Skip to content

Commit

Permalink
refactor(ut):
Browse files Browse the repository at this point in the history
 - add ut for HttpClient.kt for openai 403
 - remove unused files
  • Loading branch information
hanrw committed Jun 3, 2024
1 parent 3d20db2 commit c73e9a0
Show file tree
Hide file tree
Showing 18 changed files with 30 additions and 81 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.gradle
build/
.kotlin
!gradle/wrapper/gradle-wrapper.jar
!**/src/main/**/build/
!**/src/test/**/build/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ data class CreateMessageRequest(
val maxTokens: Int = 1024,
@SerialName("model")
val model: Model = Model.CLAUDE_3_HAIKU,
) : StreamMessageRequest {
val stream: Boolean? = null,
) {
companion object {
fun streamRequest(messages: List<Message>, systemPrompt: String? = null) =
CreateMessageRequest(messages, systemPrompt) as StreamMessageRequest
CreateMessageRequest(messages, systemPrompt, stream = true)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import kotlinx.coroutines.flow.Flow
interface Messages {
suspend fun create(request: CreateMessageRequest): CreateMessageResponse

fun stream(request: StreamMessageRequest): Flow<StreamMessageResponse>
fun stream(request: CreateMessageRequest): Flow<StreamMessageResponse>
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ import kotlinx.serialization.json.Json
class DefaultMessagesApi(
private val anthropicConfig: AnthropicConfig = AnthropicConfig(),
private val requester: HttpRequester,
private val jsonLenient: Json = JsonLenient,
) : Messages {

override fun stream(request: StreamMessageRequest): Flow<StreamMessageResponse> {
override fun stream(request: CreateMessageRequest): Flow<StreamMessageResponse> {
return requester.streamRequest<StreamMessageResponse> {
method = HttpMethod.Post
url(path = MESSAGE_API_PATH)
setBody(request.asStreamRequest(jsonLenient))
setBody(request.copy(stream = true))
contentType(ContentType.Application.Json)
accept(ContentType.Text.EventStream)
headers {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.tddworks.anthropic.api.messages.api.internal

import com.tddworks.anthropic.api.messages.api.internal.json.anthropicModule
import kotlinx.serialization.json.Json


Expand All @@ -15,9 +14,6 @@ import kotlinx.serialization.json.Json
val JsonLenient = Json {
isLenient = true
ignoreUnknownKeys = true
// https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/json.md#class-discriminator-for-polymorphism
classDiscriminator = "#class"
serializersModule = anthropicModule
encodeDefaults = true
explicitNulls = false
}
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
package com.tddworks.anthropic.api.messages.api.internal.json

import com.tddworks.anthropic.api.messages.api.*
import com.tddworks.common.network.api.StreamableRequest
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonContentPolymorphicSerializer
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic

val anthropicModule = SerializersModule {
polymorphic(StreamableRequest::class) {
subclass(CreateMessageRequest::class, CreateMessageRequest.serializer())
defaultDeserializer { CreateMessageRequest.serializer() }
}
}

object StreamMessageResponseSerializer :
JsonContentPolymorphicSerializer<StreamMessageResponse>(StreamMessageResponse::class) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ fun anthropicModules(
single<Messages> {
DefaultMessagesApi(
anthropicConfig = config,
jsonLenient = get(named("anthropicJson")),
requester = get(named("anthropicHttpRequester"))
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.tddworks.anthropic.api

import com.tddworks.anthropic.api.messages.api.internal.json.anthropicModule
import kotlinx.serialization.EncodeDefault
import kotlinx.serialization.json.Json

Expand Down Expand Up @@ -46,7 +45,5 @@ internal val JsonLenient = Json {
isLenient = true
ignoreUnknownKeys = true
// https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/json.md#class-discriminator-for-polymorphism
classDiscriminator = "#class"
serializersModule = anthropicModule
encodeDefaults = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class CreateMessageRequestTest {
],
"systemPrompt": null,
"max_tokens": 1024,
"model": "claude-3-haiku-20240307"
"model": "claude-3-haiku-20240307",
"stream": null
}
""".trimIndent()

Expand Down

This file was deleted.

14 changes: 0 additions & 14 deletions common/src/commonMain/kotlin/com/tddworks/openai/api/ChatApi.kt

This file was deleted.

This file was deleted.

This file was deleted.

2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ android.nonTransitiveRClass=true
#publishing for kmmbridge
## Darwin Publish require from - nextVersion parameter must be a valid semver string. Current value: 0.1.4.
## So we need set version to 0.1 or 0.2 ......
LIBRARY_VERSION=0.2
LIBRARY_VERSION=0.3.1
GROUP=com.tddworks

# POM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import com.tddworks.openai.api.chat.api.ChatChoice
import com.tddworks.openai.api.chat.api.ChatChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.ChatDelta
import kotlinx.serialization.ExperimentalSerializationApi
import com.tddworks.openai.api.chat.api.ChatCompletion as OpenAIChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk as OpenAIChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatMessage.AssistantMessage as OpenAIAssistantMessage
Expand Down Expand Up @@ -81,7 +82,13 @@ fun StreamMessageResponse.toOpenAIChatCompletionChunk(model: String): OpenAIChat
}


fun ChatCompletionRequest.toAnthropicRequest(): CreateMessageRequest {
@OptIn(ExperimentalSerializationApi::class)
fun ChatCompletionRequest.toAnthropicStreamRequest(): CreateMessageRequest {
return toAnthropicRequest(true)
}

@OptIn(ExperimentalSerializationApi::class)
fun ChatCompletionRequest.toAnthropicRequest(stream: Boolean? = null): CreateMessageRequest {
return CreateMessageRequest(
model = Model(model.value),
messages = messages.map {
Expand All @@ -100,7 +107,8 @@ fun ChatCompletionRequest.toAnthropicRequest(): CreateMessageRequest {
}
}
)
}
},
stream = stream
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ class AnthropicOpenAIProvider(private val client: Anthropic) : OpenAIProvider {
* @return A Flow of OpenAIChatCompletionChunk objects representing the completions
*/
override fun streamChatCompletions(request: ChatCompletionRequest): Flow<OpenAIChatCompletionChunk> {
return client.stream(request.toAnthropicRequest() as StreamMessageRequest)
.filter { it !is ContentBlockStop && it !is Ping }
.transform {
return client.stream(
request.toAnthropicRequest().copy(
stream = true
)
).filter { it !is ContentBlockStop && it !is Ping }.transform {
emit(it.toOpenAIChatCompletionChunk(request.model.value))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class AnthropicOpenAIProviderTest {
index = 0,
contentBlock = ContentBlock(type = "some-type", text = "som-text")
)
whenever(client.stream(request.toAnthropicRequest() as StreamMessageRequest)).thenReturn(flow {
whenever(client.stream(request.toAnthropicStreamRequest())).thenReturn(flow {
emit(
contentBlockStart
)
Expand All @@ -95,7 +95,10 @@ class AnthropicOpenAIProviderTest {
// when
provider.streamChatCompletions(request).test {
// then
assertEquals(contentBlockStart.toOpenAIChatCompletionChunk(Model.CLAUDE_3_HAIKU.value), awaitItem())
assertEquals(
contentBlockStart.toOpenAIChatCompletionChunk(Model.CLAUDE_3_HAIKU.value),
awaitItem()
)
awaitComplete()
}

Expand Down

0 comments on commit c73e9a0

Please sign in to comment.