diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py index e9da5b3a45c519..225f51c5b314d7 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmark_base.py @@ -4,6 +4,8 @@ from fbscribelogger import make_scribe_logger import torch._C._instruction_counter as i_counter +import torch._dynamo.config as config +from torch._dynamo.utils import CompileTimeInstructionCounter scribe_log_torch_benchmark_compile_time = make_scribe_logger( @@ -51,10 +53,19 @@ class BenchmarkBase(ABC): - _instruction_count = False + # measure total number of instruction spent in _work. + _enable_instruction_count = False + + # measure total number of instruction spent in convert_frame.compile_inner + # TODO is there other parts we need to add ? + _enable_compile_time_instruction_count = False def enable_instruction_count(self): - self._instruction_count = True + self._enable_instruction_count = True + return self + + def enable_compile_time_instruction_count(self): + self._enable_compile_time_instruction_count = True return self def name(self): @@ -64,29 +75,44 @@ def description(self): return "" @abstractmethod - def prepare(self): + def _prepare(self): pass @abstractmethod - def work(self): + def _work(self): pass - def prepare_once(self): # noqa: B027 + def _prepare_once(self): # noqa: B027 pass - def count_instructions(self): + def _count_instructions(self): print(f"collecting instruction count for {self.name()}") - self.prepare_once() - results = [] for i in range(10): - self.prepare() + self._prepare() id = i_counter.start() - self.work() + self._work() count = i_counter.end(id) print(f"instruction count for iteration {i} is {count}") - if i != 0: - results.append(count) + results.append(count) + return min(results) + + def _count_compile_time_instructions(self): + print(f"collecting compile time instruction count for {self.name()}") + config.record_compile_time_instruction_count = True + + results = [] + for i in range(10): + self._prepare() + # CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner + # hence this will only count instruction count spent in compile_inner. + CompileTimeInstructionCounter.clear() + self._work() + count = CompileTimeInstructionCounter.value() + print(f"compile time instruction count for iteration {i} is {count}") + results.append(count) + + config.record_compile_time_instruction_count = False return min(results) def append_results(self, path): @@ -102,12 +128,36 @@ def print(self): print(f"{entry[0]},{entry[1]},{entry[2]}") def collect_all(self): + self._prepare_once() self.results = [] - if self._instruction_count: - r = self.count_instructions() + if ( + self._enable_instruction_count + and self._enable_compile_time_instruction_count + ): + raise RuntimeError( + "not supported until we update the logger, both logs to the same field now" + ) + + if self._enable_instruction_count: + r = self._count_instructions() self.results.append((self.name(), "instruction_count", r)) scribe_log_torch_benchmark_compile_time( name=self.name(), instruction_count=r, ) + if self._enable_compile_time_instruction_count: + r = self._count_compile_time_instructions() + + self.results.append( + ( + self.name(), + "compile_time_instruction_count", + r, + ) + ) + # TODO add a new field compile_time_instruction_count to the logger. + scribe_log_torch_benchmark_compile_time( + name=self.name(), + instruction_count=r, + ) return self diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py index 92a83c609a1f2d..694bcea5985771 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/update_hint_benchmark.py @@ -15,17 +15,17 @@ def name(self): def description(self): return "information at https://github.com/pytorch/pytorch/pull/129893" - def prepare_once(self): + def _prepare_once(self): torch._dynamo.config.capture_scalar_outputs = True random.seed(42) self.splits = torch.randint(10, (self.N,)) sz = self.splits.sum().item() self.input = torch.randn(sz) - def prepare(self): + def _prepare(self): torch._dynamo.reset() - def work(self): + def _work(self): @torch.compile(fullgraph=True) def f(a, b): xs = b.tolist() @@ -34,12 +34,15 @@ def f(a, b): torch._check(x <= self.N) return a.split(xs) - f(self.input, self.splits) + for i in range(1000): + f(self.input, self.splits) def main(): result_path = sys.argv[1] - Benchmark().enable_instruction_count().collect_all().append_results(result_path) + Benchmark().enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) if __name__ == "__main__": diff --git a/torch/_C/_instruction_counter.pyi b/torch/_C/_instruction_counter.pyi new file mode 100644 index 00000000000000..4e3c27567eb228 --- /dev/null +++ b/torch/_C/_instruction_counter.pyi @@ -0,0 +1,4 @@ +# Defined in torch/csrc/instruction_counter/Module.cpp + +def start() -> int: ... +def end(id: int) -> int: ... diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 88081a87290070..8885ca33cc0765 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -374,6 +374,10 @@ def _get_optimize_ddp_mode(): # Inline inbuilt nn modules inline_inbuilt_nn_modules = not is_fbcode() +# When set, total compile time instruction count is recorded using +# torch._dynamo.utilsCompileTimeInstructionCounter. +record_compile_time_instruction_count = False + def default_debug_dir_root(): # [@compile_ignored: debug] diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 2f0da2be866db1..c1fadedd614405 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -27,6 +27,7 @@ import torch._logging from torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg +from torch._dynamo.utils import CompileTimeInstructionCounter from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( @@ -652,7 +653,8 @@ def compile_inner( transform: Callable[[List[Instruction], Dict[str, Any]], Any], ) -> Optional[GuardedCode]: with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"): - return _compile_inner(code, one_graph, hooks, transform) + with CompileTimeInstructionCounter.record(): + return _compile_inner(code, one_graph, hooks, transform) @compile_time_strobelight_meta(phase_name="compile_inner") @maybe_cprofile diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index c8b2ac7a953af5..ae78f22374665a 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -64,6 +64,7 @@ from torch import fx from torch._C import ( _get_function_stack_at, + _instruction_counter, _len_torch_function_stack, _pop_torch_function_stack, _push_on_torch_function_stack, @@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id): def store_user_object_weakref(obj): obj_id = id(obj) user_obj_id_to_weakref[obj_id] = weakref.ref(obj) + + +class CompileTimeInstructionCounter: + _counter: int = 0 + _id: int = -1 + _depth = 0 + + @classmethod + def start(cls) -> None: + cls._depth = cls._depth + 1 + if cls._depth == 1: + cls._id = _instruction_counter.start() + + @classmethod + def end(cls) -> None: + cls._depth = cls._depth - 1 + if cls._depth == 0: + cls._counter += _instruction_counter.end(cls._id) + cls._id = -1 + + @classmethod + def clear(cls) -> None: + cls._counter = 0 + + @classmethod + def value(cls) -> int: + return cls._counter + + @classmethod + @contextmanager + def record(cls): + try: + if config.record_compile_time_instruction_count: + cls.start() + yield + finally: + if config.record_compile_time_instruction_count: + cls.end()