Skip to content
11 changes: 11 additions & 0 deletions videodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from videodb._constants import (
VIDEO_DB_API,
IndexType,
IndexCapability,
FieldGroup,
SceneExtractionType,
MediaType,
SearchType,
Expand All @@ -25,6 +27,8 @@
RTStreamChannelType,
)
from videodb.client import Connection
from videodb.search import AskResponse, SearchResponse, SearchResult
from videodb.understanding import Understanding, UnderstandingAnalyzer
from videodb.capture_session import CaptureSession
from videodb.websocket_client import WebSocketConnection
from videodb.capture import CaptureClient, Channel, AudioChannel, VideoChannel, Channels, ChannelList
Expand Down Expand Up @@ -53,7 +57,14 @@
"AuthenticationError",
"InvalidRequestError",
"IndexType",
"IndexCapability",
"FieldGroup",
"SearchError",
"SearchResult",
"SearchResponse",
"AskResponse",
"Understanding",
"UnderstandingAnalyzer",
"play_stream",
"build_iframe_embed_code",
"MediaType",
Expand Down
25 changes: 25 additions & 0 deletions videodb/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@ class IndexType:
scene = "scene"


class IndexCapability:
"""Retrieval capabilities an index can be built for (``use_for``)."""

semantic = "semantic"
query = "query"
aggregate = "aggregate"


class FieldGroup:
"""Field groups that map artifact fields to retrieval capabilities."""

semantic = "semantic"
text = "text"
filter = "filter"
aggregate = "aggregate"
sort = "sort"


