Skip to content

Commit

Permalink
No need for a validation split, if eval_holdout_size has been specified.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 375127276
  • Loading branch information
afrozenator authored and copybara-github committed May 21, 2021
1 parent 86dc892 commit 675eae1
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 4 deletions.
6 changes: 3 additions & 3 deletions trax/data/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ def Parallel( # pylint: disable=invalid-name
# Remove generators with zero counters
counters = list(counters)
fns = list(fns)
zeros = [j for j in range(len(counters)) if counters[j] != 0]
counters = [counters[j] for j in zeros]
fns = [fns[j] for j in zeros]
non_zeros = [j for j in range(len(counters)) if counters[j] != 0]
counters = [counters[j] for j in non_zeros]
fns = [fns[j] for j in non_zeros]
else:
counters = [1] * len(fns)

Expand Down
27 changes: 27 additions & 0 deletions trax/data/testdata/para_crawl/ende/1.2.0/dataset_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"citation": "@misc {paracrawl,\n title = \"ParaCrawl\",\n year = \"2018\",\n url = \"http://paracrawl.eu/download.html.\"\n}",
"configDescription": "Translation dataset from English to de.",
"configName": "ende",
"description": "Web-Scale Parallel Corpora for Official European Languages.",
"downloadSize": "1307754745",
"location": {
"urls": [
"https://paracrawl.eu/releases.html"
]
},
"name": "para_crawl",
"splits": [
{
"name": "train",
"numBytes": "3241",
"shardLengths": [
"10"
]
}
],
"supervisedKeys": {
"input": "en",
"output": "de"
},
"version": "1.2.0"
}
9 changes: 9 additions & 0 deletions trax/data/testdata/para_crawl/ende/1.2.0/features.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"type": "tensorflow_datasets.core.features.translation_feature.Translation",
"content": {
"languages": [
"de",
"en"
]
}
}
Binary file not shown.
3 changes: 2 additions & 1 deletion trax/data/tf_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def _train_and_eval_dataset(dataset_name,
if dataset_name != 'c4/multilingual' and tfds.Split.TRAIN not in splits:
raise ValueError('To train we require a train split in the dataset.')
train_split = tfds.Split.TRAIN if dataset_name != 'c4/multilingual' else 'en'
eval_split = None
train_examples = info.splits[train_split].num_examples
eval_holdout_examples = int(train_examples * eval_holdout_size)
if eval_holdout_examples > 0 or subsplit is not None:
Expand All @@ -248,7 +249,7 @@ def _train_and_eval_dataset(dataset_name,
'validation_mismatched' if use_alt_eval else 'validation_matched')
elif dataset_name == 'c4/multilingual':
eval_split = 'en-validation'
else:
elif eval_split is None:
if tfds.Split.VALIDATION not in splits and 'test' not in splits:
raise ValueError('We require a validation or test split in the dataset.')
eval_split = tfds.Split.VALIDATION
Expand Down
31 changes: 31 additions & 0 deletions trax/data/tf_inputs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,37 @@ def test_TFDS_single_host_with_eval_holdout(self):
print(f'Eval: {d}')
break

def test_TFDS_single_host_with_eval_holdout_no_valid_split(self):
train_ds_gen = tf_inputs.TFDS(
'para_crawl/ende',
data_dir=_TESTDATA,
train=True,
host_id=0,
keys=('en', 'de'),
n_hosts=1,
eval_holdout_size=0.1)

# Just ensure that this doesn't crash.
for d in train_ds_gen():
print(f'Train: {d}')
break

# para_crawl doesn't have a validation set, see that this still doesn't
# crash because of eval_holdout_set.
valid_ds_gen = tf_inputs.TFDS(
'para_crawl/ende',
data_dir=_TESTDATA,
train=False,
host_id=0,
keys=('en', 'de'),
n_hosts=1,
eval_holdout_size=0.1)

# Just ensure that this doesn't crash.
for d in valid_ds_gen():
print(f'Eval: {d}')
break

def test_TFDS_mnli_split_is_eval(self):
with mock.patch('tensorflow_datasets.load') as tfds_load:
with mock.patch('trax.data.tf_inputs.download_and_prepare',
Expand Down

0 comments on commit 675eae1

Please sign in to comment.