Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensor shapes for elementwise binary operations with broadcasting #738

Closed
wants to merge 12 commits into from
33 changes: 30 additions & 3 deletions python/flexflow/keras/layers/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be added")
self.output_shape = tuple(self.output_shape)
fflogger.debug("add output %s" %( str(self.output_shape)))

def subtract(input_tensors):
Expand All @@ -114,7 +123,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be subtracted")
self.output_shape = tuple(self.output_shape)
fflogger.debug("subtract output %s" %( str(self.output_shape)))

def multiply(input_tensors):
Expand All @@ -127,7 +145,16 @@ def __init__(self, **kwargs):
def _calculate_inout_shape(self, input_tensors):
assert len(input_tensors) == 2, "check input_tensors"
self.input_shape = input_tensors[0].batch_shape
self.output_shape = input_tensors[0].batch_shape
self.output_shape = list(input_tensors[0].batch_shape)
for i, d in enumerate(input_tensors[1].batch_shape):
if self.output_shape[i] != d:
if self.output_shape[i] == 1 or d == 1:
self.output_shape[i] *= d
else:
raise AssertionError(
f"Tensor with shape {input_tensors[0].batch_shape} and "
f"{input_tensors[1].batch_shape} cannot be multiplied")
self.output_shape = tuple(self.output_shape)
fflogger.debug("multiply output %s" %( str(self.output_shape)))

class Maximum(_Merge):
Expand Down
15 changes: 14 additions & 1 deletion src/ops/element_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,21 @@ Tensor FFModel::binary(OperatorType op,
}
// Assert type match after broadcast
assert(ele->inputs[0]->data_type == ele->inputs[1]->data_type);

int numdim = in1->num_dims;
int dims[MAX_TENSOR_DIM];
for (int i = 0; i < numdim; i++) {
if (in1->dims[i] == 1) {
dims[i] = in2->dims[i];
} else if (in2->dims[i] == 1) {
dims[i] = in1->dims[i];
} else {
dims[i] = in1->dims[i];
}
}

ele->outputs[0] = create_tensor_legion_ordering(
in1->num_dims, in1->dims, ele->data_type, ele, 0, true /*create_grad*/);
in1->num_dims, dims, ele->data_type, ele, 0, true /*create_grad*/);
ele->add_int_property("inplace_a", inplace_a);
layers.push_back(ele);
return ele->outputs[0];
Expand Down
Loading