From 2f9f4e9d1147be4c33f55c2584954e260f734bec Mon Sep 17 00:00:00 2001
From: Uncertainty Baselines Team
Date: Wed, 21 Sep 2022 16:41:17 -0700
Subject: [PATCH] Fix waterbirds dataset to work with updated xm pipeline

PiperOrigin-RevId: 475955177
---
 .../configs/waterbirds_resnet_config.py     |  4 +-
 experimental/shoshin/data.py                | 45 ++++++++-----------
 experimental/shoshin/generate_bias_table.py | 12 +++--
 3 files changed, 30 insertions(+), 31 deletions(-)

diff --git a/experimental/shoshin/configs/waterbirds_resnet_config.py b/experimental/shoshin/configs/waterbirds_resnet_config.py
index c4907b204..27a8710c1 100644
--- a/experimental/shoshin/configs/waterbirds_resnet_config.py
+++ b/experimental/shoshin/configs/waterbirds_resnet_config.py
@@ -24,8 +24,8 @@ def get_config() -> ml_collections.ConfigDict:
   config = base_config.get_config()
 
   # Consider landbirds on water and waterbirds on land as subgroups.
-  config.subgroup_ids = ('0_1', '1_0')
-  config.subgroup_proportions = (0.04, 0.012)
+  config.data.subgroup_ids = ()  # ('0_1', '1_0')
+  config.data.subgroup_proportions = ()  # (0.04, 0.012)
 
   data = config.data
   data.name = 'waterbirds'
diff --git a/experimental/shoshin/data.py b/experimental/shoshin/data.py
index 949cb89c2..67903f61e 100644
--- a/experimental/shoshin/data.py
+++ b/experimental/shoshin/data.py
@@ -260,13 +260,10 @@ class WaterbirdsDataset(tfds.core.GeneratorBasedBuilder):
 
   def __init__(self,
                subgroup_ids: List[str],
-               initial_sample_proportion: Optional[float] = None,
                subgroup_proportions: Optional[List[float]] = None,
                **kwargs):
     super(WaterbirdsDataset, self).__init__(**kwargs)
     self.subgroup_ids = subgroup_ids
-    if initial_sample_proportion:
-      self.initial_sample_proportion = initial_sample_proportion
     if subgroup_proportions:
       self.subgroup_proportions = subgroup_proportions
     else:
@@ -411,8 +408,7 @@ def filter_fn_subgroup(image, label, place, image_filename,
 
       subgroup_dataset = dataset.filter(filter_fn_subgroup)
       subgroup_sample_size = int(dataset_size *
-                                 self.subgroup_proportions[idx] *
-                                 self.initial_sample_proportion)
+                                 self.subgroup_proportions[idx])
       subgroup_dataset = subgroup_dataset.take(subgroup_sample_size)
       sampled_datasets.append(subgroup_dataset)
       remaining_proportion -= self.subgroup_proportions[idx]
@@ -428,8 +424,7 @@ def filter_fn_remaining(image, label, place, image_filename,
                 separator='_'), self.subgroup_ids))
 
     remaining_dataset = dataset.filter(filter_fn_remaining)
-    remaining_sample_size = int(dataset_size * remaining_proportion *
-                                self.initial_sample_proportion)
+    remaining_sample_size = int(dataset_size * remaining_proportion)
    remaining_dataset = remaining_dataset.take(remaining_sample_size)
     sampled_datasets.append(remaining_dataset)
 
