Open
Labels: bug (Something isn't working)
Description
Bug report
MaxText.train_compile fails when compile_topology_num_slices > 1. By default, JAX gets initialized in pyconfig.initialize. Because train_compile calls pyconfig.initialize before get_topology_mesh updates mock_num_gpu_processes, that update has no effect on topology_devices.

Setting quantization_local_shard_count=1 works around the problem. JAX devices appear to be initialized at maxtext/src/MaxText/configs/types.py:1666, while the value of quantization_local_shard_count is being determined.
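The ordering problem can be illustrated with a minimal, self-contained sketch. Everything in it is hypothetical: _mock_settings, get_devices, resolve_shard_count, get_topology_devices and run are stand-ins for the JAX backend, pyconfig.initialize and get_topology_mesh, not MaxText or JAX APIs. It assumes only what the report describes: the device list is created once, on the first query, and later changes to the mock process count are ignored.

```python
from functools import lru_cache

# Hypothetical stand-in for the mock setting that get_topology_mesh updates
# (mock_num_gpu_processes in the real code). Not a JAX API.
_mock_settings = {"num_processes": 1}
DEVICES_PER_PROCESS = 8  # illustrative value only (e.g. 8 GPUs per host)


@lru_cache(maxsize=None)
def get_devices():
    """Create the device list once, on the first call, and cache it.

    Mimics a backend that is initialized a single time per process:
    once this has run, later edits to _mock_settings are not seen.
    """
    n = _mock_settings["num_processes"] * DEVICES_PER_PROCESS
    return tuple(range(n))


def resolve_shard_count(user_value=-1):
    """Hypothetical config step where -1 means 'derive from the devices'.

    Deriving the default queries the devices and therefore triggers
    initialization, as the report describes for quantization_local_shard_count.
    """
    if user_value == -1:
        return len(get_devices())
    return user_value


def get_topology_devices(num_slices):
    """Hypothetical topology step: set the mock process count, then query devices."""
    _mock_settings["num_processes"] = num_slices
    return get_devices()


def run(shard_count, num_slices):
    """Simulate the train_compile order: config first, topology second."""
    get_devices.cache_clear()  # start each simulation with a fresh "process"
    _mock_settings["num_processes"] = 1
    resolve_shard_count(shard_count)  # config initialization
    return len(get_topology_devices(num_slices))


print(run(shard_count=-1, num_slices=8))  # 8: devices were initialized too early
print(run(shard_count=1, num_slices=8))   # 64: the mock setting takes effect
```

With an explicit value, the first device query happens only after the mock setting is applied, which matches the reported workaround of passing quantization_local_shard_count=1. A more general fix would likely apply the topology mock settings before any config default touches the devices.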
Logs/Output
python3 -m MaxText.train_compile ./src/MaxText/configs/base.yml compile_topology=a3 hardware=gpu compile_topology_num_slices=8 compiled_trainstep_file=./out/compiled_train_step.pkl
Starting train_compile.py...
...
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/maxtext/src/MaxText/train_compile.py", line 264, in <module>
app.run(main)
File "/usr/local/lib/python3.12/dist-packages/absl/app.py", line 316, in run
_run_main(main, args)
File "/usr/local/lib/python3.12/dist-packages/absl/app.py", line 261, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/opt/maxtext/src/MaxText/train_compile.py", line 209, in main
topology_mesh = get_topology_mesh(config)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/maxtext/src/MaxText/train_compile.py", line 81, in get_topology_mesh
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/maxtext/src/MaxText/maxtext_utils.py", line 1054, in create_device_mesh
ici_parallelism = max_utils.fill_unspecified_mesh_axes(config.ici_parallelism.copy(), num_devices_per_slice, "ICI")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/maxtext/src/MaxText/max_utils.py", line 353, in fill_unspecified_mesh_axes
determined_val >= 1 and determined_val.is_integer
AssertionError: Unspecified value unable to be determined with the given ICI parallelism values
Environment Information
No response
Additional Context
No response