Skip to content

Commit

Permalink
feat: auto load support model form llm api
Browse files Browse the repository at this point in the history
  • Loading branch information
daodao97 committed Dec 15, 2024
1 parent b46a86c commit f9597a2
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 239 deletions.
4 changes: 3 additions & 1 deletion lib/llm/base_llm_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ abstract class BaseLLMClient {

final response = await chatCompletion(
CompletionRequest(
model: ProviderManager.chatProvider.currentModel,
model: ProviderManager.chatModelProvider.currentModel,
messages: [ChatMessage(role: MessageRole.user, content: content)],
tools: openaiTools,
),
Expand Down Expand Up @@ -43,4 +43,6 @@ abstract class BaseLLMClient {
}

Future<String> genTitle(List<ChatMessage> messages);

Future<List<String>> models();
}
229 changes: 132 additions & 97 deletions lib/llm/claude_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,34 @@ class ClaudeClient extends BaseLLMClient {
final Dio _dio;

ClaudeClient({
required this.apiKey,
required this.apiKey,
String? baseUrl,
Dio? dio,
}) : baseUrl = (baseUrl == null || baseUrl.isEmpty)
? 'https://api.anthropic.com/v1/messages'
}) : baseUrl = (baseUrl == null || baseUrl.isEmpty)
? 'https://api.anthropic.com/v1'
: baseUrl,
_dio = dio ?? Dio(BaseOptions(
headers: {
'Content-Type': 'application/json',
'x-api-key': apiKey,
'anthropic-version': '2023-06-01',
},
responseType: ResponseType.stream,
));
_dio = dio ??
Dio(BaseOptions(
headers: {
'Content-Type': 'application/json',
'x-api-key': apiKey,
'anthropic-version': '2023-06-01',
},
));

