diff --git a/s3torchconnectorclient/Cargo.lock b/s3torchconnectorclient/Cargo.lock index 9b0d0d02..b1065585 100644 --- a/s3torchconnectorclient/Cargo.lock +++ b/s3torchconnectorclient/Cargo.lock @@ -424,6 +424,18 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1102,6 +1114,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.35" @@ -1204,6 +1222,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.16" @@ -1224,6 +1254,7 @@ dependencies = [ "nix", "pyo3", "pyo3-log", + "rusty-fork", "tracing", "tracing-subscriber", ] @@ -1338,6 +1369,18 @@ version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "thiserror" version = "1.0.57" @@ -1546,6 +1589,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/s3torchconnectorclient/Cargo.toml b/s3torchconnectorclient/Cargo.toml index 79e5f64f..e9d29c68 100644 --- a/s3torchconnectorclient/Cargo.toml +++ b/s3torchconnectorclient/Cargo.toml @@ -27,6 +27,7 @@ tracing-subscriber = "0.3.17" nix = { version = "0.27.1", features = ["process"] } env_logger = "0.11.2" chrono = "0.4.34" +rusty-fork = "0.3.0" [features] extension-module = ["pyo3/extension-module"] diff --git a/s3torchconnectorclient/python/tst/integration/test_logging.py b/s3torchconnectorclient/python/tst/integration/test_logging.py index d08072f0..4778b82b 100644 --- a/s3torchconnectorclient/python/tst/integration/test_logging.py +++ b/s3torchconnectorclient/python/tst/integration/test_logging.py @@ -1,5 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD +import sys import pytest PYTHON_TEST_CODE = """ @@ -7,7 +8,6 @@ import os import sys -os.environ["PYTHONUNBUFFERED"] = "1" os.environ["ENABLE_CRT_LOGS"] = "{0}" from s3torchconnector import S3MapDataset @@ -18,14 +18,13 @@ ) logging.getLogger().setLevel(logging.INFO) -if __name__ == "__main__": - if len(sys.argv) != 3: - exit("The script needs an S3 uri and a region") - s3_uri = sys.argv[1] - region = sys.argv[2] - map_dataset = S3MapDataset.from_prefix(s3_uri, region=region) - obj = map_dataset[0] - assert obj is not None +if len(sys.argv) != 3: + exit("The script needs an S3 uri and a region") +s3_uri = sys.argv[1] +region = sys.argv[2] +map_dataset = S3MapDataset.from_prefix(s3_uri, region=region) +obj = map_dataset[0] +assert obj is not None """ import subprocess @@ -37,11 +36,7 @@ ( "info", ["INFO s3torchconnector.s3map_dataset"], - [ - "DEBUG awscrt::AWSProfile", - "TRACE awscrt::AWSProfile", - "DEBUG awscrt::AuthCredentialsProvider", - ], + ["DEBUG", "TRACE"], ), ( "debug", @@ -50,7 +45,7 @@ "DEBUG awscrt::AWSProfile", "DEBUG awscrt::AuthCredentialsProvider", ], - ["TRACE awscrt::AWSProfile"], + ["TRACE"], ), ( "trace", @@ -60,6 +55,7 @@ "DEBUG awscrt::AuthCredentialsProvider", "TRACE awscrt::event-loop", ], + # Python log level is set to INFO in the test script ["TRACE s3torchconnector.s3map_dataset"], ), ], @@ -67,15 +63,15 @@ def test_logging_valid(log_level, should_contain, should_not_contain, image_directory): stdout, stderr = _start_subprocess(log_level, image_directory) assert stderr == "" - assert stdout is not None - assert all([s in stdout for s in should_contain]) - assert all([s not in stdout for s in should_not_contain]) + assert stdout is not "" + assert all(s in stdout for s in should_contain) + assert all(s not in stdout for s in should_not_contain) def test_logging_off(image_directory): stdout, stderr = _start_subprocess("off", image_directory) assert stderr == "" - assert stdout is not None + assert stdout is not "" assert "INFO s3torchconnector.s3map_dataset" in stdout assert "awscrt" not in stdout @@ -92,7 +88,7 @@ def test_logging_invalid(image_directory): def _start_subprocess(log_level, image_directory): process = subprocess.Popen( [ - "python", + sys.executable, "-c", PYTHON_TEST_CODE.format(log_level), image_directory.s3_uri, diff --git a/s3torchconnectorclient/rust/src/logger_setup.rs b/s3torchconnectorclient/rust/src/logger_setup.rs index 4cabaecb..5ebf8c74 100644 --- a/s3torchconnectorclient/rust/src/logger_setup.rs +++ b/s3torchconnectorclient/rust/src/logger_setup.rs @@ -22,16 +22,16 @@ pub fn setup_logging() -> PyResult<()> { match enable_crt_logs.as_str() { "OFF" => enable_default_logging(), - level_filter_str => enable_crt_logging(level_filter_str)? + level_filter_str => enable_crt_logging(level_filter_str) } - - Ok(()) } fn enable_crt_logging(level_filter_str: &str) -> PyResult<()> { let level_filter = LevelFilter::from_str(level_filter_str) .map_err(python_exception)?; - let _ = RustLogAdapter::try_init().map_err(python_exception); + + RustLogAdapter::try_init().map_err(python_exception)?; + let mut builder = Builder::new(); builder .format(|buf, record| { @@ -50,59 +50,86 @@ fn enable_crt_logging(level_filter_str: &str) -> PyResult<()> { ) .filter_level(level_filter); - let _ = builder.try_init().map_err(python_exception); - - Ok(()) + builder.try_init().map_err(python_exception) } -fn enable_default_logging() { +fn enable_default_logging() -> PyResult<()> { let logger = Logger::default() .filter_target( "mountpoint_s3_client::s3_crt_client::request".to_owned(), LevelFilter::Off, ) .filter(LevelFilter::Trace); - let _ = logger.install().map_err(python_exception); + + logger.install().map_err(python_exception)?; + + Ok(()) } #[cfg(test)] mod tests { + use rusty_fork::rusty_fork_test; use std::{env}; use pyo3::PyResult; use crate::logger_setup::{ENABLE_CRT_LOGS_ENV_VAR, setup_logging}; - #[test] - fn test_logging_setup() { - pyo3::prepare_freethreaded_python(); - // Enforce serial execution as we modify the same environment variable - check_environment_variable_unset(); - check_valid_values(); - check_invalid_values(); - } - - fn check_environment_variable_unset() { - env::remove_var(ENABLE_CRT_LOGS_ENV_VAR); - let result: PyResult<()> = setup_logging(); - assert!(result.is_ok()); - } - - fn check_valid_values() { - let valid_values = ["OFF", "ERROR", "WARN", "INFO", "DEBUG", "TRACE", "debug"]; - for value in valid_values.iter() { - env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value); + rusty_fork_test! { + #[test] + fn test_environment_variable_unset() { + pyo3::prepare_freethreaded_python(); + env::remove_var(ENABLE_CRT_LOGS_ENV_VAR); let result: PyResult<()> = setup_logging(); assert!(result.is_ok()); } - } - fn check_invalid_values() { - let invalid_values = ["invalid", "", "\n", "123", "xyz"]; - for value in invalid_values.iter() { - env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value); - let error_result: PyResult<()> = setup_logging(); - assert!(error_result.is_err()); - let pyerr = error_result.err().unwrap(); - assert_eq!(pyerr.to_string(), "S3Exception: attempted to convert a string that doesn't match an existing log level"); + #[test] + fn test_logging_off() { + check_valid_log_level("OFF"); } + + #[test] + fn test_logging_level_error() { + check_valid_log_level("ERROR"); + } + + #[test] + fn test_logging_level_warn() { + check_valid_log_level("WARN"); + } + + #[test] + fn test_logging_level_info() { + check_valid_log_level("INFO"); + } + + #[test] + fn test_logging_level_debug() { + check_valid_log_level("debug"); + } + + #[test] + fn test_logging_level_trace() { + check_valid_log_level("trace"); + } + + #[test] + fn test_invalid_values() { + pyo3::prepare_freethreaded_python(); + let invalid_values = ["invalid", "", "\n", "123", "xyz"]; + for value in invalid_values.iter() { + env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value); + let error_result: PyResult<()> = setup_logging(); + assert!(error_result.is_err()); + let pyerr = error_result.err().unwrap(); + assert_eq!(pyerr.to_string(), "S3Exception: attempted to convert a string that doesn't match an existing log level"); + } + } + } + + fn check_valid_log_level(log_level: &str) { + pyo3::prepare_freethreaded_python(); + env::set_var(ENABLE_CRT_LOGS_ENV_VAR, log_level); + let result: PyResult<()> = setup_logging(); + assert!(result.is_ok()); } }