Fix waterbirds dataset to work with updated xm pipeline
PiperOrigin-RevId: 475955177
Uncertainty Baselines Team authored and copybara-github committed Sep 21, 2022
1 parent f667871 commit 2f9f4e9
Showing 3 changed files with 30 additions and 31 deletions.
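In short: `batch_size` is dropped from the registered dataset builders (batching is now applied afterwards with `data.apply_batch`), `initial_sample_proportion` moves out of the `WaterbirdsDataset` TFDS builder and into the TFDS split ranges, and the subgroup options move under `config.data` so they can be threaded through every builder. A minimal sketch of the resulting calling convention, with illustrative argument values (`dataset_builder` stands in for whatever builder the caller has in scope):

```python
# Sketch only; argument values are hypothetical. See the diffs below.
dataloader = dataset_builder(
    num_splits=5,                   # slices of the dataset
    initial_sample_proportion=0.5,  # fraction of the data sampled up front
    subgroup_ids=(),                # e.g. ('0_1', '1_0')
    subgroup_proportions=(),        # e.g. (0.04, 0.012)
)
dataloader = data.apply_batch(dataloader, 64)  # batching happens here now
```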
experimental/shoshin/configs/waterbirds_resnet_config.py (4 changes: 2 additions & 2 deletions)
@@ -24,8 +24,8 @@ def get_config() -> ml_collections.ConfigDict:
   config = base_config.get_config()
 
   # Consider landbirds on water and waterbirds on land as subgroups.
-  config.subgroup_ids = ('0_1', '1_0')
-  config.subgroup_proportions = (0.04, 0.012)
+  config.data.subgroup_ids = ()  # ('0_1', '1_0')
+  config.data.subgroup_proportions = ()  # (0.04, 0.012)
 
   data = config.data
   data.name = 'waterbirds'
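The subgroup settings now live under `config.data`, so they travel with the rest of the data config into the builder. To restore the earlier subgroup sampling, one would re-enable the commented-out values; a sketch assuming the same config structure:

```python
# Hypothetical override; mirrors the commented-out defaults above.
config.data.subgroup_ids = ('0_1', '1_0')         # landbirds on water, waterbirds on land
config.data.subgroup_proportions = (0.04, 0.012)  # share of each subgroup in the initial sample
```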
experimental/shoshin/data.py (45 changes: 19 additions & 26 deletions)
@@ -260,13 +260,10 @@ class WaterbirdsDataset(tfds.core.GeneratorBasedBuilder):
 
   def __init__(self,
                subgroup_ids: List[str],
-               initial_sample_proportion: Optional[float] = None,
                subgroup_proportions: Optional[List[float]] = None,
                **kwargs):
     super(WaterbirdsDataset, self).__init__(**kwargs)
     self.subgroup_ids = subgroup_ids
-    if initial_sample_proportion:
-      self.initial_sample_proportion = initial_sample_proportion
     if subgroup_proportions:
       self.subgroup_proportions = subgroup_proportions
     else:
@@ -411,8 +408,7 @@ def filter_fn_subgroup(image, label, place, image_filename,
 
       subgroup_dataset = dataset.filter(filter_fn_subgroup)
       subgroup_sample_size = int(dataset_size *
-                                 self.subgroup_proportions[idx] *
-                                 self.initial_sample_proportion)
+                                 self.subgroup_proportions[idx])
       subgroup_dataset = subgroup_dataset.take(subgroup_sample_size)
       sampled_datasets.append(subgroup_dataset)
       remaining_proportion -= self.subgroup_proportions[idx]
@@ -428,8 +424,7 @@ def filter_fn_remaining(image, label, place, image_filename,
                              separator='_'), self.subgroup_ids))
 
       remaining_dataset = dataset.filter(filter_fn_remaining)
-      remaining_sample_size = int(dataset_size * remaining_proportion *
-                                  self.initial_sample_proportion)
+      remaining_sample_size = int(dataset_size * remaining_proportion)
       remaining_dataset = remaining_dataset.take(remaining_sample_size)
       sampled_datasets.append(remaining_dataset)
 
@@ -456,14 +451,13 @@ def filter_fn_remaining(image, label, place, image_filename,
 
 @register_dataset('waterbirds')
 def get_waterbirds_dataset(
-    num_splits: int, batch_size: int, initial_sample_proportion: float,
+    num_splits: int, initial_sample_proportion: float,
     subgroup_ids: List[str], subgroup_proportions: List[float]
 ) -> Dataloader:
   """Returns datasets for training, validation, and possibly test sets.
 
   Args:
     num_splits: Integer for number of slices of the dataset.
-    batch_size: Integer for number of examples per batch.
     initial_sample_proportion: Float for proportion of entire training
       dataset to sample initially before active sampling begins.
     subgroup_ids: List of strings of IDs indicating subgroups.
@@ -475,20 +469,19 @@ def get_waterbirds_dataset(
     combined training dataset, and a dictionary mapping evaluation dataset names
     to their respective combined datasets.
   """
-  split_size_in_pct = int(100 / num_splits)
+  split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
+  reduced_datset_sz = int(100 * initial_sample_proportion)
   builder_kwargs = {
-      'initial_sample_proportion': initial_sample_proportion,
       'subgroup_ids': subgroup_ids,
       'subgroup_proportions': subgroup_proportions
   }
   val_splits = tfds.load(
       'waterbirds_dataset',
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_datset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True)
@@ -497,10 +490,9 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split=[
           f'train[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_datset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True)
@@ -509,7 +501,6 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split='train_sample',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True,
@@ -519,7 +510,6 @@ def get_waterbirds_dataset(
       'waterbirds_dataset',
       split='test',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       builder_kwargs=builder_kwargs,
       try_gcs=False,
       as_supervised=True,
@@ -541,31 +531,37 @@ def get_waterbirds_dataset(
 
 @register_dataset('celeb_a')
 def get_celeba_dataset(
-    num_splits: int, batch_size: int
+    num_splits: int, initial_sample_proportion: float,
+    subgroup_ids: List[str], subgroup_proportions: List[float],
 ) -> Dataloader:
   """Returns datasets for training, validation, and possibly test sets.
 
   Args:
     num_splits: Integer for number of slices of the dataset.
-    batch_size: Integer for number of examples per batch.
+    initial_sample_proportion: Float for proportion of entire training
+      dataset to sample initially before active sampling begins.
+    subgroup_ids: List of strings of IDs indicating subgroups.
+    subgroup_proportions: List of floats indicating proportion that each
+      subgroup should take in initial training dataset.
 
   Returns:
     A tuple containing the split training data, split validation data, the
     combined training dataset, and a dictionary mapping evaluation dataset names
     to their respective combined datasets.
   """
+  del subgroup_proportions, subgroup_ids
   read_config = tfds.ReadConfig()
   read_config.add_tfds_id = True  # Set `True` to return the 'tfds_id' key
-  split_size_in_pct = int(100 / num_splits)
+  split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
+  reduced_dataset_sz = int(100 * initial_sample_proportion)
   train_splits = tfds.load(
       'celeb_a',
       read_config=read_config,
       split=[
           f'train[:{k}%]+train[{k+split_size_in_pct}%:]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
       ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       try_gcs=False,
       as_supervised=True
   )
@@ -574,18 +570,16 @@ def get_celeba_dataset(
       read_config=read_config,
       split=[
           f'validation[{k}%:{k+split_size_in_pct}%]'
-          for k in range(0, 100, split_size_in_pct)
+          for k in range(0, reduced_dataset_sz, split_size_in_pct)
      ],
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       try_gcs=False,
       as_supervised=True
   )
   train_sample = tfds.load(
       'celeb_a',
       split='train_sample',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       try_gcs=False,
       as_supervised=True,
       with_info=False)
@@ -594,7 +588,6 @@ def get_celeba_dataset(
       'celeb_a',
       split='test',
       data_dir=DATA_DIR,
-      batch_size=batch_size,
       try_gcs=False,
       as_supervised=True,
       with_info=False)
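The core of the data.py change: `initial_sample_proportion` now shrinks the TFDS split ranges directly rather than being applied inside the builder, so the `num_splits` slices together cover only the sampled fraction. A small self-contained example of that arithmetic, with illustrative values:

```python
# Worked example of the new split computation (values are illustrative).
num_splits = 5
initial_sample_proportion = 0.5

split_size_in_pct = int(100 * initial_sample_proportion / num_splits)  # 10
reduced_dataset_sz = int(100 * initial_sample_proportion)              # 50

splits = [
    f'validation[{k}%:{k + split_size_in_pct}%]'
    for k in range(0, reduced_dataset_sz, split_size_in_pct)
]
print(splits)
# ['validation[0%:10%]', 'validation[10%:20%]', 'validation[20%:30%]',
#  'validation[30%:40%]', 'validation[40%:50%]']
```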
experimental/shoshin/generate_bias_table.py (12 changes: 9 additions & 3 deletions)
@@ -76,9 +76,13 @@ def main(_) -> None:
   if config.generate_bias_table:
     if config.round_idx == 0:
       dataloader = dataset_builder(config.data.num_splits,
-                                   config.data.initial_sample_proportion)
+                                   config.data.initial_sample_proportion,
+                                   config.data.subgroup_ids,
+                                   config.data.subgroup_proportions,)
     else:
-      dataloader = dataset_builder(config.data.num_splits, 1)
+      dataloader = dataset_builder(config.data.num_splits, 1,
+                                   config.data.subgroup_ids,
+                                   config.data.subgroup_proportions,)
     # Filter each split to only have examples from example_ids_table
     dataloader.train_splits = [
         dataloader.train_ds.filter(
@@ -95,7 +99,9 @@ def main(_) -> None:
         save_dir=config.output_dir,
         save_table=True)
   else:
-    dataloader = dataset_builder(config.data.num_splits, 1)
+    dataloader = dataset_builder(
+        config.data.num_splits, 1, config.data.subgroup_ids,
+        config.data.subgroup_proportions)
   dataloader = data.apply_batch(dataloader, config.data.batch_size)
   _ = generate_bias_table_lib.get_example_id_to_predictions_table(
       dataloader=dataloader,
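The round logic in `main` now passes the subgroup arguments through in every branch: round 0 samples `initial_sample_proportion` of the data, while later rounds load everything (proportion 1) and then filter down to the ids selected by active sampling. A condensed sketch of that control flow, not a verbatim copy of the file:

```python
# Condensed sketch of main()'s round handling; assumes the config fields shown in the diff.
proportion = (config.data.initial_sample_proportion
              if config.round_idx == 0 else 1)  # later rounds load the full data
dataloader = dataset_builder(config.data.num_splits, proportion,
                             config.data.subgroup_ids,
                             config.data.subgroup_proportions)
dataloader = data.apply_batch(dataloader, config.data.batch_size)  # batch once, at the end
```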
