Skip to content

Commit

Permalink
fix #398
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenHodgson committed Jan 12, 2025
1 parent f741b31 commit 54f05cc
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 23 deletions.
120 changes: 112 additions & 8 deletions OpenAI-DotNet-Tests/TestFixture_04_Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using OpenAI.Tests.Weather;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;

Expand Down Expand Up @@ -75,7 +76,7 @@ public async Task Test_01_02_GetChatStreamingCompletion()
}

[Test]
public async Task Test_01_03_GetChatCompletion_Modalities()
public async Task Test_01_03_01_GetChatCompletion_Modalities()
{
Assert.IsNotNull(OpenAIClient.ChatEndpoint);

Expand Down Expand Up @@ -123,6 +124,51 @@ public async Task Test_01_03_GetChatCompletion_Modalities()
response.GetUsage();
}

/// <summary>
/// Streams a two-turn audio-enabled chat and verifies that each aggregated response
/// carries both a transcript and audio data, and that the assistant's audio message
/// from the first turn is part of the conversation sent with the second request.
/// </summary>
[Test]
public async Task Test_01_03_01_GetChatCompletion_Modalities_Streaming()
{
    Assert.IsNotNull(OpenAIClient.ChatEndpoint);

    // ---- Turn 1: initial question ----
    var conversation = new List<Message>
    {
        new(Role.System, "You are a helpful assistant."),
        new(Role.User, "Is a golden retriever a good family dog?"),
    };
    var request = new ChatRequest(conversation, Model.GPT4oAudio, audioConfig: Voice.Alloy);

    // Verify the values the request reports for the audio configuration.
    Assert.IsNotNull(request);
    Assert.IsNotNull(request.AudioConfig);
    Assert.AreEqual(Model.GPT4oAudio.Id, request.Model);
    Assert.AreEqual(Voice.Alloy.Id, request.AudioConfig.Voice);
    Assert.AreEqual(AudioFormat.Pcm16, request.AudioConfig.Format);
    Assert.AreEqual(Modality.Text | Modality.Audio, request.Modalities);

    // Every partial result handed to the handler must be non-null.
    var result = await OpenAIClient.ChatEndpoint.StreamCompletionAsync(request, Assert.IsNotNull, true);
    Assert.IsNotNull(result);
    Assert.IsNotNull(result.Choices);
    Assert.IsNotEmpty(result.Choices);
    Assert.AreEqual(1, result.Choices.Count);
    Assert.IsNotNull(result.FirstChoice);
    Console.WriteLine($"{result.FirstChoice.Message.Role}: {result.FirstChoice} | Finish Reason: {result.FirstChoice.FinishReason}");

    var firstTurnAudio = result.FirstChoice.Message.AudioOutput;
    Assert.IsNotEmpty(firstTurnAudio.Transcript);
    Assert.IsNotNull(firstTurnAudio.Data);
    Assert.IsFalse(firstTurnAudio.Data.IsEmpty);
    result.GetUsage();

    // ---- Turn 2: follow-up that includes the assistant's audio message ----
    conversation.Add(result.FirstChoice.Message);
    conversation.Add(new(Role.User, "What are some other good family dog breeds?"));
    request = new ChatRequest(conversation, Model.GPT4oAudio, audioConfig: Voice.Alloy);
    Assert.IsNotNull(request);

    // Index 2 is the assistant reply appended above (system, user, assistant, user).
    var assistantTurn = conversation[2];
    Assert.IsNotNull(assistantTurn);
    Assert.AreEqual(Role.Assistant, assistantTurn.Role);
    Assert.IsNotNull(assistantTurn.AudioOutput);

    result = await OpenAIClient.ChatEndpoint.StreamCompletionAsync(request, Assert.IsNotNull, true);
    Assert.IsNotNull(result);
    Assert.IsNotNull(result.Choices);
    Assert.IsNotEmpty(result.Choices);
    Assert.AreEqual(1, result.Choices.Count);

    var secondTurnAudio = result.FirstChoice.Message.AudioOutput;
    Assert.IsNotEmpty(secondTurnAudio.Transcript);
    Assert.IsNotNull(secondTurnAudio.Data);
    Assert.IsFalse(secondTurnAudio.Data.IsEmpty);
    Assert.IsFalse(string.IsNullOrWhiteSpace(result.FirstChoice));
}

