Skip to content

Commit

Permalink
Update Python tests, use sys.executable and leverage
Browse files Browse the repository at this point in the history
rusty-fork crate for separate process testing, and
return error on logger initialization failure.
  • Loading branch information
dnnanuti committed Feb 27, 2024
1 parent 74a7abb commit 85075e2
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 57 deletions.
52 changes: 52 additions & 0 deletions s3torchconnectorclient/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions s3torchconnectorclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ tracing-subscriber = "0.3.17"
nix = { version = "0.27.1", features = ["process"] }
env_logger = "0.11.2"
chrono = "0.4.34"
rusty-fork = "0.3.0"

[features]
extension-module = ["pyo3/extension-module"]
Expand Down
36 changes: 16 additions & 20 deletions s3torchconnectorclient/python/tst/integration/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD
import sys
import pytest

PYTHON_TEST_CODE = """
import logging
import os
import sys
os.environ["PYTHONUNBUFFERED"] = "1"
os.environ["ENABLE_CRT_LOGS"] = "{0}"
from s3torchconnector import S3MapDataset
Expand All @@ -18,14 +18,13 @@
)
logging.getLogger().setLevel(logging.INFO)
if __name__ == "__main__":
if len(sys.argv) != 3:
exit("The script needs an S3 uri and a region")
s3_uri = sys.argv[1]
region = sys.argv[2]
map_dataset = S3MapDataset.from_prefix(s3_uri, region=region)
obj = map_dataset[0]
assert obj is not None
if len(sys.argv) != 3:
exit("The script needs an S3 uri and a region")
s3_uri = sys.argv[1]
region = sys.argv[2]
map_dataset = S3MapDataset.from_prefix(s3_uri, region=region)
obj = map_dataset[0]
assert obj is not None
"""

import subprocess
Expand All @@ -37,11 +36,7 @@
(
"info",
["INFO s3torchconnector.s3map_dataset"],
[
"DEBUG awscrt::AWSProfile",
"TRACE awscrt::AWSProfile",
"DEBUG awscrt::AuthCredentialsProvider",
],
["DEBUG", "TRACE"],
),
(
"debug",
Expand All @@ -50,7 +45,7 @@
"DEBUG awscrt::AWSProfile",
"DEBUG awscrt::AuthCredentialsProvider",
],
["TRACE awscrt::AWSProfile"],
["TRACE"],
),
(
"trace",
Expand All @@ -60,22 +55,23 @@
"DEBUG awscrt::AuthCredentialsProvider",
"TRACE awscrt::event-loop",
],
# Python log level is set to INFO in the test script
["TRACE s3torchconnector.s3map_dataset"],
),
],
)
def test_logging_valid(log_level, should_contain, should_not_contain, image_directory):
stdout, stderr = _start_subprocess(log_level, image_directory)
assert stderr == ""
assert stdout is not None
assert all([s in stdout for s in should_contain])
assert all([s not in stdout for s in should_not_contain])
assert stdout is not ""
assert all(s in stdout for s in should_contain)
assert all(s not in stdout for s in should_not_contain)


def test_logging_off(image_directory):
stdout, stderr = _start_subprocess("off", image_directory)
assert stderr == ""
assert stdout is not None
assert stdout is not ""
assert "INFO s3torchconnector.s3map_dataset" in stdout
assert "awscrt" not in stdout

