15 changes: 14 additions & 1 deletion docs/envvars.rst
@@ -122,6 +122,19 @@ These environment variables control the behavior of Transformer Engine during ex
Attention Backend Selection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Transformer Engine attention selects a backend in two stages. First, it filters the available
backends by environment variables, GPU architecture, installed ``flash-attn`` and cuDNN versions,
data type and FP8 recipe, training or inference mode, and the provided attention configuration.
Then it applies a performance-based preference order among the remaining eligible backends.

In PyTorch, the broad preference order is ``FlashAttention > FusedAttention >
UnfusedDotProductAttention`` on supported pre-Hopper GPUs such as Ampere/Ada, and
``FusedAttention > FlashAttention > UnfusedDotProductAttention`` on Hopper and newer GPUs,
including Blackwell. In JAX, Transformer Engine uses cuDNN fused attention when
``NVTE_FUSED_ATTN=1`` and an eligible cuDNN kernel is available; otherwise it falls back to the
JAX-native implementation. See :doc:`examples/attention/attention` for a longer
backend-selection overview.
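
The two stages described above can be sketched in Python. This is a minimal illustration under stated assumptions, not Transformer Engine's actual selector: the function ``pick_backend``, its arguments, and the two preference tuples are hypothetical names introduced here, and the eligibility set is assumed to have already been produced by the filtering stage.

```python
from typing import Optional, Set

# Hypothetical preference tables that mirror the broad PyTorch order
# described in the text; they are not taken from the library's code.
PRE_HOPPER_ORDER = ("FlashAttention", "FusedAttention", "UnfusedDotProductAttention")
HOPPER_PLUS_ORDER = ("FusedAttention", "FlashAttention", "UnfusedDotProductAttention")


def pick_backend(sm_arch: int, eligible: Set[str]) -> Optional[str]:
    """Return the preferred backend among those that survived filtering.

    sm_arch is the compute capability as an integer (80, 89, 90, 100, 120).
    eligible is the set of backend names left after the first stage.
    Returns None when no backend is eligible.
    """
    # Stage 2: choose the preference order by GPU architecture.
    order = HOPPER_PLUS_ORDER if sm_arch >= 90 else PRE_HOPPER_ORDER
    for name in order:
        if name in eligible:
            return name
    return None
```

With both optimized backends eligible, this sketch returns ``FlashAttention`` for ``sm_arch=80`` and ``FusedAttention`` for ``sm_arch=90``. With an empty eligibility set it returns ``None``, which corresponds to the case where no backend can be used.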

.. envvar:: NVTE_FLASH_ATTN

:Type: ``int`` (0 or 1)
@@ -144,7 +157,7 @@ Attention Backend Selection

:Type: ``int`` (1 or 2)
:Default: Auto-selected
:Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
:Description: Request a cuDNN FusedAttention backend when that request is supported by the active fused-attention path. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration. BF16/FP16 attention uses sub-backend ``1`` when eligible. FP8 attention uses sub-backend ``2`` when FP8 DPA is enabled and supported by the architecture, cuDNN version, and input configuration.
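
As an illustration, the variables in this section could be set from Python before Transformer Engine performs backend selection. This sketch assumes the variables are read from the process environment, so it sets them before any Transformer Engine import; the helper name ``request_fused_backend`` is hypothetical and is not part of the library.

```python
import os


def request_fused_backend(sub_backend: int) -> None:
    """Request a cuDNN FusedAttention sub-backend through environment variables.

    This only expresses a request. As described above, the request takes
    effect only when it is supported for the runtime environment and inputs.
    """
    if sub_backend not in (1, 2):
        raise ValueError("NVTE_FUSED_ATTN_BACKEND accepts 1 (F16) or 2 (FP8)")
    os.environ["NVTE_FLASH_ATTN"] = "0"  # disable the FlashAttention family
    os.environ["NVTE_FUSED_ATTN"] = "1"  # keep FusedAttention enabled
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(sub_backend)


# Assumption: call this before importing Transformer Engine.
request_fused_backend(1)
```

Leaving ``NVTE_FUSED_ATTN_BACKEND`` unset keeps the automatic selection described in the field above.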

.. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT

39 changes: 19 additions & 20 deletions docs/examples/attention/attention.ipynb
@@ -110,14 +110,6 @@
" <th>Additional info</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>Non-Flash</td>\n",
" <td>BF16/FP16</td>\n",
" <td> &le;512 </td>\n",
" <td> sm80, 90 </td>\n",
" <td> [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/latest/developer/graph-api.html#fused-attention-fprop)</td> \n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>Flash</td>\n",
" <td>BF16/FP16</td>\n",
@@ -208,34 +200,41 @@
"source": [
"## 2. Backend Selection\n",
"\n",
"Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.\n",
"Given the various attention backends, Transformer Engine first determines which backends are eligible for the provided inputs and runtime environment, then applies a preference order among the eligible backends. Eligibility is affected by user environment variables, GPU architecture, installed `flash-attn` and cuDNN versions, data type and FP8 recipe, QKV layout, training or inference mode, dropout, and other attention features.\n",
"\n",
"Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, `flash-attn`/cuDNN library versions, and the compute capability of the GPU.\n",
"In PyTorch, the candidates are FlashAttention (`flash-attn` v2, v3, or v4), FusedAttention (cuDNN sub-backends), and UnfusedDotProductAttention. Users can disable whole backend families with `NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, or `NVTE_UNFUSED_ATTN`. In JAX, Transformer Engine checks whether a cuDNN fused-attention kernel is available when `NVTE_FUSED_ATTN=1`; otherwise it falls back to the JAX-native implementation.\n",
"\n",
"When there are multiple backends available, Transformer Engine makes backend selection based on performance. In general, there are a few rules being followed in our selection logic (see table below). As we monitor the performance of different backends, the selection logic may change.\n",
"At a high level, the architecture-specific PyTorch selection order is:\n",
"\n",
"<table class=\"docutils align-default\">\n",
" <tr>\n",
" <th>Framework</th>\n",
" <th>Selection Order</th>\n",
" </tr>\n",
" <tr>\n",
" <td rowspan=\"3\">PyTorch</td>\n",
" <td>sm90: cuDNN attention > flash-attention > PyTorch-native attention</td>\n",
" <td rowspan=\"4\">PyTorch</td>\n",
" <td>sm8x (Ampere/Ada): flash-attention > cuDNN attention > PyTorch-native attention</td>\n",
" </tr>\n",
" <tr>\n",
" <td> sm80: flash-attention > cuDNN attention > PyTorch-native attention</td>\n",
" <td>sm90 (Hopper): cuDNN attention > flash-attention > PyTorch-native attention</td>\n",
" </tr>\n",
" <tr>\n",
" <td>\n",
" cuDNN attention: sub-backend 1 > sub-backend 0\n",
" </td> \n",
" <td>sm100/sm120 (Blackwell): cuDNN attention > flash-attention > PyTorch-native attention</td>\n",
" </tr>\n",
" <tr>\n",
" <td>cuDNN attention: BF16/FP16 uses sub-backend 1 when eligible; FP8 uses sub-backend 2 when enabled and eligible</td>\n",
" </tr>\n",
" <tr>\n",
" <td>JAX</td>\n",
" <td>cuDNN attention > JAX-native attention</td>\n",
" </tr>\n",
"</table>"
"</table>\n",
"\n",
"Within FlashAttention, TE uses the installed implementation that is supported for the architecture and input. FlashAttention 3 is Hopper-only (`sm90`). FlashAttention 4 supports `sm80`, `sm90`, `sm100`, and `sm120`; on Hopper, TE prefers FlashAttention 3 over FlashAttention 4 when both are installed and eligible. On Blackwell, FlashAttention 4 is the Blackwell-specific flash-attention path when installed and eligible, while FlashAttention 2 can still be eligible depending on the installed version and input configuration.\n",
"\n",
"Within cuDNN FusedAttention, TE asks the fused-attention helper which sub-backend is eligible. Sub-backend 1 is the BF16/FP16 flash-based path when available; sub-backend 2 is the FP8 path when FP8 DPA is enabled and the architecture, cuDNN version, and input configuration support it. Hopper supports eligible FP8 DPA through cuDNN sub-backend 2. In the current PyTorch selector, eligible FP8 DPA on Blackwell is an `sm100` path and is disabled on `sm120`.\n",
"\n",
"When all optimized backends are disabled or ineligible, TE falls back to UnfusedDotProductAttention if it is enabled. If no backend is eligible, backend selection returns no backend and the caller raises an error. As we monitor the performance of different backends, the selection logic may change."
]
},
{
@@ -350,7 +349,7 @@
"**cuDNN attention sub-backends:**\n",
"This environment variable lets users state a preferred cuDNN attention sub-backend. However, the requested sub-backend is used only *if* it is eligible, i.e., if it supports the provided inputs and runtime environment.\n",
"```\n",
"NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend\n",
"NVTE_FUSED_ATTN_BACKEND = 1/2 # user preference of cuDNN sub-backend\n",
"```\n",
"\n",
"**Execution paths of cuDNN sub-backend 1:**\n",
@@ -369,7 +368,7 @@
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
" \n",
"Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_FUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are only supported in PyTorch, and will be added to JAX in the future.\n",
"Environment variables <code>NVTE_FLASH_ATTN</code>, <code>NVTE_UNFUSED_ATTN</code>, <code>NVTE_FUSED_ATTN_BACKEND</code>, <code>NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT</code>, and <code>NVTE_FUSED_ATTN_USE_FAv2_BWD</code> are supported in PyTorch. <code>NVTE_FUSED_ATTN</code> and <code>NVTE_ALLOW_NONDETERMINISTIC_ALGO</code> are supported in both PyTorch and JAX.\n",
"</div>\n",
"\n",
"### 2.3 Example Tests\n",
@@ -1059,8 +1059,13 @@ def forward(

Users can use environment variables :attr:`NVTE_FLASH_ATTN`, :attr:`NVTE_FUSED_ATTN`,
and :attr:`NVTE_FUSED_ATTN_BACKEND` to control which DotProductAttention backend,
and FusedAttention backend if applicable, to use. Transformer Engine prioritizes
FlashAttention over FusedAttention and over UnfusedDotProductAttention.
and FusedAttention backend if applicable, to use. Transformer Engine first filters
backends by support for the runtime environment and input configuration, then applies
a performance-based preference order. On supported pre-Hopper GPUs, FlashAttention is
preferred over FusedAttention and UnfusedDotProductAttention when both optimized
backends are eligible. On Hopper and newer GPUs, including Blackwell, FusedAttention is
preferred over FlashAttention and UnfusedDotProductAttention when both optimized
backends are eligible.
If FusedAttention is being used, users can also choose to switch to flash-attn's
implementation for backward by setting :attr:`NVTE_FUSED_ATTN_USE_FAv2_BWD=1`
(default: 0), because of the performance differences between various versions of