diff --git a/age--1.7.0--y.y.y.sql b/age--1.7.0--y.y.y.sql index 6dc8f707f..fd1f28160 100644 --- a/age--1.7.0--y.y.y.sql +++ b/age--1.7.0--y.y.y.sql @@ -1107,17 +1107,21 @@ COMMENT ON FUNCTION ag_catalog.create_subgraph(name, name, text, text) IS -- Transition function for the age_reduce aggregate. The fold body is compiled -- by transform_cypher_reduce() with the accumulator and element rewritten to -- PARAM_EXEC params 0 and 1 and serialized into the text argument; the --- transition evaluates it for each element in list order. It must be callable --- with a NULL transition state (no initcond), so it is intentionally not STRICT. -CREATE FUNCTION ag_catalog.age_reduce_transfn(agtype, agtype, text, agtype) +-- transition evaluates it for each element in list order. The trailing +-- agtype[] argument carries the loop-invariant outer values (outer-query +-- variables and cypher() parameters) referenced by the body, bound to +-- PARAM_EXEC params 2, 3, ... It must be callable with a NULL transition state +-- (no initcond), so it is intentionally not STRICT. +CREATE FUNCTION ag_catalog.age_reduce_transfn(agtype, agtype, text, agtype, agtype[]) RETURNS agtype LANGUAGE c PARALLEL UNSAFE AS 'MODULE_PATHNAME'; -- aggregate definition for reduce(); direct arguments are --- (init, serialized-body, element), with the element fed ORDER BY ordinality. -CREATE AGGREGATE ag_catalog.age_reduce(agtype, text, agtype) +-- (init, serialized-body, element, captured-outer-values), with the element +-- fed ORDER BY ordinality. 
+CREATE AGGREGATE ag_catalog.age_reduce(agtype, text, agtype, agtype[]) ( stype = agtype, sfunc = ag_catalog.age_reduce_transfn diff --git a/regress/expected/age_reduce.out b/regress/expected/age_reduce.out index 8a198965a..0cf4acad4 100644 --- a/regress/expected/age_reduce.out +++ b/regress/expected/age_reduce.out @@ -222,6 +222,87 @@ $$) AS (result agtype); [1, 4, 9] (1 row) +-- +-- Value types in the fold +-- +-- a float accumulator and float elements +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0.0, x IN [1.5, 2.5, 3.0] | s + x) +$$) AS (result agtype); + result +-------- + 7.0 +(1 row) + +-- negative numbers +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [-1, -2, -3] | s + x) +$$) AS (result agtype); + result +-------- + -6 +(1 row) + +-- a map accumulator passed through unchanged +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = {n: 0}, x IN [1, 2, 3] | s) +$$) AS (result agtype); + result +---------- + {"n": 0} +(1 row) + +-- elements that are themselves lists, indexed in the body +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [[1, 2], [3, 4], [5, 6]] | s + x[0]) +$$) AS (result agtype); + result +-------- + 9 +(1 row) + +-- +-- Function calls in the fold body +-- +-- a scalar function applied to the element +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN ['a', 'bb', 'ccc'] | s + size(x)) +$$) AS (result agtype); + result +-------- + 6 +(1 row) + +-- the list itself produced by a function +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN range(1, 5) | s + x) +$$) AS (result agtype); + result +-------- + 15 +(1 row) + +-- +-- Composing reduce() with surrounding expressions +-- +-- the reduce() result consumed by another function +SELECT * FROM cypher('reduce', $$ + RETURN size(reduce(s = [], x IN [1, 2, 3, 4] | s + [x])) +$$) AS (result agtype); + result +-------- + 4 +(1 row) + +-- the reduce() result used in a comparison +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN 
[1, 2, 3] | s + x) = 6 +$$) AS (result agtype); + result +-------- + true +(1 row) + -- -- A conditional body (CASE) -- @@ -484,17 +565,129 @@ $$) AS (name agtype, total agtype); (3 rows) -- --- Not-yet-supported constructs raise a clean feature error +-- Outer references in the fold body -- --- an outer variable referenced in the body +-- The body may reference loop-invariant values from the enclosing query: an +-- outer variable, a property of an outer variable, or a cypher() parameter. +-- a plain outer variable in the body SELECT * FROM cypher('reduce', $$ WITH 5 AS w - RETURN reduce(s = 0, x IN [1, 2] | s + x + w) + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + w) $$) AS (result agtype); -ERROR: a reduce() expression may only reference its accumulator and element variables -LINE 1: SELECT * FROM cypher('reduce', $$ - ^ --- a nested reduce() in the body + result +-------- + 21 +(1 row) + +-- an outer variable used as a multiplier +SELECT * FROM cypher('reduce', $$ + WITH 3 AS factor + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * factor) +$$) AS (result agtype); + result +-------- + 18 +(1 row) + +-- two distinct outer variables in the body +SELECT * FROM cypher('reduce', $$ + WITH 2 AS a, 100 AS b + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * a + b) +$$) AS (result agtype); + result +-------- + 312 +(1 row) + +-- a property of an outer (graph) variable in the body +SELECT * FROM cypher('reduce', $$ + MATCH (u:bag) WHERE u.name = 'mid' + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + u.vals[0]) +$$) AS (result agtype); + result +-------- + 21 +(1 row) + +-- the same outer variable referenced more than once in the body +SELECT * FROM cypher('reduce', $$ + WITH 7 AS k + RETURN reduce(s = 0, x IN [1, 2, 3] | s + k + k) +$$) AS (result agtype); + result +-------- + 42 +(1 row) + +-- a property of an outer map referenced in the body +SELECT * FROM cypher('reduce', $$ + WITH {factor: 10} AS m + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * m.factor) +$$) AS (result 
agtype); + result +-------- + 60 +(1 row) + +-- a subexpression that mixes an outer reference with the element: only the +-- loop-invariant part (the outer list) is captured, the element index is not +SELECT * FROM cypher('reduce', $$ + WITH [10, 20, 30] AS lookup + RETURN reduce(s = 0, x IN [1, 2, 3] | s + lookup[x - 1]) +$$) AS (result agtype); + result +-------- + 60 +(1 row) + +-- an outer reference inside a CASE branch of the body is captured +SELECT * FROM cypher('reduce', $$ + WITH 10 AS w + RETURN reduce(s = 0, x IN [1, 2, 3] | CASE WHEN x % 2 = 0 THEN s + w ELSE s + x END) +$$) AS (result agtype); + result +-------- + 14 +(1 row) + +-- a NULL outer value propagates through the fold +SELECT * FROM cypher('reduce', $$ + WITH null AS w + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + w) +$$) AS (result agtype); + result +-------- + null +(1 row) + +-- multiple outer captures with a mix of NULL and non-NULL: each is bound to its +-- own slot (the non-NULL multiplier is bound and the NULL still propagates) +SELECT * FROM cypher('reduce', $$ + WITH 3 AS a, null AS b + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * a + b) +$$) AS (result agtype); + result +-------- + null +(1 row) + +-- an outer variable that changes per row is captured per group +SELECT * FROM cypher('reduce', $$ + UNWIND [1, 2, 3] AS m + RETURN reduce(s = 0, x IN [1, 2, 3, 4] | s + x * m) AS total + ORDER BY total +$$) AS (result agtype); + result +-------- + 10 + 20 + 30 +(3 rows) + +-- +-- Not-yet-supported constructs raise a clean feature error +-- +-- a nested reduce() in the body (any subquery in the body is unsupported) SELECT * FROM cypher('reduce', $$ RETURN reduce(s = 0, x IN [1, 2] | s + reduce(t = 0, y IN [x] | t + y)) $$) AS (result agtype); @@ -509,6 +702,57 @@ ERROR: aggregate functions are not supported in a reduce() expression LINE 1: SELECT * FROM cypher('reduce', $$ ^ -- +-- Syntax errors: each required piece of the reduce() form is enforced +-- +-- missing "= init" +SELECT * 
FROM cypher('reduce', $$ + RETURN reduce(s, x IN [1, 2] | s + x) +$$) AS (result agtype); +ERROR: syntax error at or near "," +LINE 2: RETURN reduce(s, x IN [1, 2] | s + x) + ^ +-- missing ", var IN list" +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0 | s) +$$) AS (result agtype); +ERROR: syntax error at or near "|" +LINE 2: RETURN reduce(s = 0 | s) + ^ +-- missing "| body" +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [1, 2]) +$$) AS (result agtype); +ERROR: syntax error at or near ")" +LINE 2: RETURN reduce(s = 0, x IN [1, 2]) + ^ +-- a qualified iterator variable is not allowed +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x.y IN [1, 2] | s) +$$) AS (result agtype); +ERROR: syntax error at or near "." +LINE 2: RETURN reduce(s = 0, x.y IN [1, 2] | s) + ^ +-- +-- cypher() parameter referenced in the fold body (via a prepared statement) +-- +PREPARE reduce_param(agtype) AS + SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + $p) + $$, $1) AS (result agtype); +EXECUTE reduce_param('{"p": 10}'); + result +-------- + 36 +(1 row) + +EXECUTE reduce_param('{"p": 100}'); + result +-------- + 306 +(1 row) + +DEALLOCATE reduce_param; +-- -- "reduce" as a property key name (safe_keywords backward compatibility): -- because reduce() introduced a reserved keyword, confirm the word is still -- usable as a map key, the same way any/none/single are. 
diff --git a/regress/sql/age_reduce.sql b/regress/sql/age_reduce.sql index cf1261010..b237488c9 100644 --- a/regress/sql/age_reduce.sql +++ b/regress/sql/age_reduce.sql @@ -151,6 +151,55 @@ SELECT * FROM cypher('reduce', $$ RETURN reduce(acc = [], x IN [1, 2, 3] | acc + [x * x]) $$) AS (result agtype); +-- +-- Value types in the fold +-- +-- a float accumulator and float elements +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0.0, x IN [1.5, 2.5, 3.0] | s + x) +$$) AS (result agtype); + +-- negative numbers +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [-1, -2, -3] | s + x) +$$) AS (result agtype); + +-- a map accumulator passed through unchanged +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = {n: 0}, x IN [1, 2, 3] | s) +$$) AS (result agtype); + +-- elements that are themselves lists, indexed in the body +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [[1, 2], [3, 4], [5, 6]] | s + x[0]) +$$) AS (result agtype); + +-- +-- Function calls in the fold body +-- +-- a scalar function applied to the element +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN ['a', 'bb', 'ccc'] | s + size(x)) +$$) AS (result agtype); + +-- the list itself produced by a function +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN range(1, 5) | s + x) +$$) AS (result agtype); + +-- +-- Composing reduce() with surrounding expressions +-- +-- the reduce() result consumed by another function +SELECT * FROM cypher('reduce', $$ + RETURN size(reduce(s = [], x IN [1, 2, 3, 4] | s + [x])) +$$) AS (result agtype); + +-- the reduce() result used in a comparison +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x) = 6 +$$) AS (result agtype); + -- -- A conditional body (CASE) -- @@ -317,15 +366,83 @@ SELECT * FROM cypher('reduce', $$ $$) AS (name agtype, total agtype); -- --- Not-yet-supported constructs raise a clean feature error +-- Outer references in the fold body -- --- an outer variable 
referenced in the body +-- The body may reference loop-invariant values from the enclosing query: an +-- outer variable, a property of an outer variable, or a cypher() parameter. +-- a plain outer variable in the body SELECT * FROM cypher('reduce', $$ WITH 5 AS w - RETURN reduce(s = 0, x IN [1, 2] | s + x + w) + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + w) +$$) AS (result agtype); + +-- an outer variable used as a multiplier +SELECT * FROM cypher('reduce', $$ + WITH 3 AS factor + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * factor) +$$) AS (result agtype); + +-- two distinct outer variables in the body +SELECT * FROM cypher('reduce', $$ + WITH 2 AS a, 100 AS b + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * a + b) +$$) AS (result agtype); + +-- a property of an outer (graph) variable in the body +SELECT * FROM cypher('reduce', $$ + MATCH (u:bag) WHERE u.name = 'mid' + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + u.vals[0]) +$$) AS (result agtype); + +-- the same outer variable referenced more than once in the body +SELECT * FROM cypher('reduce', $$ + WITH 7 AS k + RETURN reduce(s = 0, x IN [1, 2, 3] | s + k + k) $$) AS (result agtype); --- a nested reduce() in the body +-- a property of an outer map referenced in the body +SELECT * FROM cypher('reduce', $$ + WITH {factor: 10} AS m + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * m.factor) +$$) AS (result agtype); + +-- a subexpression that mixes an outer reference with the element: only the +-- loop-invariant part (the outer list) is captured, the element index is not +SELECT * FROM cypher('reduce', $$ + WITH [10, 20, 30] AS lookup + RETURN reduce(s = 0, x IN [1, 2, 3] | s + lookup[x - 1]) +$$) AS (result agtype); + +-- an outer reference inside a CASE branch of the body is captured +SELECT * FROM cypher('reduce', $$ + WITH 10 AS w + RETURN reduce(s = 0, x IN [1, 2, 3] | CASE WHEN x % 2 = 0 THEN s + w ELSE s + x END) +$$) AS (result agtype); + +-- a NULL outer value propagates through the fold +SELECT * FROM 
cypher('reduce', $$ + WITH null AS w + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + w) +$$) AS (result agtype); + +-- multiple outer captures with a mix of NULL and non-NULL: each is bound to its +-- own slot (the non-NULL multiplier is bound and the NULL still propagates) +SELECT * FROM cypher('reduce', $$ + WITH 3 AS a, null AS b + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x * a + b) +$$) AS (result agtype); + +-- an outer variable that changes per row is captured per group +SELECT * FROM cypher('reduce', $$ + UNWIND [1, 2, 3] AS m + RETURN reduce(s = 0, x IN [1, 2, 3, 4] | s + x * m) AS total + ORDER BY total +$$) AS (result agtype); + +-- +-- Not-yet-supported constructs raise a clean feature error +-- +-- a nested reduce() in the body (any subquery in the body is unsupported) SELECT * FROM cypher('reduce', $$ RETURN reduce(s = 0, x IN [1, 2] | s + reduce(t = 0, y IN [x] | t + y)) $$) AS (result agtype); @@ -335,6 +452,43 @@ SELECT * FROM cypher('reduce', $$ RETURN reduce(s = 0, x IN [1, 2] | s + count(x)) $$) AS (result agtype); +-- +-- Syntax errors: each required piece of the reduce() form is enforced +-- +-- missing "= init" +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s, x IN [1, 2] | s + x) +$$) AS (result agtype); + +-- missing ", var IN list" +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0 | s) +$$) AS (result agtype); + +-- missing "| body" +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [1, 2]) +$$) AS (result agtype); + +-- a qualified iterator variable is not allowed +SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x.y IN [1, 2] | s) +$$) AS (result agtype); + +-- +-- cypher() parameter referenced in the fold body (via a prepared statement) +-- +PREPARE reduce_param(agtype) AS + SELECT * FROM cypher('reduce', $$ + RETURN reduce(s = 0, x IN [1, 2, 3] | s + x + $p) + $$, $1) AS (result agtype); + +EXECUTE reduce_param('{"p": 10}'); + +EXECUTE reduce_param('{"p": 100}'); + +DEALLOCATE reduce_param; + -- -- 
"reduce" as a property key name (safe_keywords backward compatibility): -- because reduce() introduced a reserved keyword, confirm the word is still diff --git a/sql/age_aggregate.sql b/sql/age_aggregate.sql index fb258e5c5..9ad715683 100644 --- a/sql/age_aggregate.sql +++ b/sql/age_aggregate.sql @@ -223,17 +223,21 @@ CREATE AGGREGATE ag_catalog.age_collect(variadic "any") -- Transition function for the age_reduce aggregate. The fold body is compiled -- by transform_cypher_reduce() with the accumulator and element rewritten to -- PARAM_EXEC params 0 and 1 and serialized into the text argument; the --- transition evaluates it for each element in list order. It must be callable --- with a NULL transition state (no initcond), so it is intentionally not STRICT. -CREATE FUNCTION ag_catalog.age_reduce_transfn(agtype, agtype, text, agtype) +-- transition evaluates it for each element in list order. The trailing +-- agtype[] argument carries the loop-invariant outer values (outer-query +-- variables and cypher() parameters) referenced by the body, bound to +-- PARAM_EXEC params 2, 3, ... It must be callable with a NULL transition state +-- (no initcond), so it is intentionally not STRICT. +CREATE FUNCTION ag_catalog.age_reduce_transfn(agtype, agtype, text, agtype, agtype[]) RETURNS agtype LANGUAGE c PARALLEL UNSAFE AS 'MODULE_PATHNAME'; -- aggregate definition for reduce(); direct arguments are --- (init, serialized-body, element), with the element fed ORDER BY ordinality. -CREATE AGGREGATE ag_catalog.age_reduce(agtype, text, agtype) +-- (init, serialized-body, element, captured-outer-values), with the element +-- fed ORDER BY ordinality. 
+CREATE AGGREGATE ag_catalog.age_reduce(agtype, text, agtype, agtype[]) ( stype = agtype, sfunc = ag_catalog.age_reduce_transfn diff --git a/src/backend/parser/cypher_clause.c b/src/backend/parser/cypher_clause.c index 6582ff8d1..a7daa0e5c 100644 --- a/src/backend/parser/cypher_clause.c +++ b/src/backend/parser/cypher_clause.c @@ -2110,15 +2110,60 @@ static Query *make_reduce_var_subquery(char *acc_name, char *elem_name) } /* - * Validate a transformed-and-mutated reduce() fold body. After - * reduce_var_to_param_mutator() has replaced the accumulator and element with - * PARAM_EXEC params 0 and 1, a valid body is a pure expression over those two - * params: it must contain no other Vars (outer-query references), no other - * params, and no aggregates or subqueries, because the body is evaluated - * standalone (ExecEvalExpr) inside age_reduce_transfn with only those two - * param slots bound. + * Walker: true if the subtree references the reduce() accumulator or element, + * i.e. it contains PARAM_EXEC param 0 or 1 (assigned by + * reduce_var_to_param_mutator). Such a subtree changes per element and cannot + * be captured as a loop-invariant outer value. */ -static bool reduce_body_check_walker(Node *node, void *context) +static bool reduce_expr_has_acc_elem(Node *node, void *context) +{ + if (node == NULL) + { + return false; + } + + if (IsA(node, Param)) + { + Param *param = (Param *) node; + + if (param->paramkind == PARAM_EXEC && + (param->paramid == 0 || param->paramid == 1)) + { + return true; + } + } + + return expression_tree_walker(node, reduce_expr_has_acc_elem, context); +} + +/* + * Walker: true if the subtree contains an aggregate, grouping, or window + * function. Such a node cannot be evaluated standalone and must not be folded + * into a captured outer value (it would become an illegal nested aggregate). 
+ */ +static bool reduce_expr_has_aggregate(Node *node, void *context) +{ + if (node == NULL) + { + return false; + } + + if (IsA(node, Aggref) || IsA(node, GroupingFunc) || IsA(node, WindowFunc)) + { + return true; + } + + return expression_tree_walker(node, reduce_expr_has_aggregate, context); +} + +/* + * Walker: true if the subtree references anything that cannot be evaluated + * standalone -- an outer-query Var or a non-PARAM_EXEC parameter (e.g. a + * cypher() $parameter, which transforms to agtype_access_operator over a + * PARAM_EXTERN). Such a subtree must be captured and supplied to the fold via + * the extras array. A subtree of only constants does not need capturing. + */ +static bool reduce_expr_needs_capture(Node *node, void *context) { if (node == NULL) { @@ -2127,24 +2172,92 @@ static bool reduce_body_check_walker(Node *node, void *context) if (IsA(node, Var)) { - ereport(ERROR, - (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("a reduce() expression may only reference its accumulator and element variables"))); + return true; } if (IsA(node, Param)) { Param *param = (Param *) node; - if (param->paramkind != PARAM_EXEC || - (param->paramid != 0 && param->paramid != 1)) + if (param->paramkind != PARAM_EXEC) { - ereport(ERROR, - (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("a reduce() expression may not reference query parameters"))); + return true; } } + return expression_tree_walker(node, reduce_expr_needs_capture, context); +} + +/* + * Walker: true if the subtree contains a subquery (SubLink). A captured outer + * value is supplied to the aggregate as a plain expression argument, which the + * standalone fold evaluator cannot plan, so a subtree containing a subquery is + * never captured -- it falls through to the explicit rejection instead. 
+ */ +static bool reduce_expr_has_sublink(Node *node, void *context) +{ + if (node == NULL) + { + return false; + } + + if (IsA(node, SubLink)) + { + return true; + } + + return expression_tree_walker(node, reduce_expr_has_sublink, context); +} + +/* + * Mutator context for capturing loop-invariant outer references in a reduce() + * fold body. Each captured subtree is assigned the next PARAM_EXEC id (starting + * at 2, after the accumulator and element) and collected, in id order, so the + * caller can supply the values to age_reduce_transfn through the extras array. + */ +typedef struct reduce_capture_context +{ + int next_slot; /* next PARAM_EXEC id to assign (starts at 2) */ + List *captured; /* captured outer-reference exprs, in slot order */ +} reduce_capture_context; + +/* + * Capture the loop-invariant outer references in a reduce() fold body. + * + * After reduce_var_to_param_mutator() has rewritten the accumulator and + * element to PARAM_EXEC params 0 and 1, the largest agtype-typed subtrees that + * (a) do not reference the accumulator or element, (b) contain no aggregate, + * (c) contain no subquery, and (d) reference an outer Var or a cypher() + * $parameter are replaced by new PARAM_EXEC params 2, 3, ... and collected. + * Each such subtree is loop-invariant within a fold, so the executor evaluates + * it once per row in the outer query context (as an aggregate argument) and + * binds the value to its param slot. + * + * Aggregates and subqueries (including a nested reduce()) are not captured and + * are rejected outright: an aggregate is undefined inside a per-element fold, + * and a subquery cannot be supplied as a plain aggregate argument. 
+ */ +static Node *reduce_capture_mutator(Node *node, void *context) +{ + reduce_capture_context *ctx = (reduce_capture_context *) context; + + if (node == NULL) + { + return NULL; + } + + /* + * Container / support nodes that expression_tree_mutator hands us are not + * themselves typed expressions (calling exprType on them errors), so just + * recurse into them. For an agtype scalar fold body these are List nodes + * (argument lists) and CaseWhen nodes (CASE branches). + */ + if (IsA(node, List) || IsA(node, CaseWhen)) + { + return expression_tree_mutator(node, reduce_capture_mutator, context); + } + + /* an aggregate in the fold body is never supported */ if (IsA(node, Aggref) || IsA(node, GroupingFunc) || IsA(node, WindowFunc)) { ereport(ERROR, @@ -2152,6 +2265,37 @@ static bool reduce_body_check_walker(Node *node, void *context) errmsg("aggregate functions are not supported in a reduce() expression"))); } + /* + * Capture a maximal agtype-typed, loop-invariant subtree as a single outer + * value. It must not reference the accumulator/element, must not embed an + * aggregate, and must not contain a subquery (which could not be planned as + * a plain aggregate argument). + */ + if (exprType(node) == AGTYPEOID && + !reduce_expr_has_acc_elem(node, NULL) && + !reduce_expr_has_aggregate(node, NULL) && + !reduce_expr_has_sublink(node, NULL) && + reduce_expr_needs_capture(node, NULL)) + { + Param *param = makeNode(Param); + + param->paramkind = PARAM_EXEC; + param->paramid = ctx->next_slot++; + param->paramtype = AGTYPEOID; + param->paramtypmod = -1; + param->paramcollid = InvalidOid; + param->location = -1; + + ctx->captured = lappend(ctx->captured, copyObject(node)); + + return (Node *) param; + } + + /* + * Any subquery in the body is rejected: it is never captured (see the + * subquery exclusion in the capture test above), and it cannot be + * evaluated standalone by the fold either. 
+ */ if (IsA(node, SubLink)) { ereport(ERROR, @@ -2159,6 +2303,43 @@ static bool reduce_body_check_walker(Node *node, void *context) errmsg("subqueries (including a nested reduce()) are not supported in a reduce() expression"))); } + return expression_tree_mutator(node, reduce_capture_mutator, context); +} + +/* + * Safety net run after reduce_capture_mutator(). A valid body now references + * only PARAM_EXEC params (0/1 for the accumulator and element, 2.. for the + * captured outer values) and constants. Any remaining Var or non-PARAM_EXEC + * parameter is an outer reference that could not be captured (for example a + * non-agtype-typed one); reject it cleanly rather than letting it reach the + * standalone evaluator. + */ +static bool reduce_body_check_walker(Node *node, void *context) +{ + if (node == NULL) + { + return false; + } + + if (IsA(node, Var)) + { + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("a reduce() expression references a value that cannot be used in the fold body"))); + } + + if (IsA(node, Param)) + { + Param *param = (Param *) node; + + if (param->paramkind != PARAM_EXEC) + { + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("a reduce() expression references a value that cannot be used in the fold body"))); + } + } + return expression_tree_walker(node, reduce_body_check_walker, context); } @@ -2170,12 +2351,17 @@ static bool reduce_body_check_walker(Node *node, void *context) * aggregate ordered by that ordinality so the fold runs in list order: * * SELECT ag_catalog.age_reduce(<init>, '<body>'::text, - * r.elem ORDER BY r.ord) + * r.elem, + * <extras> ORDER BY r.ord) * FROM unnest(<list>) WITH ORDINALITY AS r(elem, ord) * * The fold body is transformed separately with the accumulator and element * rewritten to PARAM_EXEC params 0 and 1, serialized into the text argument, - * and evaluated per element inside age_reduce_transfn. + * and evaluated per element inside age_reduce_transfn.
Loop-invariant outer + * references in the body (outer-query variables and cypher() parameters) are + * captured as PARAM_EXEC params 2.. and passed through the trailing agtype[] + * argument so the body can use values from the enclosing query; correlated + * captures are re-evaluated per group. * * The null/empty-list guard * (CASE WHEN list IS NULL THEN NULL ELSE COALESCE(<agg>, init) END) is built @@ -2193,6 +2379,7 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, Node *body_node; char *body_serialized; reduce_var_param_context mutator_ctx; + reduce_capture_context capture_ctx; cypher_parsestate *child_cpstate; ParseState *child_pstate; FuncCall *unnest_fc; @@ -2210,9 +2397,11 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, Oid sort_eqop; bool sort_hashable; Const *body_const; + ArrayExpr *extras_arr; + List *extras_exprs = NIL; Aggref *agg; Oid agg_oid; - Oid agg_argtypes[3]; + Oid agg_argtypes[4]; TargetEntry *result_te; /* @@ -2266,6 +2455,18 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, mutator_ctx.varno = body_pnsi->p_rtindex; body_node = reduce_var_to_param_mutator(body_node, &mutator_ctx); + /* + * Capture loop-invariant outer references (outer-query variables and + * cypher() parameters) in the body as PARAM_EXEC params 2.. and collect + * them in slot order; their values are supplied to the fold through the + * extras array argument built below.
+ */ + capture_ctx.next_slot = 2; + capture_ctx.captured = NIL; + body_node = reduce_capture_mutator(body_node, &capture_ctx); + extras_exprs = capture_ctx.captured; + + /* reject anything in the body that could not be captured or evaluated */ reduce_body_check_walker(body_node, NULL); body_serialized = nodeToString(body_node); @@ -2316,7 +2517,7 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, get_sort_group_operators(INT8OID, true, true, false, &sort_ltop, &sort_eqop, NULL, &sort_hashable); - ord_te = makeTargetEntry((Expr *) ord_var, 4, NULL, true); + ord_te = makeTargetEntry((Expr *) ord_var, 5, NULL, true); ord_te->ressortgroupref = 1; sortcl = makeNode(SortGroupClause); @@ -2382,13 +2583,28 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, init_node = (Node *) init_case; } - /* look up the age_reduce(agtype, text, agtype) aggregate */ + /* + * The captured loop-invariant outer values (outer-query variables and + * cypher() parameters referenced by the body) are passed to the aggregate + * as an agtype[] argument, in the same order their PARAM_EXEC params 2, 3, + * ... were assigned. When the body references nothing outside the + * accumulator and element this is an empty array. 
+ */ + extras_arr = makeNode(ArrayExpr); + extras_arr->array_typeid = AGTYPEARRAYOID; + extras_arr->element_typeid = AGTYPEOID; + extras_arr->elements = extras_exprs; + extras_arr->multidims = false; + extras_arr->location = -1; + + /* look up the age_reduce(agtype, text, agtype, agtype[]) aggregate */ agg_argtypes[0] = AGTYPEOID; agg_argtypes[1] = TEXTOID; agg_argtypes[2] = AGTYPEOID; + agg_argtypes[3] = AGTYPEARRAYOID; agg_oid = LookupFuncName(list_make2(makeString("ag_catalog"), makeString("age_reduce")), - 3, agg_argtypes, false); + 4, agg_argtypes, false); agg = makeNode(Aggref); agg->aggfnoid = agg_oid; @@ -2396,11 +2612,13 @@ static Query *transform_cypher_reduce(cypher_parsestate *cpstate, agg->aggcollid = InvalidOid; agg->inputcollid = InvalidOid; agg->aggtranstype = InvalidOid; /* filled by the planner */ - agg->aggargtypes = list_make3_oid(AGTYPEOID, TEXTOID, AGTYPEOID); + agg->aggargtypes = list_make4_oid(AGTYPEOID, TEXTOID, AGTYPEOID, + AGTYPEARRAYOID); agg->aggdirectargs = NIL; - agg->args = list_make4(makeTargetEntry((Expr *) init_node, 1, NULL, false), + agg->args = list_make5(makeTargetEntry((Expr *) init_node, 1, NULL, false), makeTargetEntry((Expr *) body_const, 2, NULL, false), makeTargetEntry((Expr *) elem_var, 3, NULL, false), + makeTargetEntry((Expr *) extras_arr, 4, NULL, false), ord_te); agg->aggorder = list_make1(sortcl); agg->aggdistinct = NIL; diff --git a/src/backend/utils/adt/agtype.c b/src/backend/utils/adt/agtype.c index bf69bf1fa..8684ff006 100644 --- a/src/backend/utils/adt/agtype.c +++ b/src/backend/utils/adt/agtype.c @@ -11581,13 +11581,16 @@ Datum age_float8_stddev_pop_aggfinalfn(PG_FUNCTION_ARGS) /* * Per-aggregate-group evaluation state for reduce(). Caches the compiled * fold-body expression and a standalone ExprContext whose PARAM_EXEC slots - * (0 = accumulator, 1 = current element) are rebound on every element. + * are rebound on every element. Slot 0 = accumulator, slot 1 = current + * element, and slots 2 .. 
nparams-1 = captured loop-invariant outer values + * (outer-query variables and cypher() parameters referenced by the body). */ typedef struct reduce_eval_ctx { ExprState *body_state; /* compiled fold-body expression */ ExprContext *econtext; /* eval context carrying the param slots */ - ParamExecData *params; /* [0] = accumulator, [1] = current element */ + ParamExecData *params; /* [0]=accumulator, [1]=element, [2..]=outer refs */ + int nparams; /* total param slots = 2 + number of captures */ } reduce_eval_ctx; /* Build an agtype 'null' Datum (a real agtype value, not a SQL NULL). */ @@ -11600,12 +11603,17 @@ static Datum reduce_agtype_null(void) } /* - * age_reduce_transfn(state agtype, init agtype, body text, element agtype) + * age_reduce_transfn(state agtype, init agtype, body text, element agtype, + * extras agtype[]) * * Transition function for the age_reduce aggregate that implements the Cypher * reduce(acc = init, var IN list | body) fold. The fold body is compiled by * transform_cypher_reduce() with the accumulator and element rewritten to * PARAM_EXEC params 0 and 1, then serialized into the `body` text argument. + * Any loop-invariant outer-query variable or cypher() parameter referenced by + * the body is captured into the `extras` agtype array and rewritten to a + * PARAM_EXEC param 2, 3, ... in body order; those slots are bound from the + * array here. 
* * On the first element of a group the accumulator is seeded from `init` * (the running state is NULL because the aggregate uses no initcond); on @@ -11652,6 +11660,7 @@ Datum age_reduce_transfn(PG_FUNCTION_ARGS) text *body_txt; char *body_str; Node *body_node; + int n_extras = 0; if (PG_ARGISNULL(2)) { @@ -11660,6 +11669,25 @@ Datum age_reduce_transfn(PG_FUNCTION_ARGS) errmsg("age_reduce: missing fold expression"))); } + /* + * The number of captured outer values is fixed for this aggregate + * call (the body's structure does not change between rows), so it is + * read once here to size the param array. Their values are bound per + * row below because a correlated capture changes between groups. + * + * The PG_NARGS() guard lets the function tolerate being reached + * through an older 4-argument aggregate definition (for example a + * stale catalog paired with a newer age.so): a missing extras + * argument is simply treated as zero captures. + */ + if (PG_NARGS() > 4 && !PG_ARGISNULL(4)) + { + ArrayType *extras_arr = PG_GETARG_ARRAYTYPE_P(4); + + n_extras = ArrayGetNItems(ARR_NDIM(extras_arr), + ARR_DIMS(extras_arr)); + } + oldctx = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt); rc = (reduce_eval_ctx *) palloc0(sizeof(reduce_eval_ctx)); body_txt = PG_GETARG_TEXT_PP(2); @@ -11686,7 +11714,9 @@ Datum age_reduce_transfn(PG_FUNCTION_ARGS) rc->body_state = ExecInitExpr((Expr *) body_node, NULL); rc->econtext = CreateStandaloneExprContext(); - rc->params = (ParamExecData *) palloc0(sizeof(ParamExecData) * 2); + rc->nparams = 2 + n_extras; + rc->params = (ParamExecData *) palloc0(sizeof(ParamExecData) * + rc->nparams); rc->econtext->ecxt_param_exec_vals = rc->params; fcinfo->flinfo->fn_extra = rc; MemoryContextSwitchTo(oldctx); @@ -11710,6 +11740,9 @@ Datum age_reduce_transfn(PG_FUNCTION_ARGS) /* a NULL element is likewise normalized to agtype 'null' */ element = PG_ARGISNULL(3) ? 
reduce_agtype_null() : PG_GETARG_DATUM(3); + /* reset per-tuple memory before binding this element's params */ + ResetExprContext(rc->econtext); + /* bind PARAM_EXEC 0 = accumulator, 1 = current element */ rc->params[0].value = acc; rc->params[0].isnull = false; @@ -11718,8 +11751,57 @@ Datum age_reduce_transfn(PG_FUNCTION_ARGS) rc->params[1].isnull = false; rc->params[1].execPlan = NULL; - /* evaluate the fold body for this element */ - ResetExprContext(rc->econtext); + /* + * Bind the captured loop-invariant outer values to params 2 .. nparams-1. + * They are pulled from the extras array every row because correlated + * captures differ between groups; the per-row deconstruction is done in + * the econtext's per-tuple memory (reset above) so it does not leak. A + * NULL array element is normalized to agtype 'null' like the accumulator + * and element. + * + * While extras is non-NULL every slot 2 .. nparams-1 is rebound on every + * row, so no slot retains a value from a previous row -- which, after the + * per-tuple reset above, would be a dangling pointer. Capture slots the + * array does not supply (only reachable through a direct SQL call with a + * varying-length array) are filled with agtype 'null'. The PG_NARGS() + * guard keeps the arg-4 access safe under an older 4-argument signature. + * NOTE(review): a NULL extras row skips the rebind; slots 2.. stay as is.
+ */ + if (rc->nparams > 2 && PG_NARGS() > 4 && !PG_ARGISNULL(4)) + { + ArrayType *extras_arr = PG_GETARG_ARRAYTYPE_P(4); + Oid elemtype = ARR_ELEMTYPE(extras_arr); + int16 typlen; + bool typbyval; + char typalign; + Datum *ex_vals; + bool *ex_nulls; + int ex_n; + int i; + MemoryContext per_tuple = rc->econtext->ecxt_per_tuple_memory; + MemoryContext save = MemoryContextSwitchTo(per_tuple); + + get_typlenbyvalalign(elemtype, &typlen, &typbyval, &typalign); + deconstruct_array(extras_arr, elemtype, typlen, typbyval, typalign, + &ex_vals, &ex_nulls, &ex_n); + + for (i = 0; (2 + i) < rc->nparams; i++) + { + if (i < ex_n && !ex_nulls[i]) + { + rc->params[2 + i].value = ex_vals[i]; + } + else + { + rc->params[2 + i].value = reduce_agtype_null(); + } + rc->params[2 + i].isnull = false; + rc->params[2 + i].execPlan = NULL; + } + + MemoryContextSwitchTo(save); + } + result = ExecEvalExpr(rc->body_state, rc->econtext, &result_isnull); /*