Skip to content

Commit

Permalink
align remove_columns in the formatted case
Browse files (browse the repository at this point in the history)
  • Loading branch information
lhoestq committed Jan 6, 2025
1 parent 6457be6 commit a13ccf2
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
11 changes: 4 additions & 7 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3336,13 +3336,11 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example
if with_rank:
additional_args += (rank,)
processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
returned_same_object = processed_inputs is inputs
if isinstance(processed_inputs, LazyDict):
processed_inputs = {
k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
}
returned_lazy_dict = True
else:
returned_lazy_dict = False
if update_data is None:
# Check if the function returns updated examples
updatable_types = (Mapping, pa.Table, pd.DataFrame)
Expand All @@ -3366,10 +3364,9 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_example
if remove_columns is not None:
for column in remove_columns:
# `function` can modify input in-place causing column to be already removed.
if column in inputs_to_merge:
inputs_to_merge.pop(column)
if returned_lazy_dict and column in processed_inputs:
processed_inputs.pop(column)
inputs_to_merge.pop(column, None)
if returned_same_object:
processed_inputs.pop(column, None)
if check_same_num_examples:
input_num_examples = len(pa_inputs)
processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))])
Expand Down
3 changes: 1 addition & 2 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4356,13 +4356,12 @@ def f(x):
outputs = ds[:]
assert outputs == {"b": [-1, -1, 2, 3]}

# The formatted dataset version removes the lazy column from a different dictionary, hence it should be preserved in the output
ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
ds = ds.with_format("numpy")
ds = ds.map(f, remove_columns=["a"])
ds = ds.with_format(None)
outputs = ds[:]
assert outputs == {"a": [0, 1, 2, 3], "b": [-1, -1, 2, 3]}
assert outputs == {"b": [-1, -1, 2, 3]}

def f(x):
"""May return a mix of LazyDict and regular Dict, but we replace a lazy column"""
Expand Down

0 comments on commit a13ccf2

Please sign in to comment.