Add compile time instruction count metric (pytorch#133834)

PYTHONPATH=$(pwd) python benchmarks/update_hint_benchmark.py out

As of this diff, compile_time_instruction_count counts the number of instructions executed within
convert_frame.compile_inner:
```
update_hint_regression,compile_time_instruction_count,10522459165
```
Will add results from CI once populated.

Pull Request resolved: pytorch#133834
Approved by: https://github.com/aorenste
laithsakka authored and pytorchmergebot committed Aug 27, 2024
1 parent ef0f591 commit d6091c8
Showing 6 changed files with 122 additions and 20 deletions.
78 changes: 64 additions & 14 deletions benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py
@@ -4,6 +4,8 @@
 from fbscribelogger import make_scribe_logger

 import torch._C._instruction_counter as i_counter
+import torch._dynamo.config as config
+from torch._dynamo.utils import CompileTimeInstructionCounter


 scribe_log_torch_benchmark_compile_time = make_scribe_logger(
@@ -51,10 +53,19 @@


 class BenchmarkBase(ABC):
-    _instruction_count = False
+    # measure total number of instructions spent in _work.
+    _enable_instruction_count = False
+
+    # measure total number of instructions spent in convert_frame.compile_inner
+    # TODO: are there other parts we need to add?
+    _enable_compile_time_instruction_count = False

     def enable_instruction_count(self):
-        self._instruction_count = True
+        self._enable_instruction_count = True
         return self

+    def enable_compile_time_instruction_count(self):
+        self._enable_compile_time_instruction_count = True
+        return self
+
     def name(self):
@@ -64,29 +75,44 @@ def description(self):
         return ""

     @abstractmethod
-    def prepare(self):
+    def _prepare(self):
         pass

     @abstractmethod
-    def work(self):
+    def _work(self):
         pass

-    def prepare_once(self):  # noqa: B027
+    def _prepare_once(self):  # noqa: B027
         pass

-    def count_instructions(self):
+    def _count_instructions(self):
         print(f"collecting instruction count for {self.name()}")
-        self.prepare_once()
-
         results = []
         for i in range(10):
-            self.prepare()
+            self._prepare()
             id = i_counter.start()
-            self.work()
+            self._work()
             count = i_counter.end(id)
             print(f"instruction count for iteration {i} is {count}")
-            if i != 0:
-                results.append(count)
+            results.append(count)
         return min(results)

+    def _count_compile_time_instructions(self):
+        print(f"collecting compile time instruction count for {self.name()}")
+        config.record_compile_time_instruction_count = True
+
+        results = []
+        for i in range(10):
+            self._prepare()
+            # CompileTimeInstructionCounter.record is only called in convert_frame._compile_inner,
+            # hence this only counts instructions spent in compile_inner.
+            CompileTimeInstructionCounter.clear()
+            self._work()
+            count = CompileTimeInstructionCounter.value()
+            print(f"compile time instruction count for iteration {i} is {count}")
+            results.append(count)
+
+        config.record_compile_time_instruction_count = False
+        return min(results)

     def append_results(self, path):
@@ -102,12 +128,36 @@ def print(self):
         print(f"{entry[0]},{entry[1]},{entry[2]}")

     def collect_all(self):
+        self._prepare_once()
         self.results = []
-        if self._instruction_count:
-            r = self.count_instructions()
+        if (
+            self._enable_instruction_count
+            and self._enable_compile_time_instruction_count
+        ):
+            raise RuntimeError(
+                "not supported until we update the logger; both log to the same field now"
+            )
+
+        if self._enable_instruction_count:
+            r = self._count_instructions()
             self.results.append((self.name(), "instruction_count", r))
             scribe_log_torch_benchmark_compile_time(
                 name=self.name(),
                 instruction_count=r,
             )
+        if self._enable_compile_time_instruction_count:
+            r = self._count_compile_time_instructions()
+
+            self.results.append(
+                (
+                    self.name(),
+                    "compile_time_instruction_count",
+                    r,
+                )
+            )
+            # TODO: add a new field compile_time_instruction_count to the logger.
+            scribe_log_torch_benchmark_compile_time(
+                name=self.name(),
+                instruction_count=r,
+            )
         return self
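
For context, here is a minimal sketch of how a benchmark opts into the new metric through this API. `ToyBenchmark` and the function it compiles are hypothetical; `BenchmarkBase`, `enable_compile_time_instruction_count`, `collect_all`, and `append_results` are the APIs from the diff above. The flat `from benchmark_base import ...` assumes PYTHONPATH points at the pr_time_benchmarks directory, as in the invocation in the commit message, and counting requires a Linux build that exposes the instruction-counter bindings:

```
import sys

import torch
import torch._dynamo
from benchmark_base import BenchmarkBase


class ToyBenchmark(BenchmarkBase):
    def name(self):
        return "toy_benchmark"

    def description(self):
        return "hypothetical benchmark: compile a tiny pointwise function"

    def _prepare_once(self):
        self.x = torch.randn(8)

    def _prepare(self):
        torch._dynamo.reset()  # force a fresh compile on every measured iteration

    def _work(self):
        @torch.compile(fullgraph=True)
        def f(x):
            return x.sin() + x.cos()

        f(self.x)


if __name__ == "__main__":
    # appends "toy_benchmark,compile_time_instruction_count,<min over 10 iterations>"
    ToyBenchmark().enable_compile_time_instruction_count().collect_all().append_results(
        sys.argv[1]
    )
```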
@@ -15,17 +15,17 @@ def name(self):
     def description(self):
         return "information at https://github.com/pytorch/pytorch/pull/129893"

-    def prepare_once(self):
+    def _prepare_once(self):
         torch._dynamo.config.capture_scalar_outputs = True
         random.seed(42)
         self.splits = torch.randint(10, (self.N,))
         sz = self.splits.sum().item()
         self.input = torch.randn(sz)

-    def prepare(self):
+    def _prepare(self):
         torch._dynamo.reset()

-    def work(self):
+    def _work(self):
         @torch.compile(fullgraph=True)
         def f(a, b):
             xs = b.tolist()
@@ -34,12 +34,15 @@ def f(a, b):
                 torch._check(x <= self.N)
             return a.split(xs)

-        f(self.input, self.splits)
+        for i in range(1000):
+            f(self.input, self.splits)


 def main():
     result_path = sys.argv[1]
-    Benchmark().enable_instruction_count().collect_all().append_results(result_path)
+    Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
+        result_path
+    )


 if __name__ == "__main__":
4 changes: 4 additions & 0 deletions torch/_C/_instruction_counter.pyi
@@ -0,0 +1,4 @@
+# Defined in torch/csrc/instruction_counter/Module.cpp
+
+def start() -> int: ...
+def end(id: int) -> int: ...
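
This stub documents the raw counter API that the benchmark harness above builds on. A minimal sketch of direct use, mirroring `_count_instructions` (the summation is just a stand-in workload; requires a Linux build exposing these bindings):

```
import torch._C._instruction_counter as i_counter

handle = i_counter.start()                 # begin counting retired instructions
total = sum(i * i for i in range(10_000))  # stand-in workload to measure
count = i_counter.end(handle)              # stop counting; returns the instruction count
print(f"instructions: {count}")
```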
4 changes: 4 additions & 0 deletions torch/_dynamo/config.py
@@ -374,6 +374,10 @@ def _get_optimize_ddp_mode():
 # Inline inbuilt nn modules
 inline_inbuilt_nn_modules = not is_fbcode()

+# When set, total compile time instruction count is recorded using
+# torch._dynamo.utils.CompileTimeInstructionCounter.
+record_compile_time_instruction_count = False
+

 def default_debug_dir_root():
     # [@compile_ignored: debug]
4 changes: 3 additions & 1 deletion torch/_dynamo/convert_frame.py
@@ -27,6 +27,7 @@
 import torch._logging
 from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
+from torch._dynamo.utils import CompileTimeInstructionCounter
 from torch._guards import compile_context, CompileContext, CompileId, tracing
 from torch._logging import structured
 from torch._utils_internal import (
@@ -652,7 +653,8 @@ def compile_inner(
         transform: Callable[[List[Instruction], Dict[str, Any]], Any],
     ) -> Optional[GuardedCode]:
         with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
-            return _compile_inner(code, one_graph, hooks, transform)
+            with CompileTimeInstructionCounter.record():
+                return _compile_inner(code, one_graph, hooks, transform)

     @compile_time_strobelight_meta(phase_name="compile_inner")
     @maybe_cprofile
39 changes: 39 additions & 0 deletions torch/_dynamo/utils.py
@@ -64,6 +64,7 @@
 from torch import fx
 from torch._C import (
     _get_function_stack_at,
+    _instruction_counter,
     _len_torch_function_stack,
     _pop_torch_function_stack,
     _push_on_torch_function_stack,
@@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id):
 def store_user_object_weakref(obj):
     obj_id = id(obj)
     user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
+
+
+class CompileTimeInstructionCounter:
+    _counter: int = 0
+    _id: int = -1
+    _depth = 0
+
+    @classmethod
+    def start(cls) -> None:
+        cls._depth = cls._depth + 1
+        if cls._depth == 1:
+            cls._id = _instruction_counter.start()
+
+    @classmethod
+    def end(cls) -> None:
+        cls._depth = cls._depth - 1
+        if cls._depth == 0:
+            cls._counter += _instruction_counter.end(cls._id)
+            cls._id = -1
+
+    @classmethod
+    def clear(cls) -> None:
+        cls._counter = 0
+
+    @classmethod
+    def value(cls) -> int:
+        return cls._counter
+
+    @classmethod
+    @contextmanager
+    def record(cls):
+        try:
+            if config.record_compile_time_instruction_count:
+                cls.start()
+            yield
+        finally:
+            if config.record_compile_time_instruction_count:
+                cls.end()
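
Note the `_depth` bookkeeping: `record()` is reentrant, so nested uses start and stop the underlying counter only at the outermost level and instructions are not double-counted if compilation re-enters. A minimal sketch of that behavior, assuming the instruction-counter bindings are available (the arithmetic is a stand-in workload):

```
import torch._dynamo.config as config
from torch._dynamo.utils import CompileTimeInstructionCounter

config.record_compile_time_instruction_count = True
CompileTimeInstructionCounter.clear()

with CompileTimeInstructionCounter.record():      # depth 1: counter starts
    with CompileTimeInstructionCounter.record():  # depth 2: no second start
        total = sum(range(1_000))
# exiting the outermost block stops the counter and accumulates the count

print(CompileTimeInstructionCounter.value())
config.record_compile_time_instruction_count = False
```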
