diff --git a/README.md b/README.md index 6884171..5c57541 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo - **PartiQL-based SQL Syntax**: Built on [PartiQL](https://partiql.org/tutorial.html) (SQL for semi-structured data), enabling seamless SQL querying of nested and hierarchical MongoDB documents - **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax - **MongoDB Aggregate Pipeline Support**: Execute native MongoDB aggregation pipelines using SQL-like syntax with `aggregate()` function +- **SQL Aggregate Functions**: `COUNT(*)`, `SUM`, `AVG`, `MIN`, `MAX` translated to MongoDB aggregation pipelines - **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect - **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases - **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax @@ -87,6 +88,7 @@ pip install -e . - [WHERE Clauses](#where-clauses) - [Nested Field Support](#nested-field-support) - [Sorting and Limiting](#sorting-and-limiting) + - [SQL Aggregate Functions](#sql-aggregate-functions) - [MongoDB Aggregate Function](#mongodb-aggregate-function) - [INSERT Statements](#insert-statements) - [UPDATE Statements](#update-statements) @@ -293,6 +295,42 @@ Both functions: - **LIMIT**: `LIMIT 10` - **Combined**: `ORDER BY created_at DESC LIMIT 5` +### SQL Aggregate Functions + +PyMongoSQL supports standard SQL aggregate functions that are automatically translated into MongoDB aggregation pipelines. + +**Supported Functions**: `COUNT(*)`, `SUM(field)`, `AVG(field)`, `MIN(field)`, `MAX(field)` + +**Basic Count** + +```python +cursor.execute("SELECT COUNT(*) AS total FROM users") +row = cursor.fetchone() +print(f"Total users: {row[0]}") +``` + +**Multiple Aggregates** + +```python +cursor.execute( + "SELECT COUNT(*) AS cnt, AVG(price) AS avg_price, MIN(price) AS cheapest, MAX(price) AS priciest FROM products" +) +``` + +**Aggregate with WHERE** + +```python +cursor.execute("SELECT COUNT(*) AS total FROM users WHERE active = true AND age > 30") +``` + +**Aggregate with OR Conditions** + +```python +cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40") +``` + +**Note:** Aggregate functions are translated into a MongoDB `aggregate()` pipeline with `$match` (from WHERE), `$group` (with accumulators), and `$project` stages. `COUNT(*)` maps to `{$sum: 1}`, while `SUM`, `AVG`, `MIN`, and `MAX` map to their corresponding MongoDB accumulators (`$sum`, `$avg`, `$min`, `$max`). + ### MongoDB Aggregate Function PyMongoSQL supports executing native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function. This allows you to leverage MongoDB's powerful aggregation framework while maintaining SQL-style query patterns. @@ -608,20 +646,24 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm ### SQL Operations to MongoDB Commands -| SQL Operation | MongoDB Command | Equivalent PyMongo Method | -|---|---|---| -| `SELECT ... FROM col` | `{find: col, projection: {...}}` | `db.command("find", ...)` | -| `SELECT ... FROM col WHERE ...` | `{find: col, filter: {...}}` | `db.command("find", ...)` | -| `SELECT ... ORDER BY col ASC/DESC` | `{find: ..., sort: {col: 1/-1}}` | `db.command("find", ...)` | -| `SELECT ... LIMIT n` | `{find: ..., limit: n}` | `db.command("find", ...)` | -| `SELECT ... OFFSET n` | `{find: ..., skip: n}` | `db.command("find", ...)` | -| `SELECT * FROM col.aggregate(...)` | `collection.aggregate(pipeline)` | `collection.aggregate()` | -| `INSERT INTO col ...` | `{insert: col, documents: [...]}` | `db.command("insert", ...)` | -| `UPDATE col SET ... WHERE ...` | `{update: col, updates: [{q: filter, u: {$set: {...}}, multi: true}]}` | `db.command("update", ...)` | -| `DELETE FROM col WHERE ...` | `{delete: col, deletes: [{q: filter, limit: 0}]}` | `db.command("delete", ...)` | -| `CREATE VIEW v ON col AS '[...]'` | `{create: v, viewOn: col, pipeline: [...]}` | `db.command("create", ...)` | -| `DROP VIEW v` | `{drop: v}` | `db.command("drop", ...)` | -| `EXPLAIN ` | `{explain: , verbosity: "queryPlanner"}` | ### SQL Clauses to MongoDB Query Components @@ -635,6 +677,11 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm | `ORDER BY col DESC` | `sort: {col: -1}` | Descending sort | | `LIMIT n` | `limit: n` | Restrict result count | | `OFFSET n` | `skip: n` | Skip first n results | +| `COUNT(*)` | `{$group: {_id: null, count: {$sum: 1}}}` | Document count | +| `SUM(field)` | `{$group: {_id: null, sum: {$sum: "$field"}}}` | Field sum | +| `AVG(field)` | `{$group: {_id: null, avg: {$avg: "$field"}}}` | Field average | +| `MIN(field)` | `{$group: {_id: null, min: {$min: "$field"}}}` | Field minimum | +| `MAX(field)` | `{$group: {_id: null, max: {$max: "$field"}}}` | Field maximum | ### WHERE Operators to MongoDB Filter Operators diff --git a/pymongosql/sql/builder.py b/pymongosql/sql/builder.py index 4c7a26b..d9dff9d 100644 --- a/pymongosql/sql/builder.py +++ b/pymongosql/sql/builder.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import json import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, Optional, Union @@ -112,6 +113,11 @@ def build_from_parse_result( @staticmethod def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan": """Build a query execution plan from SELECT parsing.""" + + # Auto-generate aggregate pipeline for SQL aggregate functions (COUNT, SUM, etc.) + if getattr(parse_result, "aggregate_functions", None): + return ExecutionPlanBuilder._build_sql_aggregate_plan(parse_result) + builder = BuilderFactory.create_query_builder().collection(parse_result.collection) builder.filter(parse_result.filter_conditions).project(parse_result.projection).column_aliases( @@ -128,6 +134,60 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan": plan = builder.build() return plan + @staticmethod + def _build_sql_aggregate_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan": + """Build an aggregate execution plan from SQL aggregate functions like COUNT(*), SUM(), etc.""" + _FUNCTION_TO_ACCUMULATOR = { + "COUNT": "$sum", + "SUM": "$sum", + "AVG": "$avg", + "MIN": "$min", + "MAX": "$max", + } + + builder = BuilderFactory.create_query_builder().collection(parse_result.collection) + + pipeline = [] + + # Add $match stage if there are filter conditions (from WHERE clause) + if parse_result.filter_conditions: + pipeline.append({"$match": parse_result.filter_conditions}) + + # Build $group stage from aggregate functions + group_stage = {"_id": None} + for func_info in parse_result.aggregate_functions: + alias = func_info["alias"] + func_name = func_info["function"] + arg = func_info["argument"] + accumulator = _FUNCTION_TO_ACCUMULATOR[func_name] + + if func_name == "COUNT": + group_stage[alias] = {accumulator: 1} + else: + group_stage[alias] = {accumulator: f"${arg}"} + + pipeline.append({"$group": group_stage}) + + # Add $project to exclude _id + project_stage = {"_id": 0} + for func_info in parse_result.aggregate_functions: + project_stage[func_info["alias"]] = 1 + pipeline.append({"$project": project_stage}) + + # Configure the execution plan as an aggregate query + builder._execution_plan.is_aggregate_query = True + builder._execution_plan.aggregate_pipeline = json.dumps(pipeline) + builder._execution_plan.aggregate_options = json.dumps({}) + + # Set projection for ResultSet description + agg_projection = {} + for func_info in parse_result.aggregate_functions: + agg_projection[func_info["alias"]] = 1 + builder._execution_plan.projection_stage = agg_projection + + plan = builder.build() + return plan + @staticmethod def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan": """Build an INSERT execution plan from INSERT parsing.""" diff --git a/pymongosql/sql/query_handler.py b/pymongosql/sql/query_handler.py index 52ef408..3a09db8 100644 --- a/pymongosql/sql/query_handler.py +++ b/pymongosql/sql/query_handler.py @@ -32,6 +32,9 @@ class QueryParseResult: aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline aggregate_options: Optional[str] = None # JSON string representation of options + # SQL aggregate functions detected in SELECT (COUNT, SUM, AVG, MIN, MAX) + aggregate_functions: List[Dict[str, Any]] = field(default_factory=list) + # Subquery info (for wrapped subqueries, e.g., Superset outering) subquery_plan: Optional[Any] = None subquery_alias: Optional[str] = None @@ -111,6 +114,12 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]: class SelectHandler(BaseHandler, ContextUtilsMixin): """Handles SELECT statement parsing""" + # Pattern to detect SQL aggregate functions: COUNT(*), SUM(field), AVG(field), etc. + _AGGREGATE_PATTERN = re.compile( + r"^(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*(\*|\w+(?:\.\w+)*)\s*\)$", + re.IGNORECASE, + ) + def can_handle(self, ctx: Any) -> bool: """Check if this is a select context""" return hasattr(ctx, "projectionItems") @@ -122,6 +131,21 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "Q if hasattr(ctx, "projectionItems") and ctx.projectionItems(): for item in ctx.projectionItems().projectionItem(): field_name, alias = self._extract_field_and_alias(item) + + # Check if this is an aggregate function (COUNT, SUM, etc.) + agg_match = self._AGGREGATE_PATTERN.match(field_name) + if agg_match: + func_name = agg_match.group(1).upper() + func_arg = agg_match.group(2) + parse_result.aggregate_functions.append( + { + "function": func_name, + "argument": func_arg, + "alias": alias or field_name, + } + ) + continue + # Use MongoDB standard projection format: {field: 1} to include field projection[field_name] = 1 # Store alias if present diff --git a/tests/test_cursor_aggregate.py b/tests/test_cursor_aggregate.py index a74e105..41fb1dc 100644 --- a/tests/test_cursor_aggregate.py +++ b/tests/test_cursor_aggregate.py @@ -354,3 +354,165 @@ def test_aggregate_collection_name_with_hyphen(self, conn): customer_type_idx = col_names.index("customer_type") for row in rows: assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'" + + +class TestSqlGroupFunctions: + """Test SQL aggregate functions (COUNT, AVG, MIN, MAX, SUM) translated to MongoDB pipelines.""" + + def test_count_star(self, conn): + """SELECT COUNT(*) AS total FROM users → should return document count""" + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) AS total FROM users") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + assert "total" in col_names + + total_idx = col_names.index("total") + assert rows[0][total_idx] == 22 # 22 users in test data + + def test_count_star_no_alias(self, conn): + """SELECT COUNT(*) FROM users → column name defaults to COUNT(*)""" + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM users") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + assert "COUNT(*)" in col_names + assert rows[0][col_names.index("COUNT(*)")] == 22 + + def test_count_star_with_where(self, conn): + """SELECT COUNT(*) AS total FROM users WHERE age > 30 → filtered count""" + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) AS total FROM users WHERE age > 30") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + total = rows[0][col_names.index("total")] + assert isinstance(total, (int, float)) + assert total > 0 + assert total < 22 # Must be less than total users + + def test_avg(self, conn): + """SELECT AVG(age) AS avg_age FROM users""" + cursor = conn.cursor() + cursor.execute("SELECT AVG(age) AS avg_age FROM users") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + avg_age = rows[0][col_names.index("avg_age")] + assert isinstance(avg_age, (int, float)) + assert 24 <= avg_age <= 45 # Must be within the age range + + def test_min(self, conn): + """SELECT MIN(age) AS youngest FROM users""" + cursor = conn.cursor() + cursor.execute("SELECT MIN(age) AS youngest FROM users") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + youngest = rows[0][col_names.index("youngest")] + assert youngest == 24 # Min age in test data + + def test_max(self, conn): + """SELECT MAX(age) AS oldest FROM users""" + cursor = conn.cursor() + cursor.execute("SELECT MAX(age) AS oldest FROM users") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + oldest = rows[0][col_names.index("oldest")] + assert oldest == 45 # Max age in test data + + def test_sum(self, conn): + """SELECT SUM(price) AS total_price FROM products""" + cursor = conn.cursor() + cursor.execute("SELECT SUM(price) AS total_price FROM products") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + total_price = rows[0][col_names.index("total_price")] + assert isinstance(total_price, (int, float)) + assert total_price > 0 + + def test_multiple_aggregates(self, conn): + """SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products""" + cursor = conn.cursor() + cursor.execute( + "SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products" + ) + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + row = rows[0] + + cnt = row[col_names.index("cnt")] + cheapest = row[col_names.index("cheapest")] + priciest = row[col_names.index("priciest")] + avg_price = row[col_names.index("avg_price")] + + assert cnt == 50 + assert cheapest <= avg_price <= priciest + + def test_min_max_on_products(self, conn): + """SELECT MIN(price) AS low, MAX(price) AS high FROM products""" + cursor = conn.cursor() + cursor.execute("SELECT MIN(price) AS low, MAX(price) AS high FROM products") + + rows = cursor.fetchall() + assert len(rows) == 1 + + col_names = [desc[0] for desc in cursor.description] + low = rows[0][col_names.index("low")] + high = rows[0][col_names.index("high")] + assert low < high + + def test_count_with_and_or_conditions(self, conn): + """SELECT COUNT(*) AS cnt FROM users WHERE (active = true AND age > 30) OR age < 25""" + cursor = conn.cursor() + + # AND-only: active users over 30 + cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE active = true AND age > 30") + rows = cursor.fetchall() + col_names = [desc[0] for desc in cursor.description] + and_count = rows[0][col_names.index("cnt")] + assert isinstance(and_count, (int, float)) + assert and_count > 0 + assert and_count < 22 + + # OR-only: very young or very old + cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40") + rows = cursor.fetchall() + col_names = [desc[0] for desc in cursor.description] + or_count = rows[0][col_names.index("cnt")] + assert isinstance(or_count, (int, float)) + assert or_count > 0 + assert or_count < 22 + + # Three AND conditions + cursor.execute( + "SELECT COUNT(*) AS cnt, AVG(age) AS avg_age FROM users " "WHERE active = true AND age >= 25 AND age <= 40" + ) + rows = cursor.fetchall() + col_names = [desc[0] for desc in cursor.description] + cnt = rows[0][col_names.index("cnt")] + avg_age = rows[0][col_names.index("avg_age")] + assert cnt > 0 + assert cnt < 22 + assert 25 <= avg_age <= 40 diff --git a/tests/test_sql_parser_group.py b/tests/test_sql_parser_group.py new file mode 100644 index 0000000..1f69545 --- /dev/null +++ b/tests/test_sql_parser_group.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +import json + +from pymongosql.sql.parser import SQLParser + + +class TestCountStarParsing: + """Test that COUNT(*) in SQL is translated to a MongoDB aggregate pipeline.""" + + def test_count_star_basic(self): + """SELECT COUNT(*) FROM users → aggregate with $group and $sum:1""" + sql = "SELECT COUNT(*) FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + assert plan.collection == "users" + + pipeline = json.loads(plan.aggregate_pipeline) + # Should have $group and $project stages + assert len(pipeline) == 2 + assert "$group" in pipeline[0] + assert pipeline[0]["$group"]["_id"] is None + assert pipeline[0]["$group"]["COUNT(*)"] == {"$sum": 1} + + def test_count_star_with_alias(self): + """SELECT COUNT(*) AS total FROM users → alias used in $group""" + sql = "SELECT COUNT(*) AS total FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["total"] == {"$sum": 1} + # $project should expose the alias + assert pipeline[1]["$project"]["total"] == 1 + assert pipeline[1]["$project"]["_id"] == 0 + + def test_count_star_with_alias_no_as(self): + """SELECT COUNT(*) total FROM users → alias without AS keyword""" + sql = "SELECT COUNT(*) total FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["total"] == {"$sum": 1} + + def test_count_star_with_where(self): + """SELECT COUNT(*) AS total FROM users WHERE age > 25 → $match before $group""" + sql = "SELECT COUNT(*) AS total FROM users WHERE age > 25" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + # Should have $match, $group, $project + assert len(pipeline) == 3 + assert "$match" in pipeline[0] + assert "$group" in pipeline[1] + assert pipeline[1]["$group"]["total"] == {"$sum": 1} + + def test_count_star_projection_stage(self): + """Projection stage should reflect aggregate output fields.""" + sql = "SELECT COUNT(*) AS total FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.projection_stage == {"total": 1} + + def test_count_star_plan_validates(self): + """Generated aggregate plan should pass validation.""" + sql = "SELECT COUNT(*) FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + assert plan.validate() is True + + def test_sum(self): + """SELECT SUM(price) AS total_price FROM products""" + sql = "SELECT SUM(price) AS total_price FROM products" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["total_price"] == {"$sum": "$price"} + + def test_avg(self): + """SELECT AVG(age) AS avg_age FROM users""" + sql = "SELECT AVG(age) AS avg_age FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["avg_age"] == {"$avg": "$age"} + + def test_min(self): + """SELECT MIN(price) AS cheapest FROM products""" + sql = "SELECT MIN(price) AS cheapest FROM products" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["cheapest"] == {"$min": "$price"} + + def test_max(self): + """SELECT MAX(price) AS most_expensive FROM products""" + sql = "SELECT MAX(price) AS most_expensive FROM products" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["most_expensive"] == {"$max": "$price"} + + def test_multiple_aggregates(self): + """SELECT COUNT(*) AS cnt, AVG(price) AS avg_price, MAX(price) AS max_price FROM products""" + sql = "SELECT COUNT(*) AS cnt, AVG(price) AS avg_price, MAX(price) AS max_price FROM products" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + group = pipeline[0]["$group"] + assert group["_id"] is None + assert group["cnt"] == {"$sum": 1} + assert group["avg_price"] == {"$avg": "$price"} + assert group["max_price"] == {"$max": "$price"} + # $project exposes all three + project = pipeline[1]["$project"] + assert project == {"_id": 0, "cnt": 1, "avg_price": 1, "max_price": 1} + + def test_aggregate_with_nested_field(self): + """SELECT SUM(details.total) AS revenue FROM orders""" + sql = "SELECT SUM(details.total) AS revenue FROM orders" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is True + pipeline = json.loads(plan.aggregate_pipeline) + assert pipeline[0]["$group"]["revenue"] == {"$sum": "$details.total"} + + def test_aggregate_no_alias_uses_raw_text(self): + """SELECT SUM(price) FROM products → alias defaults to SUM(price)""" + sql = "SELECT SUM(price) FROM products" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + pipeline = json.loads(plan.aggregate_pipeline) + assert "SUM(price)" in pipeline[0]["$group"] + + def test_regular_select_unaffected(self): + """Regular SELECT without aggregate functions should not be affected.""" + sql = "SELECT name, email FROM users" + parser = SQLParser(sql) + plan = parser.get_execution_plan() + + assert plan.is_aggregate_query is False + assert plan.projection_stage == {"name": 1, "email": 1}