-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Universal functions #29
base: main
Are you sure you want to change the base?
Conversation
binop a x y (fun x y => .ok (x * y)) | ||
|
||
def div (a : Type) [Div a] [Element a] (x y : Tensor) : Err Tensor := | ||
binop a x y (fun x y => .ok (x / y)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this check for divide by zero?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numpy just gives inf or nan in this case. E.g.
#np.full(5, 1, dtype='int8') / np.arange(5)
<ipython-input-26-4ea893b50512>:1: RuntimeWarning: divide by zero encountered in divide
np.full(5, 1, dtype='int8') / np.arange(5)
array([ inf, 1. , 0.5 , 0.33333333, 0.25 ])
Seems like the correct behavior to me. E.g. if we fail on nans we will have quite a different behavior than the current numpy and neuron compiler.
TensorLib/Ufunc.lean
Outdated
return Element.arrayScalar acc | ||
|
||
-- Sum with a single axis. | ||
def sum1 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axis : Nat) : Err Tensor := do |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This came out pretty nice. Seems like the design is holding up well.
|
||
private def uniq [BEq a] (xs : List a) : Bool := match xs with | ||
| [] | [_] => true | ||
| x1 :: x2 :: xs => x1 != x2 && uniq (x2 :: xs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I don't understand this function.
uniq [1,2,1] == True ?
Maybe, the input is sorted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, fixed.
| .none => sum0 a arr | ||
| .some axes => | ||
if !(uniq axes) then .error "Duplicate axis elements" else | ||
let axes := (List.mergeSort axes).reverse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah here is the sort, but maybe in the wrong place?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test.
NumPy uses an abstraction of functions called Ufunc. It has a lot of widgets in NumPy, but for now we just support lifting element-wise operations and broadcasting.