Skip to content

Commit

Permalink
Add training service
Browse files Browse the repository at this point in the history
- Supported operations: start, resume, pause, shutdown
- pytorch-3dunet package is used as the framework to create the models
  • Loading branch information
thodkatz committed Dec 12, 2024
1 parent 5ea5d3a commit 486ad1c
Show file tree
Hide file tree
Showing 17 changed files with 1,866 additions and 294 deletions.
106 changes: 106 additions & 0 deletions proto/training.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
syntax = "proto3";

package training;


// Message with no fields. In this file it is the response type of every RPC
// that returns no payload (Start, Resume, Pause, Save, Export,
// CloseTrainerSession).
message Empty {}


// RPC surface for driving a training session.
//
// Init is the only RPC that does not take a session id; every other RPC is
// addressed by a TrainingSessionId, either directly as the request or, for
// Predict, embedded in PredictRequest.
service Training {
  // Takes a TrainingConfig and returns a TrainingSessionId, the message type
  // the remaining RPCs use to address a session.
  rpc Init(TrainingConfig) returns (TrainingSessionId);

  // Lifecycle controls for the session named by the request. Each returns
  // Empty, so success/failure is conveyed only through the gRPC status.
  rpc Start(TrainingSessionId) returns (Empty);

  rpc Resume(TrainingSessionId) returns (Empty);

  rpc Pause(TrainingSessionId) returns (Empty);

  // Server-streaming RPC: one request, a stream of StreamUpdateResponse
  // messages in reply.
  rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse);

  // Unary counterpart of StreamUpdates: returns the Logs entries in a single
  // GetLogsResponse.
  rpc GetLogs(TrainingSessionId) returns (GetLogsResponse);

  // Save and Export carry only the session id in the request, so the
  // destination of whatever is written is not expressed in this schema.
  rpc Save(TrainingSessionId) returns (Empty);

  rpc Export(TrainingSessionId) returns (Empty);

  // Sends tensors for the session named inside PredictRequest and returns
  // tensors plus a best_model_idx in PredictResponse.
  rpc Predict(PredictRequest) returns (PredictResponse);

  // Returns the session's GetStatusResponse.State.
  rpc GetStatus(TrainingSessionId) returns (GetStatusResponse);

  // NOTE(review): presumably releases the session identified by the request;
  // confirm against the server implementation.
  rpc CloseTrainerSession(TrainingSessionId) returns (Empty);
}

// Handle for a training session. Returned by Training.Init and accepted by
// the other Training RPCs.
message TrainingSessionId {
  // Session identifier as a string; its format is not constrained by this
  // schema.
  string id = 1;
}

// A single log record, carried individually in StreamUpdateResponse and as a
// list in GetLogsResponse.
message Logs {
  // Phase a record belongs to.
  //
  // Train is the zero value, so a record whose mode was never set is
  // indistinguishable on the wire from one explicitly marked Train.
  enum ModelPhase {
    Train = 0;
    Eval = 1;
  }

  // Phase this record was produced in.
  ModelPhase mode = 1;

  // Evaluation score for this record.
  // NOTE(review): presumably only meaningful when mode == Eval; confirm.
  double eval_score = 2;

  // Loss value for this record.
  double loss = 3;

  // Iteration counter associated with this record.
  uint32 iteration = 4;
}


// One element of the stream returned by Training.StreamUpdates.
message StreamUpdateResponse {
  // Index of the best model. Being an implicit-presence uint32, an unset
  // value reads as 0.
  uint32 best_model_idx = 1;

  // The log record delivered with this update (a single Logs message, not a
  // list).
  Logs logs = 2;
}


// Response of Training.GetLogs.
message GetLogsResponse {
  // Log records for the requested session; empty when there are none.
  repeated Logs logs = 1;
}


// An integer paired with a name. Used by Tensor.shape, where each entry
// describes one dimension.
message NamedInt {
  // The integer value (the extent of the dimension when used in a shape).
  uint32 size = 1;

  // Name attached to the value (the dimension's name when used in a shape).
  string name = 2;
}


// A tensor serialized as raw bytes plus the metadata needed to interpret
// them.
message Tensor {
  // Field number 3 is not used by this message. It is reserved so the gap is
  // explicit and the number cannot be picked up accidentally by a new field.
  // NOTE(review): confirm whether 3 was skipped deliberately (e.g. to stay
  // aligned with another Tensor definition) before assigning it.
  reserved 3;

  // Raw tensor contents. Byte order and memory layout are not specified by
  // this schema.
  bytes buffer = 1;

  // Element type of `buffer`, as a string.
  // NOTE(review): the accepted dtype spellings are not defined here; confirm
  // against the producer/consumer code.
  string dtype = 2;

  // Dimensions of the tensor, one NamedInt (name + size) per axis.
  repeated NamedInt shape = 4;
}


// Request of Training.Predict.
message PredictRequest {
  // Input tensors.
  repeated Tensor tensors = 1;

  // Session to run against. Note that this is a TrainingSessionId message,
  // not a bare string.
  TrainingSessionId id = 2;
}


// Response of Training.Predict.
message PredictResponse {
  // Index of the best model. Being an implicit-presence uint32, an unset
  // value reads as 0.
  uint32 best_model_idx = 1;

  // Output tensors.
  repeated Tensor tensors = 2;
}

// Carries an averaged validation score.
//
// Not referenced by any RPC of the Training service as declared in this
// file.
message ValidationResponse {
  // Average validation score.
  double validation_score_average = 1;
}

// Response of Training.GetStatus.
message GetStatusResponse {
  // Possible states of a session.
  //
  // Idle is the zero value, so a response whose state was never set reads as
  // Idle.
  enum State {
    Idle = 0;
    Running = 1;
    Paused = 2;
    Failed = 3;
    Finished = 4;
  }

  // Current state of the requested session.
  State state = 1;
}


// Carries the index of the current best model.
//
// Not referenced by any RPC of the Training service as declared in this
// file.
message GetCurrentBestModelIdxResponse {
  // Best-model index. Note the field is named `id`, whereas
  // StreamUpdateResponse and PredictResponse call the analogous field
  // `best_model_idx`.
  uint32 id = 1;
}

// Request of Training.Init.
message TrainingConfig {
  // Configuration passed as a single string; the field name indicates YAML
  // text. The expected keys/structure of that YAML are not defined in this
  // schema.
  string yaml_content = 1;
}
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ max-line-length = 120

[flake8]
max-line-length = 120
-ignore=E203
+ignore=E203,W503
exclude = tiktorch/proto/*,vendor
Loading

0 comments on commit 486ad1c

Please sign in to comment.