Skip to content

Commit

Permalink
chore: unit test for cohere and handle stop curl (#1856)
Browse files Browse the repository at this point in the history
* chore: add unit tests for Cohere and handle curl stream-stop detection

* fix: parse failed nlohmann::json

* fix: tojson string

* fix: return

* fix: escape json

---------

Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Jan 15, 2025
1 parent b3df25d commit ccece16
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 11 deletions.
23 changes: 20 additions & 3 deletions engine/extensions/remote-engine/remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = k400BadRequest;
context->need_stop = false;
(*context->callback)(std::move(status), std::move(check_error));
return size * nmemb;
}
Expand All @@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
status["status_code"] = k200OK;
context->need_stop = false;
(*context->callback)(std::move(status), Json::Value());
break;
}
Expand Down Expand Up @@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (context.need_stop) {
CTL_DBG("No stop message received, need to stop");
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
(*context.callback)(std::move(status), Json::Value());
}
return response;
}

Expand Down Expand Up @@ -602,6 +613,7 @@ void RemoteEngine::HandleChatCompletion(
status["status_code"] = k500InternalServerError;
Json::Value error;
error["error"] = "Failed to parse response";
LOG_WARN << "Failed to parse response: " << response.body;
callback(std::move(status), std::move(error));
return;
}
Expand All @@ -626,15 +638,19 @@ void RemoteEngine::HandleChatCompletion(

try {
response_json["stream"] = false;
if (!response_json.isMember("model")) {
response_json["model"] = model;
}
response_str = renderer_.Render(template_str, response_json);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
}
} catch (const std::exception& e) {
// Log error and potentially rethrow or handle accordingly
LOG_WARN << "Error in TransformRequest: " << e.what();
LOG_WARN << "Using original request body";
LOG_WARN << "Error: " << e.what();
LOG_WARN << "Response: " << response.body;
LOG_WARN << "Using original body";
response_str = response_json.toStyledString();
}

Expand All @@ -649,6 +665,7 @@ void RemoteEngine::HandleChatCompletion(
Json::Value error;
error["error"] = "Failed to parse response";
callback(std::move(status), std::move(error));
LOG_WARN << "Failed to parse response: " << response_str;
return;
}

Expand Down
1 change: 1 addition & 0 deletions engine/extensions/remote-engine/remote_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct StreamContext {
std::string model;
extensions::TemplateRenderer& renderer;
std::string stream_template;
bool need_stop = true;
};
struct CurlResponse {
std::string body;
Expand Down
13 changes: 7 additions & 6 deletions engine/extensions/template_renderer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <regex>
#include <stdexcept>
#include "utils/logging_utils.h"
#include "utils/string_utils.h"
namespace extensions {

TemplateRenderer::TemplateRenderer() {
// Configure Inja environment
env_.set_trim_blocks(true);
Expand All @@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
const auto& value = *args[0];

if (value.is_string()) {
return nlohmann::json(std::string("\"") + value.get<std::string>() +
return nlohmann::json(std::string("\"") +
string_utils::EscapeJson(value.get<std::string>()) +
"\"");
}
return value;
Expand All @@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
std::string result = env_.render(tmpl, template_data);

// Clean up any potential double quotes in JSON strings
result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
// result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");

LOG_DEBUG << "Result: " << result;

// Validate JSON
auto parsed = nlohmann::json::parse(result);

return result;
} catch (const std::exception& e) {
LOG_ERROR << "Template rendering failed: " << e.what();
LOG_ERROR << "Data: " << data.toStyledString();
LOG_ERROR << "Template: " << tmpl;
throw std::runtime_error(std::string("Template rendering failed: ") +
e.what());
Expand Down Expand Up @@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
e.what());
}
}
} // namespace remote_engine
} // namespace extensions
238 changes: 236 additions & 2 deletions engine/test/components/test_remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) {
"messages": [
{% for message in input_request.messages %}
{% if not loop.is_first %}
{"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
{"role": "{{ message.role }}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
{% endif %}
{% endfor %}
]
{% else %}
"messages": [
{% for message in input_request.messages %}
{"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
{"role": " {{ message.role}}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
{% endfor %}
]
{% endif %}
Expand Down Expand Up @@ -181,6 +181,240 @@ TEST_F(RemoteEngineTest, AnthropicResponse) {
EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull());
}

// Verifies OpenAI-style chat requests are transformed into the Cohere
// /v1/chat request shape: the leading "system" message becomes "preamble",
// intermediate messages become "chatHistory" (role mapped to USER/CHATBOT),
// and the final message becomes "message". tojson() must JSON-escape
// content so quotes/newlines survive the round trip.
TEST_F(RemoteEngineTest, CohereRequest) {
  std::string tpl =
      R"({
{% for key, value in input_request %}
{% if key == "messages" %}
{% if input_request.messages.0.role == "system" %}
"preamble": {{ tojson(input_request.messages.0.content) }},
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_first and not loop.is_last %}
{"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": {{ tojson(last(input_request.messages).content) }}
{% else %}
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_last %}
{ "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": {{ tojson(last(input_request.messages).content) }}
{% endif %}
{% if not loop.is_last %},{% endif %}
{% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %}
"{{ key }}": {{ tojson(value) }}
{% if not loop.is_last %},{% endif %}
{% endif %}
{% endfor %} })";

  // Case 1: request with a system message plus >2 messages — expects
  // "preamble", a populated "chatHistory", and "message" = last user turn.
  {
    std::string message_with_system = R"({
"engine" : "cohere",
"max_tokens" : 1024,
"messages": [
{"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."},
{"role": "user", "content": "Hello, world"},
{"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"},
{"role": "user", "content": "How are you today?"}
],
"model": "command-r-plus-08-2024",
"stream" : true
})";

    auto data = json_helper::ParseJsonString(message_with_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    // Rendered output must itself be valid JSON with pass-through fields
    // preserved and the system prompt hoisted into "preamble".
    auto res_json = json_helper::ParseJsonString(res);
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    for (auto const& msg : data["messages"]) {
      if (msg["role"].asString() == "system") {
        EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString());
      }
    }
    EXPECT_EQ(res_json["message"].asString(), "How are you today?");
  }

  // Case 2: single user message with characters that require JSON escaping
  // (quotes, newlines) — content must survive tojson() unmangled.
  // NOTE: the fixture previously had a trailing comma after "max_tokens",
  // which is invalid strict JSON; removed so the test does not depend on a
  // lenient parser configuration.
  {
    std::string message_without_system = R"({
"messages": [
{"role": "user", "content": "Hello, \"the\" \n\nworld"}
],
"model": "command-r-plus-08-2024",
"max_tokens": 1024
})";

    auto data = json_helper::ParseJsonString(message_without_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    auto res_json = json_helper::ParseJsonString(res);
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    EXPECT_EQ(data["messages"][0]["content"].asString(),
              res_json["message"].asString());
  }
}

