From 928f9d621c074aeba916a943f011195e93e588a7 Mon Sep 17 00:00:00 2001 From: YR Chen Date: Sat, 24 Aug 2024 21:29:03 +0800 Subject: [PATCH] More robust endpoint URL parsing --- auth.go | 22 ++++++++++++++++------ sshmux.go | 6 +++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/auth.go b/auth.go index 4f0ebaf..03c638b 100644 --- a/auth.go +++ b/auth.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/url" "golang.org/x/crypto/ssh" ) @@ -52,7 +53,7 @@ type Authenticator interface { Auth(request AuthRequest, username string) (int, *AuthResponse, error) } -func makeAuthenticator(auth AuthConfig) Authenticator { +func makeAuthenticator(auth AuthConfig) (Authenticator, error) { if auth.Version == "" { auth.Version = "v1" } @@ -60,15 +61,20 @@ func makeAuthenticator(auth AuthConfig) Authenticator { for _, header := range auth.Headers { headers.Add(header.Name, header.Value) } - return &RESTfulAuthenticator{ - Endpoint: auth.Endpoint, + auth_url, err := url.Parse(auth.Endpoint) + if err != nil { + return nil, err + } + authenticator := RESTfulAuthenticator{ + Endpoint: auth_url, Version: auth.Version, Headers: headers, } + return &authenticator, nil } type RESTfulAuthenticator struct { - Endpoint string + Endpoint *url.URL Version string Headers http.Header } @@ -77,16 +83,19 @@ func (auth *RESTfulAuthenticator) Auth(request AuthRequest, username string) (in if auth.Version != "v1" { return 500, nil, fmt.Errorf("unsupported API version: %s", auth.Version) } - url := fmt.Sprintf("%s/v1/auth/%s", auth.Endpoint, username) + auth_url := auth.Endpoint.JoinPath("v1", "auth", username).String() + payload := new(bytes.Buffer) if err := json.NewEncoder(payload).Encode(request); err != nil { return 0, nil, err } - req, err := http.NewRequest("POST", url, payload) + + req, err := http.NewRequest("POST", auth_url, payload) if err != nil { return 0, nil, err } req.Header = auth.Headers + res, err := http.DefaultClient.Do(req) if err != nil { return res.StatusCode, nil, err @@ -96,6 +105,7 @@ func (auth *RESTfulAuthenticator) Auth(request AuthRequest, username string) (in if err != nil { return res.StatusCode, nil, err } + var response AuthResponse err = json.Unmarshal(body, &response) if err != nil { diff --git a/sshmux.go b/sshmux.go index 9d3b325..3a94a38 100644 --- a/sshmux.go +++ b/sshmux.go @@ -106,7 +106,11 @@ func makeServer(config Config) (*Server, error) { legacyAuthenticator := makeLegacyAuthenticator(config.Auth, config.Recovery) authenticator = &legacyAuthenticator } else { - authenticator = makeAuthenticator(config.Auth) + var err error + authenticator, err = makeAuthenticator(config.Auth) + if err != nil { + return nil, err + } } sshmux := &Server{ Address: config.Address,