Skip to content

Commit

Permalink
code refactor: able to create provider from OpenAIProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Aug 20, 2024
1 parent f2c0a79 commit 862a86c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ class DefaultOpenAIProvider(
return openAI.completions(request)
}
}

fun OpenAIProvider.Companion.openAI(
config: OpenAIConfig,
models: List<Model>,
openAI: OpenAI = OpenAI.create(config)
): OpenAIProvider {
return DefaultOpenAIProvider(config, models, openAI)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ interface OpenAIProvider : Chat, Completions {
* @return true if the model is supported, false otherwise
*/
fun supports(model: Model): Boolean

companion object
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
Expand All @@ -19,7 +18,6 @@ import kotlinx.serialization.ExperimentalSerializationApi
@ExperimentalSerializationApi
class DefaultOpenAIGateway(
providers: List<OpenAIProvider>,
private val openAI: OpenAI,
) : OpenAIGateway {
private val availableProviders: MutableList<OpenAIProvider> =
providers.toMutableList()
Expand All @@ -40,9 +38,9 @@ class DefaultOpenAIGateway(
* @return A ChatCompletion object containing the completions for the provided request.
*/
override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
return availableProviders.firstOrNull {
it.supports(request.model)
}?.chatCompletions(request) ?: openAI.chatCompletions(request)
return availableProviders.firstOrNull { it.supports(request.model) }
?.chatCompletions(request)
?: throwNoProviderFound(request.model.value)
}

/**
Expand All @@ -55,7 +53,8 @@ class DefaultOpenAIGateway(
override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
return availableProviders.firstOrNull {
it.supports(request.model)
}?.streamChatCompletions(request) ?: openAI.streamChatCompletions(request)
}?.streamChatCompletions(request)
?: throwNoProviderFound(request.model.value)
}

/**
Expand All @@ -67,6 +66,12 @@ class DefaultOpenAIGateway(
override suspend fun completions(request: CompletionRequest): Completion {
return availableProviders.firstOrNull {
it.supports(request.model)
}?.completions(request) ?: openAI.completions(request)
}?.completions(request)
?: throwNoProviderFound(request.model.value)
}


private fun throwNoProviderFound(model: String): Nothing {
throw UnsupportedOperationException("No provider found for model $model")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import com.tddworks.openai.di.openAIModules
import com.tddworks.openai.gateway.api.AnthropicOpenAIProvider
import com.tddworks.openai.gateway.api.OllamaOpenAIProvider
import com.tddworks.openai.gateway.api.OpenAIGateway
import com.tddworks.openai.gateway.api.OpenAIProvider
import com.tddworks.openai.gateway.api.internal.DefaultOpenAIGateway
import kotlinx.serialization.ExperimentalSerializationApi
import org.koin.core.context.startKoin
Expand All @@ -33,6 +34,22 @@ fun initOpenAIGateway(
)
}.koin.get<OpenAIGateway>()


@ExperimentalSerializationApi
fun createOpenAIGateway(providers: List<OpenAIProvider>) = startKoin {
modules(
commonModule(false) +
openAIGatewayModules(providers)
)
}.koin.get<OpenAIGateway>()


@OptIn(ExperimentalSerializationApi::class)
fun openAIGatewayModules(providers: List<OpenAIProvider>) = module {
single { providers }
single<OpenAIGateway> { DefaultOpenAIGateway(get()) }
}

@ExperimentalSerializationApi
fun openAIGatewayModules() = module {
single<AnthropicOpenAIProvider> { AnthropicOpenAIProvider(get()) }
Expand All @@ -45,5 +62,5 @@ fun openAIGatewayModules() = module {
)
}

single<OpenAIGateway> { DefaultOpenAIGateway(get(), get()) }
single<OpenAIGateway> { DefaultOpenAIGateway(get()) }
}

0 comments on commit 862a86c

Please sign in to comment.