Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.11"
python-version: "3.12"
- name: Install pypa/build
run: >-
python3 -m
Expand Down
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ source venv/bin/activate
uv pip install -r requirements.txt
```

The base package only needs lightweight dependencies: indexing and searching with the
anserini/seismic backends start from already-encoded JSONL files, so they work without torch.
Optional dependency groups cover the rest:
- `encode`: torch, transformers, etc. for encoding text into sparse representations
- `anserini`: pyjnius for the Anserini search backend (a JAR is also required; see below)
- `seismic`: pyseismic-lsr for the Seismic backend
- `all`: everything above

```
# e.g., install everything needed to encode and to search with the Seismic backend
uv pip install 'bsparse[encode,seismic]'
```

```
# Request access to splade-v3: https://huggingface.co/naver/splade-v3
# Get your huggingface API token and then:
Expand Down Expand Up @@ -71,7 +84,7 @@ backends.

```
# install the Seismic Python bindings (optional dependency; only needed for this backend)
uv pip install pyseismic-lsr
uv pip install 'bsparse[seismic]' # or: uv pip install pyseismic-lsr
# for best performance, build against your CPU instead:
# RUSTFLAGS="-C target-cpu=native" uv pip install --no-binary :all: pyseismic-lsr

Expand Down
2 changes: 1 addition & 1 deletion bsparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
from .utils import batch_encode, get_torch_device, token_ids_to_binary_vec


__version__ = "0.2.0"
__version__ = "0.3.0"
9 changes: 7 additions & 2 deletions bsparse/anserini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import tempfile
from collections import Counter
from dataclasses import dataclass
from typing import TYPE_CHECKING

from trecrun import TRECRun

from bsparse.models import Model
from bsparse.utils import psgid_to_docid


if TYPE_CHECKING:
# only used as a type hint; importing bsparse.models would pull in torch/transformers
from bsparse.models import Model


ANSERINI_JAR = os.environ.get("ANSERINI_JAR", "anserini-1.0.0-fatjar-bsparse.jar")
JAVA_ARGS = os.environ.get("ANSERINI_JAVA_ARGS", "-Xms4g,-Xmx16g").split(",")
THREADS = os.environ.get("ANSERINI_THREADS", os.cpu_count())
Expand All @@ -23,7 +28,7 @@ class Anserini:
def __init__(self, index_path: str):
self.index_path = index_path

def query_from_raw_text(self, queries: list[str], model: Model, k: int = 1000, scale: int = 50):
def query_from_raw_text(self, queries: list[str], model: "Model", k: int = 1000, scale: int = 50):
dataset = [(str(idx), query) for idx, query in enumerate(queries)]
ids, reps = model.encode(dataset)
vectors = [{"vector": rep} for rep in reps]
Expand Down
53 changes: 33 additions & 20 deletions bsparse/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import argparse

import bsparse.commands
import bsparse.datasets
import bsparse.models


COMMANDS = {
Expand All @@ -13,22 +11,34 @@
"memsearch": bsparse.commands.MemSearch,
}

DEFAULT_MODELS = {
"splade": bsparse.models.SpladeModel,
"spladepsg": bsparse.models.SpladePsgModel,
"multilsr": bsparse.models.MultiLSRModel,
}
DEFAULT_MODEL = "splade"

DEFAULT_DATASETS = {
"irds": bsparse.datasets.IRDSDataset,
"jsonl": bsparse.datasets.JSONLDataset,
"tsv": bsparse.datasets.TSVDataset,
"hgf": bsparse.datasets.HgfDataset,
}
DEFAULT_DATASET = "irds"


# bsparse.models and bsparse.datasets are imported lazily so that commands that
# don't need them (e.g. index/search on already-encoded jsonl) don't pull in
# big packages like torch, transformers, and datasets
def get_models():
import bsparse.models

return {
"splade": bsparse.models.SpladeModel,
"spladepsg": bsparse.models.SpladePsgModel,
"multilsr": bsparse.models.MultiLSRModel,
}


def get_datasets():
import bsparse.datasets

return {
"irds": bsparse.datasets.IRDSDataset,
"jsonl": bsparse.datasets.JSONLDataset,
"tsv": bsparse.datasets.TSVDataset,
"hgf": bsparse.datasets.HgfDataset,
}


def get_command():
parser = argparse.ArgumentParser(description="bsparse CLI")
parser.add_argument("command", choices=list(COMMANDS.keys()), help="Command")
Expand All @@ -45,32 +55,35 @@ def main():
# parse the command name, so that we can see whether it needs --dataset and --model
command_cls = COMMANDS[get_command()]

datasets = get_datasets() if command_cls.needs_dataset else None
models = get_models() if command_cls.needs_model else None

# parse the --dataset and --model, so we can add dataset-specific and model-specific args
minimal_parser = argparse.ArgumentParser(description="bsparse CLI")
minimal_parser.add_argument("command", choices=list(COMMANDS.keys()), help="Command")
if command_cls.needs_dataset:
minimal_parser.add_argument("--dataset", choices=list(DEFAULT_DATASETS.keys()), required=True, help="Dataset")
minimal_parser.add_argument("--dataset", choices=list(datasets.keys()), required=True, help="Dataset")
if command_cls.needs_model:
minimal_parser.add_argument("--model", choices=list(DEFAULT_MODELS.keys()), required=True, help="Model")
minimal_parser.add_argument("--model", choices=list(models.keys()), required=True, help="Model")

known_args, remaining_args = minimal_parser.parse_known_args()

# parse the full command, so we can run() it
full_parser = argparse.ArgumentParser(parents=[minimal_parser], add_help=False)

if command_cls.needs_dataset:
DEFAULT_DATASETS[known_args.dataset].add_arguments(full_parser)
datasets[known_args.dataset].add_arguments(full_parser)
if command_cls.needs_model:
DEFAULT_MODELS[known_args.model].add_arguments(full_parser)
models[known_args.model].add_arguments(full_parser)

command_cls.add_arguments(full_parser)
args = full_parser.parse_args()

kwargs = {}
if command_cls.needs_dataset:
kwargs["dataset"] = DEFAULT_DATASETS[args.dataset](args)
kwargs["dataset"] = datasets[args.dataset](args)
if command_cls.needs_model:
kwargs["model"] = DEFAULT_MODELS[args.model](args)
kwargs["model"] = models[args.model](args)

command = command_cls(args, **kwargs)
command.run()
Expand Down
76 changes: 65 additions & 11 deletions bsparse/commands.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import os
from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import ir_datasets as irds
from tqdm import tqdm
from trecrun import TRECRun

Expand All @@ -14,6 +14,50 @@
from bsparse.utils import psgid_to_docid


def load_qrels(qrels: str) -> dict[str, dict[str, int]]:
"""Load relevance judgments into a {query_id: {doc_id: relevance}} dict.

`qrels` may be a TREC-format qrels file or an ir_datasets name: an existing file is parsed as
'qid iteration docid relevance' lines, and anything else is treated as an ir_datasets name.
"""
all_qrels = {}

if os.path.isfile(qrels):
with open(qrels, "rt", encoding="utf-8") as f:
for line_no, line in enumerate(f, start=1):
if not line.strip():
continue
try:
qid, _iteration, docid, relevance = line.split()
all_qrels.setdefault(qid, {})[docid] = int(relevance)
except ValueError as e:
raise ValueError(
f"invalid TREC qrels on line {line_no} of {qrels} (expected 'qid iteration docid relevance'): {line!r}"
) from e
return all_qrels

# ir_datasets is imported lazily to keep the no-qrels paths light
import ir_datasets as irds

try:
dataset = irds.load(qrels)
except KeyError:
raise ValueError(
f"--qrels value is neither an existing file nor a known ir_datasets name: {qrels} "
"(pass a TREC-format qrels file path, or an ir_datasets name like 'beir/nfcorpus/test')"
) from None

if not dataset.has_qrels():
raise ValueError(
f"ir_datasets dataset has no qrels: {qrels} "
"(pass a subset that includes relevance judgments, e.g. 'beir/nfcorpus/test' rather than 'beir/nfcorpus')"
)

for qr in dataset.qrels_iter():
all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance
return all_qrels


class Command(ABC):
needs_model = False
needs_dataset = False
Expand Down Expand Up @@ -100,7 +144,12 @@ def add_arguments(cls, parser):
parser.add_argument("--out", type=Path, required=True, help="Output file path")
parser.add_argument("--pool", type=int, default=20, help="Multiprocessing pool size (default: %(default)s)")
parser.add_argument("--topk", type=int, default=1000, help="Top K results to return (default: %(default)s)")
parser.add_argument("--qrels", type=str, default=None, help="Relevance judgments dataset (default: %(default)s)")
parser.add_argument(
"--qrels",
type=str,
default=None,
help="Relevance judgments: TREC qrels file or ir_datasets name (default: %(default)s)",
)
parser.add_argument(
"--aggregate",
type=str,
Expand All @@ -116,6 +165,10 @@ def run(self):
if self.cfg.out.is_dir():
raise ValueError(f"--out is a directory: {self.cfg.out}")

if self.cfg.qrels:
# load the qrels before the (slow) search so a bad --qrels fails fast
all_qrels = load_qrels(self.cfg.qrels)

with Pool(self.cfg.pool) as p:
score_f = partial(score_shard, queries_fn=self.cfg.queries, topk=self.cfg.topk, aggregate=self.cfg.aggregate)
shard_scores = list(
Expand All @@ -142,10 +195,6 @@ def run(self):

if self.cfg.qrels:
print(f"evaluating with qrels: {self.cfg.qrels}")
all_qrels = {}
for qr in irds.load(self.cfg.qrels).qrels_iter():
all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance

metrics = run.evaluate(all_qrels)
avg = {metric: vals["mean"] for metric, vals in metrics.items()}
print(json.dumps(avg, indent=4, sort_keys=True))
Expand Down Expand Up @@ -254,7 +303,12 @@ def add_arguments(cls, parser):
help="Search backend (default: %(default)s)",
)
parser.add_argument("--topk", type=int, default=1000, help="Top K results to return (default: %(default)s)")
parser.add_argument("--qrels", type=str, default=None, help="Relevance judgments dataset (default: %(default)s)")
parser.add_argument(
"--qrels",
type=str,
default=None,
help="Relevance judgments: TREC qrels file or ir_datasets name (default: %(default)s)",
)
# Backend-specific args default to None so we can tell whether the user set them: unset args fall
# back to the backend's own default, and passing an arg for a different backend is rejected (see run()).
# anserini-specific
Expand Down Expand Up @@ -284,6 +338,10 @@ def run(self):
if self.cfg.out.is_dir():
raise ValueError(f"--out is a directory: {self.cfg.out}")

if self.cfg.qrels:
# load the qrels before the (slow) search so a bad --qrels fails fast
all_qrels = load_qrels(self.cfg.qrels)

queries = load_dict(self.cfg.queries)
queries.ids = [psgid_to_docid(qid) for qid in queries.ids]
vectors = [{"vector": rep} for rep in queries.weights]
Expand All @@ -309,10 +367,6 @@ def run(self):

if self.cfg.qrels:
print(f"evaluating with qrels: {self.cfg.qrels}")
all_qrels = {}
for qr in irds.load(self.cfg.qrels).qrels_iter():
all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance

metrics = run.evaluate(all_qrels)
avg = {metric: vals["mean"] for metric, vals in metrics.items()}
print(json.dumps(avg, indent=4, sort_keys=True))
Expand Down
12 changes: 10 additions & 2 deletions bsparse/convert.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch

from bsparse.jsonl import SparseRepresentations, dict2jsonl, jsonl2dict


# torch is imported inside the functions that need it, so that torch-free paths
# (e.g. anserini/seismic index/search over already-encoded jsonl) work without it


def load_dict(fn) -> SparseRepresentations:
return jsonl2dict(fn)

Expand Down Expand Up @@ -49,10 +51,14 @@ def dict2vec(ds, term2id):
if isinstance(ds, dict):
return _single_dict2vec(ds, term2id)

import torch

return torch.stack([_single_dict2vec(d, term2id) for d in ds])


def _single_dict2vec(d, term2id):
import torch

terms, weights = zip(*d.items())
termids = [term2id[term] for term in terms]
vec = torch.zeros(len(term2id), dtype=torch.float32)
Expand All @@ -61,6 +67,8 @@ def _single_dict2vec(d, term2id):


def token_ids_to_binary_vec(input_ids, attention_mask, special_tokens_mask, vocab_size):
import torch

binary_ids = torch.ones_like(input_ids, dtype=torch.float) * attention_mask * (1 - special_tokens_mask)
batch_size = binary_ids.shape[0]
sparse_rep = torch.zeros((batch_size, vocab_size), device=binary_ids.device).scatter_reduce_(
Expand Down
8 changes: 6 additions & 2 deletions bsparse/seismic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import tempfile
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING

import numpy as np
from tqdm import tqdm

from bsparse.models import Model

if TYPE_CHECKING:
# only used as a type hint; importing bsparse.models would pull in torch/transformers
from bsparse.models import Model


THREADS = int(os.environ.get("SEISMIC_THREADS", os.cpu_count()))
Expand Down Expand Up @@ -177,7 +181,7 @@ def index(self):
def query_from_raw_text(
self,
queries: list[str],
model: Model,
model: "Model",
k: int = 1000,
query_cut: int = DEFAULT_QUERY_CUT,
heap_factor: float = DEFAULT_HEAP_FACTOR,
Expand Down
Loading