Skip to content

Commit

Permalink
Handle regen if priv key missing
Browse files Browse the repository at this point in the history
  • Loading branch information
terev committed Nov 1, 2024
1 parent a3ed44a commit 1900434
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cmd/server/helper_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func GetOrCreateTLSCertificate(ctx context.Context, d driver.Registry, iface con
}

// no certificates configured: self-sign a new cert
priv, err := jwk.GetOrGenerateKeys(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
priv, err := jwk.GetOrGenerateKeySetPrivateKey(ctx, d, d.SoftwareKeyManager(), TlsKeyName, uuid.Must(uuid.NewV4()).String(), "RS256")
if err != nil {
d.Logger().WithError(err).Fatal("Unable to fetch or generate HTTPS TLS key pair")
return nil // in case Fatal is hooked
Expand Down
16 changes: 13 additions & 3 deletions jwk/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,27 @@ import (
var jwkGenFlightGroup singleflight.Group

func EnsureAsymmetricKeypairExists(ctx context.Context, r InternalRegistry, alg, set string) error {
_, err := GetOrGenerateKeys(ctx, r, r.KeyManager(), set, set, alg)
_, err := GetOrGenerateKeySetPrivateKey(ctx, r, r.KeyManager(), set, set, alg)
return err
}

func GetOrGenerateKeys(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) {
func GetOrGenerateKeySetPrivateKey(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (private *jose.JSONWebKey, err error) {
keySet, err := GetOrGenerateKeySet(ctx, r, m, set, kid, alg)
if err != nil {
return nil, err
}

privKey, err := FindPrivateKey(keySet)
if err == nil {
return privKey, nil
}

keySet, err = generateKeySet(ctx, r, m, set, kid, alg)
if err != nil {
return nil, err
}
return privKey, nil

return FindPrivateKey(keySet)
}

func GetOrGenerateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) {
Expand All @@ -54,6 +60,10 @@ func GetOrGenerateKeySet(ctx context.Context, r InternalRegistry, m Manager, set
return keys, nil
}

return generateKeySet(ctx, r, m, set, kid, alg)
}

func generateKeySet(ctx context.Context, r InternalRegistry, m Manager, set, kid, alg string) (*jose.JSONWebKeySet, error) {
// Suppress duplicate key set generation jobs where the set+alg match.
keysResult, err, _ := jwkGenFlightGroup.Do(set+alg, func() (any, error) {
r.Logger().WithField("jwks", set).Warnf("JSON Web Key not found in JSON Web Key Set %s, generating new key pair...", set)
Expand Down
13 changes: 7 additions & 6 deletions jwk/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/x/contextx"

"github.com/ory/hydra/v2/internal"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
)

type fakeSigner struct {
Expand Down Expand Up @@ -229,7 +230,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
t.Run("Test_Helper/Run_GetOrGenerateKeys_With_GetKeySetError", func(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})
Expand All @@ -238,7 +239,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(nil, errors.Wrap(x.ErrNotFound, ""))
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})
Expand All @@ -247,7 +248,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(nil, errors.New("GetKeySetError"))
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "GetKeySetError")
})
Expand All @@ -256,7 +257,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySet, nil)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.NoError(t, err)
assert.Equal(t, privKey, &keySet.Keys[0])
})
Expand All @@ -265,7 +266,7 @@ func TestGetOrGenerateKeys(t *testing.T) {
keyManager := km(t)
keyManager.EXPECT().GetKeySet(gomock.Any(), gomock.Eq(setId)).Return(keySetWithoutPrivateKey, nil)
keyManager.EXPECT().GenerateAndPersistKeySet(gomock.Any(), gomock.Eq(setId), gomock.Eq(keyId), gomock.Eq("RS256"), gomock.Eq("sig")).Return(keySetWithoutPrivateKey, nil).Times(1)
privKey, err := jwk.GetOrGenerateKeys(context.TODO(), reg, keyManager, setId, keyId, "RS256")
privKey, err := jwk.GetOrGenerateKeySetPrivateKey(context.TODO(), reg, keyManager, setId, keyId, "RS256")
assert.Nil(t, privKey)
assert.EqualError(t, err, "key not found")
})
Expand Down
3 changes: 2 additions & 1 deletion jwk/jwt_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gofrs/uuid"

"github.com/ory/fosite"

"github.com/ory/hydra/v2/driver/config"

"github.com/pkg/errors"
Expand Down Expand Up @@ -40,7 +41,7 @@ func NewDefaultJWTSigner(c *config.DefaultProvider, r InternalRegistry, setID st
}

func (j *DefaultJWTSigner) getKeys(ctx context.Context) (private *jose.JSONWebKey, err error) {
private, err = GetOrGenerateKeys(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
private, err = GetOrGenerateKeySetPrivateKey(ctx, j.r, j.r.KeyManager(), j.setID, uuid.Must(uuid.NewV4()).String(), string(jose.RS256))
if err == nil {
return private, nil
}
Expand Down

0 comments on commit 1900434

Please sign in to comment.