Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 61 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo
- **PartiQL-based SQL Syntax**: Built on [PartiQL](https://partiql.org/tutorial.html) (SQL for semi-structured data), enabling seamless SQL querying of nested and hierarchical MongoDB documents
- **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax
- **MongoDB Aggregate Pipeline Support**: Execute native MongoDB aggregation pipelines using SQL-like syntax with `aggregate()` function
- **SQL Aggregate Functions**: `COUNT(*)`, `SUM`, `AVG`, `MIN`, `MAX` translated to MongoDB aggregation pipelines
- **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect
- **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases
- **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax
Expand Down Expand Up @@ -87,6 +88,7 @@ pip install -e .
- [WHERE Clauses](#where-clauses)
- [Nested Field Support](#nested-field-support)
- [Sorting and Limiting](#sorting-and-limiting)
- [SQL Aggregate Functions](#sql-aggregate-functions)
- [MongoDB Aggregate Function](#mongodb-aggregate-function)
- [INSERT Statements](#insert-statements)
- [UPDATE Statements](#update-statements)
Expand Down Expand Up @@ -293,6 +295,42 @@ Both functions:
- **LIMIT**: `LIMIT 10`
- **Combined**: `ORDER BY created_at DESC LIMIT 5`

### SQL Aggregate Functions

PyMongoSQL supports standard SQL aggregate functions that are automatically translated into MongoDB aggregation pipelines.

**Supported Functions**: `COUNT(*)`, `SUM(field)`, `AVG(field)`, `MIN(field)`, `MAX(field)`

**Basic Count**

```python
cursor.execute("SELECT COUNT(*) AS total FROM users")
row = cursor.fetchone()
print(f"Total users: {row[0]}")
```

**Multiple Aggregates**

```python
cursor.execute(
"SELECT COUNT(*) AS cnt, AVG(price) AS avg_price, MIN(price) AS cheapest, MAX(price) AS priciest FROM products"
)
```

**Aggregate with WHERE**

```python
cursor.execute("SELECT COUNT(*) AS total FROM users WHERE active = true AND age > 30")
```

**Aggregate with OR Conditions**

```python
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40")
```

**Note:** Aggregate functions are translated into a MongoDB `aggregate()` pipeline with `$match` (from WHERE), `$group` (with accumulators), and `$project` stages. `COUNT(*)` maps to `{$sum: 1}`, while `SUM`, `AVG`, `MIN`, and `MAX` map to their corresponding MongoDB accumulators (`$sum`, `$avg`, `$min`, `$max`).

### MongoDB Aggregate Function

PyMongoSQL supports executing native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function. This allows you to leverage MongoDB's powerful aggregation framework while maintaining SQL-style query patterns.
Expand Down Expand Up @@ -608,20 +646,24 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm

### SQL Operations to MongoDB Commands

| SQL Operation | MongoDB Command | Equivalent PyMongo Method |
|---|---|---|
| `SELECT ... FROM col` | `{find: col, projection: {...}}` | `db.command("find", ...)` |
| `SELECT ... FROM col WHERE ...` | `{find: col, filter: {...}}` | `db.command("find", ...)` |
| `SELECT ... ORDER BY col ASC/DESC` | `{find: ..., sort: {col: 1/-1}}` | `db.command("find", ...)` |
| `SELECT ... LIMIT n` | `{find: ..., limit: n}` | `db.command("find", ...)` |
| `SELECT ... OFFSET n` | `{find: ..., skip: n}` | `db.command("find", ...)` |
| `SELECT * FROM col.aggregate(...)` | `collection.aggregate(pipeline)` | `collection.aggregate()` |
| `INSERT INTO col ...` | `{insert: col, documents: [...]}` | `db.command("insert", ...)` |
| `UPDATE col SET ... WHERE ...` | `{update: col, updates: [{q: filter, u: {$set: {...}}, multi: true}]}` | `db.command("update", ...)` |
| `DELETE FROM col WHERE ...` | `{delete: col, deletes: [{q: filter, limit: 0}]}` | `db.command("delete", ...)` |
| `CREATE VIEW v ON col AS '[...]'` | `{create: v, viewOn: col, pipeline: [...]}` | `db.command("create", ...)` |
| `DROP VIEW v` | `{drop: v}` | `db.command("drop", ...)` |
| `EXPLAIN <select>` | `{explain: <find\|aggregate cmd>, verbosity: "queryPlanner"}` | `db.command("explain", ...)` |
| SQL Operation | MongoDB Command |
|---|---|
| `SELECT ... FROM col` | `{find: col, projection: {...}}` |
| `SELECT ... FROM col WHERE ...` | `{find: col, filter: {...}}` |
| `SELECT ... ORDER BY col ASC/DESC` | `{find: ..., sort: {col: 1/-1}}` |
| `SELECT ... LIMIT n` | `{find: ..., limit: n}` |
| `SELECT ... OFFSET n` | `{find: ..., skip: n}` |
| `SELECT COUNT(*) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
| `SELECT AVG(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
| `SELECT MIN/MAX(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
| `SELECT SUM(field) FROM col` | `collection.aggregate([{$group: {_id: null, ...}}, {$project: ...}])` |
| `SELECT * FROM col.aggregate(...)` | `collection.aggregate(pipeline)` |
| `INSERT INTO col ...` | `{insert: col, documents: [...]}` |
| `UPDATE col SET ... WHERE ...` | `{update: col, updates: [{q: filter, u: {$set: {...}}, multi: true}]}` |
| `DELETE FROM col WHERE ...` | `{delete: col, deletes: [{q: filter, limit: 0}]}` |
| `CREATE VIEW v ON col AS '[...]'` | `{create: v, viewOn: col, pipeline: [...]}` |
| `DROP VIEW v` | `{drop: v}` |
| `EXPLAIN <select>` | `{explain: <find\|aggregate cmd>, verbosity: "queryPlanner"}` |

### SQL Clauses to MongoDB Query Components

Expand All @@ -635,6 +677,11 @@ The table below shows how PyMongoSQL translates SQL operations into MongoDB comm
| `ORDER BY col DESC` | `sort: {col: -1}` | Descending sort |
| `LIMIT n` | `limit: n` | Restrict result count |
| `OFFSET n` | `skip: n` | Skip first n results |
| `COUNT(*)` | `{$group: {_id: null, count: {$sum: 1}}}` | Document count |
| `SUM(field)` | `{$group: {_id: null, sum: {$sum: "$field"}}}` | Field sum |
| `AVG(field)` | `{$group: {_id: null, avg: {$avg: "$field"}}}` | Field average |
| `MIN(field)` | `{$group: {_id: null, min: {$min: "$field"}}}` | Field minimum |
| `MAX(field)` | `{$group: {_id: null, max: {$max: "$field"}}}` | Field maximum |

### WHERE Operators to MongoDB Filter Operators

Expand Down
60 changes: 60 additions & 0 deletions pymongosql/sql/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import json
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
Expand Down Expand Up @@ -112,6 +113,11 @@ def build_from_parse_result(
@staticmethod
def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
"""Build a query execution plan from SELECT parsing."""

# Auto-generate aggregate pipeline for SQL aggregate functions (COUNT, SUM, etc.)
if getattr(parse_result, "aggregate_functions", None):
return ExecutionPlanBuilder._build_sql_aggregate_plan(parse_result)

builder = BuilderFactory.create_query_builder().collection(parse_result.collection)

builder.filter(parse_result.filter_conditions).project(parse_result.projection).column_aliases(
Expand All @@ -128,6 +134,60 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
plan = builder.build()
return plan

@staticmethod
def _build_sql_aggregate_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
"""Build an aggregate execution plan from SQL aggregate functions like COUNT(*), SUM(), etc."""
_FUNCTION_TO_ACCUMULATOR = {
"COUNT": "$sum",
"SUM": "$sum",
"AVG": "$avg",
"MIN": "$min",
"MAX": "$max",
}

builder = BuilderFactory.create_query_builder().collection(parse_result.collection)

pipeline = []

# Add $match stage if there are filter conditions (from WHERE clause)
if parse_result.filter_conditions:
pipeline.append({"$match": parse_result.filter_conditions})

# Build $group stage from aggregate functions
group_stage = {"_id": None}
for func_info in parse_result.aggregate_functions:
alias = func_info["alias"]
func_name = func_info["function"]
arg = func_info["argument"]
accumulator = _FUNCTION_TO_ACCUMULATOR[func_name]

if func_name == "COUNT":
group_stage[alias] = {accumulator: 1}
else:
group_stage[alias] = {accumulator: f"${arg}"}

pipeline.append({"$group": group_stage})

# Add $project to exclude _id
project_stage = {"_id": 0}
for func_info in parse_result.aggregate_functions:
project_stage[func_info["alias"]] = 1
pipeline.append({"$project": project_stage})

# Configure the execution plan as an aggregate query
builder._execution_plan.is_aggregate_query = True
builder._execution_plan.aggregate_pipeline = json.dumps(pipeline)
builder._execution_plan.aggregate_options = json.dumps({})

# Set projection for ResultSet description
agg_projection = {}
for func_info in parse_result.aggregate_functions:
agg_projection[func_info["alias"]] = 1
builder._execution_plan.projection_stage = agg_projection

plan = builder.build()
return plan

@staticmethod
def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan":
"""Build an INSERT execution plan from INSERT parsing."""
Expand Down
24 changes: 24 additions & 0 deletions pymongosql/sql/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class QueryParseResult:
aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline
aggregate_options: Optional[str] = None # JSON string representation of options

# SQL aggregate functions detected in SELECT (COUNT, SUM, AVG, MIN, MAX)
aggregate_functions: List[Dict[str, Any]] = field(default_factory=list)

# Subquery info (for wrapped subqueries, e.g., Superset outering)
subquery_plan: Optional[Any] = None
subquery_alias: Optional[str] = None
Expand Down Expand Up @@ -111,6 +114,12 @@ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]:
class SelectHandler(BaseHandler, ContextUtilsMixin):
"""Handles SELECT statement parsing"""

# Pattern to detect SQL aggregate functions: COUNT(*), SUM(field), AVG(field), etc.
_AGGREGATE_PATTERN = re.compile(
r"^(COUNT|SUM|AVG|MIN|MAX)\s*\(\s*(\*|\w+(?:\.\w+)*)\s*\)$",
re.IGNORECASE,
)

def can_handle(self, ctx: Any) -> bool:
"""Check if this is a select context"""
return hasattr(ctx, "projectionItems")
Expand All @@ -122,6 +131,21 @@ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "Q
if hasattr(ctx, "projectionItems") and ctx.projectionItems():
for item in ctx.projectionItems().projectionItem():
field_name, alias = self._extract_field_and_alias(item)

# Check if this is an aggregate function (COUNT, SUM, etc.)
agg_match = self._AGGREGATE_PATTERN.match(field_name)
if agg_match:
func_name = agg_match.group(1).upper()
func_arg = agg_match.group(2)
parse_result.aggregate_functions.append(
{
"function": func_name,
"argument": func_arg,
"alias": alias or field_name,
}
)
continue

# Use MongoDB standard projection format: {field: 1} to include field
projection[field_name] = 1
# Store alias if present
Expand Down
162 changes: 162 additions & 0 deletions tests/test_cursor_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,165 @@ def test_aggregate_collection_name_with_hyphen(self, conn):
customer_type_idx = col_names.index("customer_type")
for row in rows:
assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"


class TestSqlGroupFunctions:
"""Test SQL aggregate functions (COUNT, AVG, MIN, MAX, SUM) translated to MongoDB pipelines."""

def test_count_star(self, conn):
"""SELECT COUNT(*) AS total FROM users → should return document count"""
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS total FROM users")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
assert "total" in col_names

total_idx = col_names.index("total")
assert rows[0][total_idx] == 22 # 22 users in test data

def test_count_star_no_alias(self, conn):
"""SELECT COUNT(*) FROM users → column name defaults to COUNT(*)"""
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM users")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
assert "COUNT(*)" in col_names
assert rows[0][col_names.index("COUNT(*)")] == 22

def test_count_star_with_where(self, conn):
"""SELECT COUNT(*) AS total FROM users WHERE age > 30 → filtered count"""
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) AS total FROM users WHERE age > 30")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
total = rows[0][col_names.index("total")]
assert isinstance(total, (int, float))
assert total > 0
assert total < 22 # Must be less than total users

def test_avg(self, conn):
"""SELECT AVG(age) AS avg_age FROM users"""
cursor = conn.cursor()
cursor.execute("SELECT AVG(age) AS avg_age FROM users")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
avg_age = rows[0][col_names.index("avg_age")]
assert isinstance(avg_age, (int, float))
assert 24 <= avg_age <= 45 # Must be within the age range

def test_min(self, conn):
"""SELECT MIN(age) AS youngest FROM users"""
cursor = conn.cursor()
cursor.execute("SELECT MIN(age) AS youngest FROM users")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
youngest = rows[0][col_names.index("youngest")]
assert youngest == 24 # Min age in test data

def test_max(self, conn):
"""SELECT MAX(age) AS oldest FROM users"""
cursor = conn.cursor()
cursor.execute("SELECT MAX(age) AS oldest FROM users")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
oldest = rows[0][col_names.index("oldest")]
assert oldest == 45 # Max age in test data

def test_sum(self, conn):
"""SELECT SUM(price) AS total_price FROM products"""
cursor = conn.cursor()
cursor.execute("SELECT SUM(price) AS total_price FROM products")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
total_price = rows[0][col_names.index("total_price")]
assert isinstance(total_price, (int, float))
assert total_price > 0

def test_multiple_aggregates(self, conn):
"""SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"""
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) AS cnt, MIN(price) AS cheapest, MAX(price) AS priciest, AVG(price) AS avg_price FROM products"
)

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
row = rows[0]

cnt = row[col_names.index("cnt")]
cheapest = row[col_names.index("cheapest")]
priciest = row[col_names.index("priciest")]
avg_price = row[col_names.index("avg_price")]

assert cnt == 50
assert cheapest <= avg_price <= priciest

def test_min_max_on_products(self, conn):
"""SELECT MIN(price) AS low, MAX(price) AS high FROM products"""
cursor = conn.cursor()
cursor.execute("SELECT MIN(price) AS low, MAX(price) AS high FROM products")

rows = cursor.fetchall()
assert len(rows) == 1

col_names = [desc[0] for desc in cursor.description]
low = rows[0][col_names.index("low")]
high = rows[0][col_names.index("high")]
assert low < high

def test_count_with_and_or_conditions(self, conn):
"""SELECT COUNT(*) AS cnt FROM users WHERE (active = true AND age > 30) OR age < 25"""
cursor = conn.cursor()

# AND-only: active users over 30
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE active = true AND age > 30")
rows = cursor.fetchall()
col_names = [desc[0] for desc in cursor.description]
and_count = rows[0][col_names.index("cnt")]
assert isinstance(and_count, (int, float))
assert and_count > 0
assert and_count < 22

# OR-only: very young or very old
cursor.execute("SELECT COUNT(*) AS cnt FROM users WHERE age < 26 OR age > 40")
rows = cursor.fetchall()
col_names = [desc[0] for desc in cursor.description]
or_count = rows[0][col_names.index("cnt")]
assert isinstance(or_count, (int, float))
assert or_count > 0
assert or_count < 22

# Three AND conditions
cursor.execute(
"SELECT COUNT(*) AS cnt, AVG(age) AS avg_age FROM users " "WHERE active = true AND age >= 25 AND age <= 40"
)
rows = cursor.fetchall()
col_names = [desc[0] for desc in cursor.description]
cnt = rows[0][col_names.index("cnt")]
avg_age = rows[0][col_names.index("avg_age")]
assert cnt > 0
assert cnt < 22
assert 25 <= avg_age <= 40
Loading
Loading