From e770faeeed854cfcff45a06100f9a2eef57d779a Mon Sep 17 00:00:00 2001 From: Miguel Martinez Trivino Date: Mon, 13 Mar 2023 16:35:39 +0100 Subject: [PATCH] refactor: consolidate credentials service interface (#25) --------- Signed-off-by: Miguel Martinez Trivino --- app/cli/cmd/workflow_contract_describe.go | 8 +- .../internal/action/workflow_run_describe.go | 6 +- app/controlplane/internal/biz/attestation.go | 2 +- app/controlplane/internal/biz/integration.go | 16 ++- .../internal/biz/ocirepository.go | 11 +- .../internal/biz/ocirepository_test.go | 4 +- .../biz/organization_integration_test.go | 8 +- .../internal/service/attestation.go | 8 +- internal/attestation/crafter/crafter.go | 6 +- internal/blobmanager/oci/provider.go | 7 +- internal/blobmanager/oci/provider_test.go | 2 +- internal/credentials/aws/secretmanager.go | 66 +++-------- .../credentials/aws/secretmanager_test.go | 103 ++++++------------ internal/credentials/credentials.go | 9 +- internal/credentials/credentials_test.go | 76 +++++++++++++ internal/credentials/mocks/Reader.go | 25 +---- internal/credentials/mocks/ReaderWriter.go | 71 +++--------- internal/credentials/mocks/Writer.go | 47 ++------ internal/credentials/vault/keyval.go | 49 ++------- internal/credentials/vault/keyval_test.go | 86 +++++---------- 20 files changed, 248 insertions(+), 362 deletions(-) create mode 100644 internal/credentials/credentials_test.go diff --git a/app/cli/cmd/workflow_contract_describe.go b/app/cli/cmd/workflow_contract_describe.go index 227677549..ad819fcc4 100644 --- a/app/cli/cmd/workflow_contract_describe.go +++ b/app/cli/cmd/workflow_contract_describe.go @@ -70,8 +70,8 @@ func encodeContractOutput(run *action.WorkflowContractWithVersionItem) error { switch flagOutputFormat { case formatContract: - marshaller := protojson.MarshalOptions{Indent: " "} - rawBody, err := marshaller.Marshal(run.Revision.BodyV1) + marshaler := protojson.MarshalOptions{Indent: " "} + rawBody, err := marshaler.Marshal(run.Revision.BodyV1) if err != nil { return err } @@ -86,8 +86,8 @@ func encodeContractOutput(run *action.WorkflowContractWithVersionItem) error { func contractDescribeTableOutput(contractWithVersion *action.WorkflowContractWithVersionItem) error { revision := contractWithVersion.Revision - marshaller := protojson.MarshalOptions{Indent: " "} - rawBody, err := marshaller.Marshal(revision.BodyV1) + marshaler := protojson.MarshalOptions{Indent: " "} + rawBody, err := marshaler.Marshal(revision.BodyV1) if err != nil { return err } diff --git a/app/cli/internal/action/workflow_run_describe.go b/app/cli/internal/action/workflow_run_describe.go index 82226df91..a050547ae 100644 --- a/app/cli/internal/action/workflow_run_describe.go +++ b/app/cli/internal/action/workflow_run_describe.go @@ -122,7 +122,7 @@ func (action *WorkflowRunDescribe) Run(runID string, verify bool, publicKey stri } if err := json.Unmarshal(decodedPayload, statement); err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } var predicate *renderer.ChainloopProvenancePredicateV1 @@ -159,12 +159,12 @@ func (action *WorkflowRunDescribe) Run(runID string, verify bool, publicKey stri func extractPredicateV1(statement *in_toto.Statement) (*renderer.ChainloopProvenancePredicateV1, error) { jsonPredicate, err := json.Marshal(statement.Predicate) if err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } predicate := &renderer.ChainloopProvenancePredicateV1{} if err := json.Unmarshal(jsonPredicate, predicate); err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } return predicate, nil diff --git a/app/controlplane/internal/biz/attestation.go b/app/controlplane/internal/biz/attestation.go index c4c4c8eee..d9ca08b40 100644 --- a/app/controlplane/internal/biz/attestation.go +++ b/app/controlplane/internal/biz/attestation.go @@ -85,7 +85,7 @@ func doUploadToOCI(ctx context.Context, backend backend.Uploader, runID string, fileName := fmt.Sprintf("attestation-%s.json", runID) jsonContent, err := json.Marshal(envelope) if err != nil { - return "", fmt.Errorf("marshalling the envelope: %w", err) + return "", fmt.Errorf("marshaling the envelope: %w", err) } hash := sha256.New() diff --git a/app/controlplane/internal/biz/integration.go b/app/controlplane/internal/biz/integration.go index fc7ba6a20..9987d864d 100644 --- a/app/controlplane/internal/biz/integration.go +++ b/app/controlplane/internal/biz/integration.go @@ -97,8 +97,14 @@ func (uc *IntegrationUseCase) AddDependencyTrack(ctx context.Context, orgID, hos return nil, NewErrInvalidUUID(err) } + // Validate Credentials before saving them + creds := &credentials.APICreds{Host: host, Key: apiKey} + if err := creds.Validate(); err != nil { + return nil, newErrValidation(err) + } + // Create the secret in the external secrets manager - secretID, err := uc.credsRW.SaveAPICreds(ctx, orgID, &credentials.APICreds{Host: host, Key: apiKey}) + secretID, err := uc.credsRW.SaveCredentials(ctx, orgID, creds) if err != nil { return nil, fmt.Errorf("storing the credentials: %w", err) } @@ -160,7 +166,7 @@ func (uc *IntegrationUseCase) Delete(ctx context.Context, orgID, integrationID s if integration.SecretName != "" { uc.logger.Infow("msg", "deleting integration external secrets", "ID", integrationID, "secretName", integration.SecretName) - if err := uc.credsRW.DeleteCreds(ctx, integration.SecretName); err != nil { + if err := uc.credsRW.DeleteCredentials(ctx, integration.SecretName); err != nil { return fmt.Errorf("deleting the credentials: %w", err) } } @@ -274,10 +280,14 @@ func validateAttachment(ctx context.Context, integration *Integration, credsR cr // Check with the actual remote data that an upload would be possible creds := &credentials.APICreds{} - if err := credsR.ReadAPICreds(ctx, integration.SecretName, creds); err != nil { + if err := credsR.ReadCredentials(ctx, integration.SecretName, creds); err != nil { return err } + if err := creds.Validate(); err != nil { + return newErrValidation(err) + } + // Instantiate an actual uploader to see if it would work with the current configuration d, err := dependencytrack.NewSBOMUploader(c.DependencyTrack.GetDomain(), creds.Key, nil, ac.GetDependencyTrack().GetProjectId(), ac.GetDependencyTrack().GetProjectName()) diff --git a/app/controlplane/internal/biz/ocirepository.go b/app/controlplane/internal/biz/ocirepository.go index c365253ef..18a91bc61 100644 --- a/app/controlplane/internal/biz/ocirepository.go +++ b/app/controlplane/internal/biz/ocirepository.go @@ -113,8 +113,13 @@ func (uc *OCIRepositoryUseCase) CreateOrUpdate(ctx context.Context, orgID, repoU return nil, NewErrInvalidUUID(err) } - // Create the secret in the external secrets manager - secretName, err := uc.credsRW.SaveOCICreds(ctx, orgID, &credentials.OCIKeypair{Repo: repoURL, Username: username, Password: password}) + // Validate and store the secret in the external secrets manager + creds := &credentials.OCIKeypair{Repo: repoURL, Username: username, Password: password} + if err := creds.Validate(); err != nil { + return nil, newErrValidation(err) + } + + secretName, err := uc.credsRW.SaveCredentials(ctx, orgID, creds) if err != nil { return nil, fmt.Errorf("storing the credentials: %w", err) } @@ -167,7 +172,7 @@ func (uc *OCIRepositoryUseCase) Delete(ctx context.Context, id string) error { uc.logger.Infow("msg", "deleting OCI repository external secrets", "ID", id, "secretName", repo.SecretName) // Delete the secret in the external secrets manager - if err := uc.credsRW.DeleteCreds(ctx, repo.SecretName); err != nil { + if err := uc.credsRW.DeleteCredentials(ctx, repo.SecretName); err != nil { return fmt.Errorf("deleting the credentials: %w", err) } diff --git a/app/controlplane/internal/biz/ocirepository_test.go b/app/controlplane/internal/biz/ocirepository_test.go index 2f87b1cfc..9d4e5941b 100644 --- a/app/controlplane/internal/biz/ocirepository_test.go +++ b/app/controlplane/internal/biz/ocirepository_test.go @@ -85,7 +85,7 @@ func (s *ociRepositoryTestSuite) TestSaveMainRepoAlreadyExist() { r := &biz.OCIRepository{ID: s.validUUID.String()} ctx := context.Background() s.repo.On("FindMainRepo", ctx, s.validUUID).Return(r, nil) - s.credsRW.On("SaveOCICreds", ctx, s.validUUID.String(), mock.Anything).Return("secret-key", nil) + s.credsRW.On("SaveCredentials", ctx, s.validUUID.String(), mock.Anything).Return("secret-key", nil) s.repo.On("Update", ctx, &biz.OCIRepoUpdateOpts{ ID: s.validUUID, OCIRepoOpts: &biz.OCIRepoOpts{ @@ -105,7 +105,7 @@ func (s *ociRepositoryTestSuite) TestSaveMainRepoOk() { const repo, username, password = "repo", "username", "pass" s.repo.On("FindMainRepo", ctx, s.validUUID).Return(nil, nil) - s.credsRW.On("SaveOCICreds", ctx, s.validUUID.String(), mock.Anything).Return("secret-key", nil) + s.credsRW.On("SaveCredentials", ctx, s.validUUID.String(), mock.Anything).Return("secret-key", nil) newRepo := &biz.OCIRepository{} s.repo.On("Create", ctx, &biz.OCIRepoCreateOpts{ diff --git a/app/controlplane/internal/biz/organization_integration_test.go b/app/controlplane/internal/biz/organization_integration_test.go index 1f45dae1d..447f116e4 100644 --- a/app/controlplane/internal/biz/organization_integration_test.go +++ b/app/controlplane/internal/biz/organization_integration_test.go @@ -53,8 +53,8 @@ func (s *OrgIntegrationTestSuite) TestDeleteOrg() { s.T().Run("org, integrations and repositories deletion", func(t *testing.T) { // Mock calls to credentials deletion for both the integration and the OCI repository - s.mockedCredsReaderWriter.On("DeleteCreds", ctx, "stored-integration-secret").Return(nil) - s.mockedCredsReaderWriter.On("DeleteCreds", ctx, "stored-OCI-secret").Return(nil) + s.mockedCredsReaderWriter.On("DeleteCredentials", ctx, "stored-integration-secret").Return(nil) + s.mockedCredsReaderWriter.On("DeleteCredentials", ctx, "stored-OCI-secret").Return(nil) err := s.Organization.Delete(ctx, s.org.ID) assert.NoError(err) @@ -102,12 +102,12 @@ func (s *OrgIntegrationTestSuite) SetupTest() { // Dependency-track integration credentials s.mockedCredsReaderWriter.On( - "SaveAPICreds", ctx, mock.Anything, &credentials.APICreds{Host: "host", Key: "key"}, + "SaveCredentials", ctx, mock.Anything, &credentials.APICreds{Host: "host", Key: "key"}, ).Return("stored-integration-secret", nil) // OCI repository credentials s.mockedCredsReaderWriter.On( - "SaveOCICreds", ctx, mock.Anything, &credentials.OCIKeypair{Repo: "repo", Username: "username", Password: "pass"}, + "SaveCredentials", ctx, mock.Anything, &credentials.OCIKeypair{Repo: "repo", Username: "username", Password: "pass"}, ).Return("stored-OCI-secret", nil) s.TestingUseCases = testhelpers.NewTestingUseCases(t, testhelpers.WithCredsReaderWriter(s.mockedCredsReaderWriter)) diff --git a/app/controlplane/internal/service/attestation.go b/app/controlplane/internal/service/attestation.go index 187ec46d8..9820bf4f2 100644 --- a/app/controlplane/internal/service/attestation.go +++ b/app/controlplane/internal/service/attestation.go @@ -312,7 +312,7 @@ func extractPredicate(envelope *dsse.Envelope) (*renderer.ChainloopProvenancePre statement := &in_toto.Statement{} if err := json.Unmarshal(decodedPayload, statement); err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } var predicate *renderer.ChainloopProvenancePredicateV1 @@ -370,12 +370,12 @@ func extractMaterials(in []*renderer.ChainloopProvenanceMaterial) []*cpAPI.Attes func extractPredicateV1(statement *in_toto.Statement) (*renderer.ChainloopProvenancePredicateV1, error) { jsonPredicate, err := json.Marshal(statement.Predicate) if err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } predicate := &renderer.ChainloopProvenancePredicateV1{} if err := json.Unmarshal(jsonPredicate, predicate); err != nil { - return nil, fmt.Errorf("unmarshalling predicate: %w", err) + return nil, fmt.Errorf("un-marshaling predicate: %w", err) } return predicate, nil @@ -503,7 +503,7 @@ func doSendToDependencyTrack(ctx context.Context, credsReader credentials.Reader attachmentConfig := i.IntegrationAttachment.Config.GetDependencyTrack() creds := &credentials.APICreds{} - if err := credsReader.ReadAPICreds(ctx, i.SecretName, creds); err != nil { + if err := credsReader.ReadCredentials(ctx, i.SecretName, creds); err != nil { return err } diff --git a/internal/attestation/crafter/crafter.go b/internal/attestation/crafter/crafter.go index 78139e28d..848fdad4d 100644 --- a/internal/attestation/crafter/crafter.go +++ b/internal/attestation/crafter/crafter.go @@ -87,7 +87,7 @@ func NewCrafter(opts ...NewOpt) *Crafter { type InitOpts struct { // Control plane workflow metadata WfInfo *api.WorkflowMetadata - // already marshalled schema + // already marshaled schema SchemaV1 *schemaapi.CraftingSchema // do not record, upload or push attestation DryRun bool @@ -243,8 +243,8 @@ func initialCraftingState(schema *schemaapi.CraftingSchema, wf *api.WorkflowMeta } func persistCraftingState(craftState *api.CraftingState, stateFilePath string) error { - marshaller := protojson.MarshalOptions{Indent: " "} - raw, err := marshaller.Marshal(craftState) + marshaler := protojson.MarshalOptions{Indent: " "} + raw, err := marshaler.Marshal(craftState) if err != nil { return err } diff --git a/internal/blobmanager/oci/provider.go b/internal/blobmanager/oci/provider.go index cff3b3447..861200174 100644 --- a/internal/blobmanager/oci/provider.go +++ b/internal/blobmanager/oci/provider.go @@ -17,6 +17,7 @@ package oci import ( "context" + "fmt" backend "github.com/chainloop-dev/chainloop/internal/blobmanager" "github.com/chainloop-dev/chainloop/internal/credentials" @@ -35,10 +36,14 @@ func NewBackendProvider(cReader credentials.Reader) *BackendProvider { func (p *BackendProvider) FromCredentials(ctx context.Context, secretName string) (backend.UploaderDownloader, error) { creds := &credentials.OCIKeypair{} - if err := p.cReader.ReadOCICreds(ctx, secretName, creds); err != nil { + if err := p.cReader.ReadCredentials(ctx, secretName, creds); err != nil { return nil, err } + if err := creds.Validate(); err != nil { + return nil, fmt.Errorf("invalid credentials retrieved from storage: %w", err) + } + k, err := ociauth.NewCredentials(creds.Repo, creds.Username, creds.Password) if err != nil { return nil, err diff --git a/internal/blobmanager/oci/provider_test.go b/internal/blobmanager/oci/provider_test.go index 712a9fc48..67e54492f 100644 --- a/internal/blobmanager/oci/provider_test.go +++ b/internal/blobmanager/oci/provider_test.go @@ -32,7 +32,7 @@ func TestFromCredentials(t *testing.T) { r := mocks.NewReader(t) const repo, password, username = "repo", "password", "username" - r.On("ReadOCICreds", ctx, "secretName", mock.AnythingOfType("*credentials.OCIKeypair")).Return(nil).Run( + r.On("ReadCredentials", ctx, "secretName", mock.AnythingOfType("*credentials.OCIKeypair")).Return(nil).Run( func(args mock.Arguments) { credentials := args.Get(2).(*credentials.OCIKeypair) credentials.Repo = repo diff --git a/internal/credentials/aws/secretmanager.go b/internal/credentials/aws/secretmanager.go index b89ff414d..e6ab1e235 100644 --- a/internal/credentials/aws/secretmanager.go +++ b/internal/credentials/aws/secretmanager.go @@ -83,55 +83,15 @@ func NewManager(opts *NewManagerOpts) (*Manager, error) { }, nil } -func (m *Manager) SaveOCICreds(ctx context.Context, orgID string, creds *credentials.OCIKeypair) (string, error) { - if err := creds.Validate(); err != nil { - return "", fmt.Errorf("validating OCI keypair: %w", err) - } - - return m.save(ctx, orgID, creds) -} - -func (m *Manager) SaveAPICreds(ctx context.Context, orgID string, creds *credentials.APICreds) (string, error) { - if err := creds.Validate(); err != nil { - return "", fmt.Errorf("validating API creds: %w", err) - } - - return m.save(ctx, orgID, creds) -} - -func (m *Manager) ReadAPICreds(ctx context.Context, secretID string, creds *credentials.APICreds) error { - raw, err := m.read(ctx, secretID) - if err != nil { - return fmt.Errorf("getting the secret from AWS: %w", err) - } - - return json.Unmarshal(raw, creds) -} - -func (m *Manager) ReadOCICreds(ctx context.Context, secretID string, creds *credentials.OCIKeypair) error { - raw, err := m.read(ctx, secretID) - if err != nil { - return fmt.Errorf("getting the secret from AWS: %w", err) - } - - return json.Unmarshal(raw, creds) -} - -func (m *Manager) DeleteCreds(ctx context.Context, secretID string) error { - _, err := m.client.DeleteSecret(ctx, &secretsmanager.DeleteSecretInput{ - SecretId: aws.String(secretID), - }) - - return err -} - -func (m *Manager) save(ctx context.Context, orgID string, creds interface{}) (string, error) { +// Save Credentials, this is a generic function that can be used to save any type of credentials +// as long as they can be passed to json.Marshal +func (m *Manager) SaveCredentials(ctx context.Context, orgID string, creds any) (string, error) { secretName := strings.Join([]string{m.secretPrefix, orgID, uuid.Generate().String()}, "/") - // Store the credentials as json keypairs + // Store the credentials as json key pairs c, err := json.Marshal(creds) if err != nil { - return "", fmt.Errorf("marshalling credentials to be stored: %w", err) + return "", fmt.Errorf("marshaling credentials to be stored: %w", err) } if _, err = m.client.CreateSecret(ctx, &secretsmanager.CreateSecretInput{ @@ -143,7 +103,7 @@ func (m *Manager) save(ctx context.Context, orgID string, creds interface{}) (st return secretName, nil } -func (m *Manager) read(ctx context.Context, secretID string) ([]byte, error) { +func (m *Manager) ReadCredentials(ctx context.Context, secretID string, creds any) error { resp, err := m.client.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{ SecretId: aws.String(secretID), }) @@ -153,12 +113,20 @@ func (m *Manager) read(ctx context.Context, secretID string) ([]byte, error) { if errors.As(err, &apiErr) { switch apiErr.ErrorCode() { case (&types.ResourceNotFoundException{}).ErrorCode(): - return nil, fmt.Errorf("%w: path=%s", credentials.ErrNotFound, secretID) + return fmt.Errorf("%w: path=%s", credentials.ErrNotFound, secretID) default: - return nil, err + return err } } } - return []byte(*resp.SecretString), nil + return json.Unmarshal([]byte(*resp.SecretString), creds) +} + +func (m *Manager) DeleteCredentials(ctx context.Context, secretID string) error { + _, err := m.client.DeleteSecret(ctx, &secretsmanager.DeleteSecretInput{ + SecretId: aws.String(secretID), + }) + + return err } diff --git a/internal/credentials/aws/secretmanager_test.go b/internal/credentials/aws/secretmanager_test.go index 5b9105bcf..17e4268fa 100644 --- a/internal/credentials/aws/secretmanager_test.go +++ b/internal/credentials/aws/secretmanager_test.go @@ -17,6 +17,8 @@ package aws import ( "context" + "encoding/json" + "reflect" "testing" "github.com/aws/aws-sdk-go-v2/aws" @@ -65,48 +67,56 @@ const defaultRegion = "default-region" const defaultAccessKey = "access-key-not-a-real-key" const defaultSecretKey = "secret-key-not-a-real-key" -func (s *testSuite) TestReadWriteOCICreds() { +func (s *testSuite) TestReadWriteCredentials() { assert := assert.New(s.T()) - validCreds := &credentials.OCIKeypair{Repo: "test-repo", Username: "username", Password: "password"} - //nolint:gosec - // This is a test secret, it is not a real secret - validCredsString := "{\"Repo\":\"test-repo\",\"Username\":\"username\",\"Password\":\"password\"}" + validOCICreds := &credentials.OCIKeypair{Repo: "test-repo", Username: "username", Password: "password"} + validAPICreds := &credentials.APICreds{Host: "h", Key: "k"} testCases := []struct { name string - want *credentials.OCIKeypair + want any path string expectedError bool }{ - {"empty secret", &credentials.OCIKeypair{}, "", true}, - {"missing repo", &credentials.OCIKeypair{Username: "un", Password: "p"}, "", true}, - {"missing username", &credentials.OCIKeypair{Username: "", Password: "p", Repo: "repo"}, "", true}, - {"missing password", &credentials.OCIKeypair{Username: "u", Password: "", Repo: "repo"}, "", true}, - {"valid creds", validCreds, "", false}, - {"valid creds custom path", validCreds, "fooo", false}, + {"valid OCI creds", validOCICreds, "", false}, + {"valid OCI creds custom path", validOCICreds, "fooo", false}, + {"valid API creds", validAPICreds, "", false}, + {"valid API creds custom path", validAPICreds, "fooo", false}, } for _, tc := range testCases { s.Run(tc.name, func() { + // Re-set the manager mocked expectations + initMockedManager(s) m := s.mockedManager mc, _ := m.client.(*mclient.SecretsManagerIface) ctx := context.Background() mc.On("CreateSecret", ctx, mock.Anything).Return(nil, nil) - secretName, err := m.SaveOCICreds(ctx, orgID, tc.want) + secretName, err := m.SaveCredentials(ctx, orgID, tc.want) if tc.expectedError { assert.Error(err) return } - assert.NoError(err) + mockedResp, err := json.Marshal(tc.want) + require.NoError(s.T(), err) + // Read the keypair - got := &credentials.OCIKeypair{} mc.On("GetSecretValue", ctx, &secretsmanager.GetSecretValueInput{ SecretId: aws.String(secretName), - }).Return(&secretsmanager.GetSecretValueOutput{SecretString: aws.String(validCredsString)}, nil) + }).Return(&secretsmanager.GetSecretValueOutput{SecretString: aws.String(string(mockedResp))}, nil) + + // Choose the returning struct + var got any + switch reflect.TypeOf(tc.want).String() { + case "*credentials.APICreds": + got = &credentials.APICreds{} + case "*credentials.OCIKeypair": + got = &credentials.OCIKeypair{} + } - err = m.ReadOCICreds(ctx, secretName, got) + err = m.ReadCredentials(ctx, secretName, got) assert.NoError(err) // Compare the keypair @@ -117,7 +127,7 @@ func (s *testSuite) TestReadWriteOCICreds() { SecretId: aws.String("invalid"), }).Return(nil, &types.ResourceNotFoundException{}) - err = m.ReadOCICreds(ctx, "invalid", got) + err = m.ReadCredentials(ctx, "invalid", got) assert.Error(err) assert.ErrorIs(err, credentials.ErrNotFound) }) @@ -125,7 +135,7 @@ func (s *testSuite) TestReadWriteOCICreds() { } // // Create a new secret, delete it and check it does not exist antymore -func (s *testSuite) TestDeleteCreds() { +func (s *testSuite) TestDeleteCredentials() { assert := assert.New(s.T()) m := s.mockedManager mc, _ := m.client.(*mclient.SecretsManagerIface) @@ -136,58 +146,9 @@ func (s *testSuite) TestDeleteCreds() { SecretId: aws.String(secretName), }).Return(nil, nil) - err := m.DeleteCreds(ctx, secretName) + err := m.DeleteCredentials(ctx, secretName) assert.NoError(err) } -func (s *testSuite) TestReadWriteAPICreds() { - assert := assert.New(s.T()) - validCreds := &credentials.APICreds{Host: "http://hospath.local", Key: "api-key-not-a-secret"} - //nolint:gosec - // This is a test secret, it is not a real secret - validCredsString := "{\"Host\":\"http://hospath.local\",\"Key\":\"api-key-not-a-secret\"}" - - testCases := []struct { - name string - want *credentials.APICreds - path string - expectedError bool - }{ - {"empty secret", &credentials.APICreds{}, "", true}, - {"missing host", &credentials.APICreds{Host: "", Key: "p"}, "", true}, - {"missing key", &credentials.APICreds{Host: "host", Key: ""}, "", true}, - {"valid creds", validCreds, "", false}, - {"valid creds custom path", validCreds, "fooo", false}, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - m := s.mockedManager - mc, _ := m.client.(*mclient.SecretsManagerIface) - ctx := context.Background() - - mc.On("CreateSecret", ctx, mock.Anything).Return(nil, nil) - - secretName, err := m.SaveAPICreds(ctx, orgID, tc.want) - if tc.expectedError { - assert.Error(err) - return - } - - assert.NoError(err) - // Read the keypair - got := &credentials.APICreds{} - mc.On("GetSecretValue", ctx, &secretsmanager.GetSecretValueInput{ - SecretId: aws.String(secretName), - }).Return(&secretsmanager.GetSecretValueOutput{SecretString: aws.String(validCredsString)}, nil) - - err = m.ReadAPICreds(ctx, secretName, got) - assert.NoError(err) - - // Compare the keypair - assert.Equal(tc.want, got) - }) - } -} type testSuite struct { suite.Suite @@ -196,6 +157,10 @@ type testSuite struct { // Run before each test func (s *testSuite) SetupTest() { + initMockedManager(s) +} + +func initMockedManager(s *testSuite) { opts := &NewManagerOpts{Region: defaultRegion, AccessKey: defaultAccessKey, SecretKey: defaultSecretKey} m, err := NewManager(opts) require.NoError(s.T(), err) diff --git a/internal/credentials/credentials.go b/internal/credentials/credentials.go index 2fafd7958..f8de8d78a 100644 --- a/internal/credentials/credentials.go +++ b/internal/credentials/credentials.go @@ -34,16 +34,13 @@ type ReaderWriter interface { Writer } -// TODO: Add generics type Writer interface { - SaveAPICreds(ctx context.Context, org string, creds *APICreds) (string, error) - SaveOCICreds(ctx context.Context, org string, creds *OCIKeypair) (string, error) - DeleteCreds(ctx context.Context, credID string) error + SaveCredentials(ctx context.Context, org string, credentials any) (string, error) + DeleteCredentials(ctx context.Context, credID string) error } type Reader interface { - ReadAPICreds(ctx context.Context, secretName string, creds *APICreds) error - ReadOCICreds(ctx context.Context, secretName string, creds *OCIKeypair) error + ReadCredentials(ctx context.Context, secretName string, credentials any) error } var ErrNotFound = errors.New("credentials not found") diff --git a/internal/credentials/credentials_test.go b/internal/credentials/credentials_test.go new file mode 100644 index 000000000..587086f03 --- /dev/null +++ b/internal/credentials/credentials_test.go @@ -0,0 +1,76 @@ +// +// Copyright 2023 The Chainloop Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package credentials_test + +import ( + "testing" + + "github.com/chainloop-dev/chainloop/internal/credentials" + "github.com/stretchr/testify/assert" +) + +func TestValidateAPICreds(t *testing.T) { + assert := assert.New(t) + + testCases := []struct { + name string + input *credentials.APICreds + wantError bool + }{ + {"empty secret", &credentials.APICreds{}, true}, + {"missing host", &credentials.APICreds{Host: "", Key: "p"}, true}, + {"missing key", &credentials.APICreds{Host: "host", Key: ""}, true}, + {"valid creds", &credentials.APICreds{Host: "h", Key: "p"}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.Validate() + if tc.wantError { + assert.Error(err) + } else { + assert.NoError(err) + } + }) + } +} + +func TestValidateOCIKeyPair(t *testing.T) { + assert := assert.New(t) + + testCases := []struct { + name string + input *credentials.OCIKeypair + wantError bool + }{ + {"empty secret", &credentials.OCIKeypair{}, true}, + {"missing repo", &credentials.OCIKeypair{Username: "un", Password: "p"}, true}, + {"missing username", &credentials.OCIKeypair{Username: "", Password: "p", Repo: "repo"}, true}, + {"missing password", &credentials.OCIKeypair{Username: "u", Password: "", Repo: "repo"}, true}, + {"valid creds", &credentials.OCIKeypair{Username: "u", Password: "p", Repo: "repo"}, false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.input.Validate() + if tc.wantError { + assert.Error(err) + } else { + assert.NoError(err) + } + }) + } +} diff --git a/internal/credentials/mocks/Reader.go b/internal/credentials/mocks/Reader.go index 4ae6db4f0..d4309dcee 100644 --- a/internal/credentials/mocks/Reader.go +++ b/internal/credentials/mocks/Reader.go @@ -5,7 +5,6 @@ package mocks import ( context "context" - credentials "github.com/chainloop-dev/chainloop/internal/credentials" mock "github.com/stretchr/testify/mock" ) @@ -14,27 +13,13 @@ type Reader struct { mock.Mock } -// ReadAPICreds provides a mock function with given fields: ctx, secretName, creds -func (_m *Reader) ReadAPICreds(ctx context.Context, secretName string, creds *credentials.APICreds) error { - ret := _m.Called(ctx, secretName, creds) +// ReadCredentials provides a mock function with given fields: ctx, secretName, _a2 +func (_m *Reader) ReadCredentials(ctx context.Context, secretName string, _a2 interface{}) error { + ret := _m.Called(ctx, secretName, _a2) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) error); ok { - r0 = rf(ctx, secretName, creds) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ReadOCICreds provides a mock function with given fields: ctx, secretName, creds -func (_m *Reader) ReadOCICreds(ctx context.Context, secretName string, creds *credentials.OCIKeypair) error { - ret := _m.Called(ctx, secretName, creds) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) error); ok { - r0 = rf(ctx, secretName, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) error); ok { + r0 = rf(ctx, secretName, _a2) } else { r0 = ret.Error(0) } diff --git a/internal/credentials/mocks/ReaderWriter.go b/internal/credentials/mocks/ReaderWriter.go index 791ff2b27..b5c4d1748 100644 --- a/internal/credentials/mocks/ReaderWriter.go +++ b/internal/credentials/mocks/ReaderWriter.go @@ -5,7 +5,6 @@ package mocks import ( context "context" - credentials "github.com/chainloop-dev/chainloop/internal/credentials" mock "github.com/stretchr/testify/mock" ) @@ -14,8 +13,8 @@ type ReaderWriter struct { mock.Mock } -// DeleteCreds provides a mock function with given fields: ctx, credID -func (_m *ReaderWriter) DeleteCreds(ctx context.Context, credID string) error { +// DeleteCredentials provides a mock function with given fields: ctx, credID +func (_m *ReaderWriter) DeleteCredentials(ctx context.Context, credID string) error { ret := _m.Called(ctx, credID) var r0 error @@ -28,13 +27,13 @@ func (_m *ReaderWriter) DeleteCreds(ctx context.Context, credID string) error { return r0 } -// ReadAPICreds provides a mock function with given fields: ctx, secretName, creds -func (_m *ReaderWriter) ReadAPICreds(ctx context.Context, secretName string, creds *credentials.APICreds) error { - ret := _m.Called(ctx, secretName, creds) +// ReadCredentials provides a mock function with given fields: ctx, secretName, _a2 +func (_m *ReaderWriter) ReadCredentials(ctx context.Context, secretName string, _a2 interface{}) error { + ret := _m.Called(ctx, secretName, _a2) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) error); ok { - r0 = rf(ctx, secretName, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) error); ok { + r0 = rf(ctx, secretName, _a2) } else { r0 = ret.Error(0) } @@ -42,61 +41,23 @@ func (_m *ReaderWriter) ReadAPICreds(ctx context.Context, secretName string, cre return r0 } -// ReadOCICreds provides a mock function with given fields: ctx, secretName, creds -func (_m *ReaderWriter) ReadOCICreds(ctx context.Context, secretName string, creds *credentials.OCIKeypair) error { - ret := _m.Called(ctx, secretName, creds) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) error); ok { - r0 = rf(ctx, secretName, creds) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SaveAPICreds provides a mock function with given fields: ctx, org, creds -func (_m *ReaderWriter) SaveAPICreds(ctx context.Context, org string, creds *credentials.APICreds) (string, error) { - ret := _m.Called(ctx, org, creds) - - var r0 string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) (string, error)); ok { - return rf(ctx, org, creds) - } - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) string); ok { - r0 = rf(ctx, org, creds) - } else { - r0 = ret.Get(0).(string) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, *credentials.APICreds) error); ok { - r1 = rf(ctx, org, creds) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SaveOCICreds provides a mock function with given fields: ctx, org, creds -func (_m *ReaderWriter) SaveOCICreds(ctx context.Context, org string, creds *credentials.OCIKeypair) (string, error) { - ret := _m.Called(ctx, org, creds) +// SaveCredentials provides a mock function with given fields: ctx, org, _a2 +func (_m *ReaderWriter) SaveCredentials(ctx context.Context, org string, _a2 interface{}) (string, error) { + ret := _m.Called(ctx, org, _a2) var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) (string, error)); ok { - return rf(ctx, org, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) (string, error)); ok { + return rf(ctx, org, _a2) } - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) string); ok { - r0 = rf(ctx, org, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) string); ok { + r0 = rf(ctx, org, _a2) } else { r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func(context.Context, string, *credentials.OCIKeypair) error); ok { - r1 = rf(ctx, org, creds) + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, org, _a2) } else { r1 = ret.Error(1) } diff --git a/internal/credentials/mocks/Writer.go b/internal/credentials/mocks/Writer.go index 082cefa26..8ce555c38 100644 --- a/internal/credentials/mocks/Writer.go +++ b/internal/credentials/mocks/Writer.go @@ -5,7 +5,6 @@ package mocks import ( context "context" - credentials "github.com/chainloop-dev/chainloop/internal/credentials" mock "github.com/stretchr/testify/mock" ) @@ -14,8 +13,8 @@ type Writer struct { mock.Mock } -// DeleteCreds provides a mock function with given fields: ctx, credID -func (_m *Writer) DeleteCreds(ctx context.Context, credID string) error { +// DeleteCredentials provides a mock function with given fields: ctx, credID +func (_m *Writer) DeleteCredentials(ctx context.Context, credID string) error { ret := _m.Called(ctx, credID) var r0 error @@ -28,47 +27,23 @@ func (_m *Writer) DeleteCreds(ctx context.Context, credID string) error { return r0 } -// SaveAPICreds provides a mock function with given fields: ctx, org, creds -func (_m *Writer) SaveAPICreds(ctx context.Context, org string, creds *credentials.APICreds) (string, error) { - ret := _m.Called(ctx, org, creds) +// SaveCredentials provides a mock function with given fields: ctx, org, _a2 +func (_m *Writer) SaveCredentials(ctx context.Context, org string, _a2 interface{}) (string, error) { + ret := _m.Called(ctx, org, _a2) var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) (string, error)); ok { - return rf(ctx, org, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) (string, error)); ok { + return rf(ctx, org, _a2) } - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.APICreds) string); ok { - r0 = rf(ctx, org, creds) + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) string); ok { + r0 = rf(ctx, org, _a2) } else { r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func(context.Context, string, *credentials.APICreds) error); ok { - r1 = rf(ctx, org, creds) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// SaveOCICreds provides a mock function with given fields: ctx, org, creds -func (_m *Writer) SaveOCICreds(ctx context.Context, org string, creds *credentials.OCIKeypair) (string, error) { - ret := _m.Called(ctx, org, creds) - - var r0 string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) (string, error)); ok { - return rf(ctx, org, creds) - } - if rf, ok := ret.Get(0).(func(context.Context, string, *credentials.OCIKeypair) string); ok { - r0 = rf(ctx, org, creds) - } else { - r0 = ret.Get(0).(string) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, *credentials.OCIKeypair) error); ok { - r1 = rf(ctx, org, creds) + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, org, _a2) } else { r1 = ret.Error(1) } diff --git a/internal/credentials/vault/keyval.go b/internal/credentials/vault/keyval.go index ecc869a58..95e4f6fcd 100644 --- a/internal/credentials/vault/keyval.go +++ b/internal/credentials/vault/keyval.go @@ -103,50 +103,16 @@ func validateClient(kv *vault.KVv2, pathPrefix string) error { return nil } -func (m *Manager) SaveOCICreds(ctx context.Context, orgID string, creds *credentials.OCIKeypair) (string, error) { - if err := creds.Validate(); err != nil { - return "", fmt.Errorf("validating OCI keypair: %w", err) - } - - credsM, err := structToMap(creds) - if err != nil { - return "", fmt.Errorf("converting OCI keypair to map: %w", err) - } - - return m.save(ctx, orgID, credsM) -} - -func (m *Manager) SaveAPICreds(ctx context.Context, orgID string, creds *credentials.APICreds) (string, error) { - if err := creds.Validate(); err != nil { - return "", fmt.Errorf("validating API creds: %w", err) - } - +func (m *Manager) SaveCredentials(ctx context.Context, orgID string, creds any) (string, error) { credsM, err := structToMap(creds) if err != nil { - return "", fmt.Errorf("converting API creds to map: %w", err) + return "", fmt.Errorf("converting struct to map: %w", err) } - return m.save(ctx, orgID, credsM) -} - -func (m *Manager) ReadAPICreds(ctx context.Context, secretID string, creds *credentials.APICreds) error { - return m.read(ctx, secretID, creds) -} - -func (m *Manager) ReadOCICreds(ctx context.Context, secretID string, creds *credentials.OCIKeypair) error { - return m.read(ctx, secretID, creds) -} - -func (m *Manager) DeleteCreds(ctx context.Context, secretID string) error { - m.logger.Infow("msg", "deleting credentials", "path", secretID) - return m.client.DeleteMetadata(ctx, secretID) -} - -func (m *Manager) save(ctx context.Context, orgID string, creds map[string]interface{}) (string, error) { secretName := strings.Join([]string{m.secretPrefix, orgID, uuid.Generate().String()}, "/") m.logger.Infow("msg", "storing credentials", "path", secretName) - _, err := m.client.Put(ctx, secretName, creds) + _, err = m.client.Put(ctx, secretName, credsM) if err != nil { return "", fmt.Errorf("creating secret in Vault: %w", err) } @@ -154,7 +120,7 @@ func (m *Manager) save(ctx context.Context, orgID string, creds map[string]inter return secretName, nil } -func (m *Manager) read(ctx context.Context, secretID string, output interface{}) error { +func (m *Manager) ReadCredentials(ctx context.Context, secretID string, creds any) error { m.logger.Infow("msg", "reading credentials", "path", secretID) s, err := m.client.Get(ctx, secretID) @@ -166,13 +132,18 @@ func (m *Manager) read(ctx context.Context, secretID string, output interface{}) return fmt.Errorf("reading secret from Vault: %w", err) } - if err := mapToStruct(s.Data, output); err != nil { + if err := mapToStruct(s.Data, creds); err != nil { return fmt.Errorf("converting secret to struct: %w", err) } return nil } +func (m *Manager) DeleteCredentials(ctx context.Context, secretID string) error { + m.logger.Infow("msg", "deleting credentials", "path", secretID) + return m.client.DeleteMetadata(ctx, secretID) +} + // convert from struct to map[string]interface{} func structToMap(i interface{}) (map[string]interface{}, error) { b, err := json.Marshal(i) diff --git a/internal/credentials/vault/keyval_test.go b/internal/credentials/vault/keyval_test.go index 42d680228..bb855ad5c 100644 --- a/internal/credentials/vault/keyval_test.go +++ b/internal/credentials/vault/keyval_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "os" + "reflect" "testing" "time" @@ -65,22 +66,23 @@ func (s *testSuite) TestNewManager() { const orgID = "test-org" -func (s *testSuite) TestReadWriteOCICreds() { +type storedSecret struct { + Foo, Bar string +} + +func (s *testSuite) TestReadWriteCredentials() { assert := assert.New(s.T()) - validCreds := &credentials.OCIKeypair{Repo: "test-repo", Username: "username", Password: "password"} + validOCICreds := &credentials.OCIKeypair{Repo: "test-repo", Username: "username", Password: "password"} testCases := []struct { name string - want *credentials.OCIKeypair + want any path string expectedWriteError bool }{ - {"empty secret", &credentials.OCIKeypair{}, "", true}, - {"missing repo", &credentials.OCIKeypair{Username: "un", Password: "p"}, "", true}, - {"missing username", &credentials.OCIKeypair{Username: "", Password: "p", Repo: "repo"}, "", true}, - {"missing password", &credentials.OCIKeypair{Username: "u", Password: "", Repo: "repo"}, "", true}, - {"valid creds", validCreds, "", false}, - {"valid creds custom path", validCreds, "fooo", false}, + {"valid creds", validOCICreds, "", false}, + {"valid creds custom path", validOCICreds, "fooo", false}, + {"random struct is compatible", &storedSecret{"bar", "baz"}, "", false}, } for _, tc := range testCases { @@ -89,16 +91,23 @@ func (s *testSuite) TestReadWriteOCICreds() { m, err := vault.NewManager(opts) require.NoError(s.T(), err) - secretName, err := m.SaveOCICreds(context.Background(), orgID, tc.want) + secretName, err := m.SaveCredentials(context.Background(), orgID, tc.want) if tc.expectedWriteError { assert.Error(err) return } assert.NoError(err) - // Read the keypair - got := &credentials.OCIKeypair{} - err = m.ReadOCICreds(context.Background(), secretName, got) + // Read the keypair choosing the returning struct + var got any + switch reflect.TypeOf(tc.want).String() { + case "*credentials.OCIKeypair": + got = &credentials.OCIKeypair{} + default: + got = &storedSecret{} + } + + err = m.ReadCredentials(context.Background(), secretName, got) assert.NoError(err) // Compare the keypair @@ -110,7 +119,7 @@ func (s *testSuite) TestReadWriteOCICreds() { opts := &vault.NewManagerOpts{AuthToken: defaultToken, Address: s.connectionString} m, err := vault.NewManager(opts) require.NoError(s.T(), err) - err = m.ReadOCICreds(context.Background(), "bogus", nil) + err = m.ReadCredentials(context.Background(), "bogus", nil) assert.ErrorIs(err, credentials.ErrNotFound) } @@ -124,67 +133,26 @@ func (s *testSuite) TestDeleteCreds() { m, err := vault.NewManager(opts) require.NoError(err) - secretName, err := m.SaveOCICreds(context.Background(), orgID, validCreds) + secretName, err := m.SaveCredentials(context.Background(), orgID, validCreds) require.NoError(err) // Read the keypair got := &credentials.OCIKeypair{} - err = m.ReadOCICreds(context.Background(), secretName, got) + err = m.ReadCredentials(context.Background(), secretName, got) assert.NoError(err) // Compare the keypair assert.Equal(validCreds, got) // Delete and check it does not exist - err = m.DeleteCreds(context.Background(), secretName) + err = m.DeleteCredentials(context.Background(), secretName) assert.NoError(err) // It does not exist got = &credentials.OCIKeypair{} - err = m.ReadOCICreds(context.Background(), secretName, got) + err = m.ReadCredentials(context.Background(), secretName, got) assert.Error(err) } -func (s *testSuite) TestReadWriteAPICreds() { - assert := assert.New(s.T()) - validCreds := &credentials.APICreds{Host: "http://hospath.local", Key: "api-key-not-a-secret"} - - testCases := []struct { - name string - want *credentials.APICreds - path string - expectedError bool - }{ - {"empty secret", &credentials.APICreds{}, "", true}, - {"missing host", &credentials.APICreds{Host: "", Key: "p"}, "", true}, - {"missing key", &credentials.APICreds{Host: "host", Key: ""}, "", true}, - {"valid creds", validCreds, "", false}, - {"valid creds custom path", validCreds, "fooo", false}, - } - - for _, tc := range testCases { - s.Run(tc.name, func() { - opts := &vault.NewManagerOpts{AuthToken: defaultToken, Address: s.connectionString, SecretPrefix: tc.path} - m, err := vault.NewManager(opts) - require.NoError(s.T(), err) - - secretName, err := m.SaveAPICreds(context.Background(), orgID, tc.want) - if tc.expectedError { - assert.Error(err) - return - } - - assert.NoError(err) - // Read the keypair - got := &credentials.APICreds{} - err = m.ReadAPICreds(context.Background(), secretName, got) - assert.NoError(err) - - // Compare the keypair - assert.Equal(tc.want, got) - }) - } -} - type testSuite struct { suite.Suite vault *vaultInstance