MLFlowLogger implementation: checkpoints downloaded during mlflow.pytorch.load_model. #19394
-
Seems to have been us using the Bit of a gotcha I was not foreseeing without looking deep into the code. Is there a reason for choosing |
-
It would be nice if the release notes for v2.5 would mention that this behavior was changed now and there is still no configurable prefix to the checkpoint, but instead it now no longer adds |
-
TL;DR: It appears to me that Lightning's MLflow logger implementation puts checkpoints into MLflow's model directory, which might lead to excessive downloads during model loading.
Details
It is my understanding that checkpoints are generally much larger than the models themselves, probably due to extra information such as optimizer state.
Lightning's MLflow logger implementation
Furthermore, it seems that the default behavior of Lightning's MLFlowLogger is to upload checkpoints at the end of the run according to the ModelCheckpoint config. The checkpoint path seems hard-coded to the checkpoints subdirectory of the MLflow model.
pytorch-lightning/src/lightning/pytorch/loggers/mlflow.py
Line 356 in 8646515
MLflow's model loading behavior
The MLflow model loading code seems to download the full model directory when loading from a remote target.
To me this seems to mean that whenever one attempts to load a model logged with Lightning's MLFlowLogger via
mlflow.pytorch.load_model(f"runs:/{run_id}/model")
one would always also download all the checkpoints. However, it is my suspicion that the weights needed during unpickling are in data/model.pth.
Moving the checkpoints out of the model folder?
Ultimately my question is: is there a reason for putting the checkpoints into a subdirectory of the model directory, or would it be better to put them under a top-level checkpoints directory, which to my understanding should avoid downloading all the checkpoints during model load?
Am I missing something here?
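To make the concern concrete, here is a minimal sketch that compares the two layouts. It does not call MLflow or Lightning at all; it only assumes, as suspected above, that loading from runs:/{run_id}/model fetches every artifact under model/. The artifact names, the file sizes, and the helper functions move_checkpoints_to_top_level and bytes_under are all hypothetical and exist only for illustration.

```python
from pathlib import PurePosixPath

# Hypothetical artifact listing for one run: artifact path -> size in bytes.
# Paths and sizes are invented; real runs will differ.
ARTIFACTS_NESTED = {
    "model/MLmodel": 1_000,
    "model/data/model.pth": 50_000_000,
    "model/checkpoints/epoch=0-step=100/epoch=0-step=100.ckpt": 150_000_000,
    "model/checkpoints/epoch=1-step=200/epoch=1-step=200.ckpt": 150_000_000,
}


def move_checkpoints_to_top_level(artifacts):
    """Return a copy of the listing with model/checkpoints/* moved to checkpoints/*."""
    prefix = "model/checkpoints/"
    moved = {}
    for path, size in artifacts.items():
        if path.startswith(prefix):
            path = "checkpoints/" + path[len(prefix):]
        moved[path] = size
    return moved


def bytes_under(artifacts, directory):
    """Total size of all artifacts inside `directory`.

    This stands in for what a download of the whole directory would transfer,
    under the assumption that the loader fetches the directory recursively.
    """
    root = PurePosixPath(directory)
    return sum(
        size
        for path, size in artifacts.items()
        if root in PurePosixPath(path).parents
    )


if __name__ == "__main__":
    flat = move_checkpoints_to_top_level(ARTIFACTS_NESTED)
    print("nested layout, bytes under model/:", bytes_under(ARTIFACTS_NESTED, "model"))
    print("top-level layout, bytes under model/:", bytes_under(flat, "model"))
```

With these made-up numbers, the nested layout pulls both checkpoints along with data/model.pth, while the top-level layout leaves them out of the model directory entirely. Whether MLflow really behaves this way for a given artifact store is exactly the open question of this post.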