Skip to content

Commit

Permalink
Fix tensor shapes for elementwise binary operations with broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
Soumya Chatterjee committed May 26, 2023
1 parent ef3b76e commit 71922e4
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/ops/element_binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,21 @@ Tensor FFModel::binary(OperatorType op,
}
// Assert type match after broadcast
assert(ele->inputs[0]->data_type == ele->inputs[1]->data_type);

int numdim = in1->num_dims;
int dims[MAX_TENSOR_DIM];
for (int i = 0; i < numdim; i++) {
if (in1->dims[i] == 1) {
dims[i] = in2->dims[i];
} else if (in2->dims[i] == 1) {
dims[i] = in1->dims[i];
} else {
dims[i] = in1->dims[i];
}
}

ele->outputs[0] = create_tensor_legion_ordering(
in1->num_dims, in1->dims, ele->data_type, ele, 0, true /*create_grad*/);
in1->num_dims, dims, ele->data_type, ele, 0, true /*create_grad*/);
ele->add_int_property("inplace_a", inplace_a);
layers.push_back(ele);
return ele->outputs[0];
Expand Down

0 comments on commit 71922e4

Please sign in to comment.