[Test]
public async Task Test_01_04_JsonMode()
{
Expand All @@ -147,7 +193,7 @@ public async Task Test_01_04_JsonMode()
}

[Test]
public async Task Test_01_05_GetChatStreamingCompletionEnumerableAsync()
public async Task Test_01_05_01_GetChatStreamingCompletionEnumerableAsync()
{
Assert.IsNotNull(OpenAIClient.ChatEndpoint);
var messages = new List<Message>
Expand All @@ -159,19 +205,77 @@ public async Task Test_01_05_GetChatStreamingCompletionEnumerableAsync()
};
var cumulativeDelta = string.Empty;
var chatRequest = new ChatRequest(messages);
var didThrowException = false;

await foreach (var partialResponse in OpenAIClient.ChatEndpoint.StreamCompletionEnumerableAsync(chatRequest, true))
{
Assert.IsNotNull(partialResponse);
if (partialResponse.Usage != null) { return; }
Assert.NotNull(partialResponse.Choices);
Assert.NotZero(partialResponse.Choices.Count);
try
{
Assert.IsNotNull(partialResponse);
if (partialResponse.Usage != null) { continue; }
Assert.NotNull(partialResponse.Choices);
Assert.NotZero(partialResponse.Choices.Count);

foreach (var choice in partialResponse.Choices.Where(choice => choice.Delta?.Content != null))
if (partialResponse.FirstChoice?.Delta?.Content is not null)
{
cumulativeDelta += partialResponse.FirstChoice.Delta.Content;
}
}
catch (Exception e)
{
cumulativeDelta += choice.Delta.Content;
Console.WriteLine(e);
didThrowException = true;
}
}

Assert.IsFalse(didThrowException);
Assert.IsNotEmpty(cumulativeDelta);
Console.WriteLine(cumulativeDelta);
}

/// <summary>
/// Streams an audio-enabled chat completion as an async enumerable and verifies that
/// audio bytes arrive in the deltas. The transcript fragments carried by each audio
/// delta are accumulated and printed for inspection.
/// </summary>
[Test]
public async Task Test_01_05_02_GetChatStreamingModalitiesEnumerableAsync()
{
    Assert.IsNotNull(OpenAIClient.ChatEndpoint);

    var messages = new List<Message>
    {
        new(Role.System, "You are a helpful assistant."),
        new(Role.User, "Count from 1 to 10. Whisper please.")
    };

    var cumulativeDelta = string.Empty;
    using var audioStream = new MemoryStream();
    var chatRequest = new ChatRequest(messages, audioConfig: new AudioConfig(Voice.Nova), model: Model.GPT4oAudio);
    Assert.IsNotNull(chatRequest);
    Assert.IsNotNull(chatRequest.AudioConfig);
    Assert.AreEqual(Model.GPT4oAudio.Id, chatRequest.Model);
    Assert.AreEqual(Voice.Nova.Id, chatRequest.AudioConfig.Voice);
    Assert.AreEqual(AudioFormat.Pcm16, chatRequest.AudioConfig.Format);
    Assert.AreEqual(Modality.Text | Modality.Audio, chatRequest.Modalities);
    var didThrowException = false;

    await foreach (var partialResponse in OpenAIClient.ChatEndpoint.StreamCompletionEnumerableAsync(chatRequest, true))
    {
        // Failures inside the loop are recorded rather than rethrown so the stream is
        // always drained; the flag is asserted once enumeration completes.
        try
        {
            Assert.IsNotNull(partialResponse);

            // Skip chunks that carry usage or have no choices.
            if (partialResponse.Usage != null || partialResponse.Choices == null) { continue; }

            var audioOutput = partialResponse.FirstChoice?.Delta?.AudioOutput;

            if (audioOutput is not null)
            {
                // Previously cumulativeDelta was never written, so the final
                // Console.WriteLine always printed an empty line.
                cumulativeDelta += audioOutput.Transcript;
                await audioStream.WriteAsync(audioOutput.Data);
            }
        }
        catch (Exception e)
        {
            Console.WriteLine(e);
            didThrowException = true;
        }
    }

    Assert.IsFalse(didThrowException);
    Assert.IsTrue(audioStream.Length > 0);
    Console.WriteLine(cumulativeDelta);
}

