Skip to content

Commit

Permalink
Add tvm based scheduling example
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Gerum committed Sep 5, 2024
1 parent 7d50498 commit c4e4836
Show file tree
Hide file tree
Showing 6 changed files with 1,205 additions and 7,424 deletions.
4 changes: 2 additions & 2 deletions hannah/nas/performance_prediction/nn_meter/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def __init__(self, hardware_name, predictor_version: Optional[float] = None):
hardware_name, predictor_version
)

def predict(self, model):
def predict(self, model, input=None):
self.tmp_dir = TemporaryDirectory()
tmp_dir = Path(self.tmp_dir.name)

logging.info("transfering model to onnx")
dummy_input = model.example_input_array
dummy_input = model.example_input_array if input is None else input
torch.onnx.export(model, dummy_input, tmp_dir / "model.onnx", verbose=False)
logging.info("Creating onnxrt-model")
onnx_model = onnx.load(tmp_dir / "model.onnx")
Expand Down
2 changes: 0 additions & 2 deletions hannah/nas/performance_prediction/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ def predict(self, model, input=None):

result, std_dev = self.predictor.predict(dgl_graph)

print(result, std_dev)

metrics = {"val_error": result.item()}

logger.info("Predicted performance metrics")
Expand Down
1 change: 1 addition & 0 deletions hannah/nas/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def sample_candidates(
candidates = []
skip_ct = 0
while len(candidates) < num_total:
msglogger.info(f"Sampling candidate {len(candidates)}")
parameters = self.sample(constrain)
model = self.build_model(parameters)
estimated_metrics, satisfied_bounds = self.estimate_metrics(
Expand Down
Loading

0 comments on commit c4e4836

Please sign in to comment.