From 883bd0d9068167ed45712a4725ed668127e69b43 Mon Sep 17 00:00:00 2001 From: YR Chen Date: Tue, 13 Aug 2024 15:52:54 +0800 Subject: [PATCH] Refactor with minor improvements, and add README (#5) --- .github/workflows/build.yml | 5 +- README.md | 72 +++++++++ auth.go | 160 ++++++++++++++++++ config.go | 35 ++++ logging.go | 52 ++++++ main.go | 32 ++++ sshmux.go | 313 +++++++++--------------------------- 7 files changed, 426 insertions(+), 243 deletions(-) create mode 100644 README.md create mode 100644 auth.go create mode 100644 config.go create mode 100644 logging.go create mode 100644 main.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ebfcc46..dda402c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -34,7 +34,7 @@ jobs: CGO_ENABLED: 0 - name: Create GitHub Release if: startsWith(github.ref, 'refs/tags/') - id: upload-release-asset + id: upload-release-asset uses: softprops/action-gh-release@v1 with: files: | @@ -49,7 +49,8 @@ jobs: id: meta uses: docker/metadata-action@v5 with: - images: ghcr.io/ustc-vlab/sshmux + images: | + ghcr.io/${{ github.repository_owner }}/sshmux tags: | type=ref,event=branch type=ref,event=pr diff --git a/README.md b/README.md new file mode 100644 index 0000000..14242d6 --- /dev/null +++ b/README.md @@ -0,0 +1,72 @@ +# sshmux + +`sshmux` is a new, simple implementation of SSH reverse proxy. `sshmux` was initially developed for Vlab, while we'd like to expand its usage to cover more scenarios. + +## Build, Run and Test + +`sshmux` requires a Go 1.21+ toolchain to build. You can use `go build` or `make` to get the `sshmux` binary directly in the directory. + +You can run the built binary with `./sshmux`. Note that you'll need to provide a valid configuration file as described [here](#config). + +You can perform unit tests with `go test` or `make test`. Enable verbose logging with `go test -v`. + +## Config + +`sshmux` requires a JSON configuration file to start up. By default it will look at `/etc/sshmux/config.json`, but you can also specify a custom configuration by passing `-c path/to/config.json` in the command line arguments. An [example](config.example.json) file is provided. + +The table below shows the available options for `sshmux`: + +| Key | Type | Description | Required | Example | +|-------------|------------|--------------------------------------------------------------------|----------|------------------------------------| +| `address` | `string` | TCP host and port that `sshmux` will listen on. | `true` | `"0.0.0.0:8022"` | +| `host-keys` | `[]string` | Paths to SSH host key files with which `sshmux` identifies itself. | `true` | `["/sshmux/ssh_host_ed25519_key"]` | +| `api` | `string` | HTTP address that `sshmux` shall interact with. | `true` | `"http://127.0.0.1:5000/ssh"` | +| `token` | `string` | Token used to authenticate with the API endpoint. | `true` | `"long-and-random-token"` | +| `banner` | `string` | SSH banner to send to downstream. | `false` | `"Welcome to Vlab\n"` | +| `logger` | `string` | UDP host and port that `sshmux` send log messages to. | `false` | `"127.0.0.1:5556"` | +| `proxy-protocol-allowed-cidrs` | `[]string` | CIDRs from which [PROXY protocol](https://www.haproxy.com/blog/use-the-proxy-protocol-to-preserve-a-clients-ip-address) is allowed. | `false` | `["127.0.0.22/32"]` | + +### Advanced Config + +The table below shows extra options for `sshmux`, mainly for authentication with Vlab backends: + +| Key | Type | Description | Example | +|----------------------------|------------|----------------------------------------------------------------------------|------------------------------| +| `recovery-token` | `string` | Token used to authenticate with the recovery backend. Defaults to `token`. | `"long-and-random-token"` | +| `recovery-server` | `string` | SSH host and port of the recovery server. | `"172.30.0.101:2222"` | +| `recovery-username` | `[]string` | Usernames dedicated to the recovery server. | `["recovery", "console"]` | +| `all-username-nopassword` | `bool` | If set to `true`, no users will be asked for UNIX password. | `true` | +| `username-nopassword` | `[]string` | Usernames that won't be asked for UNIX password. | `["vlab", "ubuntu", "root"]` | +| `invalid-username` | `[]string` | Usernames that are known to be invalid. | `["user"]` | +| `invalid-username-message` | `string` | Message to display when the requested username is invalid. | `"Invalid username %s."` | + +All of these options can be omitted, if the corresponding feature is not intended to be used. + +## API server + +`sshmux` requires an API server to perform authentication and authorization for a user. + +The API accepts JSON input with the following keys: + +| Key | Type | Description | +|-------------------|----------|----------------------------------------------------------------------------------------------------------| +| `auth_type` | `string` | The authentication type. Always set to `"key"` at the moment. | +| `username` | `string` | Vlab username. Omitted if the user is authenticating with public key. | +| `password` | `string` | Vlab password. Omitted if the user is authenticating with public key. | +| `public_key_type` | `string` | SSH public key type. Omitted if the user is authenticating with username and password. | +| `public_key_data` | `string` | Base64-encoded SSH public key payload. Omitted if the user is authenticating with username and password. | +| `unix_username` | `string` | UNIX username the user is requesting access to. | +| `token` | `string` | Token used to authenticate the `sshmux` instance. | + +The API responds with JSON output with the following keys: + +| Key | Type | Description | +|------------------|-----------|----------------------------------------------------------------------------------------------------------------------| +| `status` | `string` | The authentication status. Should be `"ok"` if the user is authorized. | +| `address` | `string` | TCP host and port of the downstream SSH server the user is requesting for. | +| `private_key` | `string` | SSH private key to authenticate for the downstream. | +| `cert` | `string` | The certificate associated with the SSH private key. | +| `vmid` | `integer` | ID of the requested VM. Only used for recovery access. | +| `proxy_protocol` | `integer` | PROXY protocol version to use for the downstream. Should be `1`, `2` or omitted (which disables PROXY protocol). | + +Note that if the user is not authorized, the API server should return a `status` other than `"ok"`, and other keys can be safely ommitted. diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..f0a3f23 --- /dev/null +++ b/auth.go @@ -0,0 +1,160 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "slices" + + "golang.org/x/crypto/ssh" +) + +type AuthRequestPublicKey struct { + AuthType string `json:"auth_type"` + UnixUsername string `json:"unix_username"` + PublicKeyType string `json:"public_key_type"` + PublicKeyData string `json:"public_key_data"` + Token string `json:"token"` +} + +type AuthRequestPassword struct { + AuthType string `json:"auth_type"` + Username string `json:"username"` + Password string `json:"password"` + UnixUsername string `json:"unix_username"` + Token string `json:"token"` +} + +type AuthResponse struct { + Status string `json:"status"` + Address string `json:"address"` + PrivateKey string `json:"private_key"` + Cert string `json:"cert"` + Id int `json:"vmid"` + ProxyProtocol byte `json:"proxy_protocol,omitempty"` +} + +type UpstreamInformation struct { + Host string + Signer ssh.Signer + Password *string + ProxyProtocol byte +} + +type Authenticator struct { + Endpoint string + Token string + Recovery RecoveryConfig +} + +func makeAuthenticator(config Config) Authenticator { + recoveryToken := config.RecoveryToken + if recoveryToken == "" { + recoveryToken = config.Token + } + return Authenticator{ + Endpoint: config.API, + Token: config.Token, + Recovery: RecoveryConfig{ + Server: config.RecoveryServer, + Username: config.RecoveryUsername, + Token: recoveryToken, + }, + } +} + +func parsePrivateKey(key string, cert string) ssh.Signer { + if key == "" { + return nil + } + signer, err := ssh.ParsePrivateKey([]byte(key)) + if err != nil { + return nil + } + if cert == "" { + return signer + } + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(cert)) + if err != nil { + return signer + } + certSigner, err := ssh.NewCertSigner(pk.(*ssh.Certificate), signer) + if err != nil { + return signer + } + return certSigner +} + +func (auth Authenticator) AuthUser(request any, username string) (*UpstreamInformation, error) { + payload := new(bytes.Buffer) + if err := json.NewEncoder(payload).Encode(request); err != nil { + return nil, err + } + res, err := http.Post(auth.Endpoint, "application/json", payload) + if err != nil { + return nil, err + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + var response AuthResponse + err = json.Unmarshal(body, &response) + if err != nil { + return nil, err + } + if response.Status != "ok" { + return nil, nil + } + + var upstream UpstreamInformation + // FIXME: Can this be handled in API server? + if slices.Contains(auth.Recovery.Username, username) { + upstream.Host = auth.Recovery.Server + password := fmt.Sprintf("%d %s", response.Id, auth.Recovery.Token) + upstream.Password = &password + } else { + upstream.Host = response.Address + } + upstream.Signer = parsePrivateKey(response.PrivateKey, response.Cert) + upstream.ProxyProtocol = response.ProxyProtocol + return &upstream, nil +} + +func (auth Authenticator) AuthUserWithPublicKey(key ssh.PublicKey, unixUsername string) (*UpstreamInformation, error) { + keyType := key.Type() + keyData := base64.StdEncoding.EncodeToString(key.Marshal()) + request := &AuthRequestPublicKey{ + AuthType: "key", + UnixUsername: unixUsername, + PublicKeyType: keyType, + PublicKeyData: keyData, + Token: auth.Token, + } + return auth.AuthUser(request, unixUsername) +} + +func (auth Authenticator) AuthUserWithUserPass(username string, password string, unixUsername string) (*UpstreamInformation, error) { + request := &AuthRequestPassword{ + AuthType: "key", + Username: username, + Password: password, + UnixUsername: unixUsername, + Token: auth.Token, + } + return auth.AuthUser(request, unixUsername) +} + +func removePublicKeyMethod(methods []string) []string { + res := make([]string, 0, len(methods)) + for _, s := range methods { + if s != "publickey" { + res = append(res, s) + } + } + return res +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..4fffdc9 --- /dev/null +++ b/config.go @@ -0,0 +1,35 @@ +package main + +type Config struct { + Address string `json:"address"` + ProxyCIDRs []string `json:"proxy-protocol-allowed-cidrs"` + HostKeys []string `json:"host-keys"` + API string `json:"api"` + Logger string `json:"logger"` + Banner string `json:"banner"` + Token string `json:"token"` + // The following should be moved into API server + RecoveryToken string `json:"recovery-token"` + RecoveryServer string `json:"recovery-server"` + RecoveryUsername []string `json:"recovery-username"` + AllUsernameNoPassword bool `json:"all-username-nopassword"` + UsernameNoPassword []string `json:"username-nopassword"` + InvalidUsername []string `json:"invalid-username"` + InvalidUsernameMessage string `json:"invalid-username-message"` +} + +type UsernamePolicyConfig struct { + InvalidUsername []string `json:"invalid-username"` + InvalidUsernameMessage string `json:"invalid-username-message"` +} + +type PasswordPolicyConfig struct { + AllUsernameNoPassword bool `json:"all-username-nopassword"` + UsernameNoPassword []string `json:"username-nopassword"` +} + +type RecoveryConfig struct { + Server string `json:"recovery-server"` + Username []string `json:"recovery-username"` + Token string `json:"token"` +} diff --git a/logging.go b/logging.go new file mode 100644 index 0000000..ed0f16b --- /dev/null +++ b/logging.go @@ -0,0 +1,52 @@ +package main + +import ( + "encoding/json" + "log" + "net" + "time" +) + +type LogMessage struct { + LoginTime int64 `json:"login_time"` + DisconnectTime int64 `json:"disconnect_time"` + ClientIp string `json:"remote_ip"` + HostIp string `json:"host_ip"` + Username string `json:"user_name"` +} + +type Logger struct { + Channel chan LogMessage +} + +func makeLogger(url string) Logger { + channel := make(chan LogMessage, 256) + go func() { + if url == "" { + for range channel { + } + return + } + conn, err := net.Dial("udp", url) + if err != nil { + log.Printf("Logger Dial failed: %s\n", err) + // Drain the channel to avoid blocking + for range channel { + } + return + } + for logMessage := range channel { + jsonMsg, err := json.Marshal(logMessage) + if err != nil { + continue + } + conn.Write(jsonMsg) + } + }() + return Logger{Channel: channel} +} + +func (l Logger) SendLog(logMessage *LogMessage) { + logMessage.DisconnectTime = time.Now().Unix() + l.Channel <- *logMessage +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..a374025 --- /dev/null +++ b/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "encoding/json" + "flag" + "log" + "os" +) + +func sshmuxServer(configFile string) { + var config Config + configFileBytes, err := os.ReadFile(configFile) + if err != nil { + log.Fatal(err) + } + err = json.Unmarshal(configFileBytes, &config) + if err != nil { + log.Fatal(err) + } + sshmux, err := makeServer(config) + if err != nil { + log.Fatal(err) + } + sshmux.ListenAddr(config.Address) +} + +func main() { + var configFile string + flag.StringVar(&configFile, "c", "/etc/sshmux/config.json", "config file") + flag.Parse() + sshmuxServer(configFile) +} diff --git a/sshmux.go b/sshmux.go index 1156c53..66f06b6 100644 --- a/sshmux.go +++ b/sshmux.go @@ -1,15 +1,9 @@ package main import ( - "bytes" - "encoding/base64" - "encoding/json" - "flag" "fmt" - "io" "log" "net" - "net/http" "net/netip" "os" "slices" @@ -19,163 +13,64 @@ import ( "golang.org/x/crypto/ssh" ) -type Config struct { - Address string `json:"address"` - ProxyCIDRs []string `json:"proxy-protocol-allowed-cidrs"` - HostKeys []string `json:"host-keys"` - API string `json:"api"` - Token string `json:"token"` - RecoveryServer string `json:"recovery-server"` - RecoveryUsername []string `json:"recovery-username"` - AllUsernameNoPassword bool `json:"all-username-nopassword"` - UsernameNoPassword []string `json:"username-nopassword"` - InvalidUsername []string `json:"invalid-username"` - InvalidUsernameMessage string `json:"invalid-username-message"` - Logger string `json:"logger"` - Banner string `json:"banner"` +type Server struct { + Banner string + SSHConfig *ssh.ServerConfig + ProxyUpstreams []netip.Prefix + Authenticator Authenticator + Logger Logger + UsernamePolicy UsernamePolicyConfig + PasswordPolicy PasswordPolicyConfig } -type LogMessage struct { - LoginTime int64 `json:"login_time"` - DisconnectTime int64 `json:"disconnect_time"` - ClientIp string `json:"remote_ip"` - HostIp string `json:"host_ip"` - Username string `json:"user_name"` -} - -var configFile string -var config Config - -type AuthRequestPublicKey struct { - AuthType string `json:"auth_type"` - UnixUsername string `json:"unix_username"` - PublicKeyType string `json:"public_key_type"` - PublicKeyData string `json:"public_key_data"` - Token string `json:"token"` -} - -type AuthRequestPassword struct { - AuthType string `json:"auth_type"` - Username string `json:"username"` - Password string `json:"password"` - UnixUsername string `json:"unix_username"` - Token string `json:"token"` -} - -type AuthResponse struct { - Status string `json:"status"` - Address string `json:"address"` - PrivateKey string `json:"private_key"` - Cert string `json:"cert"` - Id int `json:"vmid"` - ProxyProtocol byte `json:"proxy_protocol,omitempty"` -} - -type UpstreamInformation struct { - Host string - Signer ssh.Signer - Password *string - ProxyProtocol byte -} - -func parsePrivateKey(key string, cert string) ssh.Signer { - if key == "" { - return nil - } - signer, err := ssh.ParsePrivateKey([]byte(key)) - if err != nil { - return nil - } - if cert == "" { - return signer - } - pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(cert)) - if err != nil { - return signer - } - certSigner, err := ssh.NewCertSigner(pk.(*ssh.Certificate), signer) - if err != nil { - return signer - } - return certSigner -} - -func authUser(request any, username string) (*UpstreamInformation, error) { - payload := new(bytes.Buffer) - if err := json.NewEncoder(payload).Encode(request); err != nil { - return nil, err - } - res, err := http.Post(config.API, "application/json", payload) - if err != nil { - return nil, err - } - defer res.Body.Close() - body, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - var response AuthResponse - err = json.Unmarshal(body, &response) - if err != nil { - return nil, err - } - if response.Status != "ok" { - return nil, nil - } - - var upstream UpstreamInformation - if slices.Contains(config.RecoveryUsername, username) { - upstream.Host = config.RecoveryServer - password := fmt.Sprintf("%d %s", response.Id, config.Token) - upstream.Password = &password - } else { - upstream.Host = response.Address - } - upstream.Signer = parsePrivateKey(response.PrivateKey, response.Cert) - upstream.ProxyProtocol = response.ProxyProtocol - return &upstream, nil -} - -func authUserWithPublicKey(key ssh.PublicKey, unixUsername string) (*UpstreamInformation, error) { - keyType := key.Type() - keyData := base64.StdEncoding.EncodeToString(key.Marshal()) - request := &AuthRequestPublicKey{ - AuthType: "key", - UnixUsername: unixUsername, - PublicKeyType: keyType, - PublicKeyData: keyData, - Token: config.Token, +func makeServer(config Config) (*Server, error) { + sshConfig := &ssh.ServerConfig{ + ServerVersion: "SSH-2.0-taokystrong", + PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(), } - return authUser(request, unixUsername) -} - -func authUserWithUserPass(username string, password string, unixUsername string) (*UpstreamInformation, error) { - request := &AuthRequestPassword{ - AuthType: "key", - Username: username, - Password: password, - UnixUsername: unixUsername, - Token: config.Token, + for _, keyFile := range config.HostKeys { + bytes, err := os.ReadFile(keyFile) + if err != nil { + return nil, err + } + key, err := ssh.ParsePrivateKey(bytes) + if err != nil { + return nil, err + } + sshConfig.AddHostKey(key) } - return authUser(request, unixUsername) -} - -func removePublicKeyMethod(methods []string) []string { - res := []string{} - for _, s := range methods { - if s != "publickey" { - res = append(res, s) + proxyUpstreams := make([]netip.Prefix, 0) + for _, cidr := range config.ProxyCIDRs { + network, err := netip.ParsePrefix(cidr) + if err != nil { + return nil, err } + proxyUpstreams = append(proxyUpstreams, network) } - return res + sshmux := &Server{ + Banner: config.Banner, + SSHConfig: sshConfig, + ProxyUpstreams: proxyUpstreams, + Authenticator: makeAuthenticator(config), + Logger: makeLogger(config.Logger), + UsernamePolicy: UsernamePolicyConfig{ + InvalidUsername: config.InvalidUsername, + InvalidUsernameMessage: config.InvalidUsernameMessage, + }, + PasswordPolicy: PasswordPolicyConfig{ + AllUsernameNoPassword: config.AllUsernameNoPassword, + UsernameNoPassword: config.UsernameNoPassword, + }, + } + return sshmux, nil } -func handshake(session *ssh.PipeSession) error { +func (s *Server) Handshake(session *ssh.PipeSession) error { hasSetUser := false var user string var upstream *UpstreamInformation - if config.Banner != "" { - err := session.Downstream.SendBanner(config.Banner) + if s.Banner != "" { + err := session.Downstream.SendBanner(s.Banner) if err != nil { return err } @@ -191,16 +86,16 @@ func handshake(session *ssh.PipeSession) error { session.Downstream.SetUser(user) hasSetUser = true } - if slices.Contains(config.InvalidUsername, user) { + if slices.Contains(s.UsernamePolicy.InvalidUsername, user) { // 15: SSH_DISCONNECT_ILLEGAL_USER_NAME - msg := fmt.Sprintf(config.InvalidUsernameMessage, user) + msg := fmt.Sprintf(s.UsernamePolicy.InvalidUsernameMessage, user) session.Downstream.WriteDisconnectMsg(15, msg) return fmt.Errorf("ssh: invalid username") } if req.Method == "none" { session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) } else if req.Method == "publickey" && !req.IsPublicKeyQuery { - upstream, err = authUserWithPublicKey(*req.PublicKey, user) + upstream, err = s.Authenticator.AuthUserWithPublicKey(*req.PublicKey, user) if err != nil { return err } @@ -209,9 +104,10 @@ func handshake(session *ssh.PipeSession) error { } session.Downstream.WriteAuthFailure([]string{"publickey", "keyboard-interactive"}, false) } else if req.Method == "keyboard-interactive" { - requireUnixPassword := !config.AllUsernameNoPassword && - !slices.Contains(config.RecoveryUsername, user) && - !slices.Contains(config.UsernameNoPassword, user) + // FIXME: Can this be handled by API server? + requireUnixPassword := !s.PasswordPolicy.AllUsernameNoPassword && + !slices.Contains(s.Authenticator.Recovery.Username, user) && + !slices.Contains(s.PasswordPolicy.UsernameNoPassword, user) interactiveQuestions := []string{"Vlab username (Student ID): ", "Vlab password: "} interactiveEcho := []bool{true, false} @@ -226,7 +122,7 @@ func handshake(session *ssh.PipeSession) error { } username := answers[0] password := answers[1] - upstream, err = authUserWithUserPass(username, password, user) + upstream, err = s.Authenticator.AuthUserWithUserPass(username, password, user) if err != nil { return err } @@ -262,11 +158,11 @@ func handshake(session *ssh.PipeSession) error { return err } } - config := &ssh.ClientConfig{ + sshConfig := &ssh.ClientConfig{ User: user, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - err = session.InitUpstream(conn, upstream.Host, config) + err = session.InitUpstream(conn, upstream.Host, sshConfig) if err != nil { return err } @@ -327,46 +223,13 @@ func handshake(session *ssh.PipeSession) error { } } -func runPipeSession(session *ssh.PipeSession, logMessage *LogMessage) error { - err := handshake(session) - if err != nil { - return err - } - logMessage.Username = session.Downstream.User() - logMessage.HostIp = session.Upstream.RemoteAddr().String() - return session.RunPipe() -} - -func runLogger(ch <-chan LogMessage) { - conn, err := net.Dial("udp", config.Logger) - if err != nil { - log.Printf("Logger Dial failed: %s\n", err) - // Drain the channel to avoid blocking - for range ch { - } - } - for logMessage := range ch { - jsonMsg, err := json.Marshal(logMessage) - if err != nil { - continue - } - conn.Write(jsonMsg) - } -} - -func sendLogAndClose(logMessage *LogMessage, session *ssh.PipeSession, logCh chan<- LogMessage) { - session.Close() - logMessage.DisconnectTime = time.Now().Unix() - logCh <- *logMessage -} - -func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstreams []netip.Prefix) { +func (s *Server) ListenAddr(address string) error { // set up TCP listener listener, err := net.Listen("tcp", address) if err != nil { log.Fatal(err) } - if len(proxyUpstreams) > 0 { + if len(s.ProxyUpstreams) > 0 { listener = &proxyproto.Listener{ Listener: listener, Policy: func(upstream net.Addr) (proxyproto.Policy, error) { @@ -377,7 +240,7 @@ func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstream } upstreamAddr := upstreamAddrPort.Addr() // only read PROXY header from allowed CIDRs - for _, network := range proxyUpstreams { + for _, network := range s.ProxyUpstreams { if network.Contains(upstreamAddr) { return proxyproto.USE, nil } @@ -389,10 +252,6 @@ func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstream } defer listener.Close() - // set up log channel - logCh := make(chan LogMessage, 256) - go runLogger(logCh) - // main handler loop for { conn, err := listener.Accept() @@ -401,7 +260,7 @@ func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstream continue } go func() { - session, err := ssh.NewPipeSession(conn, sshConfig) + session, err := ssh.NewPipeSession(conn, s.SSHConfig) logMessage := LogMessage{ LoginTime: time.Now().Unix(), ClientIp: conn.RemoteAddr().String(), @@ -409,51 +268,23 @@ func sshmuxListenAddr(address string, sshConfig *ssh.ServerConfig, proxyUpstream if err != nil { return } - defer sendLogAndClose(&logMessage, session, logCh) - if err := runPipeSession(session, &logMessage); err != nil { + defer func() { + session.Close() + s.Logger.SendLog(&logMessage) + }() + if err := s.RunPipeSession(session, &logMessage); err != nil { log.Println("runPipeSession:", err) } }() } } -func sshmuxServer(configFile string) { - configFileBytes, err := os.ReadFile(configFile) - if err != nil { - log.Fatal(err) - } - err = json.Unmarshal(configFileBytes, &config) +func (s *Server) RunPipeSession(session *ssh.PipeSession, logMessage *LogMessage) error { + err := s.Handshake(session) if err != nil { - log.Fatal(err) - } - sshConfig := &ssh.ServerConfig{ - ServerVersion: "SSH-2.0-taokystrong", - PublicKeyAuthAlgorithms: ssh.DefaultPubKeyAuthAlgos(), - } - for _, keyFile := range config.HostKeys { - bytes, err := os.ReadFile(keyFile) - if err != nil { - log.Fatal(err) - } - key, err := ssh.ParsePrivateKey(bytes) - if err != nil { - log.Fatal(err) - } - sshConfig.AddHostKey(key) - } - proxyUpstreams := make([]netip.Prefix, 0) - for _, cidr := range config.ProxyCIDRs { - network, err := netip.ParsePrefix(cidr) - if err != nil { - log.Fatal(err) - } - proxyUpstreams = append(proxyUpstreams, network) + return err } - sshmuxListenAddr(config.Address, sshConfig, proxyUpstreams) -} - -func main() { - flag.StringVar(&configFile, "c", "/etc/sshmux/config.json", "config file") - flag.Parse() - sshmuxServer(configFile) + logMessage.Username = session.Downstream.User() + logMessage.HostIp = session.Upstream.RemoteAddr().String() + return session.RunPipe() }