Skip to content

Commit

Permalink
feat(op): ID token for device authorization grant (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlemmer authored Dec 18, 2023
1 parent 7bdaf9c commit b300027
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 42 deletions.
91 changes: 65 additions & 26 deletions pkg/op/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
strs "github.com/zitadel/oidc/v3/pkg/strings"
)

type DeviceAuthorizationConfig struct {
Expand Down Expand Up @@ -185,24 +186,6 @@ func NewUserCode(charSet []rune, charAmount, dashInterval int) (string, error) {
return buf.String(), nil
}

type deviceAccessTokenRequest struct {
subject string
audience []string
scopes []string
}

func (r *deviceAccessTokenRequest) GetSubject() string {
return r.subject
}

func (r *deviceAccessTokenRequest) GetAudience() []string {
return r.audience
}

func (r *deviceAccessTokenRequest) GetScopes() []string {
return r.scopes
}

func DeviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
ctx, span := tracer.Start(r.Context(), "DeviceAccessToken")
defer span.End()
Expand All @@ -229,7 +212,7 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
if err != nil {
return err
}
state, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
tokenRequest, err := CheckDeviceAuthorizationState(ctx, clientID, req.DeviceCode, exchanger)
if err != nil {
return err
}
Expand All @@ -243,11 +226,6 @@ func deviceAccessToken(w http.ResponseWriter, r *http.Request, exchanger Exchang
WithDescription("confidential client requires authentication")
}

tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{clientID},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(r.Context(), tokenRequest, exchanger, client)
if err != nil {
return err
Expand All @@ -265,6 +243,50 @@ func ParseDeviceAccessTokenRequest(r *http.Request, exchanger Exchanger) (*oidc.
return req, nil
}

// DeviceAuthorizationState describes the current state of
// the device authorization flow.
// It implements the [IDTokenRequest] interface.
type DeviceAuthorizationState struct {
ClientID string
Audience []string
Scopes []string
Expires time.Time // The time after we consider the authorization request timed-out
Done bool // The user authenticated and approved the authorization request
Denied bool // The user authenticated and denied the authorization request

// The following fields are populated after Done == true
Subject string
AMR []string
AuthTime time.Time
}

func (r *DeviceAuthorizationState) GetAMR() []string {
return r.AMR
}

func (r *DeviceAuthorizationState) GetAudience() []string {
if !strs.Contains(r.Audience, r.ClientID) {
r.Audience = append(r.Audience, r.ClientID)
}
return r.Audience
}

func (r *DeviceAuthorizationState) GetAuthTime() time.Time {
return r.AuthTime
}

func (r *DeviceAuthorizationState) GetClientID() string {
return r.ClientID
}

func (r *DeviceAuthorizationState) GetScopes() []string {
return r.Scopes
}

func (r *DeviceAuthorizationState) GetSubject() string {
return r.Subject
}

func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string, exchanger Exchanger) (*DeviceAuthorizationState, error) {
storage, err := assertDeviceStorage(exchanger.Storage())
if err != nil {
Expand All @@ -291,15 +313,32 @@ func CheckDeviceAuthorizationState(ctx context.Context, clientID, deviceCode str
}

func CreateDeviceTokenResponse(ctx context.Context, tokenRequest TokenRequest, creator TokenCreator, client Client) (*oidc.AccessTokenResponse, error) {
/* TODO(v4):
Change the TokenRequest argument type to *DeviceAuthorizationState.
Breaking change that can not be done for v3.
*/
ctx, span := tracer.Start(ctx, "CreateDeviceTokenResponse")
defer span.End()

accessToken, refreshToken, validity, err := CreateAccessToken(ctx, tokenRequest, client.AccessTokenType(), creator, client, "")
if err != nil {
return nil, err
}

return &oidc.AccessTokenResponse{
response := &oidc.AccessTokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: oidc.BearerToken,
ExpiresIn: uint64(validity.Seconds()),
}, nil
}

// TODO(v4): remove type assertion
if idTokenRequest, ok := tokenRequest.(IDTokenRequest); ok && strs.Contains(tokenRequest.GetScopes(), oidc.ScopeOpenID) {
response.IDToken, err = CreateIDToken(ctx, IssuerFromContext(ctx), idTokenRequest, client.IDTokenLifetime(), accessToken, "", creator.Storage(), client)
if err != nil {
return nil, err
}
}

return response, nil
}
93 changes: 93 additions & 0 deletions pkg/op/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,96 @@ func TestCheckDeviceAuthorizationState(t *testing.T) {
})
}
}

