restructure pyconfig and add omegaconf
A9isha committed Feb 19, 2025
1 parent 858da97 commit 99b807e
Showing 36 changed files with 86 additions and 115 deletions.
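At a glance, the commit replaces the module-level `pyconfig.config` global with a config object returned by `pyconfig.initialize`. A minimal before/after sketch of the call-site migration, using names taken from the diffs below (the exact argv and kwargs are illustrative):

```python
import sys
import pyconfig  # MaxText module, assuming MaxText's module layout

# Before this commit (global shared across the process):
#   pyconfig.initialize([sys.argv[0], "configs/base.yml"], run_name="test")
#   config = pyconfig.config

# After this commit (initialize returns the config directly):
config = pyconfig.initialize([sys.argv[0], "configs/base.yml"], run_name="test")
print(config.steps)  # keys resolve via HyperParameters.__getattr__
```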
3 changes: 1 addition & 2 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -89,8 +89,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       "checkpoint_period=1",
       "async_checkpointing=false",
   ]
-  pyconfig.initialize(base_args)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(base_args)
   init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
   devices_array = max_utils.create_device_mesh(cfg)
   mesh = Mesh(devices_array, cfg.mesh_axes)
3 changes: 1 addition & 2 deletions MaxText/decode.py
@@ -30,8 +30,7 @@ def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   validate_config(config)
   max_utils.print_system_information()
4 changes: 2 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
@@ -142,8 +142,8 @@ def generate_decode_checkpoint(config):


 def main(argv: Sequence[str]) -> None:
   print(argv)
-  pyconfig.initialize(argv)
-  generate_decode_checkpoint(pyconfig.config)
+  config = pyconfig.initialize(argv)
+  generate_decode_checkpoint(config)


 if __name__ == "__main__":
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark.py
@@ -356,8 +356,7 @@ def run_benchmarks(config):


 def main(argv):
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  pyconfig.initialize(argv)
-  run_benchmarks(pyconfig.config)
+  run_benchmarks(pyconfig.initialize(argv))


 if __name__ == "__main__":
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark_sweep.py
@@ -45,8 +45,7 @@ def main():
   - flatten_microbenchmark_results: Whether or not to flatten results. Should
     be true
   """
-  pyconfig.initialize(sys.argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(sys.argv)
   base_run_name = config.run_name

   with open(config.inference_metadata_file, encoding="utf-8") as json_file:
1 change: 1 addition & 0 deletions MaxText/inference_mlperf/requirements.txt
@@ -5,3 +5,4 @@ absl-py==1.4.0
 rouge-score==0.1.2
 sentencepiece==0.1.99
 accelerate==0.21.0
+omegaconf
4 changes: 2 additions & 2 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -258,11 +258,11 @@ def convert_orbax_hf(hf_model_path, config):


 def main(argv: Sequence[str]):
-  pyconfig.initialize(argv[:-1])
+  config = pyconfig.initialize(argv[:-1])
   hf_model_path = argv[-1].split("=")[1]
   print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")

-  convert_orbax_hf(hf_model_path, pyconfig.config)
+  convert_orbax_hf(hf_model_path, config)


 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion MaxText/maxengine.py
@@ -54,7 +54,7 @@


 class MaxEngineConfig:
   """Engine specific config class to allow using multiple MaxEngine instances in an inference run.
-  The default pyconfig.config is a global param shared across multiple instances and doesn't
+  The default pyconfig.initialize is a global param shared across multiple instances and doesn't
   allow using different config for each MaxEngine instance.
   """
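Since `initialize` now returns a fresh `HyperParameters` object instead of mutating module state, each engine can own its config. A hedged sketch of what per-instance configs could look like (MaxEngine construction itself is not shown in this diff; the kwargs are illustrative):

```python
import sys
import pyconfig

# Two independent config objects; nothing is shared through module globals.
cfg_a = pyconfig.initialize([sys.argv[0], "configs/base.yml"], run_name="test", per_device_batch_size=1.0)
cfg_b = pyconfig.initialize([sys.argv[0], "configs/base.yml"], run_name="test", per_device_batch_size=2.0)
assert cfg_a.per_device_batch_size != cfg_b.per_device_batch_size
```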
3 changes: 1 addition & 2 deletions MaxText/maxengine_server.py
@@ -63,6 +63,5 @@ def main(config):
 if __name__ == "__main__":
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-  pyconfig.initialize(sys.argv)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(sys.argv)
   main(cfg)
45 changes: 24 additions & 21 deletions MaxText/pyconfig.py
@@ -28,7 +28,7 @@
 import accelerator_to_spec_map
 import max_logging
 import max_utils
-import yaml
+from omegaconf import OmegaConf

 # pylint: disable=line-too-long
@@ -263,10 +263,6 @@ def validate_and_assign_remat_tensors(keys):
   return keys


-_config = None
-config = None
-
-
 def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]:
   return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l

@@ -333,7 +329,8 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv,
   def _load_config(self, config_name: str) -> dict[str, Any]:
     """Loads the YAML config from a file with a given name."""
     with open(config_name, "r", encoding="utf-8") as yaml_file:
-      raw_data_from_yaml = yaml.safe_load(yaml_file)
+      raw_data_from_yaml = OmegaConf.load(yaml_file)
+      raw_data_from_yaml = OmegaConf.to_container(raw_data_from_yaml, resolve=True)

     # Load data from parent config. Note that inheritance has override
     # semantics, and the path is relative to the current config.
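The switch from `yaml.safe_load` to `OmegaConf.load` plus `OmegaConf.to_container(..., resolve=True)` keeps the return type a plain dict while adding OmegaConf's `${...}` interpolation. A small standalone sketch of the difference (the YAML keys here are hypothetical, not from MaxText's configs):

```python
from omegaconf import OmegaConf

# yaml.safe_load would return the literal string "${base_dir}/checkpoints" here;
# OmegaConf resolves the interpolation when resolve=True.
yaml_text = "base_dir: /tmp/run_0\ncheckpoint_dir: ${base_dir}/checkpoints\n"
cfg = OmegaConf.create(yaml_text)
plain = OmegaConf.to_container(cfg, resolve=True)  # back to a plain dict
assert plain["checkpoint_dir"] == "/tmp/run_0/checkpoints"
```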
@@ -507,9 +504,10 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
   if not os.path.isfile(file_path):
     dir_path = os.path.dirname(os.path.realpath(__file__))
     file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml")
-  with open(file_path, "r", encoding="utf-8") as file:
-    model_vars = yaml.safe_load(file)
-    updated_keys = list(model_vars.keys())
+  # Use OmegaConf to load the model-specific configuration.
+  model_vars = OmegaConf.load(file_path)
+  model_vars = OmegaConf.to_container(model_vars, resolve=True)
+  updated_keys = list(model_vars.keys())
   raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name)
   return updated_keys

@@ -854,28 +852,33 @@ def using_expert_parallelism(raw_keys) -> bool:


 class HyperParameters:  # pylint: disable=missing-class-docstring

-  def __init__(self):
-    pass
+  def __init__(self, config):
+    object.__setattr__(self, "_config", config)

   def __getattr__(self, attr):
-    if attr not in _config.keys:
-      raise ValueError(f"Requested key {attr}, not in config")
-    return _config.keys[attr]
+    try:
+      # Attempt to perform the normal lookup
+      return object.__getattribute__(self, "_config").keys[attr]
+    except AttributeError as exc:
+      raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") from exc

   def __setattr__(self, attr, value):
-    raise ValueError
+    if attr != "_config" or attr in self.__dict__:
+      raise ValueError("Reinitialization of config is not allowed")
+    else:  # we allow initializing once
+      object.__setattr__(self, attr, value)

   def get_keys(self):
-    return _config.keys
+    return self._config.keys


 def initialize(argv, **kwargs):
-  global _config, config
   _config = _HyperParameters(argv, **kwargs)
-  config = HyperParameters()
+  config = HyperParameters(_config)
+  return config


 if __name__ == "__main__":
-  initialize(sys.argv)
-  print(config.steps)
-  r = range(config.steps)
+  main_config = initialize(sys.argv)
+  print(main_config.steps)
+  r = range(main_config.steps)
3 changes: 1 addition & 2 deletions MaxText/scratch_code/mixtral-numerical-verification.ipynb
@@ -45,7 +45,7 @@
 "import pyconfig\n",
 "from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
 "\n",
-"pyconfig.initialize(\n",
+"config_maxtext = pyconfig.initialize(\n",
 "    [None, \"configs/base.yml\"],\n",
 "    base_emb_dim=4096,\n",
 "    base_num_query_heads=32,\n",
@@ -73,7 +73,6 @@
 "    capacity_factor=-1,\n",
 "    scan_layers=False,\n",
 ")\n",
-"config_maxtext = pyconfig.config\n",
 "\n",
 "config_hf = MixtralConfig(\n",
 "    vocab_size=config_maxtext.vocab_size,\n",
3 changes: 1 addition & 2 deletions MaxText/standalone_checkpointer.py
@@ -106,8 +106,7 @@ def add_entropy_to_checkpoint(state):
 def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_cpu_enable_gloo_collectives", True)
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   validate_train_config(config)
   print(f"Found {jax.device_count()} devices.")
   print(f"Found {jax.process_count()} processes.")
3 changes: 1 addition & 2 deletions MaxText/standalone_dataloader.py
@@ -60,8 +60,7 @@ def data_load_loop(config, state=None):
 def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_cpu_enable_gloo_collectives", True)
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
-  pyconfig.initialize(argv)
-  config = pyconfig.config
+  config = pyconfig.initialize(argv)
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
   max_logging.log(f"Found {jax.process_count()} processes.")
13 changes: 5 additions & 8 deletions MaxText/tests/attention_test.py
@@ -42,15 +42,15 @@ class AttentionTest(unittest.TestCase):

   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    config = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1.0,
         run_name="test",
         enable_checkpointing=False,
         max_target_length=128,
         max_prefill_predict_length=16,
     )
-    self.cfg = pyconfig.config
+    self.cfg = config
     self.rng = jax.random.PRNGKey(0)

     devices_array = max_utils.create_device_mesh(self.cfg)
@@ -336,7 +336,7 @@ def _dot_product_attention(

     rtol, atol = 1e-02, 1e-02

-    pyconfig.initialize(
+    config = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1.0,
         run_name="test",
@@ -345,7 +345,6 @@ def _dot_product_attention(
         max_prefill_predict_length=16,
         attention="dot_product",
     )
-    config = pyconfig.config

     prefill_length = config.max_prefill_predict_length
     decode_total_length = config.max_target_length
@@ -437,7 +436,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):

     rtol, atol = 1e-02, 1e-02

-    pyconfig.initialize(
+    config = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1.0,
         run_name="test",
@@ -446,7 +445,6 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
         max_prefill_predict_length=16,
         attention="dot_product",
     )
-    config = pyconfig.config

     prefill_length = config.max_prefill_predict_length
     decode_total_length = config.max_target_length
@@ -727,7 +725,7 @@ class MLATest(parameterized.TestCase):

   def init_mla(self, rope_type):
     """Helper function to initialize MLA with different model names."""
-    pyconfig.initialize(
+    cfg = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1.0,
         run_name="test",
@@ -737,7 +735,6 @@ def init_mla(self, rope_type):
         attention_type=attentions.AttentionType.MLA.value,
         rope_type=rope_type,
     )
-    cfg = pyconfig.config
     rng = jax.random.PRNGKey(0)

     devices_array = max_utils.create_device_mesh(cfg)
3 changes: 1 addition & 2 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -167,6 +167,5 @@ def main(config, test_args):
   for arg in to_remove_args:
     model_args = [s for s in model_args if not s.startswith(arg)]

-  pyconfig.initialize(model_args)
-  cfg = pyconfig.config
+  cfg = pyconfig.initialize(model_args)
   main(cfg, test_args)
4 changes: 1 addition & 3 deletions MaxText/tests/gpt3_test.py
@@ -60,15 +60,13 @@ class GPT3(unittest.TestCase):

   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    self.cfg = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         run_name="test",
         enable_checkpointing=False,
         model_name="gpt3-52k",
         dtype="float32",
     )
-
-    self.cfg = pyconfig.config
     self.rng = jax.random.PRNGKey(1234)

     devices_array = max_utils.create_device_mesh(self.cfg)
3 changes: 1 addition & 2 deletions MaxText/tests/grain_data_processing_test.py
@@ -40,7 +40,7 @@ def setUpClass(cls):

   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    self.config = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1,
         run_name="test",
@@ -53,7 +53,6 @@ def setUp(self):
         tokenizer_path="../assets/tokenizer",
         enable_checkpointing=False,
     )
-    self.config = pyconfig.config
     self.mesh_shape_1d = (len(jax.devices()),)
     self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
     self.process_indices = input_pipeline_interface.get_process_loading_real_data(
4 changes: 2 additions & 2 deletions MaxText/tests/hf_data_processing_test.py
@@ -31,7 +31,7 @@ class HfDataProcessingTest(unittest.TestCase):

   def setUp(self):
     super().setUp()
-    pyconfig.initialize(
+    config = pyconfig.initialize(
         [sys.argv[0], "configs/base.yml"],
         per_device_batch_size=1,
         run_name="test",
@@ -45,7 +45,7 @@ def setUp(self):
         tokenizer_path="google-t5/t5-large",
         enable_checkpointing=False,
     )
-    self.config = pyconfig.config
+    self.config = config
     self.mesh_shape_1d = (len(jax.devices()),)
     self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
     self.process_indices = input_pipeline_interface.get_process_loading_real_data(
4 changes: 2 additions & 2 deletions MaxText/tests/inference_microbenchmark_smoke_test.py
@@ -25,7 +25,7 @@ class Inference_Microbenchmark(unittest.TestCase):

   @pytest.mark.tpu_only
   def test(self):
-    pyconfig.initialize(
+    config = pyconfig.initialize(
         [
             None,
             "configs/tpu_smoke_test.yml",
@@ -38,7 +38,7 @@ def test(self):
             "weight_dtype=bfloat16",
         ]
     )
-    run_benchmarks(pyconfig.config)
+    run_benchmarks(config)


 if __name__ == "__main__":
6 changes: 2 additions & 4 deletions MaxText/tests/max_utils_test.py
@@ -116,8 +116,7 @@ def __call__(self, x, y):


 class MaxUtilsInitStateWithMultipleCollections(unittest.TestCase):

   def setUp(self):
-    pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
-    self.config = pyconfig.config
+    self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
     self.model = ModelWithMultipleCollections()
     self.key1, self.key2, self.key3 = random.split(random.key(0), num=3)
     self.input = random.normal(self.key1, (self.config.global_batch_size_to_load, self.config.max_target_length))
@@ -152,8 +151,7 @@ class MaxUtilsInitTransformerState(unittest.TestCase):
   """Tests initialization of transformer states in max_utils.py"""

   def setUp(self):
-    pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
-    self.config = pyconfig.config
+    self.config = pyconfig.initialize([None, "configs/base.yml"], enable_checkpointing=False)
     devices_array = max_utils.create_device_mesh(self.config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
     quant = quantizations.configure_quantization(self.config)