Skip to content

Commit

Permalink
Merge pull request #1765 from smallstep/mariano/init-provisioners
Browse files Browse the repository at this point in the history
Do not fail if a provisioner cannot be initialized
  • Loading branch information
maraino authored Jul 11, 2024
2 parents b6da1de + 343e730 commit 383d281
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 8 deletions.
5 changes: 4 additions & 1 deletion authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ func (a *Authority) ReloadAdminResources(ctx context.Context) error {
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
for _, p := range provList {
if err := p.Init(provisionerConfig); err != nil {
return err
log.Printf("failed to initialize %s provisioner %q: %v\n", p.GetType(), p.GetName(), err)
p = provisioner.Uninitialized{
Interface: p, Reason: err,
}
}
if err := provClxn.Store(p); err != nil {
return err
Expand Down
9 changes: 9 additions & 0 deletions authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ func testAuthority(t *testing.T, opts ...Option) *Authority {
EnableSSHCA: &enableSSHCA,
},
},
&provisioner.JWK{
Name: "uninitialized",
Type: "JWK",
Key: clijwk,
Claims: &provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: time.Minute},
},
},
}
c := &Config{
Address: "127.0.0.1:443",
Expand Down
4 changes: 4 additions & 0 deletions authority/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface
if !ok {
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// If the provisioner is disabled, send an appropriate message to the client
if _, ok := p.(provisioner.Uninitialized); ok {
return nil, nil, errs.New(http.StatusUnauthorized, "provisioner %q is disabled due to an initialization error", p.GetName())
}

return p, &claims, nil
}
Expand Down
19 changes: 19 additions & 0 deletions authority/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"

"github.com/google/uuid"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/provisioner"
Expand Down Expand Up @@ -304,6 +305,24 @@ func TestAuthority_authorizeToken(t *testing.T) {
code: http.StatusUnauthorized,
}
},
"fail/uninitialized": func(t *testing.T) *authorizeTest {
cl := jose.Claims{
Subject: "test.smallstep.com",
Issuer: "uninitialized",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(time.Minute)),
Audience: validAudience,
ID: uuid.NewString(),
}
raw, err := jose.Signed(sig).Claims(cl).CompactSerialize()
assert.FatalError(t, err)
return &authorizeTest{
auth: a,
token: raw,
err: errors.New(`provisioner "uninitialized" is disabled due to an initialization error`),
code: http.StatusUnauthorized,
}
},
}

for name, genTestCase := range tests {
Expand Down
25 changes: 25 additions & 0 deletions authority/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ type Interface interface {
AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error)
}

// Uninitialized represents a disabled provisioner. Uninitialized provisioners
// are created when the Init methods fails.
type Uninitialized struct {
Interface
Reason error
}

// MarshalJSON returns the JSON encoding of the provisioner with the disabled
// reason.
func (p Uninitialized) MarshalJSON() ([]byte, error) {
provisionerJSON, err := json.Marshal(p.Interface)
if err != nil {
return nil, err
}
reasonJSON, err := json.Marshal(struct {
State string `json:"state"`
StateReason string `json:"stateReason"`
}{"Uninitialized", p.Reason.Error()})
if err != nil {
return nil, err
}
reasonJSON[0] = ','
return append(provisionerJSON[:len(provisionerJSON)-1], reasonJSON...), nil
}

// ErrAllowTokenReuse is an error that is returned by provisioners that allows
// the reuse of tokens.
//
Expand Down
48 changes: 41 additions & 7 deletions authority/provisioner/provisioner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"net/http"
"testing"

"golang.org/x/crypto/ssh"

"github.com/smallstep/assert"
"github.com/go-jose/go-jose/v3"
"github.com/smallstep/certificates/api/render"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)

func TestType_String(t *testing.T) {
Expand Down Expand Up @@ -149,11 +149,11 @@ func TestDefaultIdentityFunc(t *testing.T) {
identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email)
if err != nil {
if assert.NotNil(t, tc.err) {
assert.Equals(t, tc.err.Error(), err.Error())
assert.Equal(t, tc.err.Error(), err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, identity.Usernames, tc.identity.Usernames)
assert.Equal(t, identity.Usernames, tc.identity.Usernames)
}
}
})
Expand Down Expand Up @@ -243,9 +243,43 @@ func TestUnimplementedMethods(t *testing.T) {
}
var sc render.StatusCodedError
if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") {
assert.Equals(t, sc.StatusCode(), http.StatusUnauthorized)
assert.Equal(t, http.StatusUnauthorized, sc.StatusCode())
}
assert.Equal(t, msg, err.Error())
})
}
}

func TestUninitialized_MarshalJSON(t *testing.T) {
p := &JWK{
Name: "bad-provisioner",
Type: "JWK",
Key: &jose.JSONWebKey{
Key: []byte("foo"),
},
}

type fields struct {
Interface Interface
Reason error
}
tests := []struct {
name string
fields fields
want []byte
assertion assert.ErrorAssertionFunc
}{
{"ok", fields{p, errors.New("bad key")}, []byte(`{"type":"JWK","name":"bad-provisioner","key":{"kty":"oct","k":"Zm9v"},"state":"Uninitialized","stateReason":"bad key"}`), assert.NoError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := Uninitialized{
Interface: tt.fields.Interface,
Reason: tt.fields.Reason,
}
assert.Equals(t, err.Error(), msg)
got, err := p.MarshalJSON()
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 383d281

Please sign in to comment.