Compare commits
No commits in common. "99b5b8c775e3a3b6d92c27f1e655c5b6078a7a1a" and "d3029d09e7949f92254496b2100f42c33b70d77c" have entirely different histories.
99b5b8c775
...
d3029d09e7
15
Makefile
15
Makefile
@ -39,18 +39,3 @@ generate:
|
|||||||
install_protobuf:
|
install_protobuf:
|
||||||
@go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.28
|
@go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.28
|
||||||
@go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.2
|
@go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.2
|
||||||
|
|
||||||
update_sshlib:
|
|
||||||
@rm -rf cryptolib && \
|
|
||||||
git clone https://go.googlesource.com/crypto cryptolib && \
|
|
||||||
rm -rf pkg/proto/ssh && \
|
|
||||||
mv cryptolib/ssh pkg/proto/ && \
|
|
||||||
mv cryptolib/internal/poly1305 pkg/proto/ssh/internal/ && \
|
|
||||||
find pkg/proto/ssh -type f -name '*.go' -exec sed -i 's?golang.org/x/crypto/ssh?github.com/Neur0toxine/sshpoke/pkg/proto/ssh?g' {} \; && \
|
|
||||||
find pkg/proto/ssh -type f -name '*.go' -exec sed -i 's?golang.org/x/crypto/internal/poly1305?github.com/Neur0toxine/sshpoke/pkg/proto/ssh/internal/poly1305?g' {} \; && \
|
|
||||||
find pkg/proto/ssh -type f -name '*_test.go' -delete && \
|
|
||||||
patch -p0 < patch/ssh_fakehost.patch && \
|
|
||||||
rm -rf pkg/proto/ssh/test && \
|
|
||||||
rm -rf pkg/proto/ssh/testdata && \
|
|
||||||
rm -rf cryptolib
|
|
||||||
|
|
||||||
|
9
go.mod
9
go.mod
@ -5,7 +5,6 @@ go 1.21.4
|
|||||||
require (
|
require (
|
||||||
github.com/docker/docker v24.0.7+incompatible
|
github.com/docker/docker v24.0.7+incompatible
|
||||||
github.com/docker/go-connections v0.4.0
|
github.com/docker/go-connections v0.4.0
|
||||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542
|
|
||||||
github.com/go-playground/validator/v10 v10.16.0
|
github.com/go-playground/validator/v10 v10.16.0
|
||||||
github.com/kevinburke/ssh_config v1.2.0
|
github.com/kevinburke/ssh_config v1.2.0
|
||||||
github.com/mitchellh/mapstructure v1.5.0
|
github.com/mitchellh/mapstructure v1.5.0
|
||||||
@ -14,9 +13,7 @@ require (
|
|||||||
github.com/spf13/viper v1.17.0
|
github.com/spf13/viper v1.17.0
|
||||||
go.uber.org/zap v1.26.0
|
go.uber.org/zap v1.26.0
|
||||||
golang.design/x/lockfree v0.0.1
|
golang.design/x/lockfree v0.0.1
|
||||||
golang.org/x/crypto v0.14.0
|
golang.org/x/crypto v0.13.0
|
||||||
golang.org/x/sys v0.13.0
|
|
||||||
golang.org/x/term v0.13.0
|
|
||||||
google.golang.org/grpc v1.58.2
|
google.golang.org/grpc v1.58.2
|
||||||
google.golang.org/protobuf v1.31.0
|
google.golang.org/protobuf v1.31.0
|
||||||
)
|
)
|
||||||
@ -52,8 +49,8 @@ require (
|
|||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||||
golang.org/x/mod v0.12.0 // indirect
|
golang.org/x/mod v0.12.0 // indirect
|
||||||
golang.org/x/net v0.17.0 // indirect
|
golang.org/x/net v0.15.0 // indirect
|
||||||
golang.org/x/sync v0.3.0 // indirect
|
golang.org/x/sys v0.12.0 // indirect
|
||||||
golang.org/x/text v0.13.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
golang.org/x/tools v0.13.0 // indirect
|
golang.org/x/tools v0.13.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect
|
||||||
|
18
go.sum
18
go.sum
@ -77,8 +77,6 @@ github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0X
|
|||||||
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542 h1:a9BTN+DOboRkVih0suT4zrRZ4zLGFpBtHPGNd+EQ4pI=
|
|
||||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542/go.mod h1:sJY957+7ush4oj4ElOMhUFaFIriAFNAGYzVh2tFJNy0=
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||||
@ -257,8 +255,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
|||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@ -328,8 +326,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
|||||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
|
||||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
@ -387,11 +385,11 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU=
|
||||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
@ -110,7 +110,7 @@ func (d *Docker) Listen() (chan dto.Event, error) {
|
|||||||
"container.ip", converted.IP.String(),
|
"container.ip", converted.IP.String(),
|
||||||
"container.port", converted.Port,
|
"container.port", converted.Port,
|
||||||
"container.server", converted.Server,
|
"container.server", converted.Server,
|
||||||
"container.remote_host", converted.RemoteHost)
|
"container.prefix", converted.Prefix)
|
||||||
output <- newEvent
|
output <- newEvent
|
||||||
case err := <-errSource:
|
case err := <-errSource:
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
|
@ -17,7 +17,7 @@ type labelsConfig struct {
|
|||||||
Network string `mapstructure:"sshpoke.network"`
|
Network string `mapstructure:"sshpoke.network"`
|
||||||
Server string `mapstructure:"sshpoke.server"`
|
Server string `mapstructure:"sshpoke.server"`
|
||||||
Port string `mapstructure:"sshpoke.port"`
|
Port string `mapstructure:"sshpoke.port"`
|
||||||
RemoteHost string `mapstructure:"sshpoke.remote_host"`
|
Prefix string `mapstructure:"sshpoke.prefix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type boolStr string
|
type boolStr string
|
||||||
@ -82,7 +82,7 @@ func dockerContainerToInternal(container types.Container) (result dto.Container,
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Port: uint16(port),
|
Port: uint16(port),
|
||||||
Server: labels.Server,
|
Server: labels.Server,
|
||||||
RemoteHost: labels.RemoteHost,
|
Prefix: labels.Prefix,
|
||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,55 +2,50 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/config"
|
"github.com/Neur0toxine/sshpoke/internal/config"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
||||||
|
"github.com/Neur0toxine/sshpoke/internal/server/proto/sshtun"
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSH struct {
|
type SSH struct {
|
||||||
base.Base
|
base.Base
|
||||||
params Params
|
params Params
|
||||||
auth []ssh.AuthMethod
|
sessions map[string]conn
|
||||||
conns map[string]conn
|
keys []ssh.Signer
|
||||||
rw sync.RWMutex
|
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
ctx context.Context
|
container dto.Container
|
||||||
cancel func()
|
|
||||||
tun *sshtun.Tunnel
|
tun *sshtun.Tunnel
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
||||||
drv := &SSH{
|
drv := &SSH{
|
||||||
Base: base.New(ctx, name),
|
Base: base.New(ctx, name),
|
||||||
conns: make(map[string]conn),
|
sessions: make(map[string]conn),
|
||||||
}
|
}
|
||||||
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
drv.populateFromSSHConfig()
|
drv.populateFromSSHConfig()
|
||||||
drv.auth = drv.authenticators()
|
if err := drv.parseKeys(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return drv, nil
|
return drv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) forward(val sshtun.Forward) conn {
|
|
||||||
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.DisableRemoteHostResolve, val, d.auth, d.Log())
|
|
||||||
ctx, cancel := context.WithCancel(d.Context())
|
|
||||||
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
|
|
||||||
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *SSH) populateFromSSHConfig() {
|
func (d *SSH) populateFromSSHConfig() {
|
||||||
if d.params.Auth.Directory == "" {
|
if d.params.Auth.Directory == "" {
|
||||||
return
|
return
|
||||||
@ -75,43 +70,8 @@ func (d *SSH) populateFromSSHConfig() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) Handle(event dto.Event) error {
|
func (d *SSH) Handle(event dto.Event) error {
|
||||||
defer d.rw.Unlock()
|
// TODO: Implement event handling & connections management.
|
||||||
d.rw.Lock()
|
return errors.New("server handler is not implemented yet")
|
||||||
switch event.Type {
|
|
||||||
case dto.EventStart:
|
|
||||||
conn := d.forward(sshtun.Forward{
|
|
||||||
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
|
||||||
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
|
||||||
})
|
|
||||||
d.conns[event.Container.ID] = conn
|
|
||||||
d.wg.Add(1)
|
|
||||||
case dto.EventStop:
|
|
||||||
conn, found := d.conns[event.Container.ID]
|
|
||||||
if !found {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
conn.cancel()
|
|
||||||
delete(d.conns, event.Container.ID)
|
|
||||||
d.wg.Done()
|
|
||||||
case dto.EventShutdown:
|
|
||||||
for id, conn := range d.conns {
|
|
||||||
conn.cancel()
|
|
||||||
delete(d.conns, id)
|
|
||||||
d.wg.Done()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
|
||||||
port := int(d.params.ForwardPort)
|
|
||||||
if port == 0 {
|
|
||||||
port = 80
|
|
||||||
}
|
|
||||||
return sshtun.Endpoint{
|
|
||||||
Host: remoteHost,
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) Driver() config.DriverType {
|
func (d *SSH) Driver() config.DriverType {
|
||||||
@ -119,38 +79,81 @@ func (d *SSH) Driver() config.DriverType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) WaitForShutdown() {
|
func (d *SSH) WaitForShutdown() {
|
||||||
go d.Handle(dto.Event{Type: dto.EventShutdown})
|
|
||||||
d.wg.Wait()
|
d.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) authenticators() []ssh.AuthMethod {
|
func (d *SSH) parseKeys() error {
|
||||||
auth := d.authenticator()
|
if d.params.Auth.Type != types.AuthTypeKey {
|
||||||
if auth == nil {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return []ssh.AuthMethod{auth}
|
dir, err := d.params.Auth.Directory.Resolve(true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot parse keys: %s", err)
|
||||||
|
}
|
||||||
|
if d.params.Auth.Keyfile != "" {
|
||||||
|
key, err := parseKey(path.Join(dir, d.params.Auth.Keyfile))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
d.keys = []ssh.Signer{key}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot read key directory: %s", err)
|
||||||
|
}
|
||||||
|
keys := []ssh.Signer{}
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
d.Log().Debugf("skipping '%s' because it's a directory", entry.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
d.Log().Debugf("skipping '%s' because stat failed: %s", entry.Name(), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(entry.Name(), ".pub") {
|
||||||
|
d.Log().Debugf("skipping '%s' because it's probably a public key", entry.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.Name() == "config" {
|
||||||
|
d.Log().Debugf("skipping '%s' because it's probably a ssh-config file", entry.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.Name() == "known_hosts" {
|
||||||
|
d.Log().Debugf(
|
||||||
|
"skipping '%s' because it's probably a list of hosts generated by OpenSSH", entry.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// this file is too small to be a private key
|
||||||
|
if info.Size() < 256 {
|
||||||
|
d.Log().Debugf("skipping '%s' because the file is smaller than 256 bytes", entry.Name())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, err := parseKey(path.Join(dir, entry.Name()))
|
||||||
|
if err != nil {
|
||||||
|
d.Log().Debugf("skipping '%s' because it's probably not a key: %s", entry.Name(), err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
d.Log().Debugf("loading key '%s', type: %s", entry.Name(), key.PublicKey().Type())
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return errors.New("no keys in the provided directory")
|
||||||
|
}
|
||||||
|
d.keys = keys
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) authenticator() ssh.AuthMethod {
|
func parseKey(keyFile string) (ssh.Signer, error) {
|
||||||
switch d.params.Auth.Type {
|
keyData, err := os.ReadFile(keyFile)
|
||||||
case types.AuthTypePasswordless:
|
|
||||||
return sshtun.AuthPassword("")
|
|
||||||
case types.AuthTypePassword:
|
|
||||||
return sshtun.AuthPassword(d.params.Auth.Password)
|
|
||||||
case types.AuthTypeKey:
|
|
||||||
if d.params.Auth.Keyfile != "" {
|
|
||||||
keyAuth, err := sshtun.AuthKeyFile(
|
|
||||||
types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return keyAuth
|
key, err := ssh.ParsePrivateKey(keyData)
|
||||||
}
|
|
||||||
dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return dirAuth
|
return key, nil
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -7,15 +7,13 @@ import (
|
|||||||
|
|
||||||
type Params struct {
|
type Params struct {
|
||||||
Address string `mapstructure:"address" validate:"required"`
|
Address string `mapstructure:"address" validate:"required"`
|
||||||
DefaultHost *string `mapstructure:"default_host,omitempty"`
|
|
||||||
ForwardPort uint16 `mapstructure:"forward_port"`
|
|
||||||
Auth types.Auth `mapstructure:"auth"`
|
Auth types.Auth `mapstructure:"auth"`
|
||||||
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
||||||
Domain string `mapstructure:"domain"`
|
Domain string `mapstructure:"domain"`
|
||||||
DomainProto string `mapstructure:"domain_proto"`
|
DomainProto string `mapstructure:"domain_proto"`
|
||||||
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
||||||
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
||||||
DisableRemoteHostResolve bool `mapstructure:"disable_remote_host_resolve"`
|
Prefix bool `mapstructure:"prefix"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Params) Validate() error {
|
func (p *Params) Validate() error {
|
||||||
|
@ -1,91 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
func AuthKeyFile(keyFile types.SmartPath) (ssh.AuthMethod, error) {
|
|
||||||
key, err := readKey(keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.PublicKeys(key), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthKeyDir(keyDir types.SmartPath) (ssh.AuthMethod, error) {
|
|
||||||
keys, err := readKeys(keyDir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh.PublicKeys(keys...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func AuthPassword(password string) ssh.AuthMethod {
|
|
||||||
return ssh.Password(password)
|
|
||||||
}
|
|
||||||
|
|
||||||
func readKeys(keyDir types.SmartPath) ([]ssh.Signer, error) {
|
|
||||||
dir, err := keyDir.Resolve(true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot parse keys: %s", err)
|
|
||||||
}
|
|
||||||
entries, err := os.ReadDir(dir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cannot read key directory: %s", err)
|
|
||||||
}
|
|
||||||
keys := []ssh.Signer{}
|
|
||||||
for _, entry := range entries {
|
|
||||||
if entry.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
info, err := entry.Info()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(entry.Name(), ".pub") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if entry.Name() == "config" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if entry.Name() == "known_hosts" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// this file is too small to be a private key
|
|
||||||
if info.Size() < 256 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key, err := readKey(types.SmartPath(path.Join(dir, entry.Name())))
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
if len(keys) == 0 {
|
|
||||||
return nil, errors.New("no keys in the provided directory")
|
|
||||||
}
|
|
||||||
return keys, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readKey(keyFile types.SmartPath) (ssh.Signer, error) {
|
|
||||||
fileName, err := keyFile.Resolve(false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
keyData, err := os.ReadFile(fileName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
key, err := ssh.ParsePrivateKey(keyData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
|
||||||
"github.com/kevinburke/ssh_config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) {
|
|
||||||
fileName, err := filePath.Resolve(false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
file, err := os.ReadFile(fileName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ssh_config.Decode(bytes.NewReader(file))
|
|
||||||
}
|
|
@ -1,37 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Forward struct {
|
|
||||||
// local service to be forwarded
|
|
||||||
Local Endpoint `json:"local"`
|
|
||||||
// remote forwarding port (on remote SSH server network)
|
|
||||||
Remote Endpoint `json:"remote"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddrToEndpoint(address string) Endpoint {
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err != nil && errtools.IsPortMissingErr(err) {
|
|
||||||
return Endpoint{Host: host, Port: 22}
|
|
||||||
}
|
|
||||||
portNum, err := strconv.Atoi(port)
|
|
||||||
if err != nil {
|
|
||||||
portNum = 22
|
|
||||||
}
|
|
||||||
return Endpoint{Host: host, Port: portNum}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Endpoint struct {
|
|
||||||
Host string `json:"host"`
|
|
||||||
Port int `json:"port"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *Endpoint) String() string {
|
|
||||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
|
||||||
return func(session *ssh.Session) {
|
|
||||||
stdout, err := session.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
scan := bufio.NewScanner(stdout)
|
|
||||||
for scan.Scan() {
|
|
||||||
log.Debug(scan.Text())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,217 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
"github.com/function61/gokit/app/backoff"
|
|
||||||
"github.com/function61/gokit/io/bidipipe"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionCallback func(*ssh.Session)
|
|
||||||
|
|
||||||
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
|
||||||
|
|
||||||
type Tunnel struct {
|
|
||||||
user string
|
|
||||||
address Endpoint
|
|
||||||
forward Forward
|
|
||||||
authMethods []ssh.AuthMethod
|
|
||||||
log *zap.SugaredLogger
|
|
||||||
connected atomic.Bool
|
|
||||||
fakeRemoteHost bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(address, user string, fakeRemoteHost bool,
|
|
||||||
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
|
|
||||||
return &Tunnel{
|
|
||||||
address: AddrToEndpoint(address),
|
|
||||||
user: user,
|
|
||||||
fakeRemoteHost: fakeRemoteHost,
|
|
||||||
forward: forward,
|
|
||||||
authMethods: auth,
|
|
||||||
log: log.With(zap.String("sshServer", address)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
|
||||||
if c.connected.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer c.connected.Store(false)
|
|
||||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
|
||||||
for {
|
|
||||||
c.connected.Store(true)
|
|
||||||
err := c.connect(ctx, sessionCb)
|
|
||||||
if err != nil {
|
|
||||||
c.log.Error("connect error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(backoffTime())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
||||||
// will try to re-connect
|
|
||||||
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
|
||||||
c.log.Debug("connecting")
|
|
||||||
sshConfig := &ssh.ClientConfig{
|
|
||||||
User: c.user,
|
|
||||||
Auth: c.authMethods,
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
}
|
|
||||||
|
|
||||||
var sshClient *ssh.Client
|
|
||||||
var errConnect error
|
|
||||||
|
|
||||||
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
|
|
||||||
if errConnect != nil {
|
|
||||||
return errConnect
|
|
||||||
}
|
|
||||||
|
|
||||||
defer sshClient.Close()
|
|
||||||
defer c.log.Debug("disconnecting")
|
|
||||||
|
|
||||||
c.log.Debug("connected")
|
|
||||||
listenerStopped := make(chan error)
|
|
||||||
|
|
||||||
sess, err := sshClient.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
c.log.Errorf("session error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer sess.Close()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
if sessionCb == nil {
|
|
||||||
sessionCb = func(*ssh.Session) {}
|
|
||||||
}
|
|
||||||
wg.Add(2)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
sessionCb(sess)
|
|
||||||
}()
|
|
||||||
reverseErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := <-reverseErr; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case listenerFirstErr := <-listenerStopped:
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return listenerFirstErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
|
||||||
//
|
|
||||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
|
||||||
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
|
||||||
var (
|
|
||||||
listener net.Listener
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c.fakeRemoteHost {
|
|
||||||
listener, err = sshClient.ListenTCP(&net.TCPAddr{
|
|
||||||
IP: c.ipFromAddr(sshClient.Conn.RemoteAddr()),
|
|
||||||
Port: c.forward.Remote.Port,
|
|
||||||
}, c.forward.Remote.Host)
|
|
||||||
} else {
|
|
||||||
listener, err = sshClient.Listen("tcp", c.forward.Remote.String())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer listener.Close()
|
|
||||||
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
|
|
||||||
|
|
||||||
for {
|
|
||||||
client, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go handleReverseForwardConn(client, c.forward, c.log)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Tunnel) ipFromAddr(addr net.Addr) net.IP {
|
|
||||||
host, _, _ := net.SplitHostPort(addr.String())
|
|
||||||
return net.ParseIP(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
log.Debugf("%s connected", client.RemoteAddr())
|
|
||||||
defer log.Debug("closed")
|
|
||||||
|
|
||||||
remote, err := net.Dial("tcp", forward.Local.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dial INTO local service error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// pipe data in both directions:
|
|
||||||
// - client => remote
|
|
||||||
// - remote => client
|
|
||||||
//
|
|
||||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
|
||||||
// remote endpoint.
|
|
||||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
|
||||||
// - this blocks until either of the parties' socket closes (or breaks)
|
|
||||||
if err := bidipipe.Pipe(
|
|
||||||
bidipipe.WithName("client", client),
|
|
||||||
bidipipe.WithName("remote", remote),
|
|
||||||
); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
|
||||||
dialer := net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
|
||||||
}
|
|
@ -62,10 +62,6 @@ func (k SmartPath) Resolve(shouldBeDirectory bool) (result string, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k SmartPath) String() string {
|
|
||||||
return string(k)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a Auth) Validate() error {
|
func (a Auth) Validate() error {
|
||||||
if a.Type == AuthTypePassword && a.Password == "" {
|
if a.Type == AuthTypePassword && a.Password == "" {
|
||||||
return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword)
|
return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword)
|
||||||
|
249
internal/server/proto/sshtun/tunnel.go
Normal file
249
internal/server/proto/sshtun/tunnel.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
// Copyright 2017, The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE.md file.
|
||||||
|
|
||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TunnelMode uint8
|
||||||
|
|
||||||
|
func (t TunnelMode) String() string {
|
||||||
|
switch t {
|
||||||
|
case TunnelForward:
|
||||||
|
return "->"
|
||||||
|
case TunnelReverse:
|
||||||
|
return "<-"
|
||||||
|
default:
|
||||||
|
return "<?>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
TunnelForward TunnelMode = iota
|
||||||
|
TunnelReverse
|
||||||
|
)
|
||||||
|
|
||||||
|
type logger interface {
|
||||||
|
Printf(string, ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tunnel struct {
|
||||||
|
Auth []ssh.AuthMethod
|
||||||
|
HostKeys ssh.HostKeyCallback
|
||||||
|
Mode TunnelMode
|
||||||
|
User string
|
||||||
|
HostAddr string
|
||||||
|
BindAddr string
|
||||||
|
DialAddr string
|
||||||
|
RetryInterval time.Duration
|
||||||
|
KeepAlive KeepAliveConfig
|
||||||
|
Logger logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeepAliveConfig struct {
|
||||||
|
// Interval is the amount of time in seconds to wait before the
|
||||||
|
// Tunnel client will send a keep-alive message to ensure some minimum
|
||||||
|
// traffic on the SSH connection.
|
||||||
|
Interval uint
|
||||||
|
|
||||||
|
// CountMax is the maximum number of consecutive failed responses to
|
||||||
|
// keep-alive messages the client is willing to tolerate before considering
|
||||||
|
// the SSH connection as dead.
|
||||||
|
CountMax uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Tunnel) String() string {
|
||||||
|
var left, right string
|
||||||
|
switch t.Mode {
|
||||||
|
case TunnelForward:
|
||||||
|
left, right = t.BindAddr, t.DialAddr
|
||||||
|
case TunnelReverse:
|
||||||
|
left, right = t.DialAddr, t.BindAddr
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, left, t.Mode, right)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var once sync.Once // Only print errors once per session
|
||||||
|
func() {
|
||||||
|
// Connect to the server host via SSH.
|
||||||
|
cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{
|
||||||
|
User: t.User,
|
||||||
|
Auth: t.Auth,
|
||||||
|
HostKeyCallback: t.HostKeys,
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go t.keepAliveMonitor(&once, wg, cl)
|
||||||
|
defer cl.Close()
|
||||||
|
|
||||||
|
// Attempt to bind to the inbound socket.
|
||||||
|
var ln net.Listener
|
||||||
|
switch t.Mode {
|
||||||
|
case TunnelForward:
|
||||||
|
ln, err = net.Listen("tcp", t.BindAddr)
|
||||||
|
case TunnelReverse:
|
||||||
|
ln, err = cl.Listen("tcp", t.BindAddr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// The socket is bound. Make sure we close it eventually.
|
||||||
|
bindCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
cl.Wait()
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
<-bindCtx.Done()
|
||||||
|
once.Do(func() {}) // Suppress future errors
|
||||||
|
ln.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Logger.Printf("(%v) bound Tunnel", t)
|
||||||
|
defer t.Logger.Printf("(%v) collapsed Tunnel", t)
|
||||||
|
|
||||||
|
// Accept all incoming connections.
|
||||||
|
for {
|
||||||
|
cn1, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go t.dialTunnel(bindCtx, wg, cl, cn1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(t.RetryInterval):
|
||||||
|
t.Logger.Printf("(%v) retrying...", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// The inbound connection is established. Make sure we close it eventually.
|
||||||
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-connCtx.Done()
|
||||||
|
cn1.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Establish the outbound connection.
|
||||||
|
var cn2 net.Conn
|
||||||
|
var err error
|
||||||
|
switch t.Mode {
|
||||||
|
case TunnelForward:
|
||||||
|
cn2, err = client.Dial("tcp", t.DialAddr)
|
||||||
|
case TunnelReverse:
|
||||||
|
cn2, err = net.Dial("tcp", t.DialAddr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Logger.Printf("(%v) dial error: %v", t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-connCtx.Done()
|
||||||
|
cn2.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Logger.Printf("(%v) connection established", t)
|
||||||
|
defer t.Logger.Printf("(%v) connection closed", t)
|
||||||
|
|
||||||
|
// Copy bytes from one connection to the other until one side closes.
|
||||||
|
var once sync.Once
|
||||||
|
var wg2 sync.WaitGroup
|
||||||
|
wg2.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg2.Done()
|
||||||
|
defer cancel()
|
||||||
|
if _, err := io.Copy(cn1, cn2); err != nil {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
|
||||||
|
}
|
||||||
|
once.Do(func() {}) // Suppress future errors
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg2.Done()
|
||||||
|
defer cancel()
|
||||||
|
if _, err := io.Copy(cn2, cn1); err != nil {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
|
||||||
|
}
|
||||||
|
once.Do(func() {}) // Suppress future errors
|
||||||
|
}()
|
||||||
|
wg2.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// keepAliveMonitor periodically sends messages to invoke a response.
|
||||||
|
// If the server does not respond after some period of time,
|
||||||
|
// assume that the underlying net.Conn abruptly died.
|
||||||
|
func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) {
|
||||||
|
defer wg.Done()
|
||||||
|
if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect when the SSH connection is closed.
|
||||||
|
wait := make(chan error, 1)
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
wait <- client.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Repeatedly check if the remote server is still alive.
|
||||||
|
var aliveCount int32
|
||||||
|
ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-wait:
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) })
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.KeepAlive.CountMax) {
|
||||||
|
once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) })
|
||||||
|
client.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||||||
|
if err == nil {
|
||||||
|
atomic.StoreInt32(&aliveCount, 0)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
509
internal/server/proto/sshtun/tunnel_test.go
Normal file
509
internal/server/proto/sshtun/tunnel_test.go
Normal file
@ -0,0 +1,509 @@
|
|||||||
|
// Copyright 2017, The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE.md file.
|
||||||
|
|
||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testLogger struct {
|
||||||
|
*testing.T // Already has Fatalf method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) }
|
||||||
|
|
||||||
|
func TestTunnel(t *testing.T) {
|
||||||
|
rootWG := new(sync.WaitGroup)
|
||||||
|
defer rootWG.Wait()
|
||||||
|
rootCtx, cancelAll := context.WithCancel(context.Background())
|
||||||
|
defer cancelAll()
|
||||||
|
|
||||||
|
// Open all of the TCP sockets needed for the test.
|
||||||
|
tcpLn0 := openListener(t) // Start of the chain
|
||||||
|
tcpLn1 := openListener(t) // Mid-point of the chain
|
||||||
|
tcpLn2 := openListener(t) // End of the chain
|
||||||
|
srvLn0 := openListener(t) // Socket for SSH server in reverse Mode
|
||||||
|
srvLn1 := openListener(t) // Socket for SSH server in forward Mode
|
||||||
|
|
||||||
|
tcpLn0.Close() // To be later binded by the reverse Tunnel
|
||||||
|
tcpLn1.Close() // To be later binded by the forward Tunnel
|
||||||
|
go closeWhenDone(rootCtx, tcpLn2)
|
||||||
|
go closeWhenDone(rootCtx, srvLn0)
|
||||||
|
go closeWhenDone(rootCtx, srvLn1)
|
||||||
|
|
||||||
|
// Generate keys for both the servers and clients.
|
||||||
|
clientPriv0, clientPub0 := generateKeys(t)
|
||||||
|
clientPriv1, clientPub1 := generateKeys(t)
|
||||||
|
serverPriv0, serverPub0 := generateKeys(t)
|
||||||
|
serverPriv1, serverPub1 := generateKeys(t)
|
||||||
|
|
||||||
|
// Start the SSH servers.
|
||||||
|
rootWG.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer rootWG.Done()
|
||||||
|
runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer rootWG.Done()
|
||||||
|
runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
defer wg.Wait()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create the Tunnel configurations.
|
||||||
|
tn0 := Tunnel{
|
||||||
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)},
|
||||||
|
HostKeys: ssh.FixedHostKey(serverPub0),
|
||||||
|
Mode: TunnelReverse, // Reverse Tunnel
|
||||||
|
User: "user0",
|
||||||
|
HostAddr: srvLn0.Addr().String(),
|
||||||
|
BindAddr: tcpLn0.Addr().String(),
|
||||||
|
DialAddr: tcpLn1.Addr().String(),
|
||||||
|
Logger: testLogger{t},
|
||||||
|
}
|
||||||
|
tn1 := Tunnel{
|
||||||
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)},
|
||||||
|
HostKeys: ssh.FixedHostKey(serverPub1),
|
||||||
|
Mode: TunnelForward, // Forward Tunnel
|
||||||
|
User: "user1",
|
||||||
|
HostAddr: srvLn1.Addr().String(),
|
||||||
|
BindAddr: tcpLn1.Addr().String(),
|
||||||
|
DialAddr: tcpLn2.Addr().String(),
|
||||||
|
Logger: testLogger{t},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the SSH client tunnels.
|
||||||
|
wg.Add(2)
|
||||||
|
go tn0.Bind(ctx, wg)
|
||||||
|
go tn1.Bind(ctx, wg)
|
||||||
|
|
||||||
|
t.Log("test started")
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
|
||||||
|
// Start all the transmitters.
|
||||||
|
for i := 0; i < cap(done); i++ {
|
||||||
|
i := i
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
rnd := rand.New(rand.NewSource(int64(i)))
|
||||||
|
hash := md5.New()
|
||||||
|
size := uint32((1 << 10) + rnd.Intn(1<<20))
|
||||||
|
buf4 := make([]byte, 4)
|
||||||
|
binary.LittleEndian.PutUint32(buf4, size)
|
||||||
|
|
||||||
|
cnStart, err := net.Dial("tcp", tcpLn0.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer cnStart.Close()
|
||||||
|
if _, err := cnStart.Write(buf4); err != nil {
|
||||||
|
t.Errorf("write size error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
r := io.LimitReader(rnd, int64(size))
|
||||||
|
w := io.MultiWriter(cnStart, hash)
|
||||||
|
if _, err := io.Copy(w, r); err != nil {
|
||||||
|
t.Errorf("copy error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if _, err := cnStart.Write(hash.Sum(nil)); err != nil {
|
||||||
|
t.Errorf("write hash error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := cnStart.Close(); err != nil {
|
||||||
|
t.Errorf("close error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start all the receivers.
|
||||||
|
for i := 0; i < cap(done); i++ {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
hash := md5.New()
|
||||||
|
buf4 := make([]byte, 4)
|
||||||
|
|
||||||
|
cnEnd, err := tcpLn2.Accept()
|
||||||
|
if err != nil {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer cnEnd.Close()
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(cnEnd, buf4); err != nil {
|
||||||
|
t.Errorf("read size error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
size := binary.LittleEndian.Uint32(buf4)
|
||||||
|
r := io.LimitReader(cnEnd, int64(size))
|
||||||
|
if _, err := io.Copy(hash, r); err != nil {
|
||||||
|
t.Errorf("copy error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wantHash, err := ioutil.ReadAll(cnEnd)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("read hash error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := cnEnd.Close(); err != nil {
|
||||||
|
t.Errorf("close error: %v", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) {
|
||||||
|
t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < cap(done); i++ {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
t.Errorf("timed out: %d remaining", cap(done)-i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Log("test complete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKeys generates a random pair of SSH private and public keys.
|
||||||
|
func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) {
|
||||||
|
rnd := rand.New(rand.NewSource(time.Now().Unix()))
|
||||||
|
rsaKey, err := rsa.GenerateKey(rnd, 1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate RSA key pair: %v", err)
|
||||||
|
}
|
||||||
|
priv, err = ssh.NewSignerFromKey(rsaKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate signer: %v", err)
|
||||||
|
}
|
||||||
|
pub, err = ssh.NewPublicKey(&rsaKey.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate public key: %v", err)
|
||||||
|
}
|
||||||
|
return priv, pub
|
||||||
|
}
|
||||||
|
|
||||||
|
func openListener(t *testing.T) net.Listener {
|
||||||
|
ln, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen error: %v", err)
|
||||||
|
}
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
// runServer starts an SSH server capable of handling forward and reverse
|
||||||
|
// TCP tunnels. This function blocks for the entire duration that the
|
||||||
|
// server is running and can be stopped by canceling the context.
|
||||||
|
//
|
||||||
|
// The server listens on the provided Listener and will present to clients
|
||||||
|
// a certificate from serverKey and will only accept users that match
|
||||||
|
// the provided clientKeys. Only users of the name "User%d" are allowed where
|
||||||
|
// the ID number is the index for the specified client key provided.
|
||||||
|
func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) {
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
// Generate SSH server configuration.
|
||||||
|
conf := ssh.ServerConfig{
|
||||||
|
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
var uid int
|
||||||
|
_, err := fmt.Sscanf(c.User(), "User%d", &uid)
|
||||||
|
if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) {
|
||||||
|
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conf.AddHostKey(serverKey)
|
||||||
|
|
||||||
|
// Handle every SSH client connection.
|
||||||
|
for {
|
||||||
|
tcpCn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !isDone(ctx) {
|
||||||
|
t.Errorf("accept error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go handleServerConn(t, ctx, wg, tcpCn, &conf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleServerConn handles a single SSH connection.
|
||||||
|
func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) {
|
||||||
|
defer wg.Done()
|
||||||
|
go closeWhenDone(ctx, tcpCn)
|
||||||
|
defer tcpCn.Close()
|
||||||
|
|
||||||
|
sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("new connection error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go closeWhenDone(ctx, sshCn)
|
||||||
|
defer sshCn.Close()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go handleServerChannels(t, ctx, wg, sshCn, chans)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go handleServerRequests(t, ctx, wg, sshCn, reqs)
|
||||||
|
|
||||||
|
if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) {
|
||||||
|
t.Errorf("connection error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleServerChannels handles new channels on a SSH connection.
|
||||||
|
// The client initiates a new channel when forwarding a TCP dial.
|
||||||
|
func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) {
|
||||||
|
defer wg.Done()
|
||||||
|
for nc := range chans {
|
||||||
|
if nc.ChannelType() != "direct-tcpip" {
|
||||||
|
nc.Reject(ssh.UnknownChannelType, "not implemented")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var args struct {
|
||||||
|
DstHost string
|
||||||
|
DstPort uint32
|
||||||
|
SrcHost string
|
||||||
|
SrcPort uint32
|
||||||
|
}
|
||||||
|
if !unmarshalData(nc.ExtraData(), &args) {
|
||||||
|
nc.Reject(ssh.Prohibited, "invalid request")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open a connection for both sides.
|
||||||
|
cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort))))
|
||||||
|
if err != nil {
|
||||||
|
nc.Reject(ssh.ConnectionFailed, err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ch, reqs, err := nc.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("accept channel error: %v", err)
|
||||||
|
cn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleServerRequests handles new requests on a SSH connection.
|
||||||
|
// The client initiates a new request for binding a local TCP socket.
|
||||||
|
func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) {
|
||||||
|
defer wg.Done()
|
||||||
|
for r := range reqs {
|
||||||
|
if !r.WantReply {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if r.Type != "tcpip-forward" {
|
||||||
|
r.Reply(false, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var args struct {
|
||||||
|
Host string
|
||||||
|
Port uint32
|
||||||
|
}
|
||||||
|
if !unmarshalData(r.Payload, &args) {
|
||||||
|
r.Reply(false, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port))))
|
||||||
|
if err != nil {
|
||||||
|
r.Reply(false, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp struct{ Port uint32 }
|
||||||
|
_, resp.Port = splitHostPort(ln.Addr().String())
|
||||||
|
if err := r.Reply(true, marshalData(resp)); err != nil {
|
||||||
|
t.Errorf("request reply error: %v", err)
|
||||||
|
ln.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalListener handles every new connection on the provided socket.
|
||||||
|
// All local connections will be forwarded to the client via a new channel.
|
||||||
|
func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) {
|
||||||
|
defer wg.Done()
|
||||||
|
go closeWhenDone(ctx, ln)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Open a connection for both sides.
|
||||||
|
cn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !isDone(ctx) {
|
||||||
|
t.Errorf("accept error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var args struct {
|
||||||
|
DstHost string
|
||||||
|
DstPort uint32
|
||||||
|
SrcHost string
|
||||||
|
SrcPort uint32
|
||||||
|
}
|
||||||
|
args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String())
|
||||||
|
args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String())
|
||||||
|
args.DstHost = host // This must match on client side!
|
||||||
|
ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("open channel error: %v", err)
|
||||||
|
cn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bidirCopyAndClose performs a bi-directional copy on both connections
|
||||||
|
// until either side closes the connection or the context is canceled.
|
||||||
|
// This will close both connections before returning.
|
||||||
|
func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) {
|
||||||
|
defer wg.Done()
|
||||||
|
go closeWhenDone(ctx, c1)
|
||||||
|
go closeWhenDone(ctx, c2)
|
||||||
|
defer c1.Close()
|
||||||
|
defer c2.Close()
|
||||||
|
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(c1, c2)
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(c2, c1)
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) {
|
||||||
|
t.Errorf("copy error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unmarshalData parses b into s, where s is a pointer to a struct.
|
||||||
|
// Only unexported fields of type uint32 or string are allowed.
|
||||||
|
func unmarshalData(b []byte, s interface{}) bool {
|
||||||
|
v := reflect.ValueOf(s)
|
||||||
|
if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
||||||
|
panic("destination must be pointer to struct")
|
||||||
|
}
|
||||||
|
v = v.Elem()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
switch v.Type().Field(i).Type.Kind() {
|
||||||
|
case reflect.Uint32:
|
||||||
|
if len(b) < 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
v.Field(i).Set(reflect.ValueOf(binary.BigEndian.Uint32(b)))
|
||||||
|
b = b[4:]
|
||||||
|
case reflect.String:
|
||||||
|
if len(b) < 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
n := binary.BigEndian.Uint32(b)
|
||||||
|
b = b[4:]
|
||||||
|
if uint64(len(b)) < uint64(n) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
v.Field(i).Set(reflect.ValueOf(string(b[:n])))
|
||||||
|
b = b[n:]
|
||||||
|
default:
|
||||||
|
panic("invalid field type: " + v.Type().Field(i).Type.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(b) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalData serializes s into b, where s is a struct (or a pointer to one).
|
||||||
|
// Only unexported fields of type uint32 or string are allowed.
|
||||||
|
func marshalData(s interface{}) (b []byte) {
|
||||||
|
v := reflect.ValueOf(s)
|
||||||
|
if v.IsValid() && v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
if !v.IsValid() || v.Kind() != reflect.Struct {
|
||||||
|
panic("source must be a struct")
|
||||||
|
}
|
||||||
|
var arr32 [4]byte
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
switch v.Type().Field(i).Type.Kind() {
|
||||||
|
case reflect.Uint32:
|
||||||
|
binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Uint()))
|
||||||
|
b = append(b, arr32[:]...)
|
||||||
|
case reflect.String:
|
||||||
|
binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Len()))
|
||||||
|
b = append(b, arr32[:]...)
|
||||||
|
b = append(b, v.Field(i).String()...)
|
||||||
|
default:
|
||||||
|
panic("invalid field type: " + v.Type().Field(i).Type.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitHostPort(s string) (string, uint32) {
|
||||||
|
host, port, _ := net.SplitHostPort(s)
|
||||||
|
p, _ := strconv.Atoi(port)
|
||||||
|
return host, uint32(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeWhenDone(ctx context.Context, c io.Closer) {
|
||||||
|
<-ctx.Done()
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDone(ctx context.Context) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
@ -1,23 +0,0 @@
|
|||||||
--- pkg/proto/ssh/tcpip.go 2023-11-18 21:39:15.394837005 +0300
|
|
||||||
+++ pkg/proto/ssh/tcpip.go 2023-11-18 21:38:25.706173351 +0300
|
|
||||||
@@ -101,14 +101,18 @@
|
|
||||||
// ListenTCP requests the remote peer open a listening socket
|
|
||||||
// on laddr. Incoming connections will be available by calling
|
|
||||||
// Accept on the returned net.Listener.
|
|
||||||
-func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
|
|
||||||
+func (c *Client) ListenTCP(laddr *net.TCPAddr, fakeHost ...string) (net.Listener, error) {
|
|
||||||
c.handleForwardsOnce.Do(c.handleForwards)
|
|
||||||
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
|
|
||||||
return c.autoPortListenWorkaround(laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
+ host := laddr.IP.String()
|
|
||||||
+ if len(fakeHost) > 0 {
|
|
||||||
+ host = fakeHost[0]
|
|
||||||
+ }
|
|
||||||
m := channelForwardMsg{
|
|
||||||
- laddr.IP.String(),
|
|
||||||
+ host,
|
|
||||||
uint32(laddr.Port),
|
|
||||||
}
|
|
||||||
// send message
|
|
@ -16,7 +16,7 @@ func MessageToAppEvent(event *pb.EventMessage) dto.Event {
|
|||||||
IP: net.ParseIP(event.GetContainer().GetIp()),
|
IP: net.ParseIP(event.GetContainer().GetIp()),
|
||||||
Port: uint16(event.GetContainer().GetPort()),
|
Port: uint16(event.GetContainer().GetPort()),
|
||||||
Server: event.GetContainer().GetServer(),
|
Server: event.GetContainer().GetServer(),
|
||||||
RemoteHost: event.GetContainer().GetRemoteHost(),
|
Prefix: event.GetContainer().GetPrefix(),
|
||||||
Domain: event.GetContainer().GetDomain(),
|
Domain: event.GetContainer().GetDomain(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -31,7 +31,7 @@ func AppEventToMessage(event dto.Event) *pb.EventMessage {
|
|||||||
Ip: event.Container.IP.String(),
|
Ip: event.Container.IP.String(),
|
||||||
Port: uint32(event.Container.Port),
|
Port: uint32(event.Container.Port),
|
||||||
Server: event.Container.Server,
|
Server: event.Container.Server,
|
||||||
RemoteHost: event.Container.RemoteHost,
|
Prefix: event.Container.Prefix,
|
||||||
Domain: event.Container.Domain,
|
Domain: event.Container.Domain,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,6 @@ type Container struct {
|
|||||||
IP net.IP `json:"ip"`
|
IP net.IP `json:"ip"`
|
||||||
Port uint16 `json:"port"`
|
Port uint16 `json:"port"`
|
||||||
Server string `json:"-"`
|
Server string `json:"-"`
|
||||||
RemoteHost string `json:"remote_host"`
|
Prefix string `json:"prefix"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
}
|
}
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
package errtools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func IsPortMissingErr(err error) bool {
|
|
||||||
for {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if strings.Contains(err.Error(), "missing port in address") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
err = errors.Unwrap(err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/convert"
|
"github.com/Neur0toxine/sshpoke/pkg/convert"
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/plugin/pb"
|
"github.com/Neur0toxine/sshpoke/pkg/plugin/pb"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@ -58,12 +57,12 @@ func normalizeAddr(addr string) string {
|
|||||||
if strings.HasPrefix(addr, "grpc://") {
|
if strings.HasPrefix(addr, "grpc://") {
|
||||||
addr = addr[7:]
|
addr = addr[7:]
|
||||||
}
|
}
|
||||||
_, _, err := net.SplitHostPort(addr)
|
host, port, err := net.SplitHostPort(addr)
|
||||||
if err != nil && errtools.IsPortMissingErr(err) {
|
if err != nil && err.Error() == "missing port in address" {
|
||||||
addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort))
|
host, port, err = net.SplitHostPort(addr + ":" + strconv.Itoa(DefaultPort))
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return addr
|
return host + ":" + port
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ message Container {
|
|||||||
string ip = 3;
|
string ip = 3;
|
||||||
uint32 port = 4;
|
uint32 port = 4;
|
||||||
string server = 5;
|
string server = 5;
|
||||||
string remote_host = 6;
|
string prefix = 6;
|
||||||
string domain = 7;
|
string domain = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,854 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package agent implements the ssh-agent protocol, and provides both
|
|
||||||
// a client and a server. The client can talk to a standard ssh-agent
|
|
||||||
// that uses UNIX sockets, and one could implement an alternative
|
|
||||||
// ssh-agent process using the sample server.
|
|
||||||
//
|
|
||||||
// References:
|
|
||||||
//
|
|
||||||
// [PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
|
|
||||||
package agent // import "github.com/Neur0toxine/sshpoke/pkg/proto/ssh/agent"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/dsa"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SignatureFlags represent additional flags that can be passed to the signature
|
|
||||||
// requests an defined in [PROTOCOL.agent] section 4.5.1.
|
|
||||||
type SignatureFlags uint32
|
|
||||||
|
|
||||||
// SignatureFlag values as defined in [PROTOCOL.agent] section 5.3.
|
|
||||||
const (
|
|
||||||
SignatureFlagReserved SignatureFlags = 1 << iota
|
|
||||||
SignatureFlagRsaSha256
|
|
||||||
SignatureFlagRsaSha512
|
|
||||||
)
|
|
||||||
|
|
||||||
// Agent represents the capabilities of an ssh-agent.
|
|
||||||
type Agent interface {
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
List() ([]*Key, error)
|
|
||||||
|
|
||||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
|
||||||
// in [PROTOCOL.agent] section 2.6.2.
|
|
||||||
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
|
|
||||||
|
|
||||||
// Add adds a private key to the agent.
|
|
||||||
Add(key AddedKey) error
|
|
||||||
|
|
||||||
// Remove removes all identities with the given public key.
|
|
||||||
Remove(key ssh.PublicKey) error
|
|
||||||
|
|
||||||
// RemoveAll removes all identities.
|
|
||||||
RemoveAll() error
|
|
||||||
|
|
||||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
|
||||||
Lock(passphrase []byte) error
|
|
||||||
|
|
||||||
// Unlock undoes the effect of Lock
|
|
||||||
Unlock(passphrase []byte) error
|
|
||||||
|
|
||||||
// Signers returns signers for all the known keys.
|
|
||||||
Signers() ([]ssh.Signer, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExtendedAgent interface {
|
|
||||||
Agent
|
|
||||||
|
|
||||||
// SignWithFlags signs like Sign, but allows for additional flags to be sent/received
|
|
||||||
SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error)
|
|
||||||
|
|
||||||
// Extension processes a custom extension request. Standard-compliant agents are not
|
|
||||||
// required to support any extensions, but this method allows agents to implement
|
|
||||||
// vendor-specific methods or add experimental features. See [PROTOCOL.agent] section 4.7.
|
|
||||||
// If agent extensions are unsupported entirely this method MUST return an
|
|
||||||
// ErrExtensionUnsupported error. Similarly, if just the specific extensionType in
|
|
||||||
// the request is unsupported by the agent then ErrExtensionUnsupported MUST be
|
|
||||||
// returned.
|
|
||||||
//
|
|
||||||
// In the case of success, since [PROTOCOL.agent] section 4.7 specifies that the contents
|
|
||||||
// of the response are unspecified (including the type of the message), the complete
|
|
||||||
// response will be returned as a []byte slice, including the "type" byte of the message.
|
|
||||||
Extension(extensionType string, contents []byte) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConstraintExtension describes an optional constraint defined by users.
|
|
||||||
type ConstraintExtension struct {
|
|
||||||
// ExtensionName consist of a UTF-8 string suffixed by the
|
|
||||||
// implementation domain following the naming scheme defined
|
|
||||||
// in Section 4.2 of RFC 4251, e.g. "foo@example.com".
|
|
||||||
ExtensionName string
|
|
||||||
// ExtensionDetails contains the actual content of the extended
|
|
||||||
// constraint.
|
|
||||||
ExtensionDetails []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddedKey describes an SSH key to be added to an Agent.
|
|
||||||
type AddedKey struct {
|
|
||||||
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey,
|
|
||||||
// ed25519.PrivateKey or *ecdsa.PrivateKey, which will be inserted into the
|
|
||||||
// agent.
|
|
||||||
PrivateKey interface{}
|
|
||||||
// Certificate, if not nil, is communicated to the agent and will be
|
|
||||||
// stored with the key.
|
|
||||||
Certificate *ssh.Certificate
|
|
||||||
// Comment is an optional, free-form string.
|
|
||||||
Comment string
|
|
||||||
// LifetimeSecs, if not zero, is the number of seconds that the
|
|
||||||
// agent will store the key for.
|
|
||||||
LifetimeSecs uint32
|
|
||||||
// ConfirmBeforeUse, if true, requests that the agent confirm with the
|
|
||||||
// user before each use of this key.
|
|
||||||
ConfirmBeforeUse bool
|
|
||||||
// ConstraintExtensions are the experimental or private-use constraints
|
|
||||||
// defined by users.
|
|
||||||
ConstraintExtensions []ConstraintExtension
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 3.
|
|
||||||
const (
|
|
||||||
agentRequestV1Identities = 1
|
|
||||||
agentRemoveAllV1Identities = 9
|
|
||||||
|
|
||||||
// 3.2 Requests from client to agent for protocol 2 key operations
|
|
||||||
agentAddIdentity = 17
|
|
||||||
agentRemoveIdentity = 18
|
|
||||||
agentRemoveAllIdentities = 19
|
|
||||||
agentAddIDConstrained = 25
|
|
||||||
|
|
||||||
// 3.3 Key-type independent requests from client to agent
|
|
||||||
agentAddSmartcardKey = 20
|
|
||||||
agentRemoveSmartcardKey = 21
|
|
||||||
agentLock = 22
|
|
||||||
agentUnlock = 23
|
|
||||||
agentAddSmartcardKeyConstrained = 26
|
|
||||||
|
|
||||||
// 3.7 Key constraint identifiers
|
|
||||||
agentConstrainLifetime = 1
|
|
||||||
agentConstrainConfirm = 2
|
|
||||||
// Constraint extension identifier up to version 2 of the protocol. A
|
|
||||||
// backward incompatible change will be required if we want to add support
|
|
||||||
// for SSH_AGENT_CONSTRAIN_MAXSIGN which uses the same ID.
|
|
||||||
agentConstrainExtensionV00 = 3
|
|
||||||
// Constraint extension identifier in version 3 and later of the protocol.
|
|
||||||
agentConstrainExtension = 255
|
|
||||||
)
|
|
||||||
|
|
||||||
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
|
||||||
// is a sanity check, not a limit in the spec.
|
|
||||||
const maxAgentResponseBytes = 16 << 20
|
|
||||||
|
|
||||||
// Agent messages:
|
|
||||||
// These structures mirror the wire format of the corresponding ssh agent
|
|
||||||
// messages found in [PROTOCOL.agent].
|
|
||||||
|
|
||||||
// 3.4 Generic replies from agent to client
|
|
||||||
const agentFailure = 5
|
|
||||||
|
|
||||||
type failureAgentMsg struct{}
|
|
||||||
|
|
||||||
const agentSuccess = 6
|
|
||||||
|
|
||||||
type successAgentMsg struct{}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.5.2.
|
|
||||||
const agentRequestIdentities = 11
|
|
||||||
|
|
||||||
type requestIdentitiesAgentMsg struct{}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.5.2.
|
|
||||||
const agentIdentitiesAnswer = 12
|
|
||||||
|
|
||||||
type identitiesAnswerAgentMsg struct {
|
|
||||||
NumKeys uint32 `sshtype:"12"`
|
|
||||||
Keys []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.6.2.
|
|
||||||
const agentSignRequest = 13
|
|
||||||
|
|
||||||
type signRequestAgentMsg struct {
|
|
||||||
KeyBlob []byte `sshtype:"13"`
|
|
||||||
Data []byte
|
|
||||||
Flags uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.6.2.
|
|
||||||
|
|
||||||
// 3.6 Replies from agent to client for protocol 2 key operations
|
|
||||||
const agentSignResponse = 14
|
|
||||||
|
|
||||||
type signResponseAgentMsg struct {
|
|
||||||
SigBlob []byte `sshtype:"14"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type publicKey struct {
|
|
||||||
Format string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3.7 Key constraint identifiers
|
|
||||||
type constrainLifetimeAgentMsg struct {
|
|
||||||
LifetimeSecs uint32 `sshtype:"1"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type constrainExtensionAgentMsg struct {
|
|
||||||
ExtensionName string `sshtype:"255|3"`
|
|
||||||
ExtensionDetails []byte
|
|
||||||
|
|
||||||
// Rest is a field used for parsing, not part of message
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 4.7
|
|
||||||
const agentExtension = 27
|
|
||||||
const agentExtensionFailure = 28
|
|
||||||
|
|
||||||
// ErrExtensionUnsupported indicates that an extension defined in
|
|
||||||
// [PROTOCOL.agent] section 4.7 is unsupported by the agent. Specifically this
|
|
||||||
// error indicates that the agent returned a standard SSH_AGENT_FAILURE message
|
|
||||||
// as the result of a SSH_AGENTC_EXTENSION request. Note that the protocol
|
|
||||||
// specification (and therefore this error) does not distinguish between a
|
|
||||||
// specific extension being unsupported and extensions being unsupported entirely.
|
|
||||||
var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
|
|
||||||
|
|
||||||
type extensionAgentMsg struct {
|
|
||||||
ExtensionType string `sshtype:"27"`
|
|
||||||
// NOTE: this matches OpenSSH's PROTOCOL.agent, not the IETF draft [PROTOCOL.agent],
|
|
||||||
// so that it matches what OpenSSH actually implements in the wild.
|
|
||||||
Contents []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Key represents a protocol 2 public key as defined in
|
|
||||||
// [PROTOCOL.agent], section 2.5.2.
|
|
||||||
type Key struct {
|
|
||||||
Format string
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
}
|
|
||||||
|
|
||||||
func clientErr(err error) error {
|
|
||||||
return fmt.Errorf("agent: client error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the storage form of an agent key with the format, base64
|
|
||||||
// encoded serialized key, and the comment if it is not empty.
|
|
||||||
func (k *Key) String() string {
|
|
||||||
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob)
|
|
||||||
|
|
||||||
if k.Comment != "" {
|
|
||||||
s += " " + k.Comment
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the public key type.
|
|
||||||
func (k *Key) Type() string {
|
|
||||||
return k.Format
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
|
|
||||||
func (k *Key) Marshal() []byte {
|
|
||||||
return k.Blob
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify satisfies the ssh.PublicKey interface.
|
|
||||||
func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
|
|
||||||
pubKey, err := ssh.ParsePublicKey(k.Blob)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("agent: bad public key: %v", err)
|
|
||||||
}
|
|
||||||
return pubKey.Verify(data, sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
type wireKey struct {
|
|
||||||
Format string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseKey(in []byte) (out *Key, rest []byte, err error) {
|
|
||||||
var record struct {
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ssh.Unmarshal(in, &record); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(record.Blob, &wk); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Key{
|
|
||||||
Format: wk.Format,
|
|
||||||
Blob: record.Blob,
|
|
||||||
Comment: record.Comment,
|
|
||||||
}, record.Rest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// client is a client for an ssh-agent process.
|
|
||||||
type client struct {
|
|
||||||
// conn is typically a *net.UnixConn
|
|
||||||
conn io.ReadWriter
|
|
||||||
// mu is used to prevent concurrent access to the agent
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient returns an Agent that talks to an ssh-agent process over
|
|
||||||
// the given connection.
|
|
||||||
func NewClient(rw io.ReadWriter) ExtendedAgent {
|
|
||||||
return &client{conn: rw}
|
|
||||||
}
|
|
||||||
|
|
||||||
// call sends an RPC to the agent. On success, the reply is
|
|
||||||
// unmarshaled into reply and replyType is set to the first byte of
|
|
||||||
// the reply, which contains the type of the message.
|
|
||||||
func (c *client) call(req []byte) (reply interface{}, err error) {
|
|
||||||
buf, err := c.callRaw(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
reply, err = unmarshal(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
return reply, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// callRaw sends an RPC to the agent. On success, the raw
|
|
||||||
// bytes of the response are returned; no unmarshalling is
|
|
||||||
// performed on the response.
|
|
||||||
func (c *client) callRaw(req []byte) (reply []byte, err error) {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
msg := make([]byte, 4+len(req))
|
|
||||||
binary.BigEndian.PutUint32(msg, uint32(len(req)))
|
|
||||||
copy(msg[4:], req)
|
|
||||||
if _, err = c.conn.Write(msg); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var respSizeBuf [4]byte
|
|
||||||
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
respSize := binary.BigEndian.Uint32(respSizeBuf[:])
|
|
||||||
if respSize > maxAgentResponseBytes {
|
|
||||||
return nil, clientErr(errors.New("response too large"))
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, respSize)
|
|
||||||
if _, err = io.ReadFull(c.conn, buf); err != nil {
|
|
||||||
return nil, clientErr(err)
|
|
||||||
}
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) simpleCall(req []byte) error {
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) RemoveAll() error {
|
|
||||||
return c.simpleCall([]byte{agentRemoveAllIdentities})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Remove(key ssh.PublicKey) error {
|
|
||||||
req := ssh.Marshal(&agentRemoveIdentityMsg{
|
|
||||||
KeyBlob: key.Marshal(),
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Lock(passphrase []byte) error {
|
|
||||||
req := ssh.Marshal(&agentLockMsg{
|
|
||||||
Passphrase: passphrase,
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) Unlock(passphrase []byte) error {
|
|
||||||
req := ssh.Marshal(&agentUnlockMsg{
|
|
||||||
Passphrase: passphrase,
|
|
||||||
})
|
|
||||||
return c.simpleCall(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
func (c *client) List() ([]*Key, error) {
|
|
||||||
// see [PROTOCOL.agent] section 2.5.2.
|
|
||||||
req := []byte{agentRequestIdentities}
|
|
||||||
|
|
||||||
msg, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *identitiesAnswerAgentMsg:
|
|
||||||
if msg.NumKeys > maxAgentResponseBytes/8 {
|
|
||||||
return nil, errors.New("agent: too many keys in agent reply")
|
|
||||||
}
|
|
||||||
keys := make([]*Key, msg.NumKeys)
|
|
||||||
data := msg.Keys
|
|
||||||
for i := uint32(0); i < msg.NumKeys; i++ {
|
|
||||||
var key *Key
|
|
||||||
var err error
|
|
||||||
if key, data, err = parseKey(data); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
keys[i] = key
|
|
||||||
}
|
|
||||||
return keys, nil
|
|
||||||
case *failureAgentMsg:
|
|
||||||
return nil, errors.New("agent: failed to list keys")
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
|
||||||
// in [PROTOCOL.agent] section 2.6.2.
|
|
||||||
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
|
|
||||||
return c.SignWithFlags(key, data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
|
|
||||||
req := ssh.Marshal(signRequestAgentMsg{
|
|
||||||
KeyBlob: key.Marshal(),
|
|
||||||
Data: data,
|
|
||||||
Flags: uint32(flags),
|
|
||||||
})
|
|
||||||
|
|
||||||
msg, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *signResponseAgentMsg:
|
|
||||||
var sig ssh.Signature
|
|
||||||
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &sig, nil
|
|
||||||
case *failureAgentMsg:
|
|
||||||
return nil, errors.New("agent: failed to sign challenge")
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshal parses an agent message in packet, returning the parsed
|
|
||||||
// form and the message type of packet.
|
|
||||||
func unmarshal(packet []byte) (interface{}, error) {
|
|
||||||
if len(packet) < 1 {
|
|
||||||
return nil, errors.New("agent: empty packet")
|
|
||||||
}
|
|
||||||
var msg interface{}
|
|
||||||
switch packet[0] {
|
|
||||||
case agentFailure:
|
|
||||||
return new(failureAgentMsg), nil
|
|
||||||
case agentSuccess:
|
|
||||||
return new(successAgentMsg), nil
|
|
||||||
case agentIdentitiesAnswer:
|
|
||||||
msg = new(identitiesAnswerAgentMsg)
|
|
||||||
case agentSignResponse:
|
|
||||||
msg = new(signResponseAgentMsg)
|
|
||||||
case agentV1IdentitiesAnswer:
|
|
||||||
msg = new(agentV1IdentityMsg)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(packet, msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
N *big.Int
|
|
||||||
E *big.Int
|
|
||||||
D *big.Int
|
|
||||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
G *big.Int
|
|
||||||
Y *big.Int
|
|
||||||
X *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaKeyMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
Curve string
|
|
||||||
KeyBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ed25519KeyMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
Pub []byte
|
|
||||||
Priv []byte
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a private key to the agent.
|
|
||||||
func (c *client) insertKey(s interface{}, comment string, constraints []byte) error {
|
|
||||||
var req []byte
|
|
||||||
switch k := s.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
if len(k.Primes) != 2 {
|
|
||||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
|
|
||||||
}
|
|
||||||
k.Precompute()
|
|
||||||
req = ssh.Marshal(rsaKeyMsg{
|
|
||||||
Type: ssh.KeyAlgoRSA,
|
|
||||||
N: k.N,
|
|
||||||
E: big.NewInt(int64(k.E)),
|
|
||||||
D: k.D,
|
|
||||||
Iqmp: k.Precomputed.Qinv,
|
|
||||||
P: k.Primes[0],
|
|
||||||
Q: k.Primes[1],
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case *dsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(dsaKeyMsg{
|
|
||||||
Type: ssh.KeyAlgoDSA,
|
|
||||||
P: k.P,
|
|
||||||
Q: k.Q,
|
|
||||||
G: k.G,
|
|
||||||
Y: k.Y,
|
|
||||||
X: k.X,
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case *ecdsa.PrivateKey:
|
|
||||||
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
|
|
||||||
req = ssh.Marshal(ecdsaKeyMsg{
|
|
||||||
Type: "ecdsa-sha2-" + nistID,
|
|
||||||
Curve: nistID,
|
|
||||||
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
|
|
||||||
D: k.D,
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case ed25519.PrivateKey:
|
|
||||||
req = ssh.Marshal(ed25519KeyMsg{
|
|
||||||
Type: ssh.KeyAlgoED25519,
|
|
||||||
Pub: []byte(k)[32:],
|
|
||||||
Priv: []byte(k),
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
// This function originally supported only *ed25519.PrivateKey, however the
|
|
||||||
// general idiom is to pass ed25519.PrivateKey by value, not by pointer.
|
|
||||||
// We still support the pointer variant for backwards compatibility.
|
|
||||||
case *ed25519.PrivateKey:
|
|
||||||
req = ssh.Marshal(ed25519KeyMsg{
|
|
||||||
Type: ssh.KeyAlgoED25519,
|
|
||||||
Pub: []byte(*k)[32:],
|
|
||||||
Priv: []byte(*k),
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("agent: unsupported key type %T", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if constraints are present then the message type needs to be changed.
|
|
||||||
if len(constraints) != 0 {
|
|
||||||
req[0] = agentAddIDConstrained
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
type rsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
CertBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
|
||||||
P *big.Int
|
|
||||||
Q *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type dsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
CertBytes []byte
|
|
||||||
X *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ecdsaCertMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
CertBytes []byte
|
|
||||||
D *big.Int
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ed25519CertMsg struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
CertBytes []byte
|
|
||||||
Pub []byte
|
|
||||||
Priv []byte
|
|
||||||
Comments string
|
|
||||||
Constraints []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds a private key to the agent. If a certificate is given,
|
|
||||||
// that certificate is added instead as public key.
|
|
||||||
func (c *client) Add(key AddedKey) error {
|
|
||||||
var constraints []byte
|
|
||||||
|
|
||||||
if secs := key.LifetimeSecs; secs != 0 {
|
|
||||||
constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.ConfirmBeforeUse {
|
|
||||||
constraints = append(constraints, agentConstrainConfirm)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert := key.Certificate
|
|
||||||
if cert == nil {
|
|
||||||
return c.insertKey(key.PrivateKey, key.Comment, constraints)
|
|
||||||
}
|
|
||||||
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
|
|
||||||
var req []byte
|
|
||||||
switch k := s.(type) {
|
|
||||||
case *rsa.PrivateKey:
|
|
||||||
if len(k.Primes) != 2 {
|
|
||||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes))
|
|
||||||
}
|
|
||||||
k.Precompute()
|
|
||||||
req = ssh.Marshal(rsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
D: k.D,
|
|
||||||
Iqmp: k.Precomputed.Qinv,
|
|
||||||
P: k.Primes[0],
|
|
||||||
Q: k.Primes[1],
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case *dsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(dsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
X: k.X,
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case *ecdsa.PrivateKey:
|
|
||||||
req = ssh.Marshal(ecdsaCertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
D: k.D,
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
case ed25519.PrivateKey:
|
|
||||||
req = ssh.Marshal(ed25519CertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
Pub: []byte(k)[32:],
|
|
||||||
Priv: []byte(k),
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
// This function originally supported only *ed25519.PrivateKey, however the
|
|
||||||
// general idiom is to pass ed25519.PrivateKey by value, not by pointer.
|
|
||||||
// We still support the pointer variant for backwards compatibility.
|
|
||||||
case *ed25519.PrivateKey:
|
|
||||||
req = ssh.Marshal(ed25519CertMsg{
|
|
||||||
Type: cert.Type(),
|
|
||||||
CertBytes: cert.Marshal(),
|
|
||||||
Pub: []byte(*k)[32:],
|
|
||||||
Priv: []byte(*k),
|
|
||||||
Comments: comment,
|
|
||||||
Constraints: constraints,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("agent: unsupported key type %T", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if constraints are present then the message type needs to be changed.
|
|
||||||
if len(constraints) != 0 {
|
|
||||||
req[0] = agentAddIDConstrained
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ssh.NewSignerFromKey(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
|
|
||||||
return errors.New("agent: signer and cert have different public key")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.call(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, ok := resp.(*successAgentMsg); ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("agent: failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signers provides a callback for client authentication.
|
|
||||||
func (c *client) Signers() ([]ssh.Signer, error) {
|
|
||||||
keys, err := c.List()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var result []ssh.Signer
|
|
||||||
for _, k := range keys {
|
|
||||||
result = append(result, &agentKeyringSigner{c, k})
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentKeyringSigner struct {
|
|
||||||
agent *client
|
|
||||||
pub ssh.PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey {
|
|
||||||
return s.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) {
|
|
||||||
// The agent has its own entropy source, so the rand argument is ignored.
|
|
||||||
return s.agent.Sign(s.pub, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *agentKeyringSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*ssh.Signature, error) {
|
|
||||||
if algorithm == "" || algorithm == underlyingAlgo(s.pub.Type()) {
|
|
||||||
return s.Sign(rand, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
var flags SignatureFlags
|
|
||||||
switch algorithm {
|
|
||||||
case ssh.KeyAlgoRSASHA256:
|
|
||||||
flags = SignatureFlagRsaSha256
|
|
||||||
case ssh.KeyAlgoRSASHA512:
|
|
||||||
flags = SignatureFlagRsaSha512
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("agent: unsupported algorithm %q", algorithm)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.agent.SignWithFlags(s.pub, data, flags)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ssh.AlgorithmSigner = &agentKeyringSigner{}
|
|
||||||
|
|
||||||
// certKeyAlgoNames is a mapping from known certificate algorithm names to the
|
|
||||||
// corresponding public key signature algorithm.
|
|
||||||
//
|
|
||||||
// This map must be kept in sync with the one in certs.go.
|
|
||||||
var certKeyAlgoNames = map[string]string{
|
|
||||||
ssh.CertAlgoRSAv01: ssh.KeyAlgoRSA,
|
|
||||||
ssh.CertAlgoRSASHA256v01: ssh.KeyAlgoRSASHA256,
|
|
||||||
ssh.CertAlgoRSASHA512v01: ssh.KeyAlgoRSASHA512,
|
|
||||||
ssh.CertAlgoDSAv01: ssh.KeyAlgoDSA,
|
|
||||||
ssh.CertAlgoECDSA256v01: ssh.KeyAlgoECDSA256,
|
|
||||||
ssh.CertAlgoECDSA384v01: ssh.KeyAlgoECDSA384,
|
|
||||||
ssh.CertAlgoECDSA521v01: ssh.KeyAlgoECDSA521,
|
|
||||||
ssh.CertAlgoSKECDSA256v01: ssh.KeyAlgoSKECDSA256,
|
|
||||||
ssh.CertAlgoED25519v01: ssh.KeyAlgoED25519,
|
|
||||||
ssh.CertAlgoSKED25519v01: ssh.KeyAlgoSKED25519,
|
|
||||||
}
|
|
||||||
|
|
||||||
// underlyingAlgo returns the signature algorithm associated with algo (which is
|
|
||||||
// an advertised or negotiated public key or host key algorithm). These are
|
|
||||||
// usually the same, except for certificate algorithms.
|
|
||||||
func underlyingAlgo(algo string) string {
|
|
||||||
if a, ok := certKeyAlgoNames[algo]; ok {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return algo
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calls an extension method. It is up to the agent implementation as to whether or not
|
|
||||||
// any particular extension is supported and may always return an error. Because the
|
|
||||||
// type of the response is up to the implementation, this returns the bytes of the
|
|
||||||
// response and does not attempt any type of unmarshalling.
|
|
||||||
func (c *client) Extension(extensionType string, contents []byte) ([]byte, error) {
|
|
||||||
req := ssh.Marshal(extensionAgentMsg{
|
|
||||||
ExtensionType: extensionType,
|
|
||||||
Contents: contents,
|
|
||||||
})
|
|
||||||
buf, err := c.callRaw(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return nil, errors.New("agent: failure; empty response")
|
|
||||||
}
|
|
||||||
// [PROTOCOL.agent] section 4.7 indicates that an SSH_AGENT_FAILURE message
|
|
||||||
// represents an agent that does not support the extension
|
|
||||||
if buf[0] == agentFailure {
|
|
||||||
return nil, ErrExtensionUnsupported
|
|
||||||
}
|
|
||||||
if buf[0] == agentExtensionFailure {
|
|
||||||
return nil, errors.New("agent: generic extension failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
@ -1,103 +0,0 @@
|
|||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestAgentForwarding sets up agent forwarding for the session.
|
|
||||||
// ForwardToAgent or ForwardToRemote should be called to route
|
|
||||||
// the authentication requests.
|
|
||||||
func RequestAgentForwarding(session *ssh.Session) error {
|
|
||||||
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return errors.New("forwarding request denied")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardToAgent routes authentication requests to the given keyring.
|
|
||||||
func ForwardToAgent(client *ssh.Client, keyring Agent) error {
|
|
||||||
channels := client.HandleChannelOpen(channelType)
|
|
||||||
if channels == nil {
|
|
||||||
return errors.New("agent: already have handler for " + channelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for ch := range channels {
|
|
||||||
channel, reqs, err := ch.Accept()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
go func() {
|
|
||||||
ServeAgent(keyring, channel)
|
|
||||||
channel.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const channelType = "auth-agent@openssh.com"
|
|
||||||
|
|
||||||
// ForwardToRemote routes authentication requests to the ssh-agent
|
|
||||||
// process serving on the given unix socket.
|
|
||||||
func ForwardToRemote(client *ssh.Client, addr string) error {
|
|
||||||
channels := client.HandleChannelOpen(channelType)
|
|
||||||
if channels == nil {
|
|
||||||
return errors.New("agent: already have handler for " + channelType)
|
|
||||||
}
|
|
||||||
conn, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for ch := range channels {
|
|
||||||
channel, reqs, err := ch.Accept()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
go forwardUnixSocket(channel, addr)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func forwardUnixSocket(channel ssh.Channel, addr string) {
|
|
||||||
conn, err := net.Dial("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
go func() {
|
|
||||||
io.Copy(conn, channel)
|
|
||||||
conn.(*net.UnixConn).CloseWrite()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
io.Copy(channel, conn)
|
|
||||||
channel.CloseWrite()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
conn.Close()
|
|
||||||
channel.Close()
|
|
||||||
}
|
|
@ -1,241 +0,0 @@
|
|||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type privKey struct {
|
|
||||||
signer ssh.Signer
|
|
||||||
comment string
|
|
||||||
expire *time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyring struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
keys []privKey
|
|
||||||
|
|
||||||
locked bool
|
|
||||||
passphrase []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var errLocked = errors.New("agent: locked")
|
|
||||||
|
|
||||||
// NewKeyring returns an Agent that holds keys in memory. It is safe
|
|
||||||
// for concurrent use by multiple goroutines.
|
|
||||||
func NewKeyring() Agent {
|
|
||||||
return &keyring{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAll removes all identities.
|
|
||||||
func (r *keyring) RemoveAll() error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.keys = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeLocked does the actual key removal. The caller must already be holding the
|
|
||||||
// keyring mutex.
|
|
||||||
func (r *keyring) removeLocked(want []byte) error {
|
|
||||||
found := false
|
|
||||||
for i := 0; i < len(r.keys); {
|
|
||||||
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
|
|
||||||
found = true
|
|
||||||
r.keys[i] = r.keys[len(r.keys)-1]
|
|
||||||
r.keys = r.keys[:len(r.keys)-1]
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return errors.New("agent: key not found")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes all identities with the given public key.
|
|
||||||
func (r *keyring) Remove(key ssh.PublicKey) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.removeLocked(key.Marshal())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lock locks the agent. Sign and Remove will fail, and List will return an empty list.
|
|
||||||
func (r *keyring) Lock(passphrase []byte) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.locked = true
|
|
||||||
r.passphrase = passphrase
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unlock undoes the effect of Lock
|
|
||||||
func (r *keyring) Unlock(passphrase []byte) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if !r.locked {
|
|
||||||
return errors.New("agent: not locked")
|
|
||||||
}
|
|
||||||
if 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
|
|
||||||
return fmt.Errorf("agent: incorrect passphrase")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.locked = false
|
|
||||||
r.passphrase = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// expireKeysLocked removes expired keys from the keyring. If a key was added
|
|
||||||
// with a lifetimesecs contraint and seconds >= lifetimesecs seconds have
|
|
||||||
// elapsed, it is removed. The caller *must* be holding the keyring mutex.
|
|
||||||
func (r *keyring) expireKeysLocked() {
|
|
||||||
for _, k := range r.keys {
|
|
||||||
if k.expire != nil && time.Now().After(*k.expire) {
|
|
||||||
r.removeLocked(k.signer.PublicKey().Marshal())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// List returns the identities known to the agent.
|
|
||||||
func (r *keyring) List() ([]*Key, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
// section 2.7: locked agents return empty.
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.expireKeysLocked()
|
|
||||||
var ids []*Key
|
|
||||||
for _, k := range r.keys {
|
|
||||||
pub := k.signer.PublicKey()
|
|
||||||
ids = append(ids, &Key{
|
|
||||||
Format: pub.Type(),
|
|
||||||
Blob: pub.Marshal(),
|
|
||||||
Comment: k.comment})
|
|
||||||
}
|
|
||||||
return ids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a private key to the keyring. If a certificate
|
|
||||||
// is given, that certificate is added as public key. Note that
|
|
||||||
// any constraints given are ignored.
|
|
||||||
func (r *keyring) Add(key AddedKey) error {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return errLocked
|
|
||||||
}
|
|
||||||
signer, err := ssh.NewSignerFromKey(key.PrivateKey)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert := key.Certificate; cert != nil {
|
|
||||||
signer, err = ssh.NewCertSigner(cert, signer)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p := privKey{
|
|
||||||
signer: signer,
|
|
||||||
comment: key.Comment,
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.LifetimeSecs > 0 {
|
|
||||||
t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
|
|
||||||
p.expire = &t
|
|
||||||
}
|
|
||||||
|
|
||||||
r.keys = append(r.keys, p)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sign returns a signature for the data.
|
|
||||||
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
|
|
||||||
return r.SignWithFlags(key, data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return nil, errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.expireKeysLocked()
|
|
||||||
wanted := key.Marshal()
|
|
||||||
for _, k := range r.keys {
|
|
||||||
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
|
|
||||||
if flags == 0 {
|
|
||||||
return k.signer.Sign(rand.Reader, data)
|
|
||||||
} else {
|
|
||||||
if algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner); !ok {
|
|
||||||
return nil, fmt.Errorf("agent: signature does not support non-default signature algorithm: %T", k.signer)
|
|
||||||
} else {
|
|
||||||
var algorithm string
|
|
||||||
switch flags {
|
|
||||||
case SignatureFlagRsaSha256:
|
|
||||||
algorithm = ssh.KeyAlgoRSASHA256
|
|
||||||
case SignatureFlagRsaSha512:
|
|
||||||
algorithm = ssh.KeyAlgoRSASHA512
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
|
|
||||||
}
|
|
||||||
return algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signers returns signers for all the known keys.
|
|
||||||
func (r *keyring) Signers() ([]ssh.Signer, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
if r.locked {
|
|
||||||
return nil, errLocked
|
|
||||||
}
|
|
||||||
|
|
||||||
r.expireKeysLocked()
|
|
||||||
s := make([]ssh.Signer, 0, len(r.keys))
|
|
||||||
for _, k := range r.keys {
|
|
||||||
s = append(s, k.signer)
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The keyring does not support any extensions
|
|
||||||
func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) {
|
|
||||||
return nil, ErrExtensionUnsupported
|
|
||||||
}
|
|
@ -1,570 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package agent
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/dsa"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/ed25519"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// server wraps an Agent and uses it to implement the agent side of
|
|
||||||
// the SSH-agent, wire protocol.
|
|
||||||
type server struct {
|
|
||||||
agent Agent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) processRequestBytes(reqData []byte) []byte {
|
|
||||||
rep, err := s.processRequest(reqData)
|
|
||||||
if err != nil {
|
|
||||||
if err != errLocked {
|
|
||||||
// TODO(hanwen): provide better logging interface?
|
|
||||||
log.Printf("agent %d: %v", reqData[0], err)
|
|
||||||
}
|
|
||||||
return []byte{agentFailure}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil && rep == nil {
|
|
||||||
return []byte{agentSuccess}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.Marshal(rep)
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalKey(k *Key) []byte {
|
|
||||||
var record struct {
|
|
||||||
Blob []byte
|
|
||||||
Comment string
|
|
||||||
}
|
|
||||||
record.Blob = k.Marshal()
|
|
||||||
record.Comment = k.Comment
|
|
||||||
|
|
||||||
return ssh.Marshal(&record)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See [PROTOCOL.agent], section 2.5.1.
|
|
||||||
const agentV1IdentitiesAnswer = 2
|
|
||||||
|
|
||||||
type agentV1IdentityMsg struct {
|
|
||||||
Numkeys uint32 `sshtype:"2"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentRemoveIdentityMsg struct {
|
|
||||||
KeyBlob []byte `sshtype:"18"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentLockMsg struct {
|
|
||||||
Passphrase []byte `sshtype:"22"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentUnlockMsg struct {
|
|
||||||
Passphrase []byte `sshtype:"23"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) processRequest(data []byte) (interface{}, error) {
|
|
||||||
switch data[0] {
|
|
||||||
case agentRequestV1Identities:
|
|
||||||
return &agentV1IdentityMsg{0}, nil
|
|
||||||
|
|
||||||
case agentRemoveAllV1Identities:
|
|
||||||
return nil, nil
|
|
||||||
|
|
||||||
case agentRemoveIdentity:
|
|
||||||
var req agentRemoveIdentityMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
|
|
||||||
|
|
||||||
case agentRemoveAllIdentities:
|
|
||||||
return nil, s.agent.RemoveAll()
|
|
||||||
|
|
||||||
case agentLock:
|
|
||||||
var req agentLockMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, s.agent.Lock(req.Passphrase)
|
|
||||||
|
|
||||||
case agentUnlock:
|
|
||||||
var req agentUnlockMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, s.agent.Unlock(req.Passphrase)
|
|
||||||
|
|
||||||
case agentSignRequest:
|
|
||||||
var req signRequestAgentMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var wk wireKey
|
|
||||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
k := &Key{
|
|
||||||
Format: wk.Format,
|
|
||||||
Blob: req.KeyBlob,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sig *ssh.Signature
|
|
||||||
var err error
|
|
||||||
if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
|
|
||||||
sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
|
|
||||||
} else {
|
|
||||||
sig, err = s.agent.Sign(k, req.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
|
|
||||||
|
|
||||||
case agentRequestIdentities:
|
|
||||||
keys, err := s.agent.List()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rep := identitiesAnswerAgentMsg{
|
|
||||||
NumKeys: uint32(len(keys)),
|
|
||||||
}
|
|
||||||
for _, k := range keys {
|
|
||||||
rep.Keys = append(rep.Keys, marshalKey(k)...)
|
|
||||||
}
|
|
||||||
return rep, nil
|
|
||||||
|
|
||||||
case agentAddIDConstrained, agentAddIdentity:
|
|
||||||
return nil, s.insertIdentity(data)
|
|
||||||
|
|
||||||
case agentExtension:
|
|
||||||
// Return a stub object where the whole contents of the response gets marshaled.
|
|
||||||
var responseStub struct {
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
|
|
||||||
// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
|
|
||||||
// requires that we return a standard SSH_AGENT_FAILURE message.
|
|
||||||
responseStub.Rest = []byte{agentFailure}
|
|
||||||
} else {
|
|
||||||
var req extensionAgentMsg
|
|
||||||
if err := ssh.Unmarshal(data, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
|
|
||||||
if err != nil {
|
|
||||||
// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
|
|
||||||
// message as required by [PROTOCOL.agent] section 4.7.
|
|
||||||
if err == ErrExtensionUnsupported {
|
|
||||||
responseStub.Rest = []byte{agentFailure}
|
|
||||||
} else {
|
|
||||||
// As the result of any other error processing an extension request,
|
|
||||||
// [PROTOCOL.agent] section 4.7 requires that we return a
|
|
||||||
// SSH_AGENT_EXTENSION_FAILURE code.
|
|
||||||
responseStub.Rest = []byte{agentExtensionFailure}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(res) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
responseStub.Rest = res
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return responseStub, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unknown opcode %d", data[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
|
|
||||||
for len(constraints) != 0 {
|
|
||||||
switch constraints[0] {
|
|
||||||
case agentConstrainLifetime:
|
|
||||||
lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
|
|
||||||
constraints = constraints[5:]
|
|
||||||
case agentConstrainConfirm:
|
|
||||||
confirmBeforeUse = true
|
|
||||||
constraints = constraints[1:]
|
|
||||||
case agentConstrainExtension, agentConstrainExtensionV00:
|
|
||||||
var msg constrainExtensionAgentMsg
|
|
||||||
if err = ssh.Unmarshal(constraints, &msg); err != nil {
|
|
||||||
return 0, false, nil, err
|
|
||||||
}
|
|
||||||
extensions = append(extensions, ConstraintExtension{
|
|
||||||
ExtensionName: msg.ExtensionName,
|
|
||||||
ExtensionDetails: msg.ExtensionDetails,
|
|
||||||
})
|
|
||||||
constraints = msg.Rest
|
|
||||||
default:
|
|
||||||
return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setConstraints(key *AddedKey, constraintBytes []byte) error {
|
|
||||||
lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key.LifetimeSecs = lifetimeSecs
|
|
||||||
key.ConfirmBeforeUse = confirmBeforeUse
|
|
||||||
key.ConstraintExtensions = constraintExtensions
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRSAKey(req []byte) (*AddedKey, error) {
|
|
||||||
var k rsaKeyMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if k.E.BitLen() > 30 {
|
|
||||||
return nil, errors.New("agent: RSA public exponent too large")
|
|
||||||
}
|
|
||||||
priv := &rsa.PrivateKey{
|
|
||||||
PublicKey: rsa.PublicKey{
|
|
||||||
E: int(k.E.Int64()),
|
|
||||||
N: k.N,
|
|
||||||
},
|
|
||||||
D: k.D,
|
|
||||||
Primes: []*big.Int{k.P, k.Q},
|
|
||||||
}
|
|
||||||
priv.Precompute()
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEd25519Key(req []byte) (*AddedKey, error) {
|
|
||||||
var k ed25519KeyMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
priv := ed25519.PrivateKey(k.Priv)
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseDSAKey(req []byte) (*AddedKey, error) {
|
|
||||||
var k dsaKeyMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
priv := &dsa.PrivateKey{
|
|
||||||
PublicKey: dsa.PublicKey{
|
|
||||||
Parameters: dsa.Parameters{
|
|
||||||
P: k.P,
|
|
||||||
Q: k.Q,
|
|
||||||
G: k.G,
|
|
||||||
},
|
|
||||||
Y: k.Y,
|
|
||||||
},
|
|
||||||
X: k.X,
|
|
||||||
}
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
|
|
||||||
priv = &ecdsa.PrivateKey{
|
|
||||||
D: privScalar,
|
|
||||||
}
|
|
||||||
|
|
||||||
switch curveName {
|
|
||||||
case "nistp256":
|
|
||||||
priv.Curve = elliptic.P256()
|
|
||||||
case "nistp384":
|
|
||||||
priv.Curve = elliptic.P384()
|
|
||||||
case "nistp521":
|
|
||||||
priv.Curve = elliptic.P521()
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("agent: unknown curve %q", curveName)
|
|
||||||
}
|
|
||||||
|
|
||||||
priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
|
|
||||||
if priv.X == nil || priv.Y == nil {
|
|
||||||
return nil, errors.New("agent: point not on curve")
|
|
||||||
}
|
|
||||||
|
|
||||||
return priv, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEd25519Cert(req []byte) (*AddedKey, error) {
|
|
||||||
var k ed25519CertMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
priv := ed25519.PrivateKey(k.Priv)
|
|
||||||
cert, ok := pubKey.(*ssh.Certificate)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("agent: bad ED25519 certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseECDSAKey(req []byte) (*AddedKey, error) {
|
|
||||||
var k ecdsaKeyMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRSACert(req []byte) (*AddedKey, error) {
|
|
||||||
var k rsaCertMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, ok := pubKey.(*ssh.Certificate)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("agent: bad RSA certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
|
|
||||||
var rsaPub struct {
|
|
||||||
Name string
|
|
||||||
E *big.Int
|
|
||||||
N *big.Int
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
|
|
||||||
return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rsaPub.E.BitLen() > 30 {
|
|
||||||
return nil, errors.New("agent: RSA public exponent too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
priv := rsa.PrivateKey{
|
|
||||||
PublicKey: rsa.PublicKey{
|
|
||||||
E: int(rsaPub.E.Int64()),
|
|
||||||
N: rsaPub.N,
|
|
||||||
},
|
|
||||||
D: k.D,
|
|
||||||
Primes: []*big.Int{k.Q, k.P},
|
|
||||||
}
|
|
||||||
priv.Precompute()
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseDSACert(req []byte) (*AddedKey, error) {
|
|
||||||
var k dsaCertMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cert, ok := pubKey.(*ssh.Certificate)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("agent: bad DSA certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
|
|
||||||
var w struct {
|
|
||||||
Name string
|
|
||||||
P, Q, G, Y *big.Int
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
|
|
||||||
return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
priv := &dsa.PrivateKey{
|
|
||||||
PublicKey: dsa.PublicKey{
|
|
||||||
Parameters: dsa.Parameters{
|
|
||||||
P: w.P,
|
|
||||||
Q: w.Q,
|
|
||||||
G: w.G,
|
|
||||||
},
|
|
||||||
Y: w.Y,
|
|
||||||
},
|
|
||||||
X: k.X,
|
|
||||||
}
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseECDSACert(req []byte) (*AddedKey, error) {
|
|
||||||
var k ecdsaCertMsg
|
|
||||||
if err := ssh.Unmarshal(req, &k); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cert, ok := pubKey.(*ssh.Certificate)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("agent: bad ECDSA certificate")
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
|
|
||||||
var ecdsaPub struct {
|
|
||||||
Name string
|
|
||||||
ID string
|
|
||||||
Key []byte
|
|
||||||
}
|
|
||||||
if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
|
|
||||||
if err := setConstraints(addedKey, k.Constraints); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return addedKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) insertIdentity(req []byte) error {
|
|
||||||
var record struct {
|
|
||||||
Type string `sshtype:"17|25"`
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ssh.Unmarshal(req, &record); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var addedKey *AddedKey
|
|
||||||
var err error
|
|
||||||
|
|
||||||
switch record.Type {
|
|
||||||
case ssh.KeyAlgoRSA:
|
|
||||||
addedKey, err = parseRSAKey(req)
|
|
||||||
case ssh.KeyAlgoDSA:
|
|
||||||
addedKey, err = parseDSAKey(req)
|
|
||||||
case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
|
|
||||||
addedKey, err = parseECDSAKey(req)
|
|
||||||
case ssh.KeyAlgoED25519:
|
|
||||||
addedKey, err = parseEd25519Key(req)
|
|
||||||
case ssh.CertAlgoRSAv01:
|
|
||||||
addedKey, err = parseRSACert(req)
|
|
||||||
case ssh.CertAlgoDSAv01:
|
|
||||||
addedKey, err = parseDSACert(req)
|
|
||||||
case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
|
|
||||||
addedKey, err = parseECDSACert(req)
|
|
||||||
case ssh.CertAlgoED25519v01:
|
|
||||||
addedKey, err = parseEd25519Cert(req)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("agent: not implemented: %q", record.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.agent.Add(*addedKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeAgent serves the agent protocol on the given connection. It
|
|
||||||
// returns when an I/O error occurs.
|
|
||||||
func ServeAgent(agent Agent, c io.ReadWriter) error {
|
|
||||||
s := &server{agent}
|
|
||||||
|
|
||||||
var length [4]byte
|
|
||||||
for {
|
|
||||||
if _, err := io.ReadFull(c, length[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
l := binary.BigEndian.Uint32(length[:])
|
|
||||||
if l == 0 {
|
|
||||||
return fmt.Errorf("agent: request size is 0")
|
|
||||||
}
|
|
||||||
if l > maxAgentResponseBytes {
|
|
||||||
// We also cap requests.
|
|
||||||
return fmt.Errorf("agent: request too large: %d", l)
|
|
||||||
}
|
|
||||||
|
|
||||||
req := make([]byte, l)
|
|
||||||
if _, err := io.ReadFull(c, req); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
repData := s.processRequestBytes(req)
|
|
||||||
if len(repData) > maxAgentResponseBytes {
|
|
||||||
return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
|
|
||||||
if _, err := c.Write(length[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := c.Write(repData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,97 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// buffer provides a linked list buffer for data exchange
|
|
||||||
// between producer and consumer. Theoretically the buffer is
|
|
||||||
// of unlimited capacity as it does no allocation of its own.
|
|
||||||
type buffer struct {
|
|
||||||
// protects concurrent access to head, tail and closed
|
|
||||||
*sync.Cond
|
|
||||||
|
|
||||||
head *element // the buffer that will be read first
|
|
||||||
tail *element // the buffer that will be read last
|
|
||||||
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// An element represents a single link in a linked list.
|
|
||||||
type element struct {
|
|
||||||
buf []byte
|
|
||||||
next *element
|
|
||||||
}
|
|
||||||
|
|
||||||
// newBuffer returns an empty buffer that is not closed.
|
|
||||||
func newBuffer() *buffer {
|
|
||||||
e := new(element)
|
|
||||||
b := &buffer{
|
|
||||||
Cond: newCond(),
|
|
||||||
head: e,
|
|
||||||
tail: e,
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// write makes buf available for Read to receive.
|
|
||||||
// buf must not be modified after the call to write.
|
|
||||||
func (b *buffer) write(buf []byte) {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
e := &element{buf: buf}
|
|
||||||
b.tail.next = e
|
|
||||||
b.tail = e
|
|
||||||
b.Cond.Signal()
|
|
||||||
b.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// eof closes the buffer. Reads from the buffer once all
|
|
||||||
// the data has been consumed will receive io.EOF.
|
|
||||||
func (b *buffer) eof() {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
b.closed = true
|
|
||||||
b.Cond.Signal()
|
|
||||||
b.Cond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads data from the internal buffer in buf. Reads will block
|
|
||||||
// if no data is available, or until the buffer is closed.
|
|
||||||
func (b *buffer) Read(buf []byte) (n int, err error) {
|
|
||||||
b.Cond.L.Lock()
|
|
||||||
defer b.Cond.L.Unlock()
|
|
||||||
|
|
||||||
for len(buf) > 0 {
|
|
||||||
// if there is data in b.head, copy it
|
|
||||||
if len(b.head.buf) > 0 {
|
|
||||||
r := copy(buf, b.head.buf)
|
|
||||||
buf, b.head.buf = buf[r:], b.head.buf[r:]
|
|
||||||
n += r
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// if there is a next buffer, make it the head
|
|
||||||
if len(b.head.buf) == 0 && b.head != b.tail {
|
|
||||||
b.head = b.head.next
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// if at least one byte has been copied, return
|
|
||||||
if n > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// if nothing was read, and there is nothing outstanding
|
|
||||||
// check to see if the buffer is closed.
|
|
||||||
if b.closed {
|
|
||||||
err = io.EOF
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// out of buffers, wait for producer
|
|
||||||
b.Cond.Wait()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,611 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sort"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear
|
|
||||||
// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms.
|
|
||||||
// Unlike key algorithm names, these are not passed to AlgorithmSigner nor
|
|
||||||
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format
|
|
||||||
// field.
|
|
||||||
const (
|
|
||||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
|
|
||||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
|
|
||||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
|
|
||||||
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
|
|
||||||
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
|
|
||||||
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
|
|
||||||
|
|
||||||
// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
|
|
||||||
// Certificate.Type (or PublicKey.Type), but only in
|
|
||||||
// ClientConfig.HostKeyAlgorithms.
|
|
||||||
CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com"
|
|
||||||
CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Deprecated: use CertAlgoRSAv01.
|
|
||||||
CertSigAlgoRSAv01 = CertAlgoRSAv01
|
|
||||||
// Deprecated: use CertAlgoRSASHA256v01.
|
|
||||||
CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01
|
|
||||||
// Deprecated: use CertAlgoRSASHA512v01.
|
|
||||||
CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01
|
|
||||||
)
|
|
||||||
|
|
||||||
// Certificate types distinguish between host and user
|
|
||||||
// certificates. The values can be set in the CertType field of
|
|
||||||
// Certificate.
|
|
||||||
const (
|
|
||||||
UserCert = 1
|
|
||||||
HostCert = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// Signature represents a cryptographic signature.
|
|
||||||
type Signature struct {
|
|
||||||
Format string
|
|
||||||
Blob []byte
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
|
||||||
// a certificate does not expire.
|
|
||||||
const CertTimeInfinity = 1<<64 - 1
|
|
||||||
|
|
||||||
// An Certificate represents an OpenSSH certificate as defined in
|
|
||||||
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
|
|
||||||
// PublicKey interface, so it can be unmarshaled using
|
|
||||||
// ParsePublicKey.
|
|
||||||
type Certificate struct {
|
|
||||||
Nonce []byte
|
|
||||||
Key PublicKey
|
|
||||||
Serial uint64
|
|
||||||
CertType uint32
|
|
||||||
KeyId string
|
|
||||||
ValidPrincipals []string
|
|
||||||
ValidAfter uint64
|
|
||||||
ValidBefore uint64
|
|
||||||
Permissions
|
|
||||||
Reserved []byte
|
|
||||||
SignatureKey PublicKey
|
|
||||||
Signature *Signature
|
|
||||||
}
|
|
||||||
|
|
||||||
// genericCertData holds the key-independent part of the certificate data.
|
|
||||||
// Overall, certificates contain an nonce, public key fields and
|
|
||||||
// key-independent fields.
|
|
||||||
type genericCertData struct {
|
|
||||||
Serial uint64
|
|
||||||
CertType uint32
|
|
||||||
KeyId string
|
|
||||||
ValidPrincipals []byte
|
|
||||||
ValidAfter uint64
|
|
||||||
ValidBefore uint64
|
|
||||||
CriticalOptions []byte
|
|
||||||
Extensions []byte
|
|
||||||
Reserved []byte
|
|
||||||
SignatureKey []byte
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalStringList(namelist []string) []byte {
|
|
||||||
var to []byte
|
|
||||||
for _, name := range namelist {
|
|
||||||
s := struct{ N string }{name}
|
|
||||||
to = append(to, Marshal(&s)...)
|
|
||||||
}
|
|
||||||
return to
|
|
||||||
}
|
|
||||||
|
|
||||||
type optionsTuple struct {
|
|
||||||
Key string
|
|
||||||
Value []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type optionsTupleValue struct {
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
// serialize a map of critical options or extensions
|
|
||||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
|
||||||
// we need two length prefixes for a non-empty string value
|
|
||||||
func marshalTuples(tups map[string]string) []byte {
|
|
||||||
keys := make([]string, 0, len(tups))
|
|
||||||
for key := range tups {
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
var ret []byte
|
|
||||||
for _, key := range keys {
|
|
||||||
s := optionsTuple{Key: key}
|
|
||||||
if value := tups[key]; len(value) > 0 {
|
|
||||||
s.Value = Marshal(&optionsTupleValue{value})
|
|
||||||
}
|
|
||||||
ret = append(ret, Marshal(&s)...)
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
|
||||||
// we need two length prefixes for a non-empty option value
|
|
||||||
func parseTuples(in []byte) (map[string]string, error) {
|
|
||||||
tups := map[string]string{}
|
|
||||||
var lastKey string
|
|
||||||
var haveLastKey bool
|
|
||||||
|
|
||||||
for len(in) > 0 {
|
|
||||||
var key, val, extra []byte
|
|
||||||
var ok bool
|
|
||||||
|
|
||||||
if key, in, ok = parseString(in); !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
keyStr := string(key)
|
|
||||||
// according to [PROTOCOL.certkeys], the names must be in
|
|
||||||
// lexical order.
|
|
||||||
if haveLastKey && keyStr <= lastKey {
|
|
||||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
|
|
||||||
}
|
|
||||||
lastKey, haveLastKey = keyStr, true
|
|
||||||
// the next field is a data field, which if non-empty has a string embedded
|
|
||||||
if val, in, ok = parseString(in); !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
if len(val) > 0 {
|
|
||||||
val, extra, ok = parseString(val)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
if len(extra) > 0 {
|
|
||||||
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value")
|
|
||||||
}
|
|
||||||
tups[keyStr] = string(val)
|
|
||||||
} else {
|
|
||||||
tups[keyStr] = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) {
|
|
||||||
nonce, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
|
|
||||||
key, rest, err := parsePubKey(rest, privAlgo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var g genericCertData
|
|
||||||
if err := Unmarshal(rest, &g); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &Certificate{
|
|
||||||
Nonce: nonce,
|
|
||||||
Key: key,
|
|
||||||
Serial: g.Serial,
|
|
||||||
CertType: g.CertType,
|
|
||||||
KeyId: g.KeyId,
|
|
||||||
ValidAfter: g.ValidAfter,
|
|
||||||
ValidBefore: g.ValidBefore,
|
|
||||||
}
|
|
||||||
|
|
||||||
for principals := g.ValidPrincipals; len(principals) > 0; {
|
|
||||||
principal, rest, ok := parseString(principals)
|
|
||||||
if !ok {
|
|
||||||
return nil, errShortRead
|
|
||||||
}
|
|
||||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal))
|
|
||||||
principals = rest
|
|
||||||
}
|
|
||||||
|
|
||||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.Extensions, err = parseTuples(g.Extensions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.Reserved = g.Reserved
|
|
||||||
k, err := ParsePublicKey(g.SignatureKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.SignatureKey = k
|
|
||||||
c.Signature, rest, ok = parseSignatureBody(g.Signature)
|
|
||||||
if !ok || len(rest) > 0 {
|
|
||||||
return nil, errors.New("ssh: signature parse error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type openSSHCertSigner struct {
|
|
||||||
pub *Certificate
|
|
||||||
signer Signer
|
|
||||||
}
|
|
||||||
|
|
||||||
type algorithmOpenSSHCertSigner struct {
|
|
||||||
*openSSHCertSigner
|
|
||||||
algorithmSigner AlgorithmSigner
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
|
||||||
// private key is held by signer. It returns an error if the public key in cert
|
|
||||||
// doesn't match the key used by signer.
|
|
||||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
|
|
||||||
if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) {
|
|
||||||
return nil, errors.New("ssh: signer and cert have different public key")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch s := signer.(type) {
|
|
||||||
case MultiAlgorithmSigner:
|
|
||||||
return &multiAlgorithmSigner{
|
|
||||||
AlgorithmSigner: &algorithmOpenSSHCertSigner{
|
|
||||||
&openSSHCertSigner{cert, signer}, s},
|
|
||||||
supportedAlgorithms: s.Algorithms(),
|
|
||||||
}, nil
|
|
||||||
case AlgorithmSigner:
|
|
||||||
return &algorithmOpenSSHCertSigner{
|
|
||||||
&openSSHCertSigner{cert, signer}, s}, nil
|
|
||||||
default:
|
|
||||||
return &openSSHCertSigner{cert, signer}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
|
|
||||||
return s.signer.Sign(rand, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *openSSHCertSigner) PublicKey() PublicKey {
|
|
||||||
return s.pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
|
|
||||||
return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
|
|
||||||
}
|
|
||||||
|
|
||||||
const sourceAddressCriticalOption = "source-address"
|
|
||||||
|
|
||||||
// CertChecker does the work of verifying a certificate. Its methods
|
|
||||||
// can be plugged into ClientConfig.HostKeyCallback and
|
|
||||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
|
||||||
// minimally, the IsAuthority callback should be set.
|
|
||||||
type CertChecker struct {
|
|
||||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
|
||||||
// server application layer understands. These are only used
|
|
||||||
// for user certificates.
|
|
||||||
SupportedCriticalOptions []string
|
|
||||||
|
|
||||||
// IsUserAuthority should return true if the key is recognized as an
|
|
||||||
// authority for the given user certificate. This allows for
|
|
||||||
// certificates to be signed by other certificates. This must be set
|
|
||||||
// if this CertChecker will be checking user certificates.
|
|
||||||
IsUserAuthority func(auth PublicKey) bool
|
|
||||||
|
|
||||||
// IsHostAuthority should report whether the key is recognized as
|
|
||||||
// an authority for this host. This allows for certificates to be
|
|
||||||
// signed by other keys, and for those other keys to only be valid
|
|
||||||
// signers for particular hostnames. This must be set if this
|
|
||||||
// CertChecker will be checking host certificates.
|
|
||||||
IsHostAuthority func(auth PublicKey, address string) bool
|
|
||||||
|
|
||||||
// Clock is used for verifying time stamps. If nil, time.Now
|
|
||||||
// is used.
|
|
||||||
Clock func() time.Time
|
|
||||||
|
|
||||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
|
||||||
// public key that is not a certificate. It must implement validation
|
|
||||||
// of user keys or else, if nil, all such keys are rejected.
|
|
||||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
|
||||||
|
|
||||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
|
||||||
// public key that is not a certificate. It must implement host key
|
|
||||||
// validation or else, if nil, all such keys are rejected.
|
|
||||||
HostKeyFallback HostKeyCallback
|
|
||||||
|
|
||||||
// IsRevoked is called for each certificate so that revocation checking
|
|
||||||
// can be implemented. It should return true if the given certificate
|
|
||||||
// is revoked and false otherwise. If nil, no certificates are
|
|
||||||
// considered to have been revoked.
|
|
||||||
IsRevoked func(cert *Certificate) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckHostKey checks a host key certificate. This method can be
|
|
||||||
// plugged into ClientConfig.HostKeyCallback.
|
|
||||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error {
|
|
||||||
cert, ok := key.(*Certificate)
|
|
||||||
if !ok {
|
|
||||||
if c.HostKeyFallback != nil {
|
|
||||||
return c.HostKeyFallback(addr, remote, key)
|
|
||||||
}
|
|
||||||
return errors.New("ssh: non-certificate host key")
|
|
||||||
}
|
|
||||||
if cert.CertType != HostCert {
|
|
||||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
|
|
||||||
}
|
|
||||||
if !c.IsHostAuthority(cert.SignatureKey, addr) {
|
|
||||||
return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pass hostname only as principal for host certificates (consistent with OpenSSH)
|
|
||||||
return c.CheckCert(hostname, cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate checks a user certificate. Authenticate can be used as
|
|
||||||
// a value for ServerConfig.PublicKeyCallback.
|
|
||||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) {
|
|
||||||
cert, ok := pubKey.(*Certificate)
|
|
||||||
if !ok {
|
|
||||||
if c.UserKeyFallback != nil {
|
|
||||||
return c.UserKeyFallback(conn, pubKey)
|
|
||||||
}
|
|
||||||
return nil, errors.New("ssh: normal key pairs not accepted")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cert.CertType != UserCert {
|
|
||||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
|
|
||||||
}
|
|
||||||
if !c.IsUserAuthority(cert.SignatureKey) {
|
|
||||||
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.CheckCert(conn.User(), cert); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cert.Permissions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
|
||||||
// the signature of the certificate.
|
|
||||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
|
|
||||||
if c.IsRevoked != nil && c.IsRevoked(cert) {
|
|
||||||
return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
|
|
||||||
}
|
|
||||||
|
|
||||||
for opt := range cert.CriticalOptions {
|
|
||||||
// sourceAddressCriticalOption will be enforced by
|
|
||||||
// serverAuthenticate
|
|
||||||
if opt == sourceAddressCriticalOption {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, supp := range c.SupportedCriticalOptions {
|
|
||||||
if supp == opt {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cert.ValidPrincipals) > 0 {
|
|
||||||
// By default, certs are valid for all users/hosts.
|
|
||||||
found := false
|
|
||||||
for _, p := range cert.ValidPrincipals {
|
|
||||||
if p == principal {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clock := c.Clock
|
|
||||||
if clock == nil {
|
|
||||||
clock = time.Now
|
|
||||||
}
|
|
||||||
|
|
||||||
unixNow := clock().Unix()
|
|
||||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
|
||||||
return fmt.Errorf("ssh: cert is not yet valid")
|
|
||||||
}
|
|
||||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) {
|
|
||||||
return fmt.Errorf("ssh: cert has expired")
|
|
||||||
}
|
|
||||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
|
|
||||||
return fmt.Errorf("ssh: certificate signature does not verify")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignCert signs the certificate with an authority, setting the Nonce,
|
|
||||||
// SignatureKey, and Signature fields. If the authority implements the
|
|
||||||
// MultiAlgorithmSigner interface the first algorithm in the list is used. This
|
|
||||||
// is useful if you want to sign with a specific algorithm.
|
|
||||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
|
|
||||||
c.Nonce = make([]byte, 32)
|
|
||||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.SignatureKey = authority.PublicKey()
|
|
||||||
|
|
||||||
if v, ok := authority.(MultiAlgorithmSigner); ok {
|
|
||||||
if len(v.Algorithms()) == 0 {
|
|
||||||
return errors.New("the provided authority has no signature algorithm")
|
|
||||||
}
|
|
||||||
// Use the first algorithm in the list.
|
|
||||||
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.Signature = sig
|
|
||||||
return nil
|
|
||||||
} else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA {
|
|
||||||
// Default to KeyAlgoRSASHA512 for ssh-rsa signers.
|
|
||||||
// TODO: consider using KeyAlgoRSASHA256 as default.
|
|
||||||
sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.Signature = sig
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := authority.Sign(rand, c.bytesForSigning())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.Signature = sig
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// certKeyAlgoNames is a mapping from known certificate algorithm names to the
|
|
||||||
// corresponding public key signature algorithm.
|
|
||||||
//
|
|
||||||
// This map must be kept in sync with the one in agent/client.go.
|
|
||||||
var certKeyAlgoNames = map[string]string{
|
|
||||||
CertAlgoRSAv01: KeyAlgoRSA,
|
|
||||||
CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
|
|
||||||
CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
|
|
||||||
CertAlgoDSAv01: KeyAlgoDSA,
|
|
||||||
CertAlgoECDSA256v01: KeyAlgoECDSA256,
|
|
||||||
CertAlgoECDSA384v01: KeyAlgoECDSA384,
|
|
||||||
CertAlgoECDSA521v01: KeyAlgoECDSA521,
|
|
||||||
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
|
|
||||||
CertAlgoED25519v01: KeyAlgoED25519,
|
|
||||||
CertAlgoSKED25519v01: KeyAlgoSKED25519,
|
|
||||||
}
|
|
||||||
|
|
||||||
// underlyingAlgo returns the signature algorithm associated with algo (which is
|
|
||||||
// an advertised or negotiated public key or host key algorithm). These are
|
|
||||||
// usually the same, except for certificate algorithms.
|
|
||||||
func underlyingAlgo(algo string) string {
|
|
||||||
if a, ok := certKeyAlgoNames[algo]; ok {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return algo
|
|
||||||
}
|
|
||||||
|
|
||||||
// certificateAlgo returns the certificate algorithms that uses the provided
|
|
||||||
// underlying signature algorithm.
|
|
||||||
func certificateAlgo(algo string) (certAlgo string, ok bool) {
|
|
||||||
for certName, algoName := range certKeyAlgoNames {
|
|
||||||
if algoName == algo {
|
|
||||||
return certName, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cert *Certificate) bytesForSigning() []byte {
|
|
||||||
c2 := *cert
|
|
||||||
c2.Signature = nil
|
|
||||||
out := c2.Marshal()
|
|
||||||
// Drop trailing signature length.
|
|
||||||
return out[:len(out)-4]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
|
||||||
// PublicKey interface.
|
|
||||||
func (c *Certificate) Marshal() []byte {
|
|
||||||
generic := genericCertData{
|
|
||||||
Serial: c.Serial,
|
|
||||||
CertType: c.CertType,
|
|
||||||
KeyId: c.KeyId,
|
|
||||||
ValidPrincipals: marshalStringList(c.ValidPrincipals),
|
|
||||||
ValidAfter: uint64(c.ValidAfter),
|
|
||||||
ValidBefore: uint64(c.ValidBefore),
|
|
||||||
CriticalOptions: marshalTuples(c.CriticalOptions),
|
|
||||||
Extensions: marshalTuples(c.Extensions),
|
|
||||||
Reserved: c.Reserved,
|
|
||||||
SignatureKey: c.SignatureKey.Marshal(),
|
|
||||||
}
|
|
||||||
if c.Signature != nil {
|
|
||||||
generic.Signature = Marshal(c.Signature)
|
|
||||||
}
|
|
||||||
genericBytes := Marshal(&generic)
|
|
||||||
keyBytes := c.Key.Marshal()
|
|
||||||
_, keyBytes, _ = parseString(keyBytes)
|
|
||||||
prefix := Marshal(&struct {
|
|
||||||
Name string
|
|
||||||
Nonce []byte
|
|
||||||
Key []byte `ssh:"rest"`
|
|
||||||
}{c.Type(), c.Nonce, keyBytes})
|
|
||||||
|
|
||||||
result := make([]byte, 0, len(prefix)+len(genericBytes))
|
|
||||||
result = append(result, prefix...)
|
|
||||||
result = append(result, genericBytes...)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the certificate algorithm name. It is part of the PublicKey interface.
|
|
||||||
func (c *Certificate) Type() string {
|
|
||||||
certName, ok := certificateAlgo(c.Key.Type())
|
|
||||||
if !ok {
|
|
||||||
panic("unknown certificate type for key type " + c.Key.Type())
|
|
||||||
}
|
|
||||||
return certName
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify verifies a signature against the certificate's public
|
|
||||||
// key. It is part of the PublicKey interface.
|
|
||||||
func (c *Certificate) Verify(data []byte, sig *Signature) error {
|
|
||||||
return c.Key.Verify(data, sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) {
|
|
||||||
format, in, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
out = &Signature{
|
|
||||||
Format: string(format),
|
|
||||||
}
|
|
||||||
|
|
||||||
if out.Blob, in, ok = parseString(in); !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch out.Format {
|
|
||||||
case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01:
|
|
||||||
out.Rest = in
|
|
||||||
return out, nil, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, in, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) {
|
|
||||||
sigBytes, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
out, trailing, ok := parseSignatureBody(sigBytes)
|
|
||||||
if !ok || len(trailing) > 0 {
|
|
||||||
return nil, nil, false
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,633 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
minPacketLength = 9
|
|
||||||
// channelMaxPacket contains the maximum number of bytes that will be
|
|
||||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
|
||||||
// the minimum.
|
|
||||||
channelMaxPacket = 1 << 15
|
|
||||||
// We follow OpenSSH here.
|
|
||||||
channelWindowSize = 64 * channelMaxPacket
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewChannel represents an incoming request to a channel. It must either be
|
|
||||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
|
||||||
type NewChannel interface {
|
|
||||||
// Accept accepts the channel creation request. It returns the Channel
|
|
||||||
// and a Go channel containing SSH requests. The Go channel must be
|
|
||||||
// serviced otherwise the Channel will hang.
|
|
||||||
Accept() (Channel, <-chan *Request, error)
|
|
||||||
|
|
||||||
// Reject rejects the channel creation request. After calling
|
|
||||||
// this, no other methods on the Channel may be called.
|
|
||||||
Reject(reason RejectionReason, message string) error
|
|
||||||
|
|
||||||
// ChannelType returns the type of the channel, as supplied by the
|
|
||||||
// client.
|
|
||||||
ChannelType() string
|
|
||||||
|
|
||||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
|
||||||
// by the client. This data is specific to the channel type.
|
|
||||||
ExtraData() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
|
||||||
// that is multiplexed over an SSH connection.
|
|
||||||
type Channel interface {
|
|
||||||
// Read reads up to len(data) bytes from the channel.
|
|
||||||
Read(data []byte) (int, error)
|
|
||||||
|
|
||||||
// Write writes len(data) bytes to the channel.
|
|
||||||
Write(data []byte) (int, error)
|
|
||||||
|
|
||||||
// Close signals end of channel use. No data may be sent after this
|
|
||||||
// call.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// CloseWrite signals the end of sending in-band
|
|
||||||
// data. Requests may still be sent, and the other side may
|
|
||||||
// still send data
|
|
||||||
CloseWrite() error
|
|
||||||
|
|
||||||
// SendRequest sends a channel request. If wantReply is true,
|
|
||||||
// it will wait for a reply and return the result as a
|
|
||||||
// boolean, otherwise the return value will be false. Channel
|
|
||||||
// requests are out-of-band messages so they may be sent even
|
|
||||||
// if the data stream is closed or blocked by flow control.
|
|
||||||
// If the channel is closed before a reply is returned, io.EOF
|
|
||||||
// is returned.
|
|
||||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
|
|
||||||
|
|
||||||
// Stderr returns an io.ReadWriter that writes to this channel
|
|
||||||
// with the extended data type set to stderr. Stderr may
|
|
||||||
// safely be read and written from a different goroutine than
|
|
||||||
// Read and Write respectively.
|
|
||||||
Stderr() io.ReadWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request is a request sent outside of the normal stream of
|
|
||||||
// data. Requests can either be specific to an SSH channel, or they
|
|
||||||
// can be global.
|
|
||||||
type Request struct {
|
|
||||||
Type string
|
|
||||||
WantReply bool
|
|
||||||
Payload []byte
|
|
||||||
|
|
||||||
ch *channel
|
|
||||||
mux *mux
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reply sends a response to a request. It must be called for all requests
|
|
||||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
|
||||||
// ignored for replies to channel-specific requests.
|
|
||||||
func (r *Request) Reply(ok bool, payload []byte) error {
|
|
||||||
if !r.WantReply {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.ch == nil {
|
|
||||||
return r.mux.ackRequest(ok, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.ch.ackRequest(ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RejectionReason is an enumeration used when rejecting channel creation
|
|
||||||
// requests. See RFC 4254, section 5.1.
|
|
||||||
type RejectionReason uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
Prohibited RejectionReason = iota + 1
|
|
||||||
ConnectionFailed
|
|
||||||
UnknownChannelType
|
|
||||||
ResourceShortage
|
|
||||||
)
|
|
||||||
|
|
||||||
// String converts the rejection reason to human readable form.
|
|
||||||
func (r RejectionReason) String() string {
|
|
||||||
switch r {
|
|
||||||
case Prohibited:
|
|
||||||
return "administratively prohibited"
|
|
||||||
case ConnectionFailed:
|
|
||||||
return "connect failed"
|
|
||||||
case UnknownChannelType:
|
|
||||||
return "unknown channel type"
|
|
||||||
case ResourceShortage:
|
|
||||||
return "resource shortage"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("unknown reason %d", int(r))
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a uint32, b int) uint32 {
|
|
||||||
if a < uint32(b) {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return uint32(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
type channelDirection uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
channelInbound channelDirection = iota
|
|
||||||
channelOutbound
|
|
||||||
)
|
|
||||||
|
|
||||||
// channel is an implementation of the Channel interface that works
|
|
||||||
// with the mux class.
|
|
||||||
type channel struct {
|
|
||||||
// R/O after creation
|
|
||||||
chanType string
|
|
||||||
extraData []byte
|
|
||||||
localId, remoteId uint32
|
|
||||||
|
|
||||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
|
||||||
// payload sizes of normal and extended data packets for
|
|
||||||
// receiving and sending, respectively. The wire packet will
|
|
||||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
|
||||||
maxIncomingPayload uint32
|
|
||||||
maxRemotePayload uint32
|
|
||||||
|
|
||||||
mux *mux
|
|
||||||
|
|
||||||
// decided is set to true if an accept or reject message has been sent
|
|
||||||
// (for outbound channels) or received (for inbound channels).
|
|
||||||
decided bool
|
|
||||||
|
|
||||||
// direction contains either channelOutbound, for channels created
|
|
||||||
// locally, or channelInbound, for channels created by the peer.
|
|
||||||
direction channelDirection
|
|
||||||
|
|
||||||
// Pending internal channel messages.
|
|
||||||
msg chan interface{}
|
|
||||||
|
|
||||||
// Since requests have no ID, there can be only one request
|
|
||||||
// with WantReply=true outstanding. This lock is held by a
|
|
||||||
// goroutine that has such an outgoing request pending.
|
|
||||||
sentRequestMu sync.Mutex
|
|
||||||
|
|
||||||
incomingRequests chan *Request
|
|
||||||
|
|
||||||
sentEOF bool
|
|
||||||
|
|
||||||
// thread-safe data
|
|
||||||
remoteWin window
|
|
||||||
pending *buffer
|
|
||||||
extPending *buffer
|
|
||||||
|
|
||||||
// windowMu protects myWindow, the flow-control window.
|
|
||||||
windowMu sync.Mutex
|
|
||||||
myWindow uint32
|
|
||||||
|
|
||||||
// writeMu serializes calls to mux.conn.writePacket() and
|
|
||||||
// protects sentClose and packetPool. This mutex must be
|
|
||||||
// different from windowMu, as writePacket can block if there
|
|
||||||
// is a key exchange pending.
|
|
||||||
writeMu sync.Mutex
|
|
||||||
sentClose bool
|
|
||||||
|
|
||||||
// packetPool has a buffer for each extended channel ID to
|
|
||||||
// save allocations during writes.
|
|
||||||
packetPool map[uint32][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
|
||||||
// sentClose. This method takes the lock c.writeMu.
|
|
||||||
func (ch *channel) writePacket(packet []byte) error {
|
|
||||||
ch.writeMu.Lock()
|
|
||||||
if ch.sentClose {
|
|
||||||
ch.writeMu.Unlock()
|
|
||||||
return io.EOF
|
|
||||||
}
|
|
||||||
ch.sentClose = (packet[0] == msgChannelClose)
|
|
||||||
err := ch.mux.conn.writePacket(packet)
|
|
||||||
ch.writeMu.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) sendMessage(msg interface{}) error {
|
|
||||||
if debugMux {
|
|
||||||
log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
p := Marshal(msg)
|
|
||||||
binary.BigEndian.PutUint32(p[1:], ch.remoteId)
|
|
||||||
return ch.writePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteExtended writes data to a specific extended stream. These streams are
|
|
||||||
// used, for example, for stderr.
|
|
||||||
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
|
|
||||||
if ch.sentEOF {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
|
||||||
opCode := byte(msgChannelData)
|
|
||||||
headerLength := uint32(9)
|
|
||||||
if extendedCode > 0 {
|
|
||||||
headerLength += 4
|
|
||||||
opCode = msgChannelExtendedData
|
|
||||||
}
|
|
||||||
|
|
||||||
ch.writeMu.Lock()
|
|
||||||
packet := ch.packetPool[extendedCode]
|
|
||||||
// We don't remove the buffer from packetPool, so
|
|
||||||
// WriteExtended calls from different goroutines will be
|
|
||||||
// flagged as errors by the race detector.
|
|
||||||
ch.writeMu.Unlock()
|
|
||||||
|
|
||||||
for len(data) > 0 {
|
|
||||||
space := min(ch.maxRemotePayload, len(data))
|
|
||||||
if space, err = ch.remoteWin.reserve(space); err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
if want := headerLength + space; uint32(cap(packet)) < want {
|
|
||||||
packet = make([]byte, want)
|
|
||||||
} else {
|
|
||||||
packet = packet[:want]
|
|
||||||
}
|
|
||||||
|
|
||||||
todo := data[:space]
|
|
||||||
|
|
||||||
packet[0] = opCode
|
|
||||||
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
|
|
||||||
if extendedCode > 0 {
|
|
||||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
|
|
||||||
copy(packet[headerLength:], todo)
|
|
||||||
if err = ch.writePacket(packet); err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n += len(todo)
|
|
||||||
data = data[len(todo):]
|
|
||||||
}
|
|
||||||
|
|
||||||
ch.writeMu.Lock()
|
|
||||||
ch.packetPool[extendedCode] = packet
|
|
||||||
ch.writeMu.Unlock()
|
|
||||||
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) handleData(packet []byte) error {
|
|
||||||
headerLen := 9
|
|
||||||
isExtendedData := packet[0] == msgChannelExtendedData
|
|
||||||
if isExtendedData {
|
|
||||||
headerLen = 13
|
|
||||||
}
|
|
||||||
if len(packet) < headerLen {
|
|
||||||
// malformed data packet
|
|
||||||
return parseError(packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var extended uint32
|
|
||||||
if isExtendedData {
|
|
||||||
extended = binary.BigEndian.Uint32(packet[5:])
|
|
||||||
}
|
|
||||||
|
|
||||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
|
|
||||||
if length == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if length > ch.maxIncomingPayload {
|
|
||||||
// TODO(hanwen): should send Disconnect?
|
|
||||||
return errors.New("ssh: incoming packet exceeds maximum payload size")
|
|
||||||
}
|
|
||||||
|
|
||||||
data := packet[headerLen:]
|
|
||||||
if length != uint32(len(data)) {
|
|
||||||
return errors.New("ssh: wrong packet length")
|
|
||||||
}
|
|
||||||
|
|
||||||
ch.windowMu.Lock()
|
|
||||||
if ch.myWindow < length {
|
|
||||||
ch.windowMu.Unlock()
|
|
||||||
// TODO(hanwen): should send Disconnect with reason?
|
|
||||||
return errors.New("ssh: remote side wrote too much")
|
|
||||||
}
|
|
||||||
ch.myWindow -= length
|
|
||||||
ch.windowMu.Unlock()
|
|
||||||
|
|
||||||
if extended == 1 {
|
|
||||||
ch.extPending.write(data)
|
|
||||||
} else if extended > 0 {
|
|
||||||
// discard other extended data.
|
|
||||||
} else {
|
|
||||||
ch.pending.write(data)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) adjustWindow(n uint32) error {
|
|
||||||
c.windowMu.Lock()
|
|
||||||
// Since myWindow is managed on our side, and can never exceed
|
|
||||||
// the initial window setting, we don't worry about overflow.
|
|
||||||
c.myWindow += uint32(n)
|
|
||||||
c.windowMu.Unlock()
|
|
||||||
return c.sendMessage(windowAdjustMsg{
|
|
||||||
AdditionalBytes: uint32(n),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
|
|
||||||
switch extended {
|
|
||||||
case 1:
|
|
||||||
n, err = c.extPending.Read(data)
|
|
||||||
case 0:
|
|
||||||
n, err = c.pending.Read(data)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
|
|
||||||
}
|
|
||||||
|
|
||||||
if n > 0 {
|
|
||||||
err = c.adjustWindow(uint32(n))
|
|
||||||
// sendWindowAdjust can return io.EOF if the remote
|
|
||||||
// peer has closed the connection, however we want to
|
|
||||||
// defer forwarding io.EOF to the caller of Read until
|
|
||||||
// the buffer has been drained.
|
|
||||||
if n > 0 && err == io.EOF {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *channel) close() {
|
|
||||||
c.pending.eof()
|
|
||||||
c.extPending.eof()
|
|
||||||
close(c.msg)
|
|
||||||
close(c.incomingRequests)
|
|
||||||
c.writeMu.Lock()
|
|
||||||
// This is not necessary for a normal channel teardown, but if
|
|
||||||
// there was another error, it is.
|
|
||||||
c.sentClose = true
|
|
||||||
c.writeMu.Unlock()
|
|
||||||
// Unblock writers.
|
|
||||||
c.remoteWin.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseMessageReceived is called when a success or failure message is
|
|
||||||
// received on a channel to check that such a message is reasonable for the
|
|
||||||
// given channel.
|
|
||||||
func (ch *channel) responseMessageReceived() error {
|
|
||||||
if ch.direction == channelInbound {
|
|
||||||
return errors.New("ssh: channel response message received on inbound channel")
|
|
||||||
}
|
|
||||||
if ch.decided {
|
|
||||||
return errors.New("ssh: duplicate response received for channel")
|
|
||||||
}
|
|
||||||
ch.decided = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) handlePacket(packet []byte) error {
|
|
||||||
switch packet[0] {
|
|
||||||
case msgChannelData, msgChannelExtendedData:
|
|
||||||
return ch.handleData(packet)
|
|
||||||
case msgChannelClose:
|
|
||||||
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
|
|
||||||
ch.mux.chanList.remove(ch.localId)
|
|
||||||
ch.close()
|
|
||||||
return nil
|
|
||||||
case msgChannelEOF:
|
|
||||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
|
||||||
// it is logical to signal EOF at the same time.
|
|
||||||
ch.extPending.eof()
|
|
||||||
ch.pending.eof()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
decoded, err := decode(packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := decoded.(type) {
|
|
||||||
case *channelOpenFailureMsg:
|
|
||||||
if err := ch.responseMessageReceived(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ch.mux.chanList.remove(msg.PeersID)
|
|
||||||
ch.msg <- msg
|
|
||||||
case *channelOpenConfirmMsg:
|
|
||||||
if err := ch.responseMessageReceived(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
|
||||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
|
|
||||||
}
|
|
||||||
ch.remoteId = msg.MyID
|
|
||||||
ch.maxRemotePayload = msg.MaxPacketSize
|
|
||||||
ch.remoteWin.add(msg.MyWindow)
|
|
||||||
ch.msg <- msg
|
|
||||||
case *windowAdjustMsg:
|
|
||||||
if !ch.remoteWin.add(msg.AdditionalBytes) {
|
|
||||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
|
|
||||||
}
|
|
||||||
case *channelRequestMsg:
|
|
||||||
req := Request{
|
|
||||||
Type: msg.Request,
|
|
||||||
WantReply: msg.WantReply,
|
|
||||||
Payload: msg.RequestSpecificData,
|
|
||||||
ch: ch,
|
|
||||||
}
|
|
||||||
|
|
||||||
ch.incomingRequests <- &req
|
|
||||||
default:
|
|
||||||
ch.msg <- msg
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
|
|
||||||
ch := &channel{
|
|
||||||
remoteWin: window{Cond: newCond()},
|
|
||||||
myWindow: channelWindowSize,
|
|
||||||
pending: newBuffer(),
|
|
||||||
extPending: newBuffer(),
|
|
||||||
direction: direction,
|
|
||||||
incomingRequests: make(chan *Request, chanSize),
|
|
||||||
msg: make(chan interface{}, chanSize),
|
|
||||||
chanType: chanType,
|
|
||||||
extraData: extraData,
|
|
||||||
mux: m,
|
|
||||||
packetPool: make(map[uint32][]byte),
|
|
||||||
}
|
|
||||||
ch.localId = m.chanList.add(ch)
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
var errUndecided = errors.New("ssh: must Accept or Reject channel")
|
|
||||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
|
|
||||||
|
|
||||||
type extChannel struct {
|
|
||||||
code uint32
|
|
||||||
ch *channel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *extChannel) Write(data []byte) (n int, err error) {
|
|
||||||
return e.ch.WriteExtended(data, e.code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *extChannel) Read(data []byte) (n int, err error) {
|
|
||||||
return e.ch.ReadExtended(data, e.code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
|
|
||||||
if ch.decided {
|
|
||||||
return nil, nil, errDecidedAlready
|
|
||||||
}
|
|
||||||
ch.maxIncomingPayload = channelMaxPacket
|
|
||||||
confirm := channelOpenConfirmMsg{
|
|
||||||
PeersID: ch.remoteId,
|
|
||||||
MyID: ch.localId,
|
|
||||||
MyWindow: ch.myWindow,
|
|
||||||
MaxPacketSize: ch.maxIncomingPayload,
|
|
||||||
}
|
|
||||||
ch.decided = true
|
|
||||||
if err := ch.sendMessage(confirm); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch, ch.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Reject(reason RejectionReason, message string) error {
|
|
||||||
if ch.decided {
|
|
||||||
return errDecidedAlready
|
|
||||||
}
|
|
||||||
reject := channelOpenFailureMsg{
|
|
||||||
PeersID: ch.remoteId,
|
|
||||||
Reason: reason,
|
|
||||||
Message: message,
|
|
||||||
Language: "en",
|
|
||||||
}
|
|
||||||
ch.decided = true
|
|
||||||
return ch.sendMessage(reject)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Read(data []byte) (int, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return 0, errUndecided
|
|
||||||
}
|
|
||||||
return ch.ReadExtended(data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Write(data []byte) (int, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return 0, errUndecided
|
|
||||||
}
|
|
||||||
return ch.WriteExtended(data, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) CloseWrite() error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
ch.sentEOF = true
|
|
||||||
return ch.sendMessage(channelEOFMsg{
|
|
||||||
PeersID: ch.remoteId})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Close() error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch.sendMessage(channelCloseMsg{
|
|
||||||
PeersID: ch.remoteId})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
|
||||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
|
||||||
func (ch *channel) Extended(code uint32) io.ReadWriter {
|
|
||||||
if !ch.decided {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &extChannel{code, ch}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) Stderr() io.ReadWriter {
|
|
||||||
return ch.Extended(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
|
||||||
if !ch.decided {
|
|
||||||
return false, errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
if wantReply {
|
|
||||||
ch.sentRequestMu.Lock()
|
|
||||||
defer ch.sentRequestMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := channelRequestMsg{
|
|
||||||
PeersID: ch.remoteId,
|
|
||||||
Request: name,
|
|
||||||
WantReply: wantReply,
|
|
||||||
RequestSpecificData: payload,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ch.sendMessage(msg); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if wantReply {
|
|
||||||
m, ok := (<-ch.msg)
|
|
||||||
if !ok {
|
|
||||||
return false, io.EOF
|
|
||||||
}
|
|
||||||
switch m.(type) {
|
|
||||||
case *channelRequestFailureMsg:
|
|
||||||
return false, nil
|
|
||||||
case *channelRequestSuccessMsg:
|
|
||||||
return true, nil
|
|
||||||
default:
|
|
||||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackRequest either sends an ack or nack to the channel request.
|
|
||||||
func (ch *channel) ackRequest(ok bool) error {
|
|
||||||
if !ch.decided {
|
|
||||||
return errUndecided
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg interface{}
|
|
||||||
if !ok {
|
|
||||||
msg = channelRequestFailureMsg{
|
|
||||||
PeersID: ch.remoteId,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg = channelRequestSuccessMsg{
|
|
||||||
PeersID: ch.remoteId,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ch.sendMessage(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) ChannelType() string {
|
|
||||||
return ch.chanType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ch *channel) ExtraData() []byte {
|
|
||||||
return ch.extraData
|
|
||||||
}
|
|
@ -1,789 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/des"
|
|
||||||
"crypto/rc4"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"hash"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh/internal/poly1305"
|
|
||||||
"golang.org/x/crypto/chacha20"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
|
||||||
|
|
||||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
|
||||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
|
||||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
|
||||||
// waffles on about reasonable limits.
|
|
||||||
//
|
|
||||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
|
||||||
// the same. maxPacket is also used to ensure that uint32
|
|
||||||
// length fields do not overflow, so it should remain well
|
|
||||||
// below 4G.
|
|
||||||
maxPacket = 256 * 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
|
||||||
// by the transport before the first key-exchange.
|
|
||||||
type noneCipher struct{}
|
|
||||||
|
|
||||||
func (c noneCipher) XORKeyStream(dst, src []byte) {
|
|
||||||
copy(dst, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) {
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return cipher.NewCTR(c, iv), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRC4(key, iv []byte) (cipher.Stream, error) {
|
|
||||||
return rc4.NewCipher(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
type cipherMode struct {
|
|
||||||
keySize int
|
|
||||||
ivSize int
|
|
||||||
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
|
||||||
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
|
||||||
stream, err := createFunc(key, iv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var streamDump []byte
|
|
||||||
if skip > 0 {
|
|
||||||
streamDump = make([]byte, 512)
|
|
||||||
}
|
|
||||||
|
|
||||||
for remainingToDump := skip; remainingToDump > 0; {
|
|
||||||
dumpThisTime := remainingToDump
|
|
||||||
if dumpThisTime > len(streamDump) {
|
|
||||||
dumpThisTime = len(streamDump)
|
|
||||||
}
|
|
||||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
|
|
||||||
remainingToDump -= dumpThisTime
|
|
||||||
}
|
|
||||||
|
|
||||||
mac := macModes[algs.MAC].new(macKey)
|
|
||||||
return &streamPacketCipher{
|
|
||||||
mac: mac,
|
|
||||||
etm: macModes[algs.MAC].etm,
|
|
||||||
macResult: make([]byte, mac.Size()),
|
|
||||||
cipher: stream,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
|
||||||
// are not supported and will not be negotiated, even if explicitly requested in
|
|
||||||
// ClientConfig.Crypto.Ciphers.
|
|
||||||
var cipherModes = map[string]*cipherMode{
|
|
||||||
// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
|
|
||||||
// are defined in the order specified in the RFC.
|
|
||||||
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
|
||||||
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
|
||||||
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
|
|
||||||
|
|
||||||
// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
|
|
||||||
// They are defined in the order specified in the RFC.
|
|
||||||
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
|
|
||||||
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
|
|
||||||
|
|
||||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
|
||||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
|
||||||
// RC4) has problems with weak keys, and should be used with caution."
|
|
||||||
// RFC 4345 introduces improved versions of Arcfour.
|
|
||||||
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
|
|
||||||
|
|
||||||
// AEAD ciphers
|
|
||||||
gcm128CipherID: {16, 12, newGCMCipher},
|
|
||||||
gcm256CipherID: {32, 12, newGCMCipher},
|
|
||||||
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
|
|
||||||
|
|
||||||
// CBC mode is insecure and so is not included in the default config.
|
|
||||||
// (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely
|
|
||||||
// needed, it's possible to specify a custom Config to enable it.
|
|
||||||
// You should expect that an active attacker can recover plaintext if
|
|
||||||
// you do.
|
|
||||||
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
|
|
||||||
|
|
||||||
// 3des-cbc is insecure and is not included in the default
|
|
||||||
// config.
|
|
||||||
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
|
|
||||||
}
|
|
||||||
|
|
||||||
// prefixLen is the length of the packet prefix that contains the packet length
|
|
||||||
// and number of padding bytes.
|
|
||||||
const prefixLen = 5
|
|
||||||
|
|
||||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
|
||||||
type streamPacketCipher struct {
|
|
||||||
mac hash.Hash
|
|
||||||
cipher cipher.Stream
|
|
||||||
etm bool
|
|
||||||
|
|
||||||
// The following members are to avoid per-packet allocations.
|
|
||||||
prefix [prefixLen]byte
|
|
||||||
seqNumBytes [4]byte
|
|
||||||
padding [2 * packetSizeMultiple]byte
|
|
||||||
packetData []byte
|
|
||||||
macResult []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// readCipherPacket reads and decrypt a single packet from the reader argument.
|
|
||||||
func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var encryptedPaddingLength [1]byte
|
|
||||||
if s.mac != nil && s.etm {
|
|
||||||
copy(encryptedPaddingLength[:], s.prefix[4:5])
|
|
||||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
|
|
||||||
} else {
|
|
||||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
length := binary.BigEndian.Uint32(s.prefix[0:4])
|
|
||||||
paddingLength := uint32(s.prefix[4])
|
|
||||||
|
|
||||||
var macSize uint32
|
|
||||||
if s.mac != nil {
|
|
||||||
s.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
|
||||||
s.mac.Write(s.seqNumBytes[:])
|
|
||||||
if s.etm {
|
|
||||||
s.mac.Write(s.prefix[:4])
|
|
||||||
s.mac.Write(encryptedPaddingLength[:])
|
|
||||||
} else {
|
|
||||||
s.mac.Write(s.prefix[:])
|
|
||||||
}
|
|
||||||
macSize = uint32(s.mac.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
if length <= paddingLength+1 {
|
|
||||||
return nil, errors.New("ssh: invalid packet length, packet too small")
|
|
||||||
}
|
|
||||||
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, errors.New("ssh: invalid packet length, packet too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
// the maxPacket check above ensures that length-1+macSize
|
|
||||||
// does not overflow.
|
|
||||||
if uint32(cap(s.packetData)) < length-1+macSize {
|
|
||||||
s.packetData = make([]byte, length-1+macSize)
|
|
||||||
} else {
|
|
||||||
s.packetData = s.packetData[:length-1+macSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, s.packetData); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mac := s.packetData[length-1:]
|
|
||||||
data := s.packetData[:length-1]
|
|
||||||
|
|
||||||
if s.mac != nil && s.etm {
|
|
||||||
s.mac.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cipher.XORKeyStream(data, data)
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
if !s.etm {
|
|
||||||
s.mac.Write(data)
|
|
||||||
}
|
|
||||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
|
||||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
|
|
||||||
return nil, errors.New("ssh: MAC failure")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.packetData[:length-paddingLength-1], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// writeCipherPacket encrypts and sends a packet of data to the writer argument
|
|
||||||
func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
if len(packet) > maxPacket {
|
|
||||||
return errors.New("ssh: packet too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
aadlen := 0
|
|
||||||
if s.mac != nil && s.etm {
|
|
||||||
// packet length is not encrypted for EtM modes
|
|
||||||
aadlen = 4
|
|
||||||
}
|
|
||||||
|
|
||||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
|
|
||||||
if paddingLength < 4 {
|
|
||||||
paddingLength += packetSizeMultiple
|
|
||||||
}
|
|
||||||
|
|
||||||
length := len(packet) + 1 + paddingLength
|
|
||||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length))
|
|
||||||
s.prefix[4] = byte(paddingLength)
|
|
||||||
padding := s.padding[:paddingLength]
|
|
||||||
if _, err := io.ReadFull(rand, padding); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
s.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
|
|
||||||
s.mac.Write(s.seqNumBytes[:])
|
|
||||||
|
|
||||||
if s.etm {
|
|
||||||
// For EtM algorithms, the packet length must stay unencrypted,
|
|
||||||
// but the following data (padding length) must be encrypted
|
|
||||||
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mac.Write(s.prefix[:])
|
|
||||||
|
|
||||||
if !s.etm {
|
|
||||||
// For non-EtM algorithms, the algorithm is applied on unencrypted data
|
|
||||||
s.mac.Write(packet)
|
|
||||||
s.mac.Write(padding)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(s.mac != nil && s.etm) {
|
|
||||||
// For EtM algorithms, the padding length has already been encrypted
|
|
||||||
// and the packet length must remain unencrypted
|
|
||||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cipher.XORKeyStream(packet, packet)
|
|
||||||
s.cipher.XORKeyStream(padding, padding)
|
|
||||||
|
|
||||||
if s.mac != nil && s.etm {
|
|
||||||
// For EtM algorithms, packet and padding must be encrypted
|
|
||||||
s.mac.Write(packet)
|
|
||||||
s.mac.Write(padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.Write(s.prefix[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := w.Write(packet); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := w.Write(padding); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.mac != nil {
|
|
||||||
s.macResult = s.mac.Sum(s.macResult[:0])
|
|
||||||
if _, err := w.Write(s.macResult); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type gcmCipher struct {
|
|
||||||
aead cipher.AEAD
|
|
||||||
prefix [4]byte
|
|
||||||
iv []byte
|
|
||||||
buf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aead, err := cipher.NewGCM(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &gcmCipher{
|
|
||||||
aead: aead,
|
|
||||||
iv: iv,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const gcmTagSize = 16
|
|
||||||
|
|
||||||
func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
// Pad out to multiple of 16 bytes. This is different from the
|
|
||||||
// stream cipher because that encrypts the length too.
|
|
||||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple)
|
|
||||||
if padding < 4 {
|
|
||||||
padding += packetSizeMultiple
|
|
||||||
}
|
|
||||||
|
|
||||||
length := uint32(len(packet) + int(padding) + 1)
|
|
||||||
binary.BigEndian.PutUint32(c.prefix[:], length)
|
|
||||||
if _, err := w.Write(c.prefix[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(c.buf) < int(length) {
|
|
||||||
c.buf = make([]byte, length)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:length]
|
|
||||||
}
|
|
||||||
|
|
||||||
c.buf[0] = padding
|
|
||||||
copy(c.buf[1:], packet)
|
|
||||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
|
||||||
if _, err := w.Write(c.buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.incIV()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *gcmCipher) incIV() {
|
|
||||||
for i := 4 + 7; i >= 4; i-- {
|
|
||||||
c.iv[i]++
|
|
||||||
if c.iv[i] != 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
length := binary.BigEndian.Uint32(c.prefix[:])
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, errors.New("ssh: max packet length exceeded")
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(c.buf) < int(length+gcmTagSize) {
|
|
||||||
c.buf = make([]byte, length+gcmTagSize)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:length+gcmTagSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, c.buf); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.incIV()
|
|
||||||
|
|
||||||
if len(plain) == 0 {
|
|
||||||
return nil, errors.New("ssh: empty packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
padding := plain[0]
|
|
||||||
if padding < 4 {
|
|
||||||
// padding is a byte, so it automatically satisfies
|
|
||||||
// the maximum size, which is 255.
|
|
||||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(padding+1) >= len(plain) {
|
|
||||||
return nil, fmt.Errorf("ssh: padding %d too large", padding)
|
|
||||||
}
|
|
||||||
plain = plain[1 : length-uint32(padding)]
|
|
||||||
return plain, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
|
||||||
type cbcCipher struct {
|
|
||||||
mac hash.Hash
|
|
||||||
macSize uint32
|
|
||||||
decrypter cipher.BlockMode
|
|
||||||
encrypter cipher.BlockMode
|
|
||||||
|
|
||||||
// The following members are to avoid per-packet allocations.
|
|
||||||
seqNumBytes [4]byte
|
|
||||||
packetData []byte
|
|
||||||
macResult []byte
|
|
||||||
|
|
||||||
// Amount of data we should still read to hide which
|
|
||||||
// verification error triggered.
|
|
||||||
oracleCamouflage uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
|
||||||
cbc := &cbcCipher{
|
|
||||||
mac: macModes[algs.MAC].new(macKey),
|
|
||||||
decrypter: cipher.NewCBCDecrypter(c, iv),
|
|
||||||
encrypter: cipher.NewCBCEncrypter(c, iv),
|
|
||||||
packetData: make([]byte, 1024),
|
|
||||||
}
|
|
||||||
if cbc.mac != nil {
|
|
||||||
cbc.macSize = uint32(cbc.mac.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
return cbc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
|
||||||
c, err := aes.NewCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cbc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
|
|
||||||
c, err := des.NewTripleDESCipher(key)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cbc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func maxUInt32(a, b int) uint32 {
|
|
||||||
if a > b {
|
|
||||||
return uint32(a)
|
|
||||||
}
|
|
||||||
return uint32(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
cbcMinPacketSizeMultiple = 8
|
|
||||||
cbcMinPacketSize = 16
|
|
||||||
cbcMinPaddingSize = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
// cbcError represents a verification error that may leak information.
|
|
||||||
type cbcError string
|
|
||||||
|
|
||||||
func (e cbcError) Error() string { return string(e) }
|
|
||||||
|
|
||||||
func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
p, err := c.readCipherPacketLeaky(seqNum, r)
|
|
||||||
if err != nil {
|
|
||||||
if _, ok := err.(cbcError); ok {
|
|
||||||
// Verification error: read a fixed amount of
|
|
||||||
// data, to make distinguishing between
|
|
||||||
// failing MAC and failing length check more
|
|
||||||
// difficult.
|
|
||||||
io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
blockSize := c.decrypter.BlockSize()
|
|
||||||
|
|
||||||
// Read the header, which will include some of the subsequent data in the
|
|
||||||
// case of block ciphers - this is copied back to the payload later.
|
|
||||||
// How many bytes of payload/padding will be read with this first read.
|
|
||||||
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
|
|
||||||
firstBlock := c.packetData[:firstBlockLength]
|
|
||||||
if _, err := io.ReadFull(r, firstBlock); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
|
|
||||||
|
|
||||||
c.decrypter.CryptBlocks(firstBlock, firstBlock)
|
|
||||||
length := binary.BigEndian.Uint32(firstBlock[:4])
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, cbcError("ssh: packet too large")
|
|
||||||
}
|
|
||||||
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
|
|
||||||
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
|
||||||
// is larger) bytes.
|
|
||||||
return nil, cbcError("ssh: packet too small")
|
|
||||||
}
|
|
||||||
// The length of the packet (including the length field but not the MAC) must
|
|
||||||
// be a multiple of the block size or 8, whichever is larger.
|
|
||||||
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
|
|
||||||
return nil, cbcError("ssh: invalid packet length multiple")
|
|
||||||
}
|
|
||||||
|
|
||||||
paddingLength := uint32(firstBlock[4])
|
|
||||||
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
|
|
||||||
return nil, cbcError("ssh: invalid packet length")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Positions within the c.packetData buffer:
|
|
||||||
macStart := 4 + length
|
|
||||||
paddingStart := macStart - paddingLength
|
|
||||||
|
|
||||||
// Entire packet size, starting before length, ending at end of mac.
|
|
||||||
entirePacketSize := macStart + c.macSize
|
|
||||||
|
|
||||||
// Ensure c.packetData is large enough for the entire packet data.
|
|
||||||
if uint32(cap(c.packetData)) < entirePacketSize {
|
|
||||||
// Still need to upsize and copy, but this should be rare at runtime, only
|
|
||||||
// on upsizing the packetData buffer.
|
|
||||||
c.packetData = make([]byte, entirePacketSize)
|
|
||||||
copy(c.packetData, firstBlock)
|
|
||||||
} else {
|
|
||||||
c.packetData = c.packetData[:entirePacketSize]
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.oracleCamouflage -= uint32(n)
|
|
||||||
|
|
||||||
remainingCrypted := c.packetData[firstBlockLength:macStart]
|
|
||||||
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
|
|
||||||
|
|
||||||
mac := c.packetData[macStart:]
|
|
||||||
if c.mac != nil {
|
|
||||||
c.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
|
|
||||||
c.mac.Write(c.seqNumBytes[:])
|
|
||||||
c.mac.Write(c.packetData[:macStart])
|
|
||||||
c.macResult = c.mac.Sum(c.macResult[:0])
|
|
||||||
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
|
|
||||||
return nil, cbcError("ssh: MAC failure")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.packetData[prefixLen:paddingStart], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize())
|
|
||||||
|
|
||||||
// Length of encrypted portion of the packet (header, payload, padding).
|
|
||||||
// Enforce minimum padding and packet size.
|
|
||||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
|
|
||||||
// Enforce block size.
|
|
||||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
|
|
||||||
|
|
||||||
length := encLength - 4
|
|
||||||
paddingLength := int(length) - (1 + len(packet))
|
|
||||||
|
|
||||||
// Overall buffer contains: header, payload, padding, mac.
|
|
||||||
// Space for the MAC is reserved in the capacity but not the slice length.
|
|
||||||
bufferSize := encLength + c.macSize
|
|
||||||
if uint32(cap(c.packetData)) < bufferSize {
|
|
||||||
c.packetData = make([]byte, encLength, bufferSize)
|
|
||||||
} else {
|
|
||||||
c.packetData = c.packetData[:encLength]
|
|
||||||
}
|
|
||||||
|
|
||||||
p := c.packetData
|
|
||||||
|
|
||||||
// Packet header.
|
|
||||||
binary.BigEndian.PutUint32(p, length)
|
|
||||||
p = p[4:]
|
|
||||||
p[0] = byte(paddingLength)
|
|
||||||
|
|
||||||
// Payload.
|
|
||||||
p = p[1:]
|
|
||||||
copy(p, packet)
|
|
||||||
|
|
||||||
// Padding.
|
|
||||||
p = p[len(packet):]
|
|
||||||
if _, err := io.ReadFull(rand, p); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.mac != nil {
|
|
||||||
c.mac.Reset()
|
|
||||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
|
|
||||||
c.mac.Write(c.seqNumBytes[:])
|
|
||||||
c.mac.Write(c.packetData)
|
|
||||||
// The MAC is now appended into the capacity reserved for it earlier.
|
|
||||||
c.packetData = c.mac.Sum(c.packetData)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength])
|
|
||||||
|
|
||||||
if _, err := w.Write(c.packetData); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
|
|
||||||
|
|
||||||
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
|
|
||||||
// AEAD, which is described here:
|
|
||||||
//
|
|
||||||
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
|
|
||||||
//
|
|
||||||
// the methods here also implement padding, which RFC 4253 Section 6
|
|
||||||
// also requires of stream ciphers.
|
|
||||||
type chacha20Poly1305Cipher struct {
|
|
||||||
lengthKey [32]byte
|
|
||||||
contentKey [32]byte
|
|
||||||
buf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
|
|
||||||
if len(key) != 64 {
|
|
||||||
panic(len(key))
|
|
||||||
}
|
|
||||||
|
|
||||||
c := &chacha20Poly1305Cipher{
|
|
||||||
buf: make([]byte, 256),
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(c.contentKey[:], key[:32])
|
|
||||||
copy(c.lengthKey[:], key[32:])
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
|
|
||||||
nonce := make([]byte, 12)
|
|
||||||
binary.BigEndian.PutUint32(nonce[8:], seqNum)
|
|
||||||
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var polyKey, discardBuf [32]byte
|
|
||||||
s.XORKeyStream(polyKey[:], polyKey[:])
|
|
||||||
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
|
|
||||||
|
|
||||||
encryptedLength := c.buf[:4]
|
|
||||||
if _, err := io.ReadFull(r, encryptedLength); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var lenBytes [4]byte
|
|
||||||
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ls.XORKeyStream(lenBytes[:], encryptedLength)
|
|
||||||
|
|
||||||
length := binary.BigEndian.Uint32(lenBytes[:])
|
|
||||||
if length > maxPacket {
|
|
||||||
return nil, errors.New("ssh: invalid packet length, packet too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
contentEnd := 4 + length
|
|
||||||
packetEnd := contentEnd + poly1305.TagSize
|
|
||||||
if uint32(cap(c.buf)) < packetEnd {
|
|
||||||
c.buf = make([]byte, packetEnd)
|
|
||||||
copy(c.buf[:], encryptedLength)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:packetEnd]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mac [poly1305.TagSize]byte
|
|
||||||
copy(mac[:], c.buf[contentEnd:packetEnd])
|
|
||||||
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
|
|
||||||
return nil, errors.New("ssh: MAC failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
plain := c.buf[4:contentEnd]
|
|
||||||
s.XORKeyStream(plain, plain)
|
|
||||||
|
|
||||||
if len(plain) == 0 {
|
|
||||||
return nil, errors.New("ssh: empty packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
padding := plain[0]
|
|
||||||
if padding < 4 {
|
|
||||||
// padding is a byte, so it automatically satisfies
|
|
||||||
// the maximum size, which is 255.
|
|
||||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(padding)+1 >= len(plain) {
|
|
||||||
return nil, fmt.Errorf("ssh: padding %d too large", padding)
|
|
||||||
}
|
|
||||||
|
|
||||||
plain = plain[1 : len(plain)-int(padding)]
|
|
||||||
|
|
||||||
return plain, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
|
|
||||||
nonce := make([]byte, 12)
|
|
||||||
binary.BigEndian.PutUint32(nonce[8:], seqNum)
|
|
||||||
s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var polyKey, discardBuf [32]byte
|
|
||||||
s.XORKeyStream(polyKey[:], polyKey[:])
|
|
||||||
s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
|
|
||||||
|
|
||||||
// There is no blocksize, so fall back to multiple of 8 byte
|
|
||||||
// padding, as described in RFC 4253, Sec 6.
|
|
||||||
const packetSizeMultiple = 8
|
|
||||||
|
|
||||||
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
|
|
||||||
if padding < 4 {
|
|
||||||
padding += packetSizeMultiple
|
|
||||||
}
|
|
||||||
|
|
||||||
// size (4 bytes), padding (1), payload, padding, tag.
|
|
||||||
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
|
|
||||||
if cap(c.buf) < totalLength {
|
|
||||||
c.buf = make([]byte, totalLength)
|
|
||||||
} else {
|
|
||||||
c.buf = c.buf[:totalLength]
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
|
|
||||||
ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ls.XORKeyStream(c.buf, c.buf[:4])
|
|
||||||
c.buf[4] = byte(padding)
|
|
||||||
copy(c.buf[5:], payload)
|
|
||||||
packetEnd := 5 + len(payload) + padding
|
|
||||||
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
|
|
||||||
|
|
||||||
var mac [poly1305.TagSize]byte
|
|
||||||
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
|
|
||||||
|
|
||||||
copy(c.buf[packetEnd:], mac[:])
|
|
||||||
|
|
||||||
if _, err := w.Write(c.buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,282 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client implements a traditional SSH client that supports shells,
|
|
||||||
// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
|
|
||||||
type Client struct {
|
|
||||||
Conn
|
|
||||||
|
|
||||||
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
|
|
||||||
|
|
||||||
forwards forwardList // forwarded tcpip connections from the remote side
|
|
||||||
mu sync.Mutex
|
|
||||||
channelHandlers map[string]chan NewChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
|
||||||
// for the given type are sent. If the type already is being handled,
|
|
||||||
// nil is returned. The channel is closed when the connection is closed.
|
|
||||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
if c.channelHandlers == nil {
|
|
||||||
// The SSH channel has been closed.
|
|
||||||
c := make(chan NewChannel)
|
|
||||||
close(c)
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := c.channelHandlers[channelType]
|
|
||||||
if ch != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ch = make(chan NewChannel, chanSize)
|
|
||||||
c.channelHandlers[channelType] = ch
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClient creates a Client on top of the given connection.
|
|
||||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
|
|
||||||
conn := &Client{
|
|
||||||
Conn: c,
|
|
||||||
channelHandlers: make(map[string]chan NewChannel, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
go conn.handleGlobalRequests(reqs)
|
|
||||||
go conn.handleChannelOpens(chans)
|
|
||||||
go func() {
|
|
||||||
conn.Wait()
|
|
||||||
conn.forwards.closeAll()
|
|
||||||
}()
|
|
||||||
return conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClientConn establishes an authenticated SSH connection using c
|
|
||||||
// as the underlying transport. The Request and NewChannel channels
|
|
||||||
// must be serviced or the connection will hang.
|
|
||||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
|
|
||||||
fullConf := *config
|
|
||||||
fullConf.SetDefaults()
|
|
||||||
if fullConf.HostKeyCallback == nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := &connection{
|
|
||||||
sshConn: sshConn{conn: c, user: fullConf.User},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := conn.clientHandshake(addr, &fullConf); err != nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
|
|
||||||
}
|
|
||||||
conn.mux = newMux(conn.transport)
|
|
||||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
|
||||||
// 7.
|
|
||||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
|
|
||||||
if config.ClientVersion != "" {
|
|
||||||
c.clientVersion = []byte(config.ClientVersion)
|
|
||||||
} else {
|
|
||||||
c.clientVersion = []byte(packageVersion)
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.transport = newClientTransport(
|
|
||||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
|
|
||||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
|
|
||||||
if err := c.transport.waitSession(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.sessionID = c.transport.getSessionID()
|
|
||||||
return c.clientAuthenticate(config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// verifyHostKeySignature verifies the host key obtained in the key exchange.
|
|
||||||
// algo is the negotiated algorithm, and may be a certificate type.
|
|
||||||
func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error {
|
|
||||||
sig, rest, ok := parseSignatureBody(result.Signature)
|
|
||||||
if len(rest) > 0 || !ok {
|
|
||||||
return errors.New("ssh: signature parse error")
|
|
||||||
}
|
|
||||||
|
|
||||||
if a := underlyingAlgo(algo); sig.Format != a {
|
|
||||||
return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostKey.Verify(result.H, sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSession opens a new Session for this client. (A session is a remote
|
|
||||||
// execution of a program.)
|
|
||||||
func (c *Client) NewSession() (*Session, error) {
|
|
||||||
ch, in, err := c.OpenChannel("session", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return newSession(ch, in)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) {
|
|
||||||
for r := range incoming {
|
|
||||||
// This handles keepalive messages and matches
|
|
||||||
// the behaviour of OpenSSH.
|
|
||||||
r.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannelOpens channel open messages from the remote side.
|
|
||||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) {
|
|
||||||
for ch := range in {
|
|
||||||
c.mu.Lock()
|
|
||||||
handler := c.channelHandlers[ch.ChannelType()]
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if handler != nil {
|
|
||||||
handler <- ch
|
|
||||||
} else {
|
|
||||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
|
||||||
for _, ch := range c.channelHandlers {
|
|
||||||
close(ch)
|
|
||||||
}
|
|
||||||
c.channelHandlers = nil
|
|
||||||
c.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial starts a client connection to the given SSH server. It is a
|
|
||||||
// convenience function that connects to the given network address,
|
|
||||||
// initiates the SSH handshake, and then sets up a Client. For access
|
|
||||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
|
||||||
// instead.
|
|
||||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
|
|
||||||
conn, err := net.DialTimeout(network, addr, config.Timeout)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, chans, reqs, err := NewClientConn(conn, addr, config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return NewClient(c, chans, reqs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostKeyCallback is the function type used for verifying server
|
|
||||||
// keys. A HostKeyCallback must return nil if the host key is OK, or
|
|
||||||
// an error to reject it. It receives the hostname as passed to Dial
|
|
||||||
// or NewClientConn. The remote address is the RemoteAddr of the
|
|
||||||
// net.Conn underlying the SSH connection.
|
|
||||||
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
|
|
||||||
|
|
||||||
// BannerCallback is the function type used for treat the banner sent by
|
|
||||||
// the server. A BannerCallback receives the message sent by the remote server.
|
|
||||||
type BannerCallback func(message string) error
|
|
||||||
|
|
||||||
// A ClientConfig structure is used to configure a Client. It must not be
|
|
||||||
// modified after having been passed to an SSH function.
|
|
||||||
type ClientConfig struct {
|
|
||||||
// Config contains configuration that is shared between clients and
|
|
||||||
// servers.
|
|
||||||
Config
|
|
||||||
|
|
||||||
// User contains the username to authenticate as.
|
|
||||||
User string
|
|
||||||
|
|
||||||
// Auth contains possible authentication methods to use with the
|
|
||||||
// server. Only the first instance of a particular RFC 4252 method will
|
|
||||||
// be used during authentication.
|
|
||||||
Auth []AuthMethod
|
|
||||||
|
|
||||||
// HostKeyCallback is called during the cryptographic
|
|
||||||
// handshake to validate the server's host key. The client
|
|
||||||
// configuration must supply this callback for the connection
|
|
||||||
// to succeed. The functions InsecureIgnoreHostKey or
|
|
||||||
// FixedHostKey can be used for simplistic host key checks.
|
|
||||||
HostKeyCallback HostKeyCallback
|
|
||||||
|
|
||||||
// BannerCallback is called during the SSH dance to display a custom
|
|
||||||
// server's message. The client configuration can supply this callback to
|
|
||||||
// handle it as wished. The function BannerDisplayStderr can be used for
|
|
||||||
// simplistic display on Stderr.
|
|
||||||
BannerCallback BannerCallback
|
|
||||||
|
|
||||||
// ClientVersion contains the version identification string that will
|
|
||||||
// be used for the connection. If empty, a reasonable default is used.
|
|
||||||
ClientVersion string
|
|
||||||
|
|
||||||
// HostKeyAlgorithms lists the public key algorithms that the client will
|
|
||||||
// accept from the server for host key authentication, in order of
|
|
||||||
// preference. If empty, a reasonable default is used. Any
|
|
||||||
// string returned from a PublicKey.Type method may be used, or
|
|
||||||
// any of the CertAlgo and KeyAlgo constants.
|
|
||||||
HostKeyAlgorithms []string
|
|
||||||
|
|
||||||
// Timeout is the maximum amount of time for the TCP connection to establish.
|
|
||||||
//
|
|
||||||
// A Timeout of zero means no timeout.
|
|
||||||
Timeout time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsecureIgnoreHostKey returns a function that can be used for
|
|
||||||
// ClientConfig.HostKeyCallback to accept any host key. It should
|
|
||||||
// not be used for production code.
|
|
||||||
func InsecureIgnoreHostKey() HostKeyCallback {
|
|
||||||
return func(hostname string, remote net.Addr, key PublicKey) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type fixedHostKey struct {
|
|
||||||
key PublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
|
|
||||||
if f.key == nil {
|
|
||||||
return fmt.Errorf("ssh: required host key was nil")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
|
|
||||||
return fmt.Errorf("ssh: host key mismatch")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FixedHostKey returns a function for use in
|
|
||||||
// ClientConfig.HostKeyCallback to accept only a specific host key.
|
|
||||||
func FixedHostKey(key PublicKey) HostKeyCallback {
|
|
||||||
hk := &fixedHostKey{key}
|
|
||||||
return hk.check
|
|
||||||
}
|
|
||||||
|
|
||||||
// BannerDisplayStderr returns a function that can be used for
|
|
||||||
// ClientConfig.BannerCallback to display banners on os.Stderr.
|
|
||||||
func BannerDisplayStderr() BannerCallback {
|
|
||||||
return func(banner string) error {
|
|
||||||
_, err := os.Stderr.WriteString(banner)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,761 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type authResult int
|
|
||||||
|
|
||||||
const (
|
|
||||||
authFailure authResult = iota
|
|
||||||
authPartialSuccess
|
|
||||||
authSuccess
|
|
||||||
)
|
|
||||||
|
|
||||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
|
||||||
func (c *connection) clientAuthenticate(config *ClientConfig) error {
|
|
||||||
// initiate user auth session
|
|
||||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
packet, err := c.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
|
|
||||||
// advertised willingness to receive one, which we always do) or not. See
|
|
||||||
// RFC 8308, Section 2.4.
|
|
||||||
extensions := make(map[string][]byte)
|
|
||||||
if len(packet) > 0 && packet[0] == msgExtInfo {
|
|
||||||
var extInfo extInfoMsg
|
|
||||||
if err := Unmarshal(packet, &extInfo); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
payload := extInfo.Payload
|
|
||||||
for i := uint32(0); i < extInfo.NumExtensions; i++ {
|
|
||||||
name, rest, ok := parseString(payload)
|
|
||||||
if !ok {
|
|
||||||
return parseError(msgExtInfo)
|
|
||||||
}
|
|
||||||
value, rest, ok := parseString(rest)
|
|
||||||
if !ok {
|
|
||||||
return parseError(msgExtInfo)
|
|
||||||
}
|
|
||||||
extensions[string(name)] = value
|
|
||||||
payload = rest
|
|
||||||
}
|
|
||||||
packet, err = c.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var serviceAccept serviceAcceptMsg
|
|
||||||
if err := Unmarshal(packet, &serviceAccept); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// during the authentication phase the client first attempts the "none" method
|
|
||||||
// then any untried methods suggested by the server.
|
|
||||||
var tried []string
|
|
||||||
var lastMethods []string
|
|
||||||
|
|
||||||
sessionID := c.transport.getSessionID()
|
|
||||||
for auth := AuthMethod(new(noneAuth)); auth != nil; {
|
|
||||||
ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
|
|
||||||
if err != nil {
|
|
||||||
// We return the error later if there is no other method left to
|
|
||||||
// try.
|
|
||||||
ok = authFailure
|
|
||||||
}
|
|
||||||
if ok == authSuccess {
|
|
||||||
// success
|
|
||||||
return nil
|
|
||||||
} else if ok == authFailure {
|
|
||||||
if m := auth.method(); !contains(tried, m) {
|
|
||||||
tried = append(tried, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if methods == nil {
|
|
||||||
methods = lastMethods
|
|
||||||
}
|
|
||||||
lastMethods = methods
|
|
||||||
|
|
||||||
auth = nil
|
|
||||||
|
|
||||||
findNext:
|
|
||||||
for _, a := range config.Auth {
|
|
||||||
candidateMethod := a.method()
|
|
||||||
if contains(tried, candidateMethod) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, meth := range methods {
|
|
||||||
if meth == candidateMethod {
|
|
||||||
auth = a
|
|
||||||
break findNext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if auth == nil && err != nil {
|
|
||||||
// We have an error and there are no other authentication methods to
|
|
||||||
// try, so we return it.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
|
|
||||||
}
|
|
||||||
|
|
||||||
func contains(list []string, e string) bool {
|
|
||||||
for _, s := range list {
|
|
||||||
if s == e {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
|
||||||
type AuthMethod interface {
|
|
||||||
// auth authenticates user over transport t.
|
|
||||||
// Returns true if authentication is successful.
|
|
||||||
// If authentication is not successful, a []string of alternative
|
|
||||||
// method names is returned. If the slice is nil, it will be ignored
|
|
||||||
// and the previous set of possible methods will be reused.
|
|
||||||
auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
|
|
||||||
|
|
||||||
// method returns the RFC 4252 method name.
|
|
||||||
method() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// "none" authentication, RFC 4252 section 5.2.
|
|
||||||
type noneAuth int
|
|
||||||
|
|
||||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
|
||||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "none",
|
|
||||||
})); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return handleAuthResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *noneAuth) method() string {
|
|
||||||
return "none"
|
|
||||||
}
|
|
||||||
|
|
||||||
// passwordCallback is an AuthMethod that fetches the password through
|
|
||||||
// a function call, e.g. by prompting the user.
|
|
||||||
type passwordCallback func() (password string, err error)
|
|
||||||
|
|
||||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
|
||||||
type passwordAuthMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Reply bool
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
pw, err := cb()
|
|
||||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
|
||||||
// The program may only find out that the user doesn't have a password
|
|
||||||
// when prompting.
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(Marshal(&passwordAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
Reply: false,
|
|
||||||
Password: pw,
|
|
||||||
})); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return handleAuthResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb passwordCallback) method() string {
|
|
||||||
return "password"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Password returns an AuthMethod using the given password.
|
|
||||||
func Password(secret string) AuthMethod {
|
|
||||||
return passwordCallback(func() (string, error) { return secret, nil })
|
|
||||||
}
|
|
||||||
|
|
||||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
|
||||||
// fetching a password.
|
|
||||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod {
|
|
||||||
return passwordCallback(prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
type publickeyAuthMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
|
||||||
// should be used for authentication of the request.
|
|
||||||
HasSig bool
|
|
||||||
Algoname string
|
|
||||||
PubKey []byte
|
|
||||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
|
||||||
// validateKey
|
|
||||||
Sig []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
|
||||||
// pairs for authentication.
|
|
||||||
type publicKeyCallback func() ([]Signer, error)
|
|
||||||
|
|
||||||
func (cb publicKeyCallback) method() string {
|
|
||||||
return "publickey"
|
|
||||||
}
|
|
||||||
|
|
||||||
func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) {
|
|
||||||
var as MultiAlgorithmSigner
|
|
||||||
keyFormat := signer.PublicKey().Type()
|
|
||||||
|
|
||||||
// If the signer implements MultiAlgorithmSigner we use the algorithms it
|
|
||||||
// support, if it implements AlgorithmSigner we assume it supports all
|
|
||||||
// algorithms, otherwise only the key format one.
|
|
||||||
switch s := signer.(type) {
|
|
||||||
case MultiAlgorithmSigner:
|
|
||||||
as = s
|
|
||||||
case AlgorithmSigner:
|
|
||||||
as = &multiAlgorithmSigner{
|
|
||||||
AlgorithmSigner: s,
|
|
||||||
supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)),
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
as = &multiAlgorithmSigner{
|
|
||||||
AlgorithmSigner: algorithmSignerWrapper{signer},
|
|
||||||
supportedAlgorithms: []string{underlyingAlgo(keyFormat)},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
getFallbackAlgo := func() (string, error) {
|
|
||||||
// Fallback to use if there is no "server-sig-algs" extension or a
|
|
||||||
// common algorithm cannot be found. We use the public key format if the
|
|
||||||
// MultiAlgorithmSigner supports it, otherwise we return an error.
|
|
||||||
if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) {
|
|
||||||
return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v",
|
|
||||||
underlyingAlgo(keyFormat), keyFormat, as.Algorithms())
|
|
||||||
}
|
|
||||||
return keyFormat, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
extPayload, ok := extensions["server-sig-algs"]
|
|
||||||
if !ok {
|
|
||||||
// If there is no "server-sig-algs" extension use the fallback
|
|
||||||
// algorithm.
|
|
||||||
algo, err := getFallbackAlgo()
|
|
||||||
return as, algo, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The server-sig-algs extension only carries underlying signature
|
|
||||||
// algorithm, but we are trying to select a protocol-level public key
|
|
||||||
// algorithm, which might be a certificate type. Extend the list of server
|
|
||||||
// supported algorithms to include the corresponding certificate algorithms.
|
|
||||||
serverAlgos := strings.Split(string(extPayload), ",")
|
|
||||||
for _, algo := range serverAlgos {
|
|
||||||
if certAlgo, ok := certificateAlgo(algo); ok {
|
|
||||||
serverAlgos = append(serverAlgos, certAlgo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter algorithms based on those supported by MultiAlgorithmSigner.
|
|
||||||
var keyAlgos []string
|
|
||||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
|
||||||
if contains(as.Algorithms(), underlyingAlgo(algo)) {
|
|
||||||
keyAlgos = append(keyAlgos, algo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos)
|
|
||||||
if err != nil {
|
|
||||||
// If there is no overlap, return the fallback algorithm to support
|
|
||||||
// servers that fail to list all supported algorithms.
|
|
||||||
algo, err := getFallbackAlgo()
|
|
||||||
return as, algo, err
|
|
||||||
}
|
|
||||||
return as, algo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) {
|
|
||||||
// Authentication is performed by sending an enquiry to test if a key is
|
|
||||||
// acceptable to the remote. If the key is acceptable, the client will
|
|
||||||
// attempt to authenticate with the valid key. If not the client will repeat
|
|
||||||
// the process with the remaining keys.
|
|
||||||
|
|
||||||
signers, err := cb()
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
var methods []string
|
|
||||||
var errSigAlgo error
|
|
||||||
for _, signer := range signers {
|
|
||||||
pub := signer.PublicKey()
|
|
||||||
as, algo, err := pickSignatureAlgorithm(signer, extensions)
|
|
||||||
if err != nil && errSigAlgo == nil {
|
|
||||||
// If we cannot negotiate a signature algorithm store the first
|
|
||||||
// error so we can return it to provide a more meaningful message if
|
|
||||||
// no other signers work.
|
|
||||||
errSigAlgo = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ok, err := validateKey(pub, algo, user, c)
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := pub.Marshal()
|
|
||||||
data := buildDataSignedForAuth(session, userAuthRequestMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
}, algo, pubKey)
|
|
||||||
sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// manually wrap the serialized signature in a string
|
|
||||||
s := Marshal(sign)
|
|
||||||
sig := make([]byte, stringLength(len(s)))
|
|
||||||
marshalString(sig, s)
|
|
||||||
msg := publickeyAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: cb.method(),
|
|
||||||
HasSig: true,
|
|
||||||
Algoname: algo,
|
|
||||||
PubKey: pubKey,
|
|
||||||
Sig: sig,
|
|
||||||
}
|
|
||||||
p := Marshal(&msg)
|
|
||||||
if err := c.writePacket(p); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
var success authResult
|
|
||||||
success, methods, err = handleAuthResponse(c)
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If authentication succeeds or the list of available methods does not
|
|
||||||
// contain the "publickey" method, do not attempt to authenticate with any
|
|
||||||
// other keys. According to RFC 4252 Section 7, the latter can occur when
|
|
||||||
// additional authentication methods are required.
|
|
||||||
if success == authSuccess || !contains(methods, cb.method()) {
|
|
||||||
return success, methods, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFailure, methods, errSigAlgo
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateKey validates the key provided is acceptable to the server.
|
|
||||||
func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) {
|
|
||||||
pubKey := key.Marshal()
|
|
||||||
msg := publickeyAuthMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "publickey",
|
|
||||||
HasSig: false,
|
|
||||||
Algoname: algo,
|
|
||||||
PubKey: pubKey,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&msg)); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return confirmKeyAck(key, algo, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func confirmKeyAck(key PublicKey, algo string, c packetConn) (bool, error) {
|
|
||||||
pubKey := key.Marshal()
|
|
||||||
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
if err := handleBannerResponse(c, packet); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
case msgUserAuthPubKeyOk:
|
|
||||||
var msg userAuthPubKeyOkMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if msg.Algo != algo || !bytes.Equal(msg.PubKey, pubKey) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
return false, nil
|
|
||||||
default:
|
|
||||||
return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicKeys returns an AuthMethod that uses the given key
|
|
||||||
// pairs.
|
|
||||||
func PublicKeys(signers ...Signer) AuthMethod {
|
|
||||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil })
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
|
||||||
// function to obtain a list of key pairs.
|
|
||||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod {
|
|
||||||
return publicKeyCallback(getSigners)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
|
||||||
// along with a list of remaining authentication methods to try next and
|
|
||||||
// an error if an unexpected response was received.
|
|
||||||
func handleAuthResponse(c packetConn) (authResult, []string, error) {
|
|
||||||
gotMsgExtInfo := false
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
if err := handleBannerResponse(c, packet); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
case msgExtInfo:
|
|
||||||
// Ignore post-authentication RFC 8308 extensions, once.
|
|
||||||
if gotMsgExtInfo {
|
|
||||||
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
|
||||||
}
|
|
||||||
gotMsgExtInfo = true
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
var msg userAuthFailureMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if msg.PartialSuccess {
|
|
||||||
return authPartialSuccess, msg.Methods, nil
|
|
||||||
}
|
|
||||||
return authFailure, msg.Methods, nil
|
|
||||||
case msgUserAuthSuccess:
|
|
||||||
return authSuccess, nil, nil
|
|
||||||
default:
|
|
||||||
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleBannerResponse(c packetConn, packet []byte) error {
|
|
||||||
var msg userAuthBannerMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
transport, ok := c.(*handshakeTransport)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport.bannerCallback != nil {
|
|
||||||
return transport.bannerCallback(msg.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyboardInteractiveChallenge should print questions, optionally
|
|
||||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
|
||||||
// Challenge may be called multiple times in a single session. After
|
|
||||||
// successful authentication, the server may send a challenge with no
|
|
||||||
// questions, for which the name and instruction messages should be
|
|
||||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
|
||||||
// both CLI and GUI environments.
|
|
||||||
type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error)
|
|
||||||
|
|
||||||
// KeyboardInteractive returns an AuthMethod using a prompt/response
|
|
||||||
// sequence controlled by the server.
|
|
||||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
|
|
||||||
return challenge
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb KeyboardInteractiveChallenge) method() string {
|
|
||||||
return "keyboard-interactive"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
|
||||||
type initiateMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Language string
|
|
||||||
Submethods string
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(Marshal(&initiateMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: "keyboard-interactive",
|
|
||||||
})); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
gotMsgExtInfo := false
|
|
||||||
for {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// like handleAuthResponse, but with less options.
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthBanner:
|
|
||||||
if err := handleBannerResponse(c, packet); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
case msgExtInfo:
|
|
||||||
// Ignore post-authentication RFC 8308 extensions, once.
|
|
||||||
if gotMsgExtInfo {
|
|
||||||
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
|
|
||||||
}
|
|
||||||
gotMsgExtInfo = true
|
|
||||||
continue
|
|
||||||
case msgUserAuthInfoRequest:
|
|
||||||
// OK
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
var msg userAuthFailureMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if msg.PartialSuccess {
|
|
||||||
return authPartialSuccess, msg.Methods, nil
|
|
||||||
}
|
|
||||||
return authFailure, msg.Methods, nil
|
|
||||||
case msgUserAuthSuccess:
|
|
||||||
return authSuccess, nil, nil
|
|
||||||
default:
|
|
||||||
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg userAuthInfoRequestMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manually unpack the prompt/echo pairs.
|
|
||||||
rest := msg.Prompts
|
|
||||||
var prompts []string
|
|
||||||
var echos []bool
|
|
||||||
for i := 0; i < int(msg.NumPrompts); i++ {
|
|
||||||
prompt, r, ok := parseString(rest)
|
|
||||||
if !ok || len(r) == 0 {
|
|
||||||
return authFailure, nil, errors.New("ssh: prompt format error")
|
|
||||||
}
|
|
||||||
prompts = append(prompts, string(prompt))
|
|
||||||
echos = append(echos, r[0] != 0)
|
|
||||||
rest = r[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rest) != 0 {
|
|
||||||
return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
|
|
||||||
}
|
|
||||||
|
|
||||||
answers, err := cb(msg.Name, msg.Instruction, prompts, echos)
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(answers) != len(prompts) {
|
|
||||||
return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts))
|
|
||||||
}
|
|
||||||
responseLength := 1 + 4
|
|
||||||
for _, a := range answers {
|
|
||||||
responseLength += stringLength(len(a))
|
|
||||||
}
|
|
||||||
serialized := make([]byte, responseLength)
|
|
||||||
p := serialized
|
|
||||||
p[0] = msgUserAuthInfoResponse
|
|
||||||
p = p[1:]
|
|
||||||
p = marshalUint32(p, uint32(len(answers)))
|
|
||||||
for _, a := range answers {
|
|
||||||
p = marshalString(p, []byte(a))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type retryableAuthMethod struct {
|
|
||||||
authMethod AuthMethod
|
|
||||||
maxTries int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) {
|
|
||||||
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
|
|
||||||
ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions)
|
|
||||||
if ok != authFailure || err != nil { // either success, partial success or error terminate
|
|
||||||
return ok, methods, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ok, methods, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *retryableAuthMethod) method() string {
|
|
||||||
return r.authMethod.method()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RetryableAuthMethod is a decorator for other auth methods enabling them to
|
|
||||||
// be retried up to maxTries before considering that AuthMethod itself failed.
|
|
||||||
// If maxTries is <= 0, will retry indefinitely
|
|
||||||
//
|
|
||||||
// This is useful for interactive clients using challenge/response type
|
|
||||||
// authentication (e.g. Keyboard-Interactive, Password, etc) where the user
|
|
||||||
// could mistype their response resulting in the server issuing a
|
|
||||||
// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4
|
|
||||||
// [keyboard-interactive]); Without this decorator, the non-retryable
|
|
||||||
// AuthMethod would be removed from future consideration, and never tried again
|
|
||||||
// (and so the user would never be able to retry their entry).
|
|
||||||
func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
|
|
||||||
return &retryableAuthMethod{authMethod: auth, maxTries: maxTries}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication.
|
|
||||||
// See RFC 4462 section 3
|
|
||||||
// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details.
|
|
||||||
// target is the server host you want to log in to.
|
|
||||||
func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod {
|
|
||||||
if gssAPIClient == nil {
|
|
||||||
panic("gss-api client must be not nil with enable gssapi-with-mic")
|
|
||||||
}
|
|
||||||
return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target}
|
|
||||||
}
|
|
||||||
|
|
||||||
type gssAPIWithMICCallback struct {
|
|
||||||
gssAPIClient GSSAPIClient
|
|
||||||
target string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
|
|
||||||
m := &userAuthRequestMsg{
|
|
||||||
User: user,
|
|
||||||
Service: serviceSSH,
|
|
||||||
Method: g.method(),
|
|
||||||
}
|
|
||||||
// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST.
|
|
||||||
// See RFC 4462 section 3.2.
|
|
||||||
m.Payload = appendU32(m.Payload, 1)
|
|
||||||
m.Payload = appendString(m.Payload, string(krb5OID))
|
|
||||||
if err := c.writePacket(Marshal(m)); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
// The server responds to the SSH_MSG_USERAUTH_REQUEST with either an
|
|
||||||
// SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or
|
|
||||||
// with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE.
|
|
||||||
// See RFC 4462 section 3.3.
|
|
||||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check
|
|
||||||
// selected mech if it is valid.
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
userAuthGSSAPIResp := &userAuthGSSAPIResponse{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
// Start the loop into the exchange token.
|
|
||||||
// See RFC 4462 section 3.4.
|
|
||||||
var token []byte
|
|
||||||
defer g.gssAPIClient.DeleteSecContext()
|
|
||||||
for {
|
|
||||||
// Initiates the establishment of a security context between the application and a remote peer.
|
|
||||||
nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false)
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if len(nextToken) > 0 {
|
|
||||||
if err := c.writePacket(Marshal(&userAuthGSSAPIToken{
|
|
||||||
Token: nextToken,
|
|
||||||
})); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !needContinue {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
packet, err = c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
switch packet[0] {
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
var msg userAuthFailureMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if msg.PartialSuccess {
|
|
||||||
return authPartialSuccess, msg.Methods, nil
|
|
||||||
}
|
|
||||||
return authFailure, msg.Methods, nil
|
|
||||||
case msgUserAuthGSSAPIError:
|
|
||||||
userAuthGSSAPIErrorResp := &userAuthGSSAPIError{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+
|
|
||||||
"Major Status: %d\n"+
|
|
||||||
"Minor Status: %d\n"+
|
|
||||||
"Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus,
|
|
||||||
userAuthGSSAPIErrorResp.Message)
|
|
||||||
case msgUserAuthGSSAPIToken:
|
|
||||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
token = userAuthGSSAPITokenReq.Token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Binding Encryption Keys.
|
|
||||||
// See RFC 4462 section 3.5.
|
|
||||||
micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic")
|
|
||||||
micToken, err := g.gssAPIClient.GetMIC(micField)
|
|
||||||
if err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{
|
|
||||||
MIC: micToken,
|
|
||||||
})); err != nil {
|
|
||||||
return authFailure, nil, err
|
|
||||||
}
|
|
||||||
return handleAuthResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *gssAPIWithMICCallback) method() string {
|
|
||||||
return "gssapi-with-mic"
|
|
||||||
}
|
|
@ -1,468 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
_ "crypto/sha1"
|
|
||||||
_ "crypto/sha256"
|
|
||||||
_ "crypto/sha512"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These are string constants in the SSH protocol.
|
|
||||||
const (
|
|
||||||
compressionNone = "none"
|
|
||||||
serviceUserAuth = "ssh-userauth"
|
|
||||||
serviceSSH = "ssh-connection"
|
|
||||||
)
|
|
||||||
|
|
||||||
// supportedCiphers lists ciphers we support but might not recommend.
|
|
||||||
var supportedCiphers = []string{
|
|
||||||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
|
||||||
"aes128-gcm@openssh.com", gcm256CipherID,
|
|
||||||
chacha20Poly1305ID,
|
|
||||||
"arcfour256", "arcfour128", "arcfour",
|
|
||||||
aes128cbcID,
|
|
||||||
tripledescbcID,
|
|
||||||
}
|
|
||||||
|
|
||||||
// preferredCiphers specifies the default preference for ciphers.
|
|
||||||
var preferredCiphers = []string{
|
|
||||||
"aes128-gcm@openssh.com", gcm256CipherID,
|
|
||||||
chacha20Poly1305ID,
|
|
||||||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
|
||||||
// preference order.
|
|
||||||
var supportedKexAlgos = []string{
|
|
||||||
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
|
|
||||||
// P384 and P521 are not constant-time yet, but since we don't
|
|
||||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
|
||||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
|
|
||||||
kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
|
|
||||||
kexAlgoDH1SHA1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
|
|
||||||
// for the server half.
|
|
||||||
var serverForbiddenKexAlgos = map[string]struct{}{
|
|
||||||
kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests
|
|
||||||
kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
|
|
||||||
}
|
|
||||||
|
|
||||||
// preferredKexAlgos specifies the default preference for key-exchange
|
|
||||||
// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
|
|
||||||
// is disabled by default because it is a bit slower than the others.
|
|
||||||
var preferredKexAlgos = []string{
|
|
||||||
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
|
|
||||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
|
|
||||||
kexAlgoDH14SHA256, kexAlgoDH14SHA1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
|
|
||||||
// of authenticating servers) in preference order.
|
|
||||||
var supportedHostKeyAlgos = []string{
|
|
||||||
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
|
|
||||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
|
|
||||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
|
|
||||||
|
|
||||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
|
||||||
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
|
|
||||||
KeyAlgoRSA, KeyAlgoDSA,
|
|
||||||
|
|
||||||
KeyAlgoED25519,
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
|
||||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
|
||||||
// because they have reached the end of their useful life.
|
|
||||||
var supportedMACs = []string{
|
|
||||||
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
|
|
||||||
}
|
|
||||||
|
|
||||||
var supportedCompressions = []string{compressionNone}
|
|
||||||
|
|
||||||
// hashFuncs keeps the mapping of supported signature algorithms to their
|
|
||||||
// respective hashes needed for signing and verification.
|
|
||||||
var hashFuncs = map[string]crypto.Hash{
|
|
||||||
KeyAlgoRSA: crypto.SHA1,
|
|
||||||
KeyAlgoRSASHA256: crypto.SHA256,
|
|
||||||
KeyAlgoRSASHA512: crypto.SHA512,
|
|
||||||
KeyAlgoDSA: crypto.SHA1,
|
|
||||||
KeyAlgoECDSA256: crypto.SHA256,
|
|
||||||
KeyAlgoECDSA384: crypto.SHA384,
|
|
||||||
KeyAlgoECDSA521: crypto.SHA512,
|
|
||||||
// KeyAlgoED25519 doesn't pre-hash.
|
|
||||||
KeyAlgoSKECDSA256: crypto.SHA256,
|
|
||||||
KeyAlgoSKED25519: crypto.SHA256,
|
|
||||||
}
|
|
||||||
|
|
||||||
// algorithmsForKeyFormat returns the supported signature algorithms for a given
|
|
||||||
// public key format (PublicKey.Type), in order of preference. See RFC 8332,
|
|
||||||
// Section 2. See also the note in sendKexInit on backwards compatibility.
|
|
||||||
func algorithmsForKeyFormat(keyFormat string) []string {
|
|
||||||
switch keyFormat {
|
|
||||||
case KeyAlgoRSA:
|
|
||||||
return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}
|
|
||||||
case CertAlgoRSAv01:
|
|
||||||
return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01}
|
|
||||||
default:
|
|
||||||
return []string{keyFormat}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// isRSA returns whether algo is a supported RSA algorithm, including certificate
|
|
||||||
// algorithms.
|
|
||||||
func isRSA(algo string) bool {
|
|
||||||
algos := algorithmsForKeyFormat(KeyAlgoRSA)
|
|
||||||
return contains(algos, underlyingAlgo(algo))
|
|
||||||
}
|
|
||||||
|
|
||||||
// supportedPubKeyAuthAlgos specifies the supported client public key
|
|
||||||
// authentication algorithms. Note that this doesn't include certificate types
|
|
||||||
// since those use the underlying algorithm. This list is sent to the client if
|
|
||||||
// it supports the server-sig-algs extension. Order is irrelevant.
|
|
||||||
var supportedPubKeyAuthAlgos = []string{
|
|
||||||
KeyAlgoED25519,
|
|
||||||
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
|
|
||||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
|
|
||||||
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
|
|
||||||
KeyAlgoDSA,
|
|
||||||
}
|
|
||||||
|
|
||||||
// unexpectedMessageError results when the SSH message that we received didn't
|
|
||||||
// match what we wanted.
|
|
||||||
func unexpectedMessageError(expected, got uint8) error {
|
|
||||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseError results from a malformed SSH message.
|
|
||||||
func parseError(tag uint8) error {
|
|
||||||
return fmt.Errorf("ssh: parse error in message type %d", tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCommon(what string, client []string, server []string) (common string, err error) {
|
|
||||||
for _, c := range client {
|
|
||||||
for _, s := range server {
|
|
||||||
if c == s {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
|
|
||||||
}
|
|
||||||
|
|
||||||
// directionAlgorithms records algorithm choices in one direction (either read or write)
|
|
||||||
type directionAlgorithms struct {
|
|
||||||
Cipher string
|
|
||||||
MAC string
|
|
||||||
Compression string
|
|
||||||
}
|
|
||||||
|
|
||||||
// rekeyBytes returns a rekeying intervals in bytes.
|
|
||||||
func (a *directionAlgorithms) rekeyBytes() int64 {
|
|
||||||
// According to RFC 4344 block ciphers should rekey after
|
|
||||||
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
|
|
||||||
// 128.
|
|
||||||
switch a.Cipher {
|
|
||||||
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
|
|
||||||
return 16 * (1 << 32)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data.
|
|
||||||
return 1 << 30
|
|
||||||
}
|
|
||||||
|
|
||||||
var aeadCiphers = map[string]bool{
|
|
||||||
gcm128CipherID: true,
|
|
||||||
gcm256CipherID: true,
|
|
||||||
chacha20Poly1305ID: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
type algorithms struct {
|
|
||||||
kex string
|
|
||||||
hostKey string
|
|
||||||
w directionAlgorithms
|
|
||||||
r directionAlgorithms
|
|
||||||
}
|
|
||||||
|
|
||||||
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
|
|
||||||
result := &algorithms{}
|
|
||||||
|
|
||||||
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stoc, ctos := &result.w, &result.r
|
|
||||||
if isClient {
|
|
||||||
ctos, stoc = stoc, ctos
|
|
||||||
}
|
|
||||||
|
|
||||||
ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !aeadCiphers[ctos.Cipher] {
|
|
||||||
ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !aeadCiphers[stoc.Cipher] {
|
|
||||||
stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If rekeythreshold is too small, we can't make any progress sending
|
|
||||||
// stuff.
|
|
||||||
const minRekeyThreshold uint64 = 256
|
|
||||||
|
|
||||||
// Config contains configuration data common to both ServerConfig and
|
|
||||||
// ClientConfig.
|
|
||||||
type Config struct {
|
|
||||||
// Rand provides the source of entropy for cryptographic
|
|
||||||
// primitives. If Rand is nil, the cryptographic random reader
|
|
||||||
// in package crypto/rand will be used.
|
|
||||||
Rand io.Reader
|
|
||||||
|
|
||||||
// The maximum number of bytes sent or received after which a
|
|
||||||
// new key is negotiated. It must be at least 256. If
|
|
||||||
// unspecified, a size suitable for the chosen cipher is used.
|
|
||||||
RekeyThreshold uint64
|
|
||||||
|
|
||||||
// The allowed key exchanges algorithms. If unspecified then a default set
|
|
||||||
// of algorithms is used. Unsupported values are silently ignored.
|
|
||||||
KeyExchanges []string
|
|
||||||
|
|
||||||
// The allowed cipher algorithms. If unspecified then a sensible default is
|
|
||||||
// used. Unsupported values are silently ignored.
|
|
||||||
Ciphers []string
|
|
||||||
|
|
||||||
// The allowed MAC algorithms. If unspecified then a sensible default is
|
|
||||||
// used. Unsupported values are silently ignored.
|
|
||||||
MACs []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDefaults sets sensible values for unset fields in config. This is
|
|
||||||
// exported for testing: Configs passed to SSH functions are copied and have
|
|
||||||
// default values set automatically.
|
|
||||||
func (c *Config) SetDefaults() {
|
|
||||||
if c.Rand == nil {
|
|
||||||
c.Rand = rand.Reader
|
|
||||||
}
|
|
||||||
if c.Ciphers == nil {
|
|
||||||
c.Ciphers = preferredCiphers
|
|
||||||
}
|
|
||||||
var ciphers []string
|
|
||||||
for _, c := range c.Ciphers {
|
|
||||||
if cipherModes[c] != nil {
|
|
||||||
// Ignore the cipher if we have no cipherModes definition.
|
|
||||||
ciphers = append(ciphers, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Ciphers = ciphers
|
|
||||||
|
|
||||||
if c.KeyExchanges == nil {
|
|
||||||
c.KeyExchanges = preferredKexAlgos
|
|
||||||
}
|
|
||||||
var kexs []string
|
|
||||||
for _, k := range c.KeyExchanges {
|
|
||||||
if kexAlgoMap[k] != nil {
|
|
||||||
// Ignore the KEX if we have no kexAlgoMap definition.
|
|
||||||
kexs = append(kexs, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.KeyExchanges = kexs
|
|
||||||
|
|
||||||
if c.MACs == nil {
|
|
||||||
c.MACs = supportedMACs
|
|
||||||
}
|
|
||||||
var macs []string
|
|
||||||
for _, m := range c.MACs {
|
|
||||||
if macModes[m] != nil {
|
|
||||||
// Ignore the MAC if we have no macModes definition.
|
|
||||||
macs = append(macs, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.MACs = macs
|
|
||||||
|
|
||||||
if c.RekeyThreshold == 0 {
|
|
||||||
// cipher specific default
|
|
||||||
} else if c.RekeyThreshold < minRekeyThreshold {
|
|
||||||
c.RekeyThreshold = minRekeyThreshold
|
|
||||||
} else if c.RekeyThreshold >= math.MaxInt64 {
|
|
||||||
// Avoid weirdness if somebody uses -1 as a threshold.
|
|
||||||
c.RekeyThreshold = math.MaxInt64
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
|
||||||
// possession of a private key. See RFC 4252, section 7. algo is the advertised
|
|
||||||
// algorithm, and may be a certificate type.
|
|
||||||
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte {
|
|
||||||
data := struct {
|
|
||||||
Session []byte
|
|
||||||
Type byte
|
|
||||||
User string
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Sign bool
|
|
||||||
Algo string
|
|
||||||
PubKey []byte
|
|
||||||
}{
|
|
||||||
sessionID,
|
|
||||||
msgUserAuthRequest,
|
|
||||||
req.User,
|
|
||||||
req.Service,
|
|
||||||
req.Method,
|
|
||||||
true,
|
|
||||||
algo,
|
|
||||||
pubKey,
|
|
||||||
}
|
|
||||||
return Marshal(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU16(buf []byte, n uint16) []byte {
|
|
||||||
return append(buf, byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU32(buf []byte, n uint32) []byte {
|
|
||||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendU64(buf []byte, n uint64) []byte {
|
|
||||||
return append(buf,
|
|
||||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32),
|
|
||||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendInt(buf []byte, n int) []byte {
|
|
||||||
return appendU32(buf, uint32(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendString(buf []byte, s string) []byte {
|
|
||||||
buf = appendU32(buf, uint32(len(s)))
|
|
||||||
buf = append(buf, s...)
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendBool(buf []byte, b bool) []byte {
|
|
||||||
if b {
|
|
||||||
return append(buf, 1)
|
|
||||||
}
|
|
||||||
return append(buf, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newCond is a helper to hide the fact that there is no usable zero
|
|
||||||
// value for sync.Cond.
|
|
||||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) }
|
|
||||||
|
|
||||||
// window represents the buffer available to clients
|
|
||||||
// wishing to write to a channel.
|
|
||||||
type window struct {
|
|
||||||
*sync.Cond
|
|
||||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
|
||||||
writeWaiters int
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// add adds win to the amount of window available
|
|
||||||
// for consumers.
|
|
||||||
func (w *window) add(win uint32) bool {
|
|
||||||
// a zero sized window adjust is a noop.
|
|
||||||
if win == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
w.L.Lock()
|
|
||||||
if w.win+win < win {
|
|
||||||
w.L.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
w.win += win
|
|
||||||
// It is unusual that multiple goroutines would be attempting to reserve
|
|
||||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
|
||||||
// that additional window is available.
|
|
||||||
w.Broadcast()
|
|
||||||
w.L.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// close sets the window to closed, so all reservations fail
|
|
||||||
// immediately.
|
|
||||||
func (w *window) close() {
|
|
||||||
w.L.Lock()
|
|
||||||
w.closed = true
|
|
||||||
w.Broadcast()
|
|
||||||
w.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// reserve reserves win from the available window capacity.
|
|
||||||
// If no capacity remains, reserve will block. reserve may
|
|
||||||
// return less than requested.
|
|
||||||
func (w *window) reserve(win uint32) (uint32, error) {
|
|
||||||
var err error
|
|
||||||
w.L.Lock()
|
|
||||||
w.writeWaiters++
|
|
||||||
w.Broadcast()
|
|
||||||
for w.win == 0 && !w.closed {
|
|
||||||
w.Wait()
|
|
||||||
}
|
|
||||||
w.writeWaiters--
|
|
||||||
if w.win < win {
|
|
||||||
win = w.win
|
|
||||||
}
|
|
||||||
w.win -= win
|
|
||||||
if w.closed {
|
|
||||||
err = io.EOF
|
|
||||||
}
|
|
||||||
w.L.Unlock()
|
|
||||||
return win, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
|
||||||
// writes. It is used in tests only.
|
|
||||||
func (w *window) waitWriterBlocked() {
|
|
||||||
w.Cond.L.Lock()
|
|
||||||
for w.writeWaiters == 0 {
|
|
||||||
w.Cond.Wait()
|
|
||||||
}
|
|
||||||
w.Cond.L.Unlock()
|
|
||||||
}
|
|
@ -1,143 +0,0 @@
|
|||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenChannelError is returned if the other side rejects an
|
|
||||||
// OpenChannel request.
|
|
||||||
type OpenChannelError struct {
|
|
||||||
Reason RejectionReason
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *OpenChannelError) Error() string {
|
|
||||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnMetadata holds metadata for the connection.
|
|
||||||
type ConnMetadata interface {
|
|
||||||
// User returns the user ID for this connection.
|
|
||||||
User() string
|
|
||||||
|
|
||||||
// SessionID returns the session hash, also denoted by H.
|
|
||||||
SessionID() []byte
|
|
||||||
|
|
||||||
// ClientVersion returns the client's version string as hashed
|
|
||||||
// into the session ID.
|
|
||||||
ClientVersion() []byte
|
|
||||||
|
|
||||||
// ServerVersion returns the server's version string as hashed
|
|
||||||
// into the session ID.
|
|
||||||
ServerVersion() []byte
|
|
||||||
|
|
||||||
// RemoteAddr returns the remote address for this connection.
|
|
||||||
RemoteAddr() net.Addr
|
|
||||||
|
|
||||||
// LocalAddr returns the local address for this connection.
|
|
||||||
LocalAddr() net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn represents an SSH connection for both server and client roles.
|
|
||||||
// Conn is the basis for implementing an application layer, such
|
|
||||||
// as ClientConn, which implements the traditional shell access for
|
|
||||||
// clients.
|
|
||||||
type Conn interface {
|
|
||||||
ConnMetadata
|
|
||||||
|
|
||||||
// SendRequest sends a global request, and returns the
|
|
||||||
// reply. If wantReply is true, it returns the response status
|
|
||||||
// and payload. See also RFC 4254, section 4.
|
|
||||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error)
|
|
||||||
|
|
||||||
// OpenChannel tries to open an channel. If the request is
|
|
||||||
// rejected, it returns *OpenChannelError. On success it returns
|
|
||||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
|
||||||
// requests. The Go channel must be serviced, or the
|
|
||||||
// connection will hang.
|
|
||||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error)
|
|
||||||
|
|
||||||
// Close closes the underlying network connection
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// Wait blocks until the connection has shut down, and returns the
|
|
||||||
// error causing the shutdown.
|
|
||||||
Wait() error
|
|
||||||
|
|
||||||
// TODO(hanwen): consider exposing:
|
|
||||||
// RequestKeyChange
|
|
||||||
// Disconnect
|
|
||||||
}
|
|
||||||
|
|
||||||
// DiscardRequests consumes and rejects all requests from the
|
|
||||||
// passed-in channel.
|
|
||||||
func DiscardRequests(in <-chan *Request) {
|
|
||||||
for req := range in {
|
|
||||||
if req.WantReply {
|
|
||||||
req.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A connection represents an incoming connection.
|
|
||||||
type connection struct {
|
|
||||||
transport *handshakeTransport
|
|
||||||
sshConn
|
|
||||||
|
|
||||||
// The connection protocol.
|
|
||||||
*mux
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) Close() error {
|
|
||||||
return c.sshConn.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshConn provides net.Conn metadata, but disallows direct reads and
|
|
||||||
// writes.
|
|
||||||
type sshConn struct {
|
|
||||||
conn net.Conn
|
|
||||||
|
|
||||||
user string
|
|
||||||
sessionID []byte
|
|
||||||
clientVersion []byte
|
|
||||||
serverVersion []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func dup(src []byte) []byte {
|
|
||||||
dst := make([]byte, len(src))
|
|
||||||
copy(dst, src)
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) User() string {
|
|
||||||
return c.user
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) RemoteAddr() net.Addr {
|
|
||||||
return c.conn.RemoteAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) LocalAddr() net.Addr {
|
|
||||||
return c.conn.LocalAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) SessionID() []byte {
|
|
||||||
return dup(c.sessionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) ClientVersion() []byte {
|
|
||||||
return dup(c.clientVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshConn) ServerVersion() []byte {
|
|
||||||
return dup(c.serverVersion)
|
|
||||||
}
|
|
@ -1,23 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
/*
|
|
||||||
Package ssh implements an SSH client and server.
|
|
||||||
|
|
||||||
SSH is a transport security protocol, an authentication protocol and a
|
|
||||||
family of application protocols. The most typical application level
|
|
||||||
protocol is a remote shell and this is specifically implemented. However,
|
|
||||||
the multiplexed nature of SSH is exposed to users that wish to support
|
|
||||||
others.
|
|
||||||
|
|
||||||
References:
|
|
||||||
|
|
||||||
[PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD
|
|
||||||
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
|
||||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
|
||||||
|
|
||||||
This package does not fall under the stability promise of the Go language itself,
|
|
||||||
so its API may be changed when pressing needs arise.
|
|
||||||
*/
|
|
||||||
package ssh // import "github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
@ -1,758 +0,0 @@
|
|||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// debugHandshake, if set, prints messages sent and received. Key
|
|
||||||
// exchange messages are printed as if DH were used, so the debug
|
|
||||||
// messages are wrong when using ECDH.
|
|
||||||
const debugHandshake = false
|
|
||||||
|
|
||||||
// chanSize sets the amount of buffering SSH connections. This is
|
|
||||||
// primarily for testing: setting chanSize=0 uncovers deadlocks more
|
|
||||||
// quickly.
|
|
||||||
const chanSize = 16
|
|
||||||
|
|
||||||
// keyingTransport is a packet based transport that supports key
|
|
||||||
// changes. It need not be thread-safe. It should pass through
|
|
||||||
// msgNewKeys in both directions.
|
|
||||||
type keyingTransport interface {
|
|
||||||
packetConn
|
|
||||||
|
|
||||||
// prepareKeyChange sets up a key change. The key change for a
|
|
||||||
// direction will be effected if a msgNewKeys message is sent
|
|
||||||
// or received.
|
|
||||||
prepareKeyChange(*algorithms, *kexResult) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
|
||||||
// and offers a thread-safe writePacket() interface.
|
|
||||||
type handshakeTransport struct {
|
|
||||||
conn keyingTransport
|
|
||||||
config *Config
|
|
||||||
|
|
||||||
serverVersion []byte
|
|
||||||
clientVersion []byte
|
|
||||||
|
|
||||||
// hostKeys is non-empty if we are the server. In that case,
|
|
||||||
// it contains all host keys that can be used to sign the
|
|
||||||
// connection.
|
|
||||||
hostKeys []Signer
|
|
||||||
|
|
||||||
// publicKeyAuthAlgorithms is non-empty if we are the server. In that case,
|
|
||||||
// it contains the supported client public key authentication algorithms.
|
|
||||||
publicKeyAuthAlgorithms []string
|
|
||||||
|
|
||||||
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
|
||||||
// we accept these key types from the server as host key.
|
|
||||||
hostKeyAlgorithms []string
|
|
||||||
|
|
||||||
// On read error, incoming is closed, and readError is set.
|
|
||||||
incoming chan []byte
|
|
||||||
readError error
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
writeError error
|
|
||||||
sentInitPacket []byte
|
|
||||||
sentInitMsg *kexInitMsg
|
|
||||||
pendingPackets [][]byte // Used when a key exchange is in progress.
|
|
||||||
writePacketsLeft uint32
|
|
||||||
writeBytesLeft int64
|
|
||||||
|
|
||||||
// If the read loop wants to schedule a kex, it pings this
|
|
||||||
// channel, and the write loop will send out a kex
|
|
||||||
// message.
|
|
||||||
requestKex chan struct{}
|
|
||||||
|
|
||||||
// If the other side requests or confirms a kex, its kexInit
|
|
||||||
// packet is sent here for the write loop to find it.
|
|
||||||
startKex chan *pendingKex
|
|
||||||
kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits
|
|
||||||
|
|
||||||
// data for host key checking
|
|
||||||
hostKeyCallback HostKeyCallback
|
|
||||||
dialAddress string
|
|
||||||
remoteAddr net.Addr
|
|
||||||
|
|
||||||
// bannerCallback is non-empty if we are the client and it has been set in
|
|
||||||
// ClientConfig. In that case it is called during the user authentication
|
|
||||||
// dance to handle a custom server's message.
|
|
||||||
bannerCallback BannerCallback
|
|
||||||
|
|
||||||
// Algorithms agreed in the last key exchange.
|
|
||||||
algorithms *algorithms
|
|
||||||
|
|
||||||
// Counters exclusively owned by readLoop.
|
|
||||||
readPacketsLeft uint32
|
|
||||||
readBytesLeft int64
|
|
||||||
|
|
||||||
// The session ID or nil if first kex did not complete yet.
|
|
||||||
sessionID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type pendingKex struct {
|
|
||||||
otherInit []byte
|
|
||||||
done chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
|
|
||||||
t := &handshakeTransport{
|
|
||||||
conn: conn,
|
|
||||||
serverVersion: serverVersion,
|
|
||||||
clientVersion: clientVersion,
|
|
||||||
incoming: make(chan []byte, chanSize),
|
|
||||||
requestKex: make(chan struct{}, 1),
|
|
||||||
startKex: make(chan *pendingKex),
|
|
||||||
kexLoopDone: make(chan struct{}),
|
|
||||||
|
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
t.resetReadThresholds()
|
|
||||||
t.resetWriteThresholds()
|
|
||||||
|
|
||||||
// We always start with a mandatory key exchange.
|
|
||||||
t.requestKex <- struct{}{}
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport {
|
|
||||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
|
||||||
t.dialAddress = dialAddr
|
|
||||||
t.remoteAddr = addr
|
|
||||||
t.hostKeyCallback = config.HostKeyCallback
|
|
||||||
t.bannerCallback = config.BannerCallback
|
|
||||||
if config.HostKeyAlgorithms != nil {
|
|
||||||
t.hostKeyAlgorithms = config.HostKeyAlgorithms
|
|
||||||
} else {
|
|
||||||
t.hostKeyAlgorithms = supportedHostKeyAlgos
|
|
||||||
}
|
|
||||||
go t.readLoop()
|
|
||||||
go t.kexLoop()
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport {
|
|
||||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
|
|
||||||
t.hostKeys = config.hostKeys
|
|
||||||
t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms
|
|
||||||
go t.readLoop()
|
|
||||||
go t.kexLoop()
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) getSessionID() []byte {
|
|
||||||
return t.sessionID
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitSession waits for the session to be established. This should be
|
|
||||||
// the first thing to call after instantiating handshakeTransport.
|
|
||||||
func (t *handshakeTransport) waitSession() error {
|
|
||||||
p, err := t.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if p[0] != msgNewKeys {
|
|
||||||
return fmt.Errorf("ssh: first packet should be msgNewKeys")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) id() string {
|
|
||||||
if len(t.hostKeys) > 0 {
|
|
||||||
return "server"
|
|
||||||
}
|
|
||||||
return "client"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) printPacket(p []byte, write bool) {
|
|
||||||
action := "got"
|
|
||||||
if write {
|
|
||||||
action = "sent"
|
|
||||||
}
|
|
||||||
|
|
||||||
if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
|
|
||||||
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
|
|
||||||
} else {
|
|
||||||
msg, err := decode(p)
|
|
||||||
log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readPacket() ([]byte, error) {
|
|
||||||
p, ok := <-t.incoming
|
|
||||||
if !ok {
|
|
||||||
return nil, t.readError
|
|
||||||
}
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readLoop() {
|
|
||||||
first := true
|
|
||||||
for {
|
|
||||||
p, err := t.readOnePacket(first)
|
|
||||||
first = false
|
|
||||||
if err != nil {
|
|
||||||
t.readError = err
|
|
||||||
close(t.incoming)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if p[0] == msgIgnore || p[0] == msgDebug {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
t.incoming <- p
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop writers too.
|
|
||||||
t.recordWriteError(t.readError)
|
|
||||||
|
|
||||||
// Unblock the writer should it wait for this.
|
|
||||||
close(t.startKex)
|
|
||||||
|
|
||||||
// Don't close t.requestKex; it's also written to from writePacket.
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) pushPacket(p []byte) error {
|
|
||||||
if debugHandshake {
|
|
||||||
t.printPacket(p, true)
|
|
||||||
}
|
|
||||||
return t.conn.writePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) getWriteError() error {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
return t.writeError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) recordWriteError(err error) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.writeError == nil && err != nil {
|
|
||||||
t.writeError = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) requestKeyExchange() {
|
|
||||||
select {
|
|
||||||
case t.requestKex <- struct{}{}:
|
|
||||||
default:
|
|
||||||
// something already requested a kex, so do nothing.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) resetWriteThresholds() {
|
|
||||||
t.writePacketsLeft = packetRekeyThreshold
|
|
||||||
if t.config.RekeyThreshold > 0 {
|
|
||||||
t.writeBytesLeft = int64(t.config.RekeyThreshold)
|
|
||||||
} else if t.algorithms != nil {
|
|
||||||
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
|
|
||||||
} else {
|
|
||||||
t.writeBytesLeft = 1 << 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) kexLoop() {
|
|
||||||
|
|
||||||
write:
|
|
||||||
for t.getWriteError() == nil {
|
|
||||||
var request *pendingKex
|
|
||||||
var sent bool
|
|
||||||
|
|
||||||
for request == nil || !sent {
|
|
||||||
var ok bool
|
|
||||||
select {
|
|
||||||
case request, ok = <-t.startKex:
|
|
||||||
if !ok {
|
|
||||||
break write
|
|
||||||
}
|
|
||||||
case <-t.requestKex:
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sent {
|
|
||||||
if err := t.sendKexInit(); err != nil {
|
|
||||||
t.recordWriteError(err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
sent = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.getWriteError(); err != nil {
|
|
||||||
if request != nil {
|
|
||||||
request.done <- err
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// We're not servicing t.requestKex, but that is OK:
|
|
||||||
// we never block on sending to t.requestKex.
|
|
||||||
|
|
||||||
// We're not servicing t.startKex, but the remote end
|
|
||||||
// has just sent us a kexInitMsg, so it can't send
|
|
||||||
// another key change request, until we close the done
|
|
||||||
// channel on the pendingKex request.
|
|
||||||
|
|
||||||
err := t.enterKeyExchange(request.otherInit)
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
t.writeError = err
|
|
||||||
t.sentInitPacket = nil
|
|
||||||
t.sentInitMsg = nil
|
|
||||||
|
|
||||||
t.resetWriteThresholds()
|
|
||||||
|
|
||||||
// we have completed the key exchange. Since the
|
|
||||||
// reader is still blocked, it is safe to clear out
|
|
||||||
// the requestKex channel. This avoids the situation
|
|
||||||
// where: 1) we consumed our own request for the
|
|
||||||
// initial kex, and 2) the kex from the remote side
|
|
||||||
// caused another send on the requestKex channel,
|
|
||||||
clear:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-t.requestKex:
|
|
||||||
//
|
|
||||||
default:
|
|
||||||
break clear
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
request.done <- t.writeError
|
|
||||||
|
|
||||||
// kex finished. Push packets that we received while
|
|
||||||
// the kex was in progress. Don't look at t.startKex
|
|
||||||
// and don't increment writtenSinceKex: if we trigger
|
|
||||||
// another kex while we are still busy with the last
|
|
||||||
// one, things will become very confusing.
|
|
||||||
for _, p := range t.pendingPackets {
|
|
||||||
t.writeError = t.pushPacket(p)
|
|
||||||
if t.writeError != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.pendingPackets = t.pendingPackets[:0]
|
|
||||||
t.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unblock reader.
|
|
||||||
t.conn.Close()
|
|
||||||
|
|
||||||
// drain startKex channel. We don't service t.requestKex
|
|
||||||
// because nobody does blocking sends there.
|
|
||||||
for request := range t.startKex {
|
|
||||||
request.done <- t.getWriteError()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark that the loop is done so that Close can return.
|
|
||||||
close(t.kexLoopDone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The protocol uses uint32 for packet counters, so we can't let them
|
|
||||||
// reach 1<<32. We will actually read and write more packets than
|
|
||||||
// this, though: the other side may send more packets, and after we
|
|
||||||
// hit this limit on writing we will send a few more packets for the
|
|
||||||
// key exchange itself.
|
|
||||||
const packetRekeyThreshold = (1 << 31)
|
|
||||||
|
|
||||||
func (t *handshakeTransport) resetReadThresholds() {
|
|
||||||
t.readPacketsLeft = packetRekeyThreshold
|
|
||||||
if t.config.RekeyThreshold > 0 {
|
|
||||||
t.readBytesLeft = int64(t.config.RekeyThreshold)
|
|
||||||
} else if t.algorithms != nil {
|
|
||||||
t.readBytesLeft = t.algorithms.r.rekeyBytes()
|
|
||||||
} else {
|
|
||||||
t.readBytesLeft = 1 << 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
|
|
||||||
p, err := t.conn.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.readPacketsLeft > 0 {
|
|
||||||
t.readPacketsLeft--
|
|
||||||
} else {
|
|
||||||
t.requestKeyExchange()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.readBytesLeft > 0 {
|
|
||||||
t.readBytesLeft -= int64(len(p))
|
|
||||||
} else {
|
|
||||||
t.requestKeyExchange()
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugHandshake {
|
|
||||||
t.printPacket(p, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
if first && p[0] != msgKexInit {
|
|
||||||
return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
|
|
||||||
}
|
|
||||||
|
|
||||||
if p[0] != msgKexInit {
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
firstKex := t.sessionID == nil
|
|
||||||
|
|
||||||
kex := pendingKex{
|
|
||||||
done: make(chan error, 1),
|
|
||||||
otherInit: p,
|
|
||||||
}
|
|
||||||
t.startKex <- &kex
|
|
||||||
err = <-kex.done
|
|
||||||
|
|
||||||
if debugHandshake {
|
|
||||||
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.resetReadThresholds()
|
|
||||||
|
|
||||||
// By default, a key exchange is hidden from higher layers by
|
|
||||||
// translating it into msgIgnore.
|
|
||||||
successPacket := []byte{msgIgnore}
|
|
||||||
if firstKex {
|
|
||||||
// sendKexInit() for the first kex waits for
|
|
||||||
// msgNewKeys so the authentication process is
|
|
||||||
// guaranteed to happen over an encrypted transport.
|
|
||||||
successPacket = []byte{msgNewKeys}
|
|
||||||
}
|
|
||||||
|
|
||||||
return successPacket, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendKexInit sends a key change message.
|
|
||||||
func (t *handshakeTransport) sendKexInit() error {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.sentInitMsg != nil {
|
|
||||||
// kexInits may be sent either in response to the other side,
|
|
||||||
// or because our side wants to initiate a key change, so we
|
|
||||||
// may have already sent a kexInit. In that case, don't send a
|
|
||||||
// second kexInit.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := &kexInitMsg{
|
|
||||||
KexAlgos: t.config.KeyExchanges,
|
|
||||||
CiphersClientServer: t.config.Ciphers,
|
|
||||||
CiphersServerClient: t.config.Ciphers,
|
|
||||||
MACsClientServer: t.config.MACs,
|
|
||||||
MACsServerClient: t.config.MACs,
|
|
||||||
CompressionClientServer: supportedCompressions,
|
|
||||||
CompressionServerClient: supportedCompressions,
|
|
||||||
}
|
|
||||||
io.ReadFull(rand.Reader, msg.Cookie[:])
|
|
||||||
|
|
||||||
isServer := len(t.hostKeys) > 0
|
|
||||||
if isServer {
|
|
||||||
for _, k := range t.hostKeys {
|
|
||||||
// If k is a MultiAlgorithmSigner, we restrict the signature
|
|
||||||
// algorithms. If k is a AlgorithmSigner, presume it supports all
|
|
||||||
// signature algorithms associated with the key format. If k is not
|
|
||||||
// an AlgorithmSigner, we can only assume it only supports the
|
|
||||||
// algorithms that matches the key format. (This means that Sign
|
|
||||||
// can't pick a different default).
|
|
||||||
keyFormat := k.PublicKey().Type()
|
|
||||||
|
|
||||||
switch s := k.(type) {
|
|
||||||
case MultiAlgorithmSigner:
|
|
||||||
for _, algo := range algorithmsForKeyFormat(keyFormat) {
|
|
||||||
if contains(s.Algorithms(), underlyingAlgo(algo)) {
|
|
||||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case AlgorithmSigner:
|
|
||||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...)
|
|
||||||
default:
|
|
||||||
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
|
|
||||||
|
|
||||||
// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
|
|
||||||
// algorithms the server supports for public key authentication. See RFC
|
|
||||||
// 8308, Section 2.1.
|
|
||||||
if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
|
|
||||||
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
|
|
||||||
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
|
|
||||||
msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
packet := Marshal(msg)
|
|
||||||
|
|
||||||
// writePacket destroys the contents, so save a copy.
|
|
||||||
packetCopy := make([]byte, len(packet))
|
|
||||||
copy(packetCopy, packet)
|
|
||||||
|
|
||||||
if err := t.pushPacket(packetCopy); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t.sentInitMsg = msg
|
|
||||||
t.sentInitPacket = packet
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) writePacket(p []byte) error {
|
|
||||||
switch p[0] {
|
|
||||||
case msgKexInit:
|
|
||||||
return errors.New("ssh: only handshakeTransport can send kexInit")
|
|
||||||
case msgNewKeys:
|
|
||||||
return errors.New("ssh: only handshakeTransport can send newKeys")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
if t.writeError != nil {
|
|
||||||
return t.writeError
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.sentInitMsg != nil {
|
|
||||||
// Copy the packet so the writer can reuse the buffer.
|
|
||||||
cp := make([]byte, len(p))
|
|
||||||
copy(cp, p)
|
|
||||||
t.pendingPackets = append(t.pendingPackets, cp)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.writeBytesLeft > 0 {
|
|
||||||
t.writeBytesLeft -= int64(len(p))
|
|
||||||
} else {
|
|
||||||
t.requestKeyExchange()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.writePacketsLeft > 0 {
|
|
||||||
t.writePacketsLeft--
|
|
||||||
} else {
|
|
||||||
t.requestKeyExchange()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := t.pushPacket(p); err != nil {
|
|
||||||
t.writeError = err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) Close() error {
|
|
||||||
// Close the connection. This should cause the readLoop goroutine to wake up
|
|
||||||
// and close t.startKex, which will shut down kexLoop if running.
|
|
||||||
err := t.conn.Close()
|
|
||||||
|
|
||||||
// Wait for the kexLoop goroutine to complete.
|
|
||||||
// At that point we know that the readLoop goroutine is complete too,
|
|
||||||
// because kexLoop itself waits for readLoop to close the startKex channel.
|
|
||||||
<-t.kexLoopDone
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
|
|
||||||
if debugHandshake {
|
|
||||||
log.Printf("%s entered key exchange", t.id())
|
|
||||||
}
|
|
||||||
|
|
||||||
otherInit := &kexInitMsg{}
|
|
||||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
magics := handshakeMagics{
|
|
||||||
clientVersion: t.clientVersion,
|
|
||||||
serverVersion: t.serverVersion,
|
|
||||||
clientKexInit: otherInitPacket,
|
|
||||||
serverKexInit: t.sentInitPacket,
|
|
||||||
}
|
|
||||||
|
|
||||||
clientInit := otherInit
|
|
||||||
serverInit := t.sentInitMsg
|
|
||||||
isClient := len(t.hostKeys) == 0
|
|
||||||
if isClient {
|
|
||||||
clientInit, serverInit = serverInit, clientInit
|
|
||||||
|
|
||||||
magics.clientKexInit = t.sentInitPacket
|
|
||||||
magics.serverKexInit = otherInitPacket
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't send FirstKexFollows, but we handle receiving it.
|
|
||||||
//
|
|
||||||
// RFC 4253 section 7 defines the kex and the agreement method for
|
|
||||||
// first_kex_packet_follows. It states that the guessed packet
|
|
||||||
// should be ignored if the "kex algorithm and/or the host
|
|
||||||
// key algorithm is guessed wrong (server and client have
|
|
||||||
// different preferred algorithm), or if any of the other
|
|
||||||
// algorithms cannot be agreed upon". The other algorithms have
|
|
||||||
// already been checked above so the kex algorithm and host key
|
|
||||||
// algorithm are checked here.
|
|
||||||
if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
|
|
||||||
// other side sent a kex message for the wrong algorithm,
|
|
||||||
// which we have to ignore.
|
|
||||||
if _, err := t.conn.readPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
kex, ok := kexAlgoMap[t.algorithms.kex]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *kexResult
|
|
||||||
if len(t.hostKeys) > 0 {
|
|
||||||
result, err = t.server(kex, &magics)
|
|
||||||
} else {
|
|
||||||
result, err = t.client(kex, &magics)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
firstKeyExchange := t.sessionID == nil
|
|
||||||
if firstKeyExchange {
|
|
||||||
t.sessionID = result.H
|
|
||||||
}
|
|
||||||
result.SessionID = t.sessionID
|
|
||||||
|
|
||||||
if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO
|
|
||||||
// message with the server-sig-algs extension if the client supports it. See
|
|
||||||
// RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9.
|
|
||||||
if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") {
|
|
||||||
supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",")
|
|
||||||
extInfo := &extInfoMsg{
|
|
||||||
NumExtensions: 2,
|
|
||||||
Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1),
|
|
||||||
}
|
|
||||||
extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs"))
|
|
||||||
extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...)
|
|
||||||
extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList))
|
|
||||||
extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...)
|
|
||||||
extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com"))
|
|
||||||
extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...)
|
|
||||||
extInfo.Payload = appendInt(extInfo.Payload, 1)
|
|
||||||
extInfo.Payload = append(extInfo.Payload, "0"...)
|
|
||||||
if err := t.conn.writePacket(Marshal(extInfo)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if packet, err := t.conn.readPacket(); err != nil {
|
|
||||||
return err
|
|
||||||
} else if packet[0] != msgNewKeys {
|
|
||||||
return unexpectedMessageError(msgNewKeys, packet[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// algorithmSignerWrapper is an AlgorithmSigner that only supports the default
|
|
||||||
// key format algorithm.
|
|
||||||
//
|
|
||||||
// This is technically a violation of the AlgorithmSigner interface, but it
|
|
||||||
// should be unreachable given where we use this. Anyway, at least it returns an
|
|
||||||
// error instead of panicing or producing an incorrect signature.
|
|
||||||
type algorithmSignerWrapper struct {
|
|
||||||
Signer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
|
|
||||||
if algorithm != underlyingAlgo(a.PublicKey().Type()) {
|
|
||||||
return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm")
|
|
||||||
}
|
|
||||||
return a.Sign(rand, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
|
|
||||||
for _, k := range hostKeys {
|
|
||||||
if s, ok := k.(MultiAlgorithmSigner); ok {
|
|
||||||
if !contains(s.Algorithms(), underlyingAlgo(algo)) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if algo == k.PublicKey().Type() {
|
|
||||||
return algorithmSignerWrapper{k}
|
|
||||||
}
|
|
||||||
|
|
||||||
k, ok := k.(AlgorithmSigner)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) {
|
|
||||||
if algo == a {
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
|
|
||||||
if hostKey == nil {
|
|
||||||
return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
|
|
||||||
return r, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
result, err := kex.Client(t.conn, t.config.Rand, magics)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKey, err := ParsePublicKey(result.HostKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
@ -1,93 +0,0 @@
|
|||||||
// Copyright 2014 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package bcrypt_pbkdf implements bcrypt_pbkdf(3) from OpenBSD.
|
|
||||||
//
|
|
||||||
// See https://flak.tedunangst.com/post/bcrypt-pbkdf and
|
|
||||||
// https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/lib/libutil/bcrypt_pbkdf.c.
|
|
||||||
package bcrypt_pbkdf
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha512"
|
|
||||||
"errors"
|
|
||||||
"golang.org/x/crypto/blowfish"
|
|
||||||
)
|
|
||||||
|
|
||||||
const blockSize = 32
|
|
||||||
|
|
||||||
// Key derives a key from the password, salt and rounds count, returning a
|
|
||||||
// []byte of length keyLen that can be used as cryptographic key.
|
|
||||||
func Key(password, salt []byte, rounds, keyLen int) ([]byte, error) {
|
|
||||||
if rounds < 1 {
|
|
||||||
return nil, errors.New("bcrypt_pbkdf: number of rounds is too small")
|
|
||||||
}
|
|
||||||
if len(password) == 0 {
|
|
||||||
return nil, errors.New("bcrypt_pbkdf: empty password")
|
|
||||||
}
|
|
||||||
if len(salt) == 0 || len(salt) > 1<<20 {
|
|
||||||
return nil, errors.New("bcrypt_pbkdf: bad salt length")
|
|
||||||
}
|
|
||||||
if keyLen > 1024 {
|
|
||||||
return nil, errors.New("bcrypt_pbkdf: keyLen is too large")
|
|
||||||
}
|
|
||||||
|
|
||||||
numBlocks := (keyLen + blockSize - 1) / blockSize
|
|
||||||
key := make([]byte, numBlocks*blockSize)
|
|
||||||
|
|
||||||
h := sha512.New()
|
|
||||||
h.Write(password)
|
|
||||||
shapass := h.Sum(nil)
|
|
||||||
|
|
||||||
shasalt := make([]byte, 0, sha512.Size)
|
|
||||||
cnt, tmp := make([]byte, 4), make([]byte, blockSize)
|
|
||||||
for block := 1; block <= numBlocks; block++ {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(salt)
|
|
||||||
cnt[0] = byte(block >> 24)
|
|
||||||
cnt[1] = byte(block >> 16)
|
|
||||||
cnt[2] = byte(block >> 8)
|
|
||||||
cnt[3] = byte(block)
|
|
||||||
h.Write(cnt)
|
|
||||||
bcryptHash(tmp, shapass, h.Sum(shasalt))
|
|
||||||
|
|
||||||
out := make([]byte, blockSize)
|
|
||||||
copy(out, tmp)
|
|
||||||
for i := 2; i <= rounds; i++ {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(tmp)
|
|
||||||
bcryptHash(tmp, shapass, h.Sum(shasalt))
|
|
||||||
for j := 0; j < len(out); j++ {
|
|
||||||
out[j] ^= tmp[j]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, v := range out {
|
|
||||||
key[i*numBlocks+(block-1)] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return key[:keyLen], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var magic = []byte("OxychromaticBlowfishSwatDynamite")
|
|
||||||
|
|
||||||
func bcryptHash(out, shapass, shasalt []byte) {
|
|
||||||
c, err := blowfish.NewSaltedCipher(shapass, shasalt)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
for i := 0; i < 64; i++ {
|
|
||||||
blowfish.ExpandKey(shasalt, c)
|
|
||||||
blowfish.ExpandKey(shapass, c)
|
|
||||||
}
|
|
||||||
copy(out, magic)
|
|
||||||
for i := 0; i < 32; i += 8 {
|
|
||||||
for j := 0; j < 64; j++ {
|
|
||||||
c.Encrypt(out[i:i+8], out[i:i+8])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Swap bytes due to different endianness.
|
|
||||||
for i := 0; i < 32; i += 4 {
|
|
||||||
out[i+3], out[i+2], out[i+1], out[i] = out[i], out[i+1], out[i+2], out[i+3]
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,39 +0,0 @@
|
|||||||
// Copyright 2019 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build !go1.13
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
// Generic fallbacks for the math/bits intrinsics, copied from
|
|
||||||
// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had
|
|
||||||
// variable time fallbacks until Go 1.13.
|
|
||||||
|
|
||||||
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
|
|
||||||
sum = x + y + carry
|
|
||||||
carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
|
|
||||||
diff = x - y - borrow
|
|
||||||
borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func bitsMul64(x, y uint64) (hi, lo uint64) {
|
|
||||||
const mask32 = 1<<32 - 1
|
|
||||||
x0 := x & mask32
|
|
||||||
x1 := x >> 32
|
|
||||||
y0 := y & mask32
|
|
||||||
y1 := y >> 32
|
|
||||||
w0 := x0 * y0
|
|
||||||
t := x1*y0 + w0>>32
|
|
||||||
w1 := t & mask32
|
|
||||||
w2 := t >> 32
|
|
||||||
w1 += x0 * y1
|
|
||||||
hi = x1*y1 + w2 + w1>>32
|
|
||||||
lo = x * y
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
// Copyright 2019 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build go1.13
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
import "math/bits"
|
|
||||||
|
|
||||||
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
|
|
||||||
return bits.Add64(x, y, carry)
|
|
||||||
}
|
|
||||||
|
|
||||||
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
|
|
||||||
return bits.Sub64(x, y, borrow)
|
|
||||||
}
|
|
||||||
|
|
||||||
func bitsMul64(x, y uint64) (hi, lo uint64) {
|
|
||||||
return bits.Mul64(x, y)
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
type mac struct{ macGeneric }
|
|
@ -1,99 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package poly1305 implements Poly1305 one-time message authentication code as
|
|
||||||
// specified in https://cr.yp.to/mac/poly1305-20050329.pdf.
|
|
||||||
//
|
|
||||||
// Poly1305 is a fast, one-time authentication function. It is infeasible for an
|
|
||||||
// attacker to generate an authenticator for a message without the key. However, a
|
|
||||||
// key must only be used for a single message. Authenticating two different
|
|
||||||
// messages with the same key allows an attacker to forge authenticators for other
|
|
||||||
// messages with the same key.
|
|
||||||
//
|
|
||||||
// Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was
|
|
||||||
// used with a fixed key in order to generate one-time keys from an nonce.
|
|
||||||
// However, in this package AES isn't used and the one-time key is specified
|
|
||||||
// directly.
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
import "crypto/subtle"
|
|
||||||
|
|
||||||
// TagSize is the size, in bytes, of a poly1305 authenticator.
|
|
||||||
const TagSize = 16
|
|
||||||
|
|
||||||
// Sum generates an authenticator for msg using a one-time key and puts the
|
|
||||||
// 16-byte result into out. Authenticating two different messages with the same
|
|
||||||
// key allows an attacker to forge messages at will.
|
|
||||||
func Sum(out *[16]byte, m []byte, key *[32]byte) {
|
|
||||||
h := New(key)
|
|
||||||
h.Write(m)
|
|
||||||
h.Sum(out[:0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify returns true if mac is a valid authenticator for m with the given key.
|
|
||||||
func Verify(mac *[16]byte, m []byte, key *[32]byte) bool {
|
|
||||||
var tmp [16]byte
|
|
||||||
Sum(&tmp, m, key)
|
|
||||||
return subtle.ConstantTimeCompare(tmp[:], mac[:]) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a new MAC computing an authentication
|
|
||||||
// tag of all data written to it with the given key.
|
|
||||||
// This allows writing the message progressively instead
|
|
||||||
// of passing it as a single slice. Common users should use
|
|
||||||
// the Sum function instead.
|
|
||||||
//
|
|
||||||
// The key must be unique for each message, as authenticating
|
|
||||||
// two different messages with the same key allows an attacker
|
|
||||||
// to forge messages at will.
|
|
||||||
func New(key *[32]byte) *MAC {
|
|
||||||
m := &MAC{}
|
|
||||||
initialize(key, &m.macState)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// MAC is an io.Writer computing an authentication tag
|
|
||||||
// of the data written to it.
|
|
||||||
//
|
|
||||||
// MAC cannot be used like common hash.Hash implementations,
|
|
||||||
// because using a poly1305 key twice breaks its security.
|
|
||||||
// Therefore writing data to a running MAC after calling
|
|
||||||
// Sum or Verify causes it to panic.
|
|
||||||
type MAC struct {
|
|
||||||
mac // platform-dependent implementation
|
|
||||||
|
|
||||||
finalized bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns the number of bytes Sum will return.
|
|
||||||
func (h *MAC) Size() int { return TagSize }
|
|
||||||
|
|
||||||
// Write adds more data to the running message authentication code.
|
|
||||||
// It never returns an error.
|
|
||||||
//
|
|
||||||
// It must not be called after the first call of Sum or Verify.
|
|
||||||
func (h *MAC) Write(p []byte) (n int, err error) {
|
|
||||||
if h.finalized {
|
|
||||||
panic("poly1305: write to MAC after Sum or Verify")
|
|
||||||
}
|
|
||||||
return h.mac.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum computes the authenticator of all data written to the
|
|
||||||
// message authentication code.
|
|
||||||
func (h *MAC) Sum(b []byte) []byte {
|
|
||||||
var mac [TagSize]byte
|
|
||||||
h.mac.Sum(&mac)
|
|
||||||
h.finalized = true
|
|
||||||
return append(b, mac[:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify returns whether the authenticator of all data written to
|
|
||||||
// the message authentication code matches the expected value.
|
|
||||||
func (h *MAC) Verify(expected []byte) bool {
|
|
||||||
var mac [TagSize]byte
|
|
||||||
h.mac.Sum(&mac)
|
|
||||||
h.finalized = true
|
|
||||||
return subtle.ConstantTimeCompare(expected, mac[:]) == 1
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
//go:noescape
|
|
||||||
func update(state *macState, msg []byte)
|
|
||||||
|
|
||||||
// mac is a wrapper for macGeneric that redirects calls that would have gone to
|
|
||||||
// updateGeneric to update.
|
|
||||||
//
|
|
||||||
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
|
|
||||||
// using function pointers would carry a major performance cost.
|
|
||||||
type mac struct{ macGeneric }
|
|
||||||
|
|
||||||
func (h *mac) Write(p []byte) (int, error) {
|
|
||||||
nn := len(p)
|
|
||||||
if h.offset > 0 {
|
|
||||||
n := copy(h.buffer[h.offset:], p)
|
|
||||||
if h.offset+n < TagSize {
|
|
||||||
h.offset += n
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
p = p[n:]
|
|
||||||
h.offset = 0
|
|
||||||
update(&h.macState, h.buffer[:])
|
|
||||||
}
|
|
||||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
|
||||||
update(&h.macState, p[:n])
|
|
||||||
p = p[n:]
|
|
||||||
}
|
|
||||||
if len(p) > 0 {
|
|
||||||
h.offset += copy(h.buffer[h.offset:], p)
|
|
||||||
}
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *mac) Sum(out *[16]byte) {
|
|
||||||
state := h.macState
|
|
||||||
if h.offset > 0 {
|
|
||||||
update(&state, h.buffer[:h.offset])
|
|
||||||
}
|
|
||||||
finalize(out, &state.h, &state.s)
|
|
||||||
}
|
|
@ -1,108 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
#include "textflag.h"
|
|
||||||
|
|
||||||
#define POLY1305_ADD(msg, h0, h1, h2) \
|
|
||||||
ADDQ 0(msg), h0; \
|
|
||||||
ADCQ 8(msg), h1; \
|
|
||||||
ADCQ $1, h2; \
|
|
||||||
LEAQ 16(msg), msg
|
|
||||||
|
|
||||||
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \
|
|
||||||
MOVQ r0, AX; \
|
|
||||||
MULQ h0; \
|
|
||||||
MOVQ AX, t0; \
|
|
||||||
MOVQ DX, t1; \
|
|
||||||
MOVQ r0, AX; \
|
|
||||||
MULQ h1; \
|
|
||||||
ADDQ AX, t1; \
|
|
||||||
ADCQ $0, DX; \
|
|
||||||
MOVQ r0, t2; \
|
|
||||||
IMULQ h2, t2; \
|
|
||||||
ADDQ DX, t2; \
|
|
||||||
\
|
|
||||||
MOVQ r1, AX; \
|
|
||||||
MULQ h0; \
|
|
||||||
ADDQ AX, t1; \
|
|
||||||
ADCQ $0, DX; \
|
|
||||||
MOVQ DX, h0; \
|
|
||||||
MOVQ r1, t3; \
|
|
||||||
IMULQ h2, t3; \
|
|
||||||
MOVQ r1, AX; \
|
|
||||||
MULQ h1; \
|
|
||||||
ADDQ AX, t2; \
|
|
||||||
ADCQ DX, t3; \
|
|
||||||
ADDQ h0, t2; \
|
|
||||||
ADCQ $0, t3; \
|
|
||||||
\
|
|
||||||
MOVQ t0, h0; \
|
|
||||||
MOVQ t1, h1; \
|
|
||||||
MOVQ t2, h2; \
|
|
||||||
ANDQ $3, h2; \
|
|
||||||
MOVQ t2, t0; \
|
|
||||||
ANDQ $0xFFFFFFFFFFFFFFFC, t0; \
|
|
||||||
ADDQ t0, h0; \
|
|
||||||
ADCQ t3, h1; \
|
|
||||||
ADCQ $0, h2; \
|
|
||||||
SHRQ $2, t3, t2; \
|
|
||||||
SHRQ $2, t3; \
|
|
||||||
ADDQ t2, h0; \
|
|
||||||
ADCQ t3, h1; \
|
|
||||||
ADCQ $0, h2
|
|
||||||
|
|
||||||
// func update(state *[7]uint64, msg []byte)
|
|
||||||
TEXT ·update(SB), $0-32
|
|
||||||
MOVQ state+0(FP), DI
|
|
||||||
MOVQ msg_base+8(FP), SI
|
|
||||||
MOVQ msg_len+16(FP), R15
|
|
||||||
|
|
||||||
MOVQ 0(DI), R8 // h0
|
|
||||||
MOVQ 8(DI), R9 // h1
|
|
||||||
MOVQ 16(DI), R10 // h2
|
|
||||||
MOVQ 24(DI), R11 // r0
|
|
||||||
MOVQ 32(DI), R12 // r1
|
|
||||||
|
|
||||||
CMPQ R15, $16
|
|
||||||
JB bytes_between_0_and_15
|
|
||||||
|
|
||||||
loop:
|
|
||||||
POLY1305_ADD(SI, R8, R9, R10)
|
|
||||||
|
|
||||||
multiply:
|
|
||||||
POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14)
|
|
||||||
SUBQ $16, R15
|
|
||||||
CMPQ R15, $16
|
|
||||||
JAE loop
|
|
||||||
|
|
||||||
bytes_between_0_and_15:
|
|
||||||
TESTQ R15, R15
|
|
||||||
JZ done
|
|
||||||
MOVQ $1, BX
|
|
||||||
XORQ CX, CX
|
|
||||||
XORQ R13, R13
|
|
||||||
ADDQ R15, SI
|
|
||||||
|
|
||||||
flush_buffer:
|
|
||||||
SHLQ $8, BX, CX
|
|
||||||
SHLQ $8, BX
|
|
||||||
MOVB -1(SI), R13
|
|
||||||
XORQ R13, BX
|
|
||||||
DECQ SI
|
|
||||||
DECQ R15
|
|
||||||
JNZ flush_buffer
|
|
||||||
|
|
||||||
ADDQ BX, R8
|
|
||||||
ADCQ CX, R9
|
|
||||||
ADCQ $0, R10
|
|
||||||
MOVQ $16, R15
|
|
||||||
JMP multiply
|
|
||||||
|
|
||||||
done:
|
|
||||||
MOVQ R8, 0(DI)
|
|
||||||
MOVQ R9, 8(DI)
|
|
||||||
MOVQ R10, 16(DI)
|
|
||||||
RET
|
|
@ -1,309 +0,0 @@
|
|||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// This file provides the generic implementation of Sum and MAC. Other files
|
|
||||||
// might provide optimized assembly implementations of some of this code.
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
import "encoding/binary"
|
|
||||||
|
|
||||||
// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
|
|
||||||
// for a 64 bytes message is approximately
|
|
||||||
//
|
|
||||||
// s + m[0:16] * r⁴ + m[16:32] * r³ + m[32:48] * r² + m[48:64] * r mod 2¹³⁰ - 5
|
|
||||||
//
|
|
||||||
// for some secret r and s. It can be computed sequentially like
|
|
||||||
//
|
|
||||||
// for len(msg) > 0:
|
|
||||||
// h += read(msg, 16)
|
|
||||||
// h *= r
|
|
||||||
// h %= 2¹³⁰ - 5
|
|
||||||
// return h + s
|
|
||||||
//
|
|
||||||
// All the complexity is about doing performant constant-time math on numbers
|
|
||||||
// larger than any available numeric type.
|
|
||||||
|
|
||||||
func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) {
|
|
||||||
h := newMACGeneric(key)
|
|
||||||
h.Write(msg)
|
|
||||||
h.Sum(out)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMACGeneric(key *[32]byte) macGeneric {
|
|
||||||
m := macGeneric{}
|
|
||||||
initialize(key, &m.macState)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// macState holds numbers in saturated 64-bit little-endian limbs. That is,
|
|
||||||
// the value of [x0, x1, x2] is x[0] + x[1] * 2⁶⁴ + x[2] * 2¹²⁸.
|
|
||||||
type macState struct {
|
|
||||||
// h is the main accumulator. It is to be interpreted modulo 2¹³⁰ - 5, but
|
|
||||||
// can grow larger during and after rounds. It must, however, remain below
|
|
||||||
// 2 * (2¹³⁰ - 5).
|
|
||||||
h [3]uint64
|
|
||||||
// r and s are the private key components.
|
|
||||||
r [2]uint64
|
|
||||||
s [2]uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
type macGeneric struct {
|
|
||||||
macState
|
|
||||||
|
|
||||||
buffer [TagSize]byte
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write splits the incoming message into TagSize chunks, and passes them to
|
|
||||||
// update. It buffers incomplete chunks.
|
|
||||||
func (h *macGeneric) Write(p []byte) (int, error) {
|
|
||||||
nn := len(p)
|
|
||||||
if h.offset > 0 {
|
|
||||||
n := copy(h.buffer[h.offset:], p)
|
|
||||||
if h.offset+n < TagSize {
|
|
||||||
h.offset += n
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
p = p[n:]
|
|
||||||
h.offset = 0
|
|
||||||
updateGeneric(&h.macState, h.buffer[:])
|
|
||||||
}
|
|
||||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
|
||||||
updateGeneric(&h.macState, p[:n])
|
|
||||||
p = p[n:]
|
|
||||||
}
|
|
||||||
if len(p) > 0 {
|
|
||||||
h.offset += copy(h.buffer[h.offset:], p)
|
|
||||||
}
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sum flushes the last incomplete chunk from the buffer, if any, and generates
|
|
||||||
// the MAC output. It does not modify its state, in order to allow for multiple
|
|
||||||
// calls to Sum, even if no Write is allowed after Sum.
|
|
||||||
func (h *macGeneric) Sum(out *[TagSize]byte) {
|
|
||||||
state := h.macState
|
|
||||||
if h.offset > 0 {
|
|
||||||
updateGeneric(&state, h.buffer[:h.offset])
|
|
||||||
}
|
|
||||||
finalize(out, &state.h, &state.s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// [rMask0, rMask1] is the specified Poly1305 clamping mask in little-endian. It
|
|
||||||
// clears some bits of the secret coefficient to make it possible to implement
|
|
||||||
// multiplication more efficiently.
|
|
||||||
const (
|
|
||||||
rMask0 = 0x0FFFFFFC0FFFFFFF
|
|
||||||
rMask1 = 0x0FFFFFFC0FFFFFFC
|
|
||||||
)
|
|
||||||
|
|
||||||
// initialize loads the 256-bit key into the two 128-bit secret values r and s.
|
|
||||||
func initialize(key *[32]byte, m *macState) {
|
|
||||||
m.r[0] = binary.LittleEndian.Uint64(key[0:8]) & rMask0
|
|
||||||
m.r[1] = binary.LittleEndian.Uint64(key[8:16]) & rMask1
|
|
||||||
m.s[0] = binary.LittleEndian.Uint64(key[16:24])
|
|
||||||
m.s[1] = binary.LittleEndian.Uint64(key[24:32])
|
|
||||||
}
|
|
||||||
|
|
||||||
// uint128 holds a 128-bit number as two 64-bit limbs, for use with the
|
|
||||||
// bits.Mul64 and bits.Add64 intrinsics.
|
|
||||||
type uint128 struct {
|
|
||||||
lo, hi uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
func mul64(a, b uint64) uint128 {
|
|
||||||
hi, lo := bitsMul64(a, b)
|
|
||||||
return uint128{lo, hi}
|
|
||||||
}
|
|
||||||
|
|
||||||
func add128(a, b uint128) uint128 {
|
|
||||||
lo, c := bitsAdd64(a.lo, b.lo, 0)
|
|
||||||
hi, c := bitsAdd64(a.hi, b.hi, c)
|
|
||||||
if c != 0 {
|
|
||||||
panic("poly1305: unexpected overflow")
|
|
||||||
}
|
|
||||||
return uint128{lo, hi}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shiftRightBy2(a uint128) uint128 {
|
|
||||||
a.lo = a.lo>>2 | (a.hi&3)<<62
|
|
||||||
a.hi = a.hi >> 2
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateGeneric absorbs msg into the state.h accumulator. For each chunk m of
|
|
||||||
// 128 bits of message, it computes
|
|
||||||
//
|
|
||||||
// h₊ = (h + m) * r mod 2¹³⁰ - 5
|
|
||||||
//
|
|
||||||
// If the msg length is not a multiple of TagSize, it assumes the last
|
|
||||||
// incomplete chunk is the final one.
|
|
||||||
func updateGeneric(state *macState, msg []byte) {
|
|
||||||
h0, h1, h2 := state.h[0], state.h[1], state.h[2]
|
|
||||||
r0, r1 := state.r[0], state.r[1]
|
|
||||||
|
|
||||||
for len(msg) > 0 {
|
|
||||||
var c uint64
|
|
||||||
|
|
||||||
// For the first step, h + m, we use a chain of bits.Add64 intrinsics.
|
|
||||||
// The resulting value of h might exceed 2¹³⁰ - 5, but will be partially
|
|
||||||
// reduced at the end of the multiplication below.
|
|
||||||
//
|
|
||||||
// The spec requires us to set a bit just above the message size, not to
|
|
||||||
// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
|
|
||||||
// add 1 to the most significant (2¹²⁸) limb, h2.
|
|
||||||
if len(msg) >= TagSize {
|
|
||||||
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
|
|
||||||
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
|
|
||||||
h2 += c + 1
|
|
||||||
|
|
||||||
msg = msg[TagSize:]
|
|
||||||
} else {
|
|
||||||
var buf [TagSize]byte
|
|
||||||
copy(buf[:], msg)
|
|
||||||
buf[len(msg)] = 1
|
|
||||||
|
|
||||||
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
|
|
||||||
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
|
|
||||||
h2 += c
|
|
||||||
|
|
||||||
msg = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multiplication of big number limbs is similar to elementary school
|
|
||||||
// columnar multiplication. Instead of digits, there are 64-bit limbs.
|
|
||||||
//
|
|
||||||
// We are multiplying a 3 limbs number, h, by a 2 limbs number, r.
|
|
||||||
//
|
|
||||||
// h2 h1 h0 x
|
|
||||||
// r1 r0 =
|
|
||||||
// ----------------
|
|
||||||
// h2r0 h1r0 h0r0 <-- individual 128-bit products
|
|
||||||
// + h2r1 h1r1 h0r1
|
|
||||||
// ------------------------
|
|
||||||
// m3 m2 m1 m0 <-- result in 128-bit overlapping limbs
|
|
||||||
// ------------------------
|
|
||||||
// m3.hi m2.hi m1.hi m0.hi <-- carry propagation
|
|
||||||
// + m3.lo m2.lo m1.lo m0.lo
|
|
||||||
// -------------------------------
|
|
||||||
// t4 t3 t2 t1 t0 <-- final result in 64-bit limbs
|
|
||||||
//
|
|
||||||
// The main difference from pen-and-paper multiplication is that we do
|
|
||||||
// carry propagation in a separate step, as if we wrote two digit sums
|
|
||||||
// at first (the 128-bit limbs), and then carried the tens all at once.
|
|
||||||
|
|
||||||
h0r0 := mul64(h0, r0)
|
|
||||||
h1r0 := mul64(h1, r0)
|
|
||||||
h2r0 := mul64(h2, r0)
|
|
||||||
h0r1 := mul64(h0, r1)
|
|
||||||
h1r1 := mul64(h1, r1)
|
|
||||||
h2r1 := mul64(h2, r1)
|
|
||||||
|
|
||||||
// Since h2 is known to be at most 7 (5 + 1 + 1), and r0 and r1 have their
|
|
||||||
// top 4 bits cleared by rMask{0,1}, we know that their product is not going
|
|
||||||
// to overflow 64 bits, so we can ignore the high part of the products.
|
|
||||||
//
|
|
||||||
// This also means that the product doesn't have a fifth limb (t4).
|
|
||||||
if h2r0.hi != 0 {
|
|
||||||
panic("poly1305: unexpected overflow")
|
|
||||||
}
|
|
||||||
if h2r1.hi != 0 {
|
|
||||||
panic("poly1305: unexpected overflow")
|
|
||||||
}
|
|
||||||
|
|
||||||
m0 := h0r0
|
|
||||||
m1 := add128(h1r0, h0r1) // These two additions don't overflow thanks again
|
|
||||||
m2 := add128(h2r0, h1r1) // to the 4 masked bits at the top of r0 and r1.
|
|
||||||
m3 := h2r1
|
|
||||||
|
|
||||||
t0 := m0.lo
|
|
||||||
t1, c := bitsAdd64(m1.lo, m0.hi, 0)
|
|
||||||
t2, c := bitsAdd64(m2.lo, m1.hi, c)
|
|
||||||
t3, _ := bitsAdd64(m3.lo, m2.hi, c)
|
|
||||||
|
|
||||||
// Now we have the result as 4 64-bit limbs, and we need to reduce it
|
|
||||||
// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
|
|
||||||
// a cheap partial reduction according to the reduction identity
|
|
||||||
//
|
|
||||||
// c * 2¹³⁰ + n = c * 5 + n mod 2¹³⁰ - 5
|
|
||||||
//
|
|
||||||
// because 2¹³⁰ = 5 mod 2¹³⁰ - 5. Partial reduction since the result is
|
|
||||||
// likely to be larger than 2¹³⁰ - 5, but still small enough to fit the
|
|
||||||
// assumptions we make about h in the rest of the code.
|
|
||||||
//
|
|
||||||
// See also https://speakerdeck.com/gtank/engineering-prime-numbers?slide=23
|
|
||||||
|
|
||||||
// We split the final result at the 2¹³⁰ mark into h and cc, the carry.
|
|
||||||
// Note that the carry bits are effectively shifted left by 2, in other
|
|
||||||
// words, cc = c * 4 for the c in the reduction identity.
|
|
||||||
h0, h1, h2 = t0, t1, t2&maskLow2Bits
|
|
||||||
cc := uint128{t2 & maskNotLow2Bits, t3}
|
|
||||||
|
|
||||||
// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
|
|
||||||
|
|
||||||
h0, c = bitsAdd64(h0, cc.lo, 0)
|
|
||||||
h1, c = bitsAdd64(h1, cc.hi, c)
|
|
||||||
h2 += c
|
|
||||||
|
|
||||||
cc = shiftRightBy2(cc)
|
|
||||||
|
|
||||||
h0, c = bitsAdd64(h0, cc.lo, 0)
|
|
||||||
h1, c = bitsAdd64(h1, cc.hi, c)
|
|
||||||
h2 += c
|
|
||||||
|
|
||||||
// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
|
|
||||||
//
|
|
||||||
// 5 * 2¹²⁸ + (2¹²⁸ - 1) = 6 * 2¹²⁸ - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
state.h[0], state.h[1], state.h[2] = h0, h1, h2
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
maskLow2Bits uint64 = 0x0000000000000003
|
|
||||||
maskNotLow2Bits uint64 = ^maskLow2Bits
|
|
||||||
)
|
|
||||||
|
|
||||||
// select64 returns x if v == 1 and y if v == 0, in constant time.
|
|
||||||
func select64(v, x, y uint64) uint64 { return ^(v-1)&x | (v-1)&y }
|
|
||||||
|
|
||||||
// [p0, p1, p2] is 2¹³⁰ - 5 in little endian order.
|
|
||||||
const (
|
|
||||||
p0 = 0xFFFFFFFFFFFFFFFB
|
|
||||||
p1 = 0xFFFFFFFFFFFFFFFF
|
|
||||||
p2 = 0x0000000000000003
|
|
||||||
)
|
|
||||||
|
|
||||||
// finalize completes the modular reduction of h and computes
|
|
||||||
//
|
|
||||||
// out = h + s mod 2¹²⁸
|
|
||||||
func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
|
|
||||||
h0, h1, h2 := h[0], h[1], h[2]
|
|
||||||
|
|
||||||
// After the partial reduction in updateGeneric, h might be more than
|
|
||||||
// 2¹³⁰ - 5, but will be less than 2 * (2¹³⁰ - 5). To complete the reduction
|
|
||||||
// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
|
|
||||||
// result if the subtraction underflows, and t otherwise.
|
|
||||||
|
|
||||||
hMinusP0, b := bitsSub64(h0, p0, 0)
|
|
||||||
hMinusP1, b := bitsSub64(h1, p1, b)
|
|
||||||
_, b = bitsSub64(h2, p2, b)
|
|
||||||
|
|
||||||
// h = h if h < p else h - p
|
|
||||||
h0 = select64(b, h0, hMinusP0)
|
|
||||||
h1 = select64(b, h1, hMinusP1)
|
|
||||||
|
|
||||||
// Finally, we compute the last Poly1305 step
|
|
||||||
//
|
|
||||||
// tag = h + s mod 2¹²⁸
|
|
||||||
//
|
|
||||||
// by just doing a wide addition with the 128 low bits of h and discarding
|
|
||||||
// the overflow.
|
|
||||||
h0, c := bitsAdd64(h0, s[0], 0)
|
|
||||||
h1, _ = bitsAdd64(h1, s[1], c)
|
|
||||||
|
|
||||||
binary.LittleEndian.PutUint64(out[0:8], h0)
|
|
||||||
binary.LittleEndian.PutUint64(out[8:16], h1)
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
// Copyright 2019 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
//go:noescape
|
|
||||||
func update(state *macState, msg []byte)
|
|
||||||
|
|
||||||
// mac is a wrapper for macGeneric that redirects calls that would have gone to
|
|
||||||
// updateGeneric to update.
|
|
||||||
//
|
|
||||||
// Its Write and Sum methods are otherwise identical to the macGeneric ones, but
|
|
||||||
// using function pointers would carry a major performance cost.
|
|
||||||
type mac struct{ macGeneric }
|
|
||||||
|
|
||||||
func (h *mac) Write(p []byte) (int, error) {
|
|
||||||
nn := len(p)
|
|
||||||
if h.offset > 0 {
|
|
||||||
n := copy(h.buffer[h.offset:], p)
|
|
||||||
if h.offset+n < TagSize {
|
|
||||||
h.offset += n
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
p = p[n:]
|
|
||||||
h.offset = 0
|
|
||||||
update(&h.macState, h.buffer[:])
|
|
||||||
}
|
|
||||||
if n := len(p) - (len(p) % TagSize); n > 0 {
|
|
||||||
update(&h.macState, p[:n])
|
|
||||||
p = p[n:]
|
|
||||||
}
|
|
||||||
if len(p) > 0 {
|
|
||||||
h.offset += copy(h.buffer[h.offset:], p)
|
|
||||||
}
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *mac) Sum(out *[16]byte) {
|
|
||||||
state := h.macState
|
|
||||||
if h.offset > 0 {
|
|
||||||
update(&state, h.buffer[:h.offset])
|
|
||||||
}
|
|
||||||
finalize(out, &state.h, &state.s)
|
|
||||||
}
|
|
@ -1,181 +0,0 @@
|
|||||||
// Copyright 2019 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
#include "textflag.h"
|
|
||||||
|
|
||||||
// This was ported from the amd64 implementation.
|
|
||||||
|
|
||||||
#define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
|
|
||||||
MOVD (msg), t0; \
|
|
||||||
MOVD 8(msg), t1; \
|
|
||||||
MOVD $1, t2; \
|
|
||||||
ADDC t0, h0, h0; \
|
|
||||||
ADDE t1, h1, h1; \
|
|
||||||
ADDE t2, h2; \
|
|
||||||
ADD $16, msg
|
|
||||||
|
|
||||||
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
|
|
||||||
MULLD r0, h0, t0; \
|
|
||||||
MULLD r0, h1, t4; \
|
|
||||||
MULHDU r0, h0, t1; \
|
|
||||||
MULHDU r0, h1, t5; \
|
|
||||||
ADDC t4, t1, t1; \
|
|
||||||
MULLD r0, h2, t2; \
|
|
||||||
ADDZE t5; \
|
|
||||||
MULHDU r1, h0, t4; \
|
|
||||||
MULLD r1, h0, h0; \
|
|
||||||
ADD t5, t2, t2; \
|
|
||||||
ADDC h0, t1, t1; \
|
|
||||||
MULLD h2, r1, t3; \
|
|
||||||
ADDZE t4, h0; \
|
|
||||||
MULHDU r1, h1, t5; \
|
|
||||||
MULLD r1, h1, t4; \
|
|
||||||
ADDC t4, t2, t2; \
|
|
||||||
ADDE t5, t3, t3; \
|
|
||||||
ADDC h0, t2, t2; \
|
|
||||||
MOVD $-4, t4; \
|
|
||||||
MOVD t0, h0; \
|
|
||||||
MOVD t1, h1; \
|
|
||||||
ADDZE t3; \
|
|
||||||
ANDCC $3, t2, h2; \
|
|
||||||
AND t2, t4, t0; \
|
|
||||||
ADDC t0, h0, h0; \
|
|
||||||
ADDE t3, h1, h1; \
|
|
||||||
SLD $62, t3, t4; \
|
|
||||||
SRD $2, t2; \
|
|
||||||
ADDZE h2; \
|
|
||||||
OR t4, t2, t2; \
|
|
||||||
SRD $2, t3; \
|
|
||||||
ADDC t2, h0, h0; \
|
|
||||||
ADDE t3, h1, h1; \
|
|
||||||
ADDZE h2
|
|
||||||
|
|
||||||
DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
|
|
||||||
DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
|
|
||||||
GLOBL ·poly1305Mask<>(SB), RODATA, $16
|
|
||||||
|
|
||||||
// func update(state *[7]uint64, msg []byte)
|
|
||||||
TEXT ·update(SB), $0-32
|
|
||||||
MOVD state+0(FP), R3
|
|
||||||
MOVD msg_base+8(FP), R4
|
|
||||||
MOVD msg_len+16(FP), R5
|
|
||||||
|
|
||||||
MOVD 0(R3), R8 // h0
|
|
||||||
MOVD 8(R3), R9 // h1
|
|
||||||
MOVD 16(R3), R10 // h2
|
|
||||||
MOVD 24(R3), R11 // r0
|
|
||||||
MOVD 32(R3), R12 // r1
|
|
||||||
|
|
||||||
CMP R5, $16
|
|
||||||
BLT bytes_between_0_and_15
|
|
||||||
|
|
||||||
loop:
|
|
||||||
POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
|
|
||||||
|
|
||||||
multiply:
|
|
||||||
POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
|
|
||||||
ADD $-16, R5
|
|
||||||
CMP R5, $16
|
|
||||||
BGE loop
|
|
||||||
|
|
||||||
bytes_between_0_and_15:
|
|
||||||
CMP R5, $0
|
|
||||||
BEQ done
|
|
||||||
MOVD $0, R16 // h0
|
|
||||||
MOVD $0, R17 // h1
|
|
||||||
|
|
||||||
flush_buffer:
|
|
||||||
CMP R5, $8
|
|
||||||
BLE just1
|
|
||||||
|
|
||||||
MOVD $8, R21
|
|
||||||
SUB R21, R5, R21
|
|
||||||
|
|
||||||
// Greater than 8 -- load the rightmost remaining bytes in msg
|
|
||||||
// and put into R17 (h1)
|
|
||||||
MOVD (R4)(R21), R17
|
|
||||||
MOVD $16, R22
|
|
||||||
|
|
||||||
// Find the offset to those bytes
|
|
||||||
SUB R5, R22, R22
|
|
||||||
SLD $3, R22
|
|
||||||
|
|
||||||
// Shift to get only the bytes in msg
|
|
||||||
SRD R22, R17, R17
|
|
||||||
|
|
||||||
// Put 1 at high end
|
|
||||||
MOVD $1, R23
|
|
||||||
SLD $3, R21
|
|
||||||
SLD R21, R23, R23
|
|
||||||
OR R23, R17, R17
|
|
||||||
|
|
||||||
// Remainder is 8
|
|
||||||
MOVD $8, R5
|
|
||||||
|
|
||||||
just1:
|
|
||||||
CMP R5, $8
|
|
||||||
BLT less8
|
|
||||||
|
|
||||||
// Exactly 8
|
|
||||||
MOVD (R4), R16
|
|
||||||
|
|
||||||
CMP R17, $0
|
|
||||||
|
|
||||||
// Check if we've already set R17; if not
|
|
||||||
// set 1 to indicate end of msg.
|
|
||||||
BNE carry
|
|
||||||
MOVD $1, R17
|
|
||||||
BR carry
|
|
||||||
|
|
||||||
less8:
|
|
||||||
MOVD $0, R16 // h0
|
|
||||||
MOVD $0, R22 // shift count
|
|
||||||
CMP R5, $4
|
|
||||||
BLT less4
|
|
||||||
MOVWZ (R4), R16
|
|
||||||
ADD $4, R4
|
|
||||||
ADD $-4, R5
|
|
||||||
MOVD $32, R22
|
|
||||||
|
|
||||||
less4:
|
|
||||||
CMP R5, $2
|
|
||||||
BLT less2
|
|
||||||
MOVHZ (R4), R21
|
|
||||||
SLD R22, R21, R21
|
|
||||||
OR R16, R21, R16
|
|
||||||
ADD $16, R22
|
|
||||||
ADD $-2, R5
|
|
||||||
ADD $2, R4
|
|
||||||
|
|
||||||
less2:
|
|
||||||
CMP R5, $0
|
|
||||||
BEQ insert1
|
|
||||||
MOVBZ (R4), R21
|
|
||||||
SLD R22, R21, R21
|
|
||||||
OR R16, R21, R16
|
|
||||||
ADD $8, R22
|
|
||||||
|
|
||||||
insert1:
|
|
||||||
// Insert 1 at end of msg
|
|
||||||
MOVD $1, R21
|
|
||||||
SLD R22, R21, R21
|
|
||||||
OR R16, R21, R16
|
|
||||||
|
|
||||||
carry:
|
|
||||||
// Add new values to h0, h1, h2
|
|
||||||
ADDC R16, R8
|
|
||||||
ADDE R17, R9
|
|
||||||
ADDZE R10, R10
|
|
||||||
MOVD $16, R5
|
|
||||||
ADD R5, R4
|
|
||||||
BR multiply
|
|
||||||
|
|
||||||
done:
|
|
||||||
// Save h0, h1, h2 in state
|
|
||||||
MOVD R8, 0(R3)
|
|
||||||
MOVD R9, 8(R3)
|
|
||||||
MOVD R10, 16(R3)
|
|
||||||
RET
|
|
@ -1,76 +0,0 @@
|
|||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
package poly1305
|
|
||||||
|
|
||||||
import (
|
|
||||||
"golang.org/x/sys/cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
// updateVX is an assembly implementation of Poly1305 that uses vector
|
|
||||||
// instructions. It must only be called if the vector facility (vx) is
|
|
||||||
// available.
|
|
||||||
//
|
|
||||||
//go:noescape
|
|
||||||
func updateVX(state *macState, msg []byte)
|
|
||||||
|
|
||||||
// mac is a replacement for macGeneric that uses a larger buffer and redirects
|
|
||||||
// calls that would have gone to updateGeneric to updateVX if the vector
|
|
||||||
// facility is installed.
|
|
||||||
//
|
|
||||||
// A larger buffer is required for good performance because the vector
|
|
||||||
// implementation has a higher fixed cost per call than the generic
|
|
||||||
// implementation.
|
|
||||||
type mac struct {
|
|
||||||
macState
|
|
||||||
|
|
||||||
buffer [16 * TagSize]byte // size must be a multiple of block size (16)
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *mac) Write(p []byte) (int, error) {
|
|
||||||
nn := len(p)
|
|
||||||
if h.offset > 0 {
|
|
||||||
n := copy(h.buffer[h.offset:], p)
|
|
||||||
if h.offset+n < len(h.buffer) {
|
|
||||||
h.offset += n
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
p = p[n:]
|
|
||||||
h.offset = 0
|
|
||||||
if cpu.S390X.HasVX {
|
|
||||||
updateVX(&h.macState, h.buffer[:])
|
|
||||||
} else {
|
|
||||||
updateGeneric(&h.macState, h.buffer[:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tail := len(p) % len(h.buffer) // number of bytes to copy into buffer
|
|
||||||
body := len(p) - tail // number of bytes to process now
|
|
||||||
if body > 0 {
|
|
||||||
if cpu.S390X.HasVX {
|
|
||||||
updateVX(&h.macState, p[:body])
|
|
||||||
} else {
|
|
||||||
updateGeneric(&h.macState, p[:body])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h.offset = copy(h.buffer[:], p[body:]) // copy tail bytes - can be 0
|
|
||||||
return nn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *mac) Sum(out *[TagSize]byte) {
|
|
||||||
state := h.macState
|
|
||||||
remainder := h.buffer[:h.offset]
|
|
||||||
|
|
||||||
// Use the generic implementation if we have 2 or fewer blocks left
|
|
||||||
// to sum. The vector implementation has a higher startup time.
|
|
||||||
if cpu.S390X.HasVX && len(remainder) > 2*TagSize {
|
|
||||||
updateVX(&state, remainder)
|
|
||||||
} else if len(remainder) > 0 {
|
|
||||||
updateGeneric(&state, remainder)
|
|
||||||
}
|
|
||||||
finalize(out, &state.h, &state.s)
|
|
||||||
}
|
|
@ -1,503 +0,0 @@
|
|||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build gc && !purego
|
|
||||||
|
|
||||||
#include "textflag.h"
|
|
||||||
|
|
||||||
// This implementation of Poly1305 uses the vector facility (vx)
|
|
||||||
// to process up to 2 blocks (32 bytes) per iteration using an
|
|
||||||
// algorithm based on the one described in:
|
|
||||||
//
|
|
||||||
// NEON crypto, Daniel J. Bernstein & Peter Schwabe
|
|
||||||
// https://cryptojedi.org/papers/neoncrypto-20120320.pdf
|
|
||||||
//
|
|
||||||
// This algorithm uses 5 26-bit limbs to represent a 130-bit
|
|
||||||
// value. These limbs are, for the most part, zero extended and
|
|
||||||
// placed into 64-bit vector register elements. Each vector
|
|
||||||
// register is 128-bits wide and so holds 2 of these elements.
|
|
||||||
// Using 26-bit limbs allows us plenty of headroom to accommodate
|
|
||||||
// accumulations before and after multiplication without
|
|
||||||
// overflowing either 32-bits (before multiplication) or 64-bits
|
|
||||||
// (after multiplication).
|
|
||||||
//
|
|
||||||
// In order to parallelise the operations required to calculate
|
|
||||||
// the sum we use two separate accumulators and then sum those
|
|
||||||
// in an extra final step. For compatibility with the generic
|
|
||||||
// implementation we perform this summation at the end of every
|
|
||||||
// updateVX call.
|
|
||||||
//
|
|
||||||
// To use two accumulators we must multiply the message blocks
|
|
||||||
// by r² rather than r. Only the final message block should be
|
|
||||||
// multiplied by r.
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
//
|
|
||||||
// We want to calculate the sum (h) for a 64 byte message (m):
|
|
||||||
//
|
|
||||||
// h = m[0:16]r⁴ + m[16:32]r³ + m[32:48]r² + m[48:64]r
|
|
||||||
//
|
|
||||||
// To do this we split the calculation into the even indices
|
|
||||||
// and odd indices of the message. These form our SIMD 'lanes':
|
|
||||||
//
|
|
||||||
// h = m[ 0:16]r⁴ + m[32:48]r² + <- lane 0
|
|
||||||
// m[16:32]r³ + m[48:64]r <- lane 1
|
|
||||||
//
|
|
||||||
// To calculate this iteratively we refactor so that both lanes
|
|
||||||
// are written in terms of r² and r:
|
|
||||||
//
|
|
||||||
// h = (m[ 0:16]r² + m[32:48])r² + <- lane 0
|
|
||||||
// (m[16:32]r² + m[48:64])r <- lane 1
|
|
||||||
// ^ ^
|
|
||||||
// | coefficients for second iteration
|
|
||||||
// coefficients for first iteration
|
|
||||||
//
|
|
||||||
// So in this case we would have two iterations. In the first
|
|
||||||
// both lanes are multiplied by r². In the second only the
|
|
||||||
// first lane is multiplied by r² and the second lane is
|
|
||||||
// instead multiplied by r. This gives use the odd and even
|
|
||||||
// powers of r that we need from the original equation.
|
|
||||||
//
|
|
||||||
// Notation:
|
|
||||||
//
|
|
||||||
// h - accumulator
|
|
||||||
// r - key
|
|
||||||
// m - message
|
|
||||||
//
|
|
||||||
// [a, b] - SIMD register holding two 64-bit values
|
|
||||||
// [a, b, c, d] - SIMD register holding four 32-bit values
|
|
||||||
// xᵢ[n] - limb n of variable x with bit width i
|
|
||||||
//
|
|
||||||
// Limbs are expressed in little endian order, so for 26-bit
|
|
||||||
// limbs x₂₆[4] will be the most significant limb and x₂₆[0]
|
|
||||||
// will be the least significant limb.
|
|
||||||
|
|
||||||
// masking constants
|
|
||||||
#define MOD24 V0 // [0x0000000000ffffff, 0x0000000000ffffff] - mask low 24-bits
|
|
||||||
#define MOD26 V1 // [0x0000000003ffffff, 0x0000000003ffffff] - mask low 26-bits
|
|
||||||
|
|
||||||
// expansion constants (see EXPAND macro)
|
|
||||||
#define EX0 V2
|
|
||||||
#define EX1 V3
|
|
||||||
#define EX2 V4
|
|
||||||
|
|
||||||
// key (r², r or 1 depending on context)
|
|
||||||
#define R_0 V5
|
|
||||||
#define R_1 V6
|
|
||||||
#define R_2 V7
|
|
||||||
#define R_3 V8
|
|
||||||
#define R_4 V9
|
|
||||||
|
|
||||||
// precalculated coefficients (5r², 5r or 0 depending on context)
|
|
||||||
#define R5_1 V10
|
|
||||||
#define R5_2 V11
|
|
||||||
#define R5_3 V12
|
|
||||||
#define R5_4 V13
|
|
||||||
|
|
||||||
// message block (m)
|
|
||||||
#define M_0 V14
|
|
||||||
#define M_1 V15
|
|
||||||
#define M_2 V16
|
|
||||||
#define M_3 V17
|
|
||||||
#define M_4 V18
|
|
||||||
|
|
||||||
// accumulator (h)
|
|
||||||
#define H_0 V19
|
|
||||||
#define H_1 V20
|
|
||||||
#define H_2 V21
|
|
||||||
#define H_3 V22
|
|
||||||
#define H_4 V23
|
|
||||||
|
|
||||||
// temporary registers (for short-lived values)
|
|
||||||
#define T_0 V24
|
|
||||||
#define T_1 V25
|
|
||||||
#define T_2 V26
|
|
||||||
#define T_3 V27
|
|
||||||
#define T_4 V28
|
|
||||||
|
|
||||||
GLOBL ·constants<>(SB), RODATA, $0x30
|
|
||||||
// EX0
|
|
||||||
DATA ·constants<>+0x00(SB)/8, $0x0006050403020100
|
|
||||||
DATA ·constants<>+0x08(SB)/8, $0x1016151413121110
|
|
||||||
// EX1
|
|
||||||
DATA ·constants<>+0x10(SB)/8, $0x060c0b0a09080706
|
|
||||||
DATA ·constants<>+0x18(SB)/8, $0x161c1b1a19181716
|
|
||||||
// EX2
|
|
||||||
DATA ·constants<>+0x20(SB)/8, $0x0d0d0d0d0d0f0e0d
|
|
||||||
DATA ·constants<>+0x28(SB)/8, $0x1d1d1d1d1d1f1e1d
|
|
||||||
|
|
||||||
// MULTIPLY multiplies each lane of f and g, partially reduced
|
|
||||||
// modulo 2¹³⁰ - 5. The result, h, consists of partial products
|
|
||||||
// in each lane that need to be reduced further to produce the
|
|
||||||
// final result.
|
|
||||||
//
|
|
||||||
// h₁₃₀ = (f₁₃₀g₁₃₀) % 2¹³⁰ + (5f₁₃₀g₁₃₀) / 2¹³⁰
|
|
||||||
//
|
|
||||||
// Note that the multiplication by 5 of the high bits is
|
|
||||||
// achieved by precalculating the multiplication of four of the
|
|
||||||
// g coefficients by 5. These are g51-g54.
|
|
||||||
#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \
|
|
||||||
VMLOF f0, g0, h0 \
|
|
||||||
VMLOF f0, g3, h3 \
|
|
||||||
VMLOF f0, g1, h1 \
|
|
||||||
VMLOF f0, g4, h4 \
|
|
||||||
VMLOF f0, g2, h2 \
|
|
||||||
VMLOF f1, g54, T_0 \
|
|
||||||
VMLOF f1, g2, T_3 \
|
|
||||||
VMLOF f1, g0, T_1 \
|
|
||||||
VMLOF f1, g3, T_4 \
|
|
||||||
VMLOF f1, g1, T_2 \
|
|
||||||
VMALOF f2, g53, h0, h0 \
|
|
||||||
VMALOF f2, g1, h3, h3 \
|
|
||||||
VMALOF f2, g54, h1, h1 \
|
|
||||||
VMALOF f2, g2, h4, h4 \
|
|
||||||
VMALOF f2, g0, h2, h2 \
|
|
||||||
VMALOF f3, g52, T_0, T_0 \
|
|
||||||
VMALOF f3, g0, T_3, T_3 \
|
|
||||||
VMALOF f3, g53, T_1, T_1 \
|
|
||||||
VMALOF f3, g1, T_4, T_4 \
|
|
||||||
VMALOF f3, g54, T_2, T_2 \
|
|
||||||
VMALOF f4, g51, h0, h0 \
|
|
||||||
VMALOF f4, g54, h3, h3 \
|
|
||||||
VMALOF f4, g52, h1, h1 \
|
|
||||||
VMALOF f4, g0, h4, h4 \
|
|
||||||
VMALOF f4, g53, h2, h2 \
|
|
||||||
VAG T_0, h0, h0 \
|
|
||||||
VAG T_3, h3, h3 \
|
|
||||||
VAG T_1, h1, h1 \
|
|
||||||
VAG T_4, h4, h4 \
|
|
||||||
VAG T_2, h2, h2
|
|
||||||
|
|
||||||
// REDUCE performs the following carry operations in four
|
|
||||||
// stages, as specified in Bernstein & Schwabe:
|
|
||||||
//
|
|
||||||
// 1: h₂₆[0]->h₂₆[1] h₂₆[3]->h₂₆[4]
|
|
||||||
// 2: h₂₆[1]->h₂₆[2] h₂₆[4]->h₂₆[0]
|
|
||||||
// 3: h₂₆[0]->h₂₆[1] h₂₆[2]->h₂₆[3]
|
|
||||||
// 4: h₂₆[3]->h₂₆[4]
|
|
||||||
//
|
|
||||||
// The result is that all of the limbs are limited to 26-bits
|
|
||||||
// except for h₂₆[1] and h₂₆[4] which are limited to 27-bits.
|
|
||||||
//
|
|
||||||
// Note that although each limb is aligned at 26-bit intervals
|
|
||||||
// they may contain values that exceed 2²⁶ - 1, hence the need
|
|
||||||
// to carry the excess bits in each limb.
|
|
||||||
#define REDUCE(h0, h1, h2, h3, h4) \
|
|
||||||
VESRLG $26, h0, T_0 \
|
|
||||||
VESRLG $26, h3, T_1 \
|
|
||||||
VN MOD26, h0, h0 \
|
|
||||||
VN MOD26, h3, h3 \
|
|
||||||
VAG T_0, h1, h1 \
|
|
||||||
VAG T_1, h4, h4 \
|
|
||||||
VESRLG $26, h1, T_2 \
|
|
||||||
VESRLG $26, h4, T_3 \
|
|
||||||
VN MOD26, h1, h1 \
|
|
||||||
VN MOD26, h4, h4 \
|
|
||||||
VESLG $2, T_3, T_4 \
|
|
||||||
VAG T_3, T_4, T_4 \
|
|
||||||
VAG T_2, h2, h2 \
|
|
||||||
VAG T_4, h0, h0 \
|
|
||||||
VESRLG $26, h2, T_0 \
|
|
||||||
VESRLG $26, h0, T_1 \
|
|
||||||
VN MOD26, h2, h2 \
|
|
||||||
VN MOD26, h0, h0 \
|
|
||||||
VAG T_0, h3, h3 \
|
|
||||||
VAG T_1, h1, h1 \
|
|
||||||
VESRLG $26, h3, T_2 \
|
|
||||||
VN MOD26, h3, h3 \
|
|
||||||
VAG T_2, h4, h4
|
|
||||||
|
|
||||||
// EXPAND splits the 128-bit little-endian values in0 and in1
|
|
||||||
// into 26-bit big-endian limbs and places the results into
|
|
||||||
// the first and second lane of d₂₆[0:4] respectively.
|
|
||||||
//
|
|
||||||
// The EX0, EX1 and EX2 constants are arrays of byte indices
|
|
||||||
// for permutation. The permutation both reverses the bytes
|
|
||||||
// in the input and ensures the bytes are copied into the
|
|
||||||
// destination limb ready to be shifted into their final
|
|
||||||
// position.
|
|
||||||
#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \
|
|
||||||
VPERM in0, in1, EX0, d0 \
|
|
||||||
VPERM in0, in1, EX1, d2 \
|
|
||||||
VPERM in0, in1, EX2, d4 \
|
|
||||||
VESRLG $26, d0, d1 \
|
|
||||||
VESRLG $30, d2, d3 \
|
|
||||||
VESRLG $4, d2, d2 \
|
|
||||||
VN MOD26, d0, d0 \ // [in0₂₆[0], in1₂₆[0]]
|
|
||||||
VN MOD26, d3, d3 \ // [in0₂₆[3], in1₂₆[3]]
|
|
||||||
VN MOD26, d1, d1 \ // [in0₂₆[1], in1₂₆[1]]
|
|
||||||
VN MOD24, d4, d4 \ // [in0₂₆[4], in1₂₆[4]]
|
|
||||||
VN MOD26, d2, d2 // [in0₂₆[2], in1₂₆[2]]
|
|
||||||
|
|
||||||
// func updateVX(state *macState, msg []byte)
|
|
||||||
TEXT ·updateVX(SB), NOSPLIT, $0
|
|
||||||
MOVD state+0(FP), R1
|
|
||||||
LMG msg+8(FP), R2, R3 // R2=msg_base, R3=msg_len
|
|
||||||
|
|
||||||
// load EX0, EX1 and EX2
|
|
||||||
MOVD $·constants<>(SB), R5
|
|
||||||
VLM (R5), EX0, EX2
|
|
||||||
|
|
||||||
// generate masks
|
|
||||||
VGMG $(64-24), $63, MOD24 // [0x00ffffff, 0x00ffffff]
|
|
||||||
VGMG $(64-26), $63, MOD26 // [0x03ffffff, 0x03ffffff]
|
|
||||||
|
|
||||||
// load h (accumulator) and r (key) from state
|
|
||||||
VZERO T_1 // [0, 0]
|
|
||||||
VL 0(R1), T_0 // [h₆₄[0], h₆₄[1]]
|
|
||||||
VLEG $0, 16(R1), T_1 // [h₆₄[2], 0]
|
|
||||||
VL 24(R1), T_2 // [r₆₄[0], r₆₄[1]]
|
|
||||||
VPDI $0, T_0, T_2, T_3 // [h₆₄[0], r₆₄[0]]
|
|
||||||
VPDI $5, T_0, T_2, T_4 // [h₆₄[1], r₆₄[1]]
|
|
||||||
|
|
||||||
// unpack h and r into 26-bit limbs
|
|
||||||
// note: h₆₄[2] may have the low 3 bits set, so h₂₆[4] is a 27-bit value
|
|
||||||
VN MOD26, T_3, H_0 // [h₂₆[0], r₂₆[0]]
|
|
||||||
VZERO H_1 // [0, 0]
|
|
||||||
VZERO H_3 // [0, 0]
|
|
||||||
VGMG $(64-12-14), $(63-12), T_0 // [0x03fff000, 0x03fff000] - 26-bit mask with low 12 bits masked out
|
|
||||||
VESLG $24, T_1, T_1 // [h₆₄[2]<<24, 0]
|
|
||||||
VERIMG $-26&63, T_3, MOD26, H_1 // [h₂₆[1], r₂₆[1]]
|
|
||||||
VESRLG $+52&63, T_3, H_2 // [h₂₆[2], r₂₆[2]] - low 12 bits only
|
|
||||||
VERIMG $-14&63, T_4, MOD26, H_3 // [h₂₆[1], r₂₆[1]]
|
|
||||||
VESRLG $40, T_4, H_4 // [h₂₆[4], r₂₆[4]] - low 24 bits only
|
|
||||||
VERIMG $+12&63, T_4, T_0, H_2 // [h₂₆[2], r₂₆[2]] - complete
|
|
||||||
VO T_1, H_4, H_4 // [h₂₆[4], r₂₆[4]] - complete
|
|
||||||
|
|
||||||
// replicate r across all 4 vector elements
|
|
||||||
VREPF $3, H_0, R_0 // [r₂₆[0], r₂₆[0], r₂₆[0], r₂₆[0]]
|
|
||||||
VREPF $3, H_1, R_1 // [r₂₆[1], r₂₆[1], r₂₆[1], r₂₆[1]]
|
|
||||||
VREPF $3, H_2, R_2 // [r₂₆[2], r₂₆[2], r₂₆[2], r₂₆[2]]
|
|
||||||
VREPF $3, H_3, R_3 // [r₂₆[3], r₂₆[3], r₂₆[3], r₂₆[3]]
|
|
||||||
VREPF $3, H_4, R_4 // [r₂₆[4], r₂₆[4], r₂₆[4], r₂₆[4]]
|
|
||||||
|
|
||||||
// zero out lane 1 of h
|
|
||||||
VLEIG $1, $0, H_0 // [h₂₆[0], 0]
|
|
||||||
VLEIG $1, $0, H_1 // [h₂₆[1], 0]
|
|
||||||
VLEIG $1, $0, H_2 // [h₂₆[2], 0]
|
|
||||||
VLEIG $1, $0, H_3 // [h₂₆[3], 0]
|
|
||||||
VLEIG $1, $0, H_4 // [h₂₆[4], 0]
|
|
||||||
|
|
||||||
// calculate 5r (ignore least significant limb)
|
|
||||||
VREPIF $5, T_0
|
|
||||||
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r₂₆[1], 5r₂₆[1], 5r₂₆[1]]
|
|
||||||
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r₂₆[2], 5r₂₆[2], 5r₂₆[2]]
|
|
||||||
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r₂₆[3], 5r₂₆[3], 5r₂₆[3]]
|
|
||||||
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r₂₆[4], 5r₂₆[4], 5r₂₆[4]]
|
|
||||||
|
|
||||||
// skip r² calculation if we are only calculating one block
|
|
||||||
CMPBLE R3, $16, skip
|
|
||||||
|
|
||||||
// calculate r²
|
|
||||||
MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, M_0, M_1, M_2, M_3, M_4)
|
|
||||||
REDUCE(M_0, M_1, M_2, M_3, M_4)
|
|
||||||
VGBM $0x0f0f, T_0
|
|
||||||
VERIMG $0, M_0, T_0, R_0 // [r₂₆[0], r²₂₆[0], r₂₆[0], r²₂₆[0]]
|
|
||||||
VERIMG $0, M_1, T_0, R_1 // [r₂₆[1], r²₂₆[1], r₂₆[1], r²₂₆[1]]
|
|
||||||
VERIMG $0, M_2, T_0, R_2 // [r₂₆[2], r²₂₆[2], r₂₆[2], r²₂₆[2]]
|
|
||||||
VERIMG $0, M_3, T_0, R_3 // [r₂₆[3], r²₂₆[3], r₂₆[3], r²₂₆[3]]
|
|
||||||
VERIMG $0, M_4, T_0, R_4 // [r₂₆[4], r²₂₆[4], r₂₆[4], r²₂₆[4]]
|
|
||||||
|
|
||||||
// calculate 5r² (ignore least significant limb)
|
|
||||||
VREPIF $5, T_0
|
|
||||||
VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r²₂₆[1], 5r₂₆[1], 5r²₂₆[1]]
|
|
||||||
VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r²₂₆[2], 5r₂₆[2], 5r²₂₆[2]]
|
|
||||||
VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r²₂₆[3], 5r₂₆[3], 5r²₂₆[3]]
|
|
||||||
VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r²₂₆[4], 5r₂₆[4], 5r²₂₆[4]]
|
|
||||||
|
|
||||||
loop:
|
|
||||||
CMPBLE R3, $32, b2 // 2 or fewer blocks remaining, need to change key coefficients
|
|
||||||
|
|
||||||
// load next 2 blocks from message
|
|
||||||
VLM (R2), T_0, T_1
|
|
||||||
|
|
||||||
// update message slice
|
|
||||||
SUB $32, R3
|
|
||||||
MOVD $32(R2), R2
|
|
||||||
|
|
||||||
// unpack message blocks into 26-bit big-endian limbs
|
|
||||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
|
||||||
|
|
||||||
// add 2¹²⁸ to each message block value
|
|
||||||
VLEIB $4, $1, M_4
|
|
||||||
VLEIB $12, $1, M_4
|
|
||||||
|
|
||||||
multiply:
|
|
||||||
// accumulate the incoming message
|
|
||||||
VAG H_0, M_0, M_0
|
|
||||||
VAG H_3, M_3, M_3
|
|
||||||
VAG H_1, M_1, M_1
|
|
||||||
VAG H_4, M_4, M_4
|
|
||||||
VAG H_2, M_2, M_2
|
|
||||||
|
|
||||||
// multiply the accumulator by the key coefficient
|
|
||||||
MULTIPLY(M_0, M_1, M_2, M_3, M_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4)
|
|
||||||
|
|
||||||
// carry and partially reduce the partial products
|
|
||||||
REDUCE(H_0, H_1, H_2, H_3, H_4)
|
|
||||||
|
|
||||||
CMPBNE R3, $0, loop
|
|
||||||
|
|
||||||
finish:
|
|
||||||
// sum lane 0 and lane 1 and put the result in lane 1
|
|
||||||
VZERO T_0
|
|
||||||
VSUMQG H_0, T_0, H_0
|
|
||||||
VSUMQG H_3, T_0, H_3
|
|
||||||
VSUMQG H_1, T_0, H_1
|
|
||||||
VSUMQG H_4, T_0, H_4
|
|
||||||
VSUMQG H_2, T_0, H_2
|
|
||||||
|
|
||||||
// reduce again after summation
|
|
||||||
// TODO(mundaym): there might be a more efficient way to do this
|
|
||||||
// now that we only have 1 active lane. For example, we could
|
|
||||||
// simultaneously pack the values as we reduce them.
|
|
||||||
REDUCE(H_0, H_1, H_2, H_3, H_4)
|
|
||||||
|
|
||||||
// carry h[1] through to h[4] so that only h[4] can exceed 2²⁶ - 1
|
|
||||||
// TODO(mundaym): in testing this final carry was unnecessary.
|
|
||||||
// Needs a proof before it can be removed though.
|
|
||||||
VESRLG $26, H_1, T_1
|
|
||||||
VN MOD26, H_1, H_1
|
|
||||||
VAQ T_1, H_2, H_2
|
|
||||||
VESRLG $26, H_2, T_2
|
|
||||||
VN MOD26, H_2, H_2
|
|
||||||
VAQ T_2, H_3, H_3
|
|
||||||
VESRLG $26, H_3, T_3
|
|
||||||
VN MOD26, H_3, H_3
|
|
||||||
VAQ T_3, H_4, H_4
|
|
||||||
|
|
||||||
// h is now < 2(2¹³⁰ - 5)
|
|
||||||
// Pack each lane in h₂₆[0:4] into h₁₂₈[0:1].
|
|
||||||
VESLG $26, H_1, H_1
|
|
||||||
VESLG $26, H_3, H_3
|
|
||||||
VO H_0, H_1, H_0
|
|
||||||
VO H_2, H_3, H_2
|
|
||||||
VESLG $4, H_2, H_2
|
|
||||||
VLEIB $7, $48, H_1
|
|
||||||
VSLB H_1, H_2, H_2
|
|
||||||
VO H_0, H_2, H_0
|
|
||||||
VLEIB $7, $104, H_1
|
|
||||||
VSLB H_1, H_4, H_3
|
|
||||||
VO H_3, H_0, H_0
|
|
||||||
VLEIB $7, $24, H_1
|
|
||||||
VSRLB H_1, H_4, H_1
|
|
||||||
|
|
||||||
// update state
|
|
||||||
VSTEG $1, H_0, 0(R1)
|
|
||||||
VSTEG $0, H_0, 8(R1)
|
|
||||||
VSTEG $1, H_1, 16(R1)
|
|
||||||
RET
|
|
||||||
|
|
||||||
b2: // 2 or fewer blocks remaining
|
|
||||||
CMPBLE R3, $16, b1
|
|
||||||
|
|
||||||
// Load the 2 remaining blocks (17-32 bytes remaining).
|
|
||||||
MOVD $-17(R3), R0 // index of final byte to load modulo 16
|
|
||||||
VL (R2), T_0 // load full 16 byte block
|
|
||||||
VLL R0, 16(R2), T_1 // load final (possibly partial) block and pad with zeros to 16 bytes
|
|
||||||
|
|
||||||
// The Poly1305 algorithm requires that a 1 bit be appended to
|
|
||||||
// each message block. If the final block is less than 16 bytes
|
|
||||||
// long then it is easiest to insert the 1 before the message
|
|
||||||
// block is split into 26-bit limbs. If, on the other hand, the
|
|
||||||
// final message block is 16 bytes long then we append the 1 bit
|
|
||||||
// after expansion as normal.
|
|
||||||
MOVBZ $1, R0
|
|
||||||
MOVD $-16(R3), R3 // index of byte in last block to insert 1 at (could be 16)
|
|
||||||
CMPBEQ R3, $16, 2(PC) // skip the insertion if the final block is 16 bytes long
|
|
||||||
VLVGB R3, R0, T_1 // insert 1 into the byte at index R3
|
|
||||||
|
|
||||||
// Split both blocks into 26-bit limbs in the appropriate lanes.
|
|
||||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
|
||||||
|
|
||||||
// Append a 1 byte to the end of the second to last block.
|
|
||||||
VLEIB $4, $1, M_4
|
|
||||||
|
|
||||||
// Append a 1 byte to the end of the last block only if it is a
|
|
||||||
// full 16 byte block.
|
|
||||||
CMPBNE R3, $16, 2(PC)
|
|
||||||
VLEIB $12, $1, M_4
|
|
||||||
|
|
||||||
// Finally, set up the coefficients for the final multiplication.
|
|
||||||
// We have previously saved r and 5r in the 32-bit even indexes
|
|
||||||
// of the R_[0-4] and R5_[1-4] coefficient registers.
|
|
||||||
//
|
|
||||||
// We want lane 0 to be multiplied by r² so that can be kept the
|
|
||||||
// same. We want lane 1 to be multiplied by r so we need to move
|
|
||||||
// the saved r value into the 32-bit odd index in lane 1 by
|
|
||||||
// rotating the 64-bit lane by 32.
|
|
||||||
VGBM $0x00ff, T_0 // [0, 0xffffffffffffffff] - mask lane 1 only
|
|
||||||
VERIMG $32, R_0, T_0, R_0 // [_, r²₂₆[0], _, r₂₆[0]]
|
|
||||||
VERIMG $32, R_1, T_0, R_1 // [_, r²₂₆[1], _, r₂₆[1]]
|
|
||||||
VERIMG $32, R_2, T_0, R_2 // [_, r²₂₆[2], _, r₂₆[2]]
|
|
||||||
VERIMG $32, R_3, T_0, R_3 // [_, r²₂₆[3], _, r₂₆[3]]
|
|
||||||
VERIMG $32, R_4, T_0, R_4 // [_, r²₂₆[4], _, r₂₆[4]]
|
|
||||||
VERIMG $32, R5_1, T_0, R5_1 // [_, 5r²₂₆[1], _, 5r₂₆[1]]
|
|
||||||
VERIMG $32, R5_2, T_0, R5_2 // [_, 5r²₂₆[2], _, 5r₂₆[2]]
|
|
||||||
VERIMG $32, R5_3, T_0, R5_3 // [_, 5r²₂₆[3], _, 5r₂₆[3]]
|
|
||||||
VERIMG $32, R5_4, T_0, R5_4 // [_, 5r²₂₆[4], _, 5r₂₆[4]]
|
|
||||||
|
|
||||||
MOVD $0, R3
|
|
||||||
BR multiply
|
|
||||||
|
|
||||||
skip:
|
|
||||||
CMPBEQ R3, $0, finish
|
|
||||||
|
|
||||||
b1: // 1 block remaining
|
|
||||||
|
|
||||||
// Load the final block (1-16 bytes). This will be placed into
|
|
||||||
// lane 0.
|
|
||||||
MOVD $-1(R3), R0
|
|
||||||
VLL R0, (R2), T_0 // pad to 16 bytes with zeros
|
|
||||||
|
|
||||||
// The Poly1305 algorithm requires that a 1 bit be appended to
|
|
||||||
// each message block. If the final block is less than 16 bytes
|
|
||||||
// long then it is easiest to insert the 1 before the message
|
|
||||||
// block is split into 26-bit limbs. If, on the other hand, the
|
|
||||||
// final message block is 16 bytes long then we append the 1 bit
|
|
||||||
// after expansion as normal.
|
|
||||||
MOVBZ $1, R0
|
|
||||||
CMPBEQ R3, $16, 2(PC)
|
|
||||||
VLVGB R3, R0, T_0
|
|
||||||
|
|
||||||
// Set the message block in lane 1 to the value 0 so that it
|
|
||||||
// can be accumulated without affecting the final result.
|
|
||||||
VZERO T_1
|
|
||||||
|
|
||||||
// Split the final message block into 26-bit limbs in lane 0.
|
|
||||||
// Lane 1 will be contain 0.
|
|
||||||
EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4)
|
|
||||||
|
|
||||||
// Append a 1 byte to the end of the last block only if it is a
|
|
||||||
// full 16 byte block.
|
|
||||||
CMPBNE R3, $16, 2(PC)
|
|
||||||
VLEIB $4, $1, M_4
|
|
||||||
|
|
||||||
// We have previously saved r and 5r in the 32-bit even indexes
|
|
||||||
// of the R_[0-4] and R5_[1-4] coefficient registers.
|
|
||||||
//
|
|
||||||
// We want lane 0 to be multiplied by r so we need to move the
|
|
||||||
// saved r value into the 32-bit odd index in lane 0. We want
|
|
||||||
// lane 1 to be set to the value 1. This makes multiplication
|
|
||||||
// a no-op. We do this by setting lane 1 in every register to 0
|
|
||||||
// and then just setting the 32-bit index 3 in R_0 to 1.
|
|
||||||
VZERO T_0
|
|
||||||
MOVD $0, R0
|
|
||||||
MOVD $0x10111213, R12
|
|
||||||
VLVGP R12, R0, T_1 // [_, 0x10111213, _, 0x00000000]
|
|
||||||
VPERM T_0, R_0, T_1, R_0 // [_, r₂₆[0], _, 0]
|
|
||||||
VPERM T_0, R_1, T_1, R_1 // [_, r₂₆[1], _, 0]
|
|
||||||
VPERM T_0, R_2, T_1, R_2 // [_, r₂₆[2], _, 0]
|
|
||||||
VPERM T_0, R_3, T_1, R_3 // [_, r₂₆[3], _, 0]
|
|
||||||
VPERM T_0, R_4, T_1, R_4 // [_, r₂₆[4], _, 0]
|
|
||||||
VPERM T_0, R5_1, T_1, R5_1 // [_, 5r₂₆[1], _, 0]
|
|
||||||
VPERM T_0, R5_2, T_1, R5_2 // [_, 5r₂₆[2], _, 0]
|
|
||||||
VPERM T_0, R5_3, T_1, R5_3 // [_, 5r₂₆[3], _, 0]
|
|
||||||
VPERM T_0, R5_4, T_1, R5_4 // [_, 5r₂₆[4], _, 0]
|
|
||||||
|
|
||||||
// Set the value of lane 1 to be 1.
|
|
||||||
VLEIF $3, $1, R_0 // [_, r₂₆[0], _, 1]
|
|
||||||
|
|
||||||
MOVD $0, R3
|
|
||||||
BR multiply
|
|
@ -1,786 +0,0 @@
|
|||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto"
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/subtle"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/curve25519"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
|
|
||||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
|
|
||||||
kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
|
|
||||||
kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
|
|
||||||
kexAlgoECDH256 = "ecdh-sha2-nistp256"
|
|
||||||
kexAlgoECDH384 = "ecdh-sha2-nistp384"
|
|
||||||
kexAlgoECDH521 = "ecdh-sha2-nistp521"
|
|
||||||
kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
|
|
||||||
kexAlgoCurve25519SHA256 = "curve25519-sha256"
|
|
||||||
|
|
||||||
// For the following kex only the client half contains a production
|
|
||||||
// ready implementation. The server half only consists of a minimal
|
|
||||||
// implementation to satisfy the automated tests.
|
|
||||||
kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
|
|
||||||
kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
|
|
||||||
)
|
|
||||||
|
|
||||||
// kexResult captures the outcome of a key exchange.
|
|
||||||
type kexResult struct {
|
|
||||||
// Session hash. See also RFC 4253, section 8.
|
|
||||||
H []byte
|
|
||||||
|
|
||||||
// Shared secret. See also RFC 4253, section 8.
|
|
||||||
K []byte
|
|
||||||
|
|
||||||
// Host key as hashed into H.
|
|
||||||
HostKey []byte
|
|
||||||
|
|
||||||
// Signature of H.
|
|
||||||
Signature []byte
|
|
||||||
|
|
||||||
// A cryptographic hash function that matches the security
|
|
||||||
// level of the key exchange algorithm. It is used for
|
|
||||||
// calculating H, and for deriving keys from H and K.
|
|
||||||
Hash crypto.Hash
|
|
||||||
|
|
||||||
// The session ID, which is the first H computed. This is used
|
|
||||||
// to derive key material inside the transport.
|
|
||||||
SessionID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshakeMagics contains data that is always included in the
|
|
||||||
// session hash.
|
|
||||||
type handshakeMagics struct {
|
|
||||||
clientVersion, serverVersion []byte
|
|
||||||
clientKexInit, serverKexInit []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *handshakeMagics) write(w io.Writer) {
|
|
||||||
writeString(w, m.clientVersion)
|
|
||||||
writeString(w, m.serverVersion)
|
|
||||||
writeString(w, m.clientKexInit)
|
|
||||||
writeString(w, m.serverKexInit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// kexAlgorithm abstracts different key exchange algorithms.
|
|
||||||
type kexAlgorithm interface {
|
|
||||||
// Server runs server-side key agreement, signing the result
|
|
||||||
// with a hostkey. algo is the negotiated algorithm, and may
|
|
||||||
// be a certificate type.
|
|
||||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error)
|
|
||||||
|
|
||||||
// Client runs the client-side key agreement. Caller is
|
|
||||||
// responsible for verifying the host key signature.
|
|
||||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
|
||||||
type dhGroup struct {
|
|
||||||
g, p, pMinus1 *big.Int
|
|
||||||
hashFunc crypto.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
|
|
||||||
if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 {
|
|
||||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
|
||||||
}
|
|
||||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
var x *big.Int
|
|
||||||
for {
|
|
||||||
var err error
|
|
||||||
if x, err = rand.Int(randSource, group.pMinus1); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if x.Sign() > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
X := new(big.Int).Exp(group.g, x, group.p)
|
|
||||||
kexDHInit := kexDHInitMsg{
|
|
||||||
X: X,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexDHReply kexDHReplyMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHReply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ki, err := group.diffieHellman(kexDHReply.Y, x)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
h := group.hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, kexDHReply.HostKey)
|
|
||||||
writeInt(h, X)
|
|
||||||
writeInt(h, kexDHReply.Y)
|
|
||||||
K := make([]byte, intLength(ki))
|
|
||||||
marshalInt(K, ki)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: kexDHReply.HostKey,
|
|
||||||
Signature: kexDHReply.Signature,
|
|
||||||
Hash: group.hashFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var kexDHInit kexDHInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHInit); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var y *big.Int
|
|
||||||
for {
|
|
||||||
if y, err = rand.Int(randSource, group.pMinus1); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if y.Sign() > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Y := new(big.Int).Exp(group.g, y, group.p)
|
|
||||||
ki, err := group.diffieHellman(kexDHInit.X, y)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
h := group.hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
writeInt(h, kexDHInit.X)
|
|
||||||
writeInt(h, Y)
|
|
||||||
|
|
||||||
K := make([]byte, intLength(ki))
|
|
||||||
marshalInt(K, ki)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
// H is already a hash, but the hostkey signing will apply its
|
|
||||||
// own key-specific hash algorithm.
|
|
||||||
sig, err := signAndMarshal(priv, randSource, H, algo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kexDHReply := kexDHReplyMsg{
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Y: Y,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
packet = Marshal(&kexDHReply)
|
|
||||||
|
|
||||||
err = c.writePacket(packet)
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: group.hashFunc,
|
|
||||||
}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
|
||||||
// described in RFC 5656, section 4.
|
|
||||||
type ecdh struct {
|
|
||||||
curve elliptic.Curve
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kexInit := kexECDHInitMsg{
|
|
||||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y),
|
|
||||||
}
|
|
||||||
|
|
||||||
serialized := Marshal(&kexInit)
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply kexECDHReplyMsg
|
|
||||||
if err = Unmarshal(packet, &reply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate shared secret
|
|
||||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes())
|
|
||||||
|
|
||||||
h := ecHash(kex.curve).New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, reply.HostKey)
|
|
||||||
writeString(h, kexInit.ClientPubKey)
|
|
||||||
writeString(h, reply.EphemeralPubKey)
|
|
||||||
K := make([]byte, intLength(secret))
|
|
||||||
marshalInt(K, secret)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: reply.HostKey,
|
|
||||||
Signature: reply.Signature,
|
|
||||||
Hash: ecHash(kex.curve),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshalECKey parses and checks an EC key.
|
|
||||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) {
|
|
||||||
x, y = elliptic.Unmarshal(curve, pubkey)
|
|
||||||
if x == nil {
|
|
||||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure")
|
|
||||||
}
|
|
||||||
if !validateECPublicKey(curve, x, y) {
|
|
||||||
return nil, nil, errors.New("ssh: public key not on curve")
|
|
||||||
}
|
|
||||||
return x, y, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateECPublicKey checks that the point is a valid public key for
|
|
||||||
// the given curve. See [SEC1], 3.2.2
|
|
||||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool {
|
|
||||||
if x.Sign() == 0 && y.Sign() == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if x.Cmp(curve.Params().P) >= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if y.Cmp(curve.Params().P) >= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !curve.IsOnCurve(x, y) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// We don't check if N * PubKey == 0, since
|
|
||||||
//
|
|
||||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
|
||||||
// (We don't foresee an implementation that supports non NIST
|
|
||||||
// curves)
|
|
||||||
//
|
|
||||||
// - for ephemeral keys, we don't need to worry about small
|
|
||||||
// subgroup attacks.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexECDHInit kexECDHInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexECDHInit); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We could cache this key across multiple users/multiple
|
|
||||||
// connection attempts, but the benefit is small. OpenSSH
|
|
||||||
// generates a new key for each incoming connection.
|
|
||||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
|
|
||||||
|
|
||||||
// generate shared secret
|
|
||||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes())
|
|
||||||
|
|
||||||
h := ecHash(kex.curve).New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
writeString(h, kexECDHInit.ClientPubKey)
|
|
||||||
writeString(h, serializedEphKey)
|
|
||||||
|
|
||||||
K := make([]byte, intLength(secret))
|
|
||||||
marshalInt(K, secret)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
// H is already a hash, but the hostkey signing will apply its
|
|
||||||
// own key-specific hash algorithm.
|
|
||||||
sig, err := signAndMarshal(priv, rand, H, algo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reply := kexECDHReplyMsg{
|
|
||||||
EphemeralPubKey: serializedEphKey,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
|
|
||||||
serialized := Marshal(&reply)
|
|
||||||
if err := c.writePacket(serialized); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: reply.HostKey,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: ecHash(kex.curve),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
|
||||||
// 5656, section 6.2.1
|
|
||||||
func ecHash(curve elliptic.Curve) crypto.Hash {
|
|
||||||
bitSize := curve.Params().BitSize
|
|
||||||
switch {
|
|
||||||
case bitSize <= 256:
|
|
||||||
return crypto.SHA256
|
|
||||||
case bitSize <= 384:
|
|
||||||
return crypto.SHA384
|
|
||||||
}
|
|
||||||
return crypto.SHA512
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexAlgoMap = map[string]kexAlgorithm{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// This is the group called diffie-hellman-group1-sha1 in
|
|
||||||
// RFC 4253 and Oakley Group 2 in RFC 2409.
|
|
||||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
|
|
||||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
|
|
||||||
g: new(big.Int).SetInt64(2),
|
|
||||||
p: p,
|
|
||||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
|
||||||
hashFunc: crypto.SHA1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// This are the groups called diffie-hellman-group14-sha1 and
|
|
||||||
// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
|
|
||||||
// and Oakley Group 14 in RFC 3526.
|
|
||||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
|
||||||
group14 := &dhGroup{
|
|
||||||
g: new(big.Int).SetInt64(2),
|
|
||||||
p: p,
|
|
||||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
|
||||||
}
|
|
||||||
|
|
||||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
|
|
||||||
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
|
|
||||||
hashFunc: crypto.SHA1,
|
|
||||||
}
|
|
||||||
kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
|
|
||||||
g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
|
|
||||||
hashFunc: crypto.SHA256,
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is the group called diffie-hellman-group16-sha512 in RFC
|
|
||||||
// 8268 and Oakley Group 16 in RFC 3526.
|
|
||||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
|
|
||||||
|
|
||||||
kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
|
|
||||||
g: new(big.Int).SetInt64(2),
|
|
||||||
p: p,
|
|
||||||
pMinus1: new(big.Int).Sub(p, bigOne),
|
|
||||||
hashFunc: crypto.SHA512,
|
|
||||||
}
|
|
||||||
|
|
||||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
|
|
||||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
|
|
||||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
|
|
||||||
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
|
|
||||||
kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
|
|
||||||
kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
|
|
||||||
kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
|
|
||||||
}
|
|
||||||
|
|
||||||
// curve25519sha256 implements the curve25519-sha256 (formerly known as
|
|
||||||
// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731.
|
|
||||||
type curve25519sha256 struct{}
|
|
||||||
|
|
||||||
type curve25519KeyPair struct {
|
|
||||||
priv [32]byte
|
|
||||||
pub [32]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kp *curve25519KeyPair) generate(rand io.Reader) error {
|
|
||||||
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
|
||||||
// convenient to compare against in order to reject curve25519 points with the
|
|
||||||
// wrong order.
|
|
||||||
var curve25519Zeros [32]byte
|
|
||||||
|
|
||||||
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
var kp curve25519KeyPair
|
|
||||||
if err := kp.generate(rand); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var reply kexECDHReplyMsg
|
|
||||||
if err = Unmarshal(packet, &reply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(reply.EphemeralPubKey) != 32 {
|
|
||||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
|
|
||||||
}
|
|
||||||
|
|
||||||
var servPub, secret [32]byte
|
|
||||||
copy(servPub[:], reply.EphemeralPubKey)
|
|
||||||
curve25519.ScalarMult(&secret, &kp.priv, &servPub)
|
|
||||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
|
|
||||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
|
|
||||||
}
|
|
||||||
|
|
||||||
h := crypto.SHA256.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, reply.HostKey)
|
|
||||||
writeString(h, kp.pub[:])
|
|
||||||
writeString(h, reply.EphemeralPubKey)
|
|
||||||
|
|
||||||
ki := new(big.Int).SetBytes(secret[:])
|
|
||||||
K := make([]byte, intLength(ki))
|
|
||||||
marshalInt(K, ki)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: reply.HostKey,
|
|
||||||
Signature: reply.Signature,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var kexInit kexECDHInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexInit); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(kexInit.ClientPubKey) != 32 {
|
|
||||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
|
|
||||||
}
|
|
||||||
|
|
||||||
var kp curve25519KeyPair
|
|
||||||
if err := kp.generate(rand); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var clientPub, secret [32]byte
|
|
||||||
copy(clientPub[:], kexInit.ClientPubKey)
|
|
||||||
curve25519.ScalarMult(&secret, &kp.priv, &clientPub)
|
|
||||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
|
|
||||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
h := crypto.SHA256.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
writeString(h, kexInit.ClientPubKey)
|
|
||||||
writeString(h, kp.pub[:])
|
|
||||||
|
|
||||||
ki := new(big.Int).SetBytes(secret[:])
|
|
||||||
K := make([]byte, intLength(ki))
|
|
||||||
marshalInt(K, ki)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
sig, err := signAndMarshal(priv, rand, H, algo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
reply := kexECDHReplyMsg{
|
|
||||||
EphemeralPubKey: kp.pub[:],
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&reply)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: crypto.SHA256,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and
|
|
||||||
// diffie-hellman-group-exchange-sha256 key agreement protocols,
|
|
||||||
// as described in RFC 4419
|
|
||||||
type dhGEXSHA struct {
|
|
||||||
hashFunc crypto.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
dhGroupExchangeMinimumBits = 2048
|
|
||||||
dhGroupExchangePreferredBits = 2048
|
|
||||||
dhGroupExchangeMaximumBits = 8192
|
|
||||||
)
|
|
||||||
|
|
||||||
func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
|
|
||||||
// Send GexRequest
|
|
||||||
kexDHGexRequest := kexDHGexRequestMsg{
|
|
||||||
MinBits: dhGroupExchangeMinimumBits,
|
|
||||||
PreferedBits: dhGroupExchangePreferredBits,
|
|
||||||
MaxBits: dhGroupExchangeMaximumBits,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive GexGroup
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var msg kexDHGexGroupMsg
|
|
||||||
if err = Unmarshal(packet, &msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits
|
|
||||||
if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits {
|
|
||||||
return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if g is safe by verifying that 1 < g < p-1
|
|
||||||
pMinusOne := new(big.Int).Sub(msg.P, bigOne)
|
|
||||||
if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 {
|
|
||||||
return nil, fmt.Errorf("ssh: server provided gex g is not safe")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send GexInit
|
|
||||||
pHalf := new(big.Int).Rsh(msg.P, 1)
|
|
||||||
x, err := rand.Int(randSource, pHalf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
X := new(big.Int).Exp(msg.G, x, msg.P)
|
|
||||||
kexDHGexInit := kexDHGexInitMsg{
|
|
||||||
X: X,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive GexReply
|
|
||||||
packet, err = c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var kexDHGexReply kexDHGexReplyMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHGexReply); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 {
|
|
||||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
|
||||||
}
|
|
||||||
kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P)
|
|
||||||
|
|
||||||
// Check if k is safe by verifying that k > 1 and k < p - 1
|
|
||||||
if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 {
|
|
||||||
return nil, fmt.Errorf("ssh: derived k is not safe")
|
|
||||||
}
|
|
||||||
|
|
||||||
h := gex.hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, kexDHGexReply.HostKey)
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
|
|
||||||
writeInt(h, msg.P)
|
|
||||||
writeInt(h, msg.G)
|
|
||||||
writeInt(h, X)
|
|
||||||
writeInt(h, kexDHGexReply.Y)
|
|
||||||
K := make([]byte, intLength(kInt))
|
|
||||||
marshalInt(K, kInt)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: h.Sum(nil),
|
|
||||||
K: K,
|
|
||||||
HostKey: kexDHGexReply.HostKey,
|
|
||||||
Signature: kexDHGexReply.Signature,
|
|
||||||
Hash: gex.hashFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
|
|
||||||
//
|
|
||||||
// This is a minimal implementation to satisfy the automated tests.
|
|
||||||
func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
|
|
||||||
// Receive GexRequest
|
|
||||||
packet, err := c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var kexDHGexRequest kexDHGexRequestMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send GexGroup
|
|
||||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
|
||||||
// 4253 and Oakley Group 14 in RFC 3526.
|
|
||||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
|
|
||||||
g := big.NewInt(2)
|
|
||||||
|
|
||||||
msg := &kexDHGexGroupMsg{
|
|
||||||
P: p,
|
|
||||||
G: g,
|
|
||||||
}
|
|
||||||
if err := c.writePacket(Marshal(msg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Receive GexInit
|
|
||||||
packet, err = c.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var kexDHGexInit kexDHGexInitMsg
|
|
||||||
if err = Unmarshal(packet, &kexDHGexInit); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
pHalf := new(big.Int).Rsh(p, 1)
|
|
||||||
|
|
||||||
y, err := rand.Int(randSource, pHalf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
Y := new(big.Int).Exp(g, y, p)
|
|
||||||
|
|
||||||
pMinusOne := new(big.Int).Sub(p, bigOne)
|
|
||||||
if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 {
|
|
||||||
return nil, errors.New("ssh: DH parameter out of bounds")
|
|
||||||
}
|
|
||||||
kInt := new(big.Int).Exp(kexDHGexInit.X, y, p)
|
|
||||||
|
|
||||||
hostKeyBytes := priv.PublicKey().Marshal()
|
|
||||||
|
|
||||||
h := gex.hashFunc.New()
|
|
||||||
magics.write(h)
|
|
||||||
writeString(h, hostKeyBytes)
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
|
|
||||||
binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
|
|
||||||
writeInt(h, p)
|
|
||||||
writeInt(h, g)
|
|
||||||
writeInt(h, kexDHGexInit.X)
|
|
||||||
writeInt(h, Y)
|
|
||||||
|
|
||||||
K := make([]byte, intLength(kInt))
|
|
||||||
marshalInt(K, kInt)
|
|
||||||
h.Write(K)
|
|
||||||
|
|
||||||
H := h.Sum(nil)
|
|
||||||
|
|
||||||
// H is already a hash, but the hostkey signing will apply its
|
|
||||||
// own key-specific hash algorithm.
|
|
||||||
sig, err := signAndMarshal(priv, randSource, H, algo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
kexDHGexReply := kexDHGexReplyMsg{
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Y: Y,
|
|
||||||
Signature: sig,
|
|
||||||
}
|
|
||||||
packet = Marshal(&kexDHGexReply)
|
|
||||||
|
|
||||||
err = c.writePacket(packet)
|
|
||||||
|
|
||||||
return &kexResult{
|
|
||||||
H: H,
|
|
||||||
K: K,
|
|
||||||
HostKey: hostKeyBytes,
|
|
||||||
Signature: sig,
|
|
||||||
Hash: gex.hashFunc,
|
|
||||||
}, err
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
@ -1,540 +0,0 @@
|
|||||||
// Copyright 2017 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package knownhosts implements a parser for the OpenSSH known_hosts
|
|
||||||
// host key database, and provides utility functions for writing
|
|
||||||
// OpenSSH compliant known_hosts files.
|
|
||||||
package knownhosts
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha1"
|
|
||||||
"encoding/base64"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// See the sshd manpage
|
|
||||||
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
|
|
||||||
// background.
|
|
||||||
|
|
||||||
type addr struct{ host, port string }
|
|
||||||
|
|
||||||
func (a *addr) String() string {
|
|
||||||
h := a.host
|
|
||||||
if strings.Contains(h, ":") {
|
|
||||||
h = "[" + h + "]"
|
|
||||||
}
|
|
||||||
return h + ":" + a.port
|
|
||||||
}
|
|
||||||
|
|
||||||
type matcher interface {
|
|
||||||
match(addr) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type hostPattern struct {
|
|
||||||
negate bool
|
|
||||||
addr addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *hostPattern) String() string {
|
|
||||||
n := ""
|
|
||||||
if p.negate {
|
|
||||||
n = "!"
|
|
||||||
}
|
|
||||||
|
|
||||||
return n + p.addr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type hostPatterns []hostPattern
|
|
||||||
|
|
||||||
func (ps hostPatterns) match(a addr) bool {
|
|
||||||
matched := false
|
|
||||||
for _, p := range ps {
|
|
||||||
if !p.match(a) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if p.negate {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
matched = true
|
|
||||||
}
|
|
||||||
return matched
|
|
||||||
}
|
|
||||||
|
|
||||||
// See
|
|
||||||
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
|
|
||||||
// The matching of * has no regard for separators, unlike filesystem globs
|
|
||||||
func wildcardMatch(pat []byte, str []byte) bool {
|
|
||||||
for {
|
|
||||||
if len(pat) == 0 {
|
|
||||||
return len(str) == 0
|
|
||||||
}
|
|
||||||
if len(str) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if pat[0] == '*' {
|
|
||||||
if len(pat) == 1 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for j := range str {
|
|
||||||
if wildcardMatch(pat[1:], str[j:]) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if pat[0] == '?' || pat[0] == str[0] {
|
|
||||||
pat = pat[1:]
|
|
||||||
str = str[1:]
|
|
||||||
} else {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *hostPattern) match(a addr) bool {
|
|
||||||
return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
|
|
||||||
}
|
|
||||||
|
|
||||||
type keyDBLine struct {
|
|
||||||
cert bool
|
|
||||||
matcher matcher
|
|
||||||
knownKey KnownKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func serialize(k ssh.PublicKey) string {
|
|
||||||
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *keyDBLine) match(a addr) bool {
|
|
||||||
return l.matcher.match(a)
|
|
||||||
}
|
|
||||||
|
|
||||||
type hostKeyDB struct {
|
|
||||||
// Serialized version of revoked keys
|
|
||||||
revoked map[string]*KnownKey
|
|
||||||
lines []keyDBLine
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostKeyDB() *hostKeyDB {
|
|
||||||
db := &hostKeyDB{
|
|
||||||
revoked: make(map[string]*KnownKey),
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func keyEq(a, b ssh.PublicKey) bool {
|
|
||||||
return bytes.Equal(a.Marshal(), b.Marshal())
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsHostAuthority can be used as a callback in ssh.CertChecker
|
|
||||||
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
|
|
||||||
h, p, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
a := addr{host: h, port: p}
|
|
||||||
|
|
||||||
for _, l := range db.lines {
|
|
||||||
if l.cert && keyEq(l.knownKey.Key, remote) && l.match(a) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsRevoked can be used as a callback in ssh.CertChecker
|
|
||||||
func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
|
|
||||||
_, ok := db.revoked[string(key.Marshal())]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
const markerCert = "@cert-authority"
|
|
||||||
const markerRevoked = "@revoked"
|
|
||||||
|
|
||||||
func nextWord(line []byte) (string, []byte) {
|
|
||||||
i := bytes.IndexAny(line, "\t ")
|
|
||||||
if i == -1 {
|
|
||||||
return string(line), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(line[:i]), bytes.TrimSpace(line[i:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
|
|
||||||
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
|
|
||||||
marker = w
|
|
||||||
line = next
|
|
||||||
}
|
|
||||||
|
|
||||||
host, line = nextWord(line)
|
|
||||||
if len(line) == 0 {
|
|
||||||
return "", "", nil, errors.New("knownhosts: missing host pattern")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ignore the keytype as it's in the key blob anyway.
|
|
||||||
_, line = nextWord(line)
|
|
||||||
if len(line) == 0 {
|
|
||||||
return "", "", nil, errors.New("knownhosts: missing key type pattern")
|
|
||||||
}
|
|
||||||
|
|
||||||
keyBlob, _ := nextWord(line)
|
|
||||||
|
|
||||||
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", nil, err
|
|
||||||
}
|
|
||||||
key, err = ssh.ParsePublicKey(keyBytes)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return marker, host, key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
|
|
||||||
marker, pattern, key, err := parseLine(line)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if marker == markerRevoked {
|
|
||||||
db.revoked[string(key.Marshal())] = &KnownKey{
|
|
||||||
Key: key,
|
|
||||||
Filename: filename,
|
|
||||||
Line: linenum,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := keyDBLine{
|
|
||||||
cert: marker == markerCert,
|
|
||||||
knownKey: KnownKey{
|
|
||||||
Filename: filename,
|
|
||||||
Line: linenum,
|
|
||||||
Key: key,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if pattern[0] == '|' {
|
|
||||||
entry.matcher, err = newHashedHost(pattern)
|
|
||||||
} else {
|
|
||||||
entry.matcher, err = newHostnameMatcher(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
db.lines = append(db.lines, entry)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHostnameMatcher(pattern string) (matcher, error) {
|
|
||||||
var hps hostPatterns
|
|
||||||
for _, p := range strings.Split(pattern, ",") {
|
|
||||||
if len(p) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var a addr
|
|
||||||
var negate bool
|
|
||||||
if p[0] == '!' {
|
|
||||||
negate = true
|
|
||||||
p = p[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(p) == 0 {
|
|
||||||
return nil, errors.New("knownhosts: negation without following hostname")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if p[0] == '[' {
|
|
||||||
a.host, a.port, err = net.SplitHostPort(p)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
a.host, a.port, err = net.SplitHostPort(p)
|
|
||||||
if err != nil {
|
|
||||||
a.host = p
|
|
||||||
a.port = "22"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hps = append(hps, hostPattern{
|
|
||||||
negate: negate,
|
|
||||||
addr: a,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return hps, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// KnownKey represents a key declared in a known_hosts file.
|
|
||||||
type KnownKey struct {
|
|
||||||
Key ssh.PublicKey
|
|
||||||
Filename string
|
|
||||||
Line int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *KnownKey) String() string {
|
|
||||||
return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeyError is returned if we did not find the key in the host key
|
|
||||||
// database, or there was a mismatch. Typically, in batch
|
|
||||||
// applications, this should be interpreted as failure. Interactive
|
|
||||||
// applications can offer an interactive prompt to the user.
|
|
||||||
type KeyError struct {
|
|
||||||
// Want holds the accepted host keys. For each key algorithm,
|
|
||||||
// there can be one hostkey. If Want is empty, the host is
|
|
||||||
// unknown. If Want is non-empty, there was a mismatch, which
|
|
||||||
// can signify a MITM attack.
|
|
||||||
Want []KnownKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *KeyError) Error() string {
|
|
||||||
if len(u.Want) == 0 {
|
|
||||||
return "knownhosts: key is unknown"
|
|
||||||
}
|
|
||||||
return "knownhosts: key mismatch"
|
|
||||||
}
|
|
||||||
|
|
||||||
// RevokedError is returned if we found a key that was revoked.
|
|
||||||
type RevokedError struct {
|
|
||||||
Revoked KnownKey
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RevokedError) Error() string {
|
|
||||||
return "knownhosts: key is revoked"
|
|
||||||
}
|
|
||||||
|
|
||||||
// check checks a key against the host database. This should not be
|
|
||||||
// used for verifying certificates.
|
|
||||||
func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
|
|
||||||
if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
|
|
||||||
return &RevokedError{Revoked: *revoked}
|
|
||||||
}
|
|
||||||
|
|
||||||
host, port, err := net.SplitHostPort(remote.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostToCheck := addr{host, port}
|
|
||||||
if address != "" {
|
|
||||||
// Give preference to the hostname if available.
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
hostToCheck = addr{host, port}
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.checkAddr(hostToCheck, remoteKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkAddr checks if we can find the given public key for the
|
|
||||||
// given address. If we only find an entry for the IP address,
|
|
||||||
// or only the hostname, then this still succeeds.
|
|
||||||
func (db *hostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error {
|
|
||||||
// TODO(hanwen): are these the right semantics? What if there
|
|
||||||
// is just a key for the IP address, but not for the
|
|
||||||
// hostname?
|
|
||||||
|
|
||||||
// Algorithm => key.
|
|
||||||
knownKeys := map[string]KnownKey{}
|
|
||||||
for _, l := range db.lines {
|
|
||||||
if l.match(a) {
|
|
||||||
typ := l.knownKey.Key.Type()
|
|
||||||
if _, ok := knownKeys[typ]; !ok {
|
|
||||||
knownKeys[typ] = l.knownKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
keyErr := &KeyError{}
|
|
||||||
for _, v := range knownKeys {
|
|
||||||
keyErr.Want = append(keyErr.Want, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unknown remote host.
|
|
||||||
if len(knownKeys) == 0 {
|
|
||||||
return keyErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the remote host starts using a different, unknown key type, we
|
|
||||||
// also interpret that as a mismatch.
|
|
||||||
if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
|
|
||||||
return keyErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The Read function parses file contents.
|
|
||||||
func (db *hostKeyDB) Read(r io.Reader, filename string) error {
|
|
||||||
scanner := bufio.NewScanner(r)
|
|
||||||
|
|
||||||
lineNum := 0
|
|
||||||
for scanner.Scan() {
|
|
||||||
lineNum++
|
|
||||||
line := scanner.Bytes()
|
|
||||||
line = bytes.TrimSpace(line)
|
|
||||||
if len(line) == 0 || line[0] == '#' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.parseLine(line, filename, lineNum); err != nil {
|
|
||||||
return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a host key callback from the given OpenSSH host key
|
|
||||||
// files. The returned callback is for use in
|
|
||||||
// ssh.ClientConfig.HostKeyCallback. By preference, the key check
|
|
||||||
// operates on the hostname if available, i.e. if a server changes its
|
|
||||||
// IP address, the host key check will still succeed, even though a
|
|
||||||
// record of the new IP address is not available.
|
|
||||||
func New(files ...string) (ssh.HostKeyCallback, error) {
|
|
||||||
db := newHostKeyDB()
|
|
||||||
for _, fn := range files {
|
|
||||||
f, err := os.Open(fn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
if err := db.Read(f, fn); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var certChecker ssh.CertChecker
|
|
||||||
certChecker.IsHostAuthority = db.IsHostAuthority
|
|
||||||
certChecker.IsRevoked = db.IsRevoked
|
|
||||||
certChecker.HostKeyFallback = db.check
|
|
||||||
|
|
||||||
return certChecker.CheckHostKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize normalizes an address into the form used in known_hosts
|
|
||||||
func Normalize(address string) string {
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
host = address
|
|
||||||
port = "22"
|
|
||||||
}
|
|
||||||
entry := host
|
|
||||||
if port != "22" {
|
|
||||||
entry = "[" + entry + "]:" + port
|
|
||||||
} else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
|
|
||||||
entry = "[" + entry + "]"
|
|
||||||
}
|
|
||||||
return entry
|
|
||||||
}
|
|
||||||
|
|
||||||
// Line returns a line to add append to the known_hosts files.
|
|
||||||
func Line(addresses []string, key ssh.PublicKey) string {
|
|
||||||
var trimmed []string
|
|
||||||
for _, a := range addresses {
|
|
||||||
trimmed = append(trimmed, Normalize(a))
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(trimmed, ",") + " " + serialize(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HashHostname hashes the given hostname. The hostname is not
|
|
||||||
// normalized before hashing.
|
|
||||||
func HashHostname(hostname string) string {
|
|
||||||
// TODO(hanwen): check if we can safely normalize this always.
|
|
||||||
salt := make([]byte, sha1.Size)
|
|
||||||
|
|
||||||
_, err := rand.Read(salt)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("crypto/rand failure %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := hashHost(hostname, salt)
|
|
||||||
return encodeHash(sha1HashType, salt, hash)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
|
|
||||||
if len(encoded) == 0 || encoded[0] != '|' {
|
|
||||||
err = errors.New("knownhosts: hashed host must start with '|'")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
components := strings.Split(encoded, "|")
|
|
||||||
if len(components) != 4 {
|
|
||||||
err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hashType = components[1]
|
|
||||||
if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeHash(typ string, salt []byte, hash []byte) string {
|
|
||||||
return strings.Join([]string{"",
|
|
||||||
typ,
|
|
||||||
base64.StdEncoding.EncodeToString(salt),
|
|
||||||
base64.StdEncoding.EncodeToString(hash),
|
|
||||||
}, "|")
|
|
||||||
}
|
|
||||||
|
|
||||||
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
|
|
||||||
func hashHost(hostname string, salt []byte) []byte {
|
|
||||||
mac := hmac.New(sha1.New, salt)
|
|
||||||
mac.Write([]byte(hostname))
|
|
||||||
return mac.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
type hashedHost struct {
|
|
||||||
salt []byte
|
|
||||||
hash []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const sha1HashType = "1"
|
|
||||||
|
|
||||||
func newHashedHost(encoded string) (*hashedHost, error) {
|
|
||||||
typ, salt, hash, err := decodeHash(encoded)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The type field seems for future algorithm agility, but it's
|
|
||||||
// actually hardcoded in openssh currently, see
|
|
||||||
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
|
|
||||||
if typ != sha1HashType {
|
|
||||||
return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &hashedHost{salt: salt, hash: hash}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hashedHost) match(a addr) bool {
|
|
||||||
return bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash)
|
|
||||||
}
|
|
@ -1,68 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Message authentication support
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"crypto/sha512"
|
|
||||||
"hash"
|
|
||||||
)
|
|
||||||
|
|
||||||
type macMode struct {
|
|
||||||
keySize int
|
|
||||||
etm bool
|
|
||||||
new func(key []byte) hash.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
|
||||||
// a given size.
|
|
||||||
type truncatingMAC struct {
|
|
||||||
length int
|
|
||||||
hmac hash.Hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Write(data []byte) (int, error) {
|
|
||||||
return t.hmac.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Sum(in []byte) []byte {
|
|
||||||
out := t.hmac.Sum(in)
|
|
||||||
return out[:len(in)+t.length]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Reset() {
|
|
||||||
t.hmac.Reset()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) Size() int {
|
|
||||||
return t.length
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
|
|
||||||
|
|
||||||
var macModes = map[string]*macMode{
|
|
||||||
"hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha512.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha256.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha512.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha256.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha1": {20, false, func(key []byte) hash.Hash {
|
|
||||||
return hmac.New(sha1.New, key)
|
|
||||||
}},
|
|
||||||
"hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
|
|
||||||
return truncatingMAC{12, hmac.New(sha1.New, key)}
|
|
||||||
}},
|
|
||||||
}
|
|
@ -1,891 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/big"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These are SSH message type numbers. They are scattered around several
|
|
||||||
// documents but many were taken from [SSH-PARAMETERS].
|
|
||||||
const (
|
|
||||||
msgIgnore = 2
|
|
||||||
msgUnimplemented = 3
|
|
||||||
msgDebug = 4
|
|
||||||
msgNewKeys = 21
|
|
||||||
)
|
|
||||||
|
|
||||||
// SSH messages:
|
|
||||||
//
|
|
||||||
// These structures mirror the wire format of the corresponding SSH messages.
|
|
||||||
// They are marshaled using reflection with the marshal and unmarshal functions
|
|
||||||
// in this file. The only wrinkle is that a final member of type []byte with a
|
|
||||||
// ssh tag of "rest" receives the remainder of a packet when unmarshaling.
|
|
||||||
|
|
||||||
// See RFC 4253, section 11.1.
|
|
||||||
const msgDisconnect = 1
|
|
||||||
|
|
||||||
// disconnectMsg is the message that signals a disconnect. It is also
|
|
||||||
// the error type returned from mux.Wait()
|
|
||||||
type disconnectMsg struct {
|
|
||||||
Reason uint32 `sshtype:"1"`
|
|
||||||
Message string
|
|
||||||
Language string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *disconnectMsg) Error() string {
|
|
||||||
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 7.1.
|
|
||||||
const msgKexInit = 20
|
|
||||||
|
|
||||||
type kexInitMsg struct {
|
|
||||||
Cookie [16]byte `sshtype:"20"`
|
|
||||||
KexAlgos []string
|
|
||||||
ServerHostKeyAlgos []string
|
|
||||||
CiphersClientServer []string
|
|
||||||
CiphersServerClient []string
|
|
||||||
MACsClientServer []string
|
|
||||||
MACsServerClient []string
|
|
||||||
CompressionClientServer []string
|
|
||||||
CompressionServerClient []string
|
|
||||||
LanguagesClientServer []string
|
|
||||||
LanguagesServerClient []string
|
|
||||||
FirstKexFollows bool
|
|
||||||
Reserved uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 8.
|
|
||||||
|
|
||||||
// Diffie-Hellman
|
|
||||||
const msgKexDHInit = 30
|
|
||||||
|
|
||||||
type kexDHInitMsg struct {
|
|
||||||
X *big.Int `sshtype:"30"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexECDHInit = 30
|
|
||||||
|
|
||||||
type kexECDHInitMsg struct {
|
|
||||||
ClientPubKey []byte `sshtype:"30"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexECDHReply = 31
|
|
||||||
|
|
||||||
type kexECDHReplyMsg struct {
|
|
||||||
HostKey []byte `sshtype:"31"`
|
|
||||||
EphemeralPubKey []byte
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexDHReply = 31
|
|
||||||
|
|
||||||
type kexDHReplyMsg struct {
|
|
||||||
HostKey []byte `sshtype:"31"`
|
|
||||||
Y *big.Int
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4419, section 5.
|
|
||||||
const msgKexDHGexGroup = 31
|
|
||||||
|
|
||||||
type kexDHGexGroupMsg struct {
|
|
||||||
P *big.Int `sshtype:"31"`
|
|
||||||
G *big.Int
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexDHGexInit = 32
|
|
||||||
|
|
||||||
type kexDHGexInitMsg struct {
|
|
||||||
X *big.Int `sshtype:"32"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexDHGexReply = 33
|
|
||||||
|
|
||||||
type kexDHGexReplyMsg struct {
|
|
||||||
HostKey []byte `sshtype:"33"`
|
|
||||||
Y *big.Int
|
|
||||||
Signature []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgKexDHGexRequest = 34
|
|
||||||
|
|
||||||
type kexDHGexRequestMsg struct {
|
|
||||||
MinBits uint32 `sshtype:"34"`
|
|
||||||
PreferedBits uint32
|
|
||||||
MaxBits uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 10.
|
|
||||||
const msgServiceRequest = 5
|
|
||||||
|
|
||||||
type serviceRequestMsg struct {
|
|
||||||
Service string `sshtype:"5"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4253, section 10.
|
|
||||||
const msgServiceAccept = 6
|
|
||||||
|
|
||||||
type serviceAcceptMsg struct {
|
|
||||||
Service string `sshtype:"6"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 8308, section 2.3
|
|
||||||
const msgExtInfo = 7
|
|
||||||
|
|
||||||
type extInfoMsg struct {
|
|
||||||
NumExtensions uint32 `sshtype:"7"`
|
|
||||||
Payload []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.
|
|
||||||
const msgUserAuthRequest = 50
|
|
||||||
|
|
||||||
type userAuthRequestMsg struct {
|
|
||||||
User string `sshtype:"50"`
|
|
||||||
Service string
|
|
||||||
Method string
|
|
||||||
Payload []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Used for debug printouts of packets.
|
|
||||||
type userAuthSuccessMsg struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.1
|
|
||||||
const msgUserAuthFailure = 51
|
|
||||||
|
|
||||||
type userAuthFailureMsg struct {
|
|
||||||
Methods []string `sshtype:"51"`
|
|
||||||
PartialSuccess bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.1
|
|
||||||
const msgUserAuthSuccess = 52
|
|
||||||
|
|
||||||
// See RFC 4252, section 5.4
|
|
||||||
const msgUserAuthBanner = 53
|
|
||||||
|
|
||||||
type userAuthBannerMsg struct {
|
|
||||||
Message string `sshtype:"53"`
|
|
||||||
// unused, but required to allow message parsing
|
|
||||||
Language string
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4256, section 3.2
|
|
||||||
const msgUserAuthInfoRequest = 60
|
|
||||||
const msgUserAuthInfoResponse = 61
|
|
||||||
|
|
||||||
type userAuthInfoRequestMsg struct {
|
|
||||||
Name string `sshtype:"60"`
|
|
||||||
Instruction string
|
|
||||||
Language string
|
|
||||||
NumPrompts uint32
|
|
||||||
Prompts []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpen = 90
|
|
||||||
|
|
||||||
type channelOpenMsg struct {
|
|
||||||
ChanType string `sshtype:"90"`
|
|
||||||
PeersID uint32
|
|
||||||
PeersWindow uint32
|
|
||||||
MaxPacketSize uint32
|
|
||||||
TypeSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgChannelExtendedData = 95
|
|
||||||
const msgChannelData = 94
|
|
||||||
|
|
||||||
// Used for debug print outs of packets.
|
|
||||||
type channelDataMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"94"`
|
|
||||||
Length uint32
|
|
||||||
Rest []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpenConfirm = 91
|
|
||||||
|
|
||||||
type channelOpenConfirmMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"91"`
|
|
||||||
MyID uint32
|
|
||||||
MyWindow uint32
|
|
||||||
MaxPacketSize uint32
|
|
||||||
TypeSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.1.
|
|
||||||
const msgChannelOpenFailure = 92
|
|
||||||
|
|
||||||
type channelOpenFailureMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"92"`
|
|
||||||
Reason RejectionReason
|
|
||||||
Message string
|
|
||||||
Language string
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgChannelRequest = 98
|
|
||||||
|
|
||||||
type channelRequestMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"98"`
|
|
||||||
Request string
|
|
||||||
WantReply bool
|
|
||||||
RequestSpecificData []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.4.
|
|
||||||
const msgChannelSuccess = 99
|
|
||||||
|
|
||||||
type channelRequestSuccessMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"99"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.4.
|
|
||||||
const msgChannelFailure = 100
|
|
||||||
|
|
||||||
type channelRequestFailureMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"100"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.3
|
|
||||||
const msgChannelClose = 97
|
|
||||||
|
|
||||||
type channelCloseMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"97"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.3
|
|
||||||
const msgChannelEOF = 96
|
|
||||||
|
|
||||||
type channelEOFMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"96"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgGlobalRequest = 80
|
|
||||||
|
|
||||||
type globalRequestMsg struct {
|
|
||||||
Type string `sshtype:"80"`
|
|
||||||
WantReply bool
|
|
||||||
Data []byte `ssh:"rest"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgRequestSuccess = 81
|
|
||||||
|
|
||||||
type globalRequestSuccessMsg struct {
|
|
||||||
Data []byte `ssh:"rest" sshtype:"81"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 4
|
|
||||||
const msgRequestFailure = 82
|
|
||||||
|
|
||||||
type globalRequestFailureMsg struct {
|
|
||||||
Data []byte `ssh:"rest" sshtype:"82"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 5.2
|
|
||||||
const msgChannelWindowAdjust = 93
|
|
||||||
|
|
||||||
type windowAdjustMsg struct {
|
|
||||||
PeersID uint32 `sshtype:"93"`
|
|
||||||
AdditionalBytes uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4252, section 7
|
|
||||||
const msgUserAuthPubKeyOk = 60
|
|
||||||
|
|
||||||
type userAuthPubKeyOkMsg struct {
|
|
||||||
Algo string `sshtype:"60"`
|
|
||||||
PubKey []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4462, section 3
|
|
||||||
const msgUserAuthGSSAPIResponse = 60
|
|
||||||
|
|
||||||
type userAuthGSSAPIResponse struct {
|
|
||||||
SupportMech []byte `sshtype:"60"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgUserAuthGSSAPIToken = 61
|
|
||||||
|
|
||||||
type userAuthGSSAPIToken struct {
|
|
||||||
Token []byte `sshtype:"61"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const msgUserAuthGSSAPIMIC = 66
|
|
||||||
|
|
||||||
type userAuthGSSAPIMIC struct {
|
|
||||||
MIC []byte `sshtype:"66"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4462, section 3.9
|
|
||||||
const msgUserAuthGSSAPIErrTok = 64
|
|
||||||
|
|
||||||
type userAuthGSSAPIErrTok struct {
|
|
||||||
ErrorToken []byte `sshtype:"64"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4462, section 3.8
|
|
||||||
const msgUserAuthGSSAPIError = 65
|
|
||||||
|
|
||||||
type userAuthGSSAPIError struct {
|
|
||||||
MajorStatus uint32 `sshtype:"65"`
|
|
||||||
MinorStatus uint32
|
|
||||||
Message string
|
|
||||||
LanguageTag string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
|
|
||||||
const msgPing = 192
|
|
||||||
|
|
||||||
type pingMsg struct {
|
|
||||||
Data string `sshtype:"192"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9
|
|
||||||
const msgPong = 193
|
|
||||||
|
|
||||||
type pongMsg struct {
|
|
||||||
Data string `sshtype:"193"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// typeTags returns the possible type bytes for the given reflect.Type, which
|
|
||||||
// should be a struct. The possible values are separated by a '|' character.
|
|
||||||
func typeTags(structType reflect.Type) (tags []byte) {
|
|
||||||
tagStr := structType.Field(0).Tag.Get("sshtype")
|
|
||||||
|
|
||||||
for _, tag := range strings.Split(tagStr, "|") {
|
|
||||||
i, err := strconv.Atoi(tag)
|
|
||||||
if err == nil {
|
|
||||||
tags = append(tags, byte(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags
|
|
||||||
}
|
|
||||||
|
|
||||||
func fieldError(t reflect.Type, field int, problem string) error {
|
|
||||||
if problem != "" {
|
|
||||||
problem = ": " + problem
|
|
||||||
}
|
|
||||||
return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errShortRead = errors.New("ssh: short read")
|
|
||||||
|
|
||||||
// Unmarshal parses data in SSH wire format into a structure. The out
|
|
||||||
// argument should be a pointer to struct. If the first member of the
|
|
||||||
// struct has the "sshtype" tag set to a '|'-separated set of numbers
|
|
||||||
// in decimal, the packet must start with one of those numbers. In
|
|
||||||
// case of error, Unmarshal returns a ParseError or
|
|
||||||
// UnexpectedMessageError.
|
|
||||||
func Unmarshal(data []byte, out interface{}) error {
|
|
||||||
v := reflect.ValueOf(out).Elem()
|
|
||||||
structType := v.Type()
|
|
||||||
expectedTypes := typeTags(structType)
|
|
||||||
|
|
||||||
var expectedType byte
|
|
||||||
if len(expectedTypes) > 0 {
|
|
||||||
expectedType = expectedTypes[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) == 0 {
|
|
||||||
return parseError(expectedType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(expectedTypes) > 0 {
|
|
||||||
goodType := false
|
|
||||||
for _, e := range expectedTypes {
|
|
||||||
if e > 0 && data[0] == e {
|
|
||||||
goodType = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !goodType {
|
|
||||||
return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
|
|
||||||
}
|
|
||||||
data = data[1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
var ok bool
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
t := field.Type()
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
if len(data) < 1 {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetBool(data[0] != 0)
|
|
||||||
data = data[1:]
|
|
||||||
case reflect.Array:
|
|
||||||
if t.Elem().Kind() != reflect.Uint8 {
|
|
||||||
return fieldError(structType, i, "array of unsupported type")
|
|
||||||
}
|
|
||||||
if len(data) < t.Len() {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
for j, n := 0, t.Len(); j < n; j++ {
|
|
||||||
field.Index(j).Set(reflect.ValueOf(data[j]))
|
|
||||||
}
|
|
||||||
data = data[t.Len():]
|
|
||||||
case reflect.Uint64:
|
|
||||||
var u64 uint64
|
|
||||||
if u64, data, ok = parseUint64(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(u64)
|
|
||||||
case reflect.Uint32:
|
|
||||||
var u32 uint32
|
|
||||||
if u32, data, ok = parseUint32(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(uint64(u32))
|
|
||||||
case reflect.Uint8:
|
|
||||||
if len(data) < 1 {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.SetUint(uint64(data[0]))
|
|
||||||
data = data[1:]
|
|
||||||
case reflect.String:
|
|
||||||
var s []byte
|
|
||||||
if s, data, ok = parseString(data); !ok {
|
|
||||||
return fieldError(structType, i, "")
|
|
||||||
}
|
|
||||||
field.SetString(string(s))
|
|
||||||
case reflect.Slice:
|
|
||||||
switch t.Elem().Kind() {
|
|
||||||
case reflect.Uint8:
|
|
||||||
if structType.Field(i).Tag.Get("ssh") == "rest" {
|
|
||||||
field.Set(reflect.ValueOf(data))
|
|
||||||
data = nil
|
|
||||||
} else {
|
|
||||||
var s []byte
|
|
||||||
if s, data, ok = parseString(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(s))
|
|
||||||
}
|
|
||||||
case reflect.String:
|
|
||||||
var nl []string
|
|
||||||
if nl, data, ok = parseNameList(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(nl))
|
|
||||||
default:
|
|
||||||
return fieldError(structType, i, "slice of unsupported type")
|
|
||||||
}
|
|
||||||
case reflect.Ptr:
|
|
||||||
if t == bigIntType {
|
|
||||||
var n *big.Int
|
|
||||||
if n, data, ok = parseInt(data); !ok {
|
|
||||||
return errShortRead
|
|
||||||
}
|
|
||||||
field.Set(reflect.ValueOf(n))
|
|
||||||
} else {
|
|
||||||
return fieldError(structType, i, "pointer to unsupported type")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) != 0 {
|
|
||||||
return parseError(expectedType)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal serializes the message in msg to SSH wire format. The msg
|
|
||||||
// argument should be a struct or pointer to struct. If the first
|
|
||||||
// member has the "sshtype" tag set to a number in decimal, that
|
|
||||||
// number is prepended to the result. If the last of member has the
|
|
||||||
// "ssh" tag set to "rest", its contents are appended to the output.
|
|
||||||
func Marshal(msg interface{}) []byte {
|
|
||||||
out := make([]byte, 0, 64)
|
|
||||||
return marshalStruct(out, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalStruct(out []byte, msg interface{}) []byte {
|
|
||||||
v := reflect.Indirect(reflect.ValueOf(msg))
|
|
||||||
msgTypes := typeTags(v.Type())
|
|
||||||
if len(msgTypes) > 0 {
|
|
||||||
out = append(out, msgTypes[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, n := 0, v.NumField(); i < n; i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
switch t := field.Type(); t.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
var v uint8
|
|
||||||
if field.Bool() {
|
|
||||||
v = 1
|
|
||||||
}
|
|
||||||
out = append(out, v)
|
|
||||||
case reflect.Array:
|
|
||||||
if t.Elem().Kind() != reflect.Uint8 {
|
|
||||||
panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
for j, l := 0, t.Len(); j < l; j++ {
|
|
||||||
out = append(out, uint8(field.Index(j).Uint()))
|
|
||||||
}
|
|
||||||
case reflect.Uint32:
|
|
||||||
out = appendU32(out, uint32(field.Uint()))
|
|
||||||
case reflect.Uint64:
|
|
||||||
out = appendU64(out, uint64(field.Uint()))
|
|
||||||
case reflect.Uint8:
|
|
||||||
out = append(out, uint8(field.Uint()))
|
|
||||||
case reflect.String:
|
|
||||||
s := field.String()
|
|
||||||
out = appendInt(out, len(s))
|
|
||||||
out = append(out, s...)
|
|
||||||
case reflect.Slice:
|
|
||||||
switch t.Elem().Kind() {
|
|
||||||
case reflect.Uint8:
|
|
||||||
if v.Type().Field(i).Tag.Get("ssh") != "rest" {
|
|
||||||
out = appendInt(out, field.Len())
|
|
||||||
}
|
|
||||||
out = append(out, field.Bytes()...)
|
|
||||||
case reflect.String:
|
|
||||||
offset := len(out)
|
|
||||||
out = appendU32(out, 0)
|
|
||||||
if n := field.Len(); n > 0 {
|
|
||||||
for j := 0; j < n; j++ {
|
|
||||||
f := field.Index(j)
|
|
||||||
if j != 0 {
|
|
||||||
out = append(out, ',')
|
|
||||||
}
|
|
||||||
out = append(out, f.String()...)
|
|
||||||
}
|
|
||||||
// overwrite length value
|
|
||||||
binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
case reflect.Ptr:
|
|
||||||
if t == bigIntType {
|
|
||||||
var n *big.Int
|
|
||||||
nValue := reflect.ValueOf(&n)
|
|
||||||
nValue.Elem().Set(field)
|
|
||||||
needed := intLength(n)
|
|
||||||
oldLength := len(out)
|
|
||||||
|
|
||||||
if cap(out)-len(out) < needed {
|
|
||||||
newOut := make([]byte, len(out), 2*(len(out)+needed))
|
|
||||||
copy(newOut, out)
|
|
||||||
out = newOut
|
|
||||||
}
|
|
||||||
out = out[:oldLength+needed]
|
|
||||||
marshalInt(out[oldLength:], n)
|
|
||||||
} else {
|
|
||||||
panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigOne = big.NewInt(1)
|
|
||||||
|
|
||||||
func parseString(in []byte) (out, rest []byte, ok bool) {
|
|
||||||
if len(in) < 4 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
length := binary.BigEndian.Uint32(in)
|
|
||||||
in = in[4:]
|
|
||||||
if uint32(len(in)) < length {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out = in[:length]
|
|
||||||
rest = in[length:]
|
|
||||||
ok = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
comma = []byte{','}
|
|
||||||
emptyNameList = []string{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
|
|
||||||
contents, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(contents) == 0 {
|
|
||||||
out = emptyNameList
|
|
||||||
return
|
|
||||||
}
|
|
||||||
parts := bytes.Split(contents, comma)
|
|
||||||
out = make([]string, len(parts))
|
|
||||||
for i, part := range parts {
|
|
||||||
out[i] = string(part)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
|
|
||||||
contents, rest, ok := parseString(in)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
out = new(big.Int)
|
|
||||||
|
|
||||||
if len(contents) > 0 && contents[0]&0x80 == 0x80 {
|
|
||||||
// This is a negative number
|
|
||||||
notBytes := make([]byte, len(contents))
|
|
||||||
for i := range notBytes {
|
|
||||||
notBytes[i] = ^contents[i]
|
|
||||||
}
|
|
||||||
out.SetBytes(notBytes)
|
|
||||||
out.Add(out, bigOne)
|
|
||||||
out.Neg(out)
|
|
||||||
} else {
|
|
||||||
// Positive number
|
|
||||||
out.SetBytes(contents)
|
|
||||||
}
|
|
||||||
ok = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseUint32(in []byte) (uint32, []byte, bool) {
|
|
||||||
if len(in) < 4 {
|
|
||||||
return 0, nil, false
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint32(in), in[4:], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseUint64(in []byte) (uint64, []byte, bool) {
|
|
||||||
if len(in) < 8 {
|
|
||||||
return 0, nil, false
|
|
||||||
}
|
|
||||||
return binary.BigEndian.Uint64(in), in[8:], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func intLength(n *big.Int) int {
|
|
||||||
length := 4 /* length bytes */
|
|
||||||
if n.Sign() < 0 {
|
|
||||||
nMinus1 := new(big.Int).Neg(n)
|
|
||||||
nMinus1.Sub(nMinus1, bigOne)
|
|
||||||
bitLen := nMinus1.BitLen()
|
|
||||||
if bitLen%8 == 0 {
|
|
||||||
// The number will need 0xff padding
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
length += (bitLen + 7) / 8
|
|
||||||
} else if n.Sign() == 0 {
|
|
||||||
// A zero is the zero length string
|
|
||||||
} else {
|
|
||||||
bitLen := n.BitLen()
|
|
||||||
if bitLen%8 == 0 {
|
|
||||||
// The number will need 0x00 padding
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
length += (bitLen + 7) / 8
|
|
||||||
}
|
|
||||||
|
|
||||||
return length
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalUint32(to []byte, n uint32) []byte {
|
|
||||||
binary.BigEndian.PutUint32(to, n)
|
|
||||||
return to[4:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalUint64(to []byte, n uint64) []byte {
|
|
||||||
binary.BigEndian.PutUint64(to, n)
|
|
||||||
return to[8:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalInt(to []byte, n *big.Int) []byte {
|
|
||||||
lengthBytes := to
|
|
||||||
to = to[4:]
|
|
||||||
length := 0
|
|
||||||
|
|
||||||
if n.Sign() < 0 {
|
|
||||||
// A negative number has to be converted to two's-complement
|
|
||||||
// form. So we'll subtract 1 and invert. If the
|
|
||||||
// most-significant-bit isn't set then we'll need to pad the
|
|
||||||
// beginning with 0xff in order to keep the number negative.
|
|
||||||
nMinus1 := new(big.Int).Neg(n)
|
|
||||||
nMinus1.Sub(nMinus1, bigOne)
|
|
||||||
bytes := nMinus1.Bytes()
|
|
||||||
for i := range bytes {
|
|
||||||
bytes[i] ^= 0xff
|
|
||||||
}
|
|
||||||
if len(bytes) == 0 || bytes[0]&0x80 == 0 {
|
|
||||||
to[0] = 0xff
|
|
||||||
to = to[1:]
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
nBytes := copy(to, bytes)
|
|
||||||
to = to[nBytes:]
|
|
||||||
length += nBytes
|
|
||||||
} else if n.Sign() == 0 {
|
|
||||||
// A zero is the zero length string
|
|
||||||
} else {
|
|
||||||
bytes := n.Bytes()
|
|
||||||
if len(bytes) > 0 && bytes[0]&0x80 != 0 {
|
|
||||||
// We'll have to pad this with a 0x00 in order to
|
|
||||||
// stop it looking like a negative number.
|
|
||||||
to[0] = 0
|
|
||||||
to = to[1:]
|
|
||||||
length++
|
|
||||||
}
|
|
||||||
nBytes := copy(to, bytes)
|
|
||||||
to = to[nBytes:]
|
|
||||||
length += nBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
lengthBytes[0] = byte(length >> 24)
|
|
||||||
lengthBytes[1] = byte(length >> 16)
|
|
||||||
lengthBytes[2] = byte(length >> 8)
|
|
||||||
lengthBytes[3] = byte(length)
|
|
||||||
return to
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeInt(w io.Writer, n *big.Int) {
|
|
||||||
length := intLength(n)
|
|
||||||
buf := make([]byte, length)
|
|
||||||
marshalInt(buf, n)
|
|
||||||
w.Write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeString(w io.Writer, s []byte) {
|
|
||||||
var lengthBytes [4]byte
|
|
||||||
lengthBytes[0] = byte(len(s) >> 24)
|
|
||||||
lengthBytes[1] = byte(len(s) >> 16)
|
|
||||||
lengthBytes[2] = byte(len(s) >> 8)
|
|
||||||
lengthBytes[3] = byte(len(s))
|
|
||||||
w.Write(lengthBytes[:])
|
|
||||||
w.Write(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringLength(n int) int {
|
|
||||||
return 4 + n
|
|
||||||
}
|
|
||||||
|
|
||||||
func marshalString(to []byte, s []byte) []byte {
|
|
||||||
to[0] = byte(len(s) >> 24)
|
|
||||||
to[1] = byte(len(s) >> 16)
|
|
||||||
to[2] = byte(len(s) >> 8)
|
|
||||||
to[3] = byte(len(s))
|
|
||||||
to = to[4:]
|
|
||||||
copy(to, s)
|
|
||||||
return to[len(s):]
|
|
||||||
}
|
|
||||||
|
|
||||||
var bigIntType = reflect.TypeOf((*big.Int)(nil))
|
|
||||||
|
|
||||||
// Decode a packet into its corresponding message.
|
|
||||||
func decode(packet []byte) (interface{}, error) {
|
|
||||||
var msg interface{}
|
|
||||||
switch packet[0] {
|
|
||||||
case msgDisconnect:
|
|
||||||
msg = new(disconnectMsg)
|
|
||||||
case msgServiceRequest:
|
|
||||||
msg = new(serviceRequestMsg)
|
|
||||||
case msgServiceAccept:
|
|
||||||
msg = new(serviceAcceptMsg)
|
|
||||||
case msgExtInfo:
|
|
||||||
msg = new(extInfoMsg)
|
|
||||||
case msgKexInit:
|
|
||||||
msg = new(kexInitMsg)
|
|
||||||
case msgKexDHInit:
|
|
||||||
msg = new(kexDHInitMsg)
|
|
||||||
case msgKexDHReply:
|
|
||||||
msg = new(kexDHReplyMsg)
|
|
||||||
case msgUserAuthRequest:
|
|
||||||
msg = new(userAuthRequestMsg)
|
|
||||||
case msgUserAuthSuccess:
|
|
||||||
return new(userAuthSuccessMsg), nil
|
|
||||||
case msgUserAuthFailure:
|
|
||||||
msg = new(userAuthFailureMsg)
|
|
||||||
case msgUserAuthPubKeyOk:
|
|
||||||
msg = new(userAuthPubKeyOkMsg)
|
|
||||||
case msgGlobalRequest:
|
|
||||||
msg = new(globalRequestMsg)
|
|
||||||
case msgRequestSuccess:
|
|
||||||
msg = new(globalRequestSuccessMsg)
|
|
||||||
case msgRequestFailure:
|
|
||||||
msg = new(globalRequestFailureMsg)
|
|
||||||
case msgChannelOpen:
|
|
||||||
msg = new(channelOpenMsg)
|
|
||||||
case msgChannelData:
|
|
||||||
msg = new(channelDataMsg)
|
|
||||||
case msgChannelOpenConfirm:
|
|
||||||
msg = new(channelOpenConfirmMsg)
|
|
||||||
case msgChannelOpenFailure:
|
|
||||||
msg = new(channelOpenFailureMsg)
|
|
||||||
case msgChannelWindowAdjust:
|
|
||||||
msg = new(windowAdjustMsg)
|
|
||||||
case msgChannelEOF:
|
|
||||||
msg = new(channelEOFMsg)
|
|
||||||
case msgChannelClose:
|
|
||||||
msg = new(channelCloseMsg)
|
|
||||||
case msgChannelRequest:
|
|
||||||
msg = new(channelRequestMsg)
|
|
||||||
case msgChannelSuccess:
|
|
||||||
msg = new(channelRequestSuccessMsg)
|
|
||||||
case msgChannelFailure:
|
|
||||||
msg = new(channelRequestFailureMsg)
|
|
||||||
case msgUserAuthGSSAPIToken:
|
|
||||||
msg = new(userAuthGSSAPIToken)
|
|
||||||
case msgUserAuthGSSAPIMIC:
|
|
||||||
msg = new(userAuthGSSAPIMIC)
|
|
||||||
case msgUserAuthGSSAPIErrTok:
|
|
||||||
msg = new(userAuthGSSAPIErrTok)
|
|
||||||
case msgUserAuthGSSAPIError:
|
|
||||||
msg = new(userAuthGSSAPIError)
|
|
||||||
default:
|
|
||||||
return nil, unexpectedMessageError(0, packet[0])
|
|
||||||
}
|
|
||||||
if err := Unmarshal(packet, msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var packetTypeNames = map[byte]string{
|
|
||||||
msgDisconnect: "disconnectMsg",
|
|
||||||
msgServiceRequest: "serviceRequestMsg",
|
|
||||||
msgServiceAccept: "serviceAcceptMsg",
|
|
||||||
msgExtInfo: "extInfoMsg",
|
|
||||||
msgKexInit: "kexInitMsg",
|
|
||||||
msgKexDHInit: "kexDHInitMsg",
|
|
||||||
msgKexDHReply: "kexDHReplyMsg",
|
|
||||||
msgUserAuthRequest: "userAuthRequestMsg",
|
|
||||||
msgUserAuthSuccess: "userAuthSuccessMsg",
|
|
||||||
msgUserAuthFailure: "userAuthFailureMsg",
|
|
||||||
msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg",
|
|
||||||
msgGlobalRequest: "globalRequestMsg",
|
|
||||||
msgRequestSuccess: "globalRequestSuccessMsg",
|
|
||||||
msgRequestFailure: "globalRequestFailureMsg",
|
|
||||||
msgChannelOpen: "channelOpenMsg",
|
|
||||||
msgChannelData: "channelDataMsg",
|
|
||||||
msgChannelOpenConfirm: "channelOpenConfirmMsg",
|
|
||||||
msgChannelOpenFailure: "channelOpenFailureMsg",
|
|
||||||
msgChannelWindowAdjust: "windowAdjustMsg",
|
|
||||||
msgChannelEOF: "channelEOFMsg",
|
|
||||||
msgChannelClose: "channelCloseMsg",
|
|
||||||
msgChannelRequest: "channelRequestMsg",
|
|
||||||
msgChannelSuccess: "channelRequestSuccessMsg",
|
|
||||||
msgChannelFailure: "channelRequestFailureMsg",
|
|
||||||
}
|
|
@ -1,357 +0,0 @@
|
|||||||
// Copyright 2013 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// debugMux, if set, causes messages in the connection protocol to be
|
|
||||||
// logged.
|
|
||||||
const debugMux = false
|
|
||||||
|
|
||||||
// chanList is a thread safe channel list.
|
|
||||||
type chanList struct {
|
|
||||||
// protects concurrent access to chans
|
|
||||||
sync.Mutex
|
|
||||||
|
|
||||||
// chans are indexed by the local id of the channel, which the
|
|
||||||
// other side should send in the PeersId field.
|
|
||||||
chans []*channel
|
|
||||||
|
|
||||||
// This is a debugging aid: it offsets all IDs by this
|
|
||||||
// amount. This helps distinguish otherwise identical
|
|
||||||
// server/client muxes
|
|
||||||
offset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assigns a channel ID to the given channel.
|
|
||||||
func (c *chanList) add(ch *channel) uint32 {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
for i := range c.chans {
|
|
||||||
if c.chans[i] == nil {
|
|
||||||
c.chans[i] = ch
|
|
||||||
return uint32(i) + c.offset
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.chans = append(c.chans, ch)
|
|
||||||
return uint32(len(c.chans)-1) + c.offset
|
|
||||||
}
|
|
||||||
|
|
||||||
// getChan returns the channel for the given ID.
|
|
||||||
func (c *chanList) getChan(id uint32) *channel {
|
|
||||||
id -= c.offset
|
|
||||||
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
if id < uint32(len(c.chans)) {
|
|
||||||
return c.chans[id]
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *chanList) remove(id uint32) {
|
|
||||||
id -= c.offset
|
|
||||||
c.Lock()
|
|
||||||
if id < uint32(len(c.chans)) {
|
|
||||||
c.chans[id] = nil
|
|
||||||
}
|
|
||||||
c.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropAll forgets all channels it knows, returning them in a slice.
|
|
||||||
func (c *chanList) dropAll() []*channel {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
var r []*channel
|
|
||||||
|
|
||||||
for _, ch := range c.chans {
|
|
||||||
if ch == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
r = append(r, ch)
|
|
||||||
}
|
|
||||||
c.chans = nil
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// mux represents the state for the SSH connection protocol, which
|
|
||||||
// multiplexes many channels onto a single packet transport.
|
|
||||||
type mux struct {
|
|
||||||
conn packetConn
|
|
||||||
chanList chanList
|
|
||||||
|
|
||||||
incomingChannels chan NewChannel
|
|
||||||
|
|
||||||
globalSentMu sync.Mutex
|
|
||||||
globalResponses chan interface{}
|
|
||||||
incomingRequests chan *Request
|
|
||||||
|
|
||||||
errCond *sync.Cond
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// When debugging, each new chanList instantiation has a different
|
|
||||||
// offset.
|
|
||||||
var globalOff uint32
|
|
||||||
|
|
||||||
func (m *mux) Wait() error {
|
|
||||||
m.errCond.L.Lock()
|
|
||||||
defer m.errCond.L.Unlock()
|
|
||||||
for m.err == nil {
|
|
||||||
m.errCond.Wait()
|
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// newMux returns a mux that runs over the given connection.
|
|
||||||
func newMux(p packetConn) *mux {
|
|
||||||
m := &mux{
|
|
||||||
conn: p,
|
|
||||||
incomingChannels: make(chan NewChannel, chanSize),
|
|
||||||
globalResponses: make(chan interface{}, 1),
|
|
||||||
incomingRequests: make(chan *Request, chanSize),
|
|
||||||
errCond: newCond(),
|
|
||||||
}
|
|
||||||
if debugMux {
|
|
||||||
m.chanList.offset = atomic.AddUint32(&globalOff, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go m.loop()
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) sendMessage(msg interface{}) error {
|
|
||||||
p := Marshal(msg)
|
|
||||||
if debugMux {
|
|
||||||
log.Printf("send global(%d): %#v", m.chanList.offset, msg)
|
|
||||||
}
|
|
||||||
return m.conn.writePacket(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
|
|
||||||
if wantReply {
|
|
||||||
m.globalSentMu.Lock()
|
|
||||||
defer m.globalSentMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.sendMessage(globalRequestMsg{
|
|
||||||
Type: name,
|
|
||||||
WantReply: wantReply,
|
|
||||||
Data: payload,
|
|
||||||
}); err != nil {
|
|
||||||
return false, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !wantReply {
|
|
||||||
return false, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
msg, ok := <-m.globalResponses
|
|
||||||
if !ok {
|
|
||||||
return false, nil, io.EOF
|
|
||||||
}
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *globalRequestFailureMsg:
|
|
||||||
return false, msg.Data, nil
|
|
||||||
case *globalRequestSuccessMsg:
|
|
||||||
return true, msg.Data, nil
|
|
||||||
default:
|
|
||||||
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackRequest must be called after processing a global request that
|
|
||||||
// has WantReply set.
|
|
||||||
func (m *mux) ackRequest(ok bool, data []byte) error {
|
|
||||||
if ok {
|
|
||||||
return m.sendMessage(globalRequestSuccessMsg{Data: data})
|
|
||||||
}
|
|
||||||
return m.sendMessage(globalRequestFailureMsg{Data: data})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) Close() error {
|
|
||||||
return m.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop runs the connection machine. It will process packets until an
|
|
||||||
// error is encountered. To synchronize on loop exit, use mux.Wait.
|
|
||||||
func (m *mux) loop() {
|
|
||||||
var err error
|
|
||||||
for err == nil {
|
|
||||||
err = m.onePacket()
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ch := range m.chanList.dropAll() {
|
|
||||||
ch.close()
|
|
||||||
}
|
|
||||||
|
|
||||||
close(m.incomingChannels)
|
|
||||||
close(m.incomingRequests)
|
|
||||||
close(m.globalResponses)
|
|
||||||
|
|
||||||
m.conn.Close()
|
|
||||||
|
|
||||||
m.errCond.L.Lock()
|
|
||||||
m.err = err
|
|
||||||
m.errCond.Broadcast()
|
|
||||||
m.errCond.L.Unlock()
|
|
||||||
|
|
||||||
if debugMux {
|
|
||||||
log.Println("loop exit", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// onePacket reads and processes one packet.
|
|
||||||
func (m *mux) onePacket() error {
|
|
||||||
packet, err := m.conn.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugMux {
|
|
||||||
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
|
|
||||||
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
|
|
||||||
} else {
|
|
||||||
p, _ := decode(packet)
|
|
||||||
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch packet[0] {
|
|
||||||
case msgChannelOpen:
|
|
||||||
return m.handleChannelOpen(packet)
|
|
||||||
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
|
|
||||||
return m.handleGlobalPacket(packet)
|
|
||||||
case msgPing:
|
|
||||||
var msg pingMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err)
|
|
||||||
}
|
|
||||||
return m.sendMessage(pongMsg(msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
// assume a channel packet.
|
|
||||||
if len(packet) < 5 {
|
|
||||||
return parseError(packet[0])
|
|
||||||
}
|
|
||||||
id := binary.BigEndian.Uint32(packet[1:])
|
|
||||||
ch := m.chanList.getChan(id)
|
|
||||||
if ch == nil {
|
|
||||||
return m.handleUnknownChannelPacket(id, packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch.handlePacket(packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) handleGlobalPacket(packet []byte) error {
|
|
||||||
msg, err := decode(packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case *globalRequestMsg:
|
|
||||||
m.incomingRequests <- &Request{
|
|
||||||
Type: msg.Type,
|
|
||||||
WantReply: msg.WantReply,
|
|
||||||
Payload: msg.Data,
|
|
||||||
mux: m,
|
|
||||||
}
|
|
||||||
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
|
|
||||||
m.globalResponses <- msg
|
|
||||||
default:
|
|
||||||
panic(fmt.Sprintf("not a global message %#v", msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleChannelOpen schedules a channel to be Accept()ed.
|
|
||||||
func (m *mux) handleChannelOpen(packet []byte) error {
|
|
||||||
var msg channelOpenMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
|
|
||||||
failMsg := channelOpenFailureMsg{
|
|
||||||
PeersID: msg.PeersID,
|
|
||||||
Reason: ConnectionFailed,
|
|
||||||
Message: "invalid request",
|
|
||||||
Language: "en_US.UTF-8",
|
|
||||||
}
|
|
||||||
return m.sendMessage(failMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
|
|
||||||
c.remoteId = msg.PeersID
|
|
||||||
c.maxRemotePayload = msg.MaxPacketSize
|
|
||||||
c.remoteWin.add(msg.PeersWindow)
|
|
||||||
m.incomingChannels <- c
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
|
|
||||||
ch, err := m.openChannel(chanType, extra)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ch, ch.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
|
|
||||||
ch := m.newChannel(chanType, channelOutbound, extra)
|
|
||||||
|
|
||||||
ch.maxIncomingPayload = channelMaxPacket
|
|
||||||
|
|
||||||
open := channelOpenMsg{
|
|
||||||
ChanType: chanType,
|
|
||||||
PeersWindow: ch.myWindow,
|
|
||||||
MaxPacketSize: ch.maxIncomingPayload,
|
|
||||||
TypeSpecificData: extra,
|
|
||||||
PeersID: ch.localId,
|
|
||||||
}
|
|
||||||
if err := m.sendMessage(open); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := (<-ch.msg).(type) {
|
|
||||||
case *channelOpenConfirmMsg:
|
|
||||||
return ch, nil
|
|
||||||
case *channelOpenFailureMsg:
|
|
||||||
return nil, &OpenChannelError{msg.Reason, msg.Message}
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
|
|
||||||
msg, err := decode(packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
// RFC 4254 section 5.4 says unrecognized channel requests should
|
|
||||||
// receive a failure response.
|
|
||||||
case *channelRequestMsg:
|
|
||||||
if msg.WantReply {
|
|
||||||
return m.sendMessage(channelRequestFailureMsg{
|
|
||||||
PeersID: msg.PeersID,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("ssh: invalid channel %d", id)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,800 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The Permissions type holds fine-grained permissions that are
|
|
||||||
// specific to a user or a specific authentication method for a user.
|
|
||||||
// The Permissions value for a successful authentication attempt is
|
|
||||||
// available in ServerConn, so it can be used to pass information from
|
|
||||||
// the user-authentication phase to the application layer.
|
|
||||||
type Permissions struct {
|
|
||||||
// CriticalOptions indicate restrictions to the default
|
|
||||||
// permissions, and are typically used in conjunction with
|
|
||||||
// user certificates. The standard for SSH certificates
|
|
||||||
// defines "force-command" (only allow the given command to
|
|
||||||
// execute) and "source-address" (only allow connections from
|
|
||||||
// the given address). The SSH package currently only enforces
|
|
||||||
// the "source-address" critical option. It is up to server
|
|
||||||
// implementations to enforce other critical options, such as
|
|
||||||
// "force-command", by checking them after the SSH handshake
|
|
||||||
// is successful. In general, SSH servers should reject
|
|
||||||
// connections that specify critical options that are unknown
|
|
||||||
// or not supported.
|
|
||||||
CriticalOptions map[string]string
|
|
||||||
|
|
||||||
// Extensions are extra functionality that the server may
|
|
||||||
// offer on authenticated connections. Lack of support for an
|
|
||||||
// extension does not preclude authenticating a user. Common
|
|
||||||
// extensions are "permit-agent-forwarding",
|
|
||||||
// "permit-X11-forwarding". The Go SSH library currently does
|
|
||||||
// not act on any extension, and it is up to server
|
|
||||||
// implementations to honor them. Extensions can be used to
|
|
||||||
// pass data from the authentication callbacks to the server
|
|
||||||
// application layer.
|
|
||||||
Extensions map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
type GSSAPIWithMICConfig struct {
|
|
||||||
// AllowLogin, must be set, is called when gssapi-with-mic
|
|
||||||
// authentication is selected (RFC 4462 section 3). The srcName is from the
|
|
||||||
// results of the GSS-API authentication. The format is username@DOMAIN.
|
|
||||||
// GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions.
|
|
||||||
// This callback is called after the user identity is established with GSSAPI to decide if the user can login with
|
|
||||||
// which permissions. If the user is allowed to login, it should return a nil error.
|
|
||||||
AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error)
|
|
||||||
|
|
||||||
// Server must be set. It's the implementation
|
|
||||||
// of the GSSAPIServer interface. See GSSAPIServer interface for details.
|
|
||||||
Server GSSAPIServer
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConfig holds server specific configuration data.
|
|
||||||
type ServerConfig struct {
|
|
||||||
// Config contains configuration shared between client and server.
|
|
||||||
Config
|
|
||||||
|
|
||||||
// PublicKeyAuthAlgorithms specifies the supported client public key
|
|
||||||
// authentication algorithms. Note that this should not include certificate
|
|
||||||
// types since those use the underlying algorithm. This list is sent to the
|
|
||||||
// client if it supports the server-sig-algs extension. Order is irrelevant.
|
|
||||||
// If unspecified then a default set of algorithms is used.
|
|
||||||
PublicKeyAuthAlgorithms []string
|
|
||||||
|
|
||||||
hostKeys []Signer
|
|
||||||
|
|
||||||
// NoClientAuth is true if clients are allowed to connect without
|
|
||||||
// authenticating.
|
|
||||||
// To determine NoClientAuth at runtime, set NoClientAuth to true
|
|
||||||
// and the optional NoClientAuthCallback to a non-nil value.
|
|
||||||
NoClientAuth bool
|
|
||||||
|
|
||||||
// NoClientAuthCallback, if non-nil, is called when a user
|
|
||||||
// attempts to authenticate with auth method "none".
|
|
||||||
// NoClientAuth must also be set to true for this be used, or
|
|
||||||
// this func is unused.
|
|
||||||
NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
|
|
||||||
|
|
||||||
// MaxAuthTries specifies the maximum number of authentication attempts
|
|
||||||
// permitted per connection. If set to a negative number, the number of
|
|
||||||
// attempts are unlimited. If set to zero, the number of attempts are limited
|
|
||||||
// to 6.
|
|
||||||
MaxAuthTries int
|
|
||||||
|
|
||||||
// PasswordCallback, if non-nil, is called when a user
|
|
||||||
// attempts to authenticate using a password.
|
|
||||||
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
|
|
||||||
|
|
||||||
// PublicKeyCallback, if non-nil, is called when a client
|
|
||||||
// offers a public key for authentication. It must return a nil error
|
|
||||||
// if the given public key can be used to authenticate the
|
|
||||||
// given user. For example, see CertChecker.Authenticate. A
|
|
||||||
// call to this function does not guarantee that the key
|
|
||||||
// offered is in fact used to authenticate. To record any data
|
|
||||||
// depending on the public key, store it inside a
|
|
||||||
// Permissions.Extensions entry.
|
|
||||||
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
|
|
||||||
|
|
||||||
// KeyboardInteractiveCallback, if non-nil, is called when
|
|
||||||
// keyboard-interactive authentication is selected (RFC
|
|
||||||
// 4256). The client object's Challenge function should be
|
|
||||||
// used to query the user. The callback may offer multiple
|
|
||||||
// Challenge rounds. To avoid information leaks, the client
|
|
||||||
// should be presented a challenge even if the user is
|
|
||||||
// unknown.
|
|
||||||
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
|
|
||||||
|
|
||||||
// AuthLogCallback, if non-nil, is called to log all authentication
|
|
||||||
// attempts.
|
|
||||||
AuthLogCallback func(conn ConnMetadata, method string, err error)
|
|
||||||
|
|
||||||
// ServerVersion is the version identification string to announce in
|
|
||||||
// the public handshake.
|
|
||||||
// If empty, a reasonable default is used.
|
|
||||||
// Note that RFC 4253 section 4.2 requires that this string start with
|
|
||||||
// "SSH-2.0-".
|
|
||||||
ServerVersion string
|
|
||||||
|
|
||||||
// BannerCallback, if present, is called and the return string is sent to
|
|
||||||
// the client after key exchange completed but before authentication.
|
|
||||||
BannerCallback func(conn ConnMetadata) string
|
|
||||||
|
|
||||||
// GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used
|
|
||||||
// when gssapi-with-mic authentication is selected (RFC 4462 section 3).
|
|
||||||
GSSAPIWithMICConfig *GSSAPIWithMICConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHostKey adds a private key as a host key. If an existing host
|
|
||||||
// key exists with the same public key format, it is replaced. Each server
|
|
||||||
// config must have at least one host key.
|
|
||||||
func (s *ServerConfig) AddHostKey(key Signer) {
|
|
||||||
for i, k := range s.hostKeys {
|
|
||||||
if k.PublicKey().Type() == key.PublicKey().Type() {
|
|
||||||
s.hostKeys[i] = key
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.hostKeys = append(s.hostKeys, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cachedPubKey contains the results of querying whether a public key is
|
|
||||||
// acceptable for a user.
|
|
||||||
type cachedPubKey struct {
|
|
||||||
user string
|
|
||||||
pubKeyData []byte
|
|
||||||
result error
|
|
||||||
perms *Permissions
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxCachedPubKeys = 16
|
|
||||||
|
|
||||||
// pubKeyCache caches tests for public keys. Since SSH clients
|
|
||||||
// will query whether a public key is acceptable before attempting to
|
|
||||||
// authenticate with it, we end up with duplicate queries for public
|
|
||||||
// key validity. The cache only applies to a single ServerConn.
|
|
||||||
type pubKeyCache struct {
|
|
||||||
keys []cachedPubKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// get returns the result for a given user/algo/key tuple.
|
|
||||||
func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
|
|
||||||
for _, k := range c.keys {
|
|
||||||
if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
|
|
||||||
return k, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cachedPubKey{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// add adds the given tuple to the cache.
|
|
||||||
func (c *pubKeyCache) add(candidate cachedPubKey) {
|
|
||||||
if len(c.keys) < maxCachedPubKeys {
|
|
||||||
c.keys = append(c.keys, candidate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerConn is an authenticated SSH connection, as seen from the
|
|
||||||
// server
|
|
||||||
type ServerConn struct {
|
|
||||||
Conn
|
|
||||||
|
|
||||||
// If the succeeding authentication callback returned a
|
|
||||||
// non-nil Permissions pointer, it is stored here.
|
|
||||||
Permissions *Permissions
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewServerConn starts a new SSH server with c as the underlying
|
|
||||||
// transport. It starts with a handshake and, if the handshake is
|
|
||||||
// unsuccessful, it closes the connection and returns an error. The
|
|
||||||
// Request and NewChannel channels must be serviced, or the connection
|
|
||||||
// will hang.
|
|
||||||
//
|
|
||||||
// The returned error may be of type *ServerAuthError for
|
|
||||||
// authentication errors.
|
|
||||||
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
|
|
||||||
fullConf := *config
|
|
||||||
fullConf.SetDefaults()
|
|
||||||
if fullConf.MaxAuthTries == 0 {
|
|
||||||
fullConf.MaxAuthTries = 6
|
|
||||||
}
|
|
||||||
if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
|
|
||||||
fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
|
|
||||||
} else {
|
|
||||||
for _, algo := range fullConf.PublicKeyAuthAlgorithms {
|
|
||||||
if !contains(supportedPubKeyAuthAlgos, algo) {
|
|
||||||
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Check if the config contains any unsupported key exchanges
|
|
||||||
for _, kex := range fullConf.KeyExchanges {
|
|
||||||
if _, ok := serverForbiddenKexAlgos[kex]; ok {
|
|
||||||
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &connection{
|
|
||||||
sshConn: sshConn{conn: c},
|
|
||||||
}
|
|
||||||
perms, err := s.serverHandshake(&fullConf)
|
|
||||||
if err != nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// signAndMarshal signs the data with the appropriate algorithm,
|
|
||||||
// and serializes the result in SSH wire format. algo is the negotiate
|
|
||||||
// algorithm and may be a certificate type.
|
|
||||||
func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) {
|
|
||||||
sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return Marshal(sig), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handshake performs key exchange and user authentication.
|
|
||||||
func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
|
|
||||||
if len(config.hostKeys) == 0 {
|
|
||||||
return nil, errors.New("ssh: server has no host keys")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil &&
|
|
||||||
config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil ||
|
|
||||||
config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) {
|
|
||||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ServerVersion != "" {
|
|
||||||
s.serverVersion = []byte(config.ServerVersion)
|
|
||||||
} else {
|
|
||||||
s.serverVersion = []byte(packageVersion)
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
|
|
||||||
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
|
|
||||||
|
|
||||||
if err := s.transport.waitSession(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We just did the key change, so the session ID is established.
|
|
||||||
s.sessionID = s.transport.getSessionID()
|
|
||||||
|
|
||||||
var packet []byte
|
|
||||||
if packet, err = s.transport.readPacket(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var serviceRequest serviceRequestMsg
|
|
||||||
if err = Unmarshal(packet, &serviceRequest); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if serviceRequest.Service != serviceUserAuth {
|
|
||||||
return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
|
|
||||||
}
|
|
||||||
serviceAccept := serviceAcceptMsg{
|
|
||||||
Service: serviceUserAuth,
|
|
||||||
}
|
|
||||||
if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
perms, err := s.serverAuthenticate(config)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.mux = newMux(s.transport)
|
|
||||||
return perms, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
|
|
||||||
if addr == nil {
|
|
||||||
return errors.New("ssh: no address known for client, but source-address match required")
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
|
|
||||||
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
|
|
||||||
if allowedIP.Equal(tcpAddr.IP) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, ipNet, err := net.ParseCIDR(sourceAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipNet.Contains(tcpAddr.IP) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
|
|
||||||
sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
|
|
||||||
gssAPIServer := gssapiConfig.Server
|
|
||||||
defer gssAPIServer.DeleteSecContext()
|
|
||||||
var srcName string
|
|
||||||
for {
|
|
||||||
var (
|
|
||||||
outToken []byte
|
|
||||||
needContinue bool
|
|
||||||
)
|
|
||||||
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
|
|
||||||
if err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
if len(outToken) != 0 {
|
|
||||||
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{
|
|
||||||
Token: outToken,
|
|
||||||
})); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !needContinue {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
packet, err := s.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
token = userAuthGSSAPITokenReq.Token
|
|
||||||
}
|
|
||||||
packet, err := s.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method)
|
|
||||||
if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil {
|
|
||||||
return err, nil, nil
|
|
||||||
}
|
|
||||||
perms, authErr = gssapiConfig.AllowLogin(s, srcName)
|
|
||||||
return authErr, perms, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isAlgoCompatible checks if the signature format is compatible with the
|
|
||||||
// selected algorithm taking into account edge cases that occur with old
|
|
||||||
// clients.
|
|
||||||
func isAlgoCompatible(algo, sigFormat string) bool {
|
|
||||||
// Compatibility for old clients.
|
|
||||||
//
|
|
||||||
// For certificate authentication with OpenSSH 7.2-7.7 signature format can
|
|
||||||
// be rsa-sha2-256 or rsa-sha2-512 for the algorithm
|
|
||||||
// ssh-rsa-cert-v01@openssh.com.
|
|
||||||
//
|
|
||||||
// With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512
|
|
||||||
// for signature format ssh-rsa.
|
|
||||||
if isRSA(algo) && isRSA(sigFormat) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Standard case: the underlying algorithm must match the signature format.
|
|
||||||
return underlyingAlgo(algo) == sigFormat
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerAuthError represents server authentication errors and is
|
|
||||||
// sometimes returned by NewServerConn. It appends any authentication
|
|
||||||
// errors that may occur, and is returned if all of the authentication
|
|
||||||
// methods provided by the user failed to authenticate.
|
|
||||||
type ServerAuthError struct {
|
|
||||||
// Errors contains authentication errors returned by the authentication
|
|
||||||
// callback methods. The first entry is typically ErrNoAuth.
|
|
||||||
Errors []error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l ServerAuthError) Error() string {
|
|
||||||
var errs []string
|
|
||||||
for _, err := range l.Errors {
|
|
||||||
errs = append(errs, err.Error())
|
|
||||||
}
|
|
||||||
return "[" + strings.Join(errs, ", ") + "]"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNoAuth is the error value returned if no
|
|
||||||
// authentication method has been passed yet. This happens as a normal
|
|
||||||
// part of the authentication loop, since the client first tries
|
|
||||||
// 'none' authentication to discover available methods.
|
|
||||||
// It is returned in ServerAuthError.Errors from NewServerConn.
|
|
||||||
var ErrNoAuth = errors.New("ssh: no auth passed yet")
|
|
||||||
|
|
||||||
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
|
|
||||||
sessionID := s.transport.getSessionID()
|
|
||||||
var cache pubKeyCache
|
|
||||||
var perms *Permissions
|
|
||||||
|
|
||||||
authFailures := 0
|
|
||||||
var authErrs []error
|
|
||||||
var displayedBanner bool
|
|
||||||
|
|
||||||
userAuthLoop:
|
|
||||||
for {
|
|
||||||
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
|
|
||||||
discMsg := &disconnectMsg{
|
|
||||||
Reason: 2,
|
|
||||||
Message: "too many authentication failures",
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, discMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
var userAuthReq userAuthRequestMsg
|
|
||||||
if packet, err := s.transport.readPacket(); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return nil, &ServerAuthError{Errors: authErrs}
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if userAuthReq.Service != serviceSSH {
|
|
||||||
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.user = userAuthReq.User
|
|
||||||
|
|
||||||
if !displayedBanner && config.BannerCallback != nil {
|
|
||||||
displayedBanner = true
|
|
||||||
msg := config.BannerCallback(s)
|
|
||||||
if msg != "" {
|
|
||||||
bannerMsg := &userAuthBannerMsg{
|
|
||||||
Message: msg,
|
|
||||||
}
|
|
||||||
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
perms = nil
|
|
||||||
authErr := ErrNoAuth
|
|
||||||
|
|
||||||
switch userAuthReq.Method {
|
|
||||||
case "none":
|
|
||||||
if config.NoClientAuth {
|
|
||||||
if config.NoClientAuthCallback != nil {
|
|
||||||
perms, authErr = config.NoClientAuthCallback(s)
|
|
||||||
} else {
|
|
||||||
authErr = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// allow initial attempt of 'none' without penalty
|
|
||||||
if authFailures == 0 {
|
|
||||||
authFailures--
|
|
||||||
}
|
|
||||||
case "password":
|
|
||||||
if config.PasswordCallback == nil {
|
|
||||||
authErr = errors.New("ssh: password auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
payload := userAuthReq.Payload
|
|
||||||
if len(payload) < 1 || payload[0] != 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
payload = payload[1:]
|
|
||||||
password, payload, ok := parseString(payload)
|
|
||||||
if !ok || len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
perms, authErr = config.PasswordCallback(s, password)
|
|
||||||
case "keyboard-interactive":
|
|
||||||
if config.KeyboardInteractiveCallback == nil {
|
|
||||||
authErr = errors.New("ssh: keyboard-interactive auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
prompter := &sshClientKeyboardInteractive{s}
|
|
||||||
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge)
|
|
||||||
case "publickey":
|
|
||||||
if config.PublicKeyCallback == nil {
|
|
||||||
authErr = errors.New("ssh: publickey auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
payload := userAuthReq.Payload
|
|
||||||
if len(payload) < 1 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
isQuery := payload[0] == 0
|
|
||||||
payload = payload[1:]
|
|
||||||
algoBytes, payload, ok := parseString(payload)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
algo := string(algoBytes)
|
|
||||||
if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
|
|
||||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKeyData, payload, ok := parseString(payload)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey, err := ParsePublicKey(pubKeyData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
candidate, ok := cache.get(s.user, pubKeyData)
|
|
||||||
if !ok {
|
|
||||||
candidate.user = s.user
|
|
||||||
candidate.pubKeyData = pubKeyData
|
|
||||||
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey)
|
|
||||||
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
|
|
||||||
candidate.result = checkSourceAddress(
|
|
||||||
s.RemoteAddr(),
|
|
||||||
candidate.perms.CriticalOptions[sourceAddressCriticalOption])
|
|
||||||
}
|
|
||||||
cache.add(candidate)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isQuery {
|
|
||||||
// The client can query if the given public key
|
|
||||||
// would be okay.
|
|
||||||
|
|
||||||
if len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
if candidate.result == nil {
|
|
||||||
okMsg := userAuthPubKeyOkMsg{
|
|
||||||
Algo: algo,
|
|
||||||
PubKey: pubKeyData,
|
|
||||||
}
|
|
||||||
if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
continue userAuthLoop
|
|
||||||
}
|
|
||||||
authErr = candidate.result
|
|
||||||
} else {
|
|
||||||
sig, payload, ok := parseSignature(payload)
|
|
||||||
if !ok || len(payload) > 0 {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
// Ensure the declared public key algo is compatible with the
|
|
||||||
// decoded one. This check will ensure we don't accept e.g.
|
|
||||||
// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
|
|
||||||
// key type. The algorithm and public key type must be
|
|
||||||
// consistent: both must be certificate algorithms, or neither.
|
|
||||||
if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
|
|
||||||
authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
|
|
||||||
pubKey.Type(), algo)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Ensure the public key algo and signature algo
|
|
||||||
// are supported. Compare the private key
|
|
||||||
// algorithm name that corresponds to algo with
|
|
||||||
// sig.Format. This is usually the same, but
|
|
||||||
// for certs, the names differ.
|
|
||||||
if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
|
|
||||||
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !isAlgoCompatible(algo, sig.Format) {
|
|
||||||
authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
|
|
||||||
|
|
||||||
if err := pubKey.Verify(signedData, sig); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authErr = candidate.result
|
|
||||||
perms = candidate.perms
|
|
||||||
}
|
|
||||||
case "gssapi-with-mic":
|
|
||||||
if config.GSSAPIWithMICConfig == nil {
|
|
||||||
authErr = errors.New("ssh: gssapi-with-mic auth not configured")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
gssapiConfig := config.GSSAPIWithMICConfig
|
|
||||||
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, parseError(msgUserAuthRequest)
|
|
||||||
}
|
|
||||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication.
|
|
||||||
if userAuthRequestGSSAPI.N == 0 {
|
|
||||||
authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
var i uint32
|
|
||||||
present := false
|
|
||||||
for i = 0; i < userAuthRequestGSSAPI.N; i++ {
|
|
||||||
if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) {
|
|
||||||
present = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !present {
|
|
||||||
authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Initial server response, see RFC 4462 section 3.3.
|
|
||||||
if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{
|
|
||||||
SupportMech: krb5OID,
|
|
||||||
})); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Exchange token, see RFC 4462 section 3.4.
|
|
||||||
packet, err := s.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
|
|
||||||
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID,
|
|
||||||
userAuthReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
|
|
||||||
}
|
|
||||||
|
|
||||||
authErrs = append(authErrs, authErr)
|
|
||||||
|
|
||||||
if config.AuthLogCallback != nil {
|
|
||||||
config.AuthLogCallback(s, userAuthReq.Method, authErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if authErr == nil {
|
|
||||||
break userAuthLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
authFailures++
|
|
||||||
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
|
|
||||||
// If we have hit the max attempts, don't bother sending the
|
|
||||||
// final SSH_MSG_USERAUTH_FAILURE message, since there are
|
|
||||||
// no more authentication methods which can be attempted,
|
|
||||||
// and this message may cause the client to re-attempt
|
|
||||||
// authentication while we send the disconnect message.
|
|
||||||
// Continue, and trigger the disconnect at the start of
|
|
||||||
// the loop.
|
|
||||||
//
|
|
||||||
// The SSH specification is somewhat confusing about this,
|
|
||||||
// RFC 4252 Section 5.1 requires each authentication failure
|
|
||||||
// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
|
|
||||||
// message, but Section 4 says the server should disconnect
|
|
||||||
// after some number of attempts, but it isn't explicit which
|
|
||||||
// message should take precedence (i.e. should there be a failure
|
|
||||||
// message than a disconnect message, or if we are going to
|
|
||||||
// disconnect, should we only send that message.)
|
|
||||||
//
|
|
||||||
// Either way, OpenSSH disconnects immediately after the last
|
|
||||||
// failed authnetication attempt, and given they are typically
|
|
||||||
// considered the golden implementation it seems reasonable
|
|
||||||
// to match that behavior.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var failureMsg userAuthFailureMsg
|
|
||||||
if config.PasswordCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "password")
|
|
||||||
}
|
|
||||||
if config.PublicKeyCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "publickey")
|
|
||||||
}
|
|
||||||
if config.KeyboardInteractiveCallback != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
|
|
||||||
}
|
|
||||||
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil &&
|
|
||||||
config.GSSAPIWithMICConfig.AllowLogin != nil {
|
|
||||||
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failureMsg.Methods) == 0 {
|
|
||||||
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return perms, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
|
|
||||||
// asking the client on the other side of a ServerConn.
|
|
||||||
type sshClientKeyboardInteractive struct {
|
|
||||||
*connection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
|
|
||||||
if len(questions) != len(echos) {
|
|
||||||
return nil, errors.New("ssh: echos and questions must have equal length")
|
|
||||||
}
|
|
||||||
|
|
||||||
var prompts []byte
|
|
||||||
for i := range questions {
|
|
||||||
prompts = appendString(prompts, questions[i])
|
|
||||||
prompts = appendBool(prompts, echos[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
|
|
||||||
Name: name,
|
|
||||||
Instruction: instruction,
|
|
||||||
NumPrompts: uint32(len(questions)),
|
|
||||||
Prompts: prompts,
|
|
||||||
})); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
packet, err := c.transport.readPacket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if packet[0] != msgUserAuthInfoResponse {
|
|
||||||
return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
|
|
||||||
}
|
|
||||||
packet = packet[1:]
|
|
||||||
|
|
||||||
n, packet, ok := parseUint32(packet)
|
|
||||||
if !ok || int(n) != len(questions) {
|
|
||||||
return nil, parseError(msgUserAuthInfoResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := uint32(0); i < n; i++ {
|
|
||||||
ans, rest, ok := parseString(packet)
|
|
||||||
if !ok {
|
|
||||||
return nil, parseError(msgUserAuthInfoResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
answers = append(answers, string(ans))
|
|
||||||
packet = rest
|
|
||||||
}
|
|
||||||
if len(packet) != 0 {
|
|
||||||
return nil, errors.New("ssh: junk at end of message")
|
|
||||||
}
|
|
||||||
|
|
||||||
return answers, nil
|
|
||||||
}
|
|
@ -1,647 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
// Session implements an interactive session described in
|
|
||||||
// "RFC 4254, section 6".
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Signal string
|
|
||||||
|
|
||||||
// POSIX signals as listed in RFC 4254 Section 6.10.
|
|
||||||
const (
|
|
||||||
SIGABRT Signal = "ABRT"
|
|
||||||
SIGALRM Signal = "ALRM"
|
|
||||||
SIGFPE Signal = "FPE"
|
|
||||||
SIGHUP Signal = "HUP"
|
|
||||||
SIGILL Signal = "ILL"
|
|
||||||
SIGINT Signal = "INT"
|
|
||||||
SIGKILL Signal = "KILL"
|
|
||||||
SIGPIPE Signal = "PIPE"
|
|
||||||
SIGQUIT Signal = "QUIT"
|
|
||||||
SIGSEGV Signal = "SEGV"
|
|
||||||
SIGTERM Signal = "TERM"
|
|
||||||
SIGUSR1 Signal = "USR1"
|
|
||||||
SIGUSR2 Signal = "USR2"
|
|
||||||
)
|
|
||||||
|
|
||||||
var signals = map[Signal]int{
|
|
||||||
SIGABRT: 6,
|
|
||||||
SIGALRM: 14,
|
|
||||||
SIGFPE: 8,
|
|
||||||
SIGHUP: 1,
|
|
||||||
SIGILL: 4,
|
|
||||||
SIGINT: 2,
|
|
||||||
SIGKILL: 9,
|
|
||||||
SIGPIPE: 13,
|
|
||||||
SIGQUIT: 3,
|
|
||||||
SIGSEGV: 11,
|
|
||||||
SIGTERM: 15,
|
|
||||||
}
|
|
||||||
|
|
||||||
type TerminalModes map[uint8]uint32
|
|
||||||
|
|
||||||
// POSIX terminal mode flags as listed in RFC 4254 Section 8.
|
|
||||||
const (
|
|
||||||
tty_OP_END = 0
|
|
||||||
VINTR = 1
|
|
||||||
VQUIT = 2
|
|
||||||
VERASE = 3
|
|
||||||
VKILL = 4
|
|
||||||
VEOF = 5
|
|
||||||
VEOL = 6
|
|
||||||
VEOL2 = 7
|
|
||||||
VSTART = 8
|
|
||||||
VSTOP = 9
|
|
||||||
VSUSP = 10
|
|
||||||
VDSUSP = 11
|
|
||||||
VREPRINT = 12
|
|
||||||
VWERASE = 13
|
|
||||||
VLNEXT = 14
|
|
||||||
VFLUSH = 15
|
|
||||||
VSWTCH = 16
|
|
||||||
VSTATUS = 17
|
|
||||||
VDISCARD = 18
|
|
||||||
IGNPAR = 30
|
|
||||||
PARMRK = 31
|
|
||||||
INPCK = 32
|
|
||||||
ISTRIP = 33
|
|
||||||
INLCR = 34
|
|
||||||
IGNCR = 35
|
|
||||||
ICRNL = 36
|
|
||||||
IUCLC = 37
|
|
||||||
IXON = 38
|
|
||||||
IXANY = 39
|
|
||||||
IXOFF = 40
|
|
||||||
IMAXBEL = 41
|
|
||||||
IUTF8 = 42 // RFC 8160
|
|
||||||
ISIG = 50
|
|
||||||
ICANON = 51
|
|
||||||
XCASE = 52
|
|
||||||
ECHO = 53
|
|
||||||
ECHOE = 54
|
|
||||||
ECHOK = 55
|
|
||||||
ECHONL = 56
|
|
||||||
NOFLSH = 57
|
|
||||||
TOSTOP = 58
|
|
||||||
IEXTEN = 59
|
|
||||||
ECHOCTL = 60
|
|
||||||
ECHOKE = 61
|
|
||||||
PENDIN = 62
|
|
||||||
OPOST = 70
|
|
||||||
OLCUC = 71
|
|
||||||
ONLCR = 72
|
|
||||||
OCRNL = 73
|
|
||||||
ONOCR = 74
|
|
||||||
ONLRET = 75
|
|
||||||
CS7 = 90
|
|
||||||
CS8 = 91
|
|
||||||
PARENB = 92
|
|
||||||
PARODD = 93
|
|
||||||
TTY_OP_ISPEED = 128
|
|
||||||
TTY_OP_OSPEED = 129
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Session represents a connection to a remote command or shell.
|
|
||||||
type Session struct {
|
|
||||||
// Stdin specifies the remote process's standard input.
|
|
||||||
// If Stdin is nil, the remote process reads from an empty
|
|
||||||
// bytes.Buffer.
|
|
||||||
Stdin io.Reader
|
|
||||||
|
|
||||||
// Stdout and Stderr specify the remote process's standard
|
|
||||||
// output and error.
|
|
||||||
//
|
|
||||||
// If either is nil, Run connects the corresponding file
|
|
||||||
// descriptor to an instance of io.Discard. There is a
|
|
||||||
// fixed amount of buffering that is shared for the two streams.
|
|
||||||
// If either blocks it may eventually cause the remote
|
|
||||||
// command to block.
|
|
||||||
Stdout io.Writer
|
|
||||||
Stderr io.Writer
|
|
||||||
|
|
||||||
ch Channel // the channel backing this session
|
|
||||||
started bool // true once Start, Run or Shell is invoked.
|
|
||||||
copyFuncs []func() error
|
|
||||||
errors chan error // one send per copyFunc
|
|
||||||
|
|
||||||
// true if pipe method is active
|
|
||||||
stdinpipe, stdoutpipe, stderrpipe bool
|
|
||||||
|
|
||||||
// stdinPipeWriter is non-nil if StdinPipe has not been called
|
|
||||||
// and Stdin was specified by the user; it is the write end of
|
|
||||||
// a pipe connecting Session.Stdin to the stdin channel.
|
|
||||||
stdinPipeWriter io.WriteCloser
|
|
||||||
|
|
||||||
exitStatus chan error
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendRequest sends an out-of-band channel request on the SSH channel
|
|
||||||
// underlying the session.
|
|
||||||
func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
|
||||||
return s.ch.SendRequest(name, wantReply, payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
|
||||||
return s.ch.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.4.
|
|
||||||
type setenvRequest struct {
|
|
||||||
Name string
|
|
||||||
Value string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setenv sets an environment variable that will be applied to any
|
|
||||||
// command executed by Shell or Run.
|
|
||||||
func (s *Session) Setenv(name, value string) error {
|
|
||||||
msg := setenvRequest{
|
|
||||||
Name: name,
|
|
||||||
Value: value,
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("env", true, Marshal(&msg))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: setenv failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.2.
|
|
||||||
type ptyRequestMsg struct {
|
|
||||||
Term string
|
|
||||||
Columns uint32
|
|
||||||
Rows uint32
|
|
||||||
Width uint32
|
|
||||||
Height uint32
|
|
||||||
Modelist string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestPty requests the association of a pty with the session on the remote host.
|
|
||||||
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error {
|
|
||||||
var tm []byte
|
|
||||||
for k, v := range termmodes {
|
|
||||||
kv := struct {
|
|
||||||
Key byte
|
|
||||||
Val uint32
|
|
||||||
}{k, v}
|
|
||||||
|
|
||||||
tm = append(tm, Marshal(&kv)...)
|
|
||||||
}
|
|
||||||
tm = append(tm, tty_OP_END)
|
|
||||||
req := ptyRequestMsg{
|
|
||||||
Term: term,
|
|
||||||
Columns: uint32(w),
|
|
||||||
Rows: uint32(h),
|
|
||||||
Width: uint32(w * 8),
|
|
||||||
Height: uint32(h * 8),
|
|
||||||
Modelist: string(tm),
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: pty-req failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.5.
|
|
||||||
type subsystemRequestMsg struct {
|
|
||||||
Subsystem string
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestSubsystem requests the association of a subsystem with the session on the remote host.
|
|
||||||
// A subsystem is a predefined command that runs in the background when the ssh session is initiated
|
|
||||||
func (s *Session) RequestSubsystem(subsystem string) error {
|
|
||||||
msg := subsystemRequestMsg{
|
|
||||||
Subsystem: subsystem,
|
|
||||||
}
|
|
||||||
ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: subsystem request failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.7.
|
|
||||||
type ptyWindowChangeMsg struct {
|
|
||||||
Columns uint32
|
|
||||||
Rows uint32
|
|
||||||
Width uint32
|
|
||||||
Height uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
|
|
||||||
func (s *Session) WindowChange(h, w int) error {
|
|
||||||
req := ptyWindowChangeMsg{
|
|
||||||
Columns: uint32(w),
|
|
||||||
Rows: uint32(h),
|
|
||||||
Width: uint32(w * 8),
|
|
||||||
Height: uint32(h * 8),
|
|
||||||
}
|
|
||||||
_, err := s.ch.SendRequest("window-change", false, Marshal(&req))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.9.
|
|
||||||
type signalMsg struct {
|
|
||||||
Signal string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal sends the given signal to the remote process.
|
|
||||||
// sig is one of the SIG* constants.
|
|
||||||
func (s *Session) Signal(sig Signal) error {
|
|
||||||
msg := signalMsg{
|
|
||||||
Signal: string(sig),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := s.ch.SendRequest("signal", false, Marshal(&msg))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 Section 6.5.
|
|
||||||
type execMsg struct {
|
|
||||||
Command string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start runs cmd on the remote host. Typically, the remote
|
|
||||||
// server passes cmd to the shell for interpretation.
|
|
||||||
// A Session only accepts one call to Run, Start or Shell.
|
|
||||||
func (s *Session) Start(cmd string) error {
|
|
||||||
if s.started {
|
|
||||||
return errors.New("ssh: session already started")
|
|
||||||
}
|
|
||||||
req := execMsg{
|
|
||||||
Command: cmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := s.ch.SendRequest("exec", true, Marshal(&req))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = fmt.Errorf("ssh: command %v failed", cmd)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run runs cmd on the remote host. Typically, the remote
|
|
||||||
// server passes cmd to the shell for interpretation.
|
|
||||||
// A Session only accepts one call to Run, Start, Shell, Output,
|
|
||||||
// or CombinedOutput.
|
|
||||||
//
|
|
||||||
// The returned error is nil if the command runs, has no problems
|
|
||||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
|
||||||
// status.
|
|
||||||
//
|
|
||||||
// If the remote server does not send an exit status, an error of type
|
|
||||||
// *ExitMissingError is returned. If the command completes
|
|
||||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
|
||||||
// *ExitError. Other error types may be returned for I/O problems.
|
|
||||||
func (s *Session) Run(cmd string) error {
|
|
||||||
err := s.Start(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Output runs cmd on the remote host and returns its standard output.
|
|
||||||
func (s *Session) Output(cmd string) ([]byte, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
var b bytes.Buffer
|
|
||||||
s.Stdout = &b
|
|
||||||
err := s.Run(cmd)
|
|
||||||
return b.Bytes(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
type singleWriter struct {
|
|
||||||
b bytes.Buffer
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *singleWriter) Write(p []byte) (int, error) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
return w.b.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CombinedOutput runs cmd on the remote host and returns its combined
|
|
||||||
// standard output and standard error.
|
|
||||||
func (s *Session) CombinedOutput(cmd string) ([]byte, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
if s.Stderr != nil {
|
|
||||||
return nil, errors.New("ssh: Stderr already set")
|
|
||||||
}
|
|
||||||
var b singleWriter
|
|
||||||
s.Stdout = &b
|
|
||||||
s.Stderr = &b
|
|
||||||
err := s.Run(cmd)
|
|
||||||
return b.b.Bytes(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shell starts a login shell on the remote host. A Session only
|
|
||||||
// accepts one call to Run, Start, Shell, Output, or CombinedOutput.
|
|
||||||
func (s *Session) Shell() error {
|
|
||||||
if s.started {
|
|
||||||
return errors.New("ssh: session already started")
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := s.ch.SendRequest("shell", true, nil)
|
|
||||||
if err == nil && !ok {
|
|
||||||
return errors.New("ssh: could not start shell")
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.start()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) start() error {
|
|
||||||
s.started = true
|
|
||||||
|
|
||||||
type F func(*Session)
|
|
||||||
for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} {
|
|
||||||
setupFd(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.errors = make(chan error, len(s.copyFuncs))
|
|
||||||
for _, fn := range s.copyFuncs {
|
|
||||||
go func(fn func() error) {
|
|
||||||
s.errors <- fn()
|
|
||||||
}(fn)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait waits for the remote command to exit.
|
|
||||||
//
|
|
||||||
// The returned error is nil if the command runs, has no problems
|
|
||||||
// copying stdin, stdout, and stderr, and exits with a zero exit
|
|
||||||
// status.
|
|
||||||
//
|
|
||||||
// If the remote server does not send an exit status, an error of type
|
|
||||||
// *ExitMissingError is returned. If the command completes
|
|
||||||
// unsuccessfully or is interrupted by a signal, the error is of type
|
|
||||||
// *ExitError. Other error types may be returned for I/O problems.
|
|
||||||
func (s *Session) Wait() error {
|
|
||||||
if !s.started {
|
|
||||||
return errors.New("ssh: session not started")
|
|
||||||
}
|
|
||||||
waitErr := <-s.exitStatus
|
|
||||||
|
|
||||||
if s.stdinPipeWriter != nil {
|
|
||||||
s.stdinPipeWriter.Close()
|
|
||||||
}
|
|
||||||
var copyError error
|
|
||||||
for range s.copyFuncs {
|
|
||||||
if err := <-s.errors; err != nil && copyError == nil {
|
|
||||||
copyError = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if waitErr != nil {
|
|
||||||
return waitErr
|
|
||||||
}
|
|
||||||
return copyError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) wait(reqs <-chan *Request) error {
|
|
||||||
wm := Waitmsg{status: -1}
|
|
||||||
// Wait for msg channel to be closed before returning.
|
|
||||||
for msg := range reqs {
|
|
||||||
switch msg.Type {
|
|
||||||
case "exit-status":
|
|
||||||
wm.status = int(binary.BigEndian.Uint32(msg.Payload))
|
|
||||||
case "exit-signal":
|
|
||||||
var sigval struct {
|
|
||||||
Signal string
|
|
||||||
CoreDumped bool
|
|
||||||
Error string
|
|
||||||
Lang string
|
|
||||||
}
|
|
||||||
if err := Unmarshal(msg.Payload, &sigval); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must sanitize strings?
|
|
||||||
wm.signal = sigval.Signal
|
|
||||||
wm.msg = sigval.Error
|
|
||||||
wm.lang = sigval.Lang
|
|
||||||
default:
|
|
||||||
// This handles keepalives and matches
|
|
||||||
// OpenSSH's behaviour.
|
|
||||||
if msg.WantReply {
|
|
||||||
msg.Reply(false, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if wm.status == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if wm.status == -1 {
|
|
||||||
// exit-status was never sent from server
|
|
||||||
if wm.signal == "" {
|
|
||||||
// signal was not sent either. RFC 4254
|
|
||||||
// section 6.10 recommends against this
|
|
||||||
// behavior, but it is allowed, so we let
|
|
||||||
// clients handle it.
|
|
||||||
return &ExitMissingError{}
|
|
||||||
}
|
|
||||||
wm.status = 128
|
|
||||||
if _, ok := signals[Signal(wm.signal)]; ok {
|
|
||||||
wm.status += signals[Signal(wm.signal)]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ExitError{wm}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExitMissingError is returned if a session is torn down cleanly, but
|
|
||||||
// the server sends no confirmation of the exit status.
|
|
||||||
type ExitMissingError struct{}
|
|
||||||
|
|
||||||
func (e *ExitMissingError) Error() string {
|
|
||||||
return "wait: remote command exited without exit status or exit signal"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stdin() {
|
|
||||||
if s.stdinpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var stdin io.Reader
|
|
||||||
if s.Stdin == nil {
|
|
||||||
stdin = new(bytes.Buffer)
|
|
||||||
} else {
|
|
||||||
r, w := io.Pipe()
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(w, s.Stdin)
|
|
||||||
w.CloseWithError(err)
|
|
||||||
}()
|
|
||||||
stdin, s.stdinPipeWriter = r, w
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.ch, stdin)
|
|
||||||
if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
|
|
||||||
err = err1
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stdout() {
|
|
||||||
if s.stdoutpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.Stdout == nil {
|
|
||||||
s.Stdout = io.Discard
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.Stdout, s.ch)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) stderr() {
|
|
||||||
if s.stderrpipe {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.Stderr == nil {
|
|
||||||
s.Stderr = io.Discard
|
|
||||||
}
|
|
||||||
s.copyFuncs = append(s.copyFuncs, func() error {
|
|
||||||
_, err := io.Copy(s.Stderr, s.ch.Stderr())
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// sessionStdin reroutes Close to CloseWrite.
|
|
||||||
type sessionStdin struct {
|
|
||||||
io.Writer
|
|
||||||
ch Channel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sessionStdin) Close() error {
|
|
||||||
return s.ch.CloseWrite()
|
|
||||||
}
|
|
||||||
|
|
||||||
// StdinPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard input when the command starts.
|
|
||||||
func (s *Session) StdinPipe() (io.WriteCloser, error) {
|
|
||||||
if s.Stdin != nil {
|
|
||||||
return nil, errors.New("ssh: Stdin already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StdinPipe after process started")
|
|
||||||
}
|
|
||||||
s.stdinpipe = true
|
|
||||||
return &sessionStdin{s.ch, s.ch}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StdoutPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard output when the command starts.
|
|
||||||
// There is a fixed amount of buffering that is shared between
|
|
||||||
// stdout and stderr streams. If the StdoutPipe reader is
|
|
||||||
// not serviced fast enough it may eventually cause the
|
|
||||||
// remote command to block.
|
|
||||||
func (s *Session) StdoutPipe() (io.Reader, error) {
|
|
||||||
if s.Stdout != nil {
|
|
||||||
return nil, errors.New("ssh: Stdout already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StdoutPipe after process started")
|
|
||||||
}
|
|
||||||
s.stdoutpipe = true
|
|
||||||
return s.ch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StderrPipe returns a pipe that will be connected to the
|
|
||||||
// remote command's standard error when the command starts.
|
|
||||||
// There is a fixed amount of buffering that is shared between
|
|
||||||
// stdout and stderr streams. If the StderrPipe reader is
|
|
||||||
// not serviced fast enough it may eventually cause the
|
|
||||||
// remote command to block.
|
|
||||||
func (s *Session) StderrPipe() (io.Reader, error) {
|
|
||||||
if s.Stderr != nil {
|
|
||||||
return nil, errors.New("ssh: Stderr already set")
|
|
||||||
}
|
|
||||||
if s.started {
|
|
||||||
return nil, errors.New("ssh: StderrPipe after process started")
|
|
||||||
}
|
|
||||||
s.stderrpipe = true
|
|
||||||
return s.ch.Stderr(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// newSession returns a new interactive session on the remote host.
|
|
||||||
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
|
|
||||||
s := &Session{
|
|
||||||
ch: ch,
|
|
||||||
}
|
|
||||||
s.exitStatus = make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
s.exitStatus <- s.wait(reqs)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ExitError reports unsuccessful completion of a remote command.
|
|
||||||
type ExitError struct {
|
|
||||||
Waitmsg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ExitError) Error() string {
|
|
||||||
return e.Waitmsg.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Waitmsg stores the information about an exited remote command
|
|
||||||
// as reported by Wait.
|
|
||||||
type Waitmsg struct {
|
|
||||||
status int
|
|
||||||
signal string
|
|
||||||
msg string
|
|
||||||
lang string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExitStatus returns the exit status of the remote command.
|
|
||||||
func (w Waitmsg) ExitStatus() int {
|
|
||||||
return w.status
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signal returns the exit signal of the remote command if
|
|
||||||
// it was terminated violently.
|
|
||||||
func (w Waitmsg) Signal() string {
|
|
||||||
return w.signal
|
|
||||||
}
|
|
||||||
|
|
||||||
// Msg returns the exit message given by the remote command
|
|
||||||
func (w Waitmsg) Msg() string {
|
|
||||||
return w.msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lang returns the language tag. See RFC 3066
|
|
||||||
func (w Waitmsg) Lang() string {
|
|
||||||
return w.lang
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w Waitmsg) String() string {
|
|
||||||
str := fmt.Sprintf("Process exited with status %v", w.status)
|
|
||||||
if w.signal != "" {
|
|
||||||
str += fmt.Sprintf(" from signal %v", w.signal)
|
|
||||||
}
|
|
||||||
if w.msg != "" {
|
|
||||||
str += fmt.Sprintf(". Reason was: %v", w.msg)
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
@ -1,139 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/asn1"
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var krb5OID []byte
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
krb5OID, _ = asn1.Marshal(krb5Mesh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins.
|
|
||||||
type GSSAPIClient interface {
|
|
||||||
// InitSecContext initiates the establishment of a security context for GSS-API between the
|
|
||||||
// ssh client and ssh server. Initially the token parameter should be specified as nil.
|
|
||||||
// The routine may return a outputToken which should be transferred to
|
|
||||||
// the ssh server, where the ssh server will present it to
|
|
||||||
// AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting
|
|
||||||
// needContinue to false. To complete the context
|
|
||||||
// establishment, one or more reply tokens may be required from the ssh
|
|
||||||
// server;if so, InitSecContext will return a needContinue which is true.
|
|
||||||
// In this case, InitSecContext should be called again when the
|
|
||||||
// reply token is received from the ssh server, passing the reply
|
|
||||||
// token to InitSecContext via the token parameters.
|
|
||||||
// See RFC 2743 section 2.2.1 and RFC 4462 section 3.4.
|
|
||||||
InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error)
|
|
||||||
// GetMIC generates a cryptographic MIC for the SSH2 message, and places
|
|
||||||
// the MIC in a token for transfer to the ssh server.
|
|
||||||
// The contents of the MIC field are obtained by calling GSS_GetMIC()
|
|
||||||
// over the following, using the GSS-API context that was just
|
|
||||||
// established:
|
|
||||||
// string session identifier
|
|
||||||
// byte SSH_MSG_USERAUTH_REQUEST
|
|
||||||
// string user name
|
|
||||||
// string service
|
|
||||||
// string "gssapi-with-mic"
|
|
||||||
// See RFC 2743 section 2.3.1 and RFC 4462 3.5.
|
|
||||||
GetMIC(micFiled []byte) ([]byte, error)
|
|
||||||
// Whenever possible, it should be possible for
|
|
||||||
// DeleteSecContext() calls to be successfully processed even
|
|
||||||
// if other calls cannot succeed, thereby enabling context-related
|
|
||||||
// resources to be released.
|
|
||||||
// In addition to deleting established security contexts,
|
|
||||||
// gss_delete_sec_context must also be able to delete "half-built"
|
|
||||||
// security contexts resulting from an incomplete sequence of
|
|
||||||
// InitSecContext()/AcceptSecContext() calls.
|
|
||||||
// See RFC 2743 section 2.2.3.
|
|
||||||
DeleteSecContext() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins.
|
|
||||||
type GSSAPIServer interface {
|
|
||||||
// AcceptSecContext allows a remotely initiated security context between the application
|
|
||||||
// and a remote peer to be established by the ssh client. The routine may return a
|
|
||||||
// outputToken which should be transferred to the ssh client,
|
|
||||||
// where the ssh client will present it to InitSecContext.
|
|
||||||
// If no token need be sent, AcceptSecContext will indicate this
|
|
||||||
// by setting the needContinue to false. To
|
|
||||||
// complete the context establishment, one or more reply tokens may be
|
|
||||||
// required from the ssh client. if so, AcceptSecContext
|
|
||||||
// will return a needContinue which is true, in which case it
|
|
||||||
// should be called again when the reply token is received from the ssh
|
|
||||||
// client, passing the token to AcceptSecContext via the
|
|
||||||
// token parameters.
|
|
||||||
// The srcName return value is the authenticated username.
|
|
||||||
// See RFC 2743 section 2.2.2 and RFC 4462 section 3.4.
|
|
||||||
AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error)
|
|
||||||
// VerifyMIC verifies that a cryptographic MIC, contained in the token parameter,
|
|
||||||
// fits the supplied message is received from the ssh client.
|
|
||||||
// See RFC 2743 section 2.3.2.
|
|
||||||
VerifyMIC(micField []byte, micToken []byte) error
|
|
||||||
// Whenever possible, it should be possible for
|
|
||||||
// DeleteSecContext() calls to be successfully processed even
|
|
||||||
// if other calls cannot succeed, thereby enabling context-related
|
|
||||||
// resources to be released.
|
|
||||||
// In addition to deleting established security contexts,
|
|
||||||
// gss_delete_sec_context must also be able to delete "half-built"
|
|
||||||
// security contexts resulting from an incomplete sequence of
|
|
||||||
// InitSecContext()/AcceptSecContext() calls.
|
|
||||||
// See RFC 2743 section 2.2.3.
|
|
||||||
DeleteSecContext() error
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,
|
|
||||||
// so we also support the krb5 mechanism only.
|
|
||||||
// See RFC 1964 section 1.
|
|
||||||
krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2}
|
|
||||||
)
|
|
||||||
|
|
||||||
// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST
|
|
||||||
// See RFC 4462 section 3.2.
|
|
||||||
type userAuthRequestGSSAPI struct {
|
|
||||||
N uint32
|
|
||||||
OIDS []asn1.ObjectIdentifier
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) {
|
|
||||||
n, rest, ok := parseUint32(payload)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("parse uint32 failed")
|
|
||||||
}
|
|
||||||
s := &userAuthRequestGSSAPI{
|
|
||||||
N: n,
|
|
||||||
OIDS: make([]asn1.ObjectIdentifier, n),
|
|
||||||
}
|
|
||||||
for i := 0; i < int(n); i++ {
|
|
||||||
var (
|
|
||||||
desiredMech []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
desiredMech, rest, ok = parseString(rest)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("parse string failed")
|
|
||||||
}
|
|
||||||
if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4462 section 3.6.
|
|
||||||
func buildMIC(sessionID string, username string, service string, authMethod string) []byte {
|
|
||||||
out := make([]byte, 0, 0)
|
|
||||||
out = appendString(out, sessionID)
|
|
||||||
out = append(out, msgUserAuthRequest)
|
|
||||||
out = appendString(out, username)
|
|
||||||
out = appendString(out, service)
|
|
||||||
out = appendString(out, authMethod)
|
|
||||||
return out
|
|
||||||
}
|
|
@ -1,116 +0,0 @@
|
|||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
|
|
||||||
// with "direct-streamlocal@openssh.com" string.
|
|
||||||
//
|
|
||||||
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
|
|
||||||
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
|
|
||||||
type streamLocalChannelOpenDirectMsg struct {
|
|
||||||
socketPath string
|
|
||||||
reserved0 string
|
|
||||||
reserved1 uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
|
|
||||||
// with "forwarded-streamlocal@openssh.com" string.
|
|
||||||
type forwardedStreamLocalPayload struct {
|
|
||||||
SocketPath string
|
|
||||||
Reserved0 string
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
|
|
||||||
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
|
|
||||||
type streamLocalChannelForwardMsg struct {
|
|
||||||
socketPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
|
|
||||||
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
|
|
||||||
c.handleForwardsOnce.Do(c.handleForwards)
|
|
||||||
m := streamLocalChannelForwardMsg{
|
|
||||||
socketPath,
|
|
||||||
}
|
|
||||||
// send message
|
|
||||||
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
|
|
||||||
}
|
|
||||||
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
|
|
||||||
|
|
||||||
return &unixListener{socketPath, c, ch}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
|
|
||||||
msg := streamLocalChannelOpenDirectMsg{
|
|
||||||
socketPath: socketPath,
|
|
||||||
}
|
|
||||||
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(in)
|
|
||||||
return ch, err
|
|
||||||
}
|
|
||||||
|
|
||||||
type unixListener struct {
|
|
||||||
socketPath string
|
|
||||||
|
|
||||||
conn *Client
|
|
||||||
in <-chan forward
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept waits for and returns the next connection to the listener.
|
|
||||||
func (l *unixListener) Accept() (net.Conn, error) {
|
|
||||||
s, ok := <-l.in
|
|
||||||
if !ok {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
ch, incoming, err := s.newCh.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(incoming)
|
|
||||||
|
|
||||||
return &chanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: &net.UnixAddr{
|
|
||||||
Name: l.socketPath,
|
|
||||||
Net: "unix",
|
|
||||||
},
|
|
||||||
raddr: &net.UnixAddr{
|
|
||||||
Name: "@",
|
|
||||||
Net: "unix",
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the listener.
|
|
||||||
func (l *unixListener) Close() error {
|
|
||||||
// this also closes the listener.
|
|
||||||
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
|
|
||||||
m := streamLocalChannelForwardMsg{
|
|
||||||
l.socketPath,
|
|
||||||
}
|
|
||||||
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Addr returns the listener's network address.
|
|
||||||
func (l *unixListener) Addr() net.Addr {
|
|
||||||
return &net.UnixAddr{
|
|
||||||
Name: l.socketPath,
|
|
||||||
Net: "unix",
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,478 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Listen requests the remote peer open a listening socket on
|
|
||||||
// addr. Incoming connections will be available by calling Accept on
|
|
||||||
// the returned net.Listener. The listener must be serviced, or the
|
|
||||||
// SSH connection may hang.
|
|
||||||
// N must be "tcp", "tcp4", "tcp6", or "unix".
|
|
||||||
func (c *Client) Listen(n, addr string) (net.Listener, error) {
|
|
||||||
switch n {
|
|
||||||
case "tcp", "tcp4", "tcp6":
|
|
||||||
laddr, err := net.ResolveTCPAddr(n, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c.ListenTCP(laddr)
|
|
||||||
case "unix":
|
|
||||||
return c.ListenUnix(addr)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Automatic port allocation is broken with OpenSSH before 6.0. See
|
|
||||||
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
|
|
||||||
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
|
|
||||||
// rather than the actual port number. This means you can never open
|
|
||||||
// two different listeners with auto allocated ports. We work around
|
|
||||||
// this by trying explicit ports until we succeed.
|
|
||||||
|
|
||||||
const openSSHPrefix = "OpenSSH_"
|
|
||||||
|
|
||||||
var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
||||||
|
|
||||||
// isBrokenOpenSSHVersion returns true if the given version string
|
|
||||||
// specifies a version of OpenSSH that is known to have a bug in port
|
|
||||||
// forwarding.
|
|
||||||
func isBrokenOpenSSHVersion(versionStr string) bool {
|
|
||||||
i := strings.Index(versionStr, openSSHPrefix)
|
|
||||||
if i < 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
i += len(openSSHPrefix)
|
|
||||||
j := i
|
|
||||||
for ; j < len(versionStr); j++ {
|
|
||||||
if versionStr[j] < '0' || versionStr[j] > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
version, _ := strconv.Atoi(versionStr[i:j])
|
|
||||||
return version < 6
|
|
||||||
}
|
|
||||||
|
|
||||||
// autoPortListenWorkaround simulates automatic port allocation by
|
|
||||||
// trying random ports repeatedly.
|
|
||||||
func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) {
|
|
||||||
var sshListener net.Listener
|
|
||||||
var err error
|
|
||||||
const tries = 10
|
|
||||||
for i := 0; i < tries; i++ {
|
|
||||||
addr := *laddr
|
|
||||||
addr.Port = 1024 + portRandomizer.Intn(60000)
|
|
||||||
sshListener, err = c.ListenTCP(&addr)
|
|
||||||
if err == nil {
|
|
||||||
laddr.Port = addr.Port
|
|
||||||
return sshListener, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 7.1
|
|
||||||
type channelForwardMsg struct {
|
|
||||||
addr string
|
|
||||||
rport uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleForwards starts goroutines handling forwarded connections.
|
|
||||||
// It's called on first use by (*Client).ListenTCP to not launch
|
|
||||||
// goroutines until needed.
|
|
||||||
func (c *Client) handleForwards() {
|
|
||||||
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
|
|
||||||
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenTCP requests the remote peer open a listening socket
|
|
||||||
// on laddr. Incoming connections will be available by calling
|
|
||||||
// Accept on the returned net.Listener.
|
|
||||||
func (c *Client) ListenTCP(laddr *net.TCPAddr, fakeHost ...string) (net.Listener, error) {
|
|
||||||
c.handleForwardsOnce.Do(c.handleForwards)
|
|
||||||
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
|
|
||||||
return c.autoPortListenWorkaround(laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
host := laddr.IP.String()
|
|
||||||
if len(fakeHost) > 0 {
|
|
||||||
host = fakeHost[0]
|
|
||||||
}
|
|
||||||
m := channelForwardMsg{
|
|
||||||
host,
|
|
||||||
uint32(laddr.Port),
|
|
||||||
}
|
|
||||||
// send message
|
|
||||||
ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh: tcpip-forward request denied by peer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the original port was 0, then the remote side will
|
|
||||||
// supply a real port number in the response.
|
|
||||||
if laddr.Port == 0 {
|
|
||||||
var p struct {
|
|
||||||
Port uint32
|
|
||||||
}
|
|
||||||
if err := Unmarshal(resp, &p); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
laddr.Port = int(p.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register this forward, using the port number we obtained.
|
|
||||||
ch := c.forwards.add(laddr)
|
|
||||||
|
|
||||||
return &tcpListener{laddr, c, ch}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardList stores a mapping between remote
|
|
||||||
// forward requests and the tcpListeners.
|
|
||||||
type forwardList struct {
|
|
||||||
sync.Mutex
|
|
||||||
entries []forwardEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardEntry represents an established mapping of a laddr on a
|
|
||||||
// remote ssh server to a channel connected to a tcpListener.
|
|
||||||
type forwardEntry struct {
|
|
||||||
laddr net.Addr
|
|
||||||
c chan forward
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward represents an incoming forwarded tcpip connection. The
|
|
||||||
// arguments to add/remove/lookup should be address as specified in
|
|
||||||
// the original forward-request.
|
|
||||||
type forward struct {
|
|
||||||
newCh NewChannel // the ssh client channel underlying this forward
|
|
||||||
raddr net.Addr // the raddr of the incoming connection
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) add(addr net.Addr) chan forward {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
f := forwardEntry{
|
|
||||||
laddr: addr,
|
|
||||||
c: make(chan forward, 1),
|
|
||||||
}
|
|
||||||
l.entries = append(l.entries, f)
|
|
||||||
return f.c
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 4254, section 7.2
|
|
||||||
type forwardedTCPPayload struct {
|
|
||||||
Addr string
|
|
||||||
Port uint32
|
|
||||||
OriginAddr string
|
|
||||||
OriginPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
|
|
||||||
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
|
|
||||||
if port == 0 || port > 65535 {
|
|
||||||
return nil, fmt.Errorf("ssh: port number out of range: %d", port)
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(string(addr))
|
|
||||||
if ip == nil {
|
|
||||||
return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
|
|
||||||
}
|
|
||||||
return &net.TCPAddr{IP: ip, Port: int(port)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) handleChannels(in <-chan NewChannel) {
|
|
||||||
for ch := range in {
|
|
||||||
var (
|
|
||||||
laddr net.Addr
|
|
||||||
raddr net.Addr
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
switch channelType := ch.ChannelType(); channelType {
|
|
||||||
case "forwarded-tcpip":
|
|
||||||
var payload forwardedTCPPayload
|
|
||||||
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 section 7.2 specifies that incoming
|
|
||||||
// addresses should list the address, in string
|
|
||||||
// format. It is implied that this should be an IP
|
|
||||||
// address, as it would be impossible to connect to it
|
|
||||||
// otherwise.
|
|
||||||
laddr, err = parseTCPAddr(payload.Addr, payload.Port)
|
|
||||||
if err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
|
|
||||||
if err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
case "forwarded-streamlocal@openssh.com":
|
|
||||||
var payload forwardedStreamLocalPayload
|
|
||||||
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
|
|
||||||
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
laddr = &net.UnixAddr{
|
|
||||||
Name: payload.SocketPath,
|
|
||||||
Net: "unix",
|
|
||||||
}
|
|
||||||
raddr = &net.UnixAddr{
|
|
||||||
Name: "@",
|
|
||||||
Net: "unix",
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
|
|
||||||
}
|
|
||||||
if ok := l.forward(laddr, raddr, ch); !ok {
|
|
||||||
// Section 7.2, implementations MUST reject spurious incoming
|
|
||||||
// connections.
|
|
||||||
ch.Reject(Prohibited, "no forward for address")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove removes the forward entry, and the channel feeding its
|
|
||||||
// listener.
|
|
||||||
func (l *forwardList) remove(addr net.Addr) {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for i, f := range l.entries {
|
|
||||||
if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
|
|
||||||
l.entries = append(l.entries[:i], l.entries[i+1:]...)
|
|
||||||
close(f.c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeAll closes and clears all forwards.
|
|
||||||
func (l *forwardList) closeAll() {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for _, f := range l.entries {
|
|
||||||
close(f.c)
|
|
||||||
}
|
|
||||||
l.entries = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
|
|
||||||
l.Lock()
|
|
||||||
defer l.Unlock()
|
|
||||||
for _, f := range l.entries {
|
|
||||||
if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
|
|
||||||
f.c <- forward{newCh: ch, raddr: raddr}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpListener struct {
|
|
||||||
laddr *net.TCPAddr
|
|
||||||
|
|
||||||
conn *Client
|
|
||||||
in <-chan forward
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accept waits for and returns the next connection to the listener.
|
|
||||||
func (l *tcpListener) Accept() (net.Conn, error) {
|
|
||||||
s, ok := <-l.in
|
|
||||||
if !ok {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
ch, incoming, err := s.newCh.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(incoming)
|
|
||||||
|
|
||||||
return &chanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: l.laddr,
|
|
||||||
raddr: s.raddr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the listener.
|
|
||||||
func (l *tcpListener) Close() error {
|
|
||||||
m := channelForwardMsg{
|
|
||||||
l.laddr.IP.String(),
|
|
||||||
uint32(l.laddr.Port),
|
|
||||||
}
|
|
||||||
|
|
||||||
// this also closes the listener.
|
|
||||||
l.conn.forwards.remove(l.laddr)
|
|
||||||
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
|
|
||||||
if err == nil && !ok {
|
|
||||||
err = errors.New("ssh: cancel-tcpip-forward failed")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Addr returns the listener's network address.
|
|
||||||
func (l *tcpListener) Addr() net.Addr {
|
|
||||||
return l.laddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial initiates a connection to the addr from the remote host.
|
|
||||||
// The resulting connection has a zero LocalAddr() and RemoteAddr().
|
|
||||||
func (c *Client) Dial(n, addr string) (net.Conn, error) {
|
|
||||||
var ch Channel
|
|
||||||
switch n {
|
|
||||||
case "tcp", "tcp4", "tcp6":
|
|
||||||
// Parse the address into host and numeric port.
|
|
||||||
host, portString, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
port, err := strconv.ParseUint(portString, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Use a zero address for local and remote address.
|
|
||||||
zeroAddr := &net.TCPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
return &chanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: zeroAddr,
|
|
||||||
raddr: zeroAddr,
|
|
||||||
}, nil
|
|
||||||
case "unix":
|
|
||||||
var err error
|
|
||||||
ch, err = c.dialStreamLocal(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &chanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: &net.UnixAddr{
|
|
||||||
Name: "@",
|
|
||||||
Net: "unix",
|
|
||||||
},
|
|
||||||
raddr: &net.UnixAddr{
|
|
||||||
Name: addr,
|
|
||||||
Net: "unix",
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTCP connects to the remote address raddr on the network net,
|
|
||||||
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used
|
|
||||||
// as the local address for the connection.
|
|
||||||
func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) {
|
|
||||||
if laddr == nil {
|
|
||||||
laddr = &net.TCPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &chanConn{
|
|
||||||
Channel: ch,
|
|
||||||
laddr: laddr,
|
|
||||||
raddr: raddr,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RFC 4254 7.2
|
|
||||||
type channelOpenDirectMsg struct {
|
|
||||||
raddr string
|
|
||||||
rport uint32
|
|
||||||
laddr string
|
|
||||||
lport uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
|
|
||||||
msg := channelOpenDirectMsg{
|
|
||||||
raddr: raddr,
|
|
||||||
rport: uint32(rport),
|
|
||||||
laddr: laddr,
|
|
||||||
lport: uint32(lport),
|
|
||||||
}
|
|
||||||
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go DiscardRequests(in)
|
|
||||||
return ch, err
|
|
||||||
}
|
|
||||||
|
|
||||||
type tcpChan struct {
|
|
||||||
Channel // the backing channel
|
|
||||||
}
|
|
||||||
|
|
||||||
// chanConn fulfills the net.Conn interface without
|
|
||||||
// the tcpChan having to hold laddr or raddr directly.
|
|
||||||
type chanConn struct {
|
|
||||||
Channel
|
|
||||||
laddr, raddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalAddr returns the local network address.
|
|
||||||
func (t *chanConn) LocalAddr() net.Addr {
|
|
||||||
return t.laddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
|
||||||
func (t *chanConn) RemoteAddr() net.Addr {
|
|
||||||
return t.raddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDeadline sets the read and write deadlines associated
|
|
||||||
// with the connection.
|
|
||||||
func (t *chanConn) SetDeadline(deadline time.Time) error {
|
|
||||||
if err := t.SetReadDeadline(deadline); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return t.SetWriteDeadline(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetReadDeadline sets the read deadline.
|
|
||||||
// A zero value for t means Read will not time out.
|
|
||||||
// After the deadline, the error from Read will implement net.Error
|
|
||||||
// with Timeout() == true.
|
|
||||||
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
|
|
||||||
// for compatibility with previous version,
|
|
||||||
// the error message contains "tcpChan"
|
|
||||||
return errors.New("ssh: tcpChan: deadline not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWriteDeadline exists to satisfy the net.Conn interface
|
|
||||||
// but is not implemented by this type. It always returns an error.
|
|
||||||
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
|
|
||||||
return errors.New("ssh: tcpChan: deadline not supported")
|
|
||||||
}
|
|
@ -1,76 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package terminal provides support functions for dealing with terminals, as
|
|
||||||
// commonly found on UNIX systems.
|
|
||||||
//
|
|
||||||
// Deprecated: this package moved to golang.org/x/term.
|
|
||||||
package terminal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// EscapeCodes contains escape sequences that can be written to the terminal in
|
|
||||||
// order to achieve different styles of text.
|
|
||||||
type EscapeCodes = term.EscapeCodes
|
|
||||||
|
|
||||||
// Terminal contains the state for running a VT100 terminal that is capable of
|
|
||||||
// reading lines of input.
|
|
||||||
type Terminal = term.Terminal
|
|
||||||
|
|
||||||
// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is
|
|
||||||
// a local terminal, that terminal must first have been put into raw mode.
|
|
||||||
// prompt is a string that is written at the start of each input line (i.e.
|
|
||||||
// "> ").
|
|
||||||
func NewTerminal(c io.ReadWriter, prompt string) *Terminal {
|
|
||||||
return term.NewTerminal(c, prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrPasteIndicator may be returned from ReadLine as the error, in addition
|
|
||||||
// to valid line data. It indicates that bracketed paste mode is enabled and
|
|
||||||
// that the returned line consists only of pasted data. Programs may wish to
|
|
||||||
// interpret pasted data more literally than typed data.
|
|
||||||
var ErrPasteIndicator = term.ErrPasteIndicator
|
|
||||||
|
|
||||||
// State contains the state of a terminal.
|
|
||||||
type State = term.State
|
|
||||||
|
|
||||||
// IsTerminal returns whether the given file descriptor is a terminal.
|
|
||||||
func IsTerminal(fd int) bool {
|
|
||||||
return term.IsTerminal(fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadPassword reads a line of input from a terminal without local echo. This
|
|
||||||
// is commonly used for inputting passwords and other sensitive data. The slice
|
|
||||||
// returned does not include the \n.
|
|
||||||
func ReadPassword(fd int) ([]byte, error) {
|
|
||||||
return term.ReadPassword(fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeRaw puts the terminal connected to the given file descriptor into raw
|
|
||||||
// mode and returns the previous state of the terminal so that it can be
|
|
||||||
// restored.
|
|
||||||
func MakeRaw(fd int) (*State, error) {
|
|
||||||
return term.MakeRaw(fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restore restores the terminal connected to the given file descriptor to a
|
|
||||||
// previous state.
|
|
||||||
func Restore(fd int, oldState *State) error {
|
|
||||||
return term.Restore(fd, oldState)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetState returns the current state of a terminal which may be useful to
|
|
||||||
// restore the terminal after a signal.
|
|
||||||
func GetState(fd int) (*State, error) {
|
|
||||||
return term.GetState(fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSize returns the dimensions of the given terminal.
|
|
||||||
func GetSize(fd int) (width, height int, err error) {
|
|
||||||
return term.GetSize(fd)
|
|
||||||
}
|
|
@ -1,358 +0,0 @@
|
|||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// debugTransport if set, will print packet types as they go over the
|
|
||||||
// wire. No message decoding is done, to minimize the impact on timing.
|
|
||||||
const debugTransport = false
|
|
||||||
|
|
||||||
const (
|
|
||||||
gcm128CipherID = "aes128-gcm@openssh.com"
|
|
||||||
gcm256CipherID = "aes256-gcm@openssh.com"
|
|
||||||
aes128cbcID = "aes128-cbc"
|
|
||||||
tripledescbcID = "3des-cbc"
|
|
||||||
)
|
|
||||||
|
|
||||||
// packetConn represents a transport that implements packet based
|
|
||||||
// operations.
|
|
||||||
type packetConn interface {
|
|
||||||
// Encrypt and send a packet of data to the remote peer.
|
|
||||||
writePacket(packet []byte) error
|
|
||||||
|
|
||||||
// Read a packet from the connection. The read is blocking,
|
|
||||||
// i.e. if error is nil, then the returned byte slice is
|
|
||||||
// always non-empty.
|
|
||||||
readPacket() ([]byte, error)
|
|
||||||
|
|
||||||
// Close closes the write-side of the connection.
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// transport is the keyingTransport that implements the SSH packet
|
|
||||||
// protocol.
|
|
||||||
type transport struct {
|
|
||||||
reader connectionState
|
|
||||||
writer connectionState
|
|
||||||
|
|
||||||
bufReader *bufio.Reader
|
|
||||||
bufWriter *bufio.Writer
|
|
||||||
rand io.Reader
|
|
||||||
isClient bool
|
|
||||||
io.Closer
|
|
||||||
}
|
|
||||||
|
|
||||||
// packetCipher represents a combination of SSH encryption/MAC
|
|
||||||
// protocol. A single instance should be used for one direction only.
|
|
||||||
type packetCipher interface {
|
|
||||||
// writeCipherPacket encrypts the packet and writes it to w. The
|
|
||||||
// contents of the packet are generally scrambled.
|
|
||||||
writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
|
|
||||||
|
|
||||||
// readCipherPacket reads and decrypts a packet of data. The
|
|
||||||
// returned packet may be overwritten by future calls of
|
|
||||||
// readPacket.
|
|
||||||
readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// connectionState represents one side (read or write) of the
|
|
||||||
// connection. This is necessary because each direction has its own
|
|
||||||
// keys, and can even have its own algorithms
|
|
||||||
type connectionState struct {
|
|
||||||
packetCipher
|
|
||||||
seqNum uint32
|
|
||||||
dir direction
|
|
||||||
pendingKeyChange chan packetCipher
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareKeyChange sets up key material for a keychange. The key changes in
|
|
||||||
// both directions are triggered by reading and writing a msgNewKey packet
|
|
||||||
// respectively.
|
|
||||||
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
|
|
||||||
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.reader.pendingKeyChange <- ciph
|
|
||||||
|
|
||||||
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
t.writer.pendingKeyChange <- ciph
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transport) printPacket(p []byte, write bool) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
who := "server"
|
|
||||||
if t.isClient {
|
|
||||||
who = "client"
|
|
||||||
}
|
|
||||||
what := "read"
|
|
||||||
if write {
|
|
||||||
what = "write"
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println(what, who, p[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and decrypt next packet.
|
|
||||||
func (t *transport) readPacket() (p []byte, err error) {
|
|
||||||
for {
|
|
||||||
p, err = t.reader.readPacket(t.bufReader)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if debugTransport {
|
|
||||||
t.printPacket(p, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
return p, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
|
|
||||||
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
|
|
||||||
s.seqNum++
|
|
||||||
if err == nil && len(packet) == 0 {
|
|
||||||
err = errors.New("ssh: zero length packet")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(packet) > 0 {
|
|
||||||
switch packet[0] {
|
|
||||||
case msgNewKeys:
|
|
||||||
select {
|
|
||||||
case cipher := <-s.pendingKeyChange:
|
|
||||||
s.packetCipher = cipher
|
|
||||||
default:
|
|
||||||
return nil, errors.New("ssh: got bogus newkeys message")
|
|
||||||
}
|
|
||||||
|
|
||||||
case msgDisconnect:
|
|
||||||
// Transform a disconnect message into an
|
|
||||||
// error. Since this is lowest level at which
|
|
||||||
// we interpret message types, doing it here
|
|
||||||
// ensures that we don't have to handle it
|
|
||||||
// elsewhere.
|
|
||||||
var msg disconnectMsg
|
|
||||||
if err := Unmarshal(packet, &msg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return nil, &msg
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The packet may point to an internal buffer, so copy the
|
|
||||||
// packet out here.
|
|
||||||
fresh := make([]byte, len(packet))
|
|
||||||
copy(fresh, packet)
|
|
||||||
|
|
||||||
return fresh, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *transport) writePacket(packet []byte) error {
|
|
||||||
if debugTransport {
|
|
||||||
t.printPacket(packet, true)
|
|
||||||
}
|
|
||||||
return t.writer.writePacket(t.bufWriter, t.rand, packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
|
|
||||||
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
|
|
||||||
|
|
||||||
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = w.Flush(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.seqNum++
|
|
||||||
if changeKeys {
|
|
||||||
select {
|
|
||||||
case cipher := <-s.pendingKeyChange:
|
|
||||||
s.packetCipher = cipher
|
|
||||||
default:
|
|
||||||
panic("ssh: no key material for msgNewKeys")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
|
|
||||||
t := &transport{
|
|
||||||
bufReader: bufio.NewReader(rwc),
|
|
||||||
bufWriter: bufio.NewWriter(rwc),
|
|
||||||
rand: rand,
|
|
||||||
reader: connectionState{
|
|
||||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
|
||||||
pendingKeyChange: make(chan packetCipher, 1),
|
|
||||||
},
|
|
||||||
writer: connectionState{
|
|
||||||
packetCipher: &streamPacketCipher{cipher: noneCipher{}},
|
|
||||||
pendingKeyChange: make(chan packetCipher, 1),
|
|
||||||
},
|
|
||||||
Closer: rwc,
|
|
||||||
}
|
|
||||||
t.isClient = isClient
|
|
||||||
|
|
||||||
if isClient {
|
|
||||||
t.reader.dir = serverKeys
|
|
||||||
t.writer.dir = clientKeys
|
|
||||||
} else {
|
|
||||||
t.reader.dir = clientKeys
|
|
||||||
t.writer.dir = serverKeys
|
|
||||||
}
|
|
||||||
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
type direction struct {
|
|
||||||
ivTag []byte
|
|
||||||
keyTag []byte
|
|
||||||
macKeyTag []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
|
|
||||||
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
|
|
||||||
// described in RFC 4253, section 6.4. direction should either be serverKeys
|
|
||||||
// (to setup server->client keys) or clientKeys (for client->server keys).
|
|
||||||
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
|
|
||||||
cipherMode := cipherModes[algs.Cipher]
|
|
||||||
|
|
||||||
iv := make([]byte, cipherMode.ivSize)
|
|
||||||
key := make([]byte, cipherMode.keySize)
|
|
||||||
|
|
||||||
generateKeyMaterial(iv, d.ivTag, kex)
|
|
||||||
generateKeyMaterial(key, d.keyTag, kex)
|
|
||||||
|
|
||||||
var macKey []byte
|
|
||||||
if !aeadCiphers[algs.Cipher] {
|
|
||||||
macMode := macModes[algs.MAC]
|
|
||||||
macKey = make([]byte, macMode.keySize)
|
|
||||||
generateKeyMaterial(macKey, d.macKeyTag, kex)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKeyMaterial fills out with key material generated from tag, K, H
|
|
||||||
// and sessionId, as specified in RFC 4253, section 7.2.
|
|
||||||
func generateKeyMaterial(out, tag []byte, r *kexResult) {
|
|
||||||
var digestsSoFar []byte
|
|
||||||
|
|
||||||
h := r.Hash.New()
|
|
||||||
for len(out) > 0 {
|
|
||||||
h.Reset()
|
|
||||||
h.Write(r.K)
|
|
||||||
h.Write(r.H)
|
|
||||||
|
|
||||||
if len(digestsSoFar) == 0 {
|
|
||||||
h.Write(tag)
|
|
||||||
h.Write(r.SessionID)
|
|
||||||
} else {
|
|
||||||
h.Write(digestsSoFar)
|
|
||||||
}
|
|
||||||
|
|
||||||
digest := h.Sum(nil)
|
|
||||||
n := copy(out, digest)
|
|
||||||
out = out[n:]
|
|
||||||
if len(out) > 0 {
|
|
||||||
digestsSoFar = append(digestsSoFar, digest...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const packageVersion = "SSH-2.0-Go"
|
|
||||||
|
|
||||||
// Sends and receives a version line. The versionLine string should
|
|
||||||
// be US ASCII, start with "SSH-2.0-", and should not include a
|
|
||||||
// newline. exchangeVersions returns the other side's version line.
|
|
||||||
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
|
|
||||||
// Contrary to the RFC, we do not ignore lines that don't
|
|
||||||
// start with "SSH-2.0-" to make the library usable with
|
|
||||||
// nonconforming servers.
|
|
||||||
for _, c := range versionLine {
|
|
||||||
// The spec disallows non US-ASCII chars, and
|
|
||||||
// specifically forbids null chars.
|
|
||||||
if c < 32 {
|
|
||||||
return nil, errors.New("ssh: junk character in version line")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
them, err = readVersion(rw)
|
|
||||||
return them, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// maxVersionStringBytes is the maximum number of bytes that we'll
|
|
||||||
// accept as a version string. RFC 4253 section 4.2 limits this at 255
|
|
||||||
// chars
|
|
||||||
const maxVersionStringBytes = 255
|
|
||||||
|
|
||||||
// Read version string as specified by RFC 4253, section 4.2.
|
|
||||||
func readVersion(r io.Reader) ([]byte, error) {
|
|
||||||
versionString := make([]byte, 0, 64)
|
|
||||||
var ok bool
|
|
||||||
var buf [1]byte
|
|
||||||
|
|
||||||
for length := 0; length < maxVersionStringBytes; length++ {
|
|
||||||
_, err := io.ReadFull(r, buf[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// The RFC says that the version should be terminated with \r\n
|
|
||||||
// but several SSH servers actually only send a \n.
|
|
||||||
if buf[0] == '\n' {
|
|
||||||
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
|
|
||||||
// RFC 4253 says we need to ignore all version string lines
|
|
||||||
// except the one containing the SSH version (provided that
|
|
||||||
// all the lines do not exceed 255 bytes in total).
|
|
||||||
versionString = versionString[:0]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ok = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// non ASCII chars are disallowed, but we are lenient,
|
|
||||||
// since Go doesn't use null-terminated strings.
|
|
||||||
|
|
||||||
// The RFC allows a comment after a space, however,
|
|
||||||
// all of it (version and comments) goes into the
|
|
||||||
// session hash.
|
|
||||||
versionString = append(versionString, buf[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh: overflow reading version string")
|
|
||||||
}
|
|
||||||
|
|
||||||
// There might be a '\r' on the end which we should remove.
|
|
||||||
if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
|
|
||||||
versionString = versionString[:len(versionString)-1]
|
|
||||||
}
|
|
||||||
return versionString, nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user