From 9386056086cbf722fe776616a023f301c7c66bd0 Mon Sep 17 00:00:00 2001 From: slam Date: Fri, 23 Aug 2024 20:34:44 +0800 Subject: [PATCH] feat(gateway): to support AzureAIProvider - refactor createHttpClient --- .../com/tddworks/anthropic/api/Anthropic.kt | 10 +- .../kotlin/com/tddworks/anthropic/di/Koin.kt | 10 +- .../network/api/ktor/internal/HttpClient.kt | 160 +++++++++--------- .../api/ktor/internal/HttpClientTest.kt | 52 ++++-- .../kotlin/com/tddworks/ollama/api/Ollama.kt | 31 +--- .../com/tddworks/ollama/api/OllamaConfig.kt | 4 +- .../tddworks/ollama/api/internal/OllamaApi.kt | 8 - .../kotlin/com/tddworks/ollama/di/Koin.kt | 11 +- .../tddworks/ollama/api/OllamaConfigTest.kt | 14 +- .../com/tddworks/ollama/api/OllamaTest.kt | 16 +- .../chat/internal/DefaultOllamaChatITest.kt | 8 +- .../com/tddworks/ollama/api/DarwinOllama.kt | 4 +- .../kotlin/com/tddworks/openai/api/OpenAI.kt | 20 ++- .../api/chat/internal/DefaultChatApi.kt | 5 +- .../kotlin/com/tddworks/openai/di/Koin.kt | 9 +- .../kotlin/com/tddworks/azure/api/AzureAI.kt | 126 ++++++++++++++ .../api/internal/DefaultOpenAIGateway.kt | 11 ++ .../api/internal/OllamaOpenAIProvider.kt | 12 +- .../internal/OllamaOpenAIProviderConfig.kt | 12 +- .../azure/api/AzureAIProviderConfigTest.kt | 23 +++ .../api/internal/DefaultOpenAIGatewayTest.kt | 48 +++++- .../openai/gateway/api/DarwinOpenAIGateway.kt | 21 +-- 22 files changed, 375 insertions(+), 240 deletions(-) create mode 100644 openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/azure/api/AzureAI.kt create mode 100644 openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/azure/api/AzureAIProviderConfigTest.kt diff --git a/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/api/Anthropic.kt b/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/api/Anthropic.kt index ddaea65..1deb363 100644 --- a/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/api/Anthropic.kt +++ b/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/api/Anthropic.kt @@ -5,10 +5,7 @@ import com.tddworks.anthropic.api.messages.api.Messages import com.tddworks.anthropic.api.messages.api.internal.DefaultMessagesApi import com.tddworks.anthropic.api.messages.api.internal.JsonLenient import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default -import com.tddworks.di.createJson -import com.tddworks.di.getInstance +import com.tddworks.common.network.api.ktor.internal.* /** * Interface for interacting with the Anthropic API. 
@@ -30,9 +27,10 @@ interface Anthropic : Messages { val requester = HttpRequester.default( createHttpClient( - host = anthropicConfig.baseUrl, + connectionConfig = UrlBasedConnectionConfig(anthropicConfig.baseUrl), + authConfig = AuthConfig(anthropicConfig.apiKey), // get from commonModule - json = JsonLenient, + features = ClientFeatures(json = JsonLenient) ) ) val messages = DefaultMessagesApi( diff --git a/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/di/Koin.kt b/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/di/Koin.kt index 49c06b3..ad6b355 100644 --- a/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/di/Koin.kt +++ b/anthropic-client/anthropic-client-core/src/commonMain/kotlin/com/tddworks/anthropic/di/Koin.kt @@ -6,12 +6,10 @@ import com.tddworks.anthropic.api.messages.api.Messages import com.tddworks.anthropic.api.messages.api.internal.DefaultMessagesApi import com.tddworks.anthropic.api.messages.api.internal.JsonLenient import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default +import com.tddworks.common.network.api.ktor.internal.* import com.tddworks.di.commonModule import kotlinx.serialization.json.Json import org.koin.core.context.startKoin -import org.koin.core.module.Module import org.koin.core.qualifier.named import org.koin.dsl.KoinAppDeclaration import org.koin.dsl.module @@ -38,8 +36,10 @@ fun anthropicModules( single(named("anthropicHttpRequester")) { HttpRequester.default( createHttpClient( - host = config.baseUrl, - json = get(named("anthropicJson")), + connectionConfig = UrlBasedConnectionConfig(config.baseUrl), + authConfig = AuthConfig(config.apiKey), + // get from commonModule + features = ClientFeatures(json = get(named("anthropicJson"))) ) ) } diff --git a/common/src/commonMain/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClient.kt b/common/src/commonMain/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClient.kt index 6500084..400c343 100644 --- a/common/src/commonMain/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClient.kt +++ b/common/src/commonMain/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClient.kt @@ -3,122 +3,114 @@ package com.tddworks.common.network.api.ktor.internal import io.ktor.client.* import io.ktor.client.engine.* import io.ktor.client.plugins.* -import io.ktor.client.plugins.auth.* -import io.ktor.client.plugins.auth.providers.* import io.ktor.client.plugins.contentnegotiation.* import io.ktor.client.plugins.logging.* import io.ktor.client.request.* import io.ktor.http.* import io.ktor.serialization.kotlinx.* +import io.ktor.util.* import kotlinx.serialization.json.Json -import kotlin.time.Duration.Companion.seconds internal expect fun httpClientEngine(): HttpClientEngine -/** - * Creates a new [HttpClient] with [OkHttp] engine and [ContentNegotiation] plugin. 
- * - * @param protocol the protocol to use - default is HTTPS - * @param host the base URL of the API - * @param port the port to use - default is 443 - * @param authToken the authentication token - * @return a new [HttpClient] instance - */ + +interface ConnectionConfig { + fun setupUrl(builder: DefaultRequest.DefaultRequestBuilder) { + builder.setupUrl(this) + } +} + +data class UrlBasedConnectionConfig( + val baseUrl: () -> String = { "" } +) : ConnectionConfig + +data class HostPortConnectionConfig( + val protocol: () -> String? = { null }, + val host: () -> String = { "" }, + val port: () -> Int? = { null }, +) : ConnectionConfig + +data class AuthConfig( + val authToken: (() -> String)? = null +) + +data class ClientFeatures( + val json: Json = Json, + val queryParams: Map<String, String> = emptyMap(), + val expectSuccess: Boolean = true +) + fun createHttpClient( - protocol: () -> String? = { null }, - host: () -> String, - port: () -> Int? = { null }, - authToken: (() -> String)? = null, - json: Json = Json, - httpClientEngine: HttpClientEngine = httpClientEngine(), + connectionConfig: ConnectionConfig = UrlBasedConnectionConfig(), + authConfig: AuthConfig = AuthConfig(), + features: ClientFeatures = ClientFeatures(), + httpClientEngine: HttpClientEngine = httpClientEngine() ): HttpClient { - return HttpClient(httpClientEngine) { -// enable proxy in the future -// engine { -// proxy = ProxyBuilder.http(url) -// } + return HttpClient(httpClientEngine) { install(ContentNegotiation) { - register(ContentType.Application.Json, KotlinxSerializationConverter(json)) + register( + ContentType.Application.Json, + KotlinxSerializationConverter(features.json) + ) } - /** - * Support configurable in the future - * Install the Logging module. - * @param logging the logging instance to use - * @return Unit - */ install(Logging) { - /** - * DEFAULT - default - LoggerFactory.getLogger - * SIMPLE - Logger using println. - * Empty - Empty Logger for test purpose. - */ logger = Logger.DEFAULT - /** - * ALL - log all - * HEADERS - log headers - * INFO - log info - * NONE - none - */ level = LogLevel.INFO } - /** - * Install the Auth module. but can't update on the fly - * @param auth the auth instance to use - * @return Unit - */ -// authToken?.let { -// install(Auth) { -// bearer { -// loadTokens { -// BearerTokens(accessToken = authToken(), refreshToken = "") -// } -// } -// } -// } - - /** - * Installs an [HttpRequestRetry] with default maxRetries of 3, - * retryIf checks for rate limit error with status code 429, - * and exponential delay with base 5.0 and max delay of 1 minute. - * - * @param retry [HttpRequestRetry] instance to install - */ install(HttpRequestRetry) { maxRetries = 3 - // retry on rate limit error.
- retryIf { _, response -> response.status.value.let { it == 429 } } - exponentialDelay(base = 5.0, maxDelayMs = 10.seconds.inWholeMilliseconds) + retryIf { _, response -> response.status.value == 429 } + exponentialDelay(base = 5.0, maxDelayMs = 60_000) } + defaultRequest { - url { - this.protocol = protocol()?.let { URLProtocol.createOrDefault(it) } - ?: URLProtocol.HTTPS - this.host = host() - port()?.let { this.port = it } - } + connectionConfig.setupUrl(this) + commonSettings(features.queryParams, authConfig.authToken) + } + + expectSuccess = features.expectSuccess + } +} - authToken?.let { - header(HttpHeaders.Authorization, "Bearer ${it()}") +private fun DefaultRequest.DefaultRequestBuilder.setupUrl(connectionConfig: ConnectionConfig) { + when (connectionConfig) { + is HostPortConnectionConfig -> { + url { + protocol = + connectionConfig.protocol()?.let { URLProtocol.createOrDefault(it) } + ?: URLProtocol.HTTPS + host = connectionConfig.host() + connectionConfig.port()?.let { port = it } } + } - header(HttpHeaders.ContentType, ContentType.Application.Json) - contentType(ContentType.Application.Json) + is UrlBasedConnectionConfig -> { + connectionConfig.baseUrl().let { url.takeFrom(it) } } + } +} - /** - * If set to true, the client will throw an exception if the response from the server is not successful. The definition of successful can vary depending on the HTTP status code. For example, a successful response for a GET request would typically be a status code of 200, while a successful response for a POST request could be a status code of 201. - * - * By setting expectSuccess = true, the developer is indicating that they want to handle non-successful responses explicitly and can throw or handle the exceptions themselves. - * - * If expectSuccess is set to false, the HttpClient will not throw exceptions for non-successful responses and the developer is responsible for parsing and handling any errors or unexpected responses. - */ - expectSuccess = true +private fun DefaultRequest.DefaultRequestBuilder.commonSettings( + queryParams: Map<String, String>, + authToken: (() -> String)?
+) { + queryParams.forEach { (key, value) -> + url.parameters.appendIfNameAbsent( + key, + value + ) + } + authToken?.let { + header(HttpHeaders.Authorization, "Bearer ${it()}") } + + header(HttpHeaders.ContentType, ContentType.Application.Json) + contentType(ContentType.Application.Json) } diff --git a/common/src/jvmTest/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClientTest.kt b/common/src/jvmTest/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClientTest.kt index c7b7962..2d2b254 100644 --- a/common/src/jvmTest/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClientTest.kt +++ b/common/src/jvmTest/kotlin/com/tddworks/common/network/api/ktor/internal/HttpClientTest.kt @@ -6,8 +6,7 @@ import io.ktor.client.engine.okhttp.* import io.ktor.client.request.* import io.ktor.http.* import io.ktor.utils.io.* -import kotlinx.coroutines.runBlocking -import kotlinx.serialization.json.Json +import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test @@ -15,25 +14,44 @@ import org.junit.jupiter.api.Test class HttpClientTest { @Test - fun `should return correct json response with default settings`() { - runBlocking { - val mockEngine = MockEngine { request -> - respond( - content = ByteReadChannel("""{"ip":"127.0.0.1"}"""), - status = HttpStatusCode.OK, - headers = headersOf(HttpHeaders.ContentType, "application/json") - ) - } - val apiClient = createHttpClient( - host = { "some-host" }, - httpClientEngine = mockEngine + fun `should return correct response with host and port based config`() = runTest { + val mockEngine = MockEngine { _ -> + respond( + content = ByteReadChannel("""{"ip":"127.0.0.1"}"""), + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType, "application/json") ) - - val body = apiClient.get("https://some-host:443").body<String>() - assertEquals("""{"ip":"127.0.0.1"}""", body) } + val apiClient = createHttpClient( + connectionConfig = HostPortConnectionConfig( + protocol = { "https" }, + host = { "some-host" }, + port = { 443 } + ), + httpClientEngine = mockEngine + ) + + val body = apiClient.get("https://some-host").body<String>() + assertEquals("""{"ip":"127.0.0.1"}""", body) } + @Test + fun `should return correct response with url based config`() = runTest { + val mockEngine = MockEngine { _ -> + respond( + content = ByteReadChannel("""{"ip":"127.0.0.1"}"""), + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType, "application/json") + ) + } + val apiClient = createHttpClient( + connectionConfig = UrlBasedConnectionConfig { "https://some-host" }, + httpClientEngine = mockEngine + ) + + val body = apiClient.get("https://some-host").body<String>() + assertEquals("""{"ip":"127.0.0.1"}""", body) + } @Test fun `should return OkHttp engine`() { diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/Ollama.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/Ollama.kt index fda63b1..e4bc6c2 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/Ollama.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/Ollama.kt @@ -1,8 +1,7 @@ package com.tddworks.ollama.api import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default +import
com.tddworks.common.network.api.ktor.internal.* import com.tddworks.ollama.api.chat.OllamaChat import com.tddworks.ollama.api.chat.internal.DefaultOllamaChatApi import com.tddworks.ollama.api.generate.OllamaGenerate @@ -16,18 +15,16 @@ import com.tddworks.ollama.api.json.JsonLenient interface Ollama : OllamaChat, OllamaGenerate { companion object { - const val BASE_URL = "localhost" - const val PORT = 11434 - const val PROTOCOL = "http" + const val BASE_URL = "http://localhost:11434" fun create(ollamaConfig: OllamaConfig): Ollama { - val requester = HttpRequester.default( createHttpClient( - host = ollamaConfig.baseUrl, - port = ollamaConfig.port, - protocol = ollamaConfig.protocol, - json = JsonLenient, + connectionConfig = UrlBasedConnectionConfig( + baseUrl = ollamaConfig.baseUrl, + ), + // get from commonModule + features = ClientFeatures(json = JsonLenient) ) ) val ollamaChat = DefaultOllamaChatApi(requester = requester) @@ -47,18 +44,4 @@ interface Ollama : OllamaChat, OllamaGenerate { * @return a string representing the base URL */ fun baseUrl(): String - - /** - * This function returns the port as an integer. - * - * @return an integer representing the port - */ - fun port(): Int - - /** - * This function returns the protocol as a string. - * - * @return a string representing the protocol - */ - fun protocol(): String } \ No newline at end of file diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/OllamaConfig.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/OllamaConfig.kt index 5176d65..f3aa348 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/OllamaConfig.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/OllamaConfig.kt @@ -3,7 +3,5 @@ package com.tddworks.ollama.api import org.koin.core.component.KoinComponent data class OllamaConfig( - val baseUrl: () -> String = { Ollama.BASE_URL }, - val protocol: () -> String = { Ollama.PROTOCOL }, - val port: () -> Int = { Ollama.PORT }, + val baseUrl: () -> String = { Ollama.BASE_URL } ) : KoinComponent \ No newline at end of file diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/internal/OllamaApi.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/internal/OllamaApi.kt index a85707e..a2154de 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/internal/OllamaApi.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/internal/OllamaApi.kt @@ -15,14 +15,6 @@ class OllamaApi( override fun baseUrl(): String { return config.baseUrl() } - - override fun port(): Int { - return config.port() - } - - override fun protocol(): String { - return config.protocol() - } } fun Ollama.Companion.create( diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/di/Koin.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/di/Koin.kt index 235da37..ab647d4 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/di/Koin.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/di/Koin.kt @@ -1,8 +1,7 @@ package com.tddworks.ollama.di import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default +import 
com.tddworks.common.network.api.ktor.internal.* import com.tddworks.di.commonModule import com.tddworks.ollama.api.Ollama import com.tddworks.ollama.api.OllamaConfig @@ -37,10 +36,10 @@ fun ollamaModules( single(named("ollamaHttpRequester")) { HttpRequester.default( createHttpClient( - protocol = config.protocol, - host = config.baseUrl, - port = config.port, - json = get(named("ollamaJson")), + connectionConfig = UrlBasedConnectionConfig( + baseUrl = config.baseUrl, + ), + features = ClientFeatures(json = get(named("ollamaJson"))) ) ) } diff --git a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaConfigTest.kt b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaConfigTest.kt index 7ed0719..0a74a0b 100644 --- a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaConfigTest.kt +++ b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaConfigTest.kt @@ -9,26 +9,16 @@ class OllamaConfigTest { @Test fun `should return overridden settings`() { val target = OllamaConfig( - baseUrl = { "some-url" }, - port = { 8080 }, - protocol = { "https" } + baseUrl = { "some-url" } ) assertEquals("some-url", target.baseUrl()) - - assertEquals(8080, target.port()) - - assertEquals("https", target.protocol()) } @Test fun `should return default settings`() { val target = OllamaConfig() - assertEquals("localhost", target.baseUrl()) - - assertEquals(11434, target.port()) - - assertEquals("http", target.protocol()) + assertEquals("http://localhost:11434", target.baseUrl()) } } \ No newline at end of file diff --git a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaTest.kt b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaTest.kt index f72f840..388860d 100644 --- a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaTest.kt +++ b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/OllamaTest.kt @@ -16,9 +16,7 @@ class OllamaTestTest : AutoCloseKoinTest() { koinApplication { initOllama( config = OllamaConfig( - baseUrl = { "127.0.0.1" }, - port = { 8080 }, - protocol = { "https" } + baseUrl = { "http://127.0.0.1:8080" }, ) ) }.checkModules() @@ -28,22 +26,14 @@ fun `should return overridden settings`() { val target = getInstance<Ollama>() - assertEquals("127.0.0.1", target.baseUrl()) - - assertEquals(8080, target.port()) - - assertEquals("https", target.protocol()) + assertEquals("http://127.0.0.1:8080", target.baseUrl()) } @Test fun `should return default settings`() { val target = Ollama.create(ollamaConfig = OllamaConfig()) - assertEquals("localhost", target.baseUrl()) - - assertEquals(11434, target.port()) - - assertEquals("http", target.protocol()) + assertEquals("http://localhost:11434", target.baseUrl()) } } \ No newline at end of file diff --git a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/internal/DefaultOllamaChatITest.kt b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/internal/DefaultOllamaChatITest.kt index 976f122..9e13d51 100644 --- a/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/internal/DefaultOllamaChatITest.kt +++ b/ollama-client/ollama-client-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/internal/DefaultOllamaChatITest.kt @@ -20,11 +20,7 @@ class DefaultOllamaChatITest : AutoCloseKoinTest() { @BeforeEach
fun setUp() { initOllama( - config = OllamaConfig( - protocol = { "http" }, - baseUrl = { "localhost" }, - port = { 11434 } - ) + config = OllamaConfig(baseUrl = { "http://localhost:11434" }) ) } @@ -70,8 +66,6 @@ class DefaultOllamaChatITest : AutoCloseKoinTest() { ) ) - println("create response: $r") - assertNotNull(r.message?.content) } } \ No newline at end of file diff --git a/ollama-client/ollama-client-darwin/src/appleMain/kotlin/com/tddworks/ollama/api/DarwinOllama.kt b/ollama-client/ollama-client-darwin/src/appleMain/kotlin/com/tddworks/ollama/api/DarwinOllama.kt index 17433d7..9081de1 100644 --- a/ollama-client/ollama-client-darwin/src/appleMain/kotlin/com/tddworks/ollama/api/DarwinOllama.kt +++ b/ollama-client/ollama-client-darwin/src/appleMain/kotlin/com/tddworks/ollama/api/DarwinOllama.kt @@ -22,9 +22,7 @@ object DarwinOllama { * @return an Ollama instance created with the provided configuration */ fun ollama( - port: () -> Int = { Ollama.PORT }, - protocol: () -> String = { Ollama.PROTOCOL }, baseUrl: () -> String = { Ollama.BASE_URL }, ): Ollama = - initOllama(OllamaConfig(baseUrl = baseUrl, port = port, protocol = protocol)) + initOllama(OllamaConfig(baseUrl = baseUrl)) } \ No newline at end of file diff --git a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt index 641fd97..b1993e7 100644 --- a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt +++ b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt @@ -1,8 +1,7 @@ package com.tddworks.openai.api import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default +import com.tddworks.common.network.api.ktor.internal.* import com.tddworks.di.createJson import com.tddworks.di.getInstance import com.tddworks.openai.api.chat.api.Chat @@ -19,14 +18,21 @@ interface OpenAI : Chat, Images, Completions { fun create(config: OpenAIConfig): OpenAI { val requester = HttpRequester.default( createHttpClient( - host = config.baseUrl, - authToken = config.apiKey, - // get from commonModule - json = createJson(), + connectionConfig = UrlBasedConnectionConfig(config.baseUrl), + authConfig = AuthConfig(config.apiKey), + features = ClientFeatures(json = createJson()) ) ) + return create(requester) + } + + fun create( + requester: HttpRequester, + chatCompletionPath: String = Chat.CHAT_COMPLETIONS_PATH + ): OpenAI { val chatApi = DefaultChatApi( - requester = requester + requester = requester, + chatCompletionPath = chatCompletionPath ) val imagesApi = DefaultImagesApi( diff --git a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/chat/internal/DefaultChatApi.kt b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/chat/internal/DefaultChatApi.kt index b3fce83..ac2e955 100644 --- a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/chat/internal/DefaultChatApi.kt +++ b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/chat/internal/DefaultChatApi.kt @@ -21,11 +21,12 @@ import kotlinx.serialization.ExperimentalSerializationApi @OptIn(ExperimentalSerializationApi::class) class DefaultChatApi( private val requester: HttpRequester, + private val chatCompletionPath: String = CHAT_COMPLETIONS_PATH ) : Chat { 
override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion { return requester.performRequest { method = HttpMethod.Post - url(path = CHAT_COMPLETIONS_PATH) + url(path = chatCompletionPath) setBody(request) contentType(ContentType.Application.Json) } @@ -34,7 +35,7 @@ class DefaultChatApi( override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> { return requester.streamRequest { method = HttpMethod.Post - url(path = CHAT_COMPLETIONS_PATH) + url(path = chatCompletionPath) setBody(request.copy(stream = true)) contentType(ContentType.Application.Json) accept(ContentType.Text.EventStream) diff --git a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/di/Koin.kt b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/di/Koin.kt index 65b9c2b..48de387 100644 --- a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/di/Koin.kt +++ b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/di/Koin.kt @@ -1,8 +1,7 @@ package com.tddworks.openai.di import com.tddworks.common.network.api.ktor.api.HttpRequester -import com.tddworks.common.network.api.ktor.internal.createHttpClient -import com.tddworks.common.network.api.ktor.internal.default +import com.tddworks.common.network.api.ktor.internal.* import com.tddworks.di.commonModule import com.tddworks.openai.api.OpenAI import com.tddworks.openai.api.OpenAIApi @@ -37,10 +36,10 @@ fun openAIModules( single(named("openAIHttpRequester")) { HttpRequester.default( createHttpClient( - host = config.baseUrl, - authToken = config.apiKey, + connectionConfig = UrlBasedConnectionConfig(config.baseUrl), + authConfig = AuthConfig(config.apiKey), // get from commonModule - json = get(), + features = ClientFeatures(json = get()) ) ) } diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/azure/api/AzureAI.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/azure/api/AzureAI.kt new file mode 100644 index 0000000..19db639 --- /dev/null +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/azure/api/AzureAI.kt @@ -0,0 +1,126 @@ +package com.tddworks.azure.api + +import com.tddworks.common.network.api.ktor.api.HttpRequester +import com.tddworks.common.network.api.ktor.api.performRequest +import com.tddworks.common.network.api.ktor.api.streamRequest +import com.tddworks.common.network.api.ktor.internal.ClientFeatures +import com.tddworks.common.network.api.ktor.internal.UrlBasedConnectionConfig +import com.tddworks.common.network.api.ktor.internal.createHttpClient +import com.tddworks.common.network.api.ktor.internal.default +import com.tddworks.di.createJson +import com.tddworks.openai.api.OpenAI +import com.tddworks.openai.api.chat.api.Chat +import com.tddworks.openai.api.chat.api.Chat.Companion.CHAT_COMPLETIONS_PATH +import com.tddworks.openai.api.chat.api.ChatCompletion +import com.tddworks.openai.api.chat.api.ChatCompletionChunk +import com.tddworks.openai.api.chat.api.ChatCompletionRequest +import com.tddworks.openai.api.images.api.Images +import com.tddworks.openai.api.images.internal.DefaultImagesApi +import com.tddworks.openai.api.legacy.completions.api.Completions +import com.tddworks.openai.api.legacy.completions.api.internal.DefaultCompletionsApi +import com.tddworks.openai.gateway.api.OpenAIProviderConfig +import io.ktor.client.request.* +import io.ktor.http.* +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.ExperimentalSerializationApi +
+data class AzureAIProviderConfig( + override val apiKey: () -> String, + override val baseUrl: () -> String, + val deploymentId: () -> String, + val apiVersion: () -> String, +) : OpenAIProviderConfig + +fun OpenAIProviderConfig.Companion.azure( + apiKey: () -> String, + baseUrl: () -> String, + deploymentId: () -> String, + apiVersion: () -> String +) = AzureAIProviderConfig( + apiKey = apiKey, + baseUrl = baseUrl, + deploymentId = deploymentId, + apiVersion = apiVersion +) + +/** + * Authentication + * Azure OpenAI provides two methods for authentication. You can use either API Keys or Microsoft Entra ID. + * + * API Key authentication: For this type of authentication, all API requests must include the API Key in the api-key HTTP header. The Quickstart provides guidance for how to make calls with this type of authentication. + * + * Microsoft Entra ID authentication: You can authenticate an API call using a Microsoft Entra token. Authentication tokens are included in a request as the Authorization header. The token provided must be preceded by Bearer, for example Bearer YOUR_AUTH_TOKEN. You can read our how-to guide on authenticating with Microsoft Entra ID. + */ +fun OpenAI.Companion.azure(config: AzureAIProviderConfig): OpenAI { + val requester = HttpRequester.default( + createHttpClient( + // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference + // POST https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/YOUR_DEPLOYMENT_NAME/completions?api-version=2024-06-01 + connectionConfig = UrlBasedConnectionConfig { "${config.baseUrl()}/openai/deployments/${config.deploymentId()}/" }, + features = ClientFeatures( + json = createJson(), + queryParams = mapOf("api-version" to config.apiVersion()) + ) + ) + ) + return azure( + config = config, + requester = requester, + chatCompletionPath = "chat/completions" + ) +} + +fun azure( + config: AzureAIProviderConfig, + requester: HttpRequester, + chatCompletionPath: String +): OpenAI { + val chatApi = AzureChatApi( + config = config, + requester = requester, + chatCompletionPath = chatCompletionPath + ) + + val imagesApi = DefaultImagesApi( + requester = requester + ) + + val completionsApi = DefaultCompletionsApi( + requester = requester + ) + + return object : OpenAI, Chat by chatApi, Images by imagesApi, + Completions by completionsApi {} +} + +@OptIn(ExperimentalSerializationApi::class) +class AzureChatApi( + private val config: AzureAIProviderConfig, + private val requester: HttpRequester, + private val chatCompletionPath: String = CHAT_COMPLETIONS_PATH +) : Chat { + override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion { + return requester.performRequest { + method = HttpMethod.Post + url(path = chatCompletionPath) + setBody(request) + contentType(ContentType.Application.Json) + } + } + + override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> { + return requester.streamRequest { + method = HttpMethod.Post + url(path = chatCompletionPath) + setBody(request.copy(stream = true)) + contentType(ContentType.Application.Json) + accept(ContentType.Text.EventStream) + headers { + append("api-key", config.apiKey()) + append(HttpHeaders.CacheControl, "no-cache") + append(HttpHeaders.Connection, "keep-alive") + } + } + } + +} \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt
b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt index ef34937..59001b3 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt @@ -1,5 +1,8 @@ package com.tddworks.openai.gateway.api.internal +import com.tddworks.azure.api.AzureAIProviderConfig +import com.tddworks.azure.api.azure +import com.tddworks.openai.api.OpenAI import com.tddworks.openai.api.chat.api.ChatCompletion import com.tddworks.openai.api.chat.api.ChatCompletionChunk import com.tddworks.openai.api.chat.api.ChatCompletionRequest @@ -55,6 +58,14 @@ class DefaultOpenAIGateway( models = models ) + is AzureAIProviderConfig -> DefaultOpenAIProvider( + id = id, + name = name, + config = config, + models = models, + openAI = OpenAI.azure(config) + ) + else -> throw IllegalArgumentException("Unsupported config type") } diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt index d3d0dde..d5e08b7 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt @@ -24,11 +24,7 @@ class OllamaOpenAIProvider( OpenAIModel(it.value) }, private val client: Ollama = Ollama.create( - ollamaConfig = OllamaConfig( - baseUrl = config.baseUrl, - port = config.port, - protocol = config.protocol - ) + ollamaConfig = OllamaConfig(baseUrl = config.baseUrl) ) ) : OpenAIProvider { /** @@ -79,11 +75,7 @@ fun OpenAIProvider.Companion.ollama( OpenAIModel(it.value) }, client: Ollama = Ollama.create( - ollamaConfig = OllamaConfig( - baseUrl = config.baseUrl, - port = config.port, - protocol = config.protocol - ) + ollamaConfig = OllamaConfig(baseUrl = config.baseUrl) ) ): OpenAIProvider { return OllamaOpenAIProvider( diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProviderConfig.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProviderConfig.kt index e15bfd4..6762bff 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProviderConfig.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProviderConfig.kt @@ -4,18 +4,14 @@ import com.tddworks.ollama.api.OllamaConfig import com.tddworks.openai.gateway.api.OpenAIProviderConfig data class OllamaOpenAIProviderConfig( - val port: () -> Int = { 11434 }, - val protocol: () -> String = { "http" }, - override val baseUrl: () -> String = { "localhost" }, + override val baseUrl: () -> String = { "http://localhost:11434" }, override val apiKey: () -> String = { "ollama-ignore-this" } ) : OpenAIProviderConfig fun OllamaOpenAIProviderConfig.toOllamaConfig() = - OllamaConfig(baseUrl = baseUrl, protocol = protocol, port = port) + OllamaConfig(baseUrl = baseUrl) fun OpenAIProviderConfig.Companion.ollama( apiKey: () -> String = { "ollama-ignore-this" }, - baseUrl: () -> String = {
"localhost" }, - protocol: () -> String = { "http" }, - port: () -> Int = { 11434 } -) = OllamaOpenAIProviderConfig(port, protocol, baseUrl, apiKey) \ No newline at end of file + baseUrl: () -> String = { "http://localhost:11434" }, +) = OllamaOpenAIProviderConfig(baseUrl, apiKey) \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/azure/api/AzureAIProviderConfigTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/azure/api/AzureAIProviderConfigTest.kt new file mode 100644 index 0000000..55fa96e --- /dev/null +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/azure/api/AzureAIProviderConfigTest.kt @@ -0,0 +1,23 @@ +package com.tddworks.azure.api + +import com.tddworks.openai.gateway.api.OpenAIProviderConfig +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Test + +class AzureAIProviderConfigTest { + + @Test + fun `should create AzureAIProviderConfig`() { + val config = OpenAIProviderConfig.azure( + apiKey = { "api-key" }, + baseUrl = { "YOUR_RESOURCE_NAME.openai.azure.com" }, + deploymentId = { "deployment-id" }, + apiVersion = { "api-version" } + ) + + assertEquals("api-key", config.apiKey()) + assertEquals("YOUR_RESOURCE_NAME.openai.azure.com", config.baseUrl()) + assertEquals("deployment-id", config.deploymentId()) + assertEquals("api-version", config.apiVersion()) + } +} \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt index 65369e9..1bf2115 100644 --- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt @@ -1,6 +1,7 @@ package com.tddworks.openai.gateway.api.internal import app.cash.turbine.test +import com.tddworks.azure.api.AzureAIProviderConfig import com.tddworks.ollama.api.OllamaModel import com.tddworks.openai.api.chat.api.ChatCompletion import com.tddworks.openai.api.chat.api.ChatCompletionChunk import com.tddworks.openai.api.chat.api.ChatCompletionRequest @@ -20,6 +21,7 @@ import com.tddworks.anthropic.api.AnthropicModel as AnthropicModel @OptIn(ExperimentalSerializationApi::class) class DefaultOpenAIGatewayTest { + private val anthropic = mock<OpenAIProvider> { on(it.id).thenReturn("anthropic") on(it.supports(OpenAIModel(AnthropicModel.CLAUDE_3_HAIKU.value))).thenReturn(true) @@ -39,16 +41,48 @@ on(it.name).thenReturn("Default") } + private val azure = mock<OpenAIProvider> { + on(it.id).thenReturn("azure") + on(it.supports(OpenAIModel(OpenAIModel.GPT_3_5_TURBO.value))).thenReturn(false) + on(it.name).thenReturn("azure") + } + private val providers: List<OpenAIProvider> = listOf( default, anthropic, - ollama + ollama, + azure ) private val openAIGateway = DefaultOpenAIGateway( providers, ) + @Test + fun `should update azure provider`() { + // Given + val id = "default" + val name = "new Default" + val config = AzureAIProviderConfig( + apiKey = { "new api key" }, + baseUrl = { "new endpoint" }, + deploymentId = { "new deployment id" }, + apiVersion = { "new api version" } + ) + + val models = listOf(OpenAIModel(OpenAIModel.GPT_3_5_TURBO.value)) + + // When + openAIGateway.updateProvider(id, name, config, models) + + // Then + assertEquals(4, openAIGateway.getProviders().size) + val openAIProvider =
openAIGateway.getProviders().first { it.id == id } + assertEquals(name, openAIProvider.name) + assertEquals(config, openAIProvider.config) + assertEquals(models, openAIProvider.models) + } + @Test fun `should update openai provider`() { // Given @@ -65,7 +99,7 @@ class DefaultOpenAIGatewayTest { openAIGateway.updateProvider(id, name, config, models) // Then - assertEquals(3, openAIGateway.getProviders().size) + assertEquals(4, openAIGateway.getProviders().size) val openAIProvider = openAIGateway.getProviders().first { it.id == id } assertEquals(name, openAIProvider.name) assertEquals(config, openAIProvider.config) @@ -85,7 +119,7 @@ class DefaultOpenAIGatewayTest { openAIGateway.updateProvider(id, name, config, models) // Then - assertEquals(3, openAIGateway.getProviders().size) + assertEquals(4, openAIGateway.getProviders().size) val openAIProvider = openAIGateway.getProviders().first { it.id == id } assertEquals(name, openAIProvider.name) assertEquals(config, openAIProvider.config) @@ -107,7 +141,7 @@ class DefaultOpenAIGatewayTest { openAIGateway.updateProvider(id, name, config, models) // Then - assertEquals(3, openAIGateway.getProviders().size) + assertEquals(4, openAIGateway.getProviders().size) val openAIProvider = openAIGateway.getProviders().first { it.id == id } assertEquals(name, openAIProvider.name) assertEquals(config, openAIProvider.config) @@ -118,8 +152,8 @@ class DefaultOpenAIGatewayTest { fun `should able to remove provider`() { openAIGateway.removeProvider(anthropic.name) // Then - assertEquals(2, openAIGateway.getProviders().size) - assertEquals(ollama, openAIGateway.getProviders().last()) + assertEquals(3, openAIGateway.getProviders().size) + assertEquals(azure, openAIGateway.getProviders().last()) } @Test @@ -135,7 +169,7 @@ class DefaultOpenAIGatewayTest { } // Then - assertEquals(4, gateway.getProviders().size) + assertEquals(5, gateway.getProviders().size) assertEquals(provider, gateway.getProviders().last()) } diff --git a/openai-gateway/openai-gateway-darwin/src/appleMain/kotlin/com/tddworks/openai/gateway/api/DarwinOpenAIGateway.kt b/openai-gateway/openai-gateway-darwin/src/appleMain/kotlin/com/tddworks/openai/gateway/api/DarwinOpenAIGateway.kt index 5b7c955..262765a 100644 --- a/openai-gateway/openai-gateway-darwin/src/appleMain/kotlin/com/tddworks/openai/gateway/api/DarwinOpenAIGateway.kt +++ b/openai-gateway/openai-gateway-darwin/src/appleMain/kotlin/com/tddworks/openai/gateway/api/DarwinOpenAIGateway.kt @@ -1,5 +1,7 @@ package com.tddworks.openai.gateway.api +import com.tddworks.anthropic.api.Anthropic +import com.tddworks.ollama.api.Ollama import com.tddworks.openai.api.OpenAI import com.tddworks.openai.gateway.api.internal.anthropic import com.tddworks.openai.gateway.api.internal.default @@ -28,26 +30,19 @@ object DarwinOpenAIGateway { fun openAIGateway( openAIBaseUrl: () -> String = { OpenAI.BASE_URL }, openAIKey: () -> String = { "CONFIGURE_ME" }, - anthropicBaseUrl: () -> String = { "api.anthropic.com" }, + anthropicBaseUrl: () -> String = { Anthropic.BASE_URL }, anthropicKey: () -> String = { "CONFIGURE_ME" }, - anthropicVersion: () -> String = { "2023-06-01" }, - ollamaBaseUrl: () -> String = { "localhost" }, - ollamaPort: () -> Int = { 8080 }, - ollamaProtocol: () -> String = { "http" }, + anthropicVersion: () -> String = { Anthropic.ANTHROPIC_VERSION }, + ollamaBaseUrl: () -> String = { Ollama.BASE_URL }, ) = initOpenAIGateway( openAIConfig = OpenAIProviderConfig.default( - baseUrl = openAIBaseUrl, - apiKey = openAIKey - ), - anthropicConfig = 
OpenAIProviderConfig.anthropic( + baseUrl = openAIBaseUrl, apiKey = openAIKey + ), anthropicConfig = OpenAIProviderConfig.anthropic( baseUrl = anthropicBaseUrl, apiKey = anthropicKey, anthropicVersion = anthropicVersion - ), - ollamaConfig = OpenAIProviderConfig.ollama( + ), ollamaConfig = OpenAIProviderConfig.ollama( baseUrl = ollamaBaseUrl, - port = ollamaPort, - protocol = ollamaProtocol ) ) }
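
Usage sketch (not part of the patch): how a call site can construct clients with the refactored createHttpClient API. All type and function names come from the diff above; the URLs, token, and Json settings are illustrative placeholders.

import com.tddworks.common.network.api.ktor.internal.*
import kotlinx.serialization.json.Json

// URL-based connection plus bearer auth, mirroring the OpenAI/Anthropic call sites in this patch.
val openAIClient = createHttpClient(
    connectionConfig = UrlBasedConnectionConfig { "https://api.openai.com" },
    authConfig = AuthConfig(authToken = { "YOUR_API_KEY" }),
    features = ClientFeatures(json = Json { ignoreUnknownKeys = true })
)

// Host/port-based connection, the shape previously expressed with separate protocol/host/port parameters.
val localClient = createHttpClient(
    connectionConfig = HostPortConnectionConfig(
        protocol = { "http" },
        host = { "localhost" },
        port = { 11434 }
    )
)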
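
Usage sketch for the new Azure provider (also not part of the patch): wiring AzureAIProviderConfig into the OpenAI facade. The resource name, deployment name, and api-version values are placeholders, and whether the base URL needs an explicit https:// scheme depends on how UrlBasedConnectionConfig resolves it.

import com.tddworks.azure.api.azure
import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.gateway.api.OpenAIProviderConfig

val azureConfig = OpenAIProviderConfig.azure(
    apiKey = { "AZURE_OPENAI_API_KEY" },
    baseUrl = { "https://YOUR_RESOURCE_NAME.openai.azure.com" },
    deploymentId = { "YOUR_DEPLOYMENT_NAME" },
    apiVersion = { "2024-06-01" }
)

// Requests resolve to {baseUrl}/openai/deployments/{deploymentId}/chat/completions?api-version={apiVersion};
// streaming calls also attach the api-key header (see AzureChatApi above).
val azureOpenAI = OpenAI.azure(azureConfig)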