diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 123680e5275..673b5b4fd4b 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -89,6 +89,11 @@ case "${IMAGE_NAME}" in OS_VERSION=24.04 GCC_VERSION=14 ;; + executorch-ubuntu-26.04-gcc15) + LINTRUNNER="" + OS_VERSION=26.04 + GCC_VERSION=15 + ;; *) echo "Invalid image name ${IMAGE_NAME}" exit 1 diff --git a/.ci/docker/common/install_docs_reqs.sh b/.ci/docker/common/install_docs_reqs.sh index 3b6d10c5c2b..ea54d90523e 100755 --- a/.ci/docker/common/install_docs_reqs.sh +++ b/.ci/docker/common/install_docs_reqs.sh @@ -15,8 +15,8 @@ if [ -n "$BUILD_DOCS" ]; then curl --retry 3 --retry-all-errors -sL https://deb.nodesource.com/setup_16.x | sudo -E bash - sudo apt-get install -y nodejs - curl --retry 3 --retry-all-errors -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - - echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list + curl --retry 3 --retry-all-errors -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo gpg --dearmor -o /usr/share/keyrings/yarn-archive-keyring.gpg + echo "deb [signed-by=/usr/share/keyrings/yarn-archive-keyring.gpg] https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list apt-get update apt-get install -y --no-install-recommends yarn diff --git a/.ci/scripts/test_riscv_qemu.sh b/.ci/scripts/test_riscv_qemu.sh index 2842542aa3a..472484ecd60 100755 --- a/.ci/scripts/test_riscv_qemu.sh +++ b/.ci/scripts/test_riscv_qemu.sh @@ -4,10 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# CI wrapper: install RISC-V cross-compile + qemu-user tooling, then run the -# RISC-V smoke test (export, cross-compile, qemu-user execution) via -# examples/riscv/run.sh. The bundled-IO comparison and Test_result: PASS -# check are done by run.sh. 
+# CI wrapper: install riscv32/64 cross-compile + qemu tooling, then drive +# examples/riscv/run.sh which does the export, cross-compile, qemu run, and +# bundled-IO PASS check. set -eu @@ -15,8 +14,11 @@ script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../..") model="add" -xnnpack=false +backend="portable" quantize=false +os="linux" +arch="rv64" +qemu_cpu_ext="" verbose_xnnpack=false debug_xnnpack=false @@ -24,20 +26,26 @@ usage() { cat < Which model to export and run (default: add) - --xnnpack Enable the XNNPACK backend (AOT partitioner + runtime) - --quantize Produce an 8-bit quantized model - --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch - --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph - -h, --help Show this help + --model= Which model to export and run (default: ${model}) + --quantize Produce an 8-bit quantized model + --backend= AOT backend (portable|xnnpack) (default: ${backend}) + --os= Target OS (linux|baremetal) (default: ${os}) + --arch= Target arch (rv32|rv64) (default: ${arch}) + --qemu-cpu-ext= QEMU -cpu extensions (no rv32/rv64 prefix, default: none) + --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch + --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph + -h, --help Show this help EOF } for arg in "$@"; do case $arg in --model=*) model="${arg#*=}" ;; - --xnnpack) xnnpack=true ;; --quantize) quantize=true ;; + --backend=*) backend="${arg#*=}" ;; + --os=*) os="${arg#*=}" ;; + --arch=*) arch="${arg#*=}" ;; + --qemu-cpu-ext=*) qemu_cpu_ext="${arg#*=}" ;; --debug-xnnpack) debug_xnnpack=true ;; --verbose-xnnpack) verbose_xnnpack=true ;; -h|--help) usage; exit 0 ;; @@ -46,8 +54,8 @@ for arg in "$@"; do done run_extra_args=() -if ${xnnpack}; then - run_extra_args+=(--xnnpack) +if [ -n "${qemu_cpu_ext}" ]; then + run_extra_args+=(--qemu-cpu-ext="${qemu_cpu_ext}") fi if ${quantize}; 
then run_extra_args+=(--quantize) @@ -60,4 +68,6 @@ if ${verbose_xnnpack}; then fi bash "${et_root_dir}/examples/riscv/setup.sh" -bash "${et_root_dir}/examples/riscv/run.sh" --model="${model}" "${run_extra_args[@]}" +bash "${et_root_dir}/examples/riscv/run.sh" \ + --model="${model}" --backend="${backend}" --os="${os}" --arch="${arch}" \ + "${run_extra_args[@]}" diff --git a/.github/workflows/_test_riscv.yml b/.github/workflows/_test_riscv.yml index 223a146e3d8..1298954242d 100644 --- a/.github/workflows/_test_riscv.yml +++ b/.github/workflows/_test_riscv.yml @@ -13,35 +13,44 @@ on: type: number default: 30 model: - description: 'Which model to run. Possible values are: add, mv2 (mobilenetv2)' + description: 'Which model to run (add, mv2, mobilebert, llama2, resnet18, yolo26)' required: false type: string default: 'add' - xnnpack: - description: 'Whether to enable XNNPACK' - required: false - type: boolean - default: false quantize: description: 'Produce an 8-bit quantized model' required: false type: boolean default: false - qemu-cpu: - description: 'Configuration(s) for the CPU to emulate with QEMU, expecting a JSON array' - required: true + backend: + description: 'AOT backend to lower to (portable|xnnpack)' + required: false type: string - docker-image: - description: 'The docker image to use for this job' + default: 'portable' + os: + description: 'Target OS for the runner (linux|baremetal)' required: false type: string + default: 'linux' + arch: + description: 'Target architecture (rv32|rv64)' + required: false + type: string + default: 'rv64' + qemu-cpu-ext: + description: >- + JSON array of QEMU -cpu *extension* strings (no rv32/rv64 prefix). + The script splices each entry with `arch` to form the final -cpu + value. Use [""] for plain base-ISA runs. 
+ required: true + type: string jobs: run: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.2xlarge - docker-image: ci-image:executorch-ubuntu-24.04-gcc14 + docker-image: ci-image:executorch-ubuntu-26.04-gcc15 submodules: 'recursive' ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} timeout: ${{ inputs.timeout }} @@ -55,20 +64,22 @@ jobs: # Allows failure in `echo | jq | while read` pipeline to bubble up and fail the workflow set -o pipefail - echo '${{ inputs.qemu-cpu }}' | jq -r '.[]' | while IFS= read -r qemu_cpu; do - export QEMU_CPU="${qemu_cpu}" - export GCC_VERSION=14 + echo '${{ inputs.qemu-cpu-ext }}' | jq -r '.[]' | while IFS= read -r qemu_cpu_ext; do bash .ci/scripts/test_riscv_qemu.sh \ --model="${{ inputs.model }}" \ - ${{ inputs.xnnpack && '--xnnpack --verbose-xnnpack' || '' }} \ + --backend="${{ inputs.backend }}" \ + --os="${{ inputs.os }}" \ + --arch="${{ inputs.arch }}" \ + --qemu-cpu-ext="${qemu_cpu_ext}" \ + ${{ inputs.backend == 'xnnpack' && '--verbose-xnnpack' || '' }} \ ${{ inputs.quantize && '--quantize' || '' }} - # We only generate riscv_test/${{ inputs.model }}_riscv.etdump.json from `--verbose-xnnpack`. - if ${{ inputs.xnnpack }}; then - # Generate markdown table from riscv_test/${{ inputs.model }}_riscv.etdump.json, sorted by sum_ms + # We only generate riscv_test/${{ inputs.model }}${{ inputs.quantize && '_q' || '' }}_${{ inputs.backend }}_${{ inputs.os }}_${{ inputs.arch }}_riscv.etdump.json from `--verbose-xnnpack`. 
+ if [[ "${{ inputs.backend }}" == "xnnpack" ]]; then + # Generate markdown table from riscv_test/${{ inputs.model }}${{ inputs.quantize && '_q' || '' }}_${{ inputs.backend }}_${{ inputs.os }}_${{ inputs.arch }}_riscv.etdump.json, sorted by sum_ms ( - etdump_json="riscv_test/${{ inputs.model }}_riscv.etdump.json" - echo "### Model=${{ inputs.model }} XNNPACK=${{ inputs.xnnpack }} Quantize=${{ inputs.quantize }} QEMU_CPU='${QEMU_CPU}'" + etdump_json="riscv_test/${{ inputs.model }}${{ inputs.quantize && '_q' || '' }}_${{ inputs.backend }}_${{ inputs.os }}_${{ inputs.arch }}_riscv.etdump.json" + echo "### Model=${{ inputs.model }} Quantize=${{ inputs.quantize }} Backend=${{ inputs.backend }} OS=${{ inputs.os }} Arch=${{ inputs.arch }}${qemu_cpu_ext:+,${qemu_cpu_ext}}" jq -r ' def r3: (. * 1000 | round) / 1000; ["Section","Op","Count","Sum (ms)","Avg (ms)","Max (ms)","Microkernels"], diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index b77e5497f79..d11b2e9e6d9 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -43,6 +43,7 @@ jobs: executorch-ubuntu-22.04-mediatek-sdk, executorch-ubuntu-22.04-clang12-android, executorch-ubuntu-24.04-gcc14, + executorch-ubuntu-26.04-gcc15, ] include: - docker-image-name: executorch-ubuntu-22.04-gcc11-aarch64 diff --git a/.github/workflows/riscv64.yml b/.github/workflows/riscv64.yml index 14b9ad62047..cb7f0174a9b 100644 --- a/.github/workflows/riscv64.yml +++ b/.github/workflows/riscv64.yml @@ -1,4 +1,4 @@ -name: Test RISC-V Backend +name: Test RISC-V on: push: @@ -10,8 +10,9 @@ on: pull_request: paths: - .github/workflows/riscv64.yml + - .github/workflows/_test_riscv.yml - .ci/scripts/test_riscv_qemu.sh - - tools/cmake/preset/riscv64_linux.cmake + - tools/cmake/preset/riscv64_*.cmake - examples/riscv/** workflow_dispatch: schedule: @@ -28,39 +29,50 @@ jobs: strategy: fail-fast: false matrix: - include: - - { model: add, xnnpack: false, quantize: false } - 
- { model: add, xnnpack: true, quantize: false } - - { model: mv2, xnnpack: false, quantize: false } - - { model: mv2, xnnpack: true, quantize: false } - - { model: mv2, xnnpack: true, quantize: true } - - { model: mobilebert, xnnpack: false, quantize: false } - - { model: mobilebert, xnnpack: true, quantize: false } - - { model: mobilebert, xnnpack: true, quantize: true } - - { model: llama2, xnnpack: false, quantize: false } - - { model: llama2, xnnpack: true, quantize: false } - - { model: llama2, xnnpack: true, quantize: true } - - { model: resnet18, xnnpack: false, quantize: false } - - { model: resnet18, xnnpack: true, quantize: false } - - { model: resnet18, xnnpack: true, quantize: true } + model: + - add + - mv2 + - mobilebert + - llama2 + - resnet18 + - yolo26 + quantize: [true, false] + backend: [portable, xnnpack] + os: [linux, baremetal] + arch: [rv64, rv32] + exclude: + # Disable quantization testing with Portable Kernels + - { backend: portable, quantize: true } + # XNNPACK needs pthreads + dynamic loading (no baremetal) and ships no rv32 microkernels. + - { backend: xnnpack, os: baremetal } + - { backend: xnnpack, arch: rv32 } + # No riscv32-linux-gnu cross is packaged on Ubuntu. + - { os: linux, arch: rv32 } + # No quantization recipe for Yolo26. + - { model: yolo26, quantize: true } permissions: id-token: write contents: read with: model: ${{ matrix.model }} - xnnpack: ${{ matrix.xnnpack }} quantize: ${{ matrix.quantize }} - # If XNNPACK, test with multiple RVV length, disabled otherwise - qemu-cpu: >- + backend: ${{ matrix.backend }} + os: ${{ matrix.os }} + arch: ${{ matrix.arch }} + # JSON array of QEMU -cpu *extension* strings (no rv32/rv64 prefix - that + # comes from `arch`). The script splices them as `,`. xnnpack + # benefits from RVV so it sweeps multiple vlen; everything else just uses + # the plain base ISA. 
+ qemu-cpu-ext: >- ${{ case( - matrix.xnnpack, '[ - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=128,elen=64,vext_spec=v1.0", - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=256,elen=64,vext_spec=v1.0", - "rv64,zba=true,zbb=true,zbs=true,v=true,vlen=512,elen=64,vext_spec=v1.0" + matrix.backend == 'xnnpack', '[ + "zba=true,zbb=true,zbs=true,v=true,vlen=128,elen=64,vext_spec=v1.0", + "zba=true,zbb=true,zbs=true,v=true,vlen=256,elen=64,vext_spec=v1.0", + "zba=true,zbb=true,zbs=true,v=true,vlen=512,elen=64,vext_spec=v1.0" ]', '[ - "rv64,zba=true,zbb=true,zbs=true,v=false" + "zba=true,zbb=true,zbs=true,v=false" ]' ) }} diff --git a/.lintrunner.toml b/.lintrunner.toml index 3ee436f61e8..02380ce1356 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -112,6 +112,8 @@ include_patterns = [ 'backends/arm/**/*.cpp', 'backends/arm/**/*.h', 'backends/arm/**/*.hpp', + 'backends/cortex_m/**/*.cpp', + 'backends/cortex_m/**/*.h', 'examples/arm/**/*.cpp', 'examples/arm/**/*.h', 'examples/arm/**/*.hpp', diff --git a/CMakePresets.json b/CMakePresets.json index 91848565067..e451084b20a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -327,6 +327,15 @@ "rhs": "Linux" } }, + { + "name": "riscv64-baremetal", + "displayName": "Build ExecuTorch for riscv64 baremetal (cross-compile)", + "inherits": ["common"], + "cacheVariables": { + "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/riscv64_baremetal.cmake", + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/examples/riscv/riscv64-unknown-elf-toolchain.cmake" + } + }, { "name": "mlx", "displayName": "Build MLX delegate", diff --git a/backends/aoti/slim/core/storage.h b/backends/aoti/slim/core/storage.h index 73c4d32d955..a3d17a89903 100644 --- a/backends/aoti/slim/core/storage.h +++ b/backends/aoti/slim/core/storage.h @@ -13,6 +13,7 @@ #ifdef CUDA_AVAILABLE #include #include +#include #endif #include @@ -107,9 +108,6 @@ struct DeviceTraits { /// @param device The target CUDA device (used to get the stream). 
/// @return Pointer to allocated device memory. static void* allocate(size_t nbytes, const c10::Device& device) { - // Get the current stream for this device (set by CUDAStreamGuard if any) - // This follows PyTorch's pattern where the allocator assumes the caller - // has already set the correct device via CUDAStreamGuard. auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(device.index()); ET_CHECK_MSG( @@ -118,31 +116,23 @@ struct DeviceTraits { static_cast(device.index())); cudaStream_t stream = stream_result.get(); - void* data = nullptr; - ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream)); - return data; + auto result = executorch::backends::cuda::CudaAllocator::allocate_async( + nbytes, device.index(), stream); + ET_CHECK_MSG( + result.ok(), + "CudaAllocator::allocate_async failed for %zu bytes on device %d", + nbytes, + static_cast(device.index())); + return result.get(); } - /// Frees CUDA device memory on the current stream. - /// @param ptr Pointer to device memory to free. static void free(void* ptr) { - // Get the current stream for the current device - // Currently all cuda slimtensors should be on the same device same stream, - // so we can just use the stream on current device. - // TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to - // support multiple devices. auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1); ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream"); - ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get())); + executorch::backends::cuda::CudaAllocator::deallocate_async( + ptr, -1, stream_result.get()); } - /// Copies memory between CPU and CUDA or CUDA and CUDA asynchronously. - /// @param dst Destination pointer. - /// @param src Source pointer. - /// @param nbytes Number of bytes to copy. - /// @param dst_device Destination device. - /// @param src_device Source device. - /// @param stream CUDA stream for async copy. 
static void memcpy_async( void* dst, const void* src, @@ -151,7 +141,6 @@ struct DeviceTraits { const c10::Device& src_device, cudaStream_t stream) { cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; - if (src_device.is_cpu()) { direction = cudaMemcpyHostToDevice; } else if (dst_device.is_cpu()) { @@ -164,15 +153,11 @@ struct DeviceTraits { static_cast(dst_device.index())); } - ET_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, direction, stream)); + auto err = executorch::backends::cuda::CudaAllocator::memcpy_async( + dst, src, nbytes, direction, stream); + ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "memcpy_async failed"); } - /// Copies memory between CPU and CUDA or CUDA and CUDA synchronously. - /// @param dst Destination pointer. - /// @param src Source pointer. - /// @param nbytes Number of bytes to copy. - /// @param dst_device Destination device. - /// @param src_device Source device. static void memcpy( void* dst, const void* src, @@ -180,7 +165,6 @@ struct DeviceTraits { const c10::Device& dst_device, const c10::Device& src_device) { cudaMemcpyKind direction = cudaMemcpyDeviceToDevice; - if (src_device.is_cpu()) { direction = cudaMemcpyHostToDevice; } else if (dst_device.is_cpu()) { diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index b9148305c91..42a7b79da6e 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(): "//executorch/runtime/platform:platform", "//executorch/backends/aoti/slim/c10/cuda:exception", "//executorch/backends/aoti/slim/cuda:guard", + "//executorch/backends/cuda/runtime:cuda_allocator", ], ) diff --git a/backends/arm/operator_support/TARGETS b/backends/arm/operator_support/TARGETS index 8f6721bd911..a2fd054d472 100644 --- a/backends/arm/operator_support/TARGETS +++ b/backends/arm/operator_support/TARGETS @@ -6,6 +6,7 @@ runtime.python_library( deps = [ "//executorch/backends/arm:constants", 
"//executorch/backends/arm/_passes:passes", + "//executorch/backends/arm/tosa:resize_utils", "//executorch/backends/arm/tosa:tosa", "//executorch/backends/transforms:remove_getitem_op", "//executorch/backends/xnnpack/_passes:xnnpack_passes", diff --git a/backends/arm/operator_support/index_select_support.py b/backends/arm/operator_support/index_select_support.py index a3188e739c7..285b2cfe79f 100644 --- a/backends/arm/operator_support/index_select_support.py +++ b/backends/arm/operator_support/index_select_support.py @@ -77,8 +77,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in (torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -90,7 +98,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/operator_support/unfold_copy_support.py b/backends/arm/operator_support/unfold_copy_support.py index bf6c1cad22e..ac9fc7d0ee3 100644 --- a/backends/arm/operator_support/unfold_copy_support.py +++ b/backends/arm/operator_support/unfold_copy_support.py @@ -84,8 +84,16 @@ def is_node_tosa_supported( f"{node.target}: dtype {values_dtype} requires INT profile.", ) return False - # fp16/fp32: either FP profile, or INT profile (via quantization) - elif values_dtype in 
(torch.float16, torch.float32): + # fp16/fp32/bf16: either FP profile, or INT profile (via quantization) + elif values_dtype in (torch.float16, torch.float32, torch.bfloat16): + if values_dtype == torch.bfloat16 and not tosa_spec.support_extension( + "bf16" + ): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires bf16 extension.", + ) + return False if not (tosa_spec.support_float() or tosa_spec.support_integer()): self.reporter.report_reject( node, @@ -97,7 +105,7 @@ def is_node_tosa_supported( self.reporter.report_reject( node, f"{node.target}: unsupported values dtype {values_dtype}; " - "expected bool/int8/int16/int32/float16/float32.", + "expected bool/int8/int16/int32/float16/bfloat16/float32.", ) return False diff --git a/backends/arm/operator_support/upsample_support.py b/backends/arm/operator_support/upsample_support.py index bd03a4d2b4f..42e88f08521 100644 --- a/backends/arm/operator_support/upsample_support.py +++ b/backends/arm/operator_support/upsample_support.py @@ -13,9 +13,53 @@ SupportedTOSAOperatorCheck, ) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.resize_utils import get_tosa_resize_validation_error from executorch.exir.dialects._ops import ops as exir_ops +def _is_upsample_node_tosa_supported( + support_check: SupportedTOSAOperatorCheck, + node: fx.Node, + tosa_spec: TosaSpecification, + *, + align_corners: bool, +) -> bool: + input_node = ensure_type(fx.Node, node.args[0]) + input_size_yx = get_first_fake_tensor(input_node).shape[2:] + output_size_yx = get_first_fake_tensor(node).shape[2:] + + try: + scale_y_n, scale_y_d, offset_y, border_y = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[0], output_size_yx[0], align_corners + ) + ) + scale_x_n, scale_x_d, offset_x, border_x = ( + RewriteUpsamplePass.get_resize_parameters_1d( + input_size_yx[1], output_size_yx[1], align_corners + ) + ) + except RuntimeError as err: + 
support_check.reporter.report_reject(node, str(err)) + return False + + # Validate the exact TOSA RESIZE parameters that RewriteUpsamplePass will + # emit so support checks and fake-op validation reject the same cases. + validation_error = get_tosa_resize_validation_error( + input_hw=input_size_yx, + output_hw=output_size_yx, + scale=[scale_y_n, scale_y_d, scale_x_n, scale_x_d], + offset=[offset_y, offset_x], + border=[border_y, border_x], + tosa_spec=tosa_spec, + ) + if validation_error is not None: + support_check.reporter.report_reject(node, validation_error) + return False + + return True + + @register_tosa_support_check class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck): """Provide the explicit TOSA support gate for nearest upsample.""" @@ -23,9 +67,11 @@ class UpsampleNearest2dSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.upsample_nearest2d.vec] def is_node_tosa_supported( - self, _node: fx.Node, _tosa_spec: TosaSpecification + self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - return True + return _is_upsample_node_tosa_supported( + self, node, tosa_spec, align_corners=False + ) @register_tosa_support_check @@ -37,33 +83,9 @@ class UpsampleBilinear2dSupported(SupportedTOSAOperatorCheck): targets = [exir_ops.edge.aten.upsample_bilinear2d.vec] def is_node_tosa_supported( - self, node: fx.Node, _tosa_spec: TosaSpecification + self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - input_node = ensure_type(fx.Node, node.args[0]) align_corners = ensure_type(bool, node.args[2]) - input_size_yx = get_first_fake_tensor(input_node).shape[2:] - output_size_yx = get_first_fake_tensor(node).shape[2:] - - try: - scale_y_n, scale_y_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( - input_size_yx[0], output_size_yx[0], align_corners - ) - scale_x_n, scale_x_d, _, _ = RewriteUpsamplePass.get_resize_parameters_1d( - input_size_yx[1], output_size_yx[1], 
align_corners - ) - except RuntimeError as err: - self.reporter.report_reject(node, str(err)) - return False - - # get_resize_parameters_1d() returns the TOSA RESIZE scale fraction for - # each spatial dimension. For align_corners=False, this is the effective - # output_size / input_size ratio, so the 1/16 boundary is checked - # directly in the same representation that RESIZE lowering will use. - if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: - self.reporter.report_reject( - node, - "Bilinear RESIZE downscale must be strictly greater than 1/16", - ) - return False - - return True + return _is_upsample_node_tosa_supported( + self, node, tosa_spec, align_corners=align_corners + ) diff --git a/backends/arm/runtime/VGFSetup.cpp b/backends/arm/runtime/VGFSetup.cpp index b62a6b2ec23..307d0ab266e 100644 --- a/backends/arm/runtime/VGFSetup.cpp +++ b/backends/arm/runtime/VGFSetup.cpp @@ -793,9 +793,14 @@ bool VgfRepr::process_vgf( return false; } - vector - bind_point_requirements; - bind_point_requirements.resize(bind_point_count); + vector bind_point_requirements( + bind_point_count, + { + .sType = + VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SESSION_BIND_POINT_REQUIREMENT_ARM, + .pNext = nullptr, + }); + result = vkGetDataGraphPipelineSessionBindPointRequirementsARM( vk_device, &bind_point_requirements_info, diff --git a/backends/arm/scripts/build_executorch.sh b/backends/arm/scripts/build_executorch.sh index 54d2091d1f4..5ac2674f964 100755 --- a/backends/arm/scripts/build_executorch.sh +++ b/backends/arm/scripts/build_executorch.sh @@ -7,6 +7,7 @@ # Optional parameter: # --build_type= "Release" | "Debug" | "RelWithDebInfo" | "UndefinedSanitizer" | "AddressSanitizer" # --etdump build with devtools-etdump support +# --cmake-args= Additional arguments passed to cmake configure set -eu @@ -24,6 +25,7 @@ build_type="Release" build_devtools=OFF build_with_etdump=OFF is_linux_musl=0 +extra_cmake_args=() target_cpu="" help() { @@ -33,6 +35,7 @@ help() { echo " 
--build_type= Build with Release, Debug, RelWithDebInfo, UndefinedSanitizer or AddressSanitizer, default is ${build_type}" echo " --devtools Build Devtools libs" echo " --etdump Adds Devtools etdump support to track timing, etdump area will be base64 encoded in the log" + echo " --cmake-args= Additional arguments passed to cmake configure" echo " --toolchain= Toolchain can be specified (arm-none-eabi-gcc, arm-zephyr-eabi-gcc, aarch64-linux-musl-gcc). Default: ${toolchain}" echo " --target_cpu= Override the toolchain's default TARGET_CPU (e.g. cortex-m4). Switching target_cpu reuses the same cmake-out dir, so clear ${et_build_root}/cmake-out first to avoid stale per-CPU artifacts. Default: unset (toolchain default)." exit 0 @@ -45,6 +48,10 @@ for arg in "$@"; do --build_type=*) build_type="${arg#*=}";; --devtools) build_devtools=ON ;; --etdump) build_with_etdump=ON ;; + --cmake-args=*) + # shellcheck disable=SC2206 + extra_cmake_args=(${arg#*=}) + ;; --toolchain=*) toolchain="${arg#*=}";; --target_cpu=*) target_cpu="${arg#*=}";; *) @@ -89,6 +96,7 @@ cmake_args=( -DEXECUTORCH_BUILD_DEVTOOLS=${build_devtools} -DEXECUTORCH_BUILD_ARM_ETDUMP=${build_with_etdump} -DEXECUTORCH_BAREMETAL_SKIP_INSTALL=OFF + "${extra_cmake_args[@]}" ) if [[ -n "${target_cpu}" ]]; then diff --git a/backends/arm/scripts/corstone_utils.cmake b/backends/arm/scripts/corstone_utils.cmake index 58ce4f9a919..34f04ba1225 100644 --- a/backends/arm/scripts/corstone_utils.cmake +++ b/backends/arm/scripts/corstone_utils.cmake @@ -50,11 +50,12 @@ function(fetch_ethos_u_content ETHOS_SDK_PATH ET_DIR_PATH) WORKING_DIRECTORY ${ET_DIR_PATH} ) # Always patch the core_platform repo since this is fast enough. TODO: - # examples/arm/ethos-u-setup/core_platform/0002-*.patch is a transient bridge - # that guards Armv8-M-only MPU init so the source compiles for non-Armv8-M - # Cortex-M cores. 
Once the same guard lands upstream in ethos-u/core_platform - # and ${core_platform_base_rev} is bumped past that commit, delete the 0002 - # patch. + # examples/arm/ethos-u-setup/core_platform/0002-*.patch and 0003-*.patch are + # transient bridges that guard Armv8-M-only MPU init and the Armv7-M-and-newer + # HardFault handler so the Corstone-300 target source compiles for older + # Cortex-M cores. Once the equivalent guards land upstream in + # ethos-u/core_platform and ${core_platform_base_rev} is bumped past those + # commits, delete the 0002 and 0003 patches. set(core_platform_base_rev "26.02") execute_process( COMMAND diff --git a/backends/arm/scripts/pre-push b/backends/arm/scripts/pre-push index 8e26463cd94..6aa32d07286 100755 --- a/backends/arm/scripts/pre-push +++ b/backends/arm/scripts/pre-push @@ -177,7 +177,7 @@ for COMMIT in ${COMMITS}; do for committed_file in "${license_files[@]}"; do # Skip files with certain extensions case "$committed_file" in - *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl) + *.md|*.md.in|*.json|*.yml|*.yaml|*.cmake|*.patch|.gitignore|*.bzl|BUCK|*/BUCK|TARGETS|*/TARGETS) echo -e "${INFO} Skipping license check for ${committed_file} (excluded extension)" continue ;; diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py b/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py index d9d8b89feb6..0a90de5c0c0 100644 --- a/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_resize.py @@ -33,13 +33,14 @@ def _expr(sym: torch.SymInt) -> sympy.Expr: return sympy.sympify(getattr(sym.node, "expr", sym.node._expr)) -def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): +@pytest.mark.parametrize("resize_mode", ("nearest", "bilinear")) +def test_resize_rejects_exact_one_sixteenth_downscale(resize_mode: str): with TosaLoweringContext( TosaSpecification.create_from_string("TOSA-1.0+INT") ), FakeTensorMode() as mode: with pytest.raises( 
TosaValueError, - match="Bilinear RESIZE downscale must be strictly greater than 1/16", + match="RESIZE downscale must be strictly greater than 1/16", ): exir_ops.backend.tosa.RESIZE.default( mode.from_tensor( @@ -48,7 +49,26 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): [2, 32, 2, 32], [15, 15], [-15, -15], - resize_mode="bilinear", + resize_mode=resize_mode, + ) + + +def test_resize_rejects_scale_numerator_over_tosa_limit(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="RESIZE scale numerator must be <= 2048", + ): + exir_ops.backend.tosa.RESIZE.default( + mode.from_tensor(torch.randint(0, 10, (1, 3, 4, 2), dtype=torch.int8)), + # 2049 violates scale_n <= 1 << 11, while 2049/2 still stays + # within MAX_SCALE so this test isolates the numerator rule. + [2049, 2, 4, 2], + [0, 0], + [0, 0], + resize_mode="nearest", ) diff --git a/backends/arm/test/ops/test_index_select.py b/backends/arm/test/ops/test_index_select.py index bb5f0a92c51..4de19d30daf 100644 --- a/backends/arm/test/ops/test_index_select.py +++ b/backends/arm/test/ops/test_index_select.py @@ -61,6 +61,26 @@ def forward(self, input_: torch.Tensor, dim: int, index_: torch.Tensor): torch.tensor([3, 1], dtype=torch.int32), # [W=2] ), } +test_data_fp_bf16: dict[str, input_params] = { + # Rank-2: [K, C] -> index_select dim=0 => [W, C] + "test_bf16_rank2_dim0": ( + torch.tensor( + [[0.5, 1.25, 2.5], [3.5, 4.25, 5.75], [6.5, 7.25, 8.75]], + dtype=torch.bfloat16, + ), # [K=3, C=3] + 0, + torch.tensor([2, 0], dtype=torch.int32), # [W=2] + ), + # Rank-3: [N, K, C] -> index_select dim=-1 => [N, K, W] + "test_bf16_rank3_dim_neg1": ( + torch.tensor( + [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]], + dtype=torch.bfloat16, + ), # [N=2, K=2, C=2] + -1, + torch.tensor([1, 0], dtype=torch.int32), # [W=2] + ), +} # ---- INT profile: integer inputs + bool ---- test_data_int: 
dict[str, input_params] = { @@ -104,6 +124,18 @@ def test_index_select_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_fp_bf16) +def test_index_select_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + IndexSelect(), + test_data, + aten_op=IndexSelect.aten_op, + exir_op=IndexSelect.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_index_select_tosa_INT(test_data: input_params): # INT profile runs quantized, so we test both int inputs and float inputs here. diff --git a/backends/arm/test/ops/test_unfold_copy.py b/backends/arm/test/ops/test_unfold_copy.py index 2b502a9be10..baa4b7f64bc 100644 --- a/backends/arm/test/ops/test_unfold_copy.py +++ b/backends/arm/test/ops/test_unfold_copy.py @@ -120,6 +120,18 @@ def forward(self, input_: torch.Tensor, dim_: int, size_: int, step_: int): ), } +test_data_bf16: dict[str, input_params] = { + "test_bf16_2d_dim1": ( + torch.tensor( + [[0.1, 0.2, 0.3, 0.4, 0.5], [1.1, 1.2, 1.3, 1.4, 1.5]], + dtype=torch.bfloat16, + ), # [B=2, T=5] + 1, + 3, + 2, # U=(5-3)//2+1=2 -> [B=2, U=2, C=3] + ), +} + @common.parametrize("test_data", test_data_fp) def test_unfold_copy_tosa_FP(test_data: input_params): @@ -132,6 +144,18 @@ def test_unfold_copy_tosa_FP(test_data: input_params): pipeline.run() +@common.parametrize("test_data", test_data_bf16) +def test_unfold_copy_tosa_FP_bf16(test_data: input_params): + pipeline = TosaPipelineFP[input_params]( + UnfoldCopy(), + test_data, + aten_op=UnfoldCopy.aten_op, + exir_op=UnfoldCopy.exir_op, + tosa_extensions=["bf16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_int | test_data_fp) def test_unfold_copy_tosa_INT(test_data: input_params): pipeline = TosaPipelineINT[input_params]( diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index 5781e4ed29d..d8bf4d7dbd5 
100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -198,6 +198,17 @@ def test_upsample_nearest2d_vec_tosa_FP_interpolate(test_data: torch.Tensor): pipeline.run() +def test_upsample_nearest2d_vec_tosa_does_not_delegate_exact_one_sixteenth_downscale(): + pipeline = OpNotSupportedPipeline[input_t1]( + Interpolate(size=None, scale_factor=1.0 / 16.0), + (torch.randn(1, 3, 256, 448),), + {exir_op: 1}, + n_expected_delegates=0, + ) + + pipeline.run() + + @common.parametrize("test_data", test_data_suite) def test_upsample_nearest2d_vec_tosa_INT(test_data: torch.Tensor): test_data, size, scale_factor, compare_outputs = test_data() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 7e7f576e35c..86a5f857e58 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -48,7 +48,7 @@ from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec from executorch.backends.test.harness.stages import StageType from executorch.exir.pass_base import ExportPass -from torch._export.pass_base import PassType +from executorch.exir.pass_manager import PassType from torch.export.graph_signature import InputKind, OutputKind from torchao.quantization.pt2e.quantizer import QuantizationSpec diff --git a/backends/arm/tosa/BUCK b/backends/arm/tosa/BUCK index 46ff6648c54..81d1f62437f 100644 --- a/backends/arm/tosa/BUCK +++ b/backends/arm/tosa/BUCK @@ -41,6 +41,17 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "resize_utils", + srcs = [ + "resize_utils.py", + ], + deps = [ + "//caffe2:torch", + ":specification", + ], +) + fbcode_target(_kind = runtime.python_library, name = "tosa", srcs = [ diff --git a/backends/arm/tosa/dialect/BUCK b/backends/arm/tosa/dialect/BUCK index 4e7f5837766..5081f5d6945 100644 --- a/backends/arm/tosa/dialect/BUCK +++ 
b/backends/arm/tosa/dialect/BUCK @@ -22,6 +22,7 @@ fbcode_target(_kind = runtime.python_library, deps = [ ":core", "//caffe2:torch", + "//executorch/backends/arm/tosa:resize_utils", "//executorch/backends/arm/tosa:tosa", ], ) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index c48ff508afc..8a2d4c5e60a 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -8,6 +8,10 @@ import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.resize_utils import ( + calculate_tosa_resize_output_hw, + get_tosa_resize_validation_error, +) from executorch.backends.arm.tosa.specification import ( get_context_spec, @@ -50,23 +54,17 @@ def _get_output_dtype( return output_dtype -def _validate_resize_parameters(scale, border, resize_mode): - def in_int16_range(values): - return all( - (x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int) - ) - - if not in_int16_range(scale): - raise TosaValueError("scale is out of the int16 range", op="RESIZE") - if not in_int16_range(border): - raise TosaValueError("border is out of the int16 range", op="RESIZE") - if resize_mode == "bilinear": - scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale - if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: - raise TosaValueError( - "Bilinear RESIZE downscale must be strictly greater than 1/16", - op="RESIZE", - ) +def _validate_resize_parameters(input_hw, output_hw, scale, offset, border, tosa_spec): + validation_error = get_tosa_resize_validation_error( + input_hw=input_hw, + output_hw=output_hw, + scale=scale, + offset=offset, + border=border, + tosa_spec=tosa_spec, + ) + if validation_error is not None: + raise TosaValueError(validation_error, op="RESIZE") @register_fake_tosa_op( @@ -88,24 +86,26 @@ def RESIZE( f"Input tensor must be 4D, 
but got {x.dim()}D", op="RESIZE" ) _validate_resize_mode(resize_mode) - _validate_resize_parameters(scale, border, resize_mode) output_dtype = _get_output_dtype(x.dtype, tosa_spec, resize_mode) input_shape = x.shape - scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale - offset_y, offset_x = offset - border_y, border_x = border H, W = input_shape[1], input_shape[2] - # RESIZE first upscales the input by an integer value, to "upscale space". - H_upscaled = (H - 1) * scale_y_n - # offset and border are provided in this scale, therefore adjust for these while in this space. - H_shifted = H_upscaled - offset_y + border_y - # Then, complete the RESIZE by downscaling with another integer value, approximating multplication with a fraction. - OH = (H_shifted // scale_y_d) + 1 - # Mirror the same computation horizontally for the output width. - W_upscaled = (W - 1) * scale_x_n - W_shifted = W_upscaled - offset_x + border_x - OW = (W_shifted // scale_x_d) + 1 + _validate_resize_parameters((H, W), None, scale, offset, border, tosa_spec) + output_hw = calculate_tosa_resize_output_hw((H, W), scale, offset, border) + _validate_resize_parameters((H, W), output_hw, scale, offset, border, tosa_spec) + if output_hw is None: + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale + offset_y, offset_x = offset + border_y, border_x = border + # RESIZE first upscales the input by an integer value to "upscale + # space". Offset and border are encoded in that space, then RESIZE + # completes by downscaling with another integer value, approximating + # multiplication by a fraction. 
+ OH = ((H - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1 + OW = ((W - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1 + else: + OH, OW = output_hw + fake_aten_tensor = torch.empty( size=(input_shape[0], OH, OW, input_shape[3]), dtype=output_dtype ) diff --git a/backends/arm/tosa/resize_utils.py b/backends/arm/tosa/resize_utils.py new file mode 100644 index 00000000000..6c716bfa59c --- /dev/null +++ b/backends/arm/tosa/resize_utils.py @@ -0,0 +1,259 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch + +from executorch.backends.arm.tosa.specification import TosaSpecification + +_MAX_RESIZE_DIMENSION = 16384 +_MAX_RESIZE_SCALE_NUMERATOR = 1 << 11 +_MAX_SCALE = 2048 +_MAX_SCALE_LEVEL_8K = 256 +_INT16_MIN = -(2**15) +_INT16_MAX = 2**15 - 1 + + +def _as_concrete_ints(values: Sequence[int | torch.SymInt]) -> list[int] | None: + if all(isinstance(value, int) for value in values): + return [int(value) for value in values] + return None + + +def _concrete_int_values(values: Sequence[int | torch.SymInt]) -> list[int]: + return [int(value) for value in values if isinstance(value, int)] + + +def _first_outside_range( + values: Sequence[int], min_value: int, max_value: int +) -> int | None: + return next( + (value for value in values if value < min_value or value > max_value), None + ) + + +def _max_scale(tosa_spec: TosaSpecification) -> int: + return _MAX_SCALE_LEVEL_8K if getattr(tosa_spec, "level_8k", False) else _MAX_SCALE + + +def _validate_dimensions( + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, +) -> str | None: + concrete_dimensions: list[int] = [] + input_hw_ints = _as_concrete_ints(input_hw) + output_hw_ints = _as_concrete_ints(output_hw) if output_hw is not None else None + if input_hw_ints is not None: + 
concrete_dimensions.extend(input_hw_ints) + if output_hw_ints is not None: + concrete_dimensions.extend(output_hw_ints) + + invalid_dimension = next( + ( + dimension + for dimension in concrete_dimensions + if dimension >= _MAX_RESIZE_DIMENSION + ), + None, + ) + if invalid_dimension is not None: + return ( + "RESIZE dimensions must be less than " + f"{_MAX_RESIZE_DIMENSION}; got {invalid_dimension}" + ) + return None + + +def _validate_scale( + scale: Sequence[int | torch.SymInt], + tosa_spec: TosaSpecification, +) -> str | None: + invalid_scale = _first_outside_range( + _concrete_int_values(scale), _INT16_MIN, _INT16_MAX + ) + if invalid_scale is not None: + return ( + "RESIZE scale must be in int16 range " + f"[{_INT16_MIN}, {_INT16_MAX}]; got {invalid_scale}" + ) + + scale_ints = _as_concrete_ints(scale) + if scale_ints is None: + return None + + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale_ints + if min(scale_y_n, scale_y_d, scale_x_n, scale_x_d) <= 0: + return f"RESIZE scale values must be positive; got {scale_ints}" + + max_scale = _max_scale(tosa_spec) + if scale_y_n > max_scale * scale_y_d or scale_x_n > max_scale * scale_x_d: + return ( + f"RESIZE scale ratio must be <= MAX_SCALE ({max_scale}); " + f"got y={scale_y_n}/{scale_y_d}, x={scale_x_n}/{scale_x_d}" + ) + + if ( + scale_y_n > _MAX_RESIZE_SCALE_NUMERATOR + or scale_x_n > _MAX_RESIZE_SCALE_NUMERATOR + ): + return ( + "RESIZE scale numerator must be <= " + f"{_MAX_RESIZE_SCALE_NUMERATOR}; got y={scale_y_n}, x={scale_x_n}" + ) + + # The scale values are already in the doubled rational representation that + # TOSA RESIZE lowering emits, so the lower-bound downscale rule can be + # checked directly against them. 
+ if scale_y_d >= 16 * scale_y_n or scale_x_d >= 16 * scale_x_n: + return ( + "RESIZE downscale must be strictly greater than 1/16; " + f"got y={scale_y_n}/{scale_y_d}, x={scale_x_n}/{scale_x_d}" + ) + return None + + +def _validate_offset( + offset: Sequence[int | torch.SymInt], + scale_ints: list[int], +) -> str | None: + offset_ints = _as_concrete_ints(offset) + if offset_ints is None: + return None + + scale_y_n, _, scale_x_n, _ = scale_ints + offset_y, offset_x = offset_ints + if offset_y < -scale_y_n or offset_y >= 16 * scale_y_n: + return ( + f"RESIZE offset_y must be in [{-scale_y_n}, {16 * scale_y_n}); " + f"got {offset_y}" + ) + if offset_x < -scale_x_n or offset_x >= 16 * scale_x_n: + return ( + f"RESIZE offset_x must be in [{-scale_x_n}, {16 * scale_x_n}); " + f"got {offset_x}" + ) + return None + + +def _validate_border( + border: Sequence[int | torch.SymInt], + scale_ints: list[int], +) -> str | None: + invalid_border = _first_outside_range( + _concrete_int_values(border), _INT16_MIN, _INT16_MAX + ) + if invalid_border is not None: + return ( + "RESIZE border must be in int16 range " + f"[{_INT16_MIN}, {_INT16_MAX}]; got {invalid_border}" + ) + + border_ints = _as_concrete_ints(border) + if border_ints is None: + return None + + scale_y_n, _, scale_x_n, _ = scale_ints + border_y, border_x = border_ints + if border_y < -16 * scale_y_n or border_y >= scale_y_n: + return ( + f"RESIZE border_y must be in [{-16 * scale_y_n}, {scale_y_n}); " + f"got {border_y}" + ) + if border_x < -16 * scale_x_n or border_x >= scale_x_n: + return ( + f"RESIZE border_x must be in [{-16 * scale_x_n}, {scale_x_n}); " + f"got {border_x}" + ) + return None + + +def _validate_output_shape( + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], +) -> str | None: + if output_hw is None: + return None + + output_hw_ints = 
_as_concrete_ints(output_hw) + expected_output_hw = calculate_tosa_resize_output_hw( + input_hw, scale, offset, border + ) + if ( + output_hw_ints is not None + and expected_output_hw is not None + and tuple(output_hw_ints) != expected_output_hw + ): + return ( + "RESIZE output shape is inconsistent with input and parameters; " + f"expected {expected_output_hw}, got {tuple(output_hw_ints)}" + ) + return None + + +def calculate_tosa_resize_output_hw( + input_hw: Sequence[int | torch.SymInt], + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], +) -> tuple[int, int] | None: + input_hw_ints = _as_concrete_ints(input_hw) + scale_ints = _as_concrete_ints(scale) + offset_ints = _as_concrete_ints(offset) + border_ints = _as_concrete_ints(border) + if ( + input_hw_ints is None + or scale_ints is None + or offset_ints is None + or border_ints is None + ): + return None + + input_h, input_w = input_hw_ints + scale_y_n, scale_y_d, scale_x_n, scale_x_d = scale_ints + offset_y, offset_x = offset_ints + border_y, border_x = border_ints + + # RESIZE first upscales the input by an integer value to "upscale space". + # Offset and border are encoded in that space, then RESIZE completes by + # downscaling with another integer value, approximating multiplication by a + # fraction. 
+ return ( + ((input_h - 1) * scale_y_n - offset_y + border_y) // scale_y_d + 1, + ((input_w - 1) * scale_x_n - offset_x + border_x) // scale_x_d + 1, + ) + + +def get_tosa_resize_validation_error( + *, + input_hw: Sequence[int | torch.SymInt], + output_hw: Sequence[int | torch.SymInt] | None, + scale: Sequence[int | torch.SymInt], + offset: Sequence[int | torch.SymInt], + border: Sequence[int | torch.SymInt], + tosa_spec: TosaSpecification, +) -> str | None: + scale_ints = _as_concrete_ints(scale) + + validation_error = _validate_dimensions(input_hw, output_hw) + if validation_error is not None: + return validation_error + validation_error = _validate_scale(scale, tosa_spec) + if validation_error is not None: + return validation_error + if scale_ints is None: + return None + + for validation_error in ( + _validate_offset(offset, scale_ints), + _validate_border(border, scale_ints), + _validate_output_shape(input_hw, output_hw, scale, offset, border), + ): + if validation_error is not None: + return validation_error + return None diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 4b60feb2121..50112a4eb66 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -162,14 +162,31 @@ def targets(self) -> list[EdgeOpOverload]: def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + out_dtype = node.kwargs.get("out_dtype") + kwargs = {k: v for k, v in node.kwargs.items() if k != "out_dtype"} with node.graph.inserting_before(node): new_node = node.graph.call_function( ns.cadence.dequantize_per_tensor.default, args=node.args, - kwargs=node.kwargs, + kwargs=kwargs, ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) + new_node.meta = node.meta.copy() + if ( + out_dtype is not None + and out_dtype != torch.float32 + and "val" in new_node.meta + ): + new_node.meta["val"] = 
new_node.meta["val"].to(torch.float32) + if out_dtype is not None and out_dtype != torch.float32: + with node.graph.inserting_after(new_node): + cast_node = node.graph.call_function( + ns.aten.to.dtype, + args=(new_node, out_dtype), + ) + cast_node.meta = node.meta.copy() + node.replace_all_uses_with(cast_node) + else: + node.replace_all_uses_with(new_node) return True diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 170da6deb09..a73ef02c996 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1250,6 +1250,7 @@ def test_replace_conv1d_with_linear(self) -> None: inputs, "ReplaceTrivialConvWithLinear", rtol=2e-5, + atol=5e-6, ) # Assert that conv1d is trivially converted to linear @@ -1294,6 +1295,7 @@ def test_replace_conv2d_with_linear(self) -> None: inputs, "ReplaceTrivialConvWithLinear", rtol=2e-5, + atol=5e-6, ) # Assert that conv2d is trivially converted to linear diff --git a/backends/cortex_m/CMakeLists.txt b/backends/cortex_m/CMakeLists.txt index 876c65982e6..627406c1935 100644 --- a/backends/cortex_m/CMakeLists.txt +++ b/backends/cortex_m/CMakeLists.txt @@ -30,6 +30,10 @@ set(CMSIS_NN_LOCAL_PATH "" CACHE PATH "Path to existing local CMSIS-NN installation" ) +option(CORTEX_M_ENABLE_RUNTIME_CHECKS + "Enable additional Cortex-M runtime assertions and validation checks" + OFF +) # Try to find existing / local CMSIS-NN installation. This is useful for # debugging and testing with local changes. 
This is not common, as the CMSIS-NN @@ -107,6 +111,11 @@ target_link_libraries( PRIVATE executorch PRIVATE kernels_util_all_deps ) +target_compile_definitions( + cortex_m_kernels + PRIVATE + $<$:CORTEX_M_ENABLE_RUNTIME_CHECKS> +) # Include directories for cortex_m_kernels target_include_directories( diff --git a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h index 4672f05e777..656309abcee 100644 --- a/backends/cortex_m/ops/cmsis_scratch_buffer_context.h +++ b/backends/cortex_m/ops/cmsis_scratch_buffer_context.h @@ -1,3 +1,4 @@ +// cppcheck-suppress-file unusedFunction /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. diff --git a/backends/cortex_m/ops/cortex_m_ops_common.h b/backends/cortex_m/ops/cortex_m_ops_common.h index 4c0f83d6eb6..2e3f49dd861 100644 --- a/backends/cortex_m/ops/cortex_m_ops_common.h +++ b/backends/cortex_m/ops/cortex_m_ops_common.h @@ -113,8 +113,7 @@ inline void validate_quantization_params( const int64_t shift2, const int64_t output_zero_point, const int64_t output_multiplier, - const int64_t output_shift, - Tensor& output) { + const int64_t output_shift) { validate_single_quant_params( zero_point1, multiplier1, shift1, "Single quant Input1"); validate_single_quant_params( @@ -346,6 +345,7 @@ inline bool prepare_cmsis_pool2d_config( // https://github.com/ARM-software/CMSIS-NN/blob/main/Include/arm_nnsupportfunctions.h#L1625 // multiplier: Range {ARM_NN_Q31_MIN + 1, Q32_MAX} // shift : Range {-31, 30} +// cppcheck-suppress unusedFunction inline bool validate_per_channel_quant_params( const Int64ArrayRef multipliers, const Int64ArrayRef shifts, diff --git a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp index ca648f74695..136bce297b0 100644 --- a/backends/cortex_m/ops/op_dequantize_per_tensor.cpp +++ b/backends/cortex_m/ops/op_dequantize_per_tensor.cpp @@ -100,6 +100,7 @@ F dequantize_val(float 
scale, int32_t zero_point, Q qvalue) { } // namespace Tensor& dequantize_per_tensor_out( + // cppcheck-suppress constParameterReference KernelRuntimeContext& context, const Tensor& input, double scale, diff --git a/backends/cortex_m/ops/op_maximum.cpp b/backends/cortex_m/ops/op_maximum.cpp index fc76f5c8c48..936ef273684 100644 --- a/backends/cortex_m/ops/op_maximum.cpp +++ b/backends/cortex_m/ops/op_maximum.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 Arm Limited and/or its affiliates. + * Copyright 2025-2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -12,6 +12,7 @@ namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& maximum_out( KernelRuntimeContext& context, const Tensor& input1, diff --git a/backends/cortex_m/ops/op_minimum.cpp b/backends/cortex_m/ops/op_minimum.cpp index 5a75cb8a1dc..3324a4e39d7 100644 --- a/backends/cortex_m/ops/op_minimum.cpp +++ b/backends/cortex_m/ops/op_minimum.cpp @@ -1,7 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. - * Copyright 2025 Arm Limited and/or its affiliates. + * Copyright 2025-2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
@@ -14,6 +14,7 @@ namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& minimum_out( KernelRuntimeContext& context, const Tensor& input1, diff --git a/backends/cortex_m/ops/op_pad.cpp b/backends/cortex_m/ops/op_pad.cpp index e59f986c37d..57b5257873e 100644 --- a/backends/cortex_m/ops/op_pad.cpp +++ b/backends/cortex_m/ops/op_pad.cpp @@ -19,6 +19,7 @@ constexpr size_t kMaxSupportedDims = 4; } // namespace +// cppcheck-suppress unusedFunction Tensor& pad_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantize_per_tensor.cpp b/backends/cortex_m/ops/op_quantize_per_tensor.cpp index 7809db379c7..d8bb34c6eb4 100644 --- a/backends/cortex_m/ops/op_quantize_per_tensor.cpp +++ b/backends/cortex_m/ops/op_quantize_per_tensor.cpp @@ -97,6 +97,7 @@ Q quantize_val( } // namespace Tensor& quantize_per_tensor_out( + // cppcheck-suppress constParameterReference KernelRuntimeContext& context, const Tensor& input, double scale, diff --git a/backends/cortex_m/ops/op_quantized_add.cpp b/backends/cortex_m/ops/op_quantized_add.cpp index f607977aa48..f93bb6c1be9 100644 --- a/backends/cortex_m/ops/op_quantized_add.cpp +++ b/backends/cortex_m/ops/op_quantized_add.cpp @@ -13,6 +13,7 @@ namespace cortex_m { namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_add_out( KernelRuntimeContext& context, const Tensor& input1_int8, @@ -49,8 +50,7 @@ Tensor& quantized_add_out( input2_shift, output_zero_point, output_multiplier, - output_shift, - out); + output_shift); ET_LOG( Debug, diff --git a/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp b/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp index fc04edcc82b..0d22971f89b 100644 --- a/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp +++ b/backends/cortex_m/ops/op_quantized_avg_pool2d.cpp @@ -12,6 +12,7 @@ namespace native { using 
KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_avg_pool2d_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp index e6bc5a949ce..fd0859e8b00 100644 --- a/backends/cortex_m/ops/op_quantized_batch_matmul.cpp +++ b/backends/cortex_m/ops/op_quantized_batch_matmul.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. @@ -62,6 +63,7 @@ bool validate_batch_matmul_arguments( } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_batch_matmul_out( KernelRuntimeContext& context, const Tensor& lhs, @@ -71,6 +73,7 @@ Tensor& quantized_batch_matmul_out( int64_t output_offset, int64_t output_multiplier, int64_t output_shift, + const Tensor& scratch, Tensor& out) { if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) { return out; @@ -100,25 +103,26 @@ Tensor& quantized_batch_matmul_out( quant_params.multiplier = static_cast(output_multiplier); quant_params.shift = static_cast(output_shift); - const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims); - cmsis_nn_context ctx; ctx.buf = nullptr; - ctx.size = 0; - - if (buf_size > 0) { - auto buffer_or_error = context.allocate_temp(buf_size); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)", - buf_size); - context.fail(buffer_or_error.error()); - return out; - } - ctx.buf = buffer_or_error.get(); - ctx.size = buf_size; + ctx.size = scratch.nbytes(); + if (ctx.size > 0) { + ctx.buf = scratch.mutable_data_ptr(); + } + +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = 
+ arm_fully_connected_s8_get_buffer_size(&out_dims); + if (ctx.size != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_batch_matmul: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(ctx.size), + runtime_buffer_bytes); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_batch_matmul_s8( &ctx, diff --git a/backends/cortex_m/ops/op_quantized_conv2d.cpp b/backends/cortex_m/ops/op_quantized_conv2d.cpp index 7d4433690f6..3d4f19e10d0 100644 --- a/backends/cortex_m/ops/op_quantized_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_conv2d.cpp @@ -98,6 +98,7 @@ bool validate_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -112,6 +113,7 @@ Tensor& quantized_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_conv2d_arguments( context, @@ -182,31 +184,30 @@ Tensor& quantized_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = arm_convolve_wrapper_s8_get_buffer_size( &conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { + if (runtime_buffer_bytes < 0) { ET_LOG( Error, "quantized_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - if (buffer_bytes > 0) { - auto buffer_or_error = - context.allocate_temp(buffer_bytes, kCortexMMveAlignment); - if (!buffer_or_error.ok()) { - ET_LOG( - Error, - "quantized_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - 
static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); - return out; - } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { + ET_LOG( + Error, + "quantized_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); + return out; } +#endif const arm_cmsis_nn_status status = arm_convolve_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp index 8dec61e0af1..a8e1fc21ed7 100644 --- a/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_depthwise_conv2d.cpp @@ -135,6 +135,7 @@ bool validate_depthwise_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_depthwise_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -150,6 +151,7 @@ Tensor& quantized_depthwise_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, Tensor& out) { if (!validate_depthwise_conv2d_arguments( context, @@ -220,32 +222,32 @@ Tensor& quantized_depthwise_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } - const int32_t buffer_bytes = arm_depthwise_conv_wrapper_s8_get_buffer_size( - &dw_conv_params, &input_dims, &filter_dims, &output_dims); - if (buffer_bytes < 0) { +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS + const int32_t runtime_buffer_bytes = + arm_depthwise_conv_wrapper_s8_get_buffer_size( + &dw_conv_params, &input_dims, &filter_dims, &output_dims); + if (runtime_buffer_bytes < 0) { ET_LOG( Error, 
"quantized_depthwise_conv2d_out: CMSIS-NN buffer size calculation failed"); context.fail(Error::Internal); return out; } - - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(runtime_buffer_bytes)) { ET_LOG( Error, - "quantized_depthwise_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - static_cast(buffer_bytes), - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_depthwise_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + static_cast(runtime_buffer_bytes)); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; - +#endif const arm_cmsis_nn_status status = arm_depthwise_conv_wrapper_s8( &cmsis_context, &dw_conv_params, diff --git a/backends/cortex_m/ops/op_quantized_linear.cpp b/backends/cortex_m/ops/op_quantized_linear.cpp index 5d018cbc0c4..7448058de8e 100644 --- a/backends/cortex_m/ops/op_quantized_linear.cpp +++ b/backends/cortex_m/ops/op_quantized_linear.cpp @@ -13,6 +13,7 @@ namespace cortex_m { namespace native { using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_linear_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_max_pool2d.cpp b/backends/cortex_m/ops/op_quantized_max_pool2d.cpp index 181a29c1b65..ca1b00ff340 100644 --- a/backends/cortex_m/ops/op_quantized_max_pool2d.cpp +++ b/backends/cortex_m/ops/op_quantized_max_pool2d.cpp @@ -10,6 +10,7 @@ namespace cortex_m { namespace native { +// cppcheck-suppress unusedFunction Tensor& quantized_max_pool2d_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_quantized_mul.cpp b/backends/cortex_m/ops/op_quantized_mul.cpp index 
524e74a6b9f..93ce2303d64 100644 --- a/backends/cortex_m/ops/op_quantized_mul.cpp +++ b/backends/cortex_m/ops/op_quantized_mul.cpp @@ -18,6 +18,7 @@ constexpr int32_t kInt8ActivationMax = std::numeric_limits::max(); using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +// cppcheck-suppress unusedFunction Tensor& quantized_mul_out( KernelRuntimeContext& context, const Tensor& input1_int8, @@ -50,8 +51,7 @@ Tensor& quantized_mul_out( kZeroShift, output_zero_point, output_multiplier, - output_shift, - out); + output_shift); // Extract quantization parameters int8_t* input1_ptr = input1_int8.data_ptr(); diff --git a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp index e3f6135c7b9..e7ecbc7c7b4 100644 --- a/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp +++ b/backends/cortex_m/ops/op_quantized_transpose_conv2d.cpp @@ -1,6 +1,7 @@ /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. + * Copyright 2026 Arm Limited and/or its affiliates. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. 
@@ -82,6 +83,7 @@ bool validate_transpose_conv2d_arguments( } } // namespace +// cppcheck-suppress unusedFunction Tensor& quantized_transpose_conv2d_out( KernelRuntimeContext& context, const Tensor& input, @@ -97,6 +99,8 @@ Tensor& quantized_transpose_conv2d_out( const Tensor& requantize_shifts, const int64_t activation_min, const int64_t activation_max, + const Tensor& scratch, + const Tensor& output_scratch, Tensor& out) { if (!validate_transpose_conv2d_arguments( context, @@ -179,44 +183,43 @@ Tensor& quantized_transpose_conv2d_out( cmsis_nn_context cmsis_context; cmsis_context.buf = nullptr; - cmsis_context.size = 0; + cmsis_context.size = scratch.nbytes(); + if (cmsis_context.size > 0) { + cmsis_context.buf = scratch.mutable_data_ptr(); + } cmsis_nn_context output_context; output_context.buf = nullptr; - output_context.size = 0; - + output_context.size = output_scratch.nbytes(); + if (output_context.size > 0) { + output_context.buf = output_scratch.mutable_data_ptr(); + } +#ifdef CORTEX_M_ENABLE_RUNTIME_CHECKS const int32_t buffer_bytes = arm_transpose_conv_s8_get_buffer_size( &transpose_conv_params, &input_dims, &filter_dims, &output_dims); - auto buffer_or_error = context.allocate_temp( - static_cast(buffer_bytes), kCortexMMveAlignment); - if (!buffer_or_error.ok()) { + if (scratch.nbytes() != static_cast(buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate scratch buffer (%d bytes, error %d)", - buffer_bytes, - static_cast(buffer_or_error.error())); - context.fail(buffer_or_error.error()); + "quantized_transpose_conv2d_out: scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(scratch.nbytes()), + buffer_bytes); + context.fail(Error::Internal); return out; } - cmsis_context.buf = buffer_or_error.get(); - cmsis_context.size = buffer_bytes; const int32_t output_buffer_bytes = arm_transpose_conv_s8_get_reverse_conv_buffer_size( &transpose_conv_params, &input_dims, &filter_dims); - auto output_buffer_or_error 
= context.allocate_temp( - static_cast(output_buffer_bytes), kCortexMMveAlignment); - if (!output_buffer_or_error.ok()) { + if (output_scratch.nbytes() != static_cast(output_buffer_bytes)) { ET_LOG( Error, - "quantized_transpose_conv2d_out: failed to allocate output scratch buffer (%d bytes, error %d)", - output_buffer_bytes, - static_cast(output_buffer_or_error.error())); - context.fail(output_buffer_or_error.error()); + "quantized_transpose_conv2d_out: output scratch buffer size incorrect - actual: (%d) needed: (%d)", + static_cast(output_scratch.nbytes()), + output_buffer_bytes); + context.fail(Error::Internal); return out; } - output_context.buf = output_buffer_or_error.get(); - output_context.size = output_buffer_bytes; +#endif const arm_cmsis_nn_status status = arm_transpose_conv_wrapper_s8( &cmsis_context, diff --git a/backends/cortex_m/ops/op_softmax.cpp b/backends/cortex_m/ops/op_softmax.cpp index c07a538db84..97d78d07a05 100644 --- a/backends/cortex_m/ops/op_softmax.cpp +++ b/backends/cortex_m/ops/op_softmax.cpp @@ -36,6 +36,7 @@ inline int64_t normalize_dim(const Tensor& tensor, int64_t dim) { } // namespace +// cppcheck-suppress unusedFunction Tensor& softmax_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/op_transpose.cpp b/backends/cortex_m/ops/op_transpose.cpp index 7fcbc034283..9ef144296b7 100644 --- a/backends/cortex_m/ops/op_transpose.cpp +++ b/backends/cortex_m/ops/op_transpose.cpp @@ -22,6 +22,7 @@ constexpr size_t kMaxSupportedDims = 4; } // namespace +// cppcheck-suppress unusedFunction Tensor& transpose_out( KernelRuntimeContext& context, const Tensor& input, diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py index 2c35ed8730b..d4393bc7ada 100644 --- a/backends/cortex_m/ops/operators.py +++ b/backends/cortex_m/ops/operators.py @@ -271,13 +271,15 @@ def quantized_mul_impl( "quantized_batch_matmul(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int 
rhs_zero_point, " - "int output_zero_point, int output_multiplier, int output_shift) -> Tensor" + "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch) -> Tensor" ) lib.define( "quantized_batch_matmul.out(" "Tensor lhs, int lhs_zero_point, " "Tensor rhs_transposed, int rhs_zero_point, " "int output_zero_point, int output_multiplier, int output_shift, " + "Tensor scratch, " "*, Tensor(a!) out) -> Tensor(a!)" ) @@ -291,6 +293,7 @@ def quantized_batch_matmul_meta( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: batch, lhs_rows, inner = lhs.shape batch_rhs, rhs_cols, inner_rhs = rhs_transposed.shape @@ -307,6 +310,7 @@ def quantized_batch_matmul_impl( output_zero_point: int, output_multiplier: int, output_shift: int, + scratch: torch.Tensor, ) -> torch.Tensor: # Offsets are negated zero points (CMSIS-NN convention) lhs_fp = lhs.to(torch.float32) + float(lhs_zero_point) @@ -638,7 +642,8 @@ def pad_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -657,6 +662,7 @@ def pad_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) 
out" ") -> Tensor(a!)" ) @@ -733,6 +739,7 @@ def quantized_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -762,6 +769,7 @@ def quantized_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError("quantized_conv2d expects 4D input and weight tensors") @@ -830,7 +838,8 @@ def quantized_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch" ") -> Tensor" ) @@ -850,6 +859,7 @@ def quantized_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " "*, Tensor(a!) out" ") -> Tensor(a!)" ) @@ -870,6 +880,7 @@ def quantized_depthwise_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -900,6 +911,7 @@ def quantized_depthwise_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, ) -> torch.Tensor: if input.dim() != 4 or weight.dim() != 4: raise RuntimeError( @@ -973,7 +985,9 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_multipliers, " "Tensor requantize_shifts, " "int activation_min, " - "int activation_max" + "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch" ") -> Tensor" ) @@ -992,6 +1006,8 @@ def quantized_depthwise_conv2d_impl( "Tensor requantize_shifts, " "int activation_min, " "int activation_max, " + "Tensor scratch, " + "Tensor output_scratch, " "*, Tensor(a!) 
out) -> Tensor(a!)" ) @@ -1057,6 +1073,8 @@ def quantized_transpose_conv2d_meta( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: stride_vals = list(stride) padding_vals = list(padding) @@ -1095,6 +1113,8 @@ def quantized_transpose_conv2d_impl( requantize_shifts: torch.Tensor, activation_min: int, activation_max: int, + scratch: torch.Tensor, + output_scratch: torch.Tensor, ) -> torch.Tensor: """ Reference implementation of quantized transposed convolution. diff --git a/backends/cortex_m/ops/operators.yaml b/backends/cortex_m/ops/operators.yaml index e0ebbfab868..8db109dea43 100644 --- a/backends/cortex_m/ops/operators.yaml +++ b/backends/cortex_m/ops/operators.yaml @@ -65,19 +65,20 @@ - arg_meta: null kernel_name: cortex_m::pad_out -- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_conv2d_out -- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) + +- func: cortex_m::quantized_depthwise_conv2d.out(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, int[] dilation, int depth_multiplier, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null kernel_name: cortex_m::quantized_depthwise_conv2d_out -- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_transpose_conv2d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int input_offset, int output_offset, Tensor requantize_multipliers, Tensor requantize_shifts, int activation_min, int activation_max, Tensor scratch, Tensor output_scratch, *, Tensor(a!) out) -> Tensor(a!) variants: function kernels: - arg_meta: null @@ -94,7 +95,7 @@ - arg_meta: null kernel_name: cortex_m::quantized_max_pool2d_out -- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!) +- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, Tensor scratch, *, Tensor(a!) out) -> Tensor(a!) 
variants: function kernels: - arg_meta: null diff --git a/backends/cortex_m/passes/BUCK b/backends/cortex_m/passes/BUCK index 4e49c8cd319..f1b7b9a201d 100644 --- a/backends/cortex_m/passes/BUCK +++ b/backends/cortex_m/passes/BUCK @@ -36,6 +36,7 @@ fbcode_target(_kind = runtime.python_library, "decompose_hardswish_pass.py", "decompose_mean_pass.py", "quantized_clamp_activation_pass.py", + "scratch_buffer_sizes.py", ], deps=[ "//caffe2:torch", diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 92179ec6654..c379461949f 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -33,6 +33,7 @@ def _ensure_cortex_m_dependencies() -> None: _ensure_cortex_m_dependencies() +from .cortex_m_pass import CortexMPass # noqa # usort: skip from .activation_fusion_pass import ActivationFusionPass # noqa from .clamp_hardswish_pass import ClampHardswishPass # noqa from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 418f6cd63ff..5704645caf8 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -6,25 +6,32 @@ # LICENSE file in the root directory of this source tree. 
import executorch.backends.cortex_m.ops.operators # noqa +import executorch.exir as exir import torch import torch.fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor + +from executorch.backends.cortex_m.passes.cortex_m_pass import CortexMPass from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot +from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( + required_cmsis_nn_buffer_sizes, +) from executorch.backends.transforms.utils import ( create_constant_placeholder, get_param_tensor, is_param_node, ) - -from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes import make_alloc_node +from torch._subclasses.fake_tensor import FakeTensorMode + from torch.export.graph_signature import InputKind from torch.fx.passes.infra.pass_manager import PassResult -class ConvertToCortexMPass(XNNPACKPass): +class ConvertToCortexMPass(CortexMPass): """ Cortex-M backend pass for replacing supported quantized kernels with Cortex-M accelerated kernels. @@ -33,6 +40,15 @@ class ConvertToCortexMPass(XNNPACKPass): by call_operator. 
""" + def _create_uninitialized_alloc_node(self): + """Create an uninitialized alloc node to be initialized at a later point.""" + with FakeTensorMode() as mode: + return make_alloc_node( + self.exported_program.graph_module, + mode.from_tensor(torch.empty(0)), + None, + ) + def _compute_kernel_sum(self, weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) @@ -238,6 +254,9 @@ def _get_convolution_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + if use_depthwise_conv: # Compute depth_multiplier for depthwise convolution # For depthwise: output_channels = input_channels * depth_multiplier @@ -263,6 +282,7 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default, new_args else: @@ -280,9 +300,36 @@ def _get_convolution_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, ) return exir_ops.edge.cortex_m.quantized_conv2d.default, new_args + def _initialize_alloc_node_size(self, node: torch.fx.Node) -> None: + """For nodes with a registered buffer size function for node.target, set the buffer sizes + of the last n args, which should be exir.memory.alloc nodes. For nodes without a + registered function, do nothing. + """ + + scratch_buffer_sizes = required_cmsis_nn_buffer_sizes( + node, self.target_config.backend + ) + if scratch_buffer_sizes is None: + return + + # Assume that scratch_buffer_sizes are given from left to right in the call signature of node.target. 
+ for i, scratch_buffer_size in enumerate(reversed(scratch_buffer_sizes)): + scratch_arg = node.args[-(i + 1)] + if ( + not isinstance(scratch_arg, torch.fx.Node) + or scratch_arg.target != exir.memory.alloc + ): + raise RuntimeError( + f"Expected scratch alloc node as final argument(s) for {node.target}, got {scratch_arg}." + ) + + # buffer size is given in bytes, always use uint8 as dtype. + scratch_arg.args = (((scratch_buffer_size,), torch.uint8),) + def _get_transpose_conv2d_replacement(self, node): """ Transform aten.convolution with transposed=True to cortex_m.quantized_transpose_conv2d @@ -363,6 +410,10 @@ def _get_transpose_conv2d_replacement(self, node): torch.tensor(quantized_shifts, dtype=torch.int32), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + output_scratch = self._create_uninitialized_alloc_node() + new_args = ( x, weight_nhwc, @@ -377,6 +428,8 @@ def _get_transpose_conv2d_replacement(self, node): quantized_shift_tensor, output_qmin, output_qmax, + scratch, + output_scratch, ) return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args @@ -415,6 +468,9 @@ def _get_bmm_replacement(self, node): args=(rhs_node, [0, 2, 1]), ) + with node.graph.inserting_before(node): + scratch = self._create_uninitialized_alloc_node() + args = ( lhs_node, -lhs_zp, @@ -423,6 +479,7 @@ def _get_bmm_replacement(self, node): output_zp, output_mult, output_shift, + scratch, ) return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args @@ -459,6 +516,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: args=args, kwargs={}, ) + self._initialize_alloc_node_size(cortex_m_op) node.replace_all_uses_with(cortex_m_op) graph_module.graph.erase_node(node) diff --git a/backends/cortex_m/passes/scratch_buffer_sizes.py b/backends/cortex_m/passes/scratch_buffer_sizes.py new file mode 100644 index 00000000000..36f3f8bbc17 --- /dev/null +++ b/backends/cortex_m/passes/scratch_buffer_sizes.py @@ 
-0,0 +1,266 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Callable +from typing import Any, cast + +import cmsis_nn # type: ignore[import-not-found, import-untyped] +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +import torch.fx + +from executorch.exir.dialects._ops import ops as exir_ops + +BufferSizeFunction = Callable[[cmsis_nn.Backend, torch.fx.Node], list[int]] + + +def _tensor_from_node(node: torch.fx.Node) -> torch.Tensor: + if "val" in node.meta: + return node.meta["val"] + elif node.op == "call_function": + args = ( + _tensor_from_node(arg) if isinstance(arg, torch.fx.Node) else arg + for arg in node.args + ) + return node.target(*args, **node.kwargs) # type: ignore[operator] + else: + raise RuntimeError("Encountered non-call_function without 'val' meta.") + + +def _shape_from_node(node: torch.fx.Node) -> torch.Size: + return _tensor_from_node(node).shape + + +def _get_common_conv_buffer_size_inputs( + conv_node: torch.fx.Node, + *, + stride_arg_idx: int = 3, + padding_arg_idx: int = 4, + dilation_arg_idx: int = 5, +) -> tuple[ + list[int], + list[int], + list[int], + list[int], + list[int], + list[int], +]: + x = cast(torch.fx.Node, conv_node.args[0]) + weight = cast(torch.fx.Node, conv_node.args[1]) + stride = cast(list[int], conv_node.args[stride_arg_idx]) + padding = cast(list[int], conv_node.args[padding_arg_idx]) + dilation = cast(list[int], conv_node.args[dilation_arg_idx]) + + # Input is NCHW (PyTorch); CMSIS-NN wants NHWC dims. + n, c_in, height, width = _shape_from_node(x) + + weight_shape = _shape_from_node(weight) + + # Output is NCHW; convert to NHWC dims. 
+ out_n, out_c, out_h, out_w = _shape_from_node(conv_node) + + input_nhwc = [n, height, width, c_in] + output_nhwc = [out_n, out_h, out_w, out_c] + stride_hw = [int(stride[0]), int(stride[1])] + padding_hw = [int(padding[0]), int(padding[1])] + dilation_hw = [int(dilation[0]), int(dilation[1])] + + return ( + input_nhwc, + list(weight_shape), + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) + + +def cmsis_nn_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + input_offset = cast(int, conv_node.args[6]) + output_offset = cast(int, conv_node.args[7]) + output_qmin = cast(int, conv_node.args[10]) + output_qmax = cast(int, conv_node.args[11]) + + # Weight is in OHWI layout after conversion. + c_out, kernel_h, kernel_w, c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, c_in] + + return [ + int( + cmsis_nn.convolve_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_depthwise_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs(conv_node=conv_node) + depth_multiplier = cast(int, conv_node.args[6]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + + # Weight is in IHWO layout after conversion. 
+ _, kernel_h, kernel_w, c_out = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, 1] + + return [ + int( + cmsis_nn.depthwise_conv_wrapper_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + ch_mult=depth_multiplier, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ) + ] + + +def cmsis_nn_batch_matmul_buffer_size( + backend: cmsis_nn.Backend, + matmul_node: torch.fx.Node, +) -> list[int]: + rhs_transposed = cast(torch.fx.Node, matmul_node.args[2]) + rhs_shape = _shape_from_node(rhs_transposed) + + _, rhs_cols, inner = rhs_shape + + return [ + int( + cmsis_nn.fully_connected_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + filter_nhwc=[inner, -1, -1, rhs_cols], # H and W values are unused. + ) + ) + ] + + +def cmsis_nn_transpose_conv_buffer_size( + backend: cmsis_nn.Backend, + conv_node: torch.fx.Node, +) -> list[int]: + ( + input_nhwc, + weight_shape, + output_nhwc, + stride_hw, + padding_hw, + dilation_hw, + ) = _get_common_conv_buffer_size_inputs( + conv_node=conv_node, + stride_arg_idx=3, + padding_arg_idx=4, + dilation_arg_idx=6, + ) + output_padding = cast(list[int], conv_node.args[5]) + input_offset = cast(int, conv_node.args[7]) + output_offset = cast(int, conv_node.args[8]) + output_qmin = cast(int, conv_node.args[11]) + output_qmax = cast(int, conv_node.args[12]) + c_out, kernel_h, kernel_w, kernel_c_in = weight_shape + filter_nhwc = [c_out, kernel_h, kernel_w, kernel_c_in] + padding_offsets_hw = [int(output_padding[0]), int(output_padding[1])] + + return [ + int( + cmsis_nn.transpose_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + output_nhwc=output_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + 
padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + int( + cmsis_nn.transpose_conv_reverse_conv_buffer_size( + backend, + cmsis_nn.DataType.A8W8, + input_nhwc=input_nhwc, + filter_nhwc=filter_nhwc, + padding_hw=padding_hw, + stride_hw=stride_hw, + dilation_hw=dilation_hw, + padding_offsets_hw=padding_offsets_hw, + input_offset=input_offset, + output_offset=output_offset, + activation_min=output_qmin, + activation_max=output_qmax, + ) + ), + ] + + +_target_to_buffer_sizes_registry: dict[Any, BufferSizeFunction] = { + exir_ops.edge.cortex_m.quantized_conv2d.default: cmsis_nn_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_depthwise_conv2d.default: cmsis_nn_depthwise_conv_buffer_size, + exir_ops.edge.cortex_m.quantized_batch_matmul.default: cmsis_nn_batch_matmul_buffer_size, + exir_ops.edge.cortex_m.quantized_transpose_conv2d.default: cmsis_nn_transpose_conv_buffer_size, +} + + +def required_cmsis_nn_buffer_sizes( + node: torch.fx.Node, backend: cmsis_nn.Backend +) -> list[int] | None: + """Returns a sequence of scratch buffer sizes required by node, in bytes. + If no function is registered to compute this for the target of the node, return None. 
+ """ + if node.target not in _target_to_buffer_sizes_registry: + return None + + buffer_size_function = _target_to_buffer_sizes_registry[node.target] + return buffer_size_function(backend, node) diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index bdca1a21e7c..a67c5a907a4 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -28,7 +28,7 @@ fi script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") et_root_dir=$(realpath "${script_dir}/../../..") build_executorch="${et_root_dir}/backends/arm/scripts/build_executorch.sh" -${build_executorch} --devtools --target_cpu="${target_cpu}" +${build_executorch} --devtools --target_cpu="${target_cpu}" --cmake-args="-DCORTEX_M_ENABLE_RUNTIME_CHECKS=ON" # Build executor runner with selected aten ops and semi hosting build_dir="${et_root_dir}/arm_test" @@ -48,4 +48,4 @@ aten::unsqueeze_copy.out,\ aten::select_copy.int_out,\ aten::amax.out" -${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0" +${build_executor_runner} --pte=semihosting --bundleio --target="${target}" --output="${build_root_test_dir}" --select_ops_list="${select_ops_list}" --extra_build_flags="-DET_ATOL=5.0 -DET_RTOL=1.0 -DET_ARM_BAREMETAL_SCRATCH_TEMP_ALLOCATOR_POOL_SIZE=0" diff --git a/backends/cortex_m/test/models/test_silero_vad.py b/backends/cortex_m/test/models/test_silero_vad.py new file mode 100644 index 00000000000..27b958627bb --- /dev/null +++ b/backends/cortex_m/test/models/test_silero_vad.py @@ -0,0 +1,94 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.examples.models.silero_vad.export_silero_vad import ( + CONTEXT_SIZE, + HIDDEN_DIM, + SileroVAD16k, + WINDOW_SIZE, +) + + +ops_before_transforms: dict[str, int] = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 2, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 6, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 2, + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_relu_default": 5, + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 4, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_sqrt_default": 1, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 2, + "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 2, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 12, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 11, +} +ops_after_transforms: dict[str, int] = { + "executorch_exir_dialects_edge__ops_aten_abs_default": 2, + 
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, + "executorch_exir_dialects_edge__ops_aten_cat_default": 1, + "executorch_exir_dialects_edge__ops_aten_convolution_default": 6, + "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_linear_default": 2, + "executorch_exir_dialects_edge__ops_aten_mean_dim": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 3, + "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar": 2, + "executorch_exir_dialects_edge__ops_aten_relu_default": 5, + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 2, + "executorch_exir_dialects_edge__ops_aten_sigmoid_default": 4, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_sqrt_default": 1, + "executorch_exir_dialects_edge__ops_aten_squeeze_copy_dims": 2, + "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 2, + "executorch_exir_dialects_edge__ops_aten_tanh_default": 2, + "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 6, + "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1, +} + + +pt_model = SileroVAD16k().eval() + +x = torch.randn( + 1, CONTEXT_SIZE + WINDOW_SIZE +) # (1, 576) — 64 context + 512 audio samples +state = torch.zeros(2, 1, HIDDEN_DIM) # (2, 1, 128) — [h, c] LSTM state + +test_cases = { + "silero_vad_16k": McuTestCase( + model=pt_model, + example_inputs=lambda: (x, state), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_silero_vad_16k(test_case): + """This model currently does largely not lower to accelerated kernels 
due to missing LSTM and conv1d support, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cortex_m/test/models/test_wav2letter.py b/backends/cortex_m/test/models/test_wav2letter.py new file mode 100644 index 00000000000..ddc5354293c --- /dev/null +++ b/backends/cortex_m/test/models/test_wav2letter.py @@ -0,0 +1,34 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.arm.test.common import parametrize +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase +from executorch.examples.models.wav2letter.model import Wav2LetterModel + + +ops_before_transforms: dict[str, int] = {} +ops_after_transforms: dict[str, int] = {} + +model = Wav2LetterModel() +pt_model = model.get_eager_model() + +test_cases = { + "wav2letter": McuTestCase( + model=pt_model, + example_inputs=lambda: model.get_example_inputs(), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_wav2letter(test_case): + """This model currently does largely not lower to accelerated kernels due to missing conv1d support, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cortex_m/test/models/test_yolo11.py b/backends/cortex_m/test/models/test_yolo11.py new file mode 100644 index 00000000000..f17c5ced331 --- /dev/null +++ b/backends/cortex_m/test/models/test_yolo11.py @@ -0,0 +1,45 @@ +# Copyright 2026 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.arm.test.common import parametrize + +from executorch.backends.cortex_m.test.tester import CortexMTester, McuTestCase + +YOLO = pytest.importorskip( + "ultralytics", + reason="ultralytics is optional; install it locally to run YOLO tests.", +).YOLO + + +ops_before_transforms: dict[str, int] = {} +ops_after_transforms: dict[str, int] = {} + + +WEIGHTS = "yolo11n.pt" +yolo = YOLO(WEIGHTS) +pt_model = yolo.model.eval() + +test_cases = { + "yolo11n": McuTestCase( + model=pt_model, + example_inputs=lambda: ( + torch.randn(1, 3, 640, 640).to(memory_format=torch.channels_last), + ), + ), +} + + +@parametrize("test_case", test_cases) +def test_dialect_yolo11(test_case): + """This model currently does not lower in the cortex-m backend, this test is to track development progress.""" + inputs = test_case.get_example_inputs() + tester = CortexMTester(test_case.model, inputs) + tester.test_dialect( + ops_before_transforms, + ops_after_transforms, + qtol=10, + ) diff --git a/backends/cuda/CMakeLists.txt b/backends/cuda/CMakeLists.txt index 217c893efe5..d56e994eab4 100644 --- a/backends/cuda/CMakeLists.txt +++ b/backends/cuda/CMakeLists.txt @@ -103,7 +103,7 @@ install( ) # CUDA-specific AOTI shim symbols (dynamically linked) -set(_aoti_cuda_shim_sources runtime/shims/memory.cpp +set(_aoti_cuda_shim_sources runtime/cuda_allocator.cpp runtime/shims/memory.cpp runtime/shims/cuda_guard.cpp ) @@ -180,8 +180,12 @@ install( # CUDA backend implementation set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp) +if(_cuda_is_msvc_toolchain) + # MSVC links aoti_cuda_backend into portable_lib without relying on C++ + # symbols exported from aoti_cuda_shims.dll. 
+ list(APPEND _aoti_cuda_backend_sources runtime/cuda_allocator.cpp) +endif() -# CUDA backend implementation add_library(aoti_cuda_backend STATIC ${_aoti_cuda_backend_sources}) target_include_directories( diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS index f13f41ab8b7..c8449a95718 100644 --- a/backends/cuda/runtime/TARGETS +++ b/backends/cuda/runtime/TARGETS @@ -74,6 +74,33 @@ runtime.cxx_library( ], ) +runtime.cxx_library( + name = "cuda_allocator", + srcs = [ + "cuda_allocator.cpp", + ], + headers = [ + "cuda_allocator.h", + ], + # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole) + link_whole = True, + supports_python_dlopen = True, + visibility = ["PUBLIC"], + exported_deps = [ + "//executorch/runtime/core:device_allocator", + ], + deps = [ + "//executorch/runtime/platform:platform", + ], + nvcc_flags = get_nvcc_arch_args() + [ + "-_NVCC_HOST_COMPILER_FLAG_", + "gcc", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], +) + runtime.cxx_library( name = "cuda_backend", srcs = [ @@ -92,6 +119,7 @@ runtime.cxx_library( deps = [ ":cuda_platform", ":runtime_shims", + ":cuda_allocator", "//executorch/backends/aoti:aoti_common_slim", "//executorch/backends/aoti/slim/core:slimtensor", "//executorch/backends/aoti/slim/factory:empty", diff --git a/backends/cuda/runtime/cuda_allocator.cpp b/backends/cuda/runtime/cuda_allocator.cpp new file mode 100644 index 00000000000..94294b08fa0 --- /dev/null +++ b/backends/cuda/runtime/cuda_allocator.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +Result +CudaAllocator::allocate(size_t nbytes, DeviceIndex index, size_t alignment) { + // index == -1 means "use the current CUDA device"; any value < -1 is invalid. + ET_CHECK_OR_RETURN_ERROR( + index >= -1, + InvalidArgument, + "CudaAllocator::allocate: invalid device index %d (must be >= -1)", + static_cast(index)); + + // Alignment must be a non-zero power of 2. + ET_CHECK_OR_RETURN_ERROR( + alignment != 0 && (alignment & (alignment - 1)) == 0, + InvalidArgument, + "CudaAllocator::allocate: alignment must be a power of 2, got %zu", + alignment); + + // cudaMalloc is documented to return memory aligned to at least 256 bytes, + // which trivially satisfies kDefaultAlignment (alignof(void*)). For any + // requested alignment <= 256 bytes, the returned pointer is already aligned. + // Stricter alignment would require over-allocation plus bookkeeping that + // deallocate() does not currently support, so reject that case. + constexpr size_t kCudaMallocAlignment = 256; + ET_CHECK_OR_RETURN_ERROR( + alignment <= kCudaMallocAlignment, + NotSupported, + "CudaAllocator::allocate: requested alignment %zu exceeds cudaMalloc's " + "guaranteed alignment of %zu bytes; stricter alignment is not supported", + alignment, + kCudaMallocAlignment); + + void* ptr = nullptr; + int prev_device = 0; + cudaError_t prev_device_err = cudaGetDevice(&prev_device); + + // If index == -1, fall back to the current device returned by cudaGetDevice + // and skip the set/restore round-trip. 
+ const bool switch_device = index >= 0 && prev_device_err == cudaSuccess && + static_cast(index) != prev_device; + if (switch_device) { + cudaSetDevice(index); + } + + cudaError_t err = cudaMalloc(&ptr, nbytes); + + if (switch_device) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMalloc failed: %s (requested %zu bytes on device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::MemoryAllocationFailed; + } + + // Sanity check: the pointer returned by cudaMalloc should already meet the + // requested alignment. If a future CUDA runtime weakens this guarantee, we + // want to fail loudly rather than silently return a misaligned pointer. + if ((reinterpret_cast(ptr) & (alignment - 1)) != 0) { + ET_LOG( + Error, + "cudaMalloc returned pointer %p not aligned to %zu bytes", + ptr, + alignment); + cudaFree(ptr); + return Error::MemoryAllocationFailed; + } + + return ptr; +} + +void CudaAllocator::deallocate(void* ptr, DeviceIndex index) { + if (ptr == nullptr) { + return; + } + + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaFree(ptr); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaFree failed: %s (ptr=%p, device %d)", + cudaGetErrorString(err), + ptr, + static_cast(index)); + } +} + +// TODO(gasoonjia): Add support for async copy +Error CudaAllocator::copy_host_to_device( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) { + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaMemcpy(dst, src, nbytes, 
cudaMemcpyHostToDevice); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpy H2D failed: %s (%zu bytes, device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::Internal; + } + return Error::Ok; +} + +// TODO(gasoonjia): Add support for async copy +Error CudaAllocator::copy_device_to_host( + void* dst, + const void* src, + size_t nbytes, + DeviceIndex index) { + int prev_device = 0; + cudaError_t prev_device_err = cudaSuccess; + + if (index >= 0) { + prev_device_err = cudaGetDevice(&prev_device); + if (prev_device_err == cudaSuccess) { + cudaSetDevice(index); + } + } + + cudaError_t err = cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToHost); + + if (index >= 0 && prev_device_err == cudaSuccess) { + cudaSetDevice(prev_device); + } + + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpy D2H failed: %s (%zu bytes, device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::Internal; + } + return Error::Ok; +} + +DeviceType CudaAllocator::device_type() const { + return DeviceType::CUDA; +} + +CudaAllocator& CudaAllocator::instance() { + static CudaAllocator allocator; + return allocator; +} + +Result CudaAllocator::allocate_async( + size_t nbytes, + DeviceIndex index, + cudaStream_t stream) { + void* ptr = nullptr; + cudaError_t err = cudaMallocAsync(&ptr, nbytes, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMallocAsync failed: %s (requested %zu bytes on device %d)", + cudaGetErrorString(err), + nbytes, + static_cast(index)); + return Error::MemoryAllocationFailed; + } + return ptr; +} + +void CudaAllocator::deallocate_async( + void* ptr, + DeviceIndex index, + cudaStream_t stream) { + if (ptr == nullptr) { + return; + } + cudaError_t err = cudaFreeAsync(ptr, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaFreeAsync failed: %s (ptr=%p, device %d)", + 
cudaGetErrorString(err), + ptr, + static_cast(index)); + } +} + +Error CudaAllocator::memcpy_async( + void* dst, + const void* src, + size_t nbytes, + cudaMemcpyKind direction, + cudaStream_t stream) { + cudaError_t err = cudaMemcpyAsync(dst, src, nbytes, direction, stream); + if (err != cudaSuccess) { + ET_LOG( + Error, + "cudaMemcpyAsync failed: %s (%zu bytes)", + cudaGetErrorString(err), + nbytes); + return Error::Internal; + } + return Error::Ok; +} + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/cuda_allocator.h b/backends/cuda/runtime/cuda_allocator.h new file mode 100644 index 00000000000..fcd8224305a --- /dev/null +++ b/backends/cuda/runtime/cuda_allocator.h @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace executorch::backends::cuda { + +/** + * CUDA implementation of DeviceAllocator. + * + * Uses cudaMalloc/cudaFree for allocation and cudaMemcpy for host-device + * transfers. This allocator is automatically registered as a singleton + * with the DeviceAllocatorRegistry when the CUDA backend library is linked. + * + * All CUDA memory operations in the CUDA backend should go through this + * allocator for consistent memory management. 
+ */ +class CudaAllocator final : public executorch::runtime::DeviceAllocator { + public: + executorch::runtime::Result allocate( + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index, + size_t alignment = kDefaultAlignment) override; + + void deallocate(void* ptr, executorch::runtime::etensor::DeviceIndex index) + override; + + executorch::runtime::Error copy_host_to_device( + void* dst, + const void* src, + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index) override; + + executorch::runtime::Error copy_device_to_host( + void* dst, + const void* src, + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index) override; + + executorch::runtime::etensor::DeviceType device_type() const override; + + /// Returns the global CudaAllocator singleton. + static CudaAllocator& instance(); + + // --- Async (stream-based) operations for SlimTensor/Storage layer --- + + /** + * Allocate device memory asynchronously on the given CUDA stream. + */ + static executorch::runtime::Result allocate_async( + size_t nbytes, + executorch::runtime::etensor::DeviceIndex index, + cudaStream_t stream); + + /** + * Deallocate device memory asynchronously on the given CUDA stream. + */ + static void deallocate_async( + void* ptr, + executorch::runtime::etensor::DeviceIndex index, + cudaStream_t stream); + + /** + * Copy memory asynchronously on the given CUDA stream. + * The transfer kind (e.g. H2D, D2H, D2D) is selected by `direction`. 
+ */ + static executorch::runtime::Error memcpy_async( + void* dst, + const void* src, + size_t nbytes, + cudaMemcpyKind direction, + cudaStream_t stream); +}; + +} // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 1497ba1e376..d2738f7a976 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -40,6 +40,7 @@ // Include our shim layer headers #include #include +#include #include #include #include @@ -1273,5 +1274,13 @@ auto cls = cuda::CudaBackend(); executorch::runtime::Backend backend{"CudaBackend", &cls}; static executorch::runtime::Error success_with_compiler = register_backend(backend); + +// Auto-register the CudaAllocator so that DeviceMemoryBuffer::create(CUDA) +// works whenever the CUDA backend library is linked. +static bool cuda_allocator_registered = [] { + executorch::runtime::register_device_allocator( + &cuda::CudaAllocator::instance()); + return true; +}(); } // namespace } // namespace executorch::backends diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py index fd28b077b8a..673af19310f 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. - return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -43,12 +68,13 @@ def _is_supported_in_IR( return True - # add.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'add_tensor' operator to TFLite 'add'.""" + """Convert 'add_tensor' operator to NeutronIR 'Add'. 
+ The ExecuTorch schema is: + add.Tensor(Tensor self, Tensor other, Scalar alpha=1) + """ self.assert_convertible(node) - t_op = self._create_tflite_op_with_io_tensors(node) - t_op.builtin_options = add_options.Add() + self.builder.append_operators([t_op]) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py index e97f4bf63c2..79dbcbcc012 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/sub_tensor_converter.py @@ -3,6 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torch + +from executorch.backends.nxp.backend.data_format import NXP_NODE_FORMAT from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -23,11 +26,33 @@ def _is_supported_on_target( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - if NodeConverter.uses_shape_broadcasting(node): - # Shape broadcasting may require the addition of `Transpose` ops during conversion. 
- return False + if custom_delegation_options.use_new_flow_neutron_c: + if not NodeConverter.at_least_one_input_shape_matches_the_output_shape( + node + ): + return False - return True + # If one input is in channel first and ranks of input tensors are not equal, we need to add Transposes + # Transpose is currently not supported for new flow + if any( + input_node.meta[NXP_NODE_FORMAT].is_channels_first() + for input_node in node.all_input_nodes + ) and NodeConverter._node_inputs_ranks_not_equal(node): + return False + + supported_types = [torch.int8, torch.uint8] + if not NodeConverter.uses_quantization_type_for_io( + node, supported_types, [0, 1], [0] + ): + return False + + return True + else: + if NodeConverter.uses_shape_broadcasting(node): + # Shape broadcasting may require the addition of `Transpose` ops during conversion. + return False + + return True @staticmethod def _is_supported_in_IR( @@ -45,9 +70,12 @@ def _is_supported_in_IR( return True - # sub.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1) def convert(self, node: Node): - """Convert 'sub_tensor' operator to NeutronIR 'Sub'.""" + """Convert 'sub_tensor' operator to NeutronIR 'Sub'. + The ExecuTorch schema is: + sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) + """ + self.assert_convertible(node) t_op = self._create_tflite_op_with_io_tensors(node) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py index 1aa58ab5d95..4a656eb9517 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_add_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import numpy as np import pytest import torch @@ -9,17 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( AddTensorConvModule, AddTensorModule, AddTensorOneInputModule, ) +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + AddTensor, + Convolution, + ExecutorchDelegateCall, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -92,20 +105,26 @@ def test_add_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): @pytest.mark.parametrize( - "input_shape", + "x_input_shape", [ pytest.param((1, 4, 8, 8), id="4D."), pytest.param((1, 4, 5, 5), id="4D, product of dims is not a multiple of 8."), ], ) -def test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): +def test_add_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): model = AddTensorConvModule() converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + n, c, h, w = x_input_shape + y_input_shape = (n, 8, h, w) + # Run conversion _ = to_quantized_edge_program( - model, input_shape, use_qat=use_qat, use_neutron_for_format_conversion=False + model, + [x_input_shape, y_input_shape], + use_qat=use_qat, + use_neutron_for_format_conversion=False, ) # Capture generated model @@ -114,7 +133,13 @@ def 
test_add_tensor_w_conv_quant_conversion(mocker, input_shape, use_qat): # Capture converted program exported_program: ExportedProgram = converter_spy.call_args.args[1] - input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + input_data_1 = (np.random.random(x_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data_2 = (np.random.random(y_input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + input_data = {0: input_data_1, 1: input_data_2} convert_run_compare( exported_program, @@ -149,7 +174,7 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert nodes[6].target.__name__ == "aten.add.Tensor" # Add Tensor is not delegated. + assert nodes[6].target == AddTensor # Add Tensor is not delegated. # Capture converted program # exported_program: ExportedProgram = converter_spy.call_args.args[1] @@ -159,3 +184,227 @@ def test_add_tensor_broadcasting_unsupported_quant_conversion( # input_data = {0: x_input_data, 1: y_input_data} # # convert_run_compare(exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data) + + +class TestAddTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + 
mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 9, 11, 4), + id="5D.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." 
+ ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = AddTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={AddTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D + 3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcasting where neither input shape matches the output shape is not supported. + model = AddTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." 
+ ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = AddTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 67)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = AddTensorConvModule() + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={AddTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = AddTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `add.Tensor` was NOT delegated. 
+ assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [AddTensor]) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index 2c73ccd8092..193b7ecf9ab 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -6,6 +6,7 @@ import numpy as np import pytest import torch + from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) @@ -29,13 +30,8 @@ ToNHWCPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import AvgPool2dConvModule, AvgPool2dModule - from executorch.backends.nxp.tests.nsys_testing import lower_run_compare - from executorch.backends.nxp.tests.ops_aliases import ( AvgPool2D, ExecutorchDelegateCall, @@ -45,6 +41,7 @@ Unsqueeze, ViewCopy, ) + from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -320,7 +317,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 9, 6, 15) model = AvgPool2dModule(False, 0) - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={AvgPool2D: 1}, expected_non_delegated_ops={} ) @@ -329,7 +325,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py 
b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py index 583dc2bfd04..9062d5efbfc 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_max_pool_2d_converter.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import numpy as np +import pytest import torch from executorch.backends.nxp.backend.edge_program_converter import ( @@ -17,9 +18,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.nsys_testing import lower_run_compare from executorch.backends.nxp.tests.ops_aliases import ( ExecutorchDelegateCall, @@ -32,7 +30,6 @@ ViewCopy, ) from executorch.backends.nxp.tests.use_qat import * # noqa F403 -import pytest class MaxPool1DModule(torch.nn.Module): @@ -286,7 +283,6 @@ def test__basic_nsys_inference(self, mocker): def test__basic_nsys_inference_qat(self, mocker): input_shape = (2, 11, 7, 16) # The old flow limited the batch size to 1. 
model = MaxPool2dModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MaxPool2DWithIndices: 1, GetItem: 1}, @@ -297,7 +293,6 @@ def test__basic_nsys_inference_qat(self, mocker): model, input_shape, graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py index 927af47bbf5..90113f484ad 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mul_tensor_converter.py @@ -21,9 +21,6 @@ ToChannelLastPreprocess, ) from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier -from executorch.backends.nxp.tests.model_output_comparator import ( - NumericalStatsOutputComparator, -) from executorch.backends.nxp.tests.models import ( MulTensorConvModule, MulTensorModule, @@ -256,7 +253,6 @@ def test__basic_nsys_inference(self, x_input_shape, mocker): def test__basic_nsys_inference_qat(self, x_input_shape, mocker): x_input_spec = ModelInputSpec(x_input_shape) model = MulTensorModule() - comparator = NumericalStatsOutputComparator() graph_verifier = DetailedGraphVerifier( mocker, expected_delegated_ops={MulTensor: 1}, expected_non_delegated_ops={} ) @@ -265,7 +261,6 @@ def test__basic_nsys_inference_qat(self, x_input_shape, mocker): model, [x_input_spec, x_input_spec], graph_verifier, - output_comparator=comparator, use_new_flow_neutron_c=True, use_qat=True, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py index 9ce3e93f39b..2734e89bc5d 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py +++ 
b/backends/nxp/tests/ir/converter/node_converter/test_sub_tensor_converter.py @@ -1,7 +1,8 @@ -# Copyright 2025 NXP +# Copyright 2025-2026 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + import numpy as np import pytest import torch @@ -9,18 +10,29 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) -from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.dataset_creator import RandomDatasetCreator +from executorch.backends.nxp.tests.executorch_pipeline import ( + ModelInputSpec, + to_quantized_edge_program, +) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.backends.nxp.tests.graph_verifier import DetailedGraphVerifier from executorch.backends.nxp.tests.models import ( SubTensorConvModule, SubTensorModule, SubTensorOneInputModule, ) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.nxp.tests.nsys_testing import lower_run_compare +from executorch.backends.nxp.tests.ops_aliases import ( + Convolution, + ExecutorchDelegateCall, + SubTensor, +) from torch.export import ExportedProgram from executorch.backends.nxp.tests.use_qat import * # noqa F403 @@ -63,7 +75,7 @@ def test_sub_tensor_quant_conversion(mocker, input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[4].target == exir_ops.edge.aten.sub.Tensor + assert nodes[4].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -96,7 +108,7 @@ def test_sub_tensor_one_input_quant_conversion(mocker, input_shape, use_qat): input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) nodes = 
list(exported_program.graph.nodes) - assert nodes[2].target == exir_ops.edge.aten.sub.Tensor + assert nodes[2].target == SubTensor convert_run_compare( exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data @@ -141,7 +153,7 @@ def test_sub_tensor_w_conv_quant_conversion(mocker, x_input_shape, use_qat): input_data = {0: input_data_1, 1: input_data_2} nodes = list(exported_program.graph.nodes) - assert nodes[15].target == exir_ops.edge.aten.sub.Tensor + assert nodes[15].target == SubTensor convert_run_compare( exported_program, @@ -176,6 +188,236 @@ def test_sub_tensor_broadcasting_unsupported_quant_conversion( nodes = list(edge_program.graph.nodes) # Broadcast is not supported, node is not converted - assert ( - nodes[6].target == exir_ops.edge.aten.sub.Tensor - ) # Sub Tensor is not delegated. + assert nodes[6].target == SubTensor # Sub Tensor is not delegated. + + +class TestSubTensorNewNeutronFlow: + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((1, 4, 7), id="3D."), + pytest.param( + (6, 82), + id="2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 68, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (2, 4, 3, 15), + id="4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + 
graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param((1,), id="1D."), + pytest.param((6, 5), id="2D."), + pytest.param((2, 4, 3, 15), id="4D."), + pytest.param( + (1, 4, 7), + id="3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + (1, 4, 9, 11, 4), + id="5D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__basic_nsys_inference_qat(self, x_input_shape, mocker): + x_input_spec = ModelInputSpec(x_input_shape) + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, x_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + use_qat=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 6)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((4,)), ModelInputSpec((4, 4))], id="2 inputs 1D + 2D." 
+ ), + pytest.param( + [ModelInputSpec((5, 3, 4)), ModelInputSpec((1, 3, 1))], + id="2 inputs 3D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + pytest.param( + [ModelInputSpec((69, 73)), ModelInputSpec((1, 73))], + id="2 inputs 2D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__broadcast(self, input_spec, mocker): + model = SubTensorModule() + graph_verifier = DetailedGraphVerifier( + mocker, expected_delegated_ops={SubTensor: 1}, expected_non_delegated_ops={} + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((4, 1)), ModelInputSpec((1, 6))], id="2 inputs 2D." + ), + pytest.param( + [ModelInputSpec((1, 3, 4)), ModelInputSpec((5, 3, 1))], + id="2 inputs 3D.", + ), + pytest.param( + [ModelInputSpec((6, 4)), ModelInputSpec((6, 6, 1))], + id="2 inputs 2D+3D.", + ), + ], + ) + def test__broadcast_unsupported(self, input_spec): + # Broadcast where at least one of the inputs is not equal to output is not supported + model = SubTensorModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. + assert not graph_contains_any_of_ops( + delegated_ep.graph, [ExecutorchDelegateCall] + ) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) + + @pytest.mark.parametrize( + "x_input_shape", + [ + pytest.param( + (1, 4, 5, 5), id="4D, product of dims is not a multiple of 8." 
+ ), + ], + ) + def test__w_conv(self, x_input_shape, mocker): + model = SubTensorConvModule() + + n, c, h, w = x_input_shape + y_input_spec = ModelInputSpec((n, 8, h, w)) + x_input_spec = ModelInputSpec(x_input_shape) + + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + [x_input_spec, y_input_spec], + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 7, 1)), ModelInputSpec((1, 8, 1, 1))], + id="2 inputs 4D + 4D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 8, 5, 1))], + id="2 inputs 4D + 4D incorrect.", + marks=pytest.mark.xfail(reason="AIR-14602: incorrect results"), + ), + ], + ) + def test__w_conv_broadcast(self, input_spec, mocker): + model = SubTensorConvModule() + graph_verifier = DetailedGraphVerifier( + mocker, + expected_delegated_ops={SubTensor: 1, Convolution: 1}, + expected_non_delegated_ops={}, + ) + dataset_creator = RandomDatasetCreator(low=-1.0, high=1.0) + + lower_run_compare( + model, + input_spec, + graph_verifier, + dataset_creator, + use_new_flow_neutron_c=True, + ) + + @pytest.mark.parametrize( + "input_spec", + [ + pytest.param( + [ModelInputSpec((1, 4, 5, 5)), ModelInputSpec((1, 5))], + id="2 inputs 4D + 2D.", + ), + pytest.param( + [ModelInputSpec((1, 4, 4, 10)), ModelInputSpec((1, 4, 1))], + id="2 inputs 4D + 3D.", + ), + ], + ) + def test__w_conv_unsupported(self, input_spec): + model = SubTensorConvModule() + + delegated_ep = to_quantized_edge_program( + model, input_spec, use_new_flow_neutron_c=True + ).exported_program() + + # Make sure the `sub.Tensor` was NOT delegated. 
+ assert graph_contains_any_of_ops(delegated_ep.graph, [ExecutorchDelegateCall]) + assert graph_contains_any_of_ops(delegated_ep.graph, [SubTensor]) diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index 045dcfaba40..1292c4cf17d 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -656,9 +656,9 @@ def __init__(self): super().__init__() self.conv = Conv2dModule(padding=1, stride=1) - def forward(self, x): + def forward(self, x, y): x = self.conv(x) - return x + x + return x + y class AddTensorOneInputModule(torch.nn.Module): diff --git a/backends/nxp/tests/ops_aliases.py b/backends/nxp/tests/ops_aliases.py index ec58072658d..7f855dd63af 100644 --- a/backends/nxp/tests/ops_aliases.py +++ b/backends/nxp/tests/ops_aliases.py @@ -13,6 +13,7 @@ Abs = exir_ops.edge.aten.abs.default AdaptiveAvgPool2D = exir_ops.edge.aten._adaptive_avg_pool2d.default +AddTensor = exir_ops.edge.aten.add.Tensor AvgPool2D = exir_ops.edge.aten.avg_pool2d.default Bmm = exir_ops.edge.aten.bmm.default ConstantPadND = exir_ops.edge.aten.constant_pad_nd.default @@ -36,6 +37,7 @@ Squeeze = exir_ops.edge.aten.squeeze.default SqueezeDim = exir_ops.edge.aten.squeeze.dim SqueezeDims = exir_ops.edge.aten.squeeze.dims +SubTensor = exir_ops.edge.aten.sub.Tensor Unsqueeze = exir_ops.edge.aten.unsqueeze.default UpsampleBilinear2D = exir_ops.edge.aten.upsample_bilinear2d.vec UpsampleNearest2D = exir_ops.edge.aten.upsample_nearest2d.vec diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index d43842e84a5..057dcc0f864 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -5,11 +5,10 @@ # LICENSE file in the root directory of this source tree. 
import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO -from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec -from torch.utils import _pytree as pytree class BuildQuantIo(ExportPass): @@ -28,22 +27,27 @@ def _make_spec(self, x): else: return None - def placeholder(self, name: str, arg, meta): - if quantized_dtype := meta.data.get(QCOM_QUANTIZED_IO, None): - arg = arg.to(dtype=quantized_dtype) - meta["spec"] = self._make_spec(arg) - return super().placeholder(name, arg, meta) - - def call_getitem(self, value, key: int, meta): - meta["spec"] = value.node.meta["spec"][key] - return super().call_getitem(value, key, meta) - - def call_delegate(self, lowered_module, args, kwargs, meta): - args_data, _ = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - meta["spec"] = pytree.tree_map( - self._make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) - return super().call_delegate(lowered_module, args, kwargs, meta) + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Forcedly update delegate node's meta['spec'] to get correct output + # tensor size in runtime + call_delegates = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == executorch_call_delegate + ] + for n in graph_module.graph.nodes: + if QCOM_QUANTIZED_IO in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) + n.meta["spec"] = self._make_spec(n.meta["val"]) + + for call_delegate in call_delegates: + spec = [] + for user in list(call_delegate.users): + spec.append(self._make_spec(user.meta["val"])) + call_delegate.meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + 
self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py index 81b4836f251..6a8374cb66a 100644 --- a/backends/qualcomm/_passes/recompose_pad_maxpool2d.py +++ b/backends/qualcomm/_passes/recompose_pad_maxpool2d.py @@ -13,12 +13,8 @@ from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.passes import dead_code_elimination_pass -from torch._subclasses.fake_tensor import FakeTensorMode - - -def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype): - fake_mode = FakeTensorMode() +def add_fake_tensor_to_node(padding_node, input_shape, padding_args, dtype, fake_mode): with fake_mode: batch, channels, height, width = input_shape pad_left, pad_right, pad_top, pad_bottom = padding_args @@ -114,6 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule): # noqa C901 input_node.meta["val"].shape, padding, input_node.meta["val"].dtype, + input_node.meta["val"].fake_mode, ) if quant_attrs: padding_node.meta["quant_attrs"] = node.meta["quant_attrs"] diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index 542fa1115a6..91a7cfdc69a 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -137,7 +137,23 @@ def copy_nn_module_stack(src, target): target.meta["nn_module_stack"] = value -def merge_decomposed_graph( +def _unify_fake_mode(node: torch.fx.Node, fake_mode) -> None: + val = node.meta.get("val") + if val is None: + return + if isinstance(val, FakeTensor) and val.fake_mode is not fake_mode: + node.meta["val"] = fake_mode.from_tensor(val) + elif isinstance(val, (list, tuple)): + unified = [] + for v in val: + if isinstance(v, FakeTensor) and v.fake_mode is not fake_mode: + unified.append(fake_mode.from_tensor(v)) + else: + unified.append(v) + node.meta["val"] = 
type(val)(unified) + + +def merge_decomposed_graph( # noqa: C901 remap: Dict[str, torch.fx.Node], target_node: torch.fx.Node, target_graph: torch.fx.GraphModule, @@ -148,6 +164,16 @@ def merge_decomposed_graph( [torch.fx.Node, torch.fx.Node, Dict[str, torch.fx.Node]], None ] = None, ) -> None: + target_fake_mode = None + target_val = target_node.meta.get("val") + if isinstance(target_val, FakeTensor): + target_fake_mode = target_val.fake_mode + elif isinstance(target_val, (list, tuple)): + for v in target_val: + if isinstance(v, FakeTensor): + target_fake_mode = v.fake_mode + break + def default_output_process(node): for user in node.users.copy(): # remap @@ -170,10 +196,13 @@ def default_output_process(node): # replace node map from string to graph node remap[decomposed_node] = remap.pop(decomposed_node.name) else: - remap[decomposed_node] = target_graph.node_copy( + copied = target_graph.node_copy( decomposed_node, arg_transform=lambda x, remap=remap: remap[x], ) + if target_fake_mode is not None: + _unify_fake_mode(copied, target_fake_mode) + remap[decomposed_node] = copied def is_float_tensor(node: torch.fx.Node) -> bool: diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 6d5b44d7a35..ee6678fa499 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -7730,8 +7730,11 @@ def test_llama_stories_110m(self): "--max_context_len", "128", ] + if self.use_fp16: + cmds.append("--use_fp16") self.add_default_cmds(cmds) - + print(" ".join(cmds)) + exit(0) golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: @@ -7750,7 +7753,10 @@ def test_llama_stories_110m(self): # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 135_000_000) # 135MB + if self.use_fp16: + 
self.assertLessEqual(pte_size, 275_000_000) # 275MB + else: + self.assertLessEqual(pte_size, 135_000_000) # 135MB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai @@ -10087,6 +10093,13 @@ def setup_environment(): choices=["wikitext_ppl", "hellaswag_acc_norm", "sqnr"], type=str, ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -10114,6 +10127,7 @@ def setup_environment(): TestQNN.backend = args.backend TestQNN.static_llm_eval_method = args.static_llm_eval_method TestQNN.direct_build_folder = args.direct_build_folder + TestQNN.use_fp16 = args.use_fp16 return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index d8802f74e68..c22ee8371e0 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -221,6 +221,7 @@ class TestQNN(unittest.TestCase): static_llm_eval_method = "" direct_build_folder: str = "" dsp_heap_profile_filename = "htp_heap_usage.txt" + use_fp16 = False @classmethod def setUpClass(cls): diff --git a/backends/vulkan/test/op_tests/utils/gen_computegraph.py b/backends/vulkan/test/op_tests/utils/gen_computegraph.py index a09b4d36b18..507719b8555 100644 --- a/backends/vulkan/test/op_tests/utils/gen_computegraph.py +++ b/backends/vulkan/test/op_tests/utils/gen_computegraph.py @@ -286,7 +286,7 @@ def create_aten_fn_call(self) -> str: def create_aten_method_call(self) -> str: # For functions with only Method variant, we fallback to the function # declared in MethodOperators.h - cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f) + cpp_sig = gen_static_dispatch_backend_call_signature(self.f) exprs = translate_args(self.f_sig, cpp_sig) func_call = 
f"at::_ops::{self.f_sig.name()}::call({exprs});" return func_call diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp index 103a8812931..5a150f92b6b 100644 --- a/backends/xnnpack/runtime/XNNExecutor.cpp +++ b/backends/xnnpack/runtime/XNNExecutor.cpp @@ -23,6 +23,28 @@ using executorch::runtime::is_contiguous_dim_order; using executorch::runtime::kTensorDimensionLimit; using executorch::runtime::Span; +namespace { +class InUseGuard { + public: + explicit InUseGuard(std::atomic& flag) : flag_(flag) {} + ~InUseGuard() { + if (!dismissed_) { + flag_.store(false, std::memory_order_release); + } + } + void dismiss() { + dismissed_ = true; + } + + InUseGuard(const InUseGuard&) = delete; + InUseGuard& operator=(const InUseGuard&) = delete; + + private: + std::atomic& flag_; + bool dismissed_ = false; +}; +} // namespace + /** * Initializes the XNNExecutor with the runtime and given number of * inputs/outputs externals_ is resized to the total number of inputs and @@ -71,6 +93,21 @@ ET_NODISCARD Error XNNExecutor::initialize( * delegate->execute() */ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { + ET_DCHECK_MSG( + !destroyed_.load(std::memory_order_acquire), + "XNNExecutor::prepare_args called after destroy"); + + bool was_in_use = in_use_.exchange(true, std::memory_order_acquire); + if (was_in_use) { + ET_LOG(Error, "XNNExecutor::prepare_args called concurrently"); + } + ET_DCHECK_MSG(!was_in_use, "XNNExecutor::prepare_args called concurrently"); + + InUseGuard in_use_guard(in_use_); + if (was_in_use) { + in_use_guard.dismiss(); + } + ET_CHECK_OR_RETURN_ERROR( runtime_ != nullptr, Internal, @@ -142,6 +179,7 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { return err; } + in_use_guard.dismiss(); return Error::Ok; } @@ -152,6 +190,8 @@ ET_NODISCARD Error XNNExecutor::prepare_args(Span args) { * After which we then execute the runtime through invoke_runtime. 
*/ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) { + InUseGuard in_use_guard(in_use_); + ET_CHECK_OR_RETURN_ERROR( runtime_ != nullptr, Internal, @@ -160,11 +200,13 @@ ET_NODISCARD Error XNNExecutor::forward(BackendExecutionContext& context) { xnn_status status = xnn_setup_runtime_v2( runtime_.get(), externals_.size(), externals_.data()); - ET_CHECK_OR_RETURN_ERROR( - status == xnn_status_success, - Internal, - "Internal Error: Setting up the runtime failed with code: %s", - xnn_status_to_string(status)); + if (status != xnn_status_success) { + ET_LOG( + Error, + "Internal Error: Setting up the runtime failed with code: %s", + xnn_status_to_string(status)); + return Error::Internal; + } auto error = profiler_.start(context.event_tracer()); if (error != Error::Ok) { diff --git a/backends/xnnpack/runtime/XNNExecutor.h b/backends/xnnpack/runtime/XNNExecutor.h index fa7c8360be4..2d709678c1c 100644 --- a/backends/xnnpack/runtime/XNNExecutor.h +++ b/backends/xnnpack/runtime/XNNExecutor.h @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -36,11 +37,20 @@ class XNNExecutor { std::vector externals_; std::vector packed_data_names_; std::shared_ptr workspace_; + std::atomic in_use_{false}; + std::atomic destroyed_{false}; public: XNNExecutor(std::shared_ptr workspace) : workspace_(workspace) {} + ~XNNExecutor() { + ET_DCHECK_MSG( + !in_use_.load(std::memory_order_acquire), + "XNNExecutor destroyed while in use"); + destroyed_.store(true, std::memory_order_release); + } + inline size_t getNumInputs() { return input_ids_.size(); } diff --git a/backends/xnnpack/runtime/XNNPACKBackend.cpp b/backends/xnnpack/runtime/XNNPACKBackend.cpp index c20fa985f46..9eaadda86f8 100644 --- a/backends/xnnpack/runtime/XNNPACKBackend.cpp +++ b/backends/xnnpack/runtime/XNNPACKBackend.cpp @@ -100,6 +100,7 @@ class XnnpackBackend final lock_weights_cache.lock(); weights_cache_->initialize_for_runtime( context.get_runtime_allocator(), named_data_map); + 
workspace->set_uses_weight_cache(); } auto [workspace_lock, workspace_ptr] = workspace->acquire(); @@ -129,6 +130,7 @@ class XnnpackBackend final Error, "XNNCompiler::compileModel failed: 0x%x", (unsigned int)err); return err; } + return executor; } @@ -138,13 +140,15 @@ class XnnpackBackend final Span args) const override { auto executor = static_cast(handle); + auto workspace = executor->get_workspace(); + std::unique_lock lock_weights_cache( weights_cache_mutex_, std::defer_lock); - if (executor->uses_weight_cache()) { + if (executor->uses_weight_cache() || workspace->uses_weight_cache()) { lock_weights_cache.lock(); } - auto [raii_lock, _] = executor->get_workspace()->acquire(); + auto [raii_lock, _] = workspace->acquire(); // Prepare Inputs/Outputs and Propagate Input Shapes Error err = executor->prepare_args(args); @@ -167,14 +171,16 @@ class XnnpackBackend final void destroy(DelegateHandle* handle) const override { if (handle != nullptr) { auto executor = static_cast(handle); + auto workspace = executor->get_workspace(); + + const std::lock_guard lock_weights_cache( + weights_cache_mutex_); #ifdef ENABLE_XNNPACK_PROFILING executor->print_avg_op_timings(); #endif if (executor->uses_weight_cache()) { - const std::lock_guard lock_weights_cache( - weights_cache_mutex_); weights_cache_->delete_packed_data(executor->get_packed_data_names()); } @@ -183,7 +189,6 @@ class XnnpackBackend final // the same backend instance. Make sure to hold onto the workspace // shared_ptr, as the pointer in the executor is freed, which includes // the mutex referenced by raii_lock. - auto workspace = executor->get_workspace(); auto [raii_lock, _] = workspace->acquire(); // XNNExecutor is not trivially destructible. 
Since this was constructed diff --git a/backends/xnnpack/runtime/XNNWorkspace.h b/backends/xnnpack/runtime/XNNWorkspace.h index b7ef442c460..e1b452a0a8b 100644 --- a/backends/xnnpack/runtime/XNNWorkspace.h +++ b/backends/xnnpack/runtime/XNNWorkspace.h @@ -59,6 +59,14 @@ class XNNWorkspace { lock_required_ = false; } + void set_uses_weight_cache() { + uses_weight_cache_.store(true, std::memory_order_release); + } + + bool uses_weight_cache() const { + return uses_weight_cache_.load(std::memory_order_acquire); + } + static runtime::Result> create() { // Because this class can't be moved, we need to construct it in-place. xnn_workspace_t workspace = nullptr; @@ -80,6 +88,7 @@ class XNNWorkspace { std::mutex mutex_; uint64_t id_; bool lock_required_ = true; + std::atomic uses_weight_cache_{false}; WorkspacePtr workspace_; }; diff --git a/backends/xnnpack/runtime/XNNWorkspaceManager.cpp b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp index d3550da5cc7..e115074a108 100644 --- a/backends/xnnpack/runtime/XNNWorkspaceManager.cpp +++ b/backends/xnnpack/runtime/XNNWorkspaceManager.cpp @@ -61,7 +61,9 @@ XNNWorkspaceManager::get_or_create_workspace( return create_result.error(); } +#ifndef XNNPACK_WORKSPACE_ALWAYS_LOCK create_result.get()->disable_locking(); +#endif return create_result.get(); } else if (mode == WorkspaceSharingMode::PerModel) { return get_or_create_model_workspace(program_id); diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl index 868e68e5b8c..b3af589df10 100644 --- a/backends/xnnpack/targets.bzl +++ b/backends/xnnpack/targets.bzl @@ -14,6 +14,8 @@ def _get_preprocessor_flags(): if native.read_config("executorch", "xnnpack_weights_cache", "0") != "0": preprocessor_flags.append("-DENABLE_XNNPACK_WEIGHTS_CACHE") + preprocessor_flags.append("-DXNNPACK_WORKSPACE_ALWAYS_LOCK") + # Enable if not disabled through config return preprocessor_flags diff --git a/backends/xnnpack/test/runtime/test_workspace_manager.cpp 
b/backends/xnnpack/test/runtime/test_workspace_manager.cpp index a7689966635..a239d19b415 100644 --- a/backends/xnnpack/test/runtime/test_workspace_manager.cpp +++ b/backends/xnnpack/test/runtime/test_workspace_manager.cpp @@ -116,7 +116,11 @@ TEST_F(XNNWorkspaceManagerTest, DisabledModeAcquireDoesNotLock) { auto [lock, ptr] = workspace->acquire(); ASSERT_NE(ptr, nullptr); +#ifdef XNNPACK_WORKSPACE_ALWAYS_LOCK + EXPECT_TRUE(lock.owns_lock()); +#else EXPECT_FALSE(lock.owns_lock()); +#endif } TEST_F(XNNWorkspaceManagerTest, PerModelMode) { diff --git a/backends/xnnpack/test/targets.bzl b/backends/xnnpack/test/targets.bzl index 812986a12e6..d690e1c9dcd 100644 --- a/backends/xnnpack/test/targets.bzl +++ b/backends/xnnpack/test/targets.bzl @@ -96,6 +96,9 @@ def define_common_targets(): runtime.cxx_test( name = "test_workspace_manager", srcs = ["runtime/test_workspace_manager.cpp"], + preprocessor_flags = [ + "-DXNNPACK_WORKSPACE_ALWAYS_LOCK", + ], deps = [ third_party_dep("XNNPACK"), "//executorch/backends/xnnpack:xnnpack_backend", diff --git a/conftest.py b/conftest.py index 19d777a74e0..be0e6e4ea3d 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ +import hashlib import sys import torch @@ -13,5 +14,8 @@ "backends/apple/**", ] -# Seed the run -torch.manual_seed(42) + +def pytest_runtest_setup(item): + # Set a stable seed for each test based on a hash of the test name. 
+ seed = int(hashlib.sha256(item.nodeid.encode()).hexdigest(), 16) % (2**32) + torch.manual_seed(seed) diff --git a/docs/source/backends/nxp/nxp-kernel-selection.md b/docs/source/backends/nxp/nxp-kernel-selection.md index 3ff61323694..307f06d1d02 100644 --- a/docs/source/backends/nxp/nxp-kernel-selection.md +++ b/docs/source/backends/nxp/nxp-kernel-selection.md @@ -1,25 +1,25 @@ # NXP eIQ Neutron Kernel Selective Kernel Registration -The NXP ExecuTorch backend supports selective Neutron kernel registration for `Neutron-C` targets, which decreases the +The NXP ExecuTorch backend supports selective Neutron kernel registration for `Neutron-C` targets, which reduces the size of the Neutron Firmware. During the backend's conversion to the Neutron representation by the Neutron Converter, microcode for the Neutron accelerator is generated. The microcode consists of kernel calls executed by the Neutron Driver. The code for kernel call functions is -distributed in Neutron Firmware. +distributed in the Neutron Firmware. -The `eiq_neutron_sdk.neutron_converter` optionally generates the `*_kernel_selection.c` file, registering -only kernels that are required for a particular model or in the case of ExecuTorch, a delegated subgraph. This -`*_kernel_selection.c`, when used during the application linking, takes precedence over the default list of registered +The `eiq_neutron_sdk.neutron_converter` optionally generates a `*_kernel_selection.c` file, registering +only kernels that are required for a particular model or, in the case of ExecuTorch, a delegated subgraph. This +`*_kernel_selection.c`, when used during application linking, takes precedence over the default list of registered kernels in the Neutron Firmware, and allows the linker to include only the necessary Neutron kernels. -This software is required for deployment on an edge device (e.g. `i.MXRT700`) and is -distributed via the MCUXpresso SDK. 
The MCUXpresso SDK enables building of a final application that is then flashed on +The Neutron Firmware is required for deployment on an edge device (e.g. `i.MX RT700`) and is +distributed via the MCUXpresso SDK. The MCUXpresso SDK enables the building of a final application that is then flashed on the edge device. For more details about this process, see [eIQ ExecuTorch Library User Guide](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/ugindex.html). -By default, for Neutron-C targets like `i.MXRT700`, all kernel implementations are present in the Neutron Firmware, which +By default, for Neutron-C targets like `i.MX RT700`, all kernel implementations are present in the Neutron Firmware, which is linked to the final application. This enables an easy build process for any model, but increases the size of the -final application with unused code. In the case of limited RAM, you can link only kernels that are used in the set of -models deployed. This way you can reduce the size of the final app by linking only selected kernels, used in one or -multiple models. +final application with unused code. In memory-constrained environments, you can link only the kernels required by the +deployed models. This way you can reduce the size of the final application by linking only selected kernels, used in one +or more models. The feature works as follows: The Neutron Converter with the appropriate flag exports a kernel selection file for each converted subgraph, the kernel selection files are then merged and ready to be included in the MCUXpresso SDK to use for @@ -30,7 +30,7 @@ a selection-only build. ## Export kernel selection file -To turn on this feature on the side of NXP ExecuTorch backend, use the parameter `--dump_kernel_selection_code` in +To enable this feature in the NXP ExecuTorch backend, use the parameter `--dump_kernel_selection_code` in `aot_neutron_compile.py`. 
An example with the CifarNet model: ```commandline @@ -43,7 +43,7 @@ This command will create a `*_kernel_selection.c` file alongside the converted P ## Kernel Registration for Multiple Models -If you want to use or experiment with multiple models in one application while having reduced kernel set, you can +If you want to use or experiment with multiple models in one application while having a reduced kernel set, you can create one kernel selection file with the script `merge_kernel_selection_code.py`: ```commandline diff --git a/examples/apple/coreml/scripts/BUCK b/examples/apple/coreml/scripts/BUCK index 164feb8d306..42a97ea893f 100644 --- a/examples/apple/coreml/scripts/BUCK +++ b/examples/apple/coreml/scripts/BUCK @@ -16,6 +16,19 @@ fbcode_target(_kind = python_binary, ], ) +fbcode_target(_kind = python_binary, + name = "coreml_compute_plan", + srcs = [ + "coreml_compute_plan.py", + ], + main_function = "executorch.examples.apple.coreml.scripts.coreml_compute_plan.main", + deps = [ + "//executorch/backends/apple/coreml:executorchcoreml", + "//executorch/exir:schema", + "//executorch/exir/_serialize:lib", + ], +) + fbcode_target(_kind = python_binary, name = "export", srcs = [ diff --git a/examples/apple/coreml/scripts/coreml_compute_plan.py b/examples/apple/coreml/scripts/coreml_compute_plan.py new file mode 100644 index 00000000000..c0ca08db831 --- /dev/null +++ b/examples/apple/coreml/scripts/coreml_compute_plan.py @@ -0,0 +1,236 @@ +# Copyright © 2026 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +"""Report which CoreML operations would dispatch to ANE / GPU / CPU. + +The CoreML runtime decides at compile/load time which compute device each +MIL operation will run on; that decision is exposed by ``MLComputePlan`` +in coremltools 9.0+. This script wraps that API so users can answer +"why isn't my model running on the ANE?" without writing Swift. 
+ +Usage:: + + # Analyze a CoreML model directly (mlpackage or compiled mlmodelc). + python coreml_compute_plan.py --model_path path/to/model.mlpackage + + # Analyze every Core ML partition embedded in an ExecuTorch .pte. + python coreml_compute_plan.py --model_path path/to/program.pte + + # Show ops that fell off the ANE, grouped by op type. + python coreml_compute_plan.py --model_path model.mlpackage --show_non_ane + + # Pick which devices the runtime is allowed to consider. + python coreml_compute_plan.py --model_path model.mlpackage \\ + --compute_units cpu_and_ne +""" + +import argparse +import os +import sys +import tempfile +from collections import Counter +from typing import Iterable, List, Tuple + +import coremltools as ct +from coremltools.models.compute_device import ( + MLCPUComputeDevice, + MLGPUComputeDevice, + MLNeuralEngineComputeDevice, +) +from coremltools.models.compute_plan import MLComputePlan + +from executorch.examples.apple.coreml.scripts.extract_coreml_models import ( + extract_coreml_models, +) + + +_DEVICE_NAMES: List[Tuple[type, str]] = [ + (MLNeuralEngineComputeDevice, "ANE"), + (MLGPUComputeDevice, "GPU"), + (MLCPUComputeDevice, "CPU"), +] + +_COMPUTE_UNIT_CHOICES = { + "all": ct.ComputeUnit.ALL, + "cpu_and_ne": ct.ComputeUnit.CPU_AND_NE, + "cpu_and_gpu": ct.ComputeUnit.CPU_AND_GPU, + "cpu_only": ct.ComputeUnit.CPU_ONLY, +} + + +def _device_name(device) -> str: + if device is None: + return "unknown" + for cls, name in _DEVICE_NAMES: + if isinstance(device, cls): + return name + return type(device).__name__ + + +def _iter_operations(block) -> Iterable: + for op in block.operations: + yield op + for nested in getattr(op, "blocks", None) or []: + yield from _iter_operations(nested) + + +def _ensure_compiled(model_path: str, tmpdir: str) -> str: + """Return a `.mlmodelc` path; compile from `.mlpackage` if needed.""" + if model_path.endswith(".mlmodelc"): + return model_path + if model_path.endswith(".mlpackage"): + dest = os.path.join( + 
tmpdir, os.path.basename(model_path).replace(".mlpackage", ".mlmodelc") + ) + return str(ct.models.utils.compile_model(model_path, destination_path=dest)) + raise ValueError(f"Expected a .mlpackage or .mlmodelc path, got: {model_path}") + + +def analyze_one( + model_path: str, compute_units: ct.ComputeUnit +) -> List[Tuple[str, str, str]]: + """Return [(function, operator_name, device)] for every op that has a plan. + + coremltools 9.0's ``MLComputePlan.load_from_path`` only exposes usage for + the default function of a multifunction package, so a multifunction + .mlpackage is analyzed function-by-function by projecting each function + as the ``main`` of a temp single-function copy. + """ + function_names = _mlpackage_function_names(model_path) + if len(function_names) <= 1: + return _analyze_compiled(model_path, compute_units) + rows: List[Tuple[str, str, str]] = [] + with tempfile.TemporaryDirectory() as tmpdir: + for fname in function_names: + projected = _project_to_single(model_path, fname, tmpdir) + for _, op_name, device in _analyze_compiled(projected, compute_units): + rows.append((fname, op_name, device)) + return rows + + +def _analyze_compiled( + model_path: str, compute_units: ct.ComputeUnit +) -> List[Tuple[str, str, str]]: + with tempfile.TemporaryDirectory() as tmpdir: + compiled = _ensure_compiled(model_path, tmpdir) + plan = MLComputePlan.load_from_path(compiled, compute_units=compute_units) + program = plan.model_structure.program + if program is None: + raise RuntimeError( + f"{model_path} is not an MLProgram model; this tool only supports " + "the MLProgram backend (the CoreML backend executorch produces today)." + ) + + rows: List[Tuple[str, str, str]] = [] + for fname, fn in program.functions.items(): + for op in _iter_operations(fn.block): + usage = plan.get_compute_device_usage_for_mlprogram_operation(op) + if usage is None: + # Constants and similar non-dispatched ops don't have a plan. 
+ continue + rows.append( + ( + fname, + op.operator_name, + _device_name(usage.preferred_compute_device), + ) + ) + return rows + + +def _mlpackage_function_names(model_path: str) -> List[str]: + """Names of the MLProgram functions inside an .mlpackage, or [] otherwise.""" + if not model_path.endswith(".mlpackage"): + return [] + spec = ct.models.MLModel(model_path, skip_model_load=True).get_spec() + if spec.WhichOneof("Type") != "mlProgram": + return [] + return list(spec.mlProgram.functions.keys()) + + +def _project_to_single(src_mlpackage: str, function_name: str, tmpdir: str) -> str: + """Re-save ``src_mlpackage`` with only ``function_name`` exposed as ``main``.""" + from coremltools.models.utils import MultiFunctionDescriptor, save_multifunction + + dest = os.path.join(tmpdir, f"{function_name}.mlpackage") + desc = MultiFunctionDescriptor() + desc.add_function( + src_mlpackage, + src_function_name=function_name, + target_function_name="main", + ) + desc.default_function_name = "main" + save_multifunction(desc, dest) + return dest + + +def _print_report( + label: str, rows: List[Tuple[str, str, str]], show_non_ane: bool +) -> None: + print(f"\n=== {label} ===") + if not rows: + print(" (no dispatched operations found)") + return + by_device = Counter(device for _, _, device in rows) + total = sum(by_device.values()) + for device in ("ANE", "GPU", "CPU", "unknown"): + count = by_device.get(device, 0) + if count == 0: + continue + pct = 100.0 * count / total + print(f" {device}: {count:5d} / {total} ({pct:5.1f}%)") + + if show_non_ane: + non_ane = [(fn, op_name) for fn, op_name, dev in rows if dev != "ANE"] + if non_ane: + print("\n Non-ANE op types:") + for op_name, count in Counter(op for _, op in non_ane).most_common(): + print(f" {count:5d} {op_name}") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--model_path", + required=True, + help="Path to a .pte, .mlpackage, or .mlmodelc.", + ) 
+ parser.add_argument( + "--compute_units", + default="cpu_and_ne", + choices=sorted(_COMPUTE_UNIT_CHOICES), + help="Which devices the runtime may use when planning dispatch.", + ) + parser.add_argument( + "--show_non_ane", + action="store_true", + help="List op types that did not get assigned to the ANE.", + ) + args = parser.parse_args() + + compute_units = _COMPUTE_UNIT_CHOICES[args.compute_units] + model_path = args.model_path + + if model_path.endswith(".pte"): + with open(model_path, "rb") as f: + pte_data = f.read() + with tempfile.TemporaryDirectory() as out_dir: + extracted = extract_coreml_models(pte_data, out_dir=out_dir) + if not extracted: + print( + f"{model_path} does not contain any CoreML delegate partitions.", + file=sys.stderr, + ) + return 1 + for path in extracted: + rows = analyze_one(str(path), compute_units) + _print_report(path.name, rows, args.show_non_ane) + else: + rows = analyze_one(model_path, compute_units) + _print_report(os.path.basename(model_path.rstrip("/")), rows, args.show_non_ane) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/apple/coreml/scripts/extract_coreml_models.py b/examples/apple/coreml/scripts/extract_coreml_models.py index 685b6b594f3..8956550eb4d 100644 --- a/examples/apple/coreml/scripts/extract_coreml_models.py +++ b/examples/apple/coreml/scripts/extract_coreml_models.py @@ -9,7 +9,7 @@ import shutil from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from executorch.backends.apple.coreml import executorchcoreml from executorch.exir._serialize._program import deserialize_pte_binary @@ -22,7 +22,12 @@ COREML_BACKEND_ID = "CoreMLBackend" -def extract_coreml_models(pte_data: bytes): +def extract_coreml_models( + pte_data: bytes, + out_dir: Optional[Union[str, Path]] = None, +) -> List[Path]: + out_root = Path(out_dir) if out_dir is not None else Path("extracted_coreml_models") + pte_file = 
deserialize_pte_binary(pte_data) program = pte_file.program @@ -44,6 +49,7 @@ def extract_coreml_models(pte_data: bytes): ] # Track extracted models to avoid duplicates (multifunction models share partitions) + extracted_paths: List[Path] = [] extracted_keys: set = set() model_index: int = 1 @@ -95,7 +101,7 @@ def extract_coreml_models(pte_data: bytes): if model_name is None: model_name = f"model_{model_index}" - model_path: Path = Path() / "extracted_coreml_models" / model_name + model_path: Path = out_root / model_name if model_path.exists(): shutil.rmtree(model_path.absolute()) os.makedirs(model_path.absolute()) @@ -104,11 +110,14 @@ def extract_coreml_models(pte_data: bytes): coreml_processed_bytes, str(model_path.absolute()) ): print(f"Core ML models are extracted and saved to path = {model_path}") + extracted_paths.append(model_path) model_index += 1 if len(coreml_delegates) == 0: print("The model isn't delegated to Core ML.") + return extracted_paths + def main() -> None: """ diff --git a/examples/apple/coreml/scripts/test_coreml_compute_plan.py b/examples/apple/coreml/scripts/test_coreml_compute_plan.py new file mode 100644 index 00000000000..83f06b7a2a8 --- /dev/null +++ b/examples/apple/coreml/scripts/test_coreml_compute_plan.py @@ -0,0 +1,161 @@ +# Copyright © 2026 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. 
+ +"""Tests for coreml_compute_plan.py.""" + +import os +import shutil +import tempfile +import unittest +from collections import Counter + +import coremltools as ct +import torch +from coremltools.models.utils import MultiFunctionDescriptor, save_multifunction + +from executorch.examples.apple.coreml.scripts.coreml_compute_plan import ( + _COMPUTE_UNIT_CHOICES, + _device_name, + analyze_one, +) + + +class _Op: + def __init__(self, operator_name: str, blocks=None): + self.operator_name = operator_name + self.blocks = blocks or [] + + +class _Block: + __slots__ = ("operations",) + + def __init__(self, ops): + self.operations = ops + + +def _build_small_mlpackage(out_dir: str) -> str: + class M(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.relu(x @ x.T) + x.sum() + + model = M().eval() + ep = torch.export.export(model, (torch.randn(8, 8),), strict=True) + ep = ep.run_decompositions({}) + mlmodel = ct.convert( + ep, + source="pytorch", + convert_to="mlprogram", + minimum_deployment_target=ct.target.iOS17, + skip_model_load=True, + ) + out = os.path.join(out_dir, "tiny.mlpackage") + mlmodel.save(out) + return out + + +class TestDeviceName(unittest.TestCase): + def test_none_device(self): + self.assertEqual(_device_name(None), "unknown") + + def test_known_device_classes(self): + from coremltools.models.compute_device import MLNeuralEngineComputeDevice + + # Don't construct the device classes directly (they wrap proxies that + # may be unavailable in some envs); just confirm the type-mapping path + # returns sensible names by mocking the isinstance check with a fake. 
+ class FakeNE(MLNeuralEngineComputeDevice): + def __init__(self): + pass + + self.assertEqual(_device_name(FakeNE()), "ANE") + + +class TestComputeUnitChoices(unittest.TestCase): + def test_includes_cpu_and_ne(self): + self.assertEqual(_COMPUTE_UNIT_CHOICES["cpu_and_ne"], ct.ComputeUnit.CPU_AND_NE) + + def test_includes_all(self): + self.assertEqual(_COMPUTE_UNIT_CHOICES["all"], ct.ComputeUnit.ALL) + + +class TestAnalyzeOne(unittest.TestCase): + """End-to-end: build a tiny mlpackage and analyze it.""" + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp() + cls.mlpackage = _build_small_mlpackage(cls.tmpdir) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def test_returns_rows_for_dispatched_ops(self): + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_AND_NE) + self.assertGreater(len(rows), 0, "expected at least one dispatched op") + # Every row is (function_name, operator_name, device_name). + for fname, op_name, device in rows: + self.assertIsInstance(fname, str) + self.assertIsInstance(op_name, str) + self.assertIn(device, {"ANE", "GPU", "CPU", "unknown"}) + + def test_main_function_present(self): + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_ONLY) + self.assertIn("main", {fname for fname, _, _ in rows}) + + def test_op_types_for_relu_matmul_model(self): + # The toy model is `relu(x @ x.T) + x.sum()` so the lowered MIL + # should at least contain matmul, relu, add and reduce_sum. + rows = analyze_one(self.mlpackage, ct.ComputeUnit.CPU_ONLY) + op_types = Counter(op for _, op, _ in rows) + # Op names are versioned (e.g. "ios17.matmul"), so match by suffix. + suffixes = {name.split(".")[-1] for name in op_types} + for expected in ("matmul", "relu", "add", "reduce_sum"): + self.assertIn(expected, suffixes, f"missing op {expected}: {suffixes}") + + +class TestAnalyzeOneMultifunction(unittest.TestCase): + """Verify analyze_one walks every function of a multifunction .mlpackage. 
+ + coremltools 9.0's MLComputePlan.load_from_path only exposes usage for + the default function, so analyze_one re-projects each function through + MultiFunctionDescriptor to surface plans for the rest. + """ + + @classmethod + def setUpClass(cls): + cls.tmpdir = tempfile.mkdtemp() + single = _build_small_mlpackage(cls.tmpdir) + desc = MultiFunctionDescriptor() + desc.add_function( + single, src_function_name="main", target_function_name="prefill" + ) + desc.add_function( + single, src_function_name="main", target_function_name="decode" + ) + desc.default_function_name = "prefill" + cls.multi = os.path.join(cls.tmpdir, "multi.mlpackage") + save_multifunction(desc, cls.multi) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def test_reports_every_function(self): + rows = analyze_one(self.multi, ct.ComputeUnit.CPU_ONLY) + fnames = {fname for fname, _, _ in rows} + self.assertEqual(fnames, {"prefill", "decode"}) + + def test_each_function_lowers_the_same_ops(self): + rows = analyze_one(self.multi, ct.ComputeUnit.CPU_ONLY) + per_fn: dict = {} + for fname, op_name, _ in rows: + per_fn.setdefault(fname, set()).add(op_name.split(".")[-1]) + for fname in ("prefill", "decode"): + self.assertIn("matmul", per_fn.get(fname, set()), f"{fname} missing matmul") + self.assertIn("relu", per_fn.get(fname, set()), f"{fname} missing relu") + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch b/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch new file mode 100644 index 00000000000..57a27cb3dee --- /dev/null +++ b/examples/arm/ethos-u-setup/core_platform/0003-Guard-HardFault-Handler-for-Armv6-M.patch @@ -0,0 +1,49 @@ +From 380045853a133f298cee1bcf0c959b93ea94f9a2 Mon Sep 17 00:00:00 2001 +From: RJ Ascani +Date: Wed, 13 May 2026 15:42:13 -0700 +Subject: [PATCH] Guard HardFault_Handler for Armv6-M / Armv8-M 
Baseline + +The Corstone-300 HardFault_Handler is written for Armv7-M / Armv8-M +Mainline: it uses an `ite eq` IT-block in inline asm, and dereferences +the SCB CFSR/BFAR/MMFAR fault-status registers. Neither is available +on Armv6-M (Cortex-M0/M0+) or Armv8-M Baseline (Cortex-M23), so the +file fails to compile when the Corstone-300 target source is built +with `-mcpu=cortex-m0plus` to exercise the scalar CMSIS-NN code paths +on the Corstone-300 M55 simulator (an ISA superset). + +Wrap the Mainline-only implementation in +`__ARM_ARCH_7M__ / 7EM / 8M_MAIN / 8_1M_MAIN` and fall back to a +minimal `printf("Hard fault"); exit(1)` stub on Baseline cores. +--- + targets/corstone-300/target.cpp | 8 ++++++++ + 1 file changed, 8 insertions(+) + +diff --git a/targets/corstone-300/target.cpp b/targets/corstone-300/target.cpp +index bda2248..4aa3eea 100644 +--- a/targets/corstone-300/target.cpp ++++ b/targets/corstone-300/target.cpp +@@ -246,6 +246,11 @@ struct ExcContext { + }; + + void HardFault_Handler() { ++ // Armv6-M (M0/M0+) and Armv8-M Baseline (M23) lack the IT instruction and ++ // the SCB CFSR/BFAR/MMFAR fault-status registers, so the rich handler ++ // can't compile or run there. Fall back to a minimal stub on those cores. 
++#if defined(__ARM_ARCH_7M__) || defined(__ARM_ARCH_7EM__) || defined(__ARM_ARCH_8M_MAIN__) || \ ++ defined(__ARM_ARCH_8_1M_MAIN__) + int irq; + struct ExcContext *e; + uint32_t sp; +@@ -267,6 +272,9 @@ void HardFault_Handler() { + sp); + printf( + "%11s cfsr=0x%08" PRIx32 " bfar=0x%08" PRIx32 " mmfar=0x%08" PRIx32 "\n", "", SCB->CFSR, SCB->BFAR, SCB->MMFAR); ++#else ++ printf("Hard fault\n"); ++#endif + exit(1); + } + } +-- +2.53.0 + diff --git a/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch b/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch new file mode 100644 index 00000000000..96dcdd9f29d --- /dev/null +++ b/examples/arm/ethos-u-setup/core_software/0002-Fix-ARMCM0plus-directory-case-and-compile-define-mis.patch @@ -0,0 +1,77 @@ +From 1ee9cf9c956ea6a266fc79dfa62071131f162510 Mon Sep 17 00:00:00 2001 +From: RJ Ascani +Date: Wed, 13 May 2026 15:48:07 -0700 +Subject: [PATCH] Fix ARMCM0plus directory case and compile-define mismatch +MIME-Version: 1.0 +Content-Type: text/plain; charset=UTF-8 +Content-Transfer-Encoding: 8bit + +The Cortex DFP names the Cortex-M0+ device directory and headers +`ARMCM0plus` (lowercase suffix), while the device source files +(`startup_ARMCM0plus.c`, `system_ARMCM0plus.c`) gate their +implementations on the `ARMCM0P` preprocessor macro — three different +spellings. `cmsis.cmake` previously did +`string(TOUPPER \"ARMCM\${CPU_NUMBER}\" ARM_CPU)`, producing +`ARMCM0PLUS`: the include path lookup fails and the source files hit +their `#error device not specified!` guard. + +Override `ARM_CPU` to `ARMCM0plus` and introduce a separate +`CMSIS_DEVICE_CPU_DEFINE` set to `ARMCM0P` for the cmsis_startup and +cmsis_system compile-definitions; all other cores still drive both +paths from the uppercased default. 
+--- + cmsis.cmake | 20 ++++++++++++++++++-- + 1 file changed, 18 insertions(+), 2 deletions(-) + +diff --git a/cmsis.cmake b/cmsis.cmake +index 7f2b93f..c49f205 100644 +--- a/cmsis.cmake ++++ b/cmsis.cmake +@@ -23,6 +23,15 @@ endif() + + string(TOUPPER "ARMCM${CPU_NUMBER}" ARM_CPU) + ++# Cortex-M0+ is special: the Cortex DFP names the device directory and headers ++# `ARMCM0plus` (lowercase suffix), while the device sources gate their ++# implementations on the `ARMCM0P` preprocessor macro. Override both so the ++# directory lookup and `#include` resolution succeed; the compile-definition ++# override is applied instead of `CMSIS_DEVICE_CPU_FEATURE` further down. ++if(CPU_NUMBER STREQUAL "0plus") ++ set(ARM_CPU "ARMCM0plus") ++endif() ++ + # Set CPU specific features + if(CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m33(\\+|$)") + set(ARM_FEATURES "_DSP_FP") +@@ -50,6 +59,13 @@ else() + cmake_path(SET CMSIS_DEVICE_CPU_FEATURE "${ARM_CPU}") + endif() + ++# Macro the device sources gate on. Matches CMSIS_DEVICE_CPU_FEATURE for most ++# cores; Cortex-M0+ keys off `ARMCM0P`, not `ARMCM0plus`. 
++set(CMSIS_DEVICE_CPU_DEFINE "${CMSIS_DEVICE_CPU_FEATURE}") ++if(CPU_NUMBER STREQUAL "0plus") ++ set(CMSIS_DEVICE_CPU_DEFINE "ARMCM0P") ++endif() ++ + target_include_directories(cmsis_device INTERFACE ${CMSIS_DEVICE_PATH}/${ARM_CPU}/Include) + + target_compile_options(cmsis_device INTERFACE +@@ -66,12 +82,12 @@ target_sources(cmsis_startup INTERFACE + set_source_files_properties(${CMSIS_DEVICE_PATH}/${ARM_CPU}/Source/startup_${ARM_CPU}.c + PROPERTIES COMPILE_FLAGS -Wno-redundant-decls) + +-target_compile_definitions(cmsis_startup INTERFACE ${CMSIS_DEVICE_CPU_FEATURE}) ++target_compile_definitions(cmsis_startup INTERFACE ${CMSIS_DEVICE_CPU_DEFINE}) + target_link_libraries(cmsis_startup INTERFACE cmsis_device) + + # CMSIS system + add_library(cmsis_system INTERFACE) + target_sources(cmsis_system INTERFACE + ${CMSIS_DEVICE_PATH}/${ARM_CPU}/Source/system_${ARM_CPU}.c) +-target_compile_definitions(cmsis_system INTERFACE ${CMSIS_DEVICE_CPU_FEATURE}) ++target_compile_definitions(cmsis_system INTERFACE ${CMSIS_DEVICE_CPU_DEFINE}) + target_link_libraries(cmsis_system INTERFACE cmsis_startup) +-- +2.53.0 + diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 249e8fd14d4..b8a052004e4 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -152,13 +152,14 @@ int main(int argc, char** argv) { ET_LOG(Error, "Preprocessing failed."); return 1; } - auto mel_features = preprocess_result.get(); + auto preprocess_out = preprocess_result.get(); // --- Transcribe --- ET_LOG(Info, "Running TDT greedy decode..."); - auto result = runner.transcribe(mel_features, [](const std::string& piece) { - std::cout << piece << std::flush; - }); + auto result = runner.transcribe( + preprocess_out.features, + [](const std::string& piece) { std::cout << piece << std::flush; }, + preprocess_out.length); if (!result.ok()) { ET_LOG(Error, "Transcription failed."); diff --git 
a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte index ad6bee06146..5903c5b5c32 100644 Binary files a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte and b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte differ diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py index 7bebf513658..a75e67933e5 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py @@ -133,7 +133,7 @@ def _init_runner_base_cmd(self): base_cmd = " ".join( [ f"export LD_LIBRARY_PATH={self.qnn_sdk}/lib/x86_64-linux-clang/:{args.build_folder}/lib &&", - f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", + f"{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", f"--decoder_model_version {DECODER_MODEL_VERSION[args.decoder_model]}", f"--tokenizer_path {self.runtime_tokenizer_path}", f"--output_path {self.device_output_response_path}", diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 5380ff5220d..184eb857661 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -317,13 +317,9 @@ def retrieve_info_from_pte(pte_path: str) -> dict: pte_max_context_len = pte_max_seq_len # FP has no scale/zero_point, use following values, which is equivalent to not performing dequantize. - if kv_io_bit_width == 32: + if kv_io_bit_width == 32 or (logits_scale is None or logits_zero_point is None): logits_scale = 1 logits_zero_point = 0 - elif logits_scale is None or logits_zero_point is None: - raise RuntimeError( - "Unable to find scale/offset. The .pte file might be deprecated. 
Please generate a new .pte file" - ) assert output_vocab_size is not None, "Couldn't find the vocab size" assert pte_max_seq_len is not None, "Couldn't find the max_seq_len from pte" meta_info = { diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index a8e28f96b71..ce0b7a80cfc 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -21,6 +21,7 @@ ) from executorch.backends.qualcomm.utils.utils import ( + generate_gpu_compiler_spec, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, get_soc_to_chipset_map, @@ -119,9 +120,15 @@ def compile( # because the encoder is quite sensitive and quantization can make it harder for the model to distinguish # between images within the same conversation. to_skip = len(args.image_path) > 1 - backend_options = generate_htp_compiler_spec( - use_fp16=to_skip, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=to_skip, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + encoder_compile_specs = generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, @@ -131,27 +138,40 @@ def compile( skip_quantize[modality] = to_skip compile_specs[modality] = encoder_compile_specs elif is_multimodal and modality == TOK_EMBEDDING: - backend_options = generate_htp_compiler_spec( - use_fp16=False, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=False, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + 
compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, + online_prepare=args.online_prepare, ) ] * len(TOK_EMBEDDING_GRAPH_NAMES) elif modality == TEXT_DECODER: # compile spec for text decoder - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=decoder_model_config.num_sharding > 1, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=args.use_fp16, + use_multi_contexts=decoder_model_config.num_sharding > 1, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + skip_quantize[modality] = args.use_fp16 compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], @@ -159,6 +179,7 @@ def compile( # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, use_mha2sha=True, + online_prepare=args.online_prepare, ) ] * len(DECODER_GRAPH_NAMES) @@ -172,7 +193,11 @@ def compile( ) # perform compilation - multi_modal_mgr.compile(compile_specs=compile_specs, pte_filenames=pte_filenames) + multi_modal_mgr.compile( + compile_specs=compile_specs, + pte_filenames=pte_filenames, + skip_quantize=skip_quantize, + ) def inference( @@ -529,6 +554,14 @@ def _build_parser(): help="Number of examples in few-shot context", ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( @@ -592,6 
+625,12 @@ def export_llama(args) -> None: pte_filename = "lookahead_llama_qnn" else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + if args.model_mode == "hybrid" and args.online_prepare: + raise RuntimeError( + "Currently hybrid mode is not compatible with online_prepare." + ) + if args.decoder_model == "stories260k": pte_filename = f"{args.decoder_model}_" + pte_filename pte_filenames = { @@ -740,6 +779,7 @@ def export_llama(args) -> None: def main(): parser = _build_parser() args = parser.parse_args() + args.build_folder = os.path.realpath(args.build_folder) try: export_llama(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index d8d82fece33..9b8cdd7999e 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -210,7 +210,6 @@ std::string get_formatted_prompt( return formatted_prompt; } -template void start_runner( std::unique_ptr module, std::vector& prompts, @@ -219,7 +218,7 @@ void start_runner( gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false : true; // create llama runner - example::Runner runner( + example::Runner runner( std::move(module), FLAGS_decoder_model_version.c_str(), FLAGS_model_path.c_str(), @@ -298,26 +297,8 @@ int main(int argc, char** argv) { FLAGS_attention_sink_rope_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); } - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. 
- example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + start_runner( + std::move(module), prompts, std::move(attention_sink_rope_module)); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp index 29b6b9d7ddc..c9c2bd19940 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp @@ -137,7 +137,6 @@ std::vector CollectPrompts(int argc, char** argv) { return prompts; } -template void start_multimodal_runner( std::unique_ptr encoder, std::unique_ptr tok_embedding, @@ -150,7 +149,7 @@ void start_multimodal_runner( : true; // Create multimodal runner - example::QNNMultimodalRunner runner( + example::QNNMultimodalRunner runner( std::move(encoder), std::move(tok_embedding), std::move(text_decoder), @@ -289,35 +288,12 @@ int main(int argc, char** argv) { FLAGS_decoder_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. 
- example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (text_decoder->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - text_decoder->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - // Start runner with appropriate KV bitwidth - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + // Start runner + start_multimodal_runner( + std::move(encoder), + std::move(tok_embedding), + std::move(text_decoder), + prompts); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h index 888e9acd421..b714f737de3 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -56,19 +57,36 @@ class DecoderRunner { inline int32_t logits_to_token( const executorch::aten::Tensor& logits_tensor, int64_t pos) { - auto* logits = logits_tensor.mutable_data_ptr(); + std::byte* logits = logits_tensor.mutable_data_ptr(); auto num_tokens = logits_tensor.size(1); auto vocab_size = logits_tensor.size(2); static std::vector logits_f(vocab_size); - auto* logits_last = logits; + std::byte* logits_last = logits; // offset to the meaningful logit we want for prefill model. 
+ executorch::aten::ScalarType logits_dtype = logits_tensor.scalar_type(); + size_t logits_nbytes = getDtypeSize(logits_dtype); if (num_tokens > 1) { - logits_last += pos * vocab_size; + logits_last += pos * vocab_size * logits_nbytes; } - // Discard dequantization (converting uint16_t to float) because the + // Discard dequantization (converting std::byte to float) because the // relative order of elements remains the same without conversion for (int i = 0; i < vocab_size; i++) { - logits_f[i] = logits_last[i]; + switch (logits_dtype) { + case executorch::aten::ScalarType::UInt16: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Byte: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Float: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + default: + ET_CHECK_MSG( + false, + "The scalar_type %s of logits is not supported", + executorch::runtime::toString(logits_dtype)); + } } return sampler_->sample(logits_f.data()); } diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index e5c12068bab..7288ca5fbd1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -7,24 +7,105 @@ */ #include +#include #include + +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +using executorch::runtime::TensorInfo; namespace example { -template -KVManager::KVManager(Metadata metadata) : metadata_(metadata) { + +namespace { +void fill_mask( + executorch::aten::ScalarType scalar_type, + std::byte* buf, + size_t size, + bool use_pos_value) { + if (use_pos_value) { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 65535u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 255u); + break; + case 
executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, 0.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } else { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + // -65535 acts as the additive "very negative" attention-mask value; + // chosen as a large finite negative so masked positions effectively + // zero out after softmax without relying on -inf. + case executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, -65535.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } +} +} // namespace + +KVManager::KVManager(Metadata metadata, std::unique_ptr method_meta) + : metadata_(metadata) { + Result attention_mask = method_meta->input_tensor_meta(1); + attention_mask_dtype_ = attention_mask->scalar_type(); + + // inputs are [input_tokens, attention_mask, (sliding window attention_mask), + // (input_pos), kv_caches] search kv_cache in inputs + for (int i = 2; i < method_meta->num_inputs(); i++) { + Result tensor_meta = method_meta->input_tensor_meta(i); + // k_cache: [1, n_heads, head_dim, seq_len] + size_t tensor_nbytes = tensor_meta->nbytes(); + size_t expected_tensor_nbytes = metadata_.head_dim * metadata_.num_heads * + metadata_.max_cache_len * getDtypeSize(tensor_meta->scalar_type()); + if (tensor_nbytes != expected_tensor_nbytes) { + // Not a kv_cache tensor (e.g. input_pos, sliding window attention mask). 
+ continue; + } + if (kv_cache_dtype_ == executorch::aten::ScalarType::Undefined) { + kv_cache_dtype_ = tensor_meta->scalar_type(); + } else { + ET_CHECK_MSG( + tensor_meta->scalar_type() == kv_cache_dtype_, + "Currently mixed scalar type of kv_cache is not allowed"); + } + } + ET_CHECK_MSG( + kv_cache_dtype_ != executorch::aten::ScalarType::Undefined, + "kv_cache_dtype was not detected from method inputs"); k_cache_.resize(metadata_.num_layers); v_cache_.resize(metadata_.num_layers); // Calculate cache size size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_cache_len * sizeof(T); + metadata_.head_dim * metadata_.max_cache_len * + getDtypeSize(kv_cache_dtype_); size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_ar_len * sizeof(T); + metadata_.head_dim * metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes); }; -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past) { @@ -33,38 +114,51 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past 
and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -75,30 +169,44 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = 
attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); // mask by limitation of sliding_window int32_t available_context_len = position_offset.empty() @@ -107,87 +215,73 @@ void KVManager::init_attention_mask( // if available_context_len is less than 0, it means we need to mask some // tokens in the past to avoid exceeding the sliding window if (available_context_len < 0) { - std::fill_n(past_ptr, -available_context_len, neg_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + -available_context_len, + /*use_pos_value=*/false); } - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t 
n_past, int32_t n_update) { - uint16_t pos_val = 65535; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); - cur_ptr += metadata_.context_len; + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, int32_t sliding_window, const std::vector& position_offset) { - uint16_t pos_val = 65535; - uint16_t neg_val = 0; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); int32_t available_cache_len = position_offset.empty() ? 
sliding_window - (i + 1) : sliding_window - (position_offset[i] + 1); if (n_past + n_update > available_cache_len) { - std::fill_n( - cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); + fill_mask( + attention_mask_dtype_, + cur_ptr - n_past * getDtypeSize(attention_mask_dtype_), + n_past + n_update - available_cache_len, + /*use_pos_value=*/false); } - cur_ptr += metadata_.context_len; + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { +void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { cur_ar_len_ = ar_len; - const size_t max_in_cache_block_in_bytes = - metadata_.max_cache_len * sizeof(T); - const size_t max_out_cache_block_in_bytes = metadata_.max_ar_len * sizeof(T); - - const size_t cache_in_bytes = - metadata_.num_heads * metadata_.head_dim * max_in_cache_block_in_bytes; - const size_t cache_out_bytes = - metadata_.num_heads * metadata_.head_dim * max_out_cache_block_in_bytes; + const size_t cache_in_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_cache_len * getDtypeSize(kv_cache_dtype_); + const size_t cache_out_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); for (int layer = 0; layer < metadata_.num_layers; ++layer) { - // Allocate buffer for key cache and value cache - T* single_layer_k_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_k_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - T* single_layer_v_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_v_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - - k_cache_[layer].buffer = single_layer_k_cache_in; - k_cache_[layer].output_buffer = single_layer_k_cache_out; - v_cache_[layer].buffer = single_layer_v_cache_in; - v_cache_[layer].output_buffer =
single_layer_v_cache_out; + k_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + k_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); + v_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + v_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); } } -template -void KVManager::rearrange_cache(int32_t ar_len_dst) { +void KVManager::rearrange_cache(int32_t ar_len_dst) { // Don't need to rearrange if cur_ar_len_ is equal to target ar_len if (cur_ar_len_ == ar_len_dst) return; @@ -199,75 +293,73 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) { cur_ar_len_ = ar_len_dst; } -template -void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { +void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* k_cache_in_read_ptr = k_cache.buffer; - T* k_cache_in_write_ptr = k_cache.buffer; - + std::byte* k_cache_in_read_ptr = k_cache.buffer; + std::byte* k_cache_in_write_ptr = k_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_num * sizeof(T)); - k_cache_in_read_ptr += src_cache_num; - k_cache_in_write_ptr += dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_nbytes); + k_cache_in_read_ptr += src_cache_nbytes; + k_cache_in_write_ptr += dst_cache_nbytes; } } else { k_cache_in_read_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_nbytes; 
k_cache_in_write_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_num * sizeof(T)); - k_cache_in_read_ptr -= src_cache_num; - k_cache_in_write_ptr -= dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_nbytes); + k_cache_in_read_ptr -= src_cache_nbytes; + k_cache_in_write_ptr -= dst_cache_nbytes; } } } -template -void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { +void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* v_cache_in_read_ptr = v_cache.buffer; - T* v_cache_in_write_ptr = v_cache.buffer; + std::byte* v_cache_in_read_ptr = v_cache.buffer; + std::byte* v_cache_in_write_ptr = v_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - dst_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr += src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr += dst_cache_num * metadata_.head_dim; + dst_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr += src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr += dst_cache_nbytes * metadata_.head_dim; } } else { v_cache_in_read_ptr += - metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_nbytes; v_cache_in_write_ptr += 
- metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - src_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr -= src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr -= dst_cache_num * metadata_.head_dim; + src_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr -= src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr -= dst_cache_nbytes * metadata_.head_dim; } } } -template -void KVManager::update_cache( +void KVManager::update_cache( int32_t ar_len, int32_t n_past, int32_t n_update, @@ -283,20 +375,19 @@ void KVManager::update_cache( } } -template -void KVManager::update_key( - KVCache& k_cache, +void KVManager::update_key( + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = k_cache.buffer; - T* read_ptr = k_cache.output_buffer; - const int32_t copy_size = n_update * sizeof(T); + std::byte* write_ptr = k_cache.buffer; + std::byte* read_ptr = k_cache.output_buffer; + const int32_t copy_size = n_update * getDtypeSize(kv_cache_dtype_); const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len - : metadata_.context_len - cur_ar_len_; - const int32_t out_size = cur_ar_len_; - const int32_t past_size = n_past; + ? 
metadata_.context_len * getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * getDtypeSize(kv_cache_dtype_); + const int32_t out_size = cur_ar_len_ * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = n_past * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.head_dim * metadata_.num_heads; write_ptr += past_size; @@ -316,7 +407,11 @@ void KVManager::update_key( for (int i = 0; i < n_iter; ++i) { auto wp = write_ptr, rp = read_ptr; for (auto ind : true_indices) { - *wp++ = rp[ind]; + std::memmove( + wp, + rp + ind * getDtypeSize(kv_cache_dtype_), + getDtypeSize(kv_cache_dtype_)); + wp += getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -324,21 +419,25 @@ void KVManager::update_key( } } -template -void KVManager::update_value( - KVCache& v_cache, +void KVManager::update_value( + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = v_cache.buffer; - T* read_ptr = v_cache.output_buffer; - const int32_t copy_size = n_update * metadata_.head_dim * sizeof(T); - const int32_t past_size = n_past * metadata_.head_dim; + std::byte* write_ptr = v_cache.buffer; + std::byte* read_ptr = v_cache.output_buffer; + const int32_t copy_size = + n_update * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = + n_past * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.num_heads; const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len * metadata_.head_dim - : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim; - const int32_t out_size = cur_ar_len_ * metadata_.head_dim; + ? 
metadata_.context_len * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_); + const int32_t out_size = + cur_ar_len_ * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); write_ptr += past_size; @@ -354,13 +453,14 @@ void KVManager::update_value( auto wp = write_ptr, rp = read_ptr; for (auto sel : selected) { if (sel) { - std::memcpy(wp, rp, metadata_.head_dim * sizeof(T)); - wp += metadata_.head_dim; + std::memcpy( + wp, rp, metadata_.head_dim * getDtypeSize(kv_cache_dtype_)); + wp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); update_times--; if (update_times == 0) break; } - rp += metadata_.head_dim; + rp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -368,8 +468,4 @@ void KVManager::update_value( } } -// Explicit instantiations -template class KVManager; -template class KVManager; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index 06fe88517a7..3b8e67dd38d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -15,17 +16,15 @@ namespace example { // Structure to hold key-value cache buffers -template struct KVCache { - T* buffer; - T* output_buffer; + std::byte* buffer; + std::byte* output_buffer; }; /** * @class KVManager * @brief Class for kv cache update, rearrangement, and buffer allocatation. */ -template class KVManager { public: struct Metadata { @@ -36,7 +35,9 @@ class KVManager { int64_t num_heads; int64_t num_layers; }; - KVManager(Metadata metadata); + KVManager( + Metadata metadata, + std::unique_ptr method_meta); /** * @brief Allocate buffer for KV cache and set the cur_ar_len_. 
@@ -71,7 +72,7 @@ class KVManager { * @param n_past Number of past elements in the cache. */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past); @@ -98,7 +99,7 @@ class KVManager { * @param position_offset (optional) attention mask position offset of */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -114,7 +115,7 @@ class KVManager { * @param n_update Number of elements to be updated. */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update); @@ -132,7 +133,7 @@ class KVManager { * lookahead decoder */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, @@ -152,10 +153,10 @@ class KVManager { int32_t n_update, const std::vector& selected); - const std::vector>& get_k_cache_() const { + const std::vector& get_k_cache_() const { return k_cache_; } - const std::vector>& get_v_cache_() const { + const std::vector& get_v_cache_() const { return v_cache_; } @@ -169,15 +170,19 @@ class KVManager { private: // Helper functions to rearrange and update key and value caches - void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); - void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + + void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); + + void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + void update_key( - KVCache& k_cache, + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected); + void update_value( - KVCache& v_cache, + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected); @@ -186,10 +191,14 @@ class KVManager { Metadata metadata_; size_t total_cache_size_; int32_t cur_ar_len_; + executorch::aten::ScalarType attention_mask_dtype_ 
= + executorch::aten::ScalarType::Undefined; + executorch::aten::ScalarType kv_cache_dtype_ = + executorch::aten::ScalarType::Undefined; // Store start pointer of k and v cache for input and output // input: layer -> head * head_dim * max_cache_len // output: layer -> head * head_dim * max_ar_len - std::vector> k_cache_; - std::vector> v_cache_; + std::vector k_cache_; + std::vector v_cache_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp index f7e44292f26..298fc1ac9ff 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -13,20 +13,19 @@ using executorch::runtime::Result; namespace example { -template -void LhdTokenGenerator::prepare_io( +void LhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { if (i < input_tokens.size()) { // Prepare pos data - this->input_pos_.data[i] = input_pos[i]; + reinterpret_cast(this->input_pos_.data)[i] = input_pos[i]; // Support CPU 4-bit embedding, which requires int64 input. // However, for QNN embedding, only int32 input is needed. // Therefore, we need to cast to the correct type to write the data. 
if (metadata_.use_int64_token) { - this->input_toks_.data[i] = input_tokens[i]; + reinterpret_cast(this->input_toks_.data)[i] = input_tokens[i]; } else { int32_t* input_toks_ptr = reinterpret_cast(this->input_toks_.data); @@ -36,8 +35,7 @@ void LhdTokenGenerator::prepare_io( } } -template -void LhdTokenGenerator::init_attention_mask(int32_t n_past) { +void LhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -73,8 +71,7 @@ void LhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void LhdTokenGenerator::init_lookahead_branch( +void LhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -91,8 +88,7 @@ void LhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { +void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -116,8 +112,7 @@ void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { } } -template -void LhdTokenGenerator::update_ngrams_pool() { +void LhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -170,8 +165,7 @@ void LhdTokenGenerator::update_ngrams_pool() { } } -template -void LhdTokenGenerator::update_lookahead_branch( +void LhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -189,8 +183,7 @@ void LhdTokenGenerator::update_lookahead_branch( } } -template -Result LhdTokenGenerator::generate( +Result LhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, 
int32_t seq_len, @@ -427,8 +420,4 @@ Result LhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class LhdTokenGenerator; -template class LhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h index 796dde88014..8fdffb8af72 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -15,8 +15,8 @@ namespace example { * @brief Class for generating the token using decoder and key-value manager * with lookahead decoding. */ -template -class LhdTokenGenerator : public TokenGenerator { + +class LhdTokenGenerator : public TokenGenerator { public: struct Metadata { int32_t context_len; @@ -34,18 +34,19 @@ class LhdTokenGenerator : public TokenGenerator { LhdTokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -54,7 +55,8 @@ class LhdTokenGenerator : public TokenGenerator { metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), lhd_branch_prev_(metadata.window), @@ -104,7 +106,7 @@ class LhdTokenGenerator : public TokenGenerator { private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). 
- using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp index 14a93104e1a..de8d1bea0fe 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp @@ -13,8 +13,7 @@ using executorch::runtime::Result; namespace example { -template -void MultimodalLhdTokenGenerator::prepare_io( +void MultimodalLhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { @@ -51,8 +50,7 @@ void MultimodalLhdTokenGenerator::prepare_io( } } -template -void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { +void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -88,8 +86,7 @@ void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void MultimodalLhdTokenGenerator::init_lookahead_branch( +void MultimodalLhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -106,9 +103,7 @@ void MultimodalLhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void MultimodalLhdTokenGenerator::init_verification_branch( - uint64_t cur_token) { +void MultimodalLhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -132,8 +127,7 @@ void 
MultimodalLhdTokenGenerator::init_verification_branch( } } -template -void MultimodalLhdTokenGenerator::update_ngrams_pool() { +void MultimodalLhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -186,8 +180,7 @@ void MultimodalLhdTokenGenerator::update_ngrams_pool() { } } -template -void MultimodalLhdTokenGenerator::update_lookahead_branch( +void MultimodalLhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -205,8 +198,7 @@ void MultimodalLhdTokenGenerator::update_lookahead_branch( } } -template -Result MultimodalLhdTokenGenerator::generate( +Result MultimodalLhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -412,8 +404,4 @@ Result MultimodalLhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class MultimodalLhdTokenGenerator; -template class MultimodalLhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h index 7494afec6da..6ffe285e536 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h @@ -15,9 +15,7 @@ namespace example { * @class MultimodalLhdTokenGenerator * @brief Extended LhdTokenGenerator with multimodal embedding support */ -template -class MultimodalLhdTokenGenerator - : public example::MultimodalTokenGenerator { +class MultimodalLhdTokenGenerator : public example::MultimodalTokenGenerator { public: struct Metadata { int32_t context_len; @@ -37,19 +35,20 @@ class MultimodalLhdTokenGenerator 
tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : MultimodalTokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : MultimodalTokenGenerator( tokenizer, embedding_runner, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -59,7 +58,8 @@ class MultimodalLhdTokenGenerator metadata.sliding_window, metadata.cache_mode, metadata.embedding_dim}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(embedding_runner), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), @@ -110,7 +110,7 @@ class MultimodalLhdTokenGenerator private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). - using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. 
diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp index 2859e16a42a..f63a431791b 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp @@ -16,13 +16,13 @@ using executorch::runtime::TensorInfo; namespace example { -template -MultimodalPromptProcessor::MultimodalPromptProcessor( +MultimodalPromptProcessor::MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) - : PromptProcessor( + Metadata metadata, + std::unique_ptr method_meta) + : PromptProcessor( decoder_runner, kv_manager, method_name, @@ -33,7 +33,8 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata.vocab_size, metadata.use_int64_token, metadata.sliding_window, - metadata.cache_mode}), + metadata.cache_mode}, + std::move(method_meta)), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead input_toks_.size = 0; @@ -41,8 +42,7 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); }; -template -void MultimodalPromptProcessor::init_io( +void MultimodalPromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -66,8 +66,7 @@ void MultimodalPromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -83,8 +82,8 @@ void 
MultimodalPromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -120,32 +119,29 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -160,21 +156,22 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -186,8 +183,7 @@ void MultimodalPromptProcessor::init_io( } // prepare embedding -template -void MultimodalPromptProcessor::prepare_io( +void MultimodalPromptProcessor::prepare_io( const TensorStruct& prompt_embedding, int32_t num_prompt_tokens, int64_t prompt_pos, @@ -208,8 +204,7 @@ void MultimodalPromptProcessor::prepare_io( } } -template -Result MultimodalPromptProcessor::prefill( +Result MultimodalPromptProcessor::prefill( const TensorStruct& prompt_embedding, int64_t start_pos, bool dump_logits, @@ -301,8 +296,4 @@ Result MultimodalPromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class MultimodalPromptProcessor; -template class MultimodalPromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h index fcfc07c9590..c2769ed9f50 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalPromptProcessor * @brief Extended PromptProcessor with multimodal embedding support */ -template -class 
MultimodalPromptProcessor : public example::PromptProcessor { +class MultimodalPromptProcessor : public example::PromptProcessor { public: struct Metadata { int32_t context_len; @@ -33,9 +32,10 @@ class MultimodalPromptProcessor : public example::PromptProcessor { MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); int64_t get_num_heads() const { return metadata_.num_heads; @@ -74,34 +74,29 @@ class MultimodalPromptProcessor : public example::PromptProcessor { * @return Total I/O size in bytes. */ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } private: // Reuse members from token_generator - using PromptProcessor::decoder_runner_; - using PromptProcessor::kv_manager_; - using PromptProcessor::method_name_; - using PromptProcessor::k_cache_in_; - using PromptProcessor::v_cache_in_; - using PromptProcessor::k_cache_out_; - using PromptProcessor::v_cache_out_; - using PromptProcessor::input_toks_; - using PromptProcessor::input_pos_; - using PromptProcessor::attention_mask_; - using PromptProcessor::window_attention_mask_; - using PromptProcessor::logits_; - using PromptProcessor::inputs_; - using PromptProcessor::input_tensors_; - using PromptProcessor::output_tensors_; - using PromptProcessor::prompt_all_logits_; - using PromptProcessor::is_bert; + using PromptProcessor::attention_mask_; + using PromptProcessor::decoder_runner_; + 
using PromptProcessor::input_pos_; + using PromptProcessor::input_tensors_; + using PromptProcessor::input_toks_; + using PromptProcessor::inputs_; + using PromptProcessor::is_bert; + using PromptProcessor::k_cache_in_; + using PromptProcessor::k_cache_out_; + using PromptProcessor::kv_manager_; + using PromptProcessor::logits_; + using PromptProcessor::method_name_; + using PromptProcessor::output_tensors_; + using PromptProcessor::prompt_all_logits_; + using PromptProcessor::v_cache_in_; + using PromptProcessor::v_cache_out_; + using PromptProcessor::window_attention_mask_; /** * @brief Fill in I/O buffers with embedding data and position. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp index 32e3baf27a9..32575994222 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp @@ -74,17 +74,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -93,8 +93,7 @@ void save_logits( } // namespace -template -QNNMultimodalRunner::QNNMultimodalRunner( +QNNMultimodalRunner::QNNMultimodalRunner( std::unique_ptr encoder, std::unique_ptr tok_embedding, std::unique_ptr text_decoder, @@ -148,16 +147,14 @@ 
QNNMultimodalRunner::QNNMultimodalRunner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool QNNMultimodalRunner::is_loaded() const { +bool QNNMultimodalRunner::is_loaded() const { return encoder_->is_loaded() && tok_embedding_->is_loaded() && text_decoder_->is_loaded() && embedding_merger_ && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error QNNMultimodalRunner::load() { +Error QNNMultimodalRunner::load() { if (is_loaded()) { return Error::Ok; } @@ -298,19 +295,22 @@ Error QNNMultimodalRunner::load() { sliding_window = ET_UNWRAP(text_decoder_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); - - prompt_processor_ = std::make_unique>( + kv_manager_ = std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); + + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename MultimodalPromptProcessor::Metadata{ + MultimodalPromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -319,7 +319,9 @@ Error QNNMultimodalRunner::load() { use_int64_token, sliding_window, cache_mode_, - static_cast(dim)}); + static_cast(dim)}, + std::make_unique(std::move( + text_decoder_->method_meta(prompt_processor_method_name).get()))); // Initialize EmbeddingGenerator tok_embedding_generator_ = std::make_unique( @@ -333,14 +335,14 @@ Error QNNMultimodalRunner::load() { static_cast(dim)}); if (eval_mode_ == EvalMode::kLookaheadDecoding) { // Initialize TokenGenerator - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), 
token_generator_method_name, std::move(eos_ids), - typename MultimodalLhdTokenGenerator::Metadata{ + MultimodalLhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -353,16 +355,18 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -372,7 +376,9 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -409,8 +415,7 @@ Error QNNMultimodalRunner::load() { return Error::Ok; } -template -executorch::runtime::Error QNNMultimodalRunner::generate( +executorch::runtime::Error QNNMultimodalRunner::generate( const std::vector& inputs, const llm::GenerationConfig& config, std::function token_callback, @@ -561,8 +566,7 @@ executorch::runtime::Error QNNMultimodalRunner::generate( return Error::Ok; } -template -Result QNNMultimodalRunner::get_model_version() { +Result QNNMultimodalRunner::get_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -571,16 +575,11 @@ Result QNNMultimodalRunner::get_model_version() { return model_version_; } -template -Result QNNMultimodalRunner::get_encoder_method_meta() { +Result QNNMultimodalRunner::get_encoder_method_meta() { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } return encoder_->method_meta(kEncoderForwardName); } -// Explicit instantiations -template 
class QNNMultimodalRunner; -template class QNNMultimodalRunner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h index 5407d5712b7..363ded0f055 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h @@ -66,12 +66,6 @@ inline Modality modality_of(const ModelVersion& model_version) { [](const auto& model) { return modality_of(model); }, model_version); } -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class QNNMultimodalRunner : public executorch::extension::llm::MultimodalRunner { public: @@ -139,11 +133,11 @@ class QNNMultimodalRunner ModelVersion model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; std::unique_ptr encoder_runner_; std::unique_ptr tok_embedding_runner_; std::unique_ptr tok_embedding_processor_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp index 2ed8ae51f1d..e3f6f8e214e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp @@ -15,17 +15,17 @@ using executorch::runtime::TensorInfo; namespace example { // Constructor with embedding runner support -template -MultimodalTokenGenerator::MultimodalTokenGenerator( +MultimodalTokenGenerator::MultimodalTokenGenerator( tokenizers::Tokenizer* tokenizer, 
TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, @@ -39,7 +39,8 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(tok_embedding_runner), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead @@ -48,8 +49,7 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); } -template -void MultimodalTokenGenerator::init_io( +void MultimodalTokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -73,8 +73,7 @@ void MultimodalTokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -90,8 +89,8 @@ void MultimodalTokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -126,30 +125,27 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; 
cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -164,21 +160,22 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -190,8 +187,7 @@ void MultimodalTokenGenerator::init_io( } // This function only considers the case where token_generator_ar_len equals 1. -template -void MultimodalTokenGenerator::prepare_io( +void MultimodalTokenGenerator::prepare_io( uint64_t cur_token, int64_t start_pos) { // Generate embedding for current token using embedding runner @@ -209,8 +205,4 @@ void MultimodalTokenGenerator::prepare_io( *input_pos_.data = static_cast(start_pos); } -// Explicit instantiations -template class MultimodalTokenGenerator; -template class MultimodalTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h index 9eb9c79aaa4..2d0bf9385b4 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalTokenGenerator * @brief Extended TokenGenerator with multimodal embedding support */ -template -class MultimodalTokenGenerator : public example::TokenGenerator { +class MultimodalTokenGenerator : public example::TokenGenerator { public: struct Metadata { 
int32_t context_len; @@ -36,11 +35,12 @@ class MultimodalTokenGenerator : public example::TokenGenerator { tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~MultimodalTokenGenerator() = default; @@ -54,36 +54,31 @@ class MultimodalTokenGenerator : public example::TokenGenerator { override; inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } protected: // Reuse members from token_generator - using TokenGenerator::kv_manager_; - using TokenGenerator::input_pos_; - using TokenGenerator::attention_mask_; - using TokenGenerator::window_attention_mask_; - using TokenGenerator::inputs_; - using TokenGenerator::input_tensors_; - using TokenGenerator::output_tensors_; + using TokenGenerator::attention_mask_; + using TokenGenerator::input_pos_; + using TokenGenerator::input_tensors_; + using TokenGenerator::inputs_; + using TokenGenerator::kv_manager_; + using TokenGenerator::output_tensors_; + using TokenGenerator::window_attention_mask_; // Additional members specific to multimodal TensorStruct input_embedding_; private: // Reuse members from token_generator - using TokenGenerator::input_toks_; - using TokenGenerator::logits_; - using TokenGenerator::k_cache_in_; - using TokenGenerator::v_cache_in_; - using 
TokenGenerator::k_cache_out_; - using TokenGenerator::v_cache_out_; + using TokenGenerator::input_toks_; + using TokenGenerator::k_cache_in_; + using TokenGenerator::k_cache_out_; + using TokenGenerator::logits_; + using TokenGenerator::v_cache_in_; + using TokenGenerator::v_cache_out_; // Additional members specific to multimodal TokenEmbeddingProcessor* tok_embedding_runner_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 59744d488bd..0cb52246a39 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -17,12 +17,12 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -PromptProcessor::PromptProcessor( +PromptProcessor::PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) + Metadata metadata, + std::unique_ptr method_meta) : decoder_runner_(decoder_runner), kv_manager_(kv_manager), method_name_(method_name), @@ -32,33 +32,41 @@ PromptProcessor::PromptProcessor( k_cache_out_.resize(metadata_.num_layers); v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); input_toks_.size = metadata_.ar_len * sizeof(int64_t); - if (is_bert()) + if (is_bert()) { input_pos_.size = 0; - else + } else { input_pos_.size = metadata_.ar_len * sizeof(int32_t); + } + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case 
CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); }; -template -void PromptProcessor::init_io( + +void PromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -80,8 +88,7 @@ void PromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -97,8 +104,8 @@ void PromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -136,33 +143,30 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? 
k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -177,21 +181,22 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -201,13 +206,11 @@ void PromptProcessor::init_io( } } -template -const std::vector& PromptProcessor::get_all_logits() { +const std::vector& PromptProcessor::get_all_logits() { return prompt_all_logits_; } -template -void PromptProcessor::prepare_io( +void PromptProcessor::prepare_io( const std::vector& prompt_tokens, int64_t prompt_pos, int64_t start_pos) { @@ -232,8 +235,7 @@ void PromptProcessor::prepare_io( } } -template -Result PromptProcessor::prefill( +Result PromptProcessor::prefill( std::vector prompt_tokens, int64_t start_pos, bool dump_logits, @@ -339,7 +341,9 @@ Result PromptProcessor::prefill( prompt_all_logits_.insert( prompt_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } // In the last run, offset to the meaningful logits. 
if (i == num_iters - 1) { @@ -369,8 +373,4 @@ Result PromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class PromptProcessor; -template class PromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h index 599f7050d83..5317a8a77e1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h @@ -21,7 +21,7 @@ namespace example { * @class PromptProcessor * @brief Class for processing prompts using decoder and key-value manager. */ -template + class PromptProcessor { public: struct Metadata { @@ -36,9 +36,10 @@ class PromptProcessor { }; PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); virtual ~PromptProcessor() = default; @@ -55,9 +56,9 @@ class PromptProcessor { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /** * Prefill an LLM Module with the given text input. @@ -79,13 +80,8 @@ class PromptProcessor { * @return Total I/O size in bytes. 
*/ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: @@ -105,7 +101,7 @@ class PromptProcessor { int64_t prompt_pos, int64_t start_pos); DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; // metadata @@ -114,9 +110,9 @@ class PromptProcessor { // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -131,6 +127,6 @@ class PromptProcessor { std::vector cache_inputs_; // Unused by default, only used when dump_logits_path is provided. 
- std::vector prompt_all_logits_; + std::vector prompt_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 0a4a8b9abb5..7257e869dcc 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -66,17 +66,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -85,8 +85,7 @@ void save_logits( } // namespace -template -Runner::Runner( +Runner::Runner( std::unique_ptr module, const std::string& decoder_model_version, const std::string& model_path, @@ -152,14 +151,12 @@ Runner::Runner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool Runner::is_loaded() const { +bool Runner::is_loaded() const { return module_->is_loaded() && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error Runner::load() { +Error Runner::load() { if (is_loaded()) { return Error::Ok; } @@ -275,13 +272,16 @@ Error Runner::load() { if (module_->method_names()->count("get_sliding_window") > 0) { sliding_window = ET_UNWRAP(module_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); + kv_manager_ = 
std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique( + std::move(module_->method_meta(token_generator_method_name).get()))); if (attention_sink_rope_module_ != nullptr) { attention_sink_rope_runner_ = std::make_unique( @@ -290,11 +290,11 @@ Error Runner::load() { attention_sink_rope_runner_->load(method_names)); } - prompt_processor_ = std::make_unique>( + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename PromptProcessor::Metadata{ + PromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -302,15 +302,17 @@ Error Runner::load() { vocab_size, use_int64_token, sliding_window, - cache_mode_}); + cache_mode_}, + std::make_unique( + std::move(module_->method_meta(prompt_processor_method_name).get()))); if (eval_mode_ == EvalMode::kLookaheadDecoding) { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename LhdTokenGenerator::Metadata{ + LhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -322,15 +324,17 @@ Error Runner::load() { gcap_, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -339,7 +343,9 @@ Error Runner::load() { use_int64_token, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -360,8 +366,7 @@ Error 
Runner::load() { return Error::Ok; } -template -Error Runner::generate( +Error Runner::generate( const std::string& prompt, const llm::GenerationConfig& config, std::function token_callback, @@ -370,8 +375,7 @@ Error Runner::generate( prompt, false, config, token_callback, stats_callback); } -template -Error Runner::generate_from_prompt_or_file( +Error Runner::generate_from_prompt_or_file( const std::string& prompt, bool tokenized_prompt, const llm::GenerationConfig& config, @@ -500,8 +504,7 @@ Error Runner::generate_from_prompt_or_file( return Error::Ok; } -template -Result Runner::get_decoder_model_version() { +Result Runner::get_decoder_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -510,8 +513,4 @@ Result Runner::get_decoder_model_version() { return decoder_model_version_; } -// Explicit instantiations -template class Runner; -template class Runner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 39ce62c2d9f..5d03a12f61a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -46,12 +46,6 @@ enum DecoderModelVersion { kGemma2, }; -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class Runner : public executorch::extension::llm::IRunner { public: explicit Runner( @@ -121,14 +115,15 @@ class Runner : public executorch::extension::llm::IRunner { DecoderModelVersion decoder_model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; std::unique_ptr attention_sink_rope_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; // stats executorch::llm::Stats stats_; }; + } // namespace example diff 
--git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 8ab82d932e1..098fcf9efa6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -17,15 +17,15 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -TokenGenerator::TokenGenerator( +TokenGenerator::TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) + executorch::llm::Stats* stats, + std::unique_ptr method_meta) : tokenizer_(tokenizer), decoder_runner_(decoder_runner), kv_manager_(kv_manager), @@ -39,32 +39,37 @@ TokenGenerator::TokenGenerator( v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); + input_toks_.size = metadata_.ar_len * sizeof(int64_t); input_pos_.size = metadata_.ar_len * sizeof(int32_t); - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + 
window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); } -template -void TokenGenerator::init_io( +void TokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -86,8 +91,7 @@ void TokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -103,8 +107,8 @@ void TokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -141,31 +145,28 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -180,21 +181,22 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -204,14 +206,12 @@ void TokenGenerator::init_io( } } -template -const std::vector& TokenGenerator::get_all_logits() { +const std::vector& TokenGenerator::get_all_logits() { return token_all_logits_; } // This function only considers the case where token_generator_ar_len equals 1. -template -void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { +void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { // update input_tok *input_toks_.data = metadata_.use_int64_token ? 
cur_token : static_cast(cur_token); @@ -219,8 +219,7 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { *input_pos_.data = static_cast(start_pos); } -template -Result TokenGenerator::generate( +Result TokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -306,7 +305,9 @@ Result TokenGenerator::generate( token_all_logits_.insert( token_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); executorch::aten::Tensor& logits_tensor = logits_res.get(); @@ -374,8 +375,5 @@ Result TokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class TokenGenerator; -template class TokenGenerator; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index 7f9264b1102..6945d907a76 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -22,7 +22,7 @@ namespace example { * @class TokenGenerator * @brief Class for generating the token using decoder and key-value manager. 
*/ -template + class TokenGenerator { public: struct Metadata { @@ -38,11 +38,12 @@ class TokenGenerator { TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~TokenGenerator() = default; /** @@ -58,9 +59,9 @@ class TokenGenerator { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /**    * @brief Generate tokens. @@ -78,28 +79,23 @@ class TokenGenerator { bool dump_logits, AttentionSinkRopeRunner* attention_sink_rope_runner); inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: tokenizers::Tokenizer* tokenizer_; DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; std::unique_ptr> eos_ids_; // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -128,6 +124,6 @@ class TokenGenerator { Metadata metadata_; // Unused by default, only used when dump_logits_path is provided. 
- std::vector token_all_logits_; + std::vector token_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/utils.h b/examples/qualcomm/oss_scripts/llama/runner/utils.h index bef6b1a2017..df6dddfdc6e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/utils.h +++ b/examples/qualcomm/oss_scripts/llama/runner/utils.h @@ -8,10 +8,16 @@ #pragma once #include +#include #include #include // Template struct to hold tensor data and tensor + +// TODO: Refactor these structs to use TensorPtr +// see https://docs.pytorch.org/executorch/stable/extension-tensor.html + +// TensorStruct whose dtype is known at compile time template struct TensorStruct { std::unique_ptr tensor; @@ -20,3 +26,38 @@ struct TensorStruct { // data size in bytes size_t size; }; + +inline size_t getDtypeSize(executorch::aten::ScalarType dtype) { + switch (dtype) { + case executorch::aten::ScalarType::Float: + return sizeof(float); + case executorch::aten::ScalarType::Double: + return sizeof(double); + case executorch::aten::ScalarType::Int: + return sizeof(int32_t); + case executorch::aten::ScalarType::Long: + return sizeof(int64_t); + case executorch::aten::ScalarType::Byte: + return sizeof(uint8_t); + case executorch::aten::ScalarType::UInt16: + return sizeof(uint16_t); + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(dtype)); + break; + } +} + +// TensorStruct whose dtype is known at runtime, and raw file is used +struct TensorStructRaw { + std::unique_ptr tensor; + std::byte* data; + // data size in bytes + size_t size; + executorch::aten::ScalarType dtype; + size_t getElementSize() const { + return getDtypeSize(dtype); + } +}; diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py index 48386f181d8..de857dfc17c 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py +++
b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py @@ -13,6 +13,7 @@ import torch from executorch.backends.qualcomm._passes import TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -460,6 +461,7 @@ def compile(self, attention_sink_evictor_pte_path: str): alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], extract_delegate_segments=True, ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index ef72e0765fd..0d5052c89bd 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -19,6 +19,7 @@ import torch from executorch.backends.qualcomm._passes import FoldQDQ, I64toI32, TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -607,23 +608,28 @@ def quantize(self, request: Request): # noqa: C901 ): return + data = request.method_data[TEXT_DECODER] # check bit width graph io fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} - if self.quant_recipe.get_kv_io_bit_width() == 8: - fixed_point_type["kv_type"] = torch.uint8 - elif self.quant_recipe.get_kv_io_bit_width() == 16: - fixed_point_type["kv_type"] = torch.uint16 + if data.skip_quantize: + # already init as float32 + return else: - raise RuntimeError( - f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" - ) + if self.quant_recipe.get_kv_io_bit_width() == 8: + fixed_point_type["kv_type"] = torch.uint8 + elif self.quant_recipe.get_kv_io_bit_width() == 16: + fixed_point_type["kv_type"] = torch.uint16 + else: + raise RuntimeError( 
+ f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" + ) - if self.quant_recipe.get_logits_output_bit_width() == 16: - fixed_point_type["io_type"] = torch.uint16 - else: - raise RuntimeError( - f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" - ) + if self.quant_recipe.get_logits_output_bit_width() == 16: + fixed_point_type["io_type"] = torch.uint16 + else: + raise RuntimeError( + f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" + ) data = request.method_data[TEXT_DECODER] audio_turns = request.method_data[ @@ -906,7 +912,11 @@ def compile(self, request: Request): # noqa: C901 # here we use a mechanism to make sure the encoding align correctly and # save AoT quantization time as well. # --- - if self.prefill.decoder is not None and self.prefill.model_args.use_kv_cache: + if ( + self.prefill.decoder is not None + and self.prefill.model_args.use_kv_cache + and not request.method_data[TEXT_DECODER].skip_quantize + ): self._encoding_override( decode_model=self.decode.decoder, prefill_model=self.prefill.decoder, @@ -973,6 +983,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) tok_embedding_exec_prog_mgr = tok_embedding_edge_prog_mgr.to_executorch( executorch_config @@ -1009,6 +1020,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) data = request.method_data[TEXT_DECODER] @@ -1127,7 +1139,9 @@ def compile(self, request: Request): if self.control_args.verbose: print_delegation_info(edge_prog_mgr.exported_program().graph_module) - exec_prog_mgr = edge_prog_mgr.to_executorch(ExecutorchBackendConfig()) + exec_prog_mgr = edge_prog_mgr.to_executorch( + ExecutorchBackendConfig(passes=[BuildQuantIo()]) + ) data = request.method_data[self.modality] with open( 
f"{self.control_args.artifact}/{data.pte_filename}.pte", "wb" @@ -1223,6 +1237,7 @@ def compile( self, compile_specs: Dict[str, List[CompileSpec]], pte_filenames: Dict[str, str], + skip_quantize: Dict[str, bool], ): compile_request = Request( inspect.currentframe().f_code.co_name, @@ -1230,6 +1245,7 @@ def compile( m: Request.Data( compile_spec=compile_specs[m], pte_filename=pte_filenames[m], + skip_quantize=skip_quantize[m] if m in skip_quantize else False, ) for m in self._modalities }, diff --git a/examples/riscv/README.md b/examples/riscv/README.md index 563ff4913fd..88dd08f7715 100644 --- a/examples/riscv/README.md +++ b/examples/riscv/README.md @@ -1,41 +1,35 @@ # RISC-V -Cross-compile `executor_runner` for `riscv64-linux-gnu` and run it under -`qemu-user-static` against a small bundled program. The end-to-end check -mirrors the Arm Cortex-M e2e flow: a `Test_result: PASS` line in stdout from -the bundled-IO comparison path is the pass criterion. +End-to-end smoke tests that cross-compile ExecuTorch for RISC-V and run a bundled program under QEMU. A `Test_result: PASS` line emitted by the bundled-IO comparison path is the pass criterion. -This is the Phase 1 deliverable for the RISC-V Support RFC at -[pytorch/executorch#18991][rfc]. The cross-compile and runner artifacts -(toolchain file, preset, AOT script) are designed to carry over unchanged -to a hardware-runner job once one becomes available; only the invocation -step (qemu-user vs. native) would change. - -[rfc]: https://github.com/pytorch/executorch/issues/18991 +Part of the RISC-V Support RFC, [pytorch/executorch#18991](https://github.com/pytorch/executorch/issues/18991). 
## Quick start (Ubuntu / Debian) ```bash -examples/riscv/setup.sh # apt: gcc-riscv64-linux-gnu, qemu-user-static -examples/riscv/run.sh # export, cross-compile, run under qemu-user +examples/riscv/setup.sh # apt: gcc cross + qemu-user/qemu-system + picolibc +examples/riscv/run.sh # export, cross-compile, run under qemu ``` -The driver does three steps: +`run.sh` accepts: + +| Flag | Values | Default | Notes | +|---|---|---|---| +| `--model=` | `add`, `mv2`, `mobilebert`, `llama2`, `resnet18`, `yolo26` | `add` | which model to export | +| `--quantize` | flag | off | XNNPACK quantizer (requires `--backend=xnnpack`) | +| `--backend=` | `portable`, `xnnpack` | `portable` | xnnpack is linux-only | +| `--os=` | `linux`, `baremetal` | `linux` | qemu-user vs qemu-system + semihosting | +| `--arch=` | `rv64` | `rv64` | (rv32 follow-up; no `riscv32-linux-gnu` cross is packaged on Ubuntu) | +| `--qemu-cpu-ext=` | e.g. `v=true,vlen=128` | empty | extensions appended after the arch base | + +## Pipelines + +**linux**: `aot_riscv.py` → `cmake --preset riscv64-linux` → `executor_runner` under `qemu-riscv64`. Portable kernels + (optional) XNNPACK delegate. + +**baremetal**: `aot_riscv.py` → `cmake -S examples/riscv/baremetal` (standalone project; pulls executorch in via `add_subdirectory`) → `executor_runner_baremetal.elf` under `qemu-system-riscv64 -machine virt -bios none -semihosting-config target=native`. -1. `python examples/riscv/aot_riscv.py` exports a `torch.add` module to - `riscv_test/add_riscv.bpte` (a BundledProgram with reference outputs - embedded for two test cases). -2. `cmake --preset riscv64-linux` configures the cross-build using - `examples/riscv/riscv64-linux-gnu-toolchain.cmake` and - `tools/cmake/preset/riscv64_linux.cmake`. `executor_runner` is built - against portable kernels with `ET_BUNDLE_IO_ENABLED` defined. -3. `qemu-riscv64-static` invokes the runner with `--model_path` pointing at - the `.bpte`. 
The runner detects the bundle, runs every embedded test case, - and emits `Test_result: PASS` (or `FAIL`) per case. +The baremetal runner embeds the `.bpte` directly in `.rodata` via the same `examples/arm/executor_runner/pte_to_header.py` Cortex-M uses; semihosting SYS_WRITE0 / SYS_EXIT carry log output and exit status to the host. ## CI -`.github/workflows/_test_riscv_qemu.yml` is a reusable `workflow_call` -job (mirroring `_test_cortex_m_e2e.yml`) invoked from `pull.yml` to run on -every PR. It runs on the standard `linux.2xlarge` x86_64 runner using the -`executorch-ubuntu-22.04-gcc11` docker image. +`.github/workflows/riscv64.yml` is the entry point; it fans out into `_test_riscv.yml` over a `(model, backend, os, arch, quantize)` matrix and sweeps `qemu-cpu-ext` per backend. Runs on the `executorch-ubuntu-26.04-gcc15` docker image (needed for the `riscv64-unknown-elf` picolibc + libstdc++ packages - see [setup.sh](setup.sh)). diff --git a/examples/riscv/aot_riscv.py b/examples/riscv/aot_riscv.py index 529e2b1e767..e01fe6f954e 100644 --- a/examples/riscv/aot_riscv.py +++ b/examples/riscv/aot_riscv.py @@ -3,11 +3,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""AOT export for the RISC-V smoke test. +"""AOT export for the RISC-V smoke tests. -Exports a small model to a BundledProgram (.bpte) that the portable -executor_runner can load on a riscv64 target and verify against the embedded -reference output, emitting ``Test_result: PASS`` on success. +Exports the model selected by ``--model`` to a BundledProgram (.bpte) that +either ``executor_runner`` (linux) or ``executor_runner_baremetal`` (qemu +virt + semihosting) consumes. The bundled-IO comparison path inside the +runner emits ``Test_result: PASS`` per testset, which is what run.sh greps. 
""" import argparse @@ -114,12 +115,45 @@ def build_resnet18(): return model, example_inputs, test_inputs, False +def build_yolo26(): + # Mirrors examples/models/yolo26/export_and_validate.py: predict() once + # to materialise the predictor state Ultralytics expects pre-export. + import numpy as np + from ultralytics import YOLO + + input_h, input_w = 320, 320 + yolo = YOLO("yolo26n") + yolo.predict( + np.ones((input_h, input_w, 3)), + imgsz=(input_h, input_w), + device="cpu", + ) + + class Wrapper(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = yolo.model.to(torch.device("cpu")).eval() + + def forward(self, x): + # yolo.model emits (predictions, feature_maps) in eval; keep the + # predictions tensor so BundledIO sees a single tensor output. + out = self.model(x) + return out[0] if isinstance(out, (tuple, list)) else out + + model = Wrapper().eval() + torch.manual_seed(0) + example_inputs = (torch.randn(1, 3, input_h, input_w),) + test_inputs = [example_inputs] + return model, example_inputs, test_inputs, False + + MODELS = { "add": build_add, "mv2": build_mv2, "mobilebert": build_mobilebert, "llama2": build_llama2, "resnet18": build_resnet18, + "yolo26": build_yolo26, } @@ -138,9 +172,19 @@ def main() -> None: help="Output .bpte path (default: _riscv.bpte)", ) parser.add_argument( - "--xnnpack", - action="store_true", - help="Lower through the XNNPACK partitioner", + "--backend", + choices=("portable", "xnnpack"), + default="portable", + help="AOT backend: 'portable' runs everything on the portable kernels, " + "'xnnpack' adds the XNNPACK partitioner (default: portable)", + ) + parser.add_argument( + "--os", + choices=("linux", "baremetal"), + default="linux", + help="Target OS for the runner that will consume this .bpte. 
The .bpte " + "itself is OS-independent; the flag is logged so callers can verify " + "the AOT/runtime sides agree (default: linux)", ) parser.add_argument( "--quantize", @@ -154,6 +198,13 @@ def main() -> None: ) args = parser.parse_args() + if args.debug_xnnpack and args.backend != "xnnpack": + parser.error("--debug-xnnpack requires --backend=xnnpack") + + # xnnpack pulls in pthreads + dynamic loading; baremetal runner doesn't have those. + if args.os == "baremetal" and args.backend == "xnnpack": + parser.error("--backend=xnnpack is not supported on --os=baremetal") + if args.debug_xnnpack: logging.basicConfig(level=logging.DEBUG) @@ -176,7 +227,7 @@ def main() -> None: exported = export(model, example_inputs, strict=strict) partitioners = [] - if args.xnnpack: + if args.backend == "xnnpack": from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( XnnpackPartitioner, ) @@ -190,7 +241,9 @@ def main() -> None: compile_config = EdgeCompileConfig(_check_ir_validity=False) edge = to_edge_transform_and_lower( - exported, partitioner=partitioners, compile_config=compile_config + exported, + partitioner=partitioners, + compile_config=compile_config, ) delegated = sum( 1 @@ -198,7 +251,7 @@ def main() -> None: if n.op == "call_function" and "call_delegate" in str(n.target) ) print( - f"[aot_riscv] model={args.model} xnnpack={args.xnnpack} " + f"[aot_riscv] model={args.model} backend={args.backend} os={args.os} " f"quantize={args.quantize} delegated_nodes={delegated}" ) diff --git a/examples/riscv/baremetal/CMakeLists.txt b/examples/riscv/baremetal/CMakeLists.txt new file mode 100644 index 00000000000..b7765c4e3a1 --- /dev/null +++ b/examples/riscv/baremetal/CMakeLists.txt @@ -0,0 +1,117 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# Standalone runner project, invoked from examples/riscv/run.sh as:
+# ~~~
+# cmake -S examples/riscv/baremetal -B <build-dir> \
+#   -DEXECUTORCH_ROOT=<executorch-checkout> \
+#   -DRISCV_BAREMETAL_PTE=<path/to/model>.bpte \
+#   -DCMAKE_TOOLCHAIN_FILE=.../riscv{32,64}-unknown-elf-toolchain.cmake
+# ~~~
+# Mirrors examples/arm/executor_runner/standalone/CMakeLists.txt so the
+# top-level executorch CMake has no reference to examples/riscv/.
+
+cmake_minimum_required(VERSION 3.20)
+project(riscv_executor_runner_baremetal LANGUAGES C CXX ASM)
+
+# EXECUTORCH_ROOT defaults to three directories above this file; callers may
+# override it on the command line.
+get_filename_component(
+  _default_executorch_root "${CMAKE_CURRENT_LIST_DIR}/../../.." ABSOLUTE
+)
+if(NOT DEFINED EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT
+      "${_default_executorch_root}"
+      CACHE PATH "Path to the ExecuTorch checkout"
+  )
+endif()
+if(NOT EXISTS "${EXECUTORCH_ROOT}/CMakeLists.txt")
+  message(
+    FATAL_ERROR
+      "EXECUTORCH_ROOT (${EXECUTORCH_ROOT}) does not contain an ExecuTorch CMake project."
+  )
+endif()
+
+# The bundled program to embed is mandatory; there is no sensible default.
+set(RISCV_BAREMETAL_PTE
+    ""
+    CACHE FILEPATH "Path to the .bpte to embed in the baremetal runner"
+)
+if(NOT RISCV_BAREMETAL_PTE)
+  message(
+    FATAL_ERROR
+      "RISCV_BAREMETAL_PTE not set; pass -DRISCV_BAREMETAL_PTE=<path-to-bpte> from run.sh"
+  )
+endif()
+
+include("${EXECUTORCH_ROOT}/tools/cmake/common/preset.cmake")
+# NOTE(review): the default preset is the riscv64 one even when the rv32
+# toolchain file is selected (run.sh passes no override) -- confirm the preset
+# is XLEN-agnostic.
+if(NOT DEFINED EXECUTORCH_BUILD_PRESET_FILE)
+  set(EXECUTORCH_BUILD_PRESET_FILE
+      "${EXECUTORCH_ROOT}/tools/cmake/preset/riscv64_baremetal.cmake"
+      CACHE PATH "Preset used when configuring the standalone baremetal runner"
+  )
+endif()
+load_build_preset()
+include("${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake")
+
+add_subdirectory(
+  "${EXECUTORCH_ROOT}" "${CMAKE_BINARY_DIR}/executorch" EXCLUDE_FROM_ALL
+)
+
+find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+# Convert the .bpte into a C array header. Both the model and the converter
+# script are listed as DEPENDS so editing either one regenerates the header.
+set(_pte_header "${CMAKE_CURRENT_BINARY_DIR}/model_pte.h")
+set(_pte_to_header
+    "${EXECUTORCH_ROOT}/examples/arm/executor_runner/pte_to_header.py"
+)
+add_custom_command(
+  OUTPUT "${_pte_header}"
+  COMMAND
+    "${Python3_EXECUTABLE}" "${_pte_to_header}" --pte "${RISCV_BAREMETAL_PTE}"
+    --outdir "${CMAKE_CURRENT_BINARY_DIR}" --outfile "model_pte.h" --section
+    ".rodata.model_pte"
+  DEPENDS "${RISCV_BAREMETAL_PTE}" "${_pte_to_header}"
+  COMMENT "Embedding ${RISCV_BAREMETAL_PTE} into model_pte.h"
+  VERBATIM
+)
+
+# pte_to_header.py emits the byte array but not its length; the glue TU
+# materialises the matching `model_pte_len` and is the only place the header is
+# included (avoids a double-definition at link time). <stddef.h> supplies the
+# unqualified `size_t` the glue line uses.
+file(
+  WRITE "${CMAKE_CURRENT_BINARY_DIR}/model_pte_glue.cpp"
+  "#include <stddef.h>\n#include \"model_pte.h\"\nextern \"C\" const size_t model_pte_len = sizeof(model_pte);\n"
+)
+
+add_executable(
+  executor_runner_baremetal
+  start.S executor_runner_baremetal.cpp
+  "${CMAKE_CURRENT_BINARY_DIR}/model_pte_glue.cpp" "${_pte_header}"
+)
+# LINK_DEPENDS makes an edit to the linker script trigger a relink; -T alone
+# does not create that dependency.
+set(_linker_script "${CMAKE_CURRENT_SOURCE_DIR}/riscv_virt.ld")
+set_target_properties(
+  executor_runner_baremetal
+  PROPERTIES SUFFIX ".elf"
+             LINKER_LANGUAGE CXX
+             LINK_DEPENDS "${_linker_script}"
+)
+target_include_directories(
+  executor_runner_baremetal PRIVATE "${CMAKE_CURRENT_BINARY_DIR}"
+)
+# The C++-only switches are gated on the compile language so start.S (ASM) is
+# not handed -fno-rtti / -fno-exceptions.
+target_compile_options(
+  executor_runner_baremetal
+  PRIVATE -fdata-sections -ffunction-sections
+          "$<$<COMPILE_LANGUAGE:CXX>:-fno-exceptions;-fno-rtti>"
+)
+# --specs=picolibc.specs / -nostartfiles / -march / -mabi all come from the
+# toolchain file; only the linker script (QEMU virt memory map) is target-
+# specific here.
+target_link_options(executor_runner_baremetal PRIVATE "-T${_linker_script}")
+
+# gen_operators_lib / executorch_target_link_options_shared_lib attach INTERFACE
+# --whole-archive options to portable_ops_lib (so the static-init
+# kernel-registration TU survives DCE) and to executorch itself. Listing the
+# libs once each is enough; an extra --whole-archive wrapper around them would
+# include the same archive twice and double-register every op.
+target_link_libraries(executor_runner_baremetal PRIVATE bundled_program) +if(TARGET portable_ops_lib) + target_link_libraries(executor_runner_baremetal PRIVATE portable_ops_lib) +endif() +if(TARGET portable_kernels) + target_link_libraries(executor_runner_baremetal PRIVATE portable_kernels) +endif() diff --git a/examples/riscv/baremetal/executor_runner_baremetal.cpp b/examples/riscv/baremetal/executor_runner_baremetal.cpp new file mode 100644 index 00000000000..d0bb128bd98 --- /dev/null +++ b/examples/riscv/baremetal/executor_runner_baremetal.cpp @@ -0,0 +1,286 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Baremetal runner for qemu-system-riscv64 -machine virt + semihosting. Loads +// a .bpte embedded into the ELF and emits "TEST: BundleIO index[N] +// Test_result: PASS|FAIL" via ET_LOG so examples/riscv/run.sh's grep can +// detect success without a host filesystem. + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "semihosting.h" + +extern "C" const uint8_t model_pte[]; +extern "C" const size_t model_pte_len; + +using executorch::extension::BufferDataLoader; +using executorch::runtime::Error; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +namespace { + +// Pools are sized for the largest model we currently test (llama2 / yolo26) +// rather than per-model; the .bss grows but freestanding picolibc never +// allocates from it so the cost is just a bigger ELF. 
Bumping these requires +// matching headroom in riscv_virt.ld's RAM region and qemu's -m flag. +alignas(16) uint8_t method_allocator_pool[1u << 23]; // 8 MiB +alignas(16) uint8_t temp_allocator_pool[1u << 22]; // 4 MiB +alignas(16) uint8_t planned_memory_pool[1u << 26]; // 64 MiB + +constexpr size_t kMaxPlannedBuffers = 8; +constexpr double kRtol = 0.01; +constexpr double kAtol = 0.01; + +} // namespace + +extern "C" [[noreturn]] void baremetal_exit(int status) { + executorch::riscv::baremetal::semihost_exit(status); +} + +// picolibc's abort()/raise() resolve _exit; with our own start.S we don't +// link its crt0, so reroute it to the semihosting trap. +extern "C" [[noreturn]] void _exit(int status) { + executorch::riscv::baremetal::semihost_exit(status); +} + +// libstdc++'s drags std::random_device → getentropy/read. The portable +// rand kernels are never invoked at runtime for our bundled-IO tests, so a +// failing stub is enough to satisfy the link. +extern "C" int getentropy(void*, size_t) { + return -1; +} +extern "C" long read(int, void*, size_t) { + return -1; +} + +// Virtual destructors emit deleting variants that reference operator delete +// even when we never new/delete. Stubs satisfy the linker; never called. +void operator delete(void*) noexcept {} +void operator delete(void*, size_t) noexcept {} +void operator delete[](void*) noexcept {} +void operator delete[](void*, size_t) noexcept {} + +// op_rand / op_native_dropout / op_randn from portable_kernels reference +// std::random_device::_M_{init,getval,fini}, whose only definitions live in +// libstdc++.a's medlow-built random.o (won't relocate at 0x80000000). The +// bundled-IO smoke tests never invoke those ops, so satisfy the linker with +// no-op trampolines under the Itanium-mangled names. 
+asm(R"( + .globl _ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE + .type _ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE, @function +_ZNSt13random_device7_M_initERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE: + ret + + .globl _ZNSt13random_device9_M_getvalEv + .type _ZNSt13random_device9_M_getvalEv, @function +_ZNSt13random_device9_M_getvalEv: + li a0, 0 + ret + + .globl _ZNSt13random_device7_M_finiEv + .type _ZNSt13random_device7_M_finiEv, @function +_ZNSt13random_device7_M_finiEv: + ret +)"); + +// Route ET_LOG through semihosting. Messages aren't null-terminated; copy and +// append \n\0 before forwarding to SYS_WRITE0. +extern "C" void et_pal_emit_log_message( + et_timestamp_t, + et_pal_log_level_t, + const char*, + const char*, + size_t, + const char* message, + size_t length) { + // The bundle doesn't expose a testset count, so we probe past the end and + // rely on InvalidArgument to terminate the loop. The accompanying ET_LOG + // ("testset_idx N is out of range ...") is benign noise — suppress it so + // run.sh's PASS/FAIL grep stays clean. + static const char kOorPrefix[] = "testset_idx "; + if (length >= sizeof(kOorPrefix) - 1 && + std::memcmp(message, kOorPrefix, sizeof(kOorPrefix) - 1) == 0) { + return; + } + char buf[512]; + size_t n = length < sizeof(buf) - 2 ? 
length : sizeof(buf) - 2; + std::memcpy(buf, message, n); + buf[n] = '\n'; + buf[n + 1] = '\0'; + executorch::riscv::baremetal::semihost_write0(buf); +} + +extern "C" void et_pal_init(void) {} +extern "C" [[noreturn]] void et_pal_abort(void) { + executorch::riscv::baremetal::semihost_exit(1); +} +extern "C" et_timestamp_t et_pal_current_ticks(void) { + return 0; +} +extern "C" et_tick_ratio_t et_pal_ticks_to_ns_multiplier(void) { + return {1, 1}; +} +extern "C" void* et_pal_allocate(size_t) { + return nullptr; +} +extern "C" void et_pal_free(void*) {} + +int main() { + executorch::runtime::runtime_init(); + + const void* program_data = nullptr; + size_t program_size = 0; + Error status = executorch::bundled_program::get_program_data( + const_cast(model_pte), + model_pte_len, + &program_data, + &program_size); + if (status != Error::Ok) { + ET_LOG( + Error, "get_program_data failed: 0x%x", static_cast(status)); + return 1; + } + + BufferDataLoader loader(program_data, program_size); + Result program = Program::load(&loader); + if (!program.ok()) { + ET_LOG( + Error, + "Program::load failed: 0x%x", + static_cast(program.error())); + return 1; + } + + // The harness always exports a single "forward" method. Skipping the + // Result deref of program->get_method_name(0) sidesteps a + // codegen wedge we hit under -mcmodel=medany + picolibc. + const char* method_name = "forward"; + ET_LOG(Info, "Using method %s", method_name); + + Result method_meta = program->method_meta(method_name); + if (!method_meta.ok()) { + ET_LOG( + Error, + "method_meta failed: 0x%x", + static_cast(method_meta.error())); + return 1; + } + + MemoryAllocator method_allocator( + sizeof(method_allocator_pool), method_allocator_pool); + MemoryAllocator temp_allocator( + sizeof(temp_allocator_pool), temp_allocator_pool); + + // One span per planned buffer, bumped through a single .bss arena so we + // don't need a heap. 
kMaxPlannedBuffers / pool size both grow with bigger + // models; failures here are loud rather than silent. + Span planned_spans[kMaxPlannedBuffers]; + size_t num_planned = method_meta->num_memory_planned_buffers(); + if (num_planned > kMaxPlannedBuffers) { + ET_LOG( + Error, + "num_planned=%zu exceeds kMaxPlannedBuffers=%zu", + num_planned, + kMaxPlannedBuffers); + return 1; + } + size_t offset = 0; + for (size_t id = 0; id < num_planned; ++id) { + size_t sz = + static_cast(method_meta->memory_planned_buffer_size(id).get()); + sz = (sz + 15u) & ~15u; + if (offset + sz > sizeof(planned_memory_pool)) { + ET_LOG( + Error, + "planned buffer %zu (size %zu) overflows pool (%zu/%zu)", + id, + sz, + offset, + sizeof(planned_memory_pool)); + return 1; + } + planned_spans[id] = Span(planned_memory_pool + offset, sz); + offset += sz; + } + HierarchicalAllocator planned_memory( + Span>(planned_spans, num_planned)); + MemoryManager memory_manager( + &method_allocator, &planned_memory, &temp_allocator); + + Result method = program->load_method(method_name, &memory_manager); + if (!method.ok()) { + ET_LOG( + Error, + "load_method failed: 0x%x", + static_cast(method.error())); + return 1; + } + + // load_bundled_input returns InvalidArgument past the last testset; that's + // how we detect the loop terminator (the bundle has no public count API). 
+ int rc = 0; + for (size_t testset_idx = 0;; ++testset_idx) { + Error load = executorch::bundled_program::load_bundled_input( + *method, const_cast(model_pte), testset_idx); + if (load != Error::Ok) { + if (testset_idx == 0) { + ET_LOG( + Error, + "load_bundled_input failed for testset 0: 0x%x", + static_cast(load)); + rc = 1; + } + break; + } + Error exec = method->execute(); + if (exec != Error::Ok) { + ET_LOG( + Error, + "execute failed for testset %zu: 0x%x", + testset_idx, + static_cast(exec)); + ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx); + rc = 1; + continue; + } + Error verify = executorch::bundled_program::verify_method_outputs( + *method, const_cast(model_pte), testset_idx, kRtol, kAtol); + if (verify == Error::Ok) { + ET_LOG(Info, "TEST: BundleIO index[%zu] Test_result: PASS", testset_idx); + } else { + ET_LOG( + Error, + "verify_method_outputs failed for testset %zu: 0x%x", + testset_idx, + static_cast(verify)); + ET_LOG(Error, "TEST: BundleIO index[%zu] Test_result: FAIL", testset_idx); + rc = 1; + } + } + + return rc; +} diff --git a/examples/riscv/baremetal/riscv_virt.ld b/examples/riscv/baremetal/riscv_virt.ld new file mode 100644 index 00000000000..34980116b1d --- /dev/null +++ b/examples/riscv/baremetal/riscv_virt.ld @@ -0,0 +1,85 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* qemu-system-riscv{32,64} -machine virt -bios none -kernel: the virt board's + * reset stub at 0x1000 jumps to DRAM base 0x80000000, so _start has to live + * there. RAM size matches the qemu `-m 512M` we pass from run.sh — the + * embedded .bpte in .rodata can be tens of MB for mv2 / llama2 / yolo26. 
*/ + +OUTPUT_ARCH(riscv) +ENTRY(_start) + +MEMORY +{ + RAM (rwx) : ORIGIN = 0x80000000, LENGTH = 512M +} + +SECTIONS +{ + .text 0x80000000 : + { + KEEP(*(.text.boot)) + *(.text .text.*) + } > RAM + + .rodata : ALIGN(8) + { + *(.rodata .rodata.*) + *(.srodata .srodata.*) + } > RAM + + /* C++ global ctors. start.S calls picolibc's __libc_init_array, which + * walks symbols __bothinit_array_start..__bothinit_array_end (preinit + + * init combined). The stock newlib names (__init_array_start/end) are + * defined too for portability, but it's the "both" pair picolibc reads. */ + .bothinit_array : ALIGN(8) + { + PROVIDE_HIDDEN(__bothinit_array_start = .); + PROVIDE_HIDDEN(__preinit_array_start = .); + KEEP(*(.preinit_array)) + PROVIDE_HIDDEN(__preinit_array_end = .); + PROVIDE_HIDDEN(__init_array_start = .); + KEEP(*(SORT_BY_INIT_PRIORITY(.init_array.*) SORT_BY_INIT_PRIORITY(.ctors.*))) + KEEP(*(.init_array EXCLUDE_FILE(*crtbegin.o *crtbegin?.o *crtend.o *crtend?.o) .ctors)) + PROVIDE_HIDDEN(__init_array_end = .); + PROVIDE_HIDDEN(__bothinit_array_end = .); + } > RAM + .fini_array : ALIGN(8) + { + PROVIDE_HIDDEN(__fini_array_start = .); + KEEP(*(SORT_BY_INIT_PRIORITY(.fini_array.*) SORT_BY_INIT_PRIORITY(.dtors.*))) + KEEP(*(.fini_array EXCLUDE_FILE(*crtbegin.o *crtbegin?.o *crtend.o *crtend?.o) .dtors)) + PROVIDE_HIDDEN(__fini_array_end = .); + } > RAM + + .data : ALIGN(8) + { + *(.data .data.*) + *(.sdata .sdata.*) + } > RAM + + .bss : ALIGN(8) + { + _bss_start = .; + *(.bss .bss.*) + *(.sbss .sbss.*) + *(COMMON) + . = ALIGN(8); + _bss_end = .; + } > RAM + + /* 2 MiB stack at the high end of RAM; grows downward. picolibc's sbrk + * looks up __heap_start / __heap_end (double-underscore). */ + . = ALIGN(16); + PROVIDE(__heap_start = .); + . = ORIGIN(RAM) + LENGTH(RAM) - 2M; + PROVIDE(__heap_end = .); + . = . 
+ 2M; + _stack_top = .; + + /DISCARD/ : { *(.note.* .comment .eh_frame .riscv.attributes) } +} diff --git a/examples/riscv/baremetal/semihosting.h b/examples/riscv/baremetal/semihosting.h new file mode 100644 index 00000000000..7af63048d29 --- /dev/null +++ b/examples/riscv/baremetal/semihosting.h @@ -0,0 +1,51 @@ +/* + * Copyright 2026 The ExecuTorch Authors. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch { +namespace riscv { +namespace baremetal { + +// The RISC-V semihosting trigger is a fixed three-insn sequence (slli/ebreak/ +// srai of x0) so qemu can distinguish it from a normal ecall. Op number in +// a0, arg pointer in a1, return value back in a0. +inline long semihost_call(long op, const void* arg) { + register long a0 asm("a0") = op; + register long a1 asm("a1") = (long)arg; + asm volatile( + ".option push\n\t" + ".option norvc\n\t" + "slli x0, x0, 0x1f\n\t" + "ebreak\n\t" + "srai x0, x0, 0x7\n\t" + ".option pop" + : "+r"(a0) + : "r"(a1) + : "memory"); + return a0; +} + +constexpr long SYS_WRITE0 = 0x04; +constexpr long SYS_EXIT_EXTENDED = 0x20; + +inline void semihost_write0(const char* s) { + semihost_call(SYS_WRITE0, s); +} + +[[noreturn]] inline void semihost_exit(int status) { + // ADP_Stopped_ApplicationExit (0x20026) + status, per the semihosting spec. + long block[2] = {0x20026, (long)status}; + semihost_call(SYS_EXIT_EXTENDED, block); + __builtin_trap(); +} + +} // namespace baremetal +} // namespace riscv +} // namespace executorch diff --git a/examples/riscv/baremetal/start.S b/examples/riscv/baremetal/start.S new file mode 100644 index 00000000000..092eeffa4a6 --- /dev/null +++ b/examples/riscv/baremetal/start.S @@ -0,0 +1,49 @@ +/* + * Copyright 2026 The ExecuTorch Authors. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Boot stub for the qemu virt RISC-V baremetal runner: set sp, enable FPU, +// zero .bss, run C++ static ctors via __libc_init_array, jump to main. On +// return, call baremetal_exit so qemu terminates deterministically. + +#if __riscv_xlen == 64 +#define SX sd +#define XLEN_BYTES 8 +#else +#define SX sw +#define XLEN_BYTES 4 +#endif + + .section .text.boot, "ax" + .globl _start + .type _start, @function +_start: + la sp, _stack_top + + // mstatus.FS resets to Off in M-mode, so any FP insn (libstdc++ template + // code emits fsd/fld) traps. We have no trap vector, so the CPU would + // loop on the fault. FS=Dirty (0b11 in bits 13-14) keeps the FPU live. + li t0, 0x6000 + csrs mstatus, t0 + + la a0, _bss_start + la a1, _bss_end +1: + bgeu a0, a1, 2f + SX zero, 0(a0) + addi a0, a0, XLEN_BYTES + j 1b +2: + call __libc_init_array + li a0, 0 + li a1, 0 + call main + call baremetal_exit +3: + wfi + j 3b + + .size _start, .-_start diff --git a/examples/riscv/requirements.txt b/examples/riscv/requirements.txt index 273e7156a1d..649696ae65c 100644 --- a/examples/riscv/requirements.txt +++ b/examples/riscv/requirements.txt @@ -1,2 +1,3 @@ torchvision transformers +ultralytics diff --git a/examples/riscv/riscv32-unknown-elf-toolchain.cmake b/examples/riscv/riscv32-unknown-elf-toolchain.cmake new file mode 100644 index 00000000000..ae968ea6fe2 --- /dev/null +++ b/examples/riscv/riscv32-unknown-elf-toolchain.cmake @@ -0,0 +1,74 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# rv32 baremetal cross-toolchain. Uses the multilib-aware riscv64-unknown-elf +# gcc (one package, both XLENs); `-march=rv32...` + `-mabi=ilp32d` selects the +# 32-bit picolibc + libstdc++ variant. 
+# ELF runs under qemu-system-riscv32 -machine virt with semihosting.
+
+set(CMAKE_SYSTEM_NAME Generic)
+set(CMAKE_SYSTEM_PROCESSOR riscv32)
+
+# All tools come from the riscv64-unknown-elf package; -march/-mabi below pick
+# the 32-bit multilib.
+set(CMAKE_C_COMPILER
+    "riscv64-unknown-elf-gcc"
+    CACHE FILEPATH ""
+)
+set(CMAKE_CXX_COMPILER
+    "riscv64-unknown-elf-g++"
+    CACHE FILEPATH ""
+)
+set(CMAKE_ASM_COMPILER
+    "riscv64-unknown-elf-gcc"
+    CACHE FILEPATH ""
+)
+set(CMAKE_AR
+    "riscv64-unknown-elf-ar"
+    CACHE FILEPATH ""
+)
+set(CMAKE_RANLIB
+    "riscv64-unknown-elf-ranlib"
+    CACHE FILEPATH ""
+)
+set(CMAKE_STRIP
+    "riscv64-unknown-elf-strip"
+    CACHE FILEPATH ""
+)
+
+set(CMAKE_EXECUTABLE_SUFFIX ".elf")
+# try_compile() can't link without crt0/specs; archive-only sidesteps that.
+set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
+set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
+set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_CXX_STANDARD 17)
+
+# Baseline rv32imafdc / ilp32d — the rv32gc-equivalent multilib Ubuntu's
+# picolibc + libstdc++ ship. (Unlike rv64, the full rv32gc multilib *is*
+# packaged, so we don't have to drop M / C here.) -mcmodel=medany because medlow
+# can't reach our 0x80000000 base. picolibc.specs must be on the compile line
+# too so libstdc++ headers find picolibc's C headers via the spec's sysroot.
+# The C++-only switches are gated on COMPILE_LANGUAGE so C and ASM sources in
+# the same build do not receive them.
+add_compile_options(
+  --specs=picolibc.specs
+  -march=rv32imafdc
+  -mabi=ilp32d
+  -mcmodel=medany
+  -fdata-sections
+  -ffunction-sections
+  "$<$<COMPILE_LANGUAGE:CXX>:-fno-rtti;-fno-exceptions;-fno-unwind-tables>"
+)
+# -nostdlib++ drops g++'s implicit libstdc++.a (medlow-built, won't relocate).
+# -nostartfiles drops picolibc's crt0 in favour of our start.S.
+add_link_options( + --specs=picolibc.specs + -march=rv32imafdc + -mabi=ilp32d + -mcmodel=medany + -nostdlib++ + -nostartfiles + "LINKER:--gc-sections" +) diff --git a/examples/riscv/riscv64-unknown-elf-toolchain.cmake b/examples/riscv/riscv64-unknown-elf-toolchain.cmake new file mode 100644 index 00000000000..a4533675f89 --- /dev/null +++ b/examples/riscv/riscv64-unknown-elf-toolchain.cmake @@ -0,0 +1,77 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# rv64 baremetal cross-toolchain (Ubuntu 26.04+ packages: +# gcc-riscv64-unknown-elf, picolibc-riscv64-unknown-elf, +# libstdc++-riscv64-unknown-elf-picolibc). The resulting ELF runs under +# qemu-system-riscv64 -machine virt with semihosting. + +set(CMAKE_SYSTEM_NAME Generic) +set(CMAKE_SYSTEM_PROCESSOR riscv64) + +set(CMAKE_C_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_CXX_COMPILER + "riscv64-unknown-elf-g++" + CACHE FILEPATH "" +) +set(CMAKE_ASM_COMPILER + "riscv64-unknown-elf-gcc" + CACHE FILEPATH "" +) +set(CMAKE_AR + "riscv64-unknown-elf-ar" + CACHE FILEPATH "" +) +set(CMAKE_RANLIB + "riscv64-unknown-elf-ranlib" + CACHE FILEPATH "" +) +set(CMAKE_STRIP + "riscv64-unknown-elf-strip" + CACHE FILEPATH "" +) + +set(CMAKE_EXECUTABLE_SUFFIX ".elf") +# try_compile() can't link without crt0/specs; archive-only sidesteps that. +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) + +# Picked baseline: rv64iafd / lp64d. Ubuntu's picolibc + libstdc++ packages +# don't ship the rv64gc (= rv64imafdc) multilib, so this drops M (integer mul) +# and C (compressed) but keeps double-float. 
+# -mcmodel=medany because medlow's signed-32-bit-around-0 reach can't address
+# our 0x80000000 base. --specs=picolibc.specs has to appear at *compile* time
+# too: libstdc++'s C-compatibility headers need picolibc's C headers via the
+# spec's sysroot. The C++-only switches are gated on COMPILE_LANGUAGE so C and
+# ASM sources in the same build do not receive them.
+add_compile_options(
+  --specs=picolibc.specs
+  -march=rv64iafd
+  -mabi=lp64d
+  -mcmodel=medany
+  -fdata-sections
+  -ffunction-sections
+  "$<$<COMPILE_LANGUAGE:CXX>:-fno-rtti;-fno-exceptions;-fno-unwind-tables>"
+)
+# -nostdlib++ drops g++'s implicit libstdc++.a (medlow-built, won't relocate at
+# 0x80000000); we only use its templates, no runtime calls. -nostartfiles drops
+# picolibc's crt0 in favour of our start.S.
+add_link_options(
+  --specs=picolibc.specs
+  -march=rv64iafd
+  -mabi=lp64d
+  -mcmodel=medany
+  -nostdlib++
+  -nostartfiles
+  "LINKER:--gc-sections"
+)
diff --git a/examples/riscv/run.sh b/examples/riscv/run.sh
index 2c207816bfc..b1294a2a2f6 100755
--- a/examples/riscv/run.sh
+++ b/examples/riscv/run.sh
@@ -4,11 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# RISC-V Phase 1 smoke test driver (pytorch/executorch#18991):
-#   1. Export a tiny model to a BundledProgram (.bpte) on the x86_64 host.
-#   2. Cross-compile executor_runner for riscv64 Linux glibc.
-#   3. Invoke the runner under qemu-user-static and grep its stdout for the
-#      Test_result: PASS marker emitted by the bundled-IO comparison path.
+# RISC-V smoke test driver:
+#   1. Export a small model to a BundledProgram (.bpte) on the host.
+#   2. Cross-compile a riscv32/64 runner (linux glibc or baremetal).
+#   3. Invoke under qemu and grep stdout for the Test_result: PASS marker.
set -eu @@ -16,30 +15,43 @@ script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) et_root_dir=$(realpath "${script_dir}/../..") build_only=false -build_dir="${et_root_dir}/cmake-out-riscv" +build_dir= output_dir="${et_root_dir}/riscv_test" -qemu="qemu-riscv64-static" -qemu_timeout="600" +qemu_timeout="1800" model="add" -xnnpack=false +backend="portable" +os="linux" +arch="rv64" +qemu_cpu_ext="" quantize=false debug_xnnpack=false verbose_xnnpack=false +qemu_override="" usage() { cat < Which model to export and run (default: ${model}) - --xnnpack Enable the XNNPACK backend (AOT partitioner + runtime) --quantize Produce an 8-bit quantized model - --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch at runtime + --backend= AOT backend (default: ${backend}): + - 'portable': portable kernels only + - 'xnnpack': XNNPACK delegate (linux only) + --os= Target OS (default: ${os}): + - 'linux': glibc, qemu-user + - 'baremetal': no OS, qemu-system + semihosting + --arch= Target arch (default: ${arch}): + - 'rv64': riscv64 + - 'rv32': riscv32 + --qemu-cpu-ext= QEMU -cpu extensions appended after the arch base + (e.g. 'v=true,vlen=128'); no rv32/rv64 prefix. 
+ --verbose-xnnpack Build XNNPACK with XNN_LOG_LEVEL=4 to log microkernel dispatch --debug-xnnpack Enable XNNPACK partitioner DEBUG logging and dump the lowered graph --build_only Only export and cross-compile; do not invoke QEMU - --build_dir= CMake build directory (default: ${build_dir}) + --build_dir= CMake build directory (default: cmake-out-riscv/-) --output_dir= Directory for the exported .bpte (default: ${output_dir}) - --qemu= qemu-user binary (default: ${qemu}) - --timeout= Maximum QEMU runtime; matches run_fvp.sh --timelimit (default: ${qemu_timeout}) + --qemu= Override qemu binary + --timeout= Maximum QEMU runtime (default: ${qemu_timeout}) -h, --help Show this help EOF } @@ -47,51 +59,128 @@ EOF for arg in "$@"; do case $arg in --model=*) model="${arg#*=}" ;; - --xnnpack) xnnpack=true ;; --quantize) quantize=true ;; + --backend=*) backend="${arg#*=}" ;; + --os=*) os="${arg#*=}" ;; + --arch=*) arch="${arg#*=}" ;; + --qemu-cpu-ext=*) qemu_cpu_ext="${arg#*=}" ;; --debug-xnnpack) debug_xnnpack=true ;; --verbose-xnnpack) verbose_xnnpack=true ;; --build_only) build_only=true ;; --build_dir=*) build_dir="${arg#*=}" ;; --output_dir=*) output_dir="${arg#*=}" ;; - --qemu=*) qemu="${arg#*=}" ;; + --qemu=*) qemu_override="${arg#*=}" ;; --timeout=*) qemu_timeout="${arg#*=}" ;; -h|--help) usage; exit 0 ;; *) echo "Unknown option: $arg" >&2; usage; exit 1 ;; esac done +case "${backend}" in + portable|xnnpack) ;; + *) echo "Unknown backend: ${backend}" >&2; usage; exit 1 ;; +esac +case "${os}" in + linux|baremetal) ;; + *) echo "Unknown os: ${os}" >&2; usage; exit 1 ;; +esac +case "${arch}" in + rv32|rv64) ;; + *) echo "Unknown arch: ${arch}" >&2; usage; exit 1 ;; +esac + +# xnnpack needs pthreads + dynamic loading: baremetal has neither, and the +# Ubuntu xnnpack microkernels don't ship an rv32 build. 
+if [[ "${backend}" == "xnnpack" && "${os}" == "baremetal" ]]; then + echo "[run.sh] --backend=xnnpack requires --os=linux" >&2 + exit 1 +fi +if [[ "${backend}" == "xnnpack" && "${arch}" == "rv32" ]]; then + echo "[run.sh] --backend=xnnpack requires --arch=rv64" >&2 + exit 1 +fi +# Ubuntu doesn't package a riscv32-linux-gnu cross (riscv64-linux-gnu has no +# rv32 multilib either), so rv32 linux is blocked on a custom toolchain build. +if [[ "${arch}" == "rv32" && "${os}" == "linux" ]]; then + echo "[run.sh] --arch=rv32 --os=linux not supported: no riscv32-linux-gnu toolchain on Ubuntu" >&2 + exit 1 +fi + +if ${debug_xnnpack} && [[ "${backend}" != "xnnpack" ]]; then + echo "[run.sh] --debug-xnnpack requires --backend=xnnpack" >&2 + exit 1 +fi +if ${verbose_xnnpack} && [[ "${backend}" != "xnnpack" ]]; then + echo "[run.sh] --verbose-xnnpack requires --backend=xnnpack" >&2 + exit 1 +fi + +build_dir="${build_dir:-${et_root_dir}/cmake-out-riscv/${os}-${arch}}" +mkdir -p "${build_dir}" + mkdir -p "${output_dir}" -bpte_path="${output_dir}/${model}_riscv.bpte" +bpte_path="${output_dir}/${model}$(test "${quantize}" = "true" && echo "_q" || echo "")_${backend}_${os}_${arch}_riscv.bpte" -echo "[run.sh] Step 1/3: AOT export on host" +echo "[run.sh] Step 1/3: AOT export on host (backend=${backend} os=${os} arch=${arch})" aot_extra_args=() -if ${xnnpack}; then - aot_extra_args+=(--xnnpack) -fi if ${quantize}; then aot_extra_args+=(--quantize) fi if ${debug_xnnpack}; then aot_extra_args+=(--debug-xnnpack) fi -python "${script_dir}/aot_riscv.py" --model "${model}" "${aot_extra_args[@]}" --output "${bpte_path}" +python "${script_dir}/aot_riscv.py" --model "${model}" --backend "${backend}" --os "${os}" "${aot_extra_args[@]}" --output "${bpte_path}" -echo "[run.sh] Step 2/3: cross-compile executor_runner for riscv64-linux" +echo "[run.sh] Step 2/3: cross-compile executor_runner for ${arch}-${os}" cmake_extra_args=() -if ${xnnpack}; then +if [[ "${backend}" == "xnnpack" ]]; then 
cmake_extra_args+=(-DEXECUTORCH_BUILD_XNNPACK=ON) fi if ${verbose_xnnpack}; then cmake_extra_args+=(-DEXECUTORCH_XNNPACK_LOG_LEVEL=4 -DEXECUTORCH_BUILD_RISCV_ETDUMP=ON) fi -cmake -S "${et_root_dir}" -B "${build_dir}" \ - --preset riscv64-linux \ - "${cmake_extra_args[@]}" \ - -DCMAKE_BUILD_TYPE=Release -cmake --build "${build_dir}" -j"$(nproc)" --target executor_runner -runner="${build_dir}/executor_runner" +# Map our short arch (rv32/rv64) to the canonical riscv32/riscv64 prefix used +# by the cross toolchain and qemu binary names. +case "${arch}" in + rv32) arch_long="riscv32" ;; + rv64) arch_long="riscv64" ;; +esac + +if [[ "${os}" == "linux" ]]; then + build_target="executor_runner" + qemu_default="qemu-${arch_long}" + # --fresh re-runs configure with a clean cache (object files stay cached) + # so preset edits always take effect across iterations. + cmake -S "${et_root_dir}" -B "${build_dir}" --fresh \ + --preset "${arch_long}-linux" \ + "${cmake_extra_args[@]}" \ + -DCMAKE_BUILD_TYPE=Release + cmake --build "${build_dir}" -j"$(nproc)" --target "${build_target}" + runner="${build_dir}/${build_target}" + +elif [[ "${os}" == "baremetal" ]]; then + build_target="executor_runner_baremetal" + qemu_default="qemu-system-${arch_long}" + # Standalone build (mirrors examples/arm/executor_runner/standalone): the + # runner pulls executorch in via add_subdirectory, so the main + # CMakeLists.txt doesn't reference examples/riscv/. 
+ cmake -S "${et_root_dir}/examples/riscv/baremetal" -B "${build_dir}" --fresh \ + -DEXECUTORCH_ROOT="${et_root_dir}" \ + -DCMAKE_TOOLCHAIN_FILE="${et_root_dir}/examples/riscv/${arch_long}-unknown-elf-toolchain.cmake" \ + -DRISCV_BAREMETAL_PTE="${bpte_path}" \ + "${cmake_extra_args[@]}" \ + -DCMAKE_BUILD_TYPE=Release + cmake --build "${build_dir}" -j"$(nproc)" --target "${build_target}" + runner="${build_dir}/${build_target}.elf" + +else + echo "Unknown os: ${os}" >&2 + usage + exit 1 +fi + +qemu="${qemu_override:-${qemu_default}}" [[ -x "${runner}" ]] || { echo "[run.sh] runner not found at ${runner}" >&2; exit 1; } if file "${runner}" | grep -q "RISC-V"; then @@ -113,45 +202,75 @@ hash "${qemu}" 2>/dev/null || { exit 1 } -# QEMU_LD_PREFIX points qemu-user at the riscv64 sysroot so the dynamic -# linker (ld-linux-riscv64-lp64d.so.1) referenced in the ELF resolves. -export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv64-linux-gnu}" +log_file="${output_dir}/${model}$(test "${quantize}" = "true" && echo "_q" || echo "")_${backend}_${os}_${arch}_riscv.log" +rm -f "${log_file}" -if [[ -n "${QEMU_CPU+x}" ]]; then - echo "[run.sh] QEMU_CPU=${QEMU_CPU}" +# Compose the QEMU -cpu value once: ${arch} alone, or ${arch},${ext} when an +# extension list was supplied. qemu-user reads $QEMU_CPU; qemu-system takes +# -cpu on the command line. +qemu_cpu="${arch}" +if [[ -n "${qemu_cpu_ext}" ]]; then + qemu_cpu="${arch},${qemu_cpu_ext}" fi +echo "[run.sh] qemu -cpu = ${qemu_cpu}" -runner_extra_args=() -if ${quantize}; then - runner_extra_args+=(--bundleio_rtol=0.1 --bundleio_atol=0.25) -fi -etdump_path="" -if ${verbose_xnnpack}; then - etdump_path="${output_dir}/${model}_riscv.etdump" - rm -f "${etdump_path}" - runner_extra_args+=(--etdump_path="${etdump_path}") -fi +if [[ "${os}" == "linux" ]]; then + # QEMU_LD_PREFIX points qemu-user at the cross sysroot so the dynamic + # linker (ld-linux-riscv*) referenced in the ELF resolves. 
+ if [[ "${arch}" == "rv64" ]]; then + export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv64-linux-gnu}" + else + export QEMU_LD_PREFIX="${QEMU_LD_PREFIX:-/usr/riscv32-linux-gnu}" + fi + export QEMU_CPU="${qemu_cpu}" -# etdump_summary.py reads the XNN_LOG_LEVEL=4 registrations. -log_file="${output_dir}/${model}_riscv.run.log" -rm -f "${log_file}" + runner_extra_args=() + if ${quantize}; then + runner_extra_args+=(--bundleio_rtol=0.1 --bundleio_atol=0.25) + fi + etdump_path="" + if ${verbose_xnnpack}; then + etdump_path="${output_dir}/${model}$(test "${quantize}" = "true" && echo "_q" || echo "")_${backend}_${os}_${arch}_riscv.etdump" + rm -f "${etdump_path}" + runner_extra_args+=(--etdump_path="${etdump_path}") + fi -set +e -timeout --signal=KILL "${qemu_timeout}" "${qemu}" "${runner}" \ - --model_path="${bpte_path}" \ - "${runner_extra_args[@]}" \ - 2>&1 | tee "${log_file}" -qemu_status=${PIPESTATUS[0]} -set -e + set +e + timeout --signal=KILL "${qemu_timeout}" "${qemu}" "${runner}" \ + --model_path="${bpte_path}" \ + "${runner_extra_args[@]}" \ + |& tee "${log_file}" + qemu_status=${PIPESTATUS[0]} + set -e -echo "[run.sh] qemu exit status: ${qemu_status}" + if [[ -n "${etdump_path}" && -f "${etdump_path}" ]]; then + python "${script_dir}/etdump_summary.py" "${etdump_path}" \ + --run-log "${log_file}" \ + --json "${etdump_path}.json" || true + fi + +elif [[ "${os}" == "baremetal" ]]; then + # qemu-system -machine virt boots at 0x80000000; -bios none skips OpenSBI; + # semihosting target=native routes SYS_WRITE0/SYS_EXIT to host stdio. 
+ # For deeper debugging, add: -accel tcg,one-insn-per-tb=on -d in_asm,nochain + # -D + set +e + timeout --signal=KILL "${qemu_timeout}" "${qemu}" \ + -machine virt -cpu "${qemu_cpu}" -m 512M -nographic -bios none \ + -semihosting-config enable=on,target=native \ + -kernel "${runner}" \ + |& tee "${log_file}" + qemu_status=${PIPESTATUS[0]} + set -e -if [[ -n "${etdump_path}" && -f "${etdump_path}" ]]; then - python "${script_dir}/etdump_summary.py" "${etdump_path}" \ - --run-log "${log_file}" \ - --json "${etdump_path}.json" || true +else + echo "Unknown os: ${os}" >&2 + usage + exit 1 fi +echo "[run.sh] qemu exit status: ${qemu_status}" + if grep -q "Test_result: PASS" "${log_file}"; then echo "[run.sh] Bundled I/O check PASSED" exit 0 diff --git a/examples/riscv/setup.sh b/examples/riscv/setup.sh index 955c8ca3386..b6e7b097b86 100755 --- a/examples/riscv/setup.sh +++ b/examples/riscv/setup.sh @@ -4,9 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# Install host tooling needed for the RISC-V Phase 1.0 smoke test: -# - gcc/g++/binutils for riscv64-linux-gnu (cross-compiler + sysroot) -# - qemu-user-static (qemu-riscv64 user-mode emulator) +# Host tooling for the RISC-V smoke tests. Targets Ubuntu 26.04: that's where +# libstdc++-riscv64-unknown-elf-picolibc was first packaged, and the baremetal +# build chain needs C++ stdlib headers paired with picolibc. 
set -eu @@ -25,23 +25,24 @@ fi ${SUDO} apt-get update ${SUDO} apt-get install -y --no-install-recommends \ build-essential \ - gcc${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ - g++${GCC_VERSION:+-${GCC_VERSION}}-riscv64-linux-gnu \ + gcc-riscv64-linux-gnu \ + g++-riscv64-linux-gnu \ binutils-riscv64-linux-gnu \ libc6-riscv64-cross \ libc6-dev-riscv64-cross \ + gcc-riscv64-unknown-elf \ + picolibc-riscv64-unknown-elf \ + libstdc++-riscv64-unknown-elf-picolibc \ cmake \ file \ ca-certificates \ - qemu-user-static - -if [[ -n "${GCC_VERSION+x}" ]]; then - ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-gcc riscv64-linux-gnu-gcc /usr/bin/riscv64-linux-gnu-gcc${GCC_VERSION:+-${GCC_VERSION}} 100 - ${SUDO} update-alternatives --install /usr/bin/riscv64-linux-gnu-g++ riscv64-linux-gnu-g++ /usr/bin/riscv64-linux-gnu-g++${GCC_VERSION:+-${GCC_VERSION}} 100 -fi + qemu-user \ + qemu-system-riscv \ + libglib2.0-0t64 \ + libxcb1 \ + libgl1 riscv64-linux-gnu-gcc --version | head -n1 -qemu-riscv64-static --version | head -n1 +qemu-riscv64 --version | head -n1 -# Some python packages also need to be installed pip install -r "${script_dir}/requirements.txt" diff --git a/exir/BUCK b/exir/BUCK index f00b3f1c787..d70900c02ae 100644 --- a/exir/BUCK +++ b/exir/BUCK @@ -259,6 +259,16 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "_program_utils", + srcs = [ + "_program_utils.py", + ], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = runtime.python_library, name = "pass_manager", srcs = [ @@ -266,7 +276,9 @@ fbcode_target(_kind = runtime.python_library, ], deps = [ "fbsource//third-party/pypi/typing-extensions:typing-extensions", + ":_program_utils", ":error", + ":pass_base", "//caffe2:torch", ], ) diff --git a/exir/_program_utils.py b/exir/_program_utils.py new file mode 100644 index 00000000000..d0d2039d93a --- /dev/null +++ b/exir/_program_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta 
Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +from torch.export.exported_program import ( + ConstantArgument, + ExportGraphSignature, + InputSpec, + OutputSpec, +) + + +def _get_updated_range_constraints(gm): + def get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode # type: ignore[21] + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + shape_env = get_shape_env(gm) + if shape_env is None: + return {} + range_constraints = { + shape_env.replacements.get(k, k): v for k, v in shape_env.var_to_range.items() + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. [2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements: + range_constraints[k] = v + return range_constraints + + +def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, +) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. + """ + new_input_specs = [] + i = 0 + for node in new_gm.graph.nodes: + if node.op != "placeholder": + continue + + assert i < len( + old_signature.input_specs + ), "Number of inputs changed after transformation" + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance(old_input_spec.arg, ConstantArgument) + # pyre-fixme[20]: Argument `class_fqn` expected. 
+ else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + persistent=old_input_spec.persistent, + ) + ) + i += 1 + + output_node = new_gm.graph.output_node() + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len( + old_signature.output_specs + ), "Number of outputs changed after transformation" + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance(old_output_spec.arg, ConstantArgument) + # pyre-fixme[20]: Argument `class_fqn` expected. + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature diff --git a/exir/pass_base.py b/exir/pass_base.py index 8ab0c675240..f93dd75d156 100644 --- a/exir/pass_base.py +++ b/exir/pass_base.py @@ -6,10 +6,11 @@ # LICENSE file in the root directory of this source tree. 
# pyre-strict - import operator import traceback +from abc import ABC, abstractmethod from contextlib import nullcontext +from dataclasses import dataclass from typing import ( Any, Callable, @@ -27,9 +28,7 @@ import torch from executorch.exir import memory - from executorch.exir.delegate import executorch_call_delegate, is_lowered_module - from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.error import ExportError, ExportErrorType from torch import fx @@ -37,6 +36,7 @@ from torch._subclasses import FakeTensorMode, UnsupportedFakeTensorException from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode +from torch.export import ExportedProgram from torch.fx import traceback as fx_traceback from torch.fx.experimental.proxy_tensor import PythonKeyTracer from torch.fx.graph import CodeGen @@ -182,6 +182,58 @@ class ExportPassBaseError(RuntimeError): pass +@dataclass(frozen=True) +class ExportedProgramPassResult: + exported_program: ExportedProgram + modified: bool + + +class ExportedProgramPassBase(ABC): + """ + Base interface for implementing passes that operate on ExportedProgram. + """ + + def __call__(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(exported_program) + res = self.call(exported_program) + self.ensures(exported_program) + return res + + @abstractmethod + def call(self, exported_program: ExportedProgram) -> ExportedProgramPassResult: + """ + The pass that is run through the given exported program. To implement a + pass, it is required to implement this function. 
+ + Args: + exported_program: The exported program we will run a pass on + """ + + def requires(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given exported program contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + exported_program: The exported program we will run checks on + """ + + def ensures(self, exported_program: ExportedProgram) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given exported program contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + exported_program: The exported program we will run checks on + """ + + class _ExportPassBase(PassBase): """ Interpreter-based pass class to help users maintain the IR spec while writing diff --git a/exir/pass_manager.py b/exir/pass_manager.py index b812ccea7b8..351e98651dd 100644 --- a/exir/pass_manager.py +++ b/exir/pass_manager.py @@ -5,28 +5,46 @@ # LICENSE file in the root directory of this source tree. 
# pyre-strict - -from typing import Callable, List, Optional, Union +import copy +import inspect +import logging +from typing import Callable, List, Optional, Type, TypeAlias, Union import torch import torch.fx.passes.infra.pass_manager as fx import torch.utils._pytree as pytree +from executorch.exir._program_utils import ( + _get_updated_graph_signature, + _get_updated_range_constraints, +) from executorch.exir.error import ExportError, ExportErrorType +from executorch.exir.pass_base import ExportedProgramPassBase, ExportedProgramPassResult +from torch._export.verifier import Verifier +from torch.export import ExportedProgram from torch.fx.passes.infra.pass_base import PassResult -from typing_extensions import TypeAlias +from torch.fx.passes.infra.pass_manager import pass_result_wrapper + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +PassType: TypeAlias = Union[ + ExportedProgramPassBase, Callable[[torch.fx.GraphModule], Optional[PassResult]] +] + -PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]] +def _get_pass_name(fn: PassType) -> str: + """Returns a human-readable name for a pass.""" + return fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ class PassManager(fx.PassManager): """ - Class to run multiple passes on a given graph module. The PassManager is - callable so to run it, we can just call the PassManager instance. + Runs multiple passes on a GraphModule. - Private Attributes: - * **passes**: A list of callable passes - * **params**: An instance of PassManagerParams containing the result of the - flags set in the constructor. + This is the legacy PassManager that extends torch.fx.passes.infra.pass_manager.PassManager. + Use this when you need to run passes on a GraphModule directly. + + For running passes on ExportedProgram, use ExportedProgramPassManager instead. 
""" def __init__( @@ -34,14 +52,11 @@ def __init__( passes: Optional[Union[List[PassType], List[List[PassType]]]] = None, run_checks_after_each_pass: bool = False, suppress_check_failures: bool = False, + steps: int = 1, ) -> None: - r""" - Args: - passes: A list of passes - enable_debug_pass: set to true to enable the debug passes - run_checks_after_each_pass: whether to run checks and linting after each pass - """ - + logger.warning( + "PassManager is deprecated. Please use ExportedProgramPassManager instead." + ) # Flatten the passes to a list of callables passes = passes if passes else [] flattened_passes = [ @@ -52,6 +67,7 @@ def __init__( flattened_passes, run_checks_after_each_pass=run_checks_after_each_pass, suppress_check_failures=suppress_check_failures, + steps=steps, ) def check(self, module: torch.nn.Module) -> None: @@ -65,10 +81,9 @@ def check(self, module: torch.nn.Module) -> None: node's spec field is a tuple) - Ensure that the graph module has type torch.fx.GraphModule """ - assert isinstance(module, fx.GraphModule) + assert isinstance(module, torch.fx.GraphModule) module.recompile() module.graph.lint() - # TODO(qihan): use verifier.check_is_exir for node in module.graph.nodes: if node.op == "call_method": @@ -76,3 +91,151 @@ def check(self, module: torch.nn.Module) -> None: ExportErrorType.NOT_SUPPORTED, f"call_method `{node}` is not supported except for backend delegate.", ) + + +class ExportedProgramPassManager(fx.PassManager): + """ + Runs multiple passes on an ExportedProgram. + + This PassManager is specifically designed for ExportedProgram and supports + both GraphModule-only passes and ExportedProgram-aware passes. + + For running passes on GraphModule directly, use PassManager instead. 
+ """ + + def __init__( + self, + passes: Optional[Union[List[PassType], List[List[PassType]]]] = None, + constraints: Optional[List[Callable[[Callable, Callable], bool]]] = None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + steps: int = 1, + ) -> None: + wrapped_passes = ( + [ + ( + fn + if isinstance(fn, ExportedProgramPassBase) + else pass_result_wrapper(fn) + ) + for fn in pytree.tree_flatten(passes)[0] + ] + if passes + else [] + ) + + super().__init__( + wrapped_passes, + constraints=constraints, + run_checks_after_each_pass=run_checks_after_each_pass, + suppress_check_failures=suppress_check_failures, + steps=steps, + ) + + def check(self, exported_program: ExportedProgram) -> None: + """Validates graph module invariants.""" + graph_module = exported_program.graph_module + graph_module.recompile() + graph_module.graph.lint() + + for node in graph_module.graph.nodes: + if node.op == "call_method": + raise ExportError( + ExportErrorType.NOT_SUPPORTED, + f"call_method `{node}` is not supported except for backend delegate.", + ) + + exported_program.validate() + + # pyre-ignore[14]: Intentionally overriding with different signature for ExportedProgram + def __call__( # noqa: C901 + self, + exported_program: ExportedProgram, + override_verifiers: Optional[list[Type[Verifier]]] = None, + ) -> ExportedProgramPassResult: + """ + Runs passes on an ExportedProgram. + + Handles both GraphModule-only passes and ExportedProgram-aware passes. Will create a shallow copy of the exported program before running passes. + + Args: + exported_program: The exported program to transform. + + Returns: + ExportedProgramPassResult containing the transformed program. 
+ """ + if not self._validated: + self.solve_constraints() + + exported_program = copy.copy(exported_program) + + if override_verifiers: + exported_program._verifiers = override_verifiers + + self.check(exported_program) + + overall_modified = False + + for _ in range(self.steps): + step_modified = False + + for i, fn in enumerate(self.passes): + pass_modified = False + try: + if not isinstance(fn, ExportedProgramPassBase): + res = fn(exported_program.graph_module) + if res is None: + raise TypeError( + f"The result of pass {_get_pass_name(fn)} should be type PassResult. " + "Please wrap it with pass_result_wrapper()" + ) + + if res.modified: + # Not running _update_exported_program_graph_module here because it is + # possible that the verifier will fail upon new ExportedProgram construction, + # and we should only run verification after each pass if + # run_checks_after_each_pass is True. + res.graph_module.recompile() + exported_program._graph_module = res.graph_module + exported_program._graph_signature = ( + _get_updated_graph_signature( + exported_program.graph_signature, + res.graph_module, + ) + ) + exported_program._range_constraints = ( + _get_updated_range_constraints(res.graph_module) + ) + pass_modified = True + + else: + assert isinstance(fn, ExportedProgramPassBase) + ep_res = fn(exported_program) + exported_program = ep_res.exported_program + + if ep_res.modified: + pass_modified = True + exported_program.graph_module.recompile() + + if self.run_checks_after_each_pass: + self.check(exported_program) + + if pass_modified: + step_modified = True + logger.debug( + "Graph after pass '%s': %s", + _get_pass_name(fn), + exported_program.graph_module.graph, + ) + + except Exception as e: + prev_names = [_get_pass_name(p) for p in self.passes[:i]] + msg = f"An error occurred when running the '{_get_pass_name(fn)}' pass after the following passes: {prev_names}" + raise Exception(msg) from e # noqa: TRY002 + + overall_modified = overall_modified or 
step_modified + if not step_modified: + break + + self.check(exported_program) + return ExportedProgramPassResult(exported_program, overall_modified) diff --git a/exir/passes/BUCK b/exir/passes/BUCK index 954f1cfdb4f..4647388b388 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -381,6 +381,14 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "device_copy_ops_registry", + srcs = ["_device_copy_ops_registry.py"], + deps = [ + "//caffe2:torch", + ], +) + fbcode_target(_kind = runtime.python_library, name = "memory_format_ops_pass", srcs = [ diff --git a/exir/passes/_device_copy_ops_registry.py b/exir/passes/_device_copy_ops_registry.py new file mode 100644 index 00000000000..a62b88d4234 --- /dev/null +++ b/exir/passes/_device_copy_ops_registry.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Registry for device copy ops used to insert explicit H2D (host-to-device) +and D2H (device-to-host) data transfer operations at delegate boundaries. + +These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning +is True, making the graph functional by explicitly transferring data between +CPU and device memory. + +Follows the same registration pattern as dim_order_ops_registry.py. +""" + +import torch +from torch.library import impl, Library + +lib = Library("et_copy", "DEF") + +# _h2d_copy: copies a CPU tensor to device memory. +# At tracing time, this is a clone (both on CPU). At runtime, the out tensor +# is memory-planned on device, and the kernel calls +# DeviceAllocator::copy_host_to_device. +lib.define("_h2d_copy(Tensor self) -> Tensor") +lib.define("_h2d_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + +# _d2h_copy: copies a device tensor to CPU memory. 
+# At tracing time, this is a clone (both on CPU). At runtime, the self tensor +# has device memory, and the kernel calls DeviceAllocator::copy_device_to_host. +lib.define("_d2h_copy(Tensor self) -> Tensor") +lib.define("_d2h_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)") + + +@impl(lib, "_h2d_copy", "CompositeImplicitAutograd") +def _h2d_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_h2d_copy.out", "CompositeImplicitAutograd") +def _h2d_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out + + +@impl(lib, "_d2h_copy", "CompositeImplicitAutograd") +def _d2h_copy_impl(self: torch.Tensor) -> torch.Tensor: + # During tracing, both tensors are on CPU. Just clone to represent the transfer. + return self.clone() + + +@impl(lib, "_d2h_copy.out", "CompositeImplicitAutograd") +def _d2h_copy_out_impl(self: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(self) + return out diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 9adbf65dd90..73f943e55e0 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -11,6 +11,7 @@ import torch from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature @@ -18,6 +19,14 @@ from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree +# register llama.fallback (optional — only needed for QNN/llama sharding paths) +try: + import executorch.extension.llm.custom_ops.op_fallback # noqa: F401 + + _llama_fallback_default = exir_ops.edge.llama.fallback.default +except (ImportError, AttributeError): + _llama_fallback_default = None + # 
pyre-ignore def make_spec(x): @@ -75,9 +84,9 @@ def get_spec(x): elif node.op == "call_function" and node.target == operator.getitem: value_spec = pytree.tree_map(get_spec, node.args[0]) node.meta["spec"] = value_spec[node.args[1]] - elif ( - node.op == "call_function" - and node.target == executorch_call_delegate + elif node.op == "call_function" and node.target in ( + executorch_call_delegate, + _llama_fallback_default, ): # Note: We currently rely on delegate node specs not being regenerated, # as the spec is set somewhat manually when adding the call delegate node. diff --git a/exir/program/BUCK b/exir/program/BUCK index 7d9642efdb7..11f62edd99e 100644 --- a/exir/program/BUCK +++ b/exir/program/BUCK @@ -22,6 +22,7 @@ fbcode_target(_kind = runtime.python_library, ], deps = [ "//caffe2:torch", + "//executorch/exir:_program_utils", "//executorch/exir:error", "//executorch/exir:graph_module", "//executorch/exir:pass_base", diff --git a/exir/program/_program.py b/exir/program/_program.py index b3d94c8ffd7..485d72bbe45 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -5,8 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe - +# pyre-strict import copy import io import logging @@ -38,7 +37,8 @@ from executorch.exir.operator.convert import _pybind_schema_to_native_schema from executorch.exir.operator.util import _QUANT_PRIMITIVES from executorch.exir.pass_base import PassBase -from executorch.exir.pass_manager import PassType +from executorch.exir.pass_manager import ExportedProgramPassManager, PassType + from executorch.exir.passes import ( base_post_op_replace_passes, base_pre_op_replace_passes, @@ -88,17 +88,11 @@ from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) -from torch.export.exported_program import ( - ConstantArgument, - ExportGraphSignature, - InputKind, - InputSpec, - OutputSpec, - TensorArgument, -) +from torch.export.exported_program import InputKind, InputSpec, TensorArgument from torch.fx import _pytree as fx_pytree from torch.fx._compatibility import compatibility -from torch.fx.passes.infra.pass_manager import PassManager +from torch.fx.passes.infra.pass_manager import PassManager as GraphModulePassManager + from torch.utils import _pytree as pytree Val = Any @@ -131,93 +125,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: transform_op_to_aten_op = {} -def _get_updated_range_constraints(gm): - def get_shape_env(gm): - vals = [ - node.meta["val"] - for node in gm.graph.nodes - if node.meta.get("val", None) is not None - ] - from torch._guards import detect_fake_mode # type: ignore[21] - - fake_mode = detect_fake_mode(vals) - if fake_mode is not None: - return fake_mode.shape_env - for v in vals: - if isinstance(v, torch.SymInt): - return v.node.shape_env - - shape_env = get_shape_env(gm) - if shape_env is None: - return {} - range_constraints = { - shape_env.replacements.get(k, k): v for k, v in shape_env.var_to_range.items() - } - # Only when we have an unbacked symint, and it's used as constructor inputs, - # runtime_var_to_range will make a difference compated to var_to_range. - # e.g. 
[2, oo) -> [0, oo) - for k, v in shape_env.var_to_range.items(): - if k not in shape_env.replacements: - range_constraints[k] = v - return range_constraints - - -def _get_updated_graph_signature( - old_signature: ExportGraphSignature, - new_gm: torch.fx.GraphModule, -) -> ExportGraphSignature: - """ - Update the graph signature's user_input/user_outputs. - """ - new_input_specs = [] - i = 0 - for node in new_gm.graph.nodes: - if node.op != "placeholder": - continue - - assert i < len( - old_signature.input_specs - ), "Number of inputs changed after transformation" - old_input_spec = old_signature.input_specs[i] - arg = ( - old_input_spec.arg - if isinstance(old_input_spec.arg, ConstantArgument) - # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_input_spec.arg)(node.name) - ) - new_input_specs.append( - InputSpec( - old_input_spec.kind, - arg, - old_input_spec.target, - persistent=old_input_spec.persistent, - ) - ) - i += 1 - - output_node = new_gm.graph.output_node() - assert output_node.op == "output" - - new_output_specs = [] - for i, node in enumerate(output_node.args[0]): - assert i < len( - old_signature.output_specs - ), "Number of outputs changed after transformation" - old_output_spec = old_signature.output_specs[i] - arg = ( - old_output_spec.arg - if isinstance(old_output_spec.arg, ConstantArgument) - # pyre-fixme[20]: Argument `class_fqn` expected. - else type(old_output_spec.arg)(node.name) - ) - new_output_specs.append( - OutputSpec(old_output_spec.kind, arg, old_output_spec.target) - ) - - new_signature = ExportGraphSignature( - input_specs=new_input_specs, output_specs=new_output_specs - ) - return new_signature +from executorch.exir._program_utils import ( # noqa: E402 + _get_updated_graph_signature, + _get_updated_range_constraints, +) def _transform( @@ -243,13 +154,13 @@ def _transform( ), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. 
Got: {list(passes)}" return _transform_with_pass_manager( - self, PassManager(list(passes)), override_verifiers + self, ExportedProgramPassManager(list(passes)), override_verifiers ) def _transform_with_pass_manager( - self, - pass_manager: PassManager, + self: ExportedProgram, + pass_manager: Union[ExportedProgramPassManager, GraphModulePassManager], override_verifiers: None | list[Type[Verifier]] = None, ) -> "ExportedProgram": """ @@ -258,22 +169,26 @@ def _transform_with_pass_manager( Args: self: The ExportedProgram instance to transform pass_manager: An instance of PassManager to apply transformations. + - ExportedProgramPassManager: operates on the full ExportedProgram + - GraphModulePassManager: operates on the GraphModule only override_verifiers: Optional list of verifier classes to use instead of the default verifiers. This is needed if the transforms yields illegal graph that the default verifier cannot handle. Returns: ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made """ - res = pass_manager(self.graph_module) - transformed_gm = res.graph_module if res is not None else self.graph_module - assert transformed_gm is not None - - if transformed_gm is self.graph_module and not res.modified: - return self - - return _update_exported_program_graph_module( - self, transformed_gm, override_verifiers - ) + if isinstance(pass_manager, ExportedProgramPassManager): + res = pass_manager(self, override_verifiers) + if not res.modified: + return self + return res.exported_program + else: + res = pass_manager(self.graph_module) + if not res.modified: + return self + return _update_exported_program_graph_module( + self, res.graph_module, override_verifiers + ) def _update_exported_program_graph_module( @@ -1324,7 +1239,12 @@ def collect_named_data_store_outputs( def to_edge_transform_and_lower( # noqa: C901 programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ - 
Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager] + Union[ + Sequence[PassType], + Dict[str, Sequence[PassType]], + GraphModulePassManager, + ExportedProgramPassManager, + ] ] = None, partitioner: Optional[ Union[List[Partitioner], Dict[str, List[Partitioner]]] @@ -1359,7 +1279,7 @@ def to_edge_transform_and_lower( # noqa: C901 2) a dictionary - only method names specified in the dictionary will be transformed with their corresponding passes - 3) an instance of a PassManager - + 3) an instance of a PassManager (either a GraphModulePassManager or an ExportedProgramPassManager) - all methods in the given EdgeProgramManager will be transformed with the given PassManager instance. @@ -1604,7 +1524,12 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram: @et_logger("transform") def transform( self, - passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager], + passes: Union[ + Sequence[PassType], + Dict[str, Sequence[PassType]], + ExportedProgramPassManager, + GraphModulePassManager, + ], compile_config: Optional[EdgeCompileConfig] = None, ) -> "EdgeProgramManager": """ @@ -1618,7 +1543,7 @@ def transform( 2) a dictionary mapping method names to lists of passes - only method names specified in the dictionary will be transformed with their corresponding passes. - 3) a PassManager instance - + 3) a PassManager (either ExportedProgramPassManager or GraphModulePassManager) instance - all methods in the given EdgeProgramManager will be transformed with the given PassManager instance. compile_config: Compile config to use for veriy the correctness of model @@ -1637,13 +1562,15 @@ def transform( # Cast passes parameter upfront. 
passes_seq: Optional[Sequence[PassType]] = None passes_dict: Optional[Dict[str, Sequence[PassType]]] = None - pass_manager: Optional[PassManager] = None + pass_manager: Optional[ + Union[ExportedProgramPassManager, GraphModulePassManager] + ] = None if isinstance(passes, Sequence): passes_seq = passes if isinstance(passes, dict): passes_dict = passes - if isinstance(passes, PassManager): + if isinstance(passes, (ExportedProgramPassManager, GraphModulePassManager)): pass_manager = passes for name, program in self._edge_programs.items(): diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 322f72c870a..21493a69644 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -504,3 +504,14 @@ python_unittest( "//executorch/exir/passes:propagate_device_pass", ], ) + +python_unittest( + name = "device_copy_ops", + srcs = [ + "test_device_copy_ops.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir/passes:device_copy_ops_registry", + ], +) diff --git a/exir/tests/test_device_copy_ops.py b/exir/tests/test_device_copy_ops.py new file mode 100644 index 00000000000..805159d9d81 --- /dev/null +++ b/exir/tests/test_device_copy_ops.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +# Import the registry to register the ops +import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + +import torch + + +class DeviceCopyOpsRegistryTest(unittest.TestCase): + """Tests that et_copy._h2d_copy and et_copy._d2h_copy ops are correctly + registered and produce expected outputs during tracing (CPU-only).""" + + def test_h2d_copy_functional(self): + """_h2d_copy should return a clone of the input tensor.""" + x = torch.randn(2, 3) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + # Should be a new tensor, not the same object + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_d2h_copy_functional(self): + """_d2h_copy should return a clone of the input tensor.""" + x = torch.randn(4, 5) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + self.assertTrue(torch.equal(result, x)) + self.assertFalse(result.data_ptr() == x.data_ptr()) + + def test_h2d_copy_out_variant(self): + """_h2d_copy.out should copy data into the provided out tensor.""" + x = torch.randn(3, 3) + out = torch.empty(3, 3) + result = torch.ops.et_copy._h2d_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_d2h_copy_out_variant(self): + """_d2h_copy.out should copy data into the provided out tensor.""" + x = torch.randn(2, 4) + out = torch.empty(2, 4) + result = torch.ops.et_copy._d2h_copy.out(x, out=out) + self.assertTrue(result is out) + self.assertTrue(torch.equal(out, x)) + + def test_h2d_copy_preserves_dtype(self): + """_h2d_copy should work with various dtypes.""" + for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: + x = torch.ones(2, 2, dtype=dtype) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.dtype, dtype) + self.assertTrue(torch.equal(result, x)) + 
+ def test_h2d_copy_scalar_tensor(self): + """_h2d_copy should handle 0-dim tensors.""" + x = torch.tensor(3.14) + result = torch.ops.et_copy._h2d_copy(x) + self.assertEqual(result.shape, torch.Size([])) + self.assertTrue(torch.equal(result, x)) + + def test_d2h_copy_empty_tensor(self): + """_d2h_copy should handle empty tensors.""" + x = torch.empty(0, 3) + result = torch.ops.et_copy._d2h_copy(x) + self.assertEqual(result.shape, torch.Size([0, 3])) diff --git a/exir/tests/test_pass_infra.py b/exir/tests/test_pass_infra.py index ded3c0e849d..7df6b76b93a 100644 --- a/exir/tests/test_pass_infra.py +++ b/exir/tests/test_pass_infra.py @@ -9,14 +9,22 @@ import unittest +import executorch.exir as exir import torch -from executorch.exir import to_edge -from executorch.exir.pass_base import ExportPassBaseError, ProxyValue -from executorch.exir.pass_manager import PassManager +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ( + ExportedProgramPassBase, + ExportedProgramPassResult, + ExportPassBaseError, + ProxyValue, +) +from executorch.exir.pass_manager import ExportedProgramPassManager, PassManager from executorch.exir.passes import ScalarToTensorPass from executorch.exir.passes.pass_registry import PassRegistry -from torch.export import Dim, export -from torch.fx.passes.infra.pass_base import PassBase +from executorch.exir.program import to_edge +from torch.export import Dim, export, ExportedProgram +from torch.export.graph_signature import InputKind, InputSpec, TensorArgument +from torch.fx.passes.infra.pass_base import PassBase, PassResult class TestPassInfra(unittest.TestCase): @@ -216,3 +224,228 @@ def test_rejects_implicit_symbolic_scalar_coercions(self) -> None: with self.assertRaisesRegex(ExportPassBaseError, "converted to float"): float(ProxyValue(sym_float, torch.fx.Graph().placeholder("x"))) + + +class TestExportedProgramPassManager(unittest.TestCase): + def test_runs_graph_module_passes_on_exported_program(self) 
-> None: + """ + Tests that ExportedProgramPassManager runs GraphModule passes + on an ExportedProgram and the graph is correctly modified. + """ + + def replace_add_with_mul(gm: torch.fx.GraphModule) -> PassResult: + modified = False + for node in gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ): + node.target = exir_ops.edge.aten.mul.Tensor + modified = True + return PassResult(gm, modified) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + z = torch.add(y, x) + return z + + exported_program = ( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + pm = ExportedProgramPassManager(passes=[replace_add_with_mul]) + result = pm(exported_program) + + # Verify return type + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Check that all add ops were replaced with mul + self.assertEqual( + len( + result.exported_program.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.add.Tensor + ) + ), + 0, + ) + + def test_updates_constants_on_exported_program(self) -> None: + """ + Tests that ExportedProgramPassManager can update constants + in the ExportedProgram using an ExportedProgram-aware pass. 
+ """ + + class DoubleConstantsPass(ExportedProgramPassBase): + """Pass that doubles all constant tensor values in the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + modified = False + for key, const in ep.constants.items(): + if isinstance(const, torch.Tensor): + ep.constants[key] = const * 2 + modified = True + return ExportedProgramPassResult(ep, modified) + + class ModuleWithConstant(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.ones(3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.weight + + module = ModuleWithConstant() + exported_program = to_edge( + torch.export.export(module, (torch.randn(3),)) + ).exported_program() + + # Verify there are constants in the ExportedProgram + self.assertGreater( + len(exported_program.constants), 0, "Expected constants in ExportedProgram" + ) + + # Store original constant values + original_values = { + key: const.clone() + for key, const in exported_program.constants.items() + if isinstance(const, torch.Tensor) + } + + pm = ExportedProgramPassManager(passes=[DoubleConstantsPass()]) + result = pm(exported_program) + + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Verify constants were doubled + for key, original_const in original_values.items(): + new_const = result.exported_program.constants[key] + self.assertTrue( + torch.allclose(new_const, original_const * 2), + f"Constant {key} was not doubled correctly", + ) + + def test_adds_constant_to_exported_program(self) -> None: + """ + Tests that ExportedProgramPassManager can add a new constant + to the ExportedProgram, including updating the graph and input specs. 
+ """ + + class AddConstantPass(ExportedProgramPassBase): + """Pass that adds a new constant tensor to the ExportedProgram.""" + + def call(self, ep: ExportedProgram) -> ExportedProgramPassResult: + graph = ep.graph_module.graph + sig = ep.graph_signature + + # Find the first user input to insert before it + placeholders = graph.find_nodes(op="placeholder") + assert len(placeholders) == 1 + user_input_node = placeholders[0] + + # Create a new constant tensor + new_constant_name = "_test_added_constant" + new_constant_tensor = torch.tensor([1.0, 2.0, 3.0]) + + # Add placeholder node for the new constant + with graph.inserting_before(user_input_node): + new_placeholder = graph.placeholder(new_constant_name) + # Set up meta for the new placeholder + new_placeholder.meta["val"] = new_constant_tensor + + # Add the constant to the constants dict + ep.constants[new_constant_name] = new_constant_tensor + + # Update input specs to include the new constant + new_input_spec = InputSpec( + kind=InputKind.CONSTANT_TENSOR, + arg=TensorArgument(name=new_placeholder.name), + target=new_constant_name, + persistent=False, + ) + sig.input_specs = (new_input_spec, sig.input_specs[0]) + + return ExportedProgramPassResult(ep, modified=True) + + class IdentityModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + exported_program = to_edge( + torch.export.export(IdentityModule(), (torch.randn(3),)) + ).exported_program() + assert len(exported_program.constants) == 0 + assert len(exported_program.graph_signature.input_specs) == 1 + + pm = ExportedProgramPassManager(passes=[AddConstantPass()]) + result = pm(exported_program) + + self.assertIsInstance(result, ExportedProgramPassResult) + self.assertTrue(result.modified) + + # Verify the new constant was added to constants dict + self.assertEqual(len(result.exported_program.constants), 1) + self.assertIn("_test_added_constant", 
result.exported_program.constants) + self.assertTrue( + torch.allclose( + result.exported_program.constants["_test_added_constant"], + torch.tensor([1.0, 2.0, 3.0]), + ) + ) + + # Verify input_specs was updated + self.assertEqual( + len(result.exported_program.graph_signature.input_specs), + 2, + ) + + # Verify the new placeholder exists in the graph + placeholder_names = [ + node.target + for node in result.exported_program.graph_module.graph.find_nodes( + op="placeholder" + ) + ] + self.assertTrue(len(placeholder_names) == 2) + + # Verify the new input spec has the correct kind + new_spec = None + for spec in result.exported_program.graph_signature.input_specs: + if spec.target == "_test_added_constant": + new_spec = spec + break + self.assertIsNotNone(new_spec) + self.assertEqual(new_spec.kind, InputKind.CONSTANT_TENSOR) + + def test_invalid_pass_creates_call_method(self) -> None: + """ + Tests that ExportedProgramPassManager detects invalid passes + that introduce call_method nodes. 
+ """ + + def introduce_call_method(gm: torch.fx.GraphModule) -> PassResult: + node = list(gm.graph.nodes)[-2] + with gm.graph.inserting_after(node): + gm.graph.call_method("torch.ops.relu", (torch.randn(2),)) + return PassResult(gm, True) + + def f(x: torch.Tensor) -> torch.Tensor: + y = torch.add(x, x) + return y + + exported_program = ( + exir.capture(f, (torch.randn(10),), exir.CaptureConfig()) + .to_edge() + .exported_program + ) + + pm = ExportedProgramPassManager( + passes=[introduce_call_method], run_checks_after_each_pass=True + ) + + with self.assertRaisesRegex(Exception, "call_method"): + pm(exported_program) diff --git a/extension/android/BUCK b/extension/android/BUCK index c7e275805e2..bae5579b2a8 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -8,17 +8,19 @@ non_fbcode_target(_kind = fb_android_library, warnings_as_errors = False, required_for_source_only_abi = True, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/DType.java", + "executorch_android/src/main/java/org/pytorch/executorch/DType.kt", "executorch_android/src/main/java/org/pytorch/executorch/EValue.java", "executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java", "executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java", - "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", + "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", - "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", + "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", + pure_kotlin = False, + extra_kotlinc_arguments = ["-Xjvm-default=all"], deps = [ "//fbandroid/java/com/facebook/jni:jni", 
"//fbandroid/libraries/soloader/java/com/facebook/soloader/nativeloader:nativeloader", @@ -47,13 +49,14 @@ non_fbcode_target(_kind = fb_android_library, name = "executorch_llama", warnings_as_errors = False, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", + extra_kotlinc_arguments = ["-Xjvm-default=all"], deps = [ ":executorch", "//fbandroid/java/com/facebook/jni:jni", diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index 3ee5b5877b3..2dbe0e1fb5f 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -51,6 +51,7 @@ android { } kotlinOptions { jvmTarget = "11" + freeCompilerArgs += ["-Xjvm-default=all"] } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt similarity index 77% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt index 3aca4871d64..a58baa34b60 100644 --- 
a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.kt @@ -6,17 +6,17 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch; +package org.pytorch.executorch -import org.pytorch.executorch.annotations.Experimental; +import org.pytorch.executorch.annotations.Experimental /** * Codes representing tensor data types. * - *

Warning: These APIs are experimental and subject to change without notice + * Warning: These APIs are experimental and subject to change without notice */ @Experimental -public enum DType { +enum class DType(@JvmField val jniCode: Int) { // NOTE: "jniCode" must be kept in sync with scalar_type.h. // NOTE: Never serialize "jniCode", because it can change between releases. @@ -68,18 +68,10 @@ public enum DType { BITS16(22), ; - final int jniCode; - - DType(int jniCode) { - this.jniCode = jniCode; - } - - public static DType fromJniCode(int jniCode) { - for (DType dtype : values()) { - if (dtype.jniCode == jniCode) { - return dtype; - } - } - throw new IllegalArgumentException("No DType found for jniCode " + jniCode); + companion object { + @JvmStatic + fun fromJniCode(jniCode: Int): DType = + entries.find { it.jniCode == jniCode } + ?: throw IllegalArgumentException("No DType found for jniCode $jniCode") } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java deleted file mode 100644 index a46b27ab39e..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -/** Immutable metadata for a method in a Module. 
*/ -public class MethodMetadata { - private final String mName; - private final String[] mBackends; - - MethodMetadata(String name, String[] backends) { - mName = name; - mBackends = backends; - } - - /** - * @return Method name - */ - public String getName() { - return mName; - } - - /** - * @return Backends used for this method - */ - public String[] getBackends() { - return mBackends; - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt new file mode 100644 index 00000000000..2f25f32c92f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.kt @@ -0,0 +1,12 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch + +/** Immutable metadata for a method in a Module. */ +class MethodMetadata internal constructor(val name: String, val backends: Array<String>) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt similarity index 68% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt index f5f36fc56da..42a5980d6ba 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.kt @@ -6,13 +6,13 @@ * LICENSE file in the root directory of this source tree.
*/ -package org.pytorch.executorch.annotations; +package org.pytorch.executorch.annotations /** * This annotation indicates that an API is experimental and may change or be removed at any time. * It does not provide any guarantees for API stability or backward-compatibility. * - *

This status is not permanent, and APIs marked with this annotation will need to be either made + * This status is not permanent, and APIs marked with this annotation will need to be either made * more robust or removed in the future. */ -public @interface Experimental {} +@Retention(AnnotationRetention.BINARY) annotation class Experimental diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java deleted file mode 100644 index 2173a04c69d..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/annotations/package-info.java +++ /dev/null @@ -1,2 +0,0 @@ -/** Annotations used by ExecuTorch Android Java/JNI package. */ -package org.pytorch.executorch.annotations; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt similarity index 53% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt index 4e834d06721..3b56986bf14 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt @@ -6,45 +6,42 @@ * LICENSE file in the root directory of this source tree. 
*/ -package org.pytorch.executorch.extension.llm; +package org.pytorch.executorch.extension.llm -import com.facebook.jni.annotations.DoNotStrip; -import org.pytorch.executorch.annotations.Experimental; +import com.facebook.jni.annotations.DoNotStrip +import org.pytorch.executorch.annotations.Experimental /** - * Callback interface for Llama model. Users can implement this interface to receive the generated + * Callback interface for Llm model. Users can implement this interface to receive the generated * tokens and statistics. * - *

Warning: These APIs are experimental and subject to change without notice + * Warning: These APIs are experimental and subject to change without notice */ @Experimental -public interface LlmCallback { +interface LlmCallback { /** * Called when a new result is available from JNI. Users will keep getting onResult() invocations * until generate() finishes. * * @param result Last generated token */ - @DoNotStrip - public void onResult(String result); + @DoNotStrip fun onResult(result: String) /** * Called when the statistics for the generate() is available. * - *

The result will be a JSON string. See extension/llm/stats.h for the field definitions. + * The result will be a JSON string. See extension/llm/stats.h for the field definitions. * * @param stats JSON string containing the statistics for the generate() */ - @DoNotStrip - default void onStats(String stats) {} + @DoNotStrip fun onStats(stats: String) {} /** * Called when an error occurs during generate(). * - * @param errorCode Error code from the ExecuTorch runtime (see {@link - * org.pytorch.executorch.ExecutorchRuntimeException}) + * @param errorCode Error code from the ExecuTorch runtime (see + * [org.pytorch.executorch.ExecutorchRuntimeException]) * @param message Human-readable error description */ - @DoNotStrip - default void onError(int errorCode, String message) {} + @DoNotStrip fun onError(errorCode: Int, message: String) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java deleted file mode 100644 index db7941aadad..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for controlling text generation parameters in LLM operations. - * - *

This class provides settings for text generation behavior including output formatting, - * generation limits, and sampling parameters. Instances should be created using the {@link - * #create()} method and the fluent builder pattern. - */ -public class LlmGenerationConfig { - private final boolean echo; - private final int maxNewTokens; - private final boolean warming; - private final int seqLen; - private final float temperature; - private final int numBos; - private final int numEos; - - private LlmGenerationConfig(Builder builder) { - this.echo = builder.echo; - this.maxNewTokens = builder.maxNewTokens; - this.warming = builder.warming; - this.seqLen = builder.seqLen; - this.temperature = builder.temperature; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - } - - /** - * Creates a new Builder instance for constructing generation configurations. - * - * @return a new Builder with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - /** - * @return true if input prompt should be included in the output - */ - public boolean isEcho() { - return echo; - } - - /** - * @return maximum number of tokens to generate (-1 for unlimited) - */ - public int getMaxNewTokens() { - return maxNewTokens; - } - - /** - * @return true if model warming is enabled - */ - public boolean isWarming() { - return warming; - } - - /** - * @return maximum sequence length for generation (-1 for default) - */ - public int getSeqLen() { - return seqLen; - } - - /** - * @return temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * Builder class for constructing LlmGenerationConfig instances. - * - *

Provides a fluent interface for configuring generation parameters with sensible defaults. - * All methods return the builder instance to enable method chaining. - */ - public static class Builder { - private boolean echo = true; - private int maxNewTokens = -1; - private boolean warming = false; - private int seqLen = -1; - private float temperature = 0.8f; - private int numBos = 0; - private int numEos = 0; - - Builder() {} - - /** - * Sets whether to include the input prompt in the generated output. - * - * @param echo true to include input prompt, false to return only new tokens - * @return this builder instance - */ - public Builder echo(boolean echo) { - this.echo = echo; - return this; - } - - /** - * Sets the maximum number of new tokens to generate. - * - * @param maxNewTokens the token limit (-1 for unlimited generation) - * @return this builder instance - */ - public Builder maxNewTokens(int maxNewTokens) { - this.maxNewTokens = maxNewTokens; - return this; - } - - /** - * Enables or disables model warming. - * - * @param warming true to generate initial tokens for model warmup - * @return this builder instance - */ - public Builder warming(boolean warming) { - this.warming = warming; - return this; - } - - /** - * Sets the maximum sequence length for generation. - * - * @param seqLen maximum sequence length (-1 for default behavior) - * @return this builder instance - */ - public Builder seqLen(int seqLen) { - this.seqLen = seqLen; - return this; - } - - /** - * Sets the temperature for random sampling. - * - * @param temperature sampling temperature (typical range 0.0-1.0) - * @return this builder instance - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. 
- * - * @param numBos number of BOS tokens - * @return this builder instance - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return this builder instance - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Constructs the LlmGenerationConfig instance with the configured parameters. - * - * @return new LlmGenerationConfig instance with current builder settings - */ - public LlmGenerationConfig build() { - return new LlmGenerationConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt new file mode 100644 index 00000000000..c0f8956fb7f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for controlling text generation parameters in LLM operations. + * + * This class provides settings for text generation behavior including output formatting, generation + * limits, and sampling parameters. Instances should be created using the [create] method and the + * fluent builder pattern. 
+ */ +class LlmGenerationConfig +private constructor( + @get:JvmName("isEcho") val echo: Boolean, + val maxNewTokens: Int, + @get:JvmName("isWarming") val warming: Boolean, + val seqLen: Int, + val temperature: Float, + val numBos: Int, + val numEos: Int, +) { + + companion object { + /** + * Creates a new Builder instance for constructing generation configurations. + * + * @return a new Builder with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmGenerationConfig instances. + * + * Provides a fluent interface for configuring generation parameters with sensible defaults. All + * methods return the builder instance to enable method chaining. + */ + class Builder internal constructor() { + private var echo: Boolean = true + private var maxNewTokens: Int = -1 + private var warming: Boolean = false + private var seqLen: Int = -1 + private var temperature: Float = 0.8f + private var numBos: Int = 0 + private var numEos: Int = 0 + + /** Sets whether to include the input prompt in the generated output. */ + fun echo(echo: Boolean): Builder = apply { this.echo = echo } + + /** Sets the maximum number of new tokens to generate. */ + fun maxNewTokens(maxNewTokens: Int): Builder = apply { this.maxNewTokens = maxNewTokens } + + /** Enables or disables model warming. */ + fun warming(warming: Boolean): Builder = apply { this.warming = warming } + + /** Sets the maximum sequence length for generation. */ + fun seqLen(seqLen: Int): Builder = apply { this.seqLen = seqLen } + + /** Sets the temperature for random sampling. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the number of BOS tokens to prepend. */ + fun numBos(numBos: Int): Builder = apply { this.numBos = numBos } + + /** Sets the number of EOS tokens to append. 
*/ + fun numEos(numEos: Int): Builder = apply { this.numEos = numEos } + + /** Constructs the LlmGenerationConfig instance with the configured parameters. */ + fun build(): LlmGenerationConfig = + LlmGenerationConfig(echo, maxNewTokens, warming, seqLen, temperature, numBos, numEos) + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java deleted file mode 100644 index 0c467b13f44..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ /dev/null @@ -1,823 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import java.io.Closeable; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.concurrent.locks.ReentrantLock; -import org.pytorch.executorch.ExecuTorchRuntime; -import org.pytorch.executorch.ExecutorchRuntimeException; -import org.pytorch.executorch.annotations.Experimental; - -/** - * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text - * from the model. - * - *

Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class LlmModule implements Closeable { - - public static final int MODEL_TYPE_TEXT = 1; - public static final int MODEL_TYPE_TEXT_VISION = 2; - public static final int MODEL_TYPE_MULTIMODAL = 2; - - private final HybridData mHybridData; - private final ReentrantLock mLock = new ReentrantLock(); - private volatile boolean mDestroyed = false; - private static final int DEFAULT_SEQ_LEN = 128; - private static final boolean DEFAULT_ECHO = true; - private static final float DEFAULT_TEMPERATURE = -1.0f; - private static final int DEFAULT_BOS = 0; - private static final int DEFAULT_EOS = 0; - private static final int DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP; - - @DoNotStrip - private static native HybridData initHybrid( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode); - - private LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode) { - ExecuTorchRuntime.getRuntime(); - ExecuTorchRuntime.validateFilePath(modulePath, "model path"); - ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path"); - - mHybridData = - initHybrid( - modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos, loadMode); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - numBos, - numEos, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. 
- */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - DEFAULT_BOS, - DEFAULT_EOS, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - String dataPath, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataPath != null ? List.of(dataPath) : List.of(), - numBos, - numEos); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { - this(modelType, modulePath, tokenizerPath, temperature, dataPath, DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ - public LlmModule(String modulePath, String tokenizerPath, float temperature) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data - * path. - */ - public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(dataPath), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. 
*/ - public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - this(modelType, modulePath, tokenizerPath, temperature, List.of(), DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with the given LlmModuleConfig */ - public LlmModule(LlmModuleConfig config) { - this( - config.getModelType(), - config.getModulePath(), - config.getTokenizerPath(), - config.getTemperature(), - config.getDataPath() != null ? List.of(config.getDataPath()) : List.of(), - config.getNumBos(), - config.getNumEos(), - config.getLoadMode()); - } - - private void checkNotDestroyed() { - if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); - } - - private void checkNotReentrant() { - if (mLock.getHoldCount() > 1) { - throw new IllegalStateException("Cannot call LlmModule methods from within a callback"); - } - } - - /** - * Releases native resources. Callers must ensure no other methods are in-flight. Call {@link - * #stop()} and wait for {@link #generate(String, LlmCallback)} to return before calling this - * method. - */ - @Override - public void close() { - if (mLock.tryLock()) { - try { - if (mLock.getHoldCount() > 1) { - throw new IllegalStateException( - "Cannot close module from within a callback during execution"); - } - if (!mDestroyed) { - mDestroyed = true; - mHybridData.resetNative(); - } - } finally { - mLock.unlock(); - } - } else { - throw new IllegalStateException("Cannot close module while method is executing"); - } - } - - /** - * @deprecated Use {@link #close()} instead. - */ - @Deprecated - public void resetNative() { - close(); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results. 
- */ - public void generate(String prompt, LlmCallback llmCallback) { - generate( - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback) { - generate( - null, - 0, - 0, - 0, - prompt, - seqLen, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate(String prompt, LlmCallback llmCallback, boolean echo) { - generate( - null, - 0, - 0, - 0, - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. 
- * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public void generate( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int generateNative( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos); - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param config the config for generation - * @param llmCallback callback object to receive results - */ - public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { - int seqLen = config.getSeqLen(); - boolean echo = config.isEcho(); - float temperature = config.getTemperature(); - int numBos = config.getNumBos(); - int numEos = config.getNumEos(); - generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. 
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo) { - generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature) { - generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - temperature, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. 
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public void generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (image != null) { - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } - int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill the KV cache with the given image input. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(int[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data - * is accessed directly without JNI array copies, unlike {@link #prefillImages(int[], int, int, - * int)}. 
The ByteBuffer must contain raw uint8 pixel data in CHW format with at least channels * - * height * width bytes remaining. Only the first channels * height * width bytes from the - * buffer's current position are read; the position of the original ByteBuffer is not modified. - * - * @param image Input image as a direct ByteBuffer containing uint8 pixel data - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining - * bytes - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - long expectedBytes; - try { - long pixels = Math.multiplyExact((long) width, (long) height); - expectedBytes = Math.multiplyExact(pixels, (long) channels); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - "width*height*channels is too large and overflows the allowed range.", ex); - } - if (width <= 0 - || height <= 0 - || channels <= 0 - || expectedBytes > Integer.MAX_VALUE - || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels (" - + expectedBytes - + ")."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. 
- int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - /** - * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The - * buffer data is accessed directly without JNI array copies, unlike {@link - * #prefillImages(float[], int, int, int)}. The ByteBuffer must contain normalized float pixel - * data in CHW format with at least channels * height * width * 4 bytes remaining. Only the first - * channels * height * width floats from the buffer's current position are consumed. The buffer - * must use the platform's native byte order (set via {@code - * buffer.order(ByteOrder.nativeOrder())}). - * - * @param image Input normalized image as a direct ByteBuffer containing float pixel data in - * native byte order - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining - * bytes, is not float-aligned, or does not use native byte order - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - if (image.order() != java.nio.ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); - } - if (image.position() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); - } - final long expectedBytes; - try { - int wh = 
Math.multiplyExact(width, height); - long whc = Math.multiplyExact((long) wh, (long) channels); - long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); - if (totalBytes > Integer.MAX_VALUE) { - throw new IllegalArgumentException( - "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " - + totalBytes); - } - expectedBytes = totalBytes; - } catch (ArithmeticException e) { - throw new IllegalArgumentException( - "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); - } - if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels*4 (" - + expectedBytes - + ")."); - } - if (image.remaining() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be a multiple of 4 (float size)."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillImagesInput(int[] image, int width, int height, int channels); - - private native int prefillImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - private native int prefillNormalizedImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given normalized image input. 
- * - * @param image Input normalized image as a float array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(float[] image, int width, int height, int channels) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillNormalizedImagesInput( - float[] image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given preprocessed audio input. - * - * @param audio Input preprocessed audio as a byte array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given preprocessed audio input. 
- * - * @param audio Input preprocessed audio as a float array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillAudioInputFloat( - float[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given raw audio input. - * - * @param audio Input raw audio as a byte array - * @param batch_size Input batch size - * @param n_channels Input number of channels - * @param n_samples Input number of samples - * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - private native int prefillRawAudioInput( - byte[] audio, int batch_size, int n_channels, int n_samples); - - /** - * Prefill the KV cache with the given text prompt. - * - * @param prompt The text prompt to prefill. 
- * @throws ExecutorchRuntimeException if the prefill failed - */ - @Experimental - public void prefillPrompt(String prompt) { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int nativeResult = prefillTextInput(prompt); - if (nativeResult != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); - } - } finally { - mLock.unlock(); - } - } - - // returns status - private native int prefillTextInput(String prompt); - - /** - * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. - * - *

The startPos will be reset to 0. - */ - public void resetContext() { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - resetContextNative(); - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native void resetContextNative(); - - /** Stop current generate() before it finishes. */ - public void stop() { - if (mDestroyed) return; - stopNative(); - } - - @DoNotStrip - private native void stopNative(); - - /** Force loading the module. Otherwise the model is loaded during first generate(). */ - public void load() { - mLock.lock(); - try { - checkNotReentrant(); - checkNotDestroyed(); - int err = loadNative(); - if (err != 0) { - throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); - } - } finally { - mLock.unlock(); - } - } - - @DoNotStrip - private native int loadNative(); -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt new file mode 100644 index 00000000000..f95e796b83b --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt @@ -0,0 +1,898 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import java.io.Closeable +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.util.concurrent.locks.ReentrantLock +import org.pytorch.executorch.ExecuTorchRuntime +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.annotations.Experimental + +/** + * LlmModule is a wrapper around the Executorch LLM. 
It provides a simple interface to generate text + * from the model. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +class LlmModule +private constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, +) : Closeable { + + private val mHybridData: HybridData + private val mLock = ReentrantLock() + @Volatile private var mDestroyed = false + + init { + ExecuTorchRuntime.getRuntime() + ExecuTorchRuntime.validateFilePath(modulePath, "model path") + ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path") + mHybridData = + initHybrid( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + loadMode, + ) + } + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + DEFAULT_BOS, + DEFAULT_EOS, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. 
+ */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String?, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + listOfNotNull(dataPath), + numBos, + numEos, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String?, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataPath, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ + constructor( + modulePath: String, + tokenizerPath: String, + temperature: Float, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data + * path. + */ + constructor( + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataPath: String, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + listOf(dataPath), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. 
*/ + constructor( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with the given LlmModuleConfig */ + constructor( + config: LlmModuleConfig + ) : this( + config.modelType, + config.modulePath, + config.tokenizerPath, + config.temperature, + listOfNotNull(config.dataPath), + config.numBos, + config.numEos, + config.loadMode, + ) + + private fun checkNotDestroyed() { + if (mDestroyed) throw IllegalStateException("LlmModule has been destroyed") + } + + private fun checkNotReentrant() { + if (mLock.holdCount > 1) { + throw IllegalStateException("Cannot call LlmModule methods from within a callback") + } + } + + /** + * Releases native resources. Callers must ensure no other methods are in-flight. Call [stop] and + * wait for [generate] to return before calling this method. + */ + override fun close() { + if (mLock.tryLock()) { + try { + if (mLock.holdCount > 1) { + throw IllegalStateException("Cannot close module from within a callback during execution") + } + if (!mDestroyed) { + mDestroyed = true + mHybridData.resetNative() + } + } finally { + mLock.unlock() + } + } else { + throw IllegalStateException("Cannot close module while method is executing") + } + } + + /** @deprecated Use [close] instead. */ + @Deprecated("Use close() instead", replaceWith = ReplaceWith("close()")) + fun resetNative() { + close() + } + + // --- generate overloads --- + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, llmCallback: LlmCallback) { + generate( + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. 
+ * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback) { + generate( + null, + 0, + 0, + 0, + prompt, + seqLen, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, llmCallback: LlmCallback, echo: Boolean) { + generate( + null, + 0, + 0, + 0, + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback, echo: Boolean) { + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS) + } + + /** + * Start generating tokens from the module. 
+ * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + fun generate( + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos) + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate") + } + } finally { + mLock.unlock() + } + } + + @DoNotStrip + private external fun generateNative( + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ): Int + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param config the config for generation + * @param llmCallback callback object to receive results + */ + fun generate(prompt: String, config: LlmGenerationConfig, llmCallback: LlmCallback) { + generate( + null, + 0, + 0, + 0, + prompt, + config.seqLen, + llmCallback, + config.echo, + config.temperature, + config.numBos, + config.numEos, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. 
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + ) { + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + ) { + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + temperature, + DEFAULT_BOS, + DEFAULT_EOS, + ) + } + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. 
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + if (image != null) { + val nativeResult = prefillImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } + val err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos) + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate") + } + } finally { + mLock.unlock() + } + } + + // --- prefill methods --- + + /** + * Prefill the KV cache with the given image input. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: IntArray, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + /** + * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data + * is accessed directly without JNI array copies, unlike [prefillImages]. 
The ByteBuffer must + * contain raw uint8 pixel data in CHW format with at least channels * height * width bytes + * remaining. Only the first channels * height * width bytes from the buffer's current position + * are read; the position of the original ByteBuffer is not modified. + * + * @param image Input image as a direct ByteBuffer containing uint8 pixel data + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining + * bytes + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: ByteBuffer, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + require(image.isDirect) { "Input ByteBuffer must be direct." } + val expectedBytes: Long + try { + val pixels = Math.multiplyExact(width.toLong(), height.toLong()) + expectedBytes = Math.multiplyExact(pixels, channels.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException( + "width*height*channels is too large and overflows the allowed range.", + ex, + ) + } + require( + width > 0 && + height > 0 && + channels > 0 && + expectedBytes <= Int.MAX_VALUE.toLong() && + image.remaining().toLong() >= expectedBytes + ) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels ($expectedBytes)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + /** + * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. 
The + * buffer data is accessed directly without JNI array copies, unlike [prefillImages]. The + * ByteBuffer must contain normalized float pixel data in CHW format with at least channels * + * height * width * 4 bytes remaining. Only the first channels * height * width floats from the + * buffer's current position are consumed. The buffer must use the platform's native byte order + * (set via `buffer.order(ByteOrder.nativeOrder())`). + * + * @param image Input normalized image as a direct ByteBuffer containing float pixel data in + * native byte order + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining + * bytes, is not float-aligned, or does not use native byte order + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillNormalizedImage(image: ByteBuffer, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + require(image.isDirect) { "Input ByteBuffer must be direct." } + require(image.order() == ByteOrder.nativeOrder()) { + "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())." + } + require(image.position() % Float.SIZE_BYTES == 0) { + "Input ByteBuffer position (${image.position()}) must be 4-byte aligned." 
+ } + val expectedBytes: Long + try { + val wh = Math.multiplyExact(width, height) + val whc = Math.multiplyExact(wh.toLong(), channels.toLong()) + val totalBytes = Math.multiplyExact(whc, Float.SIZE_BYTES.toLong()) + if (totalBytes > Int.MAX_VALUE.toLong()) { + throw IllegalArgumentException( + "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: $totalBytes", + ) + } + expectedBytes = totalBytes + } catch (e: ArithmeticException) { + throw IllegalArgumentException( + "Overflow while computing width*height*channels*4 for ByteBuffer size.", + e, + ) + } + require( + width > 0 && height > 0 && channels > 0 && image.remaining().toLong() >= expectedBytes + ) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels*4 ($expectedBytes)." + } + require(image.remaining() % Float.SIZE_BYTES == 0) { + "ByteBuffer remaining (${image.remaining()}) must be a multiple of 4 (float size)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillImagesInput( + image: IntArray, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillNormalizedImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given normalized image input. 
+ * + * @param image Input normalized image as a float array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: FloatArray, width: Int, height: Int, channels: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillNormalizedImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillNormalizedImagesInput( + image: FloatArray, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. + * + * @param audio Input preprocessed audio as a byte array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillAudioInput(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillAudioInput( + audio: ByteArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. 
+ * + * @param audio Input preprocessed audio as a float array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillAudioInputFloat( + audio: FloatArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given raw audio input. + * + * @param audio Input raw audio as a byte array + * @param batchSize Input batch size + * @param nChannels Input number of channels + * @param nSamples Input number of samples + * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + private external fun prefillRawAudioInput( + audio: ByteArray, + batchSize: Int, + nChannels: Int, + nSamples: Int, + ): Int + + /** + * Prefill the KV cache with the given text prompt. + * + * @param prompt The text prompt to prefill. 
+ * @throws ExecutorchRuntimeException if the prefill failed + */ + @Experimental + fun prefillPrompt(prompt: String) { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val nativeResult = prefillTextInput(prompt) + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed") + } + } finally { + mLock.unlock() + } + } + + // returns status + private external fun prefillTextInput(prompt: String): Int + + /** + * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. + * + * The startPos will be reset to 0. + */ + fun resetContext() { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + resetContextNative() + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun resetContextNative() + + /** Stop current generate() before it finishes. */ + fun stop() { + if (mDestroyed) return + stopNative() + } + + @DoNotStrip private external fun stopNative() + + /** Force loading the module. Otherwise the model is loaded during first generate(). 
*/ + fun load() { + mLock.lock() + try { + checkNotReentrant() + checkNotDestroyed() + val err = loadNative() + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model") + } + } finally { + mLock.unlock() + } + } + + @DoNotStrip private external fun loadNative(): Int + + companion object { + const val MODEL_TYPE_TEXT = 1 + const val MODEL_TYPE_TEXT_VISION = 2 + const val MODEL_TYPE_MULTIMODAL = 2 + + private const val DEFAULT_SEQ_LEN = 128 + private const val DEFAULT_ECHO = true + private const val DEFAULT_TEMPERATURE = -1.0f + private const val DEFAULT_BOS = 0 + private const val DEFAULT_EOS = 0 + private const val DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, + ): HybridData + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java deleted file mode 100644 index feb52a2b34b..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for initializing a LlmModule. - * - *

{@link #create()} method and the fluent builder pattern. - */ -public class LlmModuleConfig { - private final String modulePath; - private final String tokenizerPath; - private final float temperature; - private final String dataPath; - private final int modelType; - private final int numBos; - private final int numEos; - private final int loadMode; - - /** Load entire model file into a buffer (no mmap). */ - public static final int LOAD_MODE_FILE = 0; - - /** Load model via mmap without mlock (default). Pages faulted in on demand. */ - public static final int LOAD_MODE_MMAP = 1; - - /** Load model via mmap and pin all pages with mlock. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; - - /** Load model via mmap and attempt mlock, ignoring mlock failures. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - - private LlmModuleConfig(Builder builder) { - this.modulePath = builder.modulePath; - this.tokenizerPath = builder.tokenizerPath; - this.temperature = builder.temperature; - this.dataPath = builder.dataPath; - this.modelType = builder.modelType; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - this.loadMode = builder.loadMode; - } - - /** Model type constant for text-only models. */ - public static final int MODEL_TYPE_TEXT = 1; - - /** Model type constant for text-and-vision multimodal models. */ - public static final int MODEL_TYPE_TEXT_VISION = 2; - - /** Model type constant for generic multimodal models. */ - public static final int MODEL_TYPE_MULTIMODAL = 2; - - /** - * Creates a new Builder instance for constructing LlmModuleConfig objects. 
- * - * @return a new Builder instance with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - // Getters with documentation - /** - * @return Path to the compiled model module (.pte file) - */ - public String getModulePath() { - return modulePath; - } - - /** - * @return Path to the tokenizer file or directory - */ - public String getTokenizerPath() { - return tokenizerPath; - } - - /** - * @return Temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return Optional path to additional data files - */ - public String getDataPath() { - return dataPath; - } - - /** - * @return Type of model (text-only or text-vision) - */ - public int getModelType() { - return modelType; - } - - /** - * @return Number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return Number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * @return Load mode for the model file (one of LOAD_MODE_* constants) - */ - public int getLoadMode() { - return loadMode; - } - - /** - * Builder class for constructing LlmModuleConfig instances with optional parameters. - * - *

The builder provides a fluent interface for configuring model parameters and validates - * required fields before construction. - */ - public static class Builder { - private String modulePath; - private String tokenizerPath; - private float temperature = 0.8f; - private String dataPath = ""; - private int modelType = MODEL_TYPE_TEXT; - private int numBos = 0; - private int numEos = 0; - private int loadMode = LOAD_MODE_MMAP; - - Builder() {} - - /** - * Sets the path to the module. - * - * @param modulePath Path to module - * @return This builder instance for method chaining - */ - public Builder modulePath(String modulePath) { - this.modulePath = modulePath; - return this; - } - - /** - * Sets the path to the tokenizer. - * - * @param tokenizerPath Path to tokenizer - * @return This builder instance for method chaining - */ - public Builder tokenizerPath(String tokenizerPath) { - this.tokenizerPath = tokenizerPath; - return this; - } - - /** - * Sets the temperature for sampling generation. - * - * @param temperature Temperature value (typical range 0.0-1.0) - * @return This builder instance for method chaining - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the path to optional additional data files. - * - * @param dataPath Path to supplementary data resources - * @return This builder instance for method chaining - */ - public Builder dataPath(String dataPath) { - this.dataPath = dataPath; - return this; - } - - /** - * Sets the model type (text-only or multimodal). - * - * @param modelType One of MODEL_TYPE_TEXT, MODEL_TYPE_TEXT_VISION, MODEL_TYPE_MULTIMODAL - * @return This builder instance for method chaining - */ - public Builder modelType(int modelType) { - this.modelType = modelType; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. 
- * - * @param numBos number of BOS tokens - * @return This builder instance for method chaining - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return This builder instance for method chaining - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Sets the load mode for the model file. Defaults to {@link #LOAD_MODE_MMAP} (mmap without - * mlock), which avoids pinning model pages in RAM. - * - * @param loadMode One of LOAD_MODE_FILE, LOAD_MODE_MMAP, LOAD_MODE_MMAP_USE_MLOCK, - * LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS - * @return This builder instance for method chaining - * @throws IllegalArgumentException if {@code loadMode} is not one of the supported constants - */ - public Builder loadMode(int loadMode) { - if (loadMode != LOAD_MODE_FILE - && loadMode != LOAD_MODE_MMAP - && loadMode != LOAD_MODE_MMAP_USE_MLOCK - && loadMode != LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) { - throw new IllegalArgumentException("Unknown load mode: " + loadMode); - } - this.loadMode = loadMode; - return this; - } - - /** - * Constructs the LlmModuleConfig instance with validated parameters. 
- * - * @return New LlmModuleConfig instance with configured values - * @throws IllegalArgumentException if required fields are missing - */ - public LlmModuleConfig build() { - if (modulePath == null || tokenizerPath == null) { - throw new IllegalArgumentException("Module path and tokenizer path are required"); - } - return new LlmModuleConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt new file mode 100644 index 00000000000..2d65633bb9f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt @@ -0,0 +1,134 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for initializing a LlmModule. + * + * Use [create] method and the fluent builder pattern. + */ +class LlmModuleConfig +private constructor( + val modulePath: String, + val tokenizerPath: String, + val temperature: Float, + val dataPath: String?, + val modelType: Int, + val numBos: Int, + val numEos: Int, + val loadMode: Int, +) { + + companion object { + /** Load entire model file into a buffer (no mmap). */ + const val LOAD_MODE_FILE = 0 + + /** Load model via mmap without mlock (default). Pages faulted in on demand. */ + const val LOAD_MODE_MMAP = 1 + + /** Load model via mmap and pin all pages with mlock. */ + const val LOAD_MODE_MMAP_USE_MLOCK = 2 + + /** Load model via mmap and attempt mlock, ignoring mlock failures. */ + const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3 + + /** Model type constant for text-only models. 
*/ + const val MODEL_TYPE_TEXT = 1 + + /** Model type constant for text-and-vision multimodal models. */ + const val MODEL_TYPE_TEXT_VISION = 2 + + /** Model type constant for generic multimodal models. */ + const val MODEL_TYPE_MULTIMODAL = 2 + + /** + * Creates a new Builder instance for constructing LlmModuleConfig objects. + * + * @return a new Builder instance with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmModuleConfig instances with optional parameters. + * + * The builder provides a fluent interface for configuring model parameters and validates required + * fields before construction. + */ + class Builder internal constructor() { + private var modulePath: String? = null + private var tokenizerPath: String? = null + private var temperature: Float = 0.8f + private var dataPath: String? = "" + private var modelType: Int = MODEL_TYPE_TEXT + private var numBos: Int = 0 + private var numEos: Int = 0 + private var loadMode: Int = LOAD_MODE_MMAP + + /** Sets the path to the module. */ + fun modulePath(modulePath: String): Builder = apply { this.modulePath = modulePath } + + /** Sets the path to the tokenizer. */ + fun tokenizerPath(tokenizerPath: String): Builder = apply { this.tokenizerPath = tokenizerPath } + + /** Sets the temperature for sampling generation. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the path to optional additional data files. */ + fun dataPath(dataPath: String?): Builder = apply { this.dataPath = dataPath } + + /** Sets the model type (text-only or multimodal). */ + fun modelType(modelType: Int): Builder = apply { this.modelType = modelType } + + /** Sets the number of BOS tokens to prepend. */ + fun numBos(numBos: Int): Builder = apply { this.numBos = numBos } + + /** Sets the number of EOS tokens to append. 
*/ + fun numEos(numEos: Int): Builder = apply { this.numEos = numEos } + + /** + * Sets the load mode for the model file. Defaults to [LOAD_MODE_MMAP] (mmap without mlock), + * which avoids pinning model pages in RAM. + * + * @throws IllegalArgumentException if loadMode is not one of the supported constants + */ + fun loadMode(loadMode: Int): Builder { + require( + loadMode == LOAD_MODE_FILE || + loadMode == LOAD_MODE_MMAP || + loadMode == LOAD_MODE_MMAP_USE_MLOCK || + loadMode == LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS + ) { + "Unknown load mode: $loadMode" + } + return apply { this.loadMode = loadMode } + } + + /** + * Constructs the LlmModuleConfig instance with validated parameters. + * + * @throws IllegalArgumentException if required fields are missing + */ + fun build(): LlmModuleConfig { + require(modulePath != null && tokenizerPath != null) { + "Module path and tokenizer path are required" + } + return LlmModuleConfig( + modulePath!!, + tokenizerPath!!, + temperature, + dataPath, + modelType, + numBos, + numEos, + loadMode, + ) + } + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java deleted file mode 100644 index 86e19d09133..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/package-info.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * ExecuTorch LLM extension for Android. - * - *

This package provides Java bindings for running large language models (LLMs) on Android using - * ExecuTorch. It supports text generation, tokenization, and streaming token callbacks. - * - *

Quick Start

- * - *
{@code
- * import org.pytorch.executorch.extension.llm.LlmModule;
- *
- * // Load a Llama model
- * LlmModule llm = new LlmModule(
- *     "/data/local/tmp/llama.pte",
- *     "/data/local/tmp/tokenizer.bin",
- *     0.8f
- * );
- * llm.load();
- *
- * // Generate text token by token
- * llm.generate("Hello, my name is", 200, new LlmCallback() {
- *     public void onResult(String token) {
- *         System.out.print(token);
- *     }
- *     public void onStats(String stats) {
- *         System.out.println("\nStats: " + stats);
- *     }
- * });
- * }
- * - *

Key Classes

- * - *
    - *
  • {@link org.pytorch.executorch.extension.llm.LlmModule} — load and run an LLM - *
  • {@link org.pytorch.executorch.extension.llm.LlmModuleConfig} — configure model paths and - * settings - *
  • {@link org.pytorch.executorch.extension.llm.LlmGenerationConfig} — control generation - * (temperature, seq length) - *
- * - *

More Resources

- * - * - */ -package org.pytorch.executorch.extension.llm; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java deleted file mode 100644 index 7a5ed0bb5a5..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/package-info.java +++ /dev/null @@ -1,57 +0,0 @@ -/** - * ExecuTorch Android Java API. - * - *

This package provides Java bindings for running ExecuTorch models on Android. Use these - * classes to load a {@code .pte} model file and run inference directly from your Java or Kotlin - * Android app — no C++ required. - * - *

Quick Start

- * - *

Step 1. Add the dependency to your {@code app/build.gradle.kts}: - * - *

{@code
- * dependencies {
- *     implementation("org.pytorch:executorch-android:${executorch_version}")
- * }
- * }
- * - *

Step 2. Load your model and run inference: - * - *

{@code
- * import org.pytorch.executorch.EValue;
- * import org.pytorch.executorch.Module;
- * import org.pytorch.executorch.Tensor;
- *
- * // Load your exported .pte model file
- * Module module = Module.load("/data/local/tmp/model.pte");
- *
- * // Build an input tensor  e.g. a 1x3x224x224 image
- * float[] inputData = new float[1 * 3 * 224 * 224];
- * Tensor inputTensor = Tensor.fromBlob(inputData, new long[]{1, 3, 224, 224});
- *
- * // Run inference
- * EValue[] output = module.forward(EValue.from(inputTensor));
- *
- * // Read the result
- * float[] scores = output[0].toTensor().getDataAsFloatArray();
- * }
- * - *

Key Classes

- * - *
    - *
  • {@link org.pytorch.executorch.Module} — load and run a {@code .pte} model - *
  • {@link org.pytorch.executorch.Tensor} — create input tensors and read outputs - *
  • {@link org.pytorch.executorch.EValue} — wrap inputs and unwrap outputs - *
  • {@link org.pytorch.executorch.DType} — supported data types (FLOAT, INT32, etc.) - *
- * - *

More Resources

- * - * - */ -package org.pytorch.executorch; diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index e072694f913..b9215f978bc 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -206,41 +206,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { data_files_vector, cpp_load_mode); std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + runner_ = std::make_unique( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) diff --git a/extension/asr/runner/CMakeLists.txt b/extension/asr/runner/CMakeLists.txt index 66974aa2a24..b47cddaf48c 100644 --- a/extension/asr/runner/CMakeLists.txt +++ b/extension/asr/runner/CMakeLists.txt @@ -22,7 +22,7 @@ endif() include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) set(runner_deps 
executorch_core extension_module extension_tensor - tokenizers::tokenizers + extension_llm_runner tokenizers::tokenizers ) # Define runner library diff --git a/extension/asr/runner/transducer_runner.cpp b/extension/asr/runner/transducer_runner.cpp index 3461cb09cc1..7b9298845a9 100644 --- a/extension/asr/runner/transducer_runner.cpp +++ b/extension/asr/runner/transducer_runner.cpp @@ -200,7 +200,7 @@ Error TransducerRunner::load() { return Error::Ok; } -Result<::executorch::extension::TensorPtr> TransducerRunner::preprocess( +Result TransducerRunner::preprocess( ::executorch::extension::TensorPtr raw_audio) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -229,12 +229,18 @@ Result<::executorch::extension::TensorPtr> TransducerRunner::preprocess( "Preprocessor returned unexpected output."); auto mel = outputs[0].toTensor(); - return std::make_shared<::executorch::aten::Tensor>(std::move(mel)); + int64_t mel_len = mel.sizes()[1]; // default to tensor dim + if (outputs.size() >= 2 && outputs[1].isTensor()) { + mel_len = outputs[1].toTensor().const_data_ptr()[0]; + } + return PreprocessResult{ + std::make_shared<::executorch::aten::Tensor>(std::move(mel)), mel_len}; } Result> TransducerRunner::transcribe( ::executorch::extension::TensorPtr preprocessed_features, - std::function token_callback) { + std::function token_callback, + int64_t features_length) { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } @@ -242,7 +248,9 @@ Result> TransducerRunner::transcribe( stats_.inference_start_ms = ::executorch::extension::llm::time_in_ms(); // --- Encode --- - int64_t mel_len_value = preprocessed_features->size(1); + // Use provided length, or fall back to tensor dimension + int64_t mel_len_value = + features_length > 0 ? 
features_length : preprocessed_features->size(1); std::vector mel_len_data = {mel_len_value}; auto mel_len = ::executorch::extension::from_blob( mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); diff --git a/extension/asr/runner/transducer_runner.h b/extension/asr/runner/transducer_runner.h index ee819590141..aed0ad84cd6 100644 --- a/extension/asr/runner/transducer_runner.h +++ b/extension/asr/runner/transducer_runner.h @@ -29,6 +29,14 @@ using ::executorch::extension::llm::Stats; using ::executorch::runtime::Error; using ::executorch::runtime::Result; +/** + * Preprocessed audio features with actual (unpadded) length. + */ +struct PreprocessResult { + ::executorch::extension::TensorPtr features; + int64_t length; // Actual number of valid frames (excluding padding) +}; + /** * A decoded token with frame-level timing information. */ @@ -97,7 +105,7 @@ class ET_EXPERIMENTAL TransducerRunner { * @returns Preprocessed features tensor (e.g., mel spectrogram), * ready to pass to transcribe(). */ - Result<::executorch::extension::TensorPtr> preprocess( + Result preprocess( ::executorch::extension::TensorPtr raw_audio); /** @@ -112,7 +120,8 @@ class ET_EXPERIMENTAL TransducerRunner { */ Result> transcribe( ::executorch::extension::TensorPtr preprocessed_features, - std::function token_callback = {}); + std::function token_callback = {}, + int64_t features_length = -1); /** * Returns a reference to the loaded tokenizer, or nullptr if not loaded. diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java deleted file mode 100644 index 5e1dd48926b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.system.ErrnoException; -import android.system.Os; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class BenchmarkActivity extends Activity { - - File mModel; - int mNumIter; - int mNumWarmupIter; - String mTokenizerPath; - float mTemperature; - String mPrompt; - - HandlerThread mHandlerThread; - BenchmarkHandler mHandler; - - List mResult; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - Intent intent = getIntent(); - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - - int numIter = intent.getIntExtra("num_iter", 50); - int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - float temperature = intent.getFloatExtra("temperature", 0.8f); - String prompt = intent.getStringExtra("prompt"); - - mModel = model; - mNumIter = numIter; - mNumWarmupIter = numWarmupIter; - mTokenizerPath = tokenizerPath; - mTemperature = temperature; - mPrompt = prompt; - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - mResult = new ArrayList<>(); - - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new 
BenchmarkHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); - } - - void writeResult() { - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(mResult)); - } catch (IOException e) { - e.printStackTrace(); - } finally { - finish(); - } - } -} - -class BenchmarkHandler extends Handler { - public static int MESSAGE_RUN_BENCHMARK = 1; - public static int MESSAGE_LLM_RUN_BENCHMARK = 2; - - ModelRunner mModelRunner; - BenchmarkActivity mBenchmarkActivity; - - LlmModelRunner mLlmModelRunner; - LlmBenchmark mLlmBenchmark; - - public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { - super(looper); - mModelRunner = new ModelRunner(); - mBenchmarkActivity = benchmarkActivity; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_RUN_BENCHMARK) { - mModelRunner.runBenchmark( - mBenchmarkActivity.mModel, - mBenchmarkActivity.mNumWarmupIter, - mBenchmarkActivity.mNumIter, - mBenchmarkActivity.mResult); - - if (mBenchmarkActivity.mTokenizerPath == null) { - mBenchmarkActivity.writeResult(); - } else { - this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); - } - } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { - mLlmBenchmark = - new LlmBenchmark( - mBenchmarkActivity, - mBenchmarkActivity.mModel.getPath(), - mBenchmarkActivity.mTokenizerPath, - mBenchmarkActivity.mPrompt, - mBenchmarkActivity.mTemperature, - mBenchmarkActivity.mResult); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt new file mode 100644 index 00000000000..b1d69c5f24f --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.kt @@ -0,0 +1,116 @@ +/* + * Copyright (c) 
Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.Activity +import android.os.Bundle +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.system.Os +import com.google.gson.Gson +import java.io.File +import java.io.FileWriter +import java.io.IOException + +class BenchmarkActivity : Activity() { + // Benchmark settings; onCreate() overwrites them from the launch intent extras. + lateinit var model: File + var numIter: Int = 0 + var numWarmupIter: Int = 0 + var tokenizerPath: String? = null + var temperature: Float = 0.8f + var prompt: String = "The ultimate answer" + + private lateinit var handlerThread: HandlerThread + private lateinit var handler: BenchmarkHandler + // Metrics appended by the runners; writeResult() serializes this list to JSON. + val results: MutableList<BenchmarkMetric> = mutableListOf() + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + + try { + Os.setenv("ADSP_LIBRARY_PATH", applicationInfo.nativeLibraryDir, true) + } catch (e: android.system.ErrnoException) { + finish() + return + } + + val intent = intent + val modelDir = File(intent.getStringExtra("model_dir")!!)
+ model = modelDir.listFiles()!!.first { it.name.endsWith(".pte") } + + numIter = intent.getIntExtra("num_iter", 50) + numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10) + tokenizerPath = intent.getStringExtra("tokenizer_path") + temperature = intent.getFloatExtra("temperature", 0.8f) + prompt = intent.getStringExtra("prompt") ?: "The ultimate answer" + + handlerThread = HandlerThread("ModelRunner") + handlerThread.start() + handler = BenchmarkHandler(handlerThread.looper, this) + + handler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK) + } + + fun writeResult() { + try { + FileWriter("${filesDir}/benchmark_results.json").use { writer -> + writer.write(Gson().toJson(results)) + } + } catch (e: IOException) { + e.printStackTrace() + } finally { + finish() // close the activity whether or not the write succeeded + } + } +} + +private class BenchmarkHandler( + looper: Looper, + private val activity: BenchmarkActivity, +) : Handler(looper) { + // Runs the model benchmark, then the LLM benchmark when a tokenizer path is set. + private val modelRunner = ModelRunner() + + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_RUN_BENCHMARK -> { + modelRunner.runBenchmark( + activity.model, + activity.numWarmupIter, + activity.numIter, + activity.results, + ) + if (activity.tokenizerPath == null) { + activity.writeResult() + } else { + sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK) + } + } + MESSAGE_LLM_RUN_BENCHMARK -> { + LlmBenchmark( + activity, + activity.model.path, + activity.tokenizerPath!!, + activity.prompt, + activity.temperature, + activity.results, + ) + } + } + } + + companion object { + const val MESSAGE_RUN_BENCHMARK = 1 + const val MESSAGE_LLM_RUN_BENCHMARK = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java deleted file mode 100644 index 66ab50550a4..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ /dev/null @@ -1,74
+0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.ActivityManager; -import android.os.Build; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; - - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; - } - } - - BenchmarkModel benchmarkModel; - - // The metric name, i.e. TPS - String metric; - - // The actual value and the option target value - double actualValue; - double targetValue; - - public static class DeviceInfo { - // Let's see which information we want to include here - final String device = Build.BRAND; - // The phone model and Android release version - final String arch = Build.MODEL; - final String os = "Android " + Build.VERSION.RELEASE; - final long totalMem = new ActivityManager.MemoryInfo().totalMem; - final long availMem = new ActivityManager.MemoryInfo().availMem; - } - - DeviceInfo deviceInfo = new DeviceInfo(); - - public BenchmarkMetric( - final BenchmarkModel benchmarkModel, - final String metric, - final double actualValue, - final double targetValue) { - this.benchmarkModel = benchmarkModel; - this.metric = metric; - this.actualValue = actualValue; - this.targetValue = targetValue; - } - - // TODO (huydhn): Figure out a way to extract the backend and quantization information from - // the .pte model itself instead of parsing its name - public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { - final Matcher m = - 
Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); - if (m.matches()) { - return new BenchmarkMetric.BenchmarkModel( - m.group("name"), m.group("backend"), m.group("quantization")); - } else { - return new BenchmarkMetric.BenchmarkModel(model, "", ""); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt new file mode 100644 index 00000000000..7bed1ab05c0 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.kt @@ -0,0 +1,54 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.app.ActivityManager +import android.os.Build + +class BenchmarkMetric( + val benchmarkModel: BenchmarkModel, + val metric: String, + val actualValue: Double, + val targetValue: Double, +) { + data class BenchmarkModel( + val name: String, + val backend: String, + val quantization: String, + ) + // NOTE(review): MemoryInfo() is never filled via getMemoryInfo(); mem fields likely read 0. + class DeviceInfo { + val device: String = Build.BRAND + val arch: String = Build.MODEL + val os: String = "Android ${Build.VERSION.RELEASE}" + val totalMem: Long = ActivityManager.MemoryInfo().totalMem + val availMem: Long = ActivityManager.MemoryInfo().availMem + } + + val deviceInfo: DeviceInfo = DeviceInfo() + // Model names are expected to look like name_backend_quantization (backend may contain '+'). + companion object { + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + @JvmStatic + fun extractBackendAndQuantization(model: String): BenchmarkModel { + val pattern = Regex("(?<name>\\w+)_(?<backend>[\\w+]+)_(?<quantization>\\w+)") + val match = pattern.matchEntire(model) + return if (match != null) { + BenchmarkModel( + match.groups["name"]!!.value,
+ match.groups["backend"]!!.value, + match.groups["quantization"]!!.value, + ) + } else { + BenchmarkModel(model, "", "") // no match: keep the raw name, leave backend/quantization empty + } + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java deleted file mode 100644 index 0c0436d2676..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.util.Log; -import java.util.List; -import org.json.JSONException; -import org.json.JSONObject; - -public class LlmBenchmark implements LlmModelRunnerCallback { - LlmModelRunner mLlmModelRunner; - - String mPrompt; - StatsInfo mStatsInfo; - - List mResults; - BenchmarkActivity mActivity; - - LlmBenchmark( - BenchmarkActivity activity, - String modelFile, - String tokenizerPath, - String prompt, - float temperature, - List results) { - mResults = results; - mActivity = activity; - mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", ""); - mPrompt = prompt; - mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); - mStatsInfo.loadStart = System.nanoTime(); - } - - @Override - public void onModelLoaded(int status) { - mStatsInfo.loadEnd = System.nanoTime(); - mStatsInfo.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsInfo.generateStart = System.nanoTime(); - mLlmModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) {} - - @Override -
public void onStats(String stats) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - mStatsInfo.tps = tps; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } - } - - @Override - public void onGenerationStopped() { - mStatsInfo.generateEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - // The list of metrics we have atm includes: - // Load status - mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); - // Model load time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "llm_model_load_time(ms)", - (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, - 0.0f)); - // Token per second - mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - mActivity.writeResult(); - } -} - -class StatsInfo { - int loadStatus; - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - float tps; - String modelName; - - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tps; - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt new file mode 100644 index 00000000000..5c75519f870 
--- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.util.Log +import org.json.JSONException +import org.json.JSONObject + +class LlmBenchmark( + private val activity: BenchmarkActivity, + modelFile: String, + tokenizerPath: String, + private val prompt: String, + temperature: Float, + private val results: MutableList<BenchmarkMetric>, +) : LlmModelRunnerCallback { + // Drives one LLM load + generate run and appends its timing metrics to `results`. + private val runner: LlmModelRunner + private val statsInfo = StatsInfo() + + init { + statsInfo.modelName = modelFile.substringAfterLast('/').removeSuffix(".pte") + runner = LlmModelRunner(modelFile, tokenizerPath, temperature, this) + statsInfo.loadStart = System.nanoTime() + } + + override fun onModelLoaded(status: Int) { + statsInfo.loadEnd = System.nanoTime() + statsInfo.loadStatus = status + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: $status") + onGenerationStopped() + return + } + statsInfo.generateStart = System.nanoTime() + runner.generate(prompt) + } + + override fun onTokenGenerated(token: String) {} + + override fun onStats(stats: String) { + try { + val json = JSONObject(stats) + val numGeneratedTokens = json.getInt("generated_tokens") + val decodeMs = json.getInt("inference_end_ms") - json.getInt("prompt_eval_end_ms") + // A zero interval would give Infinity/NaN, which Gson rejects by default; report 0. + statsInfo.tps = if (decodeMs != 0) numGeneratedTokens.toFloat() / decodeMs * 1000 else 0f + } catch (e: JSONException) { + Log.e("LLM", "Error parsing JSON: ${e.message}") + } + } + + override fun onGenerationStopped() { + statsInfo.generateEnd = System.nanoTime() + + val benchmarkModel = BenchmarkMetric.extractBackendAndQuantization(statsInfo.modelName)
+ results.add(BenchmarkMetric(benchmarkModel, "load_status", statsInfo.loadStatus.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, + "llm_model_load_time(ms)", + (statsInfo.loadEnd - statsInfo.loadStart) * 1e-6, // ns -> ms + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (statsInfo.generateEnd - statsInfo.generateStart) * 1e-6, // ns -> ms + 0.0, + )) + results.add(BenchmarkMetric(benchmarkModel, "token_per_sec", statsInfo.tps.toDouble(), 0.0)) + activity.writeResult() + } +} + +private class StatsInfo { + var loadStatus: Int = 0 + var loadStart: Long = 0 + var loadEnd: Long = 0 + var generateStart: Long = 0 + var generateEnd: Long = 0 + var tps: Float = 0f + var modelName: String = "" +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java deleted file mode 100644 index 3a345d3465b..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import android.util.Log; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -/** A helper class to handle all model running logic within this class.
*/ -public class LlmModelRunner implements LlmCallback { - LlmModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - LlmModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - LlmModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - LlmModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("LlmModelRunner"); - mHandlerThread.start(); - mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(String result) { - mCallback.onStats(result); - } -} - -class LlmModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final LlmModelRunner mLlmModelRunner; - - public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { - super(looper); - mLlmModelRunner = llmModelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = 0; - try { - mLlmModelRunner.mModule.load(); - } catch (Exception e) { - status = - (e instanceof 
org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - mLlmModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - try { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); - } catch (Exception e) { - Log.e("LlmModelRunner", "generate() failed", e); - } - mLlmModelRunner.mCallback.onGenerationStopped(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt new file mode 100644 index 00000000000..29b9b177fb6 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench + +import android.os.Handler +import android.os.HandlerThread +import android.os.Looper +import android.os.Message +import android.util.Log +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.extension.llm.LlmCallback +import org.pytorch.executorch.extension.llm.LlmModule + +/** Runs [LlmModule] load/generate on a worker thread and reports back through [callback].
*/ +class LlmModelRunner( + modelFilePath: String, + tokenizerFilePath: String, + temperature: Float, + val callback: LlmModelRunnerCallback, +) : LlmCallback { + // The model load is queued on the worker thread as soon as this object is constructed. + val module: LlmModule = LlmModule(modelFilePath, tokenizerFilePath, temperature) + private val handlerThread: HandlerThread = HandlerThread("LlmModelRunner") + private val handler: Handler + + init { + handlerThread.start() + handler = LlmModelRunnerHandler(handlerThread.looper, this) + handler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL) + } + + fun generate(prompt: String): Int { + val msg = Message.obtain(handler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt) + msg.sendToTarget() + return 0 // always 0; the prompt is only queued here + } + + fun stop() { + module.stop() + } + + override fun onResult(result: String) { + callback.onTokenGenerated(result) + } + + override fun onStats(stats: String) { + callback.onStats(stats) + } +} + +private class LlmModelRunnerHandler( + looper: Looper, + private val runner: LlmModelRunner, +) : Handler(looper) { + // Handles MESSAGE_LOAD_MODEL and MESSAGE_GENERATE on the looper passed in by the runner. + override fun handleMessage(msg: Message) { + when (msg.what) { + MESSAGE_LOAD_MODEL -> { + val status = + try { + runner.module.load() + 0 // status 0: load() did not throw + } catch (e: ExecutorchRuntimeException) { + e.errorCode + } catch (e: Exception) { + -1 + } + runner.callback.onModelLoaded(status) + } + MESSAGE_GENERATE -> { + try { + runner.module.generate(msg.obj as String, runner) + } catch (e: Exception) { + Log.e("LlmModelRunner", "generate() failed", e) + } + runner.callback.onGenerationStopped() // reported even if generate() threw + } + } + } + + companion object { + const val MESSAGE_LOAD_MODEL = 1 + const val MESSAGE_GENERATE = 2 + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java deleted file mode 100644 index 915496a25af..00000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - *
Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.os.Debug; -import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.pytorch.executorch.Module; - -public class ModelRunner { - /** - * @return list of #BenchmarkMetric - */ - public void runBenchmark( - File model, int numWarmupIter, int numIter, List results) { - long pssIdle = Debug.getPss(); - - List latency = new ArrayList<>(); - - long loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - int errorCode = 0; - try { - module.loadMethod("forward"); - } catch (Exception e) { - errorCode = - (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) - ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() - : -1; - } - long loadEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - - if (errorCode != 0) { - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - module.destroy(); - return; - } - - try { - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } - - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); - } - - module.etdump(); - - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); - - 
results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - } finally { - module.destroy(); - } - } -} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt new file mode 100644 index 00000000000..0f292b0d900 --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.kt @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +package org.pytorch.minibench + +import android.os.Debug +import java.io.File +import org.pytorch.executorch.ExecutorchRuntimeException +import org.pytorch.executorch.Module + +class ModelRunner { + /** Loads [model], runs warmup and timed forward passes, and appends metrics to [results]. */ + fun runBenchmark( + model: File, + numWarmupIter: Int, + numIter: Int, + results: MutableList<BenchmarkMetric>, + ) { + val pssIdle = Debug.getPss() + val latency = mutableListOf<Double>() + + val loadStart = System.nanoTime() + val module = Module.load(model.path) + var errorCode = 0 + try { + module.loadMethod("forward") + } catch (e: ExecutorchRuntimeException) { + errorCode = e.errorCode + } catch (e: Exception) { + errorCode = -1 + } + val loadEnd = System.nanoTime() + + val benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.name.removeSuffix(".pte")) + + if (errorCode != 0) { + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + module.destroy() + return + } + + try { + repeat(numWarmupIter) { module.forward() } + + repeat(numIter) { + val start = System.nanoTime() + module.forward() + latency.add((System.nanoTime() - start) * 1e-6) + } + + module.etdump() + + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + latency.sort() + val trimmed = latency.subList(latency.size / 10, latency.size * 9 / 10) + // average() of an empty list is NaN, which Gson rejects by default; report 0.0 instead. + results.add( + BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + if (latency.isEmpty()) 0.0 else latency.average(), + 0.0, + )) + results.add( + BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + if (trimmed.isEmpty()) 0.0 else trimmed.average(), + 0.0, + )) + results.add( + BenchmarkMetric(benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0)) + results.add(BenchmarkMetric(benchmarkModel, "load_status", errorCode.toDouble(), 0.0)) + results.add( + BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024.0, 0.0)) + } finally
{ + module.destroy() + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt similarity index 55% rename from extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java rename to extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt index c6a6a76a4d8..b98a49e4bf9 100644 --- a/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.java +++ b/extension/benchmark/android/benchmark/app/src/test/java/org/pytorch/minibench/ExampleUnitTest.kt @@ -6,20 +6,19 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.minibench; +package org.pytorch.minibench -import static org.junit.Assert.*; - -import org.junit.Test; +import org.junit.Assert.assertEquals +import org.junit.Test /** * Example local unit test, which will execute on the development machine (host).
* - * @see Testing documentation + * @see [Testing documentation](http://d.android.com/tools/testing) */ -public class ExampleUnitTest { +class ExampleUnitTest { @Test - public void addition_isCorrect() { - assertEquals(4, 2 + 2); + fun addition_isCorrect() { + assertEquals(4, 2 + 2) } } diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py index 6838b0958a2..916b13a90b8 100644 --- a/extension/llm/custom_ops/model_sharding.py +++ b/extension/llm/custom_ops/model_sharding.py @@ -7,8 +7,9 @@ import re from typing import List -import torch +import executorch.extension.llm.custom_ops.op_fallback # noqa: F401 +import torch from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, @@ -17,27 +18,6 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.export.exported_program import ExportedProgram -from torch.library import impl, Library - - -fallback_op_lib = Library("llama", "DEF") -# registering an operator. -fallback_op_lib.define("fallback(Tensor input) -> Tensor") - - -@impl(fallback_op_lib, "fallback") -def fallback_impl(a: torch.Tensor) -> torch.Tensor: - return a - - -# registering the out variant. -fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") - - -@impl(fallback_op_lib, "fallback.out") -def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: - out.copy_(a) - return out class SplitGraph(ExportPass): diff --git a/extension/llm/custom_ops/op_fallback.py b/extension/llm/custom_ops/op_fallback.py new file mode 100644 index 00000000000..e94c81db51a --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.py @@ -0,0 +1,29 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. 
+# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# pyre-ignore-all-errors + +import torch + +from torch.library import impl, Library + +fallback_op_lib = Library("llama", "DEF") +# registering an operator. +fallback_op_lib.define("fallback(Tensor input) -> Tensor") + + +@impl(fallback_op_lib, "fallback") +def fallback_impl(a: torch.Tensor) -> torch.Tensor: + return a + + +# registering the out variant. +fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") + + +@impl(fallback_op_lib, "fallback.out") +def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(a) + return out diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 5422fb15b71..11fea031603 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace executorch { @@ -367,6 +368,51 @@ Module::make_planned_memory_with_shared_arenas( return planned; } +std::unique_ptr Module::make_planned_memory_with_devices( + const ET_RUNTIME_NAMESPACE::MethodMeta& method_meta) { + auto planned = std::make_unique(); + const size_t num_buffers = method_meta.num_memory_planned_buffers(); + planned->planned_buffers.reserve(num_buffers); + planned->planned_spans.reserve(num_buffers); + planned->device_buffers.reserve(num_buffers); + planned->planned_devices.reserve(num_buffers); + + for (size_t i = 0; i < num_buffers; ++i) { + auto size = method_meta.memory_planned_buffer_size(i); + ET_CHECK_MSG(size.ok(), "Failed to get buffer size for index %zu", i); + auto device = method_meta.memory_planned_buffer_device(i); + ET_CHECK_MSG(device.ok(), "Failed to get buffer device for index %zu", i); + planned->planned_devices.push_back(device.get()); + + if (device->is_cpu()) { + planned->planned_buffers.emplace_back(size.get()); + 
planned->planned_spans.emplace_back( + planned->planned_buffers.back().data(), size.get()); + } else { + // Allocate device memory via DeviceAllocator and store the RAII buffer. + planned->planned_buffers.emplace_back(); // empty CPU placeholder + auto dmb = runtime::DeviceMemoryBuffer::create( + size.get(), device->type(), device->index()); + ET_CHECK_MSG( + dmb.ok(), + "Failed to allocate device memory for buffer %zu (device_type=%d)", + i, + static_cast(device->type())); + planned->planned_spans.emplace_back(dmb->as_span()); + planned->device_buffers.push_back(std::move(dmb.get())); + } + } + + // HierarchicalAllocator owns the per-buffer Device metadata so the + // MemoryManager can later expose it via planned_buffer_devices(). + planned->planned_memory = std::make_unique( + runtime::Span>( + planned->planned_spans.data(), planned->planned_spans.size()), + runtime::Span( + planned->planned_devices.data(), planned->planned_devices.size())); + return planned; +} + runtime::Result> Module::get_mem_planned_buffer_sizes( const std::string& method_name) { auto meta_res = program_->method_meta(method_name.c_str()); @@ -422,10 +468,38 @@ runtime::Error Module::load_method( MethodHolder method_holder; if (!planned_memory) { - if (!share_memory_arenas_) { + // Check if any buffers need device memory allocation. + auto meta_res = program_->method_meta(method_name.c_str()); + ET_CHECK_OK_OR_RETURN_ERROR(meta_res.error()); + auto& meta = meta_res.get(); + + bool has_device_buffers = false; + for (size_t i = 0; i < meta.num_memory_planned_buffers(); ++i) { + auto dev = meta.memory_planned_buffer_device(i); + if (dev.ok() && !dev->is_cpu()) { + has_device_buffers = true; + break; + } + } + + if (has_device_buffers) { + // Device memory with shared arenas is not yet supported. + ET_CHECK_OR_RETURN_ERROR( + !share_memory_arenas_, + NotSupported, + "Device memory buffers are not yet compatible with " + "share_memory_arenas. 
Please disable share_memory_arenas " + "when using models with device-planned memory."); + + // Device-aware path: allocate CPU and device buffers. The device + // span is owned by the HierarchicalAllocator inside PlannedMemory. + method_holder.planned_memory = make_planned_memory_with_devices(meta); + planned_memory = method_holder.planned_memory->planned_memory.get(); + } else if (!share_memory_arenas_) { auto sizes_res = get_mem_planned_buffer_sizes(method_name); ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error()); method_holder.planned_memory = make_planned_memory(sizes_res.get()); + planned_memory = method_holder.planned_memory->planned_memory.get(); } else { auto sizes_res = get_mem_planned_buffer_sizes(method_name); ET_CHECK_OK_OR_RETURN_ERROR(sizes_res.error()); @@ -442,8 +516,8 @@ runtime::Error Module::load_method( } method_holder.planned_memory = make_planned_memory_with_shared_arenas(sizes, shared_arenas_); + planned_memory = method_holder.planned_memory->planned_memory.get(); } - planned_memory = method_holder.planned_memory->planned_memory.get(); } method_holder.memory_manager = std::make_unique( diff --git a/extension/module/module.h b/extension/module/module.h index 47ead23032e..91c7feaad9b 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -18,6 +18,8 @@ #include #include +#include + #ifdef USE_ATEN_LIB #define ET_MODULE_NAMESPACE module::aten #else // !USE_ATEN_LIB @@ -716,6 +718,11 @@ class Module { struct PlannedMemory { std::vector> planned_buffers; std::vector> planned_spans; + std::vector device_buffers; + /// Per-buffer Device (type + index) metadata used by + /// HierarchicalAllocator. Owns the storage backing the device span the + /// allocator references, so it must outlive `planned_memory`. 
+ std::vector planned_devices; std::unique_ptr planned_memory; }; std::unique_ptr make_planned_memory( @@ -723,6 +730,8 @@ class Module { std::unique_ptr make_planned_memory_with_shared_arenas( const std::vector& buffer_sizes, std::vector>& shared_arenas); + std::unique_ptr make_planned_memory_with_devices( + const ET_RUNTIME_NAMESPACE::MethodMeta& method_meta); runtime::Result> get_mem_planned_buffer_sizes( const std::string& method_name); runtime::Result> get_max_mem_planned_buffer_sizes(); diff --git a/extension/module/targets.bzl b/extension/module/targets.bzl index fa80203831a..e622b138ff6 100644 --- a/extension/module/targets.bzl +++ b/extension/module/targets.bzl @@ -30,6 +30,7 @@ def define_common_targets(): "//executorch/runtime/backend:backend_options", "//executorch/runtime/backend:backend_options_map", "//executorch/runtime/executor:program_no_prim_ops" + aten_suffix, + "//executorch/runtime/core:device_memory_buffer", ], ) diff --git a/extension/module/test/module_device_memory_test.cpp b/extension/module/test/module_device_memory_test.cpp new file mode 100644 index 00000000000..5031273ac2b --- /dev/null +++ b/extension/module/test/module_device_memory_test.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Tests that Module's device-aware memory allocation path works correctly. + * + * Uses ModuleAddWithDevice.pte which has: + * non_const_buffer_sizes: [0, 48] (1 buffer, index 0 reserved) + * non_const_buffer_device: [{buffer_idx=1, device_type=CUDA, device_index=0}] + * + * Since we don't have a real CUDA backend, we test that: + * 1. CPU-only models load through Module without invoking device allocator + * 2. 
Device-annotated models trigger DeviceMemoryBuffer::create via a mock + */ + +#include + +#include + +#include +#include +#include + +using executorch::extension::Module; +using executorch::runtime::DeviceAllocator; +using executorch::runtime::DeviceMemoryBuffer; +using executorch::runtime::Error; +using executorch::runtime::register_device_allocator; +using executorch::runtime::Result; +using executorch::runtime::etensor::DeviceIndex; +using executorch::runtime::etensor::DeviceType; + +namespace { + +class MockCudaAllocator : public DeviceAllocator { + public: + Result allocate( + size_t nbytes, + DeviceIndex index, + size_t alignment = kDefaultAlignment) override { + (void)alignment; + allocate_count_++; + last_allocate_size_ = nbytes; + last_allocate_index_ = index; + buffer_ = std::make_unique(nbytes); + return static_cast(buffer_.get()); + } + + void deallocate(void* ptr, DeviceIndex index) override { + deallocate_count_++; + buffer_.reset(); + } + + Error copy_host_to_device(void*, const void*, size_t, DeviceIndex) override { + return Error::Ok; + } + + Error copy_device_to_host(void*, const void*, size_t, DeviceIndex) override { + return Error::Ok; + } + + DeviceType device_type() const override { + return DeviceType::CUDA; + } + + int allocate_count_ = 0; + int deallocate_count_ = 0; + size_t last_allocate_size_ = 0; + DeviceIndex last_allocate_index_ = -1; + + private: + std::unique_ptr buffer_; +}; + +} // namespace + +static MockCudaAllocator g_mock_cuda; + +class ModuleDeviceMemoryTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + executorch::runtime::runtime_init(); + register_device_allocator(&g_mock_cuda); + } + + void SetUp() override { + g_mock_cuda.allocate_count_ = 0; + g_mock_cuda.deallocate_count_ = 0; + g_mock_cuda.last_allocate_size_ = 0; + g_mock_cuda.last_allocate_index_ = -1; + } +}; + +TEST_F(ModuleDeviceMemoryTest, CpuOnlyModelDoesNotAllocateDeviceMemory) { + const char* path = 
std::getenv("ET_MODULE_ADD_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_PATH not set"; + + Module module(path); + auto err = module.load_method("forward"); + ASSERT_EQ(err, Error::Ok); + + EXPECT_EQ(g_mock_cuda.allocate_count_, 0) + << "CPU-only model should not allocate device memory"; +} + +TEST_F(ModuleDeviceMemoryTest, DeviceMemoryBufferCreateCallsAllocator) { + // Directly test DeviceMemoryBuffer::create with the registered mock. + // This verifies the RAII allocation/deallocation path that Module uses. + { + auto result = DeviceMemoryBuffer::create(48, DeviceType::CUDA, 0); + ASSERT_TRUE(result.ok()); + auto buf = std::move(result.get()); + + EXPECT_EQ(g_mock_cuda.allocate_count_, 1); + EXPECT_EQ(g_mock_cuda.last_allocate_size_, 48); + EXPECT_EQ(g_mock_cuda.last_allocate_index_, 0); + EXPECT_NE(buf.data(), nullptr); + EXPECT_EQ(buf.size(), 48); + + // as_span() wraps the device pointer for HierarchicalAllocator. + auto span = buf.as_span(); + EXPECT_EQ(span.data(), static_cast(buf.data())); + EXPECT_EQ(span.size(), 48); + + EXPECT_EQ(g_mock_cuda.deallocate_count_, 0); + } + // RAII deallocation on scope exit. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1); +} + +TEST_F(ModuleDeviceMemoryTest, DeviceModelMethodMetaReportsCudaBuffer) { + // Verify MethodMeta reports the correct device for buffers in the + // device-annotated model, without needing to load the full method. + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + Module module(path); + auto err = module.load(); + ASSERT_EQ(err, Error::Ok); + + auto meta = module.method_meta("forward"); + ASSERT_TRUE(meta.ok()); + + // ModuleAddWithDevice has 1 planned buffer (48 bytes) on CUDA. 
+ ASSERT_EQ(meta->num_memory_planned_buffers(), 1); + + auto size = meta->memory_planned_buffer_size(0); + ASSERT_TRUE(size.ok()); + EXPECT_EQ(size.get(), 48); + + auto device = meta->memory_planned_buffer_device(0); + ASSERT_TRUE(device.ok()); + EXPECT_EQ(device->type(), DeviceType::CUDA); + EXPECT_EQ(device->index(), 0); +} + +TEST_F(ModuleDeviceMemoryTest, DeviceModelWithSharedArenasReturnsNotSupported) { + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + // share_memory_arenas = true with a device-annotated model should fail. + Module module( + path, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + auto err = module.load_method("forward"); + EXPECT_EQ(err, Error::NotSupported); +} + +TEST_F( + ModuleDeviceMemoryTest, + LoadMethodAllocatesDeviceMemoryAndDeallocatesOnDestroy) { + const char* path = std::getenv("ET_MODULE_ADD_WITH_DEVICE_PATH"); + ASSERT_NE(path, nullptr) << "ET_MODULE_ADD_WITH_DEVICE_PATH not set"; + + { + Module module(path); + auto err = module.load_method("forward"); + + // Regardless of whether load_method succeeds or fails (e.g. due to + // backend init issues), the device-aware memory allocation path + // (make_planned_memory_with_devices) runs BEFORE backend init. + EXPECT_EQ(g_mock_cuda.allocate_count_, 1) + << "Expected 1 device allocation for the CUDA buffer" + << " (actual: " << g_mock_cuda.allocate_count_ << ")" + << ", deallocate_count=" << g_mock_cuda.deallocate_count_ + << ", load_method returned error=" << static_cast(err); + EXPECT_EQ(g_mock_cuda.last_allocate_size_, 48) + << "Expected 48 bytes allocated (3 CUDA tensors sharing one buffer)"; + EXPECT_EQ(g_mock_cuda.last_allocate_index_, 0) + << "Expected device_index=0 (cuda:0)"; + + if (err == Error::Ok) { + // Success path: MethodHolder moved into methods_ map. 
+ // DeviceMemoryBuffer is alive as long as Module is alive. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 0) + << "No deallocation while method is loaded"; + } else { + // Error path: local MethodHolder destroyed on return from load_method. + // RAII deallocation already happened. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1) + << "RAII deallocation on error path"; + } + } + + // After Module destroyed, all device memory must be freed. + EXPECT_EQ(g_mock_cuda.deallocate_count_, 1) + << "Expected deallocation after Module destroyed"; +} diff --git a/extension/module/test/targets.bzl b/extension/module/test/targets.bzl index f0d7e449efd..4dc3fb537f3 100644 --- a/extension/module/test/targets.bzl +++ b/extension/module/test/targets.bzl @@ -28,7 +28,7 @@ def define_common_targets(is_fbcode=False): aten_suffix = ("_aten" if aten_mode else "") runtime.cxx_test( - name = "test" + aten_suffix, + name = "module_test" + aten_suffix, srcs = [ "module_test.cpp", ], @@ -68,6 +68,26 @@ def define_common_targets(is_fbcode=False): ], ) + runtime.cxx_test( + name = "module_device_memory_test" + aten_suffix, + srcs = [ + "module_device_memory_test.cpp", + ], + deps = [ + "//executorch/kernels/portable:generated_lib" + aten_suffix, + "//executorch/extension/module:module" + aten_suffix, + "//executorch/runtime/core:device_allocator", + "//executorch/runtime/core:device_memory_buffer", + ], + env = { + "ET_MODULE_ADD_WITH_DEVICE_PATH": "$(location fbcode//executorch/test/models:exported_program_with_device_info[ModuleAddWithDevice.pte])", + "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", + }, + compiler_flags = [ + "-Wno-error=deprecated-declarations", + ], + ) + runtime.filegroup( name = "resources", srcs = native.glob([ diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index b0545b8ce18..659a128994f 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ 
b/shim_et/xplat/executorch/build/build_variables.bzl @@ -50,6 +50,8 @@ PLATFORM_SRCS = [ EXECUTORCH_CORE_SRCS = sorted([ "runtime/backend/interface.cpp", + "runtime/core/device_allocator.cpp", + "runtime/core/device_memory_buffer.cpp", "runtime/core/evalue.cpp", "runtime/core/exec_aten/util/tensor_shape_to_c_string.cpp", "runtime/core/exec_aten/util/tensor_util_portable.cpp", diff --git a/test/models/targets.bzl b/test/models/targets.bzl index c9fb67b7d31..a80244b1383 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -226,6 +226,7 @@ def define_common_targets(): default_outs = ["."], visibility = [ "//executorch/runtime/executor/test/...", + "//executorch/extension/module/test/...", ], ) diff --git a/tools/cmake/preset/riscv64_baremetal.cmake b/tools/cmake/preset/riscv64_baremetal.cmake new file mode 100644 index 00000000000..e70fc57ba57 --- /dev/null +++ b/tools/cmake/preset/riscv64_baremetal.cmake @@ -0,0 +1,50 @@ +# Copyright 2026 The ExecuTorch Authors. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Baremetal builds consume the build tree directly; mirror arm_baremetal so +# install rules stay invokable but write back into the build dir. 
+# Declare the EXECUTORCH_BAREMETAL_SKIP_INSTALL toggle (type BOOL, value ON).
+# NOTE(review): define_overridable_option() is a project helper defined outside
+# this file; presumably a user-supplied value takes precedence -- confirm.
+# NOTE(review): despite the "skip" wording in the option name and help text,
+# enabling it does not drop install rules: the branch below keeps them enabled
+# and points the install prefix at the build tree.
+define_overridable_option(
+  EXECUTORCH_BAREMETAL_SKIP_INSTALL
+  "Skip emitting install/export rules when building bare-metal artifacts" BOOL
+  ON
+)
+
+if(EXECUTORCH_BAREMETAL_SKIP_INSTALL)
+  # Install into the build directory itself.
+  set(CMAKE_INSTALL_PREFIX "${CMAKE_BINARY_DIR}")
+  # Drop any cached CMAKE_SKIP_INSTALL_RULES entry, then force the cache entry
+  # to OFF so that install() rules are still generated.
+  unset(CMAKE_SKIP_INSTALL_RULES CACHE)
+  set(CMAKE_SKIP_INSTALL_RULES
+      OFF
+      CACHE
+        BOOL
+        "Retain install() rules so docs/scripts can keep calling --target install"
+        FORCE
+  )
+endif()
+
+# Preset values for the bare-metal RISC-V build.
+# NOTE(review): set_overridable_option() is a project helper defined outside
+# this file; presumably each value applies only when the user has not already
+# set that option -- confirm.
+set_overridable_option(EXECUTORCH_BUILD_EXECUTOR_RUNNER OFF)
+set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER OFF)
+set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR OFF)
+set_overridable_option(EXECUTORCH_BUILD_EXTENSION_EVALUE_UTIL ON)
+set_overridable_option(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL ON)
+set_overridable_option(EXECUTORCH_BUILD_KERNELS_QUANTIZED ON)
+# BUNDLE_IO requires DEVTOOLS to provide the bundled_program lib.
+set_overridable_option(EXECUTORCH_BUILD_DEVTOOLS ON)
+set_overridable_option(EXECUTORCH_ENABLE_BUNDLE_IO ON)
+set_overridable_option(EXECUTORCH_ENABLE_LOGGING ON)
+# Freestanding target: no pthreadpool, no cpuinfo, no shared lib.
+set_overridable_option(EXECUTORCH_BUILD_PTHREADPOOL OFF)
+set_overridable_option(EXECUTORCH_BUILD_CPUINFO OFF)
+
+# Opt-in switch for etdump support (type BOOL, value OFF).
+define_overridable_option(
+  EXECUTORCH_BUILD_RISCV_ETDUMP "Build etdump support for RISC-V" BOOL OFF
+)
+
+# The quoted expansion makes if() evaluate the option's current value as a
+# boolean constant. Both branches use plain set(), which writes normal
+# (non-cache) variables in the current scope.
+if("${EXECUTORCH_BUILD_RISCV_ETDUMP}")
+  # etdump requested: turn on devtools and the event tracer, and set
+  # FLATCC_ALLOW_WERROR to OFF.
+  set(EXECUTORCH_BUILD_DEVTOOLS ON)
+  set(EXECUTORCH_ENABLE_EVENT_TRACER ON)
+  set(FLATCC_ALLOW_WERROR OFF)
+else()
+  # etdump not requested: explicitly disable the event tracer.
+  set(EXECUTORCH_ENABLE_EVENT_TRACER OFF)
+endif()