Skip to content

Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306

Open
TOPAPEC wants to merge 18 commits into
MTSWebServices:mainfrom
TOPAPEC:feat/unisrec-model
Open

Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306
TOPAPEC wants to merge 18 commits into
MTSWebServices:mainfrom
TOPAPEC:feat/unisrec-model

Conversation

@TOPAPEC
Copy link
Copy Markdown

@TOPAPEC TOPAPEC commented Apr 24, 2026

Motivation

  • Introduce a standalone rectools.fast_transformers package for sequential recommendation models (UniSRec, FlatSASRec) that work directly with PyTorch tensors, without relying on the Dataset/pandas pipeline. The main motivation is speeding up and simplifying extendability of the NN recommenders. Also aiming to simplify production use of rectools models by reducing boilerplates and making dataflow inside the model more straightforward.
  • Provide fast sequence preprocessing and embedding-alignment utilities needed for efficient training and benchmarking workflows.
  • Faster training on ml-20m fast_transformers/ vs rectools native pipeline - 900sec vs 1400sec - 10 epochs

Description

  • Added a new rectools.fast_transformers package with:
    • metrics — GPU-computed ranking metrics - much faster at scale compared to pandas ones.
    • net — FlatSASRec implementation;
    • preprocessing — vectorized build_sequences, align_embeddings, and SequenceBatchDataset - also much faster at scale.
    • unisrec — model network, Lightning module, high-level model API, ONNX export helpers, and demo docs.
  • Added unit tests under tests/fast_transformers covering:
    • metrics;
    • preprocessing and sequence building;
    • FlatSASRec and UniSRec networks;
    • Lightning wrapper behavior;
    • model fit/predict flows and ONNX export roundtrips.
  • Updated project/config artifacts to integrate the new package:
    • flake8 per-file ignores for the new modules;
    • repository metadata/docs artifacts (.gitignore, CHANGELOG.md);
    • benchmark artifacts under benchmark/....

Typing/Lint note

  • Added a targeted # type: ignore[import-untyped] for requests in benchmark/compare_sasrec_unisrec.py so mypy passes in the current lint environment without changing dependency policy.

Testing

  • New fast_transformers tests are included under tests/fast_transformers and are part of the standard test suite structure.

TOPAPEC and others added 4 commits April 22, 2026 18:28
Standalone sequential recommender package, mimics ModelBase interface
without touching existing rectools code.

FlatSASRec - plain ID-embedding SASRec encoder.
UniSRec - pretrained text embeddings + PCA/BN adaptor, 3-phase training
(ID emb -> adaptor only -> full finetune).

Uses lightweight rank_topk instead of TorchRanker, reuses
SASRecDataPreparator for the data pipeline.

30 tests, smoke scripts for both models.

Fix: NaN*0=NaN in IEEE 754 breaks attention padding masking via
multiplication, switched to masked_fill.
New config options:
- ffn_type: conv1d / linear_gelu / linear_relu + ffn_expansion
- optimizer: adam / adamw
- scheduler: cosine_warmup (with warmup_ratio, min_lr_ratio)
- loss: softmax / BCE / gBCE / sampled_softmax (with gbce_t)
- patience: early stopping via EarlyStopping callback + val split
- data_preparator: accept custom preparator instance

31 tests passing.
@TOPAPEC TOPAPEC changed the title Feat/unisrec model Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing Apr 24, 2026
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new rectools.fast_transformers subpackage providing GPU-native preprocessing and standalone sequential transformer recommenders (FlatSASRec + UniSRec), plus ranking utilities, scripts, and comprehensive tests.

Changes:

  • Introduces torch-native sequence building (build_sequences), embedding alignment, and lightweight dataset/dataloader helpers.
  • Adds UniSRec (pretrained text embeddings + adaptor + SASRec encoder) with Lightning training wrapper and a standalone UniSRecModel API (fit/checkpoint/ONNX export).
  • Adds rank_topk() for batched scoring with CSR filtering + whitelist, along with benchmark scripts and extensive test coverage.

