# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This sentinel is a reminder to choose a real run name.
# If there is already a checkpoint under this run, that checkpoint will auto-resume.
run_name: ""
model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
normalization_layer_epsilon: 1.e-05
################################## CHECKPOINTING ##################################
# Checkpointing makes the following choices in the following order, starting with (1):
#   (1) If there is already a checkpoint for this run_name, we load the latest full checkpoint.
#     This ensures that if we're resuming a run after preemption or hardware failure, we lose minimal state.
#   (2) These two options share the same priority and are mutually exclusive -- you can't set both:
#     * If load_parameters_path is set, we load a parameter-only checkpoint from that path.
#     * If load_full_state_path is set, we load a full-state checkpoint from that path.
#   (3) Otherwise, we don't load a checkpoint and initialize state from scratch instead.
# Loads just the parameters from a specific directory
# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items
load_parameters_path: ""
# Loads a full checkpoint including optimizer state and step count from a specific directory
# e.g. gs://my-base-output-directory/my-previous-run-name/checkpoints/items/NUMBER or NUMBER/items
load_full_state_path: ""
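# For example, to warm-start a new run from the weights of a previous run, you
# might set (the bucket name and step number below are hypothetical):
# load_parameters_path: "gs://my-base-output-directory/my-previous-run-name/checkpoints/1000/items"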
# If enable_checkpointing is true, an asynchronous checkpointer will be used if
# async_checkpointing is true, else a synchronous one is used. If you have
# problems with the checkpointer we recommend trying the synchronous one.
enable_checkpointing: True
async_checkpointing: True
checkpoint_period: 10_000
# Enables one replica to read the checkpoint and then broadcast it to the rest
enable_single_replica_ckpt_restoring: False
force_unroll: False # whether to unroll the loop during generate_param_only_checkpoint
# Checkpointing with Orbax has two important parameters: the array driver
# and its underlying storage -- the kvstore (preferably ocdbt).
# Orbax supports setting a target file size; chunking a single
# large array into small physical files (<2GB) can speed up distributed,
# over-the-network loading enormously.
checkpoint_storage_target_data_file_size_bytes: 2147483648
checkpoint_storage_use_ocdbt: True
checkpoint_storage_use_zarr3: True
############################### END CHECKPOINTING ##################################
reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
# Activation dtypes.
dtype: "bfloat16"
# Used to configure quantization in the transformer layers, defaults to null implying bf16.
# Possible alternative settings are as follows:
# 'int8' for dynamic range quantization using 8-bits
# 'intmp' for mixed precision quantization for inference as described here: MaxText/configs/quantization/README.md
# 'fp8' for 8-bit floating-point GeMMs on NVIDIA GPUs.
quantization: ""
# Choose one of 'default', 'high', or 'highest'.
# https://kolonist26-jax-kr.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
matmul_precision: "default"
activations_in_float32: False # If true, sets activations to float32 before the nonlinearity; otherwise uses dtype
# Used to replicate the quantization scale to avoid the inefficient XLA fusion for 2d sharding.
replicate_quant_scale: False
# Path to file with quantization config for intmp.
quant_cfg_path: ""
quantize_kvcache: False # Set to True to quantize KV Cache values, defaults to False
# Valid kv_quant_axis values:
#   - "" is valid only when quantize_kvcache is False
#   - "dkv" quantizes the kv cache over the cache_kv axis, i.e. the kv dimension; expected to give better accuracy but degraded computation speed
#   - "heads_and_dkv" quantizes the kv cache over both the cache_heads and cache_kv axes; the default, for faster computation
# kv_quant_axis is not used when quantize_kvcache is False.
kv_quant_axis: "heads_and_dkv"
kv_quant_dtype: "int8"
checkpoint_is_quantized: False # Set to True if reading from a saved aqt quantized checkpoint
# Saves params quantized on the fly to the following path
save_quantized_params_path: ""
# Used to configure the mode in which the model is called.
# When left as-is, it corresponds to training;
# the only accepted alternative value is "inference".
model_call_mode: ""
# Shard the range-finding operation for quantization. By default this is set to the number of slices.
quantization_local_shard_count: -1
decoder_block: "llama2" # which style of DecoderBlock to use.
# Global parameter scale needs to be a power of 2. If you want finer-grained control of the model sizes,
# then you should explicitly set base_embed_dim, base_num_query_heads, base_num_kv_heads,
# base_mlp_dim, base_num_decoder_layers and/or head_dim.
weight_dtype: float32
global_parameter_scale: 1
base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 16
base_mlp_dim: 7168
base_num_decoder_layers: 16
head_dim: 128
mlp_activations: ["silu", "linear"]
dropout_rate: 0.0
logits_via_embedding: False
normalize_embedding_logits: True # whether to normalize pre-softmax logits if logits_via_embedding is true
logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embedding dot product for stability
cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax
# mixture of experts (moe)
num_experts: 1
num_experts_per_tok: 1
megablox: True
sparse_matmul: True
capacity_factor: -1.0 # a factor that decides expert capacity for token dropping; -1.0 (the default) means no dropping
load_balance_loss_weight: 0.01 # weight for the load balance loss
# deepseek moe
base_moe_mlp_dim: 7168 # intermediate dimension at MoE layer (use base_mlp_dim if not DeepSeek style)
first_num_dense_layers: 0 # number of initial dense layers in the model
shared_experts: 1
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # whether a bias term is added for routing
# pipeline parallelism
# The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
# There is a tradeoff between num_layers_per_pipeline_stage and num_pipeline_repeats: the more layers per stage, the easier
# it is to hide the pipeline communication behind the compute, since there is more compute per stage; however, there will be a larger bubble
# since there are fewer repeats. Similarly, there is a tradeoff for num_pipeline_microbatches: more microbatches lead to a smaller bubble,
# but a smaller size per microbatch, which may hurt per-stage performance. Additionally, note that when microbatches > num_stages we have the opportunity to
# perform the circular transfer (last stage to first) asynchronously.
# The bubble fraction is (num_stages - 1) / (num_pipeline_repeats * num_pipeline_microbatches + num_stages - 1);
# see the worked example after these settings.
num_layers_per_pipeline_stage: 1
# The number of repeats will be set to num_decoder_layers / (num_pipeline_stages * num_layers_per_pipeline_stage)
num_pipeline_repeats: -1
# num_pipeline_microbatches must be a multiple of the number of pipeline stages. By default it is set to the number of stages.
# Note the microbatch_size is given by global_batch_size / num_pipeline_microbatches, where global_batch_size = per_device_batch_size * num_devices
num_pipeline_microbatches: -1
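# Worked example with hypothetical values: with 4 pipeline stages, num_layers_per_pipeline_stage=2
# and 16 decoder layers, num_pipeline_repeats resolves to 16 / (4 * 2) = 2. Choosing
# num_pipeline_microbatches=8 with a global batch size of 256 gives a microbatch size of
# 256 / 8 = 32, and a bubble fraction of (4 - 1) / (2 * 8 + 4 - 1) = 3 / 19 ≈ 0.16.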
pipeline_delay_activation_forwarding: False # This delays activation forwarding by one loop iteration, simplifying XLA's task of overlapping, since
# the communication and compute in each iteration are now independent. However, this comes at the cost of doubling the pipeline bubble,
# and you must set the number of microbatches to at least 2 * num_stages (this minimum is set by default when the delay is enabled).
pipeline_fsdp_ag_once: False # If set to true, then all-gather all of the weights over FSDP before the first pipeline iteration.
# This is a memory/time tradeoff - we now have to store the FSDP-gathered weights and gradients (typically in bf16), as opposed
# to only one stage's worth; however, we execute only one all-gather and reduce per repeat, as opposed
# to one every microbatch. This is similar to zero-1 sharding, since we also don't need to all-gather the FSDP weights in the backward pass.
# An alternative to setting this to true may be to replace any FSDP with DP and use optimizer offloading if necessary.
# A more optimal behavior is to all-gather at the start of each repeat, which would ideally get the best of both worlds -
# a small amount of memory and time, however this has proven hard to implement in SPMD, see b/364386697 for more.
# There are two loops for PP:
# 1) Outer loop over microbatches (pipeline iterations)
# 2) Inner loop over layers (layers per stage)
# We have observed extra remat when a remat policy is set and scanning is performed on both, and recommend the default
# settings below of scanning and setting a remat policy only over the pipeline iterations.
# It may be useful to do the reverse when the layers_per_stage is very large.
# The below settings only have effect when using pipeline parallelism.
scan_pipeline_iterations: True
# The layers-per-stage scanning option is set by scan_layers; we recommend setting scan_layers=False
set_remat_policy_on_pipeline_iterations: True
set_remat_policy_on_layers_per_stage: False
# Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
# 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'.
# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
remat_policy: 'full'
# If "custom" remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory.
# Pick one of these options for following tensors: ['remat','device','offload']
decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points
context: 'remat' # From https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/attention.py#L581-L583
mlpwi: 'remat'
mlpwi_0: 'remat'
mlpwi_1: 'remat'
mlpwo: 'remat'
query_proj: 'remat'
key_proj: 'remat'
value_proj: 'remat'
qkv_proj: 'remat'
out_proj: 'remat'
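# A hypothetical 'custom' policy sketch: offload the large MLP activations to host
# memory and keep the attention output on device, rematerializing everything else:
# remat_policy: 'custom'
# mlpwi: 'offload'
# mlpwo: 'offload'
# out_proj: 'device'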
optimizer_memory_host_offload: False
scan_layers: True # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations.
param_scan_axis: 1
# The attention parameter dictates the specific algorithm/methodology used to compute the attention scores
# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
attention_type: 'global' # Supported attention_type: global, local_sliding, mla
sliding_window_size: 0
attn_logits_soft_cap: 0.0
final_logits_soft_cap: 0.0
use_post_attn_norm: False
use_post_ffw_norm: False
# MLA parameters
q_lora_rank: 0
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
# Combine matmuls for QKV and MLP
fused_qkv: False
fused_mlp: False
record_internal_nn_metrics: 0
# Output directory
# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/"
base_output_directory: ""
# Whether or not to enable emergency checkpointing. If True, `local_checkpoint_directory` and a non-zero `local_checkpoint_period` must also be specified.
# Emergency checkpointing is an experimental Orbax feature that periodically saves checkpoints to a local directory and, at a larger interval, saves to persistent storage.
# During restore, if a local copy is available in any slice, it will be broadcast to the other slices without having to fetch from persistent storage.
# See more details at https://github.com/google/orbax/tree/main/checkpoint/orbax/checkpoint/experimental/emergency.
enable_emergency_checkpoint: False
# It should be specified when and only when `enable_emergency_checkpoint` is True.
local_checkpoint_directory: ""
# It should be a positive number when and only when `enable_emergency_checkpoint` is True.
local_checkpoint_period: 0
# Whether or not to use emergency checkpoint with the replicator service.
use_replicator_service: False
# The interval to backup local checkpoints to the persistent storage.
replicator_backup_interval_minutes: 0
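# A minimal sketch of enabling emergency checkpointing (the directory and period
# below are illustrative, not recommendations):
# enable_emergency_checkpoint: True
# local_checkpoint_directory: "/local/ckpt/dir"
# local_checkpoint_period: 50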
# Jax cache directory
jax_cache_dir: "~/jax_cache"
# Hardware
hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'
# Parallelism
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_kv_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence']],
['activation_length', ['sequence']],
['activation_norm_length', ['tensor_sequence', 'sequence']],
['activation_embed', ['tensor', 'tensor_transpose']],
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_kv_head_dim', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose', 'sequence', 'tensor_sequence']],
['activation_vocab', ['tensor', 'tensor_transpose']],
['activation_vocab', 'tensor_sequence'],
['activation_vocab', 'sequence'],
['activation_stage', 'stage'],
['activation_exp', 'expert'],
['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']],
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
['embed', ['fsdp', 'sequence', 'expert']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed_no_exp', ['fsdp', 'sequence']],
['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']],
['layers', 'stage'],
['kv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence']],
['cache_heads', ['autoregressive', 'tensor', 'tensor_sequence']],
['cache_kv', []],
['cache_sequence', []],
['exp', 'expert'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
# sharding tolerance: float between 0.0 and 1.0 representing the allowed fraction of non-sharded parameters.
sharding_tolerance: 0.02
# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, the product of the DCN axes should equal the number of slices,
# and the product of the ICI axes should equal the number of devices per slice
# (see the worked example after these settings).
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
dcn_sequence_parallelism: 1 # never recommended
dcn_tensor_parallelism: 1 # never recommended
dcn_tensor_transpose_parallelism: 1
dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
ici_sequence_parallelism: 1
ici_tensor_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1
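# Worked example on hypothetical hardware: with 2 slices of 256 devices each, the
# defaults above auto-resolve dcn_data_parallelism to 2 (the number of slices) and
# ici_fsdp_parallelism to 256 (the devices per slice), since each product must match
# its respective count.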
# The number of TPU slices is automatically determined; you should not set this explicitly. For ahead-of-time compilation,
# you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
num_slices: -1
# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "assets/tokenizer.llama2"
tokenizer_type: "sentencepiece"
tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True
# Dataset
per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data; else total_hosts//expansion_factor_real_data hosts will pull data from GCS.
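# For example, with 16 hosts (hypothetical) and expansion_factor_real_data=4, only
# 16 // 4 = 4 hosts pull real data from GCS.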
eval_per_device_batch_size: 0.0
max_corpus_chars: 10_000_000
train_data_columns: ['text'] # for a DPO dataset, use the columns containing "chosen" and "rejected"
eval_data_columns: ['text'] # for a DPO dataset, use the columns containing "chosen" and "rejected"
# direct preference optimization (DPO)
use_dpo: False
dpo_label_smoothing: 0.0
dpo_beta: 0.1
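# A hypothetical DPO setup based on the comments above (the column names depend on
# your dataset's schema):
# use_dpo: True
# train_data_columns: ['chosen', 'rejected']
# eval_data_columns: ['chosen', 'rejected']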
# dataset_type must be one of synthetic, hf, grain, or tfds
# details in: https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md
dataset_type: tfds
# for TFDS input pipeline (dataset_type=tfds)
dataset_path: "" # your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
dataset_name: 'c4/en:3.0.1'
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
# for HuggingFace input pipeline (dataset_type=hf)
hf_path: ''
hf_data_dir: ''
hf_train_files: ''
hf_eval_split: ''
hf_eval_files: ''
hf_access_token: ''
# for Grain input pipeline (dataset_type=grain)
grain_train_files: ''
grain_eval_files: ''
grain_worker_count: 1
# Training loop
steps: 150_001 # If set to -1, the value will be inherited from learning_rate_schedule_steps
log_period: 100 # Flushes TensorBoard metrics every log_period steps
jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT); the timeout above refers
# only to the jax coordination service.
jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization.
skip_jax_distributed_system: False # If True we will not initialize the jax distributed system.
# Currently the jax distributed system is needed on Cloud TPUs for async checkpointing.
# However, when run on Google-internal TPUs the coordination service is started automatically,
# and we should set this to True so we won't try to initialize it a second time manually.
# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# The learning rate schedule has either two or three parts:
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
# The zero learning rate section can be used to more accurately measure the fully trained model's performance.
learning_rate: 3.e-5
cosine_learning_rate_final_fraction: 0.1
warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
# However, you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case training will end before the LR has
# fully decayed. Or you may choose a shorter schedule, where the remaining steps will have a learning rate of 0.
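# Worked example using the defaults above: with learning_rate=3.e-5, steps=150_001,
# learning_rate_schedule_steps=steps and warmup_steps_fraction=0.1, the LR warms up
# linearly from 0 to 3.e-5 over the first ~15_000 steps, then follows a cosine decay
# from 3.e-5 down to 3.e-6 (cosine_learning_rate_final_fraction * learning_rate) at
# step 150_001.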
max_target_length: 2048 # Maximum sequence length
max_prefill_predict_length: 64 # Maximum length for the prefill when doing autoregression
prompt: "I love to" # Prompt for language model sampling.
load_from_prefill_dir: False # If true, decode.py doesn't "prefill" but just reads from directory
prefill_cache_dir: "" # If set and load_from_prefill_dir is true, decode.py reads from this directory; if set, decode.py writes to it
autoregressive_decode_assert: ""
# For nsys profiler, pass the training command to nsys command
# e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command}
profiler: "" # Supported profiler: '', xplane, nsys
# If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host.
upload_all_profiler_results: False
# Skip first n steps for profiling, to omit things like compilation and to give
# the iteration time a chance to stabilize.
skip_first_n_steps_for_profiler: 1
# Profile for a small number of steps to avoid a large profile file size.
profiler_steps: 5
profile_cleanly: True # If set to true, adds a block_until_ready on train state which aligns the profile for each step.
profile_periodically_period: -1 # If set to a positive integer, profile every profile_periodically_period steps.
# This is useful to debug scenarios where performance is changing.
# Dump HLO options
dump_hlo: False
dump_hlo_local_dir: "/tmp/xla_dump/"
dump_hlo_delete_local_after: True # Cleans the local directory after it's uploaded
dump_hlo_gcs_dir: "" # Defaults to {base_output_directory}/{run_name}/xla_dump
dump_hlo_module_name: "jit_train_step" # Filter uploading modules by this string. Set to empty string to remove any filter.
dump_hlo_xla_flags: "" # Defaults to "--xla_dump_to={dump_hlo_local_dir} --xla_dump_hlo_module_re={dump_hlo_module_name} --xla_dump_large_constants"
dump_hlo_upload_all: False # If true, all hosts dump HLO; if false, only jax.process_index()==0 does.
# All hosts should have identical HLO for SPMD programs; however, we have encountered some bugs
# where this is not the case, and it is helpful to compare HLO across hosts.
# When dropout is disabled, the model is a deterministic function of the
# data_shuffle_seed and init_weights_seed (i.e. reproducible losses)
enable_dropout: True
enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0
# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0
# Instead of updating the weights every step, you may effectively use a larger
# batch by accumulating the gradient over a set of steps.
gradient_accumulation_steps: 1
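# For example (hypothetical device count): with per_device_batch_size=12, 256 devices
# and gradient_accumulation_steps=4, each weight update effectively averages gradients
# over 12 * 256 * 4 = 12_288 examples, while only a 12 * 256 batch is resident in
# memory at a time.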
# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
opt_type: "adamw" # one of "adam_pax" or "adamw"
adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients.
adam_b2: 0.95 # Exponential decay rate to track the second moment of past gradients.
adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root.
adam_eps_root: 0. # A small constant applied to denominator inside the square root.
adam_weight_decay: 0.1 # AdamW Weight decay
# Stack trace parameters
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to cloud logging if True, else prints to the console.
stack_trace_interval_seconds: 600 # Stack trace collection frequency in seconds.
# Use iota operator in Embed
use_iota_embed: False
# use positional embedding
use_untrainable_positional_embedding: False
trainable_position_size: -1 # enable gpt3-style position embedding by setting a positive trainable_position_size
# RoPE parameters
rope_type: "default" # one of "default", "llama3.1" or "yarn"
rope_min_timescale: 1
rope_max_timescale: 10_000
# yarn RoPE parameters
original_seq_len: 4096
rope_theta: 10000.0
rope_factor: 40
beta_fast: 32
beta_slow: 1
mscale: 1.0
# Ahead of time Compilation (aka AOT)
# Only set these arguments if you are running train_compile or loading a compiled train step.
compiled_trainstep_file: "" # Name of saved serialized compiled train_step, e.g. compiled_train_v5e-256.pickle
compile_topology: '' # Target hardware version, e.g. 'v5e-256'
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
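# A hypothetical ahead-of-time compilation setup targeting two v5e-256 slices (the
# pickle file name is illustrative):
# compiled_trainstep_file: "compiled_train_v5e-256.pickle"
# compile_topology: 'v5e-256'
# compile_topology_num_slices: 2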
decode_sampling_strategy: "greedy" # decode_sampling_strategy should be one of greedy, weighted, nucleus, or topk
decode_sampling_nucleus_p: -1 # set if you're doing nucleus / top-p
decode_sampling_top_k: 0 # set if you're doing top-k
decode_sampling_temperature: 1.
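# For example, nucleus (top-p) sampling might be configured as follows (the values
# are illustrative):
# decode_sampling_strategy: "nucleus"
# decode_sampling_nucleus_p: 0.95
# decode_sampling_temperature: 0.7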
eval_interval: -1 # the number of training steps between eval steps
eval_steps: -1 # only run this many batches for eval, for debugging use
target_eval_loss: 0. # stop early once the target eval loss is reached
# Goodput parameters
enable_goodput_recording: True
monitor_goodput: True
goodput_upload_interval_seconds: 30
enable_pathways_goodput: False
monitor_step_time_deviation: True
step_deviation_interval_seconds: 30
# GCP workload monitoring
report_heartbeat_metric_for_gcp_monitoring: False
heartbeat_reporting_interval_in_seconds: 5
report_performance_metric_for_gcp_monitoring: False
enable_tensorboard: True
# Vertex AI Tensorboard Configurations - https://github.com/google/maxtext/tree/main/getting_started/Use_Vertex_AI_Tensorboard.md
# Set to True for GCE, False if running via XPK
use_vertex_tensorboard: False
# Project to create Vertex AI Tensorboard in for GCE, blank if project is set using 'gcloud config set project'
# Set this to blank if running via XPK
vertex_tensorboard_project: ""
# Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
# Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
vertex_tensorboard_region: ""
# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
max_checkify: False
# Inference
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
inference_microbenchmark_log_file_path: ""
inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
inference_benchmark_test: False
enable_model_warmup: False
hf_model_path: "" # path to a HuggingFace model, used for inference checkpoint correctness verification
# Stack the prefill result cache across layers to reduce
# Python-layer latency.
stack_prefill_result_cache: False
# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
prefill_cache_axis_order: "1,2,0,3"
ar_cache_axis_order: "1,2,0,3"
# Compute layout control
# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV
# Currently only support compute layout: 0,1,2,3 and 0,2,1,3
compute_axis_order: "0,1,2,3"
reshape_q: False
# Maxengine Metrics
prometheus_port: 0
# Maxengine server
enable_jax_profiler: False
jax_profiler_port: 9999
log_config: True # Prints the config (after defaults have been set by pyconfig logic)
# Checkpoint Structured logging
enable_checkpoint_cloud_logger: False
# Single-controller
enable_single_controller: False
custom_mesh: "" # Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']
# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html
allow_split_physical_axes: False
# Apply transformations to the mesh to optimize for TPU v6e
optimize_mesh_for_tpu_v6e: False
use_ragged_attention: False
ragged_block_size: 256
### Splash attention block sizes
# These can be tuned for specific hardware generations, and can be set up to
# the model's sequence length.
sa_block_q: 512
sa_block_kv: 512
sa_block_kv_compute: 512
sa_block_q_dkv: 512
sa_block_kv_dkv: 512
sa_block_kv_dkv_compute: 512
sa_block_q_dq: 512
sa_block_kv_dq: 512
sa_use_fused_bwd_kernel: False
sa_q_layout: "HEAD_DIM_MINOR"
sa_k_layout: "HEAD_DIM_MINOR"
sa_v_layout: "HEAD_DIM_MINOR"
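# For example (illustrative values, assuming max_target_length=2048), you might raise
# the query block sizes to trade HBM usage for fewer kernel invocations:
# sa_block_q: 1024
# sa_block_q_dkv: 1024
# sa_block_q_dq: 1024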