Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306
Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306TOPAPEC wants to merge 18 commits into
Conversation
Standalone sequential recommender package, mimics ModelBase interface without touching existing rectools code. FlatSASRec - plain ID-embedding SASRec encoder. UniSRec - pretrained text embeddings + PCA/BN adaptor, 3-phase training (ID emb -> adaptor only -> full finetune). Uses lightweight rank_topk instead of TorchRanker, reuses SASRecDataPreparator for the data pipeline. 30 tests, smoke scripts for both models. Fix: NaN*0=NaN in IEEE 754 breaks attention padding masking via multiplication, switched to masked_fill.
New config options: - ffn_type: conv1d / linear_gelu / linear_relu + ffn_expansion - optimizer: adam / adamw - scheduler: cosine_warmup (with warmup_ratio, min_lr_ratio) - loss: softmax / BCE / gBCE / sampled_softmax (with gbce_t) - patience: early stopping via EarlyStopping callback + val split - data_preparator: accept custom preparator instance 31 tests passing.
2e923df to
d68834f
Compare
There was a problem hiding this comment.
Pull request overview
Adds a new rectools.fast_transformers subpackage providing GPU-native preprocessing and standalone sequential transformer recommenders (FlatSASRec + UniSRec), plus ranking utilities, scripts, and comprehensive tests.
Changes:
- Introduces torch-native sequence building (
build_sequences), embedding alignment, and lightweight dataset/dataloader helpers. - Adds UniSRec (pretrained text embeddings + adaptor + SASRec encoder) with Lightning training wrapper and a standalone
UniSRecModelAPI (fit/checkpoint/ONNX export). - Adds
rank_topk()for batched scoring with CSR filtering + whitelist, along with benchmark scripts and extensive test coverage.
Reviewed changes
Copilot reviewed 17 out of 19 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
| rectools/fast_transformers/init.py | Exposes the new fast_transformers public API surface. |
| rectools/fast_transformers/gpu_data.py | Implements torch-native preprocessing utilities (sequence building, embedding alignment, dataloader helpers). |
| rectools/fast_transformers/net.py | Adds FlatSASRec network implementation. |
| rectools/fast_transformers/ranking.py | Adds rank_topk() batching + filtering + whitelist ranking utility. |
| rectools/fast_transformers/unisrec_lightning.py | Adds LightningModule wrapper (loss/optimizer/scheduler dispatch) for UniSRec training phases. |
| rectools/fast_transformers/unisrec_model.py | Adds standalone UniSRecModel (3-phase training, checkpointing, ONNX export, ID mapping). |
| rectools/fast_transformers/unisrec_net.py | Adds UniSRec network (adaptor + transformer encoder + helper methods). |
| tests/fast_transformers/init.py | Test package marker for fast_transformers. |
| tests/fast_transformers/test_gpu_data.py | Tests for sequence building, embedding alignment, dataset/dataloader, and hashing. |
| tests/fast_transformers/test_net.py | Tests for FlatSASRec forward paths and encoding helpers. |
| tests/fast_transformers/test_onnx_export.py | Tests ONNX export/roundtrip for UniSRec network and UniSRecModel export. |
| tests/fast_transformers/test_ranking.py | Tests top-k ranking, filtering, whitelist behavior, and edge cases. |
| tests/fast_transformers/test_unisrec_lightning.py | Tests UniSRecLightning configuration + loss/scheduler dispatch behavior. |
| tests/fast_transformers/test_unisrec_model.py | Tests UniSRecModel fit phases, losses/optimizers/schedulers, checkpointing, and mapping. |
| tests/fast_transformers/test_unisrec_net.py | Tests UniSRec network output shapes, adaptor variants, and freeze/unfreeze helpers. |
| scripts/compare_sasrec_unisrec.py | Benchmark script to compare RecTools SASRec vs UniSRec-ID and generate a report. |
| scripts/comparison_report.md | Adds a sample benchmark report output. |
| CHANGELOG.md | Documents the new module and features under Unreleased. |
| .gitignore | Ignores new dev artifacts, model weights, and data folders. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
- Add hash-based ID mapping (splitmix64) as alternative to dense torch.unique mapping in build_sequences and align_embeddings. - Add UniSRecModel.export_to_onnx() for native ONNX export of encoder and item embeddings (project_all). - Add UniSRecModel.map_item_ids() for external→internal ID conversion at inference time (works for both dense and hash modes). - Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers that duplicated UniSRecModel functionality). - Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, map_item_ids for both modes.
- Remove ranking.py (duplicates TorchRanker) - Remove hash ID mapping from build_sequences/align_embeddings - Simplify UniSRecModel to single joint training phase (adaptor + transformer) - Rename gpu_data.py -> sequence_data.py, GPUBatchDataset -> SequenceBatchDataset - Vectorize map_item_ids with torch.searchsorted - Fix device default (None -> auto-detect from input tensor) - Fix double torch.unique call - Add empty dataset validation in fit() - Add **kwargs to make_dataloader - Add dataloader_num_workers passthrough - Move benchmark script to benchmark/ folder - Add KION training demo with Qwen3-Embedding-0.6B results - Update tests for simplified API - Clean up CHANGELOG and .gitignore
- Remove item_emb, use_id, freeze/unfreeze, phase references from net/lightning - Remove GPUBatchDataset alias and make_dataloader wrapper - Reorganize into preprocessing/ and unisrec/ subpackages - Add GPU-friendly HR@K, NDCG@K, MRR@K metrics (tested against RecTools) - Update benchmark, demo, and all tests (102 passed + 28 metric tests)
d68834f to
45ed8ae
Compare
- Add negative sampling transform in fit() for BCE/gBCE/sampled_softmax losses - Add e2e tests for all non-softmax losses via UniSRecModel.fit() - Fix load_checkpoint() default device: auto-detect cuda/cpu instead of hardcoded "cuda" - Fix map_item_ids() device mismatch when input is on CUDA - Fix Python 3.9 compat: replace PEP 604 unions with Optional[] in tests - Fix CHANGELOG: remove nonexistent FlatSASRecModel and make_dataloader() - Update benchmark: auto-download ML-20M, fallback random embeddings, fix paths
…ings, n_negatives validation - Run black/isort/flake8 on all fast_transformers files — all pass now - Fix val dataloader missing negatives when patience + non-softmax loss - Extract _NegativeSampler class: device-aware, resamples positive collisions - Validate n_negatives is a positive integer for non-softmax losses - Make align_embeddings() device-aware (supports CUDA pretrained embeddings) - Remove unused imports (os in benchmark, pytest in test_sequence_data) - Add CUDA guard in benchmark main() - Add e2e tests: non-softmax losses with patience, n_negatives=0/-1/None
Keep only device-awareness (the actual review request). Preserving pretrained.dtype could cause precision issues with float16 inputs.
- Add `device` parameter to UniSRecModel.__init__ (default None = input device) - Move x/y to CPU before DataLoader to avoid CUDA+multiprocessing issues - Benchmark: pass device="cuda" explicitly to build_sequences and UniSRecModel
…pylint, bandit) - Add type annotations across benchmark, tests, and source files (mypy 30→0 errors) - Annotate frozen_emb buffer and Optional head in net.py - Add assert guards for Optional item_id_mapping usage - Type sasrec_kwargs and nested functions in benchmark - Fix tensor index type in test_metrics
- isort: fix import ordering in __init__.py files and test_metrics.py - black: auto-format all new files to project style - flake8: add per-file-ignores in setup.cfg for new modules (D102, N806, N812, D401); fix D403 capitalization in test docstring - mypy: fix arg-type for align_embeddings (add assert for Optional), fix slice index type in test_unisrec_model - pylint: rename unused vars (B -> _B, unique_items -> _unique_items, y -> _y), move math import to top-level in metrics.py, add pylint-disable for too-many-* / protected-access / not-callable / redefined-outer-name, use dict literals instead of dict() - codespell: already clean - bandit: already clean
Add comprehensive tutorial covering model training, evaluation, inference, checkpointing, ONNX export, and comparison of different configurations (loss functions, adaptor types, optimizers).
| # Resample positions where negative == positive | ||
| collisions = negs == y.unsqueeze(-1) | ||
| if collisions.any(): | ||
| negs[collisions] = torch.randint(1, self.n_items + 1, (int(collisions.sum()),), device=y.device) |
| import numpy as np | ||
| import pandas as pd | ||
| import requests | ||
| import torch |
| e_b = net.encode_last(x_b) | ||
| torch.testing.assert_close(e_a, e_b) |
| k = topk_ids.shape[1] | ||
| hits = (topk_ids == targets.unsqueeze(1)).float() # (B, K) | ||
| ranks = torch.arange(1, k + 1, device=topk_ids.device, dtype=torch.float) | ||
| discounts = 1.0 / torch.log(ranks + 1) * (1.0 / _log(log_base)) |
|
|
||
| Parameters | ||
| ---------- | ||
| ffn_type : ``"conv1d"`` | ``"linear_gelu"`` | ``"linear_relu"`` |
There was a problem hiding this comment.
Could you please add all the parameters and the return
| return results | ||
|
|
||
|
|
||
| def _log(base: int) -> float: |
There was a problem hiding this comment.
it's used only once and it's a one-liner - are you sure you need it?
| ## [Unreleased] | ||
|
|
||
| ### Added | ||
| - `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M |
There was a problem hiding this comment.
Please add links to the PR (see examples below)
| ignore = D205,D400,D105,D100,E203,W503 | ||
| per-file-ignores = | ||
| tests/*: D100,D101,D102,D103,D104 | ||
| tests/*: D100,D101,D102,D103,D104,N806 |
There was a problem hiding this comment.
why do we need to ignore N806? Can't we don't violate it?
Is it violated in many places? (If not, better to have local ignores)
If absolutely necessary, please specify it only for the fast_transformers subfolder
| rectools/fast_transformers/net.py: D102,N806 | ||
| rectools/fast_transformers/unisrec/lightning.py: D102,D401,N812 | ||
| rectools/fast_transformers/unisrec/model.py: D102 | ||
| rectools/fast_transformers/unisrec/net.py: D102,N806 |
There was a problem hiding this comment.
- It's not the best way to specify rules per file - unless really needed
- Please don't ignore D102 and D401 - all public methods must have a docstring, and in the correct style
- I'd also really prefer not to violate N812
- N806 is not so critical
feldlime
left a comment
There was a problem hiding this comment.
A few more findings that haven't been raised yet — mostly around UniSRecModel.fit/save/load/predict/export lifecycle and the embedding-alignment helper. Nothing here blocks the architecture, but several would bite users in practice.
| if self.patience is not None: | ||
| val_y_last = y[:, -1:] | ||
| val_dl = DataLoader( | ||
| SequenceBatchDataset(x, val_y_last, transform=neg_transform), |
There was a problem hiding this comment.
Early stopping is measured on the training data: val_dl wraps the same x as train_dl (just with y reduced to the last position). So val_loss decreases monotonically as the model overfits, and patience either never triggers or triggers for unrelated reasons (noise / regularization stochasticity). Either accept an explicit validation split in fit() (e.g. val_user_ids/val_item_ids/val_timestamps, or hold out the last interaction per user), or document that patience is effectively a train-loss plateau heuristic, not real early stopping.
| def load_checkpoint(self, path: tp.Union[str, Path], device: tp.Optional[str] = None) -> None: | ||
| if device is None: | ||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
| ckpt = torch.load(path, map_location=device, weights_only=False) # nosec B614 |
There was a problem hiding this comment.
torch.load(path, weights_only=False) is an arbitrary-pickle deserialization — silencing bandit with # nosec B614 hides the risk, it doesn't address it. The checkpoint here only stores tensors and a Python int, so this should work with weights_only=True once unique_items/unique_users are saved as plain tensors (already the case). Please switch to weights_only=True (PyTorch is making this the default in 2.6+ anyway), and we avoid shipping a model that will RCE on a malicious checkpoint.
| "unique_items": self._unique_items, | ||
| "unique_users": self._unique_users, | ||
| "n_items": len(self._unique_items), | ||
| }, |
There was a problem hiding this comment.
save_checkpoint stores state_dict but no architecture hyperparameters (n_factors, n_blocks, n_heads, session_max_len, adaptor_type, use_adaptor_ffn, ffn_type, ffn_expansion, ...). load_checkpoint rebuilds UniSRec from self.* — i.e. whatever the user passed to the new UniSRecModel(...) constructor before calling load_checkpoint. If they don't pass the same arch, load_state_dict either fails loudly (shape mismatch) or, worse, silently loads partial weights. Save the relevant hparams in the checkpoint dict and restore them in load_checkpoint, or at minimum add a @classmethod from_checkpoint(cls, path) that constructs the model from saved hparams.
| item_embs = net.project_all() | ||
| scores = h @ item_embs.T | ||
| scores[:, 0] = float("-inf") | ||
| top_scores, top_ids = scores.topk(k, dim=1) |
There was a problem hiding this comment.
predict_topk returns internal item IDs (1..n_items), but the user gives external IDs everywhere else (in fit, in map_item_ids). The model already holds the mapping (self._unique_items), so making every caller post-process top_ids is friction and an easy source of subtle bugs (e.g. computing HR/NDCG against external targets while comparing to internal IDs). Either return external IDs by default and offer an internal=True flag, or change the return type to a small dataclass that documents the ID space.
| input_names=["input_ids"], | ||
| output_names=["hidden"], | ||
| opset_version=opset_version, | ||
| ) |
There was a problem hiding this comment.
The ONNX export is traced with dummy = torch.zeros(1, 5, ...) and no dynamic_axes, so the resulting graph only accepts inputs of shape (1, 5). In practice anyone exporting for serving wants dynamic batch and (often) dynamic sequence length up to session_max_len. Suggest passing dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "hidden": {0: "batch", 1: "seq"}} (and similar for the project_all wrapper — its output is (n_items+1, D), so dynamic n_items if you ever expect to re-export after refit). Also worth using session_max_len for the dummy length so positional embeddings are exercised at their full range during tracing.
| else: | ||
| aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2], device=device) | ||
|
|
||
| aligned[1:][valid] = pretrained[idx[valid]] |
There was a problem hiding this comment.
Two issues here:
-
valid = (idx >= 0) & (idx < pretrained.shape[0])silently zeros out unknown item embeddings — making them indistinguishable from the padding row (also zeros). Ifpretrainedwas built from a stale catalog, the model will silently treat missing items as padding rather than warning. Suggest raising (or at least logging) when(~valid).any(). -
torch.zeros(n_items + 1, ...)defaults tofloat32, so ifpretrainedisbfloat16/float16/float64the alignment silently changes dtype (and then everything downstream is fp32). Usedtype=pretrained.dtype, device=pretrained.device.
| scores[:, 0] = float("-inf") | ||
| top_scores, top_ids = scores.topk(k, dim=1) | ||
| if was_training: | ||
| net.train() |
There was a problem hiding this comment.
net.eval() ... if was_training: net.train() isn't exception-safe — if encode_last raises (CUDA OOM, shape mismatch, ...) the network is left in eval() mode for the rest of the session, which silently disables dropout on subsequent fit(). Same pattern in export_to_onnx. Wrap in try/finally or use a small context manager.
| from .metrics import compute_metrics, hitrate_at_k, mrr_at_k, ndcg_at_k | ||
| from .net import FlatSASRec, SASRecBlock | ||
| from .preprocessing import SequenceBatchDataset, align_embeddings, build_sequences | ||
| from .unisrec import UniSRec, UniSRecLightning, UniSRecModel |
There was a problem hiding this comment.
Top-level package exposes both UniSRec (the nn.Module) and UniSRecModel (the high-level trainer). Two classes whose only difference is a Model suffix is genuinely confusing on import (from rectools.fast_transformers import UniSRec vs UniSRecModel). Same goes for FlatSASRec (network) sitting next to UniSRecModel (trainer) in the same namespace. Suggest renaming the networks to UniSRecNet/FlatSASRecNet (matching lightning.py / model.py vocabulary), or only re-exporting the high-level *Model classes here and keeping nets reachable via rectools.fast_transformers.unisrec.net.
|
|
||
| def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: | ||
| all_emb = self._get_all_embs() | ||
| logits = hidden @ all_emb.T |
There was a problem hiding this comment.
_full_softmax_loss scores against project_all(), which uses variant 0 of frozen_emb (deterministic). The other losses score against _get_item_embs(...) = _adapt_score(_sample_frozen(...)), which randomly samples variants during training. For multi-variant pretrained embeddings this means the training signal is fundamentally different for softmax vs BCE/gBCE/sampled_softmax — softmax never sees variant augmentation. Either route the full-softmax path through a sampled-variant projection too, or document that variant augmentation only applies to negative-sampling losses.
Motivation
rectools.fast_transformerspackage for sequential recommendation models (UniSRec, FlatSASRec) that work directly with PyTorch tensors, without relying on theDataset/pandas pipeline. The main motivation is speeding up and simplifying extendability of the NN recommenders. Also aiming to simplify production use of rectools models by reducing boilerplates and making dataflow inside the model more straightforward.Description
rectools.fast_transformerspackage with:metrics— GPU-computed ranking metrics - much faster at scale compared to pandas ones.net— FlatSASRec implementation;preprocessing— vectorizedbuild_sequences,align_embeddings, andSequenceBatchDataset- also much faster at scale.unisrec— model network, Lightning module, high-level model API, ONNX export helpers, and demo docs.tests/fast_transformerscovering:.gitignore,CHANGELOG.md);benchmark/....Typing/Lint note
# type: ignore[import-untyped]forrequestsinbenchmark/compare_sasrec_unisrec.pysomypypasses in the current lint environment without changing dependency policy.Testing
tests/fast_transformersand are part of the standard test suite structure.