Skip to content

Commit

Permalink
Return an incremental id instead of just pinging
Browse files Browse the repository at this point in the history
Each response on the best model stream now carries an id. The id is
increased by one each time there is a new best model. Based on the id, a
client can tell whether an action was performed with an outdated model: if
the current id is greater than the one it holds, a new best model exists.
  • Loading branch information
thodkatz committed Jan 20, 2025
1 parent 3d93c94 commit 44bc634
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 21 deletions.
2 changes: 1 addition & 1 deletion proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service Training {

rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc IsBestModel(ModelSession) returns (stream Empty) {}
rpc GetBestModelIdx(ModelSession) returns (stream GetBestModelIdxResponse) {}

rpc Save(SaveRequest) returns (Empty) {}

Expand Down
13 changes: 7 additions & 6 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,6 @@ def test_close_trainer_session_twice(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)
assert "Unknown session" in excinfo.value.details()


@pytest.mark.parametrize(
"dims, shape",
[
Expand Down Expand Up @@ -671,17 +670,19 @@ def test_export_while_paused(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)

def test_best_model_ping(self, grpc_stub):
training_session_id = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
)

grpc_stub.Start(training_session_id)

responses = grpc_stub.IsBestModel(training_session_id)
responses = grpc_stub.GetBestModelIdx(training_session_id)
received_updates = 0
for response in responses:
assert isinstance(response, utils_pb2.Empty)
assert isinstance(response, training_pb2.GetBestModelIdxResponse)
assert response.id is not None
received_updates += 1

if received_updates >= 3:
if received_updates >= 2:
break

def test_close_session(self, grpc_stub):
Expand Down
4 changes: 2 additions & 2 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions tiktorch/proto/training_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def __init__(self, channel):
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetLogsResponse.FromString,
)
self.IsBestModel = channel.unary_stream(
'/training.Training/IsBestModel',
self.GetBestModelIdx = channel.unary_stream(
'/training.Training/GetBestModelIdx',
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
response_deserializer=training__pb2.GetBestModelIdxResponse.FromString,
)
self.Save = channel.unary_unary(
'/training.Training/Save',
Expand Down Expand Up @@ -127,7 +127,7 @@ def GetLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def IsBestModel(self, request, context):
def GetBestModelIdx(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand Down Expand Up @@ -201,10 +201,10 @@ def add_TrainingServicer_to_server(servicer, server):
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetLogsResponse.SerializeToString,
),
'IsBestModel': grpc.unary_stream_rpc_method_handler(
servicer.IsBestModel,
'GetBestModelIdx': grpc.unary_stream_rpc_method_handler(
servicer.GetBestModelIdx,
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
response_serializer=training__pb2.GetBestModelIdxResponse.SerializeToString,
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
Expand Down Expand Up @@ -361,7 +361,7 @@ def GetLogs(request,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def IsBestModel(request,
def GetBestModelIdx(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -371,9 +371,9 @@ def IsBestModel(request,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/training.Training/IsBestModel',
return grpc.experimental.unary_stream(request, target, '/training.Training/GetBestModelIdx',
utils__pb2.ModelSession.SerializeToString,
utils__pb2.Empty.FromString,
training__pb2.GetBestModelIdxResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

Expand Down
4 changes: 2 additions & 2 deletions tiktorch/server/grpc/training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ def StreamUpdates(self, request: utils_pb2.ModelSession, context):
def GetLogs(self, request: utils_pb2.ModelSession, context):
raise NotImplementedError

def IsBestModel(self, request, context):
def GetBestModelIdx(self, request, context):
session = self._getTrainerSession(context, request)
prev_best_model_idx = None
while context.is_active():
current_best_model_idx = session.client.get_best_model_idx()
if current_best_model_idx != prev_best_model_idx:
prev_best_model_idx = current_best_model_idx
yield utils_pb2.Empty()
yield training_pb2.GetBestModelIdxResponse(id=str(current_best_model_idx))
time.sleep(1)
logger.info("Client disconnected. Stopping stream.")

Expand Down

0 comments on commit 44bc634

Please sign in to comment.