func TestCreateDeviceTokenResponse(t *testing.T) {
tests := []struct {
name string
tokenRequest op.TokenRequest
wantAccessToken bool
wantRefreshToken bool
wantIDToken bool
wantErr bool
}{
{
name: "access token",
tokenRequest: &op.DeviceAuthorizationState{
ClientID: "client1",
Subject: "id1",
AMR: []string{"password"},
AuthTime: time.Now(),
},
wantAccessToken: true,
},
{
name: "access and refresh tokens",
tokenRequest: &op.DeviceAuthorizationState{
ClientID: "client1",
Subject: "id1",
AMR: []string{"password"},
AuthTime: time.Now(),
Scopes: []string{oidc.ScopeOfflineAccess},
},
wantAccessToken: true,
wantRefreshToken: true,
},
{
name: "access and id token",
tokenRequest: &op.DeviceAuthorizationState{
ClientID: "client1",
Subject: "id1",
AMR: []string{"password"},
AuthTime: time.Now(),
Scopes: []string{oidc.ScopeOpenID},
},
wantAccessToken: true,
wantIDToken: true,
},
{
name: "access, refresh and id token",
tokenRequest: &op.DeviceAuthorizationState{
ClientID: "client1",
Subject: "id1",
AMR: []string{"password"},
AuthTime: time.Now(),
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
},
wantAccessToken: true,
wantRefreshToken: true,
wantIDToken: true,
},
{
name: "id token creation error",
tokenRequest: &op.DeviceAuthorizationState{
ClientID: "client1",
Subject: "foobar",
AMR: []string{"password"},
AuthTime: time.Now(),
Scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := testProvider.Storage().GetClientByClientID(context.Background(), "native")
require.NoError(t, err)

got, err := op.CreateDeviceTokenResponse(context.Background(), tt.tokenRequest, testProvider, client)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.InDelta(t, 300, got.ExpiresIn, 2)
if tt.wantAccessToken {
assert.NotEmpty(t, got.AccessToken, "access token")
}
if tt.wantRefreshToken {
assert.NotEmpty(t, got.RefreshToken, "refresh token")
}
if tt.wantIDToken {
assert.NotEmpty(t, got.IDToken, "id token")
}
})
}
}
9 changes: 2 additions & 7 deletions pkg/op/server_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,23 +291,18 @@ func (s *LegacyServer) ClientCredentialsExchange(ctx context.Context, r *ClientR
}

func (s *LegacyServer) DeviceToken(ctx context.Context, r *ClientRequest[oidc.DeviceAccessTokenRequest]) (*Response, error) {
if !s.provider.GrantTypeClientCredentialsSupported() {
if !s.provider.GrantTypeDeviceCodeSupported() {
return nil, unimplementedGrantError(oidc.GrantTypeDeviceCode)
}
// use a limited context timeout shorter as the default
// poll interval of 5 seconds.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()

state, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
tokenRequest, err := CheckDeviceAuthorizationState(ctx, r.Client.GetID(), r.Data.DeviceCode, s.provider)
if err != nil {
return nil, err
}
tokenRequest := &deviceAccessTokenRequest{
subject: state.Subject,
audience: []string{r.Client.GetID()},
scopes: state.Scopes,
}
resp, err := CreateDeviceTokenResponse(ctx, tokenRequest, s.provider, r.Client)
if err != nil {
return nil, err
Expand Down
9 changes: 0 additions & 9 deletions pkg/op/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,6 @@ type EndSessionRequest struct {

var ErrDuplicateUserCode = errors.New("user code already exists")

type DeviceAuthorizationState struct {
ClientID string
Scopes []string
Expires time.Time
Done bool
Subject string
Denied bool
}

type DeviceAuthorizationStorage interface {
// StoreDeviceAuthorizationRequest stores a new device authorization request in the database.
// User code will be used by the user to complete the login flow and must be unique.
Expand Down
2 changes: 2 additions & 0 deletions pkg/op/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ func needsRefreshToken(tokenRequest TokenRequest, client AccessTokenClient) bool
return req.GetRequestedTokenType() == oidc.RefreshTokenType
case RefreshTokenRequest:
return true
case *DeviceAuthorizationState:
return strings.Contains(req.GetScopes(), oidc.ScopeOfflineAccess) && ValidateGrantType(client, oidc.GrantTypeRefreshToken)
default:
return false
}
Expand Down

0 comments on commit b300027

Please sign in to comment.