CelebA dataset
- Implements the dataset generation for CelebA to work with the new active sampling framework.

PiperOrigin-RevId: 476085835
Uncertainty Baselines Team authored and copybara-github committed Sep 23, 2022
1 parent 2f9f4e9 commit 332eb78
Showing 2 changed files with 244 additions and 19 deletions.
3 changes: 3 additions & 0 deletions experimental/shoshin/configs/celeb_a_resnet_config.py
@@ -23,6 +23,9 @@ def get_config() -> ml_collections.ConfigDict:
"""Get mlp config."""
config = base_config.get_config()

config.data.subgroup_ids = ('Blond_Hair',) # ('Blond_Hair')
config.data.subgroup_proportions = (0.01,) # (0.04, 0.012)

data = config.data
data.name = 'celeb_a'
data.num_classes = 2
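Note: these two new tuples are forwarded as `builder_kwargs` to the CelebA builder added in data.py below. A minimal sketch (not part of the commit; names and the 300000 cap are copied from `_generate_examples` further down) of how a proportion becomes an absolute sample size:

dataset_size = 300000  # training-set cap hard-coded in LocalCelebADataset
subgroup_ids = ('Blond_Hair',)
subgroup_proportions = (0.01,)

for subgroup_id, proportion in zip(subgroup_ids, subgroup_proportions):
  subgroup_sample_size = int(dataset_size * proportion)
  print(subgroup_id, subgroup_sample_size)  # Blond_Hair 3000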
260 changes: 241 additions & 19 deletions experimental/shoshin/data.py
Expand Up @@ -27,6 +27,7 @@
import os
from typing import Any, Dict, Iterator, Optional, Tuple, List, Union

import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

@@ -470,7 +471,7 @@ def get_waterbirds_dataset(
to their respective combined datasets.
"""
split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
-reduced_datset_sz = int(100 * initial_sample_proportion)
reduced_dataset_sz = int(100 * initial_sample_proportion)
builder_kwargs = {
'subgroup_ids': subgroup_ids,
'subgroup_proportions': subgroup_proportions
@@ -479,7 +480,7 @@
'waterbirds_dataset',
split=[
f'validation[{k}%:{k+split_size_in_pct}%]'
-for k in range(0, reduced_datset_sz, split_size_in_pct)
for k in range(0, reduced_dataset_sz, split_size_in_pct)
],
data_dir=DATA_DIR,
builder_kwargs=builder_kwargs,
@@ -490,7 +491,7 @@
'waterbirds_dataset',
split=[
f'train[{k}%:{k+split_size_in_pct}%]'
-for k in range(0, reduced_datset_sz, split_size_in_pct)
for k in range(0, reduced_dataset_sz, split_size_in_pct)
],
data_dir=DATA_DIR,
builder_kwargs=builder_kwargs,
@@ -528,8 +529,233 @@ def get_waterbirds_dataset(
train_sample_ds=train_sample,
eval_ds=eval_datasets)

IMG_ALIGNED_DATA = ('https://drive.google.com/uc?export=download&'
'id=0B7EVK8r0v71pZjFTYXZWM3FlRnM')
EVAL_LIST = ('https://drive.google.com/uc?export=download&'
'id=0B7EVK8r0v71pY0NSMzRuSXJEVkk')
# Landmark coordinates: left_eye, right_eye etc.
LANDMARKS_DATA = ('https://drive.google.com/uc?export=download&'
'id=0B7EVK8r0v71pd0FJY3Blby1HUTQ')

# Attributes in the image (Eyeglasses, Mustache etc).
ATTR_DATA = ('https://drive.google.com/uc?export=download&'
'id=0B7EVK8r0v71pblRyaVFSWGxPY0U')

LANDMARK_HEADINGS = ('lefteye_x lefteye_y righteye_x righteye_y '
'nose_x nose_y leftmouth_x leftmouth_y rightmouth_x '
'rightmouth_y').split()
ATTR_HEADINGS = (
'5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs '
'Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair '
'Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair '
'Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache '
'Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline '
'Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings '
'Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young'
).split()

_CITATION = """\
@inproceedings{conf/iccv/LiuLWT15,
added-at = {2018-10-09T00:00:00.000+0200},
author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
biburl = {https://www.bibsonomy.org/bibtex/250e4959be61db325d2f02c1d8cd7bfbb/dblp},
booktitle = {ICCV},
crossref = {conf/iccv/2015},
ee = {http://doi.ieeecomputersociety.org/10.1109/ICCV.2015.425},
interhash = {3f735aaa11957e73914bbe2ca9d5e702},
intrahash = {50e4959be61db325d2f02c1d8cd7bfbb},
isbn = {978-1-4673-8391-2},
keywords = {dblp},
pages = {3730-3738},
publisher = {IEEE Computer Society},
timestamp = {2018-10-11T11:43:28.000+0200},
title = {Deep Learning Face Attributes in the Wild.},
url = {http://dblp.uni-trier.de/db/conf/iccv/iccv2015.html#LiuLWT15},
year = 2015
}
"""

_DESCRIPTION = """\
CelebFaces Attributes Dataset (CelebA) is a large-scale face attributes dataset\
with more than 200K celebrity images, each with 40 attribute annotations. The \
images in this dataset cover large pose variations and background clutter. \
CelebA has large diversities, large quantities, and rich annotations, including\
- 10,177 number of identities,
- 202,599 number of face images, and
- 5 landmark locations, 40 binary attributes annotations per image.
The dataset can be employed as the training and test sets for the following \
computer vision tasks: face attribute recognition, face detection, and landmark\
(or facial part) localization.
Note: CelebA dataset may contain potential bias. The fairness indicators
[example](https://www.tensorflow.org/responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_TFCO_CelebA_Case_Study)
goes into detail about several considerations to keep in mind while using the
CelebA dataset.
"""


class LocalCelebADataset(tfds.core.GeneratorBasedBuilder):
"""CelebA dataset. Aligned and cropped. With metadata."""

VERSION = tfds.core.Version('2.0.1')
SUPPORTED_VERSIONS = [
tfds.core.Version('2.0.0'),
]
RELEASE_NOTES = {
'2.0.1': 'New split API (https://tensorflow.org/datasets/splits)',
}

def __init__(self,
subgroup_ids: List[str],
subgroup_proportions: Optional[List[float]] = None,
label_attr: Optional[str] = 'Male',
**kwargs):
super(LocalCelebADataset, self).__init__(**kwargs)
self.subgroup_ids = subgroup_ids
self.label_attr = label_attr
if subgroup_proportions:
self.subgroup_proportions = subgroup_proportions
else:
self.subgroup_proportions = [1.] * len(subgroup_ids)

def _info(self):
return tfds.core.DatasetInfo(
builder=self,
features=tfds.features.FeaturesDict({
'example_id':
tfds.features.Text(),
'subgroup_id':
tfds.features.Text(),
'subgroup_label':
tfds.features.ClassLabel(num_classes=2),
'feature':
tfds.features.Image(
shape=(218, 178, 3), encoding_format='jpeg'),
'label':
tfds.features.ClassLabel(num_classes=2),
'image_filename':
tfds.features.Text(),
}),
supervised_keys=('feature', 'label', 'example_id'),
)

def _split_generators(self, dl_manager):
downloaded_dirs = dl_manager.download({
'img_align_celeba': IMG_ALIGNED_DATA,
'list_eval_partition': EVAL_LIST,
'list_attr_celeba': ATTR_DATA,
'landmarks_celeba': LANDMARKS_DATA,
})

# Load all images in memory (~1 GiB)
# Use split to convert: `img_align_celeba/000005.jpg` -> `000005.jpg`
all_images = {
os.path.split(k)[-1]: img for k, img in dl_manager.iter_archive(
downloaded_dirs['img_align_celeba'])
}
return [
tfds.core.SplitGenerator(
name=tfds.Split.TRAIN,
gen_kwargs={
'file_id': 0,
'downloaded_dirs': downloaded_dirs,
'downloaded_images': all_images,
'is_training': True,
}),
tfds.core.SplitGenerator(
name=tfds.Split.VALIDATION,
gen_kwargs={
'file_id': 1,
'downloaded_dirs': downloaded_dirs,
'downloaded_images': all_images,
'is_training': False,
}),
tfds.core.SplitGenerator(
name=tfds.Split.TEST,
gen_kwargs={
'file_id': 2,
'downloaded_dirs': downloaded_dirs,
'downloaded_images': all_images,
'is_training': False,
})
]
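Note: TFDS registers a `GeneratorBasedBuilder` under the snake_case form of its class name, which is why the loader further down requests 'local_celeb_a_dataset'. A small sketch of that conversion (mirroring TFDS's `camelcase_to_snakecase` helper; this regex version is an approximation, not the commit's code):

import re

def camel_to_snake(name: str) -> str:
  s = re.sub(r'(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s).lower()

print(camel_to_snake('LocalCelebADataset'))  # -> local_celeb_a_dataset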

def _process_celeba_config_file(self, file_path):
"""Unpack the celeba config file.
The file starts with the number of lines, and a header.
Afterwards, there is a configuration for each file: one per line.
Args:
file_path: Path to the file with the configuration.
Returns:
keys: names of the attributes
values: map from the file name to the list of attribute values for
this file.
"""

with tf.io.gfile.GFile(file_path) as f:
data_raw = f.read()
lines = data_raw.split('\n')

keys = lines[1].strip().split()
values = {}
# Go over each line (skip the last one, as it is empty).
for line in lines[2:-1]:
row_values = line.strip().split()
# Each row starts with the 'file_name' and then space-separated values.
values[row_values[0]] = [int(v) for v in row_values[1:]]
return keys, values
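Note: the attributes file (`list_attr_celeba.txt`) begins with the example count, then a header row of attribute names, then one +1/-1 row per image. A self-contained toy run of the same parsing logic (the two-attribute file below is made up for illustration):

toy_attr_file = ('2\n'
                 'Blond_Hair Male\n'
                 '000001.jpg -1 1\n'
                 '000002.jpg 1 -1\n')
lines = toy_attr_file.split('\n')
keys = lines[1].strip().split()  # ['Blond_Hair', 'Male']
values = {}
for line in lines[2:-1]:  # skip the count, the header, and the trailing ''
  row_values = line.strip().split()
  values[row_values[0]] = [int(v) for v in row_values[1:]]
# values == {'000001.jpg': [-1, 1], '000002.jpg': [1, -1]}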

def _generate_examples(self, file_id, downloaded_dirs, downloaded_images,
is_training):
"""Yields examples."""

attr_path = downloaded_dirs['list_attr_celeba']

attributes = self._process_celeba_config_file(attr_path)
dataset = pd.DataFrame.from_dict(
attributes[1], orient='index', columns=attributes[0])

if is_training:
dataset_size = 300000
sampled_datasets = []
remaining_proportion = 1.
remaining_dataset = dataset.copy()
for idx, subgroup_id in enumerate(self.subgroup_ids):

subgroup_dataset = dataset[dataset[subgroup_id] == 1]
subgroup_sample_size = int(dataset_size *
self.subgroup_proportions[idx])
subgroup_dataset = subgroup_dataset.sample(min(len(subgroup_dataset),
subgroup_sample_size))
sampled_datasets.append(subgroup_dataset)
remaining_proportion -= self.subgroup_proportions[idx]
remaining_dataset = remaining_dataset[remaining_dataset[subgroup_id] ==
-1]

remaining_sample_size = int(dataset_size * remaining_proportion)
remaining_dataset = remaining_dataset.sample(min(len(remaining_dataset),
remaining_sample_size))
sampled_datasets.append(remaining_dataset)

dataset = pd.concat(sampled_datasets)
dataset = dataset.sample(min(len(dataset), dataset_size))
for file_name in dataset.index:
subgroup_id = self.subgroup_ids[0] if dataset.loc[file_name][
self.subgroup_ids[0]] == 1 else 'Not_' + self.subgroup_ids[0]
subgroup_label = 1 if subgroup_id in self.subgroup_ids else 0
label = 1 if dataset.loc[file_name][self.label_attr] == 1 else 0
record = {
'example_id': file_name,
'subgroup_id': subgroup_id,
'subgroup_label': subgroup_label,
'feature': downloaded_images[file_name],
'label': label,
'image_filename': file_name
}
yield file_name, record
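Note: the training branch above downsamples each listed subgroup to `proportion * dataset_size` rows and fills the remaining proportion from the complement (attribute value -1). A minimal, self-contained sketch of the same scheme on a toy DataFrame (sizes are made up):

import pandas as pd

dataset = pd.DataFrame(
    {'Blond_Hair': [1, -1, -1, 1, -1, -1, -1, -1]},
    index=[f'{i:06d}.jpg' for i in range(1, 9)])
dataset_size = 4
subgroup_ids, subgroup_proportions = ['Blond_Hair'], [0.25]

sampled_datasets = []
remaining_proportion = 1.
remaining_dataset = dataset.copy()
for idx, subgroup_id in enumerate(subgroup_ids):
  subgroup_dataset = dataset[dataset[subgroup_id] == 1]
  subgroup_sample_size = int(dataset_size * subgroup_proportions[idx])
  sampled_datasets.append(
      subgroup_dataset.sample(min(len(subgroup_dataset),
                                  subgroup_sample_size)))
  remaining_proportion -= subgroup_proportions[idx]
  remaining_dataset = remaining_dataset[remaining_dataset[subgroup_id] == -1]

remaining_sample_size = int(dataset_size * remaining_proportion)
sampled_datasets.append(remaining_dataset.sample(
    min(len(remaining_dataset), remaining_sample_size)))
print(pd.concat(sampled_datasets))  # 1 Blond_Hair row + 3 complement rows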


-@register_dataset('celeb_a')
@register_dataset('local_celeb_a')
def get_celeba_dataset(
num_splits: int, initial_sample_proportion: float,
subgroup_ids: List[str], subgroup_proportions: List[float],
@@ -549,47 +775,44 @@ def get_celeba_dataset(
combined training dataset, and a dictionary mapping evaluation dataset names
to their respective combined datasets.
"""
-del subgroup_proportions, subgroup_ids
read_config = tfds.ReadConfig()
read_config.add_tfds_id = True # Set `True` to return the 'tfds_id' key

split_size_in_pct = int(100 * initial_sample_proportion / num_splits)
reduced_dataset_sz = int(100 * initial_sample_proportion)
builder_kwargs = {
'subgroup_ids': subgroup_ids,
'subgroup_proportions': subgroup_proportions
}
train_splits = tfds.load(
-'celeb_a',
'local_celeb_a_dataset',
read_config=read_config,
split=[
f'train[:{k}%]+train[{k+split_size_in_pct}%:]'
for k in range(0, reduced_dataset_sz, split_size_in_pct)
],
builder_kwargs=builder_kwargs,
data_dir=DATA_DIR,
try_gcs=False,
as_supervised=True
)
val_splits = tfds.load(
-'celeb_a',
'local_celeb_a_dataset',
read_config=read_config,
split=[
f'validation[{k}%:{k+split_size_in_pct}%]'
for k in range(0, reduced_dataset_sz, split_size_in_pct)
],
builder_kwargs=builder_kwargs,
data_dir=DATA_DIR,
try_gcs=False,
as_supervised=True
)
-train_sample = tfds.load(
-'celeb_a',
-split='train_sample',
-data_dir=DATA_DIR,
-try_gcs=False,
-as_supervised=True,
-with_info=False)

test_ds = tfds.load(
-'celeb_a',
'local_celeb_a_dataset',
split='test',
builder_kwargs=builder_kwargs,
data_dir=DATA_DIR,
try_gcs=False,
as_supervised=True,
with_info=False)

train_ds = gather_data_splits(list(range(num_splits)), train_splits)
@@ -602,5 +825,4 @@ def get_celeba_dataset(
train_splits,
val_splits,
train_ds,
-train_sample_ds=train_sample,
eval_ds=eval_datasets)
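Note: the percent-slice split strings build leave-one-slice-out training sets: each train split is everything except one `split_size_in_pct`-wide slice, and each validation split is that slice itself. A worked example with assumed arguments (num_splits=5, initial_sample_proportion=0.5; not values from the commit):

num_splits, initial_sample_proportion = 5, 0.5
split_size_in_pct = int(100 * initial_sample_proportion / num_splits)  # 10
reduced_dataset_sz = int(100 * initial_sample_proportion)  # 50

train = [f'train[:{k}%]+train[{k + split_size_in_pct}%:]'
         for k in range(0, reduced_dataset_sz, split_size_in_pct)]
val = [f'validation[{k}%:{k + split_size_in_pct}%]'
       for k in range(0, reduced_dataset_sz, split_size_in_pct)]
# train[0] == 'train[:0%]+train[10%:]', val[0] == 'validation[0%:10%]'
# train[4] == 'train[:40%]+train[50%:]', val[4] == 'validation[40%:50%]'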
