Questions about fine-tuning strategies #55

Closed
khalilv opened this issue Dec 6, 2024 · 4 comments

@khalilv

khalilv commented Dec 6, 2024

Thank you for this great work! I have been working with Aurora to fine-tune it on some variables from ERA5. I had a few questions regarding the fine-tuning strategies you discuss in the paper, and I would greatly appreciate your time in clarifying a few aspects.

  1. In Supplementary D.3 and D.4, you mention two fine-tuning strategies: short lead-time fine-tuning and roll-out fine-tuning. It was unclear to me how both strategies were used to produce the results in Supplementary H. For example, was a single Aurora model fine-tuned first using the strategy in D.3 and then subsequently using the strategy in D.4? Or were two models trained on strategies D.3 and D.4 respectively, with the results in Supplementary H being a mixture of those two models (e.g., days 1–3 from D.3 and days 4–10 from D.4)?

If it was the former and a single model was trained using both strategies, were the optimizer states preserved between these two stages, or were just the weights loaded while the optimizer states were reset (i.e., would there be a 1K warm-up in the learning rate scheduler for both D.3 and D.4)?

  2. In Supplementary D.4, would a training epoch end once all samples from the dataset have been added to the replay buffer? I ask because the model is trained on its own generated data as well as on data from the dataset. With a sampling rate of K, it would take roughly K times as long to see all samples in the dataset. Was this the strategy used during roll-out fine-tuning, or were steps taken to keep the absolute training time roughly constant? For example, was only every Kth sample from the dataset added to the buffer, since the remaining samples would be generated by the model?

  3. During fine-tuning, was a validation set used to monitor the loss and perform early stopping? If so, which metric was monitored, and at which autoregressive step? Currently, I am monitoring the single-step MAE.

  4. In Supplementary D.4, were the samples drawn from the replay buffer with or without replacement?

  5. In Supplementary D.4, how was the size of the replay buffer maintained? As I understand it, a batch of samples is drawn from the buffer at every training step and the corresponding predictions are then added back in, so every time the buffer is also refreshed with a new sample from the dataset, its size would grow. If my understanding is correct, what happens when the buffer reaches its maximum size?

Please let me know if any of these questions are unclear, and again, I appreciate you taking the time to discuss this. Your insights will be very helpful for my work. Thanks!

@wesselb
Contributor

wesselb commented Dec 11, 2024

Hey @khalilv! Thank you for opening an issue. :) Let me attempt to answer your questions in detail.

Both fine-tuning strategies were used: the model was first fine-tuned with the strategy in D.3 and then with the strategy in D.4. The optimiser state was not preserved between D.3 and D.4. We did not use a warm-up for D.4, because it uses LoRA and we did not find that a warm-up made a big difference.
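To make the two-stage schedule concrete, here is a minimal sketch, assuming a linear 1K-step warm-up followed by a constant rate for D.3 (the 1K warm-up is the one mentioned in the question) and no warm-up for D.4. The learning rates and step counts are hypothetical, any decay is omitted, and none of this is taken from the Aurora code base. Because the optimiser is re-created for the second stage, its schedule restarts at step 0; only the model weights carry over.

```python
from typing import List


def learning_rate(step: int, base_lr: float, warmup_steps: int) -> float:
    """Linear warm-up over `warmup_steps`, then a constant learning rate.

    Only the warm-up is modelled; any decay schedule is omitted.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


def stage_schedule(num_steps: int, base_lr: float, warmup_steps: int) -> List[float]:
    # Each stage starts counting from step 0 because a fresh optimiser (and
    # scheduler) is created for it: only the model weights carry over.
    return [learning_rate(s, base_lr, warmup_steps) for s in range(num_steps)]


# Hypothetical values, for illustration only.
short_lead_time = stage_schedule(num_steps=3000, base_lr=5e-4, warmup_steps=1000)  # D.3
roll_out = stage_schedule(num_steps=3000, base_lr=5e-5, warmup_steps=0)  # D.4: no warm-up
```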

For strategy D.4, once all samples of the data have been added, training simply continues: the entire dataset is reshuffled and sampling from it resumes. You could count this as an "epoch", but training is not interrupted at this point and continues as if the dataset were infinitely large. You're right that a sampling rate of K means that it takes much longer to reach the end of the dataset. The time spent on roll-out fine-tuning is kept constant by limiting the number of training steps, which does mean that roll-out fine-tuning might see only a small fraction of the entire dataset.
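A minimal sketch of the "as if the dataset were infinitely large" behaviour, assuming (as described in the question) that one fresh dataset sample enters the replay buffer every K training steps. `endless_shuffled` and `fresh_samples_consumed` are hypothetical helpers for illustration, not part of Aurora.

```python
import random
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def endless_shuffled(dataset: Sequence[T], seed: int = 0) -> Iterator[T]:
    """Yield dataset items forever, reshuffling at the start of every pass."""
    rng = random.Random(seed)
    indices = list(range(len(dataset)))
    while True:
        rng.shuffle(indices)  # new order for each pass; no pause between passes
        for i in indices:
            yield dataset[i]


def fresh_samples_consumed(num_steps: int, k: int) -> int:
    """Count dataset samples drawn when one fresh sample enters every k steps."""
    return sum(1 for step in range(num_steps) if step % k == 0)
```

With a fixed budget of, say, 2,000 steps and K = 10, only 200 fresh samples are drawn, so a dataset much larger than that is only partially visited during roll-out fine-tuning.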

To perform early stopping for roll-out fine-tuning, you can inspect the training losses and check whether they are still decreasing or have levelled off. You can also use a validation loss by doing full roll-outs on the validation data, but this is clearly expensive. Generally, because roll-out fine-tuning uses LoRA instead of full-architecture fine-tuning, the potential to overfit is greatly reduced, and just running it for 1-3 days works well.
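One simple way to turn "check whether the training losses keep decreasing or remain constant" into an automatic rule is a plateau check on a smoothed training loss. This is a generic sketch, not something described in the paper; `PlateauStopper`, `window`, `patience` and `min_delta` are hypothetical names and knobs to tune, and a wall-clock cap can be kept alongside it.

```python
from collections import deque


class PlateauStopper:
    """Signal a stop once the smoothed training loss stops improving.

    The loss is averaged over a sliding window of `window` steps. If that
    average fails to beat the best value seen so far by at least `min_delta`
    for `patience` consecutive checks, `update` returns True.
    """

    def __init__(self, window: int = 100, patience: int = 10, min_delta: float = 1e-3):
        self.losses = deque(maxlen=window)
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.checks_without_improvement = 0

    def update(self, loss: float) -> bool:
        self.losses.append(loss)
        if len(self.losses) < self.losses.maxlen:
            return False  # wait until the window is full
        smoothed = sum(self.losses) / len(self.losses)
        if smoothed < self.best - self.min_delta:
            self.best = smoothed
            self.checks_without_improvement = 0
        else:
            self.checks_without_improvement += 1
        return self.checks_without_improvement >= self.patience
```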

In strategy D.4, samples are drawn from the replay buffer with replacement.

Once the replay buffer reaches its maximum size, the next time a prediction or sample is added, the oldest element in the buffer is ejected. From that point on, the replay buffer acts as a queue, where sampling takes a random element from the queue without removing it.
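A minimal sketch of the replay buffer behaviour described above: fixed capacity, oldest element ejected first, and uniform sampling with replacement that leaves the buffer unchanged. The toy loop assumes one fresh dataset sample every K steps and represents each item as a hypothetical `(source_id, rollout_step)` pair; "predicting" just increments the roll-out step instead of running a model. None of these names come from the Aurora code base.

```python
import itertools
import random
from collections import deque


class ReplayBuffer:
    """Fixed-capacity FIFO buffer with uniform sampling with replacement."""

    def __init__(self, capacity: int, seed: int = 0):
        # deque(maxlen=...) drops the oldest element when a new one is
        # appended to a full buffer, i.e. the queue behaviour described above.
        self.items = deque(maxlen=capacity)
        self.rng = random.Random(seed)

    def add(self, item) -> None:
        self.items.append(item)

    def sample(self, batch_size: int) -> list:
        # choices() samples with replacement and does not remove anything.
        return self.rng.choices(self.items, k=batch_size)

    def __len__(self) -> int:
        return len(self.items)


def toy_rollout_finetuning(dataset, num_steps, k, capacity, batch_size, seed=0):
    """Toy loop: a real implementation would run the model on each batch."""
    buffer = ReplayBuffer(capacity, seed)
    stream = itertools.cycle(dataset)  # stand-in for an endless data stream
    for step in range(num_steps):
        if step % k == 0:
            buffer.add((next(stream), 0))  # fresh data: rollout_step = 0
        for source_id, rollout_step in buffer.sample(batch_size):
            buffer.add((source_id, rollout_step + 1))  # predictions go back in
    return buffer
```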

@khalilv
Author

khalilv commented Dec 18, 2024

Hi @wesselb,

Thank you for your insights! They have clarified most of my questions.

Regarding the variable-weighted MAE loss function used during training, what were the typical values the model converged to? In my experiments, the values range between 50 and 100. In particular, I noticed that the model struggles to predict specific humidity compared to other variables. Did you experience this as well? I am using ERA5 at 5.625 deg resolution for fine-tuning.

Additionally, I noticed that in the batch metadata, the rollout_step parameter is a constant rather than a tuple with length equal to the batch size. This restricts all samples within a batch to the same rollout_step, instead of allowing a mixture when using lora_mode = 'all'. Could you comment on the configuration used during fine-tuning in the paper (e.g. mode, rank, steps) and whether you think it might be beneficial to generalise rollout_step to support varying lead times within a batch?

Thanks again for your help!

@wesselb
Contributor

wesselb commented Dec 18, 2024

> Regarding the variable-weighted MAE loss function used during training, what were the typical values the model converged to? In my experiments, the values range between 50 and 100. In particular, I noticed that the model struggles to predict specific humidity compared to other variables. Did you experience this as well? I am using ERA5 at 5.625 deg resolution for fine-tuning.

Do you compute the loss function over normalised variables? If you don't compute it over normalised variables (and, of course, normalised model outputs), you will be adding up MAEs of variables with very different scales, so one or a few variables might dominate the loss. The total combined MAE should be of order 0.1 to 1.
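A small sketch of why normalisation matters for a variable-weighted MAE. The variable names, statistics, predictions and weights are made up for illustration (they are not the paper's values); in practice the per-variable mean and standard deviation would come from the training data.

```python
from typing import Dict, List, Tuple


def normalise(values: List[float], mean: float, std: float) -> List[float]:
    """Standardise values with a per-variable mean and standard deviation."""
    return [(v - mean) / std for v in values]


def weighted_mae(
    pred: Dict[str, List[float]],
    target: Dict[str, List[float]],
    weights: Dict[str, float],
    stats: Dict[str, Tuple[float, float]],
    normalise_first: bool = True,
) -> float:
    """Sum over variables of weight * MAE, optionally on normalised values."""
    total = 0.0
    for name, weight in weights.items():
        p, t = pred[name], target[name]
        if normalise_first:
            mean, std = stats[name]
            # Normalise both the model output and the target.
            p, t = normalise(p, mean, std), normalise(t, mean, std)
        mae = sum(abs(a - b) for a, b in zip(p, t)) / len(t)
        total += weight * mae
    return total


# Made-up numbers: "z" is geopotential-like (large scale) and "q" is
# specific-humidity-like (tiny scale).
stats = {"z": (50000.0, 1000.0), "q": (0.005, 0.002)}
target = {"z": [50000.0, 50000.0], "q": [0.005, 0.005]}
pred = {"z": [50100.0, 49900.0], "q": [0.006, 0.004]}
weights = {"z": 1.0, "q": 1.0}

raw = weighted_mae(pred, target, weights, stats, normalise_first=False)
normalised = weighted_mae(pred, target, weights, stats)
print(f"raw: {raw:.3f}, normalised: {normalised:.3f}")
```

Here the raw loss is about 100, almost all of it from `z`, so the `q` error is effectively invisible. On normalised values the total is about 0.6, with `q` contributing 0.5 because its error is large relative to its own scale.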

> Additionally, I noticed that in the batch metadata, the rollout_step parameter is a constant rather than a tuple with length equal to the batch size.

The way it currently works is that rollout_step is the number of roll-out steps performed to produce this batch as a prediction. For data loaded from a dataset, which are data and not predictions, this value is zero. Every autoregressive application of the model increments rollout_step by one, which should allow lora_mode="all" to use different LoRA parameters for every roll-out step. We used lora_mode="single" because lora_mode="all" didn't seem to work well, and we used a LoRA rank of four. It could be worthwhile to experiment a little with this. Perhaps you can make lora_mode="all" work!
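A sketch of the bookkeeping described above: data start at rollout_step 0, each autoregressive step increments it, and the LoRA mode decides which set of LoRA parameters a step uses. This is not Aurora's implementation; `Metadata`, `lora_index`, `roll_out` and `max_steps` are hypothetical, and clipping "all" to the last available set when the roll-out is longer than `max_steps` is an assumption.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class Metadata:
    rollout_step: int = 0  # 0 for data loaded from the dataset


def lora_index(rollout_step: int, lora_mode: str, max_steps: int) -> int:
    """Pick which set of LoRA parameters to use at a given roll-out step.

    "single": one shared set of LoRA parameters for every step.
    "all":    a separate set per step (clipped to the last set, an assumption).
    """
    if lora_mode == "single":
        return 0
    if lora_mode == "all":
        return min(rollout_step, max_steps - 1)
    raise ValueError(f"unknown lora_mode: {lora_mode!r}")


def roll_out(
    metadata: Metadata, num_steps: int, lora_mode: str, max_steps: int
) -> Tuple[Metadata, List[int]]:
    used = []
    for _ in range(num_steps):
        used.append(lora_index(metadata.rollout_step, lora_mode, max_steps))
        # A real model call would go here; the prediction's metadata records
        # one more autoregressive step.
        metadata = replace(metadata, rollout_step=metadata.rollout_step + 1)
    return metadata, used
```

Because `rollout_step` is a single integer per batch in this sketch, every sample in a batch shares one LoRA index, which is the restriction raised in the question.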

@khalilv
Copy link
Author

khalilv commented Jan 10, 2025

Thank you! I corrected the loss function, and the results look as expected. We are now experimenting with roll-out fine-tuning strategies to find one that is stable and improves forecasts at longer lead times.

khalilv closed this as completed on Jan 10, 2025