Skip to content

Commit

Permalink
Gracefully exit the daemon thread that is periodically sending user signal
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 580229225
  • Loading branch information
SurbhiJainUSC authored and copybara-github committed Nov 7, 2023
1 parent 414de25 commit 12f648b
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ infrastructure based on configuration files. This repository will help the
customers to deploy various google cloud resources via script, without any
manual effort.

[cloud-tpu-diagnostics PyPI package]((https://pypi.org/project/cloud-tpu-diagnostics)) contains all the logic to monitor, debug and profile the jobs running on Cloud TPU.
[cloud-tpu-diagnostics PyPI package](https://pypi.org/project/cloud-tpu-diagnostics) contains all the logic to monitor, debug and profile the jobs running on Cloud TPU.

## Getting Started with Terraform

Expand Down
5 changes: 5 additions & 0 deletions pip_package/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
-->

## [0.1.4] - 2023-11-07
* Gracefully exiting daemon threads
* Fixed the URL for PyPI package in README

## [0.1.3] - 2023-11-01
* Fixing issue with using signals and threads together in a program

Expand All @@ -49,6 +53,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
* Initial release of cloud-tpu-diagnostics PyPI package
* FEATURE: Contains debug module to collect stack traces on faults

[0.1.4]: https://github.com/google/cloud-tpu-monitoring-debugging/compare/v0.1.3...v0.1.4
[0.1.3]: https://github.com/google/cloud-tpu-monitoring-debugging/compare/v0.1.2...v0.1.3
[0.1.2]: https://github.com/google/cloud-tpu-monitoring-debugging/compare/v0.1.1...v0.1.2
[0.1.1]: https://github.com/google/cloud-tpu-monitoring-debugging/compare/v0.1.0...v0.1.1
Expand Down
17 changes: 14 additions & 3 deletions pip_package/cloud_tpu_diagnostics/src/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@
from cloud_tpu_diagnostics.src.stack_trace import disable_stack_trace_dumping
from cloud_tpu_diagnostics.src.stack_trace import enable_stack_trace_dumping

# Handle of the daemon thread that periodically signals the main thread;
# stays None until start_debugging creates one.
_daemon_thread = None
# Event used to ask the daemon thread to exit gracefully. A newly created
# Event is already unset, so no explicit clear() is required here.
_exit_flag = threading.Event()

def start_debugging(debug_config):
  """Start debugging for the current process.

  This is a plain function, not a context manager: pair every call with a
  later call to `stop_debugging` using the same config.

  When `debug_config.stack_trace_config` is set and its `collect_stack_trace`
  is truthy, stack trace dumping is enabled and a daemon thread running
  `send_user_signal` is started. Otherwise only the exit flag is reset.

  Args:
    debug_config: Configuration object. Its `stack_trace_config` attribute
      may be None; when present it supplies `collect_stack_trace` and
      `stack_trace_interval_seconds`.
  """
  global _daemon_thread
  # Reset the flag so that a thread started after an earlier stop is not told
  # to exit straight away.
  _exit_flag.clear()
  stack_trace_config = debug_config.stack_trace_config
  if stack_trace_config is None or not stack_trace_config.collect_stack_trace:
    return
  # Enable dumping before the signalling thread exists. Presumably this is
  # what installs the SIGUSR1 handler (TODO confirm in stack_trace.py); the
  # default action for an unhandled SIGUSR1 is to terminate the process, so
  # the first signal must not be able to arrive ahead of it.
  enable_stack_trace_dumping(stack_trace_config)
  if _daemon_thread is not None and _daemon_thread.is_alive():
    # A signalling thread from an earlier call is still running. Starting a
    # second one would orphan the first, because stop_debugging only joins
    # the thread referenced by _daemon_thread.
    return
  _daemon_thread = threading.Thread(
      target=send_user_signal,
      daemon=True,
      args=(stack_trace_config.stack_trace_interval_seconds,),
  )
  _daemon_thread.start()  # start a daemon thread


Expand All @@ -41,11 +47,16 @@ def stop_debugging(debug_config):
debug_config.stack_trace_config is not None
and debug_config.stack_trace_config.collect_stack_trace
):
_exit_flag.set()
# wait for daemon thread to complete
if _daemon_thread is not None:
_daemon_thread.join()
disable_stack_trace_dumping(debug_config.stack_trace_config)
_exit_flag.clear()


def send_user_signal(stack_trace_interval_seconds):
  """Send SIGUSR1 to the main thread every stack_trace_interval_seconds seconds.

  Runs until the module-level `_exit_flag` event is set. The wait between
  signals is done on the event itself, so a stop request wakes the thread
  immediately instead of after the rest of the interval, and no further
  signal is sent once the flag is set.

  Args:
    stack_trace_interval_seconds: Number of seconds to wait between two
      consecutive SIGUSR1 signals.
  """
  # Event.wait() returns False when the timeout elapses (time to signal) and
  # True as soon as the flag is set (time to exit).
  while not _exit_flag.wait(stack_trace_interval_seconds):
    signal.pthread_kill(threading.main_thread().ident, signal.SIGUSR1)
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
stack_trace_config=stack_trace_configuration.StackTraceConfig(
collect_stack_trace=args.collect_stack_trace,
stack_trace_to_cloud=args.log_to_cloud,
stack_trace_interval_seconds=1,
),
)
diagnostic_config = diagnostic_configuration.DiagnosticConfig(
Expand Down
24 changes: 22 additions & 2 deletions pip_package/cloud_tpu_diagnostics/tests/debug_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def testDaemonThreadRunningWhenCollectStackTraceTrue(self):
debug_config = debug_configuration.DebugConfig(
stack_trace_config=stack_trace_configuration.StackTraceConfig(
collect_stack_trace=True,
stack_trace_to_cloud=False,
stack_trace_to_cloud=True,
stack_trace_interval_seconds=1,
),
)
start_debugging(debug_config)
Expand All @@ -38,12 +39,19 @@ def testDaemonThreadRunningWhenCollectStackTraceTrue(self):
filter(lambda thread: thread.daemon is True, threading.enumerate())
)
self.assertLen(daemon_thread_list, 1)
stop_debugging(debug_config)
self.assertEqual(threading.active_count(), 1)
daemon_thread_list = list(
filter(lambda thread: thread.daemon is True, threading.enumerate())
)
self.assertLen(daemon_thread_list, 0)

def testDaemonThreadNotRunningWhenCollectStackTraceFalse(self):
debug_config = debug_configuration.DebugConfig(
stack_trace_config=stack_trace_configuration.StackTraceConfig(
collect_stack_trace=False,
stack_trace_to_cloud=False,
stack_trace_to_cloud=True,
stack_trace_interval_seconds=1,
),
)
start_debugging(debug_config)
Expand All @@ -52,6 +60,12 @@ def testDaemonThreadNotRunningWhenCollectStackTraceFalse(self):
filter(lambda thread: thread.daemon is True, threading.enumerate())
)
self.assertLen(daemon_thread_list, 0)
stop_debugging(debug_config)
self.assertEqual(threading.active_count(), 1)
daemon_thread_list = list(
filter(lambda thread: thread.daemon is True, threading.enumerate())
)
self.assertLen(daemon_thread_list, 0)

