Skip to content

Commit

Permalink
Add Pathways Benchmarking Recipes for Scale Testing
Browse files Browse the repository at this point in the history
  • Loading branch information
SujeethJinesh committed Feb 12, 2025
1 parent ff59fb3 commit 68bde9d
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 15 deletions.
178 changes: 176 additions & 2 deletions benchmarks/maxtext_trillium_model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
# "enable_pathways_goodput": True,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
Expand Down Expand Up @@ -552,6 +552,34 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

# Llama2-70B benchmark recipe: real-data ("rd") long-running ("lr") variant.
# The 4096 in the name matches max_target_length below — presumably the
# target sequence length, not a chip count (TODO confirm). Registered into
# trillium_model_dict via _add_to_model_dictionary.
llama2_70b_4096_real_data_long_run = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="llama2-70b-4096-rd-lr",
        model_type="llama2-70b",
        tuning_params={
            "per_device_batch_size": 4,
            # -1: let FSDP span the whole ICI mesh automatically.
            "ici_fsdp_parallelism": -1,
            "remat_policy": "full",
            "max_target_length": 4096,
            "attention": "flash",
            "gcs_metrics": True,
            "use_iota_embed": True,
            # 0: do not reuse batches — stream fresh examples from the dataset.
            "reuse_example_batch": 0,
            "profiler": "xplane",
            # Real input pipeline (TFDS) instead of synthetic data.
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "tfds",
            "tokenizer_path": "assets/tokenizer.llama2",
            # Splash-attention block sizes.
            "sa_block_q": 1024,
            "sa_block_q_dkv": 2048,
            "sa_block_q_dq": 2048,
        },
        xla_flags=(
            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
        ),
    )
)

llama2_70b_4096_real_data_pw_long_run = _add_to_model_dictionary(
trillium_model_dict,
Expand All @@ -574,7 +602,6 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for pathways long running test.
"enable_checkpointing": True,
Expand Down Expand Up @@ -829,6 +856,51 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)


# Llama3.1-405B benchmark recipe on Pathways ("pw") with FSDP sharding
# spanning the DCN (data-center network) dimension across slices.
# Synthetic data, no checkpointing — a pure scale/throughput test.
llama3_1_405b_8192_fsdp_dcn_pw = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="llama3-1-405b-8192-fsdp-dcn-pw",
        model_type="llama3.1-405b",
        tuning_params={
            "per_device_batch_size": 1,
            # 64-way FSDP x 4-way tensor parallelism within a slice,
            # 2-way FSDP across slices over DCN.
            "ici_fsdp_parallelism": 64,
            "ici_tensor_parallelism": 4,
            "dcn_fsdp_parallelism": 2,
            "allow_split_physical_axes": True,
            "custom_mesh": "hybrid_ring_64x4",
            # Custom remat: offload selected activations to the host
            # (paired with HOST_OFFLOAD_FLAGS in xla_flags below).
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "query_proj": "offload",
            "key_proj": "offload",
            "value_proj": "offload",
            "out_proj": "offload",
            "max_target_length": 8192,
            "attention": "flash",
            "gcs_metrics": True,
            "use_iota_embed": True,
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "synthetic",
            # 1: replay the same batch — removes input pipeline from the measurement.
            "reuse_example_batch": 1,
            "enable_checkpointing": False,
            "profiler": "xplane",
            # Splash-attention block sizes.
            "sa_block_q": 1024,
            "sa_block_q_dkv": 2048,
            "sa_block_q_dq": 2048,

            # Pathways specific tuning params.
            "checkpoint_storage_use_ocdbt": False,
            "checkpoint_storage_use_zarr3": False,
            "enable_pathways_goodput": True,
            "enable_single_controller": True,
        },
        xla_flags=(
            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    )
)

llama3_1_8b_8192 = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
Expand Down Expand Up @@ -916,6 +988,108 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

# Llama3.1-70B benchmark recipe on Pathways ("pw"): synthetic data, no
# checkpointing, profiled after a warm-up — a pure throughput test.
llama3_1_70b_8192_pw = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="llama3_1-70b-8192-pw",
        model_type="llama3.1-70b",
        tuning_params={
            "per_device_batch_size": 4,
            # -1: let FSDP span the whole ICI mesh automatically.
            "ici_fsdp_parallelism": -1,
            # Custom remat: offload selected activations to the host
            # (paired with HOST_OFFLOAD_FLAGS in xla_flags below).
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "query_proj": "offload",
            "key_proj": "offload",
            "value_proj": "offload",
            "max_target_length": 8192,
            "attention": "flash",
            "use_iota_embed": True,
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "synthetic",
            "enable_checkpointing": False,
            # Splash-attention block sizes (fused backward kernel enabled).
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "sa_use_fused_bwd_kernel": True,
            # Profile 5 steps after a 10-step warm-up.
            "profiler": "xplane",
            "skip_first_n_steps_for_profiler": 10,
            "profiler_steps": 5,

            # Pathways specific tuning params.
            "checkpoint_storage_use_ocdbt": False,
            "checkpoint_storage_use_zarr3": False,
            "enable_pathways_goodput": True,
            "enable_single_controller": True,
        },
        xla_flags=(
            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    )
)


