Deprecate aqt keep deps#4107
Conversation
… True - Change default of use_qwix_quantization to True in types.py and base.yml. - Fix typo in validator error message. - Update test_quantization_fallbacks to use fp8_gpu and use_qwix_quantization=False to safely test fallback path. TAG=agy CONV=9aac9cad-26d1-453f-9b81-c70a14dd59dc
8f02c78 to
7d24a3e
Compare
…ix mode detection for Qwix
TAG=agy CONV=39f31ff3-8d14-4d1b-b519-217aeb07b904
- Updated QwixQuantization.einsum to accept **kwargs to prevent TypeError when called from MoE layer with dtype. - Reverted .github/workflows/run_tests_against_package.yml to main to isolate whether the NCCL failures were triggered by the new workflow changes (e.g., NCCL_SOCKET_IFNAME=lo). TAG=agy CONV=753fe72c-2db8-4329-9321-b25762bed269
aa7a9f4 to
ac32599
Compare
2b4fa42 to
15305a7
Compare
Removed NCCL_SOCKET_IFNAME=lo and NCCL_NET_GDR_LEVEL=0 to allow NCCL to auto-discover the optimal network interface and transport (NVLink/P2P) in the Docker container, fixing ncclInvalidArgument errors. Also changed NCCL_DEBUG to WARN to prevent log spam. TAG=agy CONV=bf394eb7-23a3-4437-91ef-0cead1a5b0a0
# Conflicts: # src/maxtext/layers/quantizations.py
|
🤖 Hi @shralex, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
🤖 CI Failure Investigation ReportI have analyzed the recent test failures in the CI pipeline and identified two distinct failure modes: a GPU collective operations conflict (due to a CUDA context pre-emption/corruption issue) and a transient TPU Pathways network timeout (infrastructure/environment flake). 1. GPU Collective Operations Failure (Codebase-Related / Environment)🔍 What Failed
🪵 Error Details & Stack TraceFAILED tests/unit/model_creation_utils_test.py::TestSetupDecodeStateFromNnx::test_returns_linen_train_state_and_annotations - jax.errors.JaxRuntimeError: INTERNAL: NCCL operation ncclAllReduce( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, ToNcclReduction(reduction_kind), comm_, AsCudaStream(stream)) failed: invalid argument (run with NCCL_DEBUG=WARN for details). Last NCCL warning(error) log entry (may be unrelated) ''.
FAILED tests/unit/multi_token_prediction_test.py::MultiTokenPredictionLayerTest::test_multi_token_prediction_layer_output - jax.errors.JaxRuntimeError: INTERNAL: NCCL operation ncclAllGather( send_buffer.opaque(), recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype, comm_, AsCudaStream(stream)) failed: invalid argument (run with NCCL_DEBUG=WARN for details). Last NCCL warning(error) log entry (may be unrelated) ''.
linux-x86-a2-48-a100-4gpu-r2txm-runner-g6f9c-workflow:1125:1301 [1] external/nccl_archive/src/misc/argcheck.cc:39 NCCL WARN Error: corrupted comm object detected💡 Root Cause Analysis & ContextConfidence: high (confirmed cause)
🛠️ Recommended FixWe can solve this robustly and permanently by unconditionally forcing JAX to initialize early inside diff --git a/tests/conftest.py b/tests/conftest.py
index 99066cc..4fdfdfd 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -37,17 +37,9 @@ if not _absl_flags.FLAGS.is_parsed():
import jax
import os
import importlib.util
-# Force early JAX initialization on GPU to prevent CUDA context conflicts with TensorFlow/PyTorch.
-# If JAX initialization is deferred, TensorFlow/PyTorch (imported during test collection)
-# might initialize CUDA first, causing JAX's subsequent NCCL communicator creation to fail
-# with 'corrupted comm object detected'.
-def _initialize_jax_early():
- _jax_platforms = os.getenv("JAX_PLATFORMS", "").lower()
- _device_type = os.getenv("INPUTS_DEVICE_TYPE", "").lower()
- _has_gpu = (
- "cuda" in _jax_platforms
- or "gpu" in _jax_platforms
- or "cuda" in _device_type
- or "gpu" in _device_type
- or os.getenv("CUDA_VISIBLE_DEVICES") is not None
- or os.getenv("NVIDIA_VISIBLE_DEVICES") is not None
- )
- if _has_gpu:
- try:
- _ = jax.devices()
- except Exception: # pylint: disable=broad-exception-caught
- pass
+# Force early JAX initialization on all platforms to prevent CUDA context conflicts with TensorFlow/PyTorch/Qwix.
+# If JAX initialization is deferred, other libraries (imported during test collection)
+# might initialize CUDA first, causing JAX's subsequent NCCL communicator creation to fail
+# with 'corrupted comm object detected'.
+try:
+ _ = jax.devices()
+except Exception: # pylint: disable=broad-exception-caught
+ pass2. TPU Pathways Integration Failure (Infrastructure/Environment Flake)🔍 What Failed
🪵 Error Details & Stack Trace💡 Root Cause Analysis & ContextConfidence: high (confirmed cause) This failure is a pure infrastructure/environment flake and is completely unrelated to the changes introduced in this PR.
|
…to grain forking tests
…xt corruption from Qwix/TF
Pull Request: [DEBUG] Test keeping aqtp dependency with Qwix/FP8 code changes
1. Executive Summary & Purpose
This is an isolated debug Pull Request created to diagnose the GPU/NCCL unit and integration test failures observed on the main PR (deprecate-aqt-phase2).
To isolate the root cause, this PR:
in_serve_modelogic, and decoupling type hints). This ensures that the codebase does not importaqt, preventingModuleNotFoundErroron CPU tests.main, meaning theaqtpdependency is still kept in the requirements and lockfiles.By running the GPU tests on this branch, we will cleanly isolate whether the NCCL failures are triggered by the dependency package changes or by the model code changes.
2. Verification & Diagnostic Scenarios
Depending on the outcome of the GPU/CUDA 12 tests on this PR, we will know exactly how to proceed:
Scenario A: The GPU tests PASS on this PR
aqtpdependency from the requirements files.aqtpis installed, it works; whenaqtpis removed, it fails with NCCL. This indicates that removingaqtpchanges howpipresolves other transitive dependencies (e.g. JAX, Jaxlib, or CUDA libraries), introducing a buggy package version.Scenario B: The GPU tests STILL FAIL with the NCCL error on this PR
aqtpis still installed here). It is triggered by one of the model/layer code changes.moe.py,linears.py, etc.3. Checklist
gemini-reviewlabel.