diff --git a/.gitignore b/.gitignore index 175eeab..def3c33 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,10 @@ vendor/ # Build directory build/ +# Protobuf Generated Code +*.pb.go + # Go workspace file go.work .idea +config.yml diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..36ac5f9 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +protoc 25.1 diff --git a/Makefile b/Makefile index 370ae85..a7073df 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ GO_VERSION=$(shell go version | sed -e 's/go version //') .PHONY: run clean_backend -build: deps fmt +build: generate deps fmt @echo " ► Building with ${GO_VERSION}" @CGO_ENABLED=0 go build -tags=release -o $(BIN) . @echo $(BIN) @@ -31,3 +31,11 @@ run: clean: @rm -rf $(BUILD_DIR) + +generate: + @echo " ► Performing code generation" + @cd $(ROOT_DIR)/pkg/plugin && go generate + +install_protobuf: + @go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.28 + @go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.2 diff --git a/cmd/root.go b/cmd/root.go index f9d4d90..c6c70e4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,11 +5,14 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/docker" "github.com/Neur0toxine/sshpoke/internal/logger" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/plugin" + "github.com/Neur0toxine/sshpoke/internal/server" + "github.com/go-playground/validator/v10" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -21,20 +24,24 @@ var rootCmd = &cobra.Command{ Short: "Expose your Docker services to the Internet via SSH.", Long: `sshpoke is a CLI application that listens to the docker socket and automatically exposes relevant services to the Internet.`, Run: func(cmd *cobra.Command, args []string) { + go plugin.StartAPIServer() var err error ctx, cancel := context.WithCancel(context.Background()) + server.DefaultManager = server.NewManager(ctx, config.Default.Servers, config.Default.DefaultServer) docker.Default, err = docker.New(ctx) if err != nil { logger.Sugar.Fatalf("cannot connect to docker daemon: %s", err) } for id, item := range docker.Default.Containers() { - logger.Sugar.Debugw("registering container", - "id", id, - "ip", item.IP.String(), - "port", item.Port, - "server", item.Server, - "domain", item.Domain) + err := server.DefaultManager.ProcessEvent(model.Event{ + Type: model.EventStart, + ID: id, + Container: item, + }) + if err != nil { + logger.Sugar.Errorw("cannot expose container", "id", id, "error", err) + } } events, err := docker.Default.Listen() @@ -45,7 +52,10 @@ var rootCmd = &cobra.Command{ go func() { logger.Sugar.Debug("listening for docker events...") for event := range events { - _ = event + err := server.DefaultManager.ProcessEvent(event) + if err != nil { + logger.Sugar.Errorw("cannot expose container", "id", event.ID, "error", err) + } } }() @@ -55,7 +65,7 @@ var rootCmd = &cobra.Command{ switch sig { case os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM: cancel() - time.Sleep(time.Millisecond * 200) + server.DefaultManager.WaitForShutdown() logger.Sugar.Infof("received %s, exiting...", sig) os.Exit(0) default: @@ -90,14 +100,18 @@ func initConfig() { log := logger.New(os.Getenv("SSHPOKE_DEBUG") == "true").Sugar() viper.SetEnvPrefix("SSHPOKE") viper.AutomaticEnv() - if err := config.BindStructEnv(&config.DefaultConfig); err != nil { + if err := config.BindStructEnv(&config.Default); err != nil { log.Fatalf("cannot bind configuration keys: %s", err) } if err := viper.ReadInConfig(); err == nil { log.Debugf("using config file: %s", viper.ConfigFileUsed()) } - if err := viper.Unmarshal(&config.DefaultConfig); err != nil { + if err := viper.Unmarshal(&config.Default); err != nil { log.Fatalf("cannot load configuration: %s", err) } + if err := validator.New().Struct(config.Default); err != nil { + log.Fatalf("invalid configuration: %s", err) + } logger.Initialize() + logger.Sugar.Debugw("configuration loaded", "config", config.Default) } diff --git a/go.mod b/go.mod index 5109fc7..f9d5946 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,15 @@ go 1.21.4 require ( github.com/docker/docker v24.0.7+incompatible github.com/docker/go-connections v0.4.0 + github.com/go-playground/validator/v10 v10.16.0 github.com/mitchellh/mapstructure v1.5.0 github.com/spf13/cast v1.5.1 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.17.0 go.uber.org/zap v1.26.0 + golang.org/x/crypto v0.13.0 + google.golang.org/grpc v1.58.2 + google.golang.org/protobuf v1.31.0 ) require ( @@ -18,9 +22,14 @@ require ( github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/go-units v0.5.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/leodido/go-urn v1.2.4 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect @@ -41,6 +50,7 @@ require ( golang.org/x/sys v0.12.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/v3 v3.5.1 // indirect diff --git a/go.sum b/go.sum index 03cdfb6..c6e67c8 100644 --- a/go.sum +++ b/go.sum @@ -75,9 +75,19 @@ github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0X github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= +github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -105,6 +115,9 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -116,6 +129,7 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -156,6 +170,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= @@ -205,6 +221,7 @@ github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= @@ -232,6 +249,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -363,6 +382,8 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -495,6 +516,8 @@ google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 h1:N3bU/SQDCDyD6R528GJ/PwW9KjYcJA3dgyH+MovAkIM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13/go.mod h1:KSqppvjFjtoCI+KGd4PELB0qLNxdJHRGqRI09mB6pQA= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -511,6 +534,8 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.58.2 h1:SXUpjxeVF3FKrTYQI4f4KvbGD5u2xccdYdurwowix5I= +google.golang.org/grpc v1.58.2/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -521,6 +546,10 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= diff --git a/internal/config/model.go b/internal/config/model.go index 17290dd..138b84d 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -8,23 +8,42 @@ import ( "github.com/docker/go-connections/tlsconfig" ) -var DefaultConfig Config +var Default Config type Config struct { - Debug bool `mapstructure:"debug"` - Docker DockerConfig `mapstructure:"docker"` + Debug bool `mapstructure:"debug"` + PluginAPIPort int `mapstructure:"plugin_api_port" validate:"gte=0,lte=65535"` + Docker DockerConfig `mapstructure:"docker"` + DefaultServer string `mapstructure:"default_server"` + Servers []Server `mapstructure:"servers"` } type DockerConfig struct { - FromEnv bool `mapstructure:"from_env"` + FromEnv *bool `mapstructure:"from_env,omitempty"` CertPath string `mapstructure:"cert_path"` TLSVerify *bool `mapstructure:"tls_verify,omitempty"` Host string `mapstructure:"host"` Version string `mapstructure:"version"` } +type DriverParams map[string]interface{} + +type DriverType string + +const ( + DriverSSH DriverType = "ssh" + DriverPlugin DriverType = "plugin" + DriverNull DriverType = "null" +) + +type Server struct { + Name string `mapstructure:"name" validate:"required"` + Driver DriverType `mapstructure:"driver"` + Params DriverParams `mapstructure:"params"` +} + func (d DockerConfig) Opts(c *client.Client) error { - if d.FromEnv { + if d.FromEnv == nil || *d.FromEnv { return client.FromEnv(c) } ops := []client.Opt{ diff --git a/internal/docker/api.go b/internal/docker/api.go index aa36ec1..027422e 100644 --- a/internal/docker/api.go +++ b/internal/docker/api.go @@ -21,7 +21,7 @@ type Docker struct { } func New(ctx context.Context) (*Docker, error) { - cli, err := client.NewClientWithOpts(config.DefaultConfig.Docker.Opts) + cli, err := client.NewClientWithOpts(config.Default.Docker.Opts) if err != nil { return nil, err } @@ -50,13 +50,13 @@ func (d *Docker) Containers() map[string]model.Container { return containers } -func (d *Docker) Listen() (chan model.ContainerEvent, error) { - cli, err := client.NewClientWithOpts(config.DefaultConfig.Docker.Opts) +func (d *Docker) Listen() (chan model.Event, error) { + cli, err := client.NewClientWithOpts(config.Default.Docker.Opts) if err != nil { return nil, err } - output := make(chan model.ContainerEvent) + output := make(chan model.Event) go func() { for { eventSource, errSource := cli.Events(d.ctx, types.EventsOptions{ @@ -68,17 +68,9 @@ func (d *Docker) Listen() (chan model.ContainerEvent, error) { if (eventType != model.EventStart && eventType != model.EventStop) || !actorEnabled(event.Actor) { continue } - if eventType == model.EventStop { - logger.Sugar.Debugw("stopping session", - "type", event.Action, "container.id", event.Actor.ID) - output <- model.ContainerEvent{ - Type: eventType, - ID: event.Actor.ID, - } - continue - } container, err := d.cli.ContainerList(d.ctx, types.ContainerListOptions{ Filters: filters.NewArgs(filters.Arg("id", event.Actor.ID)), + All: true, }) if err != nil || len(container) != 1 { logger.Sugar.Errorw("cannot get container info", @@ -89,17 +81,22 @@ func (d *Docker) Listen() (chan model.ContainerEvent, error) { if !ok { continue } - newEvent := model.ContainerEvent{ + newEvent := model.Event{ Type: eventType, ID: event.Actor.ID, Container: converted, } - logger.Sugar.Debugw("exposing container", + msg := "exposing container" + if eventType == model.EventStop { + msg = "stopping container" + } + logger.Sugar.Debugw(msg, "type", event.Action, "container.id", event.Actor.ID, "container.ip", converted.IP.String(), "container.port", converted.Port, "container.server", converted.Server, + "container.prefix", converted.Prefix, "container.domain", converted.Domain) output <- newEvent case err := <-errSource: diff --git a/internal/docker/convert.go b/internal/docker/convert.go index fe804e7..19be83f 100644 --- a/internal/docker/convert.go +++ b/internal/docker/convert.go @@ -17,6 +17,7 @@ type labelsConfig struct { Network string `mapstructure:"sshpoke.network"` Server string `mapstructure:"sshpoke.server"` Port string `mapstructure:"sshpoke.port"` + Prefix string `mapstructure:"sshpoke.prefix"` Domain string `mapstructure:"sshpoke.domain"` } @@ -80,6 +81,7 @@ func dockerContainerToInternal(container types.Container) (result model.Containe IP: ip, Port: uint16(port), Server: labels.Server, + Prefix: labels.Prefix, Domain: labels.Domain, }, true } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 47f3929..d297aaf 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -11,17 +11,18 @@ var ( ) func Initialize() { - Default = New(config.DefaultConfig.Debug) + Default = New(config.Default.Debug) Sugar = Default.Sugar() } func New(debug bool) *zap.Logger { - if debug { - logger, _ := zap.NewDevelopment() - return logger - } zapConfig := zap.NewProductionConfig() + if debug { + zapConfig = zap.NewDevelopmentConfig() + } zapConfig.Encoding = "console" + zapConfig.EncoderConfig.CallerKey = "" + zapConfig.EncoderConfig.EncodeCaller = nil logger, _ := zapConfig.Build() return logger } diff --git a/internal/model/container.go b/internal/model/event.go similarity index 82% rename from internal/model/container.go rename to internal/model/event.go index 43fedf1..9e1b4b7 100644 --- a/internal/model/container.go +++ b/internal/model/event.go @@ -21,15 +21,21 @@ func TypeFromAction(action string) EventType { } } -type ContainerEvent struct { +type Event struct { Type EventType ID string Container Container } +type EventRequest struct { + ID string + Error string +} + type Container struct { IP net.IP Port uint16 Server string + Prefix string Domain string } diff --git a/internal/plugin/server.go b/internal/plugin/server.go new file mode 100644 index 0000000..634e0c1 --- /dev/null +++ b/internal/plugin/server.go @@ -0,0 +1,81 @@ +package plugin + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/logger" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/server" + "github.com/Neur0toxine/sshpoke/internal/server/driver/plugin" + pb "github.com/Neur0toxine/sshpoke/pkg/plugin" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/emptypb" +) + +var ErrUnauthorized = errors.New("unauthorized") + +type pluginAPI struct { + pb.UnimplementedPluginServiceServer +} + +func (p *pluginAPI) Event(stream pb.PluginService_EventServer) error { + pl := p.receiverForContext(stream.Context()) + if pl == nil { + return ErrUnauthorized + } + logger.Sugar.Debugw("attached plugin event stream", "serverName", pl.Name()) + err := pl.Listen(stream.Context(), &Stream{stream: stream}) + if err != nil { + logger.Sugar.Debugw("detached plugin event stream", "serverName", pl.Name(), "error", err) + return err + } + logger.Sugar.Debugw("detached plugin event stream", "serverName", pl.Name()) + return nil +} + +func (p *pluginAPI) EventStatus(ctx context.Context, msg *pb.EventStatusMessage) (*emptypb.Empty, error) { + pl := p.receiverForContext(ctx) + if pl == nil { + return nil, ErrUnauthorized + } + pl.HandleStatus(model.EventRequest{ + ID: msg.Id, + Error: msg.Error, + }) + return &emptypb.Empty{}, nil +} + +func (p *pluginAPI) receiverForContext(ctx context.Context) *plugin.Plugin { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil + } + tokens := md.Get("token") + if len(tokens) != 1 { + return nil + } + return server.DefaultManager.PluginByToken(tokens[0]) +} + +func StartAPIServer() { + port := config.Default.PluginAPIPort + if port == 0 { + port = 3000 + } + socket, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + logger.Sugar.Errorf("cannot start plugin API server on port %d: %s", port, err) + return + } + s := grpc.NewServer() + pb.RegisterPluginServiceServer(s, &pluginAPI{}) + logger.Sugar.Debugf("starting plugin server on :%d", port) + if err := s.Serve(socket); err != nil { + logger.Sugar.Fatalf("cannot start plugin server on :%d: %s", port, err) + } +} diff --git a/internal/plugin/stream.go b/internal/plugin/stream.go new file mode 100644 index 0000000..5267b0e --- /dev/null +++ b/internal/plugin/stream.go @@ -0,0 +1,70 @@ +package plugin + +import ( + "net" + + "github.com/Neur0toxine/sshpoke/internal/model" + pb "github.com/Neur0toxine/sshpoke/pkg/plugin" +) + +type Stream struct { + stream pb.PluginService_EventServer +} + +func (s *Stream) Recv() error { + _, err := s.stream.Recv() + return err +} + +func (s *Stream) Send(event model.Event) error { + return s.stream.Send(s.eventToMessage(event)) +} + +func (s *Stream) messageToEvent(event *pb.EventMessage) model.Event { + return model.Event{ + Type: s.pbEventTypeToApp(event.Type), + ID: event.Id, + Container: model.Container{ + IP: net.ParseIP(event.Container.Ip), + Port: uint16(event.Container.Port), + Server: event.Container.Server, + Prefix: event.Container.Prefix, + Domain: event.Container.Domain, + }, + } +} + +func (s *Stream) eventToMessage(event model.Event) *pb.EventMessage { + return &pb.EventMessage{ + Type: s.appEventTypeToPB(event.Type), + Id: event.ID, + Container: &pb.Container{ + Ip: event.Container.IP.String(), + Port: uint32(event.Container.Port), + Server: event.Container.Server, + Prefix: event.Container.Prefix, + Domain: event.Container.Domain, + }, + } +} + +func (s *Stream) pbEventTypeToApp(typ pb.EventType) model.EventType { + val := model.EventType(typ.Number()) + if val > model.EventStart { + return model.EventUnknown + } + return val +} + +func (s *Stream) appEventTypeToPB(typ model.EventType) pb.EventType { + switch typ { + case 0: + return pb.EventType_EVENT_START + case 1: + return pb.EventType_EVENT_STOP + case 2: + fallthrough + default: + return pb.EventType_EVENT_UNKNOWN + } +} diff --git a/internal/server/driver/construct.go b/internal/server/driver/construct.go new file mode 100644 index 0000000..2fc23a2 --- /dev/null +++ b/internal/server/driver/construct.go @@ -0,0 +1,24 @@ +package driver + +import ( + "context" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/server/driver/iface" + "github.com/Neur0toxine/sshpoke/internal/server/driver/null" + "github.com/Neur0toxine/sshpoke/internal/server/driver/plugin" + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh" +) + +func New(ctx context.Context, name string, driver config.DriverType, params config.DriverParams) (iface.Driver, error) { + switch driver { + case config.DriverSSH: + return ssh.New(ctx, name, params) + case config.DriverPlugin: + return plugin.New(ctx, name, params) + case config.DriverNull: + fallthrough + default: + return null.New(ctx, name, params) + } +} diff --git a/internal/server/driver/iface/driver.go b/internal/server/driver/iface/driver.go new file mode 100644 index 0000000..0af39a6 --- /dev/null +++ b/internal/server/driver/iface/driver.go @@ -0,0 +1,16 @@ +package iface + +import ( + "context" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/model" +) + +type DriverConstructor func(ctx context.Context, name string, params config.DriverParams) (Driver, error) + +type Driver interface { + Handle(event model.Event) error + Driver() config.DriverType + WaitForShutdown() +} diff --git a/internal/server/driver/null/driver.go b/internal/server/driver/null/driver.go new file mode 100644 index 0000000..b1d0979 --- /dev/null +++ b/internal/server/driver/null/driver.go @@ -0,0 +1,31 @@ +package null + +import ( + "context" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/logger" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/server/driver/iface" +) + +// Null driver only logs container events to debug log. It is used when user provides invalid driver type. +// You can use it directly, but it won't do anything, so... why bother? +type Null struct { + name string +} + +func New(ctx context.Context, name string, params config.DriverParams) (iface.Driver, error) { + return &Null{name: name}, nil +} + +func (d *Null) Handle(event model.Event) error { + logger.Sugar.Debugw("handling event with null driver", "serverName", d.name, "event", event) + return nil +} + +func (d *Null) Driver() config.DriverType { + return config.DriverNull +} + +func (d *Null) WaitForShutdown() {} diff --git a/internal/server/driver/plugin/driver.go b/internal/server/driver/plugin/driver.go new file mode 100644 index 0000000..664b239 --- /dev/null +++ b/internal/server/driver/plugin/driver.go @@ -0,0 +1,111 @@ +package plugin + +import ( + "context" + "errors" + "io" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/logger" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/server/driver/iface" + "github.com/Neur0toxine/sshpoke/internal/server/driver/util" +) + +// Plugin driver uses RPC to communicate with external plugin. +type Plugin struct { + ctx context.Context + name string + params Params + send chan model.Event +} + +type EventStream interface { + Send(event model.Event) error + Recv() error +} + +func New(ctx context.Context, name string, params config.DriverParams) (iface.Driver, error) { + drv := &Plugin{ + name: name, + ctx: ctx, + send: make(chan model.Event), + } + if err := util.UnmarshalParams(params, &drv.params); err != nil { + return nil, err + } + return drv, nil +} + +func (d *Plugin) Handle(event model.Event) error { + if d.isDone() { + return nil + } + d.send <- event + return nil +} + +func (d *Plugin) Name() string { + return d.name +} + +func (d *Plugin) Driver() config.DriverType { + return config.DriverPlugin +} + +func (d *Plugin) Token() string { + return d.params.Token +} + +func (d *Plugin) Listen(ctx context.Context, stream EventStream) error { + for { + select { + case <-ctx.Done(): + return nil + default: + } + + err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + logger.Sugar.Errorw("error reading poll event from plugin", + "server", d.name, "error", err) + return err + } + select { + case <-ctx.Done(): + return nil + case event := <-d.send: + err := stream.Send(event) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + logger.Sugar.Errorw("error writing event to plugin", + "server", d.name, "error", err) + return err + } + } + } +} + +func (d *Plugin) HandleStatus(event model.EventRequest) { + logger.Sugar.Errorw("plugin error", "serverName", d.name, "id", event.ID, "error", event.Error) +} + +func (d *Plugin) isDone() bool { + select { + case <-d.ctx.Done(): + close(d.send) + return true + default: + return false + } +} + +func (d *Plugin) WaitForShutdown() { + <-d.ctx.Done() + return +} diff --git a/internal/server/driver/plugin/params.go b/internal/server/driver/plugin/params.go new file mode 100644 index 0000000..122f009 --- /dev/null +++ b/internal/server/driver/plugin/params.go @@ -0,0 +1,13 @@ +package plugin + +import ( + "github.com/Neur0toxine/sshpoke/internal/server/driver/util" +) + +type Params struct { + Token string `mapstructure:"token" validate:"required"` +} + +func (p *Params) Validate() error { + return util.Validator.Struct(p) +} diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go new file mode 100644 index 0000000..09a0eca --- /dev/null +++ b/internal/server/driver/ssh/driver.go @@ -0,0 +1,47 @@ +package ssh + +import ( + "context" + "errors" + "sync" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/server/driver/iface" + "github.com/Neur0toxine/sshpoke/internal/server/driver/util" + "github.com/Neur0toxine/sshpoke/internal/server/proto/sshtun" +) + +type SSH struct { + ctx context.Context + name string + params Params + sessions map[string]conn + wg sync.WaitGroup +} + +type conn struct { + container model.Container + tun *sshtun.Tunnel +} + +func New(ctx context.Context, name string, params config.DriverParams) (iface.Driver, error) { + drv := &SSH{ctx: ctx, name: name, sessions: make(map[string]conn)} + if err := util.UnmarshalParams(params, &drv.params); err != nil { + return nil, err + } + return drv, nil +} + +func (d *SSH) Handle(event model.Event) error { + // TODO: Implement event handling & connections management. + return errors.New(d.name + " server handler is not implemented yet") +} + +func (d *SSH) Driver() config.DriverType { + return config.DriverSSH +} + +func (d *SSH) WaitForShutdown() { + d.wg.Wait() +} diff --git a/internal/server/driver/ssh/params.go b/internal/server/driver/ssh/params.go new file mode 100644 index 0000000..429c621 --- /dev/null +++ b/internal/server/driver/ssh/params.go @@ -0,0 +1,63 @@ +package ssh + +import ( + "fmt" + + "github.com/Neur0toxine/sshpoke/internal/server/driver/util" +) + +type Params struct { + Address string `mapstructure:"address" validate:"required"` + Auth Auth `mapstructure:"auth"` + KeepAlive KeepAlive `mapstructure:"keepalive"` + Domain string `mapstructure:"domain"` + DomainProto string `mapstructure:"domain_proto"` + DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"` + Mode DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` + Prefix bool `mapstructure:"prefix"` +} + +type AuthType string + +const ( + AuthTypePasswordless AuthType = "passwordless" + AuthTypePassword AuthType = "password" + AuthTypeKey AuthType = "key" +) + +type DomainMode string + +const ( + DomainModeSingle DomainMode = "single" + DomainModeMulti DomainMode = "multi" +) + +type Auth struct { + Type AuthType `mapstructure:"type" validate:"required,oneof=passwordless password key"` + User string `mapstructure:"user"` + Password string `mapstructure:"password"` + Directory string `mapstructure:"directory"` + Keyfile string `mapstructure:"keyfile"` +} + +func (a Auth) validate() error { + if a.Type == AuthTypePassword && a.Password == "" { + return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword) + } + if a.Type == AuthTypeKey && a.Directory == "" { + return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword) + } + return nil +} + +type KeepAlive struct { + Interval int `mapstructure:"interval" validate:"gte=0"` + MaxAttempts int `mapstructure:"max_attempts" validate:"gte=1"` +} + +func (p *Params) Validate() error { + if err := util.Validator.Struct(p); err != nil { + return err + } + return p.Auth.validate() +} diff --git a/internal/server/driver/util/params.go b/internal/server/driver/util/params.go new file mode 100644 index 0000000..24ab7e1 --- /dev/null +++ b/internal/server/driver/util/params.go @@ -0,0 +1,20 @@ +package util + +import ( + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/mitchellh/mapstructure" +) + +type ValidationAvailable interface { + Validate() error +} + +func UnmarshalParams(params config.DriverParams, target ValidationAvailable) error { + if err := mapstructure.Decode(params, target); err != nil { + return err + } + if val, canValidate := target.(ValidationAvailable); canValidate { + return val.Validate() + } + return nil +} diff --git a/internal/server/driver/util/validator.go b/internal/server/driver/util/validator.go new file mode 100644 index 0000000..2d49764 --- /dev/null +++ b/internal/server/driver/util/validator.go @@ -0,0 +1,23 @@ +package util + +import ( + "regexp" + + "github.com/go-playground/validator/v10" +) + +var Validator *validator.Validate + +func init() { + Validator = validator.New() + _ = Validator.RegisterValidation("validregexp", isValidRegExp) +} + +func isValidRegExp(fl validator.FieldLevel) bool { + expr := fl.Field().String() + if expr == "" { + return true + } + _, err := regexp.Compile(expr) + return err == nil +} diff --git a/internal/server/manager.go b/internal/server/manager.go new file mode 100644 index 0000000..d789d26 --- /dev/null +++ b/internal/server/manager.go @@ -0,0 +1,84 @@ +package server + +import ( + "context" + "errors" + "sync" + + "github.com/Neur0toxine/sshpoke/internal/config" + "github.com/Neur0toxine/sshpoke/internal/logger" + "github.com/Neur0toxine/sshpoke/internal/model" + "github.com/Neur0toxine/sshpoke/internal/server/driver" + "github.com/Neur0toxine/sshpoke/internal/server/driver/iface" + "github.com/Neur0toxine/sshpoke/internal/server/driver/plugin" +) + +type Manager struct { + rw sync.RWMutex + servers map[string]iface.Driver + defaultServer string +} + +var DefaultManager *Manager +var ( + ErrNoServer = errors.New("server is not specified") + ErrNoSuchServer = errors.New("server does not exist") +) + +func NewManager(ctx context.Context, servers []config.Server, defaultServer string) *Manager { + m := &Manager{ + servers: make(map[string]iface.Driver), + defaultServer: defaultServer, + } + for _, serverConfig := range servers { + server, err := driver.New(ctx, serverConfig.Name, serverConfig.Driver, serverConfig.Params) + if err != nil { + logger.Sugar.Errorf("cannot initialize server '%s': %s", serverConfig.Name, err) + continue + } + m.servers[serverConfig.Name] = server + } + return m +} + +func (m *Manager) ProcessEvent(event model.Event) error { + serverName := event.Container.Server + if serverName == "" { + serverName = m.defaultServer + } + if serverName == "" { + return ErrNoServer + } + defer m.rw.RUnlock() + m.rw.RLock() + srv, ok := m.servers[event.Container.Server] + if !ok { + return ErrNoSuchServer + } + return srv.Handle(event) +} + +func (m *Manager) PluginByToken(token string) *plugin.Plugin { + defer m.rw.RUnlock() + m.rw.RLock() + for _, srv := range m.servers { + if srv.Driver() != config.DriverPlugin { + continue + } + pl := srv.(*plugin.Plugin) + if pl.Token() != token { + continue + } + return pl + } + return nil +} + +func (m *Manager) WaitForShutdown() { + defer m.rw.RUnlock() + m.rw.RLock() + for _, srv := range m.servers { + srv.WaitForShutdown() + } + return +} diff --git a/internal/server/proto/sshtun/tunnel.go b/internal/server/proto/sshtun/tunnel.go new file mode 100644 index 0000000..9addddc --- /dev/null +++ b/internal/server/proto/sshtun/tunnel.go @@ -0,0 +1,249 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE.md file. + +package sshtun + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "golang.org/x/crypto/ssh" +) + +type TunnelMode uint8 + +func (t TunnelMode) String() string { + switch t { + case TunnelForward: + return "->" + case TunnelReverse: + return "<-" + default: + return "" + } +} + +const ( + TunnelForward TunnelMode = iota + TunnelReverse +) + +type logger interface { + Printf(string, ...interface{}) +} + +type Tunnel struct { + Auth []ssh.AuthMethod + HostKeys ssh.HostKeyCallback + Mode TunnelMode + User string + HostAddr string + BindAddr string + DialAddr string + RetryInterval time.Duration + KeepAlive KeepAliveConfig + Logger logger +} + +type KeepAliveConfig struct { + // Interval is the amount of time in seconds to wait before the + // Tunnel client will send a keep-alive message to ensure some minimum + // traffic on the SSH connection. + Interval uint + + // CountMax is the maximum number of consecutive failed responses to + // keep-alive messages the client is willing to tolerate before considering + // the SSH connection as dead. + CountMax uint +} + +func (t Tunnel) String() string { + var left, right string + switch t.Mode { + case TunnelForward: + left, right = t.BindAddr, t.DialAddr + case TunnelReverse: + left, right = t.DialAddr, t.BindAddr + } + return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, left, t.Mode, right) +} + +func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + + for { + var once sync.Once // Only print errors once per session + func() { + // Connect to the server host via SSH. + cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{ + User: t.User, + Auth: t.Auth, + HostKeyCallback: t.HostKeys, + Timeout: 5 * time.Second, + }) + if err != nil { + once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) }) + return + } + wg.Add(1) + go t.keepAliveMonitor(&once, wg, cl) + defer cl.Close() + + // Attempt to bind to the inbound socket. + var ln net.Listener + switch t.Mode { + case TunnelForward: + ln, err = net.Listen("tcp", t.BindAddr) + case TunnelReverse: + ln, err = cl.Listen("tcp", t.BindAddr) + } + if err != nil { + once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) }) + return + } + + // The socket is bound. Make sure we close it eventually. + bindCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + cl.Wait() + cancel() + }() + go func() { + <-bindCtx.Done() + once.Do(func() {}) // Suppress future errors + ln.Close() + }() + + t.Logger.Printf("(%v) binded Tunnel", t) + defer t.Logger.Printf("(%v) collapsed Tunnel", t) + + // Accept all incoming connections. + for { + cn1, err := ln.Accept() + if err != nil { + once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) }) + return + } + wg.Add(1) + go t.dialTunnel(bindCtx, wg, cl, cn1) + } + }() + + select { + case <-ctx.Done(): + return + case <-time.After(t.RetryInterval): + t.Logger.Printf("(%v) retrying...", t) + } + } +} + +func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) { + defer wg.Done() + + // The inbound connection is established. Make sure we close it eventually. + connCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-connCtx.Done() + cn1.Close() + }() + + // Establish the outbound connection. + var cn2 net.Conn + var err error + switch t.Mode { + case TunnelForward: + cn2, err = client.Dial("tcp", t.DialAddr) + case TunnelReverse: + cn2, err = net.Dial("tcp", t.DialAddr) + } + if err != nil { + t.Logger.Printf("(%v) dial error: %v", t, err) + return + } + + go func() { + <-connCtx.Done() + cn2.Close() + }() + + t.Logger.Printf("(%v) connection established", t) + defer t.Logger.Printf("(%v) connection closed", t) + + // Copy bytes from one connection to the other until one side closes. + var once sync.Once + var wg2 sync.WaitGroup + wg2.Add(2) + go func() { + defer wg2.Done() + defer cancel() + if _, err := io.Copy(cn1, cn2); err != nil { + once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) + } + once.Do(func() {}) // Suppress future errors + }() + go func() { + defer wg2.Done() + defer cancel() + if _, err := io.Copy(cn2, cn1); err != nil { + once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) + } + once.Do(func() {}) // Suppress future errors + }() + wg2.Wait() +} + +// keepAliveMonitor periodically sends messages to invoke a response. +// If the server does not respond after some period of time, +// assume that the underlying net.Conn abruptly died. +func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) { + defer wg.Done() + if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 { + return + } + + // Detect when the SSH connection is closed. + wait := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + wait <- client.Wait() + }() + + // Repeatedly check if the remote server is still alive. + var aliveCount int32 + ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second) + defer ticker.Stop() + for { + select { + case err := <-wait: + if err != nil && err != io.EOF { + once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) }) + } + return + case <-ticker.C: + if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.KeepAlive.CountMax) { + once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) }) + client.Close() + return + } + } + + wg.Add(1) + go func() { + defer wg.Done() + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + if err == nil { + atomic.StoreInt32(&aliveCount, 0) + } + }() + } +} diff --git a/internal/server/proto/sshtun/tunnel_test.go b/internal/server/proto/sshtun/tunnel_test.go new file mode 100644 index 0000000..2561d62 --- /dev/null +++ b/internal/server/proto/sshtun/tunnel_test.go @@ -0,0 +1,509 @@ +// Copyright 2017, The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE.md file. + +package sshtun + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/rsa" + "encoding/binary" + "fmt" + "io" + "io/ioutil" + "math/rand" + "net" + "reflect" + "strconv" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +type testLogger struct { + *testing.T // Already has Fatalf method +} + +func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) } + +func TestTunnel(t *testing.T) { + rootWG := new(sync.WaitGroup) + defer rootWG.Wait() + rootCtx, cancelAll := context.WithCancel(context.Background()) + defer cancelAll() + + // Open all of the TCP sockets needed for the test. + tcpLn0 := openListener(t) // Start of the chain + tcpLn1 := openListener(t) // Mid-point of the chain + tcpLn2 := openListener(t) // End of the chain + srvLn0 := openListener(t) // Socket for SSH server in reverse Mode + srvLn1 := openListener(t) // Socket for SSH server in forward Mode + + tcpLn0.Close() // To be later binded by the reverse Tunnel + tcpLn1.Close() // To be later binded by the forward Tunnel + go closeWhenDone(rootCtx, tcpLn2) + go closeWhenDone(rootCtx, srvLn0) + go closeWhenDone(rootCtx, srvLn1) + + // Generate keys for both the servers and clients. + clientPriv0, clientPub0 := generateKeys(t) + clientPriv1, clientPub1 := generateKeys(t) + serverPriv0, serverPub0 := generateKeys(t) + serverPriv1, serverPub1 := generateKeys(t) + + // Start the SSH servers. + rootWG.Add(2) + go func() { + defer rootWG.Done() + runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1) + }() + go func() { + defer rootWG.Done() + runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1) + }() + + wg := new(sync.WaitGroup) + defer wg.Wait() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create the Tunnel configurations. + tn0 := Tunnel{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)}, + HostKeys: ssh.FixedHostKey(serverPub0), + Mode: TunnelReverse, // Reverse Tunnel + User: "user0", + HostAddr: srvLn0.Addr().String(), + BindAddr: tcpLn0.Addr().String(), + DialAddr: tcpLn1.Addr().String(), + Logger: testLogger{t}, + } + tn1 := Tunnel{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)}, + HostKeys: ssh.FixedHostKey(serverPub1), + Mode: TunnelForward, // Forward Tunnel + User: "user1", + HostAddr: srvLn1.Addr().String(), + BindAddr: tcpLn1.Addr().String(), + DialAddr: tcpLn2.Addr().String(), + Logger: testLogger{t}, + } + + // Start the SSH client tunnels. + wg.Add(2) + go tn0.Bind(ctx, wg) + go tn1.Bind(ctx, wg) + + t.Log("test started") + done := make(chan bool, 10) + + // Start all the transmitters. + for i := 0; i < cap(done); i++ { + i := i + go func() { + for { + rnd := rand.New(rand.NewSource(int64(i))) + hash := md5.New() + size := uint32((1 << 10) + rnd.Intn(1<<20)) + buf4 := make([]byte, 4) + binary.LittleEndian.PutUint32(buf4, size) + + cnStart, err := net.Dial("tcp", tcpLn0.Addr().String()) + if err != nil { + time.Sleep(10 * time.Millisecond) + continue + } + defer cnStart.Close() + if _, err := cnStart.Write(buf4); err != nil { + t.Errorf("write size error: %v", err) + break + } + r := io.LimitReader(rnd, int64(size)) + w := io.MultiWriter(cnStart, hash) + if _, err := io.Copy(w, r); err != nil { + t.Errorf("copy error: %v", err) + break + } + if _, err := cnStart.Write(hash.Sum(nil)); err != nil { + t.Errorf("write hash error: %v", err) + break + } + if err := cnStart.Close(); err != nil { + t.Errorf("close error: %v", err) + break + } + break + } + }() + } + + // Start all the receivers. + for i := 0; i < cap(done); i++ { + go func() { + for { + hash := md5.New() + buf4 := make([]byte, 4) + + cnEnd, err := tcpLn2.Accept() + if err != nil { + time.Sleep(10 * time.Millisecond) + continue + } + defer cnEnd.Close() + + if _, err := io.ReadFull(cnEnd, buf4); err != nil { + t.Errorf("read size error: %v", err) + break + } + size := binary.LittleEndian.Uint32(buf4) + r := io.LimitReader(cnEnd, int64(size)) + if _, err := io.Copy(hash, r); err != nil { + t.Errorf("copy error: %v", err) + break + } + wantHash, err := ioutil.ReadAll(cnEnd) + if err != nil { + t.Errorf("read hash error: %v", err) + break + } + if err := cnEnd.Close(); err != nil { + t.Errorf("close error: %v", err) + break + } + + if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) { + t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash) + } + break + } + done <- true + }() + } + + for i := 0; i < cap(done); i++ { + select { + case <-done: + case <-time.After(10 * time.Second): + t.Errorf("timed out: %d remaining", cap(done)-i) + return + } + } + t.Log("test complete") +} + +// generateKeys generates a random pair of SSH private and public keys. +func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) { + rnd := rand.New(rand.NewSource(time.Now().Unix())) + rsaKey, err := rsa.GenerateKey(rnd, 1024) + if err != nil { + t.Fatalf("unable to generate RSA key pair: %v", err) + } + priv, err = ssh.NewSignerFromKey(rsaKey) + if err != nil { + t.Fatalf("unable to generate signer: %v", err) + } + pub, err = ssh.NewPublicKey(&rsaKey.PublicKey) + if err != nil { + t.Fatalf("unable to generate public key: %v", err) + } + return priv, pub +} + +func openListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("listen error: %v", err) + } + return ln +} + +// runServer starts an SSH server capable of handling forward and reverse +// TCP tunnels. This function blocks for the entire duration that the +// server is running and can be stopped by canceling the context. +// +// The server listens on the provided Listener and will present to clients +// a certificate from serverKey and will only accept users that match +// the provided clientKeys. Only users of the name "User%d" are allowed where +// the ID number is the index for the specified client key provided. +func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) { + wg := new(sync.WaitGroup) + defer wg.Wait() + + // Generate SSH server configuration. + conf := ssh.ServerConfig{ + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + var uid int + _, err := fmt.Sscanf(c.User(), "User%d", &uid) + if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) { + return nil, fmt.Errorf("unknown public key for %q", c.User()) + } + return nil, nil + }, + } + conf.AddHostKey(serverKey) + + // Handle every SSH client connection. + for { + tcpCn, err := ln.Accept() + if err != nil { + if !isDone(ctx) { + t.Errorf("accept error: %v", err) + } + return + } + wg.Add(1) + go handleServerConn(t, ctx, wg, tcpCn, &conf) + } +} + +// handleServerConn handles a single SSH connection. +func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) { + defer wg.Done() + go closeWhenDone(ctx, tcpCn) + defer tcpCn.Close() + + sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf) + if err != nil { + t.Errorf("new connection error: %v", err) + return + } + go closeWhenDone(ctx, sshCn) + defer sshCn.Close() + + wg.Add(1) + go handleServerChannels(t, ctx, wg, sshCn, chans) + + wg.Add(1) + go handleServerRequests(t, ctx, wg, sshCn, reqs) + + if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) { + t.Errorf("connection error: %v", err) + } +} + +// handleServerChannels handles new channels on a SSH connection. +// The client initiates a new channel when forwarding a TCP dial. +func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) { + defer wg.Done() + for nc := range chans { + if nc.ChannelType() != "direct-tcpip" { + nc.Reject(ssh.UnknownChannelType, "not implemented") + continue + } + var args struct { + DstHost string + DstPort uint32 + SrcHost string + SrcPort uint32 + } + if !unmarshalData(nc.ExtraData(), &args) { + nc.Reject(ssh.Prohibited, "invalid request") + continue + } + + // Open a connection for both sides. + cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort)))) + if err != nil { + nc.Reject(ssh.ConnectionFailed, err.Error()) + continue + } + ch, reqs, err := nc.Accept() + if err != nil { + t.Errorf("accept channel error: %v", err) + cn.Close() + continue + } + go ssh.DiscardRequests(reqs) + + wg.Add(1) + go bidirCopyAndClose(t, ctx, wg, cn, ch) + } +} + +// handleServerRequests handles new requests on a SSH connection. +// The client initiates a new request for binding a local TCP socket. +func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) { + defer wg.Done() + for r := range reqs { + if !r.WantReply { + continue + } + if r.Type != "tcpip-forward" { + r.Reply(false, nil) + continue + } + var args struct { + Host string + Port uint32 + } + if !unmarshalData(r.Payload, &args) { + r.Reply(false, nil) + continue + } + ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port)))) + if err != nil { + r.Reply(false, nil) + continue + } + + var resp struct{ Port uint32 } + _, resp.Port = splitHostPort(ln.Addr().String()) + if err := r.Reply(true, marshalData(resp)); err != nil { + t.Errorf("request reply error: %v", err) + ln.Close() + continue + } + + wg.Add(1) + go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host) + + } +} + +// handleLocalListener handles every new connection on the provided socket. +// All local connections will be forwarded to the client via a new channel. +func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) { + defer wg.Done() + go closeWhenDone(ctx, ln) + defer ln.Close() + + for { + // Open a connection for both sides. + cn, err := ln.Accept() + if err != nil { + if !isDone(ctx) { + t.Errorf("accept error: %v", err) + } + return + } + var args struct { + DstHost string + DstPort uint32 + SrcHost string + SrcPort uint32 + } + args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String()) + args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String()) + args.DstHost = host // This must match on client side! + ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args)) + if err != nil { + t.Errorf("open channel error: %v", err) + cn.Close() + continue + } + go ssh.DiscardRequests(reqs) + + wg.Add(1) + go bidirCopyAndClose(t, ctx, wg, cn, ch) + } +} + +// bidirCopyAndClose performs a bi-directional copy on both connections +// until either side closes the connection or the context is canceled. +// This will close both connections before returning. +func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) { + defer wg.Done() + go closeWhenDone(ctx, c1) + go closeWhenDone(ctx, c2) + defer c1.Close() + defer c2.Close() + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(c1, c2) + errc <- err + }() + go func() { + _, err := io.Copy(c2, c1) + errc <- err + }() + if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) { + t.Errorf("copy error: %v", err) + } +} + +// unmarshalData parses b into s, where s is a pointer to a struct. +// Only unexported fields of type uint32 or string are allowed. +func unmarshalData(b []byte, s interface{}) bool { + v := reflect.ValueOf(s) + if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + panic("destination must be pointer to struct") + } + v = v.Elem() + for i := 0; i < v.NumField(); i++ { + switch v.Type().Field(i).Type.Kind() { + case reflect.Uint32: + if len(b) < 4 { + return false + } + v.Field(i).Set(reflect.ValueOf(binary.BigEndian.Uint32(b))) + b = b[4:] + case reflect.String: + if len(b) < 4 { + return false + } + n := binary.BigEndian.Uint32(b) + b = b[4:] + if uint64(len(b)) < uint64(n) { + return false + } + v.Field(i).Set(reflect.ValueOf(string(b[:n]))) + b = b[n:] + default: + panic("invalid field type: " + v.Type().Field(i).Type.String()) + } + } + return len(b) == 0 +} + +// marshalData serializes s into b, where s is a struct (or a pointer to one). +// Only unexported fields of type uint32 or string are allowed. +func marshalData(s interface{}) (b []byte) { + v := reflect.ValueOf(s) + if v.IsValid() && v.Kind() == reflect.Ptr { + v = v.Elem() + } + if !v.IsValid() || v.Kind() != reflect.Struct { + panic("source must be a struct") + } + var arr32 [4]byte + for i := 0; i < v.NumField(); i++ { + switch v.Type().Field(i).Type.Kind() { + case reflect.Uint32: + binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Uint())) + b = append(b, arr32[:]...) + case reflect.String: + binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Len())) + b = append(b, arr32[:]...) + b = append(b, v.Field(i).String()...) + default: + panic("invalid field type: " + v.Type().Field(i).Type.String()) + } + } + return b + +} + +func splitHostPort(s string) (string, uint32) { + host, port, _ := net.SplitHostPort(s) + p, _ := strconv.Atoi(port) + return host, uint32(p) +} + +func closeWhenDone(ctx context.Context, c io.Closer) { + <-ctx.Done() + c.Close() +} + +func isDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/pkg/plugin/pb.go b/pkg/plugin/pb.go new file mode 100644 index 0000000..6d718f7 --- /dev/null +++ b/pkg/plugin/pb.go @@ -0,0 +1,2 @@ +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative pb.proto +package plugin diff --git a/pkg/plugin/pb.proto b/pkg/plugin/pb.proto new file mode 100644 index 0000000..7047695 --- /dev/null +++ b/pkg/plugin/pb.proto @@ -0,0 +1,36 @@ +syntax = "proto3"; +import "google/protobuf/empty.proto"; + +option go_package = "github.com/Neur0toxine/sshpoke/pkg/plugin"; +option java_multiple_files = true; + +service PluginService { + rpc Event (stream google.protobuf.Empty) returns (stream EventMessage); + rpc EventStatus (EventStatusMessage) returns (google.protobuf.Empty); + rpc Shutdown (stream google.protobuf.Empty) returns (google.protobuf.Empty); +} + +enum EventType { + EVENT_START = 0; + EVENT_STOP = 1; + EVENT_UNKNOWN = 2; +} + +message Container { + string ip = 1; + uint32 port = 2; + string server = 3; + string prefix = 4; + string domain = 5; +} + +message EventMessage { + EventType type = 1; + string id = 2; + Container container = 3; +} + +message EventStatusMessage { + string id = 1; + string error = 2; +}