Skip to content

Commit

Permalink
restructure yconfig.py
Browse files Browse the repository at this point in the history
  • Loading branch information
A9isha committed Feb 19, 2025
1 parent 45a8423 commit 3982091
Show file tree
Hide file tree
Showing 28 changed files with 50 additions and 80 deletions.
3 changes: 1 addition & 2 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"checkpoint_period=1",
"async_checkpointing=false",
]
pyconfig.initialize(base_args)
cfg = pyconfig.config
cfg = pyconfig.initialize(base_args)
init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)
Expand Down
3 changes: 1 addition & 2 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_config(config)
max_utils.print_system_information()

Expand Down
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,7 @@ def run_benchmarks(config):

def main(argv):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
pyconfig.initialize(argv)
run_benchmarks(pyconfig.config)
run_benchmarks(pyconfig.initialize(argv))


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def main():
- flatten_microbenchmark_results: Whether or not to flatten results. Should
be true
"""
pyconfig.initialize(sys.argv)
config = pyconfig.config
config = pyconfig.initialize(sys.argv)
base_run_name = config.run_name

with open(config.inference_metadata_file, encoding="utf-8") as json_file:
Expand Down
4 changes: 2 additions & 2 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ def convert_orbax_hf(hf_model_path, config):


def main(argv: Sequence[str]):
pyconfig.initialize(argv[:-1])
config = pyconfig.initialize(argv[:-1])
hf_model_path = argv[-1].split("=")[1]
print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")

convert_orbax_hf(hf_model_path, pyconfig.config)
convert_orbax_hf(hf_model_path, config)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

class MaxEngineConfig:
"""Engine specific config class to allow using multiple MaxEngine instances in an inference run.
The default pyconfig.config is a global param shared across multiple instances and doesn't
The default pyconfig.initialize is a global param shared across multiple instances and doesn't
allow using different config for each MaxEngine instance.
"""

Expand Down
3 changes: 1 addition & 2 deletions MaxText/maxengine_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,5 @@ def main(config):
if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(sys.argv)
cfg = pyconfig.config
cfg = pyconfig.initialize(sys.argv)
main(cfg)
2 changes: 1 addition & 1 deletion MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def __setattr__(self, attr, value):
object.__setattr__(self, attr, value)

def get_keys(self):
return self.__dict__.keys()
return self.__getattribute__("_config").keys


def initialize(argv, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions MaxText/scratch_code/mixtral-numerical-verification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"import pyconfig\n",
"from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
"\n",
"pyconfig.initialize(\n",
"config_maxtext = pyconfig.initialize(\n",
" [None, \"configs/base.yml\"],\n",
" base_emb_dim=4096,\n",
" base_num_query_heads=32,\n",
Expand Down Expand Up @@ -73,7 +73,6 @@
" capacity_factor=-1,\n",
" scan_layers=False,\n",
")\n",
"config_maxtext = pyconfig.config\n",
"\n",
"config_hf = MixtralConfig(\n",
" vocab_size=config_maxtext.vocab_size,\n",
Expand Down
3 changes: 1 addition & 2 deletions MaxText/standalone_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def add_entropy_to_checkpoint(state):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_train_config(config)
print(f"Found {jax.device_count()} devices.")
print(f"Found {jax.process_count()} processes.")
Expand Down
3 changes: 1 addition & 2 deletions MaxText/standalone_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def data_load_loop(config, state=None):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
max_logging.log(f"Found {jax.process_count()} processes.")
Expand Down
10 changes: 4 additions & 6 deletions MaxText/tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ class AttentionTest(unittest.TestCase):

def setUp(self):
super().setUp()
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
)
self.cfg = pyconfig.config
self.cfg = config
self.rng = jax.random.PRNGKey(0)

devices_array = max_utils.create_device_mesh(self.cfg)
Expand Down Expand Up @@ -335,7 +335,7 @@ def _dot_product_attention(

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -344,7 +344,6 @@ def _dot_product_attention(
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
Expand Down Expand Up @@ -436,7 +435,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -445,7 +444,6 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
Expand Down
3 changes: 1 addition & 2 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,5 @@ def main(config, test_args):
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

pyconfig.initialize(model_args)
cfg = pyconfig.config
cfg = pyconfig.initialize(model_args)
main(cfg, test_args)
4 changes: 1 addition & 3 deletions MaxText/tests/gpt3_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@ class GPT3(unittest.TestCase):

def setUp(self):
super().setUp()
pyconfig.initialize(
self.cfg = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
run_name="test",
enable_checkpointing=False,
model_name="gpt3-52k",
dtype="float32",
)

self.cfg = pyconfig.config
self.rng = jax.random.PRNGKey(1234)

devices_array = max_utils.create_device_mesh(self.cfg)
Expand Down
3 changes: 1 addition & 2 deletions MaxText/tests/grain_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def setUpClass(cls):

def setUp(self):
super().setUp()
pyconfig.initialize(
self.config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1,
run_name="test",
Expand All @@ -53,7 +53,6 @@ def setUp(self):
tokenizer_path="../assets/tokenizer",
enable_checkpointing=False,
)
self.config = pyconfig.config
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
Expand Down
4 changes: 2 additions & 2 deletions MaxText/tests/hf_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class HfDataProcessingTest(unittest.TestCase):

def setUp(self):
super().setUp()
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1,
run_name="test",
Expand All @@ -45,7 +45,7 @@ def setUp(self):
tokenizer_path="google-t5/t5-large",
enable_checkpointing=False,
)
self.config = pyconfig.config
self.config = config
self.mesh_shape_1d = (len(jax.devices()),)
self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
self.process_indices = input_pipeline_interface.get_process_loading_real_data(
Expand Down
4 changes: 2 additions & 2 deletions MaxText/tests/inference_microbenchmark_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Inference_Microbenchmark(unittest.TestCase):

@pytest.mark.tpu_only
def test(self):
pyconfig.initialize(
config = pyconfig.initialize(
[
None,
"configs/tpu_smoke_test.yml",
Expand All @@ -38,7 +38,7 @@ def test(self):
"weight_dtype=bfloat16",
]
)
run_benchmarks(pyconfig.config)
run_benchmarks(config)


if __name__ == "__main__":
Expand Down
6 changes: 2 additions & 4 deletions MaxText/tests/max_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def __call__(self, x, y):
class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase):

def setUp(self):
pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
self.config = pyconfig.config
self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
self.model = ModelWithMultipleCollections()
self.key1, self.key2, self.key3 = random.split(random.key(0), num=3)
self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length))
Expand Down Expand Up @@ -152,8 +151,7 @@ class MaxUtilsInitTransformerState(unittest.TestCase):
"""Tests initialization of transformer states in max_utils.py"""

def setUp(self):
pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
self.config = pyconfig.config
self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
devices_array = max_utils.create_device_mesh(self.config)
self.mesh = Mesh(devices_array, self.config.mesh_axes)
quant = quantizations.configure_quantization(self.config)
Expand Down
7 changes: 3 additions & 4 deletions MaxText/tests/maxengine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setUp(self):
self.rng = jax.random.PRNGKey(0)

def init_pyconfig(self, **kwargs):
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -57,7 +57,7 @@ def init_pyconfig(self, **kwargs):
max_prefill_predict_length=4,
**kwargs,
)
return pyconfig.config
return config

def get_data(self):
s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length)
Expand All @@ -71,12 +71,11 @@ def get_data(self):
return ids, decoder_segment_ids, decoder_positions

def test_stack_and_unstack_prefill_cache(self):
pyconfig.initialize(
config = pyconfig.initialize(
[None, "configs/base.yml"],
enable_checkpointing=False,
stack_prefill_result_cache=True,
)
config = pyconfig.config
engine = MaxEngine(config, jax.devices())
num_layers = engine.config.num_decoder_layers
input = {
Expand Down
4 changes: 2 additions & 2 deletions MaxText/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUp(self):
self.rng = jax.random.PRNGKey(0)

def init_pyconfig(self, **kwargs):
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -56,7 +56,7 @@ def init_pyconfig(self, **kwargs):
max_prefill_predict_length=4,
**kwargs,
)
return pyconfig.config
return config

def get_data(self):
s = (self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length)
Expand Down
9 changes: 3 additions & 6 deletions MaxText/tests/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TokenDroppingTest(unittest.TestCase):

def setUp(self):
super().setUp()
pyconfig.initialize(
self.cfg = pyconfig.initialize(
[None, "configs/base.yml"],
run_name="token_dropping_test",
enable_checkpointing=False,
Expand All @@ -50,7 +50,6 @@ def setUp(self):
per_device_batch_size=1,
capacity_factor=2,
)
self.cfg = pyconfig.config
self.rng = jax.random.PRNGKey(42)
devices_array = max_utils.create_device_mesh(self.cfg)
self.model = linears.MoeBlock(
Expand Down Expand Up @@ -263,7 +262,7 @@ def get_moe_output(self, variables, hidden_states, cfg, mesh):

@pytest.mark.tpu_only
def test_megablox(self):
pyconfig.initialize(
cfg = pyconfig.initialize(
[None, "configs/base.yml"],
run_name="moe_block_megablox_test",
enable_checkpointing=False,
Expand All @@ -274,7 +273,6 @@ def test_megablox(self):
per_device_batch_size=4,
)

cfg = pyconfig.config
rng = jax.random.PRNGKey(1234)
rng_model, rng_hidden_states = jax.random.split(rng)
hidden_states = jax.random.uniform(
Expand All @@ -289,7 +287,7 @@ def test_megablox(self):

@pytest.mark.tpu_only
def test_dense(self):
pyconfig.initialize(
cfg = pyconfig.initialize(
[None, "configs/base.yml"],
run_name="moe_block_dense_test",
enable_checkpointing=False,
Expand All @@ -300,7 +298,6 @@ def test_dense(self):
per_device_batch_size=4,
)

cfg = pyconfig.config
rng = jax.random.PRNGKey(2345)
rng_model, rng_hidden_states = jax.random.split(rng)
hidden_states = jax.random.uniform(
Expand Down
3 changes: 1 addition & 2 deletions MaxText/tests/multihost_dataloading_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MultihostDataloadingTest(unittest.TestCase):
def setUp(self):
super().setUp()
batch_size = 4
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1,
run_name="test",
Expand All @@ -46,7 +46,6 @@ def setUp(self):
dataset_path="gs://maxtext-dataset/",
enable_checkpointing=False,
)
config = pyconfig.config
global_data_shape = PartitionSpec(batch_size, config.max_target_length)
data_sharding = ("data",)
mesh_shape_1d = (len(jax.devices()),)
Expand Down
Loading

0 comments on commit 3982091

Please sign in to comment.