From 192c4da5476fa211f51a1164c5dd8ba6ddd7d6de Mon Sep 17 00:00:00 2001 From: James Date: Thu, 14 Nov 2024 12:00:05 +0700 Subject: [PATCH 01/44] chore: cleanup httplib --- engine/CMakeLists.txt | 2 - engine/cli/CMakeLists.txt | 2 - engine/cli/command_line_parser.cc | 19 +- engine/cli/commands/chat_completion_cmd.cc | 189 +++++++------ engine/cli/commands/chat_completion_cmd.h | 5 - engine/cli/commands/cortex_upd_cmd.cc | 249 ++++++++---------- engine/cli/commands/cortex_upd_cmd.h | 14 +- engine/cli/commands/engine_list_cmd.cc | 2 +- engine/cli/commands/hardware_activate_cmd.cc | 51 ++-- engine/cli/commands/hardware_activate_cmd.h | 4 +- engine/cli/commands/hardware_list_cmd.cc | 13 +- engine/cli/commands/model_del_cmd.cc | 26 +- engine/cli/commands/model_get_cmd.cc | 30 +-- engine/cli/commands/model_import_cmd.cc | 32 ++- engine/cli/commands/model_import_cmd.h | 3 +- engine/cli/commands/model_list_cmd.cc | 66 +++-- engine/cli/commands/model_pull_cmd.cc | 154 +++++------ engine/cli/commands/model_pull_cmd.h | 10 +- engine/cli/commands/model_start_cmd.cc | 52 ++-- engine/cli/commands/model_start_cmd.h | 8 +- engine/cli/commands/model_status_cmd.cc | 28 +- engine/cli/commands/model_status_cmd.h | 8 +- engine/cli/commands/model_stop_cmd.cc | 33 ++- engine/cli/commands/model_stop_cmd.h | 7 - engine/cli/commands/model_upd_cmd.cc | 35 ++- engine/cli/commands/model_upd_cmd.h | 10 +- engine/cli/commands/ps_cmd.cc | 20 +- engine/cli/commands/run_cmd.cc | 12 +- engine/cli/commands/run_cmd.h | 6 +- engine/cli/commands/server_start_cmd.cc | 2 - engine/cli/commands/server_start_cmd.h | 26 +- engine/cli/commands/server_stop_cmd.cc | 23 +- engine/cli/main.cc | 2 + engine/controllers/server.h | 5 - engine/e2e-test/main.py | 3 +- engine/e2e-test/test_api_docker.py | 53 ++-- engine/e2e-test/test_api_engine_uninstall.py | 13 +- ..._start.py => test_api_model_start_stop.py} | 31 ++- engine/e2e-test/test_api_model_stop.py | 38 --- engine/e2e-test/test_cli_model_delete.py | 16 +- engine/e2e-test/test_runner.py | 37 +++ engine/main.cc | 2 + engine/services/model_service.cc | 14 +- engine/test/components/CMakeLists.txt | 2 - engine/utils/config_yaml_utils.h | 2 +- engine/utils/cortex_utils.h | 2 - .../utils/cpuid/detail/init_linux_gcc_arm.h | 2 +- engine/utils/curl_utils.h | 2 - engine/vcpkg.json | 4 - 49 files changed, 648 insertions(+), 721 deletions(-) rename engine/e2e-test/{test_api_model_start.py => test_api_model_start_stop.py} (74%) delete mode 100644 engine/e2e-test/test_api_model_stop.py diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index b53eb7fdf..5ffabf23c 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -71,7 +71,6 @@ add_subdirectory(cli) find_package(jsoncpp CONFIG REQUIRED) find_package(Drogon CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) -find_package(httplib CONFIG REQUIRED) find_package(unofficial-minizip CONFIG REQUIRED) find_package(LibArchive REQUIRED) find_package(CURL REQUIRED) @@ -147,7 +146,6 @@ add_executable(${TARGET_NAME} main.cc target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) target_link_libraries(${TARGET_NAME} PRIVATE unofficial::minizip::minizip) target_link_libraries(${TARGET_NAME} PRIVATE LibArchive::LibArchive) target_link_libraries(${TARGET_NAME} PRIVATE CURL::libcurl) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index c69e7e150..db2bed828 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -61,7 +61,6 @@ add_compile_definitions(CORTEX_CONFIG_FILE_PATH="${CORTEX_CONFIG_FILE_PATH}") find_package(jsoncpp CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) -find_package(httplib CONFIG REQUIRED) find_package(CLI11 CONFIG REQUIRED) find_package(unofficial-minizip CONFIG REQUIRED) find_package(LibArchive REQUIRED) @@ -87,7 +86,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ) -target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) target_link_libraries(${TARGET_NAME} PRIVATE CLI11::CLI11) target_link_libraries(${TARGET_NAME} PRIVATE unofficial::minizip::minizip) target_link_libraries(${TARGET_NAME} PRIVATE LibArchive::LibArchive) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index bb41dbe8b..e1b2f5feb 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -149,9 +149,9 @@ void CommandLineParser::SetupCommonCommands() { return; } try { - commands::ModelPullCmd(download_service_) - .Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id); + commands::ModelPullCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + cml_data_.model_id); } catch (const std::exception& e) { CLI_LOG(e.what()); } @@ -214,10 +214,9 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(model_start_cmd->help()); return; }; - commands::ModelStartCmd(model_service_) - .Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, - hw_activate_opts_); + commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + cml_data_.model_id, hw_activate_opts_); }); auto stop_model_cmd = @@ -234,9 +233,9 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(stop_model_cmd->help()); return; }; - commands::ModelStopCmd(model_service_) - .Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id); + commands::ModelStopCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + cml_data_.model_id); }); auto list_models_cmd = diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index f81040bac..0067b1c08 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -1,8 +1,8 @@ #include "chat_completion_cmd.h" +#include #include "config/yaml_config.h" #include "cortex_upd_cmd.h" #include "database/models.h" -#include "httplib.h" #include "model_status_cmd.h" #include "server_start_cmd.h" #include "utils/engine_constants.h" @@ -16,29 +16,42 @@ constexpr const auto kMinDataChunkSize = 6u; constexpr const char* kUser = "user"; constexpr const char* kAssistant = "assistant"; -} // namespace +struct StreamingCallback { + std::string* ai_chat; + bool is_done; -struct ChunkParser { - std::string content; - bool is_done = false; + StreamingCallback() : ai_chat(nullptr), is_done(false) {} +}; - ChunkParser(const char* data, size_t data_length) { - if (data && data_length > kMinDataChunkSize) { - std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize); - if (s.find("[DONE]") != std::string::npos) { - is_done = true; - } else { - try { - content = - json_helper::ParseJsonString(s)["choices"][0]["delta"]["content"] - .asString(); - } catch (const std::exception& e) { - CTL_WRN("JSON parse error: " << e.what()); - } +size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { + auto* callback = static_cast(userdata); + size_t data_length = size * nmemb; + + if (ptr && data_length > kMinDataChunkSize) { + std::string chunk(ptr + kMinDataChunkSize, data_length - kMinDataChunkSize); + if (chunk.find("[DONE]") != std::string::npos) { + callback->is_done = true; + std::cout << std::endl; + return data_length; + } + + try { + std::string content = + json_helper::ParseJsonString(chunk)["choices"][0]["delta"]["content"] + .asString(); + std::cout << content << std::flush; + if (callback->ai_chat) { + *callback->ai_chat += content; } + } catch (const std::exception& e) { + CTL_WRN("JSON parse error: " << e.what()); } } -}; + + return data_length; +} + +} // namespace void ChatCompletionCmd::Exec(const std::string& host, int port, const std::string& model_handle, std::string msg) { @@ -68,95 +81,101 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, const std::string& model_handle, const config::ModelConfig& mc, std::string msg) { auto address = host + ":" + std::to_string(port); + // Check if server is started - { - if (!commands::IsServerAlive(host, port)) { - CLI_LOG("Server is not started yet, please run `" - << commands::GetCortexBinary() << " start` to start server!"); - return; - } + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Server is not started yet, please run `" + << commands::GetCortexBinary() << " start` to start server!"); + return; } // Only check if llamacpp engine if ((mc.engine.find(kLlamaEngine) != std::string::npos || mc.engine.find(kLlamaRepo) != std::string::npos) && - !commands::ModelStatusCmd(model_service_) - .IsLoaded(host, port, model_handle)) { + !commands::ModelStatusCmd().IsLoaded(host, port, model_handle)) { CLI_LOG("Model is not loaded yet!"); return; } + auto curl = curl_easy_init(); + if (!curl) { + CLI_LOG("Failed to initialize CURL"); + return; + } + + std::string url = "http://" + address + "/v1/chat/completions"; + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + // Interactive mode or not bool interactive = msg.empty(); - // Some instruction for user here if (interactive) { - std::cout << "Inorder to exit, type `exit()`" << std::endl; + std::cout << "In order to exit, type `exit()`" << std::endl; } - // Model is loaded, start to chat - { - do { - std::string user_input = std::move(msg); - if (user_input.empty()) { - std::cout << "> "; - if (!std::getline(std::cin, user_input)) { - break; - } - } - string_utils::Trim(user_input); - if (user_input == kExitChat) { + do { + std::string user_input = std::move(msg); + if (user_input.empty()) { + std::cout << "> "; + if (!std::getline(std::cin, user_input)) { break; } + } + + string_utils::Trim(user_input); + if (user_input == kExitChat) { + break; + } + + if (!user_input.empty()) { + // Prepare JSON payload + Json::Value new_data; + new_data["role"] = kUser; + new_data["content"] = user_input; + histories_.push_back(std::move(new_data)); + + Json::Value json_data = mc.ToJson(); + json_data["engine"] = mc.engine; + + Json::Value msgs_array(Json::arrayValue); + for (const auto& m : histories_) { + msgs_array.append(m); + } + + json_data["messages"] = msgs_array; + json_data["model"] = model_handle; + json_data["stream"] = true; - if (!user_input.empty()) { - httplib::Client cli(address); - Json::Value json_data = mc.ToJson(); - Json::Value new_data; - new_data["role"] = kUser; - new_data["content"] = user_input; - histories_.push_back(std::move(new_data)); - json_data["engine"] = mc.engine; - Json::Value msgs_array(Json::arrayValue); - for (const auto& m : histories_) { - msgs_array.append(m); - } - json_data["messages"] = msgs_array; - json_data["model"] = model_handle; - //TODO: support non-stream - json_data["stream"] = true; - auto data_str = json_data.toStyledString(); - // std::cout << data_str << std::endl; - cli.set_read_timeout(std::chrono::seconds(60)); - // std::cout << "> "; - httplib::Request req; - req.headers = httplib::Headers(); - req.set_header("Content-Type", "application/json"); - req.method = "POST"; - req.path = "/v1/chat/completions"; - req.body = data_str; - std::string ai_chat; - req.content_receiver = [&](const char* data, size_t data_length, - uint64_t offset, uint64_t total_length) { - ChunkParser cp(data, data_length); - if (cp.is_done) { - std::cout << std::endl; - return false; - } - std::cout << cp.content << std::flush; - ai_chat += cp.content; - return true; - }; - cli.send(req); + std::string json_payload = json_data.toStyledString(); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str()); + + std::string ai_chat; + StreamingCallback callback; + callback.ai_chat = &ai_chat; + + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + CLI_LOG("CURL request failed: " << curl_easy_strerror(res)); + } else { Json::Value ai_res; ai_res["role"] = kAssistant; ai_res["content"] = ai_chat; histories_.push_back(std::move(ai_res)); } - // std::cout << "ok Done" << std::endl; - } while (interactive); - } -} + } + } while (interactive); -}; // namespace commands + curl_slist_free_all(headers); + curl_easy_cleanup(curl); +} +} // namespace commands diff --git a/engine/cli/commands/chat_completion_cmd.h b/engine/cli/commands/chat_completion_cmd.h index 746c25eb3..a784b4604 100644 --- a/engine/cli/commands/chat_completion_cmd.h +++ b/engine/cli/commands/chat_completion_cmd.h @@ -3,14 +3,10 @@ #include #include #include "config/model_config.h" -#include "services/model_service.h" namespace commands { class ChatCompletionCmd { public: - explicit ChatCompletionCmd(const ModelService& model_service) - : model_service_{model_service} {}; - void Exec(const std::string& host, int port, const std::string& model_handle, std::string msg); void Exec(const std::string& host, int port, const std::string& model_handle, @@ -18,6 +14,5 @@ class ChatCompletionCmd { private: std::vector histories_; - ModelService model_service_; }; } // namespace commands diff --git a/engine/cli/commands/cortex_upd_cmd.cc b/engine/cli/commands/cortex_upd_cmd.cc index fcb45fc5c..5d7b4bf4c 100644 --- a/engine/cli/commands/cortex_upd_cmd.cc +++ b/engine/cli/commands/cortex_upd_cmd.cc @@ -1,9 +1,9 @@ #include "cortex_upd_cmd.h" -#include "httplib.h" +#include "cli/commands/server_start_cmd.h" #include "server_stop_cmd.h" #include "utils/archive_utils.h" +#include "utils/curl_utils.h" #include "utils/file_manager_utils.h" -#include "utils/json_helper.h" #include "utils/logging_utils.h" #include "utils/scope_exit.h" #include "utils/system_info_utils.h" @@ -151,69 +151,62 @@ std::optional CheckNewUpdate( return config.latestRelease; } - auto host_name = GetHostName(); - auto release_path = GetReleasePath(); - CTL_INF("Engine release path: " << host_name << release_path); + auto url = url_parser::Url{ + .protocol = "https", + .host = GetHostName(), + .pathParams = GetReleasePath(), + }; - httplib::Client cli(host_name); - if (timeout.has_value()) { - cli.set_connection_timeout(*timeout); - cli.set_read_timeout(*timeout); + CTL_INF("Engine release path: " << url.ToFullPath()); + + auto res = curl_utils::SimpleGetJson(url.ToFullPath()); + if (res.has_error()) { + CTL_INF("HTTP error: " << res.error()); + return std::nullopt; } - if (auto res = cli.Get(release_path)) { - if (res->status == httplib::StatusCode::OK_200) { - try { - auto get_latest = [](const Json::Value& data) -> std::string { - if (data.empty()) { - return ""; - } - if (CORTEX_VARIANT == file_manager_utils::kBetaVariant) { - for (const auto& d : data) { - if (auto tag = d["tag_name"].asString(); - tag.find(kBetaComp) != std::string::npos) { - return tag; - } - } - return data[0]["tag_name"].asString(); - } else { - return data["tag_name"].asString(); + try { + auto get_latest = [](const Json::Value& data) -> std::string { + if (data.empty()) { + return ""; + } + + if (CORTEX_VARIANT == file_manager_utils::kBetaVariant) { + for (const auto& d : data) { + if (auto tag = d["tag_name"].asString(); + tag.find(kBetaComp) != std::string::npos) { + return tag; } - return ""; - }; - - auto json_res = json_helper::ParseJsonString(res->body); - std::string latest_version = get_latest(json_res); - if (latest_version.empty()) { - CTL_WRN("Release not found!"); - return std::nullopt; - } - std::string current_version = CORTEX_CPP_VERSION; - CTL_INF("Got the latest release, update to the config file: " - << latest_version) - config.latestRelease = latest_version; - auto result = - config_yaml_utils::CortexConfigMgr::GetInstance().DumpYamlConfig( - config, file_manager_utils::GetConfigurationPath().string()); - if (result.has_error()) { - CTL_ERR("Error update " - << file_manager_utils::GetConfigurationPath().string() - << result.error()); } - if (current_version != latest_version) { - return latest_version; - } - } catch (const std::exception& e) { - CTL_INF("JSON parse error: " << e.what()); - return std::nullopt; + return data[0]["tag_name"].asString(); + } else { + return data["tag_name"].asString(); } - } else { - CTL_INF("HTTP error: " << res->status); + return ""; + }; + + auto latest_version = get_latest(res.value()); + if (latest_version.empty()) { + CTL_WRN("Release not found!"); return std::nullopt; } - } else { - auto err = res.error(); - CTL_INF("HTTP error: " << httplib::to_string(err)); + std::string current_version = CORTEX_CPP_VERSION; + CTL_INF( + "Got the latest release, update to the config file: " << latest_version) + config.latestRelease = latest_version; + auto result = + config_yaml_utils::CortexConfigMgr::GetInstance().DumpYamlConfig( + config, file_manager_utils::GetConfigurationPath().string()); + if (result.has_error()) { + CTL_ERR("Error update " + << file_manager_utils::GetConfigurationPath().string() + << result.error()); + } + if (current_version != latest_version) { + return latest_version; + } + } catch (const std::exception& e) { + CTL_INF("JSON parse error: " << e.what()); return std::nullopt; } return std::nullopt; @@ -230,9 +223,9 @@ void CortexUpdCmd::Exec(const std::string& v, bool force) { { auto config = file_manager_utils::GetCortexConfig(); - httplib::Client cli(config.apiServerHost + ":" + config.apiServerPort); - auto res = cli.Get("/healthz"); - if (res) { + auto server_running = commands::IsServerAlive( + config.apiServerHost, std::stoi(config.apiServerPort)); + if (server_running) { CLI_LOG("Server is running. Stopping server before updating!"); commands::ServerStopCmd ssc(config.apiServerHost, std::stoi(config.apiServerPort)); @@ -270,38 +263,32 @@ bool CortexUpdCmd::GetStable(const std::string& v) { auto system_info = GetSystemInfoWithUniversal(); CTL_INF("OS: " << system_info->os << ", Arch: " << system_info->arch); - // Download file - auto github_host = GetHostName(); - auto release_path = GetReleasePath(); - CTL_INF("Engine release path: " << github_host << release_path); + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = GetHostName(), + .pathParams = GetReleasePath(), + }; + CTL_INF("Engine release path: " << url_obj.ToFullPath()); - httplib::Client cli(github_host); - if (auto res = cli.Get(release_path)) { - if (res->status == httplib::StatusCode::OK_200) { - try { - auto json_data = json_helper::ParseJsonString(res->body); - if (json_data.empty()) { - CLI_LOG("Version not found: " << v); - return false; - } + auto res = curl_utils::SimpleGetJson(url_obj.ToFullPath()); + if (res.has_error()) { + CLI_LOG_ERROR("HTTP error: " << res.error()); + return false; + } - if (downloaded_exe_path = HandleGithubRelease( - json_data["assets"], - {system_info->os + "-" + system_info->arch}); - !downloaded_exe_path) { - return false; - } - } catch (const std::exception& e) { - CLI_LOG_ERROR("JSON parse error: " << e.what()); - return false; - } - } else { - CLI_LOG_ERROR("HTTP error: " << res->status); + try { + if (res.value().empty()) { + CLI_LOG("Version not found: " << v); return false; } - } else { - auto err = res.error(); - CLI_LOG_ERROR("HTTP error: " << httplib::to_string(err)); + + if (downloaded_exe_path = HandleGithubRelease( + res.value()["assets"], {system_info->os + "-" + system_info->arch}); + !downloaded_exe_path) { + return false; + } + } catch (const std::exception& e) { + CLI_LOG_ERROR("JSON parse error: " << e.what()); return false; } @@ -330,50 +317,42 @@ bool CortexUpdCmd::GetBeta(const std::string& v) { auto system_info = GetSystemInfoWithUniversal(); CTL_INF("OS: " << system_info->os << ", Arch: " << system_info->arch); - // Download file - auto github_host = GetHostName(); - auto release_path = GetReleasePath(); - CTL_INF("Engine release path: " << github_host << release_path); + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = GetHostName(), + .pathParams = GetReleasePath(), + }; + CTL_INF("Engine release path: " << url_obj.ToFullPath()); + auto res = curl_utils::SimpleGetJson(url_obj.ToFullPath()); + if (res.has_error()) { + CLI_LOG_ERROR("HTTP error: " << res.error()); + return false; + } - httplib::Client cli(github_host); - if (auto res = cli.Get(release_path)) { - if (res->status == httplib::StatusCode::OK_200) { - try { - auto json_res = json_helper::ParseJsonString(res->body); - - Json::Value json_data; - for (const auto& jr : json_res) { - // Get the latest beta or match version - if (auto tag = jr["tag_name"].asString(); - (v.empty() && tag.find(kBetaComp) != std::string::npos) || - (tag == v)) { - json_data = jr; - break; - } - } + try { + Json::Value json_data; + for (const auto& jr : res.value()) { + // Get the latest beta or match version + if (auto tag = jr["tag_name"].asString(); + (v.empty() && tag.find(kBetaComp) != std::string::npos) || + (tag == v)) { + json_data = jr; + break; + } + } - if (json_data.empty()) { - CLI_LOG("Version not found: " << v); - return false; - } + if (json_data.empty()) { + CLI_LOG("Version not found: " << v); + return false; + } - if (downloaded_exe_path = HandleGithubRelease( - json_data["assets"], - {system_info->os + "-" + system_info->arch}); - !downloaded_exe_path) { - return false; - } - } catch (const std::exception& e) { - CLI_LOG_ERROR("JSON parse error: " << e.what()); - return false; - } - } else { - CLI_LOG_ERROR("HTTP error: " << res->status); + if (downloaded_exe_path = HandleGithubRelease( + json_data["assets"], {system_info->os + "-" + system_info->arch}); + !downloaded_exe_path) { return false; } - } else { - auto err = res.error(); - CLI_LOG_ERROR("HTTP error: " << httplib::to_string(err)); + } catch (const std::exception& e) { + CLI_LOG_ERROR("JSON parse error: " << e.what()); return false; } @@ -430,13 +409,15 @@ std::optional CortexUpdCmd::HandleGithubRelease( CLI_LOG_ERROR("Failed to create directories: " << e.what()); return std::nullopt; } - auto download_task{DownloadTask{.id = "cortex", - .type = DownloadType::Cortex, - .items = {DownloadItem{ - .id = "cortex", - .downloadUrl = download_url, - .localPath = local_path, - }}}}; + auto download_task{DownloadTask{ + .id = "cortex", + .type = DownloadType::Cortex, + .items = {DownloadItem{ + .id = "cortex", + .downloadUrl = download_url, + .localPath = local_path, + }}, + }}; auto result = download_service_->AddDownloadTask( download_task, [](const DownloadTask& finishedTask) { diff --git a/engine/cli/commands/cortex_upd_cmd.h b/engine/cli/commands/cortex_upd_cmd.h index 9c500a999..01793992f 100644 --- a/engine/cli/commands/cortex_upd_cmd.h +++ b/engine/cli/commands/cortex_upd_cmd.h @@ -1,5 +1,7 @@ #pragma once + #include +#include #include "services/download_service.h" #if !defined(_WIN32) #include @@ -67,19 +69,19 @@ inline std::string GetCortexServerBinary() { inline std::string GetHostName() { if (CORTEX_VARIANT == file_manager_utils::kNightlyVariant) { - return "https://delta.jan.ai"; + return "delta.jan.ai"; } else { - return "https://api.github.com"; + return "api.github.com"; } } -inline std::string GetReleasePath() { +inline std::vector GetReleasePath() { if (CORTEX_VARIANT == file_manager_utils::kNightlyVariant) { - return "/cortex/latest/version.json"; + return {"cortex", "latest", "version.json"}; } else if (CORTEX_VARIANT == file_manager_utils::kBetaVariant) { - return "/repos/janhq/cortex.cpp/releases"; + return {"repos", "janhq", "cortex.cpp", "releases"}; } else { - return "/repos/janhq/cortex.cpp/releases/latest"; + return {"repos", "janhq", "cortex.cpp", "releases", "latest"}; } } diff --git a/engine/cli/commands/engine_list_cmd.cc b/engine/cli/commands/engine_list_cmd.cc index 3a2b527c9..b010e8687 100644 --- a/engine/cli/commands/engine_list_cmd.cc +++ b/engine/cli/commands/engine_list_cmd.cc @@ -1,8 +1,8 @@ #include "engine_list_cmd.h" #include #include +#include "common/engine_servicei.h" #include "server_start_cmd.h" -#include "services/engine_service.h" #include "utils/curl_utils.h" #include "utils/logging_utils.h" #include "utils/url_parser.h" diff --git a/engine/cli/commands/hardware_activate_cmd.cc b/engine/cli/commands/hardware_activate_cmd.cc index a0f34e4b7..77d600233 100644 --- a/engine/cli/commands/hardware_activate_cmd.cc +++ b/engine/cli/commands/hardware_activate_cmd.cc @@ -36,7 +36,6 @@ bool HardwareActivateCmd::Exec( } } - // TODO(sang) should use curl but it does not work (?) Json::Value body; Json::Value gpus_json = Json::arrayValue; std::vector gpus; @@ -51,36 +50,30 @@ bool HardwareActivateCmd::Exec( body["gpus"] = gpus_json; auto data_str = body.toStyledString(); - httplib::Client cli(host + ":" + std::to_string(port)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "hardware", "activate"}, + }; - auto res = cli.Post("/v1/hardware/activate", httplib::Headers(), - data_str.data(), data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - auto root = json_helper::ParseJsonString(res->body); - if (!root["warning"].isNull()) { - CLI_LOG(root["warning"].asString()); - } - if(body["gpus"].empty()) { - CLI_LOG("Deactivated all GPUs!"); - } else { - std::string gpus_str; - for(auto i: gpus) { - gpus_str += " " + std::to_string(i); - } - CLI_LOG("Activated GPUs:" << gpus_str); - } - return true; - } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return false; - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); return false; } + if (!res.value()["warning"].isNull()) { + CLI_LOG(res.value()["warning"].asString()); + } + if (body["gpus"].empty()) { + CLI_LOG("Deactivated all GPUs!"); + } else { + std::string gpus_str; + for (auto i : gpus) { + gpus_str += " " + std::to_string(i); + } + CLI_LOG("Activated GPUs:" << gpus_str); + } return true; } -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/hardware_activate_cmd.h b/engine/cli/commands/hardware_activate_cmd.h index eb5b68cc3..82676ca99 100644 --- a/engine/cli/commands/hardware_activate_cmd.h +++ b/engine/cli/commands/hardware_activate_cmd.h @@ -1,7 +1,7 @@ #pragma once + #include #include -#include "common/hardware_config.h" namespace commands { class HardwareActivateCmd { @@ -9,4 +9,4 @@ class HardwareActivateCmd { bool Exec(const std::string& host, int port, const std::unordered_map& options); }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/hardware_list_cmd.cc b/engine/cli/commands/hardware_list_cmd.cc index a800b0e24..6d6cccbc3 100644 --- a/engine/cli/commands/hardware_list_cmd.cc +++ b/engine/cli/commands/hardware_list_cmd.cc @@ -1,21 +1,12 @@ #include "hardware_list_cmd.h" - #include #include #include - #include -#include "httplib.h" #include "server_start_cmd.h" +#include "services/hardware_service.h" #include "utils/curl_utils.h" -#include "utils/hardware/cpu_info.h" -#include "utils/hardware/gpu_info.h" -#include "utils/hardware/os_info.h" -#include "utils/hardware/power_info.h" -#include "utils/hardware/ram_info.h" -#include "utils/hardware/storage_info.h" #include "utils/logging_utils.h" -#include "utils/string_utils.h" // clang-format off #include // clang-format on @@ -186,4 +177,4 @@ bool HardwareListCmd::Exec(const std::string& host, int port, return true; } -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/model_del_cmd.cc b/engine/cli/commands/model_del_cmd.cc index d78fcc921..2f46aa52a 100644 --- a/engine/cli/commands/model_del_cmd.cc +++ b/engine/cli/commands/model_del_cmd.cc @@ -1,7 +1,8 @@ #include "model_del_cmd.h" -#include "httplib.h" #include "server_start_cmd.h" +#include "utils/curl_utils.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { @@ -16,18 +17,17 @@ void ModelDelCmd::Exec(const std::string& host, int port, } } - // Call API to delete model - httplib::Client cli(host + ":" + std::to_string(port)); - auto res = cli.Delete("/v1/models/" + model_handle); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CLI_LOG("Model " + model_handle + " deleted successfully"); - } else { - CTL_ERR("Model failed to delete with status code: " << res->status); - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", model_handle}, + }; + + auto res = curl_utils::SimpleDeleteJson(url.ToFullPath()); + if (res.has_error()) { + CLI_LOG("Failed to delete model: " << res.error()); + return; } + CLI_LOG("Model " + model_handle + " deleted successfully"); } } // namespace commands diff --git a/engine/cli/commands/model_get_cmd.cc b/engine/cli/commands/model_get_cmd.cc index 2c7c294e3..c4a400136 100644 --- a/engine/cli/commands/model_get_cmd.cc +++ b/engine/cli/commands/model_get_cmd.cc @@ -1,8 +1,9 @@ #include "model_get_cmd.h" -#include "httplib.h" #include "server_start_cmd.h" +#include "utils/curl_utils.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { @@ -17,20 +18,19 @@ void ModelGetCmd::Exec(const std::string& host, int port, } } - // Call API to delete model - httplib::Client cli(host + ":" + std::to_string(port)); - auto res = cli.Get("/v1/models/" + model_handle); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CLI_LOG(res->body); - } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", model_handle}, + }; + + auto res = curl_utils::SimpleGetJson(url.ToFullPath()); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return; } -} + CLI_LOG(res.value().toStyledString()); +} } // namespace commands diff --git a/engine/cli/commands/model_import_cmd.cc b/engine/cli/commands/model_import_cmd.cc index f8cf6a810..fbc01be7d 100644 --- a/engine/cli/commands/model_import_cmd.cc +++ b/engine/cli/commands/model_import_cmd.cc @@ -1,8 +1,10 @@ #include "model_import_cmd.h" #include -#include "httplib.h" #include "server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { @@ -18,23 +20,25 @@ void ModelImportCmd::Exec(const std::string& host, int port, } } - httplib::Client cli(host + ":" + std::to_string(port)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "import"}, + }; + Json::Value json_data; json_data["model"] = model_handle; json_data["modelPath"] = model_path; auto data_str = json_data.toStyledString(); - auto res = cli.Post("/v1/models/import", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CLI_LOG("Successfully import model from '" + model_path + - "' for modeID '" + model_handle + "'."); - } else { - CTL_ERR("Model failed to import model with status code: " << res->status); - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return; } + + CLI_LOG("Successfully import model from '" + model_path + "' for modelID '" + + model_handle + "'."); } } // namespace commands diff --git a/engine/cli/commands/model_import_cmd.h b/engine/cli/commands/model_import_cmd.h index 141351909..685e8d5fc 100644 --- a/engine/cli/commands/model_import_cmd.h +++ b/engine/cli/commands/model_import_cmd.h @@ -1,6 +1,7 @@ #pragma once #include + namespace commands { class ModelImportCmd { @@ -8,4 +9,4 @@ class ModelImportCmd { void Exec(const std::string& host, int port, const std::string& model_handle, const std::string& model_path); }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index 41fe61d1c..c63ed0012 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -2,12 +2,13 @@ #include #include #include - #include -#include "httplib.h" #include "server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" +#include "utils/url_parser.h" // clang-format off #include // clang-format on @@ -44,43 +45,40 @@ void ModelListCmd::Exec(const std::string& host, int port, int count = 0; // Iterate through directory - httplib::Client cli(host + ":" + std::to_string(port)); - auto res = cli.Get("/v1/models"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - Json::Value body; - Json::Reader reader; - reader.parse(res->body, body); - if (!body["data"].isNull()) { - for (auto const& v : body["data"]) { - auto model_id = v["model"].asString(); - if (!filter.empty() && - !string_utils::StringContainsIgnoreCase(model_id, filter)) { - continue; - } + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models"}, + }; - count += 1; + auto res = curl_utils::SimpleGetJson(url.ToFullPath()); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return; + } + + if (!res.value()["data"].isNull()) { + for (auto const& v : res.value()["data"]) { + auto model_id = v["model"].asString(); + if (!filter.empty() && + !string_utils::StringContainsIgnoreCase(model_id, filter)) { + continue; + } - std::vector row = {std::to_string(count), - v["model"].asString()}; - if (display_engine) { - row.push_back(v["engine"].asString()); - } - if (display_version) { - row.push_back(v["version"].asString()); - } + count += 1; - table.add_row({row.begin(), row.end()}); - } + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); } - } else { - CTL_ERR("Failed to get model list with status code: " << res->status); - return; + if (display_version) { + row.push_back(v["version"].asString()); + } + + table.add_row({row.begin(), row.end()}); } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); - return; } std::cout << table << std::endl; diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 376943fd1..75c0ce1a0 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -1,10 +1,13 @@ #include "model_pull_cmd.h" +#include #include "server_start_cmd.h" #include "utils/cli_selection_utils.h" +#include "utils/curl_utils.h" #include "utils/download_progress.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" #include "utils/scope_exit.h" +#include "utils/url_parser.h" #if defined(_WIN32) #include #endif @@ -33,65 +36,57 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, } } - // Get model info from Server - httplib::Client cli(host + ":" + std::to_string(port)); - cli.set_read_timeout(std::chrono::seconds(60)); + auto model_info_url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"models", "pull", "info"}, + }; Json::Value j_data; j_data["model"] = input; auto d_str = j_data.toStyledString(); - auto res = cli.Post("/models/pull/info", httplib::Headers(), d_str.data(), - d_str.size(), "application/json"); - - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - // CLI_LOG(res->body); - auto root = json_helper::ParseJsonString(res->body); - auto id = root["id"].asString(); - bool is_cortexso = root["modelSource"].asString() == "cortexso"; - auto default_branch = root["defaultBranch"].asString(); - std::vector downloaded; - for (auto const& v : root["downloadedModels"]) { - downloaded.push_back(v.asString()); - } - std::vector avails; - for (auto const& v : root["availableModels"]) { - avails.push_back(v.asString()); - } - auto download_url = root["downloadUrl"].asString(); - - if (downloaded.empty() && avails.empty()) { - model_id = id; - model = download_url; - } else { - if (is_cortexso) { - auto selection = cli_selection_utils::PrintModelSelection( - downloaded, avails, - default_branch.empty() - ? std::nullopt - : std::optional(default_branch)); - - if (!selection.has_value()) { - CLI_LOG("Invalid selection"); - return std::nullopt; - } - model_id = selection.value(); - model = model_id; - } else { - auto selection = cli_selection_utils::PrintSelection(avails); - CLI_LOG("Selected: " << selection.value()); - model_id = id + ":" + selection.value(); - model = download_url + selection.value(); - } + auto res = curl_utils::SimplePostJson(model_info_url.ToFullPath(), d_str); + + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return std::nullopt; + } + + auto id = res.value()["id"].asString(); + bool is_cortexso = res.value()["modelSource"].asString() == "cortexso"; + auto default_branch = res.value()["defaultBranch"].asString(); + std::vector downloaded; + for (auto const& v : res.value()["downloadedModels"]) { + downloaded.push_back(v.asString()); + } + std::vector avails; + for (auto const& v : res.value()["availableModels"]) { + avails.push_back(v.asString()); + } + auto download_url = res.value()["downloadUrl"].asString(); + + if (downloaded.empty() && avails.empty()) { + model_id = id; + model = download_url; + } else { + if (is_cortexso) { + auto selection = cli_selection_utils::PrintModelSelection( + downloaded, avails, + default_branch.empty() ? std::nullopt + : std::optional(default_branch)); + + if (!selection.has_value()) { + CLI_LOG("Invalid selection"); + return std::nullopt; } + model_id = selection.value(); + model = model_id; } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return std::nullopt; + auto selection = cli_selection_utils::PrintSelection(avails); + CLI_LOG("Selected: " << selection.value()); + model_id = id + ":" + selection.value(); + model = download_url + selection.value(); } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); - return std::nullopt; } CTL_INF("model: " << model << ", model_id: " << model_id); @@ -99,19 +94,18 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, Json::Value json_data; json_data["model"] = model; auto data_str = json_data.toStyledString(); - cli.set_read_timeout(std::chrono::seconds(60)); - res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); - - if (res) { - if (res->status != httplib::StatusCode::OK_200) { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return std::nullopt; - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + + auto pull_url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "pull"}, + }; + + auto pull_result = + curl_utils::SimplePostJson(pull_url.ToFullPath(), data_str); + if (pull_result.has_error()) { + auto root = json_helper::ParseJsonString(pull_result.error()); + CLI_LOG(root["message"].asString()); return std::nullopt; } @@ -154,23 +148,19 @@ bool ModelPullCmd::AbortModelPull(const std::string& host, int port, Json::Value json_data; json_data["taskId"] = task_id; auto data_str = json_data.toStyledString(); - httplib::Client cli(host + ":" + std::to_string(port)); - cli.set_read_timeout(std::chrono::seconds(60)); - auto res = cli.Delete("/v1/models/pull", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CTL_INF("Abort model pull successfully: " << task_id); - return true; - } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return false; - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "pull"}, + }; + auto res = curl_utils::SimpleDeleteJson(url.ToFullPath(), data_str); + + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); return false; } + CTL_INF("Abort model pull successfully: " << task_id); + return true; } }; // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.h b/engine/cli/commands/model_pull_cmd.h index d05759dbc..022da9c84 100644 --- a/engine/cli/commands/model_pull_cmd.h +++ b/engine/cli/commands/model_pull_cmd.h @@ -1,23 +1,17 @@ #pragma once -#include "services/model_service.h" +#include +#include namespace commands { class ModelPullCmd { public: - explicit ModelPullCmd(std::shared_ptr download_service) - : model_service_{ModelService(download_service)} {}; - explicit ModelPullCmd(const ModelService& model_service) - : model_service_{model_service} {}; std::optional Exec(const std::string& host, int port, const std::string& input); private: bool AbortModelPull(const std::string& host, int port, const std::string& task_id); - - private: - ModelService model_service_; }; } // namespace commands diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index 9b2f9d4b3..ea6b81e5a 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -1,7 +1,6 @@ #include "model_start_cmd.h" #include "cortex_upd_cmd.h" #include "hardware_activate_cmd.h" -#include "httplib.h" #include "run_cmd.h" #include "server_start_cmd.h" #include "utils/cli_selection_utils.h" @@ -14,7 +13,7 @@ bool ModelStartCmd::Exec( const std::unordered_map& options, bool print_success_log) { std::optional model_id = - SelectLocalModel(host, port, model_service_, model_handle); + SelectLocalModel(host, port, model_handle); if (!model_id.has_value()) { return false; @@ -46,41 +45,34 @@ bool ModelStartCmd::Exec( while (count--) { std::this_thread::sleep_for(std::chrono::milliseconds(500)); if (commands::IsServerAlive(host, port)) - break; + break; } } - // Call API to start model - httplib::Client cli(host + ":" + std::to_string(port)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "start"}, + }; + Json::Value json_data; json_data["model"] = model_id.value(); auto data_str = json_data.toStyledString(); - cli.set_read_timeout(std::chrono::seconds(60)); - auto res = cli.Post("/v1/models/start", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - if (print_success_log) { - CLI_LOG(model_id.value() - << " model started successfully. Use `" - << commands::GetCortexBinary() << " run " << *model_id - << "` for interactive chat shell"); - } - auto root = json_helper::ParseJsonString(res->body); - if (!root["warning"].isNull()) { - CLI_LOG(root["warning"].asString()); - } - return true; - } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return false; - } - } else { - auto err = res.error(); - CLI_LOG("HTTP error: " << httplib::to_string(err)); + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); return false; } -} + if (print_success_log) { + CLI_LOG(model_id.value() << " model started successfully. Use `" + << commands::GetCortexBinary() << " run " + << *model_id << "` for interactive chat shell"); + } + if (!res.value()["warning"].isNull()) { + CLI_LOG(res.value()["warning"].asString()); + } + return true; +} }; // namespace commands diff --git a/engine/cli/commands/model_start_cmd.h b/engine/cli/commands/model_start_cmd.h index 652d37994..519db0f0d 100644 --- a/engine/cli/commands/model_start_cmd.h +++ b/engine/cli/commands/model_start_cmd.h @@ -1,20 +1,14 @@ #pragma once + #include #include -#include "services/model_service.h" namespace commands { class ModelStartCmd { public: - explicit ModelStartCmd(const ModelService& model_service) - : model_service_{model_service} {}; - bool Exec(const std::string& host, int port, const std::string& model_handle, const std::unordered_map& options, bool print_success_log = true); - - private: - ModelService model_service_; }; } // namespace commands diff --git a/engine/cli/commands/model_status_cmd.cc b/engine/cli/commands/model_status_cmd.cc index 6677fe0af..cd9f3034d 100644 --- a/engine/cli/commands/model_status_cmd.cc +++ b/engine/cli/commands/model_status_cmd.cc @@ -1,7 +1,9 @@ #include "model_status_cmd.h" -#include "httplib.h" #include "server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { bool ModelStatusCmd::IsLoaded(const std::string& host, int port, @@ -14,22 +16,20 @@ bool ModelStatusCmd::IsLoaded(const std::string& host, int port, return false; } } + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "status", model_handle}, + }; - // Call API to check model status - httplib::Client cli(host + ":" + std::to_string(port)); - auto res = cli.Get("/v1/models/status/" + model_handle); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CTL_INF(res->body); - } else { - CTL_WRN("Failed to get model status with code: " << res->status); - return false; - } - } else { - auto err = res.error(); - CTL_WRN("HTTP error: " << httplib::to_string(err)); + auto res = curl_utils::SimpleGetJson(url.ToFullPath()); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); return false; } + + CTL_INF(res.value().toStyledString()); return true; } } // namespace commands diff --git a/engine/cli/commands/model_status_cmd.h b/engine/cli/commands/model_status_cmd.h index 3bf1cb115..de094e748 100644 --- a/engine/cli/commands/model_status_cmd.h +++ b/engine/cli/commands/model_status_cmd.h @@ -1,18 +1,12 @@ #pragma once + #include -#include "services/model_service.h" namespace commands { class ModelStatusCmd { public: - explicit ModelStatusCmd(const ModelService& model_service) - : model_service_{model_service} {}; - bool IsLoaded(const std::string& host, int port, const std::string& model_handle); - - private: - ModelService model_service_; }; } // namespace commands diff --git a/engine/cli/commands/model_stop_cmd.cc b/engine/cli/commands/model_stop_cmd.cc index 9a14b0876..291977dc7 100644 --- a/engine/cli/commands/model_stop_cmd.cc +++ b/engine/cli/commands/model_stop_cmd.cc @@ -1,30 +1,29 @@ #include "model_stop_cmd.h" -#include "httplib.h" +#include +#include "utils/curl_utils.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { void ModelStopCmd::Exec(const std::string& host, int port, const std::string& model_handle) { - // Call API to stop model - httplib::Client cli(host + ":" + std::to_string(port)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "stop"}, + }; + Json::Value json_data; json_data["model"] = model_handle; auto data_str = json_data.toStyledString(); - auto res = cli.Post("/v1/models/stop", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CLI_LOG("Model unloaded!"); - } else { - auto root = json_helper::ParseJsonString(res->body); - CLI_LOG(root["message"].asString()); - return; - } - } else { - auto err = res.error(); - CLI_LOG("HTTP error: " << httplib::to_string(err)); + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + + if (res.has_error()) { + CLI_LOG_ERROR("Failed to stop model: " << res.error()); + return; } -} + CLI_LOG("Model stopped!"); +} }; // namespace commands diff --git a/engine/cli/commands/model_stop_cmd.h b/engine/cli/commands/model_stop_cmd.h index f341e70d2..f437f5000 100644 --- a/engine/cli/commands/model_stop_cmd.h +++ b/engine/cli/commands/model_stop_cmd.h @@ -1,18 +1,11 @@ #pragma once #include -#include "services/model_service.h" namespace commands { class ModelStopCmd { public: - explicit ModelStopCmd(const ModelService& model_service) - : model_service_{model_service} {}; - void Exec(const std::string& host, int port, const std::string& model_handle); - - private: - ModelService model_service_; }; } // namespace commands diff --git a/engine/cli/commands/model_upd_cmd.cc b/engine/cli/commands/model_upd_cmd.cc index af37efd5f..6534d1fbd 100644 --- a/engine/cli/commands/model_upd_cmd.cc +++ b/engine/cli/commands/model_upd_cmd.cc @@ -1,9 +1,9 @@ #include "model_upd_cmd.h" -#include "httplib.h" - #include "server_start_cmd.h" -#include "utils/file_manager_utils.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { @@ -22,7 +22,12 @@ void ModelUpdCmd::Exec( } } - httplib::Client cli(host + ":" + std::to_string(port)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", model_handle_}, + }; + Json::Value json_data; for (const auto& [key, value] : options) { if (!value.empty()) { @@ -30,21 +35,15 @@ void ModelUpdCmd::Exec( } } auto data_str = json_data.toStyledString(); - auto res = cli.Patch("/v1/models/" + model_handle_, httplib::Headers(), - data_str.data(), data_str.size(), "application/json"); - if (res) { - if (res->status == httplib::StatusCode::OK_200) { - CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); - return; - } else { - CTL_ERR("Model failed to update with status code: " << res->status); - return; - } - } else { - auto err = res.error(); - CTL_ERR("HTTP error: " << httplib::to_string(err)); + auto res = curl_utils::SimplePatchJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); return; } + + CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); + return; } void ModelUpdCmd::UpdateConfig(Json::Value& data, const std::string& key, @@ -335,4 +334,4 @@ void ModelUpdCmd::UpdateBooleanField(const std::string& key, bool boolValue = (value == "true" || value == "1"); setter(boolValue); } -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/model_upd_cmd.h b/engine/cli/commands/model_upd_cmd.h index f2eaa8675..0a78c3eae 100644 --- a/engine/cli/commands/model_upd_cmd.h +++ b/engine/cli/commands/model_upd_cmd.h @@ -1,11 +1,11 @@ #pragma once -#include -#include + +#include +#include #include #include #include -#include -#include "json/json.h" + namespace commands { class ModelUpdCmd { public: @@ -28,4 +28,4 @@ class ModelUpdCmd { private: std::string model_handle_; }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/cli/commands/ps_cmd.cc b/engine/cli/commands/ps_cmd.cc index ca891dab4..c692ffc00 100644 --- a/engine/cli/commands/ps_cmd.cc +++ b/engine/cli/commands/ps_cmd.cc @@ -1,29 +1,29 @@ #include "ps_cmd.h" -#include #include #include -#include "utils/engine_constants.h" +#include "utils/curl_utils.h" #include "utils/format_utils.h" -#include "utils/json_helper.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" +#include "utils/url_parser.h" namespace commands { void PsCmd::Exec(const std::string& host, int port) { - auto host_and_port{host + ":" + std::to_string(port)}; - httplib::Client cli(host_and_port); - - auto res = cli.Get("/inferences/server/models"); - if (!res || res->status != httplib::StatusCode::OK_200) { + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"inferences", "server", "models"}, + }; + auto res = curl_utils::SimpleGetJson(url.ToFullPath()); + if (res.has_error()) { CLI_LOG("No models loaded!"); return; } - auto data = json_helper::ParseJsonString(res->body)["data"]; std::vector model_status_list; try { - for (const auto& item : data) { + for (const auto& item : res.value()["data"]) { ModelLoadedStatus model_status; // TODO(sang) hardcode for now model_status.engine = kLlamaEngine; diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 279128552..1b71f1af7 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -14,7 +14,6 @@ namespace commands { std::optional SelectLocalModel(std::string host, int port, - ModelService& model_service, const std::string& model_handle) { std::optional model_id = model_handle; cortex::db::Models modellist_handler; @@ -45,7 +44,7 @@ std::optional SelectLocalModel(std::string host, int port, } else { auto related_models_ids = modellist_handler.FindRelatedModel(model_handle); if (related_models_ids.has_error() || related_models_ids.value().empty()) { - auto result = ModelPullCmd(model_service).Exec(host, port, model_handle); + auto result = ModelPullCmd().Exec(host, port, model_handle); if (!result) { CLI_LOG("Model " << model_handle << " not found!"); return std::nullopt; @@ -70,7 +69,7 @@ std::optional SelectLocalModel(std::string host, int port, void RunCmd::Exec(bool run_detach, const std::unordered_map& options) { std::optional model_id = - SelectLocalModel(host_, port_, model_service_, model_handle_); + SelectLocalModel(host_, port_, model_handle_); if (!model_id.has_value()) { return; } @@ -127,10 +126,9 @@ void RunCmd::Exec(bool run_detach, { if ((mc.engine.find(kLlamaRepo) == std::string::npos && mc.engine.find(kLlamaEngine) == std::string::npos) || - !commands::ModelStatusCmd(model_service_) - .IsLoaded(host_, port_, *model_id)) { + !commands::ModelStatusCmd().IsLoaded(host_, port_, *model_id)) { - auto res = commands::ModelStartCmd(model_service_) + auto res = commands::ModelStartCmd() .Exec(host_, port_, *model_id, options, false /*print_success_log*/); if (!res) { @@ -146,7 +144,7 @@ void RunCmd::Exec(bool run_detach, << commands::GetCortexBinary() << " run " << *model_id << "` for interactive chat shell"); } else { - ChatCompletionCmd(model_service_).Exec(host_, port_, *model_id, mc, ""); + ChatCompletionCmd().Exec(host_, port_, *model_id, mc, ""); } } } catch (const std::exception& e) { diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index 6e524c6b1..c0f6a4eb2 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -3,12 +3,10 @@ #include #include #include "services/engine_service.h" -#include "services/model_service.h" namespace commands { std::optional SelectLocalModel(std::string host, int port, - ModelService& model_service, const std::string& model_handle); class RunCmd { @@ -19,8 +17,7 @@ class RunCmd { port_{port}, model_handle_{std::move(model_handle)}, download_service_(download_service), - engine_service_{EngineService(download_service)}, - model_service_{ModelService(download_service)} {}; + engine_service_{EngineService(download_service)} {}; void Exec(bool chat_flag, const std::unordered_map& options); @@ -31,7 +28,6 @@ class RunCmd { std::string model_handle_; std::shared_ptr download_service_; - ModelService model_service_; EngineService engine_service_; }; } // namespace commands diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 5ba972463..6f36515f1 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -25,8 +25,6 @@ bool TryConnectToServer(const std::string& host, int port) { } } // namespace -ServerStartCmd::ServerStartCmd() {} - bool ServerStartCmd::Exec(const std::string& host, int port, const std::optional& log_level) { std::string log_level_; diff --git a/engine/cli/commands/server_start_cmd.h b/engine/cli/commands/server_start_cmd.h index 780123172..f3880532e 100644 --- a/engine/cli/commands/server_start_cmd.h +++ b/engine/cli/commands/server_start_cmd.h @@ -1,22 +1,30 @@ #pragma once -#include -#include "httplib.h" #include +#include +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/url_parser.h" + namespace commands { inline bool IsServerAlive(const std::string& host, int port) { - httplib::Client cli(host + ":" + std::to_string(port)); - auto res = cli.Get("/healthz"); - if (res && res->status == httplib::StatusCode::OK_200) { - return true; + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"healthz"}, + }; + auto res = curl_utils::SimpleGet(url.ToFullPath()); + if (res.has_error()) { + CTL_WRN("Server is not alive: " << res.error()); + return false; } - return false; + return true; } class ServerStartCmd { public: - ServerStartCmd(); - bool Exec(const std::string& host, int port, const std::optional& log_level = std::nullopt); + bool Exec(const std::string& host, int port, + const std::optional& log_level = std::nullopt); }; } // namespace commands diff --git a/engine/cli/commands/server_stop_cmd.cc b/engine/cli/commands/server_stop_cmd.cc index e55446923..303022174 100644 --- a/engine/cli/commands/server_stop_cmd.cc +++ b/engine/cli/commands/server_stop_cmd.cc @@ -1,20 +1,25 @@ #include "server_stop_cmd.h" -#include "httplib.h" +#include "utils/curl_utils.h" #include "utils/logging_utils.h" +#include "utils/url_parser.h" namespace commands { ServerStopCmd::ServerStopCmd(std::string host, int port) : host_(std::move(host)), port_(port) {} void ServerStopCmd::Exec() { - httplib::Client cli(host_ + ":" + std::to_string(port_)); - auto res = cli.Delete("/processManager/destroy"); - if (res) { - CLI_LOG("Server stopped!"); - } else { - auto err = res.error(); - CLI_LOG_ERROR("HTTP error: " << httplib::to_string(err)); + auto url = url_parser::Url{ + .protocol = "http", + .host = host_ + ":" + std::to_string(port_), + .pathParams = {"processManager", "destroy"}, + }; + + auto res = curl_utils::SimpleDeleteJson(url.ToFullPath()); + if (res.has_error()) { + CLI_LOG_ERROR("Failed to stop server: " << res.error()); + return; } -} + CLI_LOG("Server stopped!"); +} }; // namespace commands diff --git a/engine/cli/main.cc b/engine/cli/main.cc index a03c5adf0..52fc5591f 100644 --- a/engine/cli/main.cc +++ b/engine/cli/main.cc @@ -88,6 +88,8 @@ int main(int argc, char* argv[]) { return 1; } + curl_global_init(CURL_GLOBAL_DEFAULT); + bool should_install_server = false; bool verbose = false; for (int i = 0; i < argc; i++) { diff --git a/engine/controllers/server.h b/engine/controllers/server.h index 2889e7ed1..5d6b8ded4 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -9,11 +9,6 @@ #include -#ifndef NDEBUG -// crash the server in debug mode, otherwise send an http 500 error -#define CPPHTTPLIB_NO_EXCEPTIONS 1 -#endif - #include #include #include "common/base.h" diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py index add2354f3..9ef2970f9 100644 --- a/engine/e2e-test/main.py +++ b/engine/e2e-test/main.py @@ -8,8 +8,7 @@ ### models, keeps in order, note that we only uninstall engine after finishing all models test from test_api_model_pull_direct_url import TestApiModelPullDirectUrl -from test_api_model_start import TestApiModelStart -from test_api_model_stop import TestApiModelStop +from test_api_model_start_stop import TestApiModelStartStop from test_api_model_get import TestApiModelGet from test_api_model_list import TestApiModelList from test_api_model_update import TestApiModelUpdate diff --git a/engine/e2e-test/test_api_docker.py b/engine/e2e-test/test_api_docker.py index 2f06e6edb..6856e05f4 100644 --- a/engine/e2e-test/test_api_docker.py +++ b/engine/e2e-test/test_api_docker.py @@ -1,18 +1,14 @@ import pytest import requests -import os - -from pathlib import Path -from test_runner import ( - wait_for_websocket_download_success_event -) +from test_runner import wait_for_websocket_download_success_event repo_branches = ["tinyllama:1b-gguf"] + class TestCortexsoModels: @pytest.fixture(autouse=True) - def setup_and_teardown(self, request): + def setup_and_teardown(self): yield @pytest.mark.parametrize("model_url", repo_branches) @@ -20,20 +16,20 @@ def setup_and_teardown(self, request): async def test_models_on_cortexso_hub(self, model_url): print("Pull model from cortexso hub") # Pull model from cortexso hub - json_body = { - "model": model_url - } + json_body = {"model": model_url} response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) assert response.status_code == 200, f"Failed to pull model: {model_url}" - + await wait_for_websocket_download_success_event(timeout=None) - + print("Check if the model was pulled successfully") # Check if the model was pulled successfully get_model_response = requests.get( f"http://127.0.0.1:3928/v1/models/{model_url}" ) - assert get_model_response.status_code == 200, f"Failed to fetch model: {model_url}" + assert ( + get_model_response.status_code == 200 + ), f"Failed to fetch model: {model_url}" assert ( get_model_response.json()["model"] == model_url ), f"Unexpected model name for: {model_url}" @@ -47,7 +43,10 @@ async def test_models_on_cortexso_hub(self, model_url): print("Start the model") # Start the model - response = requests.post("http://localhost:3928/v1/models/start", json=json_body) + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) + print(response.json()) assert response.status_code == 200, f"status_code: {response.status_code}" print("Send an inference request") @@ -55,26 +54,24 @@ async def test_models_on_cortexso_hub(self, model_url): inference_json_body = { "frequency_penalty": 0.2, "max_tokens": 4096, - "messages": [ - { - "content": "", - "role": "user" - } - ], + "messages": [{"content": "", "role": "user"}], "model": model_url, "presence_penalty": 0.6, - "stop": [ - "End" - ], + "stop": ["End"], "stream": False, "temperature": 0.8, - "top_p": 0.95 - } - response = requests.post("http://localhost:3928/v1/chat/completions", json=inference_json_body, headers={"Content-Type": "application/json"}) - assert response.status_code == 200, f"status_code: {response.status_code} response: {response.json()}" + "top_p": 0.95, + } + response = requests.post( + "http://localhost:3928/v1/chat/completions", + json=inference_json_body, + headers={"Content-Type": "application/json"}, + ) + assert ( + response.status_code == 200 + ), f"status_code: {response.status_code} response: {response.json()}" print("Stop the model") # Stop the model response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" - diff --git a/engine/e2e-test/test_api_engine_uninstall.py b/engine/e2e-test/test_api_engine_uninstall.py index 2a491d07a..1951e5c3a 100644 --- a/engine/e2e-test/test_api_engine_uninstall.py +++ b/engine/e2e-test/test_api_engine_uninstall.py @@ -1,9 +1,10 @@ -import pytest import time + +import pytest import requests from test_runner import ( run, - start_server, + start_server_if_needed, stop_server, wait_for_websocket_download_success_event, ) @@ -14,22 +15,20 @@ class TestApiEngineUninstall: @pytest.fixture(autouse=True) def setup_and_teardown(self): # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") + start_server_if_needed() yield # Teardown stop_server() - + @pytest.mark.asyncio async def test_engines_uninstall_llamacpp_should_be_successful(self): response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") assert response.status_code == 200 await wait_for_websocket_download_success_event(timeout=None) time.sleep(30) - + response = requests.delete("http://localhost:3928/v1/engines/llama-cpp/install") assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model_start.py b/engine/e2e-test/test_api_model_start_stop.py similarity index 74% rename from engine/e2e-test/test_api_model_start.py rename to engine/e2e-test/test_api_model_start_stop.py index b3e33d113..78c20e8da 100644 --- a/engine/e2e-test/test_api_model_start.py +++ b/engine/e2e-test/test_api_model_start_stop.py @@ -1,26 +1,28 @@ -import pytest import time + +import pytest import requests -from test_runner import run, start_server, stop_server from test_runner import ( - wait_for_websocket_download_success_event + run, + start_server_if_needed, + stop_server, + wait_for_websocket_download_success_event, ) -class TestApiModelStart: + + +class TestApiModelStartStop: @pytest.fixture(autouse=True) def setup_and_teardown(self): # Setup - stop_server() - success = start_server() - if not success: - raise Exception("Failed to start server") + start_server_if_needed() run("Delete model", ["models", "delete", "tinyllama:gguf"]) yield # Teardown stop_server() - + @pytest.mark.asyncio async def test_models_start_should_be_successful(self): response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") @@ -28,16 +30,17 @@ async def test_models_start_should_be_successful(self): await wait_for_websocket_download_success_event(timeout=None) # TODO(sang) need to fix for cuda download time.sleep(30) - - json_body = { - "model": "tinyllama:gguf" - } + + json_body = {"model": "tinyllama:gguf"} response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) assert response.status_code == 200, f"Failed to pull model: tinyllama:gguf" await wait_for_websocket_download_success_event(timeout=None) - + json_body = {"model": "tinyllama:gguf"} response = requests.post( "http://localhost:3928/v1/models/start", json=json_body ) assert response.status_code == 200, f"status_code: {response.status_code}" + + response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) + assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_api_model_stop.py b/engine/e2e-test/test_api_model_stop.py deleted file mode 100644 index 4fc7a55e2..000000000 --- a/engine/e2e-test/test_api_model_stop.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -import time -import requests -from test_runner import run, start_server, stop_server -from test_runner import ( - wait_for_websocket_download_success_event -) - -class TestApiModelStop: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - stop_server() - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - run("Uninstall engine", ["engines", "uninstall", "llama-cpp"]) - # Teardown - stop_server() - - @pytest.mark.asyncio - async def test_models_stop_should_be_successful(self): - response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") - assert response.status_code == 200 - await wait_for_websocket_download_success_event(timeout=None) - time.sleep(30) - - json_body = {"model": "tinyllama:gguf"} - response = requests.post( - "http://localhost:3928/v1/models/start", json=json_body - ) - assert response.status_code == 200, f"status_code: {response.status_code}" - response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) - assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_cli_model_delete.py b/engine/e2e-test/test_cli_model_delete.py index d0ba43ec1..06cc3a4c3 100644 --- a/engine/e2e-test/test_cli_model_delete.py +++ b/engine/e2e-test/test_cli_model_delete.py @@ -1,11 +1,13 @@ import pytest import requests -from test_runner import popen, run -from test_runner import start_server, stop_server from test_runner import ( - wait_for_websocket_download_success_event + run, + start_server, + stop_server, + wait_for_websocket_download_success_event, ) + class TestCliModelDelete: @pytest.fixture(autouse=True) @@ -22,15 +24,13 @@ def setup_and_teardown(self): run("Delete model", ["models", "delete", "tinyllama:gguf"]) stop_server() - @pytest.mark.asyncio + @pytest.mark.asyncio async def test_models_delete_should_be_successful(self): - json_body = { - "model": "tinyllama:gguf" - } + json_body = {"model": "tinyllama:gguf"} response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) assert response.status_code == 200, f"Failed to pull model: tinyllama:gguf" await wait_for_websocket_download_success_event(timeout=None) - + exit_code, output, error = run( "Delete model", ["models", "delete", "tinyllama:gguf"] ) diff --git a/engine/e2e-test/test_runner.py b/engine/e2e-test/test_runner.py index 843e669b4..dfc515df7 100644 --- a/engine/e2e-test/test_runner.py +++ b/engine/e2e-test/test_runner.py @@ -7,6 +7,7 @@ import threading import time import requests +from requests.exceptions import RequestException from typing import List import websockets @@ -72,6 +73,42 @@ def start_server() -> bool: return start_server_nix() +def start_server_if_needed(): + """ + Start the server if it is not already running. + Sending a healthz request to the server to check if it is running. + """ + try: + response = requests.get( + 'http://localhost:3928/healthz', + timeout=5 + ) + if response.status_code == 200: + print("Server is already running") + except RequestException as e: + print("Server is not running. Starting the server...") + start_server() + + +def pull_model_if_needed(model_id: str = "tinyllama:gguf"): + """ + Pull the model if it is not already pulled. + """ + should_pull = False + try: + response = requests.get("http://localhost:3928/models/" + model_id, + timeout=5 + ) + if response.status_code != 200: + should_pull = True + + except RequestException as e: + print("Http error occurred: " + e) + + if should_pull: + run("Pull model", ["pull", model_id], timeout=10*60) + + def start_server_nix() -> bool: executable = getExecutablePath() process = subprocess.Popen( diff --git a/engine/main.cc b/engine/main.cc index b39c4c6e2..7c37e27fe 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -219,6 +219,8 @@ int main(int argc, char* argv[]) { return 1; } + curl_global_init(CURL_GLOBAL_DEFAULT); + // avoid printing logs to terminal is_server = true; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 1ec1a68cf..a37cea12c 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -1,4 +1,6 @@ #include "model_service.h" +#include +#include #include #include #include @@ -7,7 +9,6 @@ #include "config/yaml_config.h" #include "database/models.h" #include "hardware_service.h" -#include "httplib.h" #include "utils/cli_selection_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" @@ -79,7 +80,8 @@ cpp::result GetDownloadTask( url_parser::Url url = { .protocol = "https", .host = kHuggingFaceHost, - .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}}; + .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}, + }; auto result = curl_utils::SimpleGetJson(url.ToFullPath()); if (result.has_error()) { @@ -812,9 +814,9 @@ cpp::result ModelService::StartModel( inference_svc_->LoadModel(std::make_shared(json_data)); auto status = std::get<0>(ir)["status_code"].asInt(); auto data = std::get<1>(ir); - if (status == httplib::StatusCode::OK_200) { + if (status == drogon::k200OK) { return StartModelResult{.success = true, .warning = warning}; - } else if (status == httplib::StatusCode::Conflict_409) { + } else if (status == drogon::k409Conflict) { CTL_INF("Model '" + model_handle + "' is already loaded"); return StartModelResult{.success = true, .warning = warning}; } else { @@ -859,7 +861,7 @@ cpp::result ModelService::StopModel( auto ir = inference_svc_->UnloadModel(engine_name, model_handle); auto status = std::get<0>(ir)["status_code"].asInt(); auto data = std::get<1>(ir); - if (status == httplib::StatusCode::OK_200) { + if (status == drogon::k200OK) { if (bypass_check) { bypass_stop_check_set_.erase(model_handle); } @@ -901,7 +903,7 @@ cpp::result ModelService::GetModelStatus( inference_svc_->GetModelStatus(std::make_shared(root)); auto status = std::get<0>(ir)["status_code"].asInt(); auto data = std::get<1>(ir); - if (status == httplib::StatusCode::OK_200) { + if (status == drogon::k200OK) { return true; } else { CTL_ERR("Model failed to get model status with status code: " << status); diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index b92770a65..4a15b7c8b 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -17,7 +17,6 @@ add_executable(${PROJECT_NAME} find_package(Drogon CONFIG REQUIRED) find_package(GTest CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) -find_package(httplib CONFIG REQUIRED) find_package(unofficial-minizip CONFIG REQUIRED) find_package(LibArchive REQUIRED) find_package(CURL REQUIRED) @@ -26,7 +25,6 @@ find_package(SQLiteCpp REQUIRED) target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp ${CMAKE_THREAD_LIBS_INIT}) -target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib) target_link_libraries(${PROJECT_NAME} PRIVATE unofficial::minizip::minizip) target_link_libraries(${PROJECT_NAME} PRIVATE LibArchive::LibArchive) target_link_libraries(${PROJECT_NAME} PRIVATE CURL::libcurl) diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index 3176339a0..73c990996 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include "utils/logging_utils.h" #include "utils/result.hpp" #include "yaml-cpp/yaml.h" diff --git a/engine/utils/cortex_utils.h b/engine/utils/cortex_utils.h index 895217250..4d0a956a9 100644 --- a/engine/utils/cortex_utils.h +++ b/engine/utils/cortex_utils.h @@ -12,7 +12,6 @@ #include #include #include -#include #if defined(__linux__) #include #include @@ -120,5 +119,4 @@ inline std::string GetCurrentPath() { #endif } #endif - } // namespace cortex_utils diff --git a/engine/utils/cpuid/detail/init_linux_gcc_arm.h b/engine/utils/cpuid/detail/init_linux_gcc_arm.h index f10d360fd..cfd4059a5 100644 --- a/engine/utils/cpuid/detail/init_linux_gcc_arm.h +++ b/engine/utils/cpuid/detail/init_linux_gcc_arm.h @@ -21,7 +21,7 @@ void init_cpuinfo(CpuInfo::Impl& info) { // The Advanced SIMD (NEON) instruction set is required on AArch64 // (64-bit ARM). Note that /proc/cpuinfo will display "asimd" instead of // "neon" in the Features list on a 64-bit ARM CPU. - info.m_has_neon = true; + info.has_neon = true; #else // Runtime detection of NEON is necessary on 32-bit ARM CPUs // diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 7bfbec44c..c56808b56 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -73,8 +73,6 @@ inline std::optional> GetHeaders( inline cpp::result SimpleGet(const std::string& url, const int timeout = -1) { - // Initialize libcurl - curl_global_init(CURL_GLOBAL_DEFAULT); auto curl = curl_easy_init(); if (!curl) { diff --git a/engine/vcpkg.json b/engine/vcpkg.json index 09ddb3368..36fa322a3 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -3,10 +3,6 @@ "curl", "gtest", "cli11", - { - "name": "cpp-httplib", - "features": ["openssl"] - }, "drogon", "jsoncpp", "minizip", From 208c3f51b4d69b29b96138467333c919ae47d614 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 Date: Fri, 29 Nov 2024 12:15:09 +0700 Subject: [PATCH 02/44] chore: update structured output documentation base on new cortex.llamacpp engine --- docs/docs/guides/structured-outputs.md | 264 ++++++++++--------------- 1 file changed, 102 insertions(+), 162 deletions(-) diff --git a/docs/docs/guides/structured-outputs.md b/docs/docs/guides/structured-outputs.md index f683d7c4b..1fe3f789b 100644 --- a/docs/docs/guides/structured-outputs.md +++ b/docs/docs/guides/structured-outputs.md @@ -5,17 +5,68 @@ title: Structured Outputs Structured outputs, or response formats, are a feature designed to generate responses in a defined JSON schema, enabling more predictable and machine-readable outputs. This is essential for applications where data consistency and format adherence are crucial, such as automated data processing, structured data generation, and integrations with other systems. -In recent developments, systems like OpenAI's models have excelled at producing these structured outputs. However, while open-source models like Llama 3.1 and Mistral Nemo offer powerful capabilities, they currently struggle to produce reliably structured JSON outputs required for advanced use cases. This often stems from the models not being specifically trained on tasks demanding strict schema adherence. +In recent developments, systems like OpenAI's models have excelled at producing these structured outputs. However, while open-source models like Llama 3.1 and Mistral Nemo offer powerful capabilities, they currently struggle to produce reliably structured JSON outputs required for advanced use cases. This guide explores the concept of structured outputs using these models, highlights the challenges faced in achieving consistent output formatting, and provides strategies for improving output accuracy, particularly when using models that don't inherently support this feature as robustly as GPT models. By understanding these nuances, users can make informed decisions when choosing models for tasks requiring structured outputs, ensuring that the tools they select align with their project's formatting requirements and expected accuracy. -The Structured Outputs/Response Format feature in [OpenAI](https://platform.openai.com/docs/guides/structured-outputs) is fundamentally a prompt engineering challenge. While its goal is to use system prompts to generate JSON output matching a specific schema, popular open-source models like Llama 3.1 and Mistral Nemo struggle to consistently generate exact JSON output that matches the requirements. An easy way to directly guild the model to reponse in json format in system message: +The Structured Outputs/Response Format feature in [OpenAI](https://platform.openai.com/docs/guides/structured-outputs) is fundamentally a prompt engineering challenge. While its goal is to use system prompts to generate JSON output matching a specific schema, popular open-source models like Llama 3.1 and Mistral Nemo struggle to consistently generate exact JSON output that matches the requirements. An easy way to directly guild the model to reponse in json format in system message, you just need to pass the pydantic model to `response_format`: ``` +from pydantic import BaseModel +from openai import OpenAI +import json +ENDPOINT = "http://localhost:39281/v1" +MODEL = "llama3.1:8b-gguf-q4-km" + +client = OpenAI( + base_url=ENDPOINT, + api_key="not-needed" +) + + +class CalendarEvent(BaseModel): + name: str + date: str + participants: list[str] + + +completion = client.beta.chat.completions.parse( + model=MODEL, + messages=[ + {"role": "system", "content": "Extract the event information."}, + {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."}, + ], + response_format=CalendarEvent, + stop=["<|eot_id|>"] +) + +event = completion.choices[0].message.parsed + +print(json.dumps(event.dict(), indent=4)) +``` + +The output of the model like this + +``` +{ + "name": "science fair", + "date": "Friday", + "participants": [ + "Alice", + "Bob" + ] +} +``` + +With more complex json format, llama3.1 still struggle to response correct answer: + +``` + from openai import OpenAI from pydantic import BaseModel +import json ENDPOINT = "http://localhost:39281/v1" MODEL = "llama3.1:8b-gguf-q4-km" client = OpenAI( @@ -39,203 +90,92 @@ completion_payload = { ] } -response = client.chat.completions.create( + +class Step(BaseModel): + explanation: str + output: str + + +class MathReasoning(BaseModel): + steps: list[Step] + final_answer: str + + +response = client.beta.chat.completions.parse( top_p=0.9, temperature=0.6, model=MODEL, - messages=completion_payload["messages"] + messages=completion_payload["messages"], + stop=["<|eot_id|>"], + response_format=MathReasoning ) -print(response) +math_reasoning = response.choices[0].message.parsed +print(json.dumps(math_reasoning.dict(), indent=4)) ``` -The output of the model like this +The output of model looks like this ``` - -ChatCompletion( - id='OZI0q8hghjYQY7NXlLId', - choices=[ - Choice( - finish_reason=None, - index=0, - logprobs=None, - message=ChatCompletionMessage( - content='''Here's how you can solve it: - { "steps": [ { - "explanation": "First, we need to isolate the variable x. To do this, subtract 7 from both sides of the equation.", + "explanation": "To isolate the variable x, we need to get rid of the constant term on the left-hand side. We can do this by subtracting 7 from both sides of the equation.", "output": "8x + 7 - 7 = -23 - 7" }, { - "explanation": "This simplifies to 8x = -30", + "explanation": "Simplifying the left-hand side, we get:", "output": "8x = -30" }, { - "explanation": "Next, divide both sides of the equation by 8 to solve for x.", - "output": "(8x) / 8 = -30 / 8" + "explanation": "Now, to solve for x, we need to isolate it by dividing both sides of the equation by 8.", + "output": "8x / 8 = -30 / 8" }, { - "explanation": "This simplifies to x = -3.75", + "explanation": "Simplifying the right-hand side, we get:", "output": "x = -3.75" } ], - "final_output": "-3.75" -}''', - refusal=None, - role='assistant', - audio=None, - function_call=None, - tool_calls=None - ) - ) - ], - created=1730645716, - model='_', - object='chat.completion', - service_tier=None, - system_fingerprint='_', - usage=CompletionUsage( - completion_tokens=190, - prompt_tokens=78, - total_tokens=268, - completion_tokens_details=None, - prompt_tokens_details=None - ) -) + "final_answer": "There is no final answer yet, let's break it down step by step." +} ``` -From the output, you can easily parse the response to get correct json format as you guild the model in the system prompt. +Even if the model can generate correct format but the information doesn't 100% accurate, the `final_answer` should be `-3.75` instead of `There is no final answer yet, let's break it down step by step.`. -Howerver, open source model like llama3.1 or mistral nemo still truggling on mimic newest OpenAI API on response format. For example, consider this request created using the OpenAI library with very simple request like [OpenAI](https://platform.openai.com/docs/guides/structured-outputs#chain-of-thought): +Another usecase for structured output with json response, you can provide the `response_format={"type" : "json_object"}`, the model will be force to generate json output. ``` -from openai import OpenAI -ENDPOINT = "http://localhost:39281/v1" -MODEL = "llama3.1:8b-gguf-q4-km" -client = OpenAI( - base_url=ENDPOINT, - api_key="not-needed" -) - -class Step(BaseModel): - explanation: str - output: str - - -class MathReasoning(BaseModel): - steps: List[Step] - final_answer: str - - -completion_payload = { - "messages": [ - {"role": "system", "content": f"You are a helpful math tutor. Guide the user through the solution step by step.\n"}, - {"role": "user", "content": "how can I solve 8x + 7 = -23"} - ] -} - -response = client.beta.chat.completions.parse( - top_p=0.9, - temperature=0.6, +json_format = {"song_name":"release date"} +completion = client.chat.completions.create( model=MODEL, - messages= completion_payload["messages"], - response_format=MathReasoning + messages=[ + {"role": "system", "content": f"You are a helpful assistant, you must reponse with this format: '{json_format}'"}, + {"role": "user", "content": "List 10 songs for me"} + ], + response_format={"type": "json_object"}, + stop=["<|eot_id|>"] ) -``` - -The response format parsed by OpenAI before sending to the server is quite complex for the `MathReasoning` schema. Unlike GPT models, Llama 3.1 and Mistral Nemo cannot reliably generate responses that can be parsed as shown in the [OpenAI tutorial](https://platform.openai.com/docs/guides/structured-outputs/example-response). This may be due to these models not being trained on similar structured output tasks. -``` -"response_format" : - { - "json_schema" : - { - "name" : "MathReasoning", - "schema" : - { - "$defs" : - { - "Step" : - { - "additionalProperties" : false, - "properties" : - { - "explanation" : - { - "title" : "Explanation", - "type" : "string" - }, - "output" : - { - "title" : "Output", - "type" : "string" - } - }, - "required" : - [ - "explanation", - "output" - ], - "title" : "Step", - "type" : "object" - } - }, - "additionalProperties" : false, - "properties" : - { - "final_answer" : - { - "title" : "Final Answer", - "type" : "string" - }, - "steps" : - { - "items" : - { - "$ref" : "#/$defs/Step" - }, - "title" : "Steps", - "type" : "array" - } - }, - "required" : - [ - "steps", - "final_answer" - ], - "title" : "MathReasoning", - "type" : "object" - }, - "strict" : true - }, - "type" : "json_schema" - } +print(json.dumps(json.loads(completion.choices[0].message.content), indent=4)) ``` -The response for this request by `mistral-nemo` and `llama3.1` can not be used to parse result like in the [original tutorial by openAI](https://platform.openai.com/docs/guides/structured-outputs/example-response). Maybe `llama3.1` and `mistral-nemo` didn't train with this kind of data, so it fails to handle this case. +The output will looks like this: ``` -Response: { - "choices" : - [ - { - "finish_reason" : null, - "index" : 0, - "message" : - { - "content" : "Here's a step-by-step guide to solving the equation 8x + 7 = -23:\n\n```json\n{\n \"name\": \"MathReasoning\",\n \"schema\": {\n \"$defs\": {\n \"Step\": {\n \"additionalProperties\": false,\n \"properties\": {\n \"explanation\": {\"title\": \"Explanation\", \"type\": \"string\"},\n \"output\": {\"title\": \"Output\", \"type\": \"string\"}\n },\n \"required\": [\"explanation\", \"output\"],\n \"title\": \"Step\",\n \"type\": \"object\"\n }\n },\n \"additionalProperties\": false,\n \"properties\": {\n \"final_answer\": {\"title\": \"Final Answer\", \"type\": \"string\"},\n \"steps\": {\n \"items\": {\"$ref\": \"#/$defs/Step\"},\n \"title\": \"Steps\",\n \"type\": \"array\"\n }\n },\n \"required\": [\"steps\", \"final_answer\"],\n \"title\": \"MathReasoning\",\n \"type\": \"object\"\n },\n \"strict\": true\n}\n```\n\n1. **Subtract 7 from both sides** to isolate the term with x:\n\n - Explanation: To get rid of the +7 on the left side, we add -7 to both sides of the equation.\n - Output: `8x + 7 - 7 = -23 - 7`\n\n This simplifies to:\n ```\n 8x = -30\n ```\n\n2. **Divide both sides by 8** to solve for x:\n\n - Explanation: To get rid of the 8 on the left side, we multiply both sides of the equation by the reciprocal of 8, which is 1/8.\n - Output: `8x / 8 = -30 / 8`\n\n This simplifies to:\n ```\n x = -3.75\n ```\n\nSo, the final answer is:\n\n- Final Answer: `x = -3.75`", - "role" : "assistant" - } - } - ], +{ + "Happy": "2013", + "Uptown Funk": "2014", + "Shut Up and Dance": "2014", + "Can't Stop the Feeling!": "2016", + "We Found Love": "2011", + "All About That Bass": "2014", + "Radioactive": "2012", + "SexyBack": "2006", + "Crazy": "2007", + "Viva la Vida": "2008" +} ``` - - - ## Limitations of Open-Source Models for Structured Outputs While the concept of structured outputs is compelling, particularly for applications requiring machine-readable data, it's important to understand that not all models support this capability equally. Open-source models such as Llama 3.1 and Mistral Nemo face notable challenges in generating outputs that adhere strictly to defined JSON schemas. Here are the key limitations: From 1f5631df834dfae6bdd0bc8fcf7b7d3f174a4184 Mon Sep 17 00:00:00 2001 From: James Date: Fri, 29 Nov 2024 14:30:27 +0700 Subject: [PATCH 03/44] fix: adding mutex for loaded engine map --- engine/services/engine_service.cc | 135 +++++++++++++++++------------- engine/services/engine_service.h | 4 +- 2 files changed, 81 insertions(+), 58 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4eebff669..c52e32ef0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -4,7 +4,6 @@ #include #include "algorithm" #include "utils/archive_utils.h" -#include "utils/cortex_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" #include "utils/file_manager_utils.h" @@ -631,13 +630,15 @@ EngineService::GetInstalledEngineVariants(const std::string& engine) const { return variants; } -bool EngineService::IsEngineLoaded(const std::string& engine) const { +bool EngineService::IsEngineLoaded(const std::string& engine) { + std::lock_guard lock(engines_mutex_); auto ne = NormalizeEngine(engine); return engines_.find(ne) != engines_.end(); } cpp::result EngineService::GetLoadedEngine( const std::string& engine_name) { + std::lock_guard lock(engines_mutex_); auto ne = NormalizeEngine(engine_name); if (engines_.find(ne) == engines_.end()) { return cpp::fail("Engine " + engine_name + " is not loaded yet!"); @@ -708,19 +709,19 @@ cpp::result EngineService::LoadEngine( auto add_dll = [this](const std::string& e_type, const std::filesystem::path& p) { if (auto cookie = AddDllDirectory(p.c_str()); cookie != 0) { - CTL_DBG("Added dll directory: " << p); + CTL_DBG("Added dll directory: " << p.string()); engines_[e_type].cookie = cookie; } else { - CTL_WRN("Could not add dll directory: " << p); + CTL_WRN("Could not add dll directory: " << p.string()); } auto cuda_path = file_manager_utils::GetCudaToolkitPath(e_type); if (auto cuda_cookie = AddDllDirectory(cuda_path.c_str()); cuda_cookie != 0) { - CTL_DBG("Added cuda dll directory: " << p); + CTL_DBG("Added cuda dll directory: " << p.string()); engines_[e_type].cuda_cookie = cuda_cookie; } else { - CTL_WRN("Could not add cuda dll directory: " << p); + CTL_WRN("Could not add cuda dll directory: " << p.string()); } }; @@ -732,16 +733,20 @@ cpp::result EngineService::LoadEngine( should_use_dll_search_path) { if (IsEngineLoaded(kLlamaRepo) && ne == kTrtLlmRepo && should_use_dll_search_path) { - // Remove llamacpp dll directory - if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { - CTL_WRN("Could not remove dll directory: " << kLlamaRepo); - } else { - CTL_DBG("Removed dll directory: " << kLlamaRepo); - } - if (!RemoveDllDirectory(engines_[kLlamaRepo].cuda_cookie)) { - CTL_WRN("Could not remove cuda dll directory: " << kLlamaRepo); - } else { - CTL_DBG("Removed cuda dll directory: " << kLlamaRepo); + + { + std::lock_guard lock(engines_mutex_); + // Remove llamacpp dll directory + if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { + CTL_WRN("Could not remove dll directory: " << kLlamaRepo); + } else { + CTL_DBG("Removed dll directory: " << kLlamaRepo); + } + if (!RemoveDllDirectory(engines_[kLlamaRepo].cuda_cookie)) { + CTL_WRN("Could not remove cuda dll directory: " << kLlamaRepo); + } else { + CTL_DBG("Removed cuda dll directory: " << kLlamaRepo); + } } add_dll(ne, engine_dir_path); @@ -752,8 +757,11 @@ cpp::result EngineService::LoadEngine( } } #endif - engines_[ne].dl = - std::make_unique(engine_dir_path.string(), "engine"); + { + std::lock_guard lock(engines_mutex_); + engines_[ne].dl = std::make_unique( + engine_dir_path.string(), "engine"); + } #if defined(__linux__) const char* name = "LD_LIBRARY_PATH"; auto data = getenv(name); @@ -774,65 +782,78 @@ cpp::result EngineService::LoadEngine( } catch (const cortex_cpp::dylib::load_error& e) { CTL_ERR("Could not load engine: " << e.what()); - engines_.erase(ne); + { + std::lock_guard lock(engines_mutex_); + engines_.erase(ne); + } return cpp::fail("Could not load engine " + ne + ": " + e.what()); } - auto func = engines_[ne].dl->get_function("get_engine"); - engines_[ne].engine = func(); + { + std::lock_guard lock(engines_mutex_); + auto func = engines_[ne].dl->get_function("get_engine"); + engines_[ne].engine = func(); - auto& en = std::get(engines_[ne].engine); - if (ne == kLlamaRepo) { //fix for llamacpp engine first - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - CTL_WRN("Method SetFileLogger is not supported yet"); - } - if (en->IsSupported("SetLogLevel")) { - en->SetLogLevel(logging_utils_helper::global_log_level); - } else { - CTL_WRN("Method SetLogLevel is not supported yet"); + auto& en = std::get(engines_[ne].engine); + if (ne == kLlamaRepo) { //fix for llamacpp engine first + auto config = file_manager_utils::GetCortexConfig(); + if (en->IsSupported("SetFileLogger")) { + en->SetFileLogger(config.maxLogLines, + (std::filesystem::path(config.logFolderPath) / + std::filesystem::path(config.logLlamaCppPath)) + .string()); + } else { + CTL_WRN("Method SetFileLogger is not supported yet"); + } + if (en->IsSupported("SetLogLevel")) { + en->SetLogLevel(logging_utils_helper::global_log_level); + } else { + CTL_WRN("Method SetLogLevel is not supported yet"); + } } + CTL_DBG("loaded engine: " << ne); } - CTL_DBG("Loaded engine: " << ne); return {}; } cpp::result EngineService::UnloadEngine( const std::string& engine) { auto ne = NormalizeEngine(engine); - if (!IsEngineLoaded(ne)) { - return cpp::fail("Engine " + ne + " is not loaded yet!"); - } - EngineI* e = std::get(engines_[ne].engine); - delete e; + { + std::lock_guard lock(engines_mutex_); + if (!IsEngineLoaded(ne)) { + return cpp::fail("Engine " + ne + " is not loaded yet!"); + } + EngineI* e = std::get(engines_[ne].engine); + delete e; + #if defined(_WIN32) - if (!RemoveDllDirectory(engines_[ne].cookie)) { - CTL_WRN("Could not remove dll directory: " << ne); - } else { - CTL_DBG("Removed dll directory: " << ne); - } - if (!RemoveDllDirectory(engines_[ne].cuda_cookie)) { - CTL_WRN("Could not remove cuda dll directory: " << ne); - } else { - CTL_DBG("Removed cuda dll directory: " << ne); - } + if (!RemoveDllDirectory(engines_[ne].cookie)) { + CTL_WRN("Could not remove dll directory: " << ne); + } else { + CTL_DBG("Removed dll directory: " << ne); + } + if (!RemoveDllDirectory(engines_[ne].cuda_cookie)) { + CTL_WRN("Could not remove cuda dll directory: " << ne); + } else { + CTL_DBG("Removed cuda dll directory: " << ne); + } #endif - engines_.erase(ne); + engines_.erase(ne); + } CTL_DBG("Unloaded engine " + ne); return {}; } std::vector EngineService::GetLoadedEngines() { - std::vector loaded_engines; - for (const auto& [key, value] : engines_) { - loaded_engines.push_back(value.engine); + { + std::lock_guard lock(engines_mutex_); + std::vector loaded_engines; + for (const auto& [key, value] : engines_) { + loaded_engines.push_back(value.engine); + } + return loaded_engines; } - return loaded_engines; } cpp::result diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index dee8a530b..a18a276cd 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -51,6 +52,7 @@ class EngineService : public EngineServiceI { #endif }; + std::mutex engines_mutex_; std::unordered_map engines_{}; public: @@ -99,7 +101,7 @@ class EngineService : public EngineServiceI { cpp::result, std::string> GetInstalledEngineVariants(const std::string& engine) const; - bool IsEngineLoaded(const std::string& engine) const; + bool IsEngineLoaded(const std::string& engine); cpp::result GetLoadedEngine( const std::string& engine_name); From 67e35543eeb06f6469682af38b2d77329dcc303e Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 2 Dec 2024 07:49:16 +0700 Subject: [PATCH 04/44] feat: model compatibility API (#1715) * feat: model estimation * fix: cleanup and improve * chore: cleanup * chore: API docs * chore: CLI docs * fix: correct calculation * fix: handle macOS * chore: typo --------- Co-authored-by: vansangpfiev --- docs/docs/cli/models/index.mdx | 7 +- docs/static/openapi/cortex.json | 53 ++ engine/cli/command_line_parser.cc | 13 +- engine/cli/command_line_parser.h | 2 + engine/cli/commands/model_list_cmd.cc | 31 +- engine/cli/commands/model_list_cmd.h | 3 +- engine/controllers/models.cc | 4 + engine/main.cc | 1 + engine/services/model_service.cc | 252 ++++---- engine/services/model_service.h | 9 + engine/utils/hardware/gguf/ggml.h | 235 ++++++++ engine/utils/hardware/gguf/gguf_file.h | 537 ++++++++++++++++++ .../utils/hardware/gguf/gguf_file_estimate.h | 183 ++++++ 13 files changed, 1226 insertions(+), 104 deletions(-) create mode 100644 engine/utils/hardware/gguf/ggml.h create mode 100644 engine/utils/hardware/gguf/gguf_file.h create mode 100644 engine/utils/hardware/gguf/gguf_file_estimate.h diff --git a/docs/docs/cli/models/index.mdx b/docs/docs/cli/models/index.mdx index 5b29069a6..b75bf9d49 100644 --- a/docs/docs/cli/models/index.mdx +++ b/docs/docs/cli/models/index.mdx @@ -120,8 +120,11 @@ For example, it returns the following:w | Option | Description | Required | Default value | Example | |---------------------------|----------------------------------------------------|----------|---------------|----------------------| -| `-h`, `--help` | Display help for command. | No | - | `-h` | - +| `-h`, `--help` | Display help for command. | No | - | `-h` | +| `-e`, `--engine` | Display engines. | No | - | `--engine` | +| `-v`, `--version` | Display version for model. | No | - | `--version` | +| `--cpu_mode` | Display CPU mode. | No | - | `--cpu_mode` | +| `--gpu_mode` | Display GPU mode. | No | - | `--gpu_mode` | ## `cortex models start` :::info diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index da31ab64b..78430294f 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -3940,6 +3940,55 @@ }, "required": ["description", "name", "productName", "status"] }, + "CpuModeDto": { + "type": "object", + "properties": { + "ram": { + "type": "number", + "example": 1024 + } + } + }, + "GpuModeDto": { + "type": "object", + "properties": { + "ram": { + "type": "number", + "example": 1024 + }, + "vram": { + "type": "number", + "example": 1024 + }, + "ngl": { + "type": "number", + "example": 30 + }, + "context_length": { + "type": "number", + "example": 4096 + }, + "recommend_ngl": { + "type": "number", + "example": 33 + } + } + }, + "RecommendDto": { + "type": "object", + "properties": { + "cpu_mode": { + "type": "object", + "$ref": "#/components/schemas/CpuModeDto" + }, + "gpu_mode": { + "type": "array", + "items": { + "$ref": "#/components/schemas/GPUDto" + } + } + } + }, "ModelDto": { "type": "object", "properties": { @@ -4064,6 +4113,10 @@ "type": "string", "description": "The engine to use.", "example": "llamacpp" + }, + "recommendation": { + "type": "object", + "$ref": "#/components/schemas/RecommendDto" } }, "required": ["id"] diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index e1b2f5feb..34c6b9069 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -245,14 +245,19 @@ void CommandLineParser::SetupModelCommands() { "Display engine"); list_models_cmd->add_flag("-v,--version", cml_data_.display_version, "Display version"); + list_models_cmd->add_flag("--cpu_mode", cml_data_.display_cpu_mode, + "Display cpu mode"); + list_models_cmd->add_flag("--gpu_mode", cml_data_.display_gpu_mode, + "Display gpu mode"); list_models_cmd->group(kSubcommands); list_models_cmd->callback([this]() { if (std::exchange(executed_, true)) return; - commands::ModelListCmd().Exec(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - cml_data_.filter, cml_data_.display_engine, - cml_data_.display_version); + commands::ModelListCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.filter, + cml_data_.display_engine, cml_data_.display_version, + cml_data_.display_cpu_mode, cml_data_.display_gpu_mode); }); auto get_models_cmd = diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index bce83222a..f7ca3f507 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -64,6 +64,8 @@ class CommandLineParser { // for model list bool display_engine = false; bool display_version = false; + bool display_cpu_mode = false; + bool display_gpu_mode = false; std::string filter = ""; std::string log_level = "INFO"; diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index c63ed0012..7990563f3 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -20,7 +20,8 @@ using Row_t = void ModelListCmd::Exec(const std::string& host, int port, const std::string& filter, bool display_engine, - bool display_version) { + bool display_version, bool display_cpu_mode, + bool display_gpu_mode) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -39,6 +40,12 @@ void ModelListCmd::Exec(const std::string& host, int port, column_headers.push_back("Version"); } + if (display_cpu_mode) { + column_headers.push_back("CPU Mode"); + } + if (display_gpu_mode) { + column_headers.push_back("GPU Mode"); + } Row_t header{column_headers.begin(), column_headers.end()}; table.add_row(header); table.format().font_color(Color::green); @@ -77,6 +84,28 @@ void ModelListCmd::Exec(const std::string& host, int port, row.push_back(v["version"].asString()); } + if (auto& r = v["recommendation"]; !r.isNull()) { + if (display_cpu_mode) { + if (!r["cpu_mode"].isNull()) { + row.push_back("RAM: " + r["cpu_mode"]["ram"].asString() + " MiB"); + } + } + + if (display_gpu_mode) { + if (!r["gpu_mode"].isNull()) { + std::string s; + s += "ngl: " + r["gpu_mode"][0]["ngl"].asString() + " - "; + s += "context: " + r["gpu_mode"][0]["context_length"].asString() + + " - "; + s += "RAM: " + r["gpu_mode"][0]["ram"].asString() + " MiB - "; + s += "VRAM: " + r["gpu_mode"][0]["vram"].asString() + " MiB - "; + s += "recommended ngl: " + + r["gpu_mode"][0]["recommend_ngl"].asString(); + row.push_back(s); + } + } + } + table.add_row({row.begin(), row.end()}); } } diff --git a/engine/cli/commands/model_list_cmd.h b/engine/cli/commands/model_list_cmd.h index 2e7c446e7..791c1ecf6 100644 --- a/engine/cli/commands/model_list_cmd.h +++ b/engine/cli/commands/model_list_cmd.h @@ -7,6 +7,7 @@ namespace commands { class ModelListCmd { public: void Exec(const std::string& host, int port, const std::string& filter, - bool display_engine = false, bool display_version = false); + bool display_engine = false, bool display_version = false, + bool display_cpu_mode = false, bool display_gpu_mode = false); }; } // namespace commands diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index c51bb3b77..af8061269 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -171,6 +171,10 @@ void Models::ListModel( Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; + auto es = model_service_->GetEstimation(model_entry.model); + if (es.has_value()) { + obj["recommendation"] = hardware::ToJson(es.value()); + } data.append(std::move(obj)); yaml_handler.Reset(); } catch (const std::exception& e) { diff --git a/engine/main.cc b/engine/main.cc index 7c37e27fe..61571907f 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -20,6 +20,7 @@ #include "utils/event_processor.h" #include "utils/file_logger.h" #include "utils/file_manager_utils.h" +#include "utils/hardware/gguf/gguf_file_estimate.h" #include "utils/logging_utils.h" #include "utils/system_info_utils.h" #include "utils/widechar_conv.h" diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index a37cea12c..cc1f99bdc 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -339,6 +339,53 @@ cpp::result ModelService::HandleDownloadUrlAsync( return download_service_->AddTask(downloadTask, on_finished); } +cpp::result ModelService::GetEstimation( + const std::string& model_handle, const std::string& kv_cache, int n_batch, + int n_ubatch) { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + cortex::db::Models modellist_handler; + config::YamlHandler yaml_handler; + + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + if (model_entry.has_error()) { + CTL_WRN("Error: " + model_entry.error()); + return cpp::fail(model_entry.error()); + } + auto file_path = fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .parent_path() / + "model.gguf"; + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto mc = yaml_handler.GetModelConfig(); + services::HardwareService hw_svc; + auto hw_info = hw_svc.GetHardwareInfo(); + auto free_vram_MiB = 0u; + for (const auto& gpu : hw_info.gpus) { + free_vram_MiB += gpu.free_vram; + } + +#if defined(__APPLE__) && defined(__MACH__) + free_vram_MiB = hw_info.ram.available_MiB; +#endif + + return hardware::EstimateLLaMACppRun(file_path.string(), + {.ngl = mc.ngl, + .ctx_len = mc.ctx_len, + .n_batch = n_batch, + .n_ubatch = n_ubatch, + .kv_cache_type = kv_cache, + .free_vram_MiB = free_vram_MiB}); + } catch (const std::exception& e) { + return cpp::fail("Fail to get model status with ID '" + model_handle + + "': " + e.what()); + } +} + cpp::result ModelService::HandleUrl( const std::string& url) { auto url_obj = url_parser::FromUrlString(url); @@ -713,100 +760,11 @@ cpp::result ModelService::StartModel( #undef ASSIGN_IF_PRESENT CTL_INF(json_data.toStyledString()); - // TODO(sang) move this into another function - // Calculate ram/vram needed to load model - services::HardwareService hw_svc; - auto hw_info = hw_svc.GetHardwareInfo(); - assert(!!engine_svc_); - auto default_engine = engine_svc_->GetDefaultEngineVariant(kLlamaEngine); - bool is_cuda = false; - if (default_engine.has_error()) { - CTL_INF("Could not get default engine"); - } else { - auto& de = default_engine.value(); - is_cuda = de.variant.find("cuda") != std::string::npos; - CTL_INF("is_cuda: " << is_cuda); - } - - std::optional warning; - if (is_cuda && !system_info_utils::IsNvidiaSmiAvailable()) { - CTL_INF( - "Running cuda variant but nvidia-driver is not installed yet, " - "fallback to CPU mode"); - auto res = engine_svc_->GetInstalledEngineVariants(kLlamaEngine); - if (res.has_error()) { - CTL_WRN("Could not get engine variants"); - return cpp::fail("Nvidia-driver is not installed!"); - } else { - auto& es = res.value(); - std::sort( - es.begin(), es.end(), - [](const EngineVariantResponse& e1, - const EngineVariantResponse& e2) { return e1.name > e2.name; }); - for (auto& e : es) { - CTL_INF(e.name << " " << e.version << " " << e.engine); - // Select the first CPU candidate - if (e.name.find("cuda") == std::string::npos) { - auto r = engine_svc_->SetDefaultEngineVariant(kLlamaEngine, - e.version, e.name); - if (r.has_error()) { - CTL_WRN("Could not set default engine variant"); - return cpp::fail("Nvidia-driver is not installed!"); - } else { - CTL_INF("Change default engine to: " << e.name); - auto rl = engine_svc_->LoadEngine(kLlamaEngine); - if (rl.has_error()) { - return cpp::fail("Nvidia-driver is not installed!"); - } else { - CTL_INF("Engine started"); - is_cuda = false; - warning = "Nvidia-driver is not installed, use CPU variant: " + - e.version + "-" + e.name; - break; - } - } - } - } - // If we reach here, means that no CPU variant to fallback - if (!warning) { - return cpp::fail( - "Nvidia-driver is not installed, no available CPU version to " - "fallback"); - } - } - } - // If in GPU acceleration mode: - // We use all visible GPUs, so only need to sum all free vram - auto free_vram_MiB = 0u; - for (const auto& gpu : hw_info.gpus) { - free_vram_MiB += gpu.free_vram; - } - - auto free_ram_MiB = hw_info.ram.available_MiB; - - auto const& mp = json_data["model_path"].asString(); - auto ngl = json_data["ngl"].asInt(); - // Bypass for now - auto vram_needed_MiB = 0u; - auto ram_needed_MiB = 0u; - - if (vram_needed_MiB > free_vram_MiB && is_cuda) { - CTL_WRN("Not enough VRAM - " << "required: " << vram_needed_MiB - << ", available: " << free_vram_MiB); - - return cpp::fail( - "Not enough VRAM - required: " + std::to_string(vram_needed_MiB) + - " MiB, available: " + std::to_string(free_vram_MiB) + - " MiB - Should adjust ngl to " + - std::to_string(free_vram_MiB / (vram_needed_MiB / ngl) - 1)); - } - - if (ram_needed_MiB > free_ram_MiB) { - CTL_WRN("Not enough RAM - " << "required: " << ram_needed_MiB - << ", available: " << free_ram_MiB); - return cpp::fail( - "Not enough RAM - required: " + std::to_string(ram_needed_MiB) + - " MiB,, available: " + std::to_string(free_ram_MiB) + " MiB"); + auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(), + json_data["ngl"].asInt(), + json_data["ctx_len"].asInt()); + if (may_fallback_res.has_error()) { + return cpp::fail(may_fallback_res.error()); } assert(!!inference_svc_); @@ -814,11 +772,14 @@ cpp::result ModelService::StartModel( inference_svc_->LoadModel(std::make_shared(json_data)); auto status = std::get<0>(ir)["status_code"].asInt(); auto data = std::get<1>(ir); + if (status == drogon::k200OK) { - return StartModelResult{.success = true, .warning = warning}; + return StartModelResult{.success = true, + .warning = may_fallback_res.value()}; } else if (status == drogon::k409Conflict) { CTL_INF("Model '" + model_handle + "' is already loaded"); - return StartModelResult{.success = true, .warning = warning}; + return StartModelResult{ + .success = true, .warning = may_fallback_res.value_or(std::nullopt)}; } else { // only report to user the error CTL_ERR("Model failed to start with status code: " << status); @@ -1047,3 +1008,102 @@ cpp::result ModelService::AbortDownloadModel( const std::string& task_id) { return download_service_->StopTask(task_id); } + +cpp::result, std::string> +ModelService::MayFallbackToCpu(const std::string& model_path, int ngl, + int ctx_len, int n_batch, int n_ubatch, + const std::string& kv_cache_type) { + services::HardwareService hw_svc; + auto hw_info = hw_svc.GetHardwareInfo(); + assert(!!engine_svc_); + auto default_engine = engine_svc_->GetDefaultEngineVariant(kLlamaEngine); + bool is_cuda = false; + if (default_engine.has_error()) { + CTL_INF("Could not get default engine"); + } else { + auto& de = default_engine.value(); + is_cuda = de.variant.find("cuda") != std::string::npos; + CTL_INF("is_cuda: " << is_cuda); + } + + std::optional warning; + if (is_cuda && !system_info_utils::IsNvidiaSmiAvailable()) { + CTL_INF( + "Running cuda variant but nvidia-driver is not installed yet, " + "fallback to CPU mode"); + auto res = engine_svc_->GetInstalledEngineVariants(kLlamaEngine); + if (res.has_error()) { + CTL_WRN("Could not get engine variants"); + return cpp::fail("Nvidia-driver is not installed!"); + } else { + auto& es = res.value(); + std::sort( + es.begin(), es.end(), + [](const EngineVariantResponse& e1, const EngineVariantResponse& e2) { + return e1.name > e2.name; + }); + for (auto& e : es) { + CTL_INF(e.name << " " << e.version << " " << e.engine); + // Select the first CPU candidate + if (e.name.find("cuda") == std::string::npos) { + auto r = engine_svc_->SetDefaultEngineVariant(kLlamaEngine, e.version, + e.name); + if (r.has_error()) { + CTL_WRN("Could not set default engine variant"); + return cpp::fail("Nvidia-driver is not installed!"); + } else { + CTL_INF("Change default engine to: " << e.name); + auto rl = engine_svc_->LoadEngine(kLlamaEngine); + if (rl.has_error()) { + return cpp::fail("Nvidia-driver is not installed!"); + } else { + CTL_INF("Engine started"); + is_cuda = false; + warning = "Nvidia-driver is not installed, use CPU variant: " + + e.version + "-" + e.name; + break; + } + } + } + } + // If we reach here, means that no CPU variant to fallback + if (!warning) { + return cpp::fail( + "Nvidia-driver is not installed, no available CPU version to " + "fallback"); + } + } + } + // If in GPU acceleration mode: + // We use all visible GPUs, so only need to sum all free vram + auto free_vram_MiB = 0u; + for (const auto& gpu : hw_info.gpus) { + free_vram_MiB += gpu.free_vram; + } + + auto free_ram_MiB = hw_info.ram.available_MiB; + +#if defined(__APPLE__) && defined(__MACH__) + free_vram_MiB = free_ram_MiB; +#endif + + hardware::RunConfig rc = {.ngl = ngl, + .ctx_len = ctx_len, + .n_batch = n_batch, + .n_ubatch = n_ubatch, + .kv_cache_type = kv_cache_type, + .free_vram_MiB = free_vram_MiB}; + auto es = hardware::EstimateLLaMACppRun(model_path, rc); + + if (es.gpu_mode.vram_MiB > free_vram_MiB && is_cuda) { + CTL_WRN("Not enough VRAM - " << "required: " << es.gpu_mode.vram_MiB + << ", available: " << free_vram_MiB); + } + + if (es.cpu_mode.ram_MiB > free_ram_MiB) { + CTL_WRN("Not enough RAM - " << "required: " << es.cpu_mode.ram_MiB + << ", available: " << free_ram_MiB); + } + + return warning; +} diff --git a/engine/services/model_service.h b/engine/services/model_service.h index be450fb0b..7235d5a0a 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -7,6 +7,7 @@ #include "config/model_config.h" #include "services/download_service.h" #include "services/inference_service.h" +#include "utils/hardware/gguf/gguf_file_estimate.h" struct ModelPullInfo { std::string id; @@ -96,6 +97,10 @@ class ModelService { bool HasModel(const std::string& id) const; + cpp::result GetEstimation( + const std::string& model_handle, const std::string& kv_cache = "f16", + int n_batch = 2048, int n_ubatch = 2048); + private: /** * Handle downloading model which have following pattern: author/model_name @@ -111,6 +116,10 @@ class ModelService { cpp::result HandleCortexsoModel( const std::string& modelName); + cpp::result, std::string> MayFallbackToCpu( + const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048, + int n_ubatch = 2048, const std::string& kv_cache_type = "f16"); + std::shared_ptr download_service_; std::shared_ptr inference_svc_; std::unordered_set bypass_stop_check_set_; diff --git a/engine/utils/hardware/gguf/ggml.h b/engine/utils/hardware/gguf/ggml.h new file mode 100644 index 000000000..e898fc796 --- /dev/null +++ b/engine/utils/hardware/gguf/ggml.h @@ -0,0 +1,235 @@ +#pragma once +#include +#include +#include +#include +#include "utils/result.hpp" + +namespace hardware { +enum GGMLType { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support has been removed + // GGML_TYPE_Q4_3 = 5, support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_IQ2_XXS = 16, + GGML_TYPE_IQ2_XS = 17, + GGML_TYPE_IQ3_XXS = 18, + GGML_TYPE_IQ1_S = 19, + GGML_TYPE_IQ4_NL = 20, + GGML_TYPE_IQ3_S = 21, + GGML_TYPE_IQ2_S = 22, + GGML_TYPE_IQ4_XS = 23, + GGML_TYPE_I8 = 24, + GGML_TYPE_I16 = 25, + GGML_TYPE_I32 = 26, + GGML_TYPE_I64 = 27, + GGML_TYPE_F64 = 28, + GGML_TYPE_IQ1_M = 29, + GGML_TYPE_BF16 = 30, + GGML_TYPE_Q4_0_4_4 = 31, + GGML_TYPE_Q4_0_4_8 = 32, + GGML_TYPE_Q4_0_8_8 = 33, + GGML_TYPE_TQ1_0 = 34, + GGML_TYPE_TQ2_0 = 35, + GGML_TYPE_COUNT, +}; + +inline float GetQuantBit(GGMLType gt) { + switch (gt) { + case GGML_TYPE_I32: + case GGML_TYPE_F32: + return 32.0; + case GGML_TYPE_I16: + case GGML_TYPE_BF16: + case GGML_TYPE_F16: + return 16.0; + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + return 2.31; + case GGML_TYPE_Q2_K: + return 2.5625; + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_Q3_K: + return 3.4375; + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + return 4.5; + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + return 5.5; + case GGML_TYPE_Q6_K: + return 6.5625; + case GGML_TYPE_I8: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q8_K: + return 8.0; + + case GGML_TYPE_I64: + case GGML_TYPE_F64: + return 64.0; + + default: + return 8.0; + } +} + +inline std::string to_string(GGMLType t) { + switch (t) { + case GGML_TYPE_F32: + return "F32"; + case GGML_TYPE_F16: + return "F16"; + case GGML_TYPE_Q4_0: + return "Q4_0"; + case GGML_TYPE_Q4_1: + return "Q4_1"; + case GGML_TYPE_Q5_0: + return "Q5_0"; + case GGML_TYPE_Q5_1: + return "Q5_1"; + case GGML_TYPE_Q8_0: + return "Q8_0"; + case GGML_TYPE_Q8_1: + return "Q8_1"; + case GGML_TYPE_Q2_K: + return "Q2_K"; + case GGML_TYPE_Q3_K: + return "Q3_K"; + case GGML_TYPE_Q4_K: + return "Q4_K"; + case GGML_TYPE_Q5_K: + return "Q5_K"; + case GGML_TYPE_Q6_K: + return "Q6_K"; + case GGML_TYPE_Q8_K: + return "Q8_K"; + case GGML_TYPE_IQ2_XXS: + return "IQ2_XXS"; + case GGML_TYPE_IQ2_XS: + return "IQ2_XS"; + case GGML_TYPE_IQ3_XXS: + return "IQ3_XXS"; + case GGML_TYPE_IQ1_S: + return "IQ1_S"; + case GGML_TYPE_IQ4_NL: + return "IQ4_NL"; + case GGML_TYPE_IQ3_S: + return "IQ3_S"; + case GGML_TYPE_IQ2_S: + return "IQ2_S"; + case GGML_TYPE_IQ4_XS: + return "IQ4_XS"; + case GGML_TYPE_I8: + return "I8"; + case GGML_TYPE_I16: + return "I16"; + case GGML_TYPE_I32: + return "I32"; + case GGML_TYPE_I64: + return "I64"; + case GGML_TYPE_F64: + return "F64"; + case GGML_TYPE_IQ1_M: + return "IQ1_M"; + case GGML_TYPE_BF16: + return "BF16"; + case GGML_TYPE_Q4_0_4_4: + return "Q4_0_4_4"; + case GGML_TYPE_Q4_0_4_8: + return "Q4_0_4_8"; + case GGML_TYPE_Q4_0_8_8: + return "Q4_0_8_8"; + case GGML_TYPE_TQ1_0: + return "TQ1_0"; + case GGML_TYPE_TQ2_0: + return "TQ2_0"; + default: + return "Invalid"; + } +} + +struct GGMLTypeTrait { + uint64_t block_size; + uint64_t type_size; + bool is_quantized; +}; + +const std::unordered_map kGGMLTypeTraits = { + {GGML_TYPE_F32, {.block_size = 1, .type_size = 4}}, + {GGML_TYPE_F16, {.block_size = 1, .type_size = 2}}, + {GGML_TYPE_Q4_0, {.block_size = 32, .type_size = 18, .is_quantized = true}}, + {GGML_TYPE_Q4_1, {.block_size = 32, .type_size = 20, .is_quantized = true}}, + {GGML_TYPE_Q5_0, {.block_size = 32, .type_size = 22, .is_quantized = true}}, + {GGML_TYPE_Q5_1, {.block_size = 32, .type_size = 24, .is_quantized = true}}, + {GGML_TYPE_Q8_0, {.block_size = 32, .type_size = 34, .is_quantized = true}}, + {GGML_TYPE_Q8_1, {.block_size = 32, .type_size = 36, .is_quantized = true}}, + {GGML_TYPE_Q2_K, + {.block_size = 256, .type_size = 84, .is_quantized = true}}, + {GGML_TYPE_Q3_K, + {.block_size = 256, .type_size = 110, .is_quantized = true}}, + {GGML_TYPE_Q4_K, + {.block_size = 256, .type_size = 144, .is_quantized = true}}, + {GGML_TYPE_Q5_K, + {.block_size = 256, .type_size = 176, .is_quantized = true}}, + {GGML_TYPE_Q6_K, + {.block_size = 256, .type_size = 210, .is_quantized = true}}, + {GGML_TYPE_Q8_K, + {.block_size = 256, .type_size = 292, .is_quantized = true}}, + {GGML_TYPE_IQ2_XXS, + {.block_size = 256, .type_size = 66, .is_quantized = true}}, + {GGML_TYPE_IQ2_XS, + {.block_size = 256, .type_size = 74, .is_quantized = true}}, + {GGML_TYPE_IQ3_XXS, + {.block_size = 256, .type_size = 98, .is_quantized = true}}, + {GGML_TYPE_IQ1_S, + {.block_size = 256, .type_size = 50, .is_quantized = true}}, + {GGML_TYPE_IQ4_NL, + {.block_size = 32, .type_size = 18, .is_quantized = true}}, + {GGML_TYPE_IQ3_S, + {.block_size = 256, .type_size = 110, .is_quantized = true}}, + {GGML_TYPE_IQ2_S, + {.block_size = 256, .type_size = 82, .is_quantized = true}}, + {GGML_TYPE_IQ4_XS, + {.block_size = 256, .type_size = 136, .is_quantized = true}}, + {GGML_TYPE_I8, {.block_size = 1, .type_size = 1}}, + {GGML_TYPE_I16, {.block_size = 1, .type_size = 2}}, + {GGML_TYPE_I32, {.block_size = 1, .type_size = 4}}, + {GGML_TYPE_I64, {.block_size = 1, .type_size = 8}}, + {GGML_TYPE_F64, {.block_size = 1, .type_size = 8}}, + {GGML_TYPE_IQ1_M, + {.block_size = 256, .type_size = 56, .is_quantized = true}}, + {GGML_TYPE_BF16, {.block_size = 1, .type_size = 2}}, + {GGML_TYPE_Q4_0_4_4, + {.block_size = 32, .type_size = 18, .is_quantized = true}}, + {GGML_TYPE_Q4_0_4_8, + {.block_size = 32, .type_size = 18, .is_quantized = true}}, + {GGML_TYPE_Q4_0_8_8, + {.block_size = 32, .type_size = 18, .is_quantized = true}}, + {GGML_TYPE_TQ1_0, + {.block_size = 256, .type_size = 54, .is_quantized = true}}, + {GGML_TYPE_TQ2_0, + {.block_size = 256, .type_size = 66, .is_quantized = true}}, +}; +} // namespace hardware \ No newline at end of file diff --git a/engine/utils/hardware/gguf/gguf_file.h b/engine/utils/hardware/gguf/gguf_file.h new file mode 100644 index 000000000..1263debf2 --- /dev/null +++ b/engine/utils/hardware/gguf/gguf_file.h @@ -0,0 +1,537 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#include +#else +#include // For memory-mapped file +#include // For file descriptors +#endif + +#include "ggml.h" +#include "utils/string_utils.h" + +// #define GGUF_LOG(msg) \ +// do { \ +// std::cout << __FILE__ << "(@" << __LINE__ << "): " << msg << '\n'; \ +// } while (false) + +#define GGUF_LOG(msg) +namespace hardware { +#undef min +#undef max + +using GGUFMagic = uint32_t; +constexpr const GGUFMagic kGGUFMagicGGML = 0x67676d6c; +constexpr const GGUFMagic kGGUFMagicGGMF = 0x67676d66; +constexpr const GGUFMagic kGGUFMagicGGJT = 0x67676a74; +constexpr const GGUFMagic kGGUFMagicGGUFLe = 0x46554747; // GGUF +constexpr const GGUFMagic kGGUFMagicGGUFBe = 0x47475546; // GGUF + +using GGUFVersion = uint32_t; +constexpr const GGUFVersion kGGUFVersionV1 = 1; +constexpr const GGUFVersion kGGUFVersionV2 = 2; +constexpr const GGUFVersion kGGUFVersionV3 = 3; + +enum GGUFMetadataValueType : uint32_t { + GGUFMetadataValueTypeUint8 = 0, + GGUFMetadataValueTypeInt8, + GGUFMetadataValueTypeUint16, + GGUFMetadataValueTypeInt16, + GGUFMetadataValueTypeUint32, + GGUFMetadataValueTypeInt32, + GGUFMetadataValueTypeFloat32, + GGUFMetadataValueTypeBool, + GGUFMetadataValueTypeString, + GGUFMetadataValueTypeArray, + GGUFMetadataValueTypeUint64, + GGUFMetadataValueTypeInt64, + GGUFMetadataValueTypeFloat64, + _GGUFMetadataValueTypeCount // Unknown +}; + +struct GGUFMetadataKV { + // Key is the key of the metadata key-value pair, + // which is no larger than 64 bytes long. + std::string key; // Using std::string for dynamic string handling + + // ValueType is the type of the metadata value. + GGUFMetadataValueType value_type; // Enum to represent value types + + // Value is the value of the metadata key-value pair. + std::any value; +}; + +struct GGUFMetadataKVArrayValue { + /* Basic */ + + // type is the type of the array item. + GGUFMetadataValueType type; // Enum to represent value types + + // Len is the length of the array. + uint64_t len; // Using uint64_t for length + + // Array holds all array items. + std::vector arr; + /* Appendix */ + + // start_offset is the offset in bytes of the GGUFMetadataKVArrayValue in the GGUFFile file. + int64_t start_offset; // Using int64_t for offset + + // Size is the size of the array in bytes. + int64_t size; // Using int64_t for size +}; + +inline std::string to_string(GGUFMetadataValueType vt, const std::any& v) { + switch (vt) { + case GGUFMetadataValueTypeUint8: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeInt8: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeUint16: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeInt16: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeUint32: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeInt32: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeFloat32: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeBool: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeString: + return std::any_cast(v); + case GGUFMetadataValueTypeUint64: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeInt64: + return std::to_string(std::any_cast(v)); + case GGUFMetadataValueTypeFloat64: + return std::to_string(std::any_cast(v)); + default: + break; + } + return "array"; +} +inline std::string to_string(const GGUFMetadataKVArrayValue& arr_v) { + std::string res; + auto num = std::min(size_t(5), arr_v.arr.size()); + for (size_t i = 0; i < num; i++) { + res += to_string(arr_v.type, arr_v.arr[i]) + " "; + } + return res; +} + +inline std::string to_string(const GGUFMetadataKV& kv) { + switch (kv.value_type) { + case GGUFMetadataValueTypeUint8: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeInt8: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeUint16: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeInt16: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeUint32: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeInt32: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeFloat32: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeBool: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeString: + return std::any_cast(kv.value); + case GGUFMetadataValueTypeUint64: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeInt64: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeFloat64: + return std::to_string(std::any_cast(kv.value)); + case GGUFMetadataValueTypeArray: + return to_string(std::any_cast(kv.value)); + default: + break; + } + return "Invalid type "; +} + + + +struct GGUFTensorInfo { + /* Basic */ + std::string name; + + // NDimensions is the number of dimensions of the tensor. + uint32_t n_dimensions; + // Dimensions is the dimensions of the tensor, + // the length is NDimensions. + std::vector dimensions; + // type is the type of the tensor. + GGMLType type; + // Offset is the offset in bytes of the tensor's data in this file. + // + // The offset is relative to tensor data, not to the start of the file. + uint64_t offset; + + /* Appendix */ + + // StartOffset is the offset in bytes of the GGUFTensorInfo in the GGUFFile file. + // + // The offset is the start of the file. + int64_t start_offset; +}; + +struct GGUFHelper { + uint8_t* data; + uint8_t* d_close; + uint64_t file_size; + + bool OpenAndMMap(const std::string& file_path) { +#ifdef _WIN32 + HANDLE file_handle = INVALID_HANDLE_VALUE; + HANDLE file_mapping = nullptr; + file_handle = + CreateFileA(file_path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + if (file_handle == INVALID_HANDLE_VALUE) { + std::cout << "Failed to open file" << std::endl; + return false; + } + // Get the file size + LARGE_INTEGER file_size_struct; + if (!GetFileSizeEx(file_handle, &file_size_struct)) { + CloseHandle(file_handle); + std::cout << "Failed to open file" << std::endl; + return false; + } + file_size = static_cast(file_size_struct.QuadPart); + + // Create a file mapping object + file_mapping = + CreateFileMappingA(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr); + if (file_mapping == nullptr) { + CloseHandle(file_handle); + std::cout << "Failed to create file mapping" << std::endl; + return false; + } + + // Map the file into memory + data = static_cast( + MapViewOfFile(file_mapping, FILE_MAP_READ, 0, 0, file_size)); + if (data == nullptr) { + CloseHandle(file_mapping); + CloseHandle(file_handle); + std::cout << "Failed to map file" << std::endl; + return false; + } + + // Close the file handle, as it is no longer needed after mapping + CloseHandle(file_handle); + d_close = data; +#else + file_size = std::filesystem::file_size(file_path); + + int fd = open(file_path.c_str(), O_RDONLY); + // Memory-map the file + data = static_cast( + mmap(nullptr, file_size, PROT_READ, MAP_PRIVATE, fd, 0)); + if (data == MAP_FAILED) { + perror("Error mapping file"); + close(fd); + return false; + } + + close(fd); + d_close = data; +#endif + return true; + } + + ~GGUFHelper() { Close(); } + + void Close() { +#ifdef _WIN32 + if (d_close != nullptr) { + UnmapViewOfFile(d_close); + d_close = nullptr; + } +#else + if (d_close != nullptr && d_close != MAP_FAILED) { + munmap(d_close, file_size); + d_close = nullptr; + } +#endif + } + + template + T Read() { + static_assert(std::is_floating_point::value || + std::is_integral::value || std::is_same::value); + T res = *reinterpret_cast(data); + data += sizeof(T); + return res; + } + + std::string ReadString() { + auto l = Read(); + std::string res(reinterpret_cast(data), l); + auto r = res; + data += l; + return r; + } + + GGUFMetadataKVArrayValue ReadArray() { + GGUFMetadataKVArrayValue v; + v.start_offset = (data - d_close); + v.type = static_cast(Read()); + auto arr_length = Read(); + for (uint64_t i = 0; i < arr_length; ++i) { + switch (v.type) { + case GGUFMetadataValueTypeUint8: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeInt8: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeUint16: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeInt16: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeUint32: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeInt32: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeFloat32: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeBool: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeString: + v.arr.push_back(ReadString()); + break; + case GGUFMetadataValueTypeUint64: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeInt64: + v.arr.push_back(Read()); + break; + case GGUFMetadataValueTypeFloat64: + v.arr.push_back(Read()); + break; + default: + std::cout << "Invalid type: " << std::to_string(v.type); + } + } + v.size = data - v.start_offset - d_close - 4 - 8; + return v; + } + + std::any ReadValue(GGUFMetadataValueType vt) { + switch (vt) { + case GGUFMetadataValueTypeUint8: + return Read(); + case GGUFMetadataValueTypeInt8: + return Read(); + case GGUFMetadataValueTypeUint16: + return Read(); + case GGUFMetadataValueTypeInt16: + return Read(); + case GGUFMetadataValueTypeUint32: + return Read(); + case GGUFMetadataValueTypeInt32: + return Read(); + case GGUFMetadataValueTypeFloat32: + return Read(); + case GGUFMetadataValueTypeBool: + return Read(); + case GGUFMetadataValueTypeString: + return ReadString(); + case GGUFMetadataValueTypeArray: + return ReadArray(); + case GGUFMetadataValueTypeUint64: + return Read(); + case GGUFMetadataValueTypeInt64: + return Read(); + case GGUFMetadataValueTypeFloat64: + return Read(); + default: + std::cout << "Invalid type: " << vt; + return {}; + } + } + + GGUFMetadataKV ReadMetadataKV() { + GGUFMetadataKV kv; + kv.key = ReadString(); + auto vt = Read(); + kv.value_type = GGUFMetadataValueType(vt); + kv.value = ReadValue(kv.value_type); + return kv; + } + + std::shared_ptr ReadTensorInfo() { + auto ti = std::make_shared(); + ti->start_offset = data - d_close; + ti->name = ReadString(); + ti->n_dimensions = Read(); + ti->dimensions.resize(ti->n_dimensions); + for (size_t i = 0; i < ti->n_dimensions; i++) { + ti->dimensions[i] = Read(); + } + auto v = Read(); + ti->type = GGMLType(v); + ti->offset = Read(); + return ti; + } +}; + +constexpr const auto ErrGGUFFileInvalidFormat = "invalid GGUF format"; + +struct GGUFHeader { + // Magic is a magic number that announces that this is a GGUF file. + GGUFMagic magic; + // Version is a version of the GGUF file format. + GGUFVersion version; + // TensorCount is the number of tensors in the file. + uint64_t tensor_count; + // MetadataKVCount is the number of key-value pairs in the metadata. + uint64_t metadata_kv_count; + // MetadataKV are the key-value pairs in the metadata, + std::vector metadata_kv; + + std::pair Get(const std::string& name) { + for (auto const& kv : metadata_kv) { + if (kv.key == name) { + return std::pair(kv, true); + } + } + return std::pair(GGUFMetadataKV{}, false); + } +}; + +struct GGUFFile { + // header is the header of the GGUF file. + GGUFHeader header; + // tensor_infos are the tensor infos of the GGUF file, + // the size of TensorInfos is equal to `Header.TensorCount`. + std::vector> tensor_infos; + + // padding is the padding size of the GGUF file, + // which is used to split Header and TensorInfos from tensor data. + int64_t padding; + // split_paddings holds the padding size slice of the GGUF file splits, + // each item represents splitting Header and TensorInfos from tensor data. + // + // The length of split_paddings is the number of split files. + std::vector split_paddings; + // tensor_data_start_offset is the offset in bytes of the tensor data in this file. + // + // The offset is the start of the file. + int64_t tensor_data_start_offset; + // split_tensor_data_start_offsets holds the offset slice in bytes of the tensor data of the GGUF file splits, + // each item represents the offset of the tensor data in the split file. + // + // The length of split_tensor_data_start_offsets is the number of split files. + std::vector split_tensor_data_start_offsets; + + /* Appendix */ + + // size is the size of the GGUF file, + // if the file is split, the size is the sum of all split files. + uint64_t size; + // split_sizes holds the size slice of the GGUF file splits, + // each item represents the size of the split file. + // + // The length of split_sizes is the number of split files. + std::vector split_sizes; + // model_size is the size of the model when loading. + uint64_t model_size; + // split_model_sizes holds the size slice of the model, + // each item represents a size when loading of the split file. + // + // The length of split_model_sizes is the number of split files. + std::vector split_model_sizes; + + // model_parameters is the number of the model parameters. + uint64_t model_parameters; + // model_bits_per_weight is the bits per weight of the model, + // which describes how many bits are used to store a weight, + // higher is better. + double model_bits_per_weight; +}; + +inline GGUFFile ParseGgufFile(const std::string& path) { + GGUFFile gf; + GGUFHelper h; + h.OpenAndMMap(path); + + GGUFMagic magic = h.Read(); + // GGUF_LOG("magic: " << magic); + gf.header.magic = magic; + GGUFVersion version = h.Read(); + auto tensor_count = h.Read(); + // GGUF_LOG("tensor_count: " << tensor_count); + gf.header.tensor_count += tensor_count; + + auto metadata_kv_count = h.Read(); + gf.header.metadata_kv_count += metadata_kv_count; + // GGUF_LOG("metadata_kv_count: " << metadata_kv_count); + + // metadata kv + { + std::vector kvs; + kvs.resize(metadata_kv_count); + for (size_t i = 0; i < metadata_kv_count; i++) { + kvs[i] = h.ReadMetadataKV(); + GGUF_LOG("i: " << i << " " << kvs[i].value_type << " " << kvs[i].key + << ": " << to_string(kvs[i])); + } + for (auto const& kv : kvs) { + if (kv.key == "split.no") { + gf.header.metadata_kv_count--; + continue; + } + gf.header.metadata_kv.push_back(kv); + } + } + + { + std::vector> tis; + tis.resize(tensor_count); + for (size_t i = 0; i < tensor_count; i++) { + tis[i] = h.ReadTensorInfo(); + // auto tto_string = [](const std::vector& ds) -> std::string { + // std::string res = "["; + // for (auto d : ds) + // res += std::to_string(d) + " "; + // return res + "]"; + // }; + // auto ds = tto_string(tis[i]->dimensions); + // GGUF_LOG("i: " << i << " name: " << tis[i]->name + // << " type: " << to_string(tis[i]->type) << " dimensions: " + // << std::to_string(tis[i]->n_dimensions) << " " << ds); + } + gf.tensor_infos = tis; + } + return gf; +} +} // namespace hardware \ No newline at end of file diff --git a/engine/utils/hardware/gguf/gguf_file_estimate.h b/engine/utils/hardware/gguf/gguf_file_estimate.h new file mode 100644 index 000000000..fde0b0ac0 --- /dev/null +++ b/engine/utils/hardware/gguf/gguf_file_estimate.h @@ -0,0 +1,183 @@ +#pragma once +#include +#include +#include "gguf_file.h" +#include "json/json.h" + +namespace hardware { +inline uint64_t BytesToMiB(uint64_t b) { + return (double)b / 1024 / 1024; +}; +struct RunConfig { + int ngl; + int ctx_len; + int n_batch; + int n_ubatch; + std::string kv_cache_type; + int64_t free_vram_MiB; +}; + +struct CpuMode { + int64_t ram_MiB; +}; + +struct GpuMode { + int64_t ram_MiB; + int64_t vram_MiB; + int ngl; + int ctx_len; + int recommend_ngl; +}; + +struct Estimation { + CpuMode cpu_mode; + GpuMode gpu_mode; +}; + +inline Json::Value ToJson(const Estimation& es) { + Json::Value res; + Json::Value cpu; + cpu["ram"] = es.cpu_mode.ram_MiB; + Json::Value gpus(Json::arrayValue); + Json::Value gpu; + gpu["ram"] = es.gpu_mode.ram_MiB; + gpu["vram"] = es.gpu_mode.vram_MiB; + gpu["ngl"] = es.gpu_mode.ngl; + gpu["context_length"] = es.gpu_mode.ctx_len; + gpu["recommend_ngl"] = es.gpu_mode.recommend_ngl; + gpus.append(gpu); + res["cpu_mode"] = cpu; + res["gpu_mode"] = gpus; + return res; +} + +inline float GetQuantBit(const std::string& kv_cache_t) { + if (kv_cache_t == "f16") { + return 16.0; + } else if (kv_cache_t == "q8_0") { + return 8.0; + } else if (kv_cache_t == "q4_0") { + return 4.5; + } + return 16.0; +} + +inline Estimation EstimateLLaMACppRun(const std::string& file_path, + const RunConfig& rc) { + Estimation res; + // token_embeddings_size = n_vocab * embedding_length * 2 * quant_bit/16 bytes + //RAM = token_embeddings_size + ((total_ngl-ngl) >=1 ? Output_layer_size + (total_ngl - ngl - 1 ) / (total_ngl-1) * (total_file_size - token_embeddings_size - Output_layer_size) : 0 ) (bytes) + + // VRAM = total_file_size - RAM (bytes) + auto gf = ParseGgufFile(file_path); + int32_t embedding_length = 0; + int64_t n_vocab = 0; + int32_t num_block = 0; + int32_t total_ngl = 0; + auto file_size = std::filesystem::file_size(file_path); + for (auto const& kv : gf.header.metadata_kv) { + if (kv.key.find("embedding_length") != std::string::npos) { + embedding_length = std::any_cast(kv.value); + } else if (kv.key == "tokenizer.ggml.tokens") { + n_vocab = std::any_cast(kv.value).arr.size(); + } else if (kv.key.find("block_count") != std::string::npos) { + num_block = std::any_cast(kv.value); + total_ngl = num_block + 1; + } + } + + // std::cout << n_vocab << std::endl; + + // token_embeddings_size = n_vocab * embedding_length * 2 * quant_bit_in/16 bytes + int32_t quant_bit_in = 0; + int32_t quant_bit_out = 0; + + for (auto const& ti : gf.tensor_infos) { + if (ti->name == "output.weight") { + quant_bit_out = GetQuantBit(ti->type); + // std::cout << ti->type << std::endl; + } else if (ti->name == "token_embd.weight") { + quant_bit_in = GetQuantBit(ti->type); + // std::cout << ti->type << std::endl; + } + } + // output.weight + // token_embd.weight + // std::cout << "embedding_length: " << embedding_length << std::endl; + // std::cout << "n_vocab: " << n_vocab << std::endl; + // std::cout << "file_size: " << file_size << std::endl; + // Model weight + int64_t token_embeddings_size = + n_vocab * embedding_length * 2 * quant_bit_in / 16; + int64_t output_layer_size = + n_vocab * embedding_length * 2 * quant_bit_out / 16; + // RAM = token_embeddings_size + ((total_ngl-ngl) >=1 ? output_layer_size + (total_ngl - ngl - 1 ) / (total_ngl-1) * (total_file_size - token_embeddings_size - output_layer_size) : 0 ) (bytes) + int64_t offload = 0; + if (total_ngl >= rc.ngl + 1) { + offload = output_layer_size + + (double)(total_ngl - rc.ngl - 1) / (total_ngl - 1) * + (file_size - token_embeddings_size - output_layer_size); + } + + int64_t ram_usage = token_embeddings_size + offload; + int64_t vram_usage = file_size - ram_usage; + // std::cout << "token_embeddings_size: " << BytesToMiB(token_embeddings_size) + // << std::endl; + // std::cout << "output_layer_size: " << BytesToMiB(output_layer_size) + // << std::endl; + // std::cout << "ram_usage: " << BytesToMiB(ram_usage) << std::endl; + // std::cout << "vram_usage: " << BytesToMiB(vram_usage) << std::endl; + + // KV cache + // kv_cache_size = ctx_len/8192 * hidden_dim/4096 * quant_bit/16 * num_block/33 * 1 (GB) + auto hidden_dim = embedding_length; + int kv_quant_bit = + GetQuantBit(rc.kv_cache_type); // f16, 8 bits for q8_0, 4.5 bits for q4_0 + int64_t kv_cache_size = (double)(1024 * 1024 * 1024) * rc.ctx_len / 8192 * + hidden_dim / 4096 * kv_quant_bit / 16 * num_block / + 33; //(bytes) + + // std::cout << "kv_cache_size: " << BytesToMiB(kv_cache_size) << std::endl; + + // VRAM = (min(n_batch, n_ubatch))/ 512 * 266 (MiB) + int64_t preprocessing_buffer_size = + (double)std::min(rc.n_batch, rc.n_ubatch) / 512 * 266 * 1024 * 1024 * + n_vocab / 128256 /*llama3 n_vocab*/; //(bytes) + if (total_ngl != rc.ngl) { + preprocessing_buffer_size += output_layer_size; + } + // std::cout << "preprocessing_buffer_size: " + // << BytesToMiB(preprocessing_buffer_size) << std::endl; + + // CPU mode + { + // Model weight + int64_t model_weight = file_size; + // KV cache + // Buffer + res.cpu_mode.ram_MiB = + BytesToMiB(model_weight + kv_cache_size + preprocessing_buffer_size); + } + // GPU mode + { + res.gpu_mode.ctx_len = rc.ctx_len; + res.gpu_mode.ngl = rc.ngl; + res.gpu_mode.ram_MiB = BytesToMiB(ram_usage); + // We also need to reserve extra 100 MiB -200 MiB of Ram for some small buffers during processing + constexpr const int64_t kDeltaVramMiB = 200; + res.gpu_mode.vram_MiB = + kDeltaVramMiB + + BytesToMiB(vram_usage + kv_cache_size + preprocessing_buffer_size); + if (rc.free_vram_MiB > res.gpu_mode.vram_MiB) { + res.gpu_mode.recommend_ngl = total_ngl; + } else { + res.gpu_mode.recommend_ngl = + (double)rc.free_vram_MiB / res.gpu_mode.vram_MiB * rc.ngl; + } +#if defined(__APPLE__) && defined(__MACH__) + res.cpu_mode.ram_MiB = res.gpu_mode.vram_MiB + res.gpu_mode.ram_MiB; +#endif + } + return res; +} +} // namespace hardware \ No newline at end of file From 73bbd741350820d2e23946f5f43bec399b15f066 Mon Sep 17 00:00:00 2001 From: James Date: Mon, 2 Dec 2024 11:20:39 +0700 Subject: [PATCH 05/44] fix: floating point for models endpoint --- engine/controllers/models.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index af8061269..2760663d0 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -12,6 +12,14 @@ #include "utils/logging_utils.h" #include "utils/string_utils.h" +namespace { +std::string ToJsonStringWithPrecision(Json::Value& input, int precision = 2) { + Json::StreamWriterBuilder wbuilder; + wbuilder.settings_["precision"] = 2; + return Json::writeString(wbuilder, input); +} +} // namespace + void Models::PullModel(const HttpRequestPtr& req, std::function&& callback) { if (!http_util::HasFieldInReq(req, callback, "model")) { @@ -182,9 +190,11 @@ void Models::ListModel( << model_entry.path_to_model_yaml << ", error: " << e.what(); } } + ret["data"] = data; ret["result"] = "OK"; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + auto ret_str = ToJsonStringWithPrecision(ret); + auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret_str); resp->setStatusCode(k200OK); callback(resp); } else { From 3bdf8fa1cf0cc83063dee14e1d5ef4e827ae1bb4 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 2 Dec 2024 12:58:21 +0700 Subject: [PATCH 06/44] fix: get driver version and cuda version at a single command (#1754) Co-authored-by: vansangpfiev --- engine/cli/commands/engine_install_cmd.cc | 6 ++- engine/cli/commands/engine_install_cmd.h | 3 +- engine/cli/commands/engine_update_cmd.cc | 3 +- engine/cli/commands/server_start_cmd.cc | 2 +- engine/services/engine_service.h | 3 +- engine/services/hardware_service.cc | 4 +- engine/utils/hardware/gpu_info.h | 3 +- engine/utils/system_info_utils.h | 54 +++++++++-------------- 8 files changed, 36 insertions(+), 42 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 477e38ee2..21cd9f042 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -37,7 +37,8 @@ bool EngineInstallCmd::Exec(const std::string& engine, dp.Connect(host_, port_); // engine can be small, so need to start ws first auto dp_res = std::async(std::launch::deferred, [&dp] { - bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + bool need_cuda_download = + !system_info_utils::GetDriverAndCudaVersion().second.empty(); if (need_cuda_download) { return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); } else { @@ -149,7 +150,8 @@ bool EngineInstallCmd::Exec(const std::string& engine, dp.Connect(host_, port_); // engine can be small, so need to start ws first auto dp_res = std::async(std::launch::deferred, [&dp] { - bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + bool need_cuda_download = + !system_info_utils::GetDriverAndCudaVersion().second.empty(); if (need_cuda_download) { return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); } else { diff --git a/engine/cli/commands/engine_install_cmd.h b/engine/cli/commands/engine_install_cmd.h index deb9197e1..d50776dc4 100644 --- a/engine/cli/commands/engine_install_cmd.h +++ b/engine/cli/commands/engine_install_cmd.h @@ -14,7 +14,8 @@ class EngineInstallCmd { port_(port), show_menu_(show_menu), hw_inf_{.sys_inf = system_info_utils::GetSystemInfo(), - .cuda_driver_version = system_info_utils::GetCudaVersion()} {}; + .cuda_driver_version = + system_info_utils::GetDriverAndCudaVersion().second} {}; bool Exec(const std::string& engine, const std::string& version = "latest", const std::string& src = ""); diff --git a/engine/cli/commands/engine_update_cmd.cc b/engine/cli/commands/engine_update_cmd.cc index 9717ddb15..a86106ed2 100644 --- a/engine/cli/commands/engine_update_cmd.cc +++ b/engine/cli/commands/engine_update_cmd.cc @@ -25,7 +25,8 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port, dp.Connect(host, port); // engine can be small, so need to start ws first auto dp_res = std::async(std::launch::deferred, [&dp] { - bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + bool need_cuda_download = + !system_info_utils::GetDriverAndCudaVersion().second.empty(); if (need_cuda_download) { return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); } else { diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 6f36515f1..cfed72c24 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -8,7 +8,7 @@ namespace commands { namespace { bool TryConnectToServer(const std::string& host, int port) { - constexpr const auto kMaxRetry = 3u; + constexpr const auto kMaxRetry = 4u; auto count = 0u; // Check if server is started while (true) { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index a18a276cd..47d7c272f 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -62,7 +62,8 @@ class EngineService : public EngineServiceI { explicit EngineService(std::shared_ptr download_service) : download_service_{download_service}, hw_inf_{.sys_inf = system_info_utils::GetSystemInfo(), - .cuda_driver_version = system_info_utils::GetCudaVersion()} {} + .cuda_driver_version = + system_info_utils::GetDriverAndCudaVersion().second} {} std::vector GetEngineInfoList() const; diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index 16ae234b4..a6ceb556f 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -16,7 +16,7 @@ namespace services { namespace { bool TryConnectToServer(const std::string& host, int port) { - constexpr const auto kMaxRetry = 3u; + constexpr const auto kMaxRetry = 4u; auto count = 0u; // Check if server is started while (true) { @@ -292,7 +292,7 @@ void HardwareService::UpdateHardwareInfos() { } #if defined(_WIN32) || defined(_WIN64) || defined(__linux__) - if (system_info_utils::IsNvidiaSmiAvailable()) { + if (!gpus.empty()) { const char* value = std::getenv("CUDA_VISIBLE_DEVICES"); if (value) { LOG_INFO << "CUDA_VISIBLE_DEVICES: " << value; diff --git a/engine/utils/hardware/gpu_info.h b/engine/utils/hardware/gpu_info.h index bbd4a49d6..1e10589a9 100644 --- a/engine/utils/hardware/gpu_info.h +++ b/engine/utils/hardware/gpu_info.h @@ -11,12 +11,11 @@ inline std::vector GetGPUInfo() { // Only support for nvidia for now // auto gpus = hwinfo::getAllGPUs(); auto nvidia_gpus = system_info_utils::GetGpuInfoList(); - auto cuda_version = system_info_utils::GetCudaVersion(); for (auto& n : nvidia_gpus) { res.emplace_back( GPU{.id = n.id, .name = n.name, - .version = cuda_version, + .version = nvidia_gpus[0].cuda_driver_version.value_or("unknown"), .add_info = NvidiaAddInfo{ .driver_version = n.driver_version.value_or("unknown"), diff --git a/engine/utils/system_info_utils.h b/engine/utils/system_info_utils.h index 013069699..f2fab10cb 100644 --- a/engine/utils/system_info_utils.h +++ b/engine/utils/system_info_utils.h @@ -19,7 +19,8 @@ constexpr static auto kUnsupported{"Unsupported"}; constexpr static auto kCudaVersionRegex{R"(CUDA Version:\s*([\d\.]+))"}; constexpr static auto kDriverVersionRegex{R"(Driver Version:\s*(\d+\.\d+))"}; constexpr static auto kGpuQueryCommand{ - "nvidia-smi --query-gpu=index,memory.total,memory.free,name,compute_cap,uuid " + "nvidia-smi " + "--query-gpu=index,memory.total,memory.free,name,compute_cap,uuid " "--format=csv,noheader,nounits"}; constexpr static auto kGpuInfoRegex{ R"((\d+),\s*(\d+),\s*(\d+),\s*([^,]+),\s*([\d\.]+),\s*([^\n,]+))"}; @@ -100,53 +101,42 @@ inline bool IsNvidiaSmiAvailable() { #endif } -inline std::string GetDriverVersion() { +inline std::pair GetDriverAndCudaVersion() { if (!IsNvidiaSmiAvailable()) { CTL_INF("nvidia-smi is not available!"); - return ""; + return {}; } try { + std::string driver_version; + std::string cuda_version; CommandExecutor cmd("nvidia-smi"); auto output = cmd.execute(); const std::regex driver_version_reg(kDriverVersionRegex); - std::smatch match; + std::smatch driver_match; - if (std::regex_search(output, match, driver_version_reg)) { - LOG_INFO << "Gpu Driver Version: " << match[1].str(); - return match[1].str(); + if (std::regex_search(output, driver_match, driver_version_reg)) { + LOG_INFO << "Gpu Driver Version: " << driver_match[1].str(); + driver_version = driver_match[1].str(); } else { LOG_ERROR << "Gpu Driver not found!"; - return ""; + return {}; } - } catch (const std::exception& e) { - LOG_ERROR << "Error: " << e.what(); - return ""; - } -} - -inline std::string GetCudaVersion() { - if (!IsNvidiaSmiAvailable()) { - CTL_INF("nvidia-smi is not available!"); - return ""; - } - try { - CommandExecutor cmd("nvidia-smi"); - auto output = cmd.execute(); const std::regex cuda_version_reg(kCudaVersionRegex); - std::smatch match; + std::smatch cuda_match; - if (std::regex_search(output, match, cuda_version_reg)) { - LOG_INFO << "CUDA Version: " << match[1].str(); - return match[1].str(); + if (std::regex_search(output, cuda_match, cuda_version_reg)) { + LOG_INFO << "CUDA Version: " << cuda_match[1].str(); + cuda_version = cuda_match[1].str(); } else { LOG_ERROR << "CUDA Version not found!"; - return ""; + return {}; } + return std::pair(driver_version, cuda_version); } catch (const std::exception& e) { LOG_ERROR << "Error: " << e.what(); - return ""; + return {}; } } @@ -227,9 +217,9 @@ inline std::vector GetGpuInfoList() { if (!IsNvidiaSmiAvailable()) return gpuInfoList; try { - // TODO: improve by parsing both in one command execution - auto driver_version = GetDriverVersion(); - auto cuda_version = GetCudaVersion(); + auto [driver_version, cuda_version] = GetDriverAndCudaVersion(); + if (driver_version.empty() || cuda_version.empty()) + return gpuInfoList; CommandExecutor cmd(kGpuQueryCommand); auto output = cmd.execute(); @@ -249,7 +239,7 @@ inline std::vector GetGpuInfoList() { driver_version, // driver_version cuda_version, // cuda_driver_version match[5].str(), // compute_cap - match[6].str() // uuid + match[6].str() // uuid }; gpuInfoList.push_back(gpuInfo); search_start = match.suffix().first; From 9224be00bbef9a89443c6a7f876f320746c0d4d1 Mon Sep 17 00:00:00 2001 From: Gabrielle Ong Date: Mon, 2 Dec 2024 18:42:46 +0800 Subject: [PATCH 07/44] docs: fix tools should be array not object --- docs/static/openapi/cortex.json | 39 ++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 78430294f..206ee381d 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -2224,8 +2224,18 @@ "default": [], "type": "array", "items": { - "type": "object" - } + "type": "array", + "properties": { + "type": { + "type": "string", + "enum": ["function"] + }, + "function": { + "$ref": "#/components/schemas/Function" + } + }, + "required": ["type", "function"] + }, }, "metadata": { "type": "object", @@ -2286,7 +2296,7 @@ "nullable": true }, "tools": { - "type": "object" + "type": "array" }, "metadata": { "type": "object", @@ -2869,17 +2879,20 @@ } }, "tools": { - "type": "object", - "properties": { - "type": { - "type": "string", - "enum": ["function"] + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["function"] + }, + "function": { + "$ref": "#/components/schemas/Function" + } }, - "function": { - "$ref": "#/components/schemas/Function" - } - }, - "required": ["type", "function"] + "required": ["type", "function"] + } }, "tool_choice": { "anyOf": [ From 5eda21268931939c0c4c031ceec286ec293b9240 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 3 Dec 2024 18:09:09 +0700 Subject: [PATCH 08/44] chore: refactor utils (#1760) * chore: config_yaml_utils * chore: file_manager_utils * chore: curl_utils * chore: system_info_utils * chore: clean e2e tests * fix: build macos * fix: docker e2e tests * fix: e2e docker * fix: e2e tests --------- Co-authored-by: vansangpfiev --- docker/entrypoint.sh | 2 +- engine/CMakeLists.txt | 3 +- engine/cli/CMakeLists.txt | 4 + engine/cli/commands/engine_list_cmd.cc | 1 + engine/cli/commands/ps_cmd.cc | 1 + engine/cli/commands/server_start_cmd.cc | 5 + engine/cli/main.cc | 2 +- engine/e2e-test/main.py | 16 +- ...engine_uninstall.py => test_api_engine.py} | 44 ++- engine/e2e-test/test_api_engine_get.py | 22 -- engine/e2e-test/test_api_engine_install.py | 36 -- engine/e2e-test/test_api_engine_list.py | 2 +- ...l_pull_direct_url.py => test_api_model.py} | 89 ++++- engine/e2e-test/test_api_model_delete.py | 22 -- engine/e2e-test/test_api_model_get.py | 22 -- engine/e2e-test/test_api_model_list.py | 22 -- engine/e2e-test/test_api_model_start_stop.py | 46 --- engine/e2e-test/test_api_model_update.py | 23 -- ..._cli_model_delete.py => test_cli_model.py} | 22 +- .../test_cli_model_pull_direct_url.py | 32 -- engine/test/components/CMakeLists.txt | 4 + engine/utils/config_yaml_utils.cc | 177 +++++++++ engine/utils/config_yaml_utils.h | 172 +------- engine/utils/curl_utils.cc | 321 +++++++++++++++ engine/utils/curl_utils.h | 326 +--------------- engine/utils/file_manager_utils.cc | 367 ++++++++++++++++++ engine/utils/file_manager_utils.h | 366 +---------------- engine/utils/huggingface_utils.h | 1 + engine/utils/system_info_utils.cc | 141 +++++++ engine/utils/system_info_utils.h | 139 +------ 30 files changed, 1198 insertions(+), 1232 deletions(-) rename engine/e2e-test/{test_api_engine_uninstall.py => test_api_engine.py} (60%) delete mode 100644 engine/e2e-test/test_api_engine_get.py delete mode 100644 engine/e2e-test/test_api_engine_install.py rename engine/e2e-test/{test_api_model_pull_direct_url.py => test_api_model.py} (53%) delete mode 100644 engine/e2e-test/test_api_model_delete.py delete mode 100644 engine/e2e-test/test_api_model_get.py delete mode 100644 engine/e2e-test/test_api_model_list.py delete mode 100644 engine/e2e-test/test_api_model_start_stop.py delete mode 100644 engine/e2e-test/test_api_model_update.py rename engine/e2e-test/{test_cli_model_delete.py => test_cli_model.py} (58%) delete mode 100644 engine/e2e-test/test_cli_model_pull_direct_url.py create mode 100644 engine/utils/config_yaml_utils.cc create mode 100644 engine/utils/curl_utils.cc create mode 100644 engine/utils/file_manager_utils.cc create mode 100644 engine/utils/system_info_utils.cc diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 6f0493ec2..99bdd0009 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -7,10 +7,10 @@ echo "enableCors: true" >> /root/.cortexrc # Install the engine cortex engines install llama-cpp -s /opt/cortex.llamacpp -cortex engines list # Start the cortex server cortex start +cortex engines list # Keep the container running by tailing the log files tail -f /root/cortexcpp/logs/cortex.log & diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 5ffabf23c..06e778b7e 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -176,10 +176,11 @@ aux_source_directory(cortex-common CORTEX_COMMON) aux_source_directory(config CONFIG_SRC) aux_source_directory(database DB_SRC) aux_source_directory(migrations MIGR_SRC) +aux_source_directory(utils UTILS_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index db2bed828..42d00ebd5 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -84,6 +84,10 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../utils/config_yaml_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../utils/file_manager_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../utils/curl_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../utils/system_info_utils.cc ) target_link_libraries(${TARGET_NAME} PRIVATE CLI11::CLI11) diff --git a/engine/cli/commands/engine_list_cmd.cc b/engine/cli/commands/engine_list_cmd.cc index b010e8687..35584dcd2 100644 --- a/engine/cli/commands/engine_list_cmd.cc +++ b/engine/cli/commands/engine_list_cmd.cc @@ -4,6 +4,7 @@ #include "common/engine_servicei.h" #include "server_start_cmd.h" #include "utils/curl_utils.h" +#include "utils/engine_constants.h" #include "utils/logging_utils.h" #include "utils/url_parser.h" // clang-format off diff --git a/engine/cli/commands/ps_cmd.cc b/engine/cli/commands/ps_cmd.cc index c692ffc00..24ef497c6 100644 --- a/engine/cli/commands/ps_cmd.cc +++ b/engine/cli/commands/ps_cmd.cc @@ -2,6 +2,7 @@ #include #include #include "utils/curl_utils.h" +#include "utils/engine_constants.h" #include "utils/format_utils.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index cfed72c24..ba4f7bd82 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -1,6 +1,7 @@ #include "server_start_cmd.h" #include "commands/cortex_upd_cmd.h" #include "utils/cortex_utils.h" +#include "utils/engine_constants.h" #include "utils/file_manager_utils.h" #include "utils/widechar_conv.h" @@ -27,6 +28,10 @@ bool TryConnectToServer(const std::string& host, int port) { bool ServerStartCmd::Exec(const std::string& host, int port, const std::optional& log_level) { + if (IsServerAlive(host, port)) { + CLI_LOG("The server has already started"); + return true; + } std::string log_level_; if (!log_level.has_value()) { log_level_ = "INFO"; diff --git a/engine/cli/main.cc b/engine/cli/main.cc index 52fc5591f..8ed4beb61 100644 --- a/engine/cli/main.cc +++ b/engine/cli/main.cc @@ -148,7 +148,6 @@ int main(int argc, char* argv[]) { if (should_check_for_latest_llamacpp_version) { std::thread t1([]() { - auto config = file_manager_utils::GetCortexConfig(); // TODO: namh current we only check for llamacpp. Need to add support for other engine auto get_latest_version = []() -> cpp::result { try { @@ -176,6 +175,7 @@ int main(int argc, char* argv[]) { auto now = std::chrono::system_clock::now(); CTL_DBG("latest llama.cpp version: " << res.value()); + auto config = file_manager_utils::GetCortexConfig(); config.checkedForLlamacppUpdateAt = std::chrono::duration_cast( now.time_since_epoch()) diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py index 9ef2970f9..e874ab3a0 100644 --- a/engine/e2e-test/main.py +++ b/engine/e2e-test/main.py @@ -3,26 +3,16 @@ ### e2e tests are expensive, have to keep engines tests in order from test_api_engine_list import TestApiEngineList -from test_api_engine_install import TestApiEngineInstall -from test_api_engine_get import TestApiEngineGet - -### models, keeps in order, note that we only uninstall engine after finishing all models test -from test_api_model_pull_direct_url import TestApiModelPullDirectUrl -from test_api_model_start_stop import TestApiModelStartStop -from test_api_model_get import TestApiModelGet -from test_api_model_list import TestApiModelList -from test_api_model_update import TestApiModelUpdate -from test_api_model_delete import TestApiModelDelete +from test_api_engine import TestApiEngine +from test_api_model import TestApiModel from test_api_model_import import TestApiModelImport -from test_api_engine_uninstall import TestApiEngineUninstall ### from test_cli_engine_get import TestCliEngineGet from test_cli_engine_install import TestCliEngineInstall from test_cli_engine_list import TestCliEngineList from test_cli_engine_uninstall import TestCliEngineUninstall -from test_cli_model_delete import TestCliModelDelete -from test_cli_model_pull_direct_url import TestCliModelPullDirectUrl +from test_cli_model import TestCliModel from test_cli_server_start import TestCliServerStart from test_cortex_update import TestCortexUpdate from test_create_log_folder import TestCreateLogFolder diff --git a/engine/e2e-test/test_api_engine_uninstall.py b/engine/e2e-test/test_api_engine.py similarity index 60% rename from engine/e2e-test/test_api_engine_uninstall.py rename to engine/e2e-test/test_api_engine.py index 1951e5c3a..57b47b879 100644 --- a/engine/e2e-test/test_api_engine_uninstall.py +++ b/engine/e2e-test/test_api_engine.py @@ -1,29 +1,49 @@ -import time - import pytest import requests +import time from test_runner import ( - run, - start_server_if_needed, + start_server, stop_server, wait_for_websocket_download_success_event, ) - -class TestApiEngineUninstall: +class TestApiEngine: @pytest.fixture(autouse=True) def setup_and_teardown(self): # Setup - start_server_if_needed() + success = start_server() + if not success: + raise Exception("Failed to start server") yield # Teardown stop_server() + + # engines get + def test_engines_get_llamacpp_should_be_successful(self): + response = requests.get("http://localhost:3928/engines/llama-cpp") + assert response.status_code == 200 + + # engines install + def test_engines_install_llamacpp_specific_version_and_variant(self): + data = {"version": "v0.1.35-27.10.24", "variant": "linux-amd64-avx-cuda-11-7"} + response = requests.post( + "http://localhost:3928/v1/engines/llama-cpp/install", json=data + ) + assert response.status_code == 200 + def test_engines_install_llamacpp_specific_version_and_null_variant(self): + data = {"version": "v0.1.35-27.10.24"} + response = requests.post( + "http://localhost:3928/v1/engines/llama-cpp/install", json=data + ) + assert response.status_code == 200 + + # engines uninstall @pytest.mark.asyncio - async def test_engines_uninstall_llamacpp_should_be_successful(self): + async def test_engines_install_uninstall_llamacpp_should_be_successful(self): response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") assert response.status_code == 200 await wait_for_websocket_download_success_event(timeout=None) @@ -33,7 +53,7 @@ async def test_engines_uninstall_llamacpp_should_be_successful(self): assert response.status_code == 200 @pytest.mark.asyncio - async def test_engines_uninstall_llamacpp_with_only_version_should_be_failed(self): + async def test_engines_install_uninstall_llamacpp_with_only_version_should_be_failed(self): # install first data = {"variant": "mac-arm64"} install_response = requests.post( @@ -50,7 +70,7 @@ async def test_engines_uninstall_llamacpp_with_only_version_should_be_failed(sel assert response.json()["message"] == "No variant provided" @pytest.mark.asyncio - async def test_engines_uninstall_llamacpp_with_variant_should_be_successful(self): + async def test_engines_install_uninstall_llamacpp_with_variant_should_be_successful(self): # install first data = {"variant": "mac-arm64"} install_response = requests.post( @@ -62,7 +82,7 @@ async def test_engines_uninstall_llamacpp_with_variant_should_be_successful(self response = requests.delete("http://127.0.0.1:3928/v1/engines/llama-cpp/install") assert response.status_code == 200 - def test_engines_uninstall_llamacpp_with_specific_variant_and_version_should_be_successful( + def test_engines_install_uninstall_llamacpp_with_specific_variant_and_version_should_be_successful( self, ): data = {"variant": "mac-arm64", "version": "v0.1.35"} @@ -76,3 +96,5 @@ def test_engines_uninstall_llamacpp_with_specific_variant_and_version_should_be_ "http://localhost:3928/v1/engines/llama-cpp/install", json=data ) assert response.status_code == 200 + + \ No newline at end of file diff --git a/engine/e2e-test/test_api_engine_get.py b/engine/e2e-test/test_api_engine_get.py deleted file mode 100644 index baa9c8037..000000000 --- a/engine/e2e-test/test_api_engine_get.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -import requests -from test_runner import start_server, stop_server - - -class TestApiEngineGet: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - # Teardown - stop_server() - - def test_engines_get_llamacpp_should_be_successful(self): - response = requests.get("http://localhost:3928/engines/llama-cpp") - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_engine_install.py b/engine/e2e-test/test_api_engine_install.py deleted file mode 100644 index aabe0138d..000000000 --- a/engine/e2e-test/test_api_engine_install.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -import requests -from test_runner import start_server, stop_server - - -class TestApiEngineInstall: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - # Teardown - stop_server() - - def test_engines_install_llamacpp_should_be_successful(self): - response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") - assert response.status_code == 200 - - def test_engines_install_llamacpp_specific_version_and_variant(self): - data = {"version": "v0.1.35-27.10.24", "variant": "linux-amd64-avx-cuda-11-7"} - response = requests.post( - "http://localhost:3928/v1/engines/llama-cpp/install", json=data - ) - assert response.status_code == 200 - - def test_engines_install_llamacpp_specific_version_and_null_variant(self): - data = {"version": "v0.1.35-27.10.24"} - response = requests.post( - "http://localhost:3928/v1/engines/llama-cpp/install", json=data - ) - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_engine_list.py b/engine/e2e-test/test_api_engine_list.py index 71b9ea8b4..f149f1450 100644 --- a/engine/e2e-test/test_api_engine_list.py +++ b/engine/e2e-test/test_api_engine_list.py @@ -22,4 +22,4 @@ def setup_and_teardown(self): def test_engines_list_api_run_successfully(self): response = requests.get("http://localhost:3928/engines") - assert response.status_code == 200 + assert response.status_code == 200 \ No newline at end of file diff --git a/engine/e2e-test/test_api_model_pull_direct_url.py b/engine/e2e-test/test_api_model.py similarity index 53% rename from engine/e2e-test/test_api_model_pull_direct_url.py rename to engine/e2e-test/test_api_model.py index 604f216f8..c2723d2ca 100644 --- a/engine/e2e-test/test_api_model_pull_direct_url.py +++ b/engine/e2e-test/test_api_model.py @@ -1,5 +1,6 @@ import pytest import requests +import time from test_runner import ( run, start_server, @@ -7,27 +8,22 @@ wait_for_websocket_download_success_event, ) - -class TestApiModelPullDirectUrl: +class TestApiModel: @pytest.fixture(autouse=True) def setup_and_teardown(self): - # Setup - stop_server() + # Setup success = start_server() if not success: raise Exception("Failed to start server") # Delete model if exists - run( - "Delete model", - [ - "models", - "delete", - "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", - ], - ) yield # Teardown + stop_server() + + # Pull with direct url + @pytest.mark.asyncio + async def test_model_pull_with_direct_url_should_be_success(self): run( "Delete model", [ @@ -36,10 +32,7 @@ def setup_and_teardown(self): "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", ], ) - stop_server() - - @pytest.mark.asyncio - async def test_model_pull_with_direct_url_should_be_success(self): + myobj = { "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf" } @@ -54,6 +47,15 @@ async def test_model_pull_with_direct_url_should_be_success(self): get_model_response.json()["model"] == "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf" ) + + run( + "Delete model", + [ + "models", + "delete", + "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", + ], + ) @pytest.mark.asyncio async def test_model_pull_with_direct_url_should_have_desired_name(self): @@ -73,3 +75,58 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self): get_model_response.json()["name"] == "smol_llama_100m" ) + + run( + "Delete model", + [ + "models", + "delete", + "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", + ], + ) + + async def test_models_start_stop_should_be_successful(self): + print("Install engine") + response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") + assert response.status_code == 200 + await wait_for_websocket_download_success_event(timeout=None) + # TODO(sang) need to fix for cuda download + time.sleep(30) + + print("Pull model") + json_body = {"model": "tinyllama:gguf"} + response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) + assert response.status_code == 200, f"Failed to pull model: tinyllama:gguf" + await wait_for_websocket_download_success_event(timeout=None) + + # get API + print("Get model") + response = requests.get("http://localhost:3928/v1/models/tinyllama:gguf") + assert response.status_code == 200 + + # list API + print("List model") + response = requests.get("http://localhost:3928/v1/models") + assert response.status_code == 200 + + print("Start model") + json_body = {"model": "tinyllama:gguf"} + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) + assert response.status_code == 200, f"status_code: {response.status_code}" + + print("Stop model") + response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) + assert response.status_code == 200, f"status_code: {response.status_code}" + + # update API + print("Update model") + body_json = {'model': 'tinyllama:gguf'} + response = requests.patch("http://localhost:3928/v1/models/tinyllama:gguf", json = body_json) + assert response.status_code == 200 + + # delete API + print("Delete model") + response = requests.delete("http://localhost:3928/v1/models/tinyllama:gguf") + assert response.status_code == 200 \ No newline at end of file diff --git a/engine/e2e-test/test_api_model_delete.py b/engine/e2e-test/test_api_model_delete.py deleted file mode 100644 index 455032a9b..000000000 --- a/engine/e2e-test/test_api_model_delete.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -import requests -from test_runner import start_server, stop_server - - -class TestApiModelDelete: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - # Teardown - stop_server() - - def test_models_delete_should_be_successful(self): - response = requests.delete("http://localhost:3928/v1/models/tinyllama:gguf") - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model_get.py b/engine/e2e-test/test_api_model_get.py deleted file mode 100644 index dd58ca2a4..000000000 --- a/engine/e2e-test/test_api_model_get.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -import requests -from test_runner import popen, run -from test_runner import start_server, stop_server - - -class TestApiModelGet: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - stop_server() - - def test_models_get_should_be_successful(self): - response = requests.get("http://localhost:3928/v1/models/tinyllama:gguf") - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model_list.py b/engine/e2e-test/test_api_model_list.py deleted file mode 100644 index 5e2a4b901..000000000 --- a/engine/e2e-test/test_api_model_list.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -import requests -from test_runner import start_server, stop_server - - -class TestApiModelList: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - # Teardown - stop_server() - - def test_models_list_should_be_successful(self): - response = requests.get("http://localhost:3928/v1/models") - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model_start_stop.py b/engine/e2e-test/test_api_model_start_stop.py deleted file mode 100644 index 78c20e8da..000000000 --- a/engine/e2e-test/test_api_model_start_stop.py +++ /dev/null @@ -1,46 +0,0 @@ -import time - -import pytest -import requests -from test_runner import ( - run, - start_server_if_needed, - stop_server, - wait_for_websocket_download_success_event, -) - - -class TestApiModelStartStop: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - start_server_if_needed() - run("Delete model", ["models", "delete", "tinyllama:gguf"]) - - yield - - # Teardown - stop_server() - - @pytest.mark.asyncio - async def test_models_start_should_be_successful(self): - response = requests.post("http://localhost:3928/v1/engines/llama-cpp/install") - assert response.status_code == 200 - await wait_for_websocket_download_success_event(timeout=None) - # TODO(sang) need to fix for cuda download - time.sleep(30) - - json_body = {"model": "tinyllama:gguf"} - response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) - assert response.status_code == 200, f"Failed to pull model: tinyllama:gguf" - await wait_for_websocket_download_success_event(timeout=None) - - json_body = {"model": "tinyllama:gguf"} - response = requests.post( - "http://localhost:3928/v1/models/start", json=json_body - ) - assert response.status_code == 200, f"status_code: {response.status_code}" - - response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) - assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_api_model_update.py b/engine/e2e-test/test_api_model_update.py deleted file mode 100644 index f862c8907..000000000 --- a/engine/e2e-test/test_api_model_update.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest -import requests -from test_runner import popen, run -from test_runner import start_server, stop_server - - -class TestApiModelUpdate: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - stop_server() - - def test_models_update_should_be_successful(self): - body_json = {'model': 'tinyllama:gguf'} - response = requests.patch("http://localhost:3928/v1/models/tinyllama:gguf", json = body_json) - assert response.status_code == 200 diff --git a/engine/e2e-test/test_cli_model_delete.py b/engine/e2e-test/test_cli_model.py similarity index 58% rename from engine/e2e-test/test_cli_model_delete.py rename to engine/e2e-test/test_cli_model.py index 06cc3a4c3..f6aad4ae9 100644 --- a/engine/e2e-test/test_cli_model_delete.py +++ b/engine/e2e-test/test_cli_model.py @@ -1,5 +1,7 @@ import pytest import requests +import os +from pathlib import Path from test_runner import ( run, start_server, @@ -7,8 +9,7 @@ wait_for_websocket_download_success_event, ) - -class TestCliModelDelete: +class TestCliModel: @pytest.fixture(autouse=True) def setup_and_teardown(self): @@ -23,7 +24,20 @@ def setup_and_teardown(self): # Clean up run("Delete model", ["models", "delete", "tinyllama:gguf"]) stop_server() - + + def test_model_pull_with_direct_url_should_be_success(self): + exit_code, output, error = run( + "Pull model", + [ + "pull", + "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", + ], + timeout=None, capture=False + ) + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "models" / "huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf") + assert exit_code == 0, f"Model pull failed with error: {error}" + @pytest.mark.asyncio async def test_models_delete_should_be_successful(self): json_body = {"model": "tinyllama:gguf"} @@ -35,4 +49,4 @@ async def test_models_delete_should_be_successful(self): "Delete model", ["models", "delete", "tinyllama:gguf"] ) assert "Model tinyllama:gguf deleted successfully" in output - assert exit_code == 0, f"Model does not exist: {error}" + assert exit_code == 0, f"Model does not exist: {error}" \ No newline at end of file diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py deleted file mode 100644 index b10d1593d..000000000 --- a/engine/e2e-test/test_cli_model_pull_direct_url.py +++ /dev/null @@ -1,32 +0,0 @@ -from test_runner import run -import os -from pathlib import Path - -class TestCliModelPullDirectUrl: - - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - # Teardown - stop_server() - - def test_model_pull_with_direct_url_should_be_success(self): - exit_code, output, error = run( - "Pull model", - [ - "pull", - "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", - ], - timeout=None, capture=False - ) - root = Path.home() - assert os.path.exists(root / "cortexcpp" / "models" / "huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf") - assert exit_code == 0, f"Model pull failed with error: {error}" - # TODO: verify that the model has been pull successfully - # TODO: skip this test. since download model is taking too long - diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 4a15b7c8b..58c5d83d6 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -12,6 +12,10 @@ add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/../../services/config_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../services/download_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../database/models.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/config_yaml_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/file_manager_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/curl_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/system_info_utils.cc ) find_package(Drogon CONFIG REQUIRED) diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc new file mode 100644 index 000000000..4d6f47ebe --- /dev/null +++ b/engine/utils/config_yaml_utils.cc @@ -0,0 +1,177 @@ +#include "config_yaml_utils.h" + +namespace config_yaml_utils { +cpp::result CortexConfigMgr::DumpYamlConfig( + const CortexConfig& config, const std::string& path) { + std::lock_guard l(mtx_); + std::filesystem::path config_file_path{path}; + + try { + std::ofstream out_file(config_file_path); + if (!out_file) { + throw std::runtime_error("Failed to open output file."); + } + // Workaround to save file as utf8 BOM + const unsigned char utf8_bom[] = {0xEF, 0xBB, 0xBF}; + out_file.write(reinterpret_cast(utf8_bom), sizeof(utf8_bom)); + YAML::Node node; + node["logFolderPath"] = config.logFolderPath; + node["logLlamaCppPath"] = config.logLlamaCppPath; + node["logTensorrtLLMPath"] = config.logTensorrtLLMPath; + node["logOnnxPath"] = config.logOnnxPath; + node["dataFolderPath"] = config.dataFolderPath; + node["maxLogLines"] = config.maxLogLines; + node["apiServerHost"] = config.apiServerHost; + node["apiServerPort"] = config.apiServerPort; + node["checkedForUpdateAt"] = config.checkedForUpdateAt; + node["checkedForLlamacppUpdateAt"] = config.checkedForLlamacppUpdateAt; + node["latestRelease"] = config.latestRelease; + node["latestLlamacppRelease"] = config.latestLlamacppRelease; + node["huggingFaceToken"] = config.huggingFaceToken; + node["gitHubUserAgent"] = config.gitHubUserAgent; + node["gitHubToken"] = config.gitHubToken; + node["llamacppVariant"] = config.llamacppVariant; + node["llamacppVersion"] = config.llamacppVersion; + node["enableCors"] = config.enableCors; + node["allowedOrigins"] = config.allowedOrigins; + node["proxyUrl"] = config.proxyUrl; + node["verifyProxySsl"] = config.verifyProxySsl; + node["verifyProxyHostSsl"] = config.verifyProxyHostSsl; + node["proxyUsername"] = config.proxyUsername; + node["proxyPassword"] = config.proxyPassword; + node["noProxy"] = config.noProxy; + node["verifyPeerSsl"] = config.verifyPeerSsl; + node["verifyHostSsl"] = config.verifyHostSsl; + + out_file << node; + out_file.close(); + return {}; + } catch (const std::exception& e) { + CTL_ERR("Error writing to file: " << e.what()); + return cpp::fail("Error writing to file: " + std::string(e.what())); + } +} + +CortexConfig CortexConfigMgr::FromYaml(const std::string& path, + const CortexConfig& default_cfg) { + std::unique_lock l(mtx_); + std::filesystem::path config_file_path{path}; + if (!std::filesystem::exists(config_file_path)) { + throw std::runtime_error("File not found: " + path); + } + + try { + auto node = YAML::LoadFile(config_file_path.string()); + bool should_update_config = + (!node["logFolderPath"] || !node["dataFolderPath"] || + !node["maxLogLines"] || !node["apiServerHost"] || + !node["apiServerPort"] || !node["checkedForUpdateAt"] || + !node["checkedForLlamacppUpdateAt"] || !node["latestRelease"] || + !node["latestLlamacppRelease"] || !node["logLlamaCppPath"] || + !node["logOnnxPath"] || !node["logTensorrtLLMPath"] || + !node["huggingFaceToken"] || !node["gitHubUserAgent"] || + !node["gitHubToken"] || !node["llamacppVariant"] || + !node["llamacppVersion"] || !node["enableCors"] || + !node["allowedOrigins"] || !node["proxyUrl"] || + !node["proxyUsername"] || !node["proxyPassword"] || + !node["verifyPeerSsl"] || !node["verifyHostSsl"] || + !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || + !node["noProxy"]); + + CortexConfig config = { + .logFolderPath = node["logFolderPath"] + ? node["logFolderPath"].as() + : default_cfg.logFolderPath, + .logLlamaCppPath = node["logLlamaCppPath"] + ? node["logLlamaCppPath"].as() + : default_cfg.logLlamaCppPath, + .logTensorrtLLMPath = node["logTensorrtLLMPath"] + ? node["logTensorrtLLMPath"].as() + : default_cfg.logTensorrtLLMPath, + .logOnnxPath = node["logOnnxPath"] + ? node["logOnnxPath"].as() + : default_cfg.logOnnxPath, + .dataFolderPath = node["dataFolderPath"] + ? node["dataFolderPath"].as() + : default_cfg.dataFolderPath, + .maxLogLines = node["maxLogLines"] ? node["maxLogLines"].as() + : default_cfg.maxLogLines, + .apiServerHost = node["apiServerHost"] + ? node["apiServerHost"].as() + : default_cfg.apiServerHost, + .apiServerPort = node["apiServerPort"] + ? node["apiServerPort"].as() + : default_cfg.apiServerPort, + .checkedForUpdateAt = node["checkedForUpdateAt"] + ? node["checkedForUpdateAt"].as() + : default_cfg.checkedForUpdateAt, + .checkedForLlamacppUpdateAt = + node["checkedForLlamacppUpdateAt"] + ? node["checkedForLlamacppUpdateAt"].as() + : default_cfg.checkedForLlamacppUpdateAt, + .latestRelease = node["latestRelease"] + ? node["latestRelease"].as() + : default_cfg.latestRelease, + .latestLlamacppRelease = + node["latestLlamacppRelease"] + ? node["latestLlamacppRelease"].as() + : default_cfg.latestLlamacppRelease, + .huggingFaceToken = node["huggingFaceToken"] + ? node["huggingFaceToken"].as() + : default_cfg.huggingFaceToken, + .gitHubUserAgent = node["gitHubUserAgent"] + ? node["gitHubUserAgent"].as() + : default_cfg.gitHubUserAgent, + .gitHubToken = node["gitHubToken"] + ? node["gitHubToken"].as() + : default_cfg.gitHubToken, + .llamacppVariant = node["llamacppVariant"] + ? node["llamacppVariant"].as() + : default_cfg.llamacppVariant, + .llamacppVersion = node["llamacppVersion"] + ? node["llamacppVersion"].as() + : default_cfg.llamacppVersion, + .enableCors = node["enableCors"] ? node["enableCors"].as() + : default_cfg.enableCors, + .allowedOrigins = + node["allowedOrigins"] + ? node["allowedOrigins"].as>() + : default_cfg.allowedOrigins, + .proxyUrl = node["proxyUrl"] ? node["proxyUrl"].as() + : default_cfg.proxyUrl, + .verifyProxySsl = node["verifyProxySsl"] + ? node["verifyProxySsl"].as() + : default_cfg.verifyProxySsl, + .verifyProxyHostSsl = node["verifyProxyHostSsl"] + ? node["verifyProxyHostSsl"].as() + : default_cfg.verifyProxyHostSsl, + .proxyUsername = node["proxyUsername"] + ? node["proxyUsername"].as() + : default_cfg.proxyUsername, + .proxyPassword = node["proxyPassword"] + ? node["proxyPassword"].as() + : default_cfg.proxyPassword, + .noProxy = node["noProxy"] ? node["noProxy"].as() + : default_cfg.noProxy, + .verifyPeerSsl = node["verifyPeerSsl"] + ? node["verifyPeerSsl"].as() + : default_cfg.verifyPeerSsl, + .verifyHostSsl = node["verifyHostSsl"] + ? node["verifyHostSsl"].as() + : default_cfg.verifyHostSsl, + }; + if (should_update_config) { + l.unlock(); + auto result = DumpYamlConfig(config, path); + if (result.has_error()) { + CTL_ERR("Failed to update config file: " << result.error()); + } + } + return config; + } catch (const YAML::BadFile& e) { + CTL_ERR("Failed to read file: " << e.what()); + throw; + } +} + +} // namespace config_yaml_utils \ No newline at end of file diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index 73c990996..aa1b4027e 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -77,178 +77,10 @@ class CortexConfigMgr { } cpp::result DumpYamlConfig(const CortexConfig& config, - const std::string& path) { - std::lock_guard l(mtx_); - std::filesystem::path config_file_path{path}; - - try { - std::ofstream out_file(config_file_path); - if (!out_file) { - throw std::runtime_error("Failed to open output file."); - } - // Workaround to save file as utf8 BOM - const unsigned char utf8_bom[] = {0xEF, 0xBB, 0xBF}; - out_file.write(reinterpret_cast(utf8_bom), sizeof(utf8_bom)); - YAML::Node node; - node["logFolderPath"] = config.logFolderPath; - node["logLlamaCppPath"] = config.logLlamaCppPath; - node["logTensorrtLLMPath"] = config.logTensorrtLLMPath; - node["logOnnxPath"] = config.logOnnxPath; - node["dataFolderPath"] = config.dataFolderPath; - node["maxLogLines"] = config.maxLogLines; - node["apiServerHost"] = config.apiServerHost; - node["apiServerPort"] = config.apiServerPort; - node["checkedForUpdateAt"] = config.checkedForUpdateAt; - node["checkedForLlamacppUpdateAt"] = config.checkedForLlamacppUpdateAt; - node["latestRelease"] = config.latestRelease; - node["latestLlamacppRelease"] = config.latestLlamacppRelease; - node["huggingFaceToken"] = config.huggingFaceToken; - node["gitHubUserAgent"] = config.gitHubUserAgent; - node["gitHubToken"] = config.gitHubToken; - node["llamacppVariant"] = config.llamacppVariant; - node["llamacppVersion"] = config.llamacppVersion; - node["enableCors"] = config.enableCors; - node["allowedOrigins"] = config.allowedOrigins; - node["proxyUrl"] = config.proxyUrl; - node["verifyProxySsl"] = config.verifyProxySsl; - node["verifyProxyHostSsl"] = config.verifyProxyHostSsl; - node["proxyUsername"] = config.proxyUsername; - node["proxyPassword"] = config.proxyPassword; - node["noProxy"] = config.noProxy; - node["verifyPeerSsl"] = config.verifyPeerSsl; - node["verifyHostSsl"] = config.verifyHostSsl; - - out_file << node; - out_file.close(); - return {}; - } catch (const std::exception& e) { - CTL_ERR("Error writing to file: " << e.what()); - return cpp::fail("Error writing to file: " + std::string(e.what())); - } - } + const std::string& path); CortexConfig FromYaml(const std::string& path, - const CortexConfig& default_cfg) { - std::unique_lock l(mtx_); - std::filesystem::path config_file_path{path}; - if (!std::filesystem::exists(config_file_path)) { - throw std::runtime_error("File not found: " + path); - } - - try { - auto node = YAML::LoadFile(config_file_path.string()); - bool should_update_config = - (!node["logFolderPath"] || !node["dataFolderPath"] || - !node["maxLogLines"] || !node["apiServerHost"] || - !node["apiServerPort"] || !node["checkedForUpdateAt"] || - !node["checkedForLlamacppUpdateAt"] || !node["latestRelease"] || - !node["latestLlamacppRelease"] || !node["logLlamaCppPath"] || - !node["logOnnxPath"] || !node["logTensorrtLLMPath"] || - !node["huggingFaceToken"] || !node["gitHubUserAgent"] || - !node["gitHubToken"] || !node["llamacppVariant"] || - !node["llamacppVersion"] || !node["enableCors"] || - !node["allowedOrigins"] || !node["proxyUrl"] || - !node["proxyUsername"] || !node["proxyPassword"] || - !node["verifyPeerSsl"] || !node["verifyHostSsl"] || - !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || - !node["noProxy"]); - - CortexConfig config = { - .logFolderPath = node["logFolderPath"] - ? node["logFolderPath"].as() - : default_cfg.logFolderPath, - .logLlamaCppPath = node["logLlamaCppPath"] - ? node["logLlamaCppPath"].as() - : default_cfg.logLlamaCppPath, - .logTensorrtLLMPath = - node["logTensorrtLLMPath"] - ? node["logTensorrtLLMPath"].as() - : default_cfg.logTensorrtLLMPath, - .logOnnxPath = node["logOnnxPath"] - ? node["logOnnxPath"].as() - : default_cfg.logOnnxPath, - .dataFolderPath = node["dataFolderPath"] - ? node["dataFolderPath"].as() - : default_cfg.dataFolderPath, - .maxLogLines = node["maxLogLines"] ? node["maxLogLines"].as() - : default_cfg.maxLogLines, - .apiServerHost = node["apiServerHost"] - ? node["apiServerHost"].as() - : default_cfg.apiServerHost, - .apiServerPort = node["apiServerPort"] - ? node["apiServerPort"].as() - : default_cfg.apiServerPort, - .checkedForUpdateAt = node["checkedForUpdateAt"] - ? node["checkedForUpdateAt"].as() - : default_cfg.checkedForUpdateAt, - .checkedForLlamacppUpdateAt = - node["checkedForLlamacppUpdateAt"] - ? node["checkedForLlamacppUpdateAt"].as() - : default_cfg.checkedForLlamacppUpdateAt, - .latestRelease = node["latestRelease"] - ? node["latestRelease"].as() - : default_cfg.latestRelease, - .latestLlamacppRelease = - node["latestLlamacppRelease"] - ? node["latestLlamacppRelease"].as() - : default_cfg.latestLlamacppRelease, - .huggingFaceToken = node["huggingFaceToken"] - ? node["huggingFaceToken"].as() - : default_cfg.huggingFaceToken, - .gitHubUserAgent = node["gitHubUserAgent"] - ? node["gitHubUserAgent"].as() - : default_cfg.gitHubUserAgent, - .gitHubToken = node["gitHubToken"] - ? node["gitHubToken"].as() - : default_cfg.gitHubToken, - .llamacppVariant = node["llamacppVariant"] - ? node["llamacppVariant"].as() - : default_cfg.llamacppVariant, - .llamacppVersion = node["llamacppVersion"] - ? node["llamacppVersion"].as() - : default_cfg.llamacppVersion, - .enableCors = node["enableCors"] ? node["enableCors"].as() - : default_cfg.enableCors, - .allowedOrigins = - node["allowedOrigins"] - ? node["allowedOrigins"].as>() - : default_cfg.allowedOrigins, - .proxyUrl = node["proxyUrl"] ? node["proxyUrl"].as() - : default_cfg.proxyUrl, - .verifyProxySsl = node["verifyProxySsl"] - ? node["verifyProxySsl"].as() - : default_cfg.verifyProxySsl, - .verifyProxyHostSsl = node["verifyProxyHostSsl"] - ? node["verifyProxyHostSsl"].as() - : default_cfg.verifyProxyHostSsl, - .proxyUsername = node["proxyUsername"] - ? node["proxyUsername"].as() - : default_cfg.proxyUsername, - .proxyPassword = node["proxyPassword"] - ? node["proxyPassword"].as() - : default_cfg.proxyPassword, - .noProxy = node["noProxy"] ? node["noProxy"].as() - : default_cfg.noProxy, - .verifyPeerSsl = node["verifyPeerSsl"] - ? node["verifyPeerSsl"].as() - : default_cfg.verifyPeerSsl, - .verifyHostSsl = node["verifyHostSsl"] - ? node["verifyHostSsl"].as() - : default_cfg.verifyHostSsl, - }; - if (should_update_config) { - l.unlock(); - auto result = DumpYamlConfig(config, path); - if (result.has_error()) { - CTL_ERR("Failed to update config file: " << result.error()); - } - } - return config; - } catch (const YAML::BadFile& e) { - CTL_ERR("Failed to read file: " << e.what()); - throw; - } - } + const CortexConfig& default_cfg); }; } // namespace config_yaml_utils diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc new file mode 100644 index 000000000..71f263a6a --- /dev/null +++ b/engine/utils/curl_utils.cc @@ -0,0 +1,321 @@ +#include "curl_utils.h" + +#include "utils/engine_constants.h" +#include "utils/file_manager_utils.h" +#include "utils/logging_utils.h" + +#include "utils/string_utils.h" +#include "utils/url_parser.h" + +namespace curl_utils { +namespace { +size_t WriteCallback(void* contents, size_t size, size_t nmemb, + std::string* output) { + size_t totalSize = size * nmemb; + output->append((char*)contents, totalSize); + return totalSize; +} + +void SetUpProxy(CURL* handle, const std::string& url) { + auto config = file_manager_utils::GetCortexConfig(); + if (!config.proxyUrl.empty()) { + auto proxy_url = config.proxyUrl; + auto verify_proxy_ssl = config.verifyProxySsl; + auto verify_proxy_host_ssl = config.verifyProxyHostSsl; + + auto verify_ssl = config.verifyPeerSsl; + auto verify_host_ssl = config.verifyHostSsl; + + auto proxy_username = config.proxyUsername; + auto proxy_password = config.proxyPassword; + auto no_proxy = config.noProxy; + + CTL_INF("=== Proxy configuration ==="); + CTL_INF("Request url: " << url); + CTL_INF("Proxy url: " << proxy_url); + CTL_INF("Verify proxy ssl: " << verify_proxy_ssl); + CTL_INF("Verify proxy host ssl: " << verify_proxy_host_ssl); + CTL_INF("Verify ssl: " << verify_ssl); + CTL_INF("Verify host ssl: " << verify_host_ssl); + CTL_INF("No proxy: " << no_proxy); + + curl_easy_setopt(handle, CURLOPT_PROXY, proxy_url.c_str()); + if (string_utils::StartsWith(proxy_url, "https")) { + curl_easy_setopt(handle, CURLOPT_PROXYTYPE, CURLPROXY_HTTPS); + } + curl_easy_setopt(handle, CURLOPT_SSL_VERIFYPEER, verify_ssl ? 1L : 0L); + curl_easy_setopt(handle, CURLOPT_SSL_VERIFYHOST, verify_host_ssl ? 2L : 0L); + + curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYPEER, + verify_proxy_ssl ? 1L : 0L); + curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYHOST, + verify_proxy_host_ssl ? 2L : 0L); + + auto proxy_auth = proxy_username + ":" + proxy_password; + curl_easy_setopt(handle, CURLOPT_PROXYUSERPWD, proxy_auth.c_str()); + + curl_easy_setopt(handle, CURLOPT_NOPROXY, no_proxy.c_str()); + } +} +} // namespace + +std::optional> GetHeaders( + const std::string& url) { + auto url_obj = url_parser::FromUrlString(url); + if (url_obj.has_error()) { + return std::nullopt; + } + + if (url_obj->host == kHuggingFaceHost) { + std::unordered_map headers{}; + headers["Content-Type"] = "application/json"; + auto const& token = file_manager_utils::GetCortexConfig().huggingFaceToken; + if (!token.empty()) { + headers["Authorization"] = "Bearer " + token; + + // for debug purpose + auto min_token_size = 6; + if (token.size() < min_token_size) { + CTL_WRN("Hugging Face token is too short"); + } else { + CTL_INF("Using authentication with Hugging Face token: " + + token.substr(token.size() - min_token_size)); + } + } + + return headers; + } + + if (url_obj->host == kGitHubHost) { + std::unordered_map headers{}; + headers["Accept"] = "application/vnd.github.v3+json"; + // github API requires user-agent https://docs.github.com/en/rest/using-the-rest-api/getting-started-with-the-rest-api?apiVersion=2022-11-28#user-agent + auto user_agent = file_manager_utils::GetCortexConfig().gitHubUserAgent; + auto gh_token = file_manager_utils::GetCortexConfig().gitHubToken; + headers["User-Agent"] = + user_agent.empty() ? kDefaultGHUserAgent : user_agent; + if (!gh_token.empty()) { + headers["Authorization"] = "Bearer " + gh_token; + + // for debug purpose + auto min_token_size = 6; + if (gh_token.size() < min_token_size) { + CTL_WRN("Github token is too short"); + } else { + CTL_INF("Using authentication with Github token: " + + gh_token.substr(gh_token.size() - min_token_size)); + } + } + return headers; + } + + return std::nullopt; +} + +cpp::result SimpleGet(const std::string& url, + const int timeout) { + auto curl = curl_easy_init(); + + if (!curl) { + return cpp::fail("Failed to init CURL"); + } + + auto headers = GetHeaders(url); + curl_slist* curl_headers = nullptr; + if (headers.has_value()) { + for (const auto& [key, value] : headers.value()) { + auto header = key + ": " + value; + curl_headers = curl_slist_append(curl_headers, header.c_str()); + } + + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); + } + + std::string readBuffer; + + SetUpProxy(curl, url); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + if (timeout > 0) { + curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout); + } + + // Perform the request + auto res = curl_easy_perform(curl); + + curl_slist_free_all(curl_headers); + curl_easy_cleanup(curl); + if (res != CURLE_OK) { + return cpp::fail("CURL request failed: " + + static_cast(curl_easy_strerror(res))); + } + auto http_code = 0; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + if (http_code >= 400) { + CTL_ERR("HTTP request failed with status code: " + + std::to_string(http_code)); + return cpp::fail(readBuffer); + } + + return readBuffer; +} + +cpp::result SimpleRequest( + const std::string& url, const RequestType& request_type, + const std::string& body) { + auto curl = curl_easy_init(); + + if (!curl) { + return cpp::fail("Failed to init CURL"); + } + + auto headers = GetHeaders(url); + curl_slist* curl_headers = nullptr; + curl_headers = + curl_slist_append(curl_headers, "Content-Type: application/json"); + curl_headers = curl_slist_append(curl_headers, "Expect:"); + + if (headers.has_value()) { + for (const auto& [key, value] : headers.value()) { + auto header = key + ": " + value; + curl_headers = curl_slist_append(curl_headers, header.c_str()); + } + } + std::string readBuffer; + + SetUpProxy(curl, url); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + if (request_type == RequestType::PATCH) { + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "PATCH"); + } else if (request_type == RequestType::POST) { + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + } else if (request_type == RequestType::DEL) { + curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "DELETE"); + } + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + + // Perform the request + auto res = curl_easy_perform(curl); + + auto http_code = 0L; + curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); + + // Clean up + curl_slist_free_all(curl_headers); + curl_easy_cleanup(curl); + + if (res != CURLE_OK) { + CTL_ERR("CURL request failed: " + std::string(curl_easy_strerror(res))); + return cpp::fail("CURL request failed: " + + static_cast(curl_easy_strerror(res))); + } + + if (http_code >= 400) { + CTL_ERR("HTTP request failed with status code: " + + std::to_string(http_code)); + return cpp::fail(readBuffer); + } + + return readBuffer; +} + +cpp::result ReadRemoteYaml(const std::string& url) { + auto result = SimpleGet(url); + if (result.has_error()) { + CTL_ERR("Failed to get Yaml from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + try { + return YAML::Load(result.value()); + } catch (const std::exception& e) { + return cpp::fail("YAML from " + url + + " parsing error: " + std::string(e.what())); + } +} + +cpp::result SimpleGetJson(const std::string& url, + const int timeout) { + auto result = SimpleGet(url, timeout); + if (result.has_error()) { + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} + +cpp::result SimplePostJson(const std::string& url, + const std::string& body) { + auto result = SimpleRequest(url, RequestType::POST, body); + if (result.has_error()) { + CTL_INF("url: " + url); + CTL_INF("body: " + body); + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + CTL_INF("Response: " + result.value()); + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} + +cpp::result SimpleDeleteJson( + const std::string& url, const std::string& body) { + auto result = SimpleRequest(url, RequestType::DEL, body); + if (result.has_error()) { + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + CTL_INF("Response: " + result.value()); + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} + +cpp::result SimplePatchJson(const std::string& url, + const std::string& body) { + auto result = SimpleRequest(url, RequestType::PATCH, body); + if (result.has_error()) { + CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); + return cpp::fail(result.error()); + } + + CTL_INF("Response: " + result.value()); + Json::Value root; + Json::Reader reader; + if (!reader.parse(result.value(), root)) { + return cpp::fail("JSON from " + url + + " parsing error: " + reader.getFormattedErrorMessages()); + } + + return root; +} +} // namespace curl_utils \ No newline at end of file diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index c56808b56..64b5fc339 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -5,335 +5,43 @@ #include #include #include +#include #include -#include "utils/engine_constants.h" -#include "utils/file_manager_utils.h" -#include "utils/logging_utils.h" +#include + #include "utils/result.hpp" -#include "utils/string_utils.h" -#include "utils/url_parser.h" enum class RequestType { GET, PATCH, POST, DEL }; namespace curl_utils { -namespace { -size_t WriteCallback(void* contents, size_t size, size_t nmemb, - std::string* output) { - size_t totalSize = size * nmemb; - output->append((char*)contents, totalSize); - return totalSize; -} - -void SetUpProxy(CURL* handle, const std::string& url) { - auto config = file_manager_utils::GetCortexConfig(); - if (!config.proxyUrl.empty()) { - auto proxy_url = config.proxyUrl; - auto verify_proxy_ssl = config.verifyProxySsl; - auto verify_proxy_host_ssl = config.verifyProxyHostSsl; - - auto verify_ssl = config.verifyPeerSsl; - auto verify_host_ssl = config.verifyHostSsl; - - auto proxy_username = config.proxyUsername; - auto proxy_password = config.proxyPassword; - auto no_proxy = config.noProxy; - - CTL_INF("=== Proxy configuration ==="); - CTL_INF("Request url: " << url); - CTL_INF("Proxy url: " << proxy_url); - CTL_INF("Verify proxy ssl: " << verify_proxy_ssl); - CTL_INF("Verify proxy host ssl: " << verify_proxy_host_ssl); - CTL_INF("Verify ssl: " << verify_ssl); - CTL_INF("Verify host ssl: " << verify_host_ssl); - CTL_INF("No proxy: " << no_proxy); - - curl_easy_setopt(handle, CURLOPT_PROXY, proxy_url.c_str()); - if (string_utils::StartsWith(proxy_url, "https")) { - curl_easy_setopt(handle, CURLOPT_PROXYTYPE, CURLPROXY_HTTPS); - } - curl_easy_setopt(handle, CURLOPT_SSL_VERIFYPEER, verify_ssl ? 1L : 0L); - curl_easy_setopt(handle, CURLOPT_SSL_VERIFYHOST, verify_host_ssl ? 2L : 0L); - - curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYPEER, - verify_proxy_ssl ? 1L : 0L); - curl_easy_setopt(handle, CURLOPT_PROXY_SSL_VERIFYHOST, - verify_proxy_host_ssl ? 2L : 0L); - - auto proxy_auth = proxy_username + ":" + proxy_password; - curl_easy_setopt(handle, CURLOPT_PROXYUSERPWD, proxy_auth.c_str()); - - curl_easy_setopt(handle, CURLOPT_NOPROXY, no_proxy.c_str()); - } -} -} // namespace - -inline std::optional> GetHeaders( +std::optional> GetHeaders( const std::string& url); -inline cpp::result SimpleGet(const std::string& url, - const int timeout = -1) { - auto curl = curl_easy_init(); - - if (!curl) { - return cpp::fail("Failed to init CURL"); - } +cpp::result SimpleGet(const std::string& url, + const int timeout = -1); - auto headers = GetHeaders(url); - curl_slist* curl_headers = nullptr; - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { - auto header = key + ": " + value; - curl_headers = curl_slist_append(curl_headers, header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); - } - - std::string readBuffer; - - SetUpProxy(curl, url); - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); - if (timeout > 0) { - curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout); - } - - // Perform the request - auto res = curl_easy_perform(curl); - - curl_slist_free_all(curl_headers); - curl_easy_cleanup(curl); - if (res != CURLE_OK) { - return cpp::fail("CURL request failed: " + - static_cast(curl_easy_strerror(res))); - } - auto http_code = 0; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); - if (http_code >= 400) { - CTL_ERR("HTTP request failed with status code: " + - std::to_string(http_code)); - return cpp::fail(readBuffer); - } - - return readBuffer; -} - -inline cpp::result SimpleRequest( +cpp::result SimpleRequest( const std::string& url, const RequestType& request_type, - const std::string& body = "") { - auto curl = curl_easy_init(); - - if (!curl) { - return cpp::fail("Failed to init CURL"); - } - - auto headers = GetHeaders(url); - curl_slist* curl_headers = nullptr; - curl_headers = - curl_slist_append(curl_headers, "Content-Type: application/json"); - curl_headers = curl_slist_append(curl_headers, "Expect:"); - - if (headers.has_value()) { - for (const auto& [key, value] : headers.value()) { - auto header = key + ": " + value; - curl_headers = curl_slist_append(curl_headers, header.c_str()); - } - } - std::string readBuffer; - - SetUpProxy(curl, url); - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - if (request_type == RequestType::PATCH) { - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "PATCH"); - } else if (request_type == RequestType::POST) { - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "POST"); - curl_easy_setopt(curl, CURLOPT_POST, 1L); - } else if (request_type == RequestType::DEL) { - curl_easy_setopt(curl, CURLOPT_CUSTOMREQUEST, "DELETE"); - } - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, &readBuffer); + const std::string& body = ""); - curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, body.length()); - curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); - - // Perform the request - auto res = curl_easy_perform(curl); - - auto http_code = 0L; - curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code); - - // Clean up - curl_slist_free_all(curl_headers); - curl_easy_cleanup(curl); - - if (res != CURLE_OK) { - CTL_ERR("CURL request failed: " + std::string(curl_easy_strerror(res))); - return cpp::fail("CURL request failed: " + - static_cast(curl_easy_strerror(res))); - } - - if (http_code >= 400) { - CTL_ERR("HTTP request failed with status code: " + - std::to_string(http_code)); - return cpp::fail(readBuffer); - } - - return readBuffer; -} - -inline cpp::result ReadRemoteYaml( - const std::string& url) { - auto result = SimpleGet(url); - if (result.has_error()) { - CTL_ERR("Failed to get Yaml from " + url + ": " + result.error()); - return cpp::fail(result.error()); - } - - try { - return YAML::Load(result.value()); - } catch (const std::exception& e) { - return cpp::fail("YAML from " + url + - " parsing error: " + std::string(e.what())); - } -} +cpp::result ReadRemoteYaml(const std::string& url); /** * SimpleGetJson is a helper function that sends a GET request to the given URL * * [timeout] is an optional parameter that specifies the timeout for the request. In second. */ -inline cpp::result SimpleGetJson( - const std::string& url, const int timeout = -1) { - auto result = SimpleGet(url, timeout); - if (result.has_error()) { - CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); - return cpp::fail(result.error()); - } - - Json::Value root; - Json::Reader reader; - if (!reader.parse(result.value(), root)) { - return cpp::fail("JSON from " + url + - " parsing error: " + reader.getFormattedErrorMessages()); - } - - return root; -} - -inline cpp::result SimplePostJson( - const std::string& url, const std::string& body = "") { - auto result = SimpleRequest(url, RequestType::POST, body); - if (result.has_error()) { - CTL_INF("url: " + url); - CTL_INF("body: " + body); - CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); - return cpp::fail(result.error()); - } - - CTL_INF("Response: " + result.value()); - Json::Value root; - Json::Reader reader; - if (!reader.parse(result.value(), root)) { - return cpp::fail("JSON from " + url + - " parsing error: " + reader.getFormattedErrorMessages()); - } - - return root; -} - -inline cpp::result SimpleDeleteJson( - const std::string& url, const std::string& body = "") { - auto result = SimpleRequest(url, RequestType::DEL, body); - if (result.has_error()) { - CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); - return cpp::fail(result.error()); - } - - CTL_INF("Response: " + result.value()); - Json::Value root; - Json::Reader reader; - if (!reader.parse(result.value(), root)) { - return cpp::fail("JSON from " + url + - " parsing error: " + reader.getFormattedErrorMessages()); - } - - return root; -} - -inline cpp::result SimplePatchJson( - const std::string& url, const std::string& body = "") { - auto result = SimpleRequest(url, RequestType::PATCH, body); - if (result.has_error()) { - CTL_ERR("Failed to get JSON from " + url + ": " + result.error()); - return cpp::fail(result.error()); - } - - CTL_INF("Response: " + result.value()); - Json::Value root; - Json::Reader reader; - if (!reader.parse(result.value(), root)) { - return cpp::fail("JSON from " + url + - " parsing error: " + reader.getFormattedErrorMessages()); - } - - return root; -} - -inline std::optional> GetHeaders( - const std::string& url) { - auto url_obj = url_parser::FromUrlString(url); - if (url_obj.has_error()) { - return std::nullopt; - } - - if (url_obj->host == kHuggingFaceHost) { - std::unordered_map headers{}; - headers["Content-Type"] = "application/json"; - auto const& token = file_manager_utils::GetCortexConfig().huggingFaceToken; - if (!token.empty()) { - headers["Authorization"] = "Bearer " + token; - - // for debug purpose - auto min_token_size = 6; - if (token.size() < min_token_size) { - CTL_WRN("Hugging Face token is too short"); - } else { - CTL_INF("Using authentication with Hugging Face token: " + - token.substr(token.size() - min_token_size)); - } - } +cpp::result SimpleGetJson(const std::string& url, + const int timeout = -1); - return headers; - } +cpp::result SimplePostJson( + const std::string& url, const std::string& body = ""); - if (url_obj->host == kGitHubHost) { - std::unordered_map headers{}; - headers["Accept"] = "application/vnd.github.v3+json"; - // github API requires user-agent https://docs.github.com/en/rest/using-the-rest-api/getting-started-with-the-rest-api?apiVersion=2022-11-28#user-agent - auto user_agent = file_manager_utils::GetCortexConfig().gitHubUserAgent; - auto gh_token = file_manager_utils::GetCortexConfig().gitHubToken; - headers["User-Agent"] = - user_agent.empty() ? kDefaultGHUserAgent : user_agent; - if (!gh_token.empty()) { - headers["Authorization"] = "Bearer " + gh_token; +cpp::result SimpleDeleteJson( + const std::string& url, const std::string& body = ""); - // for debug purpose - auto min_token_size = 6; - if (gh_token.size() < min_token_size) { - CTL_WRN("Github token is too short"); - } else { - CTL_INF("Using authentication with Github token: " + - gh_token.substr(gh_token.size() - min_token_size)); - } - } - return headers; - } +cpp::result SimplePatchJson( + const std::string& url, const std::string& body = ""); - return std::nullopt; -} } // namespace curl_utils diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc new file mode 100644 index 000000000..9650dd973 --- /dev/null +++ b/engine/utils/file_manager_utils.cc @@ -0,0 +1,367 @@ +#include "file_manager_utils.h" + +#include "logging_utils.h" + +#include "utils/engine_constants.h" +#include "utils/result.hpp" +#include "utils/widechar_conv.h" + +#if defined(__APPLE__) && defined(__MACH__) +#include +#elif defined(__linux__) +#include +#elif defined(_WIN32) +#include +#include +#include +#endif + +namespace file_manager_utils { +std::filesystem::path GetExecutableFolderContainerPath() { +#if defined(__APPLE__) && defined(__MACH__) + char buffer[1024]; + uint32_t size = sizeof(buffer); + + if (_NSGetExecutablePath(buffer, &size) == 0) { + // CTL_DBG("Executable path: " << buffer); + return std::filesystem::path{buffer}.parent_path(); + } else { + CTL_ERR("Failed to get executable path"); + return std::filesystem::current_path(); + } +#elif defined(__linux__) + char buffer[1024]; + ssize_t len = readlink("/proc/self/exe", buffer, sizeof(buffer) - 1); + if (len != -1) { + buffer[len] = '\0'; + // CTL_DBG("Executable path: " << buffer); + return std::filesystem::path{buffer}.parent_path(); + } else { + CTL_ERR("Failed to get executable path"); + return std::filesystem::current_path(); + } +#elif defined(_WIN32) + wchar_t buffer[MAX_PATH]; + GetModuleFileNameW(NULL, buffer, MAX_PATH); + // CTL_DBG("Executable path: " << buffer); + return std::filesystem::path{buffer}.parent_path(); +#else + LOG_ERROR << "Unsupported platform!"; + return std::filesystem::current_path(); +#endif +} + +std::filesystem::path GetHomeDirectoryPath() { +#ifdef _WIN32 + const wchar_t* homeDir = _wgetenv(L"USERPROFILE"); + if (!homeDir) { + // Fallback if USERPROFILE is not set + const wchar_t* homeDrive = _wgetenv(L"HOMEDRIVE"); + const wchar_t* homePath = _wgetenv(L"HOMEPATH"); + if (homeDrive && homePath) { + return std::filesystem::path(homeDrive) / std::filesystem::path(homePath); + } else { + throw std::runtime_error("Cannot determine the home directory"); + } + } +#else + const char* homeDir = std::getenv("HOME"); + if (!homeDir) { + throw std::runtime_error("Cannot determine the home directory"); + } +#endif + return std::filesystem::path(homeDir); +} + +std::filesystem::path GetConfigurationPath() { +#ifndef CORTEX_CONFIG_FILE_PATH +#define CORTEX_CONFIG_FILE_PATH kDefaultConfigurationPath +#endif + +#ifndef CORTEX_VARIANT +#define CORTEX_VARIANT kProdVariant +#endif + std::string config_file_path; + if (cortex_config_file_path.empty()) { + config_file_path = CORTEX_CONFIG_FILE_PATH; + } else { + config_file_path = cortex_config_file_path; + } + + if (config_file_path != kDefaultConfigurationPath) { +// CTL_INF("Config file path: " + config_file_path); +#if defined(_WIN32) + return std::filesystem::u8path(config_file_path); +#else + return std::filesystem::path(config_file_path); +#endif + } + + std::string variant{CORTEX_VARIANT}; + std::string env_postfix{""}; + if (variant == kBetaVariant) { + env_postfix.append("-").append(kBetaVariant); + } else if (variant == kNightlyVariant) { + env_postfix.append("-").append(kNightlyVariant); + } + + std::string config_file_name{kCortexConfigurationFileName}; + config_file_name.append(env_postfix); + // CTL_INF("Config file name: " + config_file_name); + + auto home_path = GetHomeDirectoryPath(); + auto configuration_path = home_path / config_file_name; + return configuration_path; +} + +std::string GetDefaultDataFolderName() { +#ifndef CORTEX_VARIANT +#define CORTEX_VARIANT "prod" +#endif + std::string default_data_folder_name{kCortexFolderName}; + std::string variant{CORTEX_VARIANT}; + std::string env_postfix{""}; + if (variant == kBetaVariant) { + env_postfix.append("-").append(kBetaVariant); + } else if (variant == kNightlyVariant) { + env_postfix.append("-").append(kNightlyVariant); + } + default_data_folder_name.append(env_postfix); + return default_data_folder_name; +} + +cpp::result UpdateCortexConfig( + const config_yaml_utils::CortexConfig& config) { + auto config_path = GetConfigurationPath(); + if (!std::filesystem::exists(config_path)) { + CTL_ERR("Config file not found: " << config_path.string()); + return cpp::fail("Config file not found: " + config_path.string()); + } + + return cyu::CortexConfigMgr::GetInstance().DumpYamlConfig( + config, config_path.string()); +} + +config_yaml_utils::CortexConfig GetDefaultConfig() { + auto config_path = GetConfigurationPath(); + auto default_data_folder_name = GetDefaultDataFolderName(); + auto default_data_folder_path = + cortex_data_folder_path.empty() + ? file_manager_utils::GetHomeDirectoryPath() / + default_data_folder_name + : std::filesystem::path(cortex_data_folder_path); + + return config_yaml_utils::CortexConfig{ +#if defined(_WIN32) + .logFolderPath = + cortex::wc::WstringToUtf8(default_data_folder_path.wstring()), +#else + .logFolderPath = default_data_folder_path.string(), +#endif + .logLlamaCppPath = kLogsLlamacppBaseName, + .logTensorrtLLMPath = kLogsTensorrtllmBaseName, + .logOnnxPath = kLogsOnnxBaseName, +#if defined(_WIN32) + .dataFolderPath = + cortex::wc::WstringToUtf8(default_data_folder_path.wstring()), +#else + .dataFolderPath = default_data_folder_path.string(), +#endif + .maxLogLines = config_yaml_utils::kDefaultMaxLines, + .apiServerHost = config_yaml_utils::kDefaultHost, + .apiServerPort = config_yaml_utils::kDefaultPort, + .checkedForUpdateAt = config_yaml_utils::kDefaultCheckedForUpdateAt, + .checkedForLlamacppUpdateAt = + config_yaml_utils::kDefaultCheckedForLlamacppUpdateAt, + .latestRelease = config_yaml_utils::kDefaultLatestRelease, + .latestLlamacppRelease = config_yaml_utils::kDefaultLatestLlamacppRelease, + .enableCors = config_yaml_utils::kDefaultCorsEnabled, + .allowedOrigins = config_yaml_utils::kDefaultEnabledOrigins, + .proxyUrl = "", + .verifyProxySsl = true, + .verifyProxyHostSsl = true, + .proxyUsername = "", + .proxyPassword = "", + .noProxy = config_yaml_utils::kDefaultNoProxy, + .verifyPeerSsl = true, + .verifyHostSsl = true, + }; +} + +cpp::result CreateConfigFileIfNotExist() { + auto config_path = GetConfigurationPath(); + if (std::filesystem::exists(config_path)) { + // already exists, no need to create + return {}; + } + + CLI_LOG("Config file not found. Creating one at " + config_path.string()); + auto config = GetDefaultConfig(); + CLI_LOG("Default data folder path: " + config.dataFolderPath); + return cyu::CortexConfigMgr::GetInstance().DumpYamlConfig( + config, config_path.string()); +} + +config_yaml_utils::CortexConfig GetCortexConfig() { + auto config_path = GetConfigurationPath(); + + auto default_cfg = GetDefaultConfig(); + return config_yaml_utils::CortexConfigMgr::GetInstance().FromYaml( + config_path.string(), default_cfg); +} + +std::filesystem::path GetCortexDataPath() { + auto result = CreateConfigFileIfNotExist(); + if (result.has_error()) { + CTL_ERR("Error creating config file: " << result.error()); + return std::filesystem::path{}; + } + + auto config = GetCortexConfig(); + std::filesystem::path data_folder_path; + if (!config.dataFolderPath.empty()) { +#if defined(_WIN32) + data_folder_path = std::filesystem::u8path(config.dataFolderPath); +#else + data_folder_path = std::filesystem::path(config.dataFolderPath); +#endif + } else { + auto home_path = GetHomeDirectoryPath(); + data_folder_path = home_path / kCortexFolderName; + } + + if (!std::filesystem::exists(data_folder_path)) { + CLI_LOG("Cortex home folder not found. Create one: " + + data_folder_path.string()); + std::filesystem::create_directory(data_folder_path); + } + return data_folder_path; +} + +std::filesystem::path GetCortexLogPath() { + // TODO: We will need to support user to move the data folder to other place. + // TODO: get the variant of cortex. As discussed, we will have: prod, beta, nightly + + // currently we will store cortex data at ~/cortexcpp + auto config = GetCortexConfig(); + std::filesystem::path log_folder_path; + if (!config.logFolderPath.empty()) { + log_folder_path = std::filesystem::path(config.logFolderPath); + } else { + auto home_path = GetHomeDirectoryPath(); + log_folder_path = home_path / kCortexFolderName; + } + + if (!std::filesystem::exists(log_folder_path)) { + CTL_INF("Cortex log folder not found. Create one: " + + log_folder_path.string()); + std::filesystem::create_directory(log_folder_path); + } + return log_folder_path; +} + +void CreateDirectoryRecursively(const std::string& path) { + // Create the directories if they don't exist + if (std::filesystem::create_directories(path)) { + CTL_INF(path + " successfully created!"); + } else { + CTL_INF(path + " already exist!"); + } +} + +std::filesystem::path GetModelsContainerPath() { + auto result = CreateConfigFileIfNotExist(); + if (result.has_error()) { + CTL_ERR("Error creating config file: " << result.error()); + } + auto cortex_path = GetCortexDataPath(); + auto models_container_path = cortex_path / "models"; + + if (!std::filesystem::exists(models_container_path)) { + CTL_INF("Model container folder not found. Create one: " + << models_container_path.string()); + std::filesystem::create_directories(models_container_path); + } + + return models_container_path; +} + +std::filesystem::path GetCudaToolkitPath(const std::string& engine) { + auto engine_path = getenv("ENGINE_PATH") + ? std::filesystem::path(getenv("ENGINE_PATH")) + : GetCortexDataPath(); + + auto cuda_path = engine_path / "engines" / engine / "deps"; + if (!std::filesystem::exists(cuda_path)) { + std::filesystem::create_directories(cuda_path); + } + + return cuda_path; +} + +std::filesystem::path GetEnginesContainerPath() { + auto cortex_path = getenv("ENGINE_PATH") + ? std::filesystem::path(getenv("ENGINE_PATH")) + : GetCortexDataPath(); + auto engines_container_path = cortex_path / "engines"; + + if (!std::filesystem::exists(engines_container_path)) { + CTL_INF("Engine container folder not found. Create one: " + << engines_container_path.string()); + std::filesystem::create_directory(engines_container_path); + } + + return engines_container_path; +} + +std::filesystem::path GetContainerFolderPath(const std::string_view type) { + std::filesystem::path container_folder_path; + + if (type == "Model") { + container_folder_path = GetModelsContainerPath(); + } else if (type == "Engine") { + container_folder_path = GetEnginesContainerPath(); + } else if (type == "CudaToolkit") { + container_folder_path = + std::filesystem::temp_directory_path() / "cuda-dependencies"; + } else if (type == "Cortex") { + container_folder_path = std::filesystem::temp_directory_path() / "cortex"; + } else { + container_folder_path = std::filesystem::temp_directory_path() / "misc"; + } + + if (!std::filesystem::exists(container_folder_path)) { + CTL_INF("Creating folder: " << container_folder_path.string() << "\n"); + std::filesystem::create_directories(container_folder_path); + } + + return container_folder_path; +} + +std::string DownloadTypeToString(DownloadType type) { + switch (type) { + case DownloadType::Model: + return "Model"; + case DownloadType::Engine: + return "Engine"; + case DownloadType::Miscellaneous: + return "Misc"; + case DownloadType::CudaToolkit: + return "CudaToolkit"; + case DownloadType::Cortex: + return "Cortex"; + default: + return "UNKNOWN"; + } +} + +std::filesystem::path ToRelativeCortexDataPath( + const std::filesystem::path& path) { + return Subtract(path, GetCortexDataPath()); +} + +std::filesystem::path ToAbsoluteCortexDataPath( + const std::filesystem::path& path) { + return GetAbsolutePath(GetCortexDataPath(), path); +} +} // namespace file_manager_utils \ No newline at end of file diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index 72310385c..a7a1b09c2 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -3,21 +3,7 @@ #include #include #include "common/download_task.h" -#include "logging_utils.h" #include "utils/config_yaml_utils.h" -#include "utils/engine_constants.h" -#include "utils/result.hpp" -#include "utils/widechar_conv.h" - -#if defined(__APPLE__) && defined(__MACH__) -#include -#elif defined(__linux__) -#include -#elif defined(_WIN32) -#include -#include -#include -#endif namespace file_manager_utils { namespace cyu = config_yaml_utils; @@ -34,344 +20,38 @@ inline std::string cortex_config_file_path; inline std::string cortex_data_folder_path; -inline std::filesystem::path GetExecutableFolderContainerPath() { -#if defined(__APPLE__) && defined(__MACH__) - char buffer[1024]; - uint32_t size = sizeof(buffer); - - if (_NSGetExecutablePath(buffer, &size) == 0) { - // CTL_DBG("Executable path: " << buffer); - return std::filesystem::path{buffer}.parent_path(); - } else { - CTL_ERR("Failed to get executable path"); - return std::filesystem::current_path(); - } -#elif defined(__linux__) - char buffer[1024]; - ssize_t len = readlink("/proc/self/exe", buffer, sizeof(buffer) - 1); - if (len != -1) { - buffer[len] = '\0'; - // CTL_DBG("Executable path: " << buffer); - return std::filesystem::path{buffer}.parent_path(); - } else { - CTL_ERR("Failed to get executable path"); - return std::filesystem::current_path(); - } -#elif defined(_WIN32) - wchar_t buffer[MAX_PATH]; - GetModuleFileNameW(NULL, buffer, MAX_PATH); - // CTL_DBG("Executable path: " << buffer); - return std::filesystem::path{buffer}.parent_path(); -#else - LOG_ERROR << "Unsupported platform!"; - return std::filesystem::current_path(); -#endif -} - -inline std::filesystem::path GetHomeDirectoryPath() { -#ifdef _WIN32 - const wchar_t* homeDir = _wgetenv(L"USERPROFILE"); - if (!homeDir) { - // Fallback if USERPROFILE is not set - const wchar_t* homeDrive = _wgetenv(L"HOMEDRIVE"); - const wchar_t* homePath = _wgetenv(L"HOMEPATH"); - if (homeDrive && homePath) { - return std::filesystem::path(homeDrive) / std::filesystem::path(homePath); - } else { - throw std::runtime_error("Cannot determine the home directory"); - } - } -#else - const char* homeDir = std::getenv("HOME"); - if (!homeDir) { - throw std::runtime_error("Cannot determine the home directory"); - } -#endif - return std::filesystem::path(homeDir); -} - -inline std::filesystem::path GetConfigurationPath() { -#ifndef CORTEX_CONFIG_FILE_PATH -#define CORTEX_CONFIG_FILE_PATH kDefaultConfigurationPath -#endif - -#ifndef CORTEX_VARIANT -#define CORTEX_VARIANT kProdVariant -#endif - std::string config_file_path; - if (cortex_config_file_path.empty()) { - config_file_path = CORTEX_CONFIG_FILE_PATH; - } else { - config_file_path = cortex_config_file_path; - } - - if (config_file_path != kDefaultConfigurationPath) { -// CTL_INF("Config file path: " + config_file_path); -#if defined(_WIN32) - return std::filesystem::u8path(config_file_path); -#else - return std::filesystem::path(config_file_path); -#endif - } - - std::string variant{CORTEX_VARIANT}; - std::string env_postfix{""}; - if (variant == kBetaVariant) { - env_postfix.append("-").append(kBetaVariant); - } else if (variant == kNightlyVariant) { - env_postfix.append("-").append(kNightlyVariant); - } - - std::string config_file_name{kCortexConfigurationFileName}; - config_file_name.append(env_postfix); - // CTL_INF("Config file name: " + config_file_name); - - auto home_path = GetHomeDirectoryPath(); - auto configuration_path = home_path / config_file_name; - return configuration_path; -} - -inline std::string GetDefaultDataFolderName() { -#ifndef CORTEX_VARIANT -#define CORTEX_VARIANT "prod" -#endif - std::string default_data_folder_name{kCortexFolderName}; - std::string variant{CORTEX_VARIANT}; - std::string env_postfix{""}; - if (variant == kBetaVariant) { - env_postfix.append("-").append(kBetaVariant); - } else if (variant == kNightlyVariant) { - env_postfix.append("-").append(kNightlyVariant); - } - default_data_folder_name.append(env_postfix); - return default_data_folder_name; -} - -inline cpp::result UpdateCortexConfig( - const config_yaml_utils::CortexConfig& config) { - auto config_path = GetConfigurationPath(); - if (!std::filesystem::exists(config_path)) { - CTL_ERR("Config file not found: " << config_path.string()); - return cpp::fail("Config file not found: " + config_path.string()); - } - - return cyu::CortexConfigMgr::GetInstance().DumpYamlConfig( - config, config_path.string()); -} - -inline config_yaml_utils::CortexConfig GetDefaultConfig() { - auto config_path = GetConfigurationPath(); - auto default_data_folder_name = GetDefaultDataFolderName(); - auto default_data_folder_path = - cortex_data_folder_path.empty() - ? file_manager_utils::GetHomeDirectoryPath() / - default_data_folder_name - : std::filesystem::path(cortex_data_folder_path); - - return config_yaml_utils::CortexConfig{ -#if defined(_WIN32) - .logFolderPath = - cortex::wc::WstringToUtf8(default_data_folder_path.wstring()), -#else - .logFolderPath = default_data_folder_path.string(), -#endif - .logLlamaCppPath = kLogsLlamacppBaseName, - .logTensorrtLLMPath = kLogsTensorrtllmBaseName, - .logOnnxPath = kLogsOnnxBaseName, -#if defined(_WIN32) - .dataFolderPath = - cortex::wc::WstringToUtf8(default_data_folder_path.wstring()), -#else - .dataFolderPath = default_data_folder_path.string(), -#endif - .maxLogLines = config_yaml_utils::kDefaultMaxLines, - .apiServerHost = config_yaml_utils::kDefaultHost, - .apiServerPort = config_yaml_utils::kDefaultPort, - .checkedForUpdateAt = config_yaml_utils::kDefaultCheckedForUpdateAt, - .checkedForLlamacppUpdateAt = - config_yaml_utils::kDefaultCheckedForLlamacppUpdateAt, - .latestRelease = config_yaml_utils::kDefaultLatestRelease, - .latestLlamacppRelease = config_yaml_utils::kDefaultLatestLlamacppRelease, - .enableCors = config_yaml_utils::kDefaultCorsEnabled, - .allowedOrigins = config_yaml_utils::kDefaultEnabledOrigins, - .proxyUrl = "", - .verifyProxySsl = true, - .verifyProxyHostSsl = true, - .proxyUsername = "", - .proxyPassword = "", - .noProxy = config_yaml_utils::kDefaultNoProxy, - .verifyPeerSsl = true, - .verifyHostSsl = true, - }; -} - -inline cpp::result CreateConfigFileIfNotExist() { - auto config_path = GetConfigurationPath(); - if (std::filesystem::exists(config_path)) { - // already exists, no need to create - return {}; - } - - CLI_LOG("Config file not found. Creating one at " + config_path.string()); - auto config = GetDefaultConfig(); - CLI_LOG("Default data folder path: " + config.dataFolderPath); - return cyu::CortexConfigMgr::GetInstance().DumpYamlConfig( - config, config_path.string()); -} - -inline config_yaml_utils::CortexConfig GetCortexConfig() { - auto config_path = GetConfigurationPath(); - - auto default_cfg = GetDefaultConfig(); - return config_yaml_utils::CortexConfigMgr::GetInstance().FromYaml( - config_path.string(), default_cfg); -} - -inline std::filesystem::path GetCortexDataPath() { - auto result = CreateConfigFileIfNotExist(); - if (result.has_error()) { - CTL_ERR("Error creating config file: " << result.error()); - return std::filesystem::path{}; - } - - auto config = GetCortexConfig(); - std::filesystem::path data_folder_path; - if (!config.dataFolderPath.empty()) { -#if defined(_WIN32) - data_folder_path = std::filesystem::u8path(config.dataFolderPath); -#else - data_folder_path = std::filesystem::path(config.dataFolderPath); -#endif - } else { - auto home_path = GetHomeDirectoryPath(); - data_folder_path = home_path / kCortexFolderName; - } - - if (!std::filesystem::exists(data_folder_path)) { - CLI_LOG("Cortex home folder not found. Create one: " + - data_folder_path.string()); - std::filesystem::create_directory(data_folder_path); - } - return data_folder_path; -} - -inline std::filesystem::path GetCortexLogPath() { - // TODO: We will need to support user to move the data folder to other place. - // TODO: get the variant of cortex. As discussed, we will have: prod, beta, nightly - - // currently we will store cortex data at ~/cortexcpp - auto config = GetCortexConfig(); - std::filesystem::path log_folder_path; - if (!config.logFolderPath.empty()) { - log_folder_path = std::filesystem::path(config.logFolderPath); - } else { - auto home_path = GetHomeDirectoryPath(); - log_folder_path = home_path / kCortexFolderName; - } - - if (!std::filesystem::exists(log_folder_path)) { - CTL_INF("Cortex log folder not found. Create one: " + - log_folder_path.string()); - std::filesystem::create_directory(log_folder_path); - } - return log_folder_path; -} +std::filesystem::path GetExecutableFolderContainerPath(); -inline void CreateDirectoryRecursively(const std::string& path) { - // Create the directories if they don't exist - if (std::filesystem::create_directories(path)) { - CTL_INF(path + " successfully created!"); - } else { - CTL_INF(path + " already exist!"); - } -} +std::filesystem::path GetHomeDirectoryPath(); -inline std::filesystem::path GetModelsContainerPath() { - auto result = CreateConfigFileIfNotExist(); - if (result.has_error()) { - CTL_ERR("Error creating config file: " << result.error()); - } - auto cortex_path = GetCortexDataPath(); - auto models_container_path = cortex_path / "models"; +std::filesystem::path GetConfigurationPath(); - if (!std::filesystem::exists(models_container_path)) { - CTL_INF("Model container folder not found. Create one: " - << models_container_path.string()); - std::filesystem::create_directories(models_container_path); - } +std::string GetDefaultDataFolderName(); - return models_container_path; -} +cpp::result UpdateCortexConfig( + const config_yaml_utils::CortexConfig& config); -inline std::filesystem::path GetCudaToolkitPath(const std::string& engine) { - auto engine_path = getenv("ENGINE_PATH") - ? std::filesystem::path(getenv("ENGINE_PATH")) - : GetCortexDataPath(); +config_yaml_utils::CortexConfig GetDefaultConfig(); - auto cuda_path = engine_path / "engines" / engine / "deps"; - if (!std::filesystem::exists(cuda_path)) { - std::filesystem::create_directories(cuda_path); - } +cpp::result CreateConfigFileIfNotExist(); - return cuda_path; -} +config_yaml_utils::CortexConfig GetCortexConfig(); -inline std::filesystem::path GetEnginesContainerPath() { - auto cortex_path = getenv("ENGINE_PATH") - ? std::filesystem::path(getenv("ENGINE_PATH")) - : GetCortexDataPath(); - auto engines_container_path = cortex_path / "engines"; +std::filesystem::path GetCortexDataPath(); - if (!std::filesystem::exists(engines_container_path)) { - CTL_INF("Engine container folder not found. Create one: " - << engines_container_path.string()); - std::filesystem::create_directory(engines_container_path); - } +std::filesystem::path GetCortexLogPath(); - return engines_container_path; -} +void CreateDirectoryRecursively(const std::string& path); -inline std::filesystem::path GetContainerFolderPath( - const std::string_view type) { - std::filesystem::path container_folder_path; +std::filesystem::path GetModelsContainerPath(); - if (type == "Model") { - container_folder_path = GetModelsContainerPath(); - } else if (type == "Engine") { - container_folder_path = GetEnginesContainerPath(); - } else if (type == "CudaToolkit") { - container_folder_path = - std::filesystem::temp_directory_path() / "cuda-dependencies"; - } else if (type == "Cortex") { - container_folder_path = std::filesystem::temp_directory_path() / "cortex"; - } else { - container_folder_path = std::filesystem::temp_directory_path() / "misc"; - } +std::filesystem::path GetCudaToolkitPath(const std::string& engine); - if (!std::filesystem::exists(container_folder_path)) { - CTL_INF("Creating folder: " << container_folder_path.string() << "\n"); - std::filesystem::create_directories(container_folder_path); - } +std::filesystem::path GetEnginesContainerPath(); - return container_folder_path; -} +std::filesystem::path GetContainerFolderPath(const std::string_view type); -inline std::string DownloadTypeToString(DownloadType type) { - switch (type) { - case DownloadType::Model: - return "Model"; - case DownloadType::Engine: - return "Engine"; - case DownloadType::Miscellaneous: - return "Misc"; - case DownloadType::CudaToolkit: - return "CudaToolkit"; - case DownloadType::Cortex: - return "Cortex"; - default: - return "UNKNOWN"; - } -} +std::string DownloadTypeToString(DownloadType type); inline std::filesystem::path GetAbsolutePath(const std::filesystem::path& base, const std::filesystem::path& r) { @@ -399,14 +79,10 @@ inline std::filesystem::path Subtract(const std::filesystem::path& path, } } -inline std::filesystem::path ToRelativeCortexDataPath( - const std::filesystem::path& path) { - return Subtract(path, GetCortexDataPath()); -} +std::filesystem::path ToRelativeCortexDataPath( + const std::filesystem::path& path); -inline std::filesystem::path ToAbsoluteCortexDataPath( - const std::filesystem::path& path) { - return GetAbsolutePath(GetCortexDataPath(), path); -} +std::filesystem::path ToAbsoluteCortexDataPath( + const std::filesystem::path& path); } // namespace file_manager_utils diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index 99df2aa77..f2895c363 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -5,6 +5,7 @@ #include #include #include "utils/curl_utils.h" +#include "utils/engine_constants.h" #include "utils/json_parser_utils.h" #include "utils/result.hpp" #include "utils/url_parser.h" diff --git a/engine/utils/system_info_utils.cc b/engine/utils/system_info_utils.cc new file mode 100644 index 000000000..e80bce035 --- /dev/null +++ b/engine/utils/system_info_utils.cc @@ -0,0 +1,141 @@ +#include "system_info_utils.h" +#include "utils/logging_utils.h" + +namespace system_info_utils { +std::pair GetDriverAndCudaVersion() { + if (!IsNvidiaSmiAvailable()) { + CTL_INF("nvidia-smi is not available!"); + return {}; + } + try { + std::string driver_version; + std::string cuda_version; + CommandExecutor cmd("nvidia-smi"); + auto output = cmd.execute(); + + const std::regex driver_version_reg(kDriverVersionRegex); + std::smatch driver_match; + + if (std::regex_search(output, driver_match, driver_version_reg)) { + LOG_INFO << "Gpu Driver Version: " << driver_match[1].str(); + driver_version = driver_match[1].str(); + } else { + LOG_ERROR << "Gpu Driver not found!"; + return {}; + } + + const std::regex cuda_version_reg(kCudaVersionRegex); + std::smatch cuda_match; + + if (std::regex_search(output, cuda_match, cuda_version_reg)) { + LOG_INFO << "CUDA Version: " << cuda_match[1].str(); + cuda_version = cuda_match[1].str(); + } else { + LOG_ERROR << "CUDA Version not found!"; + return {}; + } + return std::pair(driver_version, cuda_version); + } catch (const std::exception& e) { + LOG_ERROR << "Error: " << e.what(); + return {}; + } +} + +std::vector GetGpuInfoListVulkan() { + std::vector gpuInfoList; + + try { + // NOTE: current ly we don't have logic to download vulkaninfoSDK +#ifdef _WIN32 + CommandExecutor cmd("vulkaninfoSDK.exe --summary"); +#else + CommandExecutor cmd("vulkaninfoSDK --summary"); +#endif + auto output = cmd.execute(); + + // Regular expression patterns to match each field + std::regex gpu_block_reg(R"(GPU(\d+):)"); + std::regex field_pattern(R"(\s*(\w+)\s*=\s*(.*))"); + + std::sregex_iterator iter(output.begin(), output.end(), gpu_block_reg); + std::sregex_iterator end; + + while (iter != end) { + GpuInfo gpuInfo; + + // Extract GPU ID from the GPU block pattern (e.g., GPU0 -> id = "0") + gpuInfo.id = (*iter)[1].str(); + + auto gpu_start_pos = iter->position(0) + iter->length(0); + auto gpu_end_pos = std::next(iter) != end ? std::next(iter)->position(0) + : std::string::npos; + std::string gpu_block = + output.substr(gpu_start_pos, gpu_end_pos - gpu_start_pos); + + std::sregex_iterator field_iter(gpu_block.begin(), gpu_block.end(), + field_pattern); + + while (field_iter != end) { + std::string key = (*field_iter)[1].str(); + std::string value = (*field_iter)[2].str(); + + if (key == "deviceName") + gpuInfo.name = value; + else if (key == "apiVersion") + gpuInfo.compute_cap = value; + + gpuInfo.vram_total = ""; // not available + gpuInfo.arch = GetGpuArch(gpuInfo.name); + + ++field_iter; + } + + gpuInfoList.push_back(gpuInfo); + ++iter; + } + } catch (const std::exception& e) { + LOG_ERROR << "Error: " << e.what(); + } + + return gpuInfoList; +} + +std::vector GetGpuInfoList() { + std::vector gpuInfoList; + if (!IsNvidiaSmiAvailable()) + return gpuInfoList; + try { + auto [driver_version, cuda_version] = GetDriverAndCudaVersion(); + if (driver_version.empty() || cuda_version.empty()) + return gpuInfoList; + + CommandExecutor cmd(kGpuQueryCommand); + auto output = cmd.execute(); + + const std::regex gpu_info_reg(kGpuInfoRegex); + std::smatch match; + std::string::const_iterator search_start(output.cbegin()); + + while ( + std::regex_search(search_start, output.cend(), match, gpu_info_reg)) { + GpuInfo gpuInfo = { + match[1].str(), // id + match[2].str(), // vram_total + match[3].str(), // vram_free + match[4].str(), // name + GetGpuArch(match[4].str()), // arch + driver_version, // driver_version + cuda_version, // cuda_driver_version + match[5].str(), // compute_cap + match[6].str() // uuid + }; + gpuInfoList.push_back(gpuInfo); + search_start = match.suffix().first; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } + + return gpuInfoList; +} +} // namespace system_info_utils \ No newline at end of file diff --git a/engine/utils/system_info_utils.h b/engine/utils/system_info_utils.h index f2fab10cb..0907884be 100644 --- a/engine/utils/system_info_utils.h +++ b/engine/utils/system_info_utils.h @@ -8,7 +8,7 @@ #include #include "utils/command_executor.h" #include "utils/engine_constants.h" -#include "utils/logging_utils.h" + #ifdef _WIN32 #include #endif @@ -101,44 +101,7 @@ inline bool IsNvidiaSmiAvailable() { #endif } -inline std::pair GetDriverAndCudaVersion() { - if (!IsNvidiaSmiAvailable()) { - CTL_INF("nvidia-smi is not available!"); - return {}; - } - try { - std::string driver_version; - std::string cuda_version; - CommandExecutor cmd("nvidia-smi"); - auto output = cmd.execute(); - - const std::regex driver_version_reg(kDriverVersionRegex); - std::smatch driver_match; - - if (std::regex_search(output, driver_match, driver_version_reg)) { - LOG_INFO << "Gpu Driver Version: " << driver_match[1].str(); - driver_version = driver_match[1].str(); - } else { - LOG_ERROR << "Gpu Driver not found!"; - return {}; - } - - const std::regex cuda_version_reg(kCudaVersionRegex); - std::smatch cuda_match; - - if (std::regex_search(output, cuda_match, cuda_version_reg)) { - LOG_INFO << "CUDA Version: " << cuda_match[1].str(); - cuda_version = cuda_match[1].str(); - } else { - LOG_ERROR << "CUDA Version not found!"; - return {}; - } - return std::pair(driver_version, cuda_version); - } catch (const std::exception& e) { - LOG_ERROR << "Error: " << e.what(); - return {}; - } -} +std::pair GetDriverAndCudaVersion(); struct GpuInfo { std::string id; @@ -153,101 +116,7 @@ struct GpuInfo { std::string uuid; }; -inline std::vector GetGpuInfoListVulkan() { - std::vector gpuInfoList; - - try { - // NOTE: current ly we don't have logic to download vulkaninfoSDK -#ifdef _WIN32 - CommandExecutor cmd("vulkaninfoSDK.exe --summary"); -#else - CommandExecutor cmd("vulkaninfoSDK --summary"); -#endif - auto output = cmd.execute(); - - // Regular expression patterns to match each field - std::regex gpu_block_reg(R"(GPU(\d+):)"); - std::regex field_pattern(R"(\s*(\w+)\s*=\s*(.*))"); - - std::sregex_iterator iter(output.begin(), output.end(), gpu_block_reg); - std::sregex_iterator end; - - while (iter != end) { - GpuInfo gpuInfo; - - // Extract GPU ID from the GPU block pattern (e.g., GPU0 -> id = "0") - gpuInfo.id = (*iter)[1].str(); - - auto gpu_start_pos = iter->position(0) + iter->length(0); - auto gpu_end_pos = std::next(iter) != end ? std::next(iter)->position(0) - : std::string::npos; - std::string gpu_block = - output.substr(gpu_start_pos, gpu_end_pos - gpu_start_pos); +std::vector GetGpuInfoListVulkan(); - std::sregex_iterator field_iter(gpu_block.begin(), gpu_block.end(), - field_pattern); - - while (field_iter != end) { - std::string key = (*field_iter)[1].str(); - std::string value = (*field_iter)[2].str(); - - if (key == "deviceName") - gpuInfo.name = value; - else if (key == "apiVersion") - gpuInfo.compute_cap = value; - - gpuInfo.vram_total = ""; // not available - gpuInfo.arch = GetGpuArch(gpuInfo.name); - - ++field_iter; - } - - gpuInfoList.push_back(gpuInfo); - ++iter; - } - } catch (const std::exception& e) { - LOG_ERROR << "Error: " << e.what(); - } - - return gpuInfoList; -} - -inline std::vector GetGpuInfoList() { - std::vector gpuInfoList; - if (!IsNvidiaSmiAvailable()) - return gpuInfoList; - try { - auto [driver_version, cuda_version] = GetDriverAndCudaVersion(); - if (driver_version.empty() || cuda_version.empty()) - return gpuInfoList; - - CommandExecutor cmd(kGpuQueryCommand); - auto output = cmd.execute(); - - const std::regex gpu_info_reg(kGpuInfoRegex); - std::smatch match; - std::string::const_iterator search_start(output.cbegin()); - - while ( - std::regex_search(search_start, output.cend(), match, gpu_info_reg)) { - GpuInfo gpuInfo = { - match[1].str(), // id - match[2].str(), // vram_total - match[3].str(), // vram_free - match[4].str(), // name - GetGpuArch(match[4].str()), // arch - driver_version, // driver_version - cuda_version, // cuda_driver_version - match[5].str(), // compute_cap - match[6].str() // uuid - }; - gpuInfoList.push_back(gpuInfo); - search_start = match.suffix().first; - } - } catch (const std::exception& e) { - std::cerr << "Error: " << e.what() << std::endl; - } - - return gpuInfoList; -} +std::vector GetGpuInfoList(); } // namespace system_info_utils From 7d6199dd1ad92ab9ad0d815899ab569c633a0b4e Mon Sep 17 00:00:00 2001 From: James Date: Sat, 23 Nov 2024 13:24:13 +0700 Subject: [PATCH 09/44] feat: add messages api --- engine/CMakeLists.txt | 3 +- engine/cli/commands/engine_install_cmd.cc | 1 - .../messages/delete_message_response.h | 19 + engine/common/json_serializable.h | 11 + engine/common/message.h | 213 ++++++ engine/common/message_attachment.h | 50 ++ engine/common/message_attachment_factory.h | 48 ++ engine/common/message_content.h | 23 + engine/common/message_content_factory.h | 77 ++ engine/common/message_content_image_file.h | 69 ++ engine/common/message_content_image_url.h | 71 ++ engine/common/message_content_refusal.h | 46 ++ engine/common/message_content_text.h | 242 ++++++ engine/common/message_incomplete_detail.h | 32 + engine/common/message_role.h | 30 + engine/common/message_status.h | 34 + engine/common/repository/message_repository.h | 27 + engine/common/variant_map.h | 62 ++ engine/controllers/messages.cc | 298 ++++++++ engine/controllers/messages.h | 60 ++ engine/main.cc | 9 + engine/repositories/message_fs_repository.cc | 226 ++++++ engine/repositories/message_fs_repository.h | 39 + engine/services/hardware_service.cc | 3 +- engine/services/message_service.cc | 105 +++ engine/services/message_service.h | 39 + engine/utils/file_manager_utils.cc | 5 + engine/utils/file_manager_utils.h | 2 + engine/utils/ulid/ulid.hh | 16 + engine/utils/ulid/ulid_struct.hh | 710 ++++++++++++++++++ engine/utils/ulid/ulid_uint128.hh | 561 ++++++++++++++ 31 files changed, 3127 insertions(+), 4 deletions(-) create mode 100644 engine/common/api-dto/messages/delete_message_response.h create mode 100644 engine/common/json_serializable.h create mode 100644 engine/common/message.h create mode 100644 engine/common/message_attachment.h create mode 100644 engine/common/message_attachment_factory.h create mode 100644 engine/common/message_content.h create mode 100644 engine/common/message_content_factory.h create mode 100644 engine/common/message_content_image_file.h create mode 100644 engine/common/message_content_image_url.h create mode 100644 engine/common/message_content_refusal.h create mode 100644 engine/common/message_content_text.h create mode 100644 engine/common/message_incomplete_detail.h create mode 100644 engine/common/message_role.h create mode 100644 engine/common/message_status.h create mode 100644 engine/common/repository/message_repository.h create mode 100644 engine/common/variant_map.h create mode 100644 engine/controllers/messages.cc create mode 100644 engine/controllers/messages.h create mode 100644 engine/repositories/message_fs_repository.cc create mode 100644 engine/repositories/message_fs_repository.h create mode 100644 engine/services/message_service.cc create mode 100644 engine/services/message_service.h create mode 100644 engine/utils/ulid/ulid.hh create mode 100644 engine/utils/ulid/ulid_struct.hh create mode 100644 engine/utils/ulid/ulid_uint128.hh diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 06e778b7e..eae09d439 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -169,6 +169,7 @@ else() endif() aux_source_directory(controllers CTL_SRC) +aux_source_directory(repositories REPO_SRC) aux_source_directory(services SERVICES_SRC) aux_source_directory(common COMMON_SRC) aux_source_directory(models MODEL_SRC) @@ -180,7 +181,7 @@ aux_source_directory(utils UTILS_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC} ${REPO_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 21cd9f042..491ab0937 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -179,7 +179,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, auto response = curl_utils::SimplePostJson(install_url.ToFullPath(), body.toStyledString()); if (response.has_error()) { - // TODO: namh refactor later Json::Value root; Json::Reader reader; if (!reader.parse(response.error(), root)) { diff --git a/engine/common/api-dto/messages/delete_message_response.h b/engine/common/api-dto/messages/delete_message_response.h new file mode 100644 index 000000000..79447c93a --- /dev/null +++ b/engine/common/api-dto/messages/delete_message_response.h @@ -0,0 +1,19 @@ +#pragma once + +#include "common/json_serializable.h" + +namespace api_response { +struct DeleteMessageResponse : JsonSerializable { + std::string id; + std::string object; + bool deleted; + + cpp::result ToJson() override { + Json::Value json; + json["id"] = id; + json["object"] = object; + json["deleted"] = deleted; + return json; + } +}; +} // namespace api_response diff --git a/engine/common/json_serializable.h b/engine/common/json_serializable.h new file mode 100644 index 000000000..4afec92c5 --- /dev/null +++ b/engine/common/json_serializable.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include "utils/result.hpp" + +struct JsonSerializable { + + virtual cpp::result ToJson() = 0; + + virtual ~JsonSerializable() = default; +}; diff --git a/engine/common/message.h b/engine/common/message.h new file mode 100644 index 000000000..e5685f3bb --- /dev/null +++ b/engine/common/message.h @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include +#include +#include +#include "common/message_attachment.h" +#include "common/message_attachment_factory.h" +#include "common/message_content.h" +#include "common/message_content_factory.h" +#include "common/message_incomplete_detail.h" +#include "common/message_role.h" +#include "common/message_status.h" +#include "common/variant_map.h" +#include "json_serializable.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace ThreadMessage { + +// Represents a message within a thread. +struct Message : JsonSerializable { + Message() = default; + + Message(Message&&) = default; + + Message& operator=(Message&&) = default; + + Message(const Message&) = delete; + + Message& operator=(const Message&) = delete; + + // The identifier, which can be referenced in API endpoints. + std::string id; + + // The object type, which is always thread.message. + std::string object = "thread.message"; + + // The Unix timestamp (in seconds) for when the message was created. + uint32_t created_at; + + // The thread ID that this message belongs to. + std::string thread_id; + + // The status of the message, which can be either in_progress, incomplete, or completed. + Status status; + + // On an incomplete message, details about why the message is incomplete. + std::optional incomplete_details; + + // The Unix timestamp (in seconds) for when the message was completed. + std::optional completed_at; + + // The Unix timestamp (in seconds) for when the message was marked as incomplete. + std::optional incomplete_at; + + Role role; + + // The content of the message in array of text and/or images. + std::vector> content; + + // If applicable, the ID of the assistant that authored this message. + std::optional assistant_id; + + // The ID of the run associated with the creation of this message. Value is null when messages are created manually using the create message or create thread endpoints. + std::optional run_id; + + // A list of files attached to the message, and the tools they were added to. + std::optional> attachments; + + // Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. + Cortex::VariantMap metadata; + + static cpp::result FromJsonString( + std::string&& json_str) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(json_str, root)) { + return cpp::fail("Failed to parse JSON: " + + reader.getFormattedErrorMessages()); + } + + Message message; + + try { + message.id = std::move(root["id"].asString()); + message.object = + std::move(root.get("object", "thread.message").asString()); + message.created_at = root["created_at"].asUInt(); + if (message.created_at == 0 && root["created"].asUInt64() != 0) { + message.created_at = root["created"].asUInt64() / 1000; + } + message.thread_id = std::move(root["thread_id"].asString()); + message.status = StatusFromString(std::move(root["status"].asString())); + + message.incomplete_details = + IncompleteDetail::FromJson(std::move(root["incomplete_details"])) + .value(); + message.completed_at = root["completed_at"].asUInt(); + message.incomplete_at = root["incomplete_at"].asUInt(); + message.role = RoleFromString(std::move(root["role"].asString())); + message.content = ParseContents(std::move(root["content"])).value(); + + message.assistant_id = std::move(root["assistant_id"].asString()); + message.run_id = std::move(root["run_id"].asString()); + message.attachments = + ParseAttachments(std::move(root["attachments"])).value(); + + if (root["metadata"].isObject() && !root["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(root["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + message.metadata = res.value(); + } + } + + return message; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJsonString failed: ") + e.what()); + } + } + + cpp::result ToSingleLineJsonString() { + auto json_result = ToJson(); + if (json_result.has_error()) { + return cpp::fail(json_result.error()); + } + + Json::FastWriter writer; + try { + return writer.write(json_result.value()); + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to write JSON: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["object"] = object; + json["created_at"] = created_at; + json["thread_id"] = thread_id; + json["status"] = StatusToString(status); + + if (incomplete_details.has_value()) { + if (auto it = incomplete_details->ToJson(); it.has_value()) { + json["incomplete_details"] = it.value(); + } else { + CTL_WRN("Failed to convert incomplete_details to json: " + + it.error()); + } + } + if (completed_at.has_value() && completed_at.value() != 0) { + json["completed_at"] = *completed_at; + } + if (incomplete_at.has_value() && incomplete_at.value() != 0) { + json["incomplete_at"] = *incomplete_at; + } + + json["role"] = RoleToString(role); + + Json::Value content_json_arr{Json::arrayValue}; + for (auto& child_content : content) { + if (auto it = child_content->ToJson(); it.has_value()) { + content_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert content to json: " + it.error()); + } + } + json["content"] = content_json_arr; + if (assistant_id.has_value() && !assistant_id->empty()) { + json["assistant_id"] = *assistant_id; + } + if (run_id.has_value() && !run_id->empty()) { + json["run_id"] = *run_id; + } + if (attachments.has_value()) { + Json::Value attachments_json_arr{Json::arrayValue}; + for (auto& attachment : *attachments) { + if (auto it = attachment.ToJson(); it.has_value()) { + attachments_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert attachment to json: " + it.error()); + } + } + json["attachments"] = attachments_json_arr; + } + + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + json["metadata"] = metadata_json; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h new file mode 100644 index 000000000..ea809990e --- /dev/null +++ b/engine/common/message_attachment.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include "common/json_serializable.h" + +namespace ThreadMessage { + +// The tools to add this file to. +struct Tool { + std::string type; + + Tool(const std::string& type) : type{type} {} +}; + +// The type of tool being defined: code_interpreter +struct CodeInterpreter : Tool { + CodeInterpreter() : Tool{"code_interpreter"} {} +}; + +// The type of tool being defined: file_search +struct FileSearch : Tool { + FileSearch() : Tool{"file_search"} {} +}; + +// A list of files attached to the message, and the tools they were added to. +struct Attachment : JsonSerializable { + + // The ID of the file to attach to the message. + std::string file_id; + + std::vector tools; + + cpp::result ToJson() override { + try { + Json::Value json; + json["file_id"] = file_id; + Json::Value tools_json_arr{Json::arrayValue}; + for (auto& tool : tools) { + Json::Value tool_json; + tool_json["type"] = tool.type; + tools_json_arr.append(tool_json); + } + json["tools"] = tools_json_arr; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_attachment_factory.h b/engine/common/message_attachment_factory.h new file mode 100644 index 000000000..d9f1b8d2e --- /dev/null +++ b/engine/common/message_attachment_factory.h @@ -0,0 +1,48 @@ +#include +#include "common/message_attachment.h" +#include "utils/result.hpp" + +namespace ThreadMessage { +inline cpp::result ParseAttachment( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + Attachment attachment; + attachment.file_id = json["file_id"].asString(); + + std::vector tools{}; + if (json["tools"].isArray()) { + for (auto& tool_json : json["tools"]) { + Tool tool{tool_json["type"].asString()}; + tools.push_back(tool); + } + } + attachment.tools = tools; + + return attachment; +} + +inline cpp::result>, std::string> +ParseAttachments(Json::Value&& json) { + if (json.empty()) { + // still count as success + return std::nullopt; + } + if (!json.isArray()) { + return cpp::fail("Json is not an array"); + } + + std::vector attachments; + for (auto& attachment_json : json) { + auto attachment = ParseAttachment(std::move(attachment_json)); + if (attachment.has_error()) { + return cpp::fail(attachment.error()); + } + attachments.push_back(attachment.value()); + } + + return attachments; +} +}; // namespace ThreadMessage diff --git a/engine/common/message_content.h b/engine/common/message_content.h new file mode 100644 index 000000000..6e76b01a8 --- /dev/null +++ b/engine/common/message_content.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include "common/json_serializable.h" + +namespace ThreadMessage { + +struct Content : JsonSerializable { + std::string type; + + Content(const std::string& type) : type{type} {} + + Content(const Content&) = delete; + + Content& operator=(const Content&) = delete; + + Content(Content&&) noexcept = default; + + Content& operator=(Content&&) noexcept = default; + + virtual ~Content() = default; +}; +}; // namespace ThreadMessage diff --git a/engine/common/message_content_factory.h b/engine/common/message_content_factory.h new file mode 100644 index 000000000..854f6efd8 --- /dev/null +++ b/engine/common/message_content_factory.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include "common/message_content_image_file.h" +#include "common/message_content_image_url.h" +#include "common/message_content_refusal.h" +#include "common/message_content_text.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace ThreadMessage { +inline cpp::result, std::string> ParseContent( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + auto type = json["type"].asString(); + + if (type == "image_file") { + auto result = ImageFileContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "image_url") { + auto result = ImageUrlContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "text") { + auto result = TextContent::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else if (type == "refusal") { + auto result = Refusal::FromJson(std::move(json)); + if (result.has_error()) { + return cpp::fail(result.error()); + } + return std::make_unique(std::move(result.value())); + } else { + return cpp::fail("Unknown content type: " + type); + } + + return cpp::fail("Unknown content type: " + type); + } catch (const std::exception& e) { + return cpp::fail(std::string("ParseContent failed: ") + e.what()); + } +} + +inline cpp::result>, std::string> +ParseContents(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + if (!json.isArray()) { + return cpp::fail("Json is not an array"); + } + + std::vector> contents; + Json::Value mutable_json = std::move(json); + + for (auto& content_json : mutable_json) { + auto content = ParseContent(std::move(content_json)); + if (content.has_error()) { + CTL_WRN(content.error()); + continue; + } + contents.push_back(std::move(content.value())); + } + return contents; +} +} // namespace ThreadMessage diff --git a/engine/common/message_content_image_file.h b/engine/common/message_content_image_file.h new file mode 100644 index 000000000..1807dec1e --- /dev/null +++ b/engine/common/message_content_image_file.h @@ -0,0 +1,69 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { +struct ImageFile { + // The File ID of the image in the message content. Set purpose="vision" when uploading the File if you need to later display the file content. + std::string file_id; + + // Specifies the detail level of the image if specified by the user. low uses fewer tokens, you can opt in to high resolution using high. + std::string detail; + + ImageFile() = default; + + ImageFile(ImageFile&&) noexcept = default; + + ImageFile& operator=(ImageFile&&) noexcept = default; + + ImageFile(const ImageFile&) = delete; + + ImageFile& operator=(const ImageFile&) = delete; +}; + +// References an image File in the content of a message. +struct ImageFileContent : Content { + + ImageFileContent() : Content("image_file") {} + + ImageFileContent(ImageFileContent&&) noexcept = default; + + ImageFileContent& operator=(ImageFileContent&&) noexcept = default; + + ImageFileContent(const ImageFileContent&) = delete; + + ImageFileContent& operator=(const ImageFileContent&) = delete; + + ImageFile image_file; + + static cpp::result FromJson( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + ImageFileContent content; + ImageFile image_file; + image_file.detail = std::move(json["image_file"]["detail"].asString()); + image_file.file_id = std::move(json["image_file"]["file_id"].asString()); + content.image_file = std::move(image_file); + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["image_file"]["file_id"] = image_file.file_id; + json["image_file"]["detail"] = image_file.detail; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h new file mode 100644 index 000000000..eae6a7aa6 --- /dev/null +++ b/engine/common/message_content_image_url.h @@ -0,0 +1,71 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { + +struct ImageUrl { + // The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp. + std::string url; + + // Specifies the detail level of the image. low uses fewer tokens, you can opt in to high resolution using high. Default value is auto + std::string detail; + + ImageUrl() = default; + + ImageUrl(ImageUrl&&) noexcept = default; + + ImageUrl& operator=(ImageUrl&&) noexcept = default; + + ImageUrl(const ImageUrl&) = delete; + + ImageUrl& operator=(const ImageUrl&) = delete; +}; + +// References an image URL in the content of a message. +struct ImageUrlContent : Content { + + // The type of the content part. + ImageUrlContent(const std::string& type) : Content(type) {} + + ImageUrlContent(ImageUrlContent&&) noexcept = default; + + ImageUrlContent& operator=(ImageUrlContent&&) noexcept = default; + + ImageUrlContent(const ImageUrlContent&) = delete; + + ImageUrlContent& operator=(const ImageUrlContent&) = delete; + + ImageUrl image_url; + + static cpp::result FromJson( + Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + ImageUrlContent content{"image_url"}; + ImageUrl image_url; + image_url.url = std::move(json["image_url"]["url"].asString()); + image_url.detail = std::move(json["image_url"]["detail"].asString()); + content.image_url = std::move(image_url); + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["image_url"]["url"] = image_url.url; + json["image_url"]["detail"] = image_url.detail; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_refusal.h b/engine/common/message_content_refusal.h new file mode 100644 index 000000000..8353c3a85 --- /dev/null +++ b/engine/common/message_content_refusal.h @@ -0,0 +1,46 @@ +#pragma once + +#include "common/message_content.h" + +namespace ThreadMessage { +// The refusal content generated by the assistant. +struct Refusal : Content { + + // Always refusal. + Refusal(const std::string& refusal) : Content("refusal"), refusal{refusal} {} + + Refusal(Refusal&&) noexcept = default; + + Refusal& operator=(Refusal&&) noexcept = default; + + Refusal(const Refusal&) = delete; + + Refusal& operator=(const Refusal&) = delete; + + std::string refusal; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + Refusal content{std::move(json["refusal"].asString())}; + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["refusal"] = refusal; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h new file mode 100644 index 000000000..124d4a878 --- /dev/null +++ b/engine/common/message_content_text.h @@ -0,0 +1,242 @@ +#pragma once + +#include "common/message_content.h" +#include "utils/logging_utils.h" + +namespace ThreadMessage { + +struct Annotation : JsonSerializable { + std::string type; + + // The text in the message content that needs to be replaced. + std::string text; + + uint32_t start_index; + + uint32_t end_index; + + Annotation(const std::string& type, const std::string& text, + uint32_t start_index, uint32_t end_index) + : type{type}, + text{text}, + start_index{start_index}, + end_index{end_index} {} + + virtual ~Annotation() = default; +}; + +// A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. +struct FileCitationWrapper : Annotation { + + // Always file_citation. + FileCitationWrapper(const std::string& text, uint32_t start_index, + uint32_t end_index) + : Annotation("file_citation", text, start_index, end_index) {} + + FileCitationWrapper(FileCitationWrapper&&) noexcept = default; + + FileCitationWrapper& operator=(FileCitationWrapper&&) noexcept = default; + + FileCitationWrapper(const FileCitationWrapper&) = delete; + + FileCitationWrapper& operator=(const FileCitationWrapper&) = delete; + + struct FileCitation { + // The ID of the specific File the citation is from. + std::string file_id; + + FileCitation() = default; + + FileCitation(FileCitation&&) noexcept = default; + + FileCitation& operator=(FileCitation&&) noexcept = default; + + FileCitation(const FileCitation&) = delete; + + FileCitation& operator=(const FileCitation&) = delete; + }; + + FileCitation file_citation; + + cpp::result ToJson() override { + try { + Json::Value json; + json["text"] = text; + json["type"] = type; + json["file_citation"]["file_id"] = file_citation.file_id; + json["start_index"] = start_index; + json["end_index"] = end_index; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +// A URL for the file that's generated when the assistant used the code_interpreter tool to generate a file. +struct FilePathWrapper : Annotation { + // Always file_path. + FilePathWrapper(const std::string& text, uint32_t start_index, + uint32_t end_index) + : Annotation("file_path", text, start_index, end_index) {} + + FilePathWrapper(FilePathWrapper&&) noexcept = default; + + FilePathWrapper& operator=(FilePathWrapper&&) noexcept = default; + + FilePathWrapper(const FilePathWrapper&) = delete; + + FilePathWrapper& operator=(const FilePathWrapper&) = delete; + + struct FilePath { + // The ID of the file that was generated. + std::string file_id; + + FilePath() = default; + + FilePath(FilePath&&) noexcept = default; + + FilePath& operator=(FilePath&&) noexcept = default; + + FilePath(const FilePath&) = delete; + + FilePath& operator=(const FilePath&) = delete; + }; + + FilePath file_path; + + cpp::result ToJson() override { + try { + Json::Value json; + json["text"] = text; + json["type"] = type; + json["file_path"]["file_id"] = file_path.file_id; + json["start_index"] = start_index; + json["end_index"] = end_index; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +struct Text : JsonSerializable { + // The data that makes up the text. + + Text() = default; + + Text(Text&&) noexcept = default; + + Text& operator=(Text&&) noexcept = default; + + Text(const Text&) = delete; + + Text& operator=(const Text&) = delete; + + std::string value; + + std::vector> annotations; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + Text text; + text.value = json["value"].asString(); + + // Parse annotations array + if (json.isMember("annotations") && json["annotations"].isArray()) { + for (const auto& annotation_json : json["annotations"]) { + std::string type = std::move(annotation_json["type"].asString()); + std::string annotation_text = + std::move(annotation_json["text"].asString()); + uint32_t start_index = annotation_json["start_index"].asUInt(); + uint32_t end_index = annotation_json["end_index"].asUInt(); + + if (type == "file_citation") { + auto citation = std::make_unique( + annotation_text, start_index, end_index); + citation->file_citation.file_id = std::move( + annotation_json["file_citation"]["file_id"].asString()); + text.annotations.push_back(std::move(citation)); + } else if (type == "file_path") { + auto file_path = std::make_unique( + annotation_text, start_index, end_index); + file_path->file_path.file_id = + std::move(annotation_json["file_path"]["file_id"].asString()); + text.annotations.push_back(std::move(file_path)); + } else { + CTL_WRN("Unknown annotation type: " + type); + } + } + } + + return text; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["value"] = value; + Json::Value annotations_json_arr{Json::arrayValue}; + for (auto& annotation : annotations) { + if (auto it = annotation->ToJson(); it.has_value()) { + annotations_json_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert annotation to json: " + it.error()); + } + } + json["annotations"] = annotations_json_arr; + return json; + } catch (const std::exception e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + }; +}; + +// The text content that is part of a message. +struct TextContent : Content { + // Always text. + TextContent() : Content("text") {} + + TextContent(TextContent&&) noexcept = default; + + TextContent& operator=(TextContent&&) noexcept = default; + + TextContent(const TextContent&) = delete; + + TextContent& operator=(const TextContent&) = delete; + + Text text; + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Json string is empty"); + } + + try { + TextContent content; + content.text = std::move(Text::FromJson(std::move(json["text"])).value()); + return content; + } catch (const std::exception& e) { + return cpp::fail(std::string("FromJson failed: ") + e.what()); + } + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["type"] = type; + json["text"] = text.ToJson().value(); + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_incomplete_detail.h b/engine/common/message_incomplete_detail.h new file mode 100644 index 000000000..25e9c1169 --- /dev/null +++ b/engine/common/message_incomplete_detail.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common/json_serializable.h" + +namespace ThreadMessage { + +// On an incomplete message, details about why the message is incomplete. +struct IncompleteDetail : JsonSerializable { + // The reason the message is incomplete. + std::string reason; + + static cpp::result, std::string> FromJson( + Json::Value&& json) { + if (json.empty()) { + return std::nullopt; + } + IncompleteDetail incomplete_detail; + incomplete_detail.reason = json["reason"].asString(); + return incomplete_detail; + } + + cpp::result ToJson() override { + try { + Json::Value json; + json["reason"] = reason; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace ThreadMessage diff --git a/engine/common/message_role.h b/engine/common/message_role.h new file mode 100644 index 000000000..9d428eddc --- /dev/null +++ b/engine/common/message_role.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include "utils/string_utils.h" + +namespace ThreadMessage { +// The entity that produced the message. One of user or assistant. +enum class Role { USER, ASSISTANT }; + +inline std::string RoleToString(Role role) { + switch (role) { + case Role::USER: + return "user"; + case Role::ASSISTANT: + return "assistant"; + default: + throw new std::invalid_argument("Invalid role: " + + std::to_string((int)role)); + } +} + +inline Role RoleFromString(const std::string& input) { + if (string_utils::EqualsIgnoreCase(input, "user")) { + return Role::USER; + } else { + // for backward compatible with jan. Before, jan was mark text with `ready` + return Role::ASSISTANT; + } +} +}; // namespace ThreadMessage diff --git a/engine/common/message_status.h b/engine/common/message_status.h new file mode 100644 index 000000000..e8844ee13 --- /dev/null +++ b/engine/common/message_status.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include "utils/string_utils.h" + +namespace ThreadMessage { +// The status of the message, which can be either in_progress, incomplete, or completed. +enum class Status { IN_PROGRESS, INCOMPLETE, COMPLETED }; + +// Convert a Status enum to a string. +inline std::string StatusToString(Status status) { + switch (status) { + case Status::IN_PROGRESS: + return "in_progress"; + case Status::INCOMPLETE: + return "incomplete"; + // default as completed for backward compatible with jan + default: + return "completed"; + } +} + +// Convert a string to a Status enum. +inline Status StatusFromString(const std::string& input) { + if (string_utils::EqualsIgnoreCase(input, "in_progress")) { + return Status::IN_PROGRESS; + } else if (string_utils::EqualsIgnoreCase(input, "incomplete")) { + return Status::INCOMPLETE; + } else { + // for backward compatible with jan. Before, jan was mark text with `ready` + return Status::COMPLETED; + } +} +}; // namespace ThreadMessage diff --git a/engine/common/repository/message_repository.h b/engine/common/repository/message_repository.h new file mode 100644 index 000000000..cffc73675 --- /dev/null +++ b/engine/common/repository/message_repository.h @@ -0,0 +1,27 @@ +#pragma once + +#include "common/message.h" +#include "utils/result.hpp" + +class MessageRepository { + public: + virtual cpp::result CreateMessage( + ThreadMessage::Message& message) = 0; + + virtual cpp::result, std::string> + ListMessages(const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", + const std::string& run_id = "") const = 0; + + virtual cpp::result RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const = 0; + + virtual cpp::result ModifyMessage( + ThreadMessage::Message& message) = 0; + + virtual cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id) = 0; + + virtual ~MessageRepository() = default; +}; diff --git a/engine/common/variant_map.h b/engine/common/variant_map.h new file mode 100644 index 000000000..c8da77317 --- /dev/null +++ b/engine/common/variant_map.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include +#include "utils/result.hpp" + +namespace Cortex { + +using ValueVariant = std::variant; +using VariantMap = std::unordered_map; + +inline cpp::result ConvertJsonValueToMap( + const Json::Value& json) { + VariantMap result; + + if (!json.isObject()) { + return cpp::fail("Input json is not an object"); + } + + for (const auto& key : json.getMemberNames()) { + const Json::Value& value = json[key]; + + switch (value.type()) { + case Json::nullValue: + // Skip null values + break; + + case Json::stringValue: + result.emplace(key, value.asString()); + break; + + case Json::booleanValue: + result.emplace(key, value.asBool()); + break; + + case Json::uintValue: + case Json::intValue: + // Handle both signed and unsigned integers + if (value.isUInt64()) { + result.emplace(key, value.asUInt64()); + } else { + // Convert to double if the integer is negative or too large + result.emplace(key, value.asDouble()); + } + break; + + case Json::realValue: + result.emplace(key, value.asDouble()); + break; + + case Json::arrayValue: + case Json::objectValue: + // currently does not handle complex type + break; + } + } + + return result; +} +}; // namespace Cortex diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc new file mode 100644 index 000000000..55d9f6370 --- /dev/null +++ b/engine/controllers/messages.cc @@ -0,0 +1,298 @@ +#include "messages.h" +#include "common/api-dto/messages/delete_message_response.h" +#include "common/message_content.h" +#include "common/message_role.h" +#include "common/variant_map.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" + +void Messages::ListMessages( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, std::optional limit, + std::optional order, std::optional after, + std::optional before, + std::optional run_id) const { + auto res = message_service_->ListMessages( + thread_id, limit.value_or(20), order.value_or("desc"), after.value_or(""), + before.value_or(""), run_id.value_or("")); + + Json::Value root; + if (res.has_error()) { + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + Json::Value msg_arr(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + msg_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + root["object"] = "list"; + root["data"] = msg_arr; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Messages::CreateMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // role + auto role_str = json_body->get("role", "").asString(); + if (role_str.empty()) { + Json::Value ret; + ret["message"] = "Role is required"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + if (role_str != "user" && role_str != "assistant") { + Json::Value ret; + ret["message"] = "Role must be either 'user' or 'assistant'"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + ThreadMessage::Role role = role_str == "user" + ? ThreadMessage::Role::USER + : ThreadMessage::Role::ASSISTANT; + + std::variant>> + content; + + if (json_body->get("content", "").isArray()) { + auto result = ThreadMessage::ParseContents(json_body->get("content", "")); + if (result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse content array: " + result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (result.value().empty()) { + Json::Value ret; + ret["message"] = "Content array cannot be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + content = std::move(result.value()); + } else if (json_body->get("content", "").isString()) { + auto content_str = json_body->get("content", "").asString(); + string_utils::Trim(content_str); + if (content_str.empty()) { + Json::Value ret; + ret["message"] = "Content can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // success get content as string + content = content_str; + } else { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // attachments + std::optional> attachments = + std::nullopt; + if (json_body->get("attachments", "").isArray()) { + attachments = ThreadMessage::ParseAttachments( + std::move(json_body->get("attachments", ""))) + .value(); + } + + std::optional metadata = std::nullopt; + if (json_body->get("metadata", "").isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json_body->get("metadata", "")); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + metadata = res.value(); + } + } + + auto res = message_service_->CreateMessage( + thread_id, role, std::move(content), attachments, metadata); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::RetrieveMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) const { + auto res = message_service_->RetrieveMessage(thread_id, message_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::ModifyMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::optional metadata = std::nullopt; + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + + if (!metadata.has_value()) { + Json::Value ret; + ret["message"] = "Metadata is mandatory"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto res = + message_service_->ModifyMessage(thread_id, message_id, metadata.value()); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Failed to modify message: " + res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Messages::DeleteMessage( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, const std::string& message_id) { + auto res = message_service_->DeleteMessage(thread_id, message_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = "Failed to delete message: " + res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteMessageResponse response; + response.id = message_id; + response.object = "thread.message.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/messages.h b/engine/controllers/messages.h new file mode 100644 index 000000000..340317eb8 --- /dev/null +++ b/engine/controllers/messages.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include "services/message_service.h" + +using namespace drogon; + +class Messages : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Messages::CreateMessage, "/v1/threads/{1}/messages", Options, + Post); + + ADD_METHOD_TO(Messages::ListMessages, + "/v1/threads/{thread_id}/" + "messages?limit={limit}&order={order}&after={after}&before={" + "before}&run_id={run_id}", + Get); + + ADD_METHOD_TO(Messages::RetrieveMessage, "/v1/threads/{1}/messages/{2}", Get); + ADD_METHOD_TO(Messages::ModifyMessage, "/v1/threads/{1}/messages/{2}", + Options, Post); + ADD_METHOD_TO(Messages::DeleteMessage, "/v1/threads/{1}/messages/{2}", + Options, Delete); + METHOD_LIST_END + + Messages(std::shared_ptr msg_srv) + : message_service_{msg_srv} {} + + void CreateMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + void ListMessages(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, std::optional limit, + std::optional order, + std::optional after, + std::optional before, + std::optional run_id) const; + + void RetrieveMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id) const; + + void ModifyMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id); + + void DeleteMessage(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id, + const std::string& message_id); + + private: + std::shared_ptr message_service_; +}; diff --git a/engine/main.cc b/engine/main.cc index 61571907f..d076c02bd 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,18 +1,22 @@ #include #include #include +#include "common/repository/message_repository.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" #include "controllers/hardware.h" +#include "controllers/messages.h" #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" #include "cortex-common/cortexpythoni.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/message_fs_repository.h" #include "services/config_service.h" #include "services/file_watcher_service.h" +#include "services/message_service.h" #include "services/model_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -116,6 +120,9 @@ void RunServer(std::optional port, bool ignore_cout) { auto event_queue_ptr = std::make_shared(); cortex::event::EventProcessor event_processor(event_queue_ptr); + std::shared_ptr msg_repo = + std::make_shared(); + auto message_srv = std::make_shared(msg_repo); auto model_dir_path = file_manager_utils::GetModelsContainerPath(); auto config_service = std::make_shared(); auto download_service = @@ -131,6 +138,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); auto model_ctl = std::make_shared(model_service, engine_service); auto event_ctl = std::make_shared(event_queue_ptr); @@ -140,6 +148,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); drogon::app().registerController(event_ctl); diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc new file mode 100644 index 000000000..60cc0b5bf --- /dev/null +++ b/engine/repositories/message_fs_repository.cc @@ -0,0 +1,226 @@ +#include "message_fs_repository.h" +#include "utils/file_manager_utils.h" +#include "utils/result.hpp" + +namespace { +constexpr static const std::string_view kMessageFile = "messages.jsonl"; + +inline cpp::result GetMessageFileAbsPath( + const std::string& thread_id) { + auto path = + file_manager_utils::GetThreadsContainerPath() / thread_id / kMessageFile; + if (!std::filesystem::exists(path)) { + return cpp::fail("Message file not exist at path: " + path.string()); + } + return path; +} +} // namespace + +cpp::result MessageFsRepository::CreateMessage( + ThreadMessage::Message& message) { + CTL_INF("CreateMessage for thread " + message.thread_id); + auto path = GetMessageFileAbsPath(message.thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + std::ofstream file(path->string(), std::ios::app); + if (!file) { + return cpp::fail("Failed to open file for writing: " + path->string()); + } + + auto mutex = GrabMutex(message.thread_id); + std::shared_lock lock(*mutex); + + auto json_str = message.ToSingleLineJsonString(); + if (json_str.has_error()) { + return cpp::fail(json_str.error()); + } + file << json_str.value(); + + file.flush(); + if (file.fail()) { + return cpp::fail("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::fail("Failed to close file after writing: " + path->string()); + } + + return {}; +} + +cpp::result, std::string> +MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, + const std::string& order, + const std::string& after, + const std::string& before, + const std::string& run_id) const { + CTL_INF("Listing messages for thread " + thread_id); + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::shared_lock lock(*mutex); + + return ReadMessageFromFile(thread_id); +} + +cpp::result +MessageFsRepository::RetrieveMessage(const std::string& thread_id, + const std::string& message_id) const { + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + + auto messages = ReadMessageFromFile(thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + for (auto& msg : messages.value()) { + if (msg.id == message_id) { + return std::move(msg); + } + } + + return cpp::fail("Message not found"); +} + +cpp::result MessageFsRepository::ModifyMessage( + ThreadMessage::Message& message) { + auto path = GetMessageFileAbsPath(message.thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(message.thread_id); + std::unique_lock lock(*mutex); + + auto messages = ReadMessageFromFile(message.thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + std::ofstream file(path.value().string(), std::ios::trunc); + if (!file) { + return cpp::fail("Failed to open file for writing: " + + path.value().string()); + } + + bool found = false; + for (auto& msg : messages.value()) { + if (msg.id == message.id) { + file << message.ToSingleLineJsonString().value(); + found = true; + } else { + file << msg.ToSingleLineJsonString().value(); + } + } + + file.flush(); + if (file.fail()) { + return cpp::fail("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::fail("Failed to close file after writing: " + path->string()); + } + + if (!found) { + return cpp::fail("Message not found"); + } + return {}; +} + +cpp::result MessageFsRepository::DeleteMessage( + const std::string& thread_id, const std::string& message_id) { + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + auto messages = ReadMessageFromFile(thread_id); + if (messages.has_error()) { + return cpp::fail(messages.error()); + } + + std::ofstream file(path.value().string(), std::ios::trunc); + if (!file) { + return cpp::fail("Failed to open file for writing: " + + path.value().string()); + } + + bool found = false; + for (auto& msg : messages.value()) { + if (msg.id != message_id) { + file << msg.ToSingleLineJsonString().value(); + } else { + found = true; + } + } + + file.flush(); + if (file.fail()) { + return cpp::fail("Failed to write to file: " + path->string()); + } + file.close(); + if (file.fail()) { + return cpp::fail("Failed to close file after writing: " + path->string()); + } + + if (!found) { + return cpp::fail("Message not found"); + } + + return {}; +} + +cpp::result, std::string> +MessageFsRepository::ReadMessageFromFile(const std::string& thread_id) const { + LOG_TRACE << "Reading messages from file for thread " << thread_id; + auto path = GetMessageFileAbsPath(thread_id); + if (path.has_error()) { + return cpp::fail(path.error()); + } + + std::ifstream file(path.value()); + if (!file) { + return cpp::fail("Failed to open file: " + path->string()); + } + + std::vector messages; + std::string line; + while (std::getline(file, line)) { + if (line.empty()) + continue; + auto msg_parse_result = + ThreadMessage::Message::FromJsonString(std::move(line)); + if (msg_parse_result.has_error()) { + CTL_WRN("Failed to parse message: " + msg_parse_result.error()); + continue; + } + + messages.push_back(std::move(msg_parse_result.value())); + } + + return messages; +} + +std::shared_mutex* MessageFsRepository::GrabMutex( + const std::string& thread_id) const { + std::lock_guard lock(mutex_map_mutex_); + auto& thread_mutex = thread_mutexes_[thread_id]; + if (!thread_mutex) { + thread_mutex = std::make_unique(); + } + return thread_mutex.get(); +} diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h new file mode 100644 index 000000000..d8bcd02a7 --- /dev/null +++ b/engine/repositories/message_fs_repository.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include "common/repository/message_repository.h" + +class MessageFsRepository : public MessageRepository { + public: + cpp::result CreateMessage( + ThreadMessage::Message& message) override; + + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", + const std::string& run_id = "") const override; + + cpp::result RetrieveMessage( + const std::string& thread_id, + const std::string& message_id) const override; + + cpp::result ModifyMessage( + ThreadMessage::Message& message) override; + + cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id) override; + + ~MessageFsRepository() = default; + + private: + cpp::result, std::string> + ReadMessageFromFile(const std::string& thread_id) const; + + std::shared_mutex* GrabMutex(const std::string& thread_id) const; + + mutable std::unordered_map> + thread_mutexes_; + mutable std::mutex mutex_map_mutex_; +}; diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index a6ceb556f..681ca7578 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -8,7 +8,6 @@ #endif #include "cli/commands/cortex_upd_cmd.h" #include "database/hardware.h" -#include "services/engine_service.h" #include "utils/cortex_utils.h" #include "utils/widechar_conv.h" @@ -326,4 +325,4 @@ bool HardwareService::IsValidConfig( } return true; } -} // namespace services \ No newline at end of file +} // namespace services diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc new file mode 100644 index 000000000..31ae38420 --- /dev/null +++ b/engine/services/message_service.cc @@ -0,0 +1,105 @@ +#include "services/message_service.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" +#include "utils/ulid/ulid.hh" + +cpp::result MessageService::CreateMessage( + const std::string& thread_id, const ThreadMessage::Role& role, + std::variant>>&& + content, + std::optional> attachments, + std::optional metadata) { + LOG_TRACE << "CreateMessage for thread " << thread_id; + auto now = std::chrono::system_clock::now(); + auto seconds_since_epoch = + std::chrono::duration_cast(now.time_since_epoch()) + .count(); + std::vector> content_list{}; + // if content is string + if (std::holds_alternative(content)) { + auto text_content = std::make_unique(); + text_content->text.value = std::get(content); + content_list.push_back(std::move(text_content)); + } else { + content_list = std::move( + std::get>>( + content)); + } + + ulid::ULID ulid = ulid::Create(seconds_since_epoch, []() { return 4; }); + std::string str = ulid::Marshal(ulid); + LOG_TRACE << "Generated message ID: " << str; + + ThreadMessage::Message msg; + msg.id = str; + msg.object = "thread.message"; + msg.created_at = 0; + msg.thread_id = thread_id; + msg.status = ThreadMessage::Status::COMPLETED; + msg.completed_at = seconds_since_epoch; + msg.incomplete_at = std::nullopt; + msg.incomplete_details = std::nullopt; + msg.role = role; + msg.content = std::move(content_list); + msg.assistant_id = std::nullopt; + msg.run_id = std::nullopt; + msg.attachments = attachments; + msg.metadata = metadata.value_or(Cortex::VariantMap{}); + auto res = message_repository_->CreateMessage(msg); + if (res.has_error()) { + return cpp::fail("Failed to create message: " + res.error()); + } else { + return msg; + } +} + +cpp::result, std::string> +MessageService::ListMessages(const std::string& thread_id, uint8_t limit, + const std::string& order, const std::string& after, + const std::string& before, + const std::string& run_id) const { + CTL_INF("ListMessages for thread " + thread_id); + return message_repository_->ListMessages(thread_id); +} + +cpp::result +MessageService::RetrieveMessage(const std::string& thread_id, + const std::string& message_id) const { + CTL_INF("RetrieveMessage for thread " + thread_id); + return message_repository_->RetrieveMessage(thread_id, message_id); +} + +cpp::result MessageService::ModifyMessage( + const std::string& thread_id, const std::string& message_id, + std::optional metadata) { + LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " + << message_id; + auto msg = RetrieveMessage(thread_id, message_id); + if (msg.has_error()) { + return cpp::fail("Failed to retrieve message: " + msg.error()); + } + + msg->metadata = metadata.value(); + auto ptr = &msg.value(); + + auto res = message_repository_->ModifyMessage(msg.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify message: " + res.error()); + return cpp::fail("Failed to modify message: " + res.error()); + } else { + return RetrieveMessage(thread_id, message_id); + } +} + +cpp::result MessageService::DeleteMessage( + const std::string& thread_id, const std::string& message_id) { + LOG_TRACE << "DeleteMessage for thread " + thread_id; + auto res = message_repository_->DeleteMessage(thread_id, message_id); + if (res.has_error()) { + LOG_ERROR << "Failed to delete message: " + res.error(); + return cpp::fail("Failed to delete message: " + res.error()); + } else { + return message_id; + } +} diff --git a/engine/services/message_service.h b/engine/services/message_service.h new file mode 100644 index 000000000..e62970b54 --- /dev/null +++ b/engine/services/message_service.h @@ -0,0 +1,39 @@ +#pragma once + +#include "common/repository/message_repository.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +class MessageService { + public: + explicit MessageService(std::shared_ptr message_repository) + : message_repository_{message_repository} {} + + cpp::result CreateMessage( + const std::string& thread_id, const ThreadMessage::Role& role, + std::variant>>&& + content, + std::optional> attachments, + std::optional metadata); + + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit = 20, + const std::string& order = "desc", const std::string& after = "", + const std::string& before = "", const std::string& run_id = "") const; + + cpp::result RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const; + + cpp::result ModifyMessage( + const std::string& thread_id, const std::string& message_id, + std::optional>> + metadata); + + cpp::result DeleteMessage( + const std::string& thread_id, const std::string& message_id); + + private: + std::shared_ptr message_repository_; +}; diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index 9650dd973..11128a275 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -299,6 +299,11 @@ std::filesystem::path GetCudaToolkitPath(const std::string& engine) { return cuda_path; } +std::filesystem::path GetThreadsContainerPath() { + auto cortex_path = GetCortexDataPath(); + return cortex_path / "threads"; +} + std::filesystem::path GetEnginesContainerPath() { auto cortex_path = getenv("ENGINE_PATH") ? std::filesystem::path(getenv("ENGINE_PATH")) diff --git a/engine/utils/file_manager_utils.h b/engine/utils/file_manager_utils.h index a7a1b09c2..91102d002 100644 --- a/engine/utils/file_manager_utils.h +++ b/engine/utils/file_manager_utils.h @@ -49,6 +49,8 @@ std::filesystem::path GetCudaToolkitPath(const std::string& engine); std::filesystem::path GetEnginesContainerPath(); +std::filesystem::path GetThreadsContainerPath(); + std::filesystem::path GetContainerFolderPath(const std::string_view type); std::string DownloadTypeToString(DownloadType type); diff --git a/engine/utils/ulid/ulid.hh b/engine/utils/ulid/ulid.hh new file mode 100644 index 000000000..22b6f19b5 --- /dev/null +++ b/engine/utils/ulid/ulid.hh @@ -0,0 +1,16 @@ +#ifndef ULID_HH +#define ULID_HH + +// https://github.com/suyash/ulid +// http://stackoverflow.com/a/23981011 +#ifdef __SIZEOF_INT128__ +#define ULIDUINT128 +#endif + +#ifdef ULIDUINT128 +#include "ulid_uint128.hh" +#else +#include "ulid_struct.hh" +#endif // ULIDUINT128 + +#endif // ULID_HH diff --git a/engine/utils/ulid/ulid_struct.hh b/engine/utils/ulid/ulid_struct.hh new file mode 100644 index 000000000..ad0da59ec --- /dev/null +++ b/engine/utils/ulid/ulid_struct.hh @@ -0,0 +1,710 @@ +#ifndef ULID_STRUCT_HH +#define ULID_STRUCT_HH + +#include +#include +#include +#include +#include +#include + +#if _MSC_VER > 0 +typedef uint32_t rand_t; +#else +typedef uint8_t rand_t; +#endif + +namespace ulid { + +/** + * ULID is a 16 byte Universally Unique Lexicographically Sortable Identifier + * */ +struct ULID { + uint8_t data[16]; + + ULID() { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = 0; + // } + + // unrolled loop + data[0] = 0; + data[1] = 0; + data[2] = 0; + data[3] = 0; + data[4] = 0; + data[5] = 0; + data[6] = 0; + data[7] = 0; + data[8] = 0; + data[9] = 0; + data[10] = 0; + data[11] = 0; + data[12] = 0; + data[13] = 0; + data[14] = 0; + data[15] = 0; + } + + ULID(uint64_t val) { + // for (int i = 0 ; i < 16 ; i++) { + // data[15 - i] = static_cast(val); + // val >>= 8; + // } + + // unrolled loop + data[15] = static_cast(val); + + val >>= 8; + data[14] = static_cast(val); + + val >>= 8; + data[13] = static_cast(val); + + val >>= 8; + data[12] = static_cast(val); + + val >>= 8; + data[11] = static_cast(val); + + val >>= 8; + data[10] = static_cast(val); + + val >>= 8; + data[9] = static_cast(val); + + val >>= 8; + data[8] = static_cast(val); + + data[7] = 0; + data[6] = 0; + data[5] = 0; + data[4] = 0; + data[3] = 0; + data[2] = 0; + data[1] = 0; + data[0] = 0; + } + + ULID(const ULID& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // } + + // unrolled loop + data[0] = other.data[0]; + data[1] = other.data[1]; + data[2] = other.data[2]; + data[3] = other.data[3]; + data[4] = other.data[4]; + data[5] = other.data[5]; + data[6] = other.data[6]; + data[7] = other.data[7]; + data[8] = other.data[8]; + data[9] = other.data[9]; + data[10] = other.data[10]; + data[11] = other.data[11]; + data[12] = other.data[12]; + data[13] = other.data[13]; + data[14] = other.data[14]; + data[15] = other.data[15]; + } + + ULID& operator=(const ULID& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // } + + // unrolled loop + data[0] = other.data[0]; + data[1] = other.data[1]; + data[2] = other.data[2]; + data[3] = other.data[3]; + data[4] = other.data[4]; + data[5] = other.data[5]; + data[6] = other.data[6]; + data[7] = other.data[7]; + data[8] = other.data[8]; + data[9] = other.data[9]; + data[10] = other.data[10]; + data[11] = other.data[11]; + data[12] = other.data[12]; + data[13] = other.data[13]; + data[14] = other.data[14]; + data[15] = other.data[15]; + + return *this; + } + + ULID(ULID&& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // other.data[i] = 0; + // } + + // unrolled loop + data[0] = other.data[0]; + other.data[0] = 0; + + data[1] = other.data[1]; + other.data[1] = 0; + + data[2] = other.data[2]; + other.data[2] = 0; + + data[3] = other.data[3]; + other.data[3] = 0; + + data[4] = other.data[4]; + other.data[4] = 0; + + data[5] = other.data[5]; + other.data[5] = 0; + + data[6] = other.data[6]; + other.data[6] = 0; + + data[7] = other.data[7]; + other.data[7] = 0; + + data[8] = other.data[8]; + other.data[8] = 0; + + data[9] = other.data[9]; + other.data[9] = 0; + + data[10] = other.data[10]; + other.data[10] = 0; + + data[11] = other.data[11]; + other.data[11] = 0; + + data[12] = other.data[12]; + other.data[12] = 0; + + data[13] = other.data[13]; + other.data[13] = 0; + + data[14] = other.data[14]; + other.data[14] = 0; + + data[15] = other.data[15]; + other.data[15] = 0; + } + + ULID& operator=(ULID&& other) { + // for (int i = 0 ; i < 16 ; i++) { + // data[i] = other.data[i]; + // other.data[i] = 0; + // } + + // unrolled loop + data[0] = other.data[0]; + other.data[0] = 0; + + data[1] = other.data[1]; + other.data[1] = 0; + + data[2] = other.data[2]; + other.data[2] = 0; + + data[3] = other.data[3]; + other.data[3] = 0; + + data[4] = other.data[4]; + other.data[4] = 0; + + data[5] = other.data[5]; + other.data[5] = 0; + + data[6] = other.data[6]; + other.data[6] = 0; + + data[7] = other.data[7]; + other.data[7] = 0; + + data[8] = other.data[8]; + other.data[8] = 0; + + data[9] = other.data[9]; + other.data[9] = 0; + + data[10] = other.data[10]; + other.data[10] = 0; + + data[11] = other.data[11]; + other.data[11] = 0; + + data[12] = other.data[12]; + other.data[12] = 0; + + data[13] = other.data[13]; + other.data[13] = 0; + + data[14] = other.data[14]; + other.data[14] = 0; + + data[15] = other.data[15]; + other.data[15] = 0; + + return *this; + } +}; + +/** + * EncodeTime will encode the first 6 bytes of a uint8_t array to the passed + * timestamp + * */ +inline void EncodeTime(time_t timestamp, ULID& ulid) { + ulid.data[0] = static_cast(timestamp >> 40); + ulid.data[1] = static_cast(timestamp >> 32); + ulid.data[2] = static_cast(timestamp >> 24); + ulid.data[3] = static_cast(timestamp >> 16); + ulid.data[4] = static_cast(timestamp >> 8); + ulid.data[5] = static_cast(timestamp); +} + +/** + * EncodeTimeNow will encode a ULID using the time obtained using std::time(nullptr) + * */ +inline void EncodeTimeNow(ULID& ulid) { + EncodeTime(std::time(nullptr), ulid); +} + +/** + * EncodeTimeSystemClockNow will encode a ULID using the time obtained using + * std::chrono::system_clock::now() by taking the timestamp in milliseconds. + * */ +inline void EncodeTimeSystemClockNow(ULID& ulid) { + auto now = std::chrono::system_clock::now(); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()); + EncodeTime(ms.count(), ulid); +} + +/** + * EncodeEntropy will encode the last 10 bytes of the passed uint8_t array with + * the values generated using the passed random number generator. + * */ +inline void EncodeEntropy(const std::function& rng, ULID& ulid) { + ulid.data[6] = rng(); + ulid.data[7] = rng(); + ulid.data[8] = rng(); + ulid.data[9] = rng(); + ulid.data[10] = rng(); + ulid.data[11] = rng(); + ulid.data[12] = rng(); + ulid.data[13] = rng(); + ulid.data[14] = rng(); + ulid.data[15] = rng(); +} + +/** + * EncodeEntropyRand will encode a ulid using std::rand + * + * std::rand returns values in [0, RAND_MAX] + * */ +inline void EncodeEntropyRand(ULID& ulid) { + ulid.data[6] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[7] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[8] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[9] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[10] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[11] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[12] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[13] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[14] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; + ulid.data[15] = (uint8_t)(std::rand() * 255ull) / RAND_MAX; +} + +static std::uniform_int_distribution Distribution_0_255(0, 255); + +/** + * EncodeEntropyMt19937 will encode a ulid using std::mt19937 + * + * It also creates a std::uniform_int_distribution to generate values in [0, 255] + * */ +inline void EncodeEntropyMt19937(std::mt19937& generator, ULID& ulid) { + ulid.data[6] = Distribution_0_255(generator); + ulid.data[7] = Distribution_0_255(generator); + ulid.data[8] = Distribution_0_255(generator); + ulid.data[9] = Distribution_0_255(generator); + ulid.data[10] = Distribution_0_255(generator); + ulid.data[11] = Distribution_0_255(generator); + ulid.data[12] = Distribution_0_255(generator); + ulid.data[13] = Distribution_0_255(generator); + ulid.data[14] = Distribution_0_255(generator); + ulid.data[15] = Distribution_0_255(generator); +} + +/** + * Encode will create an encoded ULID with a timestamp and a generator. + * */ +inline void Encode(time_t timestamp, const std::function& rng, + ULID& ulid) { + EncodeTime(timestamp, ulid); + EncodeEntropy(rng, ulid); +} + +/** + * EncodeNowRand = EncodeTimeNow + EncodeEntropyRand. + * */ +inline void EncodeNowRand(ULID& ulid) { + EncodeTimeNow(ulid); + EncodeEntropyRand(ulid); +} + +/** + * Create will create a ULID with a timestamp and a generator. + * */ +inline ULID Create(time_t timestamp, const std::function& rng) { + ULID ulid; + Encode(timestamp, rng, ulid); + return ulid; +} + +/** + * CreateNowRand:EncodeNowRand = Create:Encode. + * */ +inline ULID CreateNowRand() { + ULID ulid; + EncodeNowRand(ulid); + return ulid; +} + +/** + * Crockford's Base32 + * */ +static const char Encoding[33] = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"; + +/** + * MarshalTo will marshal a ULID to the passed character array. + * + * Implementation taken directly from oklog/ulid + * (https://sourcegraph.com/github.com/oklog/ulid@0774f81f6e44af5ce5e91c8d7d76cf710e889ebb/-/blob/ulid.go#L162-190) + * + * timestamp:
+ * dst[0]: first 3 bits of data[0]
+ * dst[1]: last 5 bits of data[0]
+ * dst[2]: first 5 bits of data[1]
+ * dst[3]: last 3 bits of data[1] + first 2 bits of data[2]
+ * dst[4]: bits 3-7 of data[2]
+ * dst[5]: last bit of data[2] + first 4 bits of data[3]
+ * dst[6]: last 4 bits of data[3] + first bit of data[4]
+ * dst[7]: bits 2-6 of data[4]
+ * dst[8]: last 2 bits of data[4] + first 3 bits of data[5]
+ * dst[9]: last 5 bits of data[5]
+ * + * entropy: + * follows similarly, except now all components are set to 5 bits. + * */ +inline void MarshalTo(const ULID& ulid, char dst[26]) { + // 10 byte timestamp + dst[0] = Encoding[(ulid.data[0] & 224) >> 5]; + dst[1] = Encoding[ulid.data[0] & 31]; + dst[2] = Encoding[(ulid.data[1] & 248) >> 3]; + dst[3] = Encoding[((ulid.data[1] & 7) << 2) | ((ulid.data[2] & 192) >> 6)]; + dst[4] = Encoding[(ulid.data[2] & 62) >> 1]; + dst[5] = Encoding[((ulid.data[2] & 1) << 4) | ((ulid.data[3] & 240) >> 4)]; + dst[6] = Encoding[((ulid.data[3] & 15) << 1) | ((ulid.data[4] & 128) >> 7)]; + dst[7] = Encoding[(ulid.data[4] & 124) >> 2]; + dst[8] = Encoding[((ulid.data[4] & 3) << 3) | ((ulid.data[5] & 224) >> 5)]; + dst[9] = Encoding[ulid.data[5] & 31]; + + // 16 bytes of entropy + dst[10] = Encoding[(ulid.data[6] & 248) >> 3]; + dst[11] = Encoding[((ulid.data[6] & 7) << 2) | ((ulid.data[7] & 192) >> 6)]; + dst[12] = Encoding[(ulid.data[7] & 62) >> 1]; + dst[13] = Encoding[((ulid.data[7] & 1) << 4) | ((ulid.data[8] & 240) >> 4)]; + dst[14] = Encoding[((ulid.data[8] & 15) << 1) | ((ulid.data[9] & 128) >> 7)]; + dst[15] = Encoding[(ulid.data[9] & 124) >> 2]; + dst[16] = Encoding[((ulid.data[9] & 3) << 3) | ((ulid.data[10] & 224) >> 5)]; + dst[17] = Encoding[ulid.data[10] & 31]; + dst[18] = Encoding[(ulid.data[11] & 248) >> 3]; + dst[19] = Encoding[((ulid.data[11] & 7) << 2) | ((ulid.data[12] & 192) >> 6)]; + dst[20] = Encoding[(ulid.data[12] & 62) >> 1]; + dst[21] = Encoding[((ulid.data[12] & 1) << 4) | ((ulid.data[13] & 240) >> 4)]; + dst[22] = + Encoding[((ulid.data[13] & 15) << 1) | ((ulid.data[14] & 128) >> 7)]; + dst[23] = Encoding[(ulid.data[14] & 124) >> 2]; + dst[24] = Encoding[((ulid.data[14] & 3) << 3) | ((ulid.data[15] & 224) >> 5)]; + dst[25] = Encoding[ulid.data[15] & 31]; +} + +/** + * Marshal will marshal a ULID to a std::string. + * */ +inline std::string Marshal(const ULID& ulid) { + char data[27]; + data[26] = '\0'; + MarshalTo(ulid, data); + return std::string(data); +} + +/** + * MarshalBinaryTo will Marshal a ULID to the passed byte array + * */ +inline void MarshalBinaryTo(const ULID& ulid, uint8_t dst[16]) { + // timestamp + dst[0] = ulid.data[0]; + dst[1] = ulid.data[1]; + dst[2] = ulid.data[2]; + dst[3] = ulid.data[3]; + dst[4] = ulid.data[4]; + dst[5] = ulid.data[5]; + + // entropy + dst[6] = ulid.data[6]; + dst[7] = ulid.data[7]; + dst[8] = ulid.data[8]; + dst[9] = ulid.data[9]; + dst[10] = ulid.data[10]; + dst[11] = ulid.data[11]; + dst[12] = ulid.data[12]; + dst[13] = ulid.data[13]; + dst[14] = ulid.data[14]; + dst[15] = ulid.data[15]; +} + +/** + * MarshalBinary will Marshal a ULID to a byte vector. + * */ +inline std::vector MarshalBinary(const ULID& ulid) { + std::vector dst(16); + MarshalBinaryTo(ulid, dst.data()); + return dst; +} + +/** + * dec storesdecimal encodings for characters. + * 0xFF indicates invalid character. + * 48-57 are digits. + * 65-90 are capital alphabets. + * */ +static const uint8_t dec[256] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + /* 0 1 2 3 4 5 6 7 */ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + /* 8 9 */ + 0x08, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + /* 10(A) 11(B) 12(C) 13(D) 14(E) 15(F) 16(G) */ + 0xFF, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + /*17(H) 18(J) 19(K) 20(M) 21(N) */ + 0x11, 0xFF, 0x12, 0x13, 0xFF, 0x14, 0x15, 0xFF, + /*22(P)23(Q)24(R) 25(S) 26(T) 27(V) 28(W) */ + 0x16, 0x17, 0x18, 0x19, 0x1A, 0xFF, 0x1B, 0x1C, + /*29(X)30(Y)31(Z) */ + 0x1D, 0x1E, 0x1F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; + +/** + * UnmarshalFrom will unmarshal a ULID from the passed character array. + * */ +inline void UnmarshalFrom(const char str[26], ULID& ulid) { + // timestamp + ulid.data[0] = (dec[int(str[0])] << 5) | dec[int(str[1])]; + ulid.data[1] = (dec[int(str[2])] << 3) | (dec[int(str[3])] >> 2); + ulid.data[2] = (dec[int(str[3])] << 6) | (dec[int(str[4])] << 1) | + (dec[int(str[5])] >> 4); + ulid.data[3] = (dec[int(str[5])] << 4) | (dec[int(str[6])] >> 1); + ulid.data[4] = (dec[int(str[6])] << 7) | (dec[int(str[7])] << 2) | + (dec[int(str[8])] >> 3); + ulid.data[5] = (dec[int(str[8])] << 5) | dec[int(str[9])]; + + // entropy + ulid.data[6] = (dec[int(str[10])] << 3) | (dec[int(str[11])] >> 2); + ulid.data[7] = (dec[int(str[11])] << 6) | (dec[int(str[12])] << 1) | + (dec[int(str[13])] >> 4); + ulid.data[8] = (dec[int(str[13])] << 4) | (dec[int(str[14])] >> 1); + ulid.data[9] = (dec[int(str[14])] << 7) | (dec[int(str[15])] << 2) | + (dec[int(str[16])] >> 3); + ulid.data[10] = (dec[int(str[16])] << 5) | dec[int(str[17])]; + ulid.data[11] = (dec[int(str[18])] << 3) | (dec[int(str[19])] >> 2); + ulid.data[12] = (dec[int(str[19])] << 6) | (dec[int(str[20])] << 1) | + (dec[int(str[21])] >> 4); + ulid.data[13] = (dec[int(str[21])] << 4) | (dec[int(str[22])] >> 1); + ulid.data[14] = (dec[int(str[22])] << 7) | (dec[int(str[23])] << 2) | + (dec[int(str[24])] >> 3); + ulid.data[15] = (dec[int(str[24])] << 5) | dec[int(str[25])]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed string. + * */ +inline ULID Unmarshal(const std::string& str) { + ULID ulid; + UnmarshalFrom(str.c_str(), ulid); + return ulid; +} + +/** + * UnmarshalBinaryFrom will unmarshal a ULID from the passed byte array. + * */ +inline void UnmarshalBinaryFrom(const uint8_t b[16], ULID& ulid) { + // timestamp + ulid.data[0] = b[0]; + ulid.data[1] = b[1]; + ulid.data[2] = b[2]; + ulid.data[3] = b[3]; + ulid.data[4] = b[4]; + ulid.data[5] = b[5]; + + // entropy + ulid.data[6] = b[6]; + ulid.data[7] = b[7]; + ulid.data[8] = b[8]; + ulid.data[9] = b[9]; + ulid.data[10] = b[10]; + ulid.data[11] = b[11]; + ulid.data[12] = b[12]; + ulid.data[13] = b[13]; + ulid.data[14] = b[14]; + ulid.data[15] = b[15]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed byte vector. + * */ +inline ULID UnmarshalBinary(const std::vector& b) { + ULID ulid; + UnmarshalBinaryFrom(b.data(), ulid); + return ulid; +} + +/** + * CompareULIDs will compare two ULIDs. + * returns: + * -1 if ulid1 is Lexicographically before ulid2 + * 1 if ulid1 is Lexicographically after ulid2 + * 0 if ulid1 is same as ulid2 + * */ +inline int CompareULIDs(const ULID& ulid1, const ULID& ulid2) { + // for (int i = 0 ; i < 16 ; i++) { + // if (ulid1.data[i] != ulid2.data[i]) { + // return (ulid1.data[i] < ulid2.data[i]) * -2 + 1; + // } + // } + + // unrolled loop + + if (ulid1.data[0] != ulid2.data[0]) { + return (ulid1.data[0] < ulid2.data[0]) * -2 + 1; + } + + if (ulid1.data[1] != ulid2.data[1]) { + return (ulid1.data[1] < ulid2.data[1]) * -2 + 1; + } + + if (ulid1.data[2] != ulid2.data[2]) { + return (ulid1.data[2] < ulid2.data[2]) * -2 + 1; + } + + if (ulid1.data[3] != ulid2.data[3]) { + return (ulid1.data[3] < ulid2.data[3]) * -2 + 1; + } + + if (ulid1.data[4] != ulid2.data[4]) { + return (ulid1.data[4] < ulid2.data[4]) * -2 + 1; + } + + if (ulid1.data[5] != ulid2.data[5]) { + return (ulid1.data[5] < ulid2.data[5]) * -2 + 1; + } + + if (ulid1.data[6] != ulid2.data[6]) { + return (ulid1.data[6] < ulid2.data[6]) * -2 + 1; + } + + if (ulid1.data[7] != ulid2.data[7]) { + return (ulid1.data[7] < ulid2.data[7]) * -2 + 1; + } + + if (ulid1.data[8] != ulid2.data[8]) { + return (ulid1.data[8] < ulid2.data[8]) * -2 + 1; + } + + if (ulid1.data[9] != ulid2.data[9]) { + return (ulid1.data[9] < ulid2.data[9]) * -2 + 1; + } + + if (ulid1.data[10] != ulid2.data[10]) { + return (ulid1.data[10] < ulid2.data[10]) * -2 + 1; + } + + if (ulid1.data[11] != ulid2.data[11]) { + return (ulid1.data[11] < ulid2.data[11]) * -2 + 1; + } + + if (ulid1.data[12] != ulid2.data[12]) { + return (ulid1.data[12] < ulid2.data[12]) * -2 + 1; + } + + if (ulid1.data[13] != ulid2.data[13]) { + return (ulid1.data[13] < ulid2.data[13]) * -2 + 1; + } + + if (ulid1.data[14] != ulid2.data[14]) { + return (ulid1.data[14] < ulid2.data[14]) * -2 + 1; + } + + if (ulid1.data[15] != ulid2.data[15]) { + return (ulid1.data[15] < ulid2.data[15]) * -2 + 1; + } + + return 0; +} + +/** + * Time will extract the timestamp used to generate a ULID + * */ +inline time_t Time(const ULID& ulid) { + time_t ans = 0; + + ans |= ulid.data[0]; + + ans <<= 8; + ans |= ulid.data[1]; + + ans <<= 8; + ans |= ulid.data[2]; + + ans <<= 8; + ans |= ulid.data[3]; + + ans <<= 8; + ans |= ulid.data[4]; + + ans <<= 8; + ans |= ulid.data[5]; + + return ans; +} + +}; // namespace ulid + +#endif // ULID_STRUCT_HH diff --git a/engine/utils/ulid/ulid_uint128.hh b/engine/utils/ulid/ulid_uint128.hh new file mode 100644 index 000000000..b3f200141 --- /dev/null +++ b/engine/utils/ulid/ulid_uint128.hh @@ -0,0 +1,561 @@ +#ifndef ULID_UINT128_HH +#define ULID_UINT128_HH + +#include +#include +#include +#include +#include +#include + +#if _MSC_VER > 0 +typedef uint32_t rand_t; +#else +typedef uint8_t rand_t; +#endif + +namespace ulid { + +/** + * ULID is a 16 byte Universally Unique Lexicographically Sortable Identifier + * */ +typedef __uint128_t ULID; + +/** + * EncodeTime will encode the first 6 bytes of a uint8_t array to the passed + * timestamp + * */ +inline void EncodeTime(time_t timestamp, ULID& ulid) { + ULID t = static_cast(timestamp >> 40); + + t <<= 8; + t |= static_cast(timestamp >> 32); + + t <<= 8; + t |= static_cast(timestamp >> 24); + + t <<= 8; + t |= static_cast(timestamp >> 16); + + t <<= 8; + t |= static_cast(timestamp >> 8); + + t <<= 8; + t |= static_cast(timestamp); + + t <<= 80; + + ULID mask = 1; + mask <<= 80; + mask--; + + ulid = t | (ulid & mask); +} + +/** + * EncodeTimeNow will encode a ULID using the time obtained using std::time(nullptr) + * */ +inline void EncodeTimeNow(ULID& ulid) { + EncodeTime(std::time(nullptr), ulid); +} + +/** + * EncodeTimeSystemClockNow will encode a ULID using the time obtained using + * std::chrono::system_clock::now() by taking the timestamp in milliseconds. + * */ +inline void EncodeTimeSystemClockNow(ULID& ulid) { + auto now = std::chrono::system_clock::now(); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()); + EncodeTime(ms.count(), ulid); +} + +/** + * EncodeEntropy will encode the last 10 bytes of the passed uint8_t array with + * the values generated using the passed random number generator. + * */ +inline void EncodeEntropy(const std::function& rng, ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + e <<= 8; + e |= rng(); + + ulid |= e; +} + +/** + * EncodeEntropyRand will encode a ulid using std::rand + * + * std::rand returns values in [0, RAND_MAX] + * */ +inline void EncodeEntropyRand(ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + e <<= 8; + e |= (std::rand() * 255ull) / RAND_MAX; + + ulid |= e; +} + +static std::uniform_int_distribution Distribution_0_255(0, 255); + +/** + * EncodeEntropyMt19937 will encode a ulid using std::mt19937 + * + * It also creates a std::uniform_int_distribution to generate values in [0, 255] + * */ +inline void EncodeEntropyMt19937(std::mt19937& generator, ULID& ulid) { + ulid = (ulid >> 80) << 80; + + ULID e = Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + e <<= 8; + e |= Distribution_0_255(generator); + + ulid |= e; +} + +/** + * Encode will create an encoded ULID with a timestamp and a generator. + * */ +inline void Encode(time_t timestamp, const std::function& rng, + ULID& ulid) { + EncodeTime(timestamp, ulid); + EncodeEntropy(rng, ulid); +} + +/** + * EncodeNowRand = EncodeTimeNow + EncodeEntropyRand. + * */ +inline void EncodeNowRand(ULID& ulid) { + EncodeTimeNow(ulid); + EncodeEntropyRand(ulid); +} + +/** + * Create will create a ULID with a timestamp and a generator. + * */ +inline ULID Create(time_t timestamp, const std::function& rng) { + ULID ulid = 0; + Encode(timestamp, rng, ulid); + return ulid; +} + +/** + * CreateNowRand:EncodeNowRand = Create:Encode. + * */ +inline ULID CreateNowRand() { + ULID ulid = 0; + EncodeNowRand(ulid); + return ulid; +} + +/** + * Crockford's Base32 + * */ +static const char Encoding[33] = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"; + +/** + * MarshalTo will marshal a ULID to the passed character array. + * + * Implementation taken directly from oklog/ulid + * (https://sourcegraph.com/github.com/oklog/ulid@0774f81f6e44af5ce5e91c8d7d76cf710e889ebb/-/blob/ulid.go#L162-190) + * + * timestamp: + * dst[0]: first 3 bits of data[0] + * dst[1]: last 5 bits of data[0] + * dst[2]: first 5 bits of data[1] + * dst[3]: last 3 bits of data[1] + first 2 bits of data[2] + * dst[4]: bits 3-7 of data[2] + * dst[5]: last bit of data[2] + first 4 bits of data[3] + * dst[6]: last 4 bits of data[3] + first bit of data[4] + * dst[7]: bits 2-6 of data[4] + * dst[8]: last 2 bits of data[4] + first 3 bits of data[5] + * dst[9]: last 5 bits of data[5] + * + * entropy: + * follows similarly, except now all components are set to 5 bits. + * */ +inline void MarshalTo(const ULID& ulid, char dst[26]) { + // 10 byte timestamp + dst[0] = Encoding[(static_cast(ulid >> 120) & 224) >> 5]; + dst[1] = Encoding[static_cast(ulid >> 120) & 31]; + dst[2] = Encoding[(static_cast(ulid >> 112) & 248) >> 3]; + dst[3] = Encoding[((static_cast(ulid >> 112) & 7) << 2) | + ((static_cast(ulid >> 104) & 192) >> 6)]; + dst[4] = Encoding[(static_cast(ulid >> 104) & 62) >> 1]; + dst[5] = Encoding[((static_cast(ulid >> 104) & 1) << 4) | + ((static_cast(ulid >> 96) & 240) >> 4)]; + dst[6] = Encoding[((static_cast(ulid >> 96) & 15) << 1) | + ((static_cast(ulid >> 88) & 128) >> 7)]; + dst[7] = Encoding[(static_cast(ulid >> 88) & 124) >> 2]; + dst[8] = Encoding[((static_cast(ulid >> 88) & 3) << 3) | + ((static_cast(ulid >> 80) & 224) >> 5)]; + dst[9] = Encoding[static_cast(ulid >> 80) & 31]; + + // 16 bytes of entropy + dst[10] = Encoding[(static_cast(ulid >> 72) & 248) >> 3]; + dst[11] = Encoding[((static_cast(ulid >> 72) & 7) << 2) | + ((static_cast(ulid >> 64) & 192) >> 6)]; + dst[12] = Encoding[(static_cast(ulid >> 64) & 62) >> 1]; + dst[13] = Encoding[((static_cast(ulid >> 64) & 1) << 4) | + ((static_cast(ulid >> 56) & 240) >> 4)]; + dst[14] = Encoding[((static_cast(ulid >> 56) & 15) << 1) | + ((static_cast(ulid >> 48) & 128) >> 7)]; + dst[15] = Encoding[(static_cast(ulid >> 48) & 124) >> 2]; + dst[16] = Encoding[((static_cast(ulid >> 48) & 3) << 3) | + ((static_cast(ulid >> 40) & 224) >> 5)]; + dst[17] = Encoding[static_cast(ulid >> 40) & 31]; + dst[18] = Encoding[(static_cast(ulid >> 32) & 248) >> 3]; + dst[19] = Encoding[((static_cast(ulid >> 32) & 7) << 2) | + ((static_cast(ulid >> 24) & 192) >> 6)]; + dst[20] = Encoding[(static_cast(ulid >> 24) & 62) >> 1]; + dst[21] = Encoding[((static_cast(ulid >> 24) & 1) << 4) | + ((static_cast(ulid >> 16) & 240) >> 4)]; + dst[22] = Encoding[((static_cast(ulid >> 16) & 15) << 1) | + ((static_cast(ulid >> 8) & 128) >> 7)]; + dst[23] = Encoding[(static_cast(ulid >> 8) & 124) >> 2]; + dst[24] = Encoding[((static_cast(ulid >> 8) & 3) << 3) | + (((static_cast(ulid)) & 224) >> 5)]; + dst[25] = Encoding[(static_cast(ulid)) & 31]; +} + +/** + * Marshal will marshal a ULID to a std::string. + * */ +inline std::string Marshal(const ULID& ulid) { + char data[27]; + data[26] = '\0'; + MarshalTo(ulid, data); + return std::string(data); +} + +/** + * MarshalBinaryTo will Marshal a ULID to the passed byte array + * */ +inline void MarshalBinaryTo(const ULID& ulid, uint8_t dst[16]) { + // timestamp + dst[0] = static_cast(ulid >> 120); + dst[1] = static_cast(ulid >> 112); + dst[2] = static_cast(ulid >> 104); + dst[3] = static_cast(ulid >> 96); + dst[4] = static_cast(ulid >> 88); + dst[5] = static_cast(ulid >> 80); + + // entropy + dst[6] = static_cast(ulid >> 72); + dst[7] = static_cast(ulid >> 64); + dst[8] = static_cast(ulid >> 56); + dst[9] = static_cast(ulid >> 48); + dst[10] = static_cast(ulid >> 40); + dst[11] = static_cast(ulid >> 32); + dst[12] = static_cast(ulid >> 24); + dst[13] = static_cast(ulid >> 16); + dst[14] = static_cast(ulid >> 8); + dst[15] = static_cast(ulid); +} + +/** + * MarshalBinary will Marshal a ULID to a byte vector. + * */ +inline std::vector MarshalBinary(const ULID& ulid) { + std::vector dst(16); + MarshalBinaryTo(ulid, dst.data()); + return dst; +} + +/** + * dec storesdecimal encodings for characters. + * 0xFF indicates invalid character. + * 48-57 are digits. + * 65-90 are capital alphabets. + * */ +static const uint8_t dec[256] = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + /* 0 1 2 3 4 5 6 7 */ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + /* 8 9 */ + 0x08, 0x09, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + /* 10(A) 11(B) 12(C) 13(D) 14(E) 15(F) 16(G) */ + 0xFF, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + /*17(H) 18(J) 19(K) 20(M) 21(N) */ + 0x11, 0xFF, 0x12, 0x13, 0xFF, 0x14, 0x15, 0xFF, + /*22(P)23(Q)24(R) 25(S) 26(T) 27(V) 28(W) */ + 0x16, 0x17, 0x18, 0x19, 0x1A, 0xFF, 0x1B, 0x1C, + /*29(X)30(Y)31(Z) */ + 0x1D, 0x1E, 0x1F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}; + +/** + * UnmarshalFrom will unmarshal a ULID from the passed character array. + * */ +inline void UnmarshalFrom(const char str[26], ULID& ulid) { + // timestamp + ulid = (dec[int(str[0])] << 5) | dec[int(str[1])]; + + ulid <<= 8; + ulid |= (dec[int(str[2])] << 3) | (dec[int(str[3])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[3])] << 6) | (dec[int(str[4])] << 1) | + (dec[int(str[5])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[5])] << 4) | (dec[int(str[6])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[6])] << 7) | (dec[int(str[7])] << 2) | + (dec[int(str[8])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[8])] << 5) | dec[int(str[9])]; + + // entropy + ulid <<= 8; + ulid |= (dec[int(str[10])] << 3) | (dec[int(str[11])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[11])] << 6) | (dec[int(str[12])] << 1) | + (dec[int(str[13])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[13])] << 4) | (dec[int(str[14])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[14])] << 7) | (dec[int(str[15])] << 2) | + (dec[int(str[16])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[16])] << 5) | dec[int(str[17])]; + + ulid <<= 8; + ulid |= (dec[int(str[18])] << 3) | (dec[int(str[19])] >> 2); + + ulid <<= 8; + ulid |= (dec[int(str[19])] << 6) | (dec[int(str[20])] << 1) | + (dec[int(str[21])] >> 4); + + ulid <<= 8; + ulid |= (dec[int(str[21])] << 4) | (dec[int(str[22])] >> 1); + + ulid <<= 8; + ulid |= (dec[int(str[22])] << 7) | (dec[int(str[23])] << 2) | + (dec[int(str[24])] >> 3); + + ulid <<= 8; + ulid |= (dec[int(str[24])] << 5) | dec[int(str[25])]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed string. + * */ +inline ULID Unmarshal(const std::string& str) { + ULID ulid; + UnmarshalFrom(str.c_str(), ulid); + return ulid; +} + +/** + * UnmarshalBinaryFrom will unmarshal a ULID from the passed byte array. + * */ +inline void UnmarshalBinaryFrom(const uint8_t b[16], ULID& ulid) { + // timestamp + ulid = b[0]; + + ulid <<= 8; + ulid |= b[1]; + + ulid <<= 8; + ulid |= b[2]; + + ulid <<= 8; + ulid |= b[3]; + + ulid <<= 8; + ulid |= b[4]; + + ulid <<= 8; + ulid |= b[5]; + + // entropy + ulid <<= 8; + ulid |= b[6]; + + ulid <<= 8; + ulid |= b[7]; + + ulid <<= 8; + ulid |= b[8]; + + ulid <<= 8; + ulid |= b[9]; + + ulid <<= 8; + ulid |= b[10]; + + ulid <<= 8; + ulid |= b[11]; + + ulid <<= 8; + ulid |= b[12]; + + ulid <<= 8; + ulid |= b[13]; + + ulid <<= 8; + ulid |= b[14]; + + ulid <<= 8; + ulid |= b[15]; +} + +/** + * Unmarshal will create a new ULID by unmarshaling the passed byte vector. + * */ +inline ULID UnmarshalBinary(const std::vector& b) { + ULID ulid; + UnmarshalBinaryFrom(b.data(), ulid); + return ulid; +} + +/** + * CompareULIDs will compare two ULIDs. + * returns: + * -1 if ulid1 is Lexicographically before ulid2 + * 1 if ulid1 is Lexicographically after ulid2 + * 0 if ulid1 is same as ulid2 + * */ +inline int CompareULIDs(const ULID& ulid1, const ULID& ulid2) { + return -2 * (ulid1 < ulid2) - 1 * (ulid1 == ulid2) + 1; +} + +/** + * Time will extract the timestamp used to generate a ULID + * */ +inline time_t Time(const ULID& ulid) { + time_t ans = 0; + + ans |= static_cast(ulid >> 120); + + ans <<= 8; + ans |= static_cast(ulid >> 112); + + ans <<= 8; + ans |= static_cast(ulid >> 104); + + ans <<= 8; + ans |= static_cast(ulid >> 96); + + ans <<= 8; + ans |= static_cast(ulid >> 88); + + ans <<= 8; + ans |= static_cast(ulid >> 80); + + return ans; +} + +}; // namespace ulid + +#endif // ULID_UINT128_HH From 2b74824ae2e417895954fb8829450185fd0a8ece Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 4 Dec 2024 13:54:22 +0700 Subject: [PATCH 10/44] fix: support ctx_len for model start cli (#1766) * fix: support ctx_len for model start cli * chore: docs * fix: guard max ctx_len --- docs/docs/cli/models/index.mdx | 7 ++++--- docs/docs/cli/models/start.md | 1 + docs/docs/cli/run.mdx | 5 +++-- engine/cli/command_line_parser.cc | 18 +++++++++++------- engine/cli/command_line_parser.h | 2 +- engine/cli/commands/model_start_cmd.cc | 19 +++++++++++++++++-- engine/cli/commands/model_start_cmd.h | 4 ++++ engine/services/model_service.cc | 10 ++++++++-- 8 files changed, 49 insertions(+), 17 deletions(-) diff --git a/docs/docs/cli/models/index.mdx b/docs/docs/cli/models/index.mdx index b75bf9d49..dff452788 100644 --- a/docs/docs/cli/models/index.mdx +++ b/docs/docs/cli/models/index.mdx @@ -159,9 +159,10 @@ This command uses a `model_id` from the model that you have downloaded or availa | Option | Description | Required | Default value | Example | |---------------------------|---------------------------------------------------------------------------|----------|----------------------------------------------|------------------------| -| `model_id` | The identifier of the model you want to start. | Yes | `Prompt to select from the available models` | `mistral` | -| `--gpus` | List of GPUs to use. | No | - | `[0,1]` | -| `-h`, `--help` | Display help information for the command. | No | - | `-h` | +| `model_id` | The identifier of the model you want to start. | Yes | `Prompt to select from the available models` | `mistral` | +| `--gpus` | List of GPUs to use. | No | - | `[0,1]` | +| `--ctx_len` | Maximum context length for inference. | No | `min(8192, max_model_context_length)` | `1024` | +| `-h`, `--help` | Display help information for the command. | No | - | `-h` | ## `cortex models stop` :::info diff --git a/docs/docs/cli/models/start.md b/docs/docs/cli/models/start.md index 77addd0b4..3880cd477 100644 --- a/docs/docs/cli/models/start.md +++ b/docs/docs/cli/models/start.md @@ -33,6 +33,7 @@ cortex models start [model_id]:[engine] [options] |---------------------------|----------------------------------------------------------|----------|----------------------------------------------|-------------------| | `model_id` | The identifier of the model you want to start. | No | `Prompt to select from the available models` | `mistral` | | `--gpus` | List of GPUs to use. | No | - | `[0,1]` | +| `--ctx_len` | Maximum context length for inference. | No | `min(8192, max_model_context_length)` | `1024` | | `-h`, `--help` | Display help information for the command. | No | - | `-h` | diff --git a/docs/docs/cli/run.mdx b/docs/docs/cli/run.mdx index bbce017f1..57c8358a2 100644 --- a/docs/docs/cli/run.mdx +++ b/docs/docs/cli/run.mdx @@ -36,7 +36,8 @@ You can use the `--verbose` flag to display more detailed output of the internal | Option | Description | Required | Default value | Example | |-----------------------------|-----------------------------------------------------------------------------|----------|----------------------------------------------|------------------------| -| `model_id` | The identifier of the model you want to chat with. | Yes | - | `mistral` | -| `--gpus` | List of GPUs to use. | No | - | `[0,1]` | +| `model_id` | The identifier of the model you want to chat with. | Yes | - | `mistral` | +| `--gpus` | List of GPUs to use. | No | - | `[0,1]` | +| `--ctx_len` | Maximum context length for inference. | No | `min(8192, max_model_context_length)` | `1024` | | `-h`, `--help` | Display help information for the command. | No | - | `-h` | diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 34c6b9069..9d5d83ffc 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -163,8 +163,10 @@ void CommandLineParser::SetupCommonCommands() { run_cmd->usage("Usage:\n" + commands::GetCortexBinary() + " run [options] [model_id]"); run_cmd->add_option("model_id", cml_data_.model_id, ""); - run_cmd->add_option("--gpus", hw_activate_opts_["gpus"], + run_cmd->add_option("--gpus", run_settings_["gpus"], "List of GPU to activate, for example [0, 1]"); + run_cmd->add_option("--ctx_len", run_settings_["ctx_len"], + "Maximum context length for inference"); run_cmd->add_flag("-d,--detach", cml_data_.run_detach, "Detached mode"); run_cmd->callback([this, run_cmd] { if (std::exchange(executed_, true)) @@ -172,7 +174,7 @@ void CommandLineParser::SetupCommonCommands() { commands::RunCmd rc(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, download_service_); - rc.Exec(cml_data_.run_detach, hw_activate_opts_); + rc.Exec(cml_data_.run_detach, run_settings_); }); } @@ -203,8 +205,10 @@ void CommandLineParser::SetupModelCommands() { model_start_cmd->usage("Usage:\n" + commands::GetCortexBinary() + " models start [model_id]"); model_start_cmd->add_option("model_id", cml_data_.model_id, ""); - model_start_cmd->add_option("--gpus", hw_activate_opts_["gpus"], + model_start_cmd->add_option("--gpus", run_settings_["gpus"], "List of GPU to activate, for example [0, 1]"); + model_start_cmd->add_option("--ctx_len", run_settings_["ctx_len"], + "Maximum context length for inference"); model_start_cmd->group(kSubcommands); model_start_cmd->callback([this, model_start_cmd]() { if (std::exchange(executed_, true)) @@ -216,7 +220,7 @@ void CommandLineParser::SetupModelCommands() { }; commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), - cml_data_.model_id, hw_activate_opts_); + cml_data_.model_id, run_settings_); }); auto stop_model_cmd = @@ -562,7 +566,7 @@ void CommandLineParser::SetupHardwareCommands() { hw_activate_cmd->usage("Usage:\n" + commands::GetCortexBinary() + " hardware activate --gpus [list_gpu]"); hw_activate_cmd->group(kSubcommands); - hw_activate_cmd->add_option("--gpus", hw_activate_opts_["gpus"], + hw_activate_cmd->add_option("--gpus", run_settings_["gpus"], "List of GPU to activate, for example [0, 1]"); hw_activate_cmd->callback([this, hw_activate_cmd]() { if (std::exchange(executed_, true)) @@ -572,14 +576,14 @@ void CommandLineParser::SetupHardwareCommands() { return; } - if (hw_activate_opts_["gpus"].empty()) { + if (run_settings_["gpus"].empty()) { CLI_LOG("[list_gpu] is required\n"); CLI_LOG(hw_activate_cmd->help()); return; } commands::HardwareActivateCmd().Exec( cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), hw_activate_opts_); + std::stoi(cml_data_.config.apiServerPort), run_settings_); }); } diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index f7ca3f507..aec10dcb4 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -79,5 +79,5 @@ class CommandLineParser { std::unordered_map config_update_opts_; bool executed_ = false; commands::HarwareOptions hw_opts_; - std::unordered_map hw_activate_opts_; + std::unordered_map run_settings_; }; diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index ea6b81e5a..12aec944d 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -30,8 +30,8 @@ bool ModelStartCmd::Exec( // bool should_activate_hw = false; - for (auto const& [_, v] : options) { - if (!v.empty()) { + for (auto const& [k, v] : options) { + if (k == "gpus" && !v.empty()) { should_activate_hw = true; break; } @@ -57,6 +57,9 @@ bool ModelStartCmd::Exec( Json::Value json_data; json_data["model"] = model_id.value(); + for (auto const& [k, v] : options) { + UpdateConfig(json_data, k, v); + } auto data_str = json_data.toStyledString(); auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); if (res.has_error()) { @@ -75,4 +78,16 @@ bool ModelStartCmd::Exec( } return true; } + +bool ModelStartCmd::UpdateConfig(Json::Value& data, const std::string& key, + const std::string& value) { + if (key == "ctx_len" && !value.empty()) { + try { + data["ctx_len"] = std::stoi(value); + } catch (const std::exception& e) { + CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what()); + } + } + return true; +} }; // namespace commands diff --git a/engine/cli/commands/model_start_cmd.h b/engine/cli/commands/model_start_cmd.h index 519db0f0d..124ef463d 100644 --- a/engine/cli/commands/model_start_cmd.h +++ b/engine/cli/commands/model_start_cmd.h @@ -2,6 +2,7 @@ #include #include +#include "json/json.h" namespace commands { @@ -10,5 +11,8 @@ class ModelStartCmd { bool Exec(const std::string& host, int port, const std::string& model_handle, const std::unordered_map& options, bool print_success_log = true); + private: + bool UpdateConfig(Json::Value& data, const std::string& key, + const std::string& value); }; } // namespace commands diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index cc1f99bdc..3cfff5cb2 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -702,6 +702,8 @@ cpp::result ModelService::StartModel( config::YamlHandler yaml_handler; try { + constexpr const int kDefautlContextLength = 8192; + int max_model_context_length = kDefautlContextLength; Json::Value json_data; // Currently we don't support download vision models, so we need to bypass check if (!params_override.bypass_model_check()) { @@ -732,6 +734,8 @@ cpp::result ModelService::StartModel( json_data["system_prompt"] = mc.system_template; json_data["user_prompt"] = mc.user_template; json_data["ai_prompt"] = mc.ai_template; + json_data["ctx_len"] = std::min(kDefautlContextLength, mc.ctx_len); + max_model_context_length = mc.ctx_len; } else { bypass_stop_check_set_.insert(model_handle); } @@ -753,12 +757,14 @@ cpp::result ModelService::StartModel( ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled); ASSIGN_IF_PRESENT(json_data, params_override, ngl); ASSIGN_IF_PRESENT(json_data, params_override, n_parallel); - ASSIGN_IF_PRESENT(json_data, params_override, ctx_len); ASSIGN_IF_PRESENT(json_data, params_override, cache_type); ASSIGN_IF_PRESENT(json_data, params_override, mmproj); ASSIGN_IF_PRESENT(json_data, params_override, model_path); #undef ASSIGN_IF_PRESENT - + if (params_override.ctx_len) { + json_data["ctx_len"] = + std::min(params_override.ctx_len.value(), max_model_context_length); + } CTL_INF(json_data.toStyledString()); auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(), json_data["ngl"].asInt(), From 79f7679d6660dd8a24514a711006d4420c830ef4 Mon Sep 17 00:00:00 2001 From: nguyenhoangthuan99 <35255081+nguyenhoangthuan99@users.noreply.github.com> Date: Thu, 5 Dec 2024 08:39:06 +0700 Subject: [PATCH 11/44] feat: remote engine (#1666) * Init remote engine * Fix: CI build windows * Fix: CI build windows * Fix: CI build windows * Fix: CI build windows * feat: new db schema for model and template for engine * Add remote model * Add Get, List, Update support for remote models * change model_id to model in remote engine * fix: mac compatibility * chore: some refactors before making big changes * feat: db ops for engines * chore: small refactor before more changes * Update engine * refine db schema, composite key for engines * add entry definition for engine at db layer * complete add, get engine operations * engine managements * Integrate with remote engine to run remote model * error handling and response transform * Support for stream request * chore: fix conflicts * feat: anthropic * feat: support anthropic * feat: support anthropic * chore: rename * chore: cleanup and fix unit tests * fix: issue with db * chore: refactor remote engine * fix: e2e tests * fix: e2e tests * chore: API docs * fix: use different interface for remote engine --------- Co-authored-by: Luke Nguyen Co-authored-by: vansangpfiev Co-authored-by: vansangpfiev --- docs/static/openapi/cortex.json | 232 +++++- engine/CMakeLists.txt | 8 +- engine/cli/CMakeLists.txt | 7 +- engine/common/engine_servicei.h | 6 +- engine/config/model_config.h | 181 +++++ engine/controllers/engines.cc | 96 ++- engine/controllers/models.cc | 204 ++++- engine/controllers/models.h | 11 + engine/cortex-common/EngineI.h | 2 + engine/cortex-common/remote_enginei.h | 37 + engine/database/engines.cc | 173 +++++ engine/database/engines.h | 88 +++ engine/database/models.cc | 74 +- engine/database/models.h | 19 +- .../remote-engine/anthropic_engine.cc | 62 ++ .../remote-engine/anthropic_engine.h | 13 + .../extensions/remote-engine/openai_engine.cc | 54 ++ .../extensions/remote-engine/openai_engine.h | 14 + .../extensions/remote-engine/remote_engine.cc | 712 ++++++++++++++++++ .../extensions/remote-engine/remote_engine.h | 102 +++ .../remote-engine/template_renderer.cc | 136 ++++ .../remote-engine/template_renderer.h | 40 + engine/migrations/db_helper.h | 26 + engine/migrations/migration_helper.cc | 1 - engine/migrations/migration_manager.cc | 14 + engine/migrations/schema_version.h | 2 +- engine/migrations/v1/migration.h | 165 ++++ engine/services/engine_service.cc | 198 ++++- engine/services/engine_service.h | 53 +- engine/services/inference_service.cc | 92 ++- engine/services/inference_service.h | 2 +- engine/services/model_service.cc | 55 +- engine/test/components/test_models_db.cc | 42 +- engine/utils/engine_constants.h | 2 + engine/utils/logging_utils.h | 2 + engine/utils/remote_models_utils.h | 132 ++++ engine/utils/result.hpp | 1 - engine/vcpkg.json | 1 + 38 files changed, 2934 insertions(+), 125 deletions(-) create mode 100644 engine/cortex-common/remote_enginei.h create mode 100644 engine/database/engines.cc create mode 100644 engine/database/engines.h create mode 100644 engine/extensions/remote-engine/anthropic_engine.cc create mode 100644 engine/extensions/remote-engine/anthropic_engine.h create mode 100644 engine/extensions/remote-engine/openai_engine.cc create mode 100644 engine/extensions/remote-engine/openai_engine.h create mode 100644 engine/extensions/remote-engine/remote_engine.cc create mode 100644 engine/extensions/remote-engine/remote_engine.h create mode 100644 engine/extensions/remote-engine/template_renderer.cc create mode 100644 engine/extensions/remote-engine/template_renderer.h create mode 100644 engine/migrations/db_helper.h create mode 100644 engine/migrations/v1/migration.h create mode 100644 engine/utils/remote_models_utils.h diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 206ee381d..9cdd5c7b4 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -512,6 +512,73 @@ } } }, + "/v1/models/add": { + "post": { + "operationId": "ModelsController_addModel", + "summary": "Add a remote model", + "description": "Add a new remote model configuration to the system.", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AddModelRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "model": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "engine": { + "type": "string" + }, + "version": { + "type": "string" + } + } + } + } + }, + "example": { + "message": "Model added successfully!", + "model": { + "model": "claude-3-5-sonnet-20241022", + "engine": "anthropic", + "version": "2023-06-01" + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/SimpleErrorResponse" + } + } + } + } + }, + "tags": ["Pulling Models"] + } + }, "/v1/models": { "get": { "operationId": "ModelsController_findAll", @@ -1417,7 +1484,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1439,6 +1506,31 @@ "type": "string", "description": "The variant of the engine to install (optional)", "example": "mac-arm64" + }, + "type": { + "type": "string", + "description": "The type of connection, remote or local", + "example": "remote" + }, + "url": { + "type": "string", + "description": "The URL for the API endpoint for remote engine", + "example": "https://api.openai.com" + }, + "api_key": { + "type": "string", + "description": "The API key for authentication for remote engine", + "example": "" + }, + "metadata": { + "type": "object", + "properties": { + "get_models_url": { + "type": "string", + "description": "The URL to get models", + "example": "https://api.openai.com/v1/models" + } + } } } } @@ -1475,7 +1567,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The type of engine" @@ -1690,7 +1782,7 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm"], + "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], "default": "llama-cpp" }, "description": "The name of the engine to update" @@ -3636,6 +3728,109 @@ } } }, + "AddModelRequest": { + "type": "object", + "required": ["model", "engine", "version", "inference_params", "TransformReq", "TransformResp", "metadata"], + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model." + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "engine": { + "type": "string", + "description": "The engine used for the model." + }, + "version": { + "type": "string", + "description": "The version of the model." + }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } + }, + "TransformReq": { + "type": "object", + "properties": { + "get_models": { + "type": "object" + }, + "chat_completions": { + "type": "object", + "properties": { + "url": { + "type": "string" + }, + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "TransformResp": { + "type": "object", + "properties": { + "chat_completions": { + "type": "object", + "properties": { + "template": { + "type": "string" + } + } + }, + "embeddings": { + "type": "object" + } + } + }, + "metadata": { + "type": "object", + "properties": { + "author": { + "type": "string" + }, + "description": { + "type": "string" + }, + "end_point": { + "type": "string" + }, + "logo": { + "type": "string" + }, + "api_key_url": { + "type": "string" + } + } + } + } + }, "CreateModelDto": { "type": "object", "properties": { @@ -4305,6 +4500,37 @@ "type": "integer", "description": "Number of GPU layers.", "example": 33 + }, + "api_key_template": { + "type": "string", + "description": "Template for the API key header." + }, + "version": { + "type": "string", + "description": "The version of the model." + }, + "inference_params": { + "type": "object", + "properties": { + "temperature": { + "type": "number" + }, + "top_p": { + "type": "number" + }, + "frequency_penalty": { + "type": "number" + }, + "presence_penalty": { + "type": "number" + }, + "max_tokens": { + "type": "integer" + }, + "stream": { + "type": "boolean" + } + } } } }, diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index eae09d439..7cac3421c 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -142,6 +142,10 @@ file(APPEND "${CMAKE_CURRENT_BINARY_DIR}/cortex_openapi.h" add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -171,17 +175,17 @@ endif() aux_source_directory(controllers CTL_SRC) aux_source_directory(repositories REPO_SRC) aux_source_directory(services SERVICES_SRC) -aux_source_directory(common COMMON_SRC) aux_source_directory(models MODEL_SRC) aux_source_directory(cortex-common CORTEX_COMMON) aux_source_directory(config CONFIG_SRC) aux_source_directory(database DB_SRC) +aux_source_directory(extensions EX_SRC) aux_source_directory(migrations MIGR_SRC) aux_source_directory(utils UTILS_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} ) -target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${MIGR_SRC} ${REPO_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${UTILS_SRC} ${CONFIG_SRC} ${CTL_SRC} ${COMMON_SRC} ${SERVICES_SRC} ${DB_SRC} ${EX_SRC} ${MIGR_SRC} ${REPO_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 42d00ebd5..51382dc13 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -82,6 +82,10 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/config_yaml_utils.cc @@ -121,11 +125,12 @@ aux_source_directory(../cortex-common CORTEX_COMMON) aux_source_directory(../config CONFIG_SRC) aux_source_directory(commands COMMANDS_SRC) aux_source_directory(../database DB_SRC) +aux_source_directory(../extensions EX_SRC) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/.. ) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) -target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC}) +target_sources(${TARGET_NAME} PRIVATE ${COMMANDS_SRC} ${CONFIG_SRC} ${COMMON_SRC} ${DB_SRC} ${EX_SRC}) set_target_properties(${TARGET_NAME} PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR} diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index bd4f099ab..85fa87d76 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -3,8 +3,8 @@ #include #include #include +#include "database/engines.h" #include "utils/result.hpp" - // TODO: namh think of the other name struct DefaultEngineVariant { std::string engine; @@ -54,4 +54,8 @@ class EngineServiceI { virtual cpp::result UnloadEngine( const std::string& engine_name) = 0; + virtual cpp::result + GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 7d4076ee5..701547873 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -1,13 +1,194 @@ #pragma once #include +#include +#include +#include +#include #include #include +#include #include #include #include "utils/format_utils.h" +#include "utils/remote_models_utils.h" +#include "yaml-cpp/yaml.h" namespace config { + +namespace { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string kAnthropicTransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; +} // namespace + +struct RemoteModelConfig { + std::string model; + std::string api_key_template; + std::string engine; + std::string version; + std::size_t created; + std::string object = "model"; + std::string owned_by = ""; + Json::Value inference_params; + Json::Value TransformReq; + Json::Value TransformResp; + Json::Value metadata; + void LoadFromJson(const Json::Value& json) { + if (!json.isObject()) { + throw std::runtime_error("Input JSON must be an object"); + } + + // Load basic string fields + model = json.get("model", model).asString(); + api_key_template = + json.get("api_key_template", api_key_template).asString(); + engine = json.get("engine", engine).asString(); + version = json.get("version", version).asString(); + created = + json.get("created", static_cast(created)).asUInt64(); + object = json.get("object", object).asString(); + owned_by = json.get("owned_by", owned_by).asString(); + + // Load JSON object fields directly + inference_params = json.get("inference_params", inference_params); + TransformReq = json.get("TransformReq", TransformReq); + // Use default template if it is empty, currently we only support 2 remote engines + auto is_anthropic = [](const std::string& model) { + return model.find("claude") != std::string::npos; + }; + if (TransformReq["chat_completions"]["template"].isNull()) { + if (is_anthropic(model)) { + TransformReq["chat_completions"]["template"] = + kAnthropicTransformReqTemplate; + } else { + TransformReq["chat_completions"]["template"] = + kOpenAITransformReqTemplate; + } + } + TransformResp = json.get("TransformResp", TransformResp); + if (TransformResp["chat_completions"]["template"].isNull()) { + if (is_anthropic(model)) { + TransformResp["chat_completions"]["template"] = + kAnthropicTransformRespTemplate; + } else { + TransformResp["chat_completions"]["template"] = + kOpenAITransformRespTemplate; + } + } + metadata = json.get("metadata", metadata); + } + + Json::Value ToJson() const { + Json::Value json; + + // Add basic string fields + json["model"] = model; + json["api_key_template"] = api_key_template; + json["engine"] = engine; + json["version"] = version; + json["created"] = static_cast(created); + json["object"] = object; + json["owned_by"] = owned_by; + + // Add JSON object fields directly + json["inference_params"] = inference_params; + json["TransformReq"] = TransformReq; + json["TransformResp"] = TransformResp; + json["metadata"] = metadata; + + return json; + }; + + void SaveToYamlFile(const std::string& filepath) const { + YAML::Node root; + + // Convert basic fields + root["model"] = model; + root["api_key_template"] = api_key_template; + root["engine"] = engine; + root["version"] = version; + root["object"] = object; + root["owned_by"] = owned_by; + root["created"] = std::time(nullptr); + + // Convert Json::Value to YAML::Node using utility function + root["inference_params"] = + remote_models_utils::jsonToYaml(inference_params); + root["TransformReq"] = remote_models_utils::jsonToYaml(TransformReq); + root["TransformResp"] = remote_models_utils::jsonToYaml(TransformResp); + root["metadata"] = remote_models_utils::jsonToYaml(metadata); + + // Save to file + std::ofstream fout(filepath); + if (!fout.is_open()) { + throw std::runtime_error("Failed to open file for writing: " + filepath); + } + fout << root; + } + + void LoadFromYamlFile(const std::string& filepath) { + YAML::Node root; + try { + root = YAML::LoadFile(filepath); + } catch (const YAML::Exception& e) { + throw std::runtime_error("Failed to parse YAML file: " + + std::string(e.what())); + } + + // Load basic fields + model = root["model"].as(""); + api_key_template = root["api_key_template"].as(""); + engine = root["engine"].as(""); + version = root["version"] ? root["version"].as() : ""; + created = root["created"] ? root["created"].as() : 0; + object = root["object"] ? root["object"].as() : "model"; + owned_by = root["owned_by"] ? root["owned_by"].as() : ""; + + // Load complex fields using utility function + inference_params = + remote_models_utils::yamlToJson(root["inference_params"]); + TransformReq = remote_models_utils::yamlToJson(root["TransformReq"]); + TransformResp = remote_models_utils::yamlToJson(root["TransformResp"]); + metadata = remote_models_utils::yamlToJson(root["metadata"]); + } +}; + struct ModelConfig { std::string name; std::string model; diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 9e110bd66..3d3c0c037 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,9 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" +#include "utils/http_util.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" - namespace { // Need to change this after we rename repositories std::string NormalizeEngine(const std::string& engine) { @@ -38,6 +38,18 @@ void Engines::ListEngine( } ret[engine] = variants; } + // Add remote engine + auto remote_engines = engine_service_->GetEngines(); + if (remote_engines.has_value()) { + for (auto engine : remote_engines.value()) { + if (engine.type == "remote") { + auto engine_json = engine.ToJson(); + Json::Value list_engine(Json::arrayValue); + list_engine.append(engine_json); + ret[engine.engine_name] = list_engine; + } + } + } auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); @@ -162,6 +174,86 @@ void Engines::InstallEngine( norm_version = version; } + if ((req->getJsonObject()) && + (*(req->getJsonObject())).get("type", "").asString() == "remote") { + auto type = (*(req->getJsonObject())).get("type", "").asString(); + auto api_key = (*(req->getJsonObject())).get("api_key", "").asString(); + auto url = (*(req->getJsonObject())).get("url", "").asString(); + auto variant = norm_variant.value_or("all-platforms"); + auto status = (*(req->getJsonObject())).get("status", "Default").asString(); + std::string metadata; + if ((*(req->getJsonObject())).isMember("metadata") && + (*(req->getJsonObject()))["metadata"].isObject()) { + metadata = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .toStyledString(); + } else if ((*(req->getJsonObject())).isMember("metadata") && + !(*(req->getJsonObject()))["metadata"].isObject()) { + Json::Value res; + res["message"] = "metadata must be object"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto get_models_url = (*(req->getJsonObject())) + .get("metadata", Json::Value(Json::objectValue)) + .get("get_models_url", "") + .asString(); + + if (engine.empty() || type.empty() || url.empty()) { + Json::Value res; + res["message"] = "Engine name, type, url are required"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto exist_engine = engine_service_->GetEngineByNameAndVariant(engine); + // only allow 1 variant 1 version of a remote engine name + if (exist_engine.has_value()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = "Engine '" + engine + "' already exists"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->UpsertEngine( + engine, type, api_key, url, norm_version, variant, status, metadata); + if (result.has_error()) { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value res; + if (get_models_url.empty()) { + res["warning"] = + "'get_models_url' not found in metadata, You'll not able to search " + "remote models with this engine"; + } + res["message"] = "Remote Engine install successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); + resp->setStatusCode(k200OK); + callback(resp); + } + return; + } + auto result = engine_service_->InstallEngineAsync(engine, norm_version, norm_variant); if (result.has_error()) { @@ -169,12 +261,14 @@ void Engines::InstallEngine( res["message"] = result.error(); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k400BadRequest); + CTL_INF("Error: " << result.error()); callback(resp); } else { Json::Value res; res["message"] = "Engine starts installing!"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(res); resp->setStatusCode(k200OK); + CTL_INF("Engine starts installing!"); callback(resp); } } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 2760663d0..de14886da 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -7,6 +7,7 @@ #include "models.h" #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" +#include "utils/engine_constants.h" #include "utils/file_manager_utils.h" #include "utils/http_util.h" #include "utils/logging_utils.h" @@ -176,15 +177,29 @@ void Models::ListModel( fs::path(model_entry.path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); - Json::Value obj = model_config.ToJson(); - obj["id"] = model_entry.model; - obj["model"] = model_entry.model; - auto es = model_service_->GetEstimation(model_entry.model); - if (es.has_value()) { - obj["recommendation"] = hardware::ToJson(es.value()); + + if (!remote_engine::IsRemoteEngine(model_config.engine)) { + Json::Value obj = model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + obj["model"] = model_entry.model; + auto es = model_service_->GetEstimation(model_entry.model); + if (es.has_value()) { + obj["recommendation"] = hardware::ToJson(es.value()); + } + data.append(std::move(obj)); + yaml_handler.Reset(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + Json::Value obj = remote_model_config.ToJson(); + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + data.append(std::move(obj)); } - data.append(std::move(obj)); - yaml_handler.Reset(); } catch (const std::exception& e) { LOG_ERROR << "Failed to load yaml file for model: " << model_entry.path_to_model_yaml << ", error: " << e.what(); @@ -232,16 +247,34 @@ void Models::GetModel(const HttpRequestPtr& req, callback(resp); return; } + yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.value().path_to_model_yaml)) .string()); auto model_config = yaml_handler.GetModelConfig(); + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + auto ret = model_config.ToJsonString(); + auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret); + resp->setStatusCode(drogon::k200OK); + callback(resp); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + ret = remote_model_config.ToJson(); + ret["id"] = remote_model_config.model; + ret["object"] = "model"; + ret["result"] = "OK"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } - auto ret = model_config.ToJsonString(); - auto resp = cortex_utils::CreateCortexHttpTextAsJsonResponse(ret); - resp->setStatusCode(drogon::k200OK); - callback(resp); } catch (const std::exception& e) { std::string message = "Fail to get model information with ID '" + model_id + "': " + e.what(); @@ -289,11 +322,23 @@ void Models::UpdateModel(const HttpRequestPtr& req, fs::path(model_entry.value().path_to_model_yaml)); yaml_handler.ModelConfigFromFile(yaml_fp.string()); config::ModelConfig model_config = yaml_handler.GetModelConfig(); - model_config.FromJson(json_body); - yaml_handler.UpdateModelConfig(model_config); - yaml_handler.WriteYamlFile(yaml_fp.string()); - std::string message = "Successfully update model ID '" + model_id + - "': " + json_body.toStyledString(); + std::string message; + if (model_config.engine == kOnnxEngine || + model_config.engine == kLlamaEngine || + model_config.engine == kTrtLlmEngine) { + model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(yaml_fp.string()); + message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + } else { + config::RemoteModelConfig remote_model_config; + remote_model_config.LoadFromYamlFile(yaml_fp.string()); + remote_model_config.LoadFromJson(json_body); + remote_model_config.SaveToYamlFile(yaml_fp.string()); + message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + } LOG_INFO << message; Json::Value ret; ret["result"] = "Updated successfully!"; @@ -344,8 +389,10 @@ void Models::ImportModel( // Use relative path for model_yaml_path. In case of import, we use absolute path for model auto yaml_rel_path = fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); - cortex::db::ModelEntry model_entry{modelHandle, "local", "imported", - yaml_rel_path.string(), modelHandle}; + cortex::db::ModelEntry model_entry{ + modelHandle, "", "", yaml_rel_path.string(), + modelHandle, "local", "imported", cortex::db::ModelStatus::Downloaded, + ""}; std::filesystem::create_directories( std::filesystem::path(model_yaml_path).parent_path()); @@ -558,3 +605,122 @@ void Models::GetModelStatus( callback(resp); } } + +void Models::GetRemoteModels( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& engine_id) { + if (!remote_engine::IsRemoteEngine(engine_id)) { + Json::Value ret; + ret["message"] = "Not a remote engine: " + engine_id; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + + auto result = engine_service_->GetRemoteModels(engine_id); + + if (result.has_error()) { + Json::Value ret; + ret["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + } else { + auto resp = cortex_utils::CreateCortexHttpJsonResponse(result.value()); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::AddRemoteModel( + const HttpRequestPtr& req, + std::function&& callback) const { + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + if (!http_util::HasFieldInReq(req, callback, "model") || + !http_util::HasFieldInReq(req, callback, "engine")) { + return; + } + + auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); + /* To do: uncomment when remote engine is ready + + auto engine_validate = engine_service_->IsEngineReady(engine_name); + if (engine_validate.has_error()) { + Json::Value ret; + ret["message"] = engine_validate.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + if (!engine_validate.value()) { + Json::Value ret; + ret["message"] = "Engine is not ready! Please install first!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + */ + config::RemoteModelConfig model_config; + model_config.LoadFromJson(*(req->getJsonObject())); + cortex::db::Models modellist_utils_obj; + std::string model_yaml_path = (file_manager_utils::GetModelsContainerPath() / + std::filesystem::path("remote") / + std::filesystem::path(model_handle + ".yml")) + .string(); + try { + // Use relative path for model_yaml_path. In case of import, we use absolute path for model + auto yaml_rel_path = + fmu::ToRelativeCortexDataPath(fs::path(model_yaml_path)); + // TODO: remove hardcode "openai" when engine is finish + cortex::db::ModelEntry model_entry{ + model_handle, "", "", yaml_rel_path.string(), + model_handle, "remote", "imported", cortex::db::ModelStatus::Remote, + "openai"}; + std::filesystem::create_directories( + std::filesystem::path(model_yaml_path).parent_path()); + if (modellist_utils_obj.AddModelEntry(model_entry).value()) { + model_config.SaveToYamlFile(model_yaml_path); + std::string success_message = "Model is imported successfully!"; + LOG_INFO << success_message; + Json::Value ret; + ret["result"] = "OK"; + ret["modelHandle"] = model_handle; + ret["message"] = success_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + + } else { + std::string error_message = "Fail to import model, model_id '" + + model_handle + "' already exists!"; + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Import failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } + } catch (const std::exception& e) { + std::string error_message = + "Error while adding Remote model with model_id '" + model_handle + + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Add failed!"; + ret["modelHandle"] = model_handle; + ret["message"] = error_message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} \ No newline at end of file diff --git a/engine/controllers/models.h b/engine/controllers/models.h index da6caf024..b2b288adc 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -21,6 +21,8 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::StartModel, "/start", Options, Post); METHOD_ADD(Models::StopModel, "/stop", Options, Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); + METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); + METHOD_ADD(Models::GetRemoteModels, "/remote/{1}", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -32,6 +34,8 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Options, Post); ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Options, Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); + ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); + ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, @@ -56,6 +60,9 @@ class Models : public drogon::HttpController { void ImportModel( const HttpRequestPtr& req, std::function&& callback) const; + void AddRemoteModel( + const HttpRequestPtr& req, + std::function&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function&& callback, const std::string& model_id); @@ -73,6 +80,10 @@ class Models : public drogon::HttpController { std::function&& callback, const std::string& model_id); + void GetRemoteModels(const HttpRequestPtr& req, + std::function&& callback, + const std::string& engine_id); + private: std::shared_ptr model_service_; std::shared_ptr engine_service_; diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index 95ce605de..51e19c124 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -37,4 +37,6 @@ class EngineI { virtual bool SetFileLogger(int max_log_lines, const std::string& log_path) = 0; virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0; + + virtual Json::Value GetRemoteModels() = 0; }; diff --git a/engine/cortex-common/remote_enginei.h b/engine/cortex-common/remote_enginei.h new file mode 100644 index 000000000..81ffbf5cd --- /dev/null +++ b/engine/cortex-common/remote_enginei.h @@ -0,0 +1,37 @@ +#pragma once + +#pragma once + +#include +#include + +#include "json/value.h" +#include "trantor/utils/Logger.h" +class RemoteEngineI { + public: + virtual ~RemoteEngineI() {} + + virtual void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void LoadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) = 0; + virtual void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) = 0; + + // Get list of running models + virtual void GetModels( + std::shared_ptr jsonBody, + std::function&& callback) = 0; + + // Get available remote models + virtual Json::Value GetRemoteModels() = 0; +}; diff --git a/engine/database/engines.cc b/engine/database/engines.cc new file mode 100644 index 000000000..a4d13ef79 --- /dev/null +++ b/engine/database/engines.cc @@ -0,0 +1,173 @@ +#include "engines.h" +#include +#include "database.h" + +namespace cortex::db { + +void CreateTable(SQLite::Database& db) {} + +Engines::Engines() : db_(cortex::db::Database::GetInstance().db()) { + CreateTable(db_); +} + +Engines::Engines(SQLite::Database& db) : db_(db) { + CreateTable(db_); +} + +Engines::~Engines() {} + +std::optional Engines::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + try { + SQLite::Statement query( + db_, + "INSERT INTO engines (engine_name, type, api_key, url, version, " + "variant, status, metadata) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?) " + "ON CONFLICT(engine_name, variant) DO UPDATE SET " + "type = excluded.type, " + "api_key = excluded.api_key, " + "url = excluded.url, " + "version = excluded.version, " + "status = excluded.status, " + "metadata = excluded.metadata, " + "date_updated = CURRENT_TIMESTAMP " + "RETURNING id, engine_name, type, api_key, url, version, variant, " + "status, metadata, date_created, date_updated;"); + + query.bind(1, engine_name); + query.bind(2, type); + query.bind(3, api_key); + query.bind(4, url); + query.bind(5, version); + query.bind(6, variant); + query.bind(7, status); + query.bind(8, metadata); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional> Engines::GetEngines() const { + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE status = 'Default' " + "ORDER BY date_updated DESC"); + + std::vector engines; + while (query.executeStep()) { + engines.push_back(EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}); + } + + return engines; + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::GetEngineById(int id) const { + try { + SQLite::Statement query( + db_, + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE id = ? AND status = 'Default' " + "ORDER BY date_updated DESC LIMIT 1"); + + query.bind(1, id); + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant) const { + try { + std::string queryStr = + "SELECT id, engine_name, type, api_key, url, version, variant, status, " + "metadata, date_created, date_updated " + "FROM engines " + "WHERE engine_name = ? AND status = 'Default' "; + + if (variant) { + queryStr += "AND variant = ? "; + } + + queryStr += "ORDER BY date_updated DESC LIMIT 1"; + + SQLite::Statement query(db_, queryStr); + + query.bind(1, engine_name); + + if (variant) { + query.bind(2, variant.value()); + } + + if (query.executeStep()) { + return EngineEntry{ + query.getColumn(0).getInt(), query.getColumn(1).getString(), + query.getColumn(2).getString(), query.getColumn(3).getString(), + query.getColumn(4).getString(), query.getColumn(5).getString(), + query.getColumn(6).getString(), query.getColumn(7).getString(), + query.getColumn(8).getString(), query.getColumn(9).getString(), + query.getColumn(10).getString()}; + } else { + return std::nullopt; + } + } catch (const std::exception& e) { + return std::nullopt; + } +} + +std::optional Engines::DeleteEngineById(int id) { + try { + SQLite::Statement query(db_, "DELETE FROM engines WHERE id = ?"); + + query.bind(1, id); + query.exec(); + return std::nullopt; + } catch (const std::exception& e) { + return std::string("Failed to delete engine: ") + e.what(); + } +} + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/engines.h b/engine/database/engines.h new file mode 100644 index 000000000..7429d0fa2 --- /dev/null +++ b/engine/database/engines.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace cortex::db { + +struct EngineEntry { + int id; + std::string engine_name; + std::string type; + std::string api_key; + std::string url; + std::string version; + std::string variant; + std::string status; + std::string metadata; + std::string date_created; + std::string date_updated; + Json::Value ToJson() const { + Json::Value root; + Json::Reader reader; + + // Convert basic fields + root["id"] = id; + root["engine_name"] = engine_name; + root["type"] = type; + root["api_key"] = api_key; + root["url"] = url; + root["version"] = version; + root["variant"] = variant; + root["status"] = status; + root["date_created"] = date_created; + root["date_updated"] = date_updated; + + // Parse metadata string into JSON object + Json::Value metadataJson; + if (!metadata.empty()) { + bool success = reader.parse(metadata, metadataJson, + false); // false = don't collect comments + if (success) { + root["metadata"] = metadataJson; + } else { + root["metadata"] = Json::Value::null; + } + } else { + root["metadata"] = Json::Value(Json::objectValue); // empty object + } + + return root; + } +}; + +class Engines { + private: + SQLite::Database& db_; + + bool IsUnique(const std::vector& entries, + const std::string& model_id, + const std::string& model_alias) const; + + std::optional> LoadModelListNoLock() const; + + public: + Engines(); + Engines(SQLite::Database& db); + ~Engines(); + + std::optional UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); + + std::optional> GetEngines() const; + std::optional GetEngineById(int id) const; + std::optional GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt) const; + + std::optional DeleteEngineById(int id); +}; + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.cc b/engine/database/models.cc index 3e81fbab2..fb2128396 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -9,9 +9,32 @@ namespace cortex::db { Models::Models() : db_(cortex::db::Database::GetInstance().db()) {} +Models::~Models() {} + +std::string Models::StatusToString(ModelStatus status) const { + switch (status) { + case ModelStatus::Remote: + return "remote"; + case ModelStatus::Downloaded: + return "downloaded"; + case ModelStatus::Undownloaded: + return "undownloaded"; + } + return "unknown"; +} + Models::Models(SQLite::Database& db) : db_(db) {} -Models::~Models() {} +ModelStatus Models::StringToStatus(const std::string& status_str) const { + if (status_str == "remote") { + return ModelStatus::Remote; + } else if (status_str == "downloaded" || status_str.empty()) { + return ModelStatus::Downloaded; + } else if (status_str == "undownloaded") { + return ModelStatus::Undownloaded; + } + throw std::invalid_argument("Invalid status string"); +} cpp::result, std::string> Models::LoadModelList() const { @@ -41,7 +64,8 @@ cpp::result, std::string> Models::LoadModelListNoLock() std::vector entries; SQLite::Statement query(db_, "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models"); + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models"); while (query.executeStep()) { ModelEntry entry; @@ -50,6 +74,10 @@ cpp::result, std::string> Models::LoadModelListNoLock() entry.branch_name = query.getColumn(2).getString(); entry.path_to_model_yaml = query.getColumn(3).getString(); entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); entries.push_back(entry); } return entries; @@ -124,7 +152,8 @@ cpp::result Models::GetModelInfo( try { SQLite::Statement query(db_, "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias FROM models " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine FROM models " "WHERE model_id = ? OR model_alias = ?"); query.bind(1, identifier); @@ -136,6 +165,10 @@ cpp::result Models::GetModelInfo( entry.branch_name = query.getColumn(2).getString(); entry.path_to_model_yaml = query.getColumn(3).getString(); entry.model_alias = query.getColumn(4).getString(); + entry.model_format = query.getColumn(5).getString(); + entry.model_source = query.getColumn(6).getString(); + entry.status = StringToStatus(query.getColumn(7).getString()); + entry.engine = query.getColumn(8).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -151,6 +184,10 @@ void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Branch Name: " << entry.branch_name; LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml; LOG_INFO << "Model Alias: " << entry.model_alias; + LOG_INFO << "Model Format: " << entry.model_format; + LOG_INFO << "Model Source: " << entry.model_source; + LOG_INFO << "Status: " << StatusToString(entry.status); + LOG_INFO << "Engine: " << entry.engine; } cpp::result Models::AddModelEntry(ModelEntry new_entry, @@ -171,14 +208,18 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, SQLite::Statement insert( db_, - "INSERT INTO models (model_id, author_repo_id, " - "branch_name, path_to_model_yaml, model_alias) VALUES (?, ?, " - "?, ?, ?)"); + "INSERT INTO models (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, model_source, " + "status, engine) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); insert.bind(2, new_entry.author_repo_id); insert.bind(3, new_entry.branch_name); insert.bind(4, new_entry.path_to_model_yaml); insert.bind(5, new_entry.model_alias); + insert.bind(6, new_entry.model_format); + insert.bind(7, new_entry.model_source); + insert.bind(8, StatusToString(new_entry.status)); + insert.bind(9, new_entry.engine); insert.exec(); return true; @@ -196,16 +237,20 @@ cpp::result Models::UpdateModelEntry( return cpp::fail("Model not found: " + identifier); } try { - SQLite::Statement upd(db_, - "UPDATE models " - "SET author_repo_id = ?, branch_name = ?, " - "path_to_model_yaml = ? " - "WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement upd( + db_, + "UPDATE models SET author_repo_id = ?, branch_name = ?, " + "path_to_model_yaml = ?, model_format = ?, model_source = ?, status = " + "?, engine = ? WHERE model_id = ? OR model_alias = ?"); upd.bind(1, updated_entry.author_repo_id); upd.bind(2, updated_entry.branch_name); upd.bind(3, updated_entry.path_to_model_yaml); - upd.bind(4, identifier); - upd.bind(5, identifier); + upd.bind(4, updated_entry.model_format); + upd.bind(5, updated_entry.model_source); + upd.bind(6, StatusToString(updated_entry.status)); + upd.bind(7, updated_entry.engine); + upd.bind(8, identifier); + upd.bind(9, identifier); return upd.exec() == 1; } catch (const std::exception& e) { return cpp::fail(e.what()); @@ -293,4 +338,5 @@ bool Models::HasModel(const std::string& identifier) const { return false; } } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/database/models.h b/engine/database/models.h index 197996ab8..dd6e2a5a1 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -7,12 +7,23 @@ #include "utils/result.hpp" namespace cortex::db { + +enum class ModelStatus { + Remote, + Downloaded, + Undownloaded +}; + struct ModelEntry { - std::string model; + std::string model; std::string author_repo_id; std::string branch_name; std::string path_to_model_yaml; std::string model_alias; + std::string model_format; + std::string model_source; + ModelStatus status; + std::string engine; }; class Models { @@ -26,6 +37,9 @@ class Models { cpp::result, std::string> LoadModelListNoLock() const; + std::string StatusToString(ModelStatus status) const; + ModelStatus StringToStatus(const std::string& status_str) const; + public: cpp::result, std::string> LoadModelList() const; Models(); @@ -49,4 +63,5 @@ class Models { const std::string& identifier) const; bool HasModel(const std::string& identifier) const; }; -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc new file mode 100644 index 000000000..847cba566 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.cc @@ -0,0 +1,62 @@ +#include "anthropic_engine.h" +#include +#include +#include "utils/logging_utils.h" + +namespace remote_engine { +namespace { +constexpr const std::array kAnthropicModels = { + "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"}; +} +void AnthropicEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "anthropic"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value AnthropicEngine::GetRemoteModels() { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + for (const auto& m : kAnthropicModels) { + Json::Value val; + val["id"] = std::string(m); + val["engine"] = "anthropic"; + val["created"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + CTL_INF("Remote models responded"); + return json_resp; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h new file mode 100644 index 000000000..bcd3dfaf7 --- /dev/null +++ b/engine/extensions/remote-engine/anthropic_engine.h @@ -0,0 +1,13 @@ +#pragma once +#include "remote_engine.h" + +namespace remote_engine { + class AnthropicEngine: public RemoteEngine { +public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; + }; +} \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc new file mode 100644 index 000000000..7c7d70385 --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.cc @@ -0,0 +1,54 @@ +#include "openai_engine.h" +#include "utils/logging_utils.h" + +namespace remote_engine { + +void OpenAiEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "openai"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); +} + +Json::Value OpenAiEngine::GetRemoteModels() { + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value error; + error["error"] = response.error_message; + return error; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value error; + error["error"] = "Failed to parse response"; + return error; + } + return response_json; +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h new file mode 100644 index 000000000..61dc68f0c --- /dev/null +++ b/engine/extensions/remote-engine/openai_engine.h @@ -0,0 +1,14 @@ +#pragma once + +#include "remote_engine.h" + +namespace remote_engine { +class OpenAiEngine : public RemoteEngine { + public: + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; +}; +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc new file mode 100644 index 000000000..04effb457 --- /dev/null +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -0,0 +1,712 @@ +#include "remote_engine.h" +#include +#include +#include +#include +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +namespace remote_engine { +namespace { +constexpr const int k200OK = 200; +constexpr const int k400BadRequest = 400; +constexpr const int k409Conflict = 409; +constexpr const int k500InternalServerError = 500; +constexpr const int kFileLoggerOption = 0; +bool is_anthropic(const std::string& model) { + return model.find("claude") != std::string::npos; +} + +struct AnthropicChunk { + std::string type; + std::string id; + int index; + std::string msg; + std::string model; + std::string stop_reason; + bool should_ignore = false; + + AnthropicChunk(const std::string& str) { + if (str.size() > 6) { + std::string s = str.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + type = root["type"].asString(); + if (type == "message_start") { + id = root["message"]["id"].asString(); + model = root["message"]["model"].asString(); + } else if (type == "content_block_delta") { + index = root["index"].asInt(); + if (root["delta"]["type"].asString() == "text_delta") { + msg = root["delta"]["text"].asString(); + } + } else if (type == "message_delta") { + stop_reason = root["delta"]["stop_reason"].asString(); + } else { + // ignore other messages + should_ignore = true; + } + } catch (const std::exception& e) { + should_ignore = true; + CTL_WRN("JSON parse error: " << e.what()); + } + } else { + should_ignore = true; + } + } + + std::string ToOpenAiFormatString() { + Json::Value root; + root["id"] = id; + root["object"] = "chat.completion.chunk"; + root["created"] = Json::Value(); + root["model"] = model; + root["system_fingerprint"] = "fp_e76890f0c3"; + Json::Value choices(Json::arrayValue); + Json::Value choice; + Json::Value content; + choice["index"] = 0; + content["content"] = msg; + if (type == "message_start") { + content["role"] = "assistant"; + content["refusal"] = Json::Value(); + } + choice["delta"] = content; + choice["finish_reason"] = stop_reason.empty() ? Json::Value() : stop_reason; + choices.append(choice); + root["choices"] = choices; + return "data: " + json_helper::DumpJsonString(root); + } +}; + +} // namespace + +size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* context = static_cast(userdata); + std::string chunk(ptr, size * nmemb); + + context->buffer += chunk; + + // Process complete lines + size_t pos; + while ((pos = context->buffer.find('\n')) != std::string::npos) { + std::string line = context->buffer.substr(0, pos); + context->buffer = context->buffer.substr(pos + 1); + CTL_TRC(line); + + // Skip empty lines + if (line.empty() || line == "\r" || + line.find("event:") != std::string::npos) + continue; + + // Remove "data: " prefix if present + // if (line.substr(0, 6) == "data: ") + // { + // line = line.substr(6); + // } + + // Skip [DONE] message + // std::cout << line << std::endl; + if (line == "data: [DONE]" || + line.find("message_stop") != std::string::npos) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), Json::Value()); + break; + } + + // Parse the JSON + Json::Value chunk_json; + if (is_anthropic(context->model)) { + AnthropicChunk ac(line); + if (ac.should_ignore) + continue; + ac.model = context->model; + if (ac.type == "message_start") { + context->id = ac.id; + } else { + ac.id = context->id; + } + chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; + } else { + chunk_json["data"] = line + "\n\n"; + } + Json::Reader reader; + + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + (*context->callback)(std::move(status), std::move(chunk_json)); + } + + return size * nmemb; +} + +CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( + const ModelConfig& config, const std::string& body, + const std::function& callback) { + + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = + config.transform_req["chat_completions"]["url"].as(); + + struct curl_slist* headers = nullptr; + if (!config.api_key.empty()) { + headers = curl_slist_append(headers, api_key_template_.c_str()); + } + + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } + + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, "Accept: text/event-stream"); + headers = curl_slist_append(headers, "Cache-Control: no-cache"); + headers = curl_slist_append(headers, "Connection: keep-alive"); + + StreamContext context{ + std::make_shared>( + callback), + "", "", config.model}; + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, StreamWriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context); + curl_easy_setopt(curl, CURLOPT_TRANSFER_ENCODING, 1L); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = true; + status["status_code"] = 500; + + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +std::string ReplaceApiKeyPlaceholder(const std::string& templateStr, + const std::string& apiKey) { + const std::string placeholder = "{{api_key}}"; + std::string result = templateStr; + size_t pos = result.find(placeholder); + + if (pos != std::string::npos) { + result.replace(pos, placeholder.length(), apiKey); + } + + return result; +} + +static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, + std::string* data) { + data->append(ptr, size * nmemb); + return size * nmemb; +} + +RemoteEngine::RemoteEngine() { + curl_global_init(CURL_GLOBAL_ALL); +} + +RemoteEngine::~RemoteEngine() { + curl_global_cleanup(); +} + +RemoteEngine::ModelConfig* RemoteEngine::GetModelConfig( + const std::string& model) { + std::shared_lock lock(models_mtx_); + auto it = models_.find(model); + if (it != models_.end()) { + return &it->second; + } + return nullptr; +} + +CurlResponse RemoteEngine::MakeGetModelsRequest() { + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + + std::string full_url = metadata_["get_models_url"].asString(); + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, api_key_template_.c_str()); + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + std::string response_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + } else { + response.body = response_string; + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +CurlResponse RemoteEngine::MakeChatCompletionRequest( + const ModelConfig& config, const std::string& body, + const std::string& method) { + CURL* curl = curl_easy_init(); + CurlResponse response; + + if (!curl) { + response.error = true; + response.error_message = "Failed to initialize CURL"; + return response; + } + std::string full_url = + config.transform_req["chat_completions"]["url"].as(); + + struct curl_slist* headers = nullptr; + if (!config.api_key.empty()) { + headers = curl_slist_append(headers, api_key_template_.c_str()); + } + + if (is_anthropic(config.model)) { + std::string v = "anthropic-version: " + config.version; + headers = curl_slist_append(headers, v.c_str()); + } + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + if (method == "POST") { + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); + } + + std::string response_string; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + response.error = true; + response.error_message = curl_easy_strerror(res); + } else { + response.body = response_string; + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + return response; +} + +bool RemoteEngine::LoadModelConfig(const std::string& model, + const std::string& yaml_path, + const std::string& api_key) { + try { + YAML::Node config = YAML::LoadFile(yaml_path); + + ModelConfig model_config; + model_config.model = model; + if (is_anthropic(model)) { + if (!config["version"]) { + CTL_ERR("Missing version for model: " << model); + return false; + } + model_config.version = config["version"].as(); + } + + // Required fields + if (!config["api_key_template"]) { + LOG_ERROR << "Missing required fields in config for model " << model; + return false; + } + + model_config.api_key = api_key; + // model_config.url = ; + // Optional fields + if (config["api_key_template"]) { + api_key_template_ = ReplaceApiKeyPlaceholder( + config["api_key_template"].as(), api_key); + } + if (config["TransformReq"]) { + model_config.transform_req = config["TransformReq"]; + } else { + LOG_WARN << "Missing TransformReq in config for model " << model; + } + if (config["TransformResp"]) { + model_config.transform_resp = config["TransformResp"]; + } else { + LOG_WARN << "Missing TransformResp in config for model " << model; + } + + model_config.is_loaded = true; + + // Thread-safe update of models map + { + std::unique_lock lock(models_mtx_); + models_[model] = std::move(model_config); + } + CTL_DBG("LoadModelConfig successfully: " << model << ", " << yaml_path); + + return true; + } catch (const YAML::Exception& e) { + LOG_ERROR << "Failed to load config for model " << model << ": " + << e.what(); + return false; + } +} + +void RemoteEngine::GetModels( + std::shared_ptr json_body, + std::function&& callback) { + CTL_WRN("Not implemented yet!"); +} + +void RemoteEngine::LoadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model") || !json_body->isMember("model_path") || + !json_body->isMember("api_key")) { + Json::Value error; + error["error"] = "Missing required fields: model or model_path"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + const std::string& model_path = (*json_body)["model_path"].asString(); + const std::string& api_key = (*json_body)["api_key"].asString(); + + if (!LoadModelConfig(model, model_path, api_key)) { + Json::Value error; + error["error"] = "Failed to load model configuration"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + callback(std::move(status), std::move(error)); + return; + } + if (json_body->isMember("metadata")) { + metadata_ = (*json_body)["metadata"]; + } + + Json::Value response; + response["status"] = "Model loaded successfully"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); + CTL_INF("Model loaded successfully: " << model); +} + +void RemoteEngine::UnloadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value error; + error["error"] = "Missing required field: model"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + + { + std::unique_lock lock(models_mtx_); + models_.erase(model); + } + + Json::Value response; + response["status"] = "Model unloaded successfully"; + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); +} + +void RemoteEngine::HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Missing required fields: model"; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model); + + if (!model_config || !model_config->is_loaded) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Model not found or not loaded: " + model; + callback(std::move(status), std::move(error)); + return; + } + bool is_stream = + json_body->isMember("stream") && (*json_body)["stream"].asBool(); + Json::FastWriter writer; + // Transform request + std::string result; + try { + // Check if required YAML nodes exist + if (!model_config->transform_req["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_req"); + } + if (!model_config->transform_req["chat_completions"]["template"]) { + throw std::runtime_error("Missing 'template' node in chat_completions"); + } + + // Validate JSON body + if (!json_body || json_body->isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Get template string with error check + std::string template_str; + try { + template_str = model_config->transform_req["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error("Failed to convert template node to string: " + + std::string(e.what())); + } + + // Parse system for anthropic + if (is_anthropic(model)) { + bool has_system = false; + Json::Value msgs(Json::arrayValue); + for (auto& kv : (*json_body)["messages"]) { + if (kv["role"].asString() == "system") { + (*json_body)["system"] = kv["content"].asString(); + has_system = true; + } else { + msgs.append(kv); + } + } + if (has_system) { + (*json_body)["messages"] = msgs; + } + } + + // Render with error handling + try { + result = renderer_.Render(template_str, *json_body); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + result = (*json_body).toStyledString(); + } + + if (is_stream) { + MakeStreamingChatCompletionRequest(*model_config, result, callback); + } else { + + auto response = MakeChatCompletionRequest(*model_config, result); + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + // Transform Response + std::string response_str; + try { + // Check if required YAML nodes exist + if (!model_config->transform_resp["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_resp"); + } + if (!model_config->transform_resp["chat_completions"]["template"]) { + throw std::runtime_error("Missing 'template' node in chat_completions"); + } + + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Get template string with error check + std::string template_str; + try { + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error("Failed to convert template node to string: " + + std::string(e.what())); + } + + // Render with error handling + try { + response_str = renderer_.Render(template_str, response_json); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + response_str = response_json.toStyledString(); + } + + Json::Reader reader_final; + Json::Value response_json_final; + if (!reader_final.parse(response_str, response_json_final)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json_final)); + } +} + +void RemoteEngine::GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + Json::Value error; + error["error"] = "Missing required field: model"; + callback(Json::Value(), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model); + + if (!model_config) { + Json::Value error; + error["error"] = "Model not found: " + model; + callback(Json::Value(), std::move(error)); + return; + } + + Json::Value response; + response["model"] = model; + response["model_loaded"] = model_config->is_loaded; + response["model_data"] = model_config->url; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + callback(std::move(status), std::move(response)); +} + +// Implement remaining virtual functions +void RemoteEngine::HandleEmbedding( + std::shared_ptr, + std::function&& callback) { + callback(Json::Value(), Json::Value()); +} + +Json::Value RemoteEngine::GetRemoteModels() { + CTL_WRN("Not implemented yet!"); + return {}; +} + +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h new file mode 100644 index 000000000..8ce6fa652 --- /dev/null +++ b/engine/extensions/remote-engine/remote_engine.h @@ -0,0 +1,102 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "cortex-common/remote_enginei.h" +#include "extensions/remote-engine/template_renderer.h" +#include "utils/engine_constants.h" +#include "utils/file_logger.h" +// Helper for CURL response + +namespace remote_engine { +inline bool IsRemoteEngine(std::string_view e) { + return e == kAnthropicEngine || e == kOpenAiEngine; +} + +struct StreamContext { + std::shared_ptr> callback; + std::string buffer; + // Cache value for Anthropic + std::string id; + std::string model; +}; +struct CurlResponse { + std::string body; + bool error{false}; + std::string error_message; +}; + +class RemoteEngine : public RemoteEngineI { + protected: + // Model configuration + struct ModelConfig { + std::string model; + std::string version; + std::string api_key; + std::string url; + YAML::Node transform_req; + YAML::Node transform_resp; + bool is_loaded{false}; + }; + + // Thread-safe model config storage + mutable std::shared_mutex models_mtx_; + std::unordered_map models_; + TemplateRenderer renderer_; + Json::Value metadata_; + std::string api_key_template_; + std::unique_ptr async_file_logger_; + + // Helper functions + CurlResponse MakeChatCompletionRequest(const ModelConfig& config, + const std::string& body, + const std::string& method = "POST"); + CurlResponse MakeStreamingChatCompletionRequest( + const ModelConfig& config, const std::string& body, + const std::function& callback); + CurlResponse MakeGetModelsRequest(); + + // Internal model management + bool LoadModelConfig(const std::string& model, const std::string& yaml_path, + const std::string& api_key); + ModelConfig* GetModelConfig(const std::string& model); + + public: + RemoteEngine(); + virtual ~RemoteEngine(); + + // Main interface implementations + void GetModels( + std::shared_ptr json_body, + std::function&& callback) override; + + void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) override; + + void LoadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) override; + + void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) override; + + // Other required virtual functions + void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) override; + + Json::Value GetRemoteModels() override; +}; + +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/template_renderer.cc b/engine/extensions/remote-engine/template_renderer.cc new file mode 100644 index 000000000..15514d17c --- /dev/null +++ b/engine/extensions/remote-engine/template_renderer.cc @@ -0,0 +1,136 @@ +#if defined(_WIN32) || defined(_WIN64) +#define NOMINMAX +#undef min +#undef max +#endif +#include "template_renderer.h" +#include +#include +#include "utils/logging_utils.h" +namespace remote_engine { +TemplateRenderer::TemplateRenderer() { + // Configure Inja environment + env_.set_trim_blocks(true); + env_.set_lstrip_blocks(true); + + // Add tojson function for all value types + env_.add_callback("tojson", 1, [](inja::Arguments& args) { + if (args.empty()) { + return nlohmann::json(nullptr); + } + const auto& value = *args[0]; + + if (value.is_string()) { + return nlohmann::json(std::string("\"") + value.get() + + "\""); + } + return value; + }); +} + +std::string TemplateRenderer::Render(const std::string& tmpl, + const Json::Value& data) { + try { + // Convert Json::Value to nlohmann::json + auto json_data = ConvertJsonValue(data); + + // Create the input data structure expected by the template + nlohmann::json template_data; + template_data["input_request"] = json_data; + + // Debug output + LOG_DEBUG << "Template: " << tmpl; + LOG_DEBUG << "Data: " << template_data.dump(2); + + // Render template + std::string result = env_.render(tmpl, template_data); + + // Clean up any potential double quotes in JSON strings + result = std::regex_replace(result, std::regex("\\\"\\\""), "\""); + + LOG_DEBUG << "Result: " << result; + + // Validate JSON + auto parsed = nlohmann::json::parse(result); + + return result; + } catch (const std::exception& e) { + LOG_ERROR << "Template rendering failed: " << e.what(); + LOG_ERROR << "Template: " << tmpl; + throw std::runtime_error(std::string("Template rendering failed: ") + + e.what()); + } +} + +nlohmann::json TemplateRenderer::ConvertJsonValue(const Json::Value& input) { + if (input.isNull()) { + return nullptr; + } else if (input.isBool()) { + return input.asBool(); + } else if (input.isInt()) { + return input.asInt(); + } else if (input.isUInt()) { + return input.asUInt(); + } else if (input.isDouble()) { + return input.asDouble(); + } else if (input.isString()) { + return input.asString(); + } else if (input.isArray()) { + nlohmann::json arr = nlohmann::json::array(); + for (const auto& element : input) { + arr.push_back(ConvertJsonValue(element)); + } + return arr; + } else if (input.isObject()) { + nlohmann::json obj = nlohmann::json::object(); + for (const auto& key : input.getMemberNames()) { + obj[key] = ConvertJsonValue(input[key]); + } + return obj; + } + return nullptr; +} + +Json::Value TemplateRenderer::ConvertNlohmannJson(const nlohmann::json& input) { + if (input.is_null()) { + return Json::Value(); + } else if (input.is_boolean()) { + return Json::Value(input.get()); + } else if (input.is_number_integer()) { + return Json::Value(input.get()); + } else if (input.is_number_unsigned()) { + return Json::Value(input.get()); + } else if (input.is_number_float()) { + return Json::Value(input.get()); + } else if (input.is_string()) { + return Json::Value(input.get()); + } else if (input.is_array()) { + Json::Value arr(Json::arrayValue); + for (const auto& element : input) { + arr.append(ConvertNlohmannJson(element)); + } + return arr; + } else if (input.is_object()) { + Json::Value obj(Json::objectValue); + for (auto it = input.begin(); it != input.end(); ++it) { + obj[it.key()] = ConvertNlohmannJson(it.value()); + } + return obj; + } + return Json::Value(); +} + +std::string TemplateRenderer::RenderFile(const std::string& template_path, + const Json::Value& data) { + try { + // Convert Json::Value to nlohmann::json + auto json_data = ConvertJsonValue(data); + + // Load and render template + return env_.render_file(template_path, json_data); + } catch (const std::exception& e) { + throw std::runtime_error(std::string("Template file rendering failed: ") + + e.what()); + } +} +} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/template_renderer.h b/engine/extensions/remote-engine/template_renderer.h new file mode 100644 index 000000000..f59e7cc93 --- /dev/null +++ b/engine/extensions/remote-engine/template_renderer.h @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include +#include "json/json.h" +#include "trantor/utils/Logger.h" +// clang-format off +#if defined(_WIN32) || defined(_WIN64) +#define NOMINMAX +#undef min +#undef max +#endif +#include +#include +// clang-format on +namespace remote_engine { +class TemplateRenderer { + public: + TemplateRenderer(); + ~TemplateRenderer() = default; + + // Convert Json::Value to nlohmann::json + static nlohmann::json ConvertJsonValue(const Json::Value& input); + + // Convert nlohmann::json to Json::Value + static Json::Value ConvertNlohmannJson(const nlohmann::json& input); + + // Render template with data + std::string Render(const std::string& tmpl, const Json::Value& data); + + // Load template from file and render + std::string RenderFile(const std::string& template_path, + const Json::Value& data); + + private: + inja::Environment env_; +}; + +} // namespace remote_engine \ No newline at end of file diff --git a/engine/migrations/db_helper.h b/engine/migrations/db_helper.h new file mode 100644 index 000000000..0990426bf --- /dev/null +++ b/engine/migrations/db_helper.h @@ -0,0 +1,26 @@ +#pragma once +#include + +namespace cortex::mgr { +#include +#include +#include +#include + +inline bool ColumnExists(SQLite::Database& db, const std::string& table_name, const std::string& column_name) { + try { + SQLite::Statement query(db, "SELECT " + column_name + " FROM " + table_name + " LIMIT 0"); + return true; + } catch (std::exception&) { + return false; + } +} + +inline void AddColumnIfNotExists(SQLite::Database& db, const std::string& table_name, + const std::string& column_name, const std::string& column_type) { + if (!ColumnExists(db, table_name, column_name)) { + std::string sql = "ALTER TABLE " + table_name + " ADD COLUMN " + column_name + " " + column_type; + db.exec(sql); + } +} +} \ No newline at end of file diff --git a/engine/migrations/migration_helper.cc b/engine/migrations/migration_helper.cc index 42cc8d453..b02435cd2 100644 --- a/engine/migrations/migration_helper.cc +++ b/engine/migrations/migration_helper.cc @@ -7,7 +7,6 @@ cpp::result MigrationHelper::BackupDatabase( try { SQLite::Database src_db(src_db_path, SQLite::OPEN_READONLY); sqlite3* backup_db; - if (sqlite3_open(backup_db_path.c_str(), &backup_db) != SQLITE_OK) { throw std::runtime_error("Failed to open backup database"); } diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc index 2c2b6ddfd..0e2e41e4e 100644 --- a/engine/migrations/migration_manager.cc +++ b/engine/migrations/migration_manager.cc @@ -5,6 +5,8 @@ #include "utils/file_manager_utils.h" #include "utils/scope_exit.h" #include "utils/widechar_conv.h" +#include "v0/migration.h" +#include "v1/migration.h" namespace cortex::migr { @@ -140,6 +142,9 @@ cpp::result MigrationManager::DoUpFolderStructure( case 0: return v0::MigrateFolderStructureUp(); break; + case 1: + return v1::MigrateFolderStructureUp(); + break; default: return true; @@ -151,6 +156,9 @@ cpp::result MigrationManager::DoDownFolderStructure( case 0: return v0::MigrateFolderStructureDown(); break; + case 1: + return v1::MigrateFolderStructureDown(); + break; default: return true; @@ -184,6 +192,9 @@ cpp::result MigrationManager::DoUpDB(int version) { case 0: return v0::MigrateDBUp(db_); break; + case 1: + return v1::MigrateDBUp(db_); + break; default: return true; @@ -195,6 +206,9 @@ cpp::result MigrationManager::DoDownDB(int version) { case 0: return v0::MigrateDBDown(db_); break; + case 1: + return v1::MigrateDBDown(db_); + break; default: return true; diff --git a/engine/migrations/schema_version.h b/engine/migrations/schema_version.h index 7cfccf27a..1e64110e3 100644 --- a/engine/migrations/schema_version.h +++ b/engine/migrations/schema_version.h @@ -1,4 +1,4 @@ #pragma once //Track the current schema version -#define SCHEMA_VERSION 0 \ No newline at end of file +#define SCHEMA_VERSION 1 \ No newline at end of file diff --git a/engine/migrations/v1/migration.h b/engine/migrations/v1/migration.h new file mode 100644 index 000000000..f9a8038e3 --- /dev/null +++ b/engine/migrations/v1/migration.h @@ -0,0 +1,165 @@ +#pragma once +#include +#include +#include +#include "migrations/db_helper.h" +#include "utils/file_manager_utils.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace cortex::migr::v1 { +// Data folder +namespace fmu = file_manager_utils; + +// cortexcpp +// |__ models +// | |__ cortex.so +// | |__ tinyllama +// | |__ gguf +// |__ engines +// | |__ cortex.llamacpp +// | |__ deps +// | |__ windows-amd64-avx +// |__ logs +// +inline cpp::result MigrateFolderStructureUp() { + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "models")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "models"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "engines")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "engines"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "logs")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "logs"); + } + + return true; +} + +inline cpp::result MigrateFolderStructureDown() { + // CTL_INF("Folder structure already up to date!"); + return true; +} + +// Database +inline cpp::result MigrateDBUp(SQLite::Database& db) { + try { + db.exec( + "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY " + "KEY);"); + + // models + { + // Check if the table exists + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + + if (table_exists) { + // Alter existing table + cortex::mgr::AddColumnIfNotExists(db, "models", "model_format", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "model_source", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "status", "TEXT"); + cortex::mgr::AddColumnIfNotExists(db, "models", "engine", "TEXT"); + } else { + // Create new table + db.exec( + "CREATE TABLE models (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT" + ")"); + } + } + + db.exec( + "CREATE TABLE IF NOT EXISTS hardware (" + "uuid TEXT PRIMARY KEY, " + "type TEXT NOT NULL, " + "hardware_id INTEGER NOT NULL, " + "software_id INTEGER NOT NULL, " + "activated INTEGER NOT NULL CHECK (activated IN (0, 1)));"); + + // engines + db.exec( + "CREATE TABLE IF NOT EXISTS engines (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "engine_name TEXT," + "type TEXT," + "api_key TEXT," + "url TEXT," + "version TEXT," + "variant TEXT," + "status TEXT," + "metadata TEXT," + "date_created TEXT DEFAULT CURRENT_TIMESTAMP," + "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," + "UNIQUE(engine_name, variant));"); + + // CTL_INF("Database migration up completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration up failed: " << e.what()); + return cpp::fail(e.what()); + } +}; + +inline cpp::result MigrateDBDown(SQLite::Database& db) { + try { + // models + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + if (table_exists) { + // Create a new table with the old schema + db.exec( + "CREATE TABLE models_old (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT" + ")"); + + // Copy data from the current table to the new table + db.exec( + "INSERT INTO models_old (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias) " + "SELECT model_id, author_repo_id, branch_name, path_to_model_yaml, " + "model_alias FROM models"); + + // Drop the current table + db.exec("DROP TABLE models"); + + // Rename the new table to the original name + db.exec("ALTER TABLE models_old RENAME TO models"); + } + } + + // hardware + { + // Do nothing + } + + // engines + db.exec("DROP TABLE IF EXISTS engines;"); + // CTL_INF("Migration down completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration down failed: " << e.what()); + return cpp::fail(e.what()); + } +} + +}; // namespace cortex::migr::v1 diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index c52e32ef0..c91fd0dd0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -2,7 +2,11 @@ #include #include #include +#include #include "algorithm" +#include "database/engines.h" +#include "extensions/remote-engine/anthropic_engine.h" +#include "extensions/remote-engine/openai_engine.h" #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" @@ -13,7 +17,6 @@ #include "utils/semantic_version_utils.h" #include "utils/system_info_utils.h" #include "utils/url_parser.h" - namespace { std::string GetSuitableCudaVersion(const std::string& engine, const std::string& cuda_driver_version) { @@ -179,6 +182,18 @@ cpp::result EngineService::UninstallEngineVariant( const std::string& engine, const std::optional version, const std::optional variant) { auto ne = NormalizeEngine(engine); + // TODO: handle uninstall remote engine + // only delete a remote engine if no model are using it + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_value() && exist_engine.value().type == "remote") { + auto result = DeleteEngine(exist_engine.value().id); + if (!result.empty()) { // This mean no error when delete model + CTL_ERR("Failed to delete engine: " << result); + return cpp::fail(result); + } + return cpp::result(true); + } + if (IsEngineLoaded(ne)) { CTL_INF("Engine " << ne << " is already loaded, unloading it"); auto unload_res = UnloadEngine(ne); @@ -226,21 +241,19 @@ cpp::result EngineService::UninstallEngineVariant( cpp::result EngineService::DownloadEngine( const std::string& engine, const std::string& version, const std::optional variant_name) { + auto normalized_version = version == "latest" ? "latest" : string_utils::RemoveSubstring(version, "v"); - auto res = GetEngineVariants(engine, version); if (res.has_error()) { return cpp::fail("Failed to fetch engine releases: " + res.error()); } - if (res.value().empty()) { return cpp::fail("No release found for " + version); } std::optional selected_variant = std::nullopt; - if (variant_name.has_value()) { auto latest_version_semantic = normalized_version == "latest" ? res.value()[0].version @@ -269,9 +282,10 @@ cpp::result EngineService::DownloadEngine( } } - if (selected_variant == std::nullopt) { + if (!selected_variant) { return cpp::fail("Failed to find a suitable variant for " + engine); } + if (IsEngineLoaded(engine)) { CTL_INF("Engine " << engine << " is already loaded, unloading it"); auto unload_res = UnloadEngine(engine); @@ -282,17 +296,17 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Engine " << engine << " unloaded successfully"); } } - auto normalize_version = "v" + selected_variant->version; + auto normalize_version = "v" + selected_variant->version; auto variant_folder_name = engine_matcher_utils::GetVariantFromNameAndVersion( selected_variant->name, engine, selected_variant->version); - auto variant_folder_path = file_manager_utils::GetEnginesContainerPath() / engine / variant_folder_name.value() / normalize_version; - auto variant_path = variant_folder_path / selected_variant->name; + std::filesystem::create_directories(variant_folder_path); + CTL_INF("variant_folder_path: " + variant_folder_path.string()); auto on_finished = [this, engine, selected_variant, variant_folder_path, normalize_version](const DownloadTask& finishedTask) { @@ -301,14 +315,15 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Version: " + normalize_version); auto extract_path = finishedTask.items[0].localPath.parent_path(); - archive_utils::ExtractArchive(finishedTask.items[0].localPath.string(), extract_path.string(), true); auto variant = engine_matcher_utils::GetVariantFromNameAndVersion( selected_variant->name, engine, normalize_version); + CTL_INF("Extracted variant: " + variant.value()); // set as default + auto res = SetDefaultEngineVariant(engine, normalize_version, variant.value()); if (res.has_error()) { @@ -316,10 +331,21 @@ cpp::result EngineService::DownloadEngine( } else { CTL_INF("Set default engine variant: " << res.value().variant); } - - // remove other engines - auto engine_directories = file_manager_utils::GetEnginesContainerPath() / - engine / selected_variant->name; + auto create_res = + EngineService::UpsertEngine(engine, // engine_name + "local", // todo - luke + "", // todo - luke + "", // todo - luke + normalize_version, variant.value(), + "Default", // todo - luke + "" // todo - luke + ); + + if (create_res.has_value()) { + CTL_ERR("Failed to create engine entry: " << create_res->engine_name); + } else { + CTL_INF("Engine entry created successfully"); + } for (const auto& entry : std::filesystem::directory_iterator( variant_folder_path.parent_path())) { @@ -333,7 +359,6 @@ cpp::result EngineService::DownloadEngine( } } - // remove the downloaded file try { std::filesystem::remove(finishedTask.items[0].localPath); } catch (const std::exception& e) { @@ -342,18 +367,18 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Finished!"); }; - auto downloadTask{ + auto downloadTask = DownloadTask{.id = engine, .type = DownloadType::Engine, .items = {DownloadItem{ .id = engine, .downloadUrl = selected_variant->browser_download_url, .localPath = variant_path, - }}}}; + }}}; auto add_task_result = download_service_->AddTask(downloadTask, on_finished); - if (res.has_error()) { - return cpp::fail(res.error()); + if (add_task_result.has_error()) { + return cpp::fail(add_task_result.error()); } return {}; } @@ -656,6 +681,25 @@ cpp::result EngineService::LoadEngine( return {}; } + // Check for remote engine + if (remote_engine::IsRemoteEngine(engine_name)) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + + CTL_INF("Loaded engine: " << engine_name); + return {}; + } + + // End hard code + CTL_INF("Loading engine: " << ne); auto selected_engine_variant = GetDefaultEngineVariant(ne); @@ -824,8 +868,11 @@ cpp::result EngineService::UnloadEngine( if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); } - EngineI* e = std::get(engines_[ne].engine); - delete e; + if (std::holds_alternative(engines_[ne].engine)) { + delete std::get(engines_[ne].engine); + } else { + delete std::get(engines_[ne].engine); + } #if defined(_WIN32) if (!RemoveDllDirectory(engines_[ne].cookie)) { @@ -867,9 +914,20 @@ EngineService::GetLatestEngineVersion(const std::string& engine) const { } cpp::result EngineService::IsEngineReady( - const std::string& engine) const { + const std::string& engine) { auto ne = NormalizeEngine(engine); + // Check for remote engine + if (remote_engine::IsRemoteEngine(engine)) { + auto exist_engine = GetEngineByNameAndVariant(engine); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine + "' is not installed"); + } + return true; + } + + // End hard code + auto os = hw_inf_.sys_inf->os; if (os == kMacOs && (ne == kOnnxRepo || ne == kTrtLlmRepo)) { return cpp::fail("Engine " + engine + " is not supported on macOS"); @@ -955,3 +1013,101 @@ cpp::result EngineService::UpdateEngine( .from = default_variant->version, .to = latest_version->tag_name}; } + +cpp::result, std::string> +EngineService::GetEngines() { + cortex::db::Engines engines; + auto get_res = engines.GetEngines(); + + if (!get_res.has_value()) { + return cpp::fail("Failed to get engine entries"); + } + + return get_res.value(); +} + +cpp::result EngineService::GetEngineById( + int id) { + cortex::db::Engines engines; + auto get_res = engines.GetEngineById(id); + + if (!get_res.has_value()) { + return cpp::fail("Engine with ID " + std::to_string(id) + " not found"); + } + + return get_res.value(); +} + +cpp::result +EngineService::GetEngineByNameAndVariant( + const std::string& engine_name, const std::optional variant) { + + cortex::db::Engines engines; + auto get_res = engines.GetEngineByNameAndVariant(engine_name, variant); + + if (!get_res.has_value()) { + if (variant.has_value()) { + return cpp::fail("Variant " + variant.value() + " not found for engine " + + engine_name); + } else { + return cpp::fail("Engine " + engine_name + " not found"); + } + } + + return get_res.value(); +} + +cpp::result EngineService::UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata) { + cortex::db::Engines engines; + auto upsert_res = engines.UpsertEngine(engine_name, type, api_key, url, + version, variant, status, metadata); + if (upsert_res.has_value()) { + return upsert_res.value(); + } else { + return cpp::fail("Failed to upsert engine entry"); + } +} + +std::string EngineService::DeleteEngine(int id) { + cortex::db::Engines engines; + auto delete_res = engines.DeleteEngineById(id); + if (delete_res.has_value()) { + return delete_res.value(); + } else { + return ""; + } +} + +cpp::result EngineService::GetRemoteModels( + const std::string& engine_name) { + if (auto r = IsEngineReady(engine_name); r.has_error()) { + return cpp::fail(r.error()); + } + + if (!IsEngineLoaded(engine_name)) { + auto exist_engine = GetEngineByNameAndVariant(engine_name); + if (exist_engine.has_error()) { + return cpp::fail("Remote engine '" + engine_name + "' is not installed"); + } + + if (engine_name == kOpenAiEngine) { + engines_[engine_name].engine = new remote_engine::OpenAiEngine(); + } else { + engines_[engine_name].engine = new remote_engine::AnthropicEngine(); + } + + CTL_INF("Loaded engine: " << engine_name); + } + + auto& e = std::get(engines_[engine_name].engine); + auto res = e->GetRemoteModels(); + if (!res["error"].isNull()) { + return cpp::fail(res["error"].asString()); + } else { + return res; + } +} \ No newline at end of file diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 47d7c272f..8c8bfbbe6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -2,12 +2,18 @@ #include #include +#include #include #include +#include #include + #include "common/engine_servicei.h" #include "cortex-common/EngineI.h" #include "cortex-common/cortexpythoni.h" +#include "cortex-common/remote_enginei.h" +#include "database/engines.h" +#include "extensions/remote-engine/remote_engine.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -32,11 +38,7 @@ struct EngineUpdateResult { } }; -namespace system_info_utils { -struct SystemInfo; -} - -using EngineV = std::variant; +using EngineV = std::variant; class EngineService : public EngineServiceI { private: @@ -54,6 +56,14 @@ class EngineService : public EngineServiceI { std::mutex engines_mutex_; std::unordered_map engines_{}; + std::shared_ptr download_service_; + + struct HardwareInfo { + std::unique_ptr sys_inf; + cortex::cpuid::CpuInfo cpu_inf; + std::string cuda_driver_version; + }; + HardwareInfo hw_inf_; public: const std::vector kSupportEngines = { @@ -70,7 +80,7 @@ class EngineService : public EngineServiceI { /** * Check if an engines is ready (have at least one variant installed) */ - cpp::result IsEngineReady(const std::string& engine) const; + cpp::result IsEngineReady(const std::string& engine); /** * Handling install engine variant. @@ -110,7 +120,6 @@ class EngineService : public EngineServiceI { std::vector GetLoadedEngines(); cpp::result LoadEngine(const std::string& engine_name); - cpp::result UnloadEngine(const std::string& engine_name); cpp::result @@ -123,6 +132,25 @@ class EngineService : public EngineServiceI { cpp::result UpdateEngine( const std::string& engine); + cpp::result, std::string> GetEngines(); + + cpp::result GetEngineById(int id); + + cpp::result GetEngineByNameAndVariant( + const std::string& engine_name, + const std::optional variant = std::nullopt); + + cpp::result UpsertEngine( + const std::string& engine_name, const std::string& type, + const std::string& api_key, const std::string& url, + const std::string& version, const std::string& variant, + const std::string& status, const std::string& metadata); + + std::string DeleteEngine(int id); + + cpp::result GetRemoteModels( + const std::string& engine_name); + private: cpp::result DownloadEngine( const std::string& engine, const std::string& version = "latest", @@ -137,13 +165,4 @@ class EngineService : public EngineServiceI { cpp::result IsEngineVariantReady( const std::string& engine, const std::string& version, const std::string& variant); - - std::shared_ptr download_service_; - - struct HardwareInfo { - std::unique_ptr sys_inf; - cortex::cpuid::CpuInfo cpu_inf; - std::string cuda_driver_version; - }; - HardwareInfo hw_inf_; -}; +}; \ No newline at end of file diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index 46309823d..ace7e675f 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -24,14 +24,26 @@ cpp::result InferenceService::HandleChatCompletion( return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleChatCompletion( - json_body, [q, tool_choice](Json::Value status, Json::Value res) { - if (!tool_choice.isNull()) { - res["tool_choice"] = tool_choice; - } - q->push(std::make_pair(status, res)); - }); + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleChatCompletion( + json_body, [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }); + } + return {}; } @@ -53,10 +65,18 @@ cpp::result InferenceService::HandleEmbedding( LOG_WARN << "Engine is not loaded yet"; return cpp::fail(std::make_pair(stt, res)); } - auto engine = std::get(engine_result.value()); - engine->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } else { + std::get(engine_result.value()) + ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }); + } return {}; } @@ -83,11 +103,20 @@ InferResult InferenceService::LoadModel( // might need mutex here auto engine_result = engine_service_->GetLoadedEngine(engine_type); - auto engine = std::get(engine_result.value()); - engine->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } return std::make_pair(stt, r); } @@ -110,12 +139,22 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name, json_body["model"] = model_id; LOG_TRACE << "Start unload model"; - auto engine = std::get(engine_result.value()); - engine->UnloadModel(std::make_shared(json_body), + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), + [&r, &stt](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->UnloadModel(std::make_shared(json_body), [&r, &stt](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); } @@ -141,12 +180,23 @@ InferResult InferenceService::GetModelStatus( } LOG_TRACE << "Start to get model status"; - auto engine = std::get(engine_result.value()); - engine->GetModelStatus(json_body, + + if (std::holds_alternative(engine_result.value())) { + std::get(engine_result.value()) + ->GetModelStatus(json_body, + [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }); + } else { + std::get(engine_result.value()) + ->GetModelStatus(json_body, [&stt, &r](Json::Value status, Json::Value res) { stt = status; r = res; }); + } + return std::make_pair(stt, r); } diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 7c09156ff..94097132a 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -5,7 +5,7 @@ #include #include "services/engine_service.h" #include "utils/result.hpp" - +#include "extensions/remote-engine/remote_engine.h" namespace services { // Status and result using InferResult = std::pair; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 3cfff5cb2..d81a9b649 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -64,11 +64,13 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, auto author_id = author.has_value() ? author.value() : "cortexso"; cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = ggufDownloadItem.id, - .author_repo_id = author_id, - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = ggufDownloadItem.id}; + cortex::db::ModelEntry model_entry{ + .model = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = ggufDownloadItem.id, + .status = cortex::db::ModelStatus::Downloaded}; auto result = modellist_utils_obj.AddModelEntry(model_entry, true); if (result.has_error()) { CTL_WRN("Error adding model to modellist: " + result.error()); @@ -718,6 +720,49 @@ cpp::result ModelService::StartModel( .string()); auto mc = yaml_handler.GetModelConfig(); + // Running remote model + if (remote_engine::IsRemoteEngine(mc.engine)) { + + config::RemoteModelConfig remote_mc; + remote_mc.LoadFromYamlFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + auto remote_engine_entry = + engine_svc_->GetEngineByNameAndVariant(mc.engine); + if (remote_engine_entry.has_error()) { + CTL_WRN("Remote engine error: " + model_entry.error()); + return cpp::fail(remote_engine_entry.error()); + } + auto remote_engine_json = remote_engine_entry.value().ToJson(); + json_data = remote_mc.ToJson(); + + json_data["api_key"] = std::move(remote_engine_json["api_key"]); + json_data["model_path"] = + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string(); + json_data["metadata"] = std::move(remote_engine_json["metadata"]); + + auto ir = + inference_svc_->LoadModel(std::make_shared(json_data)); + auto status = std::get<0>(ir)["status_code"].asInt(); + auto data = std::get<1>(ir); + if (status == drogon::k200OK) { + return StartModelResult{.success = true, .warning = ""}; + } else if (status == drogon::k409Conflict) { + CTL_INF("Model '" + model_handle + "' is already loaded"); + return StartModelResult{.success = true, .warning = ""}; + } else { + // only report to user the error + CTL_ERR("Model failed to start with status code: " << status); + return cpp::fail("Model failed to start: " + + data["message"].asString()); + } + } + + // end hard code + json_data = mc.ToJson(); if (mc.files.size() > 0) { #if defined(_WIN32) diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index 8c3ebbe00..ab0ea9f70 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -6,6 +6,7 @@ namespace cortex::db { namespace { constexpr const auto kTestDb = "./test.db"; } + class ModelsTestSuite : public ::testing::Test { public: ModelsTestSuite() @@ -14,12 +15,17 @@ class ModelsTestSuite : public ::testing::Test { void SetUp() { try { db_.exec( - "CREATE TABLE IF NOT EXISTS models (" + "CREATE TABLE models (" "model_id TEXT PRIMARY KEY," "author_repo_id TEXT," "branch_name TEXT," "path_to_model_yaml TEXT," - "model_alias TEXT);"); + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT" + ")"); } catch (const std::exception& e) {} } @@ -33,20 +39,27 @@ class ModelsTestSuite : public ::testing::Test { SQLite::Database db_; cortex::db::Models model_list_; - const cortex::db::ModelEntry kTestModel{"test_model_id", "test_author", - "main", "/path/to/model.yaml", - "test_alias"}; + const cortex::db::ModelEntry kTestModel{ + "test_model_id", "test_author", + "main", "/path/to/model.yaml", + "test_alias", "test_format", + "test_source", cortex::db::ModelStatus::Downloaded, + "test_engine"}; }; TEST_F(ModelsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); + EXPECT_EQ(retrieved_model.value().model_format, kTestModel.model_format); + EXPECT_EQ(retrieved_model.value().model_source, kTestModel.model_source); + EXPECT_EQ(retrieved_model.value().status, kTestModel.status); + EXPECT_EQ(retrieved_model.value().engine, kTestModel.engine); - // // Clean up + // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } @@ -54,7 +67,7 @@ TEST_F(ModelsTestSuite, TestGetModelInfo) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); auto model_by_id = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(model_by_id); + EXPECT_TRUE(model_by_id.has_value()); EXPECT_EQ(model_by_id.value().model, kTestModel.model); auto model_by_alias = model_list_.GetModelInfo("test_alias"); @@ -71,14 +84,14 @@ TEST_F(ModelsTestSuite, TestUpdateModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); cortex::db::ModelEntry updated_model = kTestModel; + updated_model.status = cortex::db::ModelStatus::Downloaded; EXPECT_TRUE( model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); auto retrieved_model = model_list_.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); - EXPECT_TRUE( - model_list_.UpdateModelEntry(kTestModel.model, updated_model).value()); + EXPECT_TRUE(retrieved_model.has_value()); + EXPECT_EQ(retrieved_model.value().status, updated_model.status); // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -117,7 +130,7 @@ TEST_F(ModelsTestSuite, TestPersistence) { // Create a new ModelListUtils instance to test if it loads from file cortex::db::Models new_model_list(db_); auto retrieved_model = new_model_list.GetModelInfo(kTestModel.model); - EXPECT_TRUE(retrieved_model); + EXPECT_TRUE(retrieved_model.has_value()); EXPECT_EQ(retrieved_model.value().model, kTestModel.model); EXPECT_EQ(retrieved_model.value().author_repo_id, kTestModel.author_repo_id); EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); @@ -136,7 +149,7 @@ TEST_F(ModelsTestSuite, TestUpdateModelAlias) { EXPECT_TRUE( model_list_.UpdateModelAlias(kTestModel.model, kNewTestAlias).value()); auto updated_model = model_list_.GetModelInfo(kNewTestAlias); - EXPECT_TRUE(updated_model); + EXPECT_TRUE(updated_model.has_value()); EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias); EXPECT_EQ(updated_model.value().model, kTestModel.model); @@ -174,4 +187,5 @@ TEST_F(ModelsTestSuite, TestHasModel) { // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -} // namespace cortex::db + +} // namespace cortex::db \ No newline at end of file diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 5dab49936..020109fd8 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -3,6 +3,8 @@ constexpr const auto kOnnxEngine = "onnxruntime"; constexpr const auto kLlamaEngine = "llama-cpp"; constexpr const auto kTrtLlmEngine = "tensorrt-llm"; +constexpr const auto kOpenAiEngine = "openai"; +constexpr const auto kAnthropicEngine = "anthropic"; constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; diff --git a/engine/utils/logging_utils.h b/engine/utils/logging_utils.h index d2c04a7e8..7d4cf35f1 100644 --- a/engine/utils/logging_utils.h +++ b/engine/utils/logging_utils.h @@ -9,6 +9,8 @@ inline bool log_verbose = false; inline bool is_server = false; // Only use trantor log +#define CTL_TRC(msg) LOG_TRACE << msg; + #define CTL_DBG(msg) LOG_DEBUG << msg; #define CTL_INF(msg) LOG_INFO << msg; diff --git a/engine/utils/remote_models_utils.h b/engine/utils/remote_models_utils.h new file mode 100644 index 000000000..7b7906f2c --- /dev/null +++ b/engine/utils/remote_models_utils.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include + +namespace remote_models_utils { +constexpr char chat_completion_request_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_response_template[] = + "{ {% set first = true %} {% for key, value in input_request %} {% if key " + "== \"messages\" or key == \"model\" or key == \"temperature\" or key == " + "\"store\" or key == \"max_tokens\" or key == \"stream\" or key == " + "\"presence_penalty\" or key == \"metadata\" or key == " + "\"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or " + "key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" " + "or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key " + "== \"response_format\" or key == \"service_tier\" or key == \"seed\" or " + "key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key " + "== \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% " + "endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% " + "endif %} {% endfor %} }"; + +constexpr char chat_completion_url[] = + "https://api.openai.com/v1/chat/completions"; + +inline Json::Value yamlToJson(const YAML::Node& node) { + Json::Value result; + + switch (node.Type()) { + case YAML::NodeType::Null: + return Json::Value(); + case YAML::NodeType::Scalar: { + // For scalar types, we'll first try to parse as string + std::string str_val = node.as(); + + // Try to parse as boolean + if (str_val == "true" || str_val == "True" || str_val == "TRUE") + return Json::Value(true); + if (str_val == "false" || str_val == "False" || str_val == "FALSE") + return Json::Value(false); + + // Try to parse as number + try { + // Check if it's an integer + size_t pos; + long long int_val = std::stoll(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(static_cast(int_val)); + } + + // Check if it's a float + double float_val = std::stod(str_val, &pos); + if (pos == str_val.length()) { + return Json::Value(float_val); + } + } catch (...) { + // If parsing as number fails, use as string + } + + // Default to string if no other type matches + return Json::Value(str_val); + } + case YAML::NodeType::Sequence: { + result = Json::Value(Json::arrayValue); + for (const auto& elem : node) { + result.append(yamlToJson(elem)); + } + return result; + } + case YAML::NodeType::Map: { + result = Json::Value(Json::objectValue); + for (const auto& it : node) { + std::string key = it.first.as(); + result[key] = yamlToJson(it.second); + } + return result; + } + default: + return Json::Value(); + } +} + +inline YAML::Node jsonToYaml(const Json::Value& json) { + YAML::Node result; + + switch (json.type()) { + case Json::nullValue: + result = YAML::Node(YAML::NodeType::Null); + break; + case Json::intValue: + result = json.asInt64(); + break; + case Json::uintValue: + result = json.asUInt64(); + break; + case Json::realValue: + result = json.asDouble(); + break; + case Json::stringValue: + result = json.asString(); + break; + case Json::booleanValue: + result = json.asBool(); + break; + case Json::arrayValue: + result = YAML::Node(YAML::NodeType::Sequence); + for (const auto& elem : json) + result.push_back(jsonToYaml(elem)); + break; + case Json::objectValue: + result = YAML::Node(YAML::NodeType::Map); + for (const auto& key : json.getMemberNames()) + result[key] = jsonToYaml(json[key]); + break; + } + return result; +} + +} // namespace remote_models_utils \ No newline at end of file diff --git a/engine/utils/result.hpp b/engine/utils/result.hpp index 96243f72e..7f7356b84 100644 --- a/engine/utils/result.hpp +++ b/engine/utils/result.hpp @@ -34,7 +34,6 @@ #include // std::size_t #include // std::enable_if, std::is_constructible, etc -#include // placement-new #include // std::address_of #include // std::reference_wrapper, std::invoke #include // std::in_place_t, std::forward diff --git a/engine/vcpkg.json b/engine/vcpkg.json index 36fa322a3..962d06ffd 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -13,6 +13,7 @@ "sqlitecpp", "trantor", "indicators", + "inja", "lfreist-hwinfo" ] } From 4d2d23615d7e424c61b9087d381ffdefa3794f83 Mon Sep 17 00:00:00 2001 From: NamH Date: Thu, 5 Dec 2024 10:53:59 +0700 Subject: [PATCH 12/44] Merge pull request #1767 from janhq/j/add-thread feat: add thread --- ...e_response.h => delete_success_response.h} | 2 +- engine/common/message.h | 14 +- engine/common/message_attachment.h | 4 +- engine/common/message_attachment_factory.h | 6 +- engine/common/message_content.h | 4 +- engine/common/message_content_factory.h | 4 +- engine/common/message_content_image_file.h | 4 +- engine/common/message_content_image_url.h | 4 +- engine/common/message_content_refusal.h | 4 +- engine/common/message_content_text.h | 4 +- engine/common/message_incomplete_detail.h | 4 +- engine/common/message_role.h | 4 +- engine/common/message_status.h | 4 +- engine/common/repository/message_repository.h | 19 +- engine/common/repository/thread_repository.h | 25 ++ engine/common/thread.h | 142 +++++++++++ engine/common/thread_tool_resources.h | 50 ++++ engine/controllers/messages.cc | 22 +- engine/controllers/threads.cc | 220 ++++++++++++++++++ engine/controllers/threads.h | 57 +++++ engine/main.cc | 41 ++-- engine/repositories/message_fs_repository.cc | 143 +++++++----- engine/repositories/message_fs_repository.h | 46 +++- engine/repositories/thread_fs_repository.cc | 166 +++++++++++++ engine/repositories/thread_fs_repository.h | 62 +++++ engine/services/message_service.cc | 62 +++-- engine/services/message_service.h | 23 +- engine/services/thread_service.cc | 83 +++++++ engine/services/thread_service.h | 35 +++ 29 files changed, 1070 insertions(+), 188 deletions(-) rename engine/common/api-dto/{messages/delete_message_response.h => delete_success_response.h} (87%) create mode 100644 engine/common/repository/thread_repository.h create mode 100644 engine/common/thread.h create mode 100644 engine/common/thread_tool_resources.h create mode 100644 engine/controllers/threads.cc create mode 100644 engine/controllers/threads.h create mode 100644 engine/repositories/thread_fs_repository.cc create mode 100644 engine/repositories/thread_fs_repository.h create mode 100644 engine/services/thread_service.cc create mode 100644 engine/services/thread_service.h diff --git a/engine/common/api-dto/messages/delete_message_response.h b/engine/common/api-dto/delete_success_response.h similarity index 87% rename from engine/common/api-dto/messages/delete_message_response.h rename to engine/common/api-dto/delete_success_response.h index 79447c93a..ebb8f36f0 100644 --- a/engine/common/api-dto/messages/delete_message_response.h +++ b/engine/common/api-dto/delete_success_response.h @@ -3,7 +3,7 @@ #include "common/json_serializable.h" namespace api_response { -struct DeleteMessageResponse : JsonSerializable { +struct DeleteSuccessResponse : JsonSerializable { std::string id; std::string object; bool deleted; diff --git a/engine/common/message.h b/engine/common/message.h index e5685f3bb..909a843ee 100644 --- a/engine/common/message.h +++ b/engine/common/message.h @@ -17,19 +17,17 @@ #include "utils/logging_utils.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { // Represents a message within a thread. struct Message : JsonSerializable { Message() = default; - - Message(Message&&) = default; - - Message& operator=(Message&&) = default; - + // Delete copy operations Message(const Message&) = delete; - Message& operator=(const Message&) = delete; + // Allow move operations + Message(Message&&) = default; + Message& operator=(Message&&) = default; // The identifier, which can be referenced in API endpoints. std::string id; @@ -210,4 +208,4 @@ struct Message : JsonSerializable { } } }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_attachment.h b/engine/common/message_attachment.h index ea809990e..767ec9bea 100644 --- a/engine/common/message_attachment.h +++ b/engine/common/message_attachment.h @@ -3,7 +3,7 @@ #include #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { // The tools to add this file to. struct Tool { @@ -47,4 +47,4 @@ struct Attachment : JsonSerializable { } } }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_attachment_factory.h b/engine/common/message_attachment_factory.h index d9f1b8d2e..ce4eef60b 100644 --- a/engine/common/message_attachment_factory.h +++ b/engine/common/message_attachment_factory.h @@ -1,8 +1,10 @@ +#pragma once + #include #include "common/message_attachment.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { inline cpp::result ParseAttachment( Json::Value&& json) { if (json.empty()) { @@ -45,4 +47,4 @@ ParseAttachments(Json::Value&& json) { return attachments; } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_content.h b/engine/common/message_content.h index 6e76b01a8..a86dc58ed 100644 --- a/engine/common/message_content.h +++ b/engine/common/message_content.h @@ -3,7 +3,7 @@ #include #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { struct Content : JsonSerializable { std::string type; @@ -20,4 +20,4 @@ struct Content : JsonSerializable { virtual ~Content() = default; }; -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_content_factory.h b/engine/common/message_content_factory.h index 854f6efd8..6f8fcb4fe 100644 --- a/engine/common/message_content_factory.h +++ b/engine/common/message_content_factory.h @@ -8,7 +8,7 @@ #include "utils/logging_utils.h" #include "utils/result.hpp" -namespace ThreadMessage { +namespace OpenAi { inline cpp::result, std::string> ParseContent( Json::Value&& json) { if (json.empty()) { @@ -74,4 +74,4 @@ ParseContents(Json::Value&& json) { } return contents; } -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_image_file.h b/engine/common/message_content_image_file.h index 1807dec1e..c3ec57853 100644 --- a/engine/common/message_content_image_file.h +++ b/engine/common/message_content_image_file.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { struct ImageFile { // The File ID of the image in the message content. Set purpose="vision" when uploading the File if you need to later display the file content. std::string file_id; @@ -66,4 +66,4 @@ struct ImageFileContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h index eae6a7aa6..b86544e38 100644 --- a/engine/common/message_content_image_url.h +++ b/engine/common/message_content_image_url.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { struct ImageUrl { // The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp. @@ -68,4 +68,4 @@ struct ImageUrlContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_refusal.h b/engine/common/message_content_refusal.h index 8353c3a85..c2537ccbf 100644 --- a/engine/common/message_content_refusal.h +++ b/engine/common/message_content_refusal.h @@ -2,7 +2,7 @@ #include "common/message_content.h" -namespace ThreadMessage { +namespace OpenAi { // The refusal content generated by the assistant. struct Refusal : Content { @@ -43,4 +43,4 @@ struct Refusal : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h index 124d4a878..ea6aab1ab 100644 --- a/engine/common/message_content_text.h +++ b/engine/common/message_content_text.h @@ -3,7 +3,7 @@ #include "common/message_content.h" #include "utils/logging_utils.h" -namespace ThreadMessage { +namespace OpenAi { struct Annotation : JsonSerializable { std::string type; @@ -239,4 +239,4 @@ struct TextContent : Content { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_incomplete_detail.h b/engine/common/message_incomplete_detail.h index 25e9c1169..98e6ff56b 100644 --- a/engine/common/message_incomplete_detail.h +++ b/engine/common/message_incomplete_detail.h @@ -2,7 +2,7 @@ #include "common/json_serializable.h" -namespace ThreadMessage { +namespace OpenAi { // On an incomplete message, details about why the message is incomplete. struct IncompleteDetail : JsonSerializable { @@ -29,4 +29,4 @@ struct IncompleteDetail : JsonSerializable { } } }; -} // namespace ThreadMessage +} // namespace OpenAi diff --git a/engine/common/message_role.h b/engine/common/message_role.h index 9d428eddc..504e2e5f6 100644 --- a/engine/common/message_role.h +++ b/engine/common/message_role.h @@ -3,7 +3,7 @@ #include #include "utils/string_utils.h" -namespace ThreadMessage { +namespace OpenAi { // The entity that produced the message. One of user or assistant. enum class Role { USER, ASSISTANT }; @@ -27,4 +27,4 @@ inline Role RoleFromString(const std::string& input) { return Role::ASSISTANT; } } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/message_status.h b/engine/common/message_status.h index e8844ee13..453617363 100644 --- a/engine/common/message_status.h +++ b/engine/common/message_status.h @@ -3,7 +3,7 @@ #include #include "utils/string_utils.h" -namespace ThreadMessage { +namespace OpenAi { // The status of the message, which can be either in_progress, incomplete, or completed. enum class Status { IN_PROGRESS, INCOMPLETE, COMPLETED }; @@ -31,4 +31,4 @@ inline Status StatusFromString(const std::string& input) { return Status::COMPLETED; } } -}; // namespace ThreadMessage +}; // namespace OpenAi diff --git a/engine/common/repository/message_repository.h b/engine/common/repository/message_repository.h index cffc73675..a8a971fd8 100644 --- a/engine/common/repository/message_repository.h +++ b/engine/common/repository/message_repository.h @@ -6,22 +6,25 @@ class MessageRepository { public: virtual cpp::result CreateMessage( - ThreadMessage::Message& message) = 0; + OpenAi::Message& message) = 0; - virtual cpp::result, std::string> - ListMessages(const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", - const std::string& run_id = "") const = 0; + virtual cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const = 0; - virtual cpp::result RetrieveMessage( + virtual cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const = 0; virtual cpp::result ModifyMessage( - ThreadMessage::Message& message) = 0; + OpenAi::Message& message) = 0; virtual cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) = 0; + virtual cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages) = 0; + virtual ~MessageRepository() = default; }; diff --git a/engine/common/repository/thread_repository.h b/engine/common/repository/thread_repository.h new file mode 100644 index 000000000..c7bb9e7cf --- /dev/null +++ b/engine/common/repository/thread_repository.h @@ -0,0 +1,25 @@ +#pragma once + +#include "common/thread.h" +#include "utils/result.hpp" + +class ThreadRepository { + public: + virtual cpp::result CreateThread( + OpenAi::Thread& thread) = 0; + + virtual cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string&, + const std::string& before) const = 0; + + virtual cpp::result RetrieveThread( + const std::string& thread_id) const = 0; + + virtual cpp::result ModifyThread( + OpenAi::Thread& thread) = 0; + + virtual cpp::result DeleteThread( + const std::string& thread_id) = 0; + + virtual ~ThreadRepository() = default; +}; diff --git a/engine/common/thread.h b/engine/common/thread.h new file mode 100644 index 000000000..20672ff72 --- /dev/null +++ b/engine/common/thread.h @@ -0,0 +1,142 @@ +#pragma once + +#include +#include +#include +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "json_serializable.h" +#include "utils/logging_utils.h" + +namespace OpenAi { + +/** + * Represents a thread that contains messages. + */ +struct Thread : JsonSerializable { + /** + * The identifier, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always thread. + */ + std::string object = "thread"; + + /** + * The Unix timestamp (in seconds) for when the thread was created. + */ + uint64_t created_at; + + /** + * A set of resources that are made available to the assistant's + * tools in this thread. The resources are specific to the type + * of tool. For example, the code_interpreter tool requires a list of + * file IDs, while the file_search tool requires a list of vector store IDs. + */ + std::unique_ptr tool_resources; + + /** + * Set of 16 key-value pairs that can be attached to an object. + * This can be useful for storing additional information about the object + * in a structured format. + * + * Keys can be a maximum of 64 characters long and values can be a maximum + * of 512 characters long. + */ + Cortex::VariantMap metadata; + + static cpp::result FromJson(const Json::Value& json) { + Thread thread; + + thread.id = json["id"].asString(); + thread.object = "thread"; + thread.created_at = json["created_at"].asUInt(); + if (thread.created_at == 0 && json["created"].asUInt64() != 0) { + thread.created_at = json["created"].asUInt64() / 1000; + } + + if (json.isMember("tool_resources") && !json["tool_resources"].isNull()) { + const auto& tool_json = json["tool_resources"]; + + if (tool_json.isMember("code_interpreter")) { + auto code_interpreter = std::make_unique(); + const auto& file_ids = tool_json["code_interpreter"]["file_ids"]; + if (file_ids.isArray()) { + for (const auto& file_id : file_ids) { + code_interpreter->file_ids.push_back(file_id.asString()); + } + } + thread.tool_resources = std::move(code_interpreter); + } else if (tool_json.isMember("file_search")) { + auto file_search = std::make_unique(); + const auto& store_ids = tool_json["file_search"]["vector_store_ids"]; + if (store_ids.isArray()) { + for (const auto& store_id : store_ids) { + file_search->vector_store_ids.push_back(store_id.asString()); + } + } + thread.tool_resources = std::move(file_search); + } + } + + if (json["metadata"].isObject() && !json["metadata"].empty()) { + auto res = Cortex::ConvertJsonValueToMap(json["metadata"]); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + thread.metadata = res.value(); + } + } + + return thread; + } + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["object"] = object; + json["created_at"] = created_at; + + if (tool_resources) { + auto tool_result = tool_resources->ToJson(); + if (tool_result.has_error()) { + return cpp::fail("Failed to serialize tool_resources: " + + tool_result.error()); + } + + Json::Value tool_json; + if (auto code_interpreter = + dynamic_cast(tool_resources.get())) { + tool_json["code_interpreter"] = tool_result.value(); + } else if (auto file_search = + dynamic_cast(tool_resources.get())) { + tool_json["file_search"] = tool_result.value(); + } + json["tool_resources"] = tool_json; + } + + Json::Value metadata_json{Json::objectValue}; + for (const auto& [key, value] : metadata) { + if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else if (std::holds_alternative(value)) { + metadata_json[key] = std::get(value); + } else { + metadata_json[key] = std::get(value); + } + } + json["metadata"] = metadata_json; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace OpenAi diff --git a/engine/common/thread_tool_resources.h b/engine/common/thread_tool_resources.h new file mode 100644 index 000000000..3c22a4480 --- /dev/null +++ b/engine/common/thread_tool_resources.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include "common/json_serializable.h" + +namespace OpenAi { + +struct ThreadToolResources : JsonSerializable { + ~ThreadToolResources() = default; + + virtual cpp::result ToJson() override = 0; +}; + +struct ThreadCodeInterpreter : ThreadToolResources { + std::vector file_ids; + + cpp::result ToJson() override { + try { + Json::Value json; + Json::Value file_ids_json{Json::arrayValue}; + for (auto& file_id : file_ids) { + file_ids_json.append(file_id); + } + json["file_ids"] = file_ids_json; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; + +struct ThreadFileSearch : ThreadToolResources { + std::vector vector_store_ids; + + cpp::result ToJson() override { + try { + Json::Value json; + Json::Value vector_store_ids_json{Json::arrayValue}; + for (auto& vector_store_id : vector_store_ids) { + vector_store_ids_json.append(vector_store_id); + } + json["vector_store_ids"] = vector_store_ids_json; + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } +}; +} // namespace OpenAi diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc index 55d9f6370..ef82b3412 100644 --- a/engine/controllers/messages.cc +++ b/engine/controllers/messages.cc @@ -1,5 +1,5 @@ #include "messages.h" -#include "common/api-dto/messages/delete_message_response.h" +#include "common/api-dto/delete_success_response.h" #include "common/message_content.h" #include "common/message_role.h" #include "common/variant_map.h" @@ -75,16 +75,13 @@ void Messages::CreateMessage( return; } - ThreadMessage::Role role = role_str == "user" - ? ThreadMessage::Role::USER - : ThreadMessage::Role::ASSISTANT; + auto role = role_str == "user" ? OpenAi::Role::USER : OpenAi::Role::ASSISTANT; - std::variant>> + std::variant>> content; if (json_body->get("content", "").isArray()) { - auto result = ThreadMessage::ParseContents(json_body->get("content", "")); + auto result = OpenAi::ParseContents(json_body->get("content", "")); if (result.has_error()) { Json::Value ret; ret["message"] = "Failed to parse content array: " + result.error(); @@ -128,12 +125,11 @@ void Messages::CreateMessage( } // attachments - std::optional> attachments = - std::nullopt; + std::optional> attachments = std::nullopt; if (json_body->get("attachments", "").isArray()) { - attachments = ThreadMessage::ParseAttachments( - std::move(json_body->get("attachments", ""))) - .value(); + attachments = + OpenAi::ParseAttachments(std::move(json_body->get("attachments", ""))) + .value(); } std::optional metadata = std::nullopt; @@ -287,7 +283,7 @@ void Messages::DeleteMessage( return; } - api_response::DeleteMessageResponse response; + api_response::DeleteSuccessResponse response; response.id = message_id; response.object = "thread.message.deleted"; response.deleted = true; diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc new file mode 100644 index 000000000..a11c1071b --- /dev/null +++ b/engine/controllers/threads.cc @@ -0,0 +1,220 @@ +#include "threads.h" +#include "common/api-dto/delete_success_response.h" +#include "common/variant_map.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" + +void Threads::ListThreads( + const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, std::optional order, + std::optional after, std::optional before) const { + CTL_INF("ListThreads"); + auto res = + thread_service_->ListThreads(limit.value_or(20), order.value_or("desc"), + after.value_or(""), before.value_or("")); + + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + Json::Value msg_arr(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + msg_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = msg_arr; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Threads::CreateThread( + const HttpRequestPtr& req, + std::function&& callback) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // TODO: namh handle tool_resources + // TODO: namh handle messages + + std::optional metadata = std::nullopt; + if (json_body->get("metadata", "").isObject()) { + auto res = Cortex::ConvertJsonValueToMap(json_body->get("metadata", "")); + if (res.has_error()) { + CTL_WRN("Failed to convert metadata to map: " + res.error()); + } else { + metadata = res.value(); + } + } + + auto res = thread_service_->CreateThread(nullptr, metadata); + + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto init_msg_res = + message_service_->InitializeMessages(res->id, std::nullopt); + + if (res.has_error()) { + CTL_ERR("Failed to convert message to json: " + res.error()); + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::RetrieveThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) const { + auto res = thread_service_->RetrieveThread(thread_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto thread_to_json = res->ToJson(); + if (thread_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + thread_to_json.error()); + Json::Value ret; + ret["message"] = thread_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::ModifyThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + std::optional metadata = std::nullopt; + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + + if (!metadata.has_value()) { + Json::Value ret; + ret["message"] = "Metadata is mandatory"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // TODO: namh handle tools + auto res = + thread_service_->ModifyThread(thread_id, nullptr, metadata.value()); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto message_to_json = res->ToJson(); + if (message_to_json.has_error()) { + CTL_ERR("Failed to convert message to json: " + message_to_json.error()); + Json::Value ret; + ret["message"] = message_to_json.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Threads::DeleteThread( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) { + auto res = thread_service_->DeleteThread(thread_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = thread_id; + response.object = "thread.deleted"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/threads.h b/engine/controllers/threads.h new file mode 100644 index 000000000..92c509525 --- /dev/null +++ b/engine/controllers/threads.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include "services/message_service.h" +#include "services/thread_service.h" + +using namespace drogon; + +class Threads : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Threads::CreateThread, "/v1/threads", Options, Post); + + ADD_METHOD_TO(Threads::ListThreads, + "/v1/" + "threads?limit={limit}&order={order}&after={after}&before={" + "before}", + Get); + + ADD_METHOD_TO(Threads::RetrieveThread, "/v1/threads/{thread_id}", Get); + ADD_METHOD_TO(Threads::ModifyThread, "/v1/threads/{thread_id}", Options, + Post); + ADD_METHOD_TO(Threads::DeleteThread, "/v1/threads/{thread_id}", Options, + Delete); + METHOD_LIST_END + + explicit Threads(std::shared_ptr thread_srv, + std::shared_ptr msg_srv) + : thread_service_{thread_srv}, message_service_{msg_srv} {} + + void CreateThread(const HttpRequestPtr& req, + std::function&& callback); + + void ListThreads(const HttpRequestPtr& req, + std::function&& callback, + std::optional limit, + std::optional order, + std::optional after, + std::optional before) const; + + void RetrieveThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id) const; + + void ModifyThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + void DeleteThread(const HttpRequestPtr& req, + std::function&& callback, + const std::string& thread_id); + + private: + std::shared_ptr thread_service_; + std::shared_ptr message_service_; +}; diff --git a/engine/main.cc b/engine/main.cc index d076c02bd..0177a2143 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,7 +1,6 @@ #include #include #include -#include "common/repository/message_repository.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" @@ -10,24 +9,23 @@ #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" -#include "cortex-common/cortexpythoni.h" +#include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" #include "repositories/message_fs_repository.h" +#include "repositories/thread_fs_repository.h" #include "services/config_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" +#include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" -#include "utils/dylib.h" #include "utils/event_processor.h" #include "utils/file_logger.h" #include "utils/file_manager_utils.h" -#include "utils/hardware/gguf/gguf_file_estimate.h" #include "utils/logging_utils.h" #include "utils/system_info_utils.h" -#include "utils/widechar_conv.h" #if defined(__APPLE__) && defined(__MACH__) #include // for dirname() @@ -40,6 +38,7 @@ #include // for readlink() #elif defined(_WIN32) #include +#include "utils/widechar_conv.h" #undef max #else #error "Unsupported platform!" @@ -120,9 +119,14 @@ void RunServer(std::optional port, bool ignore_cout) { auto event_queue_ptr = std::make_shared(); cortex::event::EventProcessor event_processor(event_queue_ptr); - std::shared_ptr msg_repo = - std::make_shared(); + auto msg_repo = std::make_shared( + file_manager_utils::GetCortexDataPath()); + auto thread_repo = std::make_shared( + file_manager_utils::GetCortexDataPath()); + + auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); + auto model_dir_path = file_manager_utils::GetModelsContainerPath(); auto config_service = std::make_shared(); auto download_service = @@ -138,6 +142,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); auto model_ctl = std::make_shared(model_service, engine_service); @@ -148,6 +153,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(thread_ctl); drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); @@ -318,27 +324,6 @@ int main(int argc, char* argv[]) { } } - // // Check if this process is for python execution - // if (argc > 1) { - // if (strcmp(argv[1], "--run_python_file") == 0) { - // std::string py_home_path = (argc > 3) ? argv[3] : ""; - // std::unique_ptr dl; - // try { - // std::string abs_path = - // cortex_utils::GetCurrentPath() + kPythonRuntimeLibPath; - // dl = std::make_unique(abs_path, "engine"); - // } catch (const cortex_cpp::dylib::load_error& e) { - // LOG_ERROR << "Could not load engine: " << e.what(); - // return 1; - // } - - // auto func = dl->get_function("get_engine"); - // auto e = func(); - // e->ExecutePythonFile(argv[0], argv[2], py_home_path); - // return 0; - // } - // } - RunServer(server_port, ignore_cout_log); return 0; } diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index 60cc0b5bf..e576a7695 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -1,32 +1,22 @@ #include "message_fs_repository.h" -#include "utils/file_manager_utils.h" +#include +#include #include "utils/result.hpp" -namespace { -constexpr static const std::string_view kMessageFile = "messages.jsonl"; - -inline cpp::result GetMessageFileAbsPath( - const std::string& thread_id) { - auto path = - file_manager_utils::GetThreadsContainerPath() / thread_id / kMessageFile; - if (!std::filesystem::exists(path)) { - return cpp::fail("Message file not exist at path: " + path.string()); - } - return path; +std::filesystem::path MessageFsRepository::GetMessagePath( + const std::string& thread_id) const { + return data_folder_path_ / kThreadContainerFolderName / thread_id / + kMessageFile; } -} // namespace cpp::result MessageFsRepository::CreateMessage( - ThreadMessage::Message& message) { + OpenAi::Message& message) { CTL_INF("CreateMessage for thread " + message.thread_id); - auto path = GetMessageFileAbsPath(message.thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(message.thread_id); - std::ofstream file(path->string(), std::ios::app); + std::ofstream file(path, std::ios::app); if (!file) { - return cpp::fail("Failed to open file for writing: " + path->string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } auto mutex = GrabMutex(message.thread_id); @@ -40,27 +30,24 @@ cpp::result MessageFsRepository::CreateMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } return {}; } -cpp::result, std::string> +cpp::result, std::string> MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& order, const std::string& after, const std::string& before, const std::string& run_id) const { CTL_INF("Listing messages for thread " + thread_id); - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); @@ -68,13 +55,9 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, return ReadMessageFromFile(thread_id); } -cpp::result -MessageFsRepository::RetrieveMessage(const std::string& thread_id, - const std::string& message_id) const { - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } +cpp::result MessageFsRepository::RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const { + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -94,11 +77,8 @@ MessageFsRepository::RetrieveMessage(const std::string& thread_id, } cpp::result MessageFsRepository::ModifyMessage( - ThreadMessage::Message& message) { - auto path = GetMessageFileAbsPath(message.thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + OpenAi::Message& message) { + auto path = GetMessagePath(message.thread_id); auto mutex = GrabMutex(message.thread_id); std::unique_lock lock(*mutex); @@ -108,10 +88,9 @@ cpp::result MessageFsRepository::ModifyMessage( return cpp::fail(messages.error()); } - std::ofstream file(path.value().string(), std::ios::trunc); + std::ofstream file(path, std::ios::trunc); if (!file) { - return cpp::fail("Failed to open file for writing: " + - path.value().string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } bool found = false; @@ -126,11 +105,11 @@ cpp::result MessageFsRepository::ModifyMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } if (!found) { @@ -141,10 +120,7 @@ cpp::result MessageFsRepository::ModifyMessage( cpp::result MessageFsRepository::DeleteMessage( const std::string& thread_id, const std::string& message_id) { - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -153,10 +129,9 @@ cpp::result MessageFsRepository::DeleteMessage( return cpp::fail(messages.error()); } - std::ofstream file(path.value().string(), std::ios::trunc); + std::ofstream file(path, std::ios::trunc); if (!file) { - return cpp::fail("Failed to open file for writing: " + - path.value().string()); + return cpp::fail("Failed to open file for writing: " + path.string()); } bool found = false; @@ -170,11 +145,11 @@ cpp::result MessageFsRepository::DeleteMessage( file.flush(); if (file.fail()) { - return cpp::fail("Failed to write to file: " + path->string()); + return cpp::fail("Failed to write to file: " + path.string()); } file.close(); if (file.fail()) { - return cpp::fail("Failed to close file after writing: " + path->string()); + return cpp::fail("Failed to close file after writing: " + path.string()); } if (!found) { @@ -184,26 +159,22 @@ cpp::result MessageFsRepository::DeleteMessage( return {}; } -cpp::result, std::string> +cpp::result, std::string> MessageFsRepository::ReadMessageFromFile(const std::string& thread_id) const { LOG_TRACE << "Reading messages from file for thread " << thread_id; - auto path = GetMessageFileAbsPath(thread_id); - if (path.has_error()) { - return cpp::fail(path.error()); - } + auto path = GetMessagePath(thread_id); - std::ifstream file(path.value()); + std::ifstream file(path); if (!file) { - return cpp::fail("Failed to open file: " + path->string()); + return cpp::fail("Failed to open file: " + path.string()); } - std::vector messages; + std::vector messages; std::string line; while (std::getline(file, line)) { if (line.empty()) continue; - auto msg_parse_result = - ThreadMessage::Message::FromJsonString(std::move(line)); + auto msg_parse_result = OpenAi::Message::FromJsonString(std::move(line)); if (msg_parse_result.has_error()) { CTL_WRN("Failed to parse message: " + msg_parse_result.error()); continue; @@ -224,3 +195,49 @@ std::shared_mutex* MessageFsRepository::GrabMutex( } return thread_mutex.get(); } + +cpp::result MessageFsRepository::InitializeMessages( + const std::string& thread_id, + std::optional> messages) { + CTL_INF("Initializing messages for thread " + thread_id); + + auto path = GetMessagePath(thread_id); + + if (!std::filesystem::exists(path.parent_path())) { + return cpp::fail( + "Failed to initialize messages, thread is not created yet! Path does " + "not exist: " + + path.parent_path().string()); + } + + auto mutex = GrabMutex(thread_id); + std::unique_lock lock(*mutex); + + std::ofstream file(path, std::ios::trunc); + if (!file) { + return cpp::fail("Failed to create message file: " + path.string()); + } + + if (messages.has_value()) { + for (auto& message : messages.value()) { + auto json_str = message.ToSingleLineJsonString(); + if (json_str.has_error()) { + CTL_WRN("Failed to serialize message: " + json_str.error()); + continue; + } + file << json_str.value(); + } + } + + file.flush(); + if (file.fail()) { + return cpp::fail("Failed to write to file: " + path.string()); + } + + file.close(); + if (file.fail()) { + return cpp::fail("Failed to close file after writing: " + path.string()); + } + + return {}; +} diff --git a/engine/repositories/message_fs_repository.h b/engine/repositories/message_fs_repository.h index d8bcd02a7..2146778bf 100644 --- a/engine/repositories/message_fs_repository.h +++ b/engine/repositories/message_fs_repository.h @@ -1,39 +1,63 @@ #pragma once +#include #include #include #include "common/repository/message_repository.h" class MessageFsRepository : public MessageRepository { + constexpr static auto kMessageFile = "messages.jsonl"; + constexpr static auto kThreadContainerFolderName = "threads"; + public: cpp::result CreateMessage( - ThreadMessage::Message& message) override; + OpenAi::Message& message) override; - cpp::result, std::string> ListMessages( - const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", - const std::string& run_id = "") const override; + cpp::result, std::string> ListMessages( + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const override; - cpp::result RetrieveMessage( + cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const override; cpp::result ModifyMessage( - ThreadMessage::Message& message) override; + OpenAi::Message& message) override; cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id) override; + cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages) override; + + explicit MessageFsRepository(std::filesystem::path data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing MessageFsRepository.."); + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + + if (!std::filesystem::exists(thread_container_path)) { + std::filesystem::create_directories(thread_container_path); + } + } + ~MessageFsRepository() = default; private: - cpp::result, std::string> - ReadMessageFromFile(const std::string& thread_id) const; + cpp::result, std::string> ReadMessageFromFile( + const std::string& thread_id) const; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; + + std::filesystem::path GetMessagePath(const std::string& thread_id) const; std::shared_mutex* GrabMutex(const std::string& thread_id) const; + mutable std::mutex mutex_map_mutex_; mutable std::unordered_map> thread_mutexes_; - mutable std::mutex mutex_map_mutex_; }; diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc new file mode 100644 index 000000000..64dad6ea5 --- /dev/null +++ b/engine/repositories/thread_fs_repository.cc @@ -0,0 +1,166 @@ +#include "thread_fs_repository.h" +#include +#include + +cpp::result, std::string> +ThreadFsRepository::ListThreads(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("ListThreads: limit=" + std::to_string(limit) + ", order=" + order + + ", after=" + after + ", before=" + before); + std::vector threads; + + try { + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + for (const auto& entry : + std::filesystem::directory_iterator(thread_container_path)) { + if (!entry.is_directory()) + continue; + + if (!std::filesystem::exists(entry.path() / kThreadFileName)) + continue; + + auto current_thread_id = entry.path().filename().string(); + CTL_INF("ListThreads: Found thread: " + current_thread_id); + std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); + + auto thread_result = LoadThread(current_thread_id); + if (thread_result.has_value()) { + threads.push_back(std::move(thread_result.value())); + } + + thread_lock.unlock(); + } + + return threads; + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to list threads: ") + e.what()); + } +} + +std::shared_mutex& ThreadFsRepository::GrabThreadMutex( + const std::string& thread_id) const { + std::shared_lock map_lock(map_mutex_); + auto it = thread_mutexes_.find(thread_id); + if (it != thread_mutexes_.end()) { + return *it->second; + } + + map_lock.unlock(); + std::unique_lock map_write_lock(map_mutex_); + return *thread_mutexes_ + .try_emplace(thread_id, std::make_unique()) + .first->second; +} + +std::filesystem::path ThreadFsRepository::GetThreadPath( + const std::string& thread_id) const { + return data_folder_path_ / kThreadContainerFolderName / thread_id; +} + +cpp::result ThreadFsRepository::LoadThread( + const std::string& thread_id) const { + auto path = GetThreadPath(thread_id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + return OpenAi::Thread::FromJson(root); + } catch (const std::exception& e) { + return cpp::fail("Failed to load thread: " + std::string(e.what())); + } +} + +cpp::result ThreadFsRepository::CreateThread( + OpenAi::Thread& thread) { + CTL_INF("CreateThread: " + thread.id); + std::unique_lock lock(GrabThreadMutex(thread.id)); + auto thread_path = GetThreadPath(thread.id); + + if (std::filesystem::exists(thread_path)) { + return cpp::fail("Thread exists: " + thread.id); + } + + std::filesystem::create_directories(thread_path); + auto thread_file_path = thread_path / kThreadFileName; + std::ofstream thread_file(thread_file_path); + thread_file.close(); + + return SaveThread(thread); +} + +cpp::result ThreadFsRepository::SaveThread( + OpenAi::Thread& thread) { + auto path = GetThreadPath(thread.id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + std::ofstream file(path); + try { + if (!file) { + return cpp::fail("Failed to open file: " + path.string()); + } + file << thread.ToJson()->toStyledString(); + file.flush(); + file.close(); + return {}; + } catch (const std::exception& e) { + file.close(); + return cpp::fail("Failed to save thread: " + std::string(e.what())); + } +} + +cpp::result ThreadFsRepository::RetrieveThread( + const std::string& thread_id) const { + std::shared_lock lock(GrabThreadMutex(thread_id)); + return LoadThread(thread_id); +} + +cpp::result ThreadFsRepository::ModifyThread( + OpenAi::Thread& thread) { + std::unique_lock lock(GrabThreadMutex(thread.id)); + auto thread_path = GetThreadPath(thread.id); + + if (!std::filesystem::exists(thread_path)) { + return cpp::fail("Thread doesn't exist: " + thread.id); + } + + return SaveThread(thread); +} + +cpp::result ThreadFsRepository::DeleteThread( + const std::string& thread_id) { + CTL_INF("DeleteThread: " + thread_id); + + { + std::unique_lock thread_lock(GrabThreadMutex(thread_id)); + auto path = GetThreadPath(thread_id); + if (!std::filesystem::exists(path)) { + return cpp::fail("Thread doesn't exist: " + thread_id); + } + try { + std::filesystem::remove_all(path); + } catch (const std::exception& e) { + return cpp::fail(std::string("Failed to delete thread: ") + e.what()); + } + } + + std::unique_lock map_lock(map_mutex_); + thread_mutexes_.erase(thread_id); + return {}; +} diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h new file mode 100644 index 000000000..d834b8e44 --- /dev/null +++ b/engine/repositories/thread_fs_repository.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include "common/repository/thread_repository.h" +#include "common/thread.h" +#include "utils/logging_utils.h" + +class ThreadFsRepository : public ThreadRepository { + private: + constexpr static auto kThreadFileName = "thread.json"; + constexpr static auto kThreadContainerFolderName = "threads"; + + mutable std::shared_mutex map_mutex_; + mutable std::unordered_map> + thread_mutexes_; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; + + std::shared_mutex& GrabThreadMutex(const std::string& thread_id) const; + + std::filesystem::path GetThreadPath(const std::string& thread_id) const; + + /** + * Read the thread file and parse to Thread from the file system. + */ + cpp::result LoadThread( + const std::string& thread_id) const; + + cpp::result SaveThread(OpenAi::Thread& thread); + + public: + explicit ThreadFsRepository(const std::filesystem::path& data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing ThreadFsRepository.."); + auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + + if (!std::filesystem::exists(thread_container_path)) { + std::filesystem::create_directories(thread_container_path); + } + } + + cpp::result CreateThread(OpenAi::Thread& thread) override; + + cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const override; + + cpp::result RetrieveThread( + const std::string& thread_id) const override; + + cpp::result ModifyThread(OpenAi::Thread& thread) override; + + cpp::result DeleteThread( + const std::string& thread_id) override; + + ~ThreadFsRepository() = default; +}; diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index 31ae38420..dfad74236 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -3,40 +3,39 @@ #include "utils/result.hpp" #include "utils/ulid/ulid.hh" -cpp::result MessageService::CreateMessage( - const std::string& thread_id, const ThreadMessage::Role& role, - std::variant>>&& +cpp::result MessageService::CreateMessage( + const std::string& thread_id, const OpenAi::Role& role, + std::variant>>&& content, - std::optional> attachments, + std::optional> attachments, std::optional metadata) { LOG_TRACE << "CreateMessage for thread " << thread_id; - auto now = std::chrono::system_clock::now(); + auto seconds_since_epoch = - std::chrono::duration_cast(now.time_since_epoch()) + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) .count(); - std::vector> content_list{}; + std::vector> content_list{}; + // if content is string if (std::holds_alternative(content)) { - auto text_content = std::make_unique(); + auto text_content = std::make_unique(); text_content->text.value = std::get(content); content_list.push_back(std::move(text_content)); } else { content_list = std::move( - std::get>>( - content)); + std::get>>(content)); } - ulid::ULID ulid = ulid::Create(seconds_since_epoch, []() { return 4; }); - std::string str = ulid::Marshal(ulid); - LOG_TRACE << "Generated message ID: " << str; + auto ulid = ulid::CreateNowRand(); + auto msg_id = ulid::Marshal(ulid); - ThreadMessage::Message msg; - msg.id = str; + OpenAi::Message msg; + msg.id = msg_id; msg.object = "thread.message"; msg.created_at = 0; msg.thread_id = thread_id; - msg.status = ThreadMessage::Status::COMPLETED; + msg.status = OpenAi::Status::COMPLETED; msg.completed_at = seconds_since_epoch; msg.incomplete_at = std::nullopt; msg.incomplete_details = std::nullopt; @@ -54,23 +53,23 @@ cpp::result MessageService::CreateMessage( } } -cpp::result, std::string> +cpp::result, std::string> MessageService::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& order, const std::string& after, const std::string& before, const std::string& run_id) const { CTL_INF("ListMessages for thread " + thread_id); - return message_repository_->ListMessages(thread_id); + return message_repository_->ListMessages(thread_id, limit, order, after, + before, run_id); } -cpp::result -MessageService::RetrieveMessage(const std::string& thread_id, - const std::string& message_id) const { +cpp::result MessageService::RetrieveMessage( + const std::string& thread_id, const std::string& message_id) const { CTL_INF("RetrieveMessage for thread " + thread_id); return message_repository_->RetrieveMessage(thread_id, message_id); } -cpp::result MessageService::ModifyMessage( +cpp::result MessageService::ModifyMessage( const std::string& thread_id, const std::string& message_id, std::optional metadata) { LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " @@ -103,3 +102,20 @@ cpp::result MessageService::DeleteMessage( return message_id; } } + +cpp::result MessageService::InitializeMessages( + const std::string& thread_id, + std::optional> messages) { + CTL_INF("InitializeMessages for thread " + thread_id); + + if (messages.has_value()) { + CTL_INF("Prepopulated messages length: " + + std::to_string(messages->size())); + } else { + + CTL_INF("Prepopulated with empty messages"); + } + + return message_repository_->InitializeMessages(thread_id, + std::move(messages)); +} diff --git a/engine/services/message_service.h b/engine/services/message_service.h index e62970b54..6c4880f32 100644 --- a/engine/services/message_service.h +++ b/engine/services/message_service.h @@ -9,27 +9,28 @@ class MessageService { explicit MessageService(std::shared_ptr message_repository) : message_repository_{message_repository} {} - cpp::result CreateMessage( - const std::string& thread_id, const ThreadMessage::Role& role, - std::variant>>&& + cpp::result CreateMessage( + const std::string& thread_id, const OpenAi::Role& role, + std::variant>>&& content, - std::optional> attachments, + std::optional> attachments, std::optional metadata); - cpp::result, std::string> ListMessages( + cpp::result InitializeMessages( + const std::string& thread_id, + std::optional> messages); + + cpp::result, std::string> ListMessages( const std::string& thread_id, uint8_t limit = 20, const std::string& order = "desc", const std::string& after = "", const std::string& before = "", const std::string& run_id = "") const; - cpp::result RetrieveMessage( + cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const; - cpp::result ModifyMessage( + cpp::result ModifyMessage( const std::string& thread_id, const std::string& message_id, - std::optional>> - metadata); + std::optional metadata); cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id); diff --git a/engine/services/thread_service.cc b/engine/services/thread_service.cc new file mode 100644 index 000000000..25784c2ee --- /dev/null +++ b/engine/services/thread_service.cc @@ -0,0 +1,83 @@ +#include "thread_service.h" +#include "utils/logging_utils.h" +#include "utils/ulid/ulid.hh" + +cpp::result ThreadService::CreateThread( + std::unique_ptr tool_resources, + std::optional metadata) { + LOG_TRACE << "CreateThread"; + + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + auto ulid = ulid::CreateNowRand(); + auto thread_id = ulid::Marshal(ulid); + + OpenAi::Thread thread; + thread.id = thread_id; + thread.object = "thread"; + thread.created_at = seconds_since_epoch; + + if (tool_resources) { + thread.tool_resources = std::move(tool_resources); + } + thread.metadata = metadata.value_or(Cortex::VariantMap{}); + + if (auto res = thread_repository_->CreateThread(thread); res.has_error()) { + return cpp::fail("Failed to create message: " + res.error()); + } + + return thread; +} + +cpp::result, std::string> +ThreadService::ListThreads(uint8_t limit, const std::string& order, + const std::string& after, + const std::string& before) const { + CTL_INF("ListThreads"); + return thread_repository_->ListThreads(limit, order, after, before); +} + +cpp::result ThreadService::RetrieveThread( + const std::string& thread_id) const { + CTL_INF("RetrieveThread: " + thread_id); + return thread_repository_->RetrieveThread(thread_id); +} + +cpp::result ThreadService::ModifyThread( + const std::string& thread_id, + std::unique_ptr tool_resources, + std::optional metadata) { + LOG_TRACE << "ModifyThread " << thread_id; + auto retrieve_res = RetrieveThread(thread_id); + if (retrieve_res.has_error()) { + return cpp::fail("Failed to retrieve thread: " + retrieve_res.error()); + } + + if (tool_resources) { + retrieve_res->tool_resources = std::move(tool_resources); + } + retrieve_res->metadata = std::move(metadata.value()); + + auto res = thread_repository_->ModifyThread(retrieve_res.value()); + if (res.has_error()) { + CTL_ERR("Failed to modify thread: " + res.error()); + return cpp::fail("Failed to modify thread: " + res.error()); + } else { + return RetrieveThread(thread_id); + } +} + +cpp::result ThreadService::DeleteThread( + const std::string& thread_id) { + LOG_TRACE << "DeleteThread: " + thread_id; + auto res = thread_repository_->DeleteThread(thread_id); + if (res.has_error()) { + LOG_ERROR << "Failed to delete thread: " + res.error(); + return cpp::fail("Failed to delete thread: " + res.error()); + } else { + return thread_id; + } +} diff --git a/engine/services/thread_service.h b/engine/services/thread_service.h new file mode 100644 index 000000000..966b0ab01 --- /dev/null +++ b/engine/services/thread_service.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include "common/repository/thread_repository.h" +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +class ThreadService { + public: + explicit ThreadService(std::shared_ptr thread_repository) + : thread_repository_{thread_repository} {} + + cpp::result CreateThread( + std::unique_ptr tool_resources, + std::optional metadata); + + cpp::result, std::string> ListThreads( + uint8_t limit, const std::string& order, const std::string& after, + const std::string& before) const; + + cpp::result RetrieveThread( + const std::string& thread_id) const; + + cpp::result ModifyThread( + const std::string& thread_id, + std::unique_ptr tool_resources, + std::optional metadata); + + cpp::result DeleteThread( + const std::string& thread_id); + + private: + std::shared_ptr thread_repository_; +}; From a49054c32682d72de57b346fe478b5c3751667ba Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 5 Dec 2024 15:19:19 +0700 Subject: [PATCH 13/44] fix: deadlock when unload engine (#1769) * fix: deadlock when unload engine * fix: add lock --- engine/services/engine_service.cc | 75 +++++++++++++------------------ engine/services/engine_service.h | 4 +- 2 files changed, 33 insertions(+), 46 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index c91fd0dd0..fe5317c7d 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -656,7 +656,6 @@ EngineService::GetInstalledEngineVariants(const std::string& engine) const { } bool EngineService::IsEngineLoaded(const std::string& engine) { - std::lock_guard lock(engines_mutex_); auto ne = NormalizeEngine(engine); return engines_.find(ne) != engines_.end(); } @@ -675,7 +674,7 @@ cpp::result EngineService::GetLoadedEngine( cpp::result EngineService::LoadEngine( const std::string& engine_name) { auto ne = NormalizeEngine(engine_name); - + std::lock_guard lock(engines_mutex_); if (IsEngineLoaded(ne)) { CTL_INF("Engine " << ne << " is already loaded"); return {}; @@ -779,7 +778,7 @@ cpp::result EngineService::LoadEngine( should_use_dll_search_path) { { - std::lock_guard lock(engines_mutex_); + // Remove llamacpp dll directory if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { CTL_WRN("Could not remove dll directory: " << kLlamaRepo); @@ -801,11 +800,8 @@ cpp::result EngineService::LoadEngine( } } #endif - { - std::lock_guard lock(engines_mutex_); - engines_[ne].dl = std::make_unique( - engine_dir_path.string(), "engine"); - } + engines_[ne].dl = + std::make_unique(engine_dir_path.string(), "engine"); #if defined(__linux__) const char* name = "LD_LIBRARY_PATH"; auto data = getenv(name); @@ -826,45 +822,39 @@ cpp::result EngineService::LoadEngine( } catch (const cortex_cpp::dylib::load_error& e) { CTL_ERR("Could not load engine: " << e.what()); - { - std::lock_guard lock(engines_mutex_); - engines_.erase(ne); - } + engines_.erase(ne); return cpp::fail("Could not load engine " + ne + ": " + e.what()); } - { - std::lock_guard lock(engines_mutex_); - auto func = engines_[ne].dl->get_function("get_engine"); - engines_[ne].engine = func(); - - auto& en = std::get(engines_[ne].engine); - if (ne == kLlamaRepo) { //fix for llamacpp engine first - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - CTL_WRN("Method SetFileLogger is not supported yet"); - } - if (en->IsSupported("SetLogLevel")) { - en->SetLogLevel(logging_utils_helper::global_log_level); - } else { - CTL_WRN("Method SetLogLevel is not supported yet"); - } + auto func = engines_[ne].dl->get_function("get_engine"); + engines_[ne].engine = func(); + + auto& en = std::get(engines_[ne].engine); + if (ne == kLlamaRepo) { //fix for llamacpp engine first + auto config = file_manager_utils::GetCortexConfig(); + if (en->IsSupported("SetFileLogger")) { + en->SetFileLogger(config.maxLogLines, + (std::filesystem::path(config.logFolderPath) / + std::filesystem::path(config.logLlamaCppPath)) + .string()); + } else { + CTL_WRN("Method SetFileLogger is not supported yet"); + } + if (en->IsSupported("SetLogLevel")) { + en->SetLogLevel(logging_utils_helper::global_log_level); + } else { + CTL_WRN("Method SetLogLevel is not supported yet"); } - CTL_DBG("loaded engine: " << ne); } + CTL_DBG("loaded engine: " << ne); return {}; } cpp::result EngineService::UnloadEngine( const std::string& engine) { auto ne = NormalizeEngine(engine); + std::lock_guard lock(engines_mutex_); { - std::lock_guard lock(engines_mutex_); if (!IsEngineLoaded(ne)) { return cpp::fail("Engine " + ne + " is not loaded yet!"); } @@ -893,14 +883,12 @@ cpp::result EngineService::UnloadEngine( } std::vector EngineService::GetLoadedEngines() { - { - std::lock_guard lock(engines_mutex_); - std::vector loaded_engines; - for (const auto& [key, value] : engines_) { - loaded_engines.push_back(value.engine); - } - return loaded_engines; + std::lock_guard lock(engines_mutex_); + std::vector loaded_engines; + for (const auto& [key, value] : engines_) { + loaded_engines.push_back(value.engine); } + return loaded_engines; } cpp::result @@ -1084,6 +1072,7 @@ std::string EngineService::DeleteEngine(int id) { cpp::result EngineService::GetRemoteModels( const std::string& engine_name) { + std::lock_guard lock(engines_mutex_); if (auto r = IsEngineReady(engine_name); r.has_error()) { return cpp::fail(r.error()); } @@ -1093,7 +1082,6 @@ cpp::result EngineService::GetRemoteModels( if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - if (engine_name == kOpenAiEngine) { engines_[engine_name].engine = new remote_engine::OpenAiEngine(); } else { @@ -1102,7 +1090,6 @@ cpp::result EngineService::GetRemoteModels( CTL_INF("Loaded engine: " << engine_name); } - auto& e = std::get(engines_[engine_name].engine); auto res = e->GetRemoteModels(); if (!res["error"].isNull()) { diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 8c8bfbbe6..ab274825d 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -112,8 +112,6 @@ class EngineService : public EngineServiceI { cpp::result, std::string> GetInstalledEngineVariants(const std::string& engine) const; - bool IsEngineLoaded(const std::string& engine); - cpp::result GetLoadedEngine( const std::string& engine_name); @@ -152,6 +150,8 @@ class EngineService : public EngineServiceI { const std::string& engine_name); private: + bool IsEngineLoaded(const std::string& engine); + cpp::result DownloadEngine( const std::string& engine, const std::string& version = "latest", const std::optional variant_name = std::nullopt); From 61c3ee1b6a75bd16137eaffeee2470818a33019f Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 6 Dec 2024 08:59:32 +0700 Subject: [PATCH 14/44] feat: add assistants (#1770) * feat: add assistants * add pagination messages * allow edit content of message --- engine/common/assistant.h | 157 +++++++++++++++++++ engine/common/assistant_tool.h | 91 +++++++++++ engine/common/thread.h | 23 +++ engine/config/model_config.h | 3 - engine/controllers/assistants.cc | 144 +++++++++++++++++ engine/controllers/assistants.h | 39 +++++ engine/controllers/messages.cc | 79 ++++++++-- engine/controllers/messages.h | 3 +- engine/controllers/threads.cc | 8 +- engine/controllers/threads.h | 2 +- engine/main.cc | 5 + engine/repositories/message_fs_repository.cc | 57 ++++++- engine/repositories/thread_fs_repository.cc | 124 ++++++++++++++- engine/repositories/thread_fs_repository.h | 29 +++- engine/services/assistant_service.cc | 28 ++++ engine/services/assistant_service.h | 24 +++ engine/services/message_service.cc | 24 ++- engine/services/message_service.h | 11 +- 18 files changed, 813 insertions(+), 38 deletions(-) create mode 100644 engine/common/assistant.h create mode 100644 engine/common/assistant_tool.h create mode 100644 engine/controllers/assistants.cc create mode 100644 engine/controllers/assistants.h create mode 100644 engine/services/assistant_service.cc create mode 100644 engine/services/assistant_service.h diff --git a/engine/common/assistant.h b/engine/common/assistant.h new file mode 100644 index 000000000..e49147e9e --- /dev/null +++ b/engine/common/assistant.h @@ -0,0 +1,157 @@ +#pragma once + +#include +#include "common/assistant_tool.h" +#include "common/thread_tool_resources.h" +#include "common/variant_map.h" +#include "utils/result.hpp" + +namespace OpenAi { +// Deprecated. After jan's migration, we should remove this struct +struct JanAssistant : JsonSerializable { + std::string id; + + std::string name; + + std::string object = "assistant"; + + uint32_t created_at; + + Json::Value tools; + + Json::Value model; + + std::string instructions; + + ~JanAssistant() = default; + + cpp::result ToJson() override { + try { + Json::Value json; + + json["id"] = id; + json["name"] = name; + json["object"] = object; + json["created_at"] = created_at; + + json["tools"] = tools; + json["model"] = model; + json["instructions"] = instructions; + + return json; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } + + static cpp::result FromJson(Json::Value&& json) { + if (json.empty()) { + return cpp::fail("Empty JSON"); + } + + JanAssistant assistant; + if (json.isMember("assistant_id")) { + assistant.id = json["assistant_id"].asString(); + } else { + assistant.id = json["id"].asString(); + } + + if (json.isMember("assistant_name")) { + assistant.name = json["assistant_name"].asString(); + } else { + assistant.name = json["name"].asString(); + } + assistant.object = "assistant"; + assistant.created_at = 0; // Jan does not have this + if (json.isMember("tools")) { + assistant.tools = json["tools"]; + } + if (json.isMember("model")) { + assistant.model = json["model"]; + } + assistant.instructions = json["instructions"].asString(); + + return assistant; + } +}; + +struct Assistant { + /** + * The identifier, which can be referenced in API endpoints. + */ + std::string id; + + /** + * The object type, which is always assistant. + */ + std::string object = "assistant"; + + /** + * The Unix timestamp (in seconds) for when the assistant was created. + */ + uint64_t created_at; + + /** + * The name of the assistant. The maximum length is 256 characters. + */ + std::optional name; + + /** + * The description of the assistant. The maximum length is 512 characters. + */ + std::optional description; + + /** + * ID of the model to use. You can use the List models API to see all of + * your available models, or see our Model overview for descriptions of them. + */ + std::string model; + + /** + * The system instructions that the assistant uses. The maximum length is + * 256,000 characters. + */ + std::optional instructions; + + /** + * A list of tool enabled on the assistant. There can be a maximum of 128 + * tools per assistant. Tools can be of types code_interpreter, file_search, + * or function. + */ + std::vector> tools; + + /** + * A set of resources that are used by the assistant's tools. The resources + * are specific to the type of tool. For example, the code_interpreter tool + * requires a list of file IDs, while the file_search tool requires a list + * of vector store IDs. + */ + std::optional> + tool_resources; + + /** + * Set of 16 key-value pairs that can be attached to an object. This can be + * useful for storing additional information about the object in a structured + * format. Keys can be a maximum of 64 characters long and values can be a + * maximum of 512 characters long. + */ + Cortex::VariantMap metadata; + + /** + * What sampling temperature to use, between 0 and 2. Higher values like + * 0.8 will make the output more random, while lower values like 0.2 will + * make it more focused and deterministic. + */ + std::optional temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, + * where the model considers the results of the tokens with top_p + * probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. + * + * We generally recommend altering this or temperature but not both. + */ + std::optional top_p; +}; +} // namespace OpenAi diff --git a/engine/common/assistant_tool.h b/engine/common/assistant_tool.h new file mode 100644 index 000000000..622721708 --- /dev/null +++ b/engine/common/assistant_tool.h @@ -0,0 +1,91 @@ +#pragma once + +#include +#include + +namespace OpenAi { +struct AssistantTool { + std::string type; + + AssistantTool(const std::string& type) : type{type} {} + + virtual ~AssistantTool() = default; +}; + +struct AssistantCodeInterpreterTool : public AssistantTool { + AssistantCodeInterpreterTool() : AssistantTool{"code_interpreter"} {} + + ~AssistantCodeInterpreterTool() = default; +}; + +struct AssistantFileSearchTool : public AssistantTool { + AssistantFileSearchTool() : AssistantTool("file_search") {} + + ~AssistantFileSearchTool() = default; + + /** + * The ranking options for the file search. If not specified, + * the file search tool will use the auto ranker and a score_threshold of 0. + * + * See the file search tool documentation for more information. + */ + struct RankingOption { + /** + * The ranker to use for the file search. If not specified will use the auto ranker. + */ + std::string ranker; + + /** + * The score threshold for the file search. All values must be a + * floating point number between 0 and 1. + */ + float score_threshold; + }; + + /** + * Overrides for the file search tool. + */ + struct FileSearch { + /** + * The maximum number of results the file search tool should output. + * The default is 20 for gpt-4* models and 5 for gpt-3.5-turbo. + * This number should be between 1 and 50 inclusive. + * + * Note that the file search tool may output fewer than max_num_results results. + * See the file search tool documentation for more information. + */ + int max_num_result; + }; +}; + +struct AssistantFunctionTool : public AssistantTool { + AssistantFunctionTool() : AssistantTool("function") {} + + ~AssistantFunctionTool() = default; + + struct Function { + /** + * A description of what the function does, used by the model to choose + * when and how to call the function. + */ + std::string description; + + /** + * The name of the function to be called. Must be a-z, A-Z, 0-9, or contain + * underscores and dashes, with a maximum length of 64. + */ + std::string name; + + // TODO: namh handle parameters + + /** + * Whether to enable strict schema adherence when generating the function call. + * If set to true, the model will follow the exact schema defined in the parameters + * field. Only a subset of JSON Schema is supported when strict is true. + * + * Learn more about Structured Outputs in the function calling guide. + */ + std::optional strict; + }; +}; +} // namespace OpenAi diff --git a/engine/common/thread.h b/engine/common/thread.h index 20672ff72..60f408635 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -3,6 +3,7 @@ #include #include #include +#include "common/assistant.h" #include "common/thread_tool_resources.h" #include "common/variant_map.h" #include "json_serializable.h" @@ -47,6 +48,9 @@ struct Thread : JsonSerializable { */ Cortex::VariantMap metadata; + // For supporting Jan + std::optional> assistants; + static cpp::result FromJson(const Json::Value& json) { Thread thread; @@ -90,6 +94,25 @@ struct Thread : JsonSerializable { } } + if (json.isMember("title") && !json["title"].isNull()) { + thread.metadata["title"] = json["title"].asString(); + } + + if (json.isMember("assistants") && json["assistants"].isArray()) { + std::vector assistants; + for (Json::ArrayIndex i = 0; i < json["assistants"].size(); ++i) { + Json::Value assistant_json = json["assistants"][i]; + auto assistant_result = + JanAssistant::FromJson(std::move(assistant_json)); + if (assistant_result.has_error()) { + return cpp::fail("Failed to parse assistant: " + + assistant_result.error()); + } + assistants.push_back(std::move(assistant_result.value())); + } + thread.assistants = std::move(assistants); + } + return thread; } diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 701547873..84e175d54 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -1,10 +1,8 @@ #pragma once #include -#include #include #include -#include #include #include #include @@ -12,7 +10,6 @@ #include #include "utils/format_utils.h" #include "utils/remote_models_utils.h" -#include "yaml-cpp/yaml.h" namespace config { diff --git a/engine/controllers/assistants.cc b/engine/controllers/assistants.cc new file mode 100644 index 000000000..405d7ed3c --- /dev/null +++ b/engine/controllers/assistants.cc @@ -0,0 +1,144 @@ +#include "assistants.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" + +void Assistants::RetrieveAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const { + CTL_INF("RetrieveAssistant: " + assistant_id); + auto res = assistant_service_->RetrieveAssistant(assistant_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto to_json_res = res->ToJson(); + if (to_json_res.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_res.error()); + Json::Value ret; + ret["message"] = to_json_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + // TODO: namh need to use the text response because it contains model config + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } + } +} + +void Assistants::CreateAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Parse assistant from request body + auto assistant_result = OpenAi::JanAssistant::FromJson(std::move(*json_body)); + if (assistant_result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse assistant: " + assistant_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Call assistant service to create + auto create_result = assistant_service_->CreateAssistant( + assistant_id, assistant_result.value()); + if (create_result.has_error()) { + Json::Value ret; + ret["message"] = create_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Convert result to JSON and send response + auto to_json_result = create_result->ToJson(); + if (to_json_result.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_result.error()); + Json::Value ret; + ret["message"] = to_json_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(to_json_result.value()); + resp->setStatusCode(k201Created); + callback(resp); +} + +void Assistants::ModifyAssistant( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) { + auto json_body = req->getJsonObject(); + if (json_body == nullptr) { + Json::Value ret; + ret["message"] = "Request body can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Parse assistant from request body + auto assistant_result = OpenAi::JanAssistant::FromJson(std::move(*json_body)); + if (assistant_result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse assistant: " + assistant_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Call assistant service to create + auto modify_result = assistant_service_->ModifyAssistant( + assistant_id, assistant_result.value()); + if (modify_result.has_error()) { + Json::Value ret; + ret["message"] = modify_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + // Convert result to JSON and send response + auto to_json_result = modify_result->ToJson(); + if (to_json_result.has_error()) { + CTL_ERR("Failed to convert assistant to json: " + to_json_result.error()); + Json::Value ret; + ret["message"] = to_json_result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(to_json_result.value()); + resp->setStatusCode(k200OK); + callback(resp); +} diff --git a/engine/controllers/assistants.h b/engine/controllers/assistants.h new file mode 100644 index 000000000..94ddd14b1 --- /dev/null +++ b/engine/controllers/assistants.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include "services/assistant_service.h" + +using namespace drogon; + +class Assistants : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Assistants::RetrieveAssistant, "/v1/assistants/{assistant_id}", + Get); + + ADD_METHOD_TO(Assistants::CreateAssistant, "/v1/assistants/{assistant_id}", + Options, Post); + + ADD_METHOD_TO(Assistants::ModifyAssistant, "/v1/assistants/{assistant_id}", + Options, Patch); + METHOD_LIST_END + + explicit Assistants(std::shared_ptr assistant_srv) + : assistant_service_{assistant_srv} {}; + + void RetrieveAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id) const; + + void CreateAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + + void ModifyAssistant(const HttpRequestPtr& req, + std::function&& callback, + const std::string& assistant_id); + + private: + std::shared_ptr assistant_service_; +}; diff --git a/engine/controllers/messages.cc b/engine/controllers/messages.cc index ef82b3412..27307803a 100644 --- a/engine/controllers/messages.cc +++ b/engine/controllers/messages.cc @@ -10,13 +10,13 @@ void Messages::ListMessages( const HttpRequestPtr& req, std::function&& callback, - const std::string& thread_id, std::optional limit, + const std::string& thread_id, std::optional limit, std::optional order, std::optional after, std::optional before, std::optional run_id) const { auto res = message_service_->ListMessages( - thread_id, limit.value_or(20), order.value_or("desc"), after.value_or(""), - before.value_or(""), run_id.value_or("")); + thread_id, std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or(""), run_id.value_or("")); Json::Value root; if (res.has_error()) { @@ -212,39 +212,88 @@ void Messages::ModifyMessage( } std::optional metadata = std::nullopt; - if (auto it = json_body->get("metadata", ""); it) { - if (it.empty()) { + if (json_body->isMember("metadata")) { + if (auto it = json_body->get("metadata", ""); it) { + if (it.empty()) { + Json::Value ret; + ret["message"] = "Metadata can't be empty"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + auto convert_res = Cortex::ConvertJsonValueToMap(it); + if (convert_res.has_error()) { + Json::Value ret; + ret["message"] = + "Failed to convert metadata to map: " + convert_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + metadata = convert_res.value(); + } + } + + std::optional< + std::variant>>> + content = std::nullopt; + + if (json_body->get("content", "").isArray()) { + auto result = OpenAi::ParseContents(json_body->get("content", "")); + if (result.has_error()) { + Json::Value ret; + ret["message"] = "Failed to parse content array: " + result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (result.value().empty()) { Json::Value ret; - ret["message"] = "Metadata can't be empty"; + ret["message"] = "Content array cannot be empty"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - auto convert_res = Cortex::ConvertJsonValueToMap(it); - if (convert_res.has_error()) { + + content = std::move(result.value()); + } else if (json_body->get("content", "").isString()) { + auto content_str = json_body->get("content", "").asString(); + string_utils::Trim(content_str); + if (content_str.empty()) { Json::Value ret; - ret["message"] = - "Failed to convert metadata to map: " + convert_res.error(); + ret["message"] = "Content can't be empty"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - metadata = convert_res.value(); + + content = content_str; + } else if (!json_body->get("content", "").empty()) { + Json::Value ret; + ret["message"] = "Content must be either a string or an array"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; } - if (!metadata.has_value()) { + if (!metadata.has_value() && !content.has_value()) { Json::Value ret; - ret["message"] = "Metadata is mandatory"; + ret["message"] = "Nothing to update"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k400BadRequest); callback(resp); return; } - auto res = - message_service_->ModifyMessage(thread_id, message_id, metadata.value()); + auto res = message_service_->ModifyMessage(thread_id, message_id, metadata, + std::move(content)); if (res.has_error()) { Json::Value ret; ret["message"] = "Failed to modify message: " + res.error(); diff --git a/engine/controllers/messages.h b/engine/controllers/messages.h index 340317eb8..045d8a207 100644 --- a/engine/controllers/messages.h +++ b/engine/controllers/messages.h @@ -34,7 +34,8 @@ class Messages : public drogon::HttpController { void ListMessages(const HttpRequestPtr& req, std::function&& callback, - const std::string& thread_id, std::optional limit, + const std::string& thread_id, + std::optional limit, std::optional order, std::optional after, std::optional before, diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index a11c1071b..1cd3aaeef 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -7,12 +7,12 @@ void Threads::ListThreads( const HttpRequestPtr& req, std::function&& callback, - std::optional limit, std::optional order, + std::optional limit, std::optional order, std::optional after, std::optional before) const { CTL_INF("ListThreads"); - auto res = - thread_service_->ListThreads(limit.value_or(20), order.value_or("desc"), - after.value_or(""), before.value_or("")); + auto res = thread_service_->ListThreads( + std::stoi(limit.value_or("20")), order.value_or("desc"), + after.value_or(""), before.value_or("")); if (res.has_error()) { Json::Value root; diff --git a/engine/controllers/threads.h b/engine/controllers/threads.h index 92c509525..f26e35785 100644 --- a/engine/controllers/threads.h +++ b/engine/controllers/threads.h @@ -34,7 +34,7 @@ class Threads : public drogon::HttpController { void ListThreads(const HttpRequestPtr& req, std::function&& callback, - std::optional limit, + std::optional limit, std::optional order, std::optional after, std::optional before) const; diff --git a/engine/main.cc b/engine/main.cc index 0177a2143..894e9d146 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -1,6 +1,7 @@ #include #include #include +#include "controllers/assistants.h" #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" @@ -14,6 +15,7 @@ #include "migrations/migration_manager.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" +#include "services/assistant_service.h" #include "services/config_service.h" #include "services/file_watcher_service.h" #include "services/message_service.h" @@ -124,6 +126,7 @@ void RunServer(std::optional port, bool ignore_cout) { auto thread_repo = std::make_shared( file_manager_utils::GetCortexDataPath()); + auto assistant_srv = std::make_shared(thread_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); @@ -142,6 +145,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto assistant_ctl = std::make_shared(assistant_srv); auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); @@ -153,6 +157,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(assistant_ctl); drogon::app().registerController(thread_ctl); drogon::app().registerController(message_ctl); drogon::app().registerController(engine_ctl); diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index e576a7695..388409390 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -1,4 +1,5 @@ #include "message_fs_repository.h" +#include #include #include #include "utils/result.hpp" @@ -52,7 +53,61 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); - return ReadMessageFromFile(thread_id); + auto read_result = ReadMessageFromFile(thread_id); + if (read_result.has_error()) { + return cpp::fail(read_result.error()); + } + + std::vector messages = std::move(read_result.value()); + + if (!run_id.empty()) { + messages.erase(std::remove_if(messages.begin(), messages.end(), + [&run_id](const OpenAi::Message& msg) { + return msg.run_id != run_id; + }), + messages.end()); + } + + std::sort(messages.begin(), messages.end(), + [&order](const OpenAi::Message& a, const OpenAi::Message& b) { + if (order == "desc") { + return a.created_at > b.created_at; + } + return a.created_at < b.created_at; + }); + + auto start_it = messages.begin(); + auto end_it = messages.end(); + + if (!after.empty()) { + start_it = std::find_if( + messages.begin(), messages.end(), + [&after](const OpenAi::Message& msg) { return msg.id == after; }); + if (start_it != messages.end()) { + ++start_it; // Start from the message after the 'after' message + } else { + start_it = messages.begin(); + } + } + + if (!before.empty()) { + end_it = std::find_if( + messages.begin(), messages.end(), + [&before](const OpenAi::Message& msg) { return msg.id == before; }); + } + + std::vector result; + size_t distance = std::distance(start_it, end_it); + size_t limit_size = static_cast(limit); + CTL_INF("Distance: " + std::to_string(distance) + + ", limit_size: " + std::to_string(limit_size)); + result.reserve(distance < limit_size ? distance : limit_size); + + for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) { + result.push_back(std::move(*it)); + } + + return result; } cpp::result MessageFsRepository::RetrieveMessage( diff --git a/engine/repositories/thread_fs_repository.cc b/engine/repositories/thread_fs_repository.cc index 64dad6ea5..6b75db8e4 100644 --- a/engine/repositories/thread_fs_repository.cc +++ b/engine/repositories/thread_fs_repository.cc @@ -1,37 +1,67 @@ #include "thread_fs_repository.h" #include #include +#include "common/assistant.h" +#include "utils/result.hpp" cpp::result, std::string> ThreadFsRepository::ListThreads(uint8_t limit, const std::string& order, const std::string& after, const std::string& before) const { - CTL_INF("ListThreads: limit=" + std::to_string(limit) + ", order=" + order + - ", after=" + after + ", before=" + before); std::vector threads; try { auto thread_container_path = data_folder_path_ / kThreadContainerFolderName; + std::vector all_threads; + + // First load all valid threads for (const auto& entry : std::filesystem::directory_iterator(thread_container_path)) { if (!entry.is_directory()) continue; - if (!std::filesystem::exists(entry.path() / kThreadFileName)) + auto thread_file = entry.path() / kThreadFileName; + if (!std::filesystem::exists(thread_file)) continue; auto current_thread_id = entry.path().filename().string(); - CTL_INF("ListThreads: Found thread: " + current_thread_id); - std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); + // Apply pagination filters + if (!after.empty() && current_thread_id <= after) + continue; + if (!before.empty() && current_thread_id >= before) + continue; + + std::shared_lock thread_lock(GrabThreadMutex(current_thread_id)); auto thread_result = LoadThread(current_thread_id); + if (thread_result.has_value()) { - threads.push_back(std::move(thread_result.value())); + all_threads.push_back(std::move(thread_result.value())); } thread_lock.unlock(); } + // Sort threads based on order parameter using created_at + if (order == "desc") { + std::sort(all_threads.begin(), all_threads.end(), + [](const OpenAi::Thread& a, const OpenAi::Thread& b) { + return a.created_at > b.created_at; // Descending order + }); + } else { + std::sort(all_threads.begin(), all_threads.end(), + [](const OpenAi::Thread& a, const OpenAi::Thread& b) { + return a.created_at < b.created_at; // Ascending order + }); + } + + // Apply limit + size_t thread_count = + std::min(static_cast(limit), all_threads.size()); + for (size_t i = 0; i < thread_count; i++) { + threads.push_back(std::move(all_threads[i])); + } + return threads; } catch (const std::exception& e) { return cpp::fail(std::string("Failed to list threads: ") + e.what()); @@ -164,3 +194,85 @@ cpp::result ThreadFsRepository::DeleteThread( thread_mutexes_.erase(thread_id); return {}; } + +cpp::result +ThreadFsRepository::LoadAssistant(const std::string& thread_id) const { + auto path = GetThreadPath(thread_id) / kThreadFileName; + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + std::shared_lock thread_lock(GrabThreadMutex(thread_id)); + try { + std::ifstream file(path); + if (!file.is_open()) { + return cpp::fail("Failed to open file: " + path.string()); + } + + Json::Value root; + Json::CharReaderBuilder builder; + JSONCPP_STRING errs; + + if (!parseFromStream(builder, file, &root, &errs)) { + return cpp::fail("Failed to parse JSON: " + errs); + } + + Json::Value assistants = root["assistants"]; + if (!assistants.isArray()) { + return cpp::fail("Assistants field is not an array"); + } + + if (assistants.empty()) { + return cpp::fail("Assistant not found in thread: " + thread_id); + } + + return OpenAi::JanAssistant::FromJson(std::move(assistants[0])); + } catch (const std::exception& e) { + return cpp::fail("Failed to load assistant: " + std::string(e.what())); + } +} + +cpp::result +ThreadFsRepository::ModifyAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + std::unique_lock lock(GrabThreadMutex(thread_id)); + + // Load the existing thread + auto thread_result = LoadThread(thread_id); + if (!thread_result.has_value()) { + return cpp::fail("Failed to load thread: " + thread_result.error()); + } + + auto& thread = thread_result.value(); + if (thread.ToJson() + ->get("assistants", Json::Value(Json::arrayValue)) + .empty()) { + return cpp::fail("No assistants found in thread: " + thread_id); + } + + thread.assistants = {assistant}; + + auto save_result = SaveThread(thread); + if (!save_result.has_value()) { + return cpp::fail("Failed to save thread: " + save_result.error()); + } + + return assistant; +} + +cpp::result ThreadFsRepository::CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) { + std::unique_lock lock(GrabThreadMutex(thread_id)); + + // Load the existing thread + auto thread_result = LoadThread(thread_id); + if (!thread_result.has_value()) { + return cpp::fail("Failed to load thread: " + thread_result.error()); + } + + auto& thread = thread_result.value(); + thread.assistants = {assistant}; + + // Save the modified thread + return SaveThread(thread); +} diff --git a/engine/repositories/thread_fs_repository.h b/engine/repositories/thread_fs_repository.h index d834b8e44..b6f6032fa 100644 --- a/engine/repositories/thread_fs_repository.h +++ b/engine/repositories/thread_fs_repository.h @@ -3,11 +3,26 @@ #include #include #include +#include "common/assistant.h" #include "common/repository/thread_repository.h" #include "common/thread.h" #include "utils/logging_utils.h" -class ThreadFsRepository : public ThreadRepository { +// this interface is for backward supporting Jan +class AssistantBackwardCompatibleSupport { + public: + virtual cpp::result LoadAssistant( + const std::string& thread_id) const = 0; + + virtual cpp::result ModifyAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) = 0; + + virtual cpp::result CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant) = 0; +}; + +class ThreadFsRepository : public ThreadRepository, + public AssistantBackwardCompatibleSupport { private: constexpr static auto kThreadFileName = "thread.json"; constexpr static auto kThreadContainerFolderName = "threads"; @@ -58,5 +73,17 @@ class ThreadFsRepository : public ThreadRepository { cpp::result DeleteThread( const std::string& thread_id) override; + // for supporting Jan + cpp::result LoadAssistant( + const std::string& thread_id) const override; + + cpp::result ModifyAssistant( + const std::string& thread_id, + const OpenAi::JanAssistant& assistant) override; + + cpp::result CreateAssistant( + const std::string& thread_id, + const OpenAi::JanAssistant& assistant) override; + ~ThreadFsRepository() = default; }; diff --git a/engine/services/assistant_service.cc b/engine/services/assistant_service.cc new file mode 100644 index 000000000..e769bf23f --- /dev/null +++ b/engine/services/assistant_service.cc @@ -0,0 +1,28 @@ +#include "assistant_service.h" +#include "utils/logging_utils.h" + +cpp::result +AssistantService::CreateAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + CTL_INF("CreateAssistant: " + thread_id); + auto res = thread_repository_->CreateAssistant(thread_id, assistant); + + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return assistant; +} + +cpp::result +AssistantService::RetrieveAssistant(const std::string& assistant_id) const { + CTL_INF("RetrieveAssistant: " + assistant_id); + return thread_repository_->LoadAssistant(assistant_id); +} + +cpp::result +AssistantService::ModifyAssistant(const std::string& thread_id, + const OpenAi::JanAssistant& assistant) { + CTL_INF("RetrieveAssistant: " + thread_id); + return thread_repository_->ModifyAssistant(thread_id, assistant); +} diff --git a/engine/services/assistant_service.h b/engine/services/assistant_service.h new file mode 100644 index 000000000..e7f7414d1 --- /dev/null +++ b/engine/services/assistant_service.h @@ -0,0 +1,24 @@ +#pragma once + +#include "common/assistant.h" +#include "repositories/thread_fs_repository.h" +#include "utils/result.hpp" + +class AssistantService { + public: + explicit AssistantService( + std::shared_ptr thread_repository) + : thread_repository_{thread_repository} {} + + cpp::result CreateAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant); + + cpp::result RetrieveAssistant( + const std::string& thread_id) const; + + cpp::result ModifyAssistant( + const std::string& thread_id, const OpenAi::JanAssistant& assistant); + + private: + std::shared_ptr thread_repository_; +}; diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index dfad74236..ddc9e096b 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -71,7 +71,10 @@ cpp::result MessageService::RetrieveMessage( cpp::result MessageService::ModifyMessage( const std::string& thread_id, const std::string& message_id, - std::optional metadata) { + std::optional metadata, + std::optional>>> + content) { LOG_TRACE << "ModifyMessage for thread " << thread_id << ", message " << message_id; auto msg = RetrieveMessage(thread_id, message_id); @@ -79,7 +82,24 @@ cpp::result MessageService::ModifyMessage( return cpp::fail("Failed to retrieve message: " + msg.error()); } - msg->metadata = metadata.value(); + if (metadata.has_value()) { + msg->metadata = metadata.value(); + } + if (content.has_value()) { + std::vector> content_list{}; + + // If content is string + if (std::holds_alternative(*content)) { + auto text_content = std::make_unique(); + text_content->text.value = std::get(*content); + content_list.push_back(std::move(text_content)); + } else { + content_list = std::move( + std::get>>(*content)); + } + + msg->content = std::move(content_list); + } auto ptr = &msg.value(); auto res = message_repository_->ModifyMessage(msg.value()); diff --git a/engine/services/message_service.h b/engine/services/message_service.h index 6c4880f32..456cdb3a3 100644 --- a/engine/services/message_service.h +++ b/engine/services/message_service.h @@ -21,16 +21,19 @@ class MessageService { std::optional> messages); cpp::result, std::string> ListMessages( - const std::string& thread_id, uint8_t limit = 20, - const std::string& order = "desc", const std::string& after = "", - const std::string& before = "", const std::string& run_id = "") const; + const std::string& thread_id, uint8_t limit, const std::string& order, + const std::string& after, const std::string& before, + const std::string& run_id) const; cpp::result RetrieveMessage( const std::string& thread_id, const std::string& message_id) const; cpp::result ModifyMessage( const std::string& thread_id, const std::string& message_id, - std::optional metadata); + std::optional metadata, + std::optional>>> + content); cpp::result DeleteMessage( const std::string& thread_id, const std::string& message_id); From a6d9da3a9a0e5a132b16836b9016577f46be48ae Mon Sep 17 00:00:00 2001 From: hiento09 <136591877+hiento09@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:57:38 +0700 Subject: [PATCH 15/44] fix: cortex.cpp nightly test with cortex.llamacpp (#1771) Co-authored-by: Hien To --- .github/workflows/cortex-cpp-quality-gate.yml | 4 ++-- engine/e2e-test/cortex-llamacpp-e2e-nightly.py | 17 +++-------------- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/.github/workflows/cortex-cpp-quality-gate.yml b/.github/workflows/cortex-cpp-quality-gate.yml index e9fd8664b..316160ce5 100644 --- a/.github/workflows/cortex-cpp-quality-gate.yml +++ b/.github/workflows/cortex-cpp-quality-gate.yml @@ -124,7 +124,7 @@ jobs: cat ~/.cortexrc - name: Run e2e tests - if: runner.os != 'Windows' && github.event.pull_request.draft == false + if: github.event_name != 'schedule' && runner.os != 'Windows' && github.event.pull_request.draft == false run: | cd engine cp build/cortex build/cortex-nightly @@ -138,7 +138,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.PAT_SERVICE_ACCOUNT }} - name: Run e2e tests - if: runner.os == 'Windows' && github.event.pull_request.draft == false + if: github.event_name != 'schedule' && runner.os == 'Windows' && github.event.pull_request.draft == false run: | cd engine cp build/cortex.exe build/cortex-nightly.exe diff --git a/engine/e2e-test/cortex-llamacpp-e2e-nightly.py b/engine/e2e-test/cortex-llamacpp-e2e-nightly.py index 9be34519a..0511277f3 100644 --- a/engine/e2e-test/cortex-llamacpp-e2e-nightly.py +++ b/engine/e2e-test/cortex-llamacpp-e2e-nightly.py @@ -4,30 +4,19 @@ ### e2e tests are expensive, have to keep engines tests in order from test_api_engine_list import TestApiEngineList from test_api_engine_install_nightly import TestApiEngineInstall -from test_api_engine_get import TestApiEngineGet - -### models, keeps in order, note that we only uninstall engine after finishing all models test -from test_api_model_pull_direct_url import TestApiModelPullDirectUrl -from test_api_model_start import TestApiModelStart -from test_api_model_stop import TestApiModelStop -from test_api_model_get import TestApiModelGet -from test_api_model_list import TestApiModelList -from test_api_model_update import TestApiModelUpdate -from test_api_model_delete import TestApiModelDelete +from test_api_model import TestApiModel from test_api_model_import import TestApiModelImport -from test_api_engine_uninstall import TestApiEngineUninstall ### from test_cli_engine_get import TestCliEngineGet from test_cli_engine_install_nightly import TestCliEngineInstall from test_cli_engine_list import TestCliEngineList -from test_cli_model_delete import TestCliModelDelete -from test_cli_model_pull_direct_url import TestCliModelPullDirectUrl +from test_cli_engine_uninstall import TestCliEngineUninstall +from test_cli_model import TestCliModel from test_cli_server_start import TestCliServerStart from test_cortex_update import TestCortexUpdate from test_create_log_folder import TestCreateLogFolder from test_cli_model_import import TestCliModelImport -from test_cli_engine_uninstall import TestCliEngineUninstall if __name__ == "__main__": sys.exit(pytest.main([__file__, "-v"])) From 97e56360ed2128eacc035b7bff34a583ce057a21 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 6 Dec 2024 15:14:52 +0700 Subject: [PATCH 16/44] chore: add more checks and logs when load file (#1772) --- engine/controllers/models.cc | 4 ++-- engine/services/model_service.cc | 17 ++++++++-------- engine/services/model_service.h | 2 +- engine/utils/hardware/gguf/gguf_file.h | 20 ++++++++++--------- .../utils/hardware/gguf/gguf_file_estimate.h | 10 ++++++---- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index de14886da..3f91da848 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -184,8 +184,8 @@ void Models::ListModel( obj["model"] = model_entry.model; obj["model"] = model_entry.model; auto es = model_service_->GetEstimation(model_entry.model); - if (es.has_value()) { - obj["recommendation"] = hardware::ToJson(es.value()); + if (es.has_value() && !!es.value()) { + obj["recommendation"] = hardware::ToJson(*(es.value())); } data.append(std::move(obj)); yaml_handler.Reset(); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index d81a9b649..7f79ddaf7 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -341,9 +341,10 @@ cpp::result ModelService::HandleDownloadUrlAsync( return download_service_->AddTask(downloadTask, on_finished); } -cpp::result ModelService::GetEstimation( - const std::string& model_handle, const std::string& kv_cache, int n_batch, - int n_ubatch) { +cpp::result, std::string> +ModelService::GetEstimation(const std::string& model_handle, + const std::string& kv_cache, int n_batch, + int n_ubatch) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; @@ -918,7 +919,7 @@ cpp::result ModelService::GetModelStatus( if (status == drogon::k200OK) { return true; } else { - CTL_ERR("Model failed to get model status with status code: " << status); + CTL_WRN("Model failed to get model status with status code: " << status); return cpp::fail("Model failed to get model status: " + data["message"].asString()); } @@ -1146,13 +1147,13 @@ ModelService::MayFallbackToCpu(const std::string& model_path, int ngl, .free_vram_MiB = free_vram_MiB}; auto es = hardware::EstimateLLaMACppRun(model_path, rc); - if (es.gpu_mode.vram_MiB > free_vram_MiB && is_cuda) { - CTL_WRN("Not enough VRAM - " << "required: " << es.gpu_mode.vram_MiB + if (!!es && (*es).gpu_mode.vram_MiB > free_vram_MiB && is_cuda) { + CTL_WRN("Not enough VRAM - " << "required: " << (*es).gpu_mode.vram_MiB << ", available: " << free_vram_MiB); } - if (es.cpu_mode.ram_MiB > free_ram_MiB) { - CTL_WRN("Not enough RAM - " << "required: " << es.cpu_mode.ram_MiB + if (!!es && (*es).cpu_mode.ram_MiB > free_ram_MiB) { + CTL_WRN("Not enough RAM - " << "required: " << (*es).cpu_mode.ram_MiB << ", available: " << free_ram_MiB); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 7235d5a0a..e2638fd1f 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -97,7 +97,7 @@ class ModelService { bool HasModel(const std::string& id) const; - cpp::result GetEstimation( + cpp::result, std::string> GetEstimation( const std::string& model_handle, const std::string& kv_cache = "f16", int n_batch = 2048, int n_ubatch = 2048); diff --git a/engine/utils/hardware/gguf/gguf_file.h b/engine/utils/hardware/gguf/gguf_file.h index 1263debf2..361668242 100644 --- a/engine/utils/hardware/gguf/gguf_file.h +++ b/engine/utils/hardware/gguf/gguf_file.h @@ -11,6 +11,7 @@ #include #include #include +#include #ifdef _WIN32 #include @@ -23,13 +24,14 @@ #include "ggml.h" #include "utils/string_utils.h" +#include "utils/logging_utils.h" // #define GGUF_LOG(msg) \ // do { \ // std::cout << __FILE__ << "(@" << __LINE__ << "): " << msg << '\n'; \ // } while (false) -#define GGUF_LOG(msg) +#define GGUF_LOG(msg) namespace hardware { #undef min #undef max @@ -169,8 +171,6 @@ inline std::string to_string(const GGUFMetadataKV& kv) { return "Invalid type "; } - - struct GGUFTensorInfo { /* Basic */ std::string name; @@ -208,14 +208,14 @@ struct GGUFHelper { CreateFileA(file_path.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); if (file_handle == INVALID_HANDLE_VALUE) { - std::cout << "Failed to open file" << std::endl; + CTL_INF("Failed to open file: " << file_path); return false; } // Get the file size LARGE_INTEGER file_size_struct; if (!GetFileSizeEx(file_handle, &file_size_struct)) { CloseHandle(file_handle); - std::cout << "Failed to open file" << std::endl; + CTL_INF("Failed to get file size: " << file_path); return false; } file_size = static_cast(file_size_struct.QuadPart); @@ -225,7 +225,7 @@ struct GGUFHelper { CreateFileMappingA(file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr); if (file_mapping == nullptr) { CloseHandle(file_handle); - std::cout << "Failed to create file mapping" << std::endl; + CTL_INF("Failed to create file mapping: " << file_path); return false; } @@ -235,7 +235,7 @@ struct GGUFHelper { if (data == nullptr) { CloseHandle(file_mapping); CloseHandle(file_handle); - std::cout << "Failed to map file" << std::endl; + CTL_INF("Failed to map file:: " << file_path); return false; } @@ -479,10 +479,12 @@ struct GGUFFile { double model_bits_per_weight; }; -inline GGUFFile ParseGgufFile(const std::string& path) { +inline std::optional ParseGgufFile(const std::string& path) { GGUFFile gf; GGUFHelper h; - h.OpenAndMMap(path); + if(!h.OpenAndMMap(path)) { + return std::nullopt; + } GGUFMagic magic = h.Read(); // GGUF_LOG("magic: " << magic); diff --git a/engine/utils/hardware/gguf/gguf_file_estimate.h b/engine/utils/hardware/gguf/gguf_file_estimate.h index fde0b0ac0..12a7e72e1 100644 --- a/engine/utils/hardware/gguf/gguf_file_estimate.h +++ b/engine/utils/hardware/gguf/gguf_file_estimate.h @@ -62,20 +62,22 @@ inline float GetQuantBit(const std::string& kv_cache_t) { return 16.0; } -inline Estimation EstimateLLaMACppRun(const std::string& file_path, - const RunConfig& rc) { +inline std::optional EstimateLLaMACppRun( + const std::string& file_path, const RunConfig& rc) { Estimation res; // token_embeddings_size = n_vocab * embedding_length * 2 * quant_bit/16 bytes //RAM = token_embeddings_size + ((total_ngl-ngl) >=1 ? Output_layer_size + (total_ngl - ngl - 1 ) / (total_ngl-1) * (total_file_size - token_embeddings_size - Output_layer_size) : 0 ) (bytes) // VRAM = total_file_size - RAM (bytes) auto gf = ParseGgufFile(file_path); + if (!gf) + return std::nullopt; int32_t embedding_length = 0; int64_t n_vocab = 0; int32_t num_block = 0; int32_t total_ngl = 0; auto file_size = std::filesystem::file_size(file_path); - for (auto const& kv : gf.header.metadata_kv) { + for (auto const& kv : (*gf).header.metadata_kv) { if (kv.key.find("embedding_length") != std::string::npos) { embedding_length = std::any_cast(kv.value); } else if (kv.key == "tokenizer.ggml.tokens") { @@ -92,7 +94,7 @@ inline Estimation EstimateLLaMACppRun(const std::string& file_path, int32_t quant_bit_in = 0; int32_t quant_bit_out = 0; - for (auto const& ti : gf.tensor_infos) { + for (auto const& ti : (*gf).tensor_infos) { if (ti->name == "output.weight") { quant_bit_out = GetQuantBit(ti->type); // std::cout << ti->type << std::endl; From 4700f8d212c5596250fdd835e701d7b1e219a636 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 6 Dec 2024 15:50:58 +0700 Subject: [PATCH 17/44] fix: create assistant (#1773) * fix: create assistant * fix ci --- engine/common/thread.h | 13 +++++++++++++ engine/controllers/hardware.cc | 6 ++---- engine/controllers/threads.cc | 4 +++- engine/database/hardware.cc | 9 ++++----- engine/database/models.cc | 3 ++- engine/test/components/test_cortex_config.cc | 4 ++++ engine/test/components/test_cortex_upd_cmd.cc | 3 ++- .../test_file_manager_config_yaml_utils.cc | 1 + engine/utils/config_yaml_utils.cc | 7 ++++++- engine/utils/config_yaml_utils.h | 6 +----- 10 files changed, 38 insertions(+), 18 deletions(-) diff --git a/engine/common/thread.h b/engine/common/thread.h index 60f408635..480c0ba78 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -156,6 +156,19 @@ struct Thread : JsonSerializable { } json["metadata"] = metadata_json; + if (assistants.has_value()) { + Json::Value assistants_json(Json::arrayValue); + for (auto& assistant : assistants.value()) { + auto assistant_result = assistant.ToJson(); + if (assistant_result.has_error()) { + return cpp::fail("Failed to serialize assistant: " + + assistant_result.error()); + } + assistants_json.append(assistant_result.value()); + } + json["assistants"] = assistants_json; + } + return json; } catch (const std::exception& e) { return cpp::fail(std::string("ToJson failed: ") + e.what()); diff --git a/engine/controllers/hardware.cc b/engine/controllers/hardware.cc index 4f5cc2879..39a109750 100644 --- a/engine/controllers/hardware.cc +++ b/engine/controllers/hardware.cc @@ -1,8 +1,6 @@ #include "hardware.h" -#include "common/hardware_config.h" #include "utils/cortex_utils.h" -#include "utils/file_manager_utils.h" -#include "utils/scope_exit.h" +#include "utils/logging_utils.h" void Hardware::GetHardwareInfo( const HttpRequestPtr& req, @@ -73,4 +71,4 @@ void Hardware::Activate( callback(resp); app().quit(); #endif -} \ No newline at end of file +} diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index 1cd3aaeef..e130dad88 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -25,6 +25,7 @@ void Threads::ListThreads( Json::Value msg_arr(Json::arrayValue); for (auto& msg : res.value()) { if (auto it = msg.ToJson(); it.has_value()) { + it->removeMember("assistants"); msg_arr.append(it.value()); } else { CTL_WRN("Failed to convert message to json: " + it.error()); @@ -114,8 +115,9 @@ void Threads::RetrieveThread( resp->setStatusCode(k400BadRequest); callback(resp); } else { + thread_to_json->removeMember("assistants"); auto resp = - cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + cortex_utils::CreateCortexHttpJsonResponse(thread_to_json.value()); resp->setStatusCode(k200OK); callback(resp); } diff --git a/engine/database/hardware.cc b/engine/database/hardware.cc index ee68749d5..ff2eb853a 100644 --- a/engine/database/hardware.cc +++ b/engine/database/hardware.cc @@ -1,14 +1,13 @@ #include "hardware.h" #include "database.h" +#include "utils/logging_utils.h" #include "utils/scope_exit.h" namespace cortex::db { -Hardwares::Hardwares() : db_(cortex::db::Database::GetInstance().db()) { -} +Hardwares::Hardwares() : db_(cortex::db::Database::GetInstance().db()) {} -Hardwares::Hardwares(SQLite::Database& db) : db_(db) { -} +Hardwares::Hardwares(SQLite::Database& db) : db_(db) {} Hardwares::~Hardwares() {} @@ -94,4 +93,4 @@ cpp::result Hardwares::DeleteHardwareEntry( return cpp::fail(e.what()); } } -} // namespace cortex::db \ No newline at end of file +} // namespace cortex::db diff --git a/engine/database/models.cc b/engine/database/models.cc index fb2128396..8c8be9eaf 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -2,6 +2,7 @@ #include #include #include "database.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" #include "utils/scope_exit.h" @@ -339,4 +340,4 @@ bool Models::HasModel(const std::string& identifier) const { } } -} // namespace cortex::db \ No newline at end of file +} // namespace cortex::db diff --git a/engine/test/components/test_cortex_config.cc b/engine/test/components/test_cortex_config.cc index 04f3ddf33..f4bb7c1dc 100644 --- a/engine/test/components/test_cortex_config.cc +++ b/engine/test/components/test_cortex_config.cc @@ -1,3 +1,7 @@ +#include +#include +#include +#include #include "gtest/gtest.h" #include "utils/config_yaml_utils.h" diff --git a/engine/test/components/test_cortex_upd_cmd.cc b/engine/test/components/test_cortex_upd_cmd.cc index 772889fbd..06eff4a98 100644 --- a/engine/test/components/test_cortex_upd_cmd.cc +++ b/engine/test/components/test_cortex_upd_cmd.cc @@ -1,4 +1,5 @@ -#include "cli/commands/cortex_upd_cmd.h" +#include +#include #include "gtest/gtest.h" namespace { diff --git a/engine/test/components/test_file_manager_config_yaml_utils.cc b/engine/test/components/test_file_manager_config_yaml_utils.cc index f2c8c4075..ccbc92ec8 100644 --- a/engine/test/components/test_file_manager_config_yaml_utils.cc +++ b/engine/test/components/test_file_manager_config_yaml_utils.cc @@ -1,6 +1,7 @@ #include #include #include +#include #include "utils/config_yaml_utils.h" #include "utils/file_manager_utils.h" diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index 4d6f47ebe..af671d9e6 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -1,4 +1,9 @@ #include "config_yaml_utils.h" +#include +#include +#include +#include "utils/logging_utils.h" +#include "yaml-cpp/yaml.h" namespace config_yaml_utils { cpp::result CortexConfigMgr::DumpYamlConfig( @@ -174,4 +179,4 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, } } -} // namespace config_yaml_utils \ No newline at end of file +} // namespace config_yaml_utils diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index aa1b4027e..ffb3a31fa 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -1,13 +1,9 @@ #pragma once -#include -#include -#include #include #include -#include "utils/logging_utils.h" +#include #include "utils/result.hpp" -#include "yaml-cpp/yaml.h" namespace config_yaml_utils { From e4c6a6ff0229155f1b880c77c3b03b510ee1b2ca Mon Sep 17 00:00:00 2001 From: NamH Date: Sun, 8 Dec 2024 23:07:47 +0700 Subject: [PATCH 18/44] fix: message created at wrong value (#1774) --- engine/services/message_service.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/services/message_service.cc b/engine/services/message_service.cc index ddc9e096b..5b871f447 100644 --- a/engine/services/message_service.cc +++ b/engine/services/message_service.cc @@ -11,7 +11,7 @@ cpp::result MessageService::CreateMessage( std::optional metadata) { LOG_TRACE << "CreateMessage for thread " << thread_id; - auto seconds_since_epoch = + uint32_t seconds_since_epoch = std::chrono::duration_cast( std::chrono::system_clock::now().time_since_epoch()) .count(); @@ -33,7 +33,7 @@ cpp::result MessageService::CreateMessage( OpenAi::Message msg; msg.id = msg_id; msg.object = "thread.message"; - msg.created_at = 0; + msg.created_at = seconds_since_epoch; msg.thread_id = thread_id; msg.status = OpenAi::Status::COMPLETED; msg.completed_at = seconds_since_epoch; From 9694ec8c607dad57b75298f6361f8dfed3d00a67 Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 9 Dec 2024 09:32:53 +0700 Subject: [PATCH 19/44] feat: add ssl cert configuration (#1776) --- engine/main.cc | 18 ++++++++++++++++++ engine/utils/config_yaml_utils.cc | 10 ++++++++-- engine/utils/config_yaml_utils.h | 2 ++ engine/utils/file_manager_utils.cc | 4 +++- 4 files changed, 31 insertions(+), 3 deletions(-) diff --git a/engine/main.cc b/engine/main.cc index 894e9d146..93aa3b8e7 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -219,6 +219,24 @@ void RunServer(std::optional port, bool ignore_cout) { resp->addHeader("Access-Control-Allow-Methods", "*"); }); + // ssl + auto ssl_cert_path = config.sslCertPath; + auto ssl_key_path = config.sslKeyPath; + + if (!ssl_cert_path.empty() && !ssl_key_path.empty()) { + CTL_INF("SSL cert path: " << ssl_cert_path); + CTL_INF("SSL key path: " << ssl_key_path); + + if (!std::filesystem::exists(ssl_cert_path) || + !std::filesystem::exists(ssl_key_path)) { + CTL_ERR("SSL cert or key file not exist at specified path! Ignore.."); + return; + } + + drogon::app().setSSLFiles(ssl_cert_path, ssl_key_path); + drogon::app().addListener(config.apiServerHost, 443, true); + } + drogon::app().run(); if (hw_service->ShouldRestart()) { CTL_INF("Restart to update hardware configuration"); diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index af671d9e6..ed6437256 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -47,6 +47,8 @@ cpp::result CortexConfigMgr::DumpYamlConfig( node["noProxy"] = config.noProxy; node["verifyPeerSsl"] = config.verifyPeerSsl; node["verifyHostSsl"] = config.verifyHostSsl; + node["sslCertPath"] = config.sslCertPath; + node["sslKeyPath"] = config.sslKeyPath; out_file << node; out_file.close(); @@ -81,7 +83,7 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, !node["proxyUsername"] || !node["proxyPassword"] || !node["verifyPeerSsl"] || !node["verifyHostSsl"] || !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || - !node["noProxy"]); + !node["sslCertPath"] || !node["sslKeyPath"] || !node["noProxy"]); CortexConfig config = { .logFolderPath = node["logFolderPath"] @@ -164,6 +166,11 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, .verifyHostSsl = node["verifyHostSsl"] ? node["verifyHostSsl"].as() : default_cfg.verifyHostSsl, + .sslCertPath = node["sslCertPath"] + ? node["sslCertPath"].as() + : default_cfg.sslCertPath, + .sslKeyPath = node["sslKeyPath"] ? node["sslKeyPath"].as() + : default_cfg.sslKeyPath, }; if (should_update_config) { l.unlock(); @@ -178,5 +185,4 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, throw; } } - } // namespace config_yaml_utils diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index ffb3a31fa..d36cc48e0 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -55,6 +55,8 @@ struct CortexConfig { bool verifyPeerSsl; bool verifyHostSsl; + std::string sslCertPath; + std::string sslKeyPath; }; class CortexConfigMgr { diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index 11128a275..ca3d0c07b 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -185,6 +185,8 @@ config_yaml_utils::CortexConfig GetDefaultConfig() { .noProxy = config_yaml_utils::kDefaultNoProxy, .verifyPeerSsl = true, .verifyHostSsl = true, + .sslCertPath = "", + .sslKeyPath = "", }; } @@ -369,4 +371,4 @@ std::filesystem::path ToAbsoluteCortexDataPath( const std::filesystem::path& path) { return GetAbsolutePath(GetCortexDataPath(), path); } -} // namespace file_manager_utils \ No newline at end of file +} // namespace file_manager_utils From 0b5b9aa298b7792a7e29a1b07d3941db71f244a1 Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 9 Dec 2024 16:50:52 +0700 Subject: [PATCH 20/44] fix: sort messages by its ulid instead of created_at (#1778) --- engine/repositories/message_fs_repository.cc | 73 +++++++++++--------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index 388409390..422242e3a 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -48,7 +48,14 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, const std::string& before, const std::string& run_id) const { CTL_INF("Listing messages for thread " + thread_id); - auto path = GetMessagePath(thread_id); + + // Early validation + if (limit == 0) { + return std::vector(); + } + if (!after.empty() && !before.empty() && after >= before) { + return cpp::fail("Invalid range: 'after' must be less than 'before'"); + } auto mutex = GrabMutex(thread_id); std::shared_lock lock(*mutex); @@ -60,6 +67,11 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, std::vector messages = std::move(read_result.value()); + if (messages.empty()) { + return messages; + } + + // Filter by run_id if (!run_id.empty()) { messages.erase(std::remove_if(messages.begin(), messages.end(), [&run_id](const OpenAi::Message& msg) { @@ -68,52 +80,52 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, messages.end()); } - std::sort(messages.begin(), messages.end(), - [&order](const OpenAi::Message& a, const OpenAi::Message& b) { - if (order == "desc") { - return a.created_at > b.created_at; - } - return a.created_at < b.created_at; - }); + const bool is_descending = (order == "desc"); + std::sort( + messages.begin(), messages.end(), + [is_descending](const OpenAi::Message& a, const OpenAi::Message& b) { + return is_descending ? (a.id > b.id) : (a.id < b.id); + }); auto start_it = messages.begin(); auto end_it = messages.end(); if (!after.empty()) { - start_it = std::find_if( - messages.begin(), messages.end(), - [&after](const OpenAi::Message& msg) { return msg.id == after; }); - if (start_it != messages.end()) { - ++start_it; // Start from the message after the 'after' message - } else { - start_it = messages.begin(); + start_it = std::lower_bound( + messages.begin(), messages.end(), after, + [is_descending](const OpenAi::Message& msg, const std::string& value) { + return is_descending ? (msg.id > value) : (msg.id < value); + }); + + if (start_it != messages.end() && start_it->id == after) { + ++start_it; } } if (!before.empty()) { - end_it = std::find_if( - messages.begin(), messages.end(), - [&before](const OpenAi::Message& msg) { return msg.id == before; }); + end_it = std::upper_bound( + start_it, messages.end(), before, + [is_descending](const std::string& value, const OpenAi::Message& msg) { + return is_descending ? (value > msg.id) : (value < msg.id); + }); } - std::vector result; - size_t distance = std::distance(start_it, end_it); - size_t limit_size = static_cast(limit); - CTL_INF("Distance: " + std::to_string(distance) + - ", limit_size: " + std::to_string(limit_size)); - result.reserve(distance < limit_size ? distance : limit_size); + const size_t available_messages = std::distance(start_it, end_it); + const size_t result_size = + std::min(static_cast(limit), available_messages); - for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) { - result.push_back(std::move(*it)); - } + CTL_INF("Available messages: " + std::to_string(available_messages) + + ", result size: " + std::to_string(result_size)); + + std::vector result; + result.reserve(result_size); + std::move(start_it, start_it + result_size, std::back_inserter(result)); return result; } cpp::result MessageFsRepository::RetrieveMessage( const std::string& thread_id, const std::string& message_id) const { - auto path = GetMessagePath(thread_id); - auto mutex = GrabMutex(thread_id); std::unique_lock lock(*mutex); @@ -133,8 +145,6 @@ cpp::result MessageFsRepository::RetrieveMessage( cpp::result MessageFsRepository::ModifyMessage( OpenAi::Message& message) { - auto path = GetMessagePath(message.thread_id); - auto mutex = GrabMutex(message.thread_id); std::unique_lock lock(*mutex); @@ -143,6 +153,7 @@ cpp::result MessageFsRepository::ModifyMessage( return cpp::fail(messages.error()); } + auto path = GetMessagePath(message.thread_id); std::ofstream file(path, std::ios::trunc); if (!file) { return cpp::fail("Failed to open file for writing: " + path.string()); From 630073233fd71193f0f8fc39f881b997552bbcbe Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 9 Dec 2024 23:08:36 +0700 Subject: [PATCH 21/44] chore: add backward compatible for thread (#1782) --- engine/common/thread.h | 15 +++++++++++++++ engine/controllers/threads.cc | 10 ++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/engine/common/thread.h b/engine/common/thread.h index 480c0ba78..2bd5d866b 100644 --- a/engine/common/thread.h +++ b/engine/common/thread.h @@ -124,6 +124,21 @@ struct Thread : JsonSerializable { json["object"] = object; json["created_at"] = created_at; + // Deprecated: This is for backward compatibility. Please remove it later. (2-3 releases) to be sure + try { + auto it = metadata.find("title"); + if (it == metadata.end()) { + json["title"] = ""; + } else { + json["title"] = std::get(metadata["title"]); + } + + } catch (const std::bad_variant_access& ex) { + // std::cerr << "Error: value is not a string" << std::endl; + CTL_WRN("Error: value of title is not a string: " << ex.what()); + } + // End deprecated + if (tool_resources) { auto tool_result = tool_resources->ToJson(); if (tool_result.has_error()) { diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index e130dad88..81e14ce5a 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -26,6 +26,7 @@ void Threads::ListThreads( for (auto& msg : res.value()) { if (auto it = msg.ToJson(); it.has_value()) { it->removeMember("assistants"); + it->removeMember("title"); msg_arr.append(it.value()); } else { CTL_WRN("Failed to convert message to json: " + it.error()); @@ -86,8 +87,10 @@ void Threads::CreateThread( resp->setStatusCode(k400BadRequest); callback(resp); } else { - auto resp = - cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + auto json_res = res->ToJson(); + json_res->removeMember("title"); + json_res->removeMember("assistants"); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(json_res.value()); resp->setStatusCode(k200OK); callback(resp); } @@ -116,6 +119,7 @@ void Threads::RetrieveThread( callback(resp); } else { thread_to_json->removeMember("assistants"); + thread_to_json->removeMember("title"); auto resp = cortex_utils::CreateCortexHttpJsonResponse(thread_to_json.value()); resp->setStatusCode(k200OK); @@ -189,6 +193,8 @@ void Threads::ModifyThread( resp->setStatusCode(k400BadRequest); callback(resp); } else { + res->ToJson()->removeMember("title"); + res->ToJson()->removeMember("assistants"); auto resp = cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); resp->setStatusCode(k200OK); From 0fa83b2ea6faf21a6e29a82cd1df2da2ef16cf31 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 10 Dec 2024 09:00:28 +0700 Subject: [PATCH 22/44] feat: prioritize gpus (#1768) * feat: prioritize GPUs * fix: migrate db * fix: add priority * fix: db * fix: more * fix: migration --------- Co-authored-by: vansangpfiev --- engine/controllers/hardware.cc | 2 +- engine/database/hardware.cc | 35 +++-- engine/database/hardware.h | 19 +-- engine/migrations/db_helper.h | 35 +++-- engine/migrations/migration_manager.cc | 15 +- engine/migrations/schema_version.h | 3 +- engine/migrations/v2/migration.h | 210 +++++++++++++++++++++++++ engine/services/hardware_service.cc | 77 ++++++--- 8 files changed, 331 insertions(+), 65 deletions(-) create mode 100644 engine/migrations/v2/migration.h diff --git a/engine/controllers/hardware.cc b/engine/controllers/hardware.cc index 39a109750..8b7884710 100644 --- a/engine/controllers/hardware.cc +++ b/engine/controllers/hardware.cc @@ -38,7 +38,7 @@ void Hardware::Activate( ahc.gpus.push_back(g.asInt()); } } - std::sort(ahc.gpus.begin(), ahc.gpus.end()); + if (!hw_svc_->IsValidConfig(ahc)) { Json::Value ret; ret["message"] = "Invalid GPU index provided."; diff --git a/engine/database/hardware.cc b/engine/database/hardware.cc index ff2eb853a..2ee1db968 100644 --- a/engine/database/hardware.cc +++ b/engine/database/hardware.cc @@ -5,14 +5,15 @@ namespace cortex::db { -Hardwares::Hardwares() : db_(cortex::db::Database::GetInstance().db()) {} +Hardware::Hardware() : db_(cortex::db::Database::GetInstance().db()) {} -Hardwares::Hardwares(SQLite::Database& db) : db_(db) {} +Hardware::Hardware(SQLite::Database& db) : db_(db) {} -Hardwares::~Hardwares() {} + +Hardware::~Hardware() {} cpp::result, std::string> -Hardwares::LoadHardwareList() const { +Hardware::LoadHardwareList() const { try { db_.exec("BEGIN TRANSACTION;"); cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); @@ -20,7 +21,7 @@ Hardwares::LoadHardwareList() const { SQLite::Statement query( db_, "SELECT uuid, type, " - "hardware_id, software_id, activated FROM hardware"); + "hardware_id, software_id, activated, priority FROM hardware"); while (query.executeStep()) { HardwareEntry entry; @@ -29,6 +30,7 @@ Hardwares::LoadHardwareList() const { entry.hardware_id = query.getColumn(2).getInt(); entry.software_id = query.getColumn(3).getInt(); entry.activated = query.getColumn(4).getInt(); + entry.priority = query.getColumn(5).getInt(); entries.push_back(entry); } return entries; @@ -37,19 +39,20 @@ Hardwares::LoadHardwareList() const { return cpp::fail(e.what()); } } -cpp::result Hardwares::AddHardwareEntry( +cpp::result Hardware::AddHardwareEntry( const HardwareEntry& new_entry) { try { SQLite::Statement insert( db_, "INSERT INTO hardware (uuid, type, " - "hardware_id, software_id, activated) VALUES (?, ?, " - "?, ?, ?)"); + "hardware_id, software_id, activated, priority) VALUES (?, ?, " + "?, ?, ?, ?)"); insert.bind(1, new_entry.uuid); insert.bind(2, new_entry.type); insert.bind(3, new_entry.hardware_id); insert.bind(4, new_entry.software_id); insert.bind(5, new_entry.activated); + insert.bind(6, new_entry.priority); insert.exec(); CTL_INF("Inserted: " << new_entry.ToJsonString()); return true; @@ -58,17 +61,19 @@ cpp::result Hardwares::AddHardwareEntry( return cpp::fail(e.what()); } } -cpp::result Hardwares::UpdateHardwareEntry( +cpp::result Hardware::UpdateHardwareEntry( const std::string& id, const HardwareEntry& updated_entry) { try { - SQLite::Statement upd(db_, - "UPDATE hardware " - "SET hardware_id = ?, software_id = ?, activated = ? " - "WHERE uuid = ?"); + SQLite::Statement upd( + db_, + "UPDATE hardware " + "SET hardware_id = ?, software_id = ?, activated = ?, priority = ? " + "WHERE uuid = ?"); upd.bind(1, updated_entry.hardware_id); upd.bind(2, updated_entry.software_id); upd.bind(3, updated_entry.activated); - upd.bind(4, id); + upd.bind(4, updated_entry.priority); + upd.bind(5, id); if (upd.exec() == 1) { CTL_INF("Updated: " << updated_entry.ToJsonString()); return true; @@ -79,7 +84,7 @@ cpp::result Hardwares::UpdateHardwareEntry( } } -cpp::result Hardwares::DeleteHardwareEntry( +cpp::result Hardware::DeleteHardwareEntry( const std::string& id) { try { SQLite::Statement del(db_, "DELETE from hardware WHERE uuid = ?"); diff --git a/engine/database/hardware.h b/engine/database/hardware.h index 0966d58a3..04d0bbda1 100644 --- a/engine/database/hardware.h +++ b/engine/database/hardware.h @@ -4,8 +4,8 @@ #include #include #include -#include "utils/result.hpp" #include "utils/json_helper.h" +#include "utils/result.hpp" namespace cortex::db { struct HardwareEntry { @@ -14,6 +14,7 @@ struct HardwareEntry { int hardware_id; int software_id; bool activated; + int priority; std::string ToJsonString() const { Json::Value root; root["uuid"] = uuid; @@ -21,26 +22,26 @@ struct HardwareEntry { root["hardware_id"] = hardware_id; root["software_id"] = software_id; root["activated"] = activated; + root["priority"] = priority; return json_helper::DumpJsonString(root); } }; -class Hardwares { +class Hardware { private: SQLite::Database& db_; - public: - Hardwares(); - Hardwares(SQLite::Database& db); - ~Hardwares(); + Hardware(); + Hardware(SQLite::Database& db); + ~Hardware(); cpp::result, std::string> LoadHardwareList() const; - cpp::result AddHardwareEntry(const HardwareEntry& new_entry); + cpp::result AddHardwareEntry( + const HardwareEntry& new_entry); cpp::result UpdateHardwareEntry( const std::string& id, const HardwareEntry& updated_entry); - cpp::result DeleteHardwareEntry( - const std::string& id); + cpp::result DeleteHardwareEntry(const std::string& id); }; } // namespace cortex::db \ No newline at end of file diff --git a/engine/migrations/db_helper.h b/engine/migrations/db_helper.h index 0990426bf..867e871ff 100644 --- a/engine/migrations/db_helper.h +++ b/engine/migrations/db_helper.h @@ -4,23 +4,28 @@ namespace cortex::mgr { #include #include -#include #include +#include -inline bool ColumnExists(SQLite::Database& db, const std::string& table_name, const std::string& column_name) { - try { - SQLite::Statement query(db, "SELECT " + column_name + " FROM " + table_name + " LIMIT 0"); - return true; - } catch (std::exception&) { - return false; - } +inline bool ColumnExists(SQLite::Database& db, const std::string& table_name, + const std::string& column_name) { + try { + SQLite::Statement query( + db, "SELECT " + column_name + " FROM " + table_name + " LIMIT 0"); + return true; + } catch (std::exception&) { + return false; + } } -inline void AddColumnIfNotExists(SQLite::Database& db, const std::string& table_name, - const std::string& column_name, const std::string& column_type) { - if (!ColumnExists(db, table_name, column_name)) { - std::string sql = "ALTER TABLE " + table_name + " ADD COLUMN " + column_name + " " + column_type; - db.exec(sql); - } +inline void AddColumnIfNotExists(SQLite::Database& db, + const std::string& table_name, + const std::string& column_name, + const std::string& column_type) { + if (!ColumnExists(db, table_name, column_name)) { + std::string sql = "ALTER TABLE " + table_name + " ADD COLUMN " + + column_name + " " + column_type; + db.exec(sql); + } } -} \ No newline at end of file +} // namespace cortex::mgr diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc index 0e2e41e4e..6936f45a0 100644 --- a/engine/migrations/migration_manager.cc +++ b/engine/migrations/migration_manager.cc @@ -7,7 +7,7 @@ #include "utils/widechar_conv.h" #include "v0/migration.h" #include "v1/migration.h" - +#include "v2/migration.h" namespace cortex::migr { namespace { @@ -141,9 +141,11 @@ cpp::result MigrationManager::DoUpFolderStructure( switch (version) { case 0: return v0::MigrateFolderStructureUp(); - break; case 1: return v1::MigrateFolderStructureUp(); + case 2: + return v2::MigrateFolderStructureUp(); + break; default: @@ -155,9 +157,10 @@ cpp::result MigrationManager::DoDownFolderStructure( switch (version) { case 0: return v0::MigrateFolderStructureDown(); - break; case 1: return v1::MigrateFolderStructureDown(); + case 2: + return v2::MigrateFolderStructureDown(); break; default: @@ -191,9 +194,10 @@ cpp::result MigrationManager::DoUpDB(int version) { switch (version) { case 0: return v0::MigrateDBUp(db_); - break; case 1: return v1::MigrateDBUp(db_); + case 2: + return v2::MigrateDBUp(db_); break; default: @@ -205,9 +209,10 @@ cpp::result MigrationManager::DoDownDB(int version) { switch (version) { case 0: return v0::MigrateDBDown(db_); - break; case 1: return v1::MigrateDBDown(db_); + case 2: + return v2::MigrateDBDown(db_); break; default: diff --git a/engine/migrations/schema_version.h b/engine/migrations/schema_version.h index 1e64110e3..5739040d0 100644 --- a/engine/migrations/schema_version.h +++ b/engine/migrations/schema_version.h @@ -1,4 +1,5 @@ #pragma once //Track the current schema version -#define SCHEMA_VERSION 1 \ No newline at end of file +#define SCHEMA_VERSION 2 + diff --git a/engine/migrations/v2/migration.h b/engine/migrations/v2/migration.h new file mode 100644 index 000000000..54b79f666 --- /dev/null +++ b/engine/migrations/v2/migration.h @@ -0,0 +1,210 @@ +#pragma once +#include +#include +#include +#include "migrations/db_helper.h" +#include "utils/file_manager_utils.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace cortex::migr::v2 { +// Data folder +namespace fmu = file_manager_utils; + +// cortexcpp +// |__ models +// | |__ cortex.so +// | |__ tinyllama +// | |__ gguf +// |__ engines +// | |__ cortex.llamacpp +// | |__ deps +// | |__ windows-amd64-avx +// |__ logs +// +inline cpp::result MigrateFolderStructureUp() { + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "models")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "models"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "engines")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "engines"); + } + + if (!std::filesystem::exists(fmu::GetCortexDataPath() / "logs")) { + std::filesystem::create_directory(fmu::GetCortexDataPath() / "logs"); + } + + return true; +} + +inline cpp::result MigrateFolderStructureDown() { + // CTL_INF("Folder structure already up to date!"); + return true; +} + +// Database +inline cpp::result MigrateDBUp(SQLite::Database& db) { + try { + db.exec( + "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY " + "KEY);"); + + // models + { + // Check if the table exists + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + + if (table_exists) { + // Alter existing table + cortex::mgr::AddColumnIfNotExists(db, "models", "metadata", "TEXT"); + } else { + // Create new table + db.exec( + "CREATE TABLE models (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT," + "metadata TEXT" + ")"); + } + } + + // Check if the table exists + SQLite::Statement hw_query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='hardware'"); + auto hw_table_exists = hw_query.executeStep(); + + if (hw_table_exists) { + // Alter existing table + cortex::mgr::AddColumnIfNotExists(db, "hardware", "priority", "INTEGER"); + } else { + db.exec( + "CREATE TABLE IF NOT EXISTS hardware (" + "uuid TEXT PRIMARY KEY, " + "type TEXT NOT NULL, " + "hardware_id INTEGER NOT NULL, " + "software_id INTEGER NOT NULL, " + "activated INTEGER NOT NULL CHECK (activated IN (0, 1)), " + "priority INTEGER); "); + } + + // engines + db.exec( + "CREATE TABLE IF NOT EXISTS engines (" + "id INTEGER PRIMARY KEY AUTOINCREMENT," + "engine_name TEXT," + "type TEXT," + "api_key TEXT," + "url TEXT," + "version TEXT," + "variant TEXT," + "status TEXT," + "metadata TEXT," + "date_created TEXT DEFAULT CURRENT_TIMESTAMP," + "date_updated TEXT DEFAULT CURRENT_TIMESTAMP," + "UNIQUE(engine_name, variant));"); + + // CTL_INF("Database migration up completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration up failed: " << e.what()); + return cpp::fail(e.what()); + } +}; + +inline cpp::result MigrateDBDown(SQLite::Database& db) { + try { + // models + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='models'"); + auto table_exists = query.executeStep(); + if (table_exists) { + // Create a new table with the old schema + db.exec( + "CREATE TABLE models_old (" + "model_id TEXT PRIMARY KEY," + "author_repo_id TEXT," + "branch_name TEXT," + "path_to_model_yaml TEXT," + "model_alias TEXT," + "model_format TEXT," + "model_source TEXT," + "status TEXT," + "engine TEXT" + ")"); + + // Copy data from the current table to the new table + db.exec( + "INSERT INTO models_old (model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, model_source, " + "status, engine) " + "SELECT model_id, author_repo_id, branch_name, path_to_model_yaml, " + "model_alias, model_format, model_source, status, engine FROM " + "models"); + + // Drop the current table + db.exec("DROP TABLE models"); + + // Rename the new table to the original name + db.exec("ALTER TABLE models_old RENAME TO models"); + } + } + + // hardware + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='hardware'"); + auto table_exists = query.executeStep(); + if (table_exists) { + // Create a new table with the old schema + db.exec( + "CREATE TABLE hardware_old (" + "uuid TEXT PRIMARY KEY, " + "type TEXT NOT NULL, " + "hardware_id INTEGER NOT NULL, " + "software_id INTEGER NOT NULL, " + "activated INTEGER NOT NULL CHECK (activated IN (0, 1))" + ")"); + + // Copy data from the current table to the new table + db.exec( + "INSERT INTO hardware_old (uuid, type, hardware_id, " + "software_id, activated) " + "SELECT uuid, type, hardware_id, software_id, " + "activated FROM hardware"); + + // Drop the current table + db.exec("DROP TABLE hardware"); + + // Rename the new table to the original name + db.exec("ALTER TABLE hardware_old RENAME TO hardware"); + } + } + + // engines + { + // do nothing + } + // CTL_INF("Migration down completed successfully."); + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration down failed: " << e.what()); + return cpp::fail(e.what()); + } +} + +}; // namespace cortex::migr::v2 \ No newline at end of file diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index 681ca7578..25be78873 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -34,7 +34,7 @@ bool TryConnectToServer(const std::string& host, int port) { HardwareInfo HardwareService::GetHardwareInfo() { // append active state - cortex::db::Hardwares hw_db; + cortex::db::Hardware hw_db; auto gpus = cortex::hw::GetGPUInfo(); auto res = hw_db.LoadHardwareList(); if (res.has_value()) { @@ -191,31 +191,61 @@ bool HardwareService::Restart(const std::string& host, int port) { return true; } +// GPU identifiers are given as integer indices or as UUID strings. GPU UUID strings +// should follow the same format as given by nvidia-smi, such as GPU-8932f937-d72c-4106-c12f-20bd9faed9f6. +// However, for convenience, abbreviated forms are allowed; simply specify enough digits +// from the beginning of the GPU UUID to uniquely identify that GPU in the target system. +// For example, CUDA_VISIBLE_DEVICES=GPU-8932f937 may be a valid way to refer to the above GPU UUID, +// assuming no other GPU in the system shares this prefix. Only the devices whose index +// is present in the sequence are visible to CUDA applications and they are enumerated +// in the order of the sequence. If one of the indices is invalid, only the devices whose +// index precedes the invalid index are visible to CUDA applications. For example, setting +// CUDA_VISIBLE_DEVICES to 2,1 causes device 0 to be invisible and device 2 to be enumerated +// before device 1. Setting CUDA_VISIBLE_DEVICES to 0,2,-1,1 causes devices 0 and 2 to be +// visible and device 1 to be invisible. MIG format starts with MIG keyword and GPU UUID +// should follow the same format as given by nvidia-smi. +// For example, MIG-GPU-8932f937-d72c-4106-c12f-20bd9faed9f6/1/2. +// Only single MIG instance enumeration is supported. bool HardwareService::SetActivateHardwareConfig( const cortex::hw::ActivateHardwareConfig& ahc) { // Note: need to map software_id and hardware_id // Update to db - cortex::db::Hardwares hw_db; + cortex::db::Hardware hw_db; + // copy all gpu information to new vector + auto ahc_gpus = ahc.gpus; auto activate = [&ahc](int software_id) { return std::count(ahc.gpus.begin(), ahc.gpus.end(), software_id) > 0; }; + auto priority = [&ahc](int software_id) -> int { + for (size_t i = 0; i < ahc.gpus.size(); i++) { + if (ahc.gpus[i] == software_id) + return i; + break; + } + return INT_MAX; + }; + auto res = hw_db.LoadHardwareList(); if (res.has_value()) { bool need_update = false; - std::vector activated_ids; + std::vector> activated_ids; // Check if need to update for (auto const& e : res.value()) { if (e.activated) { - activated_ids.push_back(e.software_id); + activated_ids.push_back(std::pair(e.software_id, e.priority)); } } std::sort(activated_ids.begin(), activated_ids.end()); - if (ahc.gpus.size() != activated_ids.size()) { + std::sort(ahc_gpus.begin(), ahc_gpus.end()); + if (ahc_gpus.size() != activated_ids.size()) { need_update = true; } else { - for (size_t i = 0; i < ahc.gpus.size(); i++) { - if (ahc.gpus[i] != activated_ids[i]) + for (size_t i = 0; i < ahc_gpus.size(); i++) { + // if activated id or priority changes + if (ahc_gpus[i] != activated_ids[i].first || + i != activated_ids[i].second) need_update = true; + break; } } @@ -227,6 +257,7 @@ bool HardwareService::SetActivateHardwareConfig( // Need to update, proceed for (auto& e : res.value()) { e.activated = activate(e.software_id); + e.priority = priority(e.software_id); auto res = hw_db.UpdateHardwareEntry(e.uuid, e); if (res.has_error()) { CTL_WRN(res.error()); @@ -240,14 +271,14 @@ bool HardwareService::SetActivateHardwareConfig( void HardwareService::UpdateHardwareInfos() { using HwEntry = cortex::db::HardwareEntry; auto gpus = cortex::hw::GetGPUInfo(); - cortex::db::Hardwares hw_db; + cortex::db::Hardware hw_db; auto b = hw_db.LoadHardwareList(); - std::vector activated_gpu_bf; + std::vector> activated_gpu_bf; std::string debug_b; for (auto const& he : b.value()) { if (he.type == "gpu" && he.activated) { debug_b += std::to_string(he.software_id) + " "; - activated_gpu_bf.push_back(he.software_id); + activated_gpu_bf.push_back(std::pair(he.software_id, he.priority)); } } CTL_INF("Activated GPUs before: " << debug_b); @@ -258,7 +289,8 @@ void HardwareService::UpdateHardwareInfos() { .type = "gpu", .hardware_id = std::stoi(gpu.id), .software_id = std::stoi(gpu.id), - .activated = true}); + .activated = true, + .priority = INT_MAX}); if (res.has_error()) { CTL_WRN(res.error()); } @@ -266,24 +298,26 @@ void HardwareService::UpdateHardwareInfos() { auto a = hw_db.LoadHardwareList(); std::vector a_gpu; - std::vector activated_gpu_af; + std::vector> activated_gpu_af; std::string debug_a; for (auto const& he : a.value()) { if (he.type == "gpu" && he.activated) { debug_a += std::to_string(he.software_id) + " "; - activated_gpu_af.push_back(he.software_id); + activated_gpu_af.push_back(std::pair(he.software_id, he.priority)); } } CTL_INF("Activated GPUs after: " << debug_a); // if hardware list changes, need to restart - std::sort(activated_gpu_bf.begin(), activated_gpu_bf.end()); - std::sort(activated_gpu_af.begin(), activated_gpu_af.end()); + std::sort(activated_gpu_bf.begin(), activated_gpu_bf.end(), + [](auto& p1, auto& p2) { return p1.second < p2.second; }); + std::sort(activated_gpu_af.begin(), activated_gpu_af.end(), + [](auto& p1, auto& p2) { return p1.second < p2.second; }); bool need_restart = false; if (activated_gpu_bf.size() != activated_gpu_af.size()) { need_restart = true; } else { for (size_t i = 0; i < activated_gpu_bf.size(); i++) { - if (activated_gpu_bf[i] != activated_gpu_af[i]) { + if (activated_gpu_bf[i].first != activated_gpu_af[i].first) { need_restart = true; break; } @@ -291,7 +325,8 @@ void HardwareService::UpdateHardwareInfos() { } #if defined(_WIN32) || defined(_WIN64) || defined(__linux__) - if (!gpus.empty()) { + bool has_deactivated_gpu = a.value().size() != activated_gpu_af.size(); + if (!gpus.empty() && has_deactivated_gpu) { const char* value = std::getenv("CUDA_VISIBLE_DEVICES"); if (value) { LOG_INFO << "CUDA_VISIBLE_DEVICES: " << value; @@ -303,7 +338,11 @@ void HardwareService::UpdateHardwareInfos() { if (need_restart) { CTL_INF("Need restart"); - ahc_ = {.gpus = activated_gpu_af}; + std::vector gpus; + for (auto const& p : activated_gpu_af) { + gpus.push_back(p.first); + } + ahc_ = {.gpus = gpus}; } } @@ -311,7 +350,7 @@ bool HardwareService::IsValidConfig( const cortex::hw::ActivateHardwareConfig& ahc) { if (ahc.gpus.empty()) return true; - cortex::db::Hardwares hw_db; + cortex::db::Hardware hw_db; auto is_valid = [&ahc](int software_id) { return std::count(ahc.gpus.begin(), ahc.gpus.end(), software_id) > 0; }; From 43e740da5a07d1fdf240929f81541e1898df3f67 Mon Sep 17 00:00:00 2001 From: NamH Date: Tue, 10 Dec 2024 09:55:41 +0700 Subject: [PATCH 23/44] Update Engine interface (#1759) * chore: add document * feat: update engine interface --- docs/docs/engines/engine-extension.mdx | 235 ++++++++++++++++------ engine/cli/commands/server_start_cmd.cc | 22 +-- engine/controllers/engines.cc | 5 +- engine/cortex-common/EngineI.h | 30 +++ engine/services/engine_service.cc | 246 +++++++++++------------- engine/services/engine_service.h | 12 +- engine/services/hardware_service.cc | 2 +- engine/utils/config_yaml_utils.cc | 1 + engine/utils/config_yaml_utils.h | 5 +- engine/utils/file_manager_utils.cc | 1 + 10 files changed, 341 insertions(+), 218 deletions(-) diff --git a/docs/docs/engines/engine-extension.mdx b/docs/docs/engines/engine-extension.mdx index 8a62cd813..6bb966f60 100644 --- a/docs/docs/engines/engine-extension.mdx +++ b/docs/docs/engines/engine-extension.mdx @@ -1,89 +1,210 @@ --- -title: Building Engine Extensions +title: Adding a Third-Party Engine to Cortex description: Cortex supports Engine Extensions to integrate both :ocal inference engines, and Remote APIs. --- -:::info -🚧 Cortex is currently under development, and this page is a stub for future development. -::: - - +We welcome suggestions and contributions to improve this integration process. Please feel free to submit issues or pull requests through our repository. diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index ba4f7bd82..3d52f3d25 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -1,9 +1,12 @@ #include "server_start_cmd.h" #include "commands/cortex_upd_cmd.h" +#include "services/engine_service.h" #include "utils/cortex_utils.h" -#include "utils/engine_constants.h" #include "utils/file_manager_utils.h" + +#if defined(_WIN32) || defined(_WIN64) #include "utils/widechar_conv.h" +#endif namespace commands { @@ -108,22 +111,9 @@ bool ServerStartCmd::Exec(const std::string& host, int port, std::cerr << "Could not start server: " << std::endl; return false; } else if (pid == 0) { - // No need to configure LD_LIBRARY_PATH for macOS -#if !defined(__APPLE__) || !defined(__MACH__) - const char* name = "LD_LIBRARY_PATH"; - auto data = getenv(name); - std::string v; - if (auto g = getenv(name); g) { - v += g; - } - CTL_INF("LD_LIBRARY_PATH: " << v); - auto llamacpp_path = file_manager_utils::GetCudaToolkitPath(kLlamaRepo); - auto trt_path = file_manager_utils::GetCudaToolkitPath(kTrtLlmRepo); + // Some engines requires to add lib search path before process being created + EngineService().RegisterEngineLibPath(); - auto new_v = trt_path.string() + ":" + llamacpp_path.string() + ":" + v; - setenv(name, new_v.c_str(), true); - CTL_INF("LD_LIBRARY_PATH: " << getenv(name)); -#endif std::string p = cortex_utils::GetCurrentPath() + "/" + exe; execl(p.c_str(), exe.c_str(), "--start-server", "--config_file_path", get_config_file_path().c_str(), "--data_folder_path", diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 3d3c0c037..1d0223d9a 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -23,10 +23,9 @@ std::string NormalizeEngine(const std::string& engine) { void Engines::ListEngine( const HttpRequestPtr& req, std::function&& callback) const { - std::vector supported_engines{kLlamaEngine, kOnnxEngine, - kTrtLlmEngine}; Json::Value ret; - for (const auto& engine : supported_engines) { + auto engine_names = engine_service_->GetSupportedEngineNames().value(); + for (const auto& engine : engine_names) { auto installed_engines = engine_service_->GetInstalledEngineVariants(engine); if (installed_engines.has_error()) { diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index 51e19c124..11866a708 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -7,8 +8,37 @@ #include "trantor/utils/Logger.h" class EngineI { public: + struct RegisterLibraryOption { + std::vector paths; + }; + + struct EngineLoadOption { + // engine + std::filesystem::path engine_path; + std::filesystem::path cuda_path; + bool custom_engine_path; + + // logging + std::filesystem::path log_path; + int max_log_lines; + trantor::Logger::LogLevel log_level; + }; + + struct EngineUnloadOption { + bool unload_dll; + }; + virtual ~EngineI() {} + /** + * Being called before starting process to register dependencies search paths. + */ + virtual void RegisterLibraryPath(RegisterLibraryOption opts) = 0; + + virtual void Load(EngineLoadOption opts) = 0; + + virtual void Unload(EngineUnloadOption opts) = 0; + // cortex.llamacpp interface virtual void HandleChatCompletion( std::shared_ptr json_body, diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index fe5317c7d..4f2122f6b 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "algorithm" #include "database/engines.h" @@ -17,6 +18,7 @@ #include "utils/semantic_version_utils.h" #include "utils/system_info_utils.h" #include "utils/url_parser.h" + namespace { std::string GetSuitableCudaVersion(const std::string& engine, const std::string& cuda_driver_version) { @@ -701,6 +703,87 @@ cpp::result EngineService::LoadEngine( CTL_INF("Loading engine: " << ne); + auto engine_dir_path_res = GetEngineDirPath(ne); + if (engine_dir_path_res.has_error()) { + return cpp::fail(engine_dir_path_res.error()); + } + auto engine_dir_path = engine_dir_path_res.value().first; + auto custom_engine_path = engine_dir_path_res.value().second; + + try { + auto dylib = + std::make_unique(engine_dir_path.string(), "engine"); + + auto config = file_manager_utils::GetCortexConfig(); + + auto log_path = + std::filesystem::path(config.logFolderPath) / + std::filesystem::path( + config.logLlamaCppPath); // for now seems like we use same log path + + // init + auto func = dylib->get_function("get_engine"); + auto engine_obj = func(); + auto load_opts = EngineI::EngineLoadOption{ + .engine_path = engine_dir_path, + .cuda_path = file_manager_utils::GetCudaToolkitPath(ne), + .custom_engine_path = custom_engine_path, + .log_path = log_path, + .max_log_lines = config.maxLogLines, + .log_level = logging_utils_helper::global_log_level, + }; + engine_obj->Load(load_opts); + + engines_[ne].engine = engine_obj; + engines_[ne].dl = std::move(dylib); + + CTL_DBG("Engine loaded: " << ne); + return {}; + } catch (const cortex_cpp::dylib::load_error& e) { + CTL_ERR("Could not load engine: " << e.what()); + engines_.erase(ne); + return cpp::fail("Could not load engine " + ne + ": " + e.what()); + } +} + +void EngineService::RegisterEngineLibPath() { + auto engine_names = GetSupportedEngineNames().value(); + for (const auto& engine : engine_names) { + auto ne = NormalizeEngine(engine); + try { + auto engine_dir_path_res = GetEngineDirPath(engine); + if (engine_dir_path_res.has_error()) { + CTL_ERR( + "Could not get engine dir path: " << engine_dir_path_res.error()); + continue; + } + auto engine_dir_path = engine_dir_path_res.value().first; + auto custom_engine_path = engine_dir_path_res.value().second; + + auto dylib = std::make_unique(engine_dir_path.string(), + "engine"); + + auto cuda_path = file_manager_utils::GetCudaToolkitPath(ne); + // init + auto func = dylib->get_function("get_engine"); + auto engine = func(); + std::vector paths{}; + auto register_opts = EngineI::RegisterLibraryOption{ + .paths = paths, + }; + engine->RegisterLibraryPath(register_opts); + delete engine; + CTL_DBG("Register lib path for: " << engine); + } catch (const std::exception& e) { + CTL_WRN("Failed to registering engine lib path: " << e.what()); + } + } +} + +cpp::result, std::string> +EngineService::GetEngineDirPath(const std::string& engine_name) { + auto ne = NormalizeEngine(engine_name); + auto selected_engine_variant = GetDefaultEngineVariant(ne); if (selected_engine_variant.has_error()) { @@ -715,6 +798,7 @@ cpp::result EngineService::LoadEngine( auto user_defined_engine_path = getenv("ENGINE_PATH"); #endif + auto custom_engine_path = user_defined_engine_path != nullptr; CTL_DBG("user defined engine path: " << user_defined_engine_path); const std::filesystem::path engine_dir_path = [&] { if (user_defined_engine_path != nullptr) { @@ -728,157 +812,38 @@ cpp::result EngineService::LoadEngine( } }(); - CTL_DBG("Engine path: " << engine_dir_path.string()); - if (!std::filesystem::exists(engine_dir_path)) { CTL_ERR("Directory " + engine_dir_path.string() + " is not exist!"); return cpp::fail("Directory " + engine_dir_path.string() + " is not exist!"); } - CTL_INF("Engine path: " << engine_dir_path.string()); - - try { -#if defined(_WIN32) - // TODO(?) If we only allow to load an engine at a time, the logic is simpler. - // We would like to support running multiple engines at the same time. Therefore, - // the adding/removing dll directory logic is quite complicated: - // 1. If llamacpp is loaded and new requested engine is tensorrt-llm: - // Unload the llamacpp dll directory then load the tensorrt-llm - // 2. If tensorrt-llm is loaded and new requested engine is llamacpp: - // Do nothing, llamacpp can re-use tensorrt-llm dependencies (need to be tested careful) - // 3. Add dll directory if met other conditions - - auto add_dll = [this](const std::string& e_type, - const std::filesystem::path& p) { - if (auto cookie = AddDllDirectory(p.c_str()); cookie != 0) { - CTL_DBG("Added dll directory: " << p.string()); - engines_[e_type].cookie = cookie; - } else { - CTL_WRN("Could not add dll directory: " << p.string()); - } - - auto cuda_path = file_manager_utils::GetCudaToolkitPath(e_type); - if (auto cuda_cookie = AddDllDirectory(cuda_path.c_str()); - cuda_cookie != 0) { - CTL_DBG("Added cuda dll directory: " << p.string()); - engines_[e_type].cuda_cookie = cuda_cookie; - } else { - CTL_WRN("Could not add cuda dll directory: " << p.string()); - } - }; - -#if defined(_WIN32) - if (bool should_use_dll_search_path = !(_wgetenv(L"ENGINE_PATH")); -#else - if (bool should_use_dll_search_path = !(getenv("ENGINE_PATH")); -#endif - should_use_dll_search_path) { - if (IsEngineLoaded(kLlamaRepo) && ne == kTrtLlmRepo && - should_use_dll_search_path) { - - { - - // Remove llamacpp dll directory - if (!RemoveDllDirectory(engines_[kLlamaRepo].cookie)) { - CTL_WRN("Could not remove dll directory: " << kLlamaRepo); - } else { - CTL_DBG("Removed dll directory: " << kLlamaRepo); - } - if (!RemoveDllDirectory(engines_[kLlamaRepo].cuda_cookie)) { - CTL_WRN("Could not remove cuda dll directory: " << kLlamaRepo); - } else { - CTL_DBG("Removed cuda dll directory: " << kLlamaRepo); - } - } - - add_dll(ne, engine_dir_path); - } else if (IsEngineLoaded(kTrtLlmRepo) && ne == kLlamaRepo) { - // Do nothing - } else { - add_dll(ne, engine_dir_path); - } - } -#endif - engines_[ne].dl = - std::make_unique(engine_dir_path.string(), "engine"); -#if defined(__linux__) - const char* name = "LD_LIBRARY_PATH"; - auto data = getenv(name); - std::string v; - if (auto g = getenv(name); g) { - v += g; - } - CTL_INF("LD_LIBRARY_PATH: " << v); - auto llamacpp_path = file_manager_utils::GetCudaToolkitPath(kLlamaRepo); - CTL_INF("llamacpp_path: " << llamacpp_path); - // tensorrt is not supported for now - // auto trt_path = file_manager_utils::GetCudaToolkitPath(kTrtLlmRepo); - - auto new_v = llamacpp_path.string() + ":" + v; - setenv(name, new_v.c_str(), true); - CTL_INF("LD_LIBRARY_PATH: " << getenv(name)); -#endif - - } catch (const cortex_cpp::dylib::load_error& e) { - CTL_ERR("Could not load engine: " << e.what()); - engines_.erase(ne); - return cpp::fail("Could not load engine " + ne + ": " + e.what()); - } - - auto func = engines_[ne].dl->get_function("get_engine"); - engines_[ne].engine = func(); - - auto& en = std::get(engines_[ne].engine); - if (ne == kLlamaRepo) { //fix for llamacpp engine first - auto config = file_manager_utils::GetCortexConfig(); - if (en->IsSupported("SetFileLogger")) { - en->SetFileLogger(config.maxLogLines, - (std::filesystem::path(config.logFolderPath) / - std::filesystem::path(config.logLlamaCppPath)) - .string()); - } else { - CTL_WRN("Method SetFileLogger is not supported yet"); - } - if (en->IsSupported("SetLogLevel")) { - en->SetLogLevel(logging_utils_helper::global_log_level); - } else { - CTL_WRN("Method SetLogLevel is not supported yet"); - } - } - CTL_DBG("loaded engine: " << ne); - return {}; + CTL_INF("Engine path: " << engine_dir_path.string() + << ", custom_engine_path: " << custom_engine_path); + return std::make_pair(engine_dir_path, custom_engine_path); } cpp::result EngineService::UnloadEngine( const std::string& engine) { auto ne = NormalizeEngine(engine); std::lock_guard lock(engines_mutex_); - { - if (!IsEngineLoaded(ne)) { - return cpp::fail("Engine " + ne + " is not loaded yet!"); - } - if (std::holds_alternative(engines_[ne].engine)) { - delete std::get(engines_[ne].engine); - } else { - delete std::get(engines_[ne].engine); - } - -#if defined(_WIN32) - if (!RemoveDllDirectory(engines_[ne].cookie)) { - CTL_WRN("Could not remove dll directory: " << ne); - } else { - CTL_DBG("Removed dll directory: " << ne); - } - if (!RemoveDllDirectory(engines_[ne].cuda_cookie)) { - CTL_WRN("Could not remove cuda dll directory: " << ne); - } else { - CTL_DBG("Removed cuda dll directory: " << ne); - } -#endif + if (!IsEngineLoaded(ne)) { + return cpp::fail("Engine " + ne + " is not loaded yet!"); + } + if (std::holds_alternative(engines_[ne].engine)) { + LOG_INFO << "Unloading engine " << ne; + auto* e = std::get(engines_[ne].engine); + auto unload_opts = EngineI::EngineUnloadOption{ + .unload_dll = true, + }; + e->Unload(unload_opts); + delete e; engines_.erase(ne); + } else { + delete std::get(engines_[ne].engine); } - CTL_DBG("Unloaded engine " + ne); + + CTL_DBG("Engine unloaded: " + ne); return {}; } @@ -1097,4 +1062,9 @@ cpp::result EngineService::GetRemoteModels( } else { return res; } -} \ No newline at end of file +} + +cpp::result, std::string> +EngineService::GetSupportedEngineNames() { + return file_manager_utils::GetCortexConfig().supportedEngines; +} diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index ab274825d..8299655f2 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -13,7 +13,6 @@ #include "cortex-common/cortexpythoni.h" #include "cortex-common/remote_enginei.h" #include "database/engines.h" -#include "extensions/remote-engine/remote_engine.h" #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" @@ -75,6 +74,9 @@ class EngineService : public EngineServiceI { .cuda_driver_version = system_info_utils::GetDriverAndCudaVersion().second} {} + // just for initialize supported engines + EngineService() {}; + std::vector GetEngineInfoList() const; /** @@ -148,6 +150,9 @@ class EngineService : public EngineServiceI { cpp::result GetRemoteModels( const std::string& engine_name); + cpp::result, std::string> GetSupportedEngineNames(); + + void RegisterEngineLibPath(); private: bool IsEngineLoaded(const std::string& engine); @@ -162,7 +167,10 @@ class EngineService : public EngineServiceI { std::string GetMatchedVariant(const std::string& engine, const std::vector& variants); + cpp::result, std::string> + GetEngineDirPath(const std::string& engine_name); + cpp::result IsEngineVariantReady( const std::string& engine, const std::string& version, const std::string& variant); -}; \ No newline at end of file +}; diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index 25be78873..97ddacb97 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -5,11 +5,11 @@ #if defined(_WIN32) || defined(_WIN64) #include #include +#include "utils/widechar_conv.h" #endif #include "cli/commands/cortex_upd_cmd.h" #include "database/hardware.h" #include "utils/cortex_utils.h" -#include "utils/widechar_conv.h" namespace services { diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index ed6437256..c7a696df4 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -49,6 +49,7 @@ cpp::result CortexConfigMgr::DumpYamlConfig( node["verifyHostSsl"] = config.verifyHostSsl; node["sslCertPath"] = config.sslCertPath; node["sslKeyPath"] = config.sslKeyPath; + node["supportedEngines"] = config.supportedEngines; out_file << node; out_file.close(); diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index d36cc48e0..f9925ea86 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -3,6 +3,7 @@ #include #include #include +#include "utils/engine_constants.h" #include "utils/result.hpp" namespace config_yaml_utils { @@ -18,6 +19,8 @@ constexpr const auto kDefaultCorsEnabled = true; const std::vector kDefaultEnabledOrigins{ "http://localhost:39281", "http://127.0.0.1:39281", "http://0.0.0.0:39281"}; constexpr const auto kDefaultNoProxy = "example.com,::1,localhost,127.0.0.1"; +const std::vector kDefaultSupportedEngines{ + kLlamaEngine, kOnnxEngine, kTrtLlmEngine}; struct CortexConfig { std::string logFolderPath; @@ -57,6 +60,7 @@ struct CortexConfig { bool verifyHostSsl; std::string sslCertPath; std::string sslKeyPath; + std::vector supportedEngines; }; class CortexConfigMgr { @@ -80,5 +84,4 @@ class CortexConfigMgr { CortexConfig FromYaml(const std::string& path, const CortexConfig& default_cfg); }; - } // namespace config_yaml_utils diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index ca3d0c07b..338abadac 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -187,6 +187,7 @@ config_yaml_utils::CortexConfig GetDefaultConfig() { .verifyHostSsl = true, .sslCertPath = "", .sslKeyPath = "", + .supportedEngines = config_yaml_utils::kDefaultSupportedEngines, }; } From 4a839b4d14f8c51d1e95598ea552ecc8bdfd0394 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 10 Dec 2024 19:43:53 +0700 Subject: [PATCH 24/44] fix: stop inflight chat completion (#1765) * fix: stop inflight chat completion * chore: bypass docker e2e test * fix: comments --------- Co-authored-by: vansangpfiev --- engine/controllers/server.cc | 22 ++++- engine/controllers/server.h | 4 +- engine/cortex-common/EngineI.h | 3 +- engine/e2e-test/test_api_docker.py | 67 +++++++-------- engine/services/inference_service.cc | 119 ++++++++++++++------------- engine/services/inference_service.h | 5 +- 6 files changed, 124 insertions(+), 96 deletions(-) diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 4bec96f76..a9920e8aa 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -3,6 +3,7 @@ #include "trantor/utils/Logger.h" #include "utils/cortex_utils.h" #include "utils/function_calling/common.h" +#include "utils/http_util.h" using namespace inferences; @@ -27,6 +28,15 @@ void server::ChatCompletion( LOG_DEBUG << "Start chat completion"; auto json_body = req->getJsonObject(); bool is_stream = (*json_body).get("stream", false).asBool(); + auto model_id = (*json_body).get("model", "invalid_model").asString(); + auto engine_type = [this, &json_body]() -> std::string { + if (!inference_svc_->HasFieldInReq(json_body, "engine")) { + return kLlamaRepo; + } else { + return (*(json_body)).get("engine", kLlamaRepo).asString(); + } + }(); + LOG_DEBUG << "request body: " << json_body->toStyledString(); auto q = std::make_shared(); auto ir = inference_svc_->HandleChatCompletion(q, json_body); @@ -40,7 +50,7 @@ void server::ChatCompletion( } LOG_DEBUG << "Wait to chat completion responses"; if (is_stream) { - ProcessStreamRes(std::move(callback), q); + ProcessStreamRes(std::move(callback), q, engine_type, model_id); } else { ProcessNonStreamRes(std::move(callback), *q); } @@ -121,12 +131,16 @@ void server::LoadModel(const HttpRequestPtr& req, } void server::ProcessStreamRes(std::function cb, - std::shared_ptr q) { + std::shared_ptr q, + const std::string& engine_type, + const std::string& model_id) { auto err_or_done = std::make_shared(false); - auto chunked_content_provider = - [q, err_or_done](char* buf, std::size_t buf_size) -> std::size_t { + auto chunked_content_provider = [this, q, err_or_done, engine_type, model_id]( + char* buf, + std::size_t buf_size) -> std::size_t { if (buf == nullptr) { LOG_TRACE << "Buf is null"; + inference_svc_->StopInferencing(engine_type, model_id); return 0; } diff --git a/engine/controllers/server.h b/engine/controllers/server.h index 5d6b8ded4..22ea86c30 100644 --- a/engine/controllers/server.h +++ b/engine/controllers/server.h @@ -72,7 +72,9 @@ class server : public drogon::HttpController, private: void ProcessStreamRes(std::function cb, - std::shared_ptr q); + std::shared_ptr q, + const std::string& engine_type, + const std::string& model_id); void ProcessNonStreamRes(std::function cb, services::SyncQueue& q); diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index 11866a708..b456cb109 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -68,5 +68,6 @@ class EngineI { const std::string& log_path) = 0; virtual void SetLogLevel(trantor::Logger::LogLevel logLevel) = 0; - virtual Json::Value GetRemoteModels() = 0; + // Stop inflight chat completion in stream mode + virtual void StopInferencing(const std::string& model_id) = 0; }; diff --git a/engine/e2e-test/test_api_docker.py b/engine/e2e-test/test_api_docker.py index 6856e05f4..b46b1f782 100644 --- a/engine/e2e-test/test_api_docker.py +++ b/engine/e2e-test/test_api_docker.py @@ -40,38 +40,39 @@ async def test_models_on_cortexso_hub(self, model_url): assert response.status_code == 200 models = [i["id"] for i in response.json()["data"]] assert model_url in models, f"Model not found in list: {model_url}" + + # TODO(sang) bypass for now. Re-enable when we publish new stable version for llama-cpp engine + # print("Start the model") + # # Start the model + # response = requests.post( + # "http://localhost:3928/v1/models/start", json=json_body + # ) + # print(response.json()) + # assert response.status_code == 200, f"status_code: {response.status_code}" - print("Start the model") - # Start the model - response = requests.post( - "http://localhost:3928/v1/models/start", json=json_body - ) - print(response.json()) - assert response.status_code == 200, f"status_code: {response.status_code}" - - print("Send an inference request") - # Send an inference request - inference_json_body = { - "frequency_penalty": 0.2, - "max_tokens": 4096, - "messages": [{"content": "", "role": "user"}], - "model": model_url, - "presence_penalty": 0.6, - "stop": ["End"], - "stream": False, - "temperature": 0.8, - "top_p": 0.95, - } - response = requests.post( - "http://localhost:3928/v1/chat/completions", - json=inference_json_body, - headers={"Content-Type": "application/json"}, - ) - assert ( - response.status_code == 200 - ), f"status_code: {response.status_code} response: {response.json()}" + # print("Send an inference request") + # # Send an inference request + # inference_json_body = { + # "frequency_penalty": 0.2, + # "max_tokens": 4096, + # "messages": [{"content": "", "role": "user"}], + # "model": model_url, + # "presence_penalty": 0.6, + # "stop": ["End"], + # "stream": False, + # "temperature": 0.8, + # "top_p": 0.95, + # } + # response = requests.post( + # "http://localhost:3928/v1/chat/completions", + # json=inference_json_body, + # headers={"Content-Type": "application/json"}, + # ) + # assert ( + # response.status_code == 200 + # ), f"status_code: {response.status_code} response: {response.json()}" - print("Stop the model") - # Stop the model - response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) - assert response.status_code == 200, f"status_code: {response.status_code}" + # print("Stop the model") + # # Stop the model + # response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) + # assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index ace7e675f..91cb277dc 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -24,24 +24,18 @@ cpp::result InferenceService::HandleChatCompletion( return cpp::fail(std::make_pair(stt, res)); } + auto cb = [q, tool_choice](Json::Value status, Json::Value res) { + if (!tool_choice.isNull()) { + res["tool_choice"] = tool_choice; + } + q->push(std::make_pair(status, res)); + }; if (std::holds_alternative(engine_result.value())) { std::get(engine_result.value()) - ->HandleChatCompletion( - json_body, [q, tool_choice](Json::Value status, Json::Value res) { - if (!tool_choice.isNull()) { - res["tool_choice"] = tool_choice; - } - q->push(std::make_pair(status, res)); - }); + ->HandleChatCompletion(json_body, std::move(cb)); } else { std::get(engine_result.value()) - ->HandleChatCompletion( - json_body, [q, tool_choice](Json::Value status, Json::Value res) { - if (!tool_choice.isNull()) { - res["tool_choice"] = tool_choice; - } - q->push(std::make_pair(status, res)); - }); + ->HandleChatCompletion(json_body, std::move(cb)); } return {}; @@ -66,16 +60,15 @@ cpp::result InferenceService::HandleEmbedding( return cpp::fail(std::make_pair(stt, res)); } + auto cb = [q](Json::Value status, Json::Value res) { + q->push(std::make_pair(status, res)); + }; if (std::holds_alternative(engine_result.value())) { std::get(engine_result.value()) - ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + ->HandleEmbedding(json_body, std::move(cb)); } else { std::get(engine_result.value()) - ->HandleEmbedding(json_body, [q](Json::Value status, Json::Value res) { - q->push(std::make_pair(status, res)); - }); + ->HandleEmbedding(json_body, std::move(cb)); } return {}; } @@ -104,18 +97,16 @@ InferResult InferenceService::LoadModel( // might need mutex here auto engine_result = engine_service_->GetLoadedEngine(engine_type); + auto cb = [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }; if (std::holds_alternative(engine_result.value())) { std::get(engine_result.value()) - ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->LoadModel(json_body, std::move(cb)); } else { std::get(engine_result.value()) - ->LoadModel(json_body, [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->LoadModel(json_body, std::move(cb)); } return std::make_pair(stt, r); } @@ -139,20 +130,16 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name, json_body["model"] = model_id; LOG_TRACE << "Start unload model"; + auto cb = [&r, &stt](Json::Value status, Json::Value res) { + stt = status; + r = res; + }; if (std::holds_alternative(engine_result.value())) { std::get(engine_result.value()) - ->UnloadModel(std::make_shared(json_body), - [&r, &stt](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->UnloadModel(std::make_shared(json_body), std::move(cb)); } else { std::get(engine_result.value()) - ->UnloadModel(std::make_shared(json_body), - [&r, &stt](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->UnloadModel(std::make_shared(json_body), std::move(cb)); } return std::make_pair(stt, r); @@ -181,20 +168,16 @@ InferResult InferenceService::GetModelStatus( LOG_TRACE << "Start to get model status"; + auto cb = [&stt, &r](Json::Value status, Json::Value res) { + stt = status; + r = res; + }; if (std::holds_alternative(engine_result.value())) { std::get(engine_result.value()) - ->GetModelStatus(json_body, - [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->GetModelStatus(json_body, std::move(cb)); } else { std::get(engine_result.value()) - ->GetModelStatus(json_body, - [&stt, &r](Json::Value status, Json::Value res) { - stt = status; - r = res; - }); + ->GetModelStatus(json_body, std::move(cb)); } return std::make_pair(stt, r); @@ -214,15 +197,20 @@ InferResult InferenceService::GetModels( LOG_TRACE << "Start to get models"; Json::Value resp_data(Json::arrayValue); + auto cb = [&resp_data](Json::Value status, Json::Value res) { + for (auto r : res["data"]) { + resp_data.append(r); + } + }; for (const auto& loaded_engine : loaded_engines) { - auto e = std::get(loaded_engine); - if (e->IsSupported("GetModels")) { - e->GetModels(json_body, - [&resp_data](Json::Value status, Json::Value res) { - for (auto r : res["data"]) { - resp_data.append(r); - } - }); + if (std::holds_alternative(loaded_engine)) { + auto e = std::get(loaded_engine); + if (e->IsSupported("GetModels")) { + e->GetModels(json_body, std::move(cb)); + } + } else { + std::get(loaded_engine) + ->GetModels(json_body, std::move(cb)); } } @@ -283,6 +271,25 @@ InferResult InferenceService::FineTuning( return std::make_pair(stt, r); } +bool InferenceService::StopInferencing(const std::string& engine_name, + const std::string& model_id) { + CTL_DBG("Stop inferencing"); + auto engine_result = engine_service_->GetLoadedEngine(engine_name); + if (engine_result.has_error()) { + LOG_WARN << "Engine is not loaded yet"; + return false; + } + + if (std::holds_alternative(engine_result.value())) { + auto engine = std::get(engine_result.value()); + if (engine->IsSupported("StopInferencing")) { + engine->StopInferencing(model_id); + CTL_INF("Stopped inferencing"); + } + } + return true; +} + bool InferenceService::HasFieldInReq(std::shared_ptr json_body, const std::string& field) { if (!json_body || (*json_body)[field].isNull()) { diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h index 94097132a..b417fa14a 100644 --- a/engine/services/inference_service.h +++ b/engine/services/inference_service.h @@ -52,10 +52,13 @@ class InferenceService { InferResult FineTuning(std::shared_ptr json_body); - private: + bool StopInferencing(const std::string& engine_name, + const std::string& model_id); + bool HasFieldInReq(std::shared_ptr json_body, const std::string& field); + private: std::shared_ptr engine_service_; }; } // namespace services From 2ee1e814da6d6b708c601036e6a893750bfa8e28 Mon Sep 17 00:00:00 2001 From: hiento09 <136591877+hiento09@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:51:30 +0700 Subject: [PATCH 25/44] feat: macos 12 arm64 (#1791) Co-authored-by: Hien To --- .github/workflows/cortex-cpp-quality-gate.yml | 2 +- .github/workflows/template-build-macos.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cortex-cpp-quality-gate.yml b/.github/workflows/cortex-cpp-quality-gate.yml index 316160ce5..8a76e4669 100644 --- a/.github/workflows/cortex-cpp-quality-gate.yml +++ b/.github/workflows/cortex-cpp-quality-gate.yml @@ -34,7 +34,7 @@ jobs: ccache-dir: "" - os: "mac" name: "arm64" - runs-on: "macos-silicon" + runs-on: "macos-selfhosted-12-arm64" cmake-flags: "-DCORTEX_CPP_VERSION=${{github.event.pull_request.head.sha}} -DCMAKE_BUILD_TEST=ON -DMAC_ARM64=ON -DCMAKE_TOOLCHAIN_FILE=vcpkg/scripts/buildsystems/vcpkg.cmake" build-deps-cmake-flags: "" ccache-dir: "" diff --git a/.github/workflows/template-build-macos.yml b/.github/workflows/template-build-macos.yml index 371468dfb..ae10fb675 100644 --- a/.github/workflows/template-build-macos.yml +++ b/.github/workflows/template-build-macos.yml @@ -82,7 +82,7 @@ jobs: matrix: include: - arch: 'arm64' - runs-on: 'macos-silicon' + runs-on: 'macos-selfhosted-12-arm64' extra-cmake-flags: "-DMAC_ARM64=ON" - arch: 'amd64' From 8dde05cc6963f0d9abf24b8f6c55eb9b4de52d3e Mon Sep 17 00:00:00 2001 From: NamH Date: Thu, 12 Dec 2024 09:20:40 +0700 Subject: [PATCH 26/44] feat: add files api (#1781) * feat: add files api * add backward support * add db support * fix link issue on windows --- engine/common/file.h | 71 ++++++ engine/common/message.h | 69 +++++- engine/common/repository/file_repository.h | 29 +++ engine/controllers/files.cc | 269 +++++++++++++++++++++ engine/controllers/files.h | 62 +++++ engine/database/file.cc | 96 ++++++++ engine/database/file.h | 31 +++ engine/database/models.h | 10 +- engine/main.cc | 23 +- engine/migrations/migration_manager.cc | 17 +- engine/migrations/migration_manager.h | 4 +- engine/migrations/schema_version.h | 3 +- engine/migrations/v3/migration.h | 73 ++++++ engine/repositories/file_fs_repository.cc | 169 +++++++++++++ engine/repositories/file_fs_repository.h | 50 ++++ engine/services/file_service.cc | 55 +++++ engine/services/file_service.h | 40 +++ 17 files changed, 1046 insertions(+), 25 deletions(-) create mode 100644 engine/common/file.h create mode 100644 engine/common/repository/file_repository.h create mode 100644 engine/controllers/files.cc create mode 100644 engine/controllers/files.h create mode 100644 engine/database/file.cc create mode 100644 engine/database/file.h create mode 100644 engine/migrations/v3/migration.h create mode 100644 engine/repositories/file_fs_repository.cc create mode 100644 engine/repositories/file_fs_repository.h create mode 100644 engine/services/file_service.cc create mode 100644 engine/services/file_service.h diff --git a/engine/common/file.h b/engine/common/file.h new file mode 100644 index 000000000..3096023c5 --- /dev/null +++ b/engine/common/file.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include "common/json_serializable.h" + +namespace OpenAi { +/** + * The File object represents a document that has been uploaded to OpenAI. + */ +struct File : public JsonSerializable { + /** + * The file identifier, which can be referenced in the API endpoints. + */ + std::string id; + + /** + * The object type, which is always file. + */ + std::string object = "file"; + + /** + * The size of the file, in bytes. + */ + uint64_t bytes; + + /** + * The Unix timestamp (in seconds) for when the file was created. + */ + uint32_t created_at; + + /** + * The name of the file. + */ + std::string filename; + + /** + * The intended purpose of the file. Supported values are assistants, + * assistants_output, batch, batch_output, fine-tune, fine-tune-results + * and vision. + */ + std::string purpose; + + ~File() = default; + + static cpp::result FromJson(const Json::Value& json) { + File file; + + file.id = std::move(json["id"].asString()); + file.object = "file"; + file.bytes = json["bytes"].asUInt64(); + file.created_at = json["created_at"].asUInt(); + file.filename = std::move(json["filename"].asString()); + file.purpose = std::move(json["purpose"].asString()); + + return file; + } + + cpp::result ToJson() { + Json::Value root; + + root["id"] = id; + root["object"] = object; + root["bytes"] = bytes; + root["created_at"] = created_at; + root["filename"] = filename; + root["purpose"] = purpose; + + return root; + } +}; +} // namespace OpenAi diff --git a/engine/common/message.h b/engine/common/message.h index 909a843ee..3bff6f048 100644 --- a/engine/common/message.h +++ b/engine/common/message.h @@ -19,6 +19,20 @@ namespace OpenAi { +inline std::string ExtractFileId(const std::string& path) { + // Handle both forward and backward slashes + auto last_slash = path.find_last_of("/\\"); + if (last_slash == std::string::npos) + return ""; + + auto filename = path.substr(last_slash + 1); + auto dot_pos = filename.find('.'); + if (dot_pos == std::string::npos) + return ""; + + return filename.substr(0, dot_pos); +} + // Represents a message within a thread. struct Message : JsonSerializable { Message() = default; @@ -70,6 +84,12 @@ struct Message : JsonSerializable { // Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long. Cortex::VariantMap metadata; + // deprecated. remove in the future + std::optional attach_filename; + std::optional size; + std::optional rel_path; + // end deprecated + static cpp::result FromJsonString( std::string&& json_str) { Json::Value root; @@ -98,7 +118,6 @@ struct Message : JsonSerializable { message.completed_at = root["completed_at"].asUInt(); message.incomplete_at = root["incomplete_at"].asUInt(); message.role = RoleFromString(std::move(root["role"].asString())); - message.content = ParseContents(std::move(root["content"])).value(); message.assistant_id = std::move(root["assistant_id"].asString()); message.run_id = std::move(root["run_id"].asString()); @@ -114,6 +133,54 @@ struct Message : JsonSerializable { } } + if (root.isMember("content")) { + if (root["content"].isArray() && !root["content"].empty()) { + if (root["content"][0]["type"].asString() == "text") { + message.content = ParseContents(std::move(root["content"])).value(); + } else { + // deprecated, for supporting jan and should be removed in the future + // check if annotations is empty + if (!root["content"][0]["text"]["annotations"].empty()) { + // parse attachment + Json::Value attachments_json_array{Json::arrayValue}; + Json::Value attachment; + attachment["file_id"] = ExtractFileId( + root["content"][0]["text"]["annotations"][0].asString()); + + Json::Value tools_json_array{Json::arrayValue}; + Json::Value tool; + tool["type"] = "file_search"; + tools_json_array.append(tool); + + attachment["tools"] = tools_json_array; + attachment["file_id"] = attachments_json_array.append(attachment); + + message.attachments = + ParseAttachments(std::move(attachments_json_array)).value(); + + message.attach_filename = + root["content"][0]["text"]["name"].asString(); + message.size = root["content"][0]["text"]["size"].asUInt64(); + message.rel_path = + root["content"][0]["text"]["annotations"][0].asString(); + } + + // parse content + Json::Value contents_json_array{Json::arrayValue}; + Json::Value content; + Json::Value content_text; + Json::Value empty_annotations{Json::arrayValue}; + content["type"] = "text"; + content_text["value"] = root["content"][0]["text"]["value"]; + content_text["annotations"] = empty_annotations; + content["text"] = content_text; + contents_json_array.append(content); + message.content = + ParseContents(std::move(contents_json_array)).value(); + } + } + } + return message; } catch (const std::exception& e) { return cpp::fail(std::string("FromJsonString failed: ") + e.what()); diff --git a/engine/common/repository/file_repository.h b/engine/common/repository/file_repository.h new file mode 100644 index 000000000..f574b76d0 --- /dev/null +++ b/engine/common/repository/file_repository.h @@ -0,0 +1,29 @@ +#pragma once + +#include "common/file.h" +#include "utils/result.hpp" + +class FileRepository { + public: + virtual cpp::result StoreFile(OpenAi::File& file_metadata, + const char* content, + uint64_t length) = 0; + + virtual cpp::result, std::string> ListFiles( + const std::string& purpose, uint8_t limit, const std::string& order, + const std::string& after) const = 0; + + virtual cpp::result RetrieveFile( + const std::string file_id) const = 0; + + virtual cpp::result, size_t>, std::string> + RetrieveFileContent(const std::string& file_id) const = 0; + + virtual cpp::result, size_t>, std::string> + RetrieveFileContentByPath(const std::string& path) const = 0; + + virtual cpp::result DeleteFileLocal( + const std::string& file_id) = 0; + + virtual ~FileRepository() = default; +}; diff --git a/engine/controllers/files.cc b/engine/controllers/files.cc new file mode 100644 index 000000000..e0cd502f4 --- /dev/null +++ b/engine/controllers/files.cc @@ -0,0 +1,269 @@ +#include "files.h" +#include "common/api-dto/delete_success_response.h" +#include "utils/cortex_utils.h" +#include "utils/logging_utils.h" + +void Files::UploadFile(const HttpRequestPtr& req, + std::function&& callback) { + MultiPartParser parser; + if (parser.parse(req) != 0 || parser.getFiles().size() != 1) { + Json::Value root; + root["message"] = "Must only be one file"; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + auto params = parser.getParameters(); + if (params.find("purpose") == params.end()) { + Json::Value root; + root["message"] = "purpose is mandatory"; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + auto purpose = params["purpose"]; + if (std::find(file_service_->kSupportedPurposes.begin(), + file_service_->kSupportedPurposes.end(), + purpose) == file_service_->kSupportedPurposes.end()) { + Json::Value root; + root["message"] = + "purpose is not supported. Purpose can only one of these types: " + "assistants, vision, batch or fine-tune"; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + const auto& file = parser.getFiles()[0]; + auto result = + file_service_->UploadFile(file.getFileName(), purpose, + file.fileContent().data(), file.fileLength()); + + if (result.has_error()) { + Json::Value ret; + ret["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(result->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Files::ListFiles(const HttpRequestPtr& req, + std::function&& callback, + std::optional purpose, + std::optional limit, + std::optional order, + std::optional after) const { + auto res = file_service_->ListFiles( + purpose.value_or(""), std::stoi(limit.value_or("20")), + order.value_or("desc"), after.value_or("")); + if (res.has_error()) { + Json::Value root; + root["message"] = res.error(); + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k400BadRequest); + callback(response); + return; + } + + Json::Value msg_arr(Json::arrayValue); + for (auto& msg : res.value()) { + if (auto it = msg.ToJson(); it.has_value()) { + msg_arr.append(it.value()); + } else { + CTL_WRN("Failed to convert message to json: " + it.error()); + } + } + + Json::Value root; + root["object"] = "list"; + root["data"] = msg_arr; + auto response = cortex_utils::CreateCortexHttpJsonResponse(root); + response->setStatusCode(k200OK); + callback(response); +} + +void Files::RetrieveFile(const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id, + std::optional thread_id) const { + // this code part is for backward compatible. remove it later on + if (thread_id.has_value()) { + auto msg_res = + message_service_->RetrieveMessage(thread_id.value(), file_id); + if (msg_res.has_error()) { + Json::Value ret; + ret["message"] = msg_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (msg_res->attachments->empty()) { + auto res = file_service_->RetrieveFile(file_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); + return; + } else { + if (!msg_res->attach_filename.has_value() || !msg_res->size.has_value()) { + Json::Value ret; + ret["message"] = "File not found or had been removed!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k404NotFound); + callback(resp); + return; + } + + Json::Value ret; + ret["object"] = "file"; + ret["created_at"] = msg_res->created_at; + ret["filename"] = msg_res->attach_filename.value(); + ret["bytes"] = msg_res->size.value(); + ret["id"] = msg_res->id; + ret["purpose"] = "assistants"; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + return; + } + } + + auto res = file_service_->RetrieveFile(file_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + +void Files::DeleteFile(const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id) { + auto res = file_service_->DeleteFileLocal(file_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + api_response::DeleteSuccessResponse response; + response.id = file_id; + response.object = "file"; + response.deleted = true; + auto resp = + cortex_utils::CreateCortexHttpJsonResponse(response.ToJson().value()); + resp->setStatusCode(k200OK); + callback(resp); +} + +void Files::RetrieveFileContent( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id, std::optional thread_id) { + if (thread_id.has_value()) { + auto msg_res = + message_service_->RetrieveMessage(thread_id.value(), file_id); + if (msg_res.has_error()) { + Json::Value ret; + ret["message"] = msg_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + if (msg_res->attachments->empty()) { + auto res = file_service_->RetrieveFileContent(file_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto [buffer, size] = std::move(res.value()); + auto resp = HttpResponse::newHttpResponse(); + resp->setBody(std::string(buffer.get(), size)); + resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + callback(resp); + } else { + if (!msg_res->rel_path.has_value()) { + Json::Value ret; + ret["message"] = "File not found or had been removed"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto content_res = + file_service_->RetrieveFileContentByPath(msg_res->rel_path.value()); + + if (content_res.has_error()) { + Json::Value ret; + ret["message"] = content_res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto [buffer, size] = std::move(content_res.value()); + auto resp = HttpResponse::newHttpResponse(); + resp->setBody(std::string(buffer.get(), size)); + resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + callback(resp); + } + } + + auto res = file_service_->RetrieveFileContent(file_id); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto [buffer, size] = std::move(res.value()); + auto resp = HttpResponse::newHttpResponse(); + resp->setBody(std::string(buffer.get(), size)); + resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM); + callback(resp); +} diff --git a/engine/controllers/files.h b/engine/controllers/files.h new file mode 100644 index 000000000..efd7f6d93 --- /dev/null +++ b/engine/controllers/files.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include "services/file_service.h" +#include "services/message_service.h" + +using namespace drogon; + +class Files : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + ADD_METHOD_TO(Files::UploadFile, "/v1/files", Options, Post); + + ADD_METHOD_TO(Files::RetrieveFile, "/v1/files/{file_id}?thread={thread_id}", + Get); + + ADD_METHOD_TO( + Files::ListFiles, + "/v1/files?purpose={purpose}&limit={limit}&order={order}&after={after}", + Get); + + ADD_METHOD_TO(Files::DeleteFile, "/v1/files/{file_id}", Options, Delete); + + ADD_METHOD_TO(Files::RetrieveFileContent, + "/v1/files/{file_id}/content?thread={thread_id}", Get); + + METHOD_LIST_END + + explicit Files(std::shared_ptr file_service, + std::shared_ptr msg_service) + : file_service_{file_service}, message_service_{msg_service} {} + + void UploadFile(const HttpRequestPtr& req, + std::function&& callback); + + void ListFiles(const HttpRequestPtr& req, + std::function&& callback, + std::optional purpose, + std::optional limit, + std::optional order, + std::optional after) const; + + void RetrieveFile(const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id, + std::optional thread_id) const; + + void DeleteFile(const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id); + + void RetrieveFileContent( + const HttpRequestPtr& req, + std::function&& callback, + const std::string& file_id, std::optional thread_id); + + private: + std::shared_ptr file_service_; + std::shared_ptr message_service_; +}; diff --git a/engine/database/file.cc b/engine/database/file.cc new file mode 100644 index 000000000..3f9a37b98 --- /dev/null +++ b/engine/database/file.cc @@ -0,0 +1,96 @@ +#include "file.h" +#include "utils/logging_utils.h" +#include "utils/scope_exit.h" + +namespace cortex::db { + +cpp::result, std::string> File::GetFileList() const { + try { + db_.exec("BEGIN TRANSACTION;"); + cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); + std::vector entries; + SQLite::Statement query(db_, + "SELECT id, object, " + "purpose, filename, created_at, bytes FROM files"); + + while (query.executeStep()) { + OpenAi::File entry; + entry.id = query.getColumn(0).getString(); + entry.object = query.getColumn(1).getString(); + entry.purpose = query.getColumn(2).getString(); + entry.filename = query.getColumn(3).getString(); + entry.created_at = query.getColumn(4).getInt(); + entry.bytes = query.getColumn(5).getInt(); + entries.push_back(entry); + } + return entries; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result File::GetFileById( + const std::string& file_id) const { + try { + SQLite::Statement query(db_, + "SELECT id, object, " + "purpose, filename, created_at, bytes FROM files " + "WHERE id = ?"); + + query.bind(1, file_id); + if (query.executeStep()) { + OpenAi::File entry; + entry.id = query.getColumn(0).getString(); + entry.object = query.getColumn(1).getString(); + entry.purpose = query.getColumn(2).getString(); + entry.filename = query.getColumn(3).getString(); + entry.created_at = query.getColumn(4).getInt(); + entry.bytes = query.getColumn(5).getInt64(); + return entry; + } else { + return cpp::fail("File not found: " + file_id); + } + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result File::AddFileEntry(OpenAi::File& file) { + try { + SQLite::Statement insert( + db_, + "INSERT INTO files (id, object, " + "purpose, filename, created_at, bytes) VALUES (?, ?, " + "?, ?, ?, ?)"); + insert.bind(1, file.id); + insert.bind(2, file.object); + insert.bind(3, file.purpose); + insert.bind(4, file.filename); + insert.bind(5, std::to_string(file.created_at)); + insert.bind(6, std::to_string(file.bytes)); + insert.exec(); + + CTL_INF("Inserted: " << file.ToJson()->toStyledString()); + return {}; + } catch (const std::exception& e) { + CTL_WRN(e.what()); + return cpp::fail(e.what()); + } +} + +cpp::result File::DeleteFileEntry( + const std::string& file_id) { + try { + SQLite::Statement del(db_, "DELETE from files WHERE id = ?"); + del.bind(1, file_id); + if (del.exec() == 1) { + CTL_INF("Deleted: " << file_id); + return {}; + } + return {}; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} +} // namespace cortex::db diff --git a/engine/database/file.h b/engine/database/file.h new file mode 100644 index 000000000..be976ecce --- /dev/null +++ b/engine/database/file.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include +#include "common/file.h" +#include "database.h" +#include "utils/result.hpp" + +namespace cortex::db { +class File { + SQLite::Database& db_; + + public: + File(SQLite::Database& db) : db_{db} {}; + + File() : db_(cortex::db::Database::GetInstance().db()) {} + + ~File() {} + + cpp::result, std::string> GetFileList() const; + + cpp::result GetFileById( + const std::string& file_id) const; + + cpp::result AddFileEntry(OpenAi::File& file); + + cpp::result DeleteFileEntry(const std::string& file_id); +}; +} // namespace cortex::db diff --git a/engine/database/models.h b/engine/database/models.h index dd6e2a5a1..5c855cf1b 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -8,14 +8,10 @@ namespace cortex::db { -enum class ModelStatus { - Remote, - Downloaded, - Undownloaded -}; +enum class ModelStatus { Remote, Downloaded, Undownloaded }; struct ModelEntry { - std::string model; + std::string model; std::string author_repo_id; std::string branch_name; std::string path_to_model_yaml; @@ -64,4 +60,4 @@ class Models { bool HasModel(const std::string& identifier) const; }; -} // namespace cortex::db \ No newline at end of file +} // namespace cortex::db diff --git a/engine/main.cc b/engine/main.cc index 93aa3b8e7..5222ac5c2 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -5,6 +5,7 @@ #include "controllers/configs.h" #include "controllers/engines.h" #include "controllers/events.h" +#include "controllers/files.h" #include "controllers/hardware.h" #include "controllers/messages.h" #include "controllers/models.h" @@ -13,6 +14,7 @@ #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" +#include "repositories/file_fs_repository.h" #include "repositories/message_fs_repository.h" #include "repositories/thread_fs_repository.h" #include "services/assistant_service.h" @@ -121,11 +123,13 @@ void RunServer(std::optional port, bool ignore_cout) { auto event_queue_ptr = std::make_shared(); cortex::event::EventProcessor event_processor(event_queue_ptr); - auto msg_repo = std::make_shared( - file_manager_utils::GetCortexDataPath()); - auto thread_repo = std::make_shared( - file_manager_utils::GetCortexDataPath()); + auto data_folder_path = file_manager_utils::GetCortexDataPath(); + auto file_repo = std::make_shared(data_folder_path); + auto msg_repo = std::make_shared(data_folder_path); + auto thread_repo = std::make_shared(data_folder_path); + + auto file_srv = std::make_shared(file_repo); auto assistant_srv = std::make_shared(thread_repo); auto thread_srv = std::make_shared(thread_repo); auto message_srv = std::make_shared(msg_repo); @@ -145,6 +149,7 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto file_ctl = std::make_shared(file_srv, message_srv); auto assistant_ctl = std::make_shared(assistant_srv); auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); @@ -157,6 +162,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(file_ctl); drogon::app().registerController(assistant_ctl); drogon::app().registerController(thread_ctl); drogon::app().registerController(message_ctl); @@ -168,9 +174,6 @@ void RunServer(std::optional port, bool ignore_cout) { drogon::app().registerController(hw_ctl); drogon::app().registerController(config_ctl); - auto upload_path = std::filesystem::temp_directory_path() / "cortex-uploads"; - drogon::app().setUploadPath(upload_path.string()); - LOG_INFO << "Server started, listening at: " << config.apiServerHost << ":" << config.apiServerPort; LOG_INFO << "Please load your model"; @@ -185,6 +188,12 @@ void RunServer(std::optional port, bool ignore_cout) { LOG_INFO << "Number of thread is:" << drogon::app().getThreadNum(); drogon::app().disableSigtermHandling(); + // file upload + drogon::app() + .enableCompressedRequest(true) + .setClientMaxBodySize(256 * 1024 * 1024) // Max 256MiB body size + .setClientMaxMemoryBodySize(1024 * 1024); // 1MiB before writing to disk + // CORS drogon::app().registerPostHandlingAdvice( [config_service](const drogon::HttpRequestPtr& req, diff --git a/engine/migrations/migration_manager.cc b/engine/migrations/migration_manager.cc index 6936f45a0..26197115d 100644 --- a/engine/migrations/migration_manager.cc +++ b/engine/migrations/migration_manager.cc @@ -8,6 +8,8 @@ #include "v0/migration.h" #include "v1/migration.h" #include "v2/migration.h" +#include "v3/migration.h" + namespace cortex::migr { namespace { @@ -145,8 +147,8 @@ cpp::result MigrationManager::DoUpFolderStructure( return v1::MigrateFolderStructureUp(); case 2: return v2::MigrateFolderStructureUp(); - - break; + case 3: + return v3::MigrateFolderStructureUp(); default: return true; @@ -161,7 +163,8 @@ cpp::result MigrationManager::DoDownFolderStructure( return v1::MigrateFolderStructureDown(); case 2: return v2::MigrateFolderStructureDown(); - break; + case 3: + return v3::MigrateFolderStructureDown(); default: return true; @@ -198,7 +201,8 @@ cpp::result MigrationManager::DoUpDB(int version) { return v1::MigrateDBUp(db_); case 2: return v2::MigrateDBUp(db_); - break; + case 3: + return v3::MigrateDBUp(db_); default: return true; @@ -213,7 +217,8 @@ cpp::result MigrationManager::DoDownDB(int version) { return v1::MigrateDBDown(db_); case 2: return v2::MigrateDBDown(db_); - break; + case 3: + return v3::MigrateDBDown(db_); default: return true; @@ -247,4 +252,4 @@ cpp::result MigrationManager::UpdateSchemaVersion( return cpp::fail(e.what()); } } -} // namespace cortex::migr \ No newline at end of file +} // namespace cortex::migr diff --git a/engine/migrations/migration_manager.h b/engine/migrations/migration_manager.h index b05a76c26..05fc42693 100644 --- a/engine/migrations/migration_manager.h +++ b/engine/migrations/migration_manager.h @@ -1,6 +1,6 @@ #pragma once + #include "migration_helper.h" -#include "v0/migration.h" namespace cortex::migr { class MigrationManager { @@ -28,4 +28,4 @@ class MigrationManager { MigrationHelper mgr_helper_; SQLite::Database& db_; }; -} // namespace cortex::migr \ No newline at end of file +} // namespace cortex::migr diff --git a/engine/migrations/schema_version.h b/engine/migrations/schema_version.h index 5739040d0..619f3054d 100644 --- a/engine/migrations/schema_version.h +++ b/engine/migrations/schema_version.h @@ -1,5 +1,4 @@ #pragma once //Track the current schema version -#define SCHEMA_VERSION 2 - +#define SCHEMA_VERSION 3 diff --git a/engine/migrations/v3/migration.h b/engine/migrations/v3/migration.h new file mode 100644 index 000000000..3bed802fb --- /dev/null +++ b/engine/migrations/v3/migration.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +namespace cortex::migr::v3 { +inline cpp::result MigrateFolderStructureUp() { + return true; +} + +inline cpp::result MigrateFolderStructureDown() { + // CTL_INF("Folder structure already up to date!"); + return true; +} + +// Database +inline cpp::result MigrateDBUp(SQLite::Database& db) { + try { + db.exec( + "CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY " + "KEY);"); + + // files + { + // Check if the table exists + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='files'"); + auto table_exists = query.executeStep(); + + if (!table_exists) { + // Create new table + db.exec( + "CREATE TABLE files (" + "id TEXT PRIMARY KEY," + "object TEXT," + "purpose TEXT," + "filename TEXT," + "created_at INTEGER," + "bytes INTEGER" + ")"); + } + } + + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration up failed: " << e.what()); + return cpp::fail(e.what()); + } +}; + +inline cpp::result MigrateDBDown(SQLite::Database& db) { + try { + // hardware + { + SQLite::Statement query(db, + "SELECT name FROM sqlite_master WHERE " + "type='table' AND name='hardware'"); + auto table_exists = query.executeStep(); + if (table_exists) { + db.exec("DROP TABLE files"); + } + } + + return true; + } catch (const std::exception& e) { + CTL_WRN("Migration down failed: " << e.what()); + return cpp::fail(e.what()); + } +} +}; // namespace cortex::migr::v3 diff --git a/engine/repositories/file_fs_repository.cc b/engine/repositories/file_fs_repository.cc new file mode 100644 index 000000000..b9ab4fec6 --- /dev/null +++ b/engine/repositories/file_fs_repository.cc @@ -0,0 +1,169 @@ +#include "file_fs_repository.h" +#include +#include +#include +#include "database/file.h" +#include "utils/logging_utils.h" +#include "utils/result.hpp" + +std::filesystem::path FileFsRepository::GetFilePath() const { + return data_folder_path_ / kFileContainerFolderName; +} + +cpp::result FileFsRepository::StoreFile( + OpenAi::File& file_metadata, const char* content, uint64_t length) { + auto file_container_path = GetFilePath(); + if (!std::filesystem::exists(file_container_path)) { + std::filesystem::create_directories(file_container_path); + } + + cortex::db::File db; + auto file_full_path = file_container_path / file_metadata.filename; + if (std::filesystem::exists(file_full_path)) { + return cpp::fail("File already exists: " + file_full_path.string()); + } + + try { + std::ofstream file(file_full_path, std::ios::binary); + if (!file) { + return cpp::fail("Failed to open file for writing: " + + file_full_path.string()); + } + + file.write(content, length); + file.flush(); + file.close(); + + auto result = db.AddFileEntry(file_metadata); + if (result.has_error()) { + std::filesystem::remove(file_full_path); + return cpp::fail(result.error()); + } + + return {}; + } catch (const std::exception& e) { + CTL_ERR("Failed to store file: " << e.what()); + return cpp::fail("Failed to write file: " + file_full_path.string() + + ", error: " + e.what()); + } +} + +cpp::result, std::string> FileFsRepository::ListFiles( + const std::string& purpose, uint8_t limit, const std::string& order, + const std::string& after) const { + cortex::db::File db; + auto res = db.GetFileList(); + if (res.has_error()) { + return cpp::fail(res.error()); + } + auto files = res.value(); + + if (order == "desc") { + std::sort(files.begin(), files.end(), + [](const OpenAi::File& a, const OpenAi::File& b) { + return a.id > b.id; + }); + } else { + std::sort(files.begin(), files.end(), + [](const OpenAi::File& a, const OpenAi::File& b) { + return a.id < b.id; + }); + } + + if (limit > 0 && files.size() > limit) { + files.resize(limit); + } + + return files; +} + +cpp::result FileFsRepository::RetrieveFile( + const std::string file_id) const { + CTL_INF("Retrieving file: " + file_id); + + auto file_container_path = GetFilePath(); + cortex::db::File db; + auto res = db.GetFileById(file_id); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return res.value(); +} + +cpp::result, size_t>, std::string> +FileFsRepository::RetrieveFileContent(const std::string& file_id) const { + auto file_container_path = GetFilePath(); + auto file_metadata = RetrieveFile(file_id); + if (file_metadata.has_error()) { + return cpp::fail(file_metadata.error()); + } + auto file_path = file_container_path / file_metadata->filename; + if (!std::filesystem::exists(file_path)) { + return cpp::fail("File content not found: " + file_path.string()); + } + size_t size = std::filesystem::file_size(file_path); + auto buffer = std::make_unique(size); + std::ifstream file(file_path, std::ios::binary); + if (!file.read(buffer.get(), size)) { + return cpp::fail("Failed to read file: " + file_path.string()); + } + + return std::make_pair(std::move(buffer), size); +} + +cpp::result, size_t>, std::string> +FileFsRepository::RetrieveFileContentByPath(const std::string& path) const { + auto file_path = data_folder_path_ / path; + if (!std::filesystem::exists(file_path)) { + return cpp::fail("File not found: " + path); + } + + try { + size_t size = std::filesystem::file_size(file_path); + auto buffer = std::make_unique(size); + + std::ifstream file(file_path, std::ios::binary); + if (!file.read(buffer.get(), size)) { + return cpp::fail("Failed to read file: " + file_path.string()); + } + + return std::make_pair(std::move(buffer), size); + } catch (const std::exception& e) { + CTL_ERR("Failed to retrieve file content: " << e.what()); + return cpp::fail("Failed to retrieve file content"); + } +} + +cpp::result FileFsRepository::DeleteFileLocal( + const std::string& file_id) { + CTL_INF("Deleting file: " + file_id); + auto file_container_path = GetFilePath(); + cortex::db::File db; + auto file_metadata = db.GetFileById(file_id); + if (file_metadata.has_error()) { + return cpp::fail(file_metadata.error()); + } + + auto file_path = file_container_path / file_metadata->filename; + + auto res = db.DeleteFileEntry(file_id); + if (res.has_error()) { + CTL_ERR("Failed to delete file entry: " << res.error()); + return cpp::fail(res.error()); + } + + if (!std::filesystem::exists(file_path)) { + CTL_INF("File not found: " + file_path.string()); + return {}; + } + + try { + std::filesystem::remove_all(file_path); + return {}; + } catch (const std::exception& e) { + CTL_ERR("Failed to delete file: " << e.what()); + return cpp::fail("Failed to delete file: " + file_container_path.string() + + ", error: " + e.what()); + } +} diff --git a/engine/repositories/file_fs_repository.h b/engine/repositories/file_fs_repository.h new file mode 100644 index 000000000..974e81fa4 --- /dev/null +++ b/engine/repositories/file_fs_repository.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include "common/repository/file_repository.h" +#include "utils/logging_utils.h" + +class FileFsRepository : public FileRepository { + public: + constexpr static auto kFileContainerFolderName = "files"; + + cpp::result StoreFile(OpenAi::File& file_metadata, + const char* content, + uint64_t length) override; + + cpp::result, std::string> ListFiles( + const std::string& purpose, uint8_t limit, const std::string& order, + const std::string& after) const override; + + cpp::result RetrieveFile( + const std::string file_id) const override; + + cpp::result, size_t>, std::string> + RetrieveFileContent(const std::string& file_id) const override; + + cpp::result, size_t>, std::string> + RetrieveFileContentByPath(const std::string& path) const override; + + cpp::result DeleteFileLocal( + const std::string& file_id) override; + + explicit FileFsRepository(std::filesystem::path data_folder_path) + : data_folder_path_{data_folder_path} { + CTL_INF("Constructing FileFsRepository.."); + auto file_container_path = data_folder_path_ / kFileContainerFolderName; + + if (!std::filesystem::exists(file_container_path)) { + std::filesystem::create_directories(file_container_path); + } + } + + ~FileFsRepository() = default; + + private: + std::filesystem::path GetFilePath() const; + + /** + * The path to the data folder. + */ + std::filesystem::path data_folder_path_; +}; diff --git a/engine/services/file_service.cc b/engine/services/file_service.cc new file mode 100644 index 000000000..f2514fbfb --- /dev/null +++ b/engine/services/file_service.cc @@ -0,0 +1,55 @@ +#include "file_service.h" +#include +#include "utils/ulid/ulid.hh" + +cpp::result FileService::UploadFile( + const std::string& filename, const std::string& purpose, + const char* content, uint64_t content_length) { + + auto seconds_since_epoch = + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + + auto file_id{"file-" + ulid::Marshal(ulid::CreateNowRand())}; + OpenAi::File file; + file.id = file_id; + file.object = "file"; + file.bytes = content_length; + file.created_at = seconds_since_epoch; + file.filename = filename; + file.purpose = purpose; + + auto res = file_repository_->StoreFile(file, content, content_length); + if (res.has_error()) { + return cpp::fail(res.error()); + } + + return file; +} + +cpp::result, std::string> FileService::ListFiles( + const std::string& purpose, uint8_t limit, const std::string& order, + const std::string& after) const { + return file_repository_->ListFiles(purpose, limit, order, after); +} + +cpp::result FileService::RetrieveFile( + const std::string& file_id) const { + return file_repository_->RetrieveFile(file_id); +} + +cpp::result FileService::DeleteFileLocal( + const std::string& file_id) { + return file_repository_->DeleteFileLocal(file_id); +} + +cpp::result, size_t>, std::string> +FileService::RetrieveFileContent(const std::string& file_id) const { + return file_repository_->RetrieveFileContent(file_id); +} + +cpp::result, size_t>, std::string> +FileService::RetrieveFileContentByPath(const std::string& path) const { + return file_repository_->RetrieveFileContentByPath(path); +} diff --git a/engine/services/file_service.h b/engine/services/file_service.h new file mode 100644 index 000000000..397feda20 --- /dev/null +++ b/engine/services/file_service.h @@ -0,0 +1,40 @@ +#pragma once + +#include "common/file.h" +#include "common/repository/file_repository.h" +#include "utils/result.hpp" + +class FileService { + public: + const std::vector kSupportedPurposes{"assistants", "vision", + "batch", "fine-tune"}; + + cpp::result UploadFile(const std::string& filename, + const std::string& purpose, + const char* content, + uint64_t content_length); + + cpp::result, std::string> ListFiles( + const std::string& purpose, uint8_t limit, const std::string& order, + const std::string& after) const; + + cpp::result RetrieveFile( + const std::string& file_id) const; + + cpp::result DeleteFileLocal(const std::string& file_id); + + cpp::result, size_t>, std::string> + RetrieveFileContent(const std::string& file_id) const; + + /** + * For getting file content by **relative** path. + */ + cpp::result, size_t>, std::string> + RetrieveFileContentByPath(const std::string& path) const; + + explicit FileService(std::shared_ptr file_repository) + : file_repository_{file_repository} {} + + private: + std::shared_ptr file_repository_; +}; From f473b0b2d78074d4ebb2e61540de470b62740ea1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 12 Dec 2024 11:29:11 +0700 Subject: [PATCH 27/44] feat: model sources (#1777) * feat: prioritize GPUs * fix: migrate db * fix: add priority * fix: db * fix: more * feat: model sources * feat: support delete API * feat: cli: support models sources add * feat: cli: model source delete * feat: cli: add model source list * feat: sync cortex.db * chore: cleanup * feat: add metadata for model * fix: migration * chore: unit tests: cleanup * fix: add metadata * fix: pull model * chore: unit tests: update * chore: add e2e tests for models sources * chore: add API docs * chore: rename --------- Co-authored-by: vansangpfiev --- docs/static/openapi/cortex.json | 99 ++++ engine/cli/command_line_parser.cc | 76 ++- engine/cli/command_line_parser.h | 2 + engine/cli/commands/model_list_cmd.cc | 78 +-- engine/cli/commands/model_list_cmd.h | 3 +- engine/cli/commands/model_source_add_cmd.cc | 38 ++ engine/cli/commands/model_source_add_cmd.h | 12 + engine/cli/commands/model_source_del_cmd.cc | 39 ++ engine/cli/commands/model_source_del_cmd.h | 12 + engine/cli/commands/model_source_list_cmd.cc | 56 +++ engine/cli/commands/model_source_list_cmd.h | 11 + engine/controllers/models.cc | 98 +++- engine/controllers/models.h | 25 +- engine/database/models.cc | 222 ++++----- engine/database/models.h | 22 +- engine/e2e-test/test_api_model.py | 15 +- engine/main.cc | 5 +- engine/services/model_service.cc | 107 ++-- engine/services/model_source_service.cc | 493 +++++++++++++++++++ engine/services/model_source_service.h | 53 ++ engine/test/components/test_models_db.cc | 70 +-- engine/utils/huggingface_utils.h | 2 + engine/utils/json_parser_utils.h | 2 +- 23 files changed, 1269 insertions(+), 271 deletions(-) create mode 100644 engine/cli/commands/model_source_add_cmd.cc create mode 100644 engine/cli/commands/model_source_add_cmd.h create mode 100644 engine/cli/commands/model_source_del_cmd.cc create mode 100644 engine/cli/commands/model_source_del_cmd.h create mode 100644 engine/cli/commands/model_source_list_cmd.cc create mode 100644 engine/cli/commands/model_source_list_cmd.h create mode 100644 engine/services/model_source_service.cc create mode 100644 engine/services/model_source_service.h diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 9cdd5c7b4..2ff239ce2 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -807,6 +807,105 @@ "tags": ["Pulling Models"] } }, + "/v1/models/sources": { + "post": { + "summary": "Add a model source", + "description": "User can add a Huggingface Organization or Repository", + "requestBody": { + "required": false, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "source": { + "type": "string", + "description": "The url of model source to add", + "example": "https://huggingface.co/cortexso/tinyllama" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful installation", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "example": "Added model source" + } + } + } + } + } + } + }, + "tags": ["Pulling Models"] + }, + "delete": { + "summary": "Remove a model source", + "description": "User can remove a Huggingface Organization or Repository", + "requestBody": { + "required": false, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "source": { + "type": "string", + "description": "The url of model source to remove", + "example": "https://huggingface.co/cortexso/tinyllama" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful uninstallation", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Removed model source successfully!", + "example": "Removed model source successfully!" + } + } + } + } + } + }, + "400": { + "description": "Bad request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "error": { + "type": "string", + "description": "Error message describing the issue with the request" + } + } + } + } + } + } + }, + "tags": ["Pulling Models"] + } + }, "/v1/threads": { "post": { "operationId": "ThreadsController_create", diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 9d5d83ffc..624ccd3dd 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -20,6 +20,9 @@ #include "commands/model_import_cmd.h" #include "commands/model_list_cmd.h" #include "commands/model_pull_cmd.h" +#include "commands/model_source_add_cmd.h" +#include "commands/model_source_del_cmd.h" +#include "commands/model_source_list_cmd.h" #include "commands/model_start_cmd.h" #include "commands/model_stop_cmd.h" #include "commands/model_upd_cmd.h" @@ -253,6 +256,8 @@ void CommandLineParser::SetupModelCommands() { "Display cpu mode"); list_models_cmd->add_flag("--gpu_mode", cml_data_.display_gpu_mode, "Display gpu mode"); + list_models_cmd->add_flag("--available", cml_data_.display_available_model, + "Display available models to download"); list_models_cmd->group(kSubcommands); list_models_cmd->callback([this]() { if (std::exchange(executed_, true)) @@ -261,7 +266,8 @@ void CommandLineParser::SetupModelCommands() { cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), cml_data_.filter, cml_data_.display_engine, cml_data_.display_version, - cml_data_.display_cpu_mode, cml_data_.display_gpu_mode); + cml_data_.display_cpu_mode, cml_data_.display_gpu_mode, + cml_data_.display_available_model); }); auto get_models_cmd = @@ -329,6 +335,74 @@ void CommandLineParser::SetupModelCommands() { std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, cml_data_.model_path); }); + + auto model_source_cmd = models_cmd->add_subcommand( + "sources", "Subcommands for managing model sources"); + model_source_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources [options] [subcommand]"); + model_source_cmd->group(kSubcommands); + + model_source_cmd->callback([this, model_source_cmd] { + if (std::exchange(executed_, true)) + return; + if (model_source_cmd->get_subcommands().empty()) { + CLI_LOG(model_source_cmd->help()); + } + }); + + auto model_src_add_cmd = + model_source_cmd->add_subcommand("add", "Add a model source"); + model_src_add_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources add [model_source]"); + model_src_add_cmd->group(kSubcommands); + model_src_add_cmd->add_option("source", cml_data_.model_src, ""); + model_src_add_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + if (cml_data_.model_src.empty()) { + CLI_LOG("[model_source] is required\n"); + CLI_LOG(model_src_add_cmd->help()); + return; + }; + + commands::ModelSourceAddCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_src); + }); + + auto model_src_del_cmd = + model_source_cmd->add_subcommand("remove", "Remove a model source"); + model_src_del_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources remove [model_source]"); + model_src_del_cmd->group(kSubcommands); + model_src_del_cmd->add_option("source", cml_data_.model_src, ""); + model_src_del_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + if (cml_data_.model_src.empty()) { + CLI_LOG("[model_source] is required\n"); + CLI_LOG(model_src_del_cmd->help()); + return; + }; + + commands::ModelSourceDelCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_src); + }); + + auto model_src_list_cmd = + model_source_cmd->add_subcommand("list", "List all model sources"); + model_src_list_cmd->usage("Usage:\n" + commands::GetCortexBinary() + + " models sources list"); + model_src_list_cmd->group(kSubcommands); + model_src_list_cmd->callback([&]() { + if (std::exchange(executed_, true)) + return; + + commands::ModelSourceListCmd().Exec( + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)); + }); } void CommandLineParser::SetupConfigsCommands() { diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index aec10dcb4..896c026d0 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -66,6 +66,7 @@ class CommandLineParser { bool display_version = false; bool display_cpu_mode = false; bool display_gpu_mode = false; + bool display_available_model = false; std::string filter = ""; std::string log_level = "INFO"; @@ -74,6 +75,7 @@ class CommandLineParser { int port; config_yaml_utils::CortexConfig config; std::unordered_map model_update_options; + std::string model_src; }; CmlData cml_data_; std::unordered_map config_update_opts_; diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index 7990563f3..96ff2885d 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -21,7 +21,7 @@ using Row_t = void ModelListCmd::Exec(const std::string& host, int port, const std::string& filter, bool display_engine, bool display_version, bool display_cpu_mode, - bool display_gpu_mode) { + bool display_gpu_mode, bool available) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -73,40 +73,62 @@ void ModelListCmd::Exec(const std::string& host, int port, continue; } - count += 1; + if (available) { + if (v["status"].asString() != "downloadable") { + continue; + } - std::vector row = {std::to_string(count), - v["model"].asString()}; - if (display_engine) { - row.push_back(v["engine"].asString()); - } - if (display_version) { - row.push_back(v["version"].asString()); - } + count += 1; - if (auto& r = v["recommendation"]; !r.isNull()) { - if (display_cpu_mode) { - if (!r["cpu_mode"].isNull()) { - row.push_back("RAM: " + r["cpu_mode"]["ram"].asString() + " MiB"); - } + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); + } + if (display_version) { + row.push_back(v["version"].asString()); + } + table.add_row({row.begin(), row.end()}); + } else { + if (v["status"].asString() == "downloadable") { + continue; + } + + count += 1; + + std::vector row = {std::to_string(count), + v["model"].asString()}; + if (display_engine) { + row.push_back(v["engine"].asString()); + } + if (display_version) { + row.push_back(v["version"].asString()); } - if (display_gpu_mode) { - if (!r["gpu_mode"].isNull()) { - std::string s; - s += "ngl: " + r["gpu_mode"][0]["ngl"].asString() + " - "; - s += "context: " + r["gpu_mode"][0]["context_length"].asString() + - " - "; - s += "RAM: " + r["gpu_mode"][0]["ram"].asString() + " MiB - "; - s += "VRAM: " + r["gpu_mode"][0]["vram"].asString() + " MiB - "; - s += "recommended ngl: " + - r["gpu_mode"][0]["recommend_ngl"].asString(); - row.push_back(s); + if (auto& r = v["recommendation"]; !r.isNull()) { + if (display_cpu_mode) { + if (!r["cpu_mode"].isNull()) { + row.push_back("RAM: " + r["cpu_mode"]["ram"].asString() + " MiB"); + } + } + + if (display_gpu_mode) { + if (!r["gpu_mode"].isNull()) { + std::string s; + s += "ngl: " + r["gpu_mode"][0]["ngl"].asString() + " - "; + s += "context: " + r["gpu_mode"][0]["context_length"].asString() + + " - "; + s += "RAM: " + r["gpu_mode"][0]["ram"].asString() + " MiB - "; + s += "VRAM: " + r["gpu_mode"][0]["vram"].asString() + " MiB - "; + s += "recommended ngl: " + + r["gpu_mode"][0]["recommend_ngl"].asString(); + row.push_back(s); + } } } - } - table.add_row({row.begin(), row.end()}); + table.add_row({row.begin(), row.end()}); + } } } diff --git a/engine/cli/commands/model_list_cmd.h b/engine/cli/commands/model_list_cmd.h index 791c1ecf6..85dd76de9 100644 --- a/engine/cli/commands/model_list_cmd.h +++ b/engine/cli/commands/model_list_cmd.h @@ -8,6 +8,7 @@ class ModelListCmd { public: void Exec(const std::string& host, int port, const std::string& filter, bool display_engine = false, bool display_version = false, - bool display_cpu_mode = false, bool display_gpu_mode = false); + bool display_cpu_mode = false, bool display_gpu_mode = false, + bool available = false); }; } // namespace commands diff --git a/engine/cli/commands/model_source_add_cmd.cc b/engine/cli/commands/model_source_add_cmd.cc new file mode 100644 index 000000000..2fadbe8ec --- /dev/null +++ b/engine/cli/commands/model_source_add_cmd.cc @@ -0,0 +1,38 @@ +#include "model_source_add_cmd.h" +#include "server_start_cmd.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +namespace commands { +bool ModelSourceAddCmd::Exec(const std::string& host, int port, const std::string& model_source) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + + Json::Value json_data; + json_data["source"] = model_source; + + auto data_str = json_data.toStyledString(); + auto res = curl_utils::SimplePostJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return false; + } + + CLI_LOG("Added model source: " << model_source); + return true; +} + + +}; // namespace commands diff --git a/engine/cli/commands/model_source_add_cmd.h b/engine/cli/commands/model_source_add_cmd.h new file mode 100644 index 000000000..6d3bcc6c0 --- /dev/null +++ b/engine/cli/commands/model_source_add_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { + +class ModelSourceAddCmd { + public: + bool Exec(const std::string& host, int port, const std::string& model_source); +}; +} // namespace commands diff --git a/engine/cli/commands/model_source_del_cmd.cc b/engine/cli/commands/model_source_del_cmd.cc new file mode 100644 index 000000000..c3c1694e7 --- /dev/null +++ b/engine/cli/commands/model_source_del_cmd.cc @@ -0,0 +1,39 @@ +#include "model_source_del_cmd.h" +#include "server_start_cmd.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" + +namespace commands { +bool ModelSourceDelCmd::Exec(const std::string& host, int port, const std::string& model_source) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + + Json::Value json_data; + json_data["source"] = model_source; + + auto data_str = json_data.toStyledString(); + auto res = curl_utils::SimpleDeleteJson(url.ToFullPath(), data_str); + if (res.has_error()) { + auto root = json_helper::ParseJsonString(res.error()); + CLI_LOG(root["message"].asString()); + return false; + } + + CLI_LOG("Removed model source: " << model_source); + return true; +} + + +}; // namespace commands diff --git a/engine/cli/commands/model_source_del_cmd.h b/engine/cli/commands/model_source_del_cmd.h new file mode 100644 index 000000000..5015a609a --- /dev/null +++ b/engine/cli/commands/model_source_del_cmd.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace commands { + +class ModelSourceDelCmd { + public: + bool Exec(const std::string& host, int port, const std::string& model_source); +}; +} // namespace commands diff --git a/engine/cli/commands/model_source_list_cmd.cc b/engine/cli/commands/model_source_list_cmd.cc new file mode 100644 index 000000000..ae69c5aef --- /dev/null +++ b/engine/cli/commands/model_source_list_cmd.cc @@ -0,0 +1,56 @@ +#include "model_source_list_cmd.h" +#include +#include +#include +#include +#include "server_start_cmd.h" +#include "utils/curl_utils.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" +// clang-format off +#include +// clang-format on + +namespace commands { + +bool ModelSourceListCmd::Exec(const std::string& host, int port) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return false; + } + } + + tabulate::Table table; + table.add_row({"#", "Model Source"}); + + auto url = url_parser::Url{ + .protocol = "http", + .host = host + ":" + std::to_string(port), + .pathParams = {"v1", "models", "sources"}, + }; + auto result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (result.has_error()) { + CTL_ERR(result.error()); + return false; + } + table.format().font_color(tabulate::Color::green); + int count = 0; + + if (!result.value()["data"].isNull()) { + for (auto const& v : result.value()["data"]) { + auto model_source = v.asString(); + count += 1; + std::vector row = {std::to_string(count), model_source}; + table.add_row({row.begin(), row.end()}); + } + } + + std::cout << table << std::endl; + return true; +} +}; // namespace commands diff --git a/engine/cli/commands/model_source_list_cmd.h b/engine/cli/commands/model_source_list_cmd.h new file mode 100644 index 000000000..99116f592 --- /dev/null +++ b/engine/cli/commands/model_source_list_cmd.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace commands { + +class ModelSourceListCmd { + public: + bool Exec(const std::string& host, int port); +}; +} // namespace commands diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 3f91da848..affa45d52 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -172,6 +172,28 @@ void Models::ListModel( if (list_entry) { for (const auto& model_entry : list_entry.value()) { try { + if (model_entry.status == cortex::db::ModelStatus::Downloadable) { + Json::Value obj; + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + auto status_to_string = [](cortex::db::ModelStatus status) { + switch (status) { + case cortex::db::ModelStatus::Remote: + return "remote"; + case cortex::db::ModelStatus::Downloaded: + return "downloaded"; + case cortex::db::ModelStatus::Downloadable: + return "downloadable"; + } + return "unknown"; + }; + obj["modelSource"] = model_entry.model_source; + obj["status"] = status_to_string(model_entry.status); + obj["engine"] = model_entry.engine; + obj["metadata"] = model_entry.metadata; + data.append(std::move(obj)); + continue; + } yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.path_to_model_yaml)) @@ -182,7 +204,7 @@ void Models::ListModel( Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; - obj["model"] = model_entry.model; + obj["status"] = "downloaded"; auto es = model_service_->GetEstimation(model_entry.model); if (es.has_value() && !!es.value()) { obj["recommendation"] = hardware::ToJson(*(es.value())); @@ -723,4 +745,78 @@ void Models::AddRemoteModel( resp->setStatusCode(k400BadRequest); callback(resp); } +} + +void Models::AddModelSource( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "source")) { + return; + } + + auto model_source = (*(req->getJsonObject())).get("source", "").asString(); + auto res = model_src_svc_->AddModelSource(model_source); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + ret["message"] = "Model source is added successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::DeleteModelSource( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "source")) { + return; + } + + auto model_source = (*(req->getJsonObject())).get("source", "").asString(); + auto res = model_src_svc_->RemoveModelSource(model_source); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + ret["message"] = "Model source is deleted successfully!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::GetModelSources( + const HttpRequestPtr& req, + std::function&& callback) { + auto res = model_src_svc_->GetModelSources(); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + Json::Value data(Json::arrayValue); + for (auto const& i : info) { + data.append(i); + } + ret["data"] = data; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } } \ No newline at end of file diff --git a/engine/controllers/models.h b/engine/controllers/models.h index b2b288adc..d3200f33a 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -4,6 +4,7 @@ #include #include "services/engine_service.h" #include "services/model_service.h" +#include "services/model_source_service.h" using namespace drogon; @@ -23,6 +24,9 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); METHOD_ADD(Models::AddRemoteModel, "/add", Options, Post); METHOD_ADD(Models::GetRemoteModels, "/remote/{1}", Get); + METHOD_ADD(Models::AddModelSource, "/sources", Post); + METHOD_ADD(Models::DeleteModelSource, "/sources", Delete); + METHOD_ADD(Models::GetModelSources, "/sources", Get); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Options, Post); ADD_METHOD_TO(Models::AbortPullModel, "/v1/models/pull", Options, Delete); @@ -36,11 +40,17 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); ADD_METHOD_TO(Models::AddRemoteModel, "/v1/models/add", Options, Post); ADD_METHOD_TO(Models::GetRemoteModels, "/v1/models/remote/{1}", Get); + ADD_METHOD_TO(Models::AddModelSource, "/v1/models/sources", Post); + ADD_METHOD_TO(Models::DeleteModelSource, "/v1/models/sources", Delete); + ADD_METHOD_TO(Models::GetModelSources, "/v1/models/sources", Get); METHOD_LIST_END explicit Models(std::shared_ptr model_service, - std::shared_ptr engine_service) - : model_service_{model_service}, engine_service_{engine_service} {} + std::shared_ptr engine_service, + std::shared_ptr mss) + : model_service_{model_service}, + engine_service_{engine_service}, + model_src_svc_(mss) {} void PullModel(const HttpRequestPtr& req, std::function&& callback); @@ -84,7 +94,18 @@ class Models : public drogon::HttpController { std::function&& callback, const std::string& engine_id); + void AddModelSource(const HttpRequestPtr& req, + std::function&& callback); + + void DeleteModelSource( + const HttpRequestPtr& req, + std::function&& callback); + + void GetModelSources(const HttpRequestPtr& req, + std::function&& callback); + private: std::shared_ptr model_service_; std::shared_ptr engine_service_; + std::shared_ptr model_src_svc_; }; diff --git a/engine/database/models.cc b/engine/database/models.cc index 8c8be9eaf..67ff1a8c9 100644 --- a/engine/database/models.cc +++ b/engine/database/models.cc @@ -18,8 +18,8 @@ std::string Models::StatusToString(ModelStatus status) const { return "remote"; case ModelStatus::Downloaded: return "downloaded"; - case ModelStatus::Undownloaded: - return "undownloaded"; + case ModelStatus::Downloadable: + return "downloadable"; } return "unknown"; } @@ -31,8 +31,8 @@ ModelStatus Models::StringToStatus(const std::string& status_str) const { return ModelStatus::Remote; } else if (status_str == "downloaded" || status_str.empty()) { return ModelStatus::Downloaded; - } else if (status_str == "undownloaded") { - return ModelStatus::Undownloaded; + } else if (status_str == "downloadable") { + return ModelStatus::Downloadable; } throw std::invalid_argument("Invalid status string"); } @@ -50,23 +50,21 @@ cpp::result, std::string> Models::LoadModelList() } bool Models::IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const { + const std::string& model_id) const { return std::none_of( - entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return entry.model == model_id || entry.model_alias == model_id || - entry.model == model_alias || entry.model_alias == model_alias; - }); + entries.begin(), entries.end(), + [&](const ModelEntry& entry) { return entry.model == model_id; }); } cpp::result, std::string> Models::LoadModelListNoLock() const { try { std::vector entries; - SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias, model_format, " - "model_source, status, engine FROM models"); + SQLite::Statement query( + db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine, metadata FROM models"); while (query.executeStep()) { ModelEntry entry; @@ -79,6 +77,7 @@ cpp::result, std::string> Models::LoadModelListNoLock() entry.model_source = query.getColumn(6).getString(); entry.status = StringToStatus(query.getColumn(7).getString()); entry.engine = query.getColumn(8).getString(); + entry.metadata = query.getColumn(9).getString(); entries.push_back(entry); } return entries; @@ -88,77 +87,17 @@ cpp::result, std::string> Models::LoadModelListNoLock() } } -std::string Models::GenerateShortenedAlias( - const std::string& model_id, const std::vector& entries) const { - std::vector parts; - std::istringstream iss(model_id); - std::string part; - while (std::getline(iss, part, ':')) { - parts.push_back(part); - } - - if (parts.empty()) { - return model_id; // Return original if no parts - } - - // Extract the filename without extension - std::string filename = parts.back(); - size_t last_dot_pos = filename.find_last_of('.'); - if (last_dot_pos != std::string::npos) { - filename = filename.substr(0, last_dot_pos); - } - - // Convert to lowercase - std::transform(filename.begin(), filename.end(), filename.begin(), - [](unsigned char c) { return std::tolower(c); }); - - // Generate alias candidates - std::vector candidates; - candidates.push_back(filename); - - if (parts.size() >= 2) { - candidates.push_back(parts[parts.size() - 2] + ":" + filename); - } - - if (parts.size() >= 3) { - candidates.push_back(parts[parts.size() - 3] + ":" + - parts[parts.size() - 2] + ":" + filename); - } - - if (parts.size() >= 4) { - candidates.push_back(parts[0] + ":" + parts[1] + ":" + - parts[parts.size() - 2] + ":" + filename); - } - - // Find the first unique candidate - for (const auto& candidate : candidates) { - if (IsUnique(entries, model_id, candidate)) { - return candidate; - } - } - - // If all candidates are taken, append a number to the last candidate - std::string base_candidate = candidates.back(); - int suffix = 1; - std::string unique_candidate = base_candidate; - while (!IsUnique(entries, model_id, unique_candidate)) { - unique_candidate = base_candidate + "-" + std::to_string(suffix++); - } - - return unique_candidate; -} - cpp::result Models::GetModelInfo( const std::string& identifier) const { try { - SQLite::Statement query(db_, - "SELECT model_id, author_repo_id, branch_name, " - "path_to_model_yaml, model_alias, model_format, " - "model_source, status, engine FROM models " - "WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement query( + db_, + "SELECT model_id, author_repo_id, branch_name, " + "path_to_model_yaml, model_alias, model_format, " + "model_source, status, engine, metadata FROM models " + "WHERE model_id = ?"); query.bind(1, identifier); - query.bind(2, identifier); if (query.executeStep()) { ModelEntry entry; entry.model = query.getColumn(0).getString(); @@ -170,6 +109,7 @@ cpp::result Models::GetModelInfo( entry.model_source = query.getColumn(6).getString(); entry.status = StringToStatus(query.getColumn(7).getString()); entry.engine = query.getColumn(8).getString(); + entry.metadata = query.getColumn(9).getString(); return entry; } else { return cpp::fail("Model not found: " + identifier); @@ -189,10 +129,10 @@ void Models::PrintModelInfo(const ModelEntry& entry) const { LOG_INFO << "Model Source: " << entry.model_source; LOG_INFO << "Status: " << StatusToString(entry.status); LOG_INFO << "Engine: " << entry.engine; + LOG_INFO << "Metadata: " << entry.metadata; } -cpp::result Models::AddModelEntry(ModelEntry new_entry, - bool use_short_alias) { +cpp::result Models::AddModelEntry(ModelEntry new_entry) { try { db_.exec("BEGIN TRANSACTION;"); cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); @@ -201,17 +141,13 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, CTL_WRN(model_list.error()); return cpp::fail(model_list.error()); } - if (IsUnique(model_list.value(), new_entry.model, new_entry.model_alias)) { - if (use_short_alias) { - new_entry.model_alias = - GenerateShortenedAlias(new_entry.model, model_list.value()); - } + if (IsUnique(model_list.value(), new_entry.model)) { SQLite::Statement insert( db_, "INSERT INTO models (model_id, author_repo_id, branch_name, " "path_to_model_yaml, model_alias, model_format, model_source, " - "status, engine) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"); + "status, engine, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); insert.bind(1, new_entry.model); insert.bind(2, new_entry.author_repo_id); insert.bind(3, new_entry.branch_name); @@ -221,6 +157,7 @@ cpp::result Models::AddModelEntry(ModelEntry new_entry, insert.bind(7, new_entry.model_source); insert.bind(8, StatusToString(new_entry.status)); insert.bind(9, new_entry.engine); + insert.bind(10, new_entry.metadata); insert.exec(); return true; @@ -242,7 +179,7 @@ cpp::result Models::UpdateModelEntry( db_, "UPDATE models SET author_repo_id = ?, branch_name = ?, " "path_to_model_yaml = ?, model_format = ?, model_source = ?, status = " - "?, engine = ? WHERE model_id = ? OR model_alias = ?"); + "?, engine = ?, metadata = ? WHERE model_id = ?"); upd.bind(1, updated_entry.author_repo_id); upd.bind(2, updated_entry.branch_name); upd.bind(3, updated_entry.path_to_model_yaml); @@ -250,7 +187,7 @@ cpp::result Models::UpdateModelEntry( upd.bind(5, updated_entry.model_source); upd.bind(6, StatusToString(updated_entry.status)); upd.bind(7, updated_entry.engine); - upd.bind(8, identifier); + upd.bind(8, updated_entry.metadata); upd.bind(9, identifier); return upd.exec() == 1; } catch (const std::exception& e) { @@ -258,36 +195,6 @@ cpp::result Models::UpdateModelEntry( } } -cpp::result Models::UpdateModelAlias( - const std::string& model_id, const std::string& new_model_alias) { - if (!HasModel(model_id)) { - return cpp::fail("Model not found: " + model_id); - } - try { - db_.exec("BEGIN TRANSACTION;"); - cortex::utils::ScopeExit se([this] { db_.exec("COMMIT;"); }); - auto model_list = LoadModelListNoLock(); - if (model_list.has_error()) { - CTL_WRN(model_list.error()); - return cpp::fail(model_list.error()); - } - // Check new_model_alias is unique - if (IsUnique(model_list.value(), new_model_alias, new_model_alias)) { - SQLite::Statement upd(db_, - "UPDATE models " - "SET model_alias = ? " - "WHERE model_id = ? OR model_alias = ?"); - upd.bind(1, new_model_alias); - upd.bind(2, model_id); - upd.bind(3, model_id); - return upd.exec() == 1; - } - return false; - } catch (const std::exception& e) { - return cpp::fail(e.what()); - } -} - cpp::result Models::DeleteModelEntry( const std::string& identifier) { try { @@ -296,10 +203,34 @@ cpp::result Models::DeleteModelEntry( return true; } - SQLite::Statement del( - db_, "DELETE from models WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement del(db_, "DELETE from models WHERE model_id = ?"); del.bind(1, identifier); - del.bind(2, identifier); + return del.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::DeleteModelEntryWithOrg( + const std::string& src) { + try { + SQLite::Statement del(db_, + "DELETE from models WHERE model_source LIKE ? AND " + "status = \"downloadable\""); + del.bind(1, src + "%"); + return del.exec() == 1; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result Models::DeleteModelEntryWithRepo( + const std::string& src) { + try { + SQLite::Statement del(db_, + "DELETE from models WHERE model_source = ? AND " + "status = \"downloadable\""); + del.bind(1, src); return del.exec() == 1; } catch (const std::exception& e) { return cpp::fail(e.what()); @@ -310,8 +241,9 @@ cpp::result, std::string> Models::FindRelatedModel( const std::string& identifier) const { try { std::vector related_models; - SQLite::Statement query( - db_, "SELECT model_id FROM models WHERE model_id LIKE ?"); + SQLite::Statement query(db_, + "SELECT model_id FROM models WHERE model_id LIKE ? " + "AND status = \"downloaded\""); query.bind(1, "%" + identifier + "%"); while (query.executeStep()) { @@ -325,11 +257,9 @@ cpp::result, std::string> Models::FindRelatedModel( bool Models::HasModel(const std::string& identifier) const { try { - SQLite::Statement query( - db_, - "SELECT COUNT(*) FROM models WHERE model_id = ? OR model_alias = ?"); + SQLite::Statement query(db_, + "SELECT COUNT(*) FROM models WHERE model_id = ?"); query.bind(1, identifier); - query.bind(2, identifier); if (query.executeStep()) { return query.getColumn(0).getInt() > 0; } @@ -340,4 +270,38 @@ bool Models::HasModel(const std::string& identifier) const { } } +cpp::result, std::string> Models::GetModelSources() + const { + try { + std::vector sources; + SQLite::Statement query(db_, + "SELECT DISTINCT model_source FROM models WHERE " + "status = \"downloadable\""); + + while (query.executeStep()) { + sources.push_back(query.getColumn(0).getString()); + } + return sources; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + +cpp::result, std::string> Models::GetModels( + const std::string& model_src) const { + try { + std::vector ids; + SQLite::Statement query(db_, + "SELECT model_id FROM models WHERE model_source = " + "? AND status = \"downloadable\""); + query.bind(1, model_src); + while (query.executeStep()) { + ids.push_back(query.getColumn(0).getString()); + } + return ids; + } catch (const std::exception& e) { + return cpp::fail(e.what()); + } +} + } // namespace cortex::db diff --git a/engine/database/models.h b/engine/database/models.h index 5c855cf1b..b0c4bc258 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -8,7 +8,8 @@ namespace cortex::db { -enum class ModelStatus { Remote, Downloaded, Undownloaded }; +enum class ModelStatus { Remote, Downloaded, Downloadable }; + struct ModelEntry { std::string model; @@ -20,6 +21,7 @@ struct ModelEntry { std::string model_source; ModelStatus status; std::string engine; + std::string metadata; }; class Models { @@ -28,8 +30,7 @@ class Models { SQLite::Database& db_; bool IsUnique(const std::vector& entries, - const std::string& model_id, - const std::string& model_alias) const; + const std::string& model_id) const; cpp::result, std::string> LoadModelListNoLock() const; @@ -41,23 +42,24 @@ class Models { Models(); Models(SQLite::Database& db); ~Models(); - std::string GenerateShortenedAlias( - const std::string& model_id, - const std::vector& entries) const; cpp::result GetModelInfo( const std::string& identifier) const; void PrintModelInfo(const ModelEntry& entry) const; - cpp::result AddModelEntry(ModelEntry new_entry, - bool use_short_alias = false); + cpp::result AddModelEntry(ModelEntry new_entry); cpp::result UpdateModelEntry( const std::string& identifier, const ModelEntry& updated_entry); cpp::result DeleteModelEntry( const std::string& identifier); - cpp::result UpdateModelAlias( - const std::string& model_id, const std::string& model_alias); + cpp::result DeleteModelEntryWithOrg( + const std::string& src); + cpp::result DeleteModelEntryWithRepo( + const std::string& src); cpp::result, std::string> FindRelatedModel( const std::string& identifier) const; bool HasModel(const std::string& identifier) const; + cpp::result, std::string> GetModelSources() const; + cpp::result, std::string> GetModels( + const std::string& model_src) const; }; } // namespace cortex::db diff --git a/engine/e2e-test/test_api_model.py b/engine/e2e-test/test_api_model.py index c2723d2ca..8f2e4b07a 100644 --- a/engine/e2e-test/test_api_model.py +++ b/engine/e2e-test/test_api_model.py @@ -129,4 +129,17 @@ async def test_models_start_stop_should_be_successful(self): # delete API print("Delete model") response = requests.delete("http://localhost:3928/v1/models/tinyllama:gguf") - assert response.status_code == 200 \ No newline at end of file + assert response.status_code == 200 + + def test_models_sources_api(self): + json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} + response = requests.post( + "http://localhost:3928/v1/models/sources", json=json_body + ) + assert response.status_code == 200, f"status_code: {response.status_code}" + + json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} + response = requests.delete( + "http://localhost:3928/v1/models/sources", json=json_body + ) + assert response.status_code == 200, f"status_code: {response.status_code}" \ No newline at end of file diff --git a/engine/main.cc b/engine/main.cc index 5222ac5c2..13583dc00 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -22,6 +22,7 @@ #include "services/file_watcher_service.h" #include "services/message_service.h" #include "services/model_service.h" +#include "services/model_source_service.h" #include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" @@ -141,6 +142,7 @@ void RunServer(std::optional port, bool ignore_cout) { auto engine_service = std::make_shared(download_service); auto inference_svc = std::make_shared(engine_service); + auto model_src_svc = std::make_shared(); auto model_service = std::make_shared( download_service, inference_svc, engine_service); @@ -154,7 +156,8 @@ void RunServer(std::optional port, bool ignore_cout) { auto thread_ctl = std::make_shared(thread_srv, message_srv); auto message_ctl = std::make_shared(message_srv); auto engine_ctl = std::make_shared(engine_service); - auto model_ctl = std::make_shared(model_service, engine_service); + auto model_ctl = + std::make_shared(model_service, engine_service, model_src_svc); auto event_ctl = std::make_shared(event_queue_ptr); auto pm_ctl = std::make_shared(); auto hw_ctl = std::make_shared(engine_service, hw_service); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7f79ddaf7..15fee15be 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -64,16 +64,30 @@ void ParseGguf(const DownloadItem& ggufDownloadItem, auto author_id = author.has_value() ? author.value() : "cortexso"; cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{ - .model = ggufDownloadItem.id, - .author_repo_id = author_id, - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = ggufDownloadItem.id, - .status = cortex::db::ModelStatus::Downloaded}; - auto result = modellist_utils_obj.AddModelEntry(model_entry, true); - if (result.has_error()) { - CTL_WRN("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(ggufDownloadItem.id)) { + cortex::db::ModelEntry model_entry{ + .model = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = ggufDownloadItem.id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(ggufDownloadItem.id); + m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = + modellist_utils_obj.UpdateModelEntry(ggufDownloadItem.id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } } } @@ -136,6 +150,9 @@ void ModelService::ForceIndexingModelList() { CTL_DBG("Database model size: " + std::to_string(list_entry.value().size())); for (const auto& model_entry : list_entry.value()) { + if (model_entry.status != cortex::db::ModelStatus::Downloaded) { + continue; + } try { yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( @@ -301,7 +318,8 @@ cpp::result ModelService::HandleDownloadUrlAsync( } auto model_entry = modellist_handler.GetModelInfo(unique_model_id); - if (model_entry.has_value()) { + if (model_entry.has_value() && + model_entry->status == cortex::db::ModelStatus::Downloaded) { CLI_LOG("Model already downloaded: " << unique_model_id); return cpp::fail("Please delete the model before downloading again"); } @@ -491,7 +509,8 @@ ModelService::DownloadModelFromCortexsoAsync( } auto model_entry = modellist_handler.GetModelInfo(unique_model_id); - if (model_entry.has_value()) { + if (model_entry.has_value() && + model_entry->status == cortex::db::ModelStatus::Downloaded) { return cpp::fail("Please delete the model before downloading again"); } @@ -532,14 +551,32 @@ ModelService::DownloadModelFromCortexsoAsync( CTL_INF("path_to_model_yaml: " << rel.string()); cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = unique_model_id, - .author_repo_id = "cortexso", - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = unique_model_id}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); - if (result.has_error()) { - CTL_ERR("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(unique_model_id)) { + cortex::db::ModelEntry model_entry{ + .model = unique_model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = unique_model_id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(unique_model_id); + m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = + modellist_utils_obj.UpdateModelEntry(unique_model_id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } else { + CTL_WRN("Could not get model entry with model id: " << unique_model_id); + } } }; @@ -585,14 +622,28 @@ cpp::result ModelService::DownloadModelFromCortexso( CTL_INF("path_to_model_yaml: " << rel.string()); cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = model_id, - .author_repo_id = "cortexso", - .branch_name = branch, - .path_to_model_yaml = rel.string(), - .model_alias = model_id}; - auto result = modellist_utils_obj.AddModelEntry(model_entry); - if (result.has_error()) { - CTL_ERR("Error adding model to modellist: " + result.error()); + if (!modellist_utils_obj.HasModel(model_id)) { + cortex::db::ModelEntry model_entry{ + .model = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = model_id, + .status = cortex::db::ModelStatus::Downloaded}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = modellist_utils_obj.GetModelInfo(model_id); m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = modellist_utils_obj.UpdateModelEntry(model_id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } } }; diff --git a/engine/services/model_source_service.cc b/engine/services/model_source_service.cc new file mode 100644 index 000000000..a7d9d5e6e --- /dev/null +++ b/engine/services/model_source_service.cc @@ -0,0 +1,493 @@ +#include "model_source_service.h" +#include +#include +#include "database/models.h" +#include "json/json.h" +#include "utils/curl_utils.h" +#include "utils/huggingface_utils.h" +#include "utils/logging_utils.h" +#include "utils/string_utils.h" +#include "utils/url_parser.h" + +namespace services { +namespace hu = huggingface_utils; + +namespace { +struct ModelInfo { + std::string id; + int likes; + int trending_score; + bool is_private; + int downloads; + std::vector tags; + std::string created_at; + std::string model_id; +}; + +std::vector ParseJsonString(const std::string& json_str) { + std::vector models; + + // Parse the JSON string + Json::Value root; + Json::Reader reader; + bool parsing_successful = reader.parse(json_str, root); + + if (!parsing_successful) { + std::cerr << "Failed to parse JSON" << std::endl; + return models; + } + + // Iterate over the JSON array + for (const auto& model : root) { + ModelInfo info; + info.id = model["id"].asString(); + info.likes = model["likes"].asInt(); + info.trending_score = model["trendingScore"].asInt(); + info.is_private = model["private"].asBool(); + info.downloads = model["downloads"].asInt(); + + const Json::Value& tags = model["tags"]; + for (const auto& tag : tags) { + info.tags.push_back(tag.asString()); + } + + info.created_at = model["createdAt"].asString(); + info.model_id = model["modelId"].asString(); + models.push_back(info); + } + + return models; +} + +} // namespace + +ModelSourceService::ModelSourceService() { + sync_db_thread_ = std::thread(&ModelSourceService::SyncModelSource, this); + running_ = true; +} +ModelSourceService::~ModelSourceService() { + running_ = false; + if (sync_db_thread_.joinable()) { + sync_db_thread_.join(); + } + CTL_INF("Done cleanup thread"); +} + +cpp::result ModelSourceService::AddModelSource( + const std::string& model_source) { + auto res = url_parser::FromUrlString(model_source); + if (res.has_error()) { + return cpp::fail(res.error()); + } else { + auto& r = res.value(); + if (r.pathParams.empty() || r.pathParams.size() > 2) { + return cpp::fail("Invalid model source url: " + model_source); + } + + if (auto is_org = r.pathParams.size() == 1; is_org) { + auto& author = r.pathParams[0]; + if (author == "cortexso") { + return AddCortexsoOrg(model_source); + } else { + return AddHfOrg(model_source, author); + } + } else { // Repo + auto const& author = r.pathParams[0]; + auto const& model_name = r.pathParams[1]; + if (r.pathParams[0] == "cortexso") { + return AddCortexsoRepo(model_source, author, model_name); + } else { + return AddHfRepo(model_source, author, model_name); + } + } + } + return true; +} + +cpp::result ModelSourceService::RemoveModelSource( + const std::string& model_source) { + cortex::db::Models model_db; + auto srcs = model_db.GetModelSources(); + if (srcs.has_error()) { + return cpp::fail(srcs.error()); + } else { + auto& v = srcs.value(); + if (std::find(v.begin(), v.end(), model_source) == v.end()) { + return cpp::fail("Model source does not exist: " + model_source); + } + } + CTL_INF("Remove model source: " << model_source); + auto res = url_parser::FromUrlString(model_source); + if (res.has_error()) { + return cpp::fail(res.error()); + } else { + auto& r = res.value(); + if (r.pathParams.empty() || r.pathParams.size() > 2) { + return cpp::fail("Invalid model source url: " + model_source); + } + + if (r.pathParams.size() == 1) { + if (auto del_res = model_db.DeleteModelEntryWithOrg(model_source); + del_res.has_error()) { + CTL_INF(del_res.error()); + return cpp::fail(del_res.error()); + } + } else { + if (auto del_res = model_db.DeleteModelEntryWithRepo(model_source); + del_res.has_error()) { + CTL_INF(del_res.error()); + return cpp::fail(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::GetModelSources() { + cortex::db::Models model_db; + return model_db.GetModelSources(); +} + +cpp::result ModelSourceService::AddHfOrg( + const std::string& model_source, const std::string& author) { + auto res = curl_utils::SimpleGet("https://huggingface.co/api/models?author=" + + author); + if (res.has_value()) { + auto models = ParseJsonString(res.value()); + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + // Add new models + for (auto const& m : models) { + CTL_DBG(m.id); + auto author_model = string_utils::SplitBy(m.id, "/"); + if (author_model.size() == 2) { + auto const& author = author_model[0]; + auto const& model_name = author_model[1]; + auto add_res = AddRepoSiblings(model_source, author, model_name) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + } + + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); + del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + } else { + return cpp::fail(res.error()); + } + return true; +} + +cpp::result ModelSourceService::AddHfRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name) { + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + auto add_res = AddRepoSiblings(model_source, author, model_name); + if (add_res.has_error()) { + return cpp::fail(add_res.error()); + } else { + updated_model_list = add_res.value(); + } + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::AddRepoSiblings(const std::string& model_source, + const std::string& author, + const std::string& model_name) { + std::unordered_set res; + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + return cpp::fail(repo_info.error()); + } + + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } + + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + cortex::db::Models model_db; + std::string model_id = + author + ":" + model_name + ":" + sibling.rfilename; + cortex::db::ModelEntry e = { + .model = model_id, + .author_repo_id = author, + .branch_name = "main", + .path_to_model_yaml = "", + .model_alias = "", + .model_format = "hf-gguf", + .model_source = model_source, + .status = cortex::db::ModelStatus::Downloadable, + .engine = "llama-cpp", + .metadata = repo_info->metadata}; + if (!model_db.HasModel(model_id)) { + if (auto add_res = model_db.AddModelEntry(e); add_res.has_error()) { + CTL_INF(add_res.error()); + } + } else { + if (auto m = model_db.GetModelInfo(model_id); + m.has_value() && + m->status == cortex::db::ModelStatus::Downloadable) { + if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + upd_res.has_error()) { + CTL_INF(upd_res.error()); + } + } + } + res.insert(model_id); + } + } + + return res; +} + +cpp::result ModelSourceService::AddCortexsoOrg( + const std::string& model_source) { + auto res = curl_utils::SimpleGet( + "https://huggingface.co/api/models?author=cortexso"); + if (res.has_value()) { + auto models = ParseJsonString(res.value()); + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + for (auto const& m : models) { + CTL_INF(m.id); + auto author_model = string_utils::SplitBy(m.id, "/"); + if (author_model.size() == 2) { + auto const& author = author_model[0]; + auto const& model_name = author_model[1]; + auto branches = huggingface_utils::GetModelRepositoryBranches( + "cortexso", model_name); + if (branches.has_error()) { + CTL_INF(branches.error()); + continue; + } + + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + CTL_INF(repo_info.error()); + continue; + } + for (auto const& [branch, _] : branches.value()) { + CTL_INF(branch); + auto add_res = AddCortexsoRepoBranch(model_source, author, model_name, + branch, repo_info->metadata) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + } + } + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); + del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + } else { + return cpp::fail(res.error()); + } + + return true; +} + +cpp::result ModelSourceService::AddCortexsoRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name) { + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", model_name); + if (branches.has_error()) { + return cpp::fail(branches.error()); + } + + auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name); + if (repo_info.has_error()) { + return cpp::fail(repo_info.error()); + } + // Get models from db + cortex::db::Models model_db; + + auto model_list_before = + model_db.GetModels(model_source).value_or(std::vector{}); + std::unordered_set updated_model_list; + + for (auto const& [branch, _] : branches.value()) { + CTL_INF(branch); + auto add_res = AddCortexsoRepoBranch(model_source, author, model_name, + branch, repo_info->metadata) + .value_or(std::unordered_set{}); + for (auto const& a : add_res) { + updated_model_list.insert(a); + } + } + + // Clean up + for (auto const& mid : model_list_before) { + if (updated_model_list.find(mid) == updated_model_list.end()) { + if (auto del_res = model_db.DeleteModelEntry(mid); del_res.has_error()) { + CTL_INF(del_res.error()); + } + } + } + return true; +} + +cpp::result, std::string> +ModelSourceService::AddCortexsoRepoBranch(const std::string& model_source, + const std::string& author, + const std::string& model_name, + const std::string& branch, + const std::string& metadata) { + std::unordered_set res; + + url_parser::Url url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"api", "models", "cortexso", model_name, "tree", branch}, + }; + + auto result = curl_utils::SimpleGetJson(url.ToFullPath()); + if (result.has_error()) { + return cpp::fail("Model " + model_name + " not found"); + } + + bool has_gguf = false; + for (const auto& value : result.value()) { + auto path = value["path"].asString(); + if (path.find(".gguf") != std::string::npos) { + has_gguf = true; + } + } + if (!has_gguf) { + CTL_INF("Only support gguf file format! - branch: " << branch); + return {}; + } else { + cortex::db::Models model_db; + std::string model_id = model_name + ":" + branch; + cortex::db::ModelEntry e = {.model = model_id, + .author_repo_id = author, + .branch_name = branch, + .path_to_model_yaml = "", + .model_alias = "", + .model_format = "cortexso", + .model_source = model_source, + .status = cortex::db::ModelStatus::Downloadable, + .engine = "llama-cpp", + .metadata = metadata}; + if (!model_db.HasModel(model_id)) { + CTL_INF("Adding model to db: " << model_name << ":" << branch); + if (auto res = model_db.AddModelEntry(e); + res.has_error() || !res.value()) { + CTL_DBG("Cannot add model to db: " << model_id); + } + } else { + if (auto m = model_db.GetModelInfo(model_id); + m.has_value() && m->status == cortex::db::ModelStatus::Downloadable) { + if (auto upd_res = model_db.UpdateModelEntry(model_id, e); + upd_res.has_error()) { + CTL_INF(upd_res.error()); + } + } + } + res.insert(model_id); + } + return res; +} + +void ModelSourceService::SyncModelSource() { + // Do interval check for 10 minutes + constexpr const int kIntervalCheck = 10 * 60; + auto start_time = std::chrono::steady_clock::now(); + while (running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + auto current_time = std::chrono::steady_clock::now(); + auto elapsed_time = std::chrono::duration_cast( + current_time - start_time) + .count(); + + if (elapsed_time > kIntervalCheck) { + CTL_DBG("Start to sync cortex.db"); + start_time = current_time; + + cortex::db::Models model_db; + auto res = model_db.GetModelSources(); + if (res.has_error()) { + CTL_INF(res.error()); + } else { + for (auto const& src : res.value()) { + CTL_DBG(src); + } + + std::unordered_set orgs; + std::vector repos; + for (auto const& src : res.value()) { + auto url_res = url_parser::FromUrlString(src); + if (url_res.has_value()) { + if (url_res->pathParams.size() == 1) { + orgs.insert(src); + } else if (url_res->pathParams.size() == 2) { + repos.push_back(src); + } + } + } + + // Get list to update + std::vector update_cand(orgs.begin(), orgs.end()); + auto get_org = [](const std::string& rp) { + return rp.substr(0, rp.find_last_of("/")); + }; + for (auto const& repo : repos) { + if (orgs.find(get_org(repo)) != orgs.end()) { + update_cand.push_back(repo); + } + } + + // Sync cortex.db with the upstream data + for (auto const& c : update_cand) { + if (auto res = AddModelSource(c); res.has_error()) { + CTL_INF(res.error();) + } + } + } + + CTL_DBG("Done sync cortex.db"); + } + } +} + +} // namespace services \ No newline at end of file diff --git a/engine/services/model_source_service.h b/engine/services/model_source_service.h new file mode 100644 index 000000000..aa0b37259 --- /dev/null +++ b/engine/services/model_source_service.h @@ -0,0 +1,53 @@ +#pragma once +#include +#include +#include +#include "utils/result.hpp" + +namespace services { +class ModelSourceService { + public: + explicit ModelSourceService(); + ~ModelSourceService(); + + cpp::result AddModelSource( + const std::string& model_source); + + cpp::result RemoveModelSource( + const std::string& model_source); + + cpp::result, std::string> GetModelSources(); + + private: + cpp::result AddHfOrg(const std::string& model_source, + const std::string& author); + + cpp::result AddHfRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name); + + cpp::result, std::string> AddRepoSiblings( + const std::string& model_source, const std::string& author, + const std::string& model_name); + + cpp::result AddCortexsoOrg( + const std::string& model_source); + + cpp::result AddCortexsoRepo( + const std::string& model_source, const std::string& author, + const std::string& model_name); + + cpp::result, std::string> + AddCortexsoRepoBranch(const std::string& model_source, + const std::string& author, + const std::string& model_name, + const std::string& branch, + const std::string& metadata); + + void SyncModelSource(); + + private: + std::thread sync_db_thread_; + std::atomic running_; +}; +} // namespace services \ No newline at end of file diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index ab0ea9f70..06294aa8c 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -24,7 +24,8 @@ class ModelsTestSuite : public ::testing::Test { "model_format TEXT," "model_source TEXT," "status TEXT," - "engine TEXT" + "engine TEXT," + "metadata TEXT" ")"); } catch (const std::exception& e) {} } @@ -70,10 +71,6 @@ TEST_F(ModelsTestSuite, TestGetModelInfo) { EXPECT_TRUE(model_by_id.has_value()); EXPECT_EQ(model_by_id.value().model, kTestModel.model); - auto model_by_alias = model_list_.GetModelInfo("test_alias"); - EXPECT_TRUE(model_by_alias); - EXPECT_EQ(model_by_alias.value().model, kTestModel.model); - EXPECT_TRUE(model_list_.GetModelInfo("non_existent_model").has_error()); // Clean up @@ -104,26 +101,6 @@ TEST_F(ModelsTestSuite, TestDeleteModelEntry) { EXPECT_TRUE(model_list_.GetModelInfo(kTestModel.model).has_error()); } -TEST_F(ModelsTestSuite, TestGenerateShortenedAlias) { - EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); - auto models1 = model_list_.LoadModelList(); - auto alias = model_list_.GenerateShortenedAlias( - "huggingface.co:bartowski:llama3.1-7b-gguf:Model_ID_Xxx.gguf", - models1.value()); - EXPECT_EQ(alias, "model_id_xxx"); - EXPECT_TRUE(model_list_.UpdateModelAlias(kTestModel.model, alias).value()); - - // Test with existing entries to force longer alias - auto models2 = model_list_.LoadModelList(); - alias = model_list_.GenerateShortenedAlias( - "huggingface.co:bartowski:llama3.1-7b-gguf:Model_ID_Xxx.gguf", - models2.value()); - EXPECT_EQ(alias, "llama3.1-7b-gguf:model_id_xxx"); - - // Clean up - EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); -} - TEST_F(ModelsTestSuite, TestPersistence) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); @@ -136,53 +113,10 @@ TEST_F(ModelsTestSuite, TestPersistence) { EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -TEST_F(ModelsTestSuite, TestUpdateModelAlias) { - constexpr const auto kNewTestAlias = "new_test_alias"; - constexpr const auto kNonExistentModel = "non_existent_model"; - constexpr const auto kAnotherAlias = "another_alias"; - constexpr const auto kFinalTestAlias = "final_test_alias"; - constexpr const auto kAnotherModelId = "another_model_id"; - // Add the test model - ASSERT_TRUE(model_list_.AddModelEntry(kTestModel).value()); - - // Test successful update - EXPECT_TRUE( - model_list_.UpdateModelAlias(kTestModel.model, kNewTestAlias).value()); - auto updated_model = model_list_.GetModelInfo(kNewTestAlias); - EXPECT_TRUE(updated_model.has_value()); - EXPECT_EQ(updated_model.value().model_alias, kNewTestAlias); - EXPECT_EQ(updated_model.value().model, kTestModel.model); - - // Test update with non-existent model - EXPECT_TRUE(model_list_.UpdateModelAlias(kNonExistentModel, kAnotherAlias) - .has_error()); - - // Test update with non-unique alias - cortex::db::ModelEntry another_model = kTestModel; - another_model.model = kAnotherModelId; - another_model.model_alias = kAnotherAlias; - ASSERT_TRUE(model_list_.AddModelEntry(another_model).value()); - - EXPECT_FALSE( - model_list_.UpdateModelAlias(kTestModel.model, kAnotherAlias).value()); - - // Test update using model alias instead of model ID - EXPECT_TRUE(model_list_.UpdateModelAlias(kNewTestAlias, kFinalTestAlias)); - updated_model = model_list_.GetModelInfo(kFinalTestAlias); - EXPECT_TRUE(updated_model); - EXPECT_EQ(updated_model.value().model_alias, kFinalTestAlias); - EXPECT_EQ(updated_model.value().model, kTestModel.model); - - // Clean up - EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); - EXPECT_TRUE(model_list_.DeleteModelEntry(kAnotherModelId).value()); -} - TEST_F(ModelsTestSuite, TestHasModel) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel).value()); EXPECT_TRUE(model_list_.HasModel(kTestModel.model)); - EXPECT_TRUE(model_list_.HasModel("test_alias")); EXPECT_FALSE(model_list_.HasModel("non_existent_model")); // Clean up EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index f2895c363..1d1040612 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -67,6 +67,7 @@ struct HuggingFaceModelRepoInfo { std::vector siblings; std::vector spaces; std::string createdAt; + std::string metadata; static cpp::result FromJson( const Json::Value& body) { @@ -104,6 +105,7 @@ struct HuggingFaceModelRepoInfo { .spaces = json_parser_utils::ParseJsonArray(body["spaces"]), .createdAt = body["createdAt"].asString(), + .metadata = body.toStyledString(), }; } diff --git a/engine/utils/json_parser_utils.h b/engine/utils/json_parser_utils.h index 3ebd2c546..b4ea1a7e1 100644 --- a/engine/utils/json_parser_utils.h +++ b/engine/utils/json_parser_utils.h @@ -10,7 +10,7 @@ template T jsonToValue(const Json::Value& value); template <> -std::string jsonToValue(const Json::Value& value) { +inline std::string jsonToValue(const Json::Value& value) { return value.asString(); } From 9f6936c246efe4b5c77e09d94e4e040430a451b9 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 13 Dec 2024 08:15:34 +0700 Subject: [PATCH 28/44] chore: add files api docs (#1793) Signed-off-by: James --- docs/static/openapi/cortex.json | 368 ++++++++++++++++++++++++++++++-- 1 file changed, 356 insertions(+), 12 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 2ff239ce2..9b96ba0a7 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -810,7 +810,7 @@ "/v1/models/sources": { "post": { "summary": "Add a model source", - "description": "User can add a Huggingface Organization or Repository", + "description": "User can add a Huggingface Organization or Repository", "requestBody": { "required": false, "content": { @@ -850,7 +850,7 @@ }, "delete": { "summary": "Remove a model source", - "description": "User can remove a Huggingface Organization or Repository", + "description": "User can remove a Huggingface Organization or Repository", "requestBody": { "required": false, "content": { @@ -860,7 +860,7 @@ "properties": { "source": { "type": "string", - "description": "The url of model source to remove", + "description": "The url of model source to remove", "example": "https://huggingface.co/cortexso/tinyllama" } } @@ -1583,7 +1583,13 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], + "enum": [ + "llama-cpp", + "onnxruntime", + "tensorrt-llm", + "openai", + "anthropic" + ], "default": "llama-cpp" }, "description": "The type of engine" @@ -1625,9 +1631,9 @@ "type": "object", "properties": { "get_models_url": { - "type": "string", - "description": "The URL to get models", - "example": "https://api.openai.com/v1/models" + "type": "string", + "description": "The URL to get models", + "example": "https://api.openai.com/v1/models" } } } @@ -1666,7 +1672,13 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], + "enum": [ + "llama-cpp", + "onnxruntime", + "tensorrt-llm", + "openai", + "anthropic" + ], "default": "llama-cpp" }, "description": "The type of engine" @@ -1881,7 +1893,13 @@ "required": true, "schema": { "type": "string", - "enum": ["llama-cpp", "onnxruntime", "tensorrt-llm", "openai", "anthropic"], + "enum": [ + "llama-cpp", + "onnxruntime", + "tensorrt-llm", + "openai", + "anthropic" + ], "default": "llama-cpp" }, "description": "The name of the engine to update" @@ -2058,6 +2076,319 @@ "tags": ["Hardware"] } }, + "/v1/files": { + "post": { + "summary": "Upload a File", + "description": "Uploads a file to the Cortex server.", + "requestBody": { + "required": true, + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "file": { + "type": "string", + "format": "binary" + }, + "purpose": { + "type": "string", + "enum": ["assistants"], + "description": "The intended purpose of the uploaded file" + } + }, + "required": ["file", "purpose"] + } + } + } + }, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "bytes": { + "type": "integer", + "example": 3211109 + }, + "created_at": { + "type": "integer", + "example": 1733942093 + }, + "filename": { + "type": "string", + "example": "Enterprise_Application_Infrastructure_v2_20140903_toCTC_v1.0.pdf" + }, + "id": { + "type": "string", + "example": "file-0001KNKPTDDAQSDVEQGRBTCTNJ" + }, + "object": { + "type": "string", + "example": "file" + }, + "purpose": { + "type": "string", + "example": "assistants" + } + } + } + } + } + } + }, + "tags": ["Files"] + }, + "get": { + "summary": "List files", + "description": "Lists all the files in the current directory.", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "bytes": { + "type": "integer", + "example": 3211109 + }, + "created_at": { + "type": "integer", + "example": 1733942093 + }, + "filename": { + "type": "string", + "example": "Enterprise_Application_Infrastructure_v2_20140903_toCTC_v1.0.pdf" + }, + "id": { + "type": "string", + "example": "file-0001KNKPTDDAQSDVEQGRBTCTNJ" + }, + "object": { + "type": "string", + "example": "file" + }, + "purpose": { + "type": "string", + "example": "assistants" + } + } + } + }, + "object": { + "type": "string", + "example": "list" + } + } + } + } + } + } + }, + "tags": ["Files"] + } + }, + "/v1/files/{id}": { + "get": { + "summary": "Retrieve File", + "description": "Retrieves a file by its ID.", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "The ID of the file to retrieve", + "schema": { + "type": "string" + } + }, + { + "name": "thread", + "in": "query", + "required": false, + "description": "Optional thread identifier", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Successfully retrieved file", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "bytes": { + "type": "integer", + "example": 3211109 + }, + "created_at": { + "type": "integer", + "example": 1733942093 + }, + "filename": { + "type": "string", + "example": "Enterprise_Application_Infrastructure_v2_20140903_toCTC_v1.0.pdf" + }, + "id": { + "type": "string", + "example": "file-0001KNKPTDDAQSDVEQGRBTCTNJ" + }, + "object": { + "type": "string", + "example": "file" + }, + "purpose": { + "type": "string", + "example": "assistants" + } + } + } + } + } + } + }, + "tags": ["Files"] + }, + "delete": { + "summary": "Delete File", + "description": "Deletes a file by its ID.", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "The ID of the file to delete", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "File successfully deleted", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "deleted": { + "type": "boolean", + "description": "Indicates if the file was successfully deleted" + }, + "id": { + "type": "string", + "description": "The ID of the deleted file" + }, + "object": { + "type": "string", + "description": "Type of object, always 'file'" + } + }, + "required": ["deleted", "id", "object"] + }, + "example": { + "deleted": true, + "id": "file-0001KNP26FC62D620DGYNG2R8H", + "object": "file" + } + } + } + }, + "400": { + "description": "File not found or invalid request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Error message describing the issue" + } + }, + "required": ["message"] + }, + "example": { + "message": "File not found: file-0001KNP26FC62D620DGYNG2R8H" + } + } + } + } + }, + "tags": ["Files"] + } + }, + "/v1/files/{id}/content": { + "get": { + "summary": "Get File Content", + "description": "Retrieves the content of a file by its ID.", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "The ID of the file to retrieve content from", + "schema": { + "type": "string" + } + }, + { + "name": "thread", + "in": "query", + "required": false, + "description": "Optional thread identifier", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "File content retrieved successfully", + "content": { + "*/*": { + "schema": { + "type": "string", + "format": "binary", + "description": "The raw content of the file" + } + } + } + }, + "400": { + "description": "File not found or invalid request", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Error message describing the issue" + } + }, + "required": ["message"] + } + } + } + } + }, + "tags": ["Files"] + } + }, "/v1/configs": { "get": { "summary": "Get Configurations", @@ -2338,6 +2669,10 @@ "name": "Engines", "description": "Endpoints for managing the available engines within Cortex." }, + { + "name": "Files", + "description": "Endpoints for managing the files within Cortex." + }, { "name": "Hardware", "description": "Endpoints for managing the available hardware within Cortex." @@ -2354,6 +2689,7 @@ "Chat", "Embeddings", "Engines", + "Files", "Hardware", "Events", "Pulling Models", @@ -2426,7 +2762,7 @@ } }, "required": ["type", "function"] - }, + } }, "metadata": { "type": "object", @@ -3829,7 +4165,15 @@ }, "AddModelRequest": { "type": "object", - "required": ["model", "engine", "version", "inference_params", "TransformReq", "TransformResp", "metadata"], + "required": [ + "model", + "engine", + "version", + "inference_params", + "TransformReq", + "TransformResp", + "metadata" + ], "properties": { "model": { "type": "string", @@ -3878,7 +4222,7 @@ }, "chat_completions": { "type": "object", - "properties": { + "properties": { "url": { "type": "string" }, From b390fa4d73932182b79151d173fe575e32652efa Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 13 Dec 2024 08:43:57 +0700 Subject: [PATCH 29/44] chore: add thread api docs (#1794) --- docs/static/openapi/cortex.json | 566 +++++++++++++++----------------- engine/controllers/threads.cc | 7 +- 2 files changed, 268 insertions(+), 305 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 9b96ba0a7..ba7944b71 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -908,319 +908,174 @@ }, "/v1/threads": { "post": { - "operationId": "ThreadsController_create", - "summary": "Create thread", - "tags": ["Threads"], - "description": "Creates a new thread.", - "parameters": [], + "summary": "Create Thread", + "description": "Creates a new thread with optional metadata.", "requestBody": { - "required": true, "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/CreateThreadDto" - } - } - } - }, - "responses": { - "201": { - "description": "", - "content": { - "application/json": { - "schema": { - "type": "object" - } - } - } - } - } - }, - "get": { - "operationId": "ThreadsController_findAll", - "summary": "List threads", - "tags": ["Threads"], - "description": "Lists all the available threads along with its configurations.", - "parameters": [], - "responses": { - "200": { - "description": "", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "type": "object" + "type": "object", + "properties": { + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the thread" + } + }, + "description": "Optional metadata for the thread" } } + }, + "example": { + "metadata": { + "title": "New Thread" + } } } - } - } - } - }, - "/v1/threads/{thread_id}/messages/{message_id}": { - "get": { - "operationId": "ThreadsController_retrieveMessage", - "summary": "Retrieve message", - "tags": ["Messages"], - "description": "Retrieves a message.", - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } }, - { - "name": "message_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - } - ], + "required": false + }, "responses": { "200": { - "description": "The message object matching the specified ID.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/GetMessageResponseDto" - } - } - } - } - } - }, - "post": { - "operationId": "ThreadsController_updateMessage", - "summary": "Modify message", - "tags": ["Messages"], - "description": "Modifies a message.", - "responses": { - "201": { - "description": "", + "description": "Thread created successfully", "content": { "application/json": { "schema": { - "type": "object" + "type": "object", + "properties": { + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the thread was created" + }, + "id": { + "type": "string", + "description": "Unique identifier for the thread" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the thread" + } + }, + "description": "Metadata associated with the thread" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread'" + } + }, + "required": ["created_at", "id", "object"] + }, + "example": { + "created_at": 1734020845, + "id": "0001KNP3QDX314435VAEGW1Z2X", + "metadata": { + "title": "New Thread" + }, + "object": "thread" } } } } }, - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - }, - { - "name": "message_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/UpdateMessageDto" - } - } - } - } + "tags": ["Threads"] }, - "delete": { - "operationId": "ThreadsController_deleteMessage", - "summary": "Delete message", - "description": "Deletes a message.", - "tags": ["Messages"], - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - }, - { - "name": "message_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - } - ], - "responses": { - "200": { - "description": "Deletion status.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/DeleteMessageDto" - } - } - } - } - } - } - }, - "/v1/threads/{thread_id}/messages": { "get": { - "operationId": "ThreadsController_getMessagesOfThread", - "summary": "List messages", - "tags": ["Messages"], - "description": "Returns a list of messages for a given thread.", - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - }, - { - "name": "limit", - "required": true, - "in": "query", - "schema": { - "type": "number" - } - }, - { - "name": "order", - "required": true, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "after", - "required": true, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "before", - "required": true, - "in": "query", - "schema": { - "type": "string" - } - }, - { - "name": "run_id", - "required": true, - "in": "query", - "schema": { - "type": "string" - } - } - ], + "summary": "List Threads", + "description": "Returns a list of threads with their metadata.", "responses": { "200": { - "description": "A list of message objects.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ListMessagesResponseDto" - } - } - } - } - } - }, - "post": { - "operationId": "ThreadsController_createMessageInThread", - "summary": "Create message", - "tags": ["Messages"], - "description": "Create a message.", - "responses": { - "201": { - "description": "", + "description": "List of threads retrieved successfully", "content": { "application/json": { "schema": { - "type": "object" + "type": "object", + "properties": { + "object": { + "type": "string", + "description": "Type of the list response, always 'list'" + }, + "data": { + "type": "array", + "description": "Array of thread objects", + "items": { + "type": "object", + "properties": { + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the thread was created" + }, + "id": { + "type": "string", + "description": "Unique identifier for the thread" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Title of the thread" + }, + "lastMessage": { + "type": "string", + "description": "Content of the last message in the thread" + } + }, + "description": "Metadata associated with the thread" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread'" + } + }, + "required": ["created_at", "id", "object"] + } + } + }, + "required": ["object", "data"] + }, + "example": { + "data": [ + { + "created_at": 1734020845, + "id": "0001KNP3QDX314435VAEGW1Z2X", + "metadata": { + "title": "New Thread" + }, + "object": "thread" + }, + { + "created_at": 1734020803, + "id": "0001KNP3P3DAQSDVEQGRBTCTNJ", + "metadata": { + "title": "" + }, + "object": "thread" + } + ], + "object": "list" } } } } }, - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - } - ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/CreateMessageDto" - } - } - } - } - } - }, - "/v1/threads/{thread_id}/clean": { - "post": { - "operationId": "ThreadsController_cleanThread", - "summary": "Clean thread", - "description": "Deletes all messages in a thread.", - "tags": ["Threads"], - "parameters": [ - { - "name": "thread_id", - "required": true, - "in": "path", - "schema": { - "type": "string" - } - } - ], - "responses": { - "201": { - "description": "" - } - } + "tags": ["Threads"] } }, - "/v1/threads/{thread_id}": { + "/v1/threads/{id}": { "get": { - "operationId": "ThreadsController_retrieveThread", - "summary": "Retrieve thread", - "tags": ["Threads"], - "description": "Retrieves a thread.", + "summary": "Retrieve Thread", + "description": "Retrieves a specific thread by its ID.", "parameters": [ { - "name": "thread_id", - "required": true, + "name": "id", "in": "path", + "required": true, + "description": "The ID of the thread to retrieve", "schema": { "type": "string" } @@ -1228,27 +1083,65 @@ ], "responses": { "200": { - "description": "Retrieves a thread.", + "description": "Thread retrieved successfully", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/GetThreadResponseDto" + "type": "object", + "properties": { + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the thread was created" + }, + "id": { + "type": "string", + "description": "Unique identifier for the thread" + }, + "metadata": { + "type": "object", + "properties": { + "lastMessage": { + "type": "string", + "description": "Content of the last message in the thread" + }, + "title": { + "type": "string", + "description": "Title of the thread" + } + }, + "description": "Metadata associated with the thread" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread'" + } + }, + "required": ["created_at", "id", "object"] + }, + "example": { + "created_at": 1732370026, + "id": "jan_1732370027", + "metadata": { + "lastMessage": "Based on the context, I'm not sure how to build a unique experience quickly and easily. The text mentions that there are some concerns about Android apps providing consistent experiences for different users, which makes me skeptical about building one.\n\nSpecifically, it says:\n\n* \"Might not pass CTS\" (Computer Science Technology standards)\n* \"Might not comply with CDD\" (Consumer Development Division standards)\n\nThis suggests that building a unique experience for all users could be challenging or impossible. Therefore, I don't know how to build a unique experience quickly and easily.\n\nWould you like me to try again?", + "title": "hello" + }, + "object": "thread" } } } } - } + }, + "tags": ["Threads"] }, - "post": { - "operationId": "ThreadsController_modifyThread", - "summary": "Modify thread", - "tags": ["Threads"], - "description": "Modifies a thread.", + "patch": { + "summary": "Modify Thread", + "description": "Updates a specific thread's metadata.", "parameters": [ { - "name": "thread_id", - "required": true, + "name": "id", "in": "path", + "required": true, + "description": "The ID of the thread to modify", "schema": { "type": "string" } @@ -1259,37 +1152,84 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateThreadDto" + "type": "object", + "properties": { + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "New title for the thread" + } + }, + "description": "Metadata to update" + } + } + }, + "example": { + "metadata": { + "title": "my title" + } } } } }, "responses": { "200": { - "description": "The thread has been successfully updated.", + "description": "Thread modified successfully", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/UpdateThreadDto" + "type": "object", + "properties": { + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the thread was created" + }, + "id": { + "type": "string", + "description": "Unique identifier for the thread" + }, + "metadata": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Updated title of the thread" + } + }, + "description": "Updated metadata for the thread" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread'" + } + }, + "required": ["created_at", "id", "object"] + }, + "example": { + "created_at": 1733301054, + "id": "0001KN04SY7D75K0MPTXMXCH39", + "metadata": { + "title": "my title" + }, + "object": "thread" } } } - }, - "201": { - "description": "" } - } + }, + "tags": ["Threads"] }, "delete": { - "operationId": "ThreadsController_remove", - "summary": "Delete thread", - "tags": ["Threads"], - "description": "Deletes a specific thread defined by a thread `id` .", + "summary": "Delete Thread", + "description": "Deletes a specific thread by its ID.", "parameters": [ { - "name": "thread_id", - "required": true, + "name": "id", "in": "path", + "required": true, + "description": "The ID of the thread to delete", "schema": { "type": "string" } @@ -1297,16 +1237,37 @@ ], "responses": { "200": { - "description": "The thread has been successfully deleted.", + "description": "Thread deleted successfully", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/DeleteThreadResponseDto" + "type": "object", + "properties": { + "deleted": { + "type": "boolean", + "description": "Indicates if the thread was successfully deleted" + }, + "id": { + "type": "string", + "description": "ID of the deleted thread" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.deleted'" + } + }, + "required": ["deleted", "id", "object"] + }, + "example": { + "deleted": true, + "id": "jan_1732370027", + "object": "thread.deleted" } } } } - } + }, + "tags": ["Threads"] } }, "/v1/system": { @@ -2692,6 +2653,7 @@ "Files", "Hardware", "Events", + "Threads", "Pulling Models", "Running Models", "Processes", diff --git a/engine/controllers/threads.cc b/engine/controllers/threads.cc index 81e14ce5a..4a87bc9eb 100644 --- a/engine/controllers/threads.cc +++ b/engine/controllers/threads.cc @@ -193,10 +193,11 @@ void Threads::ModifyThread( resp->setStatusCode(k400BadRequest); callback(resp); } else { - res->ToJson()->removeMember("title"); - res->ToJson()->removeMember("assistants"); + auto json_res = res->ToJson(); + json_res->removeMember("title"); + json_res->removeMember("assistants"); auto resp = - cortex_utils::CreateCortexHttpJsonResponse(res->ToJson().value()); + cortex_utils::CreateCortexHttpJsonResponse(json_res.value()); resp->setStatusCode(k200OK); callback(resp); } From 4c39bdbe7697be1bbb4decdc0baf98becac490e8 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 13 Dec 2024 09:06:14 +0700 Subject: [PATCH 30/44] chore: add messages api docs (#1795) --- docs/static/openapi/cortex.json | 723 +++++++++++++++++++++++++++++++- 1 file changed, 722 insertions(+), 1 deletion(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index ba7944b71..4792fe306 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -1270,6 +1270,726 @@ "tags": ["Threads"] } }, + "/v1/threads/{thread_id}/messages": { + "post": { + "summary": "Create Message", + "description": "Creates a new message in a thread.", + "parameters": [ + { + "name": "thread_id", + "in": "path", + "required": true, + "description": "The ID of the thread to create the message in", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "role": { + "type": "string", + "description": "Role of the message sender", + "enum": ["user", "assistant"] + }, + "content": { + "type": "string", + "description": "The content of the message" + } + }, + "required": ["role", "content"] + }, + "example": { + "role": "user", + "content": "Hello, world!" + } + } + } + }, + "responses": { + "200": { + "description": "Message created successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the message" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.message'" + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the message was created" + }, + "completed_at": { + "type": "integer", + "description": "Unix timestamp of when the message was completed" + }, + "thread_id": { + "type": "string", + "description": "ID of the thread this message belongs to" + }, + "role": { + "type": "string", + "description": "Role of the message sender", + "enum": ["user", "assistant"] + }, + "status": { + "type": "string", + "description": "Status of the message", + "enum": ["completed"] + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of content", + "enum": ["text"] + }, + "text": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The message text" + }, + "annotations": { + "type": "array", + "description": "Array of annotations for the text" + } + } + } + } + } + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the message" + } + }, + "required": [ + "id", + "object", + "created_at", + "completed_at", + "thread_id", + "role", + "status", + "content" + ] + }, + "example": { + "completed_at": 1734023130, + "content": [ + { + "text": { + "annotations": [], + "value": "Hello, world!" + }, + "type": "text" + } + ], + "created_at": 1734023130, + "id": "0001KNP5YT00GW0X476W5TVBFE", + "metadata": {}, + "object": "thread.message", + "role": "user", + "status": "completed", + "thread_id": "jan_1732370027" + } + } + } + } + }, + "tags": ["Messages"] + }, + "get": { + "summary": "List Messages", + "description": "Retrieves a list of messages in a thread with optional pagination and filtering.", + "parameters": [ + { + "name": "thread_id", + "in": "path", + "required": true, + "description": "The ID of the thread to list messages from", + "schema": { + "type": "string" + } + }, + { + "name": "limit", + "in": "query", + "required": false, + "description": "Maximum number of messages to return", + "schema": { + "type": "integer" + } + }, + { + "name": "order", + "in": "query", + "required": false, + "description": "Sort order of messages", + "schema": { + "type": "string", + "enum": ["asc", "desc"] + } + }, + { + "name": "after", + "in": "query", + "required": false, + "description": "Cursor for fetching messages after this message ID", + "schema": { + "type": "string" + } + }, + { + "name": "before", + "in": "query", + "required": false, + "description": "Cursor for fetching messages before this message ID", + "schema": { + "type": "string" + } + }, + { + "name": "run_id", + "in": "query", + "required": false, + "description": "Filter messages by run ID", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Messages retrieved successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "object": { + "type": "string", + "description": "Type of the list response, always 'list'" + }, + "data": { + "type": "array", + "description": "Array of message objects", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the message" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.message'" + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the message was created" + }, + "thread_id": { + "type": "string", + "description": "ID of the thread this message belongs to" + }, + "role": { + "type": "string", + "description": "Role of the message sender", + "enum": ["assistant", "user"] + }, + "status": { + "type": "string", + "description": "Status of the message", + "enum": ["completed"] + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of content", + "enum": ["text"] + }, + "text": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The message text" + }, + "annotations": { + "type": "array", + "description": "Array of annotations for the text" + } + } + } + } + } + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the message" + }, + "attachments": { + "type": "array", + "items": { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "ID of the attached file" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of tool used" + } + } + } + } + } + } + } + }, + "required": [ + "id", + "object", + "created_at", + "thread_id", + "role", + "content" + ] + } + } + }, + "required": ["object", "data"] + }, + "example": { + "data": [ + { + "content": [ + { + "text": { + "annotations": [], + "value": "Based on the context, I'm not sure how to build a unique experience quickly and easily..." + }, + "type": "text" + } + ], + "created_at": 1732633637, + "id": "01JDMG6CG6DD4B3RQN82QD8Q7P", + "metadata": {}, + "object": "thread.message", + "role": "assistant", + "status": "completed", + "thread_id": "jan_1732370027" + } + ], + "object": "list" + } + } + } + } + }, + "tags": ["Messages"] + } + }, + "/v1/threads/{thread_id}/messages/{message_id}": { + "get": { + "summary": "Retrieve Message", + "description": "Retrieves a specific message from a thread by its ID.", + "parameters": [ + { + "name": "thread_id", + "in": "path", + "required": true, + "description": "The ID of the thread containing the message", + "schema": { + "type": "string" + } + }, + { + "name": "message_id", + "in": "path", + "required": true, + "description": "The ID of the message to retrieve", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Message retrieved successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the message" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.message'" + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the message was created" + }, + "thread_id": { + "type": "string", + "description": "ID of the thread this message belongs to" + }, + "role": { + "type": "string", + "description": "Role of the message sender", + "enum": ["assistant", "user"] + }, + "status": { + "type": "string", + "description": "Status of the message", + "enum": ["completed"] + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of content", + "enum": ["text"] + }, + "text": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The message text" + }, + "annotations": { + "type": "array", + "description": "Array of annotations for the text" + } + } + } + } + } + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the message" + }, + "attachments": { + "type": "array", + "items": { + "type": "object", + "properties": { + "file_id": { + "type": "string", + "description": "ID of the attached file" + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of tool used" + } + } + } + } + } + } + } + }, + "required": [ + "id", + "object", + "created_at", + "thread_id", + "role", + "content" + ] + }, + "example": { + "attachments": [ + { + "file_id": "01JDMG617BHMPW859VE18BPQ7Y", + "tools": [ + { + "type": "file_search" + } + ] + } + ], + "content": [ + { + "text": { + "annotations": [], + "value": "summary this" + }, + "type": "text" + } + ], + "created_at": 1732633625, + "id": "01JDMG617BHMPW859VE18BPQ7Y", + "metadata": {}, + "object": "thread.message", + "role": "user", + "status": "completed", + "thread_id": "jan_1732370027" + } + } + } + } + }, + "tags": ["Messages"] + }, + "patch": { + "summary": "Modify Message", + "description": "Modifies a specific message's content or metadata in a thread.", + "parameters": [ + { + "name": "thread_id", + "in": "path", + "required": true, + "description": "The ID of the thread containing the message", + "schema": { + "type": "string" + } + }, + { + "name": "message_id", + "in": "path", + "required": true, + "description": "The ID of the message to modify", + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "content": { + "type": "object", + "description": "New content for the message" + }, + "metadata": { + "type": "object", + "description": "Updated metadata for the message", + "additionalProperties": true + } + } + }, + "example": { + "content": {}, + "metadata": { + "test": 1 + } + } + } + } + }, + "responses": { + "200": { + "description": "Message modified successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the message" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.message'" + }, + "created_at": { + "type": "integer", + "description": "Unix timestamp of when the message was created" + }, + "completed_at": { + "type": "integer", + "description": "Unix timestamp of when the message was completed" + }, + "thread_id": { + "type": "string", + "description": "ID of the thread this message belongs to" + }, + "role": { + "type": "string", + "description": "Role of the message sender", + "enum": ["user", "assistant"] + }, + "status": { + "type": "string", + "description": "Status of the message", + "enum": ["completed"] + }, + "content": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Type of content", + "enum": ["text"] + }, + "text": { + "type": "object", + "properties": { + "value": { + "type": "string", + "description": "The message text" + }, + "annotations": { + "type": "array", + "description": "Array of annotations for the text" + } + } + } + } + } + }, + "metadata": { + "type": "object", + "description": "Additional metadata for the message", + "additionalProperties": true + } + }, + "required": [ + "id", + "object", + "created_at", + "completed_at", + "thread_id", + "role", + "status", + "content" + ] + }, + "example": { + "completed_at": 1734023130, + "content": [ + { + "text": { + "annotations": [], + "value": "Hello, world!" + }, + "type": "text" + } + ], + "created_at": 1734023130, + "id": "0001KNP5YT00GW0X476W5TVBFE", + "metadata": { + "test": 1 + }, + "object": "thread.message", + "role": "user", + "status": "completed", + "thread_id": "jan_1732370027" + } + } + } + } + }, + "tags": ["Messages"] + }, + "delete": { + "summary": "Delete Message", + "description": "Deletes a specific message from a thread.", + "parameters": [ + { + "name": "thread_id", + "in": "path", + "required": true, + "description": "The ID of the thread containing the message", + "schema": { + "type": "string" + } + }, + { + "name": "message_id", + "in": "path", + "required": true, + "description": "The ID of the message to delete", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Message deleted successfully", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "deleted": { + "type": "boolean", + "description": "Indicates if the message was successfully deleted" + }, + "id": { + "type": "string", + "description": "ID of the deleted message" + }, + "object": { + "type": "string", + "description": "Type of object, always 'thread.message.deleted'" + } + }, + "required": ["deleted", "id", "object"] + }, + "example": { + "deleted": true, + "id": "01JDCMZPBGDP276D6Z2QN2MJMX", + "object": "thread.message.deleted" + } + } + } + } + }, + "tags": ["Messages"] + } + }, "/v1/system": { "delete": { "operationId": "SystemController_delete", @@ -2615,7 +3335,7 @@ "description": "These endpoints manage the lifecycle of Server, including heath check and shutdown." }, { - "name": "Configuration", + "name": "Configurations", "description": "These endpoints manage the configuration of the Cortex server." }, { @@ -2654,6 +3374,7 @@ "Hardware", "Events", "Threads", + "Messages", "Pulling Models", "Running Models", "Processes", From a64af0090dae29a6bf1820f70031e6f687d457b3 Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 13 Dec 2024 11:19:18 +0700 Subject: [PATCH 31/44] fix: load engine linux (#1790) * fix: load engine linux * fix linux --------- Co-authored-by: vansangpfiev --- engine/CMakeLists.txt | 1 + engine/cli/CMakeLists.txt | 1 + engine/cli/command_line_parser.cc | 65 +++++------ engine/cli/command_line_parser.h | 8 +- engine/cli/commands/engine_install_cmd.cc | 2 +- engine/cli/commands/engine_install_cmd.h | 6 +- engine/cli/commands/engine_list_cmd.cc | 8 +- engine/cli/commands/engine_list_cmd.h | 7 ++ engine/cli/commands/run_cmd.cc | 4 +- engine/cli/commands/run_cmd.h | 9 +- engine/cli/commands/server_start_cmd.cc | 5 +- engine/controllers/engines.cc | 8 +- engine/cortex-common/EngineI.h | 15 +-- engine/main.cc | 6 +- engine/services/engine_service.cc | 76 +++++++++---- engine/services/engine_service.h | 13 +-- engine/services/model_service.cc | 4 +- engine/utils/config_yaml_utils.cc | 7 +- engine/utils/dylib_path_manager.cc | 129 ++++++++++++++++++++++ engine/utils/dylib_path_manager.h | 35 ++++++ 20 files changed, 293 insertions(+), 116 deletions(-) create mode 100644 engine/utils/dylib_path_manager.cc create mode 100644 engine/utils/dylib_path_manager.h diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 7cac3421c..41ebb3dd6 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -142,6 +142,7 @@ file(APPEND "${CMAKE_CURRENT_BINARY_DIR}/cortex_openapi.h" add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/dylib_path_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 51382dc13..237596f21 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -75,6 +75,7 @@ find_package(lfreist-hwinfo CONFIG REQUIRED) add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/file_logger.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../utils/dylib_path_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/command_line_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/config_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/download_service.cc diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 624ccd3dd..825780895 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -48,8 +48,11 @@ constexpr const auto kSubcommands = "Subcommands"; CommandLineParser::CommandLineParser() : app_("\nCortex.cpp CLI\n"), download_service_{std::make_shared()}, - model_service_{ModelService(download_service_)}, - engine_service_{EngineService(download_service_)} {} + dylib_path_manager_{std::make_shared()}, + engine_service_{std::make_shared(download_service_, + dylib_path_manager_)} { + supported_engines_ = engine_service_->GetSupportedEngineNames().value(); +} bool CommandLineParser::SetupCommand(int argc, char** argv) { app_.usage("Usage:\n" + commands::GetCortexBinary() + @@ -60,8 +63,6 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { SetupCommonCommands(); - SetupInferenceCommands(); - SetupModelCommands(); SetupEngineCommands(); @@ -176,17 +177,11 @@ void CommandLineParser::SetupCommonCommands() { return; commands::RunCmd rc(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), - cml_data_.model_id, download_service_); + cml_data_.model_id, engine_service_); rc.Exec(cml_data_.run_detach, run_settings_); }); } -void CommandLineParser::SetupInferenceCommands() { - // auto embeddings_cmd = app_.add_subcommand( - // "embeddings", "Creates an embedding vector representing the input text"); - // embeddings_cmd->group(kInferenceGroup); -} - void CommandLineParser::SetupModelCommands() { // Models group commands auto models_cmd = @@ -476,7 +471,7 @@ void CommandLineParser::SetupEngineCommands() { list_engines_cmd->callback([this]() { if (std::exchange(executed_, true)) return; - commands::EngineListCmd command; + auto command = commands::EngineListCmd(engine_service_); command.Exec(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort)); }); @@ -493,9 +488,9 @@ void CommandLineParser::SetupEngineCommands() { CLI_LOG(install_cmd->help()); } }); - for (const auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineInstall(install_cmd, engine_name, cml_data_.engine_version, + + for (const auto& engine : supported_engines_) { + EngineInstall(install_cmd, engine, cml_data_.engine_version, cml_data_.engine_src); } @@ -512,9 +507,8 @@ void CommandLineParser::SetupEngineCommands() { } }); uninstall_cmd->group(kSubcommands); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineUninstall(uninstall_cmd, engine_name); + for (const auto& engine : supported_engines_) { + EngineUninstall(uninstall_cmd, engine); } auto engine_upd_cmd = engines_cmd->add_subcommand("update", "Update engine"); @@ -529,9 +523,8 @@ void CommandLineParser::SetupEngineCommands() { } }); engine_upd_cmd->group(kSubcommands); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineUpdate(engine_upd_cmd, engine_name); + for (const auto& engine : supported_engines_) { + EngineUpdate(engine_upd_cmd, engine); } auto engine_use_cmd = @@ -547,9 +540,8 @@ void CommandLineParser::SetupEngineCommands() { } }); engine_use_cmd->group(kSubcommands); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineUse(engine_use_cmd, engine_name); + for (const auto& engine : supported_engines_) { + EngineUse(engine_use_cmd, engine); } auto engine_load_cmd = engines_cmd->add_subcommand("load", "Load engine"); @@ -564,9 +556,8 @@ void CommandLineParser::SetupEngineCommands() { } }); engine_load_cmd->group(kSubcommands); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineLoad(engine_load_cmd, engine_name); + for (const auto& engine : supported_engines_) { + EngineLoad(engine_load_cmd, engine); } auto engine_unload_cmd = @@ -582,9 +573,8 @@ void CommandLineParser::SetupEngineCommands() { } }); engine_unload_cmd->group(kSubcommands); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - EngineUnload(engine_unload_cmd, engine_name); + for (const auto& engine : supported_engines_) { + EngineUnload(engine_unload_cmd, engine); } EngineGet(engines_cmd); @@ -756,7 +746,7 @@ void CommandLineParser::EngineInstall(CLI::App* parent, return; try { commands::EngineInstallCmd( - download_service_, cml_data_.config.apiServerHost, + engine_service_, cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), cml_data_.show_menu) .Exec(engine_name, version, src); } catch (const std::exception& e) { @@ -878,20 +868,19 @@ void CommandLineParser::EngineGet(CLI::App* parent) { } }); - for (auto& engine : engine_service_.kSupportEngines) { - std::string engine_name{engine}; - std::string desc = "Get " + engine_name + " status"; + for (const auto& engine : supported_engines_) { + std::string desc = "Get " + engine + " status"; - auto engine_get_cmd = get_cmd->add_subcommand(engine_name, desc); + auto engine_get_cmd = get_cmd->add_subcommand(engine, desc); engine_get_cmd->usage("Usage:\n" + commands::GetCortexBinary() + - " engines get " + engine_name + " [options]"); + " engines get " + engine + " [options]"); engine_get_cmd->group(kEngineGroup); - engine_get_cmd->callback([this, engine_name] { + engine_get_cmd->callback([this, engine] { if (std::exchange(executed_, true)) return; commands::EngineGetCmd().Exec(cml_data_.config.apiServerHost, std::stoi(cml_data_.config.apiServerPort), - engine_name); + engine); }); } } diff --git a/engine/cli/command_line_parser.h b/engine/cli/command_line_parser.h index 896c026d0..14e10e420 100644 --- a/engine/cli/command_line_parser.h +++ b/engine/cli/command_line_parser.h @@ -5,7 +5,6 @@ #include "CLI/CLI.hpp" #include "commands/hardware_list_cmd.h" #include "services/engine_service.h" -#include "services/model_service.h" #include "utils/config_yaml_utils.h" class CommandLineParser { @@ -16,8 +15,6 @@ class CommandLineParser { private: void SetupCommonCommands(); - void SetupInferenceCommands(); - void SetupModelCommands(); void SetupEngineCommands(); @@ -47,8 +44,9 @@ class CommandLineParser { CLI::App app_; std::shared_ptr download_service_; - EngineService engine_service_; - ModelService model_service_; + std::shared_ptr dylib_path_manager_; + std::shared_ptr engine_service_; + std::vector supported_engines_; struct CmlData { std::string model_id; diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 491ab0937..85a5def5d 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -12,7 +12,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, const std::string& src) { // Handle local install, if fails, fallback to remote install if (!src.empty()) { - auto res = engine_service_.UnzipEngine(engine, version, src); + auto res = engine_service_->UnzipEngine(engine, version, src); if (res.has_error()) { CLI_LOG(res.error()); return false; diff --git a/engine/cli/commands/engine_install_cmd.h b/engine/cli/commands/engine_install_cmd.h index d50776dc4..2f318b4d7 100644 --- a/engine/cli/commands/engine_install_cmd.h +++ b/engine/cli/commands/engine_install_cmd.h @@ -7,9 +7,9 @@ namespace commands { class EngineInstallCmd { public: - explicit EngineInstallCmd(std::shared_ptr download_service, + explicit EngineInstallCmd(std::shared_ptr engine_service, const std::string& host, int port, bool show_menu) - : engine_service_{EngineService(download_service)}, + : engine_service_{engine_service}, host_(host), port_(port), show_menu_(show_menu), @@ -21,7 +21,7 @@ class EngineInstallCmd { const std::string& src = ""); private: - EngineService engine_service_; + std::shared_ptr engine_service_; std::string host_; int port_; bool show_menu_; diff --git a/engine/cli/commands/engine_list_cmd.cc b/engine/cli/commands/engine_list_cmd.cc index 35584dcd2..0abe32b28 100644 --- a/engine/cli/commands/engine_list_cmd.cc +++ b/engine/cli/commands/engine_list_cmd.cc @@ -13,7 +13,6 @@ // clang-format on namespace commands { - bool EngineListCmd::Exec(const std::string& host, int port) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { @@ -38,15 +37,10 @@ bool EngineListCmd::Exec(const std::string& host, int port) { return false; } - std::vector engines = { - kLlamaEngine, - kOnnxEngine, - kTrtLlmEngine, - }; - std::unordered_map> engine_map; + auto engines = engine_service_->GetSupportedEngineNames().value(); for (const auto& engine : engines) { auto installed_variants = result.value()[engine]; for (const auto& variant : installed_variants) { diff --git a/engine/cli/commands/engine_list_cmd.h b/engine/cli/commands/engine_list_cmd.h index 96ad956b2..1a06126a4 100644 --- a/engine/cli/commands/engine_list_cmd.h +++ b/engine/cli/commands/engine_list_cmd.h @@ -1,11 +1,18 @@ #pragma once #include +#include "services/engine_service.h" namespace commands { class EngineListCmd { public: + explicit EngineListCmd(std::shared_ptr engine_service) + : engine_service_{engine_service} {} + bool Exec(const std::string& host, int port); + + private: + std::shared_ptr engine_service_; }; } // namespace commands diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 1b71f1af7..91a813d64 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -94,7 +94,7 @@ void RunCmd::Exec(bool run_detach, // Check if engine existed. If not, download it { - auto is_engine_ready = engine_service_.IsEngineReady(mc.engine); + auto is_engine_ready = engine_service_->IsEngineReady(mc.engine); if (is_engine_ready.has_error()) { throw std::runtime_error(is_engine_ready.error()); } @@ -102,7 +102,7 @@ void RunCmd::Exec(bool run_detach, if (!is_engine_ready.value()) { CTL_INF("Engine " << mc.engine << " is not ready. Proceed to install.."); - if (!EngineInstallCmd(download_service_, host_, port_, false) + if (!EngineInstallCmd(engine_service_, host_, port_, false) .Exec(mc.engine)) { return; } else { diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index c0f6a4eb2..b22b064f9 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -12,12 +12,11 @@ std::optional SelectLocalModel(std::string host, int port, class RunCmd { public: explicit RunCmd(std::string host, int port, std::string model_handle, - std::shared_ptr download_service) + std::shared_ptr engine_service) : host_{std::move(host)}, port_{port}, model_handle_{std::move(model_handle)}, - download_service_(download_service), - engine_service_{EngineService(download_service)} {}; + engine_service_{engine_service} {}; void Exec(bool chat_flag, const std::unordered_map& options); @@ -26,8 +25,6 @@ class RunCmd { std::string host_; int port_; std::string model_handle_; - - std::shared_ptr download_service_; - EngineService engine_service_; + std::shared_ptr engine_service_; }; } // namespace commands diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index 3d52f3d25..3d6045cd5 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -112,7 +112,9 @@ bool ServerStartCmd::Exec(const std::string& host, int port, return false; } else if (pid == 0) { // Some engines requires to add lib search path before process being created - EngineService().RegisterEngineLibPath(); + auto download_srv = std::make_shared(); + auto dylib_path_mng = std::make_shared(); + EngineService(download_srv, dylib_path_mng).RegisterEngineLibPath(); std::string p = cortex_utils::GetCurrentPath() + "/" + exe; execl(p.c_str(), exe.c_str(), "--start-server", "--config_file_path", @@ -131,5 +133,4 @@ bool ServerStartCmd::Exec(const std::string& host, int port, #endif return true; } - }; // namespace commands diff --git a/engine/controllers/engines.cc b/engine/controllers/engines.cc index 1d0223d9a..a92d6805f 100644 --- a/engine/controllers/engines.cc +++ b/engine/controllers/engines.cc @@ -3,9 +3,9 @@ #include "utils/archive_utils.h" #include "utils/cortex_utils.h" #include "utils/engine_constants.h" -#include "utils/http_util.h" #include "utils/logging_utils.h" #include "utils/string_utils.h" + namespace { // Need to change this after we rename repositories std::string NormalizeEngine(const std::string& engine) { @@ -24,8 +24,8 @@ void Engines::ListEngine( const HttpRequestPtr& req, std::function&& callback) const { Json::Value ret; - auto engine_names = engine_service_->GetSupportedEngineNames().value(); - for (const auto& engine : engine_names) { + auto engines = engine_service_->GetSupportedEngineNames().value(); + for (const auto& engine : engines) { auto installed_engines = engine_service_->GetInstalledEngineVariants(engine); if (installed_engines.has_error()) { @@ -37,6 +37,7 @@ void Engines::ListEngine( } ret[engine] = variants; } + // Add remote engine auto remote_engines = engine_service_->GetEngines(); if (remote_engines.has_value()) { @@ -49,7 +50,6 @@ void Engines::ListEngine( } } } - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); callback(resp); diff --git a/engine/cortex-common/EngineI.h b/engine/cortex-common/EngineI.h index b456cb109..b796ebaed 100644 --- a/engine/cortex-common/EngineI.h +++ b/engine/cortex-common/EngineI.h @@ -8,15 +8,11 @@ #include "trantor/utils/Logger.h" class EngineI { public: - struct RegisterLibraryOption { - std::vector paths; - }; - struct EngineLoadOption { // engine std::filesystem::path engine_path; - std::filesystem::path cuda_path; - bool custom_engine_path; + std::filesystem::path deps_path; + bool is_custom_engine_path; // logging std::filesystem::path log_path; @@ -25,16 +21,11 @@ class EngineI { }; struct EngineUnloadOption { - bool unload_dll; + // place holder for now }; virtual ~EngineI() {} - /** - * Being called before starting process to register dependencies search paths. - */ - virtual void RegisterLibraryPath(RegisterLibraryOption opts) = 0; - virtual void Load(EngineLoadOption opts) = 0; virtual void Unload(EngineUnloadOption opts) = 0; diff --git a/engine/main.cc b/engine/main.cc index 13583dc00..8ca5ffd1f 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -26,6 +26,7 @@ #include "services/thread_service.h" #include "utils/archive_utils.h" #include "utils/cortex_utils.h" +#include "utils/dylib_path_manager.h" #include "utils/event_processor.h" #include "utils/file_logger.h" #include "utils/file_manager_utils.h" @@ -125,6 +126,8 @@ void RunServer(std::optional port, bool ignore_cout) { cortex::event::EventProcessor event_processor(event_queue_ptr); auto data_folder_path = file_manager_utils::GetCortexDataPath(); + // utils + auto dylib_path_manager = std::make_shared(); auto file_repo = std::make_shared(data_folder_path); auto msg_repo = std::make_shared(data_folder_path); @@ -139,7 +142,8 @@ void RunServer(std::optional port, bool ignore_cout) { auto config_service = std::make_shared(); auto download_service = std::make_shared(event_queue_ptr, config_service); - auto engine_service = std::make_shared(download_service); + auto engine_service = + std::make_shared(download_service, dylib_path_manager); auto inference_svc = std::make_shared(engine_service); auto model_src_svc = std::make_shared(); diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4f2122f6b..035ef4a4e 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -711,23 +711,42 @@ cpp::result EngineService::LoadEngine( auto custom_engine_path = engine_dir_path_res.value().second; try { + auto cuda_path = file_manager_utils::GetCudaToolkitPath(ne); + +#if defined(_WIN32) || defined(_WIN64) + // register deps + std::vector paths{}; + paths.push_back(std::move(cuda_path)); + paths.push_back(std::move(engine_dir_path)); + + CTL_DBG("Registering dylib for " + << ne << " with " << std::to_string(paths.size()) << " paths."); + for (const auto& path : paths) { + CTL_DBG("Registering path: " << path.string()); + } + + auto reg_result = dylib_path_manager_->RegisterPath(ne, paths); + if (reg_result.has_error()) { + CTL_DBG("Failed register lib paths for: " << ne); + } else { + CTL_DBG("Registered lib paths for: " << ne); + } +#endif + auto dylib = std::make_unique(engine_dir_path.string(), "engine"); auto config = file_manager_utils::GetCortexConfig(); - - auto log_path = - std::filesystem::path(config.logFolderPath) / - std::filesystem::path( - config.logLlamaCppPath); // for now seems like we use same log path + auto log_path = std::filesystem::path(config.logFolderPath) / + std::filesystem::path(config.logLlamaCppPath); // init auto func = dylib->get_function("get_engine"); auto engine_obj = func(); auto load_opts = EngineI::EngineLoadOption{ .engine_path = engine_dir_path, - .cuda_path = file_manager_utils::GetCudaToolkitPath(ne), - .custom_engine_path = custom_engine_path, + .deps_path = cuda_path, + .is_custom_engine_path = custom_engine_path, .log_path = log_path, .max_log_lines = config.maxLogLines, .log_level = logging_utils_helper::global_log_level, @@ -753,27 +772,32 @@ void EngineService::RegisterEngineLibPath() { try { auto engine_dir_path_res = GetEngineDirPath(engine); if (engine_dir_path_res.has_error()) { - CTL_ERR( + CTL_WRN( "Could not get engine dir path: " << engine_dir_path_res.error()); continue; } auto engine_dir_path = engine_dir_path_res.value().first; auto custom_engine_path = engine_dir_path_res.value().second; - - auto dylib = std::make_unique(engine_dir_path.string(), - "engine"); - auto cuda_path = file_manager_utils::GetCudaToolkitPath(ne); - // init - auto func = dylib->get_function("get_engine"); - auto engine = func(); + + // register deps std::vector paths{}; - auto register_opts = EngineI::RegisterLibraryOption{ - .paths = paths, - }; - engine->RegisterLibraryPath(register_opts); - delete engine; - CTL_DBG("Register lib path for: " << engine); + paths.push_back(std::move(cuda_path)); + paths.push_back(std::move(engine_dir_path)); + + CTL_DBG("Registering dylib for " + << ne << " with " << std::to_string(paths.size()) << " paths."); + for (const auto& path : paths) { + CTL_DBG("Registering path: " << path.string()); + } + + auto reg_result = dylib_path_manager_->RegisterPath(ne, paths); + if (reg_result.has_error()) { + CTL_WRN("Failed register lib path for " << engine); + } else { + CTL_DBG("Registered lib path for " << engine); + } + } catch (const std::exception& e) { CTL_WRN("Failed to registering engine lib path: " << e.what()); } @@ -832,10 +856,14 @@ cpp::result EngineService::UnloadEngine( } if (std::holds_alternative(engines_[ne].engine)) { LOG_INFO << "Unloading engine " << ne; + auto unreg_result = dylib_path_manager_->Unregister(ne); + if (unreg_result.has_error()) { + CTL_DBG("Failed unregister lib paths for: " << ne); + } else { + CTL_DBG("Unregistered lib paths for: " << ne); + } auto* e = std::get(engines_[ne].engine); - auto unload_opts = EngineI::EngineUnloadOption{ - .unload_dll = true, - }; + auto unload_opts = EngineI::EngineUnloadOption{}; e->Unload(unload_opts); delete e; engines_.erase(ne); diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 8299655f2..9253eccf1 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -16,6 +16,7 @@ #include "services/download_service.h" #include "utils/cpuid/cpu_info.h" #include "utils/dylib.h" +#include "utils/dylib_path_manager.h" #include "utils/engine_constants.h" #include "utils/github_release_utils.h" #include "utils/result.hpp" @@ -56,6 +57,7 @@ class EngineService : public EngineServiceI { std::mutex engines_mutex_; std::unordered_map engines_{}; std::shared_ptr download_service_; + std::shared_ptr dylib_path_manager_; struct HardwareInfo { std::unique_ptr sys_inf; @@ -65,18 +67,15 @@ class EngineService : public EngineServiceI { HardwareInfo hw_inf_; public: - const std::vector kSupportEngines = { - kLlamaEngine, kOnnxEngine, kTrtLlmEngine}; - - explicit EngineService(std::shared_ptr download_service) + explicit EngineService( + std::shared_ptr download_service, + std::shared_ptr dylib_path_manager) : download_service_{download_service}, + dylib_path_manager_{dylib_path_manager}, hw_inf_{.sys_inf = system_info_utils::GetSystemInfo(), .cuda_driver_version = system_info_utils::GetDriverAndCudaVersion().second} {} - // just for initialize supported engines - EngineService() {}; - std::vector GetEngineInfoList() const; /** diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 15fee15be..6a45733d3 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -970,9 +970,7 @@ cpp::result ModelService::GetModelStatus( if (status == drogon::k200OK) { return true; } else { - CTL_WRN("Model failed to get model status with status code: " << status); - return cpp::fail("Model failed to get model status: " + - data["message"].asString()); + return cpp::fail(data["message"].asString()); } } catch (const std::exception& e) { return cpp::fail("Fail to get model status with ID '" + model_handle + diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index c7a696df4..8fbfe1dbe 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -84,7 +84,8 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, !node["proxyUsername"] || !node["proxyPassword"] || !node["verifyPeerSsl"] || !node["verifyHostSsl"] || !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || - !node["sslCertPath"] || !node["sslKeyPath"] || !node["noProxy"]); + !node["supportedEngines"] || !node["sslCertPath"] || + !node["sslKeyPath"] || !node["noProxy"]); CortexConfig config = { .logFolderPath = node["logFolderPath"] @@ -172,6 +173,10 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, : default_cfg.sslCertPath, .sslKeyPath = node["sslKeyPath"] ? node["sslKeyPath"].as() : default_cfg.sslKeyPath, + .supportedEngines = + node["supportedEngines"] + ? node["supportedEngines"].as>() + : default_cfg.supportedEngines, }; if (should_update_config) { l.unlock(); diff --git a/engine/utils/dylib_path_manager.cc b/engine/utils/dylib_path_manager.cc new file mode 100644 index 000000000..3d10fc8ff --- /dev/null +++ b/engine/utils/dylib_path_manager.cc @@ -0,0 +1,129 @@ +#include "dylib_path_manager.h" +#include "utils/logging_utils.h" + +namespace cortex { + +cpp::result DylibPathManager::RegisterPath( + const std::string& key, std::vector paths) { +#if defined(_WIN32) || defined(_WIN64) + std::vector dylib_paths; + for (const auto& path : paths) { + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + + std::wstring_convert> converter; + std::wstring wide_path = converter.from_bytes(path.string()); + + auto cookie = AddDllDirectory(wide_path.c_str()); + if (cookie == nullptr) { + CTL_ERR("Failed to added DLL directory: " << path.string()); + + // Clean up any paths we've already added + for (auto& dylib_path : dylib_paths) { + CTL_DBG("Cleaning DLL path: " + dylib_path.path.string()); + RemoveDllDirectory(dylib_path.cookie); + } + return cpp::fail("Failed to add DLL directory: " + path.string()); + } else { + CTL_DBG("Added DLL directory: " << path.string()); + } + + dylib_paths.push_back({path, cookie}); + } + dylib_map_[key] = std::move(dylib_paths); + +#elif defined(__linux__) + // For Linux, we need to modify LD_LIBRARY_PATH + std::vector dylib_paths; + std::stringstream new_path; + bool first = true; + + // First verify all paths exist + for (const auto& path : paths) { + if (!std::filesystem::exists(path)) { + return cpp::fail("Path does not exist: " + path.string()); + } + } + + // Get current LD_LIBRARY_PATH + const char* current_path = getenv(kLdLibraryPath); + std::string current_paths = current_path ? current_path : ""; + CTL_DBG("Current paths: " << current_paths); + + // Add new paths + for (const auto& path : paths) { + if (!first) { + new_path << ":"; + } + new_path << path.string(); + dylib_paths.push_back({path}); + first = false; + } + + // Append existing paths if they exist + if (!current_paths.empty()) { + new_path << ":" << current_paths; + } + CTL_DBG("New paths: " << new_path.str()); + // Set the new LD_LIBRARY_PATH + if (setenv(kLdLibraryPath, new_path.str().c_str(), 1) != 0) { + CTL_ERR("Failed to set path!!!"); + return cpp::fail("Failed to set " + std::string(kLdLibraryPath)); + } + + CTL_DBG("After set path: " << getenv(kLdLibraryPath)); + + dylib_map_[key] = std::move(dylib_paths); +#endif + + return {}; +} + +cpp::result DylibPathManager::Unregister( + const std::string& key) { + auto it = dylib_map_.find(key); + if (it == dylib_map_.end()) { + return cpp::fail("Key not found: " + key); + } + +#if defined(_WIN32) || defined(_WIN64) + // For Windows, remove each DLL directory + for (auto& dylib_path : it->second) { + if (!RemoveDllDirectory(dylib_path.cookie)) { + return cpp::fail("Failed to remove DLL directory: " + + dylib_path.path.string()); + } + } + +#elif defined(__linux__) + // For Linux, we need to rebuild LD_LIBRARY_PATH without the removed paths + const char* current_path = getenv(kLdLibraryPath); + if (current_path) { + std::string paths = current_path; + for (const auto& dylib_path : it->second) { + std::string path_str = dylib_path.path.string(); + size_t pos = paths.find(path_str); + if (pos != std::string::npos) { + // Remove the path and the following colon (or preceding colon if it's at the end) + if (pos > 0 && paths[pos - 1] == ':') { + paths.erase(pos - 1, path_str.length() + 1); + } else if (pos + path_str.length() < paths.length() && + paths[pos + path_str.length()] == ':') { + paths.erase(pos, path_str.length() + 1); + } else { + paths.erase(pos, path_str.length()); + } + } + } + + if (setenv(kLdLibraryPath, paths.c_str(), 1) != 0) { + return cpp::fail("Failed to update " + std::string(kLdLibraryPath)); + } + } +#endif + + dylib_map_.erase(it); + return {}; +} +} // namespace cortex diff --git a/engine/utils/dylib_path_manager.h b/engine/utils/dylib_path_manager.h new file mode 100644 index 000000000..bfdff7c7e --- /dev/null +++ b/engine/utils/dylib_path_manager.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include "utils/result.hpp" + +#if defined(_WIN32) +#include +#include +#include +#endif + +namespace cortex { +class DylibPathManager { + // for linux + constexpr static auto kLdLibraryPath{"LD_LIBRARY_PATH"}; + + struct DylibPath { + std::filesystem::path path; +#if defined(_WIN32) || defined(_WIN64) + DLL_DIRECTORY_COOKIE cookie; +#endif + }; + + public: + cpp::result RegisterPath( + const std::string& key, std::vector paths); + + cpp::result Unregister(const std::string& key); + + private: + std::unordered_map> dylib_map_; +}; +} // namespace cortex From 5e84fb5e58413717f6d5fb659f74d675bd4908c0 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 13 Dec 2024 12:44:11 +0700 Subject: [PATCH 32/44] fix: correct stop inferencing condition (#1796) Co-authored-by: vansangpfiev --- engine/controllers/server.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index a9920e8aa..4c6bcaf82 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -140,7 +140,9 @@ void server::ProcessStreamRes(std::function cb, std::size_t buf_size) -> std::size_t { if (buf == nullptr) { LOG_TRACE << "Buf is null"; - inference_svc_->StopInferencing(engine_type, model_id); + if (!(*err_or_done)) { + inference_svc_->StopInferencing(engine_type, model_id); + } return 0; } From fc5397619d01a3be6a5590ec821533011a8d1bfd Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 16 Dec 2024 11:46:07 +0700 Subject: [PATCH 33/44] fix: add support image url for jan (#1798) --- engine/common/message.h | 21 ++++++++++++ engine/common/message_content_image_url.h | 42 ++++++++++++++++------- engine/common/message_content_text.h | 3 +- engine/test/components/test_models_db.cc | 3 +- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/engine/common/message.h b/engine/common/message.h index 3bff6f048..d31c4f0d3 100644 --- a/engine/common/message.h +++ b/engine/common/message.h @@ -137,6 +137,27 @@ struct Message : JsonSerializable { if (root["content"].isArray() && !root["content"].empty()) { if (root["content"][0]["type"].asString() == "text") { message.content = ParseContents(std::move(root["content"])).value(); + } else if (root["content"][0]["type"].asString() == "image") { + // deprecated, for supporting jan and should be removed in the future + auto text_str = root["content"][0]["text"]["value"].asString(); + auto img_url = + root["content"][0]["text"]["annotations"][0].asString(); + auto text_content = std::make_unique(); + { + auto text = OpenAi::Text(); + auto empty_annotations = + std::vector>(); + text.value = std::move(text_str); + text.annotations = std::move(empty_annotations); + text_content->text = std::move(text); + } + + auto image_url_obj = OpenAi::ImageUrl(img_url, "auto"); + auto image_url_content = std::make_unique( + "image_url", std::move(image_url_obj)); + + message.content.push_back(std::move(text_content)); + message.content.push_back(std::move(image_url_content)); } else { // deprecated, for supporting jan and should be removed in the future // check if annotations is empty diff --git a/engine/common/message_content_image_url.h b/engine/common/message_content_image_url.h index b86544e38..336cf01d3 100644 --- a/engine/common/message_content_image_url.h +++ b/engine/common/message_content_image_url.h @@ -4,14 +4,21 @@ namespace OpenAi { -struct ImageUrl { - // The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp. +struct ImageUrl : public JsonSerializable { + /** + * The external URL of the image, must be a supported image types: + * jpeg, jpg, png, gif, webp. + */ std::string url; - // Specifies the detail level of the image. low uses fewer tokens, you can opt in to high resolution using high. Default value is auto + /** + * Specifies the detail level of the image. low uses fewer tokens, you + * can opt in to high resolution using high. Default value is auto + */ std::string detail; - ImageUrl() = default; + ImageUrl(const std::string& url, const std::string& detail = "auto") + : url{url}, detail{detail} {} ImageUrl(ImageUrl&&) noexcept = default; @@ -20,13 +27,25 @@ struct ImageUrl { ImageUrl(const ImageUrl&) = delete; ImageUrl& operator=(const ImageUrl&) = delete; + + cpp::result ToJson() override { + try { + Json::Value root; + root["url"] = url; + root["detail"] = detail; + return root; + } catch (const std::exception& e) { + return cpp::fail(std::string("ToJson failed: ") + e.what()); + } + } }; // References an image URL in the content of a message. struct ImageUrlContent : Content { // The type of the content part. - ImageUrlContent(const std::string& type) : Content(type) {} + explicit ImageUrlContent(const std::string& type, ImageUrl&& image_url) + : Content(type), image_url{std::move(image_url)} {} ImageUrlContent(ImageUrlContent&&) noexcept = default; @@ -38,6 +57,8 @@ struct ImageUrlContent : Content { ImageUrl image_url; + ~ImageUrlContent() override = default; + static cpp::result FromJson( Json::Value&& json) { if (json.empty()) { @@ -45,11 +66,9 @@ struct ImageUrlContent : Content { } try { - ImageUrlContent content{"image_url"}; - ImageUrl image_url; - image_url.url = std::move(json["image_url"]["url"].asString()); - image_url.detail = std::move(json["image_url"]["detail"].asString()); - content.image_url = std::move(image_url); + auto image_url = ImageUrl(json["image_url"]["url"].asString(), + json["image_url"]["detail"].asString()); + ImageUrlContent content{"image_url", std::move(image_url)}; return content; } catch (const std::exception& e) { return cpp::fail(std::string("FromJson failed: ") + e.what()); @@ -60,8 +79,7 @@ struct ImageUrlContent : Content { try { Json::Value json; json["type"] = type; - json["image_url"]["url"] = image_url.url; - json["image_url"]["detail"] = image_url.detail; + json["image_url"] = image_url.ToJson().value(); return json; } catch (const std::exception& e) { return cpp::fail(std::string("ToJson failed: ") + e.what()); diff --git a/engine/common/message_content_text.h b/engine/common/message_content_text.h index ea6aab1ab..5ede2582d 100644 --- a/engine/common/message_content_text.h +++ b/engine/common/message_content_text.h @@ -122,7 +122,6 @@ struct FilePathWrapper : Annotation { struct Text : JsonSerializable { // The data that makes up the text. - Text() = default; Text(Text&&) noexcept = default; @@ -214,6 +213,8 @@ struct TextContent : Content { Text text; + ~TextContent() override = default; + static cpp::result FromJson(Json::Value&& json) { if (json.empty()) { return cpp::fail("Json string is empty"); diff --git a/engine/test/components/test_models_db.cc b/engine/test/components/test_models_db.cc index 06294aa8c..0cc9b0344 100644 --- a/engine/test/components/test_models_db.cc +++ b/engine/test/components/test_models_db.cc @@ -1,6 +1,5 @@ #include "database/models.h" #include "gtest/gtest.h" -#include "utils/file_manager_utils.h" namespace cortex::db { namespace { @@ -122,4 +121,4 @@ TEST_F(ModelsTestSuite, TestHasModel) { EXPECT_TRUE(model_list_.DeleteModelEntry(kTestModel.model).value()); } -} // namespace cortex::db \ No newline at end of file +} // namespace cortex::db From 2ef085adbf4e045fbf6ad4e132dc3bfb56dce02e Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 16 Dec 2024 16:14:24 +0700 Subject: [PATCH 34/44] fix: remove sort msg by ulid (#1799) --- engine/repositories/message_fs_repository.cc | 31 +++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/engine/repositories/message_fs_repository.cc b/engine/repositories/message_fs_repository.cc index 422242e3a..db6f5dd6e 100644 --- a/engine/repositories/message_fs_repository.cc +++ b/engine/repositories/message_fs_repository.cc @@ -80,34 +80,23 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit, messages.end()); } - const bool is_descending = (order == "desc"); - std::sort( - messages.begin(), messages.end(), - [is_descending](const OpenAi::Message& a, const OpenAi::Message& b) { - return is_descending ? (a.id > b.id) : (a.id < b.id); - }); - auto start_it = messages.begin(); auto end_it = messages.end(); if (!after.empty()) { - start_it = std::lower_bound( - messages.begin(), messages.end(), after, - [is_descending](const OpenAi::Message& msg, const std::string& value) { - return is_descending ? (msg.id > value) : (msg.id < value); - }); - - if (start_it != messages.end() && start_it->id == after) { - ++start_it; - } + start_it = std::find_if( + messages.begin(), messages.end(), + [&after](const OpenAi::Message& msg) { return msg.id > after; }); } if (!before.empty()) { - end_it = std::upper_bound( - start_it, messages.end(), before, - [is_descending](const std::string& value, const OpenAi::Message& msg) { - return is_descending ? (value > msg.id) : (value < msg.id); - }); + end_it = std::find_if( + start_it, messages.end(), + [&before](const OpenAi::Message& msg) { return msg.id >= before; }); + } + + if (order == "desc") { + std::reverse(start_it, end_it); } const size_t available_messages = std::distance(start_it, end_it); From c7982aeb10821f8086dd50979c204ad48e06b264 Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 16 Dec 2024 21:40:53 +0700 Subject: [PATCH 35/44] fix: allow upload file with same name (#1801) --- engine/repositories/file_fs_repository.cc | 25 ++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/engine/repositories/file_fs_repository.cc b/engine/repositories/file_fs_repository.cc index b9ab4fec6..a209d33c3 100644 --- a/engine/repositories/file_fs_repository.cc +++ b/engine/repositories/file_fs_repository.cc @@ -18,9 +18,28 @@ cpp::result FileFsRepository::StoreFile( } cortex::db::File db; - auto file_full_path = file_container_path / file_metadata.filename; - if (std::filesystem::exists(file_full_path)) { - return cpp::fail("File already exists: " + file_full_path.string()); + auto original_filename = file_metadata.filename; + auto file_full_path = file_container_path / original_filename; + + // Handle duplicate filenames + int counter = 1; + while (std::filesystem::exists(file_full_path)) { + auto dot_pos = original_filename.find_last_of('.'); + std::string name_part; + std::string ext_part; + + if (dot_pos != std::string::npos) { + name_part = original_filename.substr(0, dot_pos); + ext_part = original_filename.substr(dot_pos); + } else { + name_part = original_filename; + ext_part = ""; + } + + auto new_filename = name_part + "_" + std::to_string(counter) + ext_part; + file_full_path = file_container_path / new_filename; + file_metadata.filename = new_filename; + counter++; } try { From bb976124de92cc3b2f68f1df5bab9ee52ed1d8c2 Mon Sep 17 00:00:00 2001 From: NamH Date: Mon, 16 Dec 2024 21:42:31 +0700 Subject: [PATCH 36/44] chore: update set default engine docs (#1800) --- docs/static/openapi/cortex.json | 41 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 4792fe306..4b238d2e4 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -2522,26 +2522,31 @@ "default": "llama-cpp" }, "description": "The type of engine" - }, - { - "name": "version", - "in": "query", - "required": true, - "schema": { - "type": "string" - }, - "description": "The version of the engine variant" - }, - { - "name": "variant", - "in": "query", - "required": true, - "schema": { - "type": "string" - }, - "description": "The variant of the engine" } ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["version", "variant"], + "properties": { + "version": { + "type": "string", + "description": "The version of the engine variant", + "example": "0.1.34" + }, + "variant": { + "type": "string", + "description": "The variant of the engine", + "example": "mac-arm64" + } + } + } + } + } + }, "responses": { "200": { "description": "Successful response", From 0255f3d4f7b21943d3951735427e3378230a3bbc Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 17 Dec 2024 08:51:07 +0700 Subject: [PATCH 37/44] fix: improve remote engine (#1787) * fix: improve streaming message for remote engine * feat: improve remote engine * chore: cleanup * fix: correct remote engine check * chore: add unit tests * fix: cleanup * chore: cleanup --------- Co-authored-by: vansangpfiev --- engine/CMakeLists.txt | 2 - engine/cli/CMakeLists.txt | 2 - engine/common/engine_servicei.h | 2 + engine/config/model_config.h | 43 +-- engine/config/remote_template.h | 66 +++++ engine/controllers/models.cc | 10 +- .../remote-engine/anthropic_engine.cc | 62 ----- .../remote-engine/anthropic_engine.h | 13 - .../extensions/remote-engine/openai_engine.cc | 54 ---- .../extensions/remote-engine/openai_engine.h | 14 - .../extensions/remote-engine/remote_engine.cc | 261 ++++++++++-------- .../extensions/remote-engine/remote_engine.h | 13 +- engine/services/engine_service.cc | 43 ++- engine/services/engine_service.h | 2 + engine/services/model_service.cc | 2 +- engine/test/components/CMakeLists.txt | 1 + engine/test/components/main.cc | 4 + engine/test/components/test_remote_engine.cc | 81 ++++++ engine/utils/engine_constants.h | 3 + 19 files changed, 332 insertions(+), 346 deletions(-) create mode 100644 engine/config/remote_template.h delete mode 100644 engine/extensions/remote-engine/anthropic_engine.cc delete mode 100644 engine/extensions/remote-engine/anthropic_engine.h delete mode 100644 engine/extensions/remote-engine/openai_engine.cc delete mode 100644 engine/extensions/remote-engine/openai_engine.h create mode 100644 engine/test/components/test_remote_engine.cc diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index 41ebb3dd6..25c0783b1 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -144,8 +144,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/dylib_path_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/openai_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/remote-engine/template_renderer.cc ) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 237596f21..df4f1a76b 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -84,8 +84,6 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/hardware_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/openai_engine.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/anthropic_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc diff --git a/engine/common/engine_servicei.h b/engine/common/engine_servicei.h index 85fa87d76..a4b0c8732 100644 --- a/engine/common/engine_servicei.h +++ b/engine/common/engine_servicei.h @@ -58,4 +58,6 @@ class EngineServiceI { GetEngineByNameAndVariant( const std::string& engine_name, const std::optional variant = std::nullopt) = 0; + + virtual bool IsRemoteEngine(const std::string& engine_name) = 0; }; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 84e175d54..a799adb27 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -8,52 +8,12 @@ #include #include #include +#include "config/remote_template.h" #include "utils/format_utils.h" #include "utils/remote_models_utils.h" namespace config { -namespace { -const std::string kOpenAITransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kOpenAITransformRespTemplate = - R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == \"id\" or key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"object\" or key == \"usage\" -%} {%- if not first -%},{%- endif -%} \"{{ key }}\": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; -const std::string kAnthropicTransformReqTemplate = - R"({ {% set first = true %} {% for key, value in input_request %} {% if key == \"system\" or key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; -const std::string kAnthropicTransformRespTemplate = R"({ - "id": "{{ input_request.id }}", - "created": null, - "object": "chat.completion", - "model": "{{ input_request.model }}", - "choices": [ - { - "index": 0, - "message": { - "role": "{{ input_request.role }}", - "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", - "refusal": null - }, - "logprobs": null, - "finish_reason": "{{ input_request.stop_reason }}" - } - ], - "usage": { - "prompt_tokens": {{ input_request.usage.input_tokens }}, - "completion_tokens": {{ input_request.usage.output_tokens }}, - "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, - "prompt_tokens_details": { - "cached_tokens": 0 - }, - "completion_tokens_details": { - "reasoning_tokens": 0, - "accepted_prediction_tokens": 0, - "rejected_prediction_tokens": 0 - } - }, - "system_fingerprint": "fp_6b68a8204b" - })"; -} // namespace - struct RemoteModelConfig { std::string model; std::string api_key_template; @@ -108,6 +68,7 @@ struct RemoteModelConfig { kOpenAITransformRespTemplate; } } + metadata = json.get("metadata", metadata); } diff --git a/engine/config/remote_template.h b/engine/config/remote_template.h new file mode 100644 index 000000000..8a17aaa9a --- /dev/null +++ b/engine/config/remote_template.h @@ -0,0 +1,66 @@ +#include + +namespace config { +const std::string kOpenAITransformReqTemplate = + R"({ {% set first = true %} {% for key, value in input_request %} {% if key == "messages" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} })"; +const std::string kOpenAITransformRespTemplate = + R"({ {%- set first = true -%} {%- for key, value in input_request -%} {%- if key == "id" or key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "system_fingerprint" or key == "object" or key == "usage" -%} {%- if not first -%},{%- endif -%} "{{ key }}": {{ tojson(value) }} {%- set first = false -%} {%- endif -%} {%- endfor -%} })"; +const std::string kAnthropicTransformReqTemplate = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; +const std::string kAnthropicTransformRespTemplate = R"({ + "id": "{{ input_request.id }}", + "created": null, + "object": "chat.completion", + "model": "{{ input_request.model }}", + "choices": [ + { + "index": 0, + "message": { + "role": "{{ input_request.role }}", + "content": "{% if input_request.content and input_request.content.0.type == "text" %} {{input_request.content.0.text}} {% endif %}", + "refusal": null + }, + "logprobs": null, + "finish_reason": "{{ input_request.stop_reason }}" + } + ], + "usage": { + "prompt_tokens": {{ input_request.usage.input_tokens }}, + "completion_tokens": {{ input_request.usage.output_tokens }}, + "total_tokens": {{ input_request.usage.input_tokens + input_request.usage.output_tokens }}, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "system_fingerprint": "fp_6b68a8204b" + })"; + +} // namespace config \ No newline at end of file diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index affa45d52..59793b2a6 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -200,7 +200,7 @@ void Models::ListModel( .string()); auto model_config = yaml_handler.GetModelConfig(); - if (!remote_engine::IsRemoteEngine(model_config.engine)) { + if (!engine_service_->IsRemoteEngine(model_config.engine)) { Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; obj["model"] = model_entry.model; @@ -632,7 +632,7 @@ void Models::GetRemoteModels( const HttpRequestPtr& req, std::function&& callback, const std::string& engine_id) { - if (!remote_engine::IsRemoteEngine(engine_id)) { + if (!engine_service_->IsRemoteEngine(engine_id)) { Json::Value ret; ret["message"] = "Not a remote engine: " + engine_id; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); @@ -668,8 +668,7 @@ void Models::AddRemoteModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto engine_name = (*(req->getJsonObject())).get("engine", "").asString(); - /* To do: uncomment when remote engine is ready - + auto engine_validate = engine_service_->IsEngineReady(engine_name); if (engine_validate.has_error()) { Json::Value ret; @@ -679,6 +678,7 @@ void Models::AddRemoteModel( callback(resp); return; } + if (!engine_validate.value()) { Json::Value ret; ret["message"] = "Engine is not ready! Please install first!"; @@ -687,7 +687,7 @@ void Models::AddRemoteModel( callback(resp); return; } - */ + config::RemoteModelConfig model_config; model_config.LoadFromJson(*(req->getJsonObject())); cortex::db::Models modellist_utils_obj; diff --git a/engine/extensions/remote-engine/anthropic_engine.cc b/engine/extensions/remote-engine/anthropic_engine.cc deleted file mode 100644 index 847cba566..000000000 --- a/engine/extensions/remote-engine/anthropic_engine.cc +++ /dev/null @@ -1,62 +0,0 @@ -#include "anthropic_engine.h" -#include -#include -#include "utils/logging_utils.h" - -namespace remote_engine { -namespace { -constexpr const std::array kAnthropicModels = { - "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", "claude-3-sonnet-20240229", - "claude-3-haiku-20240307"}; -} -void AnthropicEngine::GetModels( - std::shared_ptr json_body, - std::function&& callback) { - Json::Value json_resp; - Json::Value model_array(Json::arrayValue); - { - std::shared_lock l(models_mtx_); - for (const auto& [m, _] : models_) { - Json::Value val; - val["id"] = m; - val["engine"] = "anthropic"; - val["start_time"] = "_"; - val["model_size"] = "_"; - val["vram"] = "_"; - val["ram"] = "_"; - val["object"] = "model"; - model_array.append(val); - } - } - - json_resp["object"] = "list"; - json_resp["data"] = model_array; - - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = 200; - callback(std::move(status), std::move(json_resp)); - CTL_INF("Running models responded"); -} - -Json::Value AnthropicEngine::GetRemoteModels() { - Json::Value json_resp; - Json::Value model_array(Json::arrayValue); - for (const auto& m : kAnthropicModels) { - Json::Value val; - val["id"] = std::string(m); - val["engine"] = "anthropic"; - val["created"] = "_"; - val["object"] = "model"; - model_array.append(val); - } - - json_resp["object"] = "list"; - json_resp["data"] = model_array; - CTL_INF("Remote models responded"); - return json_resp; -} -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/anthropic_engine.h b/engine/extensions/remote-engine/anthropic_engine.h deleted file mode 100644 index bcd3dfaf7..000000000 --- a/engine/extensions/remote-engine/anthropic_engine.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once -#include "remote_engine.h" - -namespace remote_engine { - class AnthropicEngine: public RemoteEngine { -public: - void GetModels( - std::shared_ptr json_body, - std::function&& callback) override; - - Json::Value GetRemoteModels() override; - }; -} \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.cc b/engine/extensions/remote-engine/openai_engine.cc deleted file mode 100644 index 7c7d70385..000000000 --- a/engine/extensions/remote-engine/openai_engine.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "openai_engine.h" -#include "utils/logging_utils.h" - -namespace remote_engine { - -void OpenAiEngine::GetModels( - std::shared_ptr json_body, - std::function&& callback) { - Json::Value json_resp; - Json::Value model_array(Json::arrayValue); - { - std::shared_lock l(models_mtx_); - for (const auto& [m, _] : models_) { - Json::Value val; - val["id"] = m; - val["engine"] = "openai"; - val["start_time"] = "_"; - val["model_size"] = "_"; - val["vram"] = "_"; - val["ram"] = "_"; - val["object"] = "model"; - model_array.append(val); - } - } - - json_resp["object"] = "list"; - json_resp["data"] = model_array; - - Json::Value status; - status["is_done"] = true; - status["has_error"] = false; - status["is_stream"] = false; - status["status_code"] = 200; - callback(std::move(status), std::move(json_resp)); - CTL_INF("Running models responded"); -} - -Json::Value OpenAiEngine::GetRemoteModels() { - auto response = MakeGetModelsRequest(); - if (response.error) { - Json::Value error; - error["error"] = response.error_message; - return error; - } - Json::Value response_json; - Json::Reader reader; - if (!reader.parse(response.body, response_json)) { - Json::Value error; - error["error"] = "Failed to parse response"; - return error; - } - return response_json; -} -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/openai_engine.h b/engine/extensions/remote-engine/openai_engine.h deleted file mode 100644 index 61dc68f0c..000000000 --- a/engine/extensions/remote-engine/openai_engine.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once - -#include "remote_engine.h" - -namespace remote_engine { -class OpenAiEngine : public RemoteEngine { - public: - void GetModels( - std::shared_ptr json_body, - std::function&& callback) override; - - Json::Value GetRemoteModels() override; -}; -} // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 04effb457..6361077dd 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -16,67 +16,14 @@ bool is_anthropic(const std::string& model) { return model.find("claude") != std::string::npos; } -struct AnthropicChunk { - std::string type; - std::string id; - int index; - std::string msg; - std::string model; - std::string stop_reason; - bool should_ignore = false; - - AnthropicChunk(const std::string& str) { - if (str.size() > 6) { - std::string s = str.substr(6); - try { - auto root = json_helper::ParseJsonString(s); - type = root["type"].asString(); - if (type == "message_start") { - id = root["message"]["id"].asString(); - model = root["message"]["model"].asString(); - } else if (type == "content_block_delta") { - index = root["index"].asInt(); - if (root["delta"]["type"].asString() == "text_delta") { - msg = root["delta"]["text"].asString(); - } - } else if (type == "message_delta") { - stop_reason = root["delta"]["stop_reason"].asString(); - } else { - // ignore other messages - should_ignore = true; - } - } catch (const std::exception& e) { - should_ignore = true; - CTL_WRN("JSON parse error: " << e.what()); - } - } else { - should_ignore = true; - } - } +bool is_openai(const std::string& model) { + return model.find("gpt") != std::string::npos; +} - std::string ToOpenAiFormatString() { - Json::Value root; - root["id"] = id; - root["object"] = "chat.completion.chunk"; - root["created"] = Json::Value(); - root["model"] = model; - root["system_fingerprint"] = "fp_e76890f0c3"; - Json::Value choices(Json::arrayValue); - Json::Value choice; - Json::Value content; - choice["index"] = 0; - content["content"] = msg; - if (type == "message_start") { - content["role"] = "assistant"; - content["refusal"] = Json::Value(); - } - choice["delta"] = content; - choice["finish_reason"] = stop_reason.empty() ? Json::Value() : stop_reason; - choices.append(choice); - root["choices"] = choices; - return "data: " + json_helper::DumpJsonString(root); - } -}; +constexpr const std::array kAnthropicModels = { + "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", + "claude-3-opus-20240229", "claude-3-sonnet-20240229", + "claude-3-haiku-20240307"}; } // namespace @@ -92,21 +39,13 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, while ((pos = context->buffer.find('\n')) != std::string::npos) { std::string line = context->buffer.substr(0, pos); context->buffer = context->buffer.substr(pos + 1); - CTL_TRC(line); // Skip empty lines if (line.empty() || line == "\r" || line.find("event:") != std::string::npos) continue; - // Remove "data: " prefix if present - // if (line.substr(0, 6) == "data: ") - // { - // line = line.substr(6); - // } - - // Skip [DONE] message - // std::cout << line << std::endl; + CTL_DBG(line); if (line == "data: [DONE]" || line.find("message_stop") != std::string::npos) { Json::Value status; @@ -120,17 +59,20 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb, // Parse the JSON Json::Value chunk_json; - if (is_anthropic(context->model)) { - AnthropicChunk ac(line); - if (ac.should_ignore) + if (!is_openai(context->model)) { + std::string s = line.substr(6); + try { + auto root = json_helper::ParseJsonString(s); + root["model"] = context->model; + root["id"] = context->id; + root["stream"] = true; + auto result = context->renderer.Render(context->stream_template, root); + CTL_DBG(result); + chunk_json["data"] = "data: " + result + "\n\n"; + } catch (const std::exception& e) { + CTL_WRN("JSON parse error: " << e.what()); continue; - ac.model = context->model; - if (ac.type == "message_start") { - context->id = ac.id; - } else { - ac.id = context->id; } - chunk_json["data"] = ac.ToOpenAiFormatString() + "\n\n"; } else { chunk_json["data"] = line + "\n\n"; } @@ -178,10 +120,16 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Cache-Control: no-cache"); headers = curl_slist_append(headers, "Connection: keep-alive"); + std::string stream_template = chat_res_template_; + StreamContext context{ std::make_shared>( callback), - "", "", config.model}; + "", + "", + config.model, + renderer_, + stream_template}; curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); @@ -232,7 +180,8 @@ static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, return size * nmemb; } -RemoteEngine::RemoteEngine() { +RemoteEngine::RemoteEngine(const std::string& engine_name) + : engine_name_(engine_name) { curl_global_init(CURL_GLOBAL_ALL); } @@ -395,7 +344,33 @@ bool RemoteEngine::LoadModelConfig(const std::string& model, void RemoteEngine::GetModels( std::shared_ptr json_body, std::function&& callback) { - CTL_WRN("Not implemented yet!"); + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + { + std::shared_lock l(models_mtx_); + for (const auto& [m, _] : models_) { + Json::Value val; + val["id"] = m; + val["engine"] = "openai"; + val["start_time"] = "_"; + val["model_size"] = "_"; + val["vram"] = "_"; + val["ram"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(json_resp)); + CTL_INF("Running models responded"); } void RemoteEngine::LoadModel( @@ -431,6 +406,21 @@ void RemoteEngine::LoadModel( } if (json_body->isMember("metadata")) { metadata_ = (*json_body)["metadata"]; + if (!metadata_["TransformReq"].isNull() && + !metadata_["TransformReq"]["chat_completions"].isNull() && + !metadata_["TransformReq"]["chat_completions"]["template"].isNull()) { + chat_req_template_ = + metadata_["TransformReq"]["chat_completions"]["template"].asString(); + CTL_INF(chat_req_template_); + } + + if (!metadata_["TransformResp"].isNull() && + !metadata_["TransformResp"]["chat_completions"].isNull() && + !metadata_["TransformResp"]["chat_completions"]["template"].isNull()) { + chat_res_template_ = + metadata_["TransformResp"]["chat_completions"]["template"].asString(); + CTL_INF(chat_res_template_); + } } Json::Value response; @@ -535,23 +525,6 @@ void RemoteEngine::HandleChatCompletion( std::string(e.what())); } - // Parse system for anthropic - if (is_anthropic(model)) { - bool has_system = false; - Json::Value msgs(Json::arrayValue); - for (auto& kv : (*json_body)["messages"]) { - if (kv["role"].asString() == "system") { - (*json_body)["system"] = kv["content"].asString(); - has_system = true; - } else { - msgs.append(kv); - } - } - if (has_system) { - (*json_body)["messages"] = msgs; - } - } - // Render with error handling try { result = renderer_.Render(template_str, *json_body); @@ -601,33 +574,42 @@ void RemoteEngine::HandleChatCompletion( // Transform Response std::string response_str; try { - // Check if required YAML nodes exist - if (!model_config->transform_resp["chat_completions"]) { - throw std::runtime_error( - "Missing 'chat_completions' node in transform_resp"); - } - if (!model_config->transform_resp["chat_completions"]["template"]) { - throw std::runtime_error("Missing 'template' node in chat_completions"); - } + std::string template_str; + if (!chat_res_template_.empty()) { + CTL_DBG( + "Use engine transform response template: " << chat_res_template_); + template_str = chat_res_template_; + } else { + // Check if required YAML nodes exist + if (!model_config->transform_resp["chat_completions"]) { + throw std::runtime_error( + "Missing 'chat_completions' node in transform_resp"); + } + if (!model_config->transform_resp["chat_completions"]["template"]) { + throw std::runtime_error( + "Missing 'template' node in chat_completions"); + } - // Validate JSON body - if (!response_json || response_json.isNull()) { - throw std::runtime_error("Invalid or null JSON body"); - } + // Validate JSON body + if (!response_json || response_json.isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } - // Get template string with error check - std::string template_str; - try { - template_str = - model_config->transform_resp["chat_completions"]["template"] - .as(); - } catch (const YAML::BadConversion& e) { - throw std::runtime_error("Failed to convert template node to string: " + - std::string(e.what())); + // Get template string with error check + + try { + template_str = + model_config->transform_resp["chat_completions"]["template"] + .as(); + } catch (const YAML::BadConversion& e) { + throw std::runtime_error( + "Failed to convert template node to string: " + + std::string(e.what())); + } } - // Render with error handling try { + response_json["stream"] = false; response_str = renderer_.Render(template_str, response_json); } catch (const std::exception& e) { throw std::runtime_error("Template rendering error: " + @@ -705,8 +687,43 @@ void RemoteEngine::HandleEmbedding( } Json::Value RemoteEngine::GetRemoteModels() { - CTL_WRN("Not implemented yet!"); - return {}; + if (metadata_["get_models_url"].isNull() || + metadata_["get_models_url"].asString().empty()) { + if (engine_name_ == kAnthropicEngine) { + Json::Value json_resp; + Json::Value model_array(Json::arrayValue); + for (const auto& m : kAnthropicModels) { + Json::Value val; + val["id"] = std::string(m); + val["engine"] = "anthropic"; + val["created"] = "_"; + val["object"] = "model"; + model_array.append(val); + } + + json_resp["object"] = "list"; + json_resp["data"] = model_array; + CTL_INF("Remote models responded"); + return json_resp; + } else { + return Json::Value(); + } + } else { + auto response = MakeGetModelsRequest(); + if (response.error) { + Json::Value error; + error["error"] = response.error_message; + return error; + } + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value error; + error["error"] = "Failed to parse response"; + return error; + } + return response_json; + } } } // namespace remote_engine \ No newline at end of file diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 8ce6fa652..d8dfbad61 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -14,9 +14,6 @@ // Helper for CURL response namespace remote_engine { -inline bool IsRemoteEngine(std::string_view e) { - return e == kAnthropicEngine || e == kOpenAiEngine; -} struct StreamContext { std::shared_ptr> callback; @@ -24,6 +21,8 @@ struct StreamContext { // Cache value for Anthropic std::string id; std::string model; + TemplateRenderer& renderer; + std::string stream_template; }; struct CurlResponse { std::string body; @@ -49,8 +48,10 @@ class RemoteEngine : public RemoteEngineI { std::unordered_map models_; TemplateRenderer renderer_; Json::Value metadata_; + std::string chat_req_template_; + std::string chat_res_template_; std::string api_key_template_; - std::unique_ptr async_file_logger_; + std::string engine_name_; // Helper functions CurlResponse MakeChatCompletionRequest(const ModelConfig& config, @@ -67,7 +68,7 @@ class RemoteEngine : public RemoteEngineI { ModelConfig* GetModelConfig(const std::string& model); public: - RemoteEngine(); + explicit RemoteEngine(const std::string& engine_name); virtual ~RemoteEngine(); // Main interface implementations @@ -95,7 +96,7 @@ class RemoteEngine : public RemoteEngineI { void HandleEmbedding( std::shared_ptr json_body, std::function&& callback) override; - + Json::Value GetRemoteModels() override; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 035ef4a4e..1f3e4d81c 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -6,8 +6,7 @@ #include #include "algorithm" #include "database/engines.h" -#include "extensions/remote-engine/anthropic_engine.h" -#include "extensions/remote-engine/openai_engine.h" +#include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" @@ -187,7 +186,7 @@ cpp::result EngineService::UninstallEngineVariant( // TODO: handle uninstall remote engine // only delete a remote engine if no model are using it auto exist_engine = GetEngineByNameAndVariant(engine); - if (exist_engine.has_value() && exist_engine.value().type == "remote") { + if (exist_engine.has_value() && exist_engine.value().type == kRemote) { auto result = DeleteEngine(exist_engine.value().id); if (!result.empty()) { // This mean no error when delete model CTL_ERR("Failed to delete engine: " << result); @@ -333,15 +332,9 @@ cpp::result EngineService::DownloadEngine( } else { CTL_INF("Set default engine variant: " << res.value().variant); } - auto create_res = - EngineService::UpsertEngine(engine, // engine_name - "local", // todo - luke - "", // todo - luke - "", // todo - luke - normalize_version, variant.value(), - "Default", // todo - luke - "" // todo - luke - ); + auto create_res = EngineService::UpsertEngine( + engine, // engine_name + kLocal, "", "", normalize_version, variant.value(), "Default", ""); if (create_res.has_value()) { CTL_ERR("Failed to create engine entry: " << create_res->engine_name); @@ -683,17 +676,13 @@ cpp::result EngineService::LoadEngine( } // Check for remote engine - if (remote_engine::IsRemoteEngine(engine_name)) { + if (IsRemoteEngine(engine_name)) { auto exist_engine = GetEngineByNameAndVariant(engine_name); if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - if (engine_name == kOpenAiEngine) { - engines_[engine_name].engine = new remote_engine::OpenAiEngine(); - } else { - engines_[engine_name].engine = new remote_engine::AnthropicEngine(); - } + engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); return {}; @@ -899,7 +888,7 @@ cpp::result EngineService::IsEngineReady( auto ne = NormalizeEngine(engine); // Check for remote engine - if (remote_engine::IsRemoteEngine(engine)) { + if (IsRemoteEngine(engine)) { auto exist_engine = GetEngineByNameAndVariant(engine); if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine + "' is not installed"); @@ -1075,11 +1064,7 @@ cpp::result EngineService::GetRemoteModels( if (exist_engine.has_error()) { return cpp::fail("Remote engine '" + engine_name + "' is not installed"); } - if (engine_name == kOpenAiEngine) { - engines_[engine_name].engine = new remote_engine::OpenAiEngine(); - } else { - engines_[engine_name].engine = new remote_engine::AnthropicEngine(); - } + engines_[engine_name].engine = new remote_engine::RemoteEngine(engine_name); CTL_INF("Loaded engine: " << engine_name); } @@ -1092,6 +1077,16 @@ cpp::result EngineService::GetRemoteModels( } } +bool EngineService::IsRemoteEngine(const std::string& engine_name) { + auto ne = Repo2Engine(engine_name); + auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines; + for (auto const& le : local_engines) { + if (le == ne) + return false; + } + return true; +} + cpp::result, std::string> EngineService::GetSupportedEngineNames() { return file_manager_utils::GetCortexConfig().supportedEngines; diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 9253eccf1..527123cb5 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -153,6 +153,8 @@ class EngineService : public EngineServiceI { void RegisterEngineLibPath(); + bool IsRemoteEngine(const std::string& engine_name) override; + private: bool IsEngineLoaded(const std::string& engine); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 6a45733d3..ce83152c4 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -773,7 +773,7 @@ cpp::result ModelService::StartModel( auto mc = yaml_handler.GetModelConfig(); // Running remote model - if (remote_engine::IsRemoteEngine(mc.engine)) { + if (engine_svc_->IsRemoteEngine(mc.engine)) { config::RemoteModelConfig remote_mc; remote_mc.LoadFromYamlFile( diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index 58c5d83d6..0df46cfc2 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable(${PROJECT_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/file_manager_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/curl_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/system_info_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../extensions/remote-engine/template_renderer.cc ) find_package(Drogon CONFIG REQUIRED) diff --git a/engine/test/components/main.cc b/engine/test/components/main.cc index 08080680e..ba24a3e01 100644 --- a/engine/test/components/main.cc +++ b/engine/test/components/main.cc @@ -4,11 +4,15 @@ int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); +#if defined(NDEBUG) ::testing::GTEST_FLAG(filter) = "-FileManagerConfigTest.*"; int ret = RUN_ALL_TESTS(); if (ret != 0) return ret; ::testing::GTEST_FLAG(filter) = "FileManagerConfigTest.*"; ret = RUN_ALL_TESTS(); +#else + int ret = RUN_ALL_TESTS(); +#endif return ret; } diff --git a/engine/test/components/test_remote_engine.cc b/engine/test/components/test_remote_engine.cc new file mode 100644 index 000000000..bfac76f49 --- /dev/null +++ b/engine/test/components/test_remote_engine.cc @@ -0,0 +1,81 @@ +#include "extensions/remote-engine/template_renderer.h" +#include "gtest/gtest.h" +#include "utils/json_helper.h" + +class RemoteEngineTest : public ::testing::Test {}; + +TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) { + std::string tpl = + R"({ + {% for key, value in input_request %} + {% if key == "messages" %} + {% if input_request.messages.0.role == "system" %} + "system": "{{ input_request.messages.0.content }}", + "messages": [ + {% for message in input_request.messages %} + {% if not loop.is_first %} + {"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endif %} + {% endfor %} + ] + {% else %} + "messages": [ + {% for message in input_request.messages %} + {"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %} + {% endfor %} + ] + {% endif %} + {% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %} + "{{ key }}": {{ tojson(value) }} + {% endif %} + {% if not loop.is_last %},{% endif %} + {% endfor %} })"; + { + std::string message_with_system = R"({ + "messages": [ + {"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."}, + {"role": "user", "content": "Hello, world"} + ], + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, +})"; + + auto data = json_helper::ParseJsonString(message_with_system); + + remote_engine::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + for (auto const& msg : data["messages"]) { + if (msg["role"].asString() == "system") { + EXPECT_EQ(msg["content"].asString(), res_json["system"].asString()); + } else if (msg["role"].asString() == "user") { + EXPECT_EQ(msg["content"].asString(), + res_json["messages"][0]["content"].asString()); + } + } + } + + { + std::string message_without_system = R"({ + "messages": [ + {"role": "user", "content": "Hello, world"} + ], + "model": "claude-3-5-sonnet-20241022", + "max_tokens": 1024, +})"; + + auto data = json_helper::ParseJsonString(message_without_system); + + remote_engine::TemplateRenderer rdr; + auto res = rdr.Render(tpl, data); + + auto res_json = json_helper::ParseJsonString(res); + EXPECT_EQ(data["model"].asString(), res_json["model"].asString()); + EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt()); + EXPECT_EQ(data["messages"][0]["content"].asString(), + res_json["messages"][0]["content"].asString()); + } +} \ No newline at end of file diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 020109fd8..dcdf6a443 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -6,6 +6,9 @@ constexpr const auto kTrtLlmEngine = "tensorrt-llm"; constexpr const auto kOpenAiEngine = "openai"; constexpr const auto kAnthropicEngine = "anthropic"; +constexpr const auto kRemote = "remote"; +constexpr const auto kLocal = "local"; + constexpr const auto kOnnxRepo = "cortex.onnx"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; constexpr const auto kTrtLlmRepo = "cortex.tensorrt-llm"; From 52acbfae73e4702cbaf4ebfb6fedb4b5056aa180 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 17 Dec 2024 12:23:46 +0700 Subject: [PATCH 38/44] fix: validate GPU (#1802) --- engine/services/hardware_service.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index 97ddacb97..ca1ea4cc6 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -357,11 +357,11 @@ bool HardwareService::IsValidConfig( auto res = hw_db.LoadHardwareList(); if (res.has_value()) { for (auto const& e : res.value()) { - if (!is_valid(e.software_id)) { - return false; + if (is_valid(e.software_id)) { + return true; } } } - return true; + return false; } } // namespace services From c9f15a203d862fc248630bf55d80b9e23046cc92 Mon Sep 17 00:00:00 2001 From: NamH Date: Tue, 17 Dec 2024 13:49:36 +0700 Subject: [PATCH 39/44] fix: swagger getting configuration from config file (#1803) --- docs/static/openapi/cortex.json | 8 ++++---- engine/controllers/swagger.cc | 31 +++++++++-------------------- engine/controllers/swagger.h | 35 +++++++++++++++++++++++++++++---- engine/main.cc | 4 ++++ 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 4b238d2e4..9134e89e6 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -3092,7 +3092,7 @@ "items": { "type": "string" }, - "example": ["http://localhost:39281", "https://cortex.so"] + "example": ["http://127.0.0.1:39281", "https://cortex.so"] }, "cors": { "type": "boolean", @@ -3139,7 +3139,7 @@ }, "example": { "allowed_origins": [ - "http://localhost:39281", + "http://127.0.0.1:39281", "https://cortex.so" ], "cors": false, @@ -3180,7 +3180,7 @@ "type": "string" }, "description": "List of allowed origins.", - "example": ["http://localhost:39281", "https://cortex.so"] + "example": ["http://127.0.0.1:39281", "https://cortex.so"] }, "proxy_username": { "type": "string", @@ -3249,7 +3249,7 @@ "type": "string" }, "example": [ - "http://localhost:39281", + "http://127.0.0.1:39281", "https://cortex.so" ] }, diff --git a/engine/controllers/swagger.cc b/engine/controllers/swagger.cc index 96a6c3837..abb80b94e 100644 --- a/engine/controllers/swagger.cc +++ b/engine/controllers/swagger.cc @@ -2,30 +2,17 @@ #include "cortex_openapi.h" #include "utils/cortex_utils.h" -constexpr auto ScalarUi = R"( - - - - Cortex API Reference - - - - - - - - - -)"; - -Json::Value SwaggerController::generateOpenAPISpec() { +Json::Value SwaggerController::GenerateOpenApiSpec() const { Json::Value root; Json::Reader reader; reader.parse(CortexOpenApi::GetOpenApiJson(), root); + + Json::Value server_url; + server_url["url"] = "http://" + host_ + ":" + port_; + Json::Value resp_data(Json::arrayValue); + resp_data.append(server_url); + + root["servers"] = resp_data; return root; } @@ -41,7 +28,7 @@ void SwaggerController::serveSwaggerUI( void SwaggerController::serveOpenAPISpec( const drogon::HttpRequestPtr& req, std::function&& callback) const { - Json::Value spec = generateOpenAPISpec(); + auto spec = GenerateOpenApiSpec(); auto resp = cortex_utils::CreateCortexHttpJsonResponse(spec); callback(resp); } diff --git a/engine/controllers/swagger.h b/engine/controllers/swagger.h index 4099bc447..61db1cc6e 100644 --- a/engine/controllers/swagger.h +++ b/engine/controllers/swagger.h @@ -5,13 +5,38 @@ using namespace drogon; -class SwaggerController : public drogon::HttpController { +class SwaggerController + : public drogon::HttpController { + + constexpr static auto ScalarUi = R"( + + + + Cortex API Reference + + + + + + + + + +)"; + public: METHOD_LIST_BEGIN ADD_METHOD_TO(SwaggerController::serveSwaggerUI, "/", Get); ADD_METHOD_TO(SwaggerController::serveOpenAPISpec, "/openapi.json", Get); METHOD_LIST_END + explicit SwaggerController(const std::string& host, const std::string& port) + : host_{host}, port_{port} {}; + void serveSwaggerUI( const drogon::HttpRequestPtr& req, std::function&& callback) const; @@ -21,6 +46,8 @@ class SwaggerController : public drogon::HttpController { std::function&& callback) const; private: - static const std::string swaggerUIHTML; - static Json::Value generateOpenAPISpec(); -}; \ No newline at end of file + std::string host_; + std::string port_; + + Json::Value GenerateOpenApiSpec() const; +}; diff --git a/engine/main.cc b/engine/main.cc index 8ca5ffd1f..b79859ef3 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -11,6 +11,7 @@ #include "controllers/models.h" #include "controllers/process_manager.h" #include "controllers/server.h" +#include "controllers/swagger.h" #include "controllers/threads.h" #include "database/database.h" #include "migrations/migration_manager.h" @@ -155,6 +156,8 @@ void RunServer(std::optional port, bool ignore_cout) { file_watcher_srv->start(); // initialize custom controllers + auto swagger_ctl = std::make_shared(config.apiServerHost, + config.apiServerPort); auto file_ctl = std::make_shared(file_srv, message_srv); auto assistant_ctl = std::make_shared(assistant_srv); auto thread_ctl = std::make_shared(thread_srv, message_srv); @@ -169,6 +172,7 @@ void RunServer(std::optional port, bool ignore_cout) { std::make_shared(inference_svc, engine_service); auto config_ctl = std::make_shared(config_service); + drogon::app().registerController(swagger_ctl); drogon::app().registerController(file_ctl); drogon::app().registerController(assistant_ctl); drogon::app().registerController(thread_ctl); From 841a8df9074aa5fe51c7adacf7c080eb86717648 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 18 Dec 2024 05:19:43 +0700 Subject: [PATCH 40/44] feat: support host parameter for server (#1805) --- engine/main.cc | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/engine/main.cc b/engine/main.cc index b79859ef3..5cc6c740e 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -51,7 +51,8 @@ #error "Unsupported platform!" #endif -void RunServer(std::optional port, bool ignore_cout) { +void RunServer(std::optional host, std::optional port, + bool ignore_cout) { #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) signal(SIGINT, SIG_IGN); #elif defined(_WIN32) @@ -62,9 +63,16 @@ void RunServer(std::optional port, bool ignore_cout) { reinterpret_cast(console_ctrl_handler), true); #endif auto config = file_manager_utils::GetCortexConfig(); - if (port.has_value() && *port != std::stoi(config.apiServerPort)) { + if (host.has_value() || port.has_value()) { + if (host.has_value() && *host != config.apiServerHost) { + config.apiServerHost = *host; + } + + if (port.has_value() && *port != std::stoi(config.apiServerPort)) { + config.apiServerPort = std::to_string(*port); + } + auto config_path = file_manager_utils::GetConfigurationPath(); - config.apiServerPort = std::to_string(*port); auto result = config_yaml_utils::CortexConfigMgr::GetInstance().DumpYamlConfig( config, config_path.string()); @@ -72,6 +80,7 @@ void RunServer(std::optional port, bool ignore_cout) { CTL_ERR("Error update " << config_path.string() << result.error()); } } + if (!ignore_cout) { std::cout << "Host: " << config.apiServerHost << " Port: " << config.apiServerPort << "\n"; @@ -283,6 +292,7 @@ int main(int argc, char* argv[]) { // avoid printing logs to terminal is_server = true; + std::optional server_host; std::optional server_port; bool ignore_cout_log = false; #if defined(_WIN32) @@ -296,6 +306,8 @@ int main(int argc, char* argv[]) { std::wstring v = argv[i + 1]; file_manager_utils::cortex_data_folder_path = cortex::wc::WstringToUtf8(v); + } else if (command == L"--host") { + server_host = cortex::wc::WstringToUtf8(argv[i + 1]); } else if (command == L"--port") { server_port = std::stoi(argv[i + 1]); } else if (command == L"--ignore_cout") { @@ -312,6 +324,8 @@ int main(int argc, char* argv[]) { file_manager_utils::cortex_config_file_path = argv[i + 1]; } else if (strcmp(argv[i], "--data_folder_path") == 0) { file_manager_utils::cortex_data_folder_path = argv[i + 1]; + } else if (strcmp(argv[i], "--host") == 0) { + server_host = argv[i + 1]; } else if (strcmp(argv[i], "--port") == 0) { server_port = std::stoi(argv[i + 1]); } else if (strcmp(argv[i], "--ignore_cout") == 0) { @@ -367,6 +381,6 @@ int main(int argc, char* argv[]) { } } - RunServer(server_port, ignore_cout_log); + RunServer(server_host, server_port, ignore_cout_log); return 0; } From 33be8d839af7200545f799ab0787136030729731 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 18 Dec 2024 09:37:50 +0700 Subject: [PATCH 41/44] fix: check cpu info size (#1804) * fix: check cpu info size * fix: sort gpus --- engine/services/hardware_service.cc | 4 ++-- engine/utils/hardware/cpu_info.h | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/engine/services/hardware_service.cc b/engine/services/hardware_service.cc index ca1ea4cc6..ca2bd8ed9 100644 --- a/engine/services/hardware_service.cc +++ b/engine/services/hardware_service.cc @@ -235,8 +235,8 @@ bool HardwareService::SetActivateHardwareConfig( activated_ids.push_back(std::pair(e.software_id, e.priority)); } } - std::sort(activated_ids.begin(), activated_ids.end()); - std::sort(ahc_gpus.begin(), ahc_gpus.end()); + std::sort(activated_ids.begin(), activated_ids.end(), + [](auto& p1, auto& p2) { return p1.second < p2.second; }); if (ahc_gpus.size() != activated_ids.size()) { need_update = true; } else { diff --git a/engine/utils/hardware/cpu_info.h b/engine/utils/hardware/cpu_info.h index 4c2cb3027..4395cc8dd 100644 --- a/engine/utils/hardware/cpu_info.h +++ b/engine/utils/hardware/cpu_info.h @@ -10,7 +10,10 @@ namespace cortex::hw { inline CPU GetCPUInfo() { - auto cpu = hwinfo::getAllCPUs()[0]; + auto res = hwinfo::getAllCPUs(); + if (res.empty()) + return CPU{}; + auto cpu = res[0]; cortex::cpuid::CpuInfo inst; return CPU{.cores = cpu.numPhysicalCores(), .arch = std::string(GetArch()), From 0ae0146e4f66a0e0558c554e3b8ef14ac7c34f00 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 18 Dec 2024 16:50:41 +0700 Subject: [PATCH 42/44] fix: only use dll search path if ENGINE_PATH is not set (#1808) * fix: only use dll search path if ENGINE_PATH is not set * chore: remove unused --- engine/services/engine_service.cc | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 1f3e4d81c..bdd080f50 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -704,21 +704,23 @@ cpp::result EngineService::LoadEngine( #if defined(_WIN32) || defined(_WIN64) // register deps - std::vector paths{}; - paths.push_back(std::move(cuda_path)); - paths.push_back(std::move(engine_dir_path)); - - CTL_DBG("Registering dylib for " - << ne << " with " << std::to_string(paths.size()) << " paths."); - for (const auto& path : paths) { - CTL_DBG("Registering path: " << path.string()); - } + if (!(getenv("ENGINE_PATH"))) { + std::vector paths{}; + paths.push_back(std::move(cuda_path)); + paths.push_back(std::move(engine_dir_path)); - auto reg_result = dylib_path_manager_->RegisterPath(ne, paths); - if (reg_result.has_error()) { - CTL_DBG("Failed register lib paths for: " << ne); - } else { - CTL_DBG("Registered lib paths for: " << ne); + CTL_DBG("Registering dylib for " + << ne << " with " << std::to_string(paths.size()) << " paths."); + for (const auto& path : paths) { + CTL_DBG("Registering path: " << path.string()); + } + + auto reg_result = dylib_path_manager_->RegisterPath(ne, paths); + if (reg_result.has_error()) { + CTL_DBG("Failed register lib paths for: " << ne); + } else { + CTL_DBG("Registered lib paths for: " << ne); + } } #endif From b0b8bec212de99f255f980c4c6c9ad8c262425db Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Wed, 18 Dec 2024 16:51:44 +0700 Subject: [PATCH 43/44] chore: correct storage dto link (#1809) --- docs/static/openapi/cortex.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 9134e89e6..a05f8b24e 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -6282,7 +6282,7 @@ }, "required": ["available", "total", "type"] }, - "Storage": { + "StorageDto": { "type": "object", "properties": { "available": { From 5414e02a7ed1fc253dabbc890a9a7225f207f161 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 19 Dec 2024 16:45:40 +0700 Subject: [PATCH 44/44] chore: add log for cpu instructions (#1807) * fix: check cpu info size * fix: sort gpus * chore: add log for cpu instructions * fix: add guard --- engine/cli/commands/model_status_cmd.cc | 2 +- engine/services/engine_service.cc | 4 ++++ engine/utils/github_release_utils.h | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/engine/cli/commands/model_status_cmd.cc b/engine/cli/commands/model_status_cmd.cc index cd9f3034d..e467e4353 100644 --- a/engine/cli/commands/model_status_cmd.cc +++ b/engine/cli/commands/model_status_cmd.cc @@ -25,7 +25,7 @@ bool ModelStatusCmd::IsLoaded(const std::string& host, int port, auto res = curl_utils::SimpleGetJson(url.ToFullPath()); if (res.has_error()) { auto root = json_helper::ParseJsonString(res.error()); - CLI_LOG(root["message"].asString()); + CTL_WRN(root["message"].asString()); return false; } diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index bdd080f50..c8f4c180c 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -8,6 +8,7 @@ #include "database/engines.h" #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" +#include "utils/cpuid/cpu_info.h" #include "utils/engine_constants.h" #include "utils/engine_matcher_utils.h" #include "utils/file_manager_utils.h" @@ -691,6 +692,9 @@ cpp::result EngineService::LoadEngine( // End hard code CTL_INF("Loading engine: " << ne); +#if defined(_WIN32) || defined(_WIN64) || defined(__linux__) + CTL_INF("CPU Info: " << cortex::cpuid::CpuInfo().to_string()); +#endif auto engine_dir_path_res = GetEngineDirPath(ne); if (engine_dir_path_res.has_error()) { diff --git a/engine/utils/github_release_utils.h b/engine/utils/github_release_utils.h index be97cb37c..72d7687f6 100644 --- a/engine/utils/github_release_utils.h +++ b/engine/utils/github_release_utils.h @@ -194,7 +194,7 @@ inline cpp::result GetReleaseByVersion( .pathParams = path_params, }; - CTL_DBG("GetReleaseByVersion: " << url.ToFullPath()); + // CTL_DBG("GetReleaseByVersion: " << url.ToFullPath()); auto result = curl_utils::SimpleGetJson(url_parser::FromUrl(url), kCurlGetTimeout);