diff --git a/clients/ui/bff/README.md b/clients/ui/bff/README.md index b0e0a3e57..026453969 100644 --- a/clients/ui/bff/README.md +++ b/clients/ui/bff/README.md @@ -68,6 +68,7 @@ make docker-build | GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID | | POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel | | GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID | +| POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion | ### Sample local calls ``` @@ -189,4 +190,22 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi ``` # GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts +``` +``` +# POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts +curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \ + -H "Content-Type: application/json" \ + -d '{ "data": { + "customProperties": { + "my-label9": { + "metadataType": "MetadataStringValue", + "string_value": "val" + } + }, + "description": "New description", + "externalId": "9927", + "name": "ModelArtifact One", + "state": "LIVE", + "artifactType": "TYPE_ONE" +}}' ``` \ No newline at end of file diff --git a/clients/ui/bff/api/app.go b/clients/ui/bff/api/app.go index 3d5fd0fe9..36b0f6b97 100644 --- a/clients/ui/bff/api/app.go +++ b/clients/ui/bff/api/app.go @@ -18,6 +18,7 @@ const ( ModelRegistryId = "model_registry_id" RegisteredModelId = "registered_model_id" ModelVersionId = "model_version_id" + ModelArtifactId = "model_artifact_id" HealthCheckPath = PathPrefix + "/healthcheck" ModelRegistryListPath = PathPrefix + "/model_registry" ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId @@ -27,6 +28,8 @@ const ( ModelVersionListPath = ModelRegistryPath + "/model_versions" ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId ModelVersionArtifactListPath = ModelVersionPath + "/artifacts" + ModelArtifactListPath = ModelRegistryPath + "/model_artifacts" + ModelArtifactPath = ModelArtifactListPath + "/:" + ModelArtifactId ) type App struct { @@ -91,6 +94,7 @@ func (app *App) Routes() http.Handler { router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler)) router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler)) router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler)) + router.POST(ModelVersionArtifactListPath, app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler)) // Kubernetes client routes router.GET(ModelRegistryListPath, app.ModelRegistryHandler) diff --git a/clients/ui/bff/api/model_versions_handler.go b/clients/ui/bff/api/model_versions_handler.go index 33dc001db..a68c25760 100644 --- a/clients/ui/bff/api/model_versions_handler.go +++ b/clients/ui/bff/api/model_versions_handler.go @@ -14,6 +14,7 @@ import ( type ModelVersionEnvelope Envelope[*openapi.ModelVersion, None] type ModelVersionListEnvelope Envelope[*openapi.ModelVersionList, None] type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None] +type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None] func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) @@ -172,3 +173,60 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, app.serverErrorResponse(w, r, err) } } + +func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) + if !ok { + app.serverErrorResponse(w, r, errors.New("REST client not found")) + return + } + + var envelope ModelArtifactEnvelope + if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error())) + return + } + + data := *envelope.Data + + if err := validation.ValidateModelArtifact(data); err != nil { + app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error())) + return + } + + jsonData, err := json.Marshal(data) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error marshaling ModelVersion to JSON: %w", err)) + return + } + + createdArtifact, err := app.modelRegistryClient.CreateModelArtifactByModelVersion(client, ps.ByName(ModelVersionId), jsonData) + if err != nil { + var httpErr *integrations.HTTPError + if errors.As(err, &httpErr) { + app.errorResponse(w, r, httpErr) + } else { + app.serverErrorResponse(w, r, err) + } + return + } + + if createdArtifact == nil { + app.serverErrorResponse(w, r, fmt.Errorf("created ModelArtifact is nil")) + return + } + + response := ModelArtifactEnvelope{ + Data: createdArtifact, + } + + w.Header().Set("Location", ParseURLTemplate(ModelArtifactPath, map[string]string{ + ModelRegistryId: ps.ByName(ModelRegistryId), + ModelArtifactId: createdArtifact.GetId(), + })) + err = app.WriteJSON(w, http.StatusCreated, response, nil) + if err != nil { + app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON")) + return + } +} diff --git a/clients/ui/bff/api/model_versions_handler_test.go b/clients/ui/bff/api/model_versions_handler_test.go index 3c426392e..27050e819 100644 --- a/clients/ui/bff/api/model_versions_handler_test.go +++ b/clients/ui/bff/api/model_versions_handler_test.go @@ -59,3 +59,21 @@ func TestGetAllModelArtifactsByModelVersionHandler(t *testing.T) { assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken) assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items)) } + +func TestCreateModelArtifactByModelVersionHandler(t *testing.T) { + data := mocks.GetModelArtifactMocks()[0] + expected := ModelArtifactEnvelope{Data: &data} + + artifact := openapi.ModelArtifact{ + Name: openapi.PtrString("Artifact One"), + ArtifactType: "ARTIFACT_TYPE_ONE", + } + body := ModelArtifactEnvelope{Data: &artifact} + + actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body) + assert.NoError(t, err) + + assert.Equal(t, http.StatusCreated, rs.StatusCode) + assert.Equal(t, expected.Data.GetArtifactType(), actual.Data.GetArtifactType()) + assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_artifacts/1") +} diff --git a/clients/ui/bff/data/model_version.go b/clients/ui/bff/data/model_version.go index f93318049..41057d0a3 100644 --- a/clients/ui/bff/data/model_version.go +++ b/clients/ui/bff/data/model_version.go @@ -17,6 +17,7 @@ type ModelVersionInterface interface { CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) + CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) } type ModelVersion struct { @@ -58,7 +59,7 @@ func (v ModelVersion) CreateModelVersion(client integrations.HTTPClientInterface return &model, nil } -func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { +func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) { path, err := url.JoinPath(modelVersionPath, id) if err != nil { @@ -78,7 +79,7 @@ func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface return &model, nil } -func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { +func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) if err != nil { @@ -97,3 +98,22 @@ func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPCl return &model, nil } + +func (v ModelVersion) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { + path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) + if err != nil { + return nil, err + } + + responseData, err := client.POST(path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error posting model artifact: %w", err) + } + + var model openapi.ModelArtifact + if err := json.Unmarshal(responseData, &model); err != nil { + return nil, fmt.Errorf("error decoding response data: %w", err) + } + + return &model, nil +} diff --git a/clients/ui/bff/data/model_version_test.go b/clients/ui/bff/data/model_version_test.go index 25236c77a..a17f0d1ff 100644 --- a/clients/ui/bff/data/model_version_test.go +++ b/clients/ui/bff/data/model_version_test.go @@ -114,3 +114,29 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) { assert.Equal(t, expected.PageSize, actual.PageSize) assert.Equal(t, len(expected.Items), len(actual.Items)) } + +func TestCreateModelArtifactByModelVersion(t *testing.T) { + gofakeit.Seed(0) + + expected := mocks.GenerateMockModelArtifact() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodPost, path, mock.Anything).Return(mockData, nil) + + jsonInnput, err := json.Marshal(expected) + assert.NoError(t, err) + + actual, err := modelVersion.CreateModelArtifactByModelVersion(mockClient, "1", jsonInnput) + assert.NoError(t, err) + assert.NotNil(t, actual) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.ArtifactType, actual.ArtifactType) +} diff --git a/clients/ui/bff/internals/mocks/model_registry_client_mock.go b/clients/ui/bff/internals/mocks/model_registry_client_mock.go index a1819cac8..375f5dbfd 100644 --- a/clients/ui/bff/internals/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internals/mocks/model_registry_client_mock.go @@ -64,3 +64,7 @@ func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integra mockData := GetModelArtifactListMock() return &mockData, nil } +func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { + mockData := GetModelArtifactMocks()[0] + return &mockData, nil +} diff --git a/clients/ui/bff/validation/validation.go b/clients/ui/bff/validation/validation.go index 131ae1607..2c988bfc5 100644 --- a/clients/ui/bff/validation/validation.go +++ b/clients/ui/bff/validation/validation.go @@ -20,3 +20,11 @@ func ValidateModelVersion(input openapi.ModelVersion) error { // Add more field validations as required return nil } + +func ValidateModelArtifact(input openapi.ModelArtifact) error { + if input.GetName() == "" { + return errors.New("name cannot be empty") + } + // Add more field validations as required + return nil +} diff --git a/clients/ui/bff/validation/validation_test.go b/clients/ui/bff/validation/validation_test.go index 4ee3c9e1b..b79f4bba5 100644 --- a/clients/ui/bff/validation/validation_test.go +++ b/clients/ui/bff/validation/validation_test.go @@ -38,3 +38,20 @@ func TestValidateModelVersion(t *testing.T) { validateTestSpecs(t, specs, ValidateModelVersion) } + +func TestValidateModel(t *testing.T) { + specs := []testSpec[openapi.ModelArtifact]{ + { + name: "Empty name", + input: openapi.ModelArtifact{Name: openapi.PtrString("")}, + wantErr: true, + }, + { + name: "Valid name", + input: openapi.ModelArtifact{Name: openapi.PtrString("ValidName")}, + wantErr: false, + }, + } + + validateTestSpecs(t, specs, ValidateModelArtifact) +}