# Description #4138

Open
copybara-service[bot] wants to merge 1 commit into main from test_930062734

@copybara-service copybara-service Bot commented Jun 10, 2026

Description

This PR fixes a ValueError: can only convert an array of size 1 to a Python scalar that occurs in RemoteIteratorWrapper during state save/restore on multi-device topologies (size > 1). It also adds validation to ensure colocated Python data input is only used with Pathways (single controller) enabled, and replaces incorrect usages of jax.local_devices() with global_mesh.devices.

Root Cause

  1. ValueError in save/restore: RemoteIteratorWrapper.save_state and restore_state were attempting to shape the step value array using self.dummy_array.shape and shard it across devices. On topologies with more than 1 device, this resulted in a partitioned array. When this partitioned array was passed to the local iterator, attempting to unpack it to a Python scalar (e.g. via .item() or direct conversion) failed because JAX does not allow converting partitioned arrays of size > 1 to Python scalars.
  2. Incorrect Device Resolution: RemoteIteratorWrapper was using jax.local_devices() to determine CPU/TPU devices. Under Pathways (single-controller), all devices in the cluster are virtualized as local to the JAX client, meaning jax.local_devices() returns all devices (including inactive ones during elastic scale-down), which is incorrect for sharding and shape calculations.
  3. Missing Validation: colocated_python_data_input relies on Pathways single-controller mode, but there was no validation enforcing this constraint, which could lead to cryptic failures if misconfigured.
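
The first root cause can be reproduced outside MaxText with a few lines of JAX. The sketch below is not the `RemoteIteratorWrapper` code: the mesh axis name `data`, the step value, and the array shape (two elements per device) are assumptions chosen only so that the array is partitioned and has more than one element.

```python
import os

# Assumption: simulate several CPU devices. This must be set before JAX
# initializes its backends, and is ignored if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("data",))

# Hypothetical stand-in for a step array built from a dummy array's shape:
# two elements per device, partitioned along the "data" axis.
step = 7
partitioned = jax.device_put(
    jnp.full((2 * len(devices),), step, dtype=jnp.int32),
    NamedSharding(mesh, PartitionSpec("data")),
)

# The array holds more than one element, so it cannot be unpacked
# into a single Python scalar.
try:
    partitioned.item()
    conversion_failed = False
except (ValueError, TypeError) as err:
    conversion_failed = True
    print(f"scalar conversion failed: {err}")
```

The step value is the same in every element, but the conversion still fails because it depends on the array's size, not its contents.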

Solution

  1. Replicated Scalar for Step: Modified RemoteIteratorWrapper.save_state and restore_state in multihost_dataloading.py to pass the training step as a replicated 0D JAX scalar array (global shape ()) with replicated sharding (NamedSharding with PartitionSpec()). This ensures the array has size 1 on all devices and can be safely converted to a Python scalar by the local iterator.
  2. Use Global Mesh Devices: Replaced jax.local_devices() with global_mesh.devices (via tuple(global_mesh.devices.flat)) in RemoteIteratorWrapper.__init__ to ensure it only uses the active devices defined by the global mesh, handling elastic scaling correctly.
  3. Config Validation: Added a check in types.py to raise a ValueError if colocated_python_data_input is enabled but enable_single_controller is false.
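
A minimal sketch of the three changes, under stated assumptions: `replicated_step` and `DataInputConfig` are hypothetical names introduced for illustration, the mesh axis name `data` is arbitrary, and the real check in `types.py` may be structured differently. Only the two config field names, `NamedSharding`, `PartitionSpec()`, and `tuple(global_mesh.devices.flat)` come from the description above.

```python
import os

# Assumption: simulate several CPU devices; set before JAX initializes.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from dataclasses import dataclass

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicated_step(step: int, mesh: Mesh) -> jax.Array:
    """Returns `step` as a 0-D array replicated on every device in `mesh`."""
    # An empty PartitionSpec partitions no axis, so each device holds the full scalar.
    sharding = NamedSharding(mesh, PartitionSpec())
    return jax.device_put(jnp.asarray(step, dtype=jnp.int32), sharding)


# Stand-in for the global mesh; devices are taken from the mesh itself
# rather than from jax.local_devices().
global_mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
active_devices = tuple(global_mesh.devices.flat)

step_array = replicated_step(42, global_mesh)
step_value = int(step_array)  # Global shape () and size 1, so this conversion is valid.


@dataclass
class DataInputConfig:
    """Hypothetical config holder; only the two field names come from the PR."""

    colocated_python_data_input: bool = False
    enable_single_controller: bool = False

    def __post_init__(self):
        if self.colocated_python_data_input and not self.enable_single_controller:
            raise ValueError(
                "colocated_python_data_input requires enable_single_controller=True (Pathways)."
            )
```

With this shape of check, a misconfiguration such as `DataInputConfig(colocated_python_data_input=True)` fails at construction time instead of surfacing later as a harder-to-diagnose error.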

Tests

Added new unit tests in third_party/py/maxtext/tests/unit/multihost_dataloading_test.py to verify the fixes:

  1. test_remote_iterator_wrapper_save_state: Parameterized over different mesh shapes (1, 2, and 4 devices). Instantiates RemoteIteratorWrapper and verifies that calling save_state successfully writes the state to a JSON file without raising ValueError.
  2. test_remote_iterator_wrapper_restore_state: Parameterized over different mesh shapes. Verifies that restore_state successfully restores the state from a JSON file and resumes iteration correctly.

These tests are configured to run with XLA_FLAGS="--xla_force_host_platform_device_count=4" via the BUILD target to simulate multi-device environments.
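
For readers without access to the test file, the parameterization over mesh sizes might look roughly like the sketch below. It is an assumption-laden stand-in: it does not instantiate `RemoteIteratorWrapper`, the class name `ReplicatedStepTest` is hypothetical, and it uses the standard-library `unittest` with `subTest` rather than whatever parameterization helper the real tests use. It checks only that a replicated 0-D step converts to a Python scalar on meshes of 1, 2, and 4 devices.

```python
import os

# Assumption: mirrors the XLA_FLAGS setting from the BUILD target; it must be
# set before JAX initializes its backends.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import unittest

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


class ReplicatedStepTest(unittest.TestCase):
    """Hypothetical stand-in for the save/restore tests described above."""

    def test_step_converts_on_each_mesh_size(self):
        for num_devices in (1, 2, 4):
            with self.subTest(num_devices=num_devices):
                if len(jax.devices()) < num_devices:
                    # Not enough simulated devices in this environment.
                    continue
                mesh = Mesh(np.array(jax.devices()[:num_devices]), axis_names=("data",))
                step = jax.device_put(
                    jnp.asarray(5, dtype=jnp.int32),
                    NamedSharding(mesh, PartitionSpec()),
                )
                self.assertEqual(step.shape, ())
                self.assertEqual(int(step), 5)
                # A replicated array has one full copy per device in the mesh.
                self.assertEqual(len(step.addressable_shards), num_devices)
```

Saved to a file, this can be run with `python -m unittest <file>`; mesh sizes larger than the number of available devices are silently skipped.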

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [X] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [X] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [X] I have run end-to-end tests and provided workload links above if applicable.
  • [X] I have made or will make corresponding changes to the doc if needed.

codecov Bot commented Jun 10, 2026

Codecov Report

❌ Patch coverage is 0% with 7 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...rc/maxtext/input_pipeline/multihost_dataloading.py | 0.00% | 7 Missing ⚠️ |


@copybara-service copybara-service Bot changed the title from "Fix ValueError in RemoteIterator.save_state on scaled-down topologies." to "Fix ValueError in RemoteIterator.save_state on topologies with size > 1." Jun 12, 2026
@copybara-service copybara-service Bot force-pushed the test_930062734 branch 2 times, most recently from 6cce657 to 1e578e3 June 12, 2026 21:35
@copybara-service copybara-service Bot changed the title from "Fix ValueError in RemoteIterator.save_state on topologies with size > 1." to the full Markdown text of the description above, ending with "PiperOrigin-RevId: 930062734". Jun 12, 2026