From fbdd6bc1353b473e51ceea5fc7f96a627e33ec82 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Date: Thu, 19 Sep 2024 14:13:23 +0100 Subject: [PATCH] .Net: Support polymorphic serialization of ChatMessageContent class and its derivatives (#8901) ### Motivation, Context and Description Today, when serializing chat history, the `ChatMessageContent` type or its derivatives, like `OpenAIChatMessageContent`, are serialized as the `ChatMessageContent` type because the type is used as the chat history element type and the JSON serializer uses this type's public contract for serialization. However, when attempting serialization of instances of either the `ChatMessageContent` or `OpenAIChatMessageContent` type that are declared as `KernelContent` or `object` type, the serialization fails with the `System.NotSupportedException: Runtime type '{OpenAI}ChatMessageContent' is not supported by polymorphic type 'Microsoft.SemanticKernel.KernelContent'` exception. The reason for this exception is that neither of these types is registered for polymorphic serialization. This PR registers `ChatMessageContent` type for polymorphic serialization to allow serialization of the type instances declared as of `KernelContent` or `object` types: ```csharp KernelContent content = new ChatMessageContent(...); // Now it's possible to serialize the content variable of KernelContent type that holds reference to an instance of the ChatMessageContent type as ChatMessageContent type. var json = JsonSerializer.Serialize(content); ``` Additionally, it enables serialization of unknow in advance and external derivatives of `ChatMessageContent` type like `OpenAIChatMessageContent`. These types are serialized using contract of nearest ancestor which is `ChatMessageContent` by default. To change this behavior and register the unknown type for polymorphic serialization use the contract model - [Configure polymorphism with the contract model](https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0#configure-polymorphism-with-the-contract-model) ```csharp KernelContent content = new UnknownChatMessageContent(...); // The content variable will be serialized using the ChatMessageContent type contract. var json = JsonSerializer.Serialize(content); private class UnknownChatMessageContent : ChatMessageContent{} ``` Closes: https://github.com/microsoft/semantic-kernel/issues/7478 ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone :smile: --- .../Contents/KernelContent.cs | 3 +- .../Contents/ChatMessageContentTests.cs | 90 ++++++++++++++++++- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs index 183542021705..8dbcc00eb25d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs @@ -9,13 +9,14 @@ namespace Microsoft.SemanticKernel; /// /// Base class for all AI non-streaming results /// -[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type", UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FallBackToNearestAncestor)] [JsonDerivedType(typeof(TextContent), typeDiscriminator: nameof(TextContent))] [JsonDerivedType(typeof(ImageContent), typeDiscriminator: nameof(ImageContent))] [JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: nameof(FunctionCallContent))] [JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: nameof(FunctionResultContent))] [JsonDerivedType(typeof(BinaryContent), typeDiscriminator: nameof(BinaryContent))] [JsonDerivedType(typeof(AudioContent), typeDiscriminator: nameof(AudioContent))] +[JsonDerivedType(typeof(ChatMessageContent), typeDiscriminator: nameof(ChatMessageContent))] #pragma warning disable SKEXP0110 [JsonDerivedType(typeof(AnnotationContent), typeDiscriminator: nameof(AnnotationContent))] [JsonDerivedType(typeof(FileReferenceContent), typeDiscriminator: nameof(FileReferenceContent))] diff --git a/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs index 9fe258d8bba8..cd753a15e201 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs @@ -202,7 +202,7 @@ public void ItCanBeSerializeAndDeserialized() new FunctionCallContent("function-name", "plugin-name", "function-id", new KernelArguments { ["parameter"] = "argument" }), new FunctionResultContent(new FunctionCallContent("function-name", "plugin-name", "function-id"), "function-result"), new FileReferenceContent(fileId: "file-id-1") { ModelId = "model-7", Metadata = new Dictionary() { ["metadata-key-7"] = "metadata-value-7" } }, - new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary() { ["metadata-key-8"] = "metadata-value-8" } } + new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary() { ["metadata-key-8"] = "metadata-value-8" } }, ]; // Act @@ -320,4 +320,92 @@ public void ItCanBeSerializeAndDeserialized() Assert.Single(annotationContent.Metadata); Assert.Equal("metadata-value-8", annotationContent.Metadata["metadata-key-8"]?.ToString()); } + + [Fact] + public void ItCanBePolymorphicallySerializedAndDeserializedAsKernelContentType() + { + // Arrange + KernelContent sut = new ChatMessageContent(AuthorRole.User, "test-content", "test-model", metadata: new Dictionary() + { + ["test-metadata-key"] = "test-metadata-value" + }) + { + MimeType = "test-mime-type" + }; + + // Act + var json = JsonSerializer.Serialize(sut); + + var deserialized = JsonSerializer.Deserialize(json)!; + + // Assert + Assert.IsType(deserialized); + Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content); + Assert.Equal("test-model", deserialized.ModelId); + Assert.Equal("test-mime-type", deserialized.MimeType); + Assert.NotNull(deserialized.Metadata); + Assert.Single(deserialized.Metadata); + Assert.Equal("test-metadata-value", deserialized.Metadata["test-metadata-key"]?.ToString()); + } + + [Fact] + public void UnknownDerivativeCanBePolymorphicallySerializedAndDeserializedAsChatMessageContentType() + { + // Arrange + KernelContent sut = new UnknownExternalChatMessageContent(AuthorRole.User, "test-content") + { + MimeType = "test-mime-type", + }; + + // Act + var json = JsonSerializer.Serialize(sut); + + var deserialized = JsonSerializer.Deserialize(json)!; + + // Assert + Assert.IsType(deserialized); + Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content); + Assert.Equal("test-mime-type", deserialized.MimeType); + } + + [Fact] + public void ItCanBeSerializeAndDeserializedWithFunctionResultOfChatMessageType() + { + // Arrange + ChatMessageContentItemCollection items = [ + new FunctionResultContent(new FunctionCallContent("function-name-1", "plugin-name-1", "function-id-1"), new ChatMessageContent(AuthorRole.User, "test-content-1")), + new FunctionResultContent(new FunctionCallContent("function-name-2", "plugin-name-2", "function-id-2"), new UnknownExternalChatMessageContent(AuthorRole.Assistant, "test-content-2")), + ]; + + // Act + var chatMessageJson = JsonSerializer.Serialize(new ChatMessageContent(AuthorRole.User, items: items, "message-model")); + + var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson)!; + + // Assert + var functionResultContentWithResultOfChatMessageContentType = deserializedMessage.Items[0] as FunctionResultContent; + Assert.NotNull(functionResultContentWithResultOfChatMessageContentType); + Assert.Equal("function-name-1", functionResultContentWithResultOfChatMessageContentType.FunctionName); + Assert.Equal("function-id-1", functionResultContentWithResultOfChatMessageContentType.CallId); + Assert.Equal("plugin-name-1", functionResultContentWithResultOfChatMessageContentType.PluginName); + var chatMessageContent = Assert.IsType(functionResultContentWithResultOfChatMessageContentType.Result); + Assert.Equal("user", chatMessageContent.GetProperty("Role").GetProperty("Label").GetString()); + Assert.Equal("test-content-1", chatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString()); + + var functionResultContentWithResultOfUnknownChatMessageContentType = deserializedMessage.Items[1] as FunctionResultContent; + Assert.NotNull(functionResultContentWithResultOfUnknownChatMessageContentType); + Assert.Equal("function-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.FunctionName); + Assert.Equal("function-id-2", functionResultContentWithResultOfUnknownChatMessageContentType.CallId); + Assert.Equal("plugin-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.PluginName); + var unknownChatMessageContent = Assert.IsType(functionResultContentWithResultOfUnknownChatMessageContentType.Result); + Assert.Equal("assistant", unknownChatMessageContent.GetProperty("Role").GetProperty("Label").GetString()); + Assert.Equal("test-content-2", unknownChatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString()); + } + + private sealed class UnknownExternalChatMessageContent : ChatMessageContent + { + public UnknownExternalChatMessageContent(AuthorRole role, string? content) : base(role, content) + { + } + } }