From c00840bf479452011b5ac27fe94964ea97075fc8 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:04:15 -0700 Subject: [PATCH] .Net: Implemented generic data model support for Azure CosmosDB MongoDB connector (#8967) ### Motivation and Context Related: https://github.com/microsoft/semantic-kernel/issues/6522 - Implemented `AzureCosmosDBMongoDBGenericDataModelMapper` class. - Added unit and integration tests. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- ...mosDBMongoDBGenericDataModelMapperTests.cs | 310 ++++++++++++++++++ .../AzureAISearchGenericDataModelMapper.cs | 2 +- .../AzureCosmosDBMongoDBConstants.cs | 38 +++ ...reCosmosDBMongoDBGenericDataModelMapper.cs | 181 ++++++++++ ...mosDBMongoDBVectorStoreRecordCollection.cs | 36 +- ...eCosmosDBMongoDBVectorStoreRecordMapper.cs | 40 +-- .../QdrantGenericDataModelMapper.cs | 2 +- .../RedisHashSetGenericDataModelMapper.cs | 2 +- .../RedisJsonGenericDataModelMapper.cs | 2 +- .../AzureCosmosDBMongoDBVectorStoreFixture.cs | 5 + ...MongoDBVectorStoreRecordCollectionTests.cs | 48 +++ .../Data/VectorStoreRecordPropertyReader.cs | 26 +- 12 files changed, 638 insertions(+), 54 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs diff --git 
a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs new file mode 100644 index 000000000000..e2b02c35a41f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs @@ -0,0 +1,310 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Data; +using MongoDB.Bson; +using Xunit; + +namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class AzureCosmosDBMongoDBGenericDataModelMapperTests +{ + private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("BoolDataProp", typeof(bool)), + new VectorStoreRecordDataProperty("NullableBoolDataProp", typeof(bool?)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("IntDataProp", typeof(int)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordDataProperty("LongDataProp", typeof(long)), + new VectorStoreRecordDataProperty("NullableLongDataProp", typeof(long?)), + new VectorStoreRecordDataProperty("FloatDataProp", typeof(float)), + new VectorStoreRecordDataProperty("NullableFloatDataProp", typeof(float?)), + new VectorStoreRecordDataProperty("DoubleDataProp", typeof(double)), + new VectorStoreRecordDataProperty("NullableDoubleDataProp", typeof(double?)), + new VectorStoreRecordDataProperty("DecimalDataProp", typeof(decimal)), + new VectorStoreRecordDataProperty("NullableDecimalDataProp", 
typeof(decimal?)), + new VectorStoreRecordDataProperty("DateTimeDataProp", typeof(DateTime)), + new VectorStoreRecordDataProperty("NullableDateTimeDataProp", typeof(DateTime?)), + new VectorStoreRecordDataProperty("TagListDataProp", typeof(List)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), + new VectorStoreRecordVectorProperty("DoubleVector", typeof(ReadOnlyMemory)), + new VectorStoreRecordVectorProperty("NullableDoubleVector", typeof(ReadOnlyMemory?)), + }, + }; + + private static readonly float[] s_floatVector = [1.0f, 2.0f, 3.0f]; + private static readonly double[] s_doubleVector = [1.0f, 2.0f, 3.0f]; + private static readonly List s_taglist = ["tag1", "tag2"]; + + [Fact] + public void MapFromDataToStorageModelMapsAllSupportedTypes() + { + // Arrange + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var dataModel = new VectorStoreGenericDataModel("key") + { + Data = + { + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["DecimalDataProp"] = 9.0m, + ["NullableDecimalDataProp"] = 10.0m, + ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["TagListDataProp"] = s_taglist, + }, + Vectors = + { + ["FloatVector"] = new ReadOnlyMemory(s_floatVector), + ["NullableFloatVector"] = new ReadOnlyMemory(s_floatVector), + ["DoubleVector"] = new ReadOnlyMemory(s_doubleVector), + ["NullableDoubleVector"] = new ReadOnlyMemory(s_doubleVector), + }, + }; + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); 
+ + // Assert + Assert.Equal("key", storageModel["_id"]); + Assert.Equal(true, (bool?)storageModel["BoolDataProp"]); + Assert.Equal(false, (bool?)storageModel["NullableBoolDataProp"]); + Assert.Equal("string", (string?)storageModel["StringDataProp"]); + Assert.Equal(1, (int?)storageModel["IntDataProp"]); + Assert.Equal(2, (int?)storageModel["NullableIntDataProp"]); + Assert.Equal(3L, (long?)storageModel["LongDataProp"]); + Assert.Equal(4L, (long?)storageModel["NullableLongDataProp"]); + Assert.Equal(5.0f, (float?)storageModel["FloatDataProp"].AsDouble); + Assert.Equal(6.0f, (float?)storageModel["NullableFloatDataProp"].AsNullableDouble); + Assert.Equal(7.0, (double?)storageModel["DoubleDataProp"]); + Assert.Equal(8.0, (double?)storageModel["NullableDoubleDataProp"]); + Assert.Equal(9.0m, (decimal?)storageModel["DecimalDataProp"]); + Assert.Equal(10.0m, (decimal?)storageModel["NullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["DateTimeDataProp"].ToUniversalTime()); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), storageModel["NullableDateTimeDataProp"].ToUniversalTime()); + Assert.Equal(s_taglist, storageModel["TagListDataProp"]!.AsBsonArray.Select(x => (string)x!).ToArray()); + Assert.Equal(s_floatVector, storageModel["FloatVector"]!.AsBsonArray.Select(x => (float)x.AsDouble!).ToArray()); + Assert.Equal(s_floatVector, storageModel["NullableFloatVector"]!.AsBsonArray.Select(x => (float)x.AsNullableDouble!).ToArray()); + Assert.Equal(s_doubleVector, storageModel["DoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); + Assert.Equal(s_doubleVector, storageModel["NullableDoubleVector"]!.AsBsonArray.Select(x => (double)x!).ToArray()); + } + + [Fact] + public void MapFromDataToStorageModelMapsNullValues() + { + // Arrange + VectorStoreRecordDefinition vectorStoreRecordDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new 
VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", typeof(ReadOnlyMemory?)), + }, + }; + + var dataModel = new VectorStoreGenericDataModel("key") + { + Data = + { + ["StringDataProp"] = null, + ["NullableIntDataProp"] = null, + }, + Vectors = + { + ["NullableFloatVector"] = null, + }, + }; + + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal(BsonNull.Value, storageModel["StringDataProp"]); + Assert.Equal(BsonNull.Value, storageModel["NullableIntDataProp"]); + Assert.Empty(storageModel["NullableFloatVector"].AsBsonArray); + } + + [Fact] + public void MapFromStorageToDataModelMapsAllSupportedTypes() + { + // Arrange + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var storageModel = new BsonDocument + { + ["_id"] = "key", + ["BoolDataProp"] = true, + ["NullableBoolDataProp"] = false, + ["StringDataProp"] = "string", + ["IntDataProp"] = 1, + ["NullableIntDataProp"] = 2, + ["LongDataProp"] = 3L, + ["NullableLongDataProp"] = 4L, + ["FloatDataProp"] = 5.0f, + ["NullableFloatDataProp"] = 6.0f, + ["DoubleDataProp"] = 7.0, + ["NullableDoubleDataProp"] = 8.0, + ["DecimalDataProp"] = 9.0m, + ["NullableDecimalDataProp"] = 10.0m, + ["DateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["NullableDateTimeDataProp"] = new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), + ["TagListDataProp"] = BsonArray.Create(s_taglist), + ["FloatVector"] = BsonArray.Create(s_floatVector), + ["NullableFloatVector"] = BsonArray.Create(s_floatVector), + ["DoubleVector"] = BsonArray.Create(s_doubleVector), + ["NullableDoubleVector"] = BsonArray.Create(s_doubleVector) + }; + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new 
StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel.Key); + Assert.Equal(true, dataModel.Data["BoolDataProp"]); + Assert.Equal(false, dataModel.Data["NullableBoolDataProp"]); + Assert.Equal("string", dataModel.Data["StringDataProp"]); + Assert.Equal(1, dataModel.Data["IntDataProp"]); + Assert.Equal(2, dataModel.Data["NullableIntDataProp"]); + Assert.Equal(3L, dataModel.Data["LongDataProp"]); + Assert.Equal(4L, dataModel.Data["NullableLongDataProp"]); + Assert.Equal(5.0f, dataModel.Data["FloatDataProp"]); + Assert.Equal(6.0f, dataModel.Data["NullableFloatDataProp"]); + Assert.Equal(7.0, dataModel.Data["DoubleDataProp"]); + Assert.Equal(8.0, dataModel.Data["NullableDoubleDataProp"]); + Assert.Equal(9.0m, dataModel.Data["DecimalDataProp"]); + Assert.Equal(10.0m, dataModel.Data["NullableDecimalDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel.Data["DateTimeDataProp"]); + Assert.Equal(new DateTime(2021, 1, 1, 0, 0, 0).ToUniversalTime(), dataModel.Data["NullableDateTimeDataProp"]); + Assert.Equal(s_taglist, dataModel.Data["TagListDataProp"]); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["FloatVector"]!).ToArray()); + Assert.Equal(s_floatVector, ((ReadOnlyMemory)dataModel.Vectors["NullableFloatVector"]!)!.ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["DoubleVector"]!).ToArray()); + Assert.Equal(s_doubleVector, ((ReadOnlyMemory)dataModel.Vectors["NullableDoubleVector"]!)!.ToArray()); + } + + [Fact] + public void MapFromStorageToDataModelMapsNullValues() + { + // Arrange + VectorStoreRecordDefinition vectorStoreRecordDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordDataProperty("NullableIntDataProp", typeof(int?)), + new VectorStoreRecordVectorProperty("NullableFloatVector", 
typeof(ReadOnlyMemory?)), + }, + }; + + var storageModel = new BsonDocument + { + ["_id"] = "key", + ["StringDataProp"] = BsonNull.Value, + ["NullableIntDataProp"] = BsonNull.Value, + ["NullableFloatVector"] = BsonNull.Value + }; + + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel.Key); + Assert.Null(dataModel.Data["StringDataProp"]); + Assert.Null(dataModel.Data["NullableIntDataProp"]); + Assert.Null(dataModel.Vectors["NullableFloatVector"]); + } + + [Fact] + public void MapFromStorageToDataModelThrowsForMissingKey() + { + // Arrange + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var storageModel = new BsonDocument(); + + // Act & Assert + var exception = Assert.Throws( + () => sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true })); + } + + [Fact] + public void MapFromDataToStorageModelSkipsMissingProperties() + { + // Arrange + VectorStoreRecordDefinition vectorStoreRecordDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + }, + }; + + var dataModel = new VectorStoreGenericDataModel("key"); + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + + // Act + var storageModel = sut.MapFromDataToStorageModel(dataModel); + + // Assert + Assert.Equal("key", (string?)storageModel["_id"]); + Assert.False(storageModel.Contains("StringDataProp")); + Assert.False(storageModel.Contains("FloatVector")); + } + + [Fact] + public void MapFromStorageToDataModelSkipsMissingProperties() + { + // Arrange + 
VectorStoreRecordDefinition vectorStoreRecordDefinition = new() + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Key", typeof(string)), + new VectorStoreRecordDataProperty("StringDataProp", typeof(string)), + new VectorStoreRecordVectorProperty("FloatVector", typeof(ReadOnlyMemory)), + }, + }; + + var storageModel = new BsonDocument + { + ["_id"] = "key" + }; + + var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + + // Act + var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); + + // Assert + Assert.Equal("key", dataModel.Key); + Assert.False(dataModel.Data.ContainsKey("StringDataProp")); + Assert.False(dataModel.Vectors.ContainsKey("FloatVector")); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs index 98f76f1142fe..33d995cf87e0 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureAISearch/AzureAISearchGenericDataModelMapper.cs @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureAISearch; /// -/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in in Azure AI Search. +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure AI Search. 
/// internal class AzureAISearchGenericDataModelMapper : IVectorStoreRecordMapper, JsonObject> { diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs index ac78c98fabc2..197faf81f093 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; + namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// @@ -12,4 +15,39 @@ internal static class AzureCosmosDBMongoDBConstants /// Reserved key property name in data model. internal const string DataModelReservedKeyPropertyName = "Id"; + + /// A containing the supported key types. + internal static readonly HashSet SupportedKeyTypes = + [ + typeof(string) + ]; + + /// A containing the supported data property types. + internal static readonly HashSet SupportedDataTypes = + [ + typeof(bool), + typeof(bool?), + typeof(string), + typeof(int), + typeof(int?), + typeof(long), + typeof(long?), + typeof(float), + typeof(float?), + typeof(double), + typeof(double?), + typeof(decimal), + typeof(decimal?), + typeof(DateTime), + typeof(DateTime?), + ]; + + /// A containing the supported vector types. + internal static readonly HashSet SupportedVectorTypes = + [ + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?), + typeof(ReadOnlyMemory), + typeof(ReadOnlyMemory?) 
+ ]; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs new file mode 100644 index 000000000000..e3ea3d2a12fc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Microsoft.SemanticKernel.Data; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; + +/// +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure CosmosDB MongoDB. +/// +internal sealed class AzureCosmosDBMongoDBGenericDataModelMapper : IVectorStoreRecordMapper, BsonDocument> +{ + /// A that defines the schema of the data in the database. + private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; + + /// + /// Initializes a new instance of the class. + /// + /// A that defines the schema of the data in the database. + public AzureCosmosDBMongoDBGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition) + { + Verify.NotNull(vectorStoreRecordDefinition); + + this._vectorStoreRecordDefinition = vectorStoreRecordDefinition; + } + + /// + public BsonDocument MapFromDataToStorageModel(VectorStoreGenericDataModel dataModel) + { + Verify.NotNull(dataModel); + + var document = new BsonDocument(); + + // Loop through all known properties and map each from the data model to the storage model. + foreach (var property in this._vectorStoreRecordDefinition.Properties) + { + var storagePropertyName = property.StoragePropertyName ?? 
property.DataModelPropertyName; + + if (property is VectorStoreRecordKeyProperty keyProperty) + { + document[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName] = dataModel.Key; + } + else if (property is VectorStoreRecordDataProperty dataProperty) + { + if (dataModel.Data is not null && dataModel.Data.TryGetValue(dataProperty.DataModelPropertyName, out var dataValue)) + { + document[storagePropertyName] = BsonValue.Create(dataValue); + } + } + else if (property is VectorStoreRecordVectorProperty vectorProperty) + { + if (dataModel.Vectors is not null && dataModel.Vectors.TryGetValue(vectorProperty.DataModelPropertyName, out var vectorValue)) + { + document[storagePropertyName] = BsonArray.Create(GetVectorArray(vectorValue)); + } + } + } + + return document; + } + + /// + public VectorStoreGenericDataModel MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options) + { + Verify.NotNull(storageModel); + + // Create variables to store the response properties. + string? key = null; + var dataProperties = new Dictionary(); + var vectorProperties = new Dictionary(); + + // Loop through all known properties and map each from the storage model to the data model. + foreach (var property in this._vectorStoreRecordDefinition.Properties) + { + var storagePropertyName = property.StoragePropertyName ?? 
property.DataModelPropertyName; + + if (property is VectorStoreRecordKeyProperty keyProperty) + { + if (storageModel.TryGetValue(AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName, out var keyValue)) + { + key = keyValue.AsString; + } + } + else if (property is VectorStoreRecordDataProperty dataProperty) + { + if (!storageModel.TryGetValue(storagePropertyName, out var dataValue)) + { + continue; + } + + dataProperties.Add(dataProperty.DataModelPropertyName, GetDataPropertyValue(property.DataModelPropertyName, property.PropertyType, dataValue)); + } + else if (property is VectorStoreRecordVectorProperty vectorProperty && options.IncludeVectors) + { + if (!storageModel.TryGetValue(storagePropertyName, out var vectorValue)) + { + continue; + } + + vectorProperties.Add(vectorProperty.DataModelPropertyName, GetVectorPropertyValue(property.DataModelPropertyName, property.PropertyType, vectorValue)); + } + } + + if (key is null) + { + throw new VectorStoreRecordMappingException("No key property was found in the record retrieved from storage."); + } + + return new VectorStoreGenericDataModel(key) { Data = dataProperties, Vectors = vectorProperties }; + } + + #region private + + private static object? GetDataPropertyValue(string propertyName, Type propertyType, BsonValue value) + { + if (value.IsBsonNull) + { + return null; + } + + return propertyType switch + { + Type t when t == typeof(bool) => value.AsBoolean, + Type t when t == typeof(bool?) => value.AsNullableBoolean, + Type t when t == typeof(string) => value.AsString, + Type t when t == typeof(int) => value.AsInt32, + Type t when t == typeof(int?) => value.AsNullableInt32, + Type t when t == typeof(long) => value.AsInt64, + Type t when t == typeof(long?) => value.AsNullableInt64, + Type t when t == typeof(float) => ((float)value.AsDouble), + Type t when t == typeof(float?) => ((float?)value.AsNullableDouble), + Type t when t == typeof(double) => value.AsDouble, + Type t when t == typeof(double?) 
=> value.AsNullableDouble, + Type t when t == typeof(decimal) => value.AsDecimal, + Type t when t == typeof(decimal?) => value.AsNullableDecimal, + Type t when t == typeof(DateTime) => value.ToUniversalTime(), + Type t when t == typeof(DateTime?) => value.ToNullableUniversalTime(), + Type t when typeof(IEnumerable).IsAssignableFrom(t) => value.AsBsonArray.Select( + item => GetDataPropertyValue(propertyName, VectorStoreRecordPropertyReader.GetCollectionElementType(t), item)), + _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in generic data model.") + }; + } + + private static object? GetVectorPropertyValue(string propertyName, Type propertyType, BsonValue value) + { + if (value.IsBsonNull) + { + return null; + } + + return propertyType switch + { + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => + new ReadOnlyMemory(value.AsBsonArray.Select(item => (float)item.AsDouble).ToArray()), + Type t when t == typeof(ReadOnlyMemory) || t == typeof(ReadOnlyMemory?) => + new ReadOnlyMemory(value.AsBsonArray.Select(item => item.AsDouble).ToArray()), + _ => throw new NotSupportedException($"Mapping for property {propertyName} with type {propertyType.FullName} is not supported in generic data model.") + }; + } + + private static object GetVectorArray(object? 
vector) + { + if (vector is null) + { + return Array.Empty(); + } + + return vector switch + { + ReadOnlyMemory memoryFloat => memoryFloat.ToArray(), + ReadOnlyMemory memoryDouble => memoryDouble.ToArray(), + _ => throw new NotSupportedException($"Mapping for type {vector.GetType().FullName} is not supported in generic data model.") + }; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs index 6b3ad5601c58..6368e08acf6b 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs @@ -66,6 +66,8 @@ public AzureCosmosDBMongoDBVectorStoreRecordCollection( // Verify. Verify.NotNull(mongoDatabase); Verify.NotNullOrWhiteSpace(collectionName); + VectorStoreRecordPropertyReader.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, AzureCosmosDBMongoDBConstants.SupportedKeyTypes); + VectorStoreRecordPropertyReader.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); // Assign. this._mongoDatabase = mongoDatabase; @@ -88,8 +90,7 @@ public AzureCosmosDBMongoDBVectorStoreRecordCollection( this._vectorProperties = properties.VectorProperties; this._vectorStoragePropertyNames = this._vectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); - this._mapper = this._options.BsonDocumentCustomMapper ?? 
- new AzureCosmosDBMongoDBVectorStoreRecordMapper(this._vectorStoreRecordDefinition, properties.KeyProperty.DataModelPropertyName); + this._mapper = this.InitializeMapper(properties.KeyProperty); } /// @@ -144,6 +145,8 @@ public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) const string OperationName = "Find"; + var includeVectors = options?.IncludeVectors ?? false; + var record = await this.RunOperationAsync(OperationName, async () => { using var cursor = await this @@ -162,7 +165,7 @@ public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) DatabaseName, this.CollectionName, OperationName, - () => this._mapper.MapFromStorageToDataModel(record, new())); + () => this._mapper.MapFromStorageToDataModel(record, new() { IncludeVectors = includeVectors })); } /// @@ -417,17 +420,42 @@ private static Dictionary GetStoragePropertyNames( foreach (var property in allProperties) { var propertyInfo = dataModel.GetProperty(property.DataModelPropertyName); + string propertyName; if (propertyInfo != null) { var bsonElementAttribute = propertyInfo.GetCustomAttribute(); - storagePropertyNames[property.DataModelPropertyName] = bsonElementAttribute?.ElementName ?? property.DataModelPropertyName; + propertyName = bsonElementAttribute?.ElementName ?? property.DataModelPropertyName; + } + else + { + propertyName = property.DataModelPropertyName; } + + storagePropertyNames[property.DataModelPropertyName] = propertyName; } return storagePropertyNames; } + /// + /// Returns custom mapper, generic data model mapper or default record mapper. 
+ /// + private IVectorStoreRecordMapper InitializeMapper(VectorStoreRecordKeyProperty keyProperty) + { + if (this._options.BsonDocumentCustomMapper is not null) + { + return this._options.BsonDocumentCustomMapper; + } + + if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + { + return (new AzureCosmosDBMongoDBGenericDataModelMapper(this._vectorStoreRecordDefinition) as IVectorStoreRecordMapper)!; + } + + return new AzureCosmosDBMongoDBVectorStoreRecordMapper(this._vectorStoreRecordDefinition, keyProperty.DataModelPropertyName); + } + #endregion } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs index cae5cd1d7496..2e94ad70d7b3 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Reflection; using Microsoft.SemanticKernel.Data; using MongoDB.Bson; @@ -13,39 +12,6 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; internal sealed class AzureCosmosDBMongoDBVectorStoreRecordMapper : IVectorStoreRecordMapper where TRecord : class { - /// A set of types that a key on the provided model may have. - private static readonly HashSet s_supportedKeyTypes = - [ - typeof(string) - ]; - - /// A set of types that data properties on the provided model may have. 
- private static readonly HashSet s_supportedDataTypes = - [ - typeof(bool), - typeof(bool?), - typeof(string), - typeof(int), - typeof(int?), - typeof(long), - typeof(long?), - typeof(float), - typeof(float?), - typeof(double), - typeof(double?), - typeof(decimal), - typeof(decimal?), - ]; - - /// A set of types that vectors on the provided model may have. - private static readonly HashSet s_supportedVectorTypes = - [ - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?), - typeof(ReadOnlyMemory), - typeof(ReadOnlyMemory?) - ]; - /// A key property info of the data model. private readonly PropertyInfo _keyProperty; @@ -61,9 +27,9 @@ public AzureCosmosDBMongoDBVectorStoreRecordMapper(VectorStoreRecordDefinition v { var (keyProperty, dataProperties, vectorProperties) = VectorStoreRecordPropertyReader.FindProperties(typeof(TRecord), vectorStoreRecordDefinition, supportsMultipleVectors: true); - VectorStoreRecordPropertyReader.VerifyPropertyTypes([keyProperty], s_supportedKeyTypes, "Key"); - VectorStoreRecordPropertyReader.VerifyPropertyTypes(dataProperties, s_supportedDataTypes, "Data", supportEnumerable: true); - VectorStoreRecordPropertyReader.VerifyPropertyTypes(vectorProperties, s_supportedVectorTypes, "Vector"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes([keyProperty], AzureCosmosDBMongoDBConstants.SupportedKeyTypes, "Key"); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(dataProperties, AzureCosmosDBMongoDBConstants.SupportedDataTypes, "Data", supportEnumerable: true); + VectorStoreRecordPropertyReader.VerifyPropertyTypes(vectorProperties, AzureCosmosDBMongoDBConstants.SupportedVectorTypes, "Vector"); this._keyPropertyName = keyPropertyName; this._keyProperty = keyProperty; diff --git a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs index 881ba7ad55fa..6e76b4fb1e7b 100644 --- 
a/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Qdrant/QdrantGenericDataModelMapper.cs @@ -11,7 +11,7 @@ namespace Microsoft.SemanticKernel.Connectors.Qdrant; /// -/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in in Qdrant. +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Qdrant. /// internal class QdrantGenericDataModelMapper : IVectorStoreRecordMapper, PointStruct>, IVectorStoreRecordMapper, PointStruct> { diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs index df07afa971d5..0a2d7a3b8e8e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisHashSetGenericDataModelMapper.cs @@ -10,7 +10,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// -/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in in Redis when using hash sets. +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Redis when using hash sets. 
/// internal class RedisHashSetGenericDataModelMapper : IVectorStoreRecordMapper, (string Key, HashEntry[] HashEntries)> { diff --git a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs index f1fcc75136e9..98a470b7caff 100644 --- a/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs +++ b/dotnet/src/Connectors/Connectors.Memory.Redis/RedisJsonGenericDataModelMapper.cs @@ -9,7 +9,7 @@ namespace Microsoft.SemanticKernel.Connectors.Redis; /// -/// A mapper that maps between the generic semantic kernel data model and the model that the data is stored in in Redis when using JSON. +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Redis when using JSON. /// internal class RedisJsonGenericDataModelMapper : IVectorStoreRecordMapper, (string Key, JsonNode Node)> { diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs index 3af1a3c66b1a..6a54d6983a6d 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreFixture.cs @@ -53,6 +53,7 @@ public AzureCosmosDBMongoDBVectorStoreFixture() new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, new VectorStoreRecordDataProperty("HotelRating", typeof(float)), new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Timestamp", typeof(DateTime)), new VectorStoreRecordDataProperty("Description", typeof(string)), new VectorStoreRecordVectorProperty("DescriptionEmbedding", 
typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } ] @@ -106,6 +107,10 @@ public record AzureCosmosDBMongoDBHotel() [VectorStoreRecordData] public string Description { get; set; } + /// A datetime metadata field. + [VectorStoreRecordData] + public DateTime Timestamp { get; set; } + /// A vector field. [VectorStoreRecordVector(Dimensions: 4, IndexKind: IndexKind.IvfFlat, DistanceFunction: DistanceFunction.CosineDistance)] public ReadOnlyMemory? DescriptionEmbedding { get; set; } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs index 8672d75b3530..1296e2983c01 100644 --- a/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/Memory/AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollectionTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -86,6 +87,7 @@ public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bo Assert.Equal(record.ParkingIncluded, getResult.ParkingIncluded); Assert.Equal(record.Tags.ToArray(), getResult.Tags.ToArray()); Assert.Equal(record.Description, getResult.Description); + Assert.Equal(record.Timestamp.ToUniversalTime(), getResult.Timestamp.ToUniversalTime()); if (includeVectors) { @@ -324,6 +326,51 @@ public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() Assert.Equal("Test Name", getResult.HotelName); } + [Fact(Skip = SkipReason)] + public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + { + // Arrange + var options = new AzureCosmosDBMongoDBVectorStoreRecordCollectionOptions> + { + VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition + }; + + var sut = new AzureCosmosDBMongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); + + // Act + var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "Tags", new string[] { "generic" } }, + { "ParkingIncluded", false }, + { "Timestamp", new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime() }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } + } + }); + + var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(upsertResult); + Assert.Equal("GenericMapper-1", upsertResult); + + Assert.NotNull(localGetResult); + Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); + Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); + Assert.Equal(new[] { "generic" }, 
localGetResult.Data["Tags"]); + Assert.False((bool?)localGetResult.Data["ParkingIncluded"]); + Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult.Data["Timestamp"]); + Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + } + #region private private AzureCosmosDBMongoDBHotel CreateTestHotel(string hotelId) @@ -337,6 +384,7 @@ private AzureCosmosDBMongoDBHotel CreateTestHotel(string hotelId) ParkingIncluded = true, Tags = { "t1", "t2" }, Description = "This is a great hotel.", + Timestamp = new DateTime(2024, 09, 23, 15, 32, 33), DescriptionEmbedding = new[] { 30f, 31f, 32f, 33f }, }; } diff --git a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs index f59f760100e4..c08b2d96e472 100644 --- a/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs +++ b/dotnet/src/InternalUtilities/src/Data/VectorStoreRecordPropertyReader.cs @@ -373,15 +373,7 @@ public static void VerifyPropertyType(string propertyName, Type propertyType, Ha // Check all collection scenarios and get stored type. 
if (supportedEnumerableTypes.Count > 0 && typeof(IEnumerable).IsAssignableFrom(propertyType)) { - var typeToCheck = propertyType switch - { - IEnumerable => typeof(object), - var enumerableType when enumerableType.IsGenericType && enumerableType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => enumerableType.GetGenericArguments()[0], - var arrayType when arrayType.IsArray => arrayType.GetElementType()!, - var interfaceType when interfaceType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) is Type enumerableInterface => - enumerableInterface.GetGenericArguments()[0], - _ => propertyType - }; + var typeToCheck = GetCollectionElementType(propertyType); if (!supportedEnumerableTypes.Contains(typeToCheck)) { @@ -585,4 +577,20 @@ public static Dictionary BuildPropertyNameToStorageNameMap((Vect return storagePropertyNameMap; } + + /// <summary> + /// Returns the element <see cref="Type"/> of a collection type: the T of <see cref="IEnumerable{T}"/>, the element type of an array, <see cref="object"/> for non-generic enumerables, or the type itself when it is not a collection. + /// </summary> + public static Type GetCollectionElementType(Type collectionType) + { + return collectionType switch + { + var enumerableType when enumerableType.IsGenericType && enumerableType.GetGenericTypeDefinition() == typeof(IEnumerable<>) => enumerableType.GetGenericArguments()[0], + var arrayType when arrayType.IsArray => arrayType.GetElementType()!, + var interfaceType when interfaceType.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)) is Type enumerableInterface => + enumerableInterface.GetGenericArguments()[0], + var nonGenericType when typeof(IEnumerable).IsAssignableFrom(nonGenericType) => typeof(object), // test the described type: a bare 'IEnumerable' type pattern never matches a System.Type instance + _ => collectionType + }; + } }