Skip to content

Deprecate aqt keep deps#4107

Open
sarunsingla11722 wants to merge 44 commits into
mainfrom
deprecate-aqt-keep-deps
Open

Deprecate aqt keep deps#4107
sarunsingla11722 wants to merge 44 commits into
mainfrom
deprecate-aqt-keep-deps

Conversation

@sarunsingla11722

@sarunsingla11722 sarunsingla11722 commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Pull Request: [DEBUG] Test keeping aqtp dependency with Qwix/FP8 code changes

1. Executive Summary & Purpose

This is an isolated debug Pull Request created to diagnose the GPU/NCCL unit and integration test failures observed on the main PR (deprecate-aqt-phase2).

To isolate the root cause, this PR:

  1. Includes ALL of the code changes from the main PR (transitioning from legacy AQT to Qwix/FP8, pruning in_serve_mode logic, and decoupling type hints). This ensures that the codebase does not import aqt, preventing ModuleNotFoundError on CPU tests.
  2. Reverts ONLY the requirements files back to main, meaning the aqtp dependency is still kept in the requirements and lockfiles.

By running the GPU tests on this branch, we will cleanly isolate whether the NCCL failures are triggered by the dependency package changes or by the model code changes.


2. Verification & Diagnostic Scenarios

Depending on the outcome of the GPU/CUDA 12 tests on this PR, we will know exactly how to proceed:

Scenario A: The GPU tests PASS on this PR

  • The Verdict: The NCCL error is caused by removing the aqtp dependency from the requirements files.
  • Why: When your new code runs in an environment where aqtp is installed, it works; when aqtp is removed, it fails with NCCL. This indicates that removing aqtp changes how pip resolves other transitive dependencies (e.g. JAX, Jaxlib, or CUDA libraries), introducing a buggy package version.
  • Next Step: Compare the package installation logs between the two environments, identify the shifting package, and explicitly pin the correct version.

Scenario B: The GPU tests STILL FAIL with the NCCL error on this PR

  • The Verdict: The NCCL error is not caused by the dependency removal (since aqtp is still installed here). It is triggered by one of the model/layer code changes.
  • Next Step: We can completely rule out dependencies and focus 100% of our debugging on the model, sharding, or collective communication changes in moe.py, linears.py, etc.

3. Checklist

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

… True

- Change default of use_qwix_quantization to True in types.py and base.yml.
- Fix typo in validator error message.
- Update test_quantization_fallbacks to use fp8_gpu and use_qwix_quantization=False to safely test fallback path.

TAG=agy
CONV=9aac9cad-26d1-453f-9b81-c70a14dd59dc
@sarunsingla11722 sarunsingla11722 force-pushed the deprecate-aqt-keep-deps branch from 8f02c78 to 7d24a3e Compare June 9, 2026 21:26
TAG=agy
CONV=39f31ff3-8d14-4d1b-b519-217aeb07b904
- Updated QwixQuantization.einsum to accept **kwargs to prevent TypeError when called from MoE layer with dtype.
- Reverted .github/workflows/run_tests_against_package.yml to main to isolate whether the NCCL failures were triggered by the new workflow changes (e.g., NCCL_SOCKET_IFNAME=lo).

TAG=agy
CONV=753fe72c-2db8-4329-9321-b25762bed269
@sarunsingla11722 sarunsingla11722 force-pushed the deprecate-aqt-keep-deps branch from 2b4fa42 to 15305a7 Compare June 11, 2026 16:07
Removed NCCL_SOCKET_IFNAME=lo and NCCL_NET_GDR_LEVEL=0 to allow NCCL to
auto-discover the optimal network interface and transport (NVLink/P2P) in the Docker container, fixing ncclInvalidArgument errors. Also changed NCCL_DEBUG to WARN to prevent log spam.

TAG=agy
CONV=bf394eb7-23a3-4437-91ef-0cead1a5b0a0
# Conflicts:
#	src/maxtext/layers/quantizations.py
@shralex shralex added the gemini-investigate investigate CI failures label Jun 13, 2026
@github-actions

Copy link
Copy Markdown

🤖 Hi @shralex, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 CI Failure Investigation Report

I have analyzed the recent test failures in the CI pipeline and identified two distinct failure modes: a GPU collective operations conflict (due to a CUDA context pre-emption/corruption issue) and a transient TPU Pathways network timeout (infrastructure/environment flake).


1. GPU Collective Operations Failure (Codebase-Related / Environment)

🔍 What Failed

  • Job/Matrix: gpu-unit and gpu-integration
  • Failing Tests (Examples):
    • tests/unit/model_creation_utils_test.py::TestSetupDecodeStateFromNnx::test_returns_linen_train_state_and_annotations
    • tests/unit/multi_token_prediction_test.py::MultiTokenPredictionLayerTest::test_multi_token_prediction_layer_output
    • tests/integration/maxengine_test.py::MaxEngineTest::test_basic_decode_nnx
  • Error: jax.errors.JaxRuntimeError: INTERNAL: NCCL operation failed: invalid argument / NCCL WARN Error: corrupted comm object detected

🪵 Error Details & Stack Trace

FAILED tests/unit/model_creation_utils_test.py::TestSetupDecodeStateFromNnx::test_returns_linen_train_state_and_annotations - jax.errors.JaxRuntimeError: INTERNAL: NCCL operation ncclAllReduce( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, ToNcclReduction(reduction_kind), comm_, AsCudaStream(stream)) failed: invalid argument (run with NCCL_DEBUG=WARN for details). Last NCCL warning(error) log entry (may be unrelated) ''.

