Skip to content

Commit

Permalink
feat(gemini): support gemini TextGeneration api
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Dec 31, 2024
1 parent b68a9ad commit 5426e16
Show file tree
Hide file tree
Showing 15 changed files with 517 additions and 1 deletion.
14 changes: 14 additions & 0 deletions gemini-client/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
plugins {
`maven-publish`
}

kotlin {
jvm()
sourceSets {
commonMain {
dependencies {
api(projects.geminiClient.geminiClientCore)
}
}
}
}
48 changes: 48 additions & 0 deletions gemini-client/gemini-client-core/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
plugins {
alias(libs.plugins.kotlinx.serialization)
alias(libs.plugins.kover)
`maven-publish`
}

kotlin {
jvm()
macosArm64()
macosX64()

sourceSets {
commonMain.dependencies {
// put your Multiplatform dependencies here
api(projects.common)
}

commonTest.dependencies {
implementation(libs.ktor.client.mock)
api(projects.common)
}

macosMain.dependencies {
api(libs.ktor.client.darwin)
}

jvmMain.dependencies {
api(libs.ktor.client.cio)
}

jvmTest.dependencies {
implementation(project.dependencies.platform(libs.junit.bom))
implementation(libs.bundles.jvm.test)
implementation(libs.kotlinx.coroutines.test)
implementation(libs.koin.test)
implementation(libs.koin.test.junit5)
implementation(libs.app.cash.turbine)
implementation("com.tngtech.archunit:archunit-junit5:1.1.0")
implementation("org.reflections:reflections:0.10.2")
}
}
}

