Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure Pyconfig #1285

Merged
merged 1 commit into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
"checkpoint_period=1",
"async_checkpointing=false",
]
pyconfig.initialize(base_args)
cfg = pyconfig.config
cfg = pyconfig.initialize(base_args)
init_rng, _ = random.split(random.PRNGKey(cfg.init_weights_seed), 2)
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)
Expand Down
3 changes: 1 addition & 2 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_config(config)
max_utils.print_system_information()

Expand Down
4 changes: 2 additions & 2 deletions MaxText/generate_param_only_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def generate_decode_checkpoint(config):

def main(argv: Sequence[str]) -> None:
  """Entry point: build the run config from CLI args and write a decode-only checkpoint.

  Args:
    argv: Command-line arguments forwarded to pyconfig (program name,
      config path, then key=value overrides).
  """
  print(argv)
  # pyconfig.initialize now returns the config object directly (no global state).
  config = pyconfig.initialize(argv)
  generate_decode_checkpoint(config)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,7 @@ def run_benchmarks(config):

def main(argv):
  """Entry point: select the RBG PRNG implementation, build the config, and run benchmarks.

  Args:
    argv: Command-line arguments forwarded to pyconfig.
  """
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  # initialize() returns the config; pass it straight into the benchmark driver.
  run_benchmarks(pyconfig.initialize(argv))


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions MaxText/inference_microbenchmark_sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def main():
- flatten_microbenchmark_results: Whether or not to flatten results. Should
be true
"""
pyconfig.initialize(sys.argv)
config = pyconfig.config
config = pyconfig.initialize(sys.argv)
base_run_name = config.run_name

with open(config.inference_metadata_file, encoding="utf-8") as json_file:
Expand Down
1 change: 1 addition & 0 deletions MaxText/inference_mlperf/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ absl-py==1.4.0
rouge-score==0.1.2
sentencepiece==0.1.99
accelerate==0.21.0
omegaconf
4 changes: 2 additions & 2 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ def convert_orbax_hf(hf_model_path, config):


def main(argv: Sequence[str]):
  """Convert an Orbax MaxText checkpoint into a HuggingFace checkpoint.

  The last CLI argument must be of the form `key=<hf_model_path>`; every
  argument before it is forwarded to pyconfig to build the model config.

  Args:
    argv: Command-line arguments; argv[:-1] configures the model, argv[-1]
      names the HuggingFace output path.
  """
  config = pyconfig.initialize(argv[:-1])
  # The output path rides in as the final `key=value` argument.
  hf_model_path = argv[-1].split("=")[1]
  print(f"Will save converted HuggingFace checkpoint to path = {hf_model_path}")

  convert_orbax_hf(hf_model_path, config)


if __name__ == "__main__":
Expand Down
3 changes: 1 addition & 2 deletions MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@

class MaxEngineConfig:
"""Engine specific config class to allow using multiple MaxEngine instances in an inference run.
The default pyconfig.config is a global param shared across multiple instances and doesn't
allow using different config for each MaxEngine instance.
TODO: evaluate the need for this given the restructured pyconfig.py
"""

def __init__(self, keys):
Expand Down
3 changes: 1 addition & 2 deletions MaxText/maxengine_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,5 @@ def main(config):
if __name__ == "__main__":
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
  # initialize() returns the config object; hand it to the server entry point.
  cfg = pyconfig.initialize(sys.argv)
  main(cfg)
82 changes: 45 additions & 37 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import accelerator_to_spec_map
import max_logging
import max_utils
import yaml
from omegaconf import OmegaConf

# pylint: disable=line-too-long

Expand Down Expand Up @@ -64,7 +64,7 @@ def validate_kv_quant_axis(s: str, quantize_kvcache: bool) -> None:
if s not in valid_kv_quant_axis: # currently supported kv_quant_axis
raise ValueError("Invalid kv_quant_axis was passed. Valid options ", valid_kv_quant_axis)
if quantize_kvcache and s == "":
raise ValueError("kv_quant_axis can not be '' when quantize_kvcache is True")
raise ValueError("kv_quant_axis cannot be '' when quantize_kvcache is True")


def validate_attention_kernel(s: str) -> None:
Expand Down Expand Up @@ -92,7 +92,7 @@ def validate_periodic_profiler(profiler, profile_periodically_period, profiler_s
raise ValueError("Periodic profiler requested but no profiler was set, set it via profiler=xplane or profiler=nsys")
if profile_periodically_period < profiler_steps:
raise ValueError(
f"You must set the profile_periodically_period {profile_periodically_period} at least as long profiler_steps {profiler_steps}."
f"You must set the profile_periodically_period {profile_periodically_period} at least as long as profiler_steps {profiler_steps}."
)


Expand All @@ -108,8 +108,8 @@ def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_leng
if max_target_length < max_prefill_length:
# valid max_target_length = max_prefill_length for existing logit checks
raise ValueError(
f"Invalid max_target_length {max_target_length}, this should be sum of "
f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."
f"Invalid max_target_length {max_target_length}, this should be the sum of "
f"max_prefill_predict_length ({max_prefill_length}) and the expected max output length."
)


Expand Down Expand Up @@ -263,16 +263,14 @@ def validate_and_assign_remat_tensors(keys):
return keys


_config = None
config = None


def _lists_to_tuples(l: list[Any]) -> Union[tuple[Any], list[Any]]:
return tuple(_lists_to_tuples(x) for x in l) if isinstance(l, list) else l


class _HyperParameters:
# pylint: disable=missing-class-docstring
# This class is responsible for loading, merging, and overriding the configuration.

def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]):
for environment_var in os.environ:
if environment_var[: len(_MAX_PREFIX)] == _MAX_PREFIX:
Expand All @@ -282,14 +280,16 @@ def _validate_env_variables(self, raw_data_from_yaml: dict[str, Any]):
if not environment_var[len(_MAX_PREFIX) :].isupper():
raise ValueError(f"We received env `{environment_var}` but it isn't all uppercase.")

def _load_kwargs(self, argv: list[str], **kwargs):
args_dict = dict(a.split("=", 1) for a in argv[2:])
args_dict.update(kwargs)
return args_dict

def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv, **kwargs) -> list[str]:
"""Update model config from environment and command line"""
raw_data_from_cmd_line = self._load_kwargs(argv, **kwargs)
"""Update model config from environment and command line using OmegaConf overrides."""
# Use OmegaConf.from_cli to capture CLI arguments.
cli_cfg = OmegaConf.from_cli(argv[2:])
# Also create a configuration from any extra keyword arguments.
kwargs_cfg = OmegaConf.create(kwargs)
# Merge command-line and keyword arguments.
cmdline_cfg = OmegaConf.merge(cli_cfg, kwargs_cfg)
raw_data_from_cmd_line = OmegaConf.to_container(cmdline_cfg, resolve=True)

updated_keys = []

for k in raw_data_from_cmd_line:
Expand All @@ -300,7 +300,7 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv,
if k in raw_data_from_cmd_line and yaml_key_to_env_key(k) in os.environ:
raise ValueError(f"You are passing overrides by both CLI and ENV for `{k}`. This isn't allowed.")

if not k in raw_data_from_cmd_line and not yaml_key_to_env_key(k) in os.environ:
if k not in raw_data_from_cmd_line and yaml_key_to_env_key(k) not in os.environ:
raw_keys[k] = raw_data_from_yaml[k]
continue

Expand Down Expand Up @@ -331,9 +331,9 @@ def _update_from_env_and_command_line(self, raw_keys, raw_data_from_yaml, argv,
return updated_keys

def _load_config(self, config_name: str) -> dict[str, Any]:
"""Loads the YAML config from a file with a given name."""
with open(config_name, "r", encoding="utf-8") as yaml_file:
raw_data_from_yaml = yaml.safe_load(yaml_file)
"""Loads the YAML config from a file using OmegaConf, and resolves inheritance."""
base_cfg = OmegaConf.load(config_name)
raw_data_from_yaml = OmegaConf.to_container(base_cfg, resolve=True)

# Load data from parent config. Note that inheritance has override
# semantics, and the path is relative to the current config.
Expand All @@ -348,6 +348,7 @@ def _load_config(self, config_name: str) -> dict[str, Any]:
loaded_parent_config_filename = parent_config_filename

base_config = self._load_config(loaded_parent_config_filename)
# Override base_config with values from raw_data_from_yaml.
for key, value in raw_data_from_yaml.items():
base_config[key] = value
return base_config
Expand Down Expand Up @@ -451,7 +452,10 @@ def user_init(raw_keys):
raw_keys["global_batch_size_to_eval_on"],
raw_keys["micro_batch_size_to_eval_on"],
) = calculate_global_batch_sizes(
raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1
raw_keys["eval_per_device_batch_size"],
raw_keys["expansion_factor_real_data"],
get_num_target_devices(raw_keys),
1,
)

raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
Expand Down Expand Up @@ -508,9 +512,10 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
if not os.path.isfile(file_path):
dir_path = os.path.dirname(os.path.realpath(__file__))
file_path = os.path.join(dir_path, f"configs/models/{model_name}.yml")
with open(file_path, "r", encoding="utf-8") as file:
model_vars = yaml.safe_load(file)
updated_keys = list(model_vars.keys())
# Use OmegaConf to load the model-specific configuration.
model_vars = OmegaConf.load(file_path)
model_vars = OmegaConf.to_container(model_vars, resolve=True)
updated_keys = list(model_vars.keys())
raw_keys = validate_and_update_keys(raw_keys, model_vars, config_name)
return updated_keys

Expand Down Expand Up @@ -853,30 +858,33 @@ def using_expert_parallelism(raw_keys) -> bool:
return int(raw_keys["ici_expert_parallelism"]) > 1 or int(raw_keys["dcn_expert_parallelism"]) > 1


class HyperParameters:
  """Read-only attribute view over a built config.

  Attribute access is forwarded to the backing config's `keys` mapping; any
  attempt to assign an attribute raises, keeping the configuration immutable
  after `initialize()` has run.
  """

  def __init__(self, config):
    # Bypass our own __setattr__ (which always raises) to store the backing config.
    object.__setattr__(self, "_config", config)

  def __getattr__(self, attr):
    # Only invoked when normal attribute lookup fails, i.e. for config keys.
    try:
      return object.__getattribute__(self, "_config").keys[attr]
    except (AttributeError, KeyError) as exc:
      # A missing key must surface as AttributeError (not KeyError) so that
      # hasattr() and getattr(..., default) behave correctly on this object.
      raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'") from exc

  def __setattr__(self, attr, value):
    # The config is read-only once constructed.
    raise ValueError("Reinitialization of config is not allowed")

  def get_keys(self):
    """Return the underlying mapping of all config keys."""
    return self._config.keys


def initialize(argv, **kwargs):
  """Build the configuration from `argv` plus keyword overrides and return it.

  Also stores the result in the module-level `_config`/`config` globals for
  backward compatibility; new callers should use the returned object instead
  of reading the global.

  Args:
    argv: Command-line style arguments (program name, config path, overrides).
    **kwargs: Extra key=value overrides applied on top of argv.

  Returns:
    A read-only `HyperParameters` wrapper around the parsed config.
  """
  global _config, config
  _config = _HyperParameters(argv, **kwargs)
  config = HyperParameters(_config)
  return config


if __name__ == "__main__":
  # Smoke test: build a config from the CLI and exercise a representative key.
  main_config = initialize(sys.argv)
  print(main_config.steps)
  r = range(main_config.steps)
3 changes: 1 addition & 2 deletions MaxText/scratch_code/mixtral-numerical-verification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"import pyconfig\n",
"from transformers.models.mixtral.configuration_mixtral import MixtralConfig\n",
"\n",
"pyconfig.initialize(\n",
"config_maxtext = pyconfig.initialize(\n",
" [None, \"configs/base.yml\"],\n",
" base_emb_dim=4096,\n",
" base_num_query_heads=32,\n",
Expand Down Expand Up @@ -73,7 +73,6 @@
" capacity_factor=-1,\n",
" scan_layers=False,\n",
")\n",
"config_maxtext = pyconfig.config\n",
"\n",
"config_hf = MixtralConfig(\n",
" vocab_size=config_maxtext.vocab_size,\n",
Expand Down
3 changes: 1 addition & 2 deletions MaxText/standalone_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def add_entropy_to_checkpoint(state):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_train_config(config)
print(f"Found {jax.device_count()} devices.")
print(f"Found {jax.process_count()} processes.")
Expand Down
3 changes: 1 addition & 2 deletions MaxText/standalone_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def data_load_loop(config, state=None):
def main(argv: Sequence[str]) -> None:
jax.config.update("jax_cpu_enable_gloo_collectives", True)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
config = pyconfig.config
config = pyconfig.initialize(argv)
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
max_logging.log(f"Found {jax.process_count()} processes.")
Expand Down
13 changes: 5 additions & 8 deletions MaxText/tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ class AttentionTest(unittest.TestCase):

def setUp(self):
super().setUp()
pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
)
self.cfg = pyconfig.config
self.cfg = config
self.rng = jax.random.PRNGKey(0)

devices_array = max_utils.create_device_mesh(self.cfg)
Expand Down Expand Up @@ -336,7 +336,7 @@ def _dot_product_attention(

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -345,7 +345,6 @@ def _dot_product_attention(
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
Expand Down Expand Up @@ -437,7 +436,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
config = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -446,7 +445,6 @@ def _dot_product_attention_reshape_q(self, compute_axis_order):
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
Expand Down Expand Up @@ -727,7 +725,7 @@ class MLATest(parameterized.TestCase):

def init_mla(self, rope_type):
"""Helper function to initialize MLA with different model names."""
pyconfig.initialize(
cfg = pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
Expand All @@ -737,7 +735,6 @@ def init_mla(self, rope_type):
attention_type=attentions.AttentionType.MLA.value,
rope_type=rope_type,
)
cfg = pyconfig.config
rng = jax.random.PRNGKey(0)

devices_array = max_utils.create_device_mesh(cfg)
Expand Down
3 changes: 1 addition & 2 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,5 @@ def main(config, test_args):
for arg in to_remove_args:
model_args = [s for s in model_args if not s.startswith(arg)]

pyconfig.initialize(model_args)
cfg = pyconfig.config
cfg = pyconfig.initialize(model_args)
main(cfg, test_args)
Loading
Loading