Skip to content

Commit

Permalink
Add Pathways Benchmarking Recipes for Scale Testing
Browse files Browse the repository at this point in the history
  • Loading branch information
SujeethJinesh committed Jan 31, 2025
1 parent d0270c7 commit 59799d7
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 4 deletions.
102 changes: 102 additions & 0 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,108 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

llama3_1_70b_8192_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)


llama3_1_70b_8192_pw_lr_real_data = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw-lr-rd",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways specific tuning params.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)

llama3_1_70b_129024 = _add_to_model_dictionary(
trillium_model_dict,
Expand Down
12 changes: 8 additions & 4 deletions benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class PathwaysConfig:
server_image: str
proxy_image: str
runner_image: str
remote_python_sidecar_image: str
remote_python_sidecar_image: str = None


# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
Expand Down Expand Up @@ -284,7 +284,7 @@ def build_user_command(

install_libtpu_cmd = ''
jax_platforms = None
vertex_tensorboard = None
vertex_tensorboard = ''
# TODO() support modifying nightly / stable dependencies in pathway flow
if is_pw_enabled:
jax_platforms = 'proxy'
Expand Down Expand Up @@ -372,12 +372,16 @@ def generate_xpk_workload_cmd(
pathways_specific_flags = ''
if is_pathways_enabled:
pw_config = wl_config.pathways_config
remote_python_sidecar_image_flag = (
f' --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}'
if pw_config.remote_python_sidecar_image is not None
else ''
)
pathways_specific_flags = (
'--use-pathways'
f' --server-image={pw_config.server_image}'
f' --proxy-server-image={pw_config.proxy_image}'
f' --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}'
if pw_config.remote_python_sidecar_image is not None else ''
f'{remote_python_sidecar_image_flag}'
' --termination-grace-period-seconds=300'
f' --pathways-gcs-location={wl_config.base_output_directory}'
f' --restart-on-user-code-failure'
Expand Down
117 changes: 117 additions & 0 deletions benchmarks/recipes/pw_long_running_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import sys
import os
import args_helper as helper

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import maxtext_trillium_model_configs as model_configs
import maxtext_xpk_runner as mxr

HEAD_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest"
HEAD_SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest"
HEAD_RUNNER = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest"

CLUSTER = "v6e-256-cluster"
PROJECT = "tpu-prod-env-cluster"
ZONE = "us-east5-b"
DEVICE_TYPE = "v6e-256"
BENCHMARK_STEPS = 100000


def main() -> int:
# V6e cluster config
cluster_config = mxr.XpkClusterConfig(
cluster_name=CLUSTER,
project=PROJECT,
zone=ZONE,
device_type=DEVICE_TYPE,
)
xpk_path = "xpk"

# Handle command line arguments using args_helper
should_continue = helper.handle_cmd_args(
cluster_config, helper.DELETE, xpk_path=xpk_path
)

if not should_continue:
return 0

user = os.environ["USER"]
region = "-".join(cluster_config.zone.split("-")[:-1])
base_output_directory = f"gs://{user}-{region}/{user}"

list_of_models = [
model_configs.llama3_1_70b_8192_pw_lr_real_data,
]
pathways_config = mxr.PathwaysConfig(
server_image=HEAD_SERVER_IMAGE,
proxy_image=HEAD_PROXY_IMAGE,
runner_image=HEAD_RUNNER,
)
num_slices_list = [
2
]

xpk_workload_cmds = []
xpk_workload_names = []

for model in list_of_models:
# Run workloads on the below clusters
for cluster_config in [
cluster_config,
]:
# Run workloads in the following slice configurations
for num_slices in num_slices_list:
wl_config = mxr.WorkloadConfig(
model=model,
num_slices=num_slices,
device_type=cluster_config.device_type,
base_output_directory=base_output_directory,
max_restarts=100,
libtpu_type=None,
libtpu_nightly_version="",
base_docker_image="",
pathways_config=pathways_config,
xpk_path=xpk_path,
num_steps=BENCHMARK_STEPS,
)
command, name = mxr.generate_xpk_workload_cmd(
cluster_config=cluster_config, wl_config=wl_config
)

print(f"Name of the workload is: {name} \n")
xpk_workload_names.append(name)

print(f"XPK command to be used is: {command} \n")
xpk_workload_cmds.append(command)

for xpk_workload_name, xpk_workload_cmd in zip(
xpk_workload_names, xpk_workload_cmds
):
print(f"Running workload: {xpk_workload_name} with command: {xpk_workload_cmd}")
return_code = mxr.run_command_with_updates(
xpk_workload_cmd, xpk_workload_name
)
if return_code != 0:
print(f"Unable to run xpk workload: {xpk_workload_name}")


if __name__ == "__main__":
main()
127 changes: 127 additions & 0 deletions benchmarks/recipes/pw_mcjax_benchmark_recipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import sys
import os
import args_helper as helper

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import maxtext_trillium_model_configs as model_configs
import maxtext_xpk_runner as mxr

HEAD_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest"
HEAD_SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest"
HEAD_RUNNER = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest"

CLUSTER = "v6e-256-cluster"
PROJECT = "tpu-prod-env-cluster"
ZONE = "us-east5-b"
DEVICE_TYPE = "v6e-256"
BENCHMARK_STEPS = 20


def main() -> int:
# V6e cluster config
cluster_config = mxr.XpkClusterConfig(
cluster_name=CLUSTER,
project=PROJECT,
zone=ZONE,
device_type=DEVICE_TYPE,
)
xpk_path = "xpk"

# Handle command line arguments using args_helper
should_continue = helper.handle_cmd_args(
cluster_config, helper.DELETE, xpk_path=xpk_path
)

if not should_continue:
return 0

user = os.environ["USER"]
region = "-".join(cluster_config.zone.split("-")[:-1])
base_output_directory = f"gs://{user}-{region}/{user}"

# The user MUST build their own maxtext image:
# https://github.com/AI-Hypercomputer/maxtext/blob/main/docker_build_dependency_image.sh
local_maxtext_image = f"gcr.io/{PROJECT}/{user}_latest:latest"

models = {
"mcjax": [
model_configs.llama3_1_70b_8192,
],
"pathways": [
model_configs.llama3_1_70b_8192_pw,
]
}
pathways_config = mxr.PathwaysConfig(
server_image=HEAD_SERVER_IMAGE,
proxy_image=HEAD_PROXY_IMAGE,
runner_image=HEAD_RUNNER,
)
num_slices_list = [
2
]

xpk_workload_cmds = []
xpk_workload_names = []

for infra, model_list in models.items():
for model in model_list:
# Run workloads on the below clusters
for cluster_config in [
cluster_config,
]:
# Run workloads in the following slice configurations
for num_slices in num_slices_list:
wl_config = mxr.WorkloadConfig(
model=model,
num_slices=num_slices,
device_type=cluster_config.device_type,
base_output_directory=base_output_directory,
max_restarts=0,
libtpu_type=None,
libtpu_nightly_version="",
base_docker_image=local_maxtext_image if infra == "mcjax" else "",
pathways_config=pathways_config if infra == "pathways" else None,
xpk_path=xpk_path,
num_steps=BENCHMARK_STEPS,
)
command, name = mxr.generate_xpk_workload_cmd(
cluster_config=cluster_config, wl_config=wl_config
)

print(f"Name of the workload is: {name} \n")
xpk_workload_names.append(name)

print(f"XPK command to be used is: {command} \n")
xpk_workload_cmds.append(command)

for xpk_workload_name, xpk_workload_cmd in zip(
xpk_workload_names, xpk_workload_cmds
):
print(f"Running workload: {xpk_workload_name} with command: {xpk_workload_cmd}")
return_code = mxr.run_command_with_updates(
xpk_workload_cmd, xpk_workload_name
)
if return_code != 0:
print(f"Unable to run xpk workload: {xpk_workload_name}")


if __name__ == "__main__":
main()

0 comments on commit 59799d7

Please sign in to comment.