Skip to content

Commit

Permalink
feat(textGeneration): implement GeminiOpenAIProvider
Browse files Browse the repository at this point in the history
Add a GeminiOpenAIProvider class with support for chat completions using the Gemini API. Implement tests to verify functionality and ensure unsupported operations throw appropriate exceptions. Include additional utility functions and config setup for Gemini integration.
  • Loading branch information
hanrw committed Jan 2, 2025
1 parent 118cd60 commit a70e9c8
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,9 @@ interface Gemini : TextGeneration {
fun default(): Gemini {
return object : Gemini, TextGeneration by getInstance() {}
}

/**
 * Creates a [Gemini] instance for the given [config].
 *
 * NOTE(review): [config] is not used here — the returned object delegates to
 * getInstance(), exactly like default(). Presumably the config is registered
 * with the DI container elsewhere during Gemini setup — TODO confirm; otherwise
 * this parameter is dead and callers' configs are silently ignored.
 */
fun create(config: GeminiConfig): Gemini {
    return object : Gemini, TextGeneration by getInstance() {}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,12 @@ value class GeminiModel(val value: String) {
* Optimized for: Next generation features, speed, and multimodal generation for a diverse variety of tasks
*/
val GEMINI_2_0_FLASH = GeminiModel("gemini-2.0-flash-exp")

// Every Gemini model identifier declared above, in one list.
// Consumed by the OpenAI gateway to advertise which models the
// Gemini provider supports (see GeminiOpenAIProvider.models).
val availableModels = listOf(
    GEMINI_1_5_PRO,
    GEMINI_1_5_FLASH_8b,
    GEMINI_1_5_FLASH,
    GEMINI_2_0_FLASH
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ data class GenerateContentResponse(
val candidates: List<Candidate>,
val usageMetadata: UsageMetadata,
val modelVersion: String
)
) {
companion object {
    /**
     * Returns an empty placeholder response: no candidates, zeroed
     * usage metadata, blank model version. Used as a canned value in
     * tests (see GeminiOpenAIProviderTest).
     */
    fun dummy() = GenerateContentResponse(
        candidates = emptyList(),
        usageMetadata = UsageMetadata(0, 0, 0),
        modelVersion = ""
    )
}
}

@Serializable
data class Candidate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ fun GenerateContentResponse.toOpenAIChatCompletionChunk(): OpenAIChatCompletionC
index = 0,
delta = ChatDelta(
role = OpenAIRole.Assistant,
content = candidates.first().content.parts.first().text
content = candidates.firstOrNull()?.content?.parts?.firstOrNull()?.text
),
finishReason = candidates.first().finishReason,
finishReason = candidates.firstOrNull()?.finishReason,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class AnthropicOpenAIProvider(
}

/**
 * Image generation is not offered by the Anthropic API.
 *
 * @throws UnsupportedOperationException always — callers should route image
 *   requests to a provider that supports them.
 */
override suspend fun generate(request: ImageCreate): ListResponse<Image> {
    throw UnsupportedOperationException("Not supported")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.anthropic.api.messages.api.toAnthropicRequest
import com.tddworks.common.network.api.ktor.api.ListResponse
import com.tddworks.gemini.api.textGeneration.api.*
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.images.api.Image
import com.tddworks.openai.api.images.api.ImageCreate
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.OpenAIProvider
import com.tddworks.openai.gateway.api.OpenAIProviderConfig
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.transform
import kotlinx.serialization.ExperimentalSerializationApi

/**
 * [OpenAIProvider] adapter that fulfils OpenAI-shaped chat requests via the
 * Gemini API: requests are translated to Gemini content-generation requests,
 * and responses translated back to OpenAI chat-completion shapes.
 *
 * Legacy completions and image generation are not available through Gemini
 * and always throw [UnsupportedOperationException].
 */
@OptIn(ExperimentalSerializationApi::class)
class GeminiOpenAIProvider(
    override val id: String = "gemini",
    override val name: String = "Gemini",
    override val models: List<OpenAIModel> = GeminiModel.availableModels.map {
        OpenAIModel(it.value)
    },
    override val config: OpenAIProviderConfig,
    val client: Gemini
) : OpenAIProvider {

    /** True when [model]'s identifier matches one of this provider's [models]. */
    override fun supports(model: OpenAIModel): Boolean =
        models.any { candidate -> candidate.value == model.value }

    /** Translates [request] to Gemini, performs one generation, and maps the result back. */
    override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion =
        client.generateContent(request.toGeminiGenerateContentRequest())
            .toOpenAIChatCompletion()

    /**
     * Streams a Gemini generation, mapping each emitted response into an
     * OpenAI-style [ChatCompletionChunk].
     */
    override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> =
        client.streamGenerateContent(request.toGeminiGenerateContentRequest())
            .transform { geminiResponse ->
                emit(geminiResponse.toOpenAIChatCompletionChunk())
            }

    /** Legacy completions endpoint — not offered by Gemini. */
    override suspend fun completions(request: CompletionRequest): Completion =
        throw UnsupportedOperationException("Not supported")

    /** Image generation — not offered by Gemini. */
    override suspend fun generate(request: ImageCreate): ListResponse<Image> =
        throw UnsupportedOperationException("Not supported")
}

/**
 * Convenience factory for a Gemini-backed [OpenAIProvider].
 *
 * @param id provider identifier; defaults to "gemini".
 * @param models models this provider advertises; defaults to all known Gemini models.
 * @param config provider configuration (API key, base URL).
 * @param client the Gemini client used to serve requests.
 */
fun OpenAIProvider.Companion.gemini(
    id: String = "gemini",
    models: List<OpenAIModel> = GeminiModel.availableModels.map { OpenAIModel(it.value) },
    config: OpenAIProviderConfig,
    client: Gemini
): OpenAIProvider = GeminiOpenAIProvider(
    id = id,
    models = models,
    config = config,
    client = client
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.anthropic.api.Anthropic
import com.tddworks.gemini.api.textGeneration.api.Gemini
import com.tddworks.gemini.api.textGeneration.api.GeminiConfig
import com.tddworks.openai.gateway.api.OpenAIProviderConfig

/**
 * [OpenAIProviderConfig] for the Gemini provider.
 *
 * @property apiKey supplier of the Gemini API key (lazily evaluated per call).
 * @property baseUrl supplier of the API host; defaults to [Gemini.BASE_URL].
 */
class GeminiOpenAIProviderConfig(
    override val apiKey: () -> String,
    override val baseUrl: () -> String = { Gemini.BASE_URL }
) : OpenAIProviderConfig

/** Adapts a gateway-level [OpenAIProviderConfig] into the Gemini SDK's [GeminiConfig]. */
fun OpenAIProviderConfig.toGeminiConfig(): GeminiConfig =
    GeminiConfig(apiKey = apiKey, baseUrl = baseUrl)

/**
 * Factory for a Gemini provider configuration.
 *
 * @param apiKey supplier of the Gemini API key.
 * @param baseUrl supplier of the API host; defaults to [Gemini.BASE_URL].
 *   Fixed: this previously defaulted to [Anthropic.BASE_URL] — a copy-paste
 *   error from the Anthropic config factory that contradicted the
 *   [GeminiOpenAIProviderConfig] class's own default and pointed the Gemini
 *   provider at the wrong host.
 */
fun OpenAIProviderConfig.Companion.gemini(
    apiKey: () -> String,
    baseUrl: () -> String = { Gemini.BASE_URL },
) = GeminiOpenAIProviderConfig(apiKey, baseUrl)
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package com.tddworks.openai.gateway.api.internal

import app.cash.turbine.test
import com.tddworks.gemini.api.textGeneration.api.*
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.images.api.ImageCreate
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.OpenAIProvider
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.ExperimentalSerializationApi
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.kotlin.mock
import org.mockito.kotlin.whenever
import kotlin.test.assertFalse

/**
 * Tests for [GeminiOpenAIProvider] wired through the [OpenAIProvider.gemini]
 * factory, with the Gemini client mocked.
 *
 * Fixed: the "not supported" tests previously used
 * `runCatching { ... }.onFailure { assert... }`, which passes vacuously when no
 * exception is thrown (onFailure simply never runs). They now assert
 * `result.isFailure` explicitly before checking the message.
 */
@OptIn(ExperimentalSerializationApi::class)
class GeminiOpenAIProviderTest {
    private lateinit var client: Gemini
    private lateinit var config: GeminiOpenAIProviderConfig

    private lateinit var provider: OpenAIProvider

    @BeforeEach
    fun setUp() {
        client = mock()
        config = mock()
        provider = OpenAIProvider.gemini(
            client = client,
            config = config
        )
    }

    @Test
    fun `should throw not supported when invoke generate`() = runTest {
        // given
        val request = ImageCreate.create("A cute baby sea otter", OpenAIModel.DALL_E_3)

        // when
        val result = runCatching { provider.generate(request) }

        // then — must actually fail, not merely "fail if it happens to throw"
        assertTrue(result.isFailure)
        assertEquals("Not supported", result.exceptionOrNull()?.message)
    }

    @Test
    fun `should throw not supported when invoke completions`() = runTest {
        // given
        val request = CompletionRequest(
            prompt = "Once upon a time",
            suffix = "The end",
            maxTokens = 10,
            temperature = 0.5
        )

        // when
        val result = runCatching { provider.completions(request) }

        // then — must actually fail, not merely "fail if it happens to throw"
        assertTrue(result.isFailure)
        assertEquals("Not supported", result.exceptionOrNull()?.message)
    }

    @Test
    fun `should fetch chat completions from Gemini API`() = runTest {
        // given
        val request =
            ChatCompletionRequest.dummy(OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value))
        val response = GenerateContentResponse.dummy()
        whenever(client.generateContent(request.toGeminiGenerateContentRequest())).thenReturn(
            response
        )

        // when
        val completions = provider.chatCompletions(request)

        // then — provider must return the mapped Gemini response
        assertEquals(response.toOpenAIChatCompletion(), completions)
    }

    @Test
    fun `should stream chat completions for chat`() = runTest {
        // given
        val request =
            ChatCompletionRequest.dummy(OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value))

        val response = GenerateContentResponse.dummy()
        whenever(client.streamGenerateContent(request.toGeminiGenerateContentRequest())).thenReturn(
            flow {
                emit(response)
            }
        )

        // when
        provider.streamChatCompletions(request).test {
            // then — each Gemini response maps to exactly one chunk, then completion
            assertEquals(
                response.toOpenAIChatCompletionChunk(),
                awaitItem()
            )
            awaitComplete()
        }
    }

    @Test
    fun `should return false when model is not supported`() {
        // given — a non-Gemini model
        val model = OpenAIModel(OpenAIModel.GPT_3_5_TURBO.value)

        // when
        val supported = provider.supports(model)

        // then
        assertFalse(supported)
    }

    @Test
    fun `should return true when model is supported`() {
        // given — a model from GeminiModel.availableModels
        val model = OpenAIModel(GeminiModel.GEMINI_1_5_FLASH.value)

        // when
        val supported = provider.supports(model)

        // then
        assertTrue(supported)
    }
}

0 comments on commit a70e9c8

Please sign in to comment.