diff --git a/core/src/main/java/org/apache/datafusion/ColumnarValue.java b/core/src/main/java/org/apache/datafusion/ColumnarValue.java new file mode 100644 index 0000000..cbe0703 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/ColumnarValue.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.Objects; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** + * The value of a scalar UDF argument or result: either a per-row {@link Array} of length {@code + * rowCount}, or a {@link Scalar} (length-1 vector) that the framework broadcasts. + * + *

Mirrors DataFusion's {@code datafusion::logical_expr::ColumnarValue} enum. Use {@link + * #array(FieldVector)} and {@link #scalar(FieldVector)} factories rather than constructing the + * records directly so the length invariants are enforced consistently. + */ +public sealed interface ColumnarValue permits ColumnarValue.Array, ColumnarValue.Scalar { + + /** The underlying Arrow vector. For {@link Scalar} this vector has {@code valueCount == 1}. */ + FieldVector vector(); + + /** Convenience: the vector's declared Arrow type. */ + default ArrowType dataType() { + return vector().getField().getType(); + } + + /** Wrap an arbitrary-length vector as an {@link Array}. */ + static ColumnarValue array(FieldVector vector) { + return new Array(Objects.requireNonNull(vector, "vector")); + } + + /** + * Wrap a length-1 vector as a {@link Scalar}. + * + * @throws IllegalArgumentException if {@code vector.getValueCount() != 1} + */ + static ColumnarValue scalar(FieldVector vector) { + Objects.requireNonNull(vector, "vector"); + if (vector.getValueCount() != 1) { + throw new IllegalArgumentException( + "Scalar vector must have valueCount == 1, got " + vector.getValueCount()); + } + return new Scalar(vector); + } + + /** Per-row Arrow vector of length equal to the batch row count. */ + record Array(FieldVector vector) implements ColumnarValue { + public Array { + Objects.requireNonNull(vector, "vector"); + } + } + + /** Length-1 Arrow vector representing a single value broadcast across all rows. */ + record Scalar(FieldVector vector) implements ColumnarValue { + public Scalar { + Objects.requireNonNull(vector, "vector"); + if (vector.getValueCount() != 1) { + throw new IllegalArgumentException( + "Scalar vector must have valueCount == 1, got " + vector.getValueCount()); + } + } + } +} diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunction.java b/core/src/main/java/org/apache/datafusion/ScalarFunction.java index 676154e..b83c636 100644 --- a/core/src/main/java/org/apache/datafusion/ScalarFunction.java +++ b/core/src/main/java/org/apache/datafusion/ScalarFunction.java @@ -22,7 +22,6 @@ import java.util.List; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.types.pojo.ArrowType; /** @@ -46,7 +45,9 @@ public interface ScalarFunction { */ List argTypes(); - /** Declared return type. The returned {@link FieldVector} must have this exact type. */ + /** + * Declared return type. The returned {@link ColumnarValue}'s vector must have this exact type. + */ ArrowType returnType(); /** @@ -59,14 +60,16 @@ public interface ScalarFunction { /** * Compute the function result for one input batch. * - * @param allocator the {@link BufferAllocator} that MUST be used for any new {@link FieldVector} + * @param allocator the {@link BufferAllocator} that MUST be used for any new Arrow vector * allocation, including the result. Buffers allocated from other allocators will not survive * the JNI handoff. - * @param args one {@link FieldVector} per declared argument, all of the same length. These are - * read-only views; the implementation must NOT close them. - * @return a {@link FieldVector} of the declared return type and the same length as the inputs. - * Ownership transfers to the framework on return; the implementation must NOT close the - * returned vector. + * @param args the per-arg {@link ColumnarValue}s and the batch row count. Each {@link + * ColumnarValue} is a read-only view; the implementation must NOT close its underlying + * vector. + * @return a {@link ColumnarValue} of the declared return type. If {@link ColumnarValue.Array}, + * the underlying vector must have length {@code args.rowCount()}; if {@link + * ColumnarValue.Scalar}, length 1. Ownership of the returned vector transfers to the + * framework; the implementation must NOT close it. */ - FieldVector evaluate(BufferAllocator allocator, List args); + ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args); } diff --git a/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java b/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java new file mode 100644 index 0000000..927fcb1 --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/ScalarFunctionArgs.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import java.util.List; +import java.util.Objects; + +/** + * Bundle of inputs passed to {@link ScalarFunction#evaluate}: the per-arg {@link ColumnarValue}s + * (in declared order) and the batch row count DataFusion is driving. + * + *

Mirrors DataFusion's {@code datafusion::logical_expr::ScalarFunctionArgs}. {@code rowCount} is + * the only channel by which an array-returning UDF without array-typed inputs (all-scalar args, or + * nullary) can size its output. Nullary UDFs that prefer to broadcast a single value should return + * {@link ColumnarValue#scalar(org.apache.arrow.vector.FieldVector) ColumnarValue.scalar(...)} + * instead, which removes the need to consult {@code rowCount}. + */ +public record ScalarFunctionArgs(List args, int rowCount) { + public ScalarFunctionArgs { + args = List.copyOf(Objects.requireNonNull(args, "args")); + if (rowCount < 0) { + throw new IllegalArgumentException("rowCount must be >= 0, got " + rowCount); + } + } +} diff --git a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java index 210e996..8248357 100644 --- a/core/src/main/java/org/apache/datafusion/internal/JniBridge.java +++ b/core/src/main/java/org/apache/datafusion/internal/JniBridge.java @@ -19,6 +19,7 @@ package org.apache.datafusion.internal; +import java.util.ArrayList; import java.util.List; import org.apache.arrow.c.ArrowArray; @@ -27,7 +28,9 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.datafusion.ColumnarValue; import org.apache.datafusion.ScalarFunction; +import org.apache.datafusion.ScalarFunctionArgs; /** Internal trampoline invoked from native code on every UDF call. Not part of the public API. */ public final class JniBridge { @@ -40,54 +43,100 @@ public final class JniBridge { private JniBridge() {} + /** argKind byte signalling a {@link ColumnarValue.Array} arg. */ + private static final byte KIND_ARRAY = 0; + + /** argKind byte signalling a {@link ColumnarValue.Scalar} arg. */ + private static final byte KIND_SCALAR = 1; + /** * Invoke a scalar UDF for one batch. Called from native code; not for application use. * - * @param impl the registered {@link ScalarFunction} implementation - * @param argsArrayAddr address of a populated {@code FFI_ArrowArray} struct holding the input - * batch as a struct array (one field per UDF argument) - * @param argsSchemaAddr address of the matching {@code FFI_ArrowSchema} - * @param resultArrayAddr address of an empty {@code FFI_ArrowArray} the bridge writes into - * @param resultSchemaAddr address of an empty {@code FFI_ArrowSchema} the bridge writes into - * @param expectedRowCount the row count the result vector must have + *

Args arrive split into two struct arrays: {@code arrayArgs*} of length {@code rowCount} + * holding the {@link ColumnarValue.Array} arguments in their relative order, and {@code + * scalarArgs*} of length 1 holding the {@link ColumnarValue.Scalar} arguments. {@code argKinds} + * records the original positional order so the bridge can interleave them back into a single + * {@code List} for the user. + * + * @return {@link #KIND_ARRAY} if the UDF returned an Array, {@link #KIND_SCALAR} if it returned a + * Scalar. The native caller uses this to reconstruct the right {@code ColumnarValue} variant. */ - public static void invokeScalarUdf( + public static byte invokeScalarUdf( ScalarFunction impl, - long argsArrayAddr, - long argsSchemaAddr, + long arrayArgsArrayAddr, + long arrayArgsSchemaAddr, + long scalarArgsArrayAddr, + long scalarArgsSchemaAddr, + byte[] argKinds, long resultArrayAddr, long resultSchemaAddr, - int expectedRowCount) { - ArrowArray argsArr = ArrowArray.wrap(argsArrayAddr); - ArrowSchema argsSch = ArrowSchema.wrap(argsSchemaAddr); + int rowCount) { + ArrowArray arrayArr = ArrowArray.wrap(arrayArgsArrayAddr); + ArrowSchema arraySch = ArrowSchema.wrap(arrayArgsSchemaAddr); + ArrowArray scalarArr = ArrowArray.wrap(scalarArgsArrayAddr); + ArrowSchema scalarSch = ArrowSchema.wrap(scalarArgsSchemaAddr); ArrowArray resultArr = ArrowArray.wrap(resultArrayAddr); ArrowSchema resultSch = ArrowSchema.wrap(resultSchemaAddr); - try (VectorSchemaRoot root = Data.importVectorSchemaRoot(ALLOCATOR, argsArr, argsSch, null)) { - List argVectors = root.getFieldVectors(); + try (VectorSchemaRoot arrayRoot = + Data.importVectorSchemaRoot(ALLOCATOR, arrayArr, arraySch, null); + VectorSchemaRoot scalarRoot = + Data.importVectorSchemaRoot(ALLOCATOR, scalarArr, scalarSch, null)) { - FieldVector result = impl.evaluate(ALLOCATOR, argVectors); + List arrayFields = arrayRoot.getFieldVectors(); + List scalarFields = scalarRoot.getFieldVectors(); + + List args = new ArrayList<>(argKinds.length); + int arrayIdx = 0; + int scalarIdx = 0; + for (byte kind : argKinds) { + if (kind == KIND_ARRAY) { + args.add(ColumnarValue.array(arrayFields.get(arrayIdx++))); + } else if (kind == KIND_SCALAR) { + args.add(ColumnarValue.scalar(scalarFields.get(scalarIdx++))); + } else { + throw new IllegalStateException("Unknown argKind byte: " + kind); + } + } + + ColumnarValue result = impl.evaluate(ALLOCATOR, new ScalarFunctionArgs(args, rowCount)); if (result == null) { throw new IllegalStateException("ScalarFunction.evaluate returned null"); } - if (result.getValueCount() != expectedRowCount) { + + FieldVector resultVec = result.vector(); + byte resultKind; + int expectedLen; + if (result instanceof ColumnarValue.Array) { + resultKind = KIND_ARRAY; + expectedLen = rowCount; + } else { + resultKind = KIND_SCALAR; + expectedLen = 1; + } + + if (resultVec.getValueCount() != expectedLen) { try { throw new IllegalStateException( - "ScalarFunction.evaluate returned vector with " - + result.getValueCount() + "ScalarFunction.evaluate returned " + + (resultKind == KIND_ARRAY ? "Array" : "Scalar") + + " vector with " + + resultVec.getValueCount() + " rows; expected " - + expectedRowCount); + + expectedLen); } finally { - result.close(); + resultVec.close(); } } try { - Data.exportVector(ALLOCATOR, result, null, resultArr, resultSch); + Data.exportVector(ALLOCATOR, resultVec, null, resultArr, resultSch); } finally { - result.close(); + resultVec.close(); } + + return resultKind; } } } diff --git a/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java b/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java new file mode 100644 index 0000000..42879e0 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/ColumnarValueTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.jupiter.api.Test; + +class ColumnarValueTest { + + private static final ArrowType INT32 = new ArrowType.Int(32, true); + + @Test + void array_factory_returnsArrayVariant() { + try (BufferAllocator allocator = new RootAllocator(); + IntVector v = new IntVector("v", allocator)) { + v.allocateNew(3); + v.setValueCount(3); + ColumnarValue cv = ColumnarValue.array(v); + assertSame(v, cv.vector()); + assertEquals(INT32, cv.dataType()); + } + } + + @Test + void scalar_factory_returnsScalarVariant() { + try (BufferAllocator allocator = new RootAllocator(); + IntVector v = new IntVector("v", allocator)) { + v.allocateNew(1); + v.set(0, 42); + v.setValueCount(1); + ColumnarValue cv = ColumnarValue.scalar(v); + assertSame(v, cv.vector()); + assertEquals(INT32, cv.dataType()); + } + } + + @Test + void scalar_factory_rejectsNonOneLength() { + try (BufferAllocator allocator = new RootAllocator(); + IntVector v = new IntVector("v", allocator)) { + v.allocateNew(2); + v.setValueCount(2); + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> ColumnarValue.scalar(v)); + assertEquals("Scalar vector must have valueCount == 1, got 2", ex.getMessage()); + } + } + + @Test + void array_factory_rejectsNull() { + assertThrows(NullPointerException.class, () -> ColumnarValue.array(null)); + } + + @Test + void scalar_factory_rejectsNull() { + assertThrows(NullPointerException.class, () -> ColumnarValue.scalar(null)); + } +} diff --git a/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java b/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java new file mode 100644 index 0000000..df9f480 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/ScalarFunctionArgsTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.junit.jupiter.api.Test; + +class ScalarFunctionArgsTest { + + @Test + void construct_emptyArgs_zeroRows_ok() { + ScalarFunctionArgs a = new ScalarFunctionArgs(List.of(), 0); + assertEquals(List.of(), a.args()); + assertEquals(0, a.rowCount()); + } + + @Test + void construct_rejectsNullArgs() { + assertThrows(NullPointerException.class, () -> new ScalarFunctionArgs(null, 3)); + } + + @Test + void construct_rejectsNegativeRowCount() { + assertThrows(IllegalArgumentException.class, () -> new ScalarFunctionArgs(List.of(), -1)); + } + + @Test + void construct_copiesArgsDefensively() { + try (BufferAllocator allocator = new RootAllocator(); + IntVector v = new IntVector("v", allocator)) { + v.allocateNew(1); + v.setValueCount(1); + List source = new ArrayList<>(); + source.add(ColumnarValue.scalar(v)); + ScalarFunctionArgs a = new ScalarFunctionArgs(source, 1); + source.clear(); + assertEquals(1, a.args().size()); + assertThrows(UnsupportedOperationException.class, () -> a.args().clear()); + } + } + + @Test + void args_singletonCase_preservesValue() { + try (BufferAllocator allocator = new RootAllocator(); + IntVector v = new IntVector("v", allocator)) { + v.allocateNew(1); + v.set(0, 7); + v.setValueCount(1); + ScalarFunctionArgs a = + new ScalarFunctionArgs(Collections.singletonList(ColumnarValue.scalar(v)), 5); + assertEquals(5, a.rowCount()); + assertTrue(a.args().get(0) instanceof ColumnarValue.Scalar); + } + } +} diff --git a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java index 97f5f52..3e14580 100644 --- a/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java +++ b/core/src/test/java/org/apache/datafusion/ScalarUdfTest.java @@ -86,8 +86,8 @@ static final class AddOne extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { - IntVector in = (IntVector) args.get(0); + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + IntVector in = (IntVector) args.args().get(0).vector(); IntVector out = new IntVector("add_one_out", allocator); int n = in.getValueCount(); out.allocateNew(n); @@ -99,7 +99,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) { } } out.setValueCount(n); - return out; + return ColumnarValue.array(out); } } @@ -133,11 +133,11 @@ static final class Concat extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { org.apache.arrow.vector.VarCharVector left = - (org.apache.arrow.vector.VarCharVector) args.get(0); + (org.apache.arrow.vector.VarCharVector) args.args().get(0).vector(); org.apache.arrow.vector.VarCharVector right = - (org.apache.arrow.vector.VarCharVector) args.get(1); + (org.apache.arrow.vector.VarCharVector) args.args().get(1).vector(); org.apache.arrow.vector.VarCharVector out = new org.apache.arrow.vector.VarCharVector("concat_out", allocator); int n = left.getValueCount(); @@ -155,7 +155,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) { } } out.setValueCount(n); - return out; + return ColumnarValue.array(out); } } @@ -188,8 +188,9 @@ static final class Square extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { - org.apache.arrow.vector.Float8Vector in = (org.apache.arrow.vector.Float8Vector) args.get(0); + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + org.apache.arrow.vector.Float8Vector in = + (org.apache.arrow.vector.Float8Vector) args.args().get(0).vector(); org.apache.arrow.vector.Float8Vector out = new org.apache.arrow.vector.Float8Vector("square_out", allocator); int n = in.getValueCount(); @@ -203,7 +204,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) { } } out.setValueCount(n); - return out; + return ColumnarValue.array(out); } } @@ -252,7 +253,7 @@ static final class ReturnsNull extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { return null; } } @@ -283,13 +284,13 @@ static final class WrongRowCount extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { - IntVector in = (IntVector) args.get(0); + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + IntVector in = (IntVector) args.args().get(0).vector(); IntVector out = new IntVector("out", allocator); out.allocateNew(in.getValueCount() + 1); // off by one for (int i = 0; i < in.getValueCount() + 1; i++) out.set(i, 0); out.setValueCount(in.getValueCount() + 1); - return out; + return ColumnarValue.array(out); } } @@ -319,14 +320,15 @@ static final class WrongType extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { // Declared return type is Int32; return Float64. org.apache.arrow.vector.Float8Vector out = new org.apache.arrow.vector.Float8Vector("out", allocator); - out.allocateNew(args.get(0).getValueCount()); - for (int i = 0; i < args.get(0).getValueCount(); i++) out.set(i, 0.0); - out.setValueCount(args.get(0).getValueCount()); - return out; + FieldVector in = args.args().get(0).vector(); + out.allocateNew(in.getValueCount()); + for (int i = 0; i < in.getValueCount(); i++) out.set(i, 0.0); + out.setValueCount(in.getValueCount()); + return ColumnarValue.array(out); } } @@ -356,7 +358,7 @@ static final class ThrowsIAE extends AbstractScalarFunction { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { throw new IllegalArgumentException("custom boom from UDF"); } } @@ -448,6 +450,152 @@ void udfAppliedToMultiRowQuery_processesAllRows() throws Exception { } } + /** + * Nullary UDF returning a length-1 Float8 vector. Marked VOLATILE so DataFusion's constant folder + * does not collapse the call before reaching us. Exercises the path that the abandoned PR #57 + * added a separate rowCount parameter for: a nullary UDF can now broadcast its value through + * {@link ColumnarValue#scalar(FieldVector)} and the framework handles per-row expansion. + */ + static final class JavaPi extends AbstractScalarFunction { + JavaPi() { + super("java_pi", List.of(), FLOAT64, Volatility.VOLATILE); + } + + @Override + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + org.apache.arrow.vector.Float8Vector out = + new org.apache.arrow.vector.Float8Vector("pi_out", allocator); + out.allocateNew(1); + out.set(0, Math.PI); + out.setValueCount(1); + return ColumnarValue.scalar(out); + } + } + + @Test + void nullaryScalarReturnUdf_overMultiRowQuery_broadcasts() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new JavaPi())); + + try (DataFrame df = ctx.sql("SELECT java_pi() AS p FROM (VALUES (1), (2), (3)) AS t(x)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + org.apache.arrow.vector.Float8Vector p = + (org.apache.arrow.vector.Float8Vector) root.getVector("p"); + assertEquals(3, p.getValueCount()); + assertEquals(Math.PI, p.get(0), 0.0); + assertEquals(Math.PI, p.get(1), 0.0); + assertEquals(Math.PI, p.get(2), 0.0); + } + } + } + + /** + * UDF over (int_col, int_literal). On every invocation it asserts that arg 0 is an Array and arg + * 1 is a Scalar (length-1 vector). Proves the FFI protocol preserves scalar-ness end-to-end + * rather than materialising the literal to a length-N array on the native side. + */ + static final class AssertSecondArgIsScalar extends AbstractScalarFunction { + AssertSecondArgIsScalar() { + super("assert_scalar_arg", List.of(INT32, INT32), INT32, Volatility.IMMUTABLE); + } + + @Override + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + if (!(args.args().get(0) instanceof ColumnarValue.Array)) { + throw new AssertionError( + "arg 0 expected Array, got " + args.args().get(0).getClass().getSimpleName()); + } + if (!(args.args().get(1) instanceof ColumnarValue.Scalar)) { + throw new AssertionError( + "arg 1 expected Scalar, got " + args.args().get(1).getClass().getSimpleName()); + } + IntVector left = (IntVector) args.args().get(0).vector(); + IntVector right = (IntVector) args.args().get(1).vector(); + if (right.getValueCount() != 1) { + throw new AssertionError( + "Scalar arg vector should have length 1, got " + right.getValueCount()); + } + int rightVal = right.get(0); + IntVector out = new IntVector("out", allocator); + int n = left.getValueCount(); + out.allocateNew(n); + for (int i = 0; i < n; i++) { + if (left.isNull(i)) { + out.setNull(i); + } else { + out.set(i, left.get(i) + rightVal); + } + } + out.setValueCount(n); + return ColumnarValue.array(out); + } + } + + @Test + void scalarLiteralArg_arrivesAsScalarColumnarValue() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new AssertSecondArgIsScalar())); + + try (DataFrame df = + ctx.sql( + "SELECT assert_scalar_arg(x, CAST(100 AS INT)) AS y" + + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT)), (CAST(3 AS INT)))" + + " AS t(x)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + IntVector y = (IntVector) root.getVector("y"); + assertEquals(3, y.getValueCount()); + assertEquals(101, y.get(0)); + assertEquals(102, y.get(1)); + assertEquals(103, y.get(2)); + } + } + } + + /** UDF that ignores its input and returns a constant Scalar. */ + static final class IgnoreInputReturnFortyTwo extends AbstractScalarFunction { + IgnoreInputReturnFortyTwo() { + super("forty_two", List.of(INT32), INT32, Volatility.IMMUTABLE); + } + + @Override + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + IntVector out = new IntVector("out", allocator); + out.allocateNew(1); + out.set(0, 42); + out.setValueCount(1); + return ColumnarValue.scalar(out); + } + } + + @Test + void udfReturningScalar_isBroadcastByFramework() throws Exception { + try (SessionContext ctx = new SessionContext(); + BufferAllocator allocator = new RootAllocator()) { + ctx.registerUdf(new ScalarUdf(new IgnoreInputReturnFortyTwo())); + + try (DataFrame df = + ctx.sql( + "SELECT forty_two(x) AS y" + + " FROM (VALUES (CAST(1 AS INT)), (CAST(2 AS INT))," + + " (CAST(3 AS INT)), (CAST(4 AS INT)), (CAST(5 AS INT))) AS t(x)"); + ArrowReader r = df.collect(allocator)) { + assertEquals(true, r.loadNextBatch()); + VectorSchemaRoot root = r.getVectorSchemaRoot(); + IntVector y = (IntVector) root.getVector("y"); + assertEquals(5, y.getValueCount()); + for (int i = 0; i < 5; i++) { + assertEquals(42, y.get(i)); + } + } + } + } + @Test void volatilityBytesRoundTrip_forAllThreeKinds() throws Exception { for (Volatility v : Volatility.values()) { diff --git a/docs/source/user-guide/scalar-udf.md b/docs/source/user-guide/scalar-udf.md index 8b78410..b53b32f 100644 --- a/docs/source/user-guide/scalar-udf.md +++ b/docs/source/user-guide/scalar-udf.md @@ -19,9 +19,10 @@ under the License. # Scalar UDFs -A scalar UDF is a Java-implemented SQL function that operates on one row at a -time, expressed in vectorised form: each invocation receives a batch of input -columns and returns a single output column of the same length. +A scalar UDF is a Java-implemented SQL function that operates one row at a time, +expressed in vectorised form: each invocation receives a batch of input columns +and returns either a per-row output column of the same length (`Array`) or a +single value broadcast to every row (`Scalar`). ## Implement @@ -32,10 +33,11 @@ per-batch `evaluate` body: ```java import java.util.List; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.datafusion.ColumnarValue; import org.apache.datafusion.ScalarFunction; +import org.apache.datafusion.ScalarFunctionArgs; import org.apache.datafusion.Volatility; public final class AddOne implements ScalarFunction { @@ -47,8 +49,8 @@ public final class AddOne implements ScalarFunction { @Override public Volatility volatility() { return Volatility.IMMUTABLE; } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { - IntVector in = (IntVector) args.get(0); + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + IntVector in = (IntVector) args.args().get(0).vector(); IntVector out = new IntVector("add_one", allocator); out.allocateNew(in.getValueCount()); for (int i = 0; i < in.getValueCount(); i++) { @@ -56,15 +58,50 @@ public final class AddOne implements ScalarFunction { else out.set(i, in.get(i) + 1); } out.setValueCount(in.getValueCount()); - return out; + return ColumnarValue.array(out); } } ``` +Each entry in `args.args()` is a `ColumnarValue` — either `ColumnarValue.Array` +(a per-row vector of length `args.rowCount()`) or `ColumnarValue.Scalar` (a +length-1 vector representing a single literal or folded constant). Access the +underlying Arrow vector with `.vector()`. + Allocate any new vectors — including the result — from the supplied `BufferAllocator`. The input vectors are read-only views; do not close them. Ownership of the returned vector transfers to the framework on return. +## Returning a Scalar + +Functions that yield a single value (nullary constants like `pi()`, or any +function that wants the framework to broadcast a result across the batch) can +return `ColumnarValue.scalar(...)` over a length-1 vector: + +```java +public final class JavaPi implements ScalarFunction { + private static final ArrowType FLOAT64 = + new ArrowType.FloatingPoint(org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE); + + @Override public String name() { return "java_pi"; } + @Override public List argTypes() { return List.of(); } + @Override public ArrowType returnType() { return FLOAT64; } + @Override public Volatility volatility() { return Volatility.VOLATILE; } + + @Override + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + org.apache.arrow.vector.Float8Vector out = + new org.apache.arrow.vector.Float8Vector("pi", allocator); + out.allocateNew(1); + out.set(0, Math.PI); + out.setValueCount(1); + return ColumnarValue.scalar(out); + } +} +``` + +The framework expands the scalar across `args.rowCount()` rows automatically. + ## Register Wrap the implementation in a `ScalarUdf` and pass it to @@ -90,9 +127,11 @@ non-deterministic functions. ## Errors If the UDF throws, the exception class and message surface in the -`RuntimeException` raised from `collect()`. If the returned vector is `null`, -has the wrong row count, or the wrong type, the runtime raises a -`RuntimeException` with a descriptive message. +`RuntimeException` raised from `collect()`. If the returned `ColumnarValue` is +`null`, an Array result's vector length does not equal `args.rowCount()`, or +the result's Arrow type differs from the declared return type, the runtime +raises a `RuntimeException` with a descriptive message. A Scalar result whose +vector is not length-1 is rejected at the `ColumnarValue.scalar` factory. ## Threading diff --git a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java index c27bff0..d9416b1 100644 --- a/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java +++ b/examples/src/main/java/org/apache/datafusion/examples/AddOneExample.java @@ -23,13 +23,14 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.datafusion.ColumnarValue; import org.apache.datafusion.DataFrame; import org.apache.datafusion.ScalarFunction; +import org.apache.datafusion.ScalarFunctionArgs; import org.apache.datafusion.ScalarUdf; import org.apache.datafusion.SessionContext; import org.apache.datafusion.Volatility; @@ -62,8 +63,8 @@ public Volatility volatility() { } @Override - public FieldVector evaluate(BufferAllocator allocator, List args) { - IntVector in = (IntVector) args.get(0); + public ColumnarValue evaluate(BufferAllocator allocator, ScalarFunctionArgs args) { + IntVector in = (IntVector) args.args().get(0).vector(); IntVector out = new IntVector("add_one_out", allocator); int n = in.getValueCount(); out.allocateNew(n); @@ -75,7 +76,7 @@ public FieldVector evaluate(BufferAllocator allocator, List args) { } } out.setValueCount(n); - return out; + return ColumnarValue.array(out); } } diff --git a/native/src/lib.rs b/native/src/lib.rs index f6f16d3..fe46a07 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -663,7 +663,7 @@ pub extern "system" fn Java_org_apache_datafusion_SessionContext_registerScalarU let invoke_method = env.get_static_method_id( &bridge_class_local, "invokeScalarUdf", - "(Lorg/apache/datafusion/ScalarFunction;JJJJI)V", + "(Lorg/apache/datafusion/ScalarFunction;JJJJ[BJJI)B", )?; let java_udf = crate::udf::JavaScalarUdf { diff --git a/native/src/udf.rs b/native/src/udf.rs index 62d0e24..d2b18b4 100644 --- a/native/src/udf.rs +++ b/native/src/udf.rs @@ -19,18 +19,18 @@ use std::any::Any; use std::fmt; -use std::sync::Arc; use datafusion::arrow::array::{make_array, Array, ArrayRef, StructArray}; use datafusion::arrow::datatypes::{DataType, Field, Fields}; use datafusion::arrow::ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}; +use datafusion::common::ScalarValue; use datafusion::error::DataFusionError; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use jni::objects::{GlobalRef, JStaticMethodID, JThrowable}; use jni::signature::{Primitive, ReturnType}; -use jni::sys::{jlong, jvalue}; +use jni::sys::{jbyte, jlong, jvalue}; use jni::JNIEnv; pub(crate) struct JavaScalarUdf { @@ -99,21 +99,8 @@ impl ScalarUDFImpl for JavaScalarUdf { ) -> datafusion::error::Result { let number_rows = args.number_rows; - // 1. Materialise scalars to arrays so all columns are length-N. - let arrays: Vec = args - .args - .iter() - .map(|cv| cv.clone().into_array(number_rows)) - .collect::>>()?; - - // 2. Build a single struct array carrying all arg columns. Field names/types come - // from the signature's Exact type list (matches what the Java caller declared). - let signature_fields: Vec> = match &self.signature.type_signature { - TypeSignature::Exact(types) => types - .iter() - .enumerate() - .map(|(i, ty)| Arc::new(Field::new(format!("arg{}", i), ty.clone(), true))) - .collect(), + let signature_types: &[DataType] = match &self.signature.type_signature { + TypeSignature::Exact(types) => types, _ => { return Err(DataFusionError::Internal( "JavaScalarUdf signature is not Exact; only Signature::exact is supported" @@ -122,43 +109,87 @@ impl ScalarUDFImpl for JavaScalarUdf { } }; - let fields = Fields::from( - signature_fields - .iter() - .map(|f| f.as_ref().clone()) - .collect::>(), - ); - let struct_array = StructArray::try_new_with_length(fields, arrays, None, number_rows) + if args.args.len() != signature_types.len() { + return Err(DataFusionError::Internal(format!( + "Java UDF '{}' called with {} args; signature declares {}", + self.name, + args.args.len(), + signature_types.len() + ))); + } + + // 1. Partition args by kind. ColumnarValue::Scalar stays as a length-1 array so the Java + // side observes it as a Scalar; ColumnarValue::Array passes through at full length. + let mut array_arrays: Vec = Vec::new(); + let mut array_fields: Vec = Vec::new(); + let mut scalar_arrays: Vec = Vec::new(); + let mut scalar_fields: Vec = Vec::new(); + let mut arg_kinds: Vec = Vec::with_capacity(args.args.len()); + + for (i, cv) in args.args.iter().enumerate() { + let ty = signature_types[i].clone(); + match cv { + ColumnarValue::Array(a) => { + array_fields.push(Field::new(format!("arg{}", array_arrays.len()), ty, true)); + array_arrays.push(a.clone()); + arg_kinds.push(0); + } + ColumnarValue::Scalar(s) => { + let arr = s.to_array_of_size(1)?; + scalar_fields.push(Field::new(format!("arg{}", scalar_arrays.len()), ty, true)); + scalar_arrays.push(arr); + arg_kinds.push(1); + } + } + } + + // 2. Build the two struct arrays. Empty field+array vectors with the appropriate length + // cover nullary and all-one-kind cases. + let array_struct = StructArray::try_new_with_length( + Fields::from(array_fields), + array_arrays, + None, + number_rows, + ) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + let scalar_struct = + StructArray::try_new_with_length(Fields::from(scalar_fields), scalar_arrays, None, 1) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + + let (array_ffi_arr, array_ffi_sch) = to_ffi(&array_struct.into_data()) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; + let (scalar_ffi_arr, scalar_ffi_sch) = to_ffi(&scalar_struct.into_data()) .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; - let args_data = struct_array.into_data(); - let (args_ffi_array, args_ffi_schema) = - to_ffi(&args_data).map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; // 3. Pre-allocate empty FFI structs for the result. - let result_ffi_array = FFI_ArrowArray::empty(); - let result_ffi_schema = FFI_ArrowSchema::empty(); + let result_ffi_arr = FFI_ArrowArray::empty(); + let result_ffi_sch = FFI_ArrowSchema::empty(); // 4. Box for stable addresses across the JNI call. - let mut args_array_box = Box::new(args_ffi_array); - let mut args_schema_box = Box::new(args_ffi_schema); - let mut result_array_box = Box::new(result_ffi_array); - let mut result_schema_box = Box::new(result_ffi_schema); - - let args_array_addr = args_array_box.as_mut() as *mut _ as jlong; - let args_schema_addr = args_schema_box.as_mut() as *mut _ as jlong; - let result_array_addr = result_array_box.as_mut() as *mut _ as jlong; - let result_schema_addr = result_schema_box.as_mut() as *mut _ as jlong; + let mut array_arr_box = Box::new(array_ffi_arr); + let mut array_sch_box = Box::new(array_ffi_sch); + let mut scalar_arr_box = Box::new(scalar_ffi_arr); + let mut scalar_sch_box = Box::new(scalar_ffi_sch); + let mut result_arr_box = Box::new(result_ffi_arr); + let mut result_sch_box = Box::new(result_ffi_sch); + + let array_arr_addr = array_arr_box.as_mut() as *mut _ as jlong; + let array_sch_addr = array_sch_box.as_mut() as *mut _ as jlong; + let scalar_arr_addr = scalar_arr_box.as_mut() as *mut _ as jlong; + let scalar_sch_addr = scalar_sch_box.as_mut() as *mut _ as jlong; + let result_arr_addr = result_arr_box.as_mut() as *mut _ as jlong; + let result_sch_addr = result_sch_box.as_mut() as *mut _ as jlong; // 5. Attach JNI to current thread. let mut env = crate::jvm() .attach_current_thread() .map_err(|e| DataFusionError::Execution(format!("JNI attach failed: {}", e)))?; - // 6. Call JniBridge.invokeScalarUdf(udf, args*, result*, expectedRowCount). - // - // Build the jvalue argument array for call_static_method_unchecked. - // SAFETY: we build the args inline and pass them immediately; the JObject - // pointed to by udf_global_ref is alive for the duration of this call. + // 6. Build the byte[] for argKinds inside the JVM heap. JNI local; freed when env drops. + let arg_kinds_array = env.byte_array_from_slice(&arg_kinds).map_err(|e| { + DataFusionError::Execution(format!("byte_array_from_slice failed: {}", e)) + })?; + let expected_rows = i32::try_from(number_rows).map_err(|_| { DataFusionError::Execution(format!( "batch row count {} exceeds i32::MAX; UDFs cannot handle batches larger than 2^31 - 1 rows", @@ -167,29 +198,20 @@ impl ScalarUDFImpl for JavaScalarUdf { })?; let udf_jobject = self.udf_global_ref.as_obj(); - // SAFETY: udf_jobject is derived from a GlobalRef alive for the duration of this - // function. The raw pointer is only read by the JNI call below, which happens - // before any code that could drop udf_global_ref. - let call_args: [jvalue; 6] = [ - // ScalarFunction instance + // SAFETY: udf_global_ref and arg_kinds_array are alive for the duration of this call. + let call_args: [jvalue; 9] = [ jvalue { l: udf_jobject.as_raw(), }, - // argsArrayAddr - jvalue { j: args_array_addr }, - // argsSchemaAddr + jvalue { j: array_arr_addr }, + jvalue { j: array_sch_addr }, + jvalue { j: scalar_arr_addr }, + jvalue { j: scalar_sch_addr }, jvalue { - j: args_schema_addr, + l: arg_kinds_array.as_raw(), }, - // resultArrayAddr - jvalue { - j: result_array_addr, - }, - // resultSchemaAddr - jvalue { - j: result_schema_addr, - }, - // expectedRowCount + jvalue { j: result_arr_addr }, + jvalue { j: result_sch_addr }, jvalue { i: expected_rows }, ]; @@ -197,12 +219,12 @@ impl ScalarUDFImpl for JavaScalarUdf { env.call_static_method_unchecked( &self.bridge_class, self.invoke_method, - ReturnType::Primitive(Primitive::Void), + ReturnType::Primitive(Primitive::Byte), &call_args, ) }; - // 7. If Java threw, translate to DataFusionError. Always check exception_check first. + // 7. Java-exception path: translate to DataFusionError. if env.exception_check().unwrap_or(false) { let throwable = env.exception_occurred().map_err(|e| { DataFusionError::Execution(format!("exception_occurred failed: {}", e)) @@ -211,19 +233,22 @@ impl ScalarUDFImpl for JavaScalarUdf { let message = jthrowable_to_string(&mut env, &throwable, &self.name); return Err(DataFusionError::Execution(message)); } - call_result.map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))?; - - // 8. Import result. from_ffi consumes the FFI_ArrowArray. - let result_array = *result_array_box; - let result_schema = *result_schema_box; - // SAFETY: Java's `Data.exportVector` populated `result_array_box` and - // `result_schema_box` in place via the C Data Interface, and the - // exception check above guarantees the call succeeded without - // throwing — so the FFI structs are fully initialized. - let result_data = unsafe { from_ffi(result_array, &result_schema) } + + let result_kind: jbyte = call_result + .map_err(|e| DataFusionError::Execution(format!("JNI call failed: {}", e)))? + .b() + .map_err(|e| { + DataFusionError::Execution(format!("invokeScalarUdf return decode failed: {}", e)) + })?; + + // 8. Import the result vector. from_ffi consumes the FFI_ArrowArray. + let result_array_ffi = *result_arr_box; + let result_schema_ffi = *result_sch_box; + // SAFETY: bridge populated both structs via Arrow C Data Interface; the exception check + // above confirmed no Java exception, so the FFI structs are fully initialised. + let result_data = unsafe { from_ffi(result_array_ffi, &result_schema_ffi) } .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; - // 9. Validate type. if result_data.data_type() != &self.return_type { return Err(DataFusionError::Execution(format!( "Java UDF '{}' returned vector of type {:?}; declared return type was {:?}", @@ -234,7 +259,25 @@ impl ScalarUDFImpl for JavaScalarUdf { } let array: ArrayRef = make_array(result_data); - Ok(ColumnarValue::Array(array)) + + match result_kind { + 0 => Ok(ColumnarValue::Array(array)), + 1 => { + if array.len() != 1 { + return Err(DataFusionError::Internal(format!( + "Java UDF '{}' returned Scalar with length {} (expected 1)", + self.name, + array.len() + ))); + } + let scalar = ScalarValue::try_from_array(&array, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + other => Err(DataFusionError::Internal(format!( + "Java UDF '{}' returned unknown kind byte: {}", + self.name, other + ))), + } } }