diff --git a/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py b/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py index ffaac119..3a12e868 100644 --- a/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py +++ b/dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py @@ -42,6 +42,7 @@ def config( num_slices: int = 1, model_configs: Dict = {}, maxtext_branch: str = "", + xla_flags: str = "", ): job_gcp_config = gcp_config.GCPConfig( project_name=project_name, @@ -86,7 +87,8 @@ def config( "cat MaxText/metadata.json", ### Benchmark # Configure flags - "export XLA_FLAGS='--xla_disable_hlo_passes=rematerialization'", + "export XLA_FLAGS='--xla_disable_hlo_passes=rematerialization'" + + xla_flags, f"""python MaxText/inference_microbenchmark_sweep.py \ MaxText/configs/base.yml \ model_name={model_configs['model_name']} \ diff --git a/dags/inference/maxtext_inference_microbenchmark.py b/dags/inference/maxtext_inference_microbenchmark.py index 7cb67cdf..4976e4d2 100644 --- a/dags/inference/maxtext_inference_microbenchmark.py +++ b/dags/inference/maxtext_inference_microbenchmark.py @@ -19,9 +19,10 @@ import itertools import numpy from airflow import models -from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion +from dags.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK from dags.inference.configs import maxtext_inference_microbenchmark_gce_config from dags.multipod.configs.common import SetupMode +from sklearn.model_selection import ParameterGrid USER_PREFIX = "" MAXTEXT_BRANCH = "" @@ -100,6 +101,61 @@ def get_concatenated_list_of_params(sweep_vm_count=1): ) +def generate_possible_values(min_val="", max_val="", type="", interval="0.1"): + possible_vals = ["None"] + if type == "boolean": + possible_vals = ["True", "False"] + if type == "int": + for i in range(min_val, max_val): + possible_vals.append(i) + if type == "flaot": + for i in range(min_val, max_val, interval): + possible_vals.append(i) + return possible_vals + + +boolean_flags = { + "xla_tpu_enable_async_collective_fusion", + "xla_tpu_enable_async_collective_fusion_fuse_all_gather", + "xla_tpu_overlap_compute_collective_tc", + "xla_tpu_rwb_fusion", +} + +int_flags = { + "xla_tpu_rematerialization_min_size_in_bytes": [], + "xla_tpu_relayout_group_size_threshold_for_reduce_scatter": [], + "xla_jf_spmd_threshold_for_windowed_einsum_mib": [], + "xla_tpu_nd_short_transfer_max_chunks": [], +} +float_flags = {} + + +def create_flags_possible_vals_dict(): + flags_possible_vals = {} + list_flags_combinations = [] + for flag in boolean_flags.items(): + flags_possible_vals[flag] = generate_possible_values(type="boolean") + for flag, val_range in int_flags.items(): + flags_possible_vals[flag] = generate_possible_values( + type="int", min_val=val_range[0], max_val=val_range[1] + ) + for flag, val_range in float_flags.items(): + flags_possible_vals[flag] = generate_possible_values( + type="float", min_val=val_range[0], max_val=val_range[1], interval=0.1 + ) + + flags_grid = ParameterGrid(flags_possible_vals) + for dict_ in flags_grid: + dict_pruned = {} + str_flags = "" + for k, v in dict_.items(): + if v != "None": + dict_pruned[k] = v + str_flags += f" --{k}={v} " + list_flags_combinations.append(str_flags) + return list_flags_combinations + + def generate_model_configs( test_name_prefix, model_config_name, @@ -109,6 +165,7 @@ def generate_model_configs( vm_number, tpu_version, tpu_cores, + xla_flags, ): model_configs = {} model_configs["model_config_name"] = model_config_name @@ -184,6 +241,12 @@ def generate_model_configs( network = V5_NETWORKS subnetwork = V5E_SUBNETWORKS runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value + if tpu_version == TpuVersion.TRILLIUM: + project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value + zone = Zone.EUROPE_WEST4_A.value + network = V6E_GCE_NETWORK + subnetwork = V6E_GCE_SUBNETWORK + runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value maxtext_kv_cache_layout_optimization = ( maxtext_inference_microbenchmark_gce_config.config( @@ -200,6 +263,7 @@ def generate_model_configs( is_tpu_reserved=True, model_configs=model_configs, maxtext_branch=model_configs["maxtext_branch"], + xla_flags=xla_flags, ) ) @@ -241,7 +305,7 @@ def generate_model_configs( if not MAXTEXT_BRANCH else f"-b {MAXTEXT_BRANCH}", "sleep_time": 60, - "tpu_version_cores": [(TpuVersion.V5E, 8)], + "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_7B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", @@ -300,14 +364,16 @@ def generate_model_configs( for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]: for compute_axis_order in sweep_model_configs["compute_axis_order"]: for ici_parallelism in sweep_model_configs["ici_parallelisms"]: - for vm_number in range(sweep_vm_count): - maxtext_kv_cache_layout_optimization = generate_model_configs( - test_name_prefix=test_name_prefix, - model_config_name=model_config_name, - sweep_model_configs=sweep_model_configs, - compute_axis_order=compute_axis_order, - ici_parallelism=ici_parallelism, - vm_number=vm_number, - tpu_version=tpu_version, - tpu_cores=tpu_cores, - ) + for flag_combination in sweep_model_configs["xla_flags"]: + for vm_number in range(sweep_vm_count): + maxtext_kv_cache_layout_optimization = generate_model_configs( + test_name_prefix=test_name_prefix, + model_config_name=model_config_name, + sweep_model_configs=sweep_model_configs, + compute_axis_order=compute_axis_order, + ici_parallelism=ici_parallelism, + vm_number=vm_number, + tpu_version=tpu_version, + tpu_cores=tpu_cores, + xla_flags=flag_combination, + ) diff --git a/dags/vm_resource.py b/dags/vm_resource.py index 6b03e39d..93682240 100644 --- a/dags/vm_resource.py +++ b/dags/vm_resource.py @@ -26,6 +26,9 @@ V6E_SUBNETWORKS = ( f"{V5_NETWORKS_PREFIX}/regions/us-central2/subnetworks/mas-test" ) +# TODO: Figure V6E_GCE_NETWORK and V6E_GCE_SUBNETWORK +V6E_GCE_NETWORK = "default" +V6E_GCE_SUBNETWORK = "default" BM_NETWORKS_PREFIX_BENCHMARKING = "projects/cloud-ml-benchmarking" BM_NETWORKS = f"{BM_NETWORKS_PREFIX_BENCHMARKING}/global/networks/mas-test" @@ -100,6 +103,8 @@ class Zone(enum.Enum): US_WEST1_C = "us-west1-c" # reserved a3+ cluster in supercomputer-testing AUSTRALIA_SOUTHEAST1_C = "australia-southeast1-c" + # reserved TRILLIUM capacity + EUROPE_WEST4_A = "europe-west4-a" class MachineVersion(enum.Enum): @@ -159,6 +164,7 @@ class RuntimeVersion(enum.Enum): TPU_VM_V4_BASE = "tpu-vm-v4-base" V2_ALPHA_TPUV5_LITE = "v2-alpha-tpuv5-lite" V2_ALPHA_TPUV5 = "v2-alpha-tpuv5" + V2_ALPHA_TPUV6 = "v2-alpha-tpuv6e" class XpkClusters: