Add Pathways Benchmarking Recipes for Scale Testing
SujeethJinesh committed Feb 20, 2025
1 parent bea1cef commit c7f68ca
Showing 3 changed files with 500 additions and 2 deletions.
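Each recipe below registers a MaxTextModel (model_name, model_type, tuning_params, xla_flags) into trillium_model_dict via _add_to_model_dictionary. As a rough sketch of how a registered recipe could be consumed — the dictionary key convention and the key=value override style are assumptions for illustration, not something this commit asserts:

# Hypothetical usage sketch, not part of this commit. Assumes
# trillium_model_dict is keyed by model_name and that tuning_params
# map 1:1 onto MaxText-style key=value config overrides.
from benchmarks.maxtext_trillium_model_configs import trillium_model_dict

model = trillium_model_dict["llama3_1-70b-8192-pw"]
overrides = [f"{k}={v}" for k, v in model.tuning_params.items()]
print(model.model_type)   # "llama3.1-70b"
print(overrides[0])       # e.g. "per_device_batch_size=4"
print(model.xla_flags)    # concatenated XLA flag string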
228 changes: 226 additions & 2 deletions benchmarks/maxtext_trillium_model_configs.py
@@ -337,7 +337,7 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
# "enable_pathways_goodput": True,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
@@ -552,6 +552,34 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

llama2_70b_4096_real_data_long_run = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama2-70b-4096-rd-lr",
model_type="llama2-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "full",
"max_target_length": 4096,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"reuse_example_batch": 0,
"profiler": "xplane",
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "tfds",
"tokenizer_path": "assets/tokenizer.llama2",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
),
)
)
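Note how each recipe's xla_flags is assembled by concatenating reusable constants from xla_flags_library, so compiler tuning is shared across recipes rather than copy-pasted per model. A minimal sketch of that pattern with placeholder values (the real flag strings live in xla_flags_library and are not reproduced here):

# Placeholder values for illustration only; the real constants come
# from xla_flags_library.
DENSE_VMEM_LIMIT_FLAG = "--xla_tpu_scoped_vmem_limit_kib=<limit> "
CF_FOR_ALL_GATHER = "--<continuation-fusion-for-all-gather-flags> "

# Recipes compose flag groups by simple string concatenation:
xla_flags = DENSE_VMEM_LIMIT_FLAG + CF_FOR_ALL_GATHER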

llama2_70b_4096_real_data_pw_long_run = _add_to_model_dictionary(
trillium_model_dict,
@@ -574,7 +602,6 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,
"steps": 1000000,

# Additional tuning params for the Pathways long-running test.
"enable_checkpointing": True,
@@ -829,6 +856,51 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)


llama3_1_405b_8192_fsdp_dcn_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3-1-405b-8192-fsdp-dcn-pw",
model_type="llama3.1-405b",
tuning_params={
"per_device_batch_size": 1,
"ici_fsdp_parallelism": 64,
"ici_tensor_parallelism": 4,
"dcn_fsdp_parallelism": 2,
"allow_split_physical_axes": True,
"custom_mesh": "hybrid_ring_64x4",
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"out_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"gcs_metrics": True,
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"reuse_example_batch": 1,
"enable_checkpointing": False,
"profiler": "xplane",
"sa_block_q": 1024,
"sa_block_q_dkv": 2048,
"sa_block_q_dq": 2048,

# Pathways-specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)
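For orientation: assuming the usual MaxText convention that the device count equals the product of the parallelism degrees (an assumption stated here, not asserted by the diff), this recipe implies 64-way FSDP x 4-way tensor parallelism = 256 chips per slice over ICI (matching custom_mesh "hybrid_ring_64x4"), times 2-way FSDP over DCN = 512 chips total. That arithmetic as a sketch:

# Sketch (assumption: total devices = product of parallelism degrees).
ici_fsdp, ici_tensor = 64, 4
dcn_fsdp = 2

chips_per_slice = ici_fsdp * ici_tensor    # 256, the 64x4 hybrid ring
total_chips = chips_per_slice * dcn_fsdp   # 512 across two slices
assert (chips_per_slice, total_chips) == (256, 512)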

llama3_1_8b_8192 = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
@@ -874,6 +946,56 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)


llama3_1_8b_8192_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-8b-8192-pw",
model_type="llama3.1-8b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"out_proj": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways-specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE_PW
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)
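Among the Pathways recipes, this 8B config additionally pulls in ENABLE_SPARSECORE_OFFLOADING_FOR_ALL_REDUCE_PW, which the 70B variant below does not. A hedged helper for auditing such per-recipe flag differences (continuing the sketch above, with its assumed model_name key convention):

# Illustrative only: report flag tokens present in one recipe's
# xla_flags string but not the other's.
def flag_diff(models, name_a, name_b):
    a = set(models[name_a].xla_flags.split())
    b = set(models[name_b].xla_flags.split())
    return a - b, b - a

only_in_8b, only_in_70b = flag_diff(
    trillium_model_dict, "llama3_1-8b-8192-pw", "llama3_1-70b-8192-pw")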

llama3_1_70b_8192 = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
@@ -916,6 +1038,108 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
)
)

llama3_1_70b_8192_pw = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Pathways-specific tuning params.
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"enable_pathways_goodput": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)


llama3_1_70b_8192_pw_lr_real_data = _add_to_model_dictionary(
trillium_model_dict,
MaxTextModel(
model_name="llama3_1-70b-8192-pw-lr-rd",
model_type="llama3.1-70b",
tuning_params={
"per_device_batch_size": 4,
"ici_fsdp_parallelism": -1,
"remat_policy": "custom",
"decoder_layer_input": "offload",
"query_proj": "offload",
"key_proj": "offload",
"value_proj": "offload",
"max_target_length": 8192,
"attention": "flash",
"use_iota_embed": True,
"dataset_path": "gs://max-datasets-rogue",
"dataset_type": "synthetic",
"enable_checkpointing": False,
"sa_block_q": 2048,
"sa_block_kv": 2048,
"sa_block_kv_compute": 2048,
"sa_block_q_dkv": 2048,
"sa_block_kv_dkv": 2048,
"sa_block_kv_dkv_compute": 2048,
"sa_block_q_dq": 2048,
"sa_block_kv_dq": 2048,
"sa_use_fused_bwd_kernel": True,
"profiler": "xplane",
"skip_first_n_steps_for_profiler": 10,
"profiler_steps": 5,

# Checkpointing and Pathways-specific tuning params for the long-running test.
"enable_checkpointing": True,
"async_checkpointing": True,
"checkpoint_period": 100,
"checkpoint_storage_use_ocdbt": False,
"checkpoint_storage_use_zarr3": False,
"metrics_file": "metrics.txt",
"goodput_upload_interval_seconds": 30,
"enable_pathways_goodput": True,
"enable_checkpoint_cloud_logger": True,
"enable_single_controller": True,
},
xla_flags=(
xla_flags_library.DENSE_VMEM_LIMIT_FLAG
+ xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
+ xla_flags_library.DATA_PARALLEL_OVERLAP
+ xla_flags_library.CF_FOR_ALL_GATHER
+ xla_flags_library.HOST_OFFLOAD_FLAGS
),
)
)
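All of the Pathways recipes in this commit share the same controller and checkpoint-storage settings (OCDBT and zarr3 off, Pathways goodput and single-controller on). A small consistency check one could run over the registry — illustrative only, and again assuming model_name keys:

# Sketch: assert every Pathways ("-pw") recipe carries the shared
# Pathways-specific tuning params shown in this diff.
PATHWAYS_REQUIRED = {
    "checkpoint_storage_use_ocdbt": False,
    "checkpoint_storage_use_zarr3": False,
    "enable_pathways_goodput": True,
    "enable_single_controller": True,
}

for name, model in trillium_model_dict.items():
    if "-pw" in name:
        for key, want in PATHWAYS_REQUIRED.items():
            got = model.tuning_params.get(key)
            assert got == want, f"{name}: {key}={got}, expected {want}"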

llama3_1_70b_129024 = _add_to_model_dictionary(
trillium_model_dict,