@@ -456,14 +451,13 @@ def filter_fn_remaining(image, label, place, image_filename,
 
 @register_dataset('waterbirds')
 def get_waterbirds_dataset(
-    num_splits: int, batch_size: int, initial_sample_proportion: float,
+    num_splits: int, initial_sample_proportion: float,
     subgroup_ids: List[str], subgroup_proportions: List[float]
 ) -> Dataloader:
   """Returns datasets for training, validation, and possibly test sets.
 
   Args:
     num_splits: Integer for number of slices of the dataset.
-    batch_size: Integer for number of examples per batch.
     initial_sample_proportion: Float for proportion of entire training
       dataset to sample initially before active sampling begins.
     subgroup_ids: List of strings of IDs indicating subgroups.
@@ -475,9 +469,9 @@ def get_waterbirds_dataset(
     combined training dataset, and a dictionary mapping evaluation dataset
     names to their respective combined datasets.
   """
-  split_size_in_pct = int(100 / num_splits)
+  split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
+  reduced_dataset_sz = int(100 * initial_sample_proportion)
   builder_kwargs = {
-      'initial_sample_proportion': initial_sample_proportion,
       'subgroup_ids': subgroup_ids,
       'subgroup_proportions': subgroup_proportions
   }
@@ -485,10 +479,9 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True)
@@ -497,10 +490,9 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split=[
           f'train[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True)
@@ -509,7 +501,6 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split='train_sample',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True,
@@ -519,7 +510,6 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split='test',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True,
@@ -541,31 +531,37 @@ def get_waterbirds_dataset(
 
 @register_dataset('celeb_a')
 def get_celeba_dataset(
-    num_splits: int, batch_size: int
+    num_splits: int, initial_sample_proportion: float,
+    subgroup_ids: List[str], subgroup_proportions: List[float],
 ) -> Dataloader:
   """Returns datasets for training, validation, and possibly test sets.
 
   Args:
     num_splits: Integer for number of slices of the dataset.
-    batch_size: Integer for number of examples per batch.
+    initial_sample_proportion: Float for proportion of entire training
+      dataset to sample initially before active sampling begins.
+    subgroup_ids: List of strings of IDs indicating subgroups.
+    subgroup_proportions: List of floats indicating proportion that each
+      subgroup should take in initial training dataset.
 
   Returns:
     A tuple containing the split training data, split validation data, the
     combined training dataset, and a dictionary mapping evaluation dataset
     names to their respective combined datasets.
""" + del subgroup_proportions, subgroup_ids read_config = tfds.ReadConfig() read_config.add_tfds_id = True # Set `True` to return the 'tfds_id' key - split_size_in_pct = int(100 / num_splits) + split_size_in_pct = int(100 * initial_sample_proportion / num_splits) + reduced_dataset_sz = int(100 * initial_sample_proportion) train_splits = tfds.load( 'celeb_a', read_config=read_config, split=[ f'train[:{k}%]+train[{k+split_size_in_pct}%:]' - for k in range(0, 100, split_size_in_pct) + for k in range(0, reduced_dataset_sz, split_size_in_pct) ], data_dir=DATA_DIR, - batch_size=batch_size, try_gcs=False, as_supervised=True ) @@ -574,10 +570,9 @@ def get_celeba_dataset( read_config=read_config, split=[ f'validation[{k}%:{k+split_size_in_pct}%]' - for k in range(0, 100, split_size_in_pct) + for k in range(0, reduced_dataset_sz, split_size_in_pct) ], data_dir=DATA_DIR, - batch_size=batch_size, try_gcs=False, as_supervised=True ) @@ -585,7 +580,6 @@ def get_celeba_dataset( 'celeb_a', split='train_sample', data_dir=DATA_DIR, - batch_size=batch_size, try_gcs=False, as_supervised=True, with_info=False) @@ -594,7 +588,6 @@ def get_celeba_dataset( 'celeb_a', split='test', data_dir=DATA_DIR, - batch_size=batch_size, try_gcs=False, as_supervised=True, with_info=False) diff --git a/experimental/shoshin/generate_bias_table.py b/experimental/shoshin/generate_bias_table.py index d2f2f494f..ad84edd48 100644 --- a/experimental/shoshin/generate_bias_table.py +++ b/experimental/shoshin/generate_bias_table.py @@ -76,9 +76,13 @@ def main(_) -> None: if config.generate_bias_table: if config.round_idx == 0: dataloader = dataset_builder(config.data.num_splits, - config.data.initial_sample_proportion) + config.data.initial_sample_proportion, + config.data.subgroup_ids, + config.data.subgroup_proportions,) else: - dataloader = dataset_builder(config.data.num_splits, 1) + dataloader = dataset_builder(config.data.num_splits, 1, + config.data.subgroup_ids, + config.data.subgroup_proportions,) # Filter each split to only have examples from example_ids_table dataloader.train_splits = [ dataloader.train_ds.filter( @@ -95,7 +99,9 @@ def main(_) -> None: save_dir=config.output_dir, save_table=True) else: - dataloader = dataset_builder(config.data.num_splits, 1) + dataloader = dataset_builder( + config.data.num_splits, 1, config.data.subgroup_ids, + config.data.subgroup_proportions) dataloader = data.apply_batch(dataloader, config.data.batch_size) _ = generate_bias_table_lib.get_example_id_to_predictions_table( dataloader=dataloader,