# Llama3.1-70B benchmark recipe on Pathways ("pw") for a long-running ("lr")
# real-data ("rd") test. Mirrors llama3_1_70b_8192_pw but enables periodic
# async checkpointing plus goodput/metrics reporting for the long run.
#
# Fix: the original dict listed "enable_checkpointing" twice (False in the
# base params, then True in the pathways section). In a dict literal the
# later value wins, so the False entry was dead — it has been removed and
# checkpointing is enabled once, below.
llama3_1_70b_8192_pw_lr_real_data = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="llama3_1-70b-8192-pw-lr-rd",
        model_type="llama3.1-70b",
        tuning_params={
            "per_device_batch_size": 4,
            # -1: let FSDP span the whole ICI mesh automatically.
            "ici_fsdp_parallelism": -1,
            # Custom remat: offload selected activations to the host
            # (paired with HOST_OFFLOAD_FLAGS in xla_flags below).
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "query_proj": "offload",
            "key_proj": "offload",
            "value_proj": "offload",
            "max_target_length": 8192,
            "attention": "flash",
            "use_iota_embed": True,
            "dataset_path": "gs://max-datasets-rogue",
            # NOTE(review): this recipe is named "real data" (-rd) but still
            # uses the synthetic pipeline; the llama2 real-data recipe uses
            # "tfds". Confirm whether this should be "tfds" + tokenizer_path.
            "dataset_type": "synthetic",
            # Splash-attention block sizes (fused backward kernel enabled).
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "sa_use_fused_bwd_kernel": True,
            # Profile 5 steps after a 10-step warm-up.
            "profiler": "xplane",
            "skip_first_n_steps_for_profiler": 10,
            "profiler_steps": 5,

            # Pathways specific tuning params: checkpoint every 100 steps
            # asynchronously and upload goodput metrics every 30 seconds.
            "enable_checkpointing": True,
            "async_checkpointing": True,
            "checkpoint_period": 100,
            "checkpoint_storage_use_ocdbt": False,
            "checkpoint_storage_use_zarr3": False,
            "metrics_file": "metrics.txt",
            "goodput_upload_interval_seconds": 30,
            "enable_pathways_goodput": True,
            "enable_checkpoint_cloud_logger": True,
            "enable_single_controller": True,
        },
        xla_flags=(
            xla_flags_library.DENSE_VMEM_LIMIT_FLAG
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    )
)

llama3_1_70b_129024 = _add_to_model_dictionary(
trillium_model_dict,
Expand Down
46 changes: 33 additions & 13 deletions benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class PathwaysConfig:
server_image: str
proxy_image: str
runner_image: str
remote_python_sidecar_image: str
remote_python_sidecar_image: str = None
server_flags: str = ''


# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
Expand Down Expand Up @@ -284,7 +285,7 @@ def build_user_command(

install_libtpu_cmd = ''
jax_platforms = None
vertex_tensorboard = None
vertex_tensorboard = ''
# TODO() support modifying nightly / stable dependencies in pathway flow
if is_pw_enabled:
jax_platforms = 'proxy'
Expand Down Expand Up @@ -375,35 +376,54 @@ def generate_xpk_workload_cmd(
docker_image_flag = ''
# pathways-related flags
pathways_specific_flags = ''
workload_create_command = f'python3 {wl_config.xpk_path}/xpk.py workload create'
device_type = f' --device-type={cluster_config.device_type}'
if is_pathways_enabled:
pw_config = wl_config.pathways_config
pathways_specific_flags = (
'--use-pathways'
remote_python_sidecar_image_flag = (
f' --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}'
if pw_config.remote_python_sidecar_image is not None
else ''
)
server_image_flag = (
f' --server-image={pw_config.server_image}'
if pw_config.server_image is not None
else ''
)
proxy_image_flag = (
f' --proxy-server-image={pw_config.proxy_image}'
f' --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}'
if pw_config.remote_python_sidecar_image is not None else ''
' --termination-grace-period-seconds=300'
f' --pathways-gcs-location={wl_config.base_output_directory}'
f' --restart-on-user-code-failure'
f' --debug-dump-gcs={wl_config.base_output_directory}'
if pw_config.proxy_image is not None
else ''
)
pathways_specific_flags = (
f' {server_image_flag} '
f' {proxy_image_flag} '
f' {remote_python_sidecar_image_flag} '
f' --termination-grace-period-seconds=300 '
f' --pathways-gcs-location={wl_config.base_output_directory} '
# f' --restart-on-user-code-failure'
# f' --debug-dump-gcs={wl_config.base_output_directory} '
f' --custom-pathways-server-args="{wl_config.pathways_config.server_flags}" '
)
device_type = f' --tpu-type={wl_config.device_type}'
workload_create_command = (
f'python3 {wl_config.xpk_path}/xpk.py workload create-pathways'
)
docker_image_flag = (
f'--docker-image={pw_config.runner_image}'
)
else:
docker_image_flag = f'--base-docker-image="{wl_config.base_docker_image}"'


print(f'User command: {user_command}')
return (
(
f'python3 {wl_config.xpk_path}/xpk.py workload create'
f'{workload_create_command}'
f' {pathways_specific_flags}'
f' --cluster={cluster_config.cluster_name}'
f' --project={cluster_config.project}'
f' --zone={cluster_config.zone}'
f' --device-type={cluster_config.device_type}'
f' {device_type}'
f' --num-slices={wl_config.num_slices}'
f' --command="{user_command}"'
f' {docker_image_flag}'
Expand Down
Loading

0 comments on commit 68bde9d

Please sign in to comment.