
Validation loss #1864

Open · wants to merge 77 commits into base: sd3

Conversation

@rockerBOO (Contributor) commented Jan 3, 2025

Related #1856 #1858 #1165 #914

Original implementation by @rockerBOO
Timestep validation implementation by @gesen2egee
Updated implementation for sd3/flux by @hinablue

I went through and tried to merge the different PRs together. I probably messed up some things in the process.

One thing I want to note is that process_batch was created to limit duplication of code between validation and training and keep the two consistent. I implemented the timestep processing so it works for both. I noticed that other PRs used only debiased_estimation here, but I don't know why it was done that way.
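
As a minimal sketch of that shared-path idea (hypothetical signature and helper names; the actual function in this PR takes more arguments), both loops call one function and only training enables gradients:

    import torch

    def process_batch(batch, model, compute_loss, is_train: bool) -> torch.Tensor:
        # Shared loss computation for training and validation. Gradients are
        # only enabled during training; validation runs gradient-free.
        with torch.set_grad_enabled(is_train):
            loss = compute_loss(model, batch)  # noise prediction + timestep weighting
        return loss

    # training:   loss = process_batch(batch, model, compute_loss, is_train=True)
    # validation: loss = process_batch(batch, model, compute_loss, is_train=False)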

I did not update train_db.py to fully match my goal of a unified process_batch, as I do not have a good way to test it. I will try to get it into an acceptable state, and we can refine it.

I'm posting this a little early so others can view it and give feedback. I am still working on some issues with the code, so let me know before you dive in to fix anything. I'm open to commits to this PR; you can post them to this branch on my fork.

Testing

  • Test training code is actually training
  • Test validation epoch (validation runs every epoch)
  • Test validate per n steps (after N steps, a validation run is performed)
  • Test validate per n epochs (after N epochs, validation runs)
  • Test max validation steps
  • Test validation split (the split should be applied accordingly; 0.2 should hold out 20% of the primary dataset for validation)
  • Test validation split from train_network.py arguments (--validation_split) as well as dataset_config.toml (validation_split=0.1)
  • Test validation seed (Seed is used for dataset shuffling only right now)
  • Test image latent caching (validation and training datasets)
  • Test tokenizing strategy (SD, SDXL, SD3, Flux)
  • Test text encoding strategy (SD, SDXL, SD3, Flux)
  • Test --network_train_text_encoder_only
  • Test --network_train_unet_only
  • Test training some text encoders (I think this is a feature?)
  • Test on SD1.5, SDXL, SD3, Flux LoRAs

Parameters

The validation dataset applies to dreambooth datasets (text/image pairs); the dataset is split into two parts, train_dataset and validation_dataset, according to the split ratio.

  • --validation_seed Validation seed for shuffling validation dataset, training --seed used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング --seed を使用する
  • --validation_split Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合
  • --validate_every_n_steps Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます
  • --validate_every_n_epochs Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます
  • --max_validation_steps Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します
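
For example, an invocation might look like this (the model and dataset flags are illustrative; only the validation flags come from this PR):

    accelerate launch train_network.py \
        --pretrained_model_name_or_path model.safetensors \
        --dataset_config dataset_config.toml \
        --validation_split 0.1 \
        --validation_seed 42 \
        --validate_every_n_steps 200 \
        --max_validation_steps 16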

validation_seed and validation_split can also be set inside dataset_config.toml.
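
A sketch of what that could look like, based on the PR description (exact key placement in the dataset config may differ):

    [general]
    caption_extension = ".txt"

    [[datasets]]
    resolution = 512
    batch_size = 2
    validation_split = 0.2  # hold out 20% of this dataset for validation
    validation_seed = 42    # optional; the training --seed is used otherwise

      [[datasets.subsets]]
      image_dir = "/path/to/images"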

I'm open to feedback on this approach and on anything in the code that needs to be fixed for accuracy.

@rockerBOO (Contributor Author):

To allow this to be completed and to support SD3 and Flux, I decided to drop fixed_timesteps for validation. Supporting it properly would require a lot of refactoring, but it can be added afterwards. This limits the amount of updated/refactored code so we can get this released and iterate on it afterward.

I also reverted train_db.py, mostly to limit the changes in this PR. If the process_batch change is accepted, each training script will need to be updated, and we can iterate on that.

@rockerBOO (Contributor Author):

  • Added --validate_every_n_epochs
  • Changed --validation_every_n_step to --validate_every_n_steps

@rockerBOO (Contributor Author):

The last remaining piece is the accelerator.log() handling, but otherwise this should be in a good state now.

@kohya-ss (Owner) commented Jan 7, 2025

Thanks for the great work! I will review this soon, perhaps tomorrow.

@rockerBOO (Contributor Author) commented Jan 8, 2025

I tried to make accelerator.log() work appropriately for wandb and tensorboard. Since wandb drops log entries if the step value is not the same as or greater than the previous one, I feel this compromise is OK.

I added the following wandb metric definitions for epoch and val_step. This allows us to check those values in correlation and update the charts to reflect this.

    wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

    # Define specific metrics to handle validation and epoch "steps"
    wandb_tracker.define_metric("epoch", hidden=True)
    wandb_tracker.define_metric("val_step", hidden=True)

See how this shows "epoch" on the bottom vs the normal "Step".

[Screenshot: Weights & Biases run showing the "epoch" metric below the normal "Step"]

This allows us to set epochs as the X axis, which aligns the charts as expected.

[Screenshot: Weights & Biases chart with epochs on the X axis]

The only issue right now is with "val_step" incrementing: because we do multiple validation runs at the same "global_step", this metric only records the last value. That should probably be fixed, but this value has a lot of noise anyway, and it's the only one affected right now.

[Screenshot: Weights & Biases "val_step" chart recording only the last value per global step]

TensorBoard might require more work to behave as well as wandb, but I think it's OK for now.

[Screenshot: TensorBoard showing the added epoch and val_step metrics]

Before, we would use the epoch as the step value, accelerator.log(log, step=epoch + 1), but this would break wandb. Now we use the global_step value as the step, accelerator.log(log, step=global_step), and include the epoch in the payload, log = {"epoch": epoch + 1}, but that value is no longer reflected in the same spot it was previously.
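
Roughly, the change looks like this (metric names illustrative):

    # before: stepping by epoch broke wandb's monotonic-step requirement
    # accelerator.log({"loss/epoch_average": avg_loss}, step=epoch + 1)

    # now: always step by global_step and carry the epoch in the payload,
    # so wandb can re-plot any series against the hidden "epoch" metric
    accelerator.log({"loss/epoch_average": avg_loss, "epoch": epoch + 1}, step=global_step)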

The other factor is that this adds epoch and val_step as metrics in TensorBoard; they can be minimized there, which leaves other trackers like wandb the flexibility. Later it might be appropriate to handle this in a more nuanced way than going through accelerator.log directly.

I have tested everything in the checklist, but haven't tested SDXL, SD3, or Flux network training.

At this point everything should be completed and ready for a full review.

@kohya-ss (Owner) left a review:

Thank you, the code is very well written.

I did some reviewing before actually running and testing it. Please take a look.

Review comments (since outdated/resolved) on: library/config_util.py, library/train_util.py (four locations), train_network.py
@kohya-ss (Owner) commented Jan 8, 2025

It seems that SD1.5, SDXL, SD3, FLUX.1, ControlNet, etc. training no longer works with this PR. Are there any plans to support them?

If it's difficult, I will either make them compatible with validation loss, or make them work as before, after merging.

@rockerBOO (Contributor Author) commented Jan 8, 2025

It seems that SD1.5, SDXL, SD3, FLUX.1, and ControlNet etc. training no longer work with this PR. Are there any plans to support them?

I'm not sure what you mean by the training no longer working. I have been testing with SD1.5 network training and would expect all of these to work. The dataset issue you mentioned above could be causing a problem with getting all the datasets, but I would still expect them all to work. If you can hint at what is failing, I can try to resolve it.

I have added your suggestions in the review.

I set train_dataset_group.verify_bucket_reso_steps(32) for the base train_network.py to preserve the previous SD1.5 behavior. If you feel 64 is more appropriate, I can update my older datasets that had it at 32.

Thanks for the review, apologies for the mistakes.

@kohya-ss (Owner) commented Jan 8, 2025

Not at all - your code is excellent and very well structured. Thank you for taking the time to submit this PR.

In config_util.py, the following will return two DatasetGroups (or one DatasetGroup and None):

    return (
        DatasetGroup(datasets),
        DatasetGroup(val_datasets) if val_datasets else None
    )

However flux_train.py etc. seem to expect one DatasetGroup.

        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

I haven't run the code yet, so I apologize if it actually works.

@rockerBOO (Contributor Author):

I have updated all calls to config_util.generate_dataset_group_by_blueprint to unpack the tuple, and added a return type annotation to that function to help with type checking. Thanks for pointing it out.
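
Concretely, callers now unpack the pair (the second element may be None when no validation split is configured):

    train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(
        blueprint.dataset_group
    )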

@kohya-ss (Owner) left a review:

Thank you for the update. I'm sorry the code is complicated.

I'm currently checking the training script to see if it still works as before. There was one problem (in two places), so I would appreciate it if you could check it.

@kohya-ss (Owner) commented on lines +2108 to +2109:

    validation_seed: int,
    validation_split: float,

The is_training_dataset argument is not added to the constructor of FineTuningDataset, which seems to cause the error TypeError: FineTuningDataset.__init__() got an unexpected keyword argument 'is_training_dataset'. The same problem may occur with ControlNetDataset.

@rockerBOO (Contributor Author):

I added an extra_dataset_params variable and separated out DreamBoothDataset so that is_training_dataset is only applied there. This should leave the other datasets unaffected; a sketch of the approach is below.
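
A minimal sketch of that approach (surrounding names are hypothetical; the real construction happens in config_util.py):

    # Only DreamBoothDataset accepts is_training_dataset; other dataset
    # classes are constructed without it, so they remain unaffected.
    extra_dataset_params = {}
    if dataset_klass is DreamBoothDataset:
        extra_dataset_params["is_training_dataset"] = is_training_dataset
    dataset = dataset_klass(**dataset_params, **extra_dataset_params)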

Commits pushed to this branch:
  • Divergence is the difference between training and validation, to allow a clear value indicating the difference between the two in the logs
  • Adding metadata recording for validation arguments
  • Add comments about the validation split for clarity of intention
@rockerBOO (Contributor Author):

[Screenshot: Weights & Biases workspace showing the divergence charts]

Added a divergence value for step and epoch, indicating the difference between training and validation loss; see the sketch below. This makes it easier to see the gap without relying on overlapping charts. A different term might be better, since "divergence" could suggest the model is moving apart and away from convergence. It might also be better to invert the current sign so it matches the direction of the loss values.
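
For illustration, one plausible definition (the PR may use the opposite sign, per the note above; names illustrative):

    # divergence: gap between validation loss and training loss at the same step
    divergence = val_loss - train_loss
    accelerator.log({"loss/divergence": divergence}, step=global_step)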

Fixed a number of things with regularization-image datasets and repeats, and fixed some issues with validate-every-n-steps (which matters when using repeats and regularization images).

@rockerBOO (Contributor Author):

Bug: if text encoder outputs are cached and validation is enabled, then when validation runs process_batch, the line

    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]

errors because batch["input_ids_list"] is None.
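
A guard along these lines would likely avoid the crash (a sketch, assuming batch behaves like a dict and cached text encoder outputs leave input_ids_list as None):

    if batch.get("input_ids_list") is not None:
        input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
    else:
        input_ids = None  # text encoder outputs are cached; skip re-encoding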
