Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Universal functions #29

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Universal functions #29

wants to merge 1 commit into from

Conversation

seanmcl
Copy link
Collaborator

@seanmcl seanmcl commented Jan 21, 2025

NumPy uses an abstraction of functions called Ufunc. It has a lot of widgets in NumPy, but for now we just support lifting element-wise operations and broadcasting.

binop a x y (fun x y => .ok (x * y))

def div (a : Type) [Div a] [Element a] (x y : Tensor) : Err Tensor :=
binop a x y (fun x y => .ok (x / y))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this check for divide by zero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numpy just gives inf or nan in this case. E.g.

#np.full(5, 1, dtype='int8') / np.arange(5)
<ipython-input-26-4ea893b50512>:1: RuntimeWarning: divide by zero encountered in divide
  np.full(5, 1, dtype='int8') / np.arange(5)
array([       inf, 1.        , 0.5       , 0.33333333, 0.25      ])

Seems like the correct behavior to me. E.g. if we fail on nans we will have quite a different behavior than the current numpy and neuron compiler.

return Element.arrayScalar acc

-- Sum with a single axis.
def sum1 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axis : Nat) : Err Tensor := do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This came out pretty nice. Seems like the design is holding up well.


private def uniq [BEq a] (xs : List a) : Bool := match xs with
| [] | [_] => true
| x1 :: x2 :: xs => x1 != x2 && uniq (x2 :: xs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't understand this function.

uniq [1,2,1] == True ?

Maybe, the input is sorted?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed.

| .none => sum0 a arr
| .some axes =>
if !(uniq axes) then .error "Duplicate axis elements" else
let axes := (List.mergeSort axes).reverse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah here is the sort, but maybe in the wrong place?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, thanks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants