restructure pyconfig and add omegaconf
A9isha committed Feb 20, 2025
1 parent bea1cef commit 6815d67
Showing 32 changed files with 135 additions and 127 deletions.
3 changes: 1 addition & 2 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
@@ -89,8 +89,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"checkpoint_period=1",
"async_checkpointing=false",
]
- pyconfig.initialize(base_args)
- cfg = pyconfig.config
+ cfg = pyconfig.initialize(base_args)
init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)
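The same mechanical change repeats across the call sites in this commit: pyconfig.initialize now returns the configuration object instead of populating a module-level global. A minimal before/after sketch of the pattern (the argv values here are illustrative, not taken from this diff):

import pyconfig

# Old pattern removed by this commit: initialize() mutated module state and
# callers then read the pyconfig.config global.
# pyconfig.initialize([__file__, "configs/base.yml", "run_name=test"])
# cfg = pyconfig.config

# New pattern: initialize() returns the config object directly.
cfg = pyconfig.initialize([__file__, "configs/base.yml", "run_name=test"])
print(cfg.run_name)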
3 changes: 1 addition & 2 deletions MaxText/decode.py
@@ -30,8 +30,7 @@ def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

- pyconfig.initialize(argv)
- config = pyconfig.config
+ config = pyconfig.initialize(argv)
validate_config(config)
max_utils.print_system_information()

4 changes: 2 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
@@ -142,8 +142,8 @@ def generate_decode_checkpoint(config):

def main(argv: Sequence[str]) -> None:
print(argv)
- pyconfig.initialize(argv)
- generate_decode_checkpoint(pyconfig.config)
+ config = pyconfig.initialize(argv)
+ generate_decode_checkpoint(config)


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark.py
@@ -356,8 +356,7 @@ def run_benchmarks(config):

def main(argv):
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
- pyconfig.initialize(argv)
- run_benchmarks(pyconfig.config)
+ run_benchmarks(pyconfig.initialize(argv))


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark_sweep.py
@@ -45,8 +45,7 @@ def main():
- flatten_microbenchmark_results: Whether or not to flatten results. Should
be true
"""
- pyconfig.initialize(sys.argv)
- config = pyconfig.config
+ config = pyconfig.initialize(sys.argv)
base_run_name = config.run_name

with open(config.inference_metadata_file, encoding="utf-8") as json_file:
1 change: 1 addition & 0 deletions MaxText/inference_mlperf/requirements.txt
@@ -5,3 +5,4 @@ absl-py==1.4.0
rouge-score==0.1.2
sentencepiece==0.1.99
accelerate==0.21.0
+ omegaconf
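The new omegaconf dependency supplies the YAML loading and override merging that the restructured MaxText/pyconfig.py below builds on. A small standalone sketch of the OmegaConf entry points this commit uses (the file name and values are illustrative):

from omegaconf import OmegaConf

# Load a YAML file into a config object (replaces yaml.safe_load).
base = OmegaConf.load("configs/base.yml")

# Parse key=value overrides, e.g. from sys.argv[2:].
cli = OmegaConf.from_cli(["run_name=test", "steps=10"])

# Later arguments win on merge; plain dicts are wrapped via OmegaConf.create.
merged = OmegaConf.merge(base, cli, OmegaConf.create({"steps": 20}))

# Convert back to plain Python containers, resolving any interpolations.
raw = OmegaConf.to_container(merged, resolve=True)
assert raw["steps"] == 20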
4 changes: 2 additions & 2 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -258,11 +258,11 @@ def convert_orbax_hf(hf_model_path, config):


def main(argv: Sequence[str]):
- pyconfig.initialize(argv[:-1])
+ config = pyconfig.initialize(argv[:-1])
hf_model_path = argv[-1].split("=")[1]
print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")

- convert_orbax_hf(hf_model_path, pyconfig.config)
+ convert_orbax_hf(hf_model_path, config)


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions MaxText/maxengine.py
@@ -54,8 +54,7 @@

class MaxEngineConfig:
"""Engine specific config class to allow using multiple MaxEngine instances in an inference run.
- The default pyconfig.config is a global param shared across multiple instances and doesn't
- allow using different config for each MaxEngine instance.
+ TODO: evaluate the need for this given the restructured pyconfig.py
"""

def __init__(self, keys):
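The TODO added above exists because the old pyconfig.config global is what forced MaxEngineConfig into being; once initialize returns an independent object, two engines can hold different configs in one process. A hypothetical sketch of that now-possible usage (the override values are made up):

import sys
import pyconfig

# Two independent configs in the same process, which the old global disallowed.
cfg_a = pyconfig.initialize([sys.argv[0], "configs/base.yml", "run_name=engine_a"])
cfg_b = pyconfig.initialize([sys.argv[0], "configs/base.yml", "run_name=engine_b"])
assert cfg_a.run_name != cfg_b.run_name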
3 changes: 1 addition & 2 deletions MaxText/maxengine_server.py
@@ -63,6 +63,5 @@ def main(config):
if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
- pyconfig.initialize(sys.argv)
- cfg = pyconfig.config
+ cfg = pyconfig.initialize(sys.argv)
main(cfg)
82 changes: 45 additions & 37 deletions MaxText/pyconfig.py
@@ -28,7 +28,7 @@
import accelerator_to_spec_map
import max_logging
import max_utils
- import yaml
+ from omegaconf import OmegaConf

# pylint: disable=line-too-long

@@ -64,7 +64,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:
if s not in valid_kv_quant_axis: # currently supported kv_quant_axis
raise ValueError("Invalid kv_quant_axis was passed. Valid options ", valid_kv_quant_axis)
if quantize_kvcache and s == "":
raise ValueError("kv_quant_axis can not be '' when quantize_kvcache is True")
raise ValueError("kv_quant_axis cannot be '' when quantize_kvcache is True")


def validate_attention_kernel(s: str) -> None:
@@ -92,7 +92,7 @@ def validate_periodic_profiler(profiler, profile_periodically_period, profiler_s
raise ValueError("Periodic profiler requested but no profiler was set, set it via profiler=xplane or profiler=nsys")
if profile_periodically_period < profiler_steps:
raise ValueError(
f"You must set the profile_periodically_period {profile_periodically_period} at least as long profiler_steps {profiler_steps}."
f"You must set the profile_periodically_period {profile_periodically_period} at least as long as profiler_steps {profiler_steps}."
)


@@ -108,8 +108,8 @@ def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_leng
if max_target_length < max_prefill_length:
# valid max_target_length = max_prefill_length for existing logit checks
raise ValueError(
f"Invalid max_target_length {max_target_length}, this should be sum of "
f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."
f"Invalid max_target_length {max_target_length}, this should be the sum of "
f"max_prefill_predict_length ({max_prefill_length}) and the expected max output length."
)


@@ -263,16 +263,14 @@ def validate_and_assign_remat_tensors(keys):
return keys


- _config = None
- config = None


def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]:
return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l


class _HyperParameters:
- # pylint: disable=missing-class-docstring
+ # This class is responsible for loading, merging, and overriding the configuration.

def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]):
for environment_var in os.environ:
if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX:
@@ -282,14 +280,16 @@ def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]):
if not environment_var[len(_MAX_PREFIX) :].isupper():
raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.")

- def _load_kwargs(self, argv: list[str], **kwargs):
- args_dict = dict(a.split("=", 1) for a in argv[2:])
- args_dict.update(kwargs)
- return args_dict

def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]:
"""Update model config from environment and command line"""
raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs)
"""Update model config from environment and command line using OmegaConf overrides."""
# Use OmegaConf.from_cli to capture CLI arguments.
cli_cfg = OmegaConf.from_cli(argv[2:])
# Also create a configuration from any extra keyword arguments.
kwargs_cfg = OmegaConf.create(kwargs)
# Merge command-line and keyword arguments.
cmdline_cfg = OmegaConf.merge(cli_cfg, kwargs_cfg)
raw_data_from_cmd_line = OmegaConf.to_container(cmdline_cfg, resolve=True)

updated_keys = []

for k in raw_data_from_cmd_line:
@@ -300,7 +300,7 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv,
if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ:
raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.")

- if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ:
+ if k not in raw_data_from_cmd_line and yaml_key_to_env_key(k) not in os.environ:
raw_keys[k] = raw_data_from_yaml[k]
continue

@@ -331,9 +331,9 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv,
return updated_keys

def _load_config(self, config_name: str) -> dict[str, Any]:
"""Loads the YAML config from a file with a given name."""
with open(config_name, "r", encoding="utf-8") as yaml_file:
raw_data_from_yaml = yaml.safe_load(yaml_file)
"""Loads the YAML config from a file using OmegaConf, and resolves inheritance."""
base_cfg = OmegaConf.load(config_name)
raw_data_from_yaml = OmegaConf.to_container(base_cfg, resolve=True)

# Load data from parent config. Note that inheritance has override
# semantics, and the path is relative to the current config.
@@ -348,6 +348,7 @@ def _load_config(self, config_name: str) -> dict[str, Any]:
loaded_parent_config_filename = parent_config_filename

base_config = self._load_config(loaded_parent_config_filename)
+ # Override base_config with values from raw_data_from_yaml.
for key, value in raw_data_from_yaml.items():
base_config[key] = value
return base_config
@@ -451,7 +452,10 @@ def user_init(raw_keys):
raw_keys["global_batch_size_to_eval_on"],
raw_keys["micro_batch_size_to_eval_on"],
) = calculate_global_batch_sizes(
raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1
raw_keys["eval_per_device_batch_size"],
raw_keys["expansion_factor_real_data"],
get_num_target_devices(raw_keys),
1,
)

raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
@@ -508,9 +512,10 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
if not os.path.isfile(file_path):
dir_path = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml")
- with open(file_path, "r", encoding="utf-8") as file:
- model_vars = yaml.safe_load(file)
- updated_keys = list(model_vars.keys())
+ # Use OmegaConf to load the model-specific configuration.
+ model_vars = OmegaConf.load(file_path)
+ model_vars = OmegaConf.to_container(model_vars, resolve=True)
+ updated_keys = list(model_vars.keys())
raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name)
return updated_keys

@@ -853,30 +858,33 @@ def using_expert_parallelism(raw_keys) -> bool:
return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1


- class HyperParameters:  # pylint: disable=missing-class-docstring
+ class HyperParameters:
+ """Wrapper class to expose the configuration in a read-only manner."""

- def __init__(self):
- pass
+ def __init__(self, config):
+ object.__setattr__(self, "_config", config)

def __getattr__(self, attr):
- if attr not in _config.keys:
- raise ValueError(f"Requested key {attr}, not in config")
- return _config.keys[attr]
+ try:
+ # Attempt to perform the normal lookup
+ return object.__getattribute__(self, "_config").keys[attr]
+ except AttributeError as exc:
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") from exc

def __setattr__(self, attr, value):
- raise ValueError
+ raise ValueError("Reinitialization of config is not allowed")

def get_keys(self):
- return _config.keys
+ return self._config.keys


def initialize(argv, **kwargs):
- global _config, config
_config = _HyperParameters(argv, **kwargs)
- config = HyperParameters()
+ config = HyperParameters(_config)
+ return config


if __name__ == "__main__":
- initialize(sys.argv)
- print(config.steps)
- r = range(config.steps)
+ main_config = initialize(sys.argv)
+ print(main_config.steps)
+ r = range(main_config.steps)
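Taken together, the new initialize builds a private _HyperParameters instance and hands it to the read-only HyperParameters wrapper: attribute reads are delegated to the wrapped keys mapping, and writes always raise. A short usage sketch under those semantics:

import sys
import pyconfig

config = pyconfig.initialize([sys.argv[0], "configs/base.yml", "run_name=test"])

# Reads are delegated to the wrapped _HyperParameters.keys mapping.
print(config.run_name)

# Writes are rejected, keeping the config effectively immutable.
try:
    config.run_name = "other"
except ValueError as err:
    print(err)  # Reinitialization of config is not allowed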
3 changes: 1 addition & 2 deletions MaxText/scratch_code/mixtral-numerical-verification.ipynb
@@ -45,7 +45,7 @@
"import pyconfig\n",
"from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
"\n",
"pyconfig.initialize(\n",
"config_maxtext = pyconfig.initialize(\n",
" [None, \"configs/base.yml\"],\n",
" base_emb_dim=4096,\n",
" base_num_query_heads=32,\n",
@@ -73,7 +73,6 @@
" capacity_factor=-1,\n",
" scan_layers=False,\n",
")\n",
"config_maxtext = pyconfig.config\n",
"\n",
"config_hf = MixtralConfig(\n",
" vocab_size=config_maxtext.vocab_size,\n",
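As the notebook change shows, initialize also accepts keyword overrides in addition to key=value CLI strings; these flow through the OmegaConf.create/merge path in _update_from_env_and_command_line above. A brief sketch (parameter values are illustrative, taken from the notebook):

import pyconfig

# Keyword overrides are merged over the YAML defaults via OmegaConf.
config_maxtext = pyconfig.initialize(
    [None, "configs/base.yml"],
    base_emb_dim=4096,
    scan_layers=False,
)
print(config_maxtext.base_emb_dim)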
3 changes: 1 addition & 2 deletions MaxText/standalone_checkpointer.py
@@ -106,8 +106,7 @@ def add_entropy_to_checkpoint(state):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
- pyconfig.initialize(argv)
- config = pyconfig.config
+ config = pyconfig.initialize(argv)
validate_train_config(config)
print(f"Found {jax.device_count()} devices.")
print(f"Found {jax.process_count()} processes.")
3 changes: 1 addition & 2 deletions MaxText/standalone_dataloader.py
@@ -60,8 +60,7 @@ def data_load_loop(config, state=None):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
- pyconfig.initialize(argv)
- config = pyconfig.config
+ config = pyconfig.initialize(argv)
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
max_logging.log(f"Found {jax.process_count()} processes.")
13 changes: 5 additions & 8 deletions MaxText/tests/attention_test.py
@@ -42,15 +42,15 @@ class AttentionTest(unittest.TestCase):

def setUp(self):
super().setUp()
- pyconfig.initialize(
+ config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
)
- self.cfg = pyconfig.config
+ self.cfg = config
self.rng = jax.random.PRNGKey(0)

devices_array = max_utils.create_device_mesh(self.cfg)
@@ -336,7 +336,7 @@ def _dot_product_attention(

rtol, atol = 1e-02, 1e-02

- pyconfig.initialize(
+ config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
@@ -345,7 +345,6 @@
max_prefill_predict_length=16,
attention="dot_product",
)
- config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
@@ -437,7 +436,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):

rtol, atol = 1e-02, 1e-02

- pyconfig.initialize(
+ config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
@@ -446,7 +445,6 @@
max_prefill_predict_length=16,
attention="dot_product",
)
- config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
@@ -727,7 +725,7 @@ class MLATest(parameterized.TestCase):

def init_mla(self, rope_type):
"""Helper function to initialize MLA with different model names."""
- pyconfig.initialize(
+ cfg = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
@@ -737,7 +735,6 @@ def init_mla(self, rope_type):
attention_type=attentions.AttentionType.MLA.value,
rope_type=rope_type,
)
- cfg = pyconfig.config
rng = jax.random.PRNGKey(0)

devices_array = max_utils.create_device_mesh(cfg)
3 changes: 1 addition & 2 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -167,6 +167,5 @@ def main(config, test_args):
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

- pyconfig.initialize(model_args)
- cfg = pyconfig.config
+ cfg = pyconfig.initialize(model_args)
main(cfg, test_args)