Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Universal functions #29

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TensorLib/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ import TensorLib.Slice
import TensorLib.Tensor
import TensorLib.Index
import TensorLib.Mgrid
import TensorLib.Ufunc
187 changes: 187 additions & 0 deletions TensorLib/Ufunc.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jean-Baptiste Tristan, Paul Govereau, Sean McLaughlin
-/

import TensorLib.Common
import TensorLib.Tensor
import TensorLib.Broadcast

/-!
Universal functions: https://numpy.org/doc/stable/reference/ufuncs.html
-/

namespace TensorLib
namespace Tensor
namespace Ufunc

private def binop (a : Type) [Element a] (x y : Tensor) (op : a -> a -> Err a) : Err Tensor :=
match Broadcast.broadcast { left := x.shape, right := y.shape } with
| .none => .error s!"Can't broadcast shapes ${x.shape} with {y.shape}"
| .some shape =>
if x.dtype != y.dtype then .error s!"Casting between dtypes is not implemented yet: {repr x.dtype} <> {repr y.dtype}" else
do
let mut arr := Tensor.empty x.dtype shape
let iter := DimsIter.make shape
for idx in iter do
let v <- Element.getDimIndex x idx
let w <- Element.getDimIndex y idx
let k <- op v w
let arr' <- Element.setDimIndex arr idx k
arr := arr'
.ok arr

def add (a : Type) [Add a] [Element a] (x y : Tensor) : Err Tensor :=
binop a x y (fun x y => .ok (x + y))

def sub (a : Type) [Sub a] [Element a] (x y : Tensor) : Err Tensor :=
binop a x y (fun x y => .ok (x - y))

def mul (a : Type) [Mul a] [Element a] (x y : Tensor) : Err Tensor :=
binop a x y (fun x y => .ok (x * y))

def div (a : Type) [Div a] [Element a] (x y : Tensor) : Err Tensor :=
binop a x y (fun x y => .ok (x / y))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this check for divide by zero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numpy just gives inf or nan in this case. E.g.

#np.full(5, 1, dtype='int8') / np.arange(5)
<ipython-input-26-4ea893b50512>:1: RuntimeWarning: divide by zero encountered in divide
  np.full(5, 1, dtype='int8') / np.arange(5)
array([       inf, 1.        , 0.5       , 0.33333333, 0.25      ])

Seems like the correct behavior to me. E.g. if we fail on nans we will have quite a different behavior than the current numpy and neuron compiler.


/-
TODO:
- np.sum. Prove that np.sum(x, axis=(2, 4, 6)) == np.sum(np.sum(np.sum(x, axis=6), axis=4), axis=2) # and other variations
-/

-- Sum with no axis. Adds all the elements.
private def sum0 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) : Err Tensor := do
let mut acc : a := 0
let mut iter := DimsIter.make arr.shape
for index in iter do
let n : a <- Element.getDimIndex arr index
acc := Add.add acc n
return Element.arrayScalar acc

-- Sum with a single axis.
private def sum1 (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axis : Nat) : Err Tensor := do
if arr.ndim <= axis then .error "axis out of range" else
let oldshape := arr.shape
let (leftShape, rightShape) := oldshape.val.splitAt axis
match rightShape with
| [] => .error "Invariant failure"
| dim :: dims =>
let newshape := Shape.mk $ leftShape ++ dims
let mut res := Tensor.zeros arr.dtype newshape
for index in DimsIter.make newshape do
let mut acc : a := 0
for i in [0:dim] do
let index' := index.insertIdx axis i
let v : a <- Element.getDimIndex arr index'
acc := acc + v
res <- Element.setDimIndex res index acc
return res

