-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(textGeneration): implement GeminiOpenAIProvider
Add a GeminiOpenAIProvider class with support for chat completions using the Gemini API. Implement tests to verify functionality and ensure unsupported operations throw appropriate exceptions. Include additional utility functions and config setup for Gemini integration.
- Loading branch information
Showing
8 changed files
with
247 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
...re/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/GeminiOpenAIProvider.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
package com.tddworks.openai.gateway.api.internal | ||
|
||
import com.tddworks.anthropic.api.messages.api.toAnthropicRequest | ||
import com.tddworks.common.network.api.ktor.api.ListResponse | ||
import com.tddworks.gemini.api.textGeneration.api.* | ||
import com.tddworks.openai.api.chat.api.ChatCompletion | ||
import com.tddworks.openai.api.chat.api.ChatCompletionChunk | ||
import com.tddworks.openai.api.chat.api.ChatCompletionRequest | ||
import com.tddworks.openai.api.chat.api.OpenAIModel | ||
import com.tddworks.openai.api.images.api.Image | ||
import com.tddworks.openai.api.images.api.ImageCreate | ||
import com.tddworks.openai.api.legacy.completions.api.Completion | ||
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest | ||
import com.tddworks.openai.gateway.api.OpenAIProvider | ||
import com.tddworks.openai.gateway.api.OpenAIProviderConfig | ||
import kotlinx.coroutines.flow.Flow | ||
import kotlinx.coroutines.flow.transform | ||
import kotlinx.serialization.ExperimentalSerializationApi | ||
|
||
@OptIn(ExperimentalSerializationApi::class) | ||
class GeminiOpenAIProvider( | ||
override val id: String = "gemini", | ||
override val name: String = "Gemini", | ||
override val models: List<OpenAIModel> = GeminiModel.availableModels.map { | ||
OpenAIModel(it.value) | ||
}, | ||
override val config: OpenAIProviderConfig, | ||
val client: Gemini | ||
) : OpenAIProvider { | ||
|
||
override fun supports(model: OpenAIModel): Boolean { | ||
return models.any { it.value == model.value } | ||
} | ||
|
||
override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion { | ||
val geminiRequest = request.toGeminiGenerateContentRequest() | ||
return client.generateContent(geminiRequest).toOpenAIChatCompletion() | ||
} | ||
|
||
override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> { | ||
val geminiRequest = request.toGeminiGenerateContentRequest() | ||
|
||
return client.streamGenerateContent(geminiRequest).transform { | ||
emit(it.toOpenAIChatCompletionChunk()) | ||
} | ||
} | ||
|
||
override suspend fun completions(request: CompletionRequest): Completion { | ||
throw UnsupportedOperationException("Not supported") | ||
} | ||
|
||
override suspend fun generate(request: ImageCreate): ListResponse<Image> { | ||
throw UnsupportedOperationException("Not supported") | ||
} | ||
} | ||
|
||
fun OpenAIProvider.Companion.gemini( | ||
id: String = "gemini", models: List<OpenAIModel> = GeminiModel.availableModels.map { | ||
OpenAIModel(it.value) | ||
}, config: OpenAIProviderConfig, client: Gemini | ||
): OpenAIProvider { | ||
return GeminiOpenAIProvider( | ||
id = id, models = models, config = config, client = client | ||
) | ||
} |
22 changes: 22 additions & 0 deletions
22
.../commonMain/kotlin/com/tddworks/openai/gateway/api/internal/GeminiOpenAIProviderConfig.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package com.tddworks.openai.gateway.api.internal | ||
|
||
import com.tddworks.anthropic.api.Anthropic | ||
import com.tddworks.gemini.api.textGeneration.api.Gemini | ||
import com.tddworks.gemini.api.textGeneration.api.GeminiConfig | ||
import com.tddworks.openai.gateway.api.OpenAIProviderConfig | ||
|
||
class GeminiOpenAIProviderConfig( | ||
override val apiKey: () -> String, | ||
override val baseUrl: () -> String = { Gemini.BASE_URL } | ||
) : OpenAIProviderConfig | ||
|
||
fun OpenAIProviderConfig.toGeminiConfig() = | ||
GeminiConfig( | ||
apiKey = apiKey, | ||
baseUrl = baseUrl, | ||
) | ||
|
||
fun OpenAIProviderConfig.Companion.gemini( | ||
apiKey: () -> String, | ||
baseUrl: () -> String = { Anthropic.BASE_URL }, | ||
) = GeminiOpenAIProviderConfig(apiKey, baseUrl) |
137 changes: 137 additions & 0 deletions
137
...e/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/GeminiOpenAIProviderTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package com.tddworks.openai.gateway.api.internal | ||
|
||
import app.cash.turbine.test | ||
import com.tddworks.gemini.api.textGeneration.api.* | ||
import com.tddworks.openai.api.chat.api.ChatCompletionRequest | ||
import com.tddworks.openai.api.chat.api.OpenAIModel | ||
import com.tddworks.openai.api.images.api.ImageCreate | ||
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest | ||
import com.tddworks.openai.gateway.api.OpenAIProvider | ||
import kotlinx.coroutines.flow.flow | ||
import kotlinx.coroutines.test.runTest | ||
import kotlinx.serialization.ExperimentalSerializationApi | ||
import org.junit.jupiter.api.Assertions.assertEquals | ||
import org.junit.jupiter.api.Assertions.assertTrue | ||
import org.junit.jupiter.api.BeforeEach | ||
import org.junit.jupiter.api.Test | ||
import org.mockito.kotlin.mock | ||
import org.mockito.kotlin.whenever | ||
import kotlin.test.assertFalse | ||
|
||
@OptIn(ExperimentalSerializationApi::class) | ||
class GeminiOpenAIProviderTest { | ||
private lateinit var client: Gemini | ||
private lateinit var config: GeminiOpenAIProviderConfig | ||
|
||
private lateinit var provider: OpenAIProvider | ||
|
||
@BeforeEach | ||
fun setUp() { | ||
client = mock() | ||
config = mock() | ||
provider = OpenAIProvider.gemini( | ||
client = client, | ||
config = config | ||
) | ||
} | ||
|
||
@Test | ||
fun `should throw not supported when invoke generate`() = runTest { | ||
// given | ||
val request = ImageCreate.create("A cute baby sea otter", OpenAIModel.DALL_E_3) | ||
|
||
runCatching { | ||
// when | ||
provider.generate(request) | ||
}.onFailure { | ||
// then | ||
assertEquals("Not supported", it.message) | ||
} | ||
} | ||
|
||
@Test | ||
fun `should throw not supported when invoke completions`() = runTest { | ||
// given | ||
val request = CompletionRequest( | ||
prompt = "Once upon a time", | ||
suffix = "The end", | ||
maxTokens = 10, | ||
temperature = 0.5 | ||
) | ||
|
||
runCatching { | ||
// when | ||
provider.completions(request) | ||
}.onFailure { | ||
// then | ||
assertEquals("Not supported", it.message) | ||
} | ||
} | ||
|
||
@Test | ||
fun `should fetch chat completions from OpenAI API`() = runTest { | ||
// given | ||
val request = | ||
ChatCompletionRequest.dummy(OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value)) | ||
val response = GenerateContentResponse.dummy() | ||
whenever(client.generateContent(request.toGeminiGenerateContentRequest())).thenReturn( | ||
response | ||
) | ||
|
||
// when | ||
val completions = provider.chatCompletions(request) | ||
|
||
// then | ||
assertEquals(response.toOpenAIChatCompletion(), completions) | ||
} | ||
|
||
@Test | ||
fun `should stream chat completions for chat`() = runTest { | ||
// given | ||
val request = | ||
ChatCompletionRequest.dummy(OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value)) | ||
|
||
val response = GenerateContentResponse.dummy() | ||
whenever(client.streamGenerateContent(request.toGeminiGenerateContentRequest())).thenReturn( | ||
flow { | ||
emit( | ||
response | ||
) | ||
}) | ||
|
||
// when | ||
provider.streamChatCompletions(request).test { | ||
// then | ||
assertEquals( | ||
response.toOpenAIChatCompletionChunk(), | ||
awaitItem() | ||
) | ||
awaitComplete() | ||
} | ||
|
||
} | ||
|
||
@Test | ||
fun `should return false when model is not supported`() { | ||
// given | ||
val model = OpenAIModel(OpenAIModel.GPT_3_5_TURBO.value) | ||
|
||
// when | ||
val supported = provider.supports(model) | ||
|
||
// then | ||
assertFalse(supported) | ||
} | ||
|
||
@Test | ||
fun `should return true when model is supported`() { | ||
// given | ||
val model = OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value) | ||
|
||
// when | ||
val supported = provider.supports(model) | ||
|
||
// then | ||
assertTrue(supported) | ||
} | ||
} |