From caad77d98c8506c0535b76fb2ba5224df7d67ec2 Mon Sep 17 00:00:00 2001 From: Philip Monk Date: Fri, 12 Jun 2026 17:31:15 -0700 Subject: [PATCH 1/2] Add multi_tensor_raw_moments Signed-off-by: Philip Monk --- tests/pytorch/test_multi_tensor.py | 43 ++++ transformer_engine/common/CMakeLists.txt | 2 + .../include/transformer_engine/multi_tensor.h | 23 +++ .../common/multi_tensor/raw_moments.cu | 195 ++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 3 + .../extensions/multi_tensor/raw_moments.cpp | 46 +++++ .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/optimizers/__init__.py | 1 + 8 files changed, 316 insertions(+) create mode 100644 transformer_engine/common/multi_tensor/raw_moments.cu create mode 100644 transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 6f1b6948ab..65d8a40137 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -284,6 +284,49 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens assert overflow_buf.item() == 0 +raw_moment_size_pairs = [ + (777, 555), + (2048 * 32 + 1, 555), +] + + +def _raw_moment_reference(tensor): + values = tensor.float() + values_2 = values * values + return torch.stack( + [ + torch.tensor(float(values.numel()), dtype=torch.float32, device=tensor.device), + values.sum(), + values_2.sum(), + (values_2 * values).sum(), + (values_2 * values_2).sum(), + ] + ) + + +@pytest.mark.parametrize("input_size_pair", raw_moment_size_pairs) +@pytest.mark.parametrize("applier", appliers) +@pytest.mark.parametrize("repeat", [1, 55]) +@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16]) +def test_multi_tensor_raw_moments(input_size_pair, applier, repeat, in_type): + sizea, sizeb = input_size_pair + device = torch.device("cuda") + overflow_buf = torch.zeros(1, dtype=torch.int32, device=device) + + a = 
(torch.arange(sizea, dtype=torch.float32, device=device) % 17) - 8 + b = (torch.arange(sizeb, dtype=torch.float32, device=device) % 11) - 5 + + in_list = [] + for _ in range(repeat): + in_list += [a.clone().to(in_type), b.clone().to(in_type)] + + moments = applier(tex.multi_tensor_raw_moments, overflow_buf, [in_list]) + references = torch.stack([_raw_moment_reference(tensor) for tensor in in_list]) + + torch.testing.assert_close(moments, references, rtol=1e-5, atol=1e-2) + assert overflow_buf.item() == 0 + + @pytest.mark.parametrize("input_size_pair", input_size_pairs) @pytest.mark.parametrize("applier", appliers) @pytest.mark.parametrize("repeat", [1, 55]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index edb8c5e109..c47502ed45 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -204,6 +204,7 @@ list(APPEND transformer_engine_cuda_sources common.cu multi_tensor/adam.cu multi_tensor/l2norm.cu + multi_tensor/raw_moments.cu multi_tensor/scale.cu multi_tensor/sgd.cu transpose/cast_transpose.cu @@ -569,6 +570,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu multi_tensor/adam.cu multi_tensor/compute_scale.cu multi_tensor/l2norm.cu + multi_tensor/raw_moments.cu multi_tensor/scale.cu multi_tensor/sgd.cu fused_attn/flash_attn.cu diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 09ab260f15..5289d374bd 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -68,6 +68,29 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, int per_tensor, int max_chunks_per_tensor, cudaStream_t stream); +/*! \brief Computes raw moments for a list of tensors. 
+ * + * The returned rows contain count and raw sums of powers 1-4 for each tensor. + * + * \warning This API is **experimental** and subject to change. + * + * \param[in] chunk_size Number of tensor elements processed by a CUDA block. + * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. + * \param[in] tensor_lists 2D array of input tensors. + * \param[in] num_tensor_lists Size (dim0) of tensor_lists. + * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. + * \param[in] output_per_tensor Fixed size auxiliary scratch space. + * \param[out] ret Raw-moment rows for each tensor. + * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + NVTETensor output_per_tensor, NVTETensor ret, + int max_chunks_per_tensor, cudaStream_t stream); + /*! \brief Compute and apply gradient update to parameters for Adam optimizer. * * \warning This API is **experimental** and subject to change. diff --git a/transformer_engine/common/multi_tensor/raw_moments.cu b/transformer_engine/common/multi_tensor/raw_moments.cu new file mode 100644 index 0000000000..6c39eaaa51 --- /dev/null +++ b/transformer_engine/common/multi_tensor/raw_moments.cu @@ -0,0 +1,195 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include "../utils.cuh" +#include "multi_tensor_apply.cuh" + +namespace transformer_engine { +namespace multi_tensor_raw_moments { + +#define BLOCK_SIZE 512 +#define ILP 4 +#define RAW_MOMENT_FIELDS 5 + +template +__device__ __forceinline__ bool is_aligned(T *p) { + return ((uint64_t)p) % (ILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset) { + typedef typename std::aligned_storage::type LT; + ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) +} + +__device__ __forceinline__ float reduce_block_sum(float *x, float val) { + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + float final = 0.f; + if (tid < 32) { + if (blockSize >= 64) { + final = x[tid] + x[tid + 32]; + } else { + final = val; + } + +#pragma unroll + for (int i = 16; i >= 1; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); + } + + __syncthreads(); + return final; +} + +template +struct RawMomentsFunctor { + __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, + TensorListMetadata<1> &tl, // NOLINT(*) + float *output_per_tensor, int max_chunks_per_tensor) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int tensor_idx = tl.start_tensor_this_launch + tensor_loc; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + x_t *x = reinterpret_cast(tl.addresses[0][tensor_loc]); + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + int elements_this_chunk = n < chunk_size ? 
n : chunk_size; + + __shared__ float s_vals[RAW_MOMENT_FIELDS - 1][BLOCK_SIZE]; + + float sum_1 = 0.f; + float sum_2 = 0.f; + float sum_3 = 0.f; + float sum_4 = 0.f; + + x_t r_x[ILP]; + for (int i = 0; i < ILP; i++) r_x[i] = 0; + + if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) { + for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; + i_start += blockDim.x) { + load_store(r_x, x, 0, i_start); +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + float val = static_cast(r_x[ii]); + float val_2 = val * val; + sum_1 += val; + sum_2 += val_2; + sum_3 += val_2 * val; + sum_4 += val_2 * val_2; + } + } + } else { + for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + float val = static_cast(x[i]); + float val_2 = val * val; + sum_1 += val; + sum_2 += val_2; + sum_3 += val_2 * val; + sum_4 += val_2 * val_2; + } + } + } + } + + float final_sum_1 = reduce_block_sum(s_vals[0], sum_1); + float final_sum_2 = reduce_block_sum(s_vals[1], sum_2); + float final_sum_3 = reduce_block_sum(s_vals[2], sum_3); + float final_sum_4 = reduce_block_sum(s_vals[3], sum_4); + + if (threadIdx.x == 0) { + if (!isfinite(final_sum_1) || !isfinite(final_sum_2) || !isfinite(final_sum_3) || + !isfinite(final_sum_4)) { + *noop_gmem = 1; + } + float *row = output_per_tensor + + (tensor_idx * max_chunks_per_tensor + chunk_idx) * RAW_MOMENT_FIELDS; + row[0] = static_cast(elements_this_chunk); + row[1] = final_sum_1; + row[2] = final_sum_2; + row[3] = final_sum_3; + row[4] = final_sum_4; + } + } +}; + +__global__ void cleanup(float *output_per_tensor, float *ret, int max_chunks_per_tensor) { + int tensor_idx = blockIdx.x; + int field_idx = blockIdx.y; + __shared__ float vals[BLOCK_SIZE]; + + float *chunks = + output_per_tensor + tensor_idx * max_chunks_per_tensor * RAW_MOMENT_FIELDS + 
field_idx; + + float val = 0.f; + for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) { + val += chunks[i * RAW_MOMENT_FIELDS]; + } + + float final = reduce_block_sum(vals, val); + if (threadIdx.x == 0) ret[tensor_idx * RAW_MOMENT_FIELDS + field_idx] = final; +} + +void multi_tensor_raw_moments_cuda(int chunk_size, Tensor noop_flag, + std::vector> tensor_lists, + Tensor output_per_tensor, Tensor ret, + int max_chunks_per_tensor, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + tensor_lists[0][0]->dtype(), dtype, + multi_tensor_apply<1>( + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, RawMomentsFunctor(), stream, + reinterpret_cast(output_per_tensor.data.dptr), max_chunks_per_tensor);) + + NVTE_CHECK_CUDA(cudaGetLastError()); + + dim3 grid(tensor_lists[0].size(), RAW_MOMENT_FIELDS); + cleanup<<>>( + reinterpret_cast(output_per_tensor.data.dptr), + reinterpret_cast(ret.data.dptr), max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace multi_tensor_raw_moments +} // namespace transformer_engine + +void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + NVTETensor output_per_tensor, NVTETensor ret, + int max_chunks_per_tensor, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_raw_moments_cuda); + using namespace transformer_engine; + + multi_tensor_raw_moments::multi_tensor_raw_moments_cuda( + chunk_size, *convertNVTETensorCheck(noop_flag), + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), + *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(ret), + max_chunks_per_tensor, stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e66698dc2c..d4d81d38bc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ 
-580,6 +580,9 @@ std::tuple multi_tensor_unscale_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, at::optional per_tensor_python); +at::Tensor multi_tensor_raw_moments_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists); + void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp new file mode 100644 index 0000000000..a9c6a7aaeb --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../../extensions.h" + +namespace transformer_engine::pytorch { + +at::Tensor multi_tensor_raw_moments_cuda( + int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists) { + auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); + + int ntensors = tensor_lists[0].size(); + int max_chunks_per_tensor = 0; + for (int t = 0; t < ntensors; t++) { + int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + if (max_chunks_this_tensor > max_chunks_per_tensor) { + max_chunks_per_tensor = max_chunks_this_tensor; + } + } + + auto ret = at::empty({ntensors, 5}, float_options); + if (max_chunks_per_tensor == 0) { + ret.zero_(); + return ret; + } + + auto output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor * 5}, float_options); + + auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); + auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = + 
makeTransformerEngineTensorList(tensor_lists); + auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); + auto ret_cu = makeTransformerEngineTensor(ret); + + nvte_multi_tensor_raw_moments_cuda( + chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, + output_per_tensor_cu.data(), ret_cu.data(), max_chunks_per_tensor, + at::cuda::getCurrentCUDAStream()); + + return ret; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index ef4f5d17d8..bd0b8ccf7f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -596,6 +596,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only " "performed for L2 norm computation, and tensors are not updated)", py::call_guard()); + m.def("multi_tensor_raw_moments", &transformer_engine::pytorch::multi_tensor_raw_moments_cuda, + "Computes count and raw sums of powers 1-4 for a list of contiguous tensors", + py::call_guard()); m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 7220f1924a..ed30e79310 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -8,6 +8,7 @@ multi_tensor_scale_tensor, multi_tensor_l2norm, multi_tensor_unscale_l2norm, + multi_tensor_raw_moments, multi_tensor_adam, multi_tensor_adam_fp8, multi_tensor_adam_capturable, From 33c255ae93c706368c57a37aaa22d75d882bf34d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Jun 2026 01:19:17 
+0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../include/transformer_engine/multi_tensor.h | 3 +-- .../common/multi_tensor/raw_moments.cu | 17 ++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 4 ++-- .../extensions/multi_tensor/raw_moments.cpp | 12 ++++++------ 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index 5289d374bd..2b8417cdcd 100644 --- a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -85,8 +85,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag, - NVTETensor **tensor_lists, - const size_t num_tensor_lists, + NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output_per_tensor, NVTETensor ret, int max_chunks_per_tensor, cudaStream_t stream); diff --git a/transformer_engine/common/multi_tensor/raw_moments.cu b/transformer_engine/common/multi_tensor/raw_moments.cu index 6c39eaaa51..7e7666d869 100644 --- a/transformer_engine/common/multi_tensor/raw_moments.cu +++ b/transformer_engine/common/multi_tensor/raw_moments.cu @@ -128,8 +128,8 @@ struct RawMomentsFunctor { !isfinite(final_sum_4)) { *noop_gmem = 1; } - float *row = output_per_tensor + - (tensor_idx * max_chunks_per_tensor + chunk_idx) * RAW_MOMENT_FIELDS; + float *row = + output_per_tensor + (tensor_idx * max_chunks_per_tensor + chunk_idx) * RAW_MOMENT_FIELDS; row[0] = static_cast(elements_this_chunk); row[1] = final_sum_1; row[2] = final_sum_2; @@ -158,8 +158,8 @@ __global__ void cleanup(float *output_per_tensor, float *ret, int 
max_chunks_per void multi_tensor_raw_moments_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, - Tensor output_per_tensor, Tensor ret, - int max_chunks_per_tensor, cudaStream_t stream) { + Tensor output_per_tensor, Tensor ret, int max_chunks_per_tensor, + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<1>( @@ -169,9 +169,9 @@ void multi_tensor_raw_moments_cuda(int chunk_size, Tensor noop_flag, NVTE_CHECK_CUDA(cudaGetLastError()); dim3 grid(tensor_lists[0].size(), RAW_MOMENT_FIELDS); - cleanup<<>>( - reinterpret_cast(output_per_tensor.data.dptr), - reinterpret_cast(ret.data.dptr), max_chunks_per_tensor); + cleanup<<>>(reinterpret_cast(output_per_tensor.data.dptr), + reinterpret_cast(ret.data.dptr), + max_chunks_per_tensor); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -179,8 +179,7 @@ void multi_tensor_raw_moments_cuda(int chunk_size, Tensor noop_flag, } // namespace transformer_engine void nvte_multi_tensor_raw_moments_cuda(int chunk_size, NVTETensor noop_flag, - NVTETensor **tensor_lists, - const size_t num_tensor_lists, + NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output_per_tensor, NVTETensor ret, int max_chunks_per_tensor, cudaStream_t stream) { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d4d81d38bc..b7010ad23b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -580,8 +580,8 @@ std::tuple multi_tensor_unscale_l2norm_cuda( int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, at::optional per_tensor_python); -at::Tensor multi_tensor_raw_moments_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists); +at::Tensor multi_tensor_raw_moments_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists); void multi_tensor_adam_cuda(int 
chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp index a9c6a7aaeb..458334eae8 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/raw_moments.cpp @@ -8,8 +8,8 @@ namespace transformer_engine::pytorch { -at::Tensor multi_tensor_raw_moments_cuda( - int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists) { +at::Tensor multi_tensor_raw_moments_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists) { auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); int ntensors = tensor_lists[0].size(); @@ -35,10 +35,10 @@ at::Tensor multi_tensor_raw_moments_cuda( auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); auto ret_cu = makeTransformerEngineTensor(ret); - nvte_multi_tensor_raw_moments_cuda( - chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, - output_per_tensor_cu.data(), ret_cu.data(), max_chunks_per_tensor, - at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_raw_moments_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), + num_lists, num_tensors, output_per_tensor_cu.data(), + ret_cu.data(), max_chunks_per_tensor, + at::cuda::getCurrentCUDAStream()); return ret; }