Validation loss #1864
base: sd3
Conversation
We only want to enable grad if we are training.
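As a minimal sketch of that gating, assuming a standard PyTorch loop (the `model` and `batch` names here are illustrative, not this PR's code):

```python
import torch

def forward_pass(model: torch.nn.Module, batch: torch.Tensor, is_train: bool):
    # Track gradients only for training steps; validation batches run
    # without building the autograd graph, saving memory and compute.
    with torch.set_grad_enabled(is_train):
        return model(batch)
```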
… to calculate validation loss. Balances the influence of different timesteps on the measured loss (without affecting actual training results)
To allow this to be completed and support SD3 and Flux, I decided to drop fixed_timesteps for validation. Supporting it properly would require a fair amount of refactoring, but it can be added back later. This keeps the updated/refactored code to a minimum so we can get this released and iterate on it afterwards. I also mostly reverted train_db.py to limit the changes in this PR. If the `process_batch` change is accepted, each training script will need to be updated, and we can iterate on that.
Added …
Last bit is the …
Thanks for the great job! I will review this soon, perhaps tomorrow.
…_trackers library code
I tried to make `accelerate.log()` work appropriately for wandb and tensorboard. Since wandb will drop logging if the step value is not equal to or greater than the previous step, I feel this compromise is OK. Added the following wandb metrics for epoch and val_step, which allows us to check those values in correlation and update the charts to reflect this:

```python
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

# Define specific metrics to handle validation and epoch "steps"
wandb_tracker.define_metric("epoch", hidden=True)
wandb_tracker.define_metric("val_step", hidden=True)
```

See how this shows "epoch" on the bottom vs the normal "Step", which allows us to set epochs as the X axis so the charts align as expected. The only issue right now is "val_step" incrementing: because we do multiple validation runs on the same "global_step", the current metric only records the last one. That's probably something to fix, but this value has a lot of noise anyway, and it's the only one affected right now. On tensorboard it might require more work to get it working as well as wandb, but I think it's OK for now. Before, we would set the "epoch" as the step value. The other factor is that it adds metrics for … I have tested everything in the checklist but haven't tested SDXL, SD3, or Flux network training. At this point everything should be completed and ready for a full review.
Thank you, the code is very well written.
I did some reviewing before actually running it and testing it. Please check it out.
It seems that SD1.5, SDXL, SD3, FLUX.1, ControlNet, etc. training no longer works with this PR. Are there any plans to support them? If it's difficult, after merging I will either make them compatible with validation loss or make them work as before.
… fix multiple datasets
I'm not sure what you mean by the training no longer working. I have been testing with SD1.5 network training, but I would expect all of these to work. The dataset issue you mentioned above could be causing a problem with getting all the datasets, but I would expect them all to work at least. If you can hint at what is not working, I can try to resolve it. I have added your suggestions from the review. I set … Thanks for the review; apologies for the mistakes.
Not at all - your code is excellent and very well structured. Thank you for taking the time to submit this PR. In …
However, …
I haven't run the code yet, so I apologize if it actually works.
I have updated all calls to …
Thank you for the update. I'm sorry that the code is complicated.
I'm currently checking the training scripts to see if they still work as before. There was one problem (in two places), so I would appreciate it if you could check it.
```python
validation_seed: int,
validation_split: float,
```
The `is_training_dataset` argument is not added to the constructor of `FineTuningDataset`, which seems to cause the error `TypeError: FineTuningDataset.__init__() got an unexpected keyword argument 'is_training_dataset'`. The same problem may occur with `ControlNetDataset`.
Added a variable for `extra_dataset_params` and separated out `DreamBoothDataset` for the `is_training_dataset` to be applied. This should leave the other datasets unaffected.
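A minimal sketch of that construction pattern, with hypothetical parameter wiring (the actual PR code may differ):

```python
def build_dataset(dataset_cls, common_params: dict, is_training_dataset: bool):
    # Only DreamBoothDataset accepts is_training_dataset; FineTuningDataset
    # and ControlNetDataset keep their original constructor signatures, so
    # the extra keyword is added conditionally.
    extra_dataset_params = {}
    if dataset_cls.__name__ == "DreamBoothDataset":
        extra_dataset_params["is_training_dataset"] = is_training_dataset
    return dataset_cls(**common_params, **extra_dataset_params)
```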
…plit check and warning
Divergence is the difference between training and validation loss, giving a single clear value in the logs for how far the two have drifted apart.
Adding metadata recording for validation arguments. Add comments about the validation split for clarity of intention.
Added a divergence value for step and epoch, indicating the difference between training and validation loss. It makes the gap easier to see without having to rely on overlapping chart lines. Maybe a different term would be better, as "divergence" might suggest the run is moving apart and away from convergence. It might also be better to invert the sign so it matches the loss values. Fixed a bunch of things with regularization image datasets and repeats. Fixed some issues with validate-every-n-steps (which matters when using repeats and regularization images).
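A minimal sketch of the metric itself, using hypothetical names and log keys (the PR's exact sign convention may differ):

```python
def log_divergence(accelerator, train_loss: float, val_loss: float, step: int):
    # Gap between validation and training loss at the same point in
    # training; positive means validation loss is higher than training.
    divergence = val_loss - train_loss
    accelerator.log({"loss/divergence": divergence}, step=step)
```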
Bug: if text encoders are cached and validation is enabled, `process_batch` errors during validation because `batch['input_ids_list']` is `None`.
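A hedged sketch of the kind of guard that could address it; the batch keys and helper here are assumptions, not the repository's actual schema:

```python
def get_text_conds(batch: dict, encode_tokens):
    # When text encoder outputs are cached, the collator may provide no
    # input_ids; fall back to the cached conditions instead of re-encoding.
    # Both key names are hypothetical.
    if batch.get("input_ids_list") is None:
        return batch["text_encoder_outputs_list"]
    return encode_tokens(batch["input_ids_list"])
```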
Related #1856 #1858 #1165 #914
Original implementation by @rockerBOO
Timestep validation implementation by @gesen2egee
Updated implementation for sd3/flux by @hinablue
I went through and tried to merge the different PRs together. I probably messed up some things in the process.
One thing I wanted to note is that `process_batch` was made to limit duplication of the code between validation and training, to keep them consistent. I implemented the timestep processing so it could work for both. I noted that other PRs were using only debiased_estimation, but I didn't know why it was like that. I did not update train_db.py appropriately to my goal of a unified `process_batch`, as I do not have a good way to test it. I will try to get the scripts into an acceptable state and we can refine them.
I'm posting this a little early so others can view it and give me feedback. I am still working on some issues with the code, so let me know before you dive in to fix anything. Open to commits to this PR; they can be posted to this branch on my fork.
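To illustrate the shape of that unification, a minimal sketch under assumed names (the PR's actual `process_batch` signature and internals will differ):

```python
import torch
import torch.nn.functional as F

def process_batch(model: torch.nn.Module, batch: dict, is_train: bool) -> torch.Tensor:
    # One code path for training and validation so the loss is computed
    # identically in both modes; only gradient tracking differs.
    latents = batch["latents"]  # hypothetical batch key
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, 1000, (latents.size(0),), device=latents.device)
    with torch.set_grad_enabled(is_train):
        pred = model(latents + noise, timesteps)  # toy noising, for illustration
        loss = F.mse_loss(pred, noise)
    return loss
```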
Testing
- `--network_train_text_encoder_only`
- `--network_train_unet_only`
Parameters
The validation dataset is for dreambooth datasets (text/image pairs) and will split the dataset into two parts, train_dataset and validation_dataset, depending on the split.
- `--validation_seed`: Validation seed for shuffling the validation dataset; the training `--seed` is used otherwise
- `--validation_split`: Split ratio for validation images out of the training dataset
- `--validate_every_n_steps`: Run validation on the validation dataset every N steps. By default, validation only occurs every epoch when a validation dataset is available
- `--validate_every_n_epochs`: Run validation every N epochs. By default, validation runs every epoch when a validation dataset is available
- `--max_validation_steps`: Max number of validation dataset items processed. By default, validation runs over the entire validation dataset

`validation_seed` and `validation_split` can also be set inside dataset_config.toml (see the sketch below).
I'm open to feedback about this approach and whether anything needs to be fixed in the code to be accurate.
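A hedged sketch of the dataset config usage; the placement of these keys inside dataset_config.toml is an assumption based on the description above, and the values are illustrative:

```toml
# dataset_config.toml (sketch; key placement assumed)
[[datasets]]
resolution = 512
batch_size = 2
validation_seed = 42      # seed for shuffling before the split
validation_split = 0.1    # hold out 10% of images for validation

  [[datasets.subsets]]
  image_dir = "/path/to/train/images"
  num_repeats = 1
```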