-- Remove duplicate elements in a sorted list
private def uniq [BEq a] (xs : List a) : Bool := match xs with
| [] | [_] => true
| x1 :: x2 :: xs => x1 != x2 && uniq (x2 :: xs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't understand this function.

uniq [1,2,1] == True ?

Maybe, the input is sorted?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, fixed.


def sum (a : Type) [Add a] [Zero a] [Element a] (arr : Tensor) (axes : Option (List Nat)) : Err Tensor :=
match axes with
| .none => sum0 a arr
| .some axes =>
let axes := (List.mergeSort axes).reverse
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah here is the sort, but maybe in the wrong place?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, thanks

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test.

if !(uniq axes) then .error "Duplicate axis elements" else
match axes with
| [] => sum0 a arr
| axis :: axes => do
let mut res <- sum1 a arr axis
let rec loop (axes : List Nat) (acc : Tensor) : Err Tensor := match axes with
| [] => .ok acc
| axis :: axes => do
let acc <- sum1 a acc axis
let axes := axes.map fun n => n-1 -- When we remove an axis, all later axes point to one dimension less
loop axes acc
termination_by axes.length
loop axes res

private def hasTree0 (a : Type) [BEq a] [Element a] (arr : Tensor) (n : a) : Bool :=
arr.shape.val == [] && match Element.getPosition arr 0 with
| .error _ => false
| .ok (v : a) => v == n

private def hasTree1 (a : Type) [Repr a] [BEq a] [Element a] (arr : Tensor) (xs : List a) : Bool :=
arr.shape.val == [xs.length] && match arr.toTree a with
| .error _ => false
| .ok v => v == .root xs

#guard
let typ := BV8
let arr := get! $ (Element.arange typ 10).reshape (Shape.mk [2, 5])
!(sum typ arr (.some [0, 1, 0])).isOk &&
!(sum typ arr (.some [0, 0, 1])).isOk &&
!(sum typ arr (.some [7])).isOk

-- [[0, 1, 2, 3, 4],
-- [5, 6, 7, 8, 9]]
#guard
let typ := BV8
let arr := get! $ (Element.arange typ 10).reshape (Shape.mk [2, 5])
let x0 := get! $ sum typ arr .none
let x1 := get! $ sum typ arr (.some [])
let x2 := get! $ sum typ arr (.some [0])
let x3 := get! $ sum typ arr (.some [1])
let x4 := get! $ sum typ arr (.some [1, 0])
let x5 := get! $ sum typ arr (.some [0, 1])
let res :=
hasTree0 typ x0 45 &&
hasTree0 typ x1 45 &&
hasTree1 typ x2 [5, 7, 9, 11, 13] &&
hasTree1 typ x3 [10, 35] &&
hasTree0 typ x4 45 &&
hasTree0 typ x5 45
res

#guard
let typ := BV8
let x := Element.arange typ 10
let arr := get! $ add typ x x
hasTree1 typ arr [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

#guard
let typ := BV8
let x := Element.arange typ 10
let y := Element.arrayScalar (7 : typ)
let arr := get! $ add typ x y
hasTree1 typ arr [7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

/-! WIP example NKI kernel
"""
NKI kernel to compute element-wise addition of two input tensors

This kernel assumes strict input/output tile-sizes, of up-to [128,512]

Args:
a_input: a first input tensor, of shape [128,512]
b_input: a second input tensor, of shape [128,512]
c_output: an output tensor, of shape [128,512]
"""
private def nki_tensor_add_kernel_ (program_id0 program_id1 : Nat) (a_input b_input c_input : NumpyRepr) : Err Unit := do
let tp := BV64

-- Calculate tile offsets based on current 'program'
let offset_i_x : tp := program_id0 * 128
let offset_i_y : tp := program_id1 * 512
-- Generate tensor indices to index tensors a and b
let rx0 := Element.arange tp 128
let rx <- rx0.reshape [128, 1]
let ox := Element.arrayScalar offset_i_x
let ix <- Ufunc.add tp ox rx
let ry0 := Element.arange tp 128
let ry <- ry0.reshape [1, 512]
let oy := Element.arrayScalar offset_i_y
let iy <- Ufunc.add tp oy ry
let a_tile <- sorry -- load from a_input
let b_tile <- sorry -- load from b_input
let c_tile <- Ufunc.add tp a_tile b_tile
let () <- sorry -- store to c_input
.ok ()
-/

end Ufunc
Loading