Skip to content

Commit

Permalink
Add Pathways Benchmarking Recipes for Scale Testing
Browse files Browse the repository at this point in the history
  • Loading branch information
SujeethJinesh committed Feb 19, 2025
1 parent 10b24ab commit 7888e57
Show file tree
Hide file tree
Showing 7 changed files with 549 additions and 25 deletions.
4 changes: 2 additions & 2 deletions benchmarks/Getting_Started_Benchmarking.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export RUNNER=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_sta
export PROXY_IMAGE=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest
export SERVER_IMAGE=us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/server:latest

python3 benchmarks/benchmark_runner.py xpk --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5 --pathways_server_image="${SERVER_IMAGE}" --pathways_proxy_image="${PROXY_IMAGE}" --pathways_runner_image="${RUNNER}"
python3 benchmarks/benchmark_runner.py xpk --project $PROJECT --zone $ZONE --cluster_name $CLUSTER --device_type v6e-256 --base_output_directory gs://maxtext-experiments-tpem/ --num_steps=5 --pathways_server_image="${SERVER_IMAGE}" --pathways_proxy_server_image="${PROXY_IMAGE}" --pathways_runner_image="${RUNNER}"
```

```shell
Expand Down Expand Up @@ -87,4 +87,4 @@ for model in list_of_models:
if return_code != 0:
print('Unable to run xpk workload: {xpk_workload_name}')

```
```
4 changes: 2 additions & 2 deletions benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def add_xpk_runner_arguments(custom_parser: argparse.ArgumentParser):
help='version of pathways server image to be benchmarked command.',
)
custom_parser.add_argument(
'--pathways_proxy_image',
'--pathways_proxy_server_image',
type=str,
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/proxy_server:latest',
help='version of pathways proxy image to be benchmarked command.',
Expand Down Expand Up @@ -251,7 +251,7 @@ def main() -> None:
if options.use_pathways:
pw_config = PathwaysConfig(
server_image=options.pathways_server_image,
proxy_image=options.pathways_proxy_image,
proxy_image=options.pathways_proxy_server_image,
runner_image=options.pathways_runner_image,
remote_python_sidecar_image=options.remote_python_sidecar_image,
)
Expand Down
228 changes: 226 additions & 2 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
# "enable_pathways_goodput": True,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
Expand Down Expand Up @@ -552,6 +552,34 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

llama2_70b_4096_real_data_long_run = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama2-70b-4096-rd-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"reuse_example_batch": 0,
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
"tokenizer_path": "assets/tokenizer.llama2",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)
)

llama2_70b_4096_real_data_pw_long_run = _add_to_model_dictionary(
trillium_model_dict,
Expand All @@ -574,7 +602,6 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
Expand Down Expand Up @@ -829,6 +856,51 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)


llama3_1_405b_8192_fsdp_dcn_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3-1-405b-8192-fsdp-dcn-pw",
model_type="llama3.1-405b",
tuning_params={
"per_device_batch_size": 1,
"ici_fsdp_parallelism": 64,
"ici_tensor_parallelism": 4,
"dcn_fsdp_parallelism": 2,
"allow_split_physical_axes": True,
"custom_mesh": "hybrid_ring_64x4",
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"out_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Pathways specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)

llama3_1_8b_8192 = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
Expand Down Expand Up @@ -874,6 +946,56 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)


llama3_1_8b_8192_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-8b-8192-pw",
model_type="llama3.1-8b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"out_proj": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE_PW
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)

llama3_1_70b_8192 = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
Expand Down Expand Up @@ -916,6 +1038,108 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

llama3_1_70b_8192_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)


llama3_1_70b_8192_pw_lr_real_data = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw-lr-rd",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways specific tuning params.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)

llama3_1_70b_129024 = _add_to_model_dictionary(
trillium_model_dict,
Expand Down
Loading

0 comments on commit 7888e57

Please sign in to comment.