diff --git a/.github/workflows/publish-release.yml b/.github/workflows/publish-release.yml index e98fc23..fd7ce5e 100644 --- a/.github/workflows/publish-release.yml +++ b/.github/workflows/publish-release.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.11" + python-version: "3.12" - name: Install pypa/build run: >- python3 -m diff --git a/README.md b/README.md index ea59c5f..3e24e6c 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,19 @@ source venv/bin/activate uv pip install -r requirements.txt ``` +The base package only needs lightweight dependencies: indexing and searching with the +anserini/seismic backends start from already-encoded JSONL files, so they work without torch. +Optional dependency groups cover the rest: +- `encode`: torch, transformers, etc. for encoding text into sparse representations +- `anserini`: pyjnius for the Anserini search backend (a JAR is also required; see below) +- `seismic`: pyseismic-lsr for the Seismic backend +- `all`: everything above + +``` +# e.g., install everything needed to encode and to search with the Seismic backend +uv pip install 'bsparse[encode,seismic]' +``` + ``` # Request access to splade-v3: https://huggingface.co/naver/splade-v3 # Get your huggingface API token and then: @@ -71,7 +84,7 @@ backends. ``` # install the Seismic Python bindings (optional dependency; only needed for this backend) -uv pip install pyseismic-lsr +uv pip install 'bsparse[seismic]' # or: uv pip install pyseismic-lsr # for best performance, build against your CPU instead: # RUSTFLAGS="-C target-cpu=native" uv pip install --no-binary :all: pyseismic-lsr diff --git a/bsparse/__init__.py b/bsparse/__init__.py index 7b35b87..6b82190 100644 --- a/bsparse/__init__.py +++ b/bsparse/__init__.py @@ -10,4 +10,4 @@ from .utils import batch_encode, get_torch_device, token_ids_to_binary_vec -__version__ = "0.2.0" +__version__ = "0.3.0" diff --git a/bsparse/anserini.py b/bsparse/anserini.py index 8c0ddf4..af648e2 100644 --- a/bsparse/anserini.py +++ b/bsparse/anserini.py @@ -3,13 +3,18 @@ import tempfile from collections import Counter from dataclasses import dataclass +from typing import TYPE_CHECKING from trecrun import TRECRun -from bsparse.models import Model from bsparse.utils import psgid_to_docid +if TYPE_CHECKING: + # only used as a type hint; importing bsparse.models would pull in torch/transformers + from bsparse.models import Model + + ANSERINI_JAR = os.environ.get("ANSERINI_JAR", "anserini-1.0.0-fatjar-bsparse.jar") JAVA_ARGS = os.environ.get("ANSERINI_JAVA_ARGS", "-Xms4g,-Xmx16g").split(",") THREADS = os.environ.get("ANSERINI_THREADS", os.cpu_count()) @@ -23,7 +28,7 @@ class Anserini: def __init__(self, index_path: str): self.index_path = index_path - def query_from_raw_text(self, queries: list[str], model: Model, k: int = 1000, scale: int = 50): + def query_from_raw_text(self, queries: list[str], model: "Model", k: int = 1000, scale: int = 50): dataset = [(str(idx), query) for idx, query in enumerate(queries)] ids, reps = model.encode(dataset) vectors = [{"vector": rep} for rep in reps] diff --git a/bsparse/cli.py b/bsparse/cli.py index aad14a3..1058c77 100644 --- a/bsparse/cli.py +++ b/bsparse/cli.py @@ -1,8 +1,6 @@ import argparse import bsparse.commands -import bsparse.datasets -import bsparse.models COMMANDS = { @@ -13,22 +11,34 @@ "memsearch": bsparse.commands.MemSearch, } -DEFAULT_MODELS = { - "splade": bsparse.models.SpladeModel, - "spladepsg": bsparse.models.SpladePsgModel, - "multilsr": bsparse.models.MultiLSRModel, -} DEFAULT_MODEL = "splade" - -DEFAULT_DATASETS = { - "irds": bsparse.datasets.IRDSDataset, - "jsonl": bsparse.datasets.JSONLDataset, - "tsv": bsparse.datasets.TSVDataset, - "hgf": bsparse.datasets.HgfDataset, -} DEFAULT_DATASET = "irds" +# bsparse.models and bsparse.datasets are imported lazily so that commands that +# don't need them (e.g. index/search on already-encoded jsonl) don't pull in +# big packages like torch, transformers, and datasets +def get_models(): + import bsparse.models + + return { + "splade": bsparse.models.SpladeModel, + "spladepsg": bsparse.models.SpladePsgModel, + "multilsr": bsparse.models.MultiLSRModel, + } + + +def get_datasets(): + import bsparse.datasets + + return { + "irds": bsparse.datasets.IRDSDataset, + "jsonl": bsparse.datasets.JSONLDataset, + "tsv": bsparse.datasets.TSVDataset, + "hgf": bsparse.datasets.HgfDataset, + } + + def get_command(): parser = argparse.ArgumentParser(description="bsparse CLI") parser.add_argument("command", choices=list(COMMANDS.keys()), help="Command") @@ -45,13 +55,16 @@ def main(): # parse the command name, so that we can see whether it needs --dataset and --model command_cls = COMMANDS[get_command()] + datasets = get_datasets() if command_cls.needs_dataset else None + models = get_models() if command_cls.needs_model else None + # parse the --dataset and --model, so we can add dataset-specific and model-specific args minimal_parser = argparse.ArgumentParser(description="bsparse CLI") minimal_parser.add_argument("command", choices=list(COMMANDS.keys()), help="Command") if command_cls.needs_dataset: - minimal_parser.add_argument("--dataset", choices=list(DEFAULT_DATASETS.keys()), required=True, help="Dataset") + minimal_parser.add_argument("--dataset", choices=list(datasets.keys()), required=True, help="Dataset") if command_cls.needs_model: - minimal_parser.add_argument("--model", choices=list(DEFAULT_MODELS.keys()), required=True, help="Model") + minimal_parser.add_argument("--model", choices=list(models.keys()), required=True, help="Model") known_args, remaining_args = minimal_parser.parse_known_args() @@ -59,18 +72,18 @@ def main(): full_parser = argparse.ArgumentParser(parents=[minimal_parser], add_help=False) if command_cls.needs_dataset: - DEFAULT_DATASETS[known_args.dataset].add_arguments(full_parser) + datasets[known_args.dataset].add_arguments(full_parser) if command_cls.needs_model: - DEFAULT_MODELS[known_args.model].add_arguments(full_parser) + models[known_args.model].add_arguments(full_parser) command_cls.add_arguments(full_parser) args = full_parser.parse_args() kwargs = {} if command_cls.needs_dataset: - kwargs["dataset"] = DEFAULT_DATASETS[args.dataset](args) + kwargs["dataset"] = datasets[args.dataset](args) if command_cls.needs_model: - kwargs["model"] = DEFAULT_MODELS[args.model](args) + kwargs["model"] = models[args.model](args) command = command_cls(args, **kwargs) command.run() diff --git a/bsparse/commands.py b/bsparse/commands.py index 02ab128..7a1ae63 100644 --- a/bsparse/commands.py +++ b/bsparse/commands.py @@ -1,10 +1,10 @@ import json +import os from abc import ABC, abstractmethod from functools import partial from multiprocessing import Pool from pathlib import Path -import ir_datasets as irds from tqdm import tqdm from trecrun import TRECRun @@ -14,6 +14,50 @@ from bsparse.utils import psgid_to_docid +def load_qrels(qrels: str) -> dict[str, dict[str, int]]: + """Load relevance judgments into a {query_id: {doc_id: relevance}} dict. + + `qrels` may be a TREC-format qrels file or an ir_datasets name: an existing file is parsed as + 'qid iteration docid relevance' lines, and anything else is treated as an ir_datasets name. + """ + all_qrels = {} + + if os.path.isfile(qrels): + with open(qrels, "rt", encoding="utf-8") as f: + for line_no, line in enumerate(f, start=1): + if not line.strip(): + continue + try: + qid, _iteration, docid, relevance = line.split() + all_qrels.setdefault(qid, {})[docid] = int(relevance) + except ValueError as e: + raise ValueError( + f"invalid TREC qrels on line {line_no} of {qrels} (expected 'qid iteration docid relevance'): {line!r}" + ) from e + return all_qrels + + # ir_datasets is imported lazily to keep the no-qrels paths light + import ir_datasets as irds + + try: + dataset = irds.load(qrels) + except KeyError: + raise ValueError( + f"--qrels value is neither an existing file nor a known ir_datasets name: {qrels} " + "(pass a TREC-format qrels file path, or an ir_datasets name like 'beir/nfcorpus/test')" + ) from None + + if not dataset.has_qrels(): + raise ValueError( + f"ir_datasets dataset has no qrels: {qrels} " + "(pass a subset that includes relevance judgments, e.g. 'beir/nfcorpus/test' rather than 'beir/nfcorpus')" + ) + + for qr in dataset.qrels_iter(): + all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance + return all_qrels + + class Command(ABC): needs_model = False needs_dataset = False @@ -100,7 +144,12 @@ def add_arguments(cls, parser): parser.add_argument("--out", type=Path, required=True, help="Output file path") parser.add_argument("--pool", type=int, default=20, help="Multiprocessing pool size (default: %(default)s)") parser.add_argument("--topk", type=int, default=1000, help="Top K results to return (default: %(default)s)") - parser.add_argument("--qrels", type=str, default=None, help="Relevance judgments dataset (default: %(default)s)") + parser.add_argument( + "--qrels", + type=str, + default=None, + help="Relevance judgments: TREC qrels file or ir_datasets name (default: %(default)s)", + ) parser.add_argument( "--aggregate", type=str, @@ -116,6 +165,10 @@ def run(self): if self.cfg.out.is_dir(): raise ValueError(f"--out is a directory: {self.cfg.out}") + if self.cfg.qrels: + # load the qrels before the (slow) search so a bad --qrels fails fast + all_qrels = load_qrels(self.cfg.qrels) + with Pool(self.cfg.pool) as p: score_f = partial(score_shard, queries_fn=self.cfg.queries, topk=self.cfg.topk, aggregate=self.cfg.aggregate) shard_scores = list( @@ -142,10 +195,6 @@ def run(self): if self.cfg.qrels: print(f"evaluating with qrels: {self.cfg.qrels}") - all_qrels = {} - for qr in irds.load(self.cfg.qrels).qrels_iter(): - all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance - metrics = run.evaluate(all_qrels) avg = {metric: vals["mean"] for metric, vals in metrics.items()} print(json.dumps(avg, indent=4, sort_keys=True)) @@ -254,7 +303,12 @@ def add_arguments(cls, parser): help="Search backend (default: %(default)s)", ) parser.add_argument("--topk", type=int, default=1000, help="Top K results to return (default: %(default)s)") - parser.add_argument("--qrels", type=str, default=None, help="Relevance judgments dataset (default: %(default)s)") + parser.add_argument( + "--qrels", + type=str, + default=None, + help="Relevance judgments: TREC qrels file or ir_datasets name (default: %(default)s)", + ) # Backend-specific args default to None so we can tell whether the user set them: unset args fall # back to the backend's own default, and passing an arg for a different backend is rejected (see run()). # anserini-specific @@ -284,6 +338,10 @@ def run(self): if self.cfg.out.is_dir(): raise ValueError(f"--out is a directory: {self.cfg.out}") + if self.cfg.qrels: + # load the qrels before the (slow) search so a bad --qrels fails fast + all_qrels = load_qrels(self.cfg.qrels) + queries = load_dict(self.cfg.queries) queries.ids = [psgid_to_docid(qid) for qid in queries.ids] vectors = [{"vector": rep} for rep in queries.weights] @@ -309,10 +367,6 @@ def run(self): if self.cfg.qrels: print(f"evaluating with qrels: {self.cfg.qrels}") - all_qrels = {} - for qr in irds.load(self.cfg.qrels).qrels_iter(): - all_qrels.setdefault(qr.query_id, {})[qr.doc_id] = qr.relevance - metrics = run.evaluate(all_qrels) avg = {metric: vals["mean"] for metric, vals in metrics.items()} print(json.dumps(avg, indent=4, sort_keys=True)) diff --git a/bsparse/convert.py b/bsparse/convert.py index d7a02ad..643771d 100644 --- a/bsparse/convert.py +++ b/bsparse/convert.py @@ -1,8 +1,10 @@ -import torch - from bsparse.jsonl import SparseRepresentations, dict2jsonl, jsonl2dict +# torch is imported inside the functions that need it, so that torch-free paths +# (e.g. anserini/seismic index/search over already-encoded jsonl) work without it + + def load_dict(fn) -> SparseRepresentations: return jsonl2dict(fn) @@ -49,10 +51,14 @@ def dict2vec(ds, term2id): if isinstance(ds, dict): return _single_dict2vec(ds, term2id) + import torch + return torch.stack([_single_dict2vec(d, term2id) for d in ds]) def _single_dict2vec(d, term2id): + import torch + terms, weights = zip(*d.items()) termids = [term2id[term] for term in terms] vec = torch.zeros(len(term2id), dtype=torch.float32) @@ -61,6 +67,8 @@ def _single_dict2vec(d, term2id): def token_ids_to_binary_vec(input_ids, attention_mask, special_tokens_mask, vocab_size): + import torch + binary_ids = torch.ones_like(input_ids, dtype=torch.float) * attention_mask * (1 - special_tokens_mask) batch_size = binary_ids.shape[0] sparse_rep = torch.zeros((batch_size, vocab_size), device=binary_ids.device).scatter_reduce_( diff --git a/bsparse/seismic.py b/bsparse/seismic.py index 5e9be6c..79894c1 100644 --- a/bsparse/seismic.py +++ b/bsparse/seismic.py @@ -5,11 +5,15 @@ import tempfile import warnings from contextlib import contextmanager +from typing import TYPE_CHECKING import numpy as np from tqdm import tqdm -from bsparse.models import Model + +if TYPE_CHECKING: + # only used as a type hint; importing bsparse.models would pull in torch/transformers + from bsparse.models import Model THREADS = int(os.environ.get("SEISMIC_THREADS", os.cpu_count())) @@ -177,7 +181,7 @@ def index(self): def query_from_raw_text( self, queries: list[str], - model: Model, + model: "Model", k: int = 1000, query_cut: int = DEFAULT_QUERY_CUT, heap_factor: float = DEFAULT_HEAP_FACTOR, diff --git a/bsparse/utils.py b/bsparse/utils.py index 5041c5f..2a100df 100644 --- a/bsparse/utils.py +++ b/bsparse/utils.py @@ -1,10 +1,13 @@ import os from functools import cache -import torch from tqdm import tqdm +# torch is imported inside the functions that need it, so that torch-free paths +# (e.g. anserini/seismic index/search over already-encoded jsonl) work without it + + def psgid_to_docid(psgid): """Convert passage IDs (`doc123::psg2`) to their docID (`doc123`)""" return psgid.rsplit("::psg", 1)[0] @@ -12,6 +15,8 @@ def psgid_to_docid(psgid): @cache def get_torch_device(): + import torch + if torch.cuda.is_available(): if os.environ.get("ASSERT_GPU_SPECIFIED", "false").lower() == "true" and not os.environ.get("CUDA_VISIBLE_DEVICES", ""): raise OSError("ASSERT_GPU_SPECIFIED=true but CUDA_VISIBLE_DEVICES is empty") @@ -25,6 +30,8 @@ def get_torch_device(): def batch_encode(tokenized, model_encodef, device, batch_size: int = 128): + import torch + encoded = [] with torch.no_grad(): for i in tqdm( @@ -39,6 +46,8 @@ def batch_encode(tokenized, model_encodef, device, batch_size: int = 128): def batch_encode_untok(data, model_encodef, device, batch_size: int = 128): + import torch + encoded = [] with torch.no_grad(): for i in tqdm( @@ -53,6 +62,8 @@ def batch_encode_untok(data, model_encodef, device, batch_size: int = 128): def token_ids_to_binary_vec(input_ids, attention_mask, special_tokens_mask, vocab_size): + import torch + binary_ids = torch.ones_like(input_ids, dtype=torch.float) * attention_mask * (1 - special_tokens_mask) batch_size = binary_ids.shape[0] sparse_rep = torch.zeros((batch_size, vocab_size), device=binary_ids.device).scatter_reduce_( diff --git a/requirements.txt b/requirements.txt index 664cd5b..8ddaeb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,10 @@ +# full dev environment: the base requirements plus the 'encode' extra (see setup.py). +# the 'anserini' and 'seismic' extras are not needed for tests, which stub the native packages. ir_datasets numpy -torch tqdm -transformers trecrun~=0.3.0 datasets +safetensors +torch +transformers diff --git a/setup.py b/setup.py index c85e94a..57aae40 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,17 @@ long_description = fh.read() +# the base install covers the index/search paths, which operate on already-encoded +# jsonl files and don't need torch; encoding requires the 'encode' extra, and each +# search backend has its own extra ('anserini' also needs a JAR; see the README) +extras_require = { + "encode": ["datasets", "safetensors", "torch", "transformers"], + "anserini": ["pyjnius"], + "seismic": ["pyseismic-lsr"], +} +extras_require["all"] = sorted({dep for deps in extras_require.values() for dep in deps}) + + # from https://packaging.python.org/guides/single-sourcing-package-version/ def read(rel_path): here = os.path.abspath(os.path.dirname(__file__)) @@ -33,12 +44,13 @@ def get_version(rel_path): long_description_content_type="text/markdown", url="https://github.com/hltcoe/bsparse", packages=setuptools.find_packages(), - install_requires=["ir_datasets", "numpy", "torch", "tqdm", "transformers"], + install_requires=["ir_datasets", "numpy", "tqdm", "trecrun~=0.3.0"], + extras_require=extras_require, classifiers=[ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires=">=3.9", + python_requires=">=3.10", include_package_data=True, entry_points={ "console_scripts": [ diff --git a/tests/test_commands.py b/tests/test_commands.py index 792e3f5..524ffdf 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -4,9 +4,8 @@ def _make_search(**kwargs): - # importing bsparse.commands requires trecrun/ir_datasets, which are optional in the test env + # importing bsparse.commands requires trecrun, which is optional in the test env pytest.importorskip("trecrun") - pytest.importorskip("ir_datasets") from bsparse.commands import Search defaults = {"backend": "anserini", "scale": None, "query_cut": None, "heap_factor": None} @@ -30,3 +29,51 @@ def test_search_rejects_args_for_other_backend(): with pytest.raises(ValueError, match="anserini backend"): _make_search(backend="seismic", scale=80)._check_backend_args() + + +def test_load_qrels_from_trec_file(tmp_path): + pytest.importorskip("trecrun") + from bsparse.commands import load_qrels + + qrels_fn = tmp_path / "qrels.txt" + qrels_fn.write_text("q1 0 d1 1\nq1 0 d2 0\n\nq2 Q0 d1 2\n") + assert load_qrels(str(qrels_fn)) == {"q1": {"d1": 1, "d2": 0}, "q2": {"d1": 2}} + + +def test_load_qrels_rejects_invalid_file(tmp_path): + pytest.importorskip("trecrun") + from bsparse.commands import load_qrels + + qrels_fn = tmp_path / "qrels.txt" + qrels_fn.write_text("q1 0 d1 1\nq1 d2 0\n") + with pytest.raises(ValueError, match="line 2"): + load_qrels(str(qrels_fn)) + + qrels_fn.write_text("q1 0 d1 relevant\n") + with pytest.raises(ValueError, match="line 1"): + load_qrels(str(qrels_fn)) + + +def test_load_qrels_rejects_dataset_without_qrels(monkeypatch): + pytest.importorskip("trecrun") + import sys + import types + + from bsparse.commands import load_qrels + + # stub ir_datasets with a dataset that exists but has no relevance judgments + fake_irds = types.SimpleNamespace(load=lambda name: types.SimpleNamespace(has_qrels=lambda: False)) + monkeypatch.setitem(sys.modules, "ir_datasets", fake_irds) + + with pytest.raises(ValueError, match="has no qrels"): + load_qrels("beir/nfcorpus") + + +def test_load_qrels_rejects_unknown_name(tmp_path): + pytest.importorskip("trecrun") + pytest.importorskip("ir_datasets") + from bsparse.commands import load_qrels + + # not an existing file, and not a known ir_datasets name + with pytest.raises(ValueError, match="neither an existing file nor a known ir_datasets name"): + load_qrels(str(tmp_path / "no-such-qrels.txt")) diff --git a/tests/test_convert.py b/tests/test_convert.py index d480b44..68c75c6 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -1,7 +1,10 @@ import pytest -import torch -from bsparse.convert import ( + +# these tests exercise the torch-based conversion functions; torch is optional in the test env +torch = pytest.importorskip("torch") + +from bsparse.convert import ( # noqa: E402 dict2vec, vec2dict, )