From fbdd6bc1353b473e51ceea5fc7f96a627e33ec82 Mon Sep 17 00:00:00 2001
From: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com>
Date: Thu, 19 Sep 2024 14:13:23 +0100
Subject: [PATCH] .Net: Support polymorphic serialization of ChatMessageContent
class and its derivatives (#8901)
### Motivation, Context and Description
Today, when serializing chat history, the `ChatMessageContent` type or
its derivatives, like `OpenAIChatMessageContent`, are serialized as the
`ChatMessageContent` type because the type is used as the chat history
element type and the JSON serializer uses this type's public contract
for serialization.
However, when attempting serialization of instances of either the
`ChatMessageContent` or `OpenAIChatMessageContent` type that are
declared as `KernelContent` or `object` type, the serialization fails
with the `System.NotSupportedException: Runtime type
'{OpenAI}ChatMessageContent' is not supported by polymorphic type
'Microsoft.SemanticKernel.KernelContent'` exception. The reason for this
exception is that neither of these types is registered for polymorphic
serialization.
This PR registers `ChatMessageContent` type for polymorphic
serialization to allow serialization of the type instances declared as
of `KernelContent` or `object` types:
```csharp
KernelContent content = new ChatMessageContent(...);
// Now it's possible to serialize the content variable of KernelContent type that holds reference to an instance of the ChatMessageContent type as ChatMessageContent type.
var json = JsonSerializer.Serialize(content);
```
Additionally, it enables serialization of unknow in advance and external
derivatives of `ChatMessageContent` type like
`OpenAIChatMessageContent`. These types are serialized using contract of
nearest ancestor which is `ChatMessageContent` by default. To change
this behavior and register the unknown type for polymorphic
serialization use the contract model - [Configure polymorphism with the
contract
model](https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0#configure-polymorphism-with-the-contract-model)
```csharp
KernelContent content = new UnknownChatMessageContent(...);
// The content variable will be serialized using the ChatMessageContent type contract.
var json = JsonSerializer.Serialize(content);
private class UnknownChatMessageContent : ChatMessageContent{}
```
Closes: https://github.com/microsoft/semantic-kernel/issues/7478
### Contribution Checklist
- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone :smile:
---
.../Contents/KernelContent.cs | 3 +-
.../Contents/ChatMessageContentTests.cs | 90 ++++++++++++++++++-
2 files changed, 91 insertions(+), 2 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs
index 183542021705..8dbcc00eb25d 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Contents/KernelContent.cs
@@ -9,13 +9,14 @@ namespace Microsoft.SemanticKernel;
///
/// Base class for all AI non-streaming results
///
-[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")]
+[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type", UnknownDerivedTypeHandling = JsonUnknownDerivedTypeHandling.FallBackToNearestAncestor)]
[JsonDerivedType(typeof(TextContent), typeDiscriminator: nameof(TextContent))]
[JsonDerivedType(typeof(ImageContent), typeDiscriminator: nameof(ImageContent))]
[JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: nameof(FunctionCallContent))]
[JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: nameof(FunctionResultContent))]
[JsonDerivedType(typeof(BinaryContent), typeDiscriminator: nameof(BinaryContent))]
[JsonDerivedType(typeof(AudioContent), typeDiscriminator: nameof(AudioContent))]
+[JsonDerivedType(typeof(ChatMessageContent), typeDiscriminator: nameof(ChatMessageContent))]
#pragma warning disable SKEXP0110
[JsonDerivedType(typeof(AnnotationContent), typeDiscriminator: nameof(AnnotationContent))]
[JsonDerivedType(typeof(FileReferenceContent), typeDiscriminator: nameof(FileReferenceContent))]
diff --git a/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs b/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs
index 9fe258d8bba8..cd753a15e201 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Contents/ChatMessageContentTests.cs
@@ -202,7 +202,7 @@ public void ItCanBeSerializeAndDeserialized()
new FunctionCallContent("function-name", "plugin-name", "function-id", new KernelArguments { ["parameter"] = "argument" }),
new FunctionResultContent(new FunctionCallContent("function-name", "plugin-name", "function-id"), "function-result"),
new FileReferenceContent(fileId: "file-id-1") { ModelId = "model-7", Metadata = new Dictionary() { ["metadata-key-7"] = "metadata-value-7" } },
- new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary() { ["metadata-key-8"] = "metadata-value-8" } }
+ new AnnotationContent("quote-8") { ModelId = "model-8", FileId = "file-id-2", StartIndex = 2, EndIndex = 24, Metadata = new Dictionary() { ["metadata-key-8"] = "metadata-value-8" } },
];
// Act
@@ -320,4 +320,92 @@ public void ItCanBeSerializeAndDeserialized()
Assert.Single(annotationContent.Metadata);
Assert.Equal("metadata-value-8", annotationContent.Metadata["metadata-key-8"]?.ToString());
}
+
+ [Fact]
+ public void ItCanBePolymorphicallySerializedAndDeserializedAsKernelContentType()
+ {
+ // Arrange
+ KernelContent sut = new ChatMessageContent(AuthorRole.User, "test-content", "test-model", metadata: new Dictionary()
+ {
+ ["test-metadata-key"] = "test-metadata-value"
+ })
+ {
+ MimeType = "test-mime-type"
+ };
+
+ // Act
+ var json = JsonSerializer.Serialize(sut);
+
+ var deserialized = JsonSerializer.Deserialize(json)!;
+
+ // Assert
+ Assert.IsType(deserialized);
+ Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content);
+ Assert.Equal("test-model", deserialized.ModelId);
+ Assert.Equal("test-mime-type", deserialized.MimeType);
+ Assert.NotNull(deserialized.Metadata);
+ Assert.Single(deserialized.Metadata);
+ Assert.Equal("test-metadata-value", deserialized.Metadata["test-metadata-key"]?.ToString());
+ }
+
+ [Fact]
+ public void UnknownDerivativeCanBePolymorphicallySerializedAndDeserializedAsChatMessageContentType()
+ {
+ // Arrange
+ KernelContent sut = new UnknownExternalChatMessageContent(AuthorRole.User, "test-content")
+ {
+ MimeType = "test-mime-type",
+ };
+
+ // Act
+ var json = JsonSerializer.Serialize(sut);
+
+ var deserialized = JsonSerializer.Deserialize(json)!;
+
+ // Assert
+ Assert.IsType(deserialized);
+ Assert.Equal("test-content", ((ChatMessageContent)deserialized).Content);
+ Assert.Equal("test-mime-type", deserialized.MimeType);
+ }
+
+ [Fact]
+ public void ItCanBeSerializeAndDeserializedWithFunctionResultOfChatMessageType()
+ {
+ // Arrange
+ ChatMessageContentItemCollection items = [
+ new FunctionResultContent(new FunctionCallContent("function-name-1", "plugin-name-1", "function-id-1"), new ChatMessageContent(AuthorRole.User, "test-content-1")),
+ new FunctionResultContent(new FunctionCallContent("function-name-2", "plugin-name-2", "function-id-2"), new UnknownExternalChatMessageContent(AuthorRole.Assistant, "test-content-2")),
+ ];
+
+ // Act
+ var chatMessageJson = JsonSerializer.Serialize(new ChatMessageContent(AuthorRole.User, items: items, "message-model"));
+
+ var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson)!;
+
+ // Assert
+ var functionResultContentWithResultOfChatMessageContentType = deserializedMessage.Items[0] as FunctionResultContent;
+ Assert.NotNull(functionResultContentWithResultOfChatMessageContentType);
+ Assert.Equal("function-name-1", functionResultContentWithResultOfChatMessageContentType.FunctionName);
+ Assert.Equal("function-id-1", functionResultContentWithResultOfChatMessageContentType.CallId);
+ Assert.Equal("plugin-name-1", functionResultContentWithResultOfChatMessageContentType.PluginName);
+ var chatMessageContent = Assert.IsType(functionResultContentWithResultOfChatMessageContentType.Result);
+ Assert.Equal("user", chatMessageContent.GetProperty("Role").GetProperty("Label").GetString());
+ Assert.Equal("test-content-1", chatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString());
+
+ var functionResultContentWithResultOfUnknownChatMessageContentType = deserializedMessage.Items[1] as FunctionResultContent;
+ Assert.NotNull(functionResultContentWithResultOfUnknownChatMessageContentType);
+ Assert.Equal("function-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.FunctionName);
+ Assert.Equal("function-id-2", functionResultContentWithResultOfUnknownChatMessageContentType.CallId);
+ Assert.Equal("plugin-name-2", functionResultContentWithResultOfUnknownChatMessageContentType.PluginName);
+ var unknownChatMessageContent = Assert.IsType(functionResultContentWithResultOfUnknownChatMessageContentType.Result);
+ Assert.Equal("assistant", unknownChatMessageContent.GetProperty("Role").GetProperty("Label").GetString());
+ Assert.Equal("test-content-2", unknownChatMessageContent.GetProperty("Items")[0].GetProperty("Text").GetString());
+ }
+
+ private sealed class UnknownExternalChatMessageContent : ChatMessageContent
+ {
+ public UnknownExternalChatMessageContent(AuthorRole role, string? content) : base(role, content)
+ {
+ }
+ }
}