Skip to content

Commit

Permalink
🔄 synced local 'tools/submission/power/power_checker.py' with remote …
Browse files Browse the repository at this point in the history
…'compliance/check.py'
  • Loading branch information
mlcommons-bot committed Dec 17, 2024
1 parent aebc018 commit ebdd220
Showing 1 changed file with 22 additions and 53 deletions.
75 changes: 22 additions & 53 deletions tools/submission/power/power_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ class CheckerWarning(Exception):
]
COMMON_ERROR_TESTING = ["USB."]
WARNING_NEEDS_TO_BE_ERROR_TESTING_RE = [
re.compile(
r"Uncertainty \d+.\d+%, which is above 1.00% limit for the last sample!")
re.compile(r"Uncertainty \d+.\d+%, which is above 1.00% limit for the last sample!")
]

TIME_DELTA_TOLERANCE = 800 # in milliseconds
Expand Down Expand Up @@ -127,10 +126,8 @@ def get_time_from_line(
) -> float:
log_time_str = re.search(data_regexp, line)
if log_time_str and log_time_str.group(0):
log_datetime = datetime.strptime(
log_time_str.group(0), "%m-%d-%Y %H:%M:%S.%f")
return log_datetime.replace(
tzinfo=timezone.utc).timestamp() + timezone_offset
log_datetime = datetime.strptime(log_time_str.group(0), "%m-%d-%Y %H:%M:%S.%f")
return log_datetime.replace(tzinfo=timezone.utc).timestamp() + timezone_offset
raise LineWithoutTimeStamp(f"{line.strip()!r} in {file}.")


Expand Down Expand Up @@ -159,10 +156,8 @@ def required_fields_check(self) -> None:
), f"Required fields {', '.join(absent_keys)!r} does not exist in {self.path!r}"