@override
Future<LLMResponse> chatCompletion(CompletionRequest request) async {
final messages = request.messages.map((m) => {
'role': m.role == MessageRole.user ? 'user' : 'assistant',
'content': [
{
'type': 'text',
'text': m.content ?? '',
}
],
}).toList();
final messages = request.messages
.map((m) => {
'role': m.role == MessageRole.user ? 'user' : 'assistant',
'content': [
{
'type': 'text',
'text': m.content ?? '',
}
],
})
.toList();

final body = {
'model': request.model,
Expand All @@ -69,46 +71,51 @@ class ClaudeClient extends BaseLLMClient {

try {
final response = await _dio.post(
baseUrl,
'$baseUrl/messages',
data: jsonEncode(body),
);

final buffer = StringBuffer();
await for (final chunk in response.data.stream) {
buffer.write(utf8.decode(chunk));
var json;
if (response.data is ResponseBody) {
final responseBody = response.data as ResponseBody;
final responseStr = await utf8.decodeStream(responseBody.stream);
json = jsonDecode(responseStr);
} else {
json = response.data;
}

final json = jsonDecode(buffer.toString());
final content = json['content'][0]['text'];

// Parse tool calls if present
final toolCalls = json['tool_calls']?.map<ToolCall>((t) => ToolCall(
id: t['id'],
type: t['type'],
function: FunctionCall(
name: t['function']['name'],
arguments: t['function']['arguments'],
),
))?.toList();
final toolCalls = json['tool_calls']
?.map<ToolCall>((t) => ToolCall(
id: t['id'],
type: t['type'],
function: FunctionCall(
name: t['function']['name'],
arguments: t['function']['arguments'],
),
))
?.toList();

return LLMResponse(
content: content,
toolCalls: toolCalls,
);

} catch (e) {
final tips = "Claude API call failed: $baseUrl body: $body error: $e";
Logger.root.severe(tips);
throw Exception(tips);
throw Exception(
"Claude API call failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e");
}
}

@override
Stream<LLMResponse> chatStreamCompletion(CompletionRequest request) async* {
final messages = request.messages.map((m) => {
'role': m.role == MessageRole.user ? 'user' : 'assistant',
'content': m.content ?? '',
}).toList();
final messages = request.messages
.map((m) => {
'role': m.role == MessageRole.user ? 'user' : 'assistant',
'content': m.content ?? '',
})
.toList();

final body = {
'model': request.model,
Expand All @@ -127,15 +134,16 @@ class ClaudeClient extends BaseLLMClient {
}

try {
_dio.options.responseType = ResponseType.stream;
final response = await _dio.post(
baseUrl,
'$baseUrl/messages',
data: jsonEncode(body),
);

String buffer = '';
String currentContent = '';
List<ToolCall>? currentToolCalls;

await for (final chunk in response.data.stream) {
final decodedChunk = utf8.decode(chunk);
buffer += decodedChunk;
Expand All @@ -145,16 +153,17 @@ class ClaudeClient extends BaseLLMClient {
final line = buffer.substring(0, index).trim();
buffer = buffer.substring(index + 1);

if (!line.startsWith('data: ')) continue;
final jsonStr = line.substring(6).trim();
if (!line.startsWith('data:')) continue;

final jsonStr = line.substring(5).trim();
if (jsonStr.isEmpty) continue;

try {
final event = jsonDecode(jsonStr);
final eventType = event['type'];

switch (eventType) {
case 'content_block_start':
case 'content_block_delta':
final delta = event['delta'];
if (delta['type'] == 'text_delta') {
Expand Down Expand Up @@ -190,9 +199,8 @@ class ClaudeClient extends BaseLLMClient {
}
}
} catch (e) {
final error = "Claude streaming API call failed: $baseUrl body: $body error: $e";
Logger.root.severe(error);
throw Exception(error);
throw Exception(
"Claude streaming API call failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e");
}
}

Expand All @@ -205,18 +213,19 @@ class ClaudeClient extends BaseLLMClient {

final prompt = ChatMessage(
role: MessageRole.user,
content: """Generate a concise title (max 20 characters) for the following conversation.
content:
"""Generate a concise title (max 20 characters) for the following conversation.
The title should summarize the main topic. Return only the title without any explanation or extra punctuation.
Conversation:
$conversationText""",
);

final response = await chatCompletion(CompletionRequest(
model: "claude-3-5-haiku-latest",
model: "claude-3-5-haiku-20241022",
messages: [prompt],
));

return response.content?.trim() ?? "New Chat";
}

Expand All @@ -226,35 +235,38 @@ $conversationText""",
Map<String, List<Map<String, dynamic>>> toolsResponse,
) async {
// Convert tools to Claude's format
final tools = toolsResponse.entries.map((entry) {
return entry.value.map((tool) {
final parameters = tool['parameters'];
if (parameters is! Map<String, dynamic>) {
return {
'name': tool['name'],
'description': tool['description'],
'input_schema': {
'type': 'object',
'properties': {},
'required': [],
},
};
}
final tools = toolsResponse.entries
.map((entry) {
return entry.value.map((tool) {
final parameters = tool['parameters'];
if (parameters is! Map<String, dynamic>) {
return {
'name': tool['name'],
'description': tool['description'],
'input_schema': {
'type': 'object',
'properties': {},
'required': [],
},
};
}

return {
'name': tool['name'],
'description': tool['description'],
'input_schema': {
'type': 'object',
'properties': parameters['properties'] ?? {},
'required': parameters['required'] ?? [],
},
};
}).toList();
}).expand((x) => x).toList();
return {
'name': tool['name'],
'description': tool['description'],
'input_schema': {
'type': 'object',
'properties': parameters['properties'] ?? {},
'required': parameters['required'] ?? [],
},
};
}).toList();
})
.expand((x) => x)
.toList();

final body = {
'model': ProviderManager.chatProvider.currentModel,
'model': ProviderManager.chatModelProvider.currentModel,
'messages': [
{
'role': 'user',
Expand All @@ -267,19 +279,21 @@ $conversationText""",

try {
final response = await _dio.post(
baseUrl,
'$baseUrl/messages',
data: jsonEncode(body),
);

final buffer = StringBuffer();
await for (final chunk in response.data.stream) {
buffer.write(utf8.decode(chunk));
var jsonData;
if (response.data is ResponseBody) {
final responseBody = response.data as ResponseBody;
final responseStr = await utf8.decodeStream(responseBody.stream);
jsonData = jsonDecode(responseStr);
} else {
jsonData = response.data;
}

final json = jsonDecode(buffer.toString());

// Check if response contains tool calls in the content array
final contentBlocks = json['content'] as List?;
final contentBlocks = jsonData['content'] as List?;
if (contentBlocks == null || contentBlocks.isEmpty) {
return {
'need_tool_call': false,
Expand All @@ -288,9 +302,9 @@ $conversationText""",
}

// Look for tool_calls in the response
final toolUseBlocks = contentBlocks.where((block) =>
block['type'] == 'tool_calls' || block['type'] == 'tool_use');
final toolUseBlocks = contentBlocks.where((block) =>
block['type'] == 'tool_calls' || block['type'] == 'tool_use');

if (toolUseBlocks.isEmpty) {
// Get text content from the first text block
final textBlock = contentBlocks.firstWhere(
Expand All @@ -304,11 +318,13 @@ $conversationText""",
}

// Extract tool calls
final toolCalls = toolUseBlocks.map((block) => {
'id': block['id'],
'name': block['name'],
'arguments': block['input'],
}).toList();
final toolCalls = toolUseBlocks
.map((block) => {
'id': block['id'],
'name': block['name'],
'arguments': block['input'],
})
.toList();

// Get any accompanying text content
final textBlock = contentBlocks.firstWhere(
Expand All @@ -321,10 +337,29 @@ $conversationText""",
'content': textBlock['text'] ?? '',
'tool_calls': toolCalls,
};

} catch (e) {
Logger.root.severe('Claude tool call check failed: $baseUrl body: $body error: $e');
throw Exception('Failed to check tool calls: $e');
throw Exception(
'Claude tool call check failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e');
}
}

@override
Future<List<String>> models() async {
try {
final response = await _dio.get("$baseUrl/models");

final data = response.data;

final models = (data['data'] as List)
.map((m) => m['id'].toString())
.where((id) => id.contains('claude'))
.toList();

return models;
} catch (e, trace) {
Logger.root.severe('获取模型列表失败: $e, trace: $trace');
// 返回预定义的模型列表作为后备
return [];
}
}
}
Loading

0 comments on commit f9597a2

Please sign in to comment.