diff --git a/client.go b/client.go index 4f2af70..ad0aded 100644 --- a/client.go +++ b/client.go @@ -18,16 +18,68 @@ type Client struct { expiration time.Time token *Token m *sync.Mutex + stopChan chan struct{} } -type refreshTokenTransport struct { +type Transport struct { rt http.RoundTripper cli *Client } -func (t refreshTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { - var err error +func (c *Client) refreshTokenIfNeeded() error { + c.m.Lock() + defer c.m.Unlock() + if time.Now().Add(time.Minute).Before(c.expiration) { + return nil + } else { + // Refresh the token if its expiration is less than a minute away + newToken, err := c.refreshToken(c.token.Refresh) + if err != nil { + return err + } + c.token = newToken + c.expiration = time.Now().Add(time.Duration(newToken.RefreshExpires-60) * time.Second) + return nil + } +} + +func (c *Client) StartTokenHandler() { + c.stopChan = make(chan struct{}) + + // Initialize the first token and start the token handler + token, err := c.newToken() + if err != nil { + panic("Failed to get initial token: " + err.Error()) + } + c.token = token + + go func() { + for { + timeToWait := time.Until(c.expiration) - time.Minute + if timeToWait < 0 { + // If the token is already expired, try to refresh immediately + timeToWait = 0 + } + + select { + case <-c.stopChan: + return + case <-time.After(timeToWait): + if err := c.refreshTokenIfNeeded(); err != nil { + // TODO(Martin): add retry logic + panic("Failed to refresh token: " + err.Error()) + } + } + } + }() +} + +func (c *Client) StopTokenHandler() { + close(c.stopChan) +} + +func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) { req.URL.Scheme = "https" req.URL.Host = baseUrl req.URL.Path = strings.Join([]string{apiPath, req.URL.Path}, "/") @@ -35,33 +87,28 @@ func (t refreshTokenTransport) RoundTrip(req *http.Request) (*http.Response, err req.Header.Add("Content-Type", "application/json") req.Header.Add("Accept", "application/json") - t.cli.m.Lock() - - if t.cli.expiration.Before(time.Now()) { - t.cli.token, err = t.cli.refreshToken(t.cli.token.Refresh) - - if err != nil { - return nil, err - } - t.cli.expiration = t.cli.expiration.Add(time.Duration(t.cli.token.RefreshExpires-60) * time.Second) + // Add the access token to the request if it exists + if t.cli.token != nil { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access)) } - t.cli.m.Unlock() - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.cli.token.Access)) return t.rt.RoundTrip(req) } +// NewClient creates a new Nordigen client that handles token refreshes and adds +// the necessary headers, host, and path to all requests. func NewClient(secretId, secretKey string) (*Client, error) { - var err error + c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}, + secretId: secretId, + secretKey: secretKey, + } - c := &Client{c: &http.Client{Timeout: 60 * time.Second}, m: &sync.Mutex{}} - c.token, err = c.newToken(secretId, secretKey) + // Add transport to handle headers, host and path for all requests + c.c.Transport = Transport{rt: http.DefaultTransport, cli: c} - if err != nil { - return nil, err - } - c.c.Transport = refreshTokenTransport{rt: http.DefaultTransport, cli: c} - c.expiration = time.Now().Add(time.Duration(c.token.AccessExpires-60) * time.Second) + // Start token handler + c.StartTokenHandler() + defer c.StopTokenHandler() return c, nil } diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..e452181 --- /dev/null +++ b/client_test.go @@ -0,0 +1,30 @@ +package nordigen + +import ( + "os" + "testing" + "time" +) + +// TestClientTokenRefresh should do a successful token refresh. We force this by +// setting the expiration to a time in the past and then calling any method. +// This test will only run if you have a valid secretId and secretKey in your +// environment. +func TestClientTokenRefresh(t *testing.T) { + id, id_exists := os.LookupEnv("NORDIGEN_SECRET_ID") + key, key_exists := os.LookupEnv("NORDIGEN_SECRET_KEY") + if !id_exists || !key_exists { + t.Skip("NORDIGEN_SECRET_ID and NORDIGEN_SECRET_KEY not set") + } + + c, err := NewClient(id, key) + if err != nil { + t.Fatalf("NewClient: %s", err) + } + + c.expiration = time.Now().Add(-time.Hour) + _, err = c.ListRequisitions() + if err != nil { + t.Fatalf("ListRequisitions: %s", err) + } +} diff --git a/token.go b/token.go index 49a60a1..77fb6c2 100644 --- a/token.go +++ b/token.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -26,22 +25,17 @@ const tokenPath = "token" const tokenNewPath = "new/" const tokenRefreshPath = "refresh" -func (c Client) newToken(secretId, secretKey string) (*Token, error) { +func (c Client) newToken() (*Token, error) { req := http.Request{ Method: http.MethodPost, URL: &url.URL{ - Scheme: "https", - Host: baseUrl, - Path: strings.Join([]string{apiPath, tokenPath, tokenNewPath}, "/"), + Path: strings.Join([]string{tokenPath, tokenNewPath}, "/"), }, } - req.Header = http.Header{} - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") data, err := json.Marshal(Secret{ - SecretId: secretId, - AccessId: secretKey, + SecretId: c.secretId, + AccessId: c.secretKey, }) if err != nil { return nil, err @@ -52,7 +46,7 @@ func (c Client) newToken(secretId, secretKey string) (*Token, error) { if err != nil { return nil, err } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err @@ -89,7 +83,7 @@ func (c Client) refreshToken(refresh string) (*Token, error) { if err != nil { return nil, err } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { return nil, err