From 00af9798390f8cab6966ac4a922ff46d84956b53 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 10:51:43 +0700 Subject: [PATCH] feat: cortex pull and cortex engines install CLI uses API server (#1550) * fix: add ws and indicators * fix: more * fix: pull models info from server * fix: model_source * fix: rename * fix: download cortexso * fix: pull models * fix: remove comments * fix: rename * fix: change download UI * fix: comment out * fix: e2e tests * fix: run, start * fix: start server * fix: e2e * fix: remove * fix: abort model * fix: build * fix: clean code * fix: clean more * fix: normalize engine id * fix: use auto * fix: use vcpkg for indicators * fix: download progress --------- Co-authored-by: vansangpfiev --- engine/cli/CMakeLists.txt | 5 + engine/cli/command_line_parser.cc | 8 +- engine/cli/commands/engine_install_cmd.cc | 66 +- engine/cli/commands/engine_install_cmd.h | 8 +- engine/cli/commands/model_pull_cmd.cc | 178 +++++- engine/cli/commands/model_pull_cmd.h | 9 +- engine/cli/commands/model_start_cmd.cc | 2 +- engine/cli/commands/run_cmd.cc | 17 +- engine/cli/commands/run_cmd.h | 5 +- engine/cli/commands/server_start_cmd.cc | 4 + engine/cli/utils/download_progress.cc | 109 ++++ engine/cli/utils/download_progress.h | 28 + engine/cli/utils/easywsclient.cc | 594 ++++++++++++++++++ engine/cli/utils/easywsclient.hpp | 85 +++ engine/common/download_task.h | 66 ++ engine/common/event.h | 28 + engine/controllers/models.cc | 39 ++ engine/controllers/models.h | 3 + engine/e2e-test/test_cli_engine_install.py | 23 +- engine/e2e-test/test_cli_engine_uninstall.py | 2 +- ..._cli_model_pull_cortexso_with_selection.py | 13 + .../test_cli_model_pull_direct_url.py | 18 +- .../test_cli_model_pull_from_cortexso.py | 11 + ..._cli_model_pull_hugging_face_repository.py | 11 + engine/e2e-test/test_create_log_folder.py | 1 + engine/e2e-test/test_runner.py | 4 +- engine/services/engine_service.cc | 56 +- engine/services/engine_service.h | 12 +- engine/services/model_service.cc | 136 +++- engine/services/model_service.h | 13 + engine/test/components/test_event.cc | 50 ++ engine/vcpkg.json | 3 +- 32 files changed, 1526 insertions(+), 81 deletions(-) create mode 100644 engine/cli/utils/download_progress.cc create mode 100644 engine/cli/utils/download_progress.h create mode 100644 engine/cli/utils/easywsclient.cc create mode 100644 engine/cli/utils/easywsclient.hpp create mode 100644 engine/test/components/test_event.cc diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 11e2c384b..19f206a40 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -71,6 +71,8 @@ find_package(tabulate CONFIG REQUIRED) find_package(CURL REQUIRED) find_package(SQLiteCpp REQUIRED) find_package(Trantor CONFIG REQUIRED) +find_package(indicators CONFIG REQUIRED) + add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc @@ -80,6 +82,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ) target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) @@ -93,6 +97,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp OpenSSL::SSL OpenS ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp) target_link_libraries(${TARGET_NAME} PRIVATE Trantor::Trantor) +target_link_libraries(${TARGET_NAME} PRIVATE indicators::indicators) # ############################################################################## diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index c40c90a9e..c4c612aa6 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -130,7 +130,9 @@ void CommandLineParser::SetupCommonCommands() { return; } try { - commands::ModelPullCmd(download_service_).Exec(cml_data_.model_id); + commands::ModelPullCmd(download_service_) + .Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id); } catch (const std::exception& e) { CLI_LOG(e.what()); } @@ -462,7 +464,9 @@ void CommandLineParser::EngineInstall(CLI::App* parent, if (std::exchange(executed_, true)) return; try { - commands::EngineInstallCmd(download_service_) + commands::EngineInstallCmd(download_service_, + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)) .Exec(engine_name, version, src); } catch (const std::exception& e) { CTL_ERR(e.what()); diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 8cf7c1cc7..4cb9c0277 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -1,16 +1,68 @@ #include "engine_install_cmd.h" +#include "server_start_cmd.h" +#include "utils/download_progress.h" +#include "utils/engine_constants.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" namespace commands { - -void EngineInstallCmd::Exec(const std::string& engine, +bool EngineInstallCmd::Exec(const std::string& engine, const std::string& version, const std::string& src) { - auto result = engine_service_.InstallEngine(engine, version, src); - if (result.has_error()) { - CLI_LOG(result.error()); - } else if(result && result.value()){ - CLI_LOG("Engine " << engine << " installed successfully!"); + // Handle local install, if fails, fallback to remote install + if (!src.empty()) { + auto res = engine_service_.UnzipEngine(engine, version, src); + if (res.has_error()) { + CLI_LOG(res.error()); + return false; + } + if (res.value()) { + CLI_LOG("Engine " << engine << " installed successfully!"); + return true; + } + } + + // Start server if server is not started yet + if (!commands::IsServerAlive(host_, port_)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host_, port_)) { + return false; + } + } + + httplib::Client cli(host_ + ":" + std::to_string(port_)); + Json::Value json_data; + auto data_str = json_data.toStyledString(); + cli.set_read_timeout(std::chrono::seconds(60)); + auto res = cli.Post("/v1/engines/install/" + engine, httplib::Headers(), + data_str.data(), data_str.size(), "application/json"); + + if (res) { + if (res->status != httplib::StatusCode::OK_200) { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return false; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return false; } + + CLI_LOG("Start downloading ...") + DownloadProgress dp; + dp.Connect(host_, port_); + if (!dp.Handle(engine)) + return false; + + bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (check_cuda_download) { + if (!dp.Handle("cuda")) + return false; + } + + CLI_LOG("Engine " << engine << " downloaded successfully!") + return true; } }; // namespace commands diff --git a/engine/cli/commands/engine_install_cmd.h b/engine/cli/commands/engine_install_cmd.h index 199d4d319..4a22d03f7 100644 --- a/engine/cli/commands/engine_install_cmd.h +++ b/engine/cli/commands/engine_install_cmd.h @@ -7,13 +7,15 @@ namespace commands { class EngineInstallCmd { public: - explicit EngineInstallCmd(std::shared_ptr download_service) - : engine_service_{EngineService(download_service)} {}; + explicit EngineInstallCmd(std::shared_ptr download_service, const std::string& host, int port) + : engine_service_{EngineService(download_service)}, host_(host), port_(port) {}; - void Exec(const std::string& engine, const std::string& version = "latest", + bool Exec(const std::string& engine, const std::string& version = "latest", const std::string& src = ""); private: EngineService engine_service_; + std::string host_; + int port_; }; } // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 4ec5344bb..3a8f202d3 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -1,11 +1,181 @@ #include "model_pull_cmd.h" +#include +#include "common/event.h" +#include "database/models.h" +#include "server_start_cmd.h" +#include "utils/cli_selection_utils.h" +#include "utils/download_progress.h" +#include "utils/format_utils.h" +#include "utils/huggingface_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/scope_exit.h" +#include "utils/string_utils.h" +#if defined(_WIN32) +#include +#endif namespace commands { -void ModelPullCmd::Exec(const std::string& input) { - auto result = model_service_.DownloadModel(input); - if (result.has_error()) { - CLI_LOG(result.error()); +std::function shutdown_handler; +inline void signal_handler(int signal) { + if (shutdown_handler) { + shutdown_handler(signal); + } +} +std::optional ModelPullCmd::Exec(const std::string& host, int port, + const std::string& input) { + + // model_id: use to check the download progress + // model: use as a parameter for pull API + auto model_id = input; + auto model = input; + + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return std::nullopt; + } + } + + // Get model info from Server + httplib::Client cli(host + ":" + std::to_string(port)); + cli.set_read_timeout(std::chrono::seconds(60)); + Json::Value j_data; + j_data["model"] = input; + auto d_str = j_data.toStyledString(); + auto res = cli.Post("/models/pull/info", httplib::Headers(), d_str.data(), + d_str.size(), "application/json"); + + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + // CLI_LOG(res->body); + auto root = json_helper::ParseJsonString(res->body); + auto id = root["id"].asString(); + bool is_cortexso = root["modelSource"].asString() == "cortexso"; + auto default_branch = root["defaultBranch"].asString(); + std::vector downloaded; + for (auto const& v : root["downloadedModels"]) { + downloaded.push_back(v.asString()); + } + std::vector avails; + for (auto const& v : root["availableModels"]) { + avails.push_back(v.asString()); + } + auto download_url = root["downloadUrl"].asString(); + + if (downloaded.empty() && avails.empty()) { + model_id = id; + model = download_url; + } else { + if (is_cortexso) { + auto selection = cli_selection_utils::PrintModelSelection( + downloaded, avails, + default_branch.empty() + ? std::nullopt + : std::optional(default_branch)); + + if (!selection.has_value()) { + CLI_LOG("Invalid selection"); + return std::nullopt; + } + model_id = selection.value(); + model = model_id; + } else { + auto selection = cli_selection_utils::PrintSelection(avails); + CLI_LOG("Selected: " << selection.value()); + model_id = id + ":" + selection.value(); + model = download_url + selection.value(); + } + } + } else { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return std::nullopt; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return std::nullopt; + } + + // Send request download model to server + Json::Value json_data; + json_data["model"] = model; + auto data_str = json_data.toStyledString(); + cli.set_read_timeout(std::chrono::seconds(60)); + res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), + data_str.size(), "application/json"); + + if (res) { + if (res->status != httplib::StatusCode::OK_200) { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return std::nullopt; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return std::nullopt; + } + + CLI_LOG("Start downloading ...") + DownloadProgress dp; + bool force_stop = false; + + shutdown_handler = [this, &dp, &host, &port, &model_id, &force_stop](int) { + force_stop = true; + AbortModelPull(host, port, model_id); + dp.ForceStop(); + }; + + utils::ScopeExit se([]() { shutdown_handler = {}; }); +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined(_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler( + reinterpret_cast(console_ctrl_handler), true); +#endif + dp.Connect(host, port); + if (!dp.Handle(model_id)) + return std::nullopt; + if (force_stop) + return std::nullopt; + CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; +} + +bool ModelPullCmd::AbortModelPull(const std::string& host, int port, + const std::string& task_id) { + Json::Value json_data; + json_data["taskId"] = task_id; + auto data_str = json_data.toStyledString(); + httplib::Client cli(host + ":" + std::to_string(port)); + cli.set_read_timeout(std::chrono::seconds(60)); + auto res = cli.Delete("/v1/models/pull", httplib::Headers(), data_str.data(), + data_str.size(), "application/json"); + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + CTL_INF("Abort model pull successfully: " << task_id); + return true; + } else { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return false; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return false; } } }; // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.h b/engine/cli/commands/model_pull_cmd.h index 3586b3cd4..d05759dbc 100644 --- a/engine/cli/commands/model_pull_cmd.h +++ b/engine/cli/commands/model_pull_cmd.h @@ -8,7 +8,14 @@ class ModelPullCmd { public: explicit ModelPullCmd(std::shared_ptr download_service) : model_service_{ModelService(download_service)} {}; - void Exec(const std::string& input); + explicit ModelPullCmd(const ModelService& model_service) + : model_service_{model_service} {}; + std::optional Exec(const std::string& host, int port, + const std::string& input); + + private: + bool AbortModelPull(const std::string& host, int port, + const std::string& task_id); private: ModelService model_service_; diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index 9041e7e07..1055805f5 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -14,7 +14,7 @@ bool ModelStartCmd::Exec(const std::string& host, int port, const std::string& model_handle, bool print_success_log) { std::optional model_id = - SelectLocalModel(model_service_, model_handle); + SelectLocalModel(host, port, model_service_, model_handle); if (!model_id.has_value()) { return false; diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index c80f12de1..d09298cd5 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -3,6 +3,8 @@ #include "config/yaml_config.h" #include "cortex_upd_cmd.h" #include "database/models.h" +#include "engine_install_cmd.h" +#include "model_pull_cmd.h" #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" @@ -11,7 +13,8 @@ namespace commands { -std::optional SelectLocalModel(ModelService& model_service, +std::optional SelectLocalModel(std::string host, int port, + ModelService& model_service, const std::string& model_handle) { std::optional model_id = model_handle; cortex::db::Models modellist_handler; @@ -42,8 +45,8 @@ std::optional SelectLocalModel(ModelService& model_service, } else { auto related_models_ids = modellist_handler.FindRelatedModel(model_handle); if (related_models_ids.has_error() || related_models_ids.value().empty()) { - auto result = model_service.DownloadModel(model_handle); - if (result.has_error()) { + auto result = ModelPullCmd(model_service).Exec(host, port, model_handle); + if (!result) { CLI_LOG("Model " << model_handle << " not found!"); return std::nullopt; } @@ -79,7 +82,7 @@ std::string Repo2Engine(const std::string& r) { void RunCmd::Exec(bool run_detach) { std::optional model_id = - SelectLocalModel(model_service_, model_handle_); + SelectLocalModel(host_, port_, model_service_, model_handle_); if (!model_id.has_value()) { return; } @@ -114,9 +117,9 @@ void RunCmd::Exec(bool run_detach) { throw std::runtime_error("Engine " + mc.engine + " is incompatible"); } if (required_engine.value().status == EngineService::kNotInstalled) { - auto install_engine_result = engine_service_.InstallEngine(mc.engine); - if (install_engine_result.has_error()) { - throw std::runtime_error(install_engine_result.error()); + if (!EngineInstallCmd(download_service_, host_, port_) + .Exec(mc.engine)) { + return; } } } diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index 4a0d68078..46a687fce 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -6,7 +6,8 @@ namespace commands { -std::optional SelectLocalModel(ModelService& model_service, +std::optional SelectLocalModel(std::string host, int port, + ModelService& model_service, const std::string& model_handle); class RunCmd { @@ -16,6 +17,7 @@ class RunCmd { : host_{std::move(host)}, port_{port}, model_handle_{std::move(model_handle)}, + download_service_(download_service), engine_service_{EngineService(download_service)}, model_service_{ModelService(download_service)} {}; @@ -26,6 +28,7 @@ class RunCmd { int port_; std::string model_handle_; + std::shared_ptr download_service_; ModelService model_service_; EngineService engine_service_; }; diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index cd06a3ba3..ca5363fa6 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -100,6 +100,10 @@ bool ServerStartCmd::Exec(const std::string& host, int port) { auto data_path = file_manager_utils::GetEnginesContainerPath(); auto llamacpp_path = data_path / "cortex.llamacpp/"; auto trt_path = data_path / "cortex.tensorrt-llm/"; + if (!std::filesystem::exists(llamacpp_path)) { + std::filesystem::create_directory(llamacpp_path); + } + auto new_v = trt_path.string() + ":" + llamacpp_path.string() + ":" + v; setenv(name, new_v.c_str(), true); CTL_INF("LD_LIBRARY_PATH: " << getenv(name)); diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc new file mode 100644 index 000000000..2613fe413 --- /dev/null +++ b/engine/cli/utils/download_progress.cc @@ -0,0 +1,109 @@ +#include "download_progress.h" +#include +#include "common/event.h" +#include "indicators/dynamic_progress.hpp" +#include "indicators/progress_bar.hpp" +#include "utils/format_utils.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" + +bool DownloadProgress::Connect(const std::string& host, int port) { + if (ws_) { + CTL_INF("Already connected!"); + return true; + } + ws_.reset(easywsclient::WebSocket::from_url( + "ws://" + host + ":" + std::to_string(port) + "/events")); + if (!!ws_) + return false; + + return true; +} + +bool DownloadProgress::Handle(const std::string& id) { + assert(!!ws_); + status_ = DownloadStatus::DownloadStarted; + std::unique_ptr> bars; + + std::vector> items; + indicators::show_console_cursor(false); + auto handle_message = [this, &bars, &items, id](const std::string& message) { + CTL_INF(message); + + auto pad_string = [](const std::string& str, + size_t max_length = 20) -> std::string { + // Check the length of the input string + if (str.length() >= max_length) { + return str.substr( + 0, max_length); // Return truncated string if it's too long + } + + // Calculate the number of spaces needed + size_t padding_size = max_length - str.length(); + + // Create a new string with the original string followed by spaces + return str + std::string(padding_size, ' '); + }; + + auto ev = cortex::event::GetDownloadEventFromJson( + json_helper::ParseJsonString(message)); + // Ignore other task ids + if (ev.download_task_.id != id) { + return; + } + + if (!bars) { + bars = std::make_unique< + indicators::DynamicProgress>(); + for (auto& i : ev.download_task_.items) { + items.emplace_back(std::make_unique( + indicators::option::BarWidth{50}, indicators::option::Start{"["}, + indicators::option::Fill{"="}, indicators::option::Lead{">"}, + indicators::option::End{"]"}, + indicators::option::PrefixText{pad_string(i.id)}, + indicators::option::ForegroundColor{indicators::Color::white}, + indicators::option::ShowRemainingTime{true})); + bars->push_back(*(items.back())); + } + } else { + for (int i = 0; i < ev.download_task_.items.size(); i++) { + auto& it = ev.download_task_.items[i]; + uint64_t downloaded = it.downloadedBytes.value_or(0); + uint64_t total = it.bytes.value_or(9999); + if (ev.type_ == DownloadStatus::DownloadUpdated) { + (*bars)[i].set_option(indicators::option::PrefixText{ + pad_string(it.id) + + std::to_string( + int(static_cast(downloaded) / total * 100)) + + '%'}); + (*bars)[i].set_progress( + int(static_cast(downloaded) / total * 100)); + (*bars)[i].set_option(indicators::option::PostfixText{ + format_utils::BytesToHumanReadable(downloaded) + "/" + + format_utils::BytesToHumanReadable(total)}); + } else if (ev.type_ == DownloadStatus::DownloadSuccess) { + (*bars)[i].set_progress(100); + auto total_str = format_utils::BytesToHumanReadable(total); + (*bars)[i].set_option( + indicators::option::PostfixText{total_str + "/" + total_str}); + (*bars)[i].set_option( + indicators::option::PrefixText{pad_string(it.id) + "100%"}); + (*bars)[i].set_progress(100); + + CTL_INF("Download success"); + } + } + } + status_ = ev.type_; + }; + + while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED && + !should_stop()) { + ws_->poll(); + ws_->dispatch(handle_message); + } + indicators::show_console_cursor(true); + if (status_ == DownloadStatus::DownloadError) + return false; + return true; +} \ No newline at end of file diff --git a/engine/cli/utils/download_progress.h b/engine/cli/utils/download_progress.h new file mode 100644 index 000000000..4f71e6d84 --- /dev/null +++ b/engine/cli/utils/download_progress.h @@ -0,0 +1,28 @@ +#pragma once +#include +#include +#include +#include "common/event.h" +#include "easywsclient.hpp" + +using DownloadStatus = cortex::event::DownloadEventType; +class DownloadProgress { + public: + bool Connect(const std::string& host, int port); + + bool Handle(const std::string& id); + + void ForceStop() { force_stop_ = true; } + + private: + bool should_stop() const { + return (status_ != DownloadStatus::DownloadStarted && + status_ != DownloadStatus::DownloadUpdated) || + force_stop_; + } + + private: + std::unique_ptr ws_; + std::atomic status_ = DownloadStatus::DownloadStarted; + std::atomic force_stop_ = false; +}; \ No newline at end of file diff --git a/engine/cli/utils/easywsclient.cc b/engine/cli/utils/easywsclient.cc new file mode 100644 index 000000000..5c6ed38e8 --- /dev/null +++ b/engine/cli/utils/easywsclient.cc @@ -0,0 +1,594 @@ + +#ifdef _WIN32 +#if defined(_MSC_VER) && !defined(_CRT_SECURE_NO_WARNINGS) +#define _CRT_SECURE_NO_WARNINGS // _CRT_SECURE_NO_WARNINGS for sscanf errors in MSVC2013 Express +#endif +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#include +#include +#pragma comment(lib, "ws2_32") +#include +#include +#include +#include +#include +#ifndef _SSIZE_T_DEFINED +typedef int ssize_t; +#define _SSIZE_T_DEFINED +#endif +#ifndef _SOCKET_T_DEFINED +typedef SOCKET socket_t; +#define _SOCKET_T_DEFINED +#endif +#ifndef snprintf +#define snprintf _snprintf_s +#endif +#if _MSC_VER >= 1600 +// vs2010 or later +#include +#else +typedef __int8 int8_t; +typedef unsigned __int8 uint8_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#endif +#define socketerrno WSAGetLastError() +#define SOCKET_EAGAIN_EINPROGRESS WSAEINPROGRESS +#define SOCKET_EWOULDBLOCK WSAEWOULDBLOCK +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _SOCKET_T_DEFINED +typedef int socket_t; +#define _SOCKET_T_DEFINED +#endif +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#ifndef SOCKET_ERROR +#define SOCKET_ERROR (-1) +#endif +#define closesocket(s) ::close(s) +#include +#define socketerrno errno +#define SOCKET_EAGAIN_EINPROGRESS EAGAIN +#define SOCKET_EWOULDBLOCK EWOULDBLOCK +#endif + +#include +#include + +#include "easywsclient.hpp" + +using easywsclient::BytesCallback_Imp; +using easywsclient::Callback_Imp; + +namespace { // private module-only namespace + +socket_t hostname_connect(const std::string& hostname, int port) { + struct addrinfo hints; + struct addrinfo* result; + struct addrinfo* p; + int ret; + socket_t sockfd = INVALID_SOCKET; + char sport[16]; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + snprintf(sport, 16, "%d", port); + if ((ret = getaddrinfo(hostname.c_str(), sport, &hints, &result)) != 0) { + fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ret)); + return 1; + } + for (p = result; p != NULL; p = p->ai_next) { + sockfd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); + if (sockfd == INVALID_SOCKET) { + continue; + } + if (connect(sockfd, p->ai_addr, p->ai_addrlen) != SOCKET_ERROR) { + break; + } + closesocket(sockfd); + sockfd = INVALID_SOCKET; + } + freeaddrinfo(result); + return sockfd; +} + +class _DummyWebSocket : public easywsclient::WebSocket { + public: + void poll(int timeout) {} + void send(const std::string& message) {} + void sendBinary(const std::string& message) {} + void sendBinary(const std::vector& message) {} + void sendPing() {} + void close() {} + readyStateValues getReadyState() const { return CLOSED; } + void _dispatch(Callback_Imp& callable) {} + void _dispatchBinary(BytesCallback_Imp& callable) {} +}; + +class _RealWebSocket : public easywsclient::WebSocket { + public: + // http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + struct wsheader_type { + unsigned header_size; + bool fin; + bool mask; + enum opcode_type { + CONTINUATION = 0x0, + TEXT_FRAME = 0x1, + BINARY_FRAME = 0x2, + CLOSE = 8, + PING = 9, + PONG = 0xa, + } opcode; + int N0; + uint64_t N; + uint8_t masking_key[4]; + }; + + std::vector rxbuf; + std::vector txbuf; + std::vector receivedData; + + socket_t sockfd; + readyStateValues readyState; + bool useMask; + bool isRxBad; + + _RealWebSocket(socket_t sockfd, bool useMask) + : sockfd(sockfd), readyState(OPEN), useMask(useMask), isRxBad(false) {} + + readyStateValues getReadyState() const { return readyState; } + + void poll(int timeout) { // timeout in milliseconds + if (readyState == CLOSED) { + if (timeout > 0) { + timeval tv = {timeout / 1000, (timeout % 1000) * 1000}; + select(0, NULL, NULL, NULL, &tv); + } + return; + } + if (timeout != 0) { + fd_set rfds; + fd_set wfds; + timeval tv = {timeout / 1000, (timeout % 1000) * 1000}; + FD_ZERO(&rfds); + FD_ZERO(&wfds); + FD_SET(sockfd, &rfds); + if (txbuf.size()) { + FD_SET(sockfd, &wfds); + } + select(sockfd + 1, &rfds, &wfds, 0, timeout > 0 ? &tv : 0); + } + while (true) { + // FD_ISSET(0, &rfds) will be true + int N = rxbuf.size(); + ssize_t ret; + rxbuf.resize(N + 1500); + ret = recv(sockfd, (char*)&rxbuf[0] + N, 1500, 0); + if (false) { + } else if (ret < 0 && (socketerrno == SOCKET_EWOULDBLOCK || + socketerrno == SOCKET_EAGAIN_EINPROGRESS)) { + rxbuf.resize(N); + break; + } else if (ret <= 0) { + rxbuf.resize(N); + closesocket(sockfd); + readyState = CLOSED; + fputs(ret < 0 ? "Connection error!\n" : "Connection closed!\n", stderr); + break; + } else { + rxbuf.resize(N + ret); + } + } + while (txbuf.size()) { + int ret = ::send(sockfd, (char*)&txbuf[0], txbuf.size(), 0); + if (false) { + } // ?? + else if (ret < 0 && (socketerrno == SOCKET_EWOULDBLOCK || + socketerrno == SOCKET_EAGAIN_EINPROGRESS)) { + break; + } else if (ret <= 0) { + closesocket(sockfd); + readyState = CLOSED; + fputs(ret < 0 ? "Connection error!\n" : "Connection closed!\n", stderr); + break; + } else { + txbuf.erase(txbuf.begin(), txbuf.begin() + ret); + } + } + if (!txbuf.size() && readyState == CLOSING) { + closesocket(sockfd); + readyState = CLOSED; + } + } + + // Callable must have signature: void(const std::string & message). + // Should work with C functions, C++ functors, and C++11 std::function and + // lambda: + //template + //void dispatch(Callable callable) + virtual void _dispatch(Callback_Imp& callable) { + struct CallbackAdapter : public BytesCallback_Imp + // Adapt void(const std::string&) to void(const std::string&) + { + Callback_Imp& callable; + CallbackAdapter(Callback_Imp& callable) : callable(callable) {} + void operator()(const std::vector& message) { + std::string stringMessage(message.begin(), message.end()); + callable(stringMessage); + } + }; + CallbackAdapter bytesCallback(callable); + _dispatchBinary(bytesCallback); + } + + virtual void _dispatchBinary(BytesCallback_Imp& callable) { + // TODO: consider acquiring a lock on rxbuf... + if (isRxBad) { + return; + } + while (true) { + wsheader_type ws; + if (rxbuf.size() < 2) { + return; /* Need at least 2 */ + } + const uint8_t* data = (uint8_t*)&rxbuf[0]; // peek, but don't consume + ws.fin = (data[0] & 0x80) == 0x80; + ws.opcode = (wsheader_type::opcode_type)(data[0] & 0x0f); + ws.mask = (data[1] & 0x80) == 0x80; + ws.N0 = (data[1] & 0x7f); + ws.header_size = 2 + (ws.N0 == 126 ? 2 : 0) + (ws.N0 == 127 ? 8 : 0) + + (ws.mask ? 4 : 0); + if (rxbuf.size() < ws.header_size) { + return; /* Need: ws.header_size - rxbuf.size() */ + } + int i = 0; + if (ws.N0 < 126) { + ws.N = ws.N0; + i = 2; + } else if (ws.N0 == 126) { + ws.N = 0; + ws.N |= ((uint64_t)data[2]) << 8; + ws.N |= ((uint64_t)data[3]) << 0; + i = 4; + } else if (ws.N0 == 127) { + ws.N = 0; + ws.N |= ((uint64_t)data[2]) << 56; + ws.N |= ((uint64_t)data[3]) << 48; + ws.N |= ((uint64_t)data[4]) << 40; + ws.N |= ((uint64_t)data[5]) << 32; + ws.N |= ((uint64_t)data[6]) << 24; + ws.N |= ((uint64_t)data[7]) << 16; + ws.N |= ((uint64_t)data[8]) << 8; + ws.N |= ((uint64_t)data[9]) << 0; + i = 10; + if (ws.N & 0x8000000000000000ull) { + // https://tools.ietf.org/html/rfc6455 writes the "the most + // significant bit MUST be 0." + // + // We can't drop the frame, because (1) we don't we don't + // know how much data to skip over to find the next header, + // and (2) this would be an impractically long length, even + // if it were valid. So just close() and return immediately + // for now. + isRxBad = true; + fprintf(stderr, "ERROR: Frame has invalid frame length. Closing.\n"); + close(); + return; + } + } + if (ws.mask) { + ws.masking_key[0] = ((uint8_t)data[i + 0]) << 0; + ws.masking_key[1] = ((uint8_t)data[i + 1]) << 0; + ws.masking_key[2] = ((uint8_t)data[i + 2]) << 0; + ws.masking_key[3] = ((uint8_t)data[i + 3]) << 0; + } else { + ws.masking_key[0] = 0; + ws.masking_key[1] = 0; + ws.masking_key[2] = 0; + ws.masking_key[3] = 0; + } + + // Note: The checks above should hopefully ensure this addition + // cannot overflow: + if (rxbuf.size() < ws.header_size + ws.N) { + return; /* Need: ws.header_size+ws.N - rxbuf.size() */ + } + + // We got a whole message, now do something with it: + if (false) { + } else if (ws.opcode == wsheader_type::TEXT_FRAME || + ws.opcode == wsheader_type::BINARY_FRAME || + ws.opcode == wsheader_type::CONTINUATION) { + if (ws.mask) { + for (size_t i = 0; i != ws.N; ++i) { + rxbuf[i + ws.header_size] ^= ws.masking_key[i & 0x3]; + } + } + receivedData.insert( + receivedData.end(), rxbuf.begin() + ws.header_size, + rxbuf.begin() + ws.header_size + (size_t)ws.N); // just feed + if (ws.fin) { + callable((const std::vector)receivedData); + receivedData.erase(receivedData.begin(), receivedData.end()); + std::vector().swap(receivedData); // free memory + } + } else if (ws.opcode == wsheader_type::PING) { + if (ws.mask) { + for (size_t i = 0; i != ws.N; ++i) { + rxbuf[i + ws.header_size] ^= ws.masking_key[i & 0x3]; + } + } + std::string data(rxbuf.begin() + ws.header_size, + rxbuf.begin() + ws.header_size + (size_t)ws.N); + sendData(wsheader_type::PONG, data.size(), data.begin(), data.end()); + } else if (ws.opcode == wsheader_type::PONG) { + } else if (ws.opcode == wsheader_type::CLOSE) { + close(); + } else { + fprintf(stderr, "ERROR: Got unexpected WebSocket message.\n"); + close(); + } + + rxbuf.erase(rxbuf.begin(), rxbuf.begin() + ws.header_size + (size_t)ws.N); + } + } + + void sendPing() { + std::string empty; + sendData(wsheader_type::PING, empty.size(), empty.begin(), empty.end()); + } + + void send(const std::string& message) { + sendData(wsheader_type::TEXT_FRAME, message.size(), message.begin(), + message.end()); + } + + void sendBinary(const std::string& message) { + sendData(wsheader_type::BINARY_FRAME, message.size(), message.begin(), + message.end()); + } + + void sendBinary(const std::vector& message) { + sendData(wsheader_type::BINARY_FRAME, message.size(), message.begin(), + message.end()); + } + + template + void sendData(wsheader_type::opcode_type type, uint64_t message_size, + Iterator message_begin, Iterator message_end) { + // TODO: + // Masking key should (must) be derived from a high quality random + // number generator, to mitigate attacks on non-WebSocket friendly + // middleware: + const uint8_t masking_key[4] = {0x12, 0x34, 0x56, 0x78}; + // TODO: consider acquiring a lock on txbuf... + if (readyState == CLOSING || readyState == CLOSED) { + return; + } + std::vector header; + header.assign(2 + (message_size >= 126 ? 2 : 0) + + (message_size >= 65536 ? 6 : 0) + (useMask ? 4 : 0), + 0); + header[0] = 0x80 | type; + if (false) { + } else if (message_size < 126) { + header[1] = (message_size & 0xff) | (useMask ? 0x80 : 0); + if (useMask) { + header[2] = masking_key[0]; + header[3] = masking_key[1]; + header[4] = masking_key[2]; + header[5] = masking_key[3]; + } + } else if (message_size < 65536) { + header[1] = 126 | (useMask ? 0x80 : 0); + header[2] = (message_size >> 8) & 0xff; + header[3] = (message_size >> 0) & 0xff; + if (useMask) { + header[4] = masking_key[0]; + header[5] = masking_key[1]; + header[6] = masking_key[2]; + header[7] = masking_key[3]; + } + } else { // TODO: run coverage testing here + header[1] = 127 | (useMask ? 0x80 : 0); + header[2] = (message_size >> 56) & 0xff; + header[3] = (message_size >> 48) & 0xff; + header[4] = (message_size >> 40) & 0xff; + header[5] = (message_size >> 32) & 0xff; + header[6] = (message_size >> 24) & 0xff; + header[7] = (message_size >> 16) & 0xff; + header[8] = (message_size >> 8) & 0xff; + header[9] = (message_size >> 0) & 0xff; + if (useMask) { + header[10] = masking_key[0]; + header[11] = masking_key[1]; + header[12] = masking_key[2]; + header[13] = masking_key[3]; + } + } + // N.B. - txbuf will keep growing until it can be transmitted over the socket: + txbuf.insert(txbuf.end(), header.begin(), header.end()); + txbuf.insert(txbuf.end(), message_begin, message_end); + if (useMask) { + size_t message_offset = txbuf.size() - message_size; + for (size_t i = 0; i != message_size; ++i) { + txbuf[message_offset + i] ^= masking_key[i & 0x3]; + } + } + } + + void close() { + if (readyState == CLOSING || readyState == CLOSED) { + return; + } + readyState = CLOSING; + uint8_t closeFrame[6] = {0x88, 0x80, 0x00, 0x00, + 0x00, 0x00}; // last 4 bytes are a masking key + std::vector header(closeFrame, closeFrame + 6); + txbuf.insert(txbuf.end(), header.begin(), header.end()); + } +}; + +easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, + const std::string& origin) { + char host[512]; + int port; + char path[512]; + if (url.size() >= 512) { + fprintf(stderr, "ERROR: url size limit exceeded: %s\n", url.c_str()); + return NULL; + } + if (origin.size() >= 200) { + fprintf(stderr, "ERROR: origin size limit exceeded: %s\n", origin.c_str()); + return NULL; + } + if (false) { + } else if (sscanf(url.c_str(), "ws://%[^:/]:%d/%s", host, &port, path) == 3) { + } else if (sscanf(url.c_str(), "ws://%[^:/]/%s", host, path) == 2) { + port = 80; + } else if (sscanf(url.c_str(), "ws://%[^:/]:%d", host, &port) == 2) { + path[0] = '\0'; + } else if (sscanf(url.c_str(), "ws://%[^:/]", host) == 1) { + port = 80; + path[0] = '\0'; + } else { + fprintf(stderr, "ERROR: Could not parse WebSocket url: %s\n", url.c_str()); + return NULL; + } + //fprintf(stderr, "easywsclient: connecting: host=%s port=%d path=/%s\n", host, port, path); + socket_t sockfd = hostname_connect(host, port); + if (sockfd == INVALID_SOCKET) { + fprintf(stderr, "Unable to connect to %s:%d\n", host, port); + return NULL; + } + { + // XXX: this should be done non-blocking, + char line[1024]; + int status; + int i; + snprintf(line, 1024, "GET /%s HTTP/1.1\r\n", path); + ::send(sockfd, line, strlen(line), 0); + if (port == 80) { + snprintf(line, 1024, "Host: %s\r\n", host); + ::send(sockfd, line, strlen(line), 0); + } else { + snprintf(line, 1024, "Host: %s:%d\r\n", host, port); + ::send(sockfd, line, strlen(line), 0); + } + snprintf(line, 1024, "Upgrade: websocket\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "Connection: Upgrade\r\n"); + ::send(sockfd, line, strlen(line), 0); + if (!origin.empty()) { + snprintf(line, 1024, "Origin: %s\r\n", origin.c_str()); + ::send(sockfd, line, strlen(line), 0); + } + snprintf(line, 1024, "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "Sec-WebSocket-Version: 13\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "\r\n"); + ::send(sockfd, line, strlen(line), 0); + for (i = 0; + i < 2 || (i < 1023 && line[i - 2] != '\r' && line[i - 1] != '\n'); + ++i) { + if (recv(sockfd, line + i, 1, 0) == 0) { + return NULL; + } + } + line[i] = 0; + if (i == 1023) { + fprintf(stderr, "ERROR: Got invalid status line connecting to: %s\n", + url.c_str()); + return NULL; + } + if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101) { + fprintf(stderr, "ERROR: Got bad status connecting to %s: %s", url.c_str(), + line); + return NULL; + } + // TODO: verify response headers, + while (true) { + for (i = 0; + i < 2 || (i < 1023 && line[i - 2] != '\r' && line[i - 1] != '\n'); + ++i) { + if (recv(sockfd, line + i, 1, 0) == 0) { + return NULL; + } + } + if (line[0] == '\r' && line[1] == '\n') { + break; + } + } + } + int flag = 1; + setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, + sizeof(flag)); // Disable Nagle's algorithm +#ifdef _WIN32 + u_long on = 1; + ioctlsocket(sockfd, FIONBIO, &on); +#else + fcntl(sockfd, F_SETFL, O_NONBLOCK); +#endif + //fprintf(stderr, "Connected to: %s\n", url.c_str()); + return easywsclient::WebSocket::pointer(new _RealWebSocket(sockfd, useMask)); +} + +} // namespace + +namespace easywsclient { + +WebSocket::pointer WebSocket::create_dummy() { + static pointer dummy = pointer(new _DummyWebSocket); + return dummy; +} + +WebSocket::pointer WebSocket::from_url(const std::string& url, + const std::string& origin) { + return ::from_url(url, true, origin); +} + +WebSocket::pointer WebSocket::from_url_no_mask(const std::string& url, + const std::string& origin) { + return ::from_url(url, false, origin); +} + +} // namespace easywsclient \ No newline at end of file diff --git a/engine/cli/utils/easywsclient.hpp b/engine/cli/utils/easywsclient.hpp new file mode 100644 index 000000000..1f0149d2c --- /dev/null +++ b/engine/cli/utils/easywsclient.hpp @@ -0,0 +1,85 @@ +#ifndef EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD +#define EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD + +// This code comes from: +// https://github.com/dhbaird/easywsclient +// +// To get the latest version: +// wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.hpp +// wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.cpp + +#include +#include + +namespace easywsclient { + +struct Callback_Imp { + virtual void operator()(const std::string& message) = 0; +}; +struct BytesCallback_Imp { + virtual void operator()(const std::vector& message) = 0; +}; + +class WebSocket { + public: + typedef WebSocket* pointer; + typedef enum readyStateValues { + CLOSING, + CLOSED, + CONNECTING, + OPEN + } readyStateValues; + + // Factories: + static pointer create_dummy(); + static pointer from_url(const std::string& url, + const std::string& origin = std::string()); + static pointer from_url_no_mask(const std::string& url, + const std::string& origin = std::string()); + + // Interfaces: + virtual ~WebSocket() {} + virtual void poll(int timeout = 0) = 0; // timeout in milliseconds + virtual void send(const std::string& message) = 0; + virtual void sendBinary(const std::string& message) = 0; + virtual void sendBinary(const std::vector& message) = 0; + virtual void sendPing() = 0; + virtual void close() = 0; + virtual readyStateValues getReadyState() const = 0; + + template + void dispatch(Callable callable) + // For callbacks that accept a string argument. + { // N.B. this is compatible with both C++11 lambdas, functors and C function pointers + struct _Callback : public Callback_Imp { + Callable& callable; + _Callback(Callable& callable) : callable(callable) {} + void operator()(const std::string& message) { callable(message); } + }; + _Callback callback(callable); + _dispatch(callback); + } + + template + void dispatchBinary(Callable callable) + // For callbacks that accept a std::vector argument. + { // N.B. this is compatible with both C++11 lambdas, functors and C function pointers + struct _Callback : public BytesCallback_Imp { + Callable& callable; + _Callback(Callable& callable) : callable(callable) {} + void operator()(const std::vector& message) { + callable(message); + } + }; + _Callback callback(callable); + _dispatchBinary(callback); + } + + protected: + virtual void _dispatch(Callback_Imp& callable) = 0; + virtual void _dispatchBinary(BytesCallback_Imp& callable) = 0; +}; + +} // namespace easywsclient + +#endif /* EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD */ \ No newline at end of file diff --git a/engine/common/download_task.h b/engine/common/download_task.h index 5994cdaed..39bf03a99 100644 --- a/engine/common/download_task.h +++ b/engine/common/download_task.h @@ -5,6 +5,7 @@ #include #include #include +#include enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex }; @@ -55,6 +56,22 @@ inline std::string DownloadTypeToString(DownloadType type) { } } +inline DownloadType DownloadTypeFromString(const std::string& str) { + if (str == "Model") { + return DownloadType::Model; + } else if (str == "Engine") { + return DownloadType::Engine; + } else if (str == "Miscellaneous") { + return DownloadType::Miscellaneous; + } else if (str == "CudaToolkit") { + return DownloadType::CudaToolkit; + } else if (str == "Cortex") { + return DownloadType::Cortex; + } else { + return DownloadType::Miscellaneous; + } +} + struct DownloadTask { enum class Status { Pending, InProgress, Completed, Cancelled, Error }; @@ -116,3 +133,52 @@ struct DownloadTask { {"id", id}, {"type", DownloadTypeToString(type)}, {"items", dl_items}}; } }; + +namespace common { +inline DownloadItem GetDownloadItemFromJson(const Json::Value item_json) { + DownloadItem item; + if (!item_json["id"].isNull()) { + item.id = item_json["id"].asString(); + } + if (!item_json["downloadUrl"].isNull()) { + item.downloadUrl = item_json["downloadUrl"].asString(); + } + + if (!item_json["localPath"].isNull()) { + item.localPath = std::filesystem::path(item_json["localPath"].asString()); + } + + if (!item_json["checksum"].isNull()) { + item.checksum = item_json["checksum"].asString(); + } + + if (!item_json["bytes"].isNull()) { + item.bytes = item_json["bytes"].asUInt64(); + } + + if (!item_json["downloadedBytes"].isNull()) { + item.downloadedBytes = item_json["downloadedBytes"].asUInt64(); + } + + return item; +} + +inline DownloadTask GetDownloadTaskFromJson(const Json::Value item_json) { + DownloadTask task; + + if (!item_json["id"].isNull()) { + task.id = item_json["id"].asString(); + } + + if (!item_json["type"].isNull()) { + task.type = DownloadTypeFromString(item_json["type"].asString()); + } + + if (!item_json["items"].isNull() && item_json["items"].isArray()) { + for (auto const& i_json : item_json["items"]) { + task.items.emplace_back(GetDownloadItemFromJson(i_json)); + } + } + return task; +} +} // namespace common \ No newline at end of file diff --git a/engine/common/event.h b/engine/common/event.h index fe68bd04e..c23ebea5f 100644 --- a/engine/common/event.h +++ b/engine/common/event.h @@ -45,6 +45,22 @@ std::string DownloadEventTypeToString(DownloadEventType type) { return "Unknown"; } } + +inline DownloadEventType DownloadEventTypeFromString(const std::string& str) { + if (str == "DownloadStarted") { + return DownloadEventType::DownloadStarted; + } else if (str == "DownloadStopped") { + return DownloadEventType::DownloadStopped; + } else if (str == "DownloadUpdated") { + return DownloadEventType::DownloadUpdated; + } else if (str == "DownloadSuccess") { + return DownloadEventType::DownloadSuccess; + } else if (str == "DownloadError") { + return DownloadEventType::DownloadError; + } else { + return DownloadEventType::DownloadError; + } +} } // namespace struct DownloadEvent : public cortex::event::Event { @@ -57,6 +73,18 @@ struct DownloadEvent : public cortex::event::Event { DownloadEventType type_; DownloadTask download_task_; }; + +inline DownloadEvent GetDownloadEventFromJson(const Json::Value& item_json) { + DownloadEvent ev; + if (!item_json["type"].isNull()) { + ev.type_ = DownloadEventTypeFromString(item_json["type"].asString()); + } + + if (!item_json["task"].isNull()) { + ev.download_task_ = common::GetDownloadTaskFromJson(item_json["task"]); + } + return ev; +} } // namespace cortex::event constexpr std::size_t eventMaxSize = diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index f812e896d..602c81ab6 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -65,6 +65,45 @@ void Models::PullModel(const HttpRequestPtr& req, } } +void Models::GetModelPullInfo( + const HttpRequestPtr& req, + std::function&& callback) const { + if (!http_util::HasFieldInReq(req, callback, "model")) { + return; + } + + auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto res = model_service_->GetModelPullInfo(model_handle); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + Json::Value downloaded(Json::arrayValue); + for (auto const& s : info.downloaded_models) { + downloaded.append(s); + } + Json::Value avails(Json::arrayValue); + for (auto const& s : info.available_models) { + avails.append(s); + } + ret["id"] = info.id; + ret["modelSource"] = info.model_source; + ret["defaultBranch"] = info.default_branch; + ret["message"] = "Get model pull information successfully"; + ret["downloadedModels"] = downloaded; + ret["availableModels"] = avails; + ret["downloadUrl"] = info.download_url; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + void Models::AbortPullModel( const HttpRequestPtr& req, std::function&& callback) { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index cacec2e48..b48a0d1aa 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -11,6 +11,7 @@ class Models : public drogon::HttpController { public: METHOD_LIST_BEGIN METHOD_ADD(Models::PullModel, "/pull", Post); + METHOD_ADD(Models::GetModelPullInfo, "/pull/info", Post); METHOD_ADD(Models::AbortPullModel, "/pull", Delete); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); @@ -39,6 +40,8 @@ class Models : public drogon::HttpController { void PullModel(const HttpRequestPtr& req, std::function&& callback); + void GetModelPullInfo(const HttpRequestPtr& req, + std::function&& callback) const; void AbortPullModel(const HttpRequestPtr& req, std::function&& callback); void ListModel(const HttpRequestPtr& req, diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py index b4c27f3ef..572e62ed9 100644 --- a/engine/e2e-test/test_cli_engine_install.py +++ b/engine/e2e-test/test_cli_engine_install.py @@ -1,17 +1,29 @@ import platform import tempfile - +import os +from pathlib import Path import pytest from test_runner import run class TestCliEngineInstall: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() def test_engines_install_llamacpp_should_be_successfully(self): exit_code, output, error = run( - "Install Engine", ["engines", "install", "llama-cpp"], timeout=None + "Install Engine", ["engines", "install", "llama-cpp"], timeout=None, capture = False ) - assert "Start downloading" in output, "Should display downloading message" + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp" / "version.txt") assert exit_code == 0, f"Install engine failed with error: {error}" @pytest.mark.skipif(platform.system() != "Darwin", reason="macOS-specific test") @@ -32,9 +44,10 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self): def test_engines_install_pre_release_llamacpp(self): exit_code, output, error = run( - "Install Engine", ["engines", "install", "llama-cpp", "-v", "v0.1.29"], timeout=600 + "Install Engine", ["engines", "install", "llama-cpp", "-v", "v0.1.29"], timeout=None, capture = False ) - assert "Start downloading" in output, "Should display downloading message" + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp" / "version.txt") assert exit_code == 0, f"Install engine failed with error: {error}" def test_engines_should_fallback_to_download_llamacpp_engine_if_not_exists(self): diff --git a/engine/e2e-test/test_cli_engine_uninstall.py b/engine/e2e-test/test_cli_engine_uninstall.py index 5190cee7a..23b621b0e 100644 --- a/engine/e2e-test/test_cli_engine_uninstall.py +++ b/engine/e2e-test/test_cli_engine_uninstall.py @@ -12,7 +12,7 @@ def setup_and_teardown(self): raise Exception("Failed to start server") # Preinstall llamacpp engine - run("Install Engine", ["engines", "install", "llama-cpp"],timeout = None) + run("Install Engine", ["engines", "install", "llama-cpp"],timeout = None, capture = False) yield diff --git a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py index 619833e16..8c3de8d98 100644 --- a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py +++ b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py @@ -1,8 +1,21 @@ from test_runner import popen +import os +from pathlib import Path class TestCliModelPullCortexsoWithSelection: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_pull_model_from_cortexso_should_display_list_and_allow_user_to_choose( self, ): diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py index 4907ced1f..b10d1593d 100644 --- a/engine/e2e-test/test_cli_model_pull_direct_url.py +++ b/engine/e2e-test/test_cli_model_pull_direct_url.py @@ -1,8 +1,20 @@ from test_runner import run - +import os +from pathlib import Path class TestCliModelPullDirectUrl: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( "Pull model", @@ -10,8 +22,10 @@ def test_model_pull_with_direct_url_should_be_success(self): "pull", "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", ], - timeout=None, + timeout=None, capture=False ) + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "models" / "huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf") assert exit_code == 0, f"Model pull failed with error: {error}" # TODO: verify that the model has been pull successfully # TODO: skip this test. since download model is taking too long diff --git a/engine/e2e-test/test_cli_model_pull_from_cortexso.py b/engine/e2e-test/test_cli_model_pull_from_cortexso.py index c9c3f4c40..1791e39a6 100644 --- a/engine/e2e-test/test_cli_model_pull_from_cortexso.py +++ b/engine/e2e-test/test_cli_model_pull_from_cortexso.py @@ -4,6 +4,17 @@ class TestCliModelPullCortexso: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( "Pull model", diff --git a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py index 50b7e832b..996ac086c 100644 --- a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py +++ b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py @@ -4,6 +4,17 @@ class TestCliModelPullHuggingFaceRepository: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_hugging_face_repository(self): """ Test pull model pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF from issue #1017 diff --git a/engine/e2e-test/test_create_log_folder.py b/engine/e2e-test/test_create_log_folder.py index 8b667141b..5dbbd521c 100644 --- a/engine/e2e-test/test_create_log_folder.py +++ b/engine/e2e-test/test_create_log_folder.py @@ -10,6 +10,7 @@ class TestCreateLogFolder: @pytest.fixture(autouse=True) def setup_and_teardown(self): # Setup + stop_server() root = Path.home() if os.path.exists(root / "cortexcpp" / "logs"): shutil.rmtree(root / "cortexcpp" / "logs") diff --git a/engine/e2e-test/test_runner.py b/engine/e2e-test/test_runner.py index 320b8e332..20a8490a4 100644 --- a/engine/e2e-test/test_runner.py +++ b/engine/e2e-test/test_runner.py @@ -24,14 +24,14 @@ def getExecutablePath() -> str: # Execute a command -def run(test_name: str, arguments: List[str], timeout=timeout) -> (int, str, str): +def run(test_name: str, arguments: List[str], timeout=timeout, capture = True) -> (int, str, str): executable_path = getExecutablePath() print("Running:", test_name) print("Command:", [executable_path] + arguments) result = subprocess.run( [executable_path] + arguments, - capture_output=True, + capture_output=capture, text=True, timeout=timeout, ) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4dfe7fefb..9d2ef42c0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -44,6 +44,17 @@ std::string NormalizeEngine(const std::string& engine) { } return engine; }; + +std::string Repo2Engine(const std::string& r) { + if (r == kLlamaRepo) { + return kLlamaEngine; + } else if (r == kOnnxRepo) { + return kOnnxEngine; + } else if (r == kTrtLlmRepo) { + return kTrtLlmEngine; + } + return r; +}; } // namespace cpp::result EngineService::GetEngineInfo( @@ -121,36 +132,23 @@ std::vector EngineService::GetEngineInfoList() const { return engines; } -cpp::result EngineService::InstallEngine( - const std::string& engine, const std::string& version, - const std::string& src) { - auto ne = NormalizeEngine(engine); - if (!src.empty()) { - return UnzipEngine(ne, version, src); - } else { - auto result = DownloadEngine(ne, version); - if (result.has_error()) { - return result; - } - return DownloadCuda(ne); - } -} - cpp::result EngineService::InstallEngineAsync( const std::string& engine, const std::string& version, const std::string& src) { // Although this function is called async, only download tasks are performed async - // TODO(sang) better handler for unzip and download scenarios auto ne = NormalizeEngine(engine); if (!src.empty()) { - return UnzipEngine(ne, version, src); - } else { - auto result = DownloadEngine(ne, version, true /*async*/); - if (result.has_error()) { - return result; + auto res = UnzipEngine(ne, version, src); + // If has error or engine is installed successfully + if (res.has_error() || res.value()) { + return res; } - return DownloadCuda(ne, true /*async*/); } + auto result = DownloadEngine(ne, version, true /*async*/); + if (result.has_error()) { + return result; + } + return DownloadCuda(ne, true /*async*/); } cpp::result EngineService::UnzipEngine( @@ -198,23 +196,21 @@ cpp::result EngineService::UnzipEngine( auto matched_variant = GetMatchedVariant(engine, variants); CTL_INF("Matched variant: " << matched_variant); + if (!found_cuda || matched_variant.empty()) { + return false; + } + if (matched_variant.empty()) { CTL_INF("No variant found for " << hw_inf_.sys_inf->os << "-" << hw_inf_.sys_inf->arch << ", will get engine from remote"); // Go with the remote flow - return DownloadEngine(engine, version); } else { auto engine_path = file_manager_utils::GetEnginesContainerPath(); archive_utils::ExtractArchive(path + "/" + matched_variant, engine_path.string()); } - // Not match any cuda binary, download from remote - if (!found_cuda) { - return DownloadCuda(engine); - } - return true; } @@ -329,10 +325,10 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Engine folder path: " << engine_folder_path.string() << "\n"); auto local_path = engine_folder_path / file_name; - auto downloadTask{DownloadTask{.id = engine, + auto downloadTask{DownloadTask{.id = Repo2Engine(engine), .type = DownloadType::Engine, .items = {DownloadItem{ - .id = engine, + .id = Repo2Engine(engine), .downloadUrl = download_url, .localPath = local_path, }}}}; diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 521771325..0f491edc7 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -43,25 +43,23 @@ class EngineService { std::vector GetEngineInfoList() const; - cpp::result InstallEngine( - const std::string& engine, const std::string& version = "latest", - const std::string& src = ""); - cpp::result InstallEngineAsync( const std::string& engine, const std::string& version = "latest", const std::string& src = ""); cpp::result UninstallEngine(const std::string& engine); - private: cpp::result UnzipEngine(const std::string& engine, const std::string& version, const std::string& path); + private: cpp::result DownloadEngine( - const std::string& engine, const std::string& version = "latest", bool async = false); + const std::string& engine, const std::string& version = "latest", + bool async = false); - cpp::result DownloadCuda(const std::string& engine, bool async = false); + cpp::result DownloadCuda(const std::string& engine, + bool async = false); std::string GetMatchedVariant(const std::string& engine, const std::vector& variants); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7966fd890..b49df3420 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -236,7 +236,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( auto file_name{url_obj.pathParams.back()}; if (author == "cortexso") { - return DownloadModelFromCortexsoAsync(model_id); + return DownloadModelFromCortexsoAsync(model_id, url_obj.pathParams[3]); } if (url_obj.pathParams.size() < 5) { @@ -279,7 +279,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( .localPath = local_path, }}}}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [author](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; ParseGguf(gguf_download_item, author); }; @@ -344,7 +344,7 @@ cpp::result ModelService::HandleUrl( .localPath = local_path, }}}}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [author](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; ParseGguf(gguf_download_item, author); }; @@ -381,8 +381,9 @@ ModelService::DownloadModelFromCortexsoAsync( if (model_entry.has_value()) { return cpp::fail("Please delete the model before downloading again"); } - auto on_finished = [branch, - unique_model_id](const DownloadTask& finishedTask) { + + auto on_finished = [unique_model_id, + branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -436,7 +437,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [&, model_id](const DownloadTask& finishedTask) { + auto on_finished = [branch, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -628,7 +629,7 @@ cpp::result ModelService::StartModel( ASSIGN_IF_PRESENT(json_data, params_override, cache_type); ASSIGN_IF_PRESENT(json_data, params_override, mmproj); ASSIGN_IF_PRESENT(json_data, params_override, model_path); -#undef ASSIGN_IF_PRESENT; +#undef ASSIGN_IF_PRESENT CTL_INF(json_data.toStyledString()); assert(!!inference_svc_); @@ -750,6 +751,127 @@ cpp::result ModelService::GetModelStatus( } } +cpp::result ModelService::GetModelPullInfo( + const std::string& input) { + if (input.empty()) { + return cpp::fail( + "Input must be Cortex Model Hub handle or HuggingFace url!"); + } + auto model_name = input; + + if (string_utils::StartsWith(input, "https://")) { + auto url_obj = url_parser::FromUrlString(input); + + if (url_obj.host == kHuggingFaceHost) { + if (url_obj.pathParams[2] == "blob") { + url_obj.pathParams[2] = "resolve"; + } + } + auto author{url_obj.pathParams[0]}; + auto model_id{url_obj.pathParams[1]}; + auto file_name{url_obj.pathParams.back()}; + if (author == "cortexso") { + return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3], + .downloaded_models = {}, + .available_models = {}, + .download_url = url_parser::FromUrl(url_obj)}; + } + return ModelPullInfo{.id = author + ":" + model_id + ":" + file_name, + .downloaded_models = {}, + .available_models = {}, + .download_url = url_parser::FromUrl(url_obj)}; + } + + if (input.find(":") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, ":"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + return ModelPullInfo{ + .id = input, .downloaded_models = {}, .available_models = {}, .download_url = input}; + } + + if (input.find("/") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, "/"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + + auto author = parsed[0]; + model_name = parsed[1]; + if (author != "cortexso") { + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); + + if (!repo_info.has_value()) { + return cpp::fail("Model not found"); + } + + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } + + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } + } + + return ModelPullInfo{ + .id = author + ":" + model_name, + .downloaded_models = {}, + .available_models = options, + .download_url = + huggingface_utils::GetDownloadableUrl(author, model_name, "")}; + } + } + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", model_name); + if (branches.has_error()) { + return cpp::fail(branches.error()); + } + + auto default_model_branch = huggingface_utils::GetDefaultBranch(model_name); + + cortex::db::Models modellist_handler; + auto downloaded_model_ids = modellist_handler.FindRelatedModel(model_name) + .value_or(std::vector{}); + + std::vector avai_download_opts{}; + for (const auto& branch : branches.value()) { + if (branch.second.name == "main") { // main branch only have metadata. skip + continue; + } + auto model_id = model_name + ":" + branch.second.name; + if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), + model_id) != + downloaded_model_ids.end()) { // if downloaded, we skip it + continue; + } + avai_download_opts.emplace_back(model_id); + } + + if (avai_download_opts.empty()) { + // TODO: only with pull, we return + return cpp::fail("No variant available"); + } + std::optional normalized_def_branch = std::nullopt; + if (default_model_branch.has_value()) { + normalized_def_branch = model_name + ":" + default_model_branch.value(); + } + string_utils::SortStrings(downloaded_model_ids); + string_utils::SortStrings(avai_download_opts); + + return ModelPullInfo{.id = model_name, + .default_branch = normalized_def_branch.value_or(""), + .downloaded_models = downloaded_model_ids, + .available_models = avai_download_opts, + .model_source = "cortexso"}; +} + cpp::result ModelService::AbortDownloadModel( const std::string& task_id) { return download_service_->StopTask(task_id); diff --git a/engine/services/model_service.h b/engine/services/model_service.h index cdae6c6f1..495685982 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -7,6 +7,15 @@ #include "services/download_service.h" #include "services/inference_service.h" +struct ModelPullInfo { + std::string id; + std::string default_branch; + std::vector downloaded_models; + std::vector available_models; + std::string model_source; + std::string download_url; +}; + struct StartParameterOverride { std::optional cache_enabled; std::optional ngl; @@ -18,6 +27,7 @@ struct StartParameterOverride { std::optional model_path; bool bypass_model_check() const { return mmproj.has_value(); } }; + class ModelService { public: constexpr auto static kHuggingFaceHost = "huggingface.co"; @@ -65,6 +75,9 @@ class ModelService { cpp::result GetModelStatus( const std::string& host, int port, const std::string& model_handle); + cpp::result GetModelPullInfo( + const std::string& model_handle); + cpp::result HandleUrl(const std::string& url); cpp::result HandleDownloadUrlAsync( diff --git a/engine/test/components/test_event.cc b/engine/test/components/test_event.cc new file mode 100644 index 000000000..d10933f52 --- /dev/null +++ b/engine/test/components/test_event.cc @@ -0,0 +1,50 @@ + +#include "common/event.h" +#include "gtest/gtest.h" +#include "utils/json_helper.h" + +class EventTest : public ::testing::Test {}; + +TEST_F(EventTest, EventFromString) { + // clang-format off + std::string ev_str = R"({ + "task": { + "id": "tinyllama:gguf", + "items": [ + { + "bytes": 668788096, + "checksum": "N/A", + "downloadUrl": "https://huggingface.co/cortexso/tinyllama/resolve/gguf/model.gguf", + "downloadedBytes": 0, + "id": "model.gguf", + "localPath": + "/home/jan/cortexcpp/models/cortex.so/tinyllama/gguf/model.gguf" + }, + { + "bytes": 545, + "checksum": "N/A", + "downloadUrl": "https://huggingface.co/cortexso/tinyllama/resolve/gguf/model.yml", + "downloadedBytes": 0, + "id": "model.yml", + "localPath": + "/home/jan/cortexcpp/models/cortex.so/tinyllama/gguf/model.yml" + } + ], + "type": "Model" + }, + "type": "DownloadStarted" + })"; + // clang-format on + auto root = json_helper::ParseJsonString(ev_str); + std::cout << root.toStyledString() << std::endl; + + auto download_item = common::GetDownloadItemFromJson(root["task"]["items"][0]); + EXPECT_EQ(download_item.downloadUrl, root["task"]["items"][0]["downloadUrl"].asString()); + std::cout << download_item.ToString() << std::endl; + + auto download_task = common::GetDownloadTaskFromJson(root["task"]); + std::cout << download_task.ToString() << std::endl; + + auto ev = cortex::event::GetDownloadEventFromJson(root); + EXPECT_EQ(ev.type_, cortex::event::DownloadEventType::DownloadStarted); +} \ No newline at end of file diff --git a/engine/vcpkg.json b/engine/vcpkg.json index cfab8d2f5..1f8d31bcc 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -16,6 +16,7 @@ "tabulate", "eventpp", "sqlitecpp", - "trantor" + "trantor", + "indicators" ] }