tasks {
named<Test>("jvmTest") {
useJUnitPlatform()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.tddworks.gemini.api.textGeneration.api

interface Gemini : TextGeneration {
companion object {
const val HOST = "generativelanguage.googleapis.com"
const val BASE_URL = "https://$HOST"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.tddworks.gemini.api.textGeneration.api

data class GeminiConfig(
val apiKey: () -> String = { "CONFIG_API_KEY" },
val baseUrl: () -> String = { Gemini.BASE_URL },
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.tddworks.gemini.api.textGeneration.api

import kotlinx.serialization.Serializable
import kotlin.jvm.JvmInline

/**
* https://ai.google.dev/gemini-api/docs/models/gemini
*/
@Serializable
@JvmInline
value class GeminiModel(val value: String) {
companion object {

/**
* Input(s): Audio, images, videos, and text
* Output(s): Text
* Optimized for: Complex reasoning tasks requiring more intelligence
*/
val GEMINI_1_5_PRO = GeminiModel("gemini-1.5-pro")

/**
* Input(s): Audio, images, videos, and text
* Output(s): Text
* Optimized for: High volume and lower intelligence tasks
*/
val GEMINI_1_5_FLASH_8b = GeminiModel("gemini-1.5-flash-8b")

/**
* Input(s): Audio, images, videos, and text
* Output(s): Text
* Optimized for: Fast and versatile performance across a diverse variety of tasks
*/
val GEMINI_1_5_FLASH = GeminiModel("gemini-1.5-flash")

/**
* Input(s): Audio, images, videos, and text
* Output(s): Text, images (coming soon), and audio (coming soon)
* Optimized for: Next generation features, speed, and multimodal generation for a diverse variety of tasks
*/
val GEMINI_2_0_FLASH = GeminiModel("gemini-2.0-flash-exp")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.tddworks.gemini.api.textGeneration.api

import com.tddworks.gemini.api.textGeneration.api.internal.DefaultTextGenerationApi.Companion.GEMINI_API_PATH
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient

// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:streamGenerateContent?alt=sse&key=$GOOGLE_API_KEY \
// curl https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=$GOOGLE_API_KEY \

/**
* {
* "contents": [
* {"role":"user",
* "parts":[{
* "text": "Hello"}]},
* {"role": "model",
* "parts":[{
* "text": "Great to meet you. What would you like to know?"}]},
* {"role":"user",
* "parts":[{
* "text": "I have two dogs in my house. How many paws are in my house?"}]},
* ]
* }
*/
@Serializable
data class GenerateContentRequest(
val contents: List<Content>,
@Transient
val model: GeminiModel = GeminiModel.GEMINI_1_5_FLASH,
@Transient
val stream: Boolean = false,
@Transient
val apiKey: String = ""
) {
fun toRequestUrl(): String {
val endpoint = if (stream) {
"streamGenerateContent?alt=sse&key=$apiKey"
} else {
"generateContent?key=$apiKey"
}
return "$GEMINI_API_PATH/${model.value}:$endpoint"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.tddworks.gemini.api.textGeneration.api

import kotlinx.serialization.Serializable

@Serializable
data class GenerateContentResponse(
val candidates: List<Candidate>,
val usageMetadata: UsageMetadata,
val modelVersion: String
)

@Serializable
data class Candidate(
val content: Content,
val finishReason: String,
val avgLogprobs: Double
)

@Serializable
data class Content(
val parts: List<Part>,
val role: String
)

@Serializable
data class Part(
val text: String
)

@Serializable
data class UsageMetadata(
val promptTokenCount: Int,
val candidatesTokenCount: Int,
val totalTokenCount: Int
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.tddworks.gemini.api.textGeneration.api

import kotlinx.coroutines.flow.Flow

/**
* https://ai.google.dev/api/generate-content#v1beta.Candidate
*/
interface TextGeneration {
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse

/**
* data: {"candidates": [{"content": {"parts": [{"text": " understand that AI is a constantly evolving field. New techniques and approaches are continually being developed, pushing the boundaries of what's possible. While AI can achieve impressive feats, it's important to remember that it's a tool, and its capabilities are limited by the data it's trained on and the algorithms"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 4,"totalTokenCount": 4},"modelVersion": "gemini-1.5-flash"}
*
* data: {"candidates": [{"content": {"parts": [{"text": " it uses. It doesn't possess consciousness or genuine understanding in the human sense.\n"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 4,"candidatesTokenCount": 724,"totalTokenCount": 728},"modelVersion": "gemini-1.5-flash"}
*/
fun streamGenerateContent(request: GenerateContentRequest): Flow<GenerateContentResponse>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.tddworks.gemini.api.textGeneration.api.internal

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.api.performRequest
import com.tddworks.gemini.api.textGeneration.api.GenerateContentRequest
import com.tddworks.gemini.api.textGeneration.api.GenerateContentResponse
import com.tddworks.gemini.api.textGeneration.api.TextGeneration
import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.coroutines.flow.Flow

class DefaultTextGenerationApi(
private val requester: HttpRequester
) : TextGeneration {
override suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
return requester.performRequest<GenerateContentResponse> {
method = HttpMethod.Post
url(path = request.toRequestUrl())
setBody(request)
contentType(ContentType.Application.Json)
}
}

override fun streamGenerateContent(request: GenerateContentRequest): Flow<GenerateContentResponse> {
TODO("Not yet implemented")
}

companion object {
const val GEMINI_API_PATH = "/v1beta/models"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.tddworks.gemini.api.textGeneration.api

import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test

class GenerateContentRequestTest {

@Test
fun `should return correct streamGenerateContent request url`() {
// Given
val generateContentRequest = GenerateContentRequest(
contents = listOf(),
stream = true,
apiKey = "some-key"
)

// When
val result = generateContentRequest.toRequestUrl()

// Then
assertEquals(
"/v1beta/models/gemini-1.5-flash:streamGenerateContent?alt=sse&key=some-key",
result
)
}


@Test
fun `should return correct generateContent request url`() {
// Given
val generateContentRequest = GenerateContentRequest(
contents = listOf(),
stream = false,
apiKey = "some-key"
)

// When
val result = generateContentRequest.toRequestUrl()

// Then
assertEquals(
"/v1beta/models/gemini-1.5-flash:generateContent?key=some-key",
result
)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.tddworks.gemini.api.textGeneration.api

import kotlinx.serialization.json.Json
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

class GenerateContentResponseTest {

@Test
fun `should deserialize GenerateContentResponse`() {
// Given
val json = """
{
"candidates": [
{
"content": {
"parts": [
{
"text": "some-text"
}
],
"role": "model"
},
"finishReason": "STOP",
"avgLogprobs": -0.24741496906413898
}
],
"usageMetadata": {
"promptTokenCount": 4,
"candidatesTokenCount": 715,
"totalTokenCount": 719
},
"modelVersion": "gemini-1.5-flash"
}
""".trimIndent()

// When
val response = Json.decodeFromString<GenerateContentResponse>(json)

// Then
assertEquals(1, response.candidates.size)
assertEquals("gemini-1.5-flash", response.modelVersion)
assertEquals(4, response.usageMetadata.promptTokenCount)
assertEquals(715, response.usageMetadata.candidatesTokenCount)
assertEquals(719, response.usageMetadata.totalTokenCount)
assertEquals("STOP", response.candidates[0].finishReason)
assertEquals(-0.24741496906413898, response.candidates[0].avgLogprobs)
assertEquals("model", response.candidates[0].content.role)
assertEquals(1, response.candidates[0].content.parts.size)
assertEquals("some-text", response.candidates[0].content.parts[0].text)
}
}
Loading

0 comments on commit 5426e16

Please sign in to comment.