Expand All @@ -92,7 +88,7 @@ def test_logging_invalid(image_directory):
def _start_subprocess(log_level, image_directory):
process = subprocess.Popen(
[
"python",
sys.executable,
"-c",
PYTHON_TEST_CODE.format(log_level),
image_directory.s3_uri,
Expand Down
101 changes: 64 additions & 37 deletions s3torchconnectorclient/rust/src/logger_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ pub fn setup_logging() -> PyResult<()> {

match enable_crt_logs.as_str() {
"OFF" => enable_default_logging(),
level_filter_str => enable_crt_logging(level_filter_str)?
level_filter_str => enable_crt_logging(level_filter_str)
}

Ok(())
}

fn enable_crt_logging(level_filter_str: &str) -> PyResult<()> {
let level_filter = LevelFilter::from_str(level_filter_str)
.map_err(python_exception)?;
let _ = RustLogAdapter::try_init().map_err(python_exception);

RustLogAdapter::try_init().map_err(python_exception)?;

let mut builder = Builder::new();
builder
.format(|buf, record| {
Expand All @@ -50,59 +50,86 @@ fn enable_crt_logging(level_filter_str: &str) -> PyResult<()> {
)
.filter_level(level_filter);

let _ = builder.try_init().map_err(python_exception);

Ok(())
builder.try_init().map_err(python_exception)
}

fn enable_default_logging() {
fn enable_default_logging() -> PyResult<()> {
let logger = Logger::default()
.filter_target(
"mountpoint_s3_client::s3_crt_client::request".to_owned(),
LevelFilter::Off,
)
.filter(LevelFilter::Trace);
let _ = logger.install().map_err(python_exception);

logger.install().map_err(python_exception)?;

Ok(())
}

#[cfg(test)]
mod tests {
use rusty_fork::rusty_fork_test;
use std::{env};
use pyo3::PyResult;
use crate::logger_setup::{ENABLE_CRT_LOGS_ENV_VAR, setup_logging};

#[test]
fn test_logging_setup() {
pyo3::prepare_freethreaded_python();
// Enforce serial execution as we modify the same environment variable
check_environment_variable_unset();
check_valid_values();
check_invalid_values();
}

fn check_environment_variable_unset() {
env::remove_var(ENABLE_CRT_LOGS_ENV_VAR);
let result: PyResult<()> = setup_logging();
assert!(result.is_ok());
}

fn check_valid_values() {
let valid_values = ["OFF", "ERROR", "WARN", "INFO", "DEBUG", "TRACE", "debug"];
for value in valid_values.iter() {
env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value);
rusty_fork_test! {
#[test]
fn test_environment_variable_unset() {
pyo3::prepare_freethreaded_python();
env::remove_var(ENABLE_CRT_LOGS_ENV_VAR);
let result: PyResult<()> = setup_logging();
assert!(result.is_ok());
}
}

fn check_invalid_values() {
let invalid_values = ["invalid", "", "\n", "123", "xyz"];
for value in invalid_values.iter() {
env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value);
let error_result: PyResult<()> = setup_logging();
assert!(error_result.is_err());
let pyerr = error_result.err().unwrap();
assert_eq!(pyerr.to_string(), "S3Exception: attempted to convert a string that doesn't match an existing log level");
#[test]
fn test_logging_off() {
check_valid_log_level("OFF");
}

#[test]
fn test_logging_level_error() {
check_valid_log_level("ERROR");
}

#[test]
fn test_logging_level_warn() {
check_valid_log_level("WARN");
}

#[test]
fn test_logging_level_info() {
check_valid_log_level("INFO");
}

#[test]
fn test_logging_level_debug() {
check_valid_log_level("debug");
}

#[test]
fn test_logging_level_trace() {
check_valid_log_level("trace");
}

#[test]
fn test_invalid_values() {
pyo3::prepare_freethreaded_python();
let invalid_values = ["invalid", "", "\n", "123", "xyz"];
for value in invalid_values.iter() {
env::set_var(ENABLE_CRT_LOGS_ENV_VAR, *value);
let error_result: PyResult<()> = setup_logging();
assert!(error_result.is_err());
let pyerr = error_result.err().unwrap();
assert_eq!(pyerr.to_string(), "S3Exception: attempted to convert a string that doesn't match an existing log level");
}
}
}

fn check_valid_log_level(log_level: &str) {
pyo3::prepare_freethreaded_python();
env::set_var(ENABLE_CRT_LOGS_ENV_VAR, log_level);
let result: PyResult<()> = setup_logging();
assert!(result.is_ok());
}
}

0 comments on commit 85075e2

Please sign in to comment.