@mock.patch(
'google3.third_party.cloud_tpu_monitoring_debugging.pip_package.cloud_tpu_diagnostics.src.debug.disable_stack_trace_dumping'
Expand All @@ -63,10 +77,16 @@ def testStopDebuggingDisableStackTraceDumpingCalled(
stack_trace_config=stack_trace_configuration.StackTraceConfig(
collect_stack_trace=True,
stack_trace_to_cloud=True,
stack_trace_interval_seconds=1,
),
)
stop_debugging(debug_config)
disable_stack_trace_dumping_mock.assert_called_once()
self.assertEqual(threading.active_count(), 1)
daemon_thread_list = list(
filter(lambda thread: thread.daemon is True, threading.enumerate())
)
self.assertLen(daemon_thread_list, 0)

def testSendUserSignalSIGUSR1SignalReceived(self):
signal.signal(signal.SIGUSR1, user_signal_handler)
Expand Down
27 changes: 8 additions & 19 deletions pip_package/cloud_tpu_diagnostics/tests/stack_trace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,31 @@ def tearDown(self):
@unittest.skipIf(not hasattr(signal, 'SIGSEGV'), 'Missing signal.SIGSEGV')
def testSigsegvCollectStackTraceTrueTraceCollectedOnCloud(self):
error = 'Fatal Python error: Segmentation fault'
self.check_fatal_error(51, error, 'SIGSEGV', True)
self.check_fatal_error(52, error, 'SIGSEGV', True)

@unittest.skipIf(not hasattr(signal, 'SIGABRT'), 'Missing signal.SIGABRT')
def testSigabrtCollectStackTraceTrueTraceCollectedOnCloud(self):
error = 'Fatal Python error: Aborted'
self.check_fatal_error(54, error, 'SIGABRT', True)
self.check_fatal_error(55, error, 'SIGABRT', True)

@unittest.skipIf(not hasattr(signal, 'SIGFPE'), 'Missing signal.SIGFPE')
def testSigfpeCollectStackTraceTrueTraceCollectedOnCloud(self):
error = 'Fatal Python error: Floating point exception'
self.check_fatal_error(57, error, 'SIGFPE', True)
self.check_fatal_error(58, error, 'SIGFPE', True)

@unittest.skipIf(not hasattr(signal, 'SIGILL'), 'Missing signal.SIGILL')
def testSigillCollectStackTraceTrueTraceCollectedOnCloud(self):
error = 'Fatal Python error: Illegal instruction'
self.check_fatal_error(60, error, 'SIGILL', True)
self.check_fatal_error(61, error, 'SIGILL', True)

@unittest.skipIf(not hasattr(signal, 'SIGBUS'), 'Missing signal.SIGBUS')
def testSigbusCollectStackTraceTrueTraceCollectedOnCloud(self):
error = 'Fatal Python error: Bus error'
self.check_fatal_error(63, error, 'SIGBUS', True)
self.check_fatal_error(64, error, 'SIGBUS', True)

@unittest.skipIf(not hasattr(signal, 'SIGUSR1'), 'Missing signal.SIGUSR1')
def testSigusrCollectStackTraceTrueTraceCollectedOnCloud(self):
self.check_fatal_error(66, '', 'SIGUSR1', True)

def testNoFaultCollectStackTraceTrueNoTraceCollectedOnCloud(self):
output, stderr = self.get_output('', True, True)
self.assertEqual(output, '')
self.assertEqual(stderr, '')
self.check_fatal_error(67, '', 'SIGUSR1', True)

def testCollectStackTraceFalseNoTraceDirCreated(self):
process = self.run_python_code('', False, True)
Expand All @@ -88,18 +83,12 @@ def testCollectStackTraceFalseNoTraceDirCreated(self):

@unittest.skipIf(not hasattr(signal, 'SIGUSR1'), 'Missing signal.SIGUSR1')
def testCollectStackTraceToConsole(self):
self.check_fatal_error(66, '', 'SIGUSR1', False)

def testNoFaultCollectStackTraceTrueNoTraceCollectedOnConsole(self):
output, stderr = self.get_output('', True, False)
self.assertEqual(output, '')
self.assertEqual(stderr, '')
self.check_fatal_error(67, '', 'SIGUSR1', False)

def testCollectStackTraceFalseNoTraceCollectedOnConsole(self):
process = self.run_python_code('', False, False)
_, stderr = process.communicate()
self.assertEmpty(stderr)
self.assertFalse(os.path.exists(default.STACK_TRACE_DIR_DEFAULT))

def testEnableStackTraceDumpingFaulthandlerEnabled(self):
stack_trace_config = stack_trace_configuration.StackTraceConfig(
Expand Down Expand Up @@ -157,7 +146,7 @@ def check_fatal_error(self, line_number, error, signal_name, log_to_cloud):
else:
header = (
r'INFO: Not a crash. cloud\-tpu\-diagnostics emits a stack trace'
r' snapshot every 600 seconds.\n'
r' snapshot every 1 seconds.\n'
r'Stack \(most recent call first\)'
)
regex = """
Expand Down
2 changes: 1 addition & 1 deletion pip_package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

[project]
name = "cloud-tpu-diagnostics"
version = "0.1.3"
version = "0.1.4"
authors = [
{ name="Cloud TPU Team", email="[email protected]" },
]
Expand Down

0 comments on commit 12f648b

Please sign in to comment.