class SceneExtractionType:
shot_based = "shot"
time_based = "time"
Expand Down Expand Up @@ -80,7 +98,14 @@ class ApiPath:
upload_url = "upload_url"
transcription = "transcription"
index = "index"
indexes = "indexes"
records = "records"
understand = "understand"
search = "search"
ask = "ask"
semantic_search = "semantic-search"
query = "query"
aggregate = "aggregate"
compile = "compile"
workflow = "workflow"
timeline = "timeline"
Expand Down
163 changes: 160 additions & 3 deletions videodb/collection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from typing import Optional, Union, List, Dict, Any, Literal
from typing import Optional, Union, List, Dict, Any, Literal, Tuple
from videodb._upload import (
upload,
)
Expand All @@ -17,7 +17,7 @@
from videodb.meeting import Meeting
from videodb.capture_session import CaptureSession
from videodb.rtstream import RTStream, RTStreamSearchResult, RTStreamShot
from videodb.search import SearchFactory, SearchResult
from videodb.search import AskResponse, SearchFactory, SearchResponse, SearchResult, warn_legacy_search_once

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -462,7 +462,159 @@ def dub_video(
if dub_data:
return Video(self._connection, **dub_data)

def search(
def search(self, query: str, *args, **kwargs) -> Union[SearchResponse, SearchResult, RTStreamSearchResult]:
"""Search the collection.

New search is used by default. Calls that use legacy-shaped parameters are
routed to :meth:`legacy_search` with a warning.
"""
old_params = {
"search_type",
"index_type",
"result_threshold",
"dynamic_score_percentage",
"scene_index_id",
"index_id",
"algorithm",
"sort_docs_on",
"namespace",
}
new_params = {
"top_k",
"mode",
"return_fields",
"include_clip",
}
unsupported_params = {"index_name", "index_names"}

if args:
legacy_arg_names = [
"search_type",
"index_type",
"result_threshold",
"score_threshold",
"dynamic_score_percentage",
"filter",
"sort_docs_on",
"namespace",
"scene_index_id",
]
for name, value in zip(legacy_arg_names, args):
kwargs.setdefault(name, value)

has_old = bool(args) or any(k in kwargs and kwargs[k] is not None for k in old_params)
has_new = any(k in kwargs and kwargs[k] is not None for k in new_params)
has_unsupported = any(k in kwargs and kwargs[k] is not None for k in unsupported_params)

if has_old and (has_new or has_unsupported):
raise ValueError(
"Cannot mix legacy search params with new search params. "
"Use search(...) for new search or legacy_search(...) for legacy search."
)
if has_unsupported:
raise ValueError(
"index_name/index_names are not supported in search(). "
"Use semantic_search(), query(), or aggregate() for index-specific calls."
)

if has_old:
warn_legacy_search_once()
return self.legacy_search(query=query, **kwargs)

return self._new_search(query=query, **kwargs)

def _new_search(self, query: str, **kwargs) -> SearchResponse:
payload = {"query": query, **{k: v for k, v in kwargs.items() if v is not None}}
search_data = self._connection.post(
path=f"{ApiPath.collection}/{self.id}/{ApiPath.search}/v2",
data=payload,
show_progress=True,
)
return SearchResponse(self._connection, **search_data)

def ask(
self,
question: str,
top_k: int = 15,
mode: str = "default",
include_sources: bool = False,
) -> AskResponse:
ask_data = self._connection.post(
path=f"{ApiPath.collection}/{self.id}/{ApiPath.ask}",
data={
"question": question,
"top_k": top_k,
"mode": mode,
"include_sources": include_sources,
},
show_progress=True,
)
return AskResponse(self._connection, **ask_data)

def semantic_search(
self,
query: str,
index_names: Optional[Union[List[str], str]] = None,
top_k: int = 10,
score_threshold: Optional[float] = None,
filter: Optional[Union[List, Dict]] = None,
return_fields: Optional[Union[List, Dict, str]] = None,
) -> SearchResult:
search_data = self._connection.post(
path=f"{ApiPath.collection}/{self.id}/{ApiPath.semantic_search}",
data={
"query": query,
"index_names": index_names,
"top_k": top_k,
"score_threshold": score_threshold,
"filter": filter,
"return_fields": return_fields,
},
)
return SearchResult(self._connection, **search_data)

def query(
self,
index_name: str,
filter: Optional[Union[List, Dict]] = None,
limit: int = 100,
return_fields: Optional[Union[List, Dict, str]] = None,
sort: Optional[Union[str, List[Tuple[str, str]]]] = None,
) -> SearchResult:
query_data = self._connection.post(
path=f"{ApiPath.collection}/{self.id}/{ApiPath.query}",
data={
"index_name": index_name,
"filter": filter,
"limit": limit,
"return_fields": return_fields,
"sort": sort,
},
)
return SearchResult(self._connection, **query_data)

def aggregate(
self,
index_name: str,
filter: Optional[Union[List, Dict]] = None,
group_by: Optional[str] = None,
metric: str = "count",
limit: int = 100,
sort: Optional[Union[str, List[Tuple[str, str]]]] = None,
) -> Union[Dict, List[Dict]]:
return self._connection.post(
path=f"{ApiPath.collection}/{self.id}/{ApiPath.aggregate}",
data={
"index_name": index_name,
"filter": filter,
"group_by": group_by,
"metric": metric,
"limit": limit,
"sort": sort,
},
)

def legacy_search(
self,
query: str,
search_type: Optional[str] = SearchType.semantic,
Expand All @@ -474,6 +626,8 @@ def search(
sort_docs_on: Optional[str] = None,
namespace: Optional[str] = None,
scene_index_id: Optional[str] = None,
index_id: Optional[str] = None,
algorithm: Optional[str] = None,
) -> Union[SearchResult, RTStreamSearchResult]:
"""Search for a query in the collection.

Expand All @@ -493,6 +647,9 @@ def search(
:rtype: Union[:class:`videodb.search.SearchResult`,
:class:`videodb.rtstream.RTStreamSearchResult`]
"""
if scene_index_id is None and index_id is not None:
scene_index_id = index_id

if namespace == "rtstream":
data = {"query": query}
if scene_index_id is not None:
Expand Down
Loading
Loading