FAILED tests/unit/multi_token_prediction_test.py::MultiTokenPredictionLayerTest::test_multi_token_prediction_layer_output - jax.errors.JaxRuntimeError: INTERNAL: NCCL operation ncclAllGather( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, comm_, AsCudaStream(stream)) failed: invalid argument (run with NCCL_DEBUG=WARN for details). Last NCCL warning(error) log entry (may be unrelated) ''.

linux-x86-a2-48-a100-4gpu-r2txm-runner-g6f9c-workflow:1125:1301 [1] external/nccl_archive/src/misc/argcheck.cc:39 NCCL WARN Error: corrupted comm object detected

💡 Root Cause Analysis & Context

Confidence: high (confirmed cause)

  • The Conflict: In multi-GPU JAX environments, if TensorFlow or libraries importing TensorFlow (like Google's JAX-native quantization library Qwix or TFDS) are imported before JAX has explicitly initialized its devices, TensorFlow pre-emptively initializes the CUDA/NCCL context on the GPUs. When JAX later attempts to initialize its own NCCL communicators for multi-GPU collective operations (e.g., ncclAllReduce or ncclAllGather), a CUDA/NCCL state conflict occurs, causing NCCL to detect a "corrupted comm object" and abort with an invalid argument error.
  • Why it triggered in this PR: This PR (deprecate-aqt-keep-deps) deprecates/removes the old AQT library and sets use_qwix_quantization to True by default in the base configuration (src/maxtext/configs/base.yml). Because of this, qwix is now imported early during test collection in several core modules (e.g., src/maxtext/layers/quantizations.py, src/maxtext/layers/moe.py, etc.). Since qwix includes LiteRT/TensorFlow serialization modules, its import pre-emptively loads the CUDA context.
  • Why the existing workaround failed: tests/conftest.py has an early-initialization block that attempts to call jax.devices() early, but it restricts this check using strict environment variables:
    _has_gpu = (
        "cuda" in _jax_platforms
        or "gpu" in _jax_platforms
        ...
    )
    In the self-hosted GitHub Actions GPU runners, these environment variables are not set or are empty, so the early-initialization was bypassed, allowing Qwix/TensorFlow imports to initialize CUDA first and corrupt subsequent JAX collective runs.

🛠️ Recommended Fix

We can solve this robustly and permanently by unconditionally forcing JAX to initialize early inside tests/conftest.py with a simple try-except block. This ensures JAX always establishes its CUDA/NCCL context first, avoiding context pre-emption/corruption from any subsequent library imports (including TensorFlow, TFDS, or Qwix).

diff --git a/tests/conftest.py b/tests/conftest.py
index 99066cc..4fdfdfd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,17 +37,9 @@ if not _absl_flags.FLAGS.is_parsed():
 import jax
 import os
 import importlib.util
 
-# Force early JAX initialization on GPU to prevent CUDA context conflicts with TensorFlow/PyTorch.
-# If JAX initialization is deferred, TensorFlow/PyTorch (imported during test collection)
-# might initialize CUDA first, causing JAX's subsequent NCCL communicator creation to fail
-# with 'corrupted comm object detected'.
-def _initialize_jax_early():
-  _jax_platforms = os.getenv("JAX_PLATFORMS", "").lower()
-  _device_type = os.getenv("INPUTS_DEVICE_TYPE", "").lower()
-  _has_gpu = (
-      "cuda" in _jax_platforms
-      or "gpu" in _jax_platforms
-      or "cuda" in _device_type
-      or "gpu" in _device_type
-      or os.getenv("CUDA_VISIBLE_DEVICES") is not None
-      or os.getenv("NVIDIA_VISIBLE_DEVICES") is not None
-  )
-  if _has_gpu:
-    try:
-      _ = jax.devices()
-    except Exception:  # pylint: disable=broad-exception-caught
-      pass
+# Force early JAX initialization on all platforms to prevent CUDA context conflicts with TensorFlow/PyTorch/Qwix.
+# If JAX initialization is deferred, other libraries (imported during test collection)
+# might initialize CUDA first, causing JAX's subsequent NCCL communicator creation to fail
+# with 'corrupted comm object detected'.
+try:
+  _ = jax.devices()
+except Exception:  # pylint: disable=broad-exception-caught
+  pass

2. TPU Pathways Integration Failure (Infrastructure/Environment Flake)

🔍 What Failed

  • Job/Matrix: maxtext_tpu_pathways_integration_tests
  • Failing Component: Runner initialization / checkout stage
  • Error: Error: 14 UNAVAILABLE: No connection established. Last error: Error: connect ECONNREFUSED 34.118.229.212:50051 followed by Error: backoff timeout

🪵 Error Details & Stack Trace

maxtext_tpu_pathways_integration_tests / run	UNKNOWN STEP	2026-06-13T06:13:49.7098315Z ##[warning]Retrying execution for ECONNREFUSED or UNAVAILABLE "Error: Error execing #!/bin/sh -l\n\ncd /__w/maxtext/maxtext && exec env ...
: Error: 14 UNAVAILABLE: No connection established. Last error: Error: connect ECONNREFUSED 34.118.229.212:50051".
maxtext_tpu_pathways_integration_tests / run	UNKNOWN STEP	2026-06-13T06:14:09.6992250Z ##[error]Error: backoff timeout

💡 Root Cause Analysis & Context

Confidence: high (confirmed cause)

This failure is a pure infrastructure/environment flake and is completely unrelated to the changes introduced in this PR.

  • Details: The self-hosted runner failed to establish a network connection to the GKE/Pathways master endpoint 34.118.229.212:50051 during the runner's execution loop. This caused a connection refused error (ECONNREFUSED) which eventually hit the backoff timeout and aborted the job.
  • Resolution: Re-triggering the TPU Pathways integration job when the cluster endpoints are healthy and reachable should resolve this failure.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

gemini-investigate investigate CI failures

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants