-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: cortex pull and cortex engines install CLI uses API server (#1550)
* fix: add ws and indicators * fix: more * fix: pull models info from server * fix: model_source * fix: rename * fix: download cortexso * fix: pull models * fix: remove comments * fix: rename * fix: change download UI * fix: comment out * fix: e2e tests * fix: run, start * fix: start server * fix: e2e * fix: remove * fix: abort model * fix: build * fix: clean code * fix: clean more * fix: normalize engine id * fix: use auto * fix: use vcpkg for indicators * fix: download progress --------- Co-authored-by: vansangpfiev <[email protected]>
- Loading branch information
1 parent
04c5c40
commit 00af979
Showing
32 changed files
with
1,526 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,68 @@ | ||
#include "engine_install_cmd.h" | ||
#include "server_start_cmd.h" | ||
#include "utils/download_progress.h" | ||
#include "utils/engine_constants.h" | ||
#include "utils/json_helper.h" | ||
#include "utils/logging_utils.h" | ||
|
||
namespace commands { | ||
|
||
void EngineInstallCmd::Exec(const std::string& engine, | ||
bool EngineInstallCmd::Exec(const std::string& engine, | ||
const std::string& version, | ||
const std::string& src) { | ||
auto result = engine_service_.InstallEngine(engine, version, src); | ||
if (result.has_error()) { | ||
CLI_LOG(result.error()); | ||
} else if(result && result.value()){ | ||
CLI_LOG("Engine " << engine << " installed successfully!"); | ||
// Handle local install, if fails, fallback to remote install | ||
if (!src.empty()) { | ||
auto res = engine_service_.UnzipEngine(engine, version, src); | ||
if (res.has_error()) { | ||
CLI_LOG(res.error()); | ||
return false; | ||
} | ||
if (res.value()) { | ||
CLI_LOG("Engine " << engine << " installed successfully!"); | ||
return true; | ||
} | ||
} | ||
|
||
// Start server if server is not started yet | ||
if (!commands::IsServerAlive(host_, port_)) { | ||
CLI_LOG("Starting server ..."); | ||
commands::ServerStartCmd ssc; | ||
if (!ssc.Exec(host_, port_)) { | ||
return false; | ||
} | ||
} | ||
|
||
httplib::Client cli(host_ + ":" + std::to_string(port_)); | ||
Json::Value json_data; | ||
auto data_str = json_data.toStyledString(); | ||
cli.set_read_timeout(std::chrono::seconds(60)); | ||
auto res = cli.Post("/v1/engines/install/" + engine, httplib::Headers(), | ||
data_str.data(), data_str.size(), "application/json"); | ||
|
||
if (res) { | ||
if (res->status != httplib::StatusCode::OK_200) { | ||
auto root = json_helper::ParseJsonString(res->body); | ||
CLI_LOG(root["message"].asString()); | ||
return false; | ||
} | ||
} else { | ||
auto err = res.error(); | ||
CTL_ERR("HTTP error: " << httplib::to_string(err)); | ||
return false; | ||
} | ||
|
||
CLI_LOG("Start downloading ...") | ||
DownloadProgress dp; | ||
dp.Connect(host_, port_); | ||
if (!dp.Handle(engine)) | ||
return false; | ||
|
||
bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); | ||
if (check_cuda_download) { | ||
if (!dp.Handle("cuda")) | ||
return false; | ||
} | ||
|
||
CLI_LOG("Engine " << engine << " downloaded successfully!") | ||
return true; | ||
} | ||
}; // namespace commands |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,181 @@ | ||
#include "model_pull_cmd.h" | ||
#include <memory> | ||
#include "common/event.h" | ||
#include "database/models.h" | ||
#include "server_start_cmd.h" | ||
#include "utils/cli_selection_utils.h" | ||
#include "utils/download_progress.h" | ||
#include "utils/format_utils.h" | ||
#include "utils/huggingface_utils.h" | ||
#include "utils/json_helper.h" | ||
#include "utils/logging_utils.h" | ||
#include "utils/scope_exit.h" | ||
#include "utils/string_utils.h" | ||
#if defined(_WIN32) | ||
#include <signal.h> | ||
#endif | ||
|
||
namespace commands { | ||
void ModelPullCmd::Exec(const std::string& input) { | ||
auto result = model_service_.DownloadModel(input); | ||
if (result.has_error()) { | ||
CLI_LOG(result.error()); | ||
std::function<void(int)> shutdown_handler; | ||
inline void signal_handler(int signal) { | ||
if (shutdown_handler) { | ||
shutdown_handler(signal); | ||
} | ||
} | ||
std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port, | ||
const std::string& input) { | ||
|
||
// model_id: use to check the download progress | ||
// model: use as a parameter for pull API | ||
auto model_id = input; | ||
auto model = input; | ||
|
||
// Start server if server is not started yet | ||
if (!commands::IsServerAlive(host, port)) { | ||
CLI_LOG("Starting server ..."); | ||
commands::ServerStartCmd ssc; | ||
if (!ssc.Exec(host, port)) { | ||
return std::nullopt; | ||
} | ||
} | ||
|
||
// Get model info from Server | ||
httplib::Client cli(host + ":" + std::to_string(port)); | ||
cli.set_read_timeout(std::chrono::seconds(60)); | ||
Json::Value j_data; | ||
j_data["model"] = input; | ||
auto d_str = j_data.toStyledString(); | ||
auto res = cli.Post("/models/pull/info", httplib::Headers(), d_str.data(), | ||
d_str.size(), "application/json"); | ||
|
||
if (res) { | ||
if (res->status == httplib::StatusCode::OK_200) { | ||
// CLI_LOG(res->body); | ||
auto root = json_helper::ParseJsonString(res->body); | ||
auto id = root["id"].asString(); | ||
bool is_cortexso = root["modelSource"].asString() == "cortexso"; | ||
auto default_branch = root["defaultBranch"].asString(); | ||
std::vector<std::string> downloaded; | ||
for (auto const& v : root["downloadedModels"]) { | ||
downloaded.push_back(v.asString()); | ||
} | ||
std::vector<std::string> avails; | ||
for (auto const& v : root["availableModels"]) { | ||
avails.push_back(v.asString()); | ||
} | ||
auto download_url = root["downloadUrl"].asString(); | ||
|
||
if (downloaded.empty() && avails.empty()) { | ||
model_id = id; | ||
model = download_url; | ||
} else { | ||
if (is_cortexso) { | ||
auto selection = cli_selection_utils::PrintModelSelection( | ||
downloaded, avails, | ||
default_branch.empty() | ||
? std::nullopt | ||
: std::optional<std::string>(default_branch)); | ||
|
||
if (!selection.has_value()) { | ||
CLI_LOG("Invalid selection"); | ||
return std::nullopt; | ||
} | ||
model_id = selection.value(); | ||
model = model_id; | ||
} else { | ||
auto selection = cli_selection_utils::PrintSelection(avails); | ||
CLI_LOG("Selected: " << selection.value()); | ||
model_id = id + ":" + selection.value(); | ||
model = download_url + selection.value(); | ||
} | ||
} | ||
} else { | ||
auto root = json_helper::ParseJsonString(res->body); | ||
CLI_LOG(root["message"].asString()); | ||
return std::nullopt; | ||
} | ||
} else { | ||
auto err = res.error(); | ||
CTL_ERR("HTTP error: " << httplib::to_string(err)); | ||
return std::nullopt; | ||
} | ||
|
||
// Send request download model to server | ||
Json::Value json_data; | ||
json_data["model"] = model; | ||
auto data_str = json_data.toStyledString(); | ||
cli.set_read_timeout(std::chrono::seconds(60)); | ||
res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), | ||
data_str.size(), "application/json"); | ||
|
||
if (res) { | ||
if (res->status != httplib::StatusCode::OK_200) { | ||
auto root = json_helper::ParseJsonString(res->body); | ||
CLI_LOG(root["message"].asString()); | ||
return std::nullopt; | ||
} | ||
} else { | ||
auto err = res.error(); | ||
CTL_ERR("HTTP error: " << httplib::to_string(err)); | ||
return std::nullopt; | ||
} | ||
|
||
CLI_LOG("Start downloading ...") | ||
DownloadProgress dp; | ||
bool force_stop = false; | ||
|
||
shutdown_handler = [this, &dp, &host, &port, &model_id, &force_stop](int) { | ||
force_stop = true; | ||
AbortModelPull(host, port, model_id); | ||
dp.ForceStop(); | ||
}; | ||
|
||
utils::ScopeExit se([]() { shutdown_handler = {}; }); | ||
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) | ||
struct sigaction sigint_action; | ||
sigint_action.sa_handler = signal_handler; | ||
sigemptyset(&sigint_action.sa_mask); | ||
sigint_action.sa_flags = 0; | ||
sigaction(SIGINT, &sigint_action, NULL); | ||
sigaction(SIGTERM, &sigint_action, NULL); | ||
#elif defined(_WIN32) | ||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { | ||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; | ||
}; | ||
SetConsoleCtrlHandler( | ||
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); | ||
#endif | ||
dp.Connect(host, port); | ||
if (!dp.Handle(model_id)) | ||
return std::nullopt; | ||
if (force_stop) | ||
return std::nullopt; | ||
CLI_LOG("Model " << model_id << " downloaded successfully!") | ||
return model_id; | ||
} | ||
|
||
bool ModelPullCmd::AbortModelPull(const std::string& host, int port, | ||
const std::string& task_id) { | ||
Json::Value json_data; | ||
json_data["taskId"] = task_id; | ||
auto data_str = json_data.toStyledString(); | ||
httplib::Client cli(host + ":" + std::to_string(port)); | ||
cli.set_read_timeout(std::chrono::seconds(60)); | ||
auto res = cli.Delete("/v1/models/pull", httplib::Headers(), data_str.data(), | ||
data_str.size(), "application/json"); | ||
if (res) { | ||
if (res->status == httplib::StatusCode::OK_200) { | ||
CTL_INF("Abort model pull successfully: " << task_id); | ||
return true; | ||
} else { | ||
auto root = json_helper::ParseJsonString(res->body); | ||
CLI_LOG(root["message"].asString()); | ||
return false; | ||
} | ||
} else { | ||
auto err = res.error(); | ||
CTL_ERR("HTTP error: " << httplib::to_string(err)); | ||
return false; | ||
} | ||
} | ||
}; // namespace commands |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.