Expand Down
46 changes: 38 additions & 8 deletions OpenAI-DotNet/Chat/AudioOutput.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,61 @@
// Licensed under the MIT License. See LICENSE in the project root for license information.

using System;
using System.Linq;
using System.Text.Json.Serialization;

namespace OpenAI.Chat
{
[JsonConverter(typeof(AudioOutputConverter))]
public sealed class AudioOutput
{
internal AudioOutput(string id, int expiresAtUnixSeconds, ReadOnlyMemory<byte> data, string transcript)
internal AudioOutput(string id, int? expiresAtUnixSeconds, Memory<byte> data, string transcript)
{
Id = id;
ExpiresAtUnixSeconds = expiresAtUnixSeconds;
Data = data;
this.data = data;
Transcript = transcript;
ExpiresAtUnixSeconds = expiresAtUnixSeconds;
}

public string Id { get; }
public string Id { get; private set; }

public string Transcript { get; private set; }

public int ExpiresAtUnixSeconds { get; }
private Memory<byte> data;

public DateTime ExpiresAt => DateTimeOffset.FromUnixTimeSeconds(ExpiresAtUnixSeconds).DateTime;
public ReadOnlyMemory<byte> Data => data;

public ReadOnlyMemory<byte> Data { get; }
public int? ExpiresAtUnixSeconds { get; private set; }

public string Transcript { get; }
public DateTime? ExpiresAt => ExpiresAtUnixSeconds.HasValue
? DateTimeOffset.FromUnixTimeSeconds(ExpiresAtUnixSeconds.Value).DateTime
: null;

public override string ToString() => Transcript ?? string.Empty;

/// <summary>
/// Merges a streamed <see cref="AudioOutput"/> fragment into this instance.
/// </summary>
/// <remarks>
/// <see cref="Id"/> and <see cref="ExpiresAtUnixSeconds"/> are overwritten when the fragment
/// supplies them; <see cref="Transcript"/> and the audio data are appended.
/// </remarks>
/// <param name="other">The fragment to merge. A null fragment is ignored.</param>
internal void AppendFrom(AudioOutput other)
{
    if (other == null) { return; }

    if (!string.IsNullOrWhiteSpace(other.Id))
    {
        Id = other.Id;
    }

    if (other.ExpiresAtUnixSeconds.HasValue)
    {
        ExpiresAtUnixSeconds = other.ExpiresAtUnixSeconds;
    }

    // Only skip null/empty fragments: a whitespace-only fragment (a space or a newline)
    // is still part of the transcript and dropping it would fuse adjacent words.
    if (!string.IsNullOrEmpty(other.Transcript))
    {
        Transcript += other.Transcript;
    }

    if (other.Data.Length > 0)
    {
        // Single allocation + two block copies instead of three LINQ-driven array copies.
        var combined = new byte[data.Length + other.Data.Length];
        data.CopyTo(combined);
        other.Data.CopyTo(combined.AsMemory(data.Length));
        data = combined;
    }
}
}
}
11 changes: 7 additions & 4 deletions OpenAI-DotNet/Chat/ChatEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public async Task<ChatResponse> GetCompletionAsync(ChatRequest chatRequest, Canc
/// Creates a completion for the chat message and streams the results to the <paramref name="resultHandler"/> as they come in.
/// </summary>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="resultHandler">An <see cref="Action{ChatResponse}"/> to be invoked as each new result arrives.</param>
/// <param name="resultHandler">An <see cref="Action{ChatResponse}"/> to be invoked as each new result arrives.</param>
/// <param name="streamUsage">
/// Optional, If set, an additional chunk will be streamed before the 'data: [DONE]' message.
/// The 'usage' field on this chunk shows the token usage statistics for the entire request,
Expand All @@ -82,7 +82,7 @@ public async Task<ChatResponse> StreamCompletionAsync(ChatRequest chatRequest, A
/// </summary>
/// <typeparam name="T"><see cref="JsonSchema"/> to use for structured outputs.</typeparam>
/// <param name="chatRequest">The chat request which contains the message content.</param>
/// <param name="resultHandler">An <see cref="Action{ChatResponse}"/> to be invoked as each new result arrives.</param>
/// <param name="resultHandler">An <see cref="Action{ChatResponse}"/> to be invoked as each new result arrives.</param>
/// <param name="streamUsage">
/// Optional, If set, an additional chunk will be streamed before the 'data: [DONE]' message.
/// The 'usage' field on this chunk shows the token usage statistics for the entire request,
Expand Down Expand Up @@ -196,7 +196,7 @@ public async IAsyncEnumerable<ChatResponse> StreamCompletionEnumerableAsync(Chat
await responseStream.WriteAsync("["u8.ToArray(), cancellationToken);
}

while (await reader.ReadLineAsync() is { } streamData)
while (await reader.ReadLineAsync(cancellationToken) is { } streamData)
{
cancellationToken.ThrowIfCancellationRequested();

Expand All @@ -207,7 +207,10 @@ public async IAsyncEnumerable<ChatResponse> StreamCompletionEnumerableAsync(Chat
continue;
}

if (string.IsNullOrWhiteSpace(eventData)) { continue; }
if (string.IsNullOrWhiteSpace(eventData))
{
continue;
}

if (responseStream != null)
{
Expand Down
17 changes: 16 additions & 1 deletion OpenAI-DotNet/Chat/Delta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public sealed class Delta
[JsonPropertyName("tool_calls")]
public IReadOnlyList<ToolCall> ToolCalls { get; private set; }

/// <summary>
/// If the audio output modality is requested, this object contains data about the audio response from the model.
/// </summary>
[JsonInclude]
[JsonPropertyName("audio")]
public AudioOutput AudioOutput { get; private set; }

/// <summary>
/// Optional, The name of the author of this message.<br/>
/// May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
Expand All @@ -43,7 +50,15 @@ public sealed class Delta
[JsonPropertyName("name")]
public string Name { get; private set; }

public override string ToString() => Content ?? string.Empty;
/// <summary>
/// Returns the text carried by this delta: <see cref="Content"/> when it is present,
/// otherwise the string form of <see cref="AudioOutput"/>, otherwise an empty string.
/// </summary>
public override string ToString()
{
    // Use IsNullOrEmpty rather than IsNullOrWhiteSpace: a whitespace-only content delta
    // (a space or newline) is real text and must not be replaced with an empty string.
    if (string.IsNullOrEmpty(Content))
    {
        return AudioOutput?.ToString() ?? string.Empty;
    }

    return Content;
}

public static implicit operator string(Delta delta) => delta?.ToString();
}
Expand Down
12 changes: 12 additions & 0 deletions OpenAI-DotNet/Chat/Message.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,18 @@ internal void AppendFrom(Delta other)
toolCalls ??= new List<ToolCall>();
toolCalls.AppendFrom(other.ToolCalls);
}

if (other is { AudioOutput: not null })
{
if (AudioOutput == null)
{
AudioOutput = other.AudioOutput;
}
else
{
AudioOutput.AppendFrom(other.AudioOutput);
}
}
}
}
}
4 changes: 2 additions & 2 deletions OpenAI-DotNet/Extensions/AudioOutputConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ internal class AudioOutputConverter : JsonConverter<AudioOutput>
public override AudioOutput Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
string id = null;
var expiresAt = 0;
int? expiresAt = null;
string b64Data = null;
string transcript = null;
ReadOnlyMemory<byte> data = null;
Memory<byte> data = null;

while (reader.Read())
{
Expand Down

0 comments on commit 54f05cc

Please sign in to comment.