diff --git a/trax/data/__init__.py b/trax/data/__init__.py
index 0289325bc..0130dc385 100644
--- a/trax/data/__init__.py
+++ b/trax/data/__init__.py
@@ -52,6 +52,7 @@
 from trax.data.inputs import generate_sequential_chunks
 from trax.data.inputs import Log
 from trax.data.inputs import MLM
+from trax.data.inputs import Pad
 from trax.data.inputs import PadToLength
 from trax.data.inputs import Parallel
 from trax.data.inputs import Prefetch
diff --git a/trax/data/inputs.py b/trax/data/inputs.py
index 1f92b272b..8204395fe 100644
--- a/trax/data/inputs.py
+++ b/trax/data/inputs.py
@@ -1589,3 +1589,75 @@ def _pad_to_multiple_of(x, y, axis):
   pad_widths[axis] = (0, int(pad_len - x.shape[axis]))
   return np.pad(x, pad_widths, mode='constant',
                 constant_values=x.dtype.type(0))
+
+
+def _pad_or_truncate_1d(array, length, padding, value):
+  """Returns 1-D `array` padded or truncated to exactly `length` items."""
+  array = np.asarray(array)
+  missing = max(length - array.shape[0], 0)
+  if padding == 'pre':
+    # Keep the trailing `length` items and put the padding in front of them.
+    return np.pad(array[-length:], (missing, 0), mode='constant',
+                  constant_values=value)
+  # 'post': keep the leading `length` items and put the padding after them.
+  return np.pad(array[:length], (0, missing), mode='constant',
+                constant_values=value)
+
+
+@gin.configurable(module='trax.data')
+def Pad(len_map=None, padding='pre', value=0):  # pylint: disable=invalid-name
+  """Pads or truncates 1-D sequences to the fixed length `len_map`.
+
+  Args:
+    len_map: positive integer. Length of every sequence that is yielded.
+    padding: string, 'pre' or 'post'. Default is 'pre'. Whether `value` is
+      added before ('pre') or after ('post') each sequence. Sequences longer
+      than `len_map` are truncated on the same side: 'pre' keeps the last
+      `len_map` items and 'post' keeps the first `len_map` items.
+    value: number. Default is zero. The value used for padding.
+
+  Returns:
+    A function that maps a generator of examples to a generator of padded
+    examples. An example is either a single 1-D ndarray or a list/tuple of
+    1-D arrays; lists and tuples are yielded as a new container of the same
+    kind, so the incoming example is never modified in place.
+
+  Raises:
+    ValueError: if `len_map` is missing, not an integer or not positive, or
+      if `padding` is neither 'pre' nor 'post'. The returned function raises
+      ValueError for examples that are not made of 1-D arrays.
+  """
+  # The arguments never change between examples, so validate them once here
+  # instead of on every iteration of the pipeline.
+  if len_map is None:
+    raise ValueError('len_map parameter should be provided.')
+  if not isinstance(len_map, int):
+    raise ValueError('len_map should be of type integer.')
+  if len_map <= 0:
+    raise ValueError('len_map should be greater than zero.')
+  if padding not in ('pre', 'post'):
+    raise ValueError(
+        'padding parameter should be equal to \'pre\' or \'post\'.')
+
+  @debug_data_pipeline.debug_pipeline
+  def _pad(generator):
+    for example in generator:
+      if isinstance(example, (list, tuple)):
+        for element in example:
+          if np.ndim(element) != 1:
+            raise ValueError(
+                f'example isn\'t a collection (list or tuple) of ndarray with'
+                f' dimension equal to one, but should be: {example}')
+        padded = [_pad_or_truncate_1d(element, len_map, padding, value)
+                  for element in example]
+        output = tuple(padded) if isinstance(example, tuple) else padded
+      elif isinstance(example, np.ndarray) and example.ndim == 1:
+        output = _pad_or_truncate_1d(example, len_map, padding, value)
+      else:
+        raise ValueError(
+            f'example isn\'t a collection (list or tuple) of ndarray '
+            f'or a single ndarray, with dimension equal to one, '
+            f'but should be: {example}')
+      yield output
+
+  return _pad