// Verifies Cohere API responses are transformed into OpenAI-compatible
// chat-completion JSON. Three cases: a streaming "text-generation" event,
// a streaming "stream-end" event, and a full non-streaming response.
// tojson() must escape quotes/newlines/tabs in the generated text.
TEST_F(RemoteEngineTest, CohereResponse) {
  std::string tpl = R"(
  {% if input_request.stream %}
    {"object": "chat.completion.chunk",
      "model": "{{ input_request.model }}",
      "choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": {{ tojson(input_request.text) }} {% else %} "role": "assistant", "content": null {% endif %} },
      {% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }]
    }
  {% else %}
   {"id": "{{ input_request.generation_id }}",
      "created": null,
      "object": "chat.completion",
      "model": "{{ input_request.model }}",
      "choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})";
  // Case 1: streaming text-generation event — the text fragment must land
  // in choices[0].delta.content of a chat.completion.chunk.
  std::string message = R"({
      "event_type": "text-generation",
      "text": " help"
    })";
  auto data = json_helper::ParseJsonString(message);
  data["stream"] = true;
  data["model"] = "cohere";
  extensions::TemplateRenderer rdr;
  auto res = rdr.Render(tpl, data);
  auto res_json = json_helper::ParseJsonString(res);
  EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help");

  // Case 2: streaming stream-end event — no text fragment, so delta.content
  // must render as JSON null (not an empty string).
  message = R"(
  {
  "event_type": "stream-end",
  "response": {
    "text": "Hello! How can I help you today?",
    "generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea",
    "chat_history": [
      {
        "role": "USER",
        "message": "hello world!"
      },
      {
        "role": "CHATBOT",
        "message": "Hello! How can I help you today?"
      }
    ],
    "finish_reason": "COMPLETE",
    "meta": {
      "api_version": {
        "version": "1"
      },
      "billed_units": {
        "input_tokens": 3,
        "output_tokens": 9
      },
      "tokens": {
        "input_tokens": 69,
        "output_tokens": 9
      }
    }
  },
  "finish_reason": "COMPLETE"
})";
  data = json_helper::ParseJsonString(message);
  data["stream"] = true;
  data["model"] = "cohere";
  res = rdr.Render(tpl, data);
  res_json = json_helper::ParseJsonString(res);
  EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull());

  // non-stream
  // Case 3: full non-streaming response — text contains tabs, quotes and
  // newlines, so this exercises tojson() escaping end to end; usage token
  // counts are computed from meta.tokens inside the template.
  message = R"(
    {
  "text": "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).",
  "generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1",
  "citations": [
    {
      "start": 25,
      "end": 41,
      "text": "25 December 1642",
      "document_ids": [
        "web-search_0"
      ]
    }
  ],
  "search_queries": [
    {
      "text": "Isaac Newton birth year",
      "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
    }
  ],
  "search_results": [
    {
      "connector": {
        "id": "web-search"
      },
      "document_ids": [
        "web-search_0"
      ],
      "search_query": {
        "text": "Isaac Newton birth year",
        "generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
      }
    }
  ],
  "finish_reason": "COMPLETE",
  "chat_history": [
    {
      "role": "USER",
      "message": "Who discovered gravity?"
    },
    {
      "role": "CHATBOT",
      "message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"
    },
    {
      "role": "USER",
      "message": "What year was he born?"
    },
    {
      "role": "CHATBOT",
      "message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)."
    }
  ],
  "meta": {
    "api_version": {
      "version": "1"
    },
    "billed_units": {
      "input_tokens": 31738,
      "output_tokens": 35
    },
    "tokens": {
      "input_tokens": 32465,
      "output_tokens": 205
    }
  }
}
)";

  data = json_helper::ParseJsonString(message);
  data["stream"] = false;
  data["model"] = "cohere";
  res = rdr.Render(tpl, data);
  res_json = json_helper::ParseJsonString(res);
  // Escaped characters must round-trip exactly through render + re-parse.
  EXPECT_EQ(
      res_json["choices"][0]["message"]["content"].asString(),
      "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 "
      "January 1643 (New Style).");
}

TEST_F(RemoteEngineTest, HeaderTemplate) {
{
std::string header_template =
Expand Down
Loading

0 comments on commit ccece16

Please sign in to comment.