Skip to content

Commit

Permalink
More robust endpoint URL parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
stevapple committed Aug 24, 2024
1 parent 9f2ad2e commit 928f9d6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
22 changes: 16 additions & 6 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"

"golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -52,23 +53,28 @@ type Authenticator interface {
Auth(request AuthRequest, username string) (int, *AuthResponse, error)
}

func makeAuthenticator(auth AuthConfig) Authenticator {
func makeAuthenticator(auth AuthConfig) (Authenticator, error) {
if auth.Version == "" {
auth.Version = "v1"
}
headers := http.Header{}
for _, header := range auth.Headers {
headers.Add(header.Name, header.Value)
}
return &RESTfulAuthenticator{
Endpoint: auth.Endpoint,
auth_url, err := url.Parse(auth.Endpoint)
if err != nil {
return nil, err
}
authenticator := RESTfulAuthenticator{
Endpoint: auth_url,
Version: auth.Version,
Headers: headers,
}
return &authenticator, nil
}

type RESTfulAuthenticator struct {
Endpoint string
Endpoint *url.URL
Version string
Headers http.Header
}
Expand All @@ -77,16 +83,19 @@ func (auth *RESTfulAuthenticator) Auth(request AuthRequest, username string) (in
if auth.Version != "v1" {
return 500, nil, fmt.Errorf("unsupported API version: %s", auth.Version)
}
url := fmt.Sprintf("%s/v1/auth/%s", auth.Endpoint, username)
auth_url := auth.Endpoint.JoinPath("v1", "auth", username).String()

payload := new(bytes.Buffer)
if err := json.NewEncoder(payload).Encode(request); err != nil {
return 0, nil, err
}
req, err := http.NewRequest("POST", url, payload)

req, err := http.NewRequest("POST", auth_url, payload)
if err != nil {
return 0, nil, err
}
req.Header = auth.Headers

res, err := http.DefaultClient.Do(req)
if err != nil {
return res.StatusCode, nil, err
Expand All @@ -96,6 +105,7 @@ func (auth *RESTfulAuthenticator) Auth(request AuthRequest, username string) (in
if err != nil {
return res.StatusCode, nil, err
}

var response AuthResponse
err = json.Unmarshal(body, &response)
if err != nil {
Expand Down
6 changes: 5 additions & 1 deletion sshmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func makeServer(config Config) (*Server, error) {
legacyAuthenticator := makeLegacyAuthenticator(config.Auth, config.Recovery)
authenticator = &legacyAuthenticator
} else {
authenticator = makeAuthenticator(config.Auth)
var err error
authenticator, err = makeAuthenticator(config.Auth)
if err != nil {
return nil, err
}
}
sshmux := &Server{
Address: config.Address,
Expand Down

0 comments on commit 928f9d6

Please sign in to comment.