diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt index f7808a5..6c5095a 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt @@ -36,3 +36,11 @@ class DefaultOpenAIProvider( return openAI.completions(request) } } + +fun OpenAIProvider.Companion.openAI( + config: OpenAIConfig, + models: List, + openAI: OpenAI = OpenAI.create(config) +): OpenAIProvider { + return DefaultOpenAIProvider(config, models, openAI) +} diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt index 556b30d..767c7a7 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt @@ -15,4 +15,6 @@ interface OpenAIProvider : Chat, Completions { * @return true if the model is supported, false otherwise */ fun supports(model: Model): Boolean + + companion object } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt index 404819b..3b98cc7 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt @@ -1,6 +1,5 @@ package com.tddworks.openai.gateway.api.internal -import com.tddworks.openai.api.OpenAI import com.tddworks.openai.api.chat.api.ChatCompletion import com.tddworks.openai.api.chat.api.ChatCompletionChunk import com.tddworks.openai.api.chat.api.ChatCompletionRequest @@ -19,7 +18,6 @@ import kotlinx.serialization.ExperimentalSerializationApi @ExperimentalSerializationApi class DefaultOpenAIGateway( providers: List, - private val openAI: OpenAI, ) : OpenAIGateway { private val availableProviders: MutableList = providers.toMutableList() @@ -40,9 +38,9 @@ class DefaultOpenAIGateway( * @return A ChatCompletion object containing the completions for the provided request. */ override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion { - return availableProviders.firstOrNull { - it.supports(request.model) - }?.chatCompletions(request) ?: openAI.chatCompletions(request) + return availableProviders.firstOrNull { it.supports(request.model) } + ?.chatCompletions(request) + ?: throwNoProviderFound(request.model.value) } /** @@ -55,7 +53,8 @@ class DefaultOpenAIGateway( override fun streamChatCompletions(request: ChatCompletionRequest): Flow { return availableProviders.firstOrNull { it.supports(request.model) - }?.streamChatCompletions(request) ?: openAI.streamChatCompletions(request) + }?.streamChatCompletions(request) + ?: throwNoProviderFound(request.model.value) } /** @@ -67,6 +66,12 @@ class DefaultOpenAIGateway( override suspend fun completions(request: CompletionRequest): Completion { return availableProviders.firstOrNull { it.supports(request.model) - }?.completions(request) ?: openAI.completions(request) + }?.completions(request) + ?: throwNoProviderFound(request.model.value) + } + + + private fun throwNoProviderFound(model: String): Nothing { + throw UnsupportedOperationException("No provider found for model $model") } } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/di/Koin.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/di/Koin.kt index 454ffa8..8f6c749 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/di/Koin.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/di/Koin.kt @@ -10,6 +10,7 @@ import com.tddworks.openai.di.openAIModules import com.tddworks.openai.gateway.api.AnthropicOpenAIProvider import com.tddworks.openai.gateway.api.OllamaOpenAIProvider import com.tddworks.openai.gateway.api.OpenAIGateway +import com.tddworks.openai.gateway.api.OpenAIProvider import com.tddworks.openai.gateway.api.internal.DefaultOpenAIGateway import kotlinx.serialization.ExperimentalSerializationApi import org.koin.core.context.startKoin @@ -33,6 +34,22 @@ fun initOpenAIGateway( ) }.koin.get() + +@ExperimentalSerializationApi +fun createOpenAIGateway(providers: List) = startKoin { + modules( + commonModule(false) + + openAIGatewayModules(providers) + ) +}.koin.get() + + +@OptIn(ExperimentalSerializationApi::class) +fun openAIGatewayModules(providers: List) = module { + single { providers } + single { DefaultOpenAIGateway(get()) } +} + @ExperimentalSerializationApi fun openAIGatewayModules() = module { single { AnthropicOpenAIProvider(get()) } @@ -45,5 +62,5 @@ fun openAIGatewayModules() = module { ) } - single { DefaultOpenAIGateway(get(), get()) } + single { DefaultOpenAIGateway(get()) } }