def compare_dicts_values(
d1: Dict[str, str], d2: Dict[str, str], comment: str) -> None:
files_with_diff_check_sum = {k: d1[k]
for k in d1 if k in d2 and d1[k] != d2[k]}
def compare_dicts_values(d1: Dict[str, str], d2: Dict[str, str], comment: str) -> None:
files_with_diff_check_sum = {k: d1[k] for k in d1 if k in d2 and d1[k] != d2[k]}
assert len(files_with_diff_check_sum) == 0, f"{comment}" + "".join(
[
f"Expected {d1[i]}, but got {d2[i]} for {i}\n"
Expand All @@ -171,8 +166,7 @@ def compare_dicts_values(
)


def compare_dicts(s1: Dict[str, str],
s2: Dict[str, str], comment: str) -> None:
def compare_dicts(s1: Dict[str, str], s2: Dict[str, str], comment: str) -> None:
assert (
not s1.keys() - s2.keys()
), f"{comment} Missing {', '.join(sorted(s1.keys() - s2.keys()))!r}"
Expand Down Expand Up @@ -230,8 +224,7 @@ def check_reply(cmd: str, reply: str) -> None:
for msg in msgs:
if msg["cmd"].startswith(cmd):
if msg["cmd"] == "Stop":
# In normal flow the third answer to stop command is
# `Error: no measurement to stop`
# In normal flow the third answer to stop command is `Error: no measurement to stop`
if stop_counter == 2:
reply = "Error: no measurement to stop"
stop_counter += 1
Expand All @@ -250,15 +243,13 @@ def check_reply(cmd: str, reply: str) -> None:
def get_initial_range(param_num: int, reply: str) -> str:
reply_list = reply.split(",")
try:
if reply_list[param_num] == "0" and float(
reply_list[param_num + 1]) > 0:
if reply_list[param_num] == "0" and float(reply_list[param_num + 1]) > 0:
return reply_list[param_num + 1]
except (ValueError, IndexError):
assert False, f"Can not get power meters initial values from {reply!r}"
return "Auto"

def get_command_by_value_and_number(
cmd: str, number: int) -> Optional[str]:
def get_command_by_value_and_number(cmd: str, number: int) -> Optional[str]:
command_counter = 0
for msg in msgs:
if msg["cmd"].startswith(cmd):
Expand All @@ -282,8 +273,7 @@ def get_command_by_value_and_number(
), f"Do not set Volts range as initial. Expected 'SR,V,{initial_volts}', got {initial_volts_command!r}."


def uuid_check(client_sd: SessionDescriptor,
server_sd: SessionDescriptor) -> None:
def uuid_check(client_sd: SessionDescriptor, server_sd: SessionDescriptor) -> None:
"""Compare UUIDs from client.json and server.json. They should be the same."""
uuid_c = client_sd.json_object["uuid"]
uuid_s = server_sd.json_object["uuid"]
Expand Down Expand Up @@ -372,8 +362,7 @@ def compare_duration(range_duration: float, test_duration: float) -> None:
def compare_time_boundaries(
begin: float, end: float, phases: List[Any], mode: str
) -> None:
# TODO: temporary workaround, remove when proper DST handling is
# implemented!
# TODO: temporary workaround, remove when proper DST handling is implemented!
assert (
phases[1][0] < begin < phases[2][0]
or phases[1][0] < begin - 3600 < phases[2][0]
Expand All @@ -391,16 +380,8 @@ def compare_time_boundaries(
os.path.join(path, "run_1"), client_sd
)

compare_time_boundaries(
system_begin_r,
system_end_r,
phases_ranging_c,
"ranging")
compare_time_boundaries(
system_begin_t,
system_end_t,
phases_testing_c,
"testing")
compare_time_boundaries(system_begin_r, system_end_r, phases_ranging_c, "ranging")
compare_time_boundaries(system_begin_t, system_end_t, phases_testing_c, "testing")

ranging_duration_d = system_end_r - system_begin_r
testing_duration_d = system_end_t - system_begin_t
Expand Down Expand Up @@ -483,8 +464,7 @@ def session_name_check(
), f"Session name is not equal. Client session name is {session_name_c!r}. Server session name is {session_name_s!r}"


def messages_check(client_sd: SessionDescriptor,
server_sd: SessionDescriptor) -> None:
def messages_check(client_sd: SessionDescriptor, server_sd: SessionDescriptor) -> None:
"""Compare client and server messages list length.
Compare messages values and replies from client.json and server.json.
Compare client and server version.
Expand All @@ -508,19 +488,14 @@ def messages_check(client_sd: SessionDescriptor,
)

# Check client and server version from server.json.
# Server.json contains all client.json messages and replies. Checked
# earlier.
# Server.json contains all client.json messages and replies. Checked earlier.
def get_version(regexp: str, line: str) -> str:
version_o = re.search(regexp, line)
assert version_o is not None, f"Server version is not defined in:'{line}'"
return version_o.group(1)

client_version = get_version(
r"mlcommons\/power client v(\d+)$",
ms[0]["cmd"])
server_version = get_version(
r"mlcommons\/power server v(\d+)$",
ms[0]["reply"])
client_version = get_version(r"mlcommons\/power client v(\d+)$", ms[0]["cmd"])
server_version = get_version(r"mlcommons\/power server v(\d+)$", ms[0]["reply"])

assert (
client_version == server_version
Expand Down Expand Up @@ -575,8 +550,7 @@ def remove_optional_path(res: Dict[str, str]) -> None:
f"{client_sd.path} and {server_sd.path} results checksum comparison",
)

# Check if the hashes of the files in results directory match the ones
# recorded in server.json/client.json.
# Check if the hashes of the files in results directory match the ones recorded in server.json/client.json.
result_c_s = {**results_c, **results_s}

compare_dicts(
Expand Down Expand Up @@ -642,8 +616,7 @@ def find_error_or_warning(reg_exp: str, line: str, error: bool) -> None:

# Treat uncommon errors in ranging phase as warnings
if all(
not problem_line.group(0).strip().startswith(
common_ranging_error)
not problem_line.group(0).strip().startswith(common_ranging_error)
for common_ranging_error in COMMON_ERROR_RANGING
):
raise CheckerWarning(
Expand Down Expand Up @@ -701,8 +674,7 @@ def get_msg_without_time(line: str) -> Optional[str]:
is_uncertainty_check_activated = False

for line in ptd_log_lines:
msg_o = re.search(
r"Uncertainty checking for Yokogawa\S+ is activated", line)
msg_o = re.search(r"Uncertainty checking for Yokogawa\S+ is activated", line)
if msg_o is not None:
try:
log_time = None
Expand Down Expand Up @@ -770,8 +742,7 @@ def debug_check(server_sd: SessionDescriptor) -> None:
), "Server was running in debug mode"


def check_with_logging(
check_name: str, check: Callable[[], None]) -> Tuple[bool, bool]:
def check_with_logging(check_name: str, check: Callable[[], None]) -> Tuple[bool, bool]:
try:
check()
except AssertionError as e:
Expand Down Expand Up @@ -840,9 +811,7 @@ def check(path: str) -> int:
parser = argparse.ArgumentParser(
description="Check PTD client-server session results"
)
parser.add_argument(
"session_directory",
help="directory with session results data")
parser.add_argument("session_directory", help="directory with session results data")

args = parser.parse_args()

Expand Down

0 comments on commit ebdd220

Please sign in to comment.