diff --git a/ADBench/plot_graphs.py b/ADBench/plot_graphs.py index 2fef158f3..7db6d7c17 100644 --- a/ADBench/plot_graphs.py +++ b/ADBench/plot_graphs.py @@ -76,7 +76,7 @@ def graph_data(build_type, objective, maybe_test_size, function_type): test_size = ", ".join([utils.cap_str(s) for s in maybe_test_size[0].split("_")]) if len(maybe_test_size) == 1 else None has_ts = test_size is not None - graph_name = (f"{objective.upper()}" + + graph_name = (f"{objective_display_name(objective)}" + (f" ({test_size})" if has_ts else "") + f" [{function_type.capitalize()}] - {build_type}") graph_save_location = os.path.join(build_type, function_type, f"{graph_name} Graph") @@ -85,6 +85,18 @@ def graph_data(build_type, objective, maybe_test_size, function_type): return (graph_name, graph_save_location) +# What we call LSTM is not quite a full LSTM. Rather it's an LSTM +# with diagonal weight matrices. We don't want to be misleading so we +# rename the graph. Eventually we will implement a full LSTM and we +# will remove this special case. See +# +# https://github.com/awf/ADBench/issues/143 +def objective_display_name(objective): + if objective.upper() == "LSTM": + return "D-LSTM" + else: + return objective.upper() + has_manual = lambda tool: tool.lower() in ["manual", "manual_eigen"] def tool_names(graph_files):