Reviewed changes

Copilot reviewed 17 out of 19 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
rectools/fast_transformers/init.py Exposes the new fast_transformers public API surface.
rectools/fast_transformers/gpu_data.py Implements torch-native preprocessing utilities (sequence building, embedding alignment, dataloader helpers).
rectools/fast_transformers/net.py Adds FlatSASRec network implementation.
rectools/fast_transformers/ranking.py Adds rank_topk() batching + filtering + whitelist ranking utility.
rectools/fast_transformers/unisrec_lightning.py Adds LightningModule wrapper (loss/optimizer/scheduler dispatch) for UniSRec training phases.
rectools/fast_transformers/unisrec_model.py Adds standalone UniSRecModel (3-phase training, checkpointing, ONNX export, ID mapping).
rectools/fast_transformers/unisrec_net.py Adds UniSRec network (adaptor + transformer encoder + helper methods).
tests/fast_transformers/init.py Test package marker for fast_transformers.
tests/fast_transformers/test_gpu_data.py Tests for sequence building, embedding alignment, dataset/dataloader, and hashing.
tests/fast_transformers/test_net.py Tests for FlatSASRec forward paths and encoding helpers.
tests/fast_transformers/test_onnx_export.py Tests ONNX export/roundtrip for UniSRec network and UniSRecModel export.
tests/fast_transformers/test_ranking.py Tests top-k ranking, filtering, whitelist behavior, and edge cases.
tests/fast_transformers/test_unisrec_lightning.py Tests UniSRecLightning configuration + loss/scheduler dispatch behavior.
tests/fast_transformers/test_unisrec_model.py Tests UniSRecModel fit phases, losses/optimizers/schedulers, checkpointing, and mapping.
tests/fast_transformers/test_unisrec_net.py Tests UniSRec network output shapes, adaptor variants, and freeze/unfreeze helpers.
scripts/compare_sasrec_unisrec.py Benchmark script to compare RecTools SASRec vs UniSRec-ID and generate a report.
scripts/comparison_report.md Adds a sample benchmark report output.
CHANGELOG.md Documents the new module and features under Unreleased.
.gitignore Ignores new dev artifacts, model weights, and data folders.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread rectools/fast_transformers/preprocessing/sequence_data.py Outdated
Comment thread rectools/fast_transformers/gpu_data.py Outdated
Comment thread rectools/fast_transformers/unisrec/model.py
Comment thread rectools/fast_transformers/unisrec_model.py Outdated
Comment thread rectools/fast_transformers/ranking.py Outdated
Comment thread tests/fast_transformers/test_net.py Outdated
Comment thread tests/fast_transformers/test_unisrec_net.py
Comment thread rectools/fast_transformers/unisrec_model.py Outdated
Comment thread rectools/fast_transformers/unisrec/model.py
Comment thread rectools/fast_transformers/gpu_data.py Outdated
Comment thread rectools/fast_transformers/preprocessing/sequence_data.py
Comment thread .gitignore Outdated
Comment thread CHANGELOG.md Outdated
Comment thread benchmark/compare_sasrec_unisrec.py
Comment thread rectools/fast_transformers/gpu_data.py Outdated
Comment thread rectools/fast_transformers/gpu_data.py Outdated
Comment thread rectools/fast_transformers/ranking.py Outdated
TOPAPEC added 3 commits May 14, 2026 20:41
- Add hash-based ID mapping (splitmix64) as alternative to dense
  torch.unique mapping in build_sequences and align_embeddings.
- Add UniSRecModel.export_to_onnx() for native ONNX export of
  encoder and item embeddings (project_all).
- Add UniSRecModel.map_item_ids() for external→internal ID conversion
  at inference time (works for both dense and hash modes).
- Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers
  that duplicated UniSRecModel functionality).
- Add tests: hash mapping (including string-derived IDs),
  ONNX export roundtrip, map_item_ids for both modes.
- Remove ranking.py (duplicates TorchRanker)
- Remove hash ID mapping from build_sequences/align_embeddings
- Simplify UniSRecModel to single joint training phase (adaptor + transformer)
- Rename gpu_data.py -> sequence_data.py, GPUBatchDataset -> SequenceBatchDataset
- Vectorize map_item_ids with torch.searchsorted
- Fix device default (None -> auto-detect from input tensor)
- Fix double torch.unique call
- Add empty dataset validation in fit()
- Add **kwargs to make_dataloader
- Add dataloader_num_workers passthrough
- Move benchmark script to benchmark/ folder
- Add KION training demo with Qwen3-Embedding-0.6B results
- Update tests for simplified API
- Clean up CHANGELOG and .gitignore
- Remove item_emb, use_id, freeze/unfreeze, phase references from net/lightning
- Remove GPUBatchDataset alias and make_dataloader wrapper
- Reorganize into preprocessing/ and unisrec/ subpackages
- Add GPU-friendly HR@K, NDCG@K, MRR@K metrics (tested against RecTools)
- Update benchmark, demo, and all tests (102 passed + 28 metric tests)
@TOPAPEC TOPAPEC force-pushed the feat/unisrec-model branch from d68834f to 45ed8ae Compare May 14, 2026 21:13
TOPAPEC and others added 9 commits May 15, 2026 13:54
- Add negative sampling transform in fit() for BCE/gBCE/sampled_softmax losses
- Add e2e tests for all non-softmax losses via UniSRecModel.fit()
- Fix load_checkpoint() default device: auto-detect cuda/cpu instead of hardcoded "cuda"
- Fix map_item_ids() device mismatch when input is on CUDA
- Fix Python 3.9 compat: replace PEP 604 unions with Optional[] in tests
- Fix CHANGELOG: remove nonexistent FlatSASRecModel and make_dataloader()
- Update benchmark: auto-download ML-20M, fallback random embeddings, fix paths
…ings, n_negatives validation

- Run black/isort/flake8 on all fast_transformers files — all pass now
- Fix val dataloader missing negatives when patience + non-softmax loss
- Extract _NegativeSampler class: device-aware, resamples positive collisions
- Validate n_negatives is a positive integer for non-softmax losses
- Make align_embeddings() device-aware (supports CUDA pretrained embeddings)
- Remove unused imports (os in benchmark, pytest in test_sequence_data)
- Add CUDA guard in benchmark main()
- Add e2e tests: non-softmax losses with patience, n_negatives=0/-1/None
Keep only device-awareness (the actual review request). Preserving
pretrained.dtype could cause precision issues with float16 inputs.
- Add `device` parameter to UniSRecModel.__init__ (default None = input device)
- Move x/y to CPU before DataLoader to avoid CUDA+multiprocessing issues
- Benchmark: pass device="cuda" explicitly to build_sequences and UniSRecModel
…pylint, bandit)

- Add type annotations across benchmark, tests, and source files (mypy 30→0 errors)
- Annotate frozen_emb buffer and Optional head in net.py
- Add assert guards for Optional item_id_mapping usage
- Type sasrec_kwargs and nested functions in benchmark
- Fix tensor index type in test_metrics
- isort: fix import ordering in __init__.py files and test_metrics.py
- black: auto-format all new files to project style
- flake8: add per-file-ignores in setup.cfg for new modules (D102, N806,
  N812, D401); fix D403 capitalization in test docstring
- mypy: fix arg-type for align_embeddings (add assert for Optional),
  fix slice index type in test_unisrec_model
- pylint: rename unused vars (B -> _B, unique_items -> _unique_items,
  y -> _y), move math import to top-level in metrics.py, add
  pylint-disable for too-many-* / protected-access / not-callable /
  redefined-outer-name, use dict literals instead of dict()
- codespell: already clean
- bandit: already clean
Add comprehensive tutorial covering model training, evaluation,
inference, checkpointing, ONNX export, and comparison of different
configurations (loss functions, adaptor types, optimizers).
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 20 out of 23 changed files in this pull request and generated 5 comments.

Comment thread rectools/fast_transformers/net.py
Comment on lines +26 to +29
# Resample positions where negative == positive
collisions = negs == y.unsqueeze(-1)
if collisions.any():
negs[collisions] = torch.randint(1, self.n_items + 1, (int(collisions.sum()),), device=y.device)
Comment on lines +25 to +28
import numpy as np
import pandas as pd
import requests
import torch
Comment on lines +94 to +95
e_b = net.encode_last(x_b)
torch.testing.assert_close(e_a, e_b)
k = topk_ids.shape[1]
hits = (topk_ids == targets.unsqueeze(1)).float() # (B, K)
ranks = torch.arange(1, k + 1, device=topk_ids.device, dtype=torch.float)
discounts = 1.0 / torch.log(ranks + 1) * (1.0 / _log(log_base))

Parameters
----------
ffn_type : ``"conv1d"`` | ``"linear_gelu"`` | ``"linear_relu"``
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add all the parameters and the return

return results


def _log(base: int) -> float:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's used only once and it's a one-liner - are you sure you need it?

Comment thread CHANGELOG.md
## [Unreleased]

### Added
- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add links to the PR (see examples below)

Comment thread setup.cfg
ignore = D205,D400,D105,D100,E203,W503
per-file-ignores =
tests/*: D100,D101,D102,D103,D104
tests/*: D100,D101,D102,D103,D104,N806
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to ignore N806? Can't we don't violate it?

Is it violated in many places? (If not, better to have local ignores)

If absolutely necessary, please specify it only for the fast_transformers subfolder

Comment thread setup.cfg
Comment on lines +51 to +54
rectools/fast_transformers/net.py: D102,N806
rectools/fast_transformers/unisrec/lightning.py: D102,D401,N812
rectools/fast_transformers/unisrec/model.py: D102
rectools/fast_transformers/unisrec/net.py: D102,N806
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. It's not the best way to specify rules per file - unless really needed
  2. Please don't ignore D102 and D401 - all public methods must have a docstring, and in the correct style
  3. I'd also really prefer not to violate N812
  4. N806 is not so critical

Copy link
Copy Markdown
Collaborator

@feldlime feldlime left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more findings that haven't been raised yet — mostly around UniSRecModel.fit/save/load/predict/export lifecycle and the embedding-alignment helper. Nothing here blocks the architecture, but several would bite users in practice.

if self.patience is not None:
val_y_last = y[:, -1:]
val_dl = DataLoader(
SequenceBatchDataset(x, val_y_last, transform=neg_transform),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Early stopping is measured on the training data: val_dl wraps the same x as train_dl (just with y reduced to the last position). So val_loss decreases monotonically as the model overfits, and patience either never triggers or triggers for unrelated reasons (noise / regularization stochasticity). Either accept an explicit validation split in fit() (e.g. val_user_ids/val_item_ids/val_timestamps, or hold out the last interaction per user), or document that patience is effectively a train-loss plateau heuristic, not real early stopping.

def load_checkpoint(self, path: tp.Union[str, Path], device: tp.Optional[str] = None) -> None:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
ckpt = torch.load(path, map_location=device, weights_only=False) # nosec B614
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.load(path, weights_only=False) is an arbitrary-pickle deserialization — silencing bandit with # nosec B614 hides the risk, it doesn't address it. The checkpoint here only stores tensors and a Python int, so this should work with weights_only=True once unique_items/unique_users are saved as plain tensors (already the case). Please switch to weights_only=True (PyTorch is making this the default in 2.6+ anyway), and we avoid shipping a model that will RCE on a malicious checkpoint.

"unique_items": self._unique_items,
"unique_users": self._unique_users,
"n_items": len(self._unique_items),
},
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save_checkpoint stores state_dict but no architecture hyperparameters (n_factors, n_blocks, n_heads, session_max_len, adaptor_type, use_adaptor_ffn, ffn_type, ffn_expansion, ...). load_checkpoint rebuilds UniSRec from self.* — i.e. whatever the user passed to the new UniSRecModel(...) constructor before calling load_checkpoint. If they don't pass the same arch, load_state_dict either fails loudly (shape mismatch) or, worse, silently loads partial weights. Save the relevant hparams in the checkpoint dict and restore them in load_checkpoint, or at minimum add a @classmethod from_checkpoint(cls, path) that constructs the model from saved hparams.

item_embs = net.project_all()
scores = h @ item_embs.T
scores[:, 0] = float("-inf")
top_scores, top_ids = scores.topk(k, dim=1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict_topk returns internal item IDs (1..n_items), but the user gives external IDs everywhere else (in fit, in map_item_ids). The model already holds the mapping (self._unique_items), so making every caller post-process top_ids is friction and an easy source of subtle bugs (e.g. computing HR/NDCG against external targets while comparing to internal IDs). Either return external IDs by default and offer an internal=True flag, or change the return type to a small dataclass that documents the ID space.

input_names=["input_ids"],
output_names=["hidden"],
opset_version=opset_version,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ONNX export is traced with dummy = torch.zeros(1, 5, ...) and no dynamic_axes, so the resulting graph only accepts inputs of shape (1, 5). In practice anyone exporting for serving wants dynamic batch and (often) dynamic sequence length up to session_max_len. Suggest passing dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "hidden": {0: "batch", 1: "seq"}} (and similar for the project_all wrapper — its output is (n_items+1, D), so dynamic n_items if you ever expect to re-export after refit). Also worth using session_max_len for the dummy length so positional embeddings are exercised at their full range during tracing.

else:
aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2], device=device)

aligned[1:][valid] = pretrained[idx[valid]]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two issues here:

  1. valid = (idx >= 0) & (idx < pretrained.shape[0]) silently zeros out unknown item embeddings — making them indistinguishable from the padding row (also zeros). If pretrained was built from a stale catalog, the model will silently treat missing items as padding rather than warning. Suggest raising (or at least logging) when (~valid).any().

  2. torch.zeros(n_items + 1, ...) defaults to float32, so if pretrained is bfloat16/float16/float64 the alignment silently changes dtype (and then everything downstream is fp32). Use dtype=pretrained.dtype, device=pretrained.device.

scores[:, 0] = float("-inf")
top_scores, top_ids = scores.topk(k, dim=1)
if was_training:
net.train()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

net.eval() ... if was_training: net.train() isn't exception-safe — if encode_last raises (CUDA OOM, shape mismatch, ...) the network is left in eval() mode for the rest of the session, which silently disables dropout on subsequent fit(). Same pattern in export_to_onnx. Wrap in try/finally or use a small context manager.

from .metrics import compute_metrics, hitrate_at_k, mrr_at_k, ndcg_at_k
from .net import FlatSASRec, SASRecBlock
from .preprocessing import SequenceBatchDataset, align_embeddings, build_sequences
from .unisrec import UniSRec, UniSRecLightning, UniSRecModel
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Top-level package exposes both UniSRec (the nn.Module) and UniSRecModel (the high-level trainer). Two classes whose only difference is a Model suffix is genuinely confusing on import (from rectools.fast_transformers import UniSRec vs UniSRecModel). Same goes for FlatSASRec (network) sitting next to UniSRecModel (trainer) in the same namespace. Suggest renaming the networks to UniSRecNet/FlatSASRecNet (matching lightning.py / model.py vocabulary), or only re-exporting the high-level *Model classes here and keeping nets reachable via rectools.fast_transformers.unisrec.net.


def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
all_emb = self._get_all_embs()
logits = hidden @ all_emb.T
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_full_softmax_loss scores against project_all(), which uses variant 0 of frozen_emb (deterministic). The other losses score against _get_item_embs(...) = _adapt_score(_sample_frozen(...)), which randomly samples variants during training. For multi-variant pretrained embeddings this means the training signal is fundamentally different for softmax vs BCE/gBCE/sampled_softmax — softmax never sees variant augmentation. Either route the full-softmax path through a sampled-variant projection too, or document that variant augmentation only applies to negative-sampling losses.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants