-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added new framework APPy, and its implementation for one benchmark go…
…_fast
- Loading branch information
1 parent
0bc108b
commit 3f11fd0
Showing
4 changed files
with
76 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"framework": { | ||
"simple_name": "appy", | ||
"full_name": "APPy", | ||
"prefix": "ap", | ||
"postfix": "appy", | ||
"class": "APPyFramework", | ||
"arch": "gpu" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# https://numba.readthedocs.io/en/stable/user/5minguide.html | ||
|
||
import torch | ||
import appy | ||
|
||
@appy.jit | ||
def go_fast(a): | ||
trace = torch.zeros(1, dtype=a.dtype) | ||
#pragma parallel for | ||
for i in range(a.shape[0]): | ||
#pragma atomic | ||
trace[0] += torch.tanh(a[i, i]) | ||
return a + trace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2021 ETH Zurich and the NPBench authors. All rights reserved. | ||
import pkg_resources | ||
|
||
from npbench.infrastructure import Benchmark, Framework | ||
from typing import Any, Callable, Dict | ||
|
||
|
||
class APPyFramework(Framework): | ||
""" A class for reading and processing framework information. """ | ||
|
||
def __init__(self, fname: str): | ||
""" Reads framework information. | ||
:param fname: The framework name. | ||
""" | ||
|
||
super().__init__(fname) | ||
|
||
def version(self) -> str: | ||
""" Return the framework version. """ | ||
return 0.1 | ||
|
||
# def copy_func(self) -> Callable: | ||
# """ Returns the copy-method that should be used | ||
# for copying the benchmark arguments. """ | ||
# import cupy | ||
# return cupy.asarray | ||
|
||
def copy_func(self) -> Callable: | ||
import torch | ||
torch.set_default_device('cuda') | ||
def inner(arr): | ||
copy = torch.from_numpy(arr).to('cuda') | ||
return copy | ||
return inner | ||
|
||
def imports(self) -> Dict[str, Any]: | ||
import torch | ||
import appy | ||
return {'torch': torch} | ||
|
||
def exec_str(self, bench: Benchmark, impl: Callable = None): | ||
""" Generates the execution-string that should be used to call | ||
the benchmark implementation. | ||
:param bench: A benchmark. | ||
:param impl: A benchmark implementation. | ||
""" | ||
|
||
arg_str = self.arg_str(bench, impl) | ||
# param_str = self.param_str(bench, impl) | ||
main_exec_str = "__npb_result = __npb_impl({a})".format(a=arg_str) | ||
sync_str = "torch.cuda.synchronize()" | ||
return main_exec_str + "; " + sync_str |