From 8bfdf34d53c6de9ad0191ad2f07e1f4ca871e723 Mon Sep 17 00:00:00 2001 From: lzhan94swu <92148689+lzhan94swu@users.noreply.github.com> Date: Mon, 15 Apr 2024 18:21:12 +0800 Subject: [PATCH] Fix: Modify the time_eval func in utils.py for killing timeout sub-process (#109) --- hcga/feature_class.py | 5 +---- hcga/utils.py | 38 ++++++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/hcga/feature_class.py b/hcga/feature_class.py index 9bae8df6..2674a6a2 100644 --- a/hcga/feature_class.py +++ b/hcga/feature_class.py @@ -235,16 +235,13 @@ def evaluate_feature( # pylint: disable=too-many-branches try: try: - feature = timeout_eval( - feature_function, (function_args,), timeout=self.timeout, pool=self.pool - ) + feature = timeout_eval(feature_function, (function_args,), timeout=self.timeout) except NetworkXNotImplemented: if self.graph_type == "directed": feature = timeout_eval( feature_function, (to_undirected(function_args),), timeout=self.timeout, - pool=self.pool, ) else: return None diff --git a/hcga/utils.py b/hcga/utils.py index a83e80f2..3bd19e43 100644 --- a/hcga/utils.py +++ b/hcga/utils.py @@ -36,15 +36,41 @@ class NestedPool(multiprocessing.pool.Pool): # pylint: disable=abstract-method Process = NoDaemonProcess -def timeout_eval(func, args, timeout=None, pool=None): - """Evaluate a function and kill it is it takes longer than timeout. +def timeout_eval(func, args, timeout=None): + """Evaluate a function within a given timeout period. - If timeout is Nonei or == 0, a simple evaluation will take place. + Args: + func: The function to call. + args: Arguments to pass to the function. + timeout: The timeout period in seconds. + + Returns: + The function's result, or None if a timeout or an error occurs. """ if timeout is None or timeout == 0: - return func(*args) - - return pool.apply_async(func, args).get(timeout=timeout) + try: + return func(*args) + except Exception: # pylint: disable=broad-exception-caught + return None + + def target(queue, args): + try: + result = func(*args) + queue.put(result) + except Exception: # pylint: disable=broad-exception-caught + queue.put(None) + + queue = multiprocessing.Queue() + process = multiprocessing.Process(target=target, args=(queue, args)) + process.start() + process.join(timeout) + + if process.is_alive(): + process.terminate() + process.join() + return None + + return queue.get_nowait() def get_trivial_graph(n_node_features=0):