Skip to content

Commit

Permalink
Fix: Modify the time_eval func in utils.py for killing timeout sub-pr…
Browse files Browse the repository at this point in the history
…ocess (#109)
  • Loading branch information
lzhan94swu authored Apr 15, 2024
1 parent 0480c93 commit 8bfdf34
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
5 changes: 1 addition & 4 deletions hcga/feature_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,13 @@ def evaluate_feature( # pylint: disable=too-many-branches

try:
try:
feature = timeout_eval(
feature_function, (function_args,), timeout=self.timeout, pool=self.pool
)
feature = timeout_eval(feature_function, (function_args,), timeout=self.timeout)
except NetworkXNotImplemented:
if self.graph_type == "directed":
feature = timeout_eval(
feature_function,
(to_undirected(function_args),),
timeout=self.timeout,
pool=self.pool,
)
else:
return None
Expand Down
38 changes: 32 additions & 6 deletions hcga/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,41 @@ class NestedPool(multiprocessing.pool.Pool): # pylint: disable=abstract-method
Process = NoDaemonProcess


def timeout_eval(func, args, timeout=None, pool=None):
"""Evaluate a function and kill it is it takes longer than timeout.
def timeout_eval(func, args, timeout=None):
"""Evaluate a function within a given timeout period.
If timeout is Nonei or == 0, a simple evaluation will take place.
Args:
func: The function to call.
args: Arguments to pass to the function.
timeout: The timeout period in seconds.
Returns:
The function's result, or None if a timeout or an error occurs.
"""
if timeout is None or timeout == 0:
return func(*args)

return pool.apply_async(func, args).get(timeout=timeout)
try:
return func(*args)
except Exception: # pylint: disable=broad-exception-caught
return None

def target(queue, args):
try:
result = func(*args)
queue.put(result)
except Exception: # pylint: disable=broad-exception-caught
queue.put(None)

queue = multiprocessing.Queue()
process = multiprocessing.Process(target=target, args=(queue, args))
process.start()
process.join(timeout)

if process.is_alive():
process.terminate()
process.join()
return None

return queue.get_nowait()


def get_trivial_graph(n_node_features=0):
Expand Down

0 comments on commit 8bfdf34

Please sign in to comment.