From 2f9392695d60ae2674ae69ae3a3a0b0443c9bacd Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 16:03:29 +0200 Subject: [PATCH 1/5] table predictor integration --- .../using_doctr/custom_models_training.rst | 33 ++ docs/source/using_doctr/using_models.rst | 32 ++ doctr/datasets/table_structure.py | 2 +- doctr/io/elements.py | 146 ++++++++- doctr/models/builder.py | 282 +++++++++++++++--- doctr/models/modules/layers/pytorch.py | 10 +- doctr/models/predictor/pytorch.py | 91 +++++- doctr/models/utils/pytorch.py | 36 +-- doctr/models/zoo.py | 24 +- doctr/utils/visualization.py | 2 +- pyproject.toml | 1 + references/detection/train.py | 6 +- tests/common/test_io_elements.py | 92 ++++++ tests/common/test_models_builder.py | 110 ++++++- tests/pytorch/test_models_utils_pt.py | 2 +- tests/pytorch/test_models_zoo_pt.py | 54 +++- 16 files changed, 851 insertions(+), 72 deletions(-) diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index b1b95c6110..bcb0fc0556 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -186,3 +186,36 @@ Loading your custom trained orientation classification model # Overwrite the default orientation models predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model) predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model) + +Custom table structure recognition models +----------------------------------------- + +If you work with documents containing tables and make use of the table structure recognition feature by passing the following arguments: + +* `detect_tables=True` + +You can train your own table structure recognition model using the docTR library. For details on the training process and the necessary data and data format, refer to the following link: + +- `table structure recognition `_ + +Loading your custom trained table structure recognition model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code:: python3 + + import torch + from doctr.io import DocumentFile + from doctr.models import ocr_predictor, tablecenternet + from doctr.models.table_structure.zoo import table_predictor + + custom_table_structure_model = tablecenternet(pretrained=False) + custom_table_structure_model.from_pretrained('') + + predictor = ocr_predictor( + pretrained=True, + detect_layout=True, + detect_tables=True, + ) + + # Overwrite the default table structure model + predictor.table_predictor = table_predictor(custom_table_structure_model) diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index f1ec642553..a7ac2c3c3b 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -303,6 +303,38 @@ and columns. # out[0] -> {"cells": [{"geometry": ..., "score": ..., "row_start": 0, "row_end": 0, # "col_start": 0, "col_end": 0}, ...], "num_rows": ..., "num_cols": ...} +Tables in the OCR pipeline +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Passing ``detect_tables=True`` to :py:meth:`ocr_predictor ` runs table structure +recognition. It relies on the layout model: each region the layout model labels as +a table is cropped and passed to the table model, so a page yields one structured table per detected table region +(``detect_tables=True`` therefore also enables the layout model, whose regions are attached to the page). The words +whose center falls inside a detected cell are regrouped into a structured table, attached to the page as +``page.tables``, and **removed from the regular** ``blocks`` **output** so the same text is not returned twice. Each +table exposes a dense row/column grid that loads directly into pandas. + +.. code:: python3 + + from doctr.io import DocumentFile + from doctr.models import ocr_predictor + + model = ocr_predictor(pretrained=True, detect_tables=True) + doc = DocumentFile.from_images("invoice_with_table.png") + result = model(doc) + + page = result.pages[0] + # Structured tables (one or more per page), kept out of the regular text blocks + for i, table in enumerate(page.tables): + df = table.to_grid() # for a plain list of lists + print(f"Table {i} ({table.num_rows}x{table.num_cols}):") + print(df) + + # The remaining (non-table) text is still available as usual + print(page.render()) + +Tables are included in :meth:`Page.export` under the ``tables`` key, so they are preserved in the JSON export as well. + End-to-End OCR -------------- diff --git a/doctr/datasets/table_structure.py b/doctr/datasets/table_structure.py index aa3bfc38ab..a376742b33 100644 --- a/doctr/datasets/table_structure.py +++ b/doctr/datasets/table_structure.py @@ -48,7 +48,7 @@ class TableStructureDataset(AbstractDataset): img_folder: folder with all the dataset images label_path: path to the JSON labels use_polygons: whether to keep cell polygons instead of converting them to straight boxes - **kwargs: keyword arguments from `AbstractDataset` (e.g. ``img_transforms``, ``sample_transforms``) + **kwargs: keyword arguments from `AbstractDataset` (e.g. `img_transforms`, `sample_transforms`) """ def __init__( diff --git a/doctr/io/elements.py b/doctr/io/elements.py index 4e8829f22b..7d0f80b530 100644 --- a/doctr/io/elements.py +++ b/doctr/io/elements.py @@ -34,6 +34,8 @@ "KIEDocument", "Document", "LayoutElement", + "TableCell", + "Table", ] @@ -178,6 +180,136 @@ def from_dict(cls, save_dict: dict[str, Any], **kwargs): return cls(layout_type=kwargs["type"], confidence=kwargs["confidence"], geometry=kwargs["geometry"]) +class TableCell(Element): + """Implements a single cell of a recognized table + + Args: + value: the text content of the cell (words assigned to the cell, joined together) + confidence: the mean recognition confidence of the words assigned to the cell + geometry: bounding box of the cell in format ((xmin, ymin), (xmax, ymax)) or a (4, 2) polygon, + with coordinates relative to the page's size + row_start: index of the first row spanned by the cell (0-indexed) + row_end: index of the last row spanned by the cell (0-indexed, inclusive) + col_start: index of the first column spanned by the cell (0-indexed) + col_end: index of the last column spanned by the cell (0-indexed, inclusive) + """ + + _exported_keys: list[str] = [ + "geometry", + "value", + "confidence", + "row_start", + "row_end", + "col_start", + "col_end", + ] + _children_names: list[str] = [] + + def __init__( + self, + value: str, + confidence: float, + geometry: BoundingBox | np.ndarray, + row_start: int, + row_end: int, + col_start: int, + col_end: int, + ) -> None: + super().__init__() + self.value = value + self.confidence = confidence + self.geometry = geometry + self.row_start = row_start + self.row_end = row_end + self.col_start = col_start + self.col_end = col_end + + @property + def row_span(self) -> int: + """Number of rows spanned by the cell""" + return self.row_end - self.row_start + 1 + + @property + def col_span(self) -> int: + """Number of columns spanned by the cell""" + return self.col_end - self.col_start + 1 + + def render(self) -> str: + """Renders the cell text""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', rows=({self.row_start}, {self.row_end}), cols=({self.col_start}, {self.col_end})" + + @classmethod + def from_dict(cls, save_dict: dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Table(Element): + """Implements a table recognized on a page as a grid of cells + + The recognized text of the words falling inside the table is regrouped here and removed from the + regular `blocks` output of the page, so it is not duplicated. The structured content can be loaded + directly into pandas, e.g. `pd.DataFrame(table.to_grid())`. + + Args: + cells: list of `TableCell` objects composing the table + num_rows: number of rows of the table + num_cols: number of columns of the table + geometry: bounding box enclosing the whole table, with coordinates relative to the page's size + confidence: the confidence of the table structure prediction + """ + + _exported_keys: list[str] = ["geometry", "num_rows", "num_cols", "confidence"] + _children_names: list[str] = ["cells"] + cells: list[TableCell] = [] + + def __init__( + self, + cells: list[TableCell], + num_rows: int, + num_cols: int, + geometry: BoundingBox | np.ndarray, + confidence: float = 1.0, + ) -> None: + super().__init__(cells=cells) + self.num_rows = num_rows + self.num_cols = num_cols + self.geometry = geometry + self.confidence = confidence + + def to_grid(self) -> list[list[str]]: + """Return the table content as a dense `num_rows` x `num_cols` grid of strings. + + Cells spanning several rows/columns have their value placed at their top-left position; the + remaining positions they span are left empty. The result is directly loadable into pandas via + `pd.DataFrame(table.to_grid())`. + + Returns: + a list of `num_rows` lists, each of length `num_cols` + """ + grid = [["" for _ in range(self.num_cols)] for _ in range(self.num_rows)] + for cell in self.cells: + if 0 <= cell.row_start < self.num_rows and 0 <= cell.col_start < self.num_cols: + grid[cell.row_start][cell.col_start] = cell.value + return grid + + def render(self, row_break: str = "\n", col_break: str = "\t") -> str: + """Renders the table as plain text (tab-separated values)""" + return row_break.join(col_break.join(row) for row in self.to_grid()) + + def extra_repr(self) -> str: + return f"num_rows={self.num_rows}, num_cols={self.num_cols}, confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs["cells"] = [TableCell.from_dict(cell) for cell in save_dict["cells"]] + return cls(**kwargs) + + class Line(Element): """Implements a line element as a collection of words @@ -299,12 +431,14 @@ class Page(Element): orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction language: a dictionary with the language value and confidence of the prediction layout: optional list of layout regions detected on the page + tables: optional list of tables recognized on the page. Words assigned to a table are removed from `blocks`. """ _exported_keys: list[str] = ["page_idx", "dimensions", "orientation", "language"] - _children_names: list[str] = ["blocks", "layout"] + _children_names: list[str] = ["blocks", "layout", "tables"] blocks: list[Block] = [] layout: list[LayoutElement] = [] + tables: list[Table] = [] def __init__( self, @@ -315,8 +449,13 @@ def __init__( orientation: dict[str, Any] | None = None, language: dict[str, Any] | None = None, layout: list[LayoutElement] | None = None, + tables: list[Table] | None = None, ) -> None: - super().__init__(blocks=blocks, layout=layout if layout is not None else []) + super().__init__( + blocks=blocks, + layout=layout if layout is not None else [], + tables=tables if tables is not None else [], + ) self.page = page self.page_idx = page_idx self.dimensions = dimensions @@ -337,7 +476,7 @@ def show(self, interactive: bool = True, preserve_aspect_ratio: bool = False, ** interactive: whether the display should be interactive preserve_aspect_ratio: pass True if you passed True to the predictor **kwargs: additional keyword arguments passed to the matplotlib.pyplot.show method - (e.g. ``display_layout=False`` to hide detected layout regions) + (e.g. `display_layout=False` to hide detected layout regions) """ requires_package("matplotlib", "`.show()` requires matplotlib & mplcursors installed") requires_package("mplcursors", "`.show()` requires matplotlib & mplcursors installed") @@ -474,6 +613,7 @@ def from_dict(cls, save_dict: dict[str, Any], **kwargs): kwargs.update({ "blocks": [Block.from_dict(block_dict) for block_dict in save_dict["blocks"]], "layout": [LayoutElement.from_dict(region_dict) for region_dict in save_dict.get("layout", [])], + "tables": [Table.from_dict(table_dict) for table_dict in save_dict.get("tables", [])], }) return cls(**kwargs) diff --git a/doctr/models/builder.py b/doctr/models/builder.py index 685de2d215..7b40840ff8 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -9,7 +9,19 @@ import numpy as np from scipy.cluster.hierarchy import fclusterdata -from doctr.io.elements import Block, Document, KIEDocument, KIEPage, LayoutElement, Line, Page, Prediction, Word +from doctr.io.elements import ( + Block, + Document, + KIEDocument, + KIEPage, + LayoutElement, + Line, + Page, + Prediction, + Table, + TableCell, + Word, +) from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes from doctr.utils.repr import NestedObject @@ -213,14 +225,14 @@ def _resolve_blocks(boxes: np.ndarray, lines: list[list[int]]) -> list[list[list @staticmethod def _build_layout_elements(regions: dict[str, Any] | None) -> list[LayoutElement]: - """Convert a raw layout prediction into exportable ``LayoutElement`` objects. + """Convert a raw layout prediction into exportable `LayoutElement` objects. Args: - regions: a layout prediction ``{"boxes": (R, 4) | (R, 4, 2), "class_names": [...], "scores": [...]}`` - as returned by a ``LayoutPredictor``, or None. + regions: a layout prediction `{"boxes": (R, 4) | (R, 4, 2), "class_names": [...], "scores": [...]}` + as returned by a `LayoutPredictor`, or None. Returns: - list of ``LayoutElement`` (empty if no layout was provided). + list of `LayoutElement` (empty if no layout was provided). """ if regions is None or len(regions.get("boxes", [])) == 0: return [] @@ -238,6 +250,172 @@ def _build_layout_elements(regions: dict[str, Any] | None) -> list[LayoutElement elements.append(LayoutElement(layout_type=str(cname), confidence=float(score), geometry=geometry)) return elements + @staticmethod + def _word_centers(boxes: np.ndarray) -> np.ndarray: + """Return the (x, y) center of each word box. + + Args: + boxes: word boxes of shape (N, 4) (straight: x1, y1, x2, y2) or (N, 4, 2) (rotated polygon) + + Returns: + array of shape (N, 2) with the relative center coordinates of each box + """ + if boxes.ndim == 3: # rotated polygons (N, 4, 2) + return boxes.mean(axis=1) + return np.stack([(boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2], axis=1) + + @staticmethod + def _point_in_poly(point: np.ndarray, poly: np.ndarray) -> bool: + """Test whether a 2D point lies inside a polygon using the ray casting algorithm. + + Args: + point: array of shape (2,) with the (x, y) coordinates of the point + poly: array of shape (M, 2) with the polygon vertices + + Returns: + True if the point is inside the polygon + """ + x, y = float(point[0]), float(point[1]) + inside = False + n = len(poly) + j = n - 1 + for i in range(n): + xi, yi = float(poly[i][0]), float(poly[i][1]) + xj, yj = float(poly[j][0]), float(poly[j][1]) + if ((yi > y) != (yj > y)) and (x < (xj - xi) * (y - yi) / (yj - yi + 1e-12) + xi): + inside = not inside + j = i + return inside + + @staticmethod + def _localize_logic(cells: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], int, int]: + """Re-index a table's logical coordinates to a local 0-based grid. + + The table model returns logical (row/column) coordinates that may carry a constant offset; shifting them + so that the smallest row/column start is 0 makes the grid directly usable by :meth:`Table.to_grid`. + + Args: + cells: the cells of a single table + + Returns: + a tuple `(cells, num_rows, num_cols)` with the re-indexed cells and the table dimensions + """ + min_row = min(int(c["row_start"]) for c in cells) + min_col = min(int(c["col_start"]) for c in cells) + norm: list[dict[str, Any]] = [] + max_row = max_col = 0 + for c in cells: + nc = dict(c) + nc["row_start"] = int(c["row_start"]) - min_row + nc["row_end"] = int(c["row_end"]) - min_row + nc["col_start"] = int(c["col_start"]) - min_col + nc["col_end"] = int(c["col_end"]) - min_col + max_row, max_col = max(max_row, nc["row_end"]), max(max_col, nc["col_end"]) + norm.append(nc) + return norm, max_row + 1, max_col + 1 + + def _build_tables( + self, + boxes: np.ndarray, + word_preds: list[tuple[str, float]], + page_table: dict[str, Any] | list[dict[str, Any]] | None, + ) -> tuple[list[Table], np.ndarray]: + """Assign detected words to table cells and build the page tables. + + A page may contain several tables; each one is provided as its own grid (the OCR pipeline detects table + regions with the layout model, then runs the table model on every cropped region). Both a single grid and + a list of grids are accepted. Each word whose center falls inside a cell polygon is assigned to (at most) + one cell, across all tables, and flagged so it can be removed from the regular `blocks` output. Words + are joined per cell in reading order (top to bottom, then left to right). + + Args: + boxes: word boxes of the page, of shape (N, 4) or (N, 4, 2), in relative coordinates + word_preds: list of (text, confidence) for each of the N words + page_table: the table structure prediction(s) for the page. Either a single grid + `{"cells": [{"geometry", "score", "row_start", "row_end", "col_start", "col_end"}], "num_rows", + "num_cols"}` (cell geometries in page-relative coordinates), a list of such grids, or None + + Returns: + a tuple with the list of `Table` objects of the page (one per provided table) and a boolean mask of + shape (N,) that is True for words assigned to a table (to be removed from `blocks`) + """ + num_words = boxes.shape[0] + consumed = np.zeros(num_words, dtype=bool) + if page_table is None: + return [], consumed + + # Normalize the prediction(s) to a list of per-table grids with local 0-based logical coordinates + raw_tables = [page_table] if isinstance(page_table, dict) else list(page_table) + table_dicts: list[dict[str, Any]] = [] + for raw in raw_tables: + if not raw or len(raw.get("cells", [])) == 0: + continue + cells, n_rows, n_cols = self._localize_logic(raw["cells"]) + table_dicts.append({"cells": cells, "num_rows": n_rows, "num_cols": n_cols}) + if len(table_dicts) == 0: + return [], consumed + + centers = self._word_centers(boxes) if num_words > 0 else np.empty((0, 2)) + + tables_out: list[Table] = [] + for table_dict in table_dicts: + cells = table_dict["cells"] + cell_polys = [np.asarray(cell["geometry"], dtype=np.float32) for cell in cells] + + # Assign each (still unassigned) word to at most one cell of this table + cell_word_idcs: list[list[int]] = [[] for _ in cells] + for w_idx in range(num_words): + if consumed[w_idx]: + continue + for c_idx, poly in enumerate(cell_polys): + if self._point_in_poly(centers[w_idx], poly): + cell_word_idcs[c_idx].append(w_idx) + consumed[w_idx] = True + break + + # Build the cells + table_cells: list[TableCell] = [] + for cell, poly, w_idcs in zip(cells, cell_polys, cell_word_idcs): + if len(w_idcs) > 0: + # Reading order inside the cell: top to bottom, then left to right + ordered = sorted(w_idcs, key=lambda i: (round(float(centers[i][1]), 3), float(centers[i][0]))) + value = " ".join(word_preds[i][0] for i in ordered) + confidence = float(np.mean([word_preds[i][1] for i in ordered])) + else: + value, confidence = "", float(cell["score"]) + geometry = tuple(tuple(float(c) for c in pt) for pt in poly.tolist()) + table_cells.append( + TableCell( + value=value, + confidence=confidence, + geometry=geometry, # type: ignore[arg-type] + row_start=int(cell["row_start"]), + row_end=int(cell["row_end"]), + col_start=int(cell["col_start"]), + col_end=int(cell["col_end"]), + ) + ) + + # Enclosing geometry of the whole table (relative bbox) + all_pts = np.concatenate(cell_polys, axis=0) + table_geometry = ( + (float(all_pts[:, 0].min()), float(all_pts[:, 1].min())), + (float(all_pts[:, 0].max()), float(all_pts[:, 1].max())), + ) + table_confidence = float(np.mean([cell["score"] for cell in cells])) + + tables_out.append( + Table( + cells=table_cells, + num_rows=int(table_dict["num_rows"]), + num_cols=int(table_dict["num_cols"]), + geometry=table_geometry, + confidence=table_confidence, + ) + ) + + return tables_out, consumed + def _build_blocks( self, boxes: np.ndarray, @@ -320,6 +498,7 @@ def __call__( orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, regions: list[dict[str, Any] | None] | None = None, + tables: list[dict[str, Any] | None] | None = None, ) -> Document: """Re-arrange detected words into structured blocks @@ -337,7 +516,11 @@ def __call__( languages: optional, list of N elements, where each element is a dictionary containing the language (language + confidence) regions: optional, list of N elements, where each element is a layout prediction - ``{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}`` attached to each page + `{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}` attached to each page + tables: optional, list of N elements, where each element is the table structure prediction(s) of a + page: a single grid `{"cells": [...], "num_rows": int, "num_cols": int}` or a list of such grids + (one per table region detected by the layout model). Words assigned to any table are removed from + the `blocks` output of that page. Returns: document object @@ -350,40 +533,66 @@ def __call__( _orientations = orientations if isinstance(orientations, list) else [None] * len(boxes) _languages = languages if isinstance(languages, list) else [None] * len(boxes) _regions = regions if isinstance(regions, list) else [None] * len(boxes) + _tables = tables if isinstance(tables, list) else [None] * len(boxes) if self.export_as_straight_boxes and len(boxes) > 0: # If boxes are already straight OK, else fit a bounding rect if boxes[0].ndim == 3: # Iterate over pages and boxes boxes = [np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1) for p_boxes in boxes] - _pages = [ - Page( - page, - self._build_blocks( - page_boxes, - loc_scores, - word_preds, - word_crop_orientations, - ), - _idx, - shape, - orientation, - language, - self._build_layout_elements(page_regions), - ) - for page, _idx, shape, page_boxes, loc_scores, word_preds, word_crop_orientations, orientation, language, page_regions in zip( # noqa: E501 - pages, - range(len(boxes)), - page_shapes, - boxes, - objectness_scores, - text_preds, - crop_orientations, - _orientations, - _languages, - _regions, + _pages = [] + for ( + page, + _idx, + shape, + page_boxes, + loc_scores, + word_preds, + word_crop_orientations, + orientation, + language, + page_regions, + page_table, + ) in zip( # noqa: E501 + pages, + range(len(boxes)), + page_shapes, + boxes, + objectness_scores, + text_preds, + crop_orientations, + _orientations, + _languages, + _regions, + _tables, + ): + # Build the page tables and flag the words that belong to a table + page_tables, consumed = self._build_tables(page_boxes, word_preds, page_table) + if consumed.any(): + # Remove the words assigned to a table from the regular blocks output + keep = ~consumed + page_boxes = page_boxes[keep] + loc_scores = loc_scores[keep] + word_preds = [wp for wp, k in zip(word_preds, keep) if k] + word_crop_orientations = [co for co, k in zip(word_crop_orientations, keep) if k] + + _pages.append( + Page( + page, + self._build_blocks( + page_boxes, + loc_scores, + word_preds, + word_crop_orientations, + ), + _idx, + shape, + orientation, + language, + self._build_layout_elements(page_regions), + page_tables, + ) ) - ] return Document(_pages) @@ -410,6 +619,7 @@ def __call__( # type: ignore[override] orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, regions: list[dict[str, Any] | None] | None = None, + tables: list[list[dict[str, Any] | None] | None] | None = None, ) -> KIEDocument: """Re-arrange detected words into structured predictions @@ -427,7 +637,11 @@ def __call__( # type: ignore[override] languages: optional, list of N elements, where each element is a dictionary containing the language (language + confidence) regions: optional, list of N elements, where each element is a layout prediction - ``{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}`` attached to each page + `{"boxes": (R, 4|4x2), "class_names": [...], "scores": [...]}` attached to each page + tables: optional, list of N elements, where each element is the table structure prediction(s) of a + page: a single grid `{"cells": [...], "num_rows": int, "num_cols": int}` or a list of such grids + (one per table region detected by the layout model). Words assigned to any table are removed from + the `blocks` output of that page. Unused for KIE documents, as tables are not supported in KIE. Returns: document object diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py index 89d0fffa0c..07ddd0ea4a 100644 --- a/doctr/models/modules/layers/pytorch.py +++ b/doctr/models/modules/layers/pytorch.py @@ -93,11 +93,11 @@ def _deform_conv2d( ) -> torch.Tensor: """Modulated deformable convolution (DCNv2). - Numerically equivalent to ``torchvision.ops.deform_conv2d`` (same offset/mask channel layout and bilinear - convention) but built only from ``grid_sample`` + ``conv2d``, so the model is ONNX-exportable (the - ``torchvision::deform_conv2d`` operator has no ONNX symbolic). ``offset`` is laid out as torchvision - expects: for kernel position ``k = kh * Kw + kw``, ``offset[:, 2 * k]`` is the vertical offset and - ``offset[:, 2 * k + 1]`` the horizontal one; ``mask[:, k]`` is the modulation. + Numerically equivalent to `torchvision.ops.deform_conv2d` (same offset/mask channel layout and bilinear + convention) but built only from `grid_sample` + `conv2d`, so the model is ONNX-exportable (the + `torchvision::deform_conv2d` operator has no ONNX symbolic). `offset` is laid out as torchvision + expects: for kernel position `k = kh * Kw + kw`, `offset[:, 2 * k]` is the vertical offset and + `offset[:, 2 * k + 1]` the horizontal one; `mask[:, k]` is the modulation. Args: x: input feature map, shape (N, C, H, W) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index ceab3e8414..51696f7127 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -14,6 +14,7 @@ from doctr.models.detection.predictor import DetectionPredictor from doctr.models.layout.predictor import LayoutPredictor from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.table_structure.predictor import TablePredictor from doctr.utils.geometry import detach_scores from .base import _OCRPredictor @@ -37,6 +38,9 @@ class OCRPredictor(nn.Module, _OCRPredictor): detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. layout_predictor: optional layout detection module + table_predictor: optional table structure recognition module. Requires `layout_predictor`: table + regions are located by the layout model, cropped, and passed to this module. Words falling inside a + detected table are regrouped into a structured table and removed from the regular text output. **kwargs: keyword args of `DocumentBuilder` """ @@ -51,6 +55,7 @@ def __init__( detect_orientation: bool = False, detect_language: bool = False, layout_predictor: LayoutPredictor | None = None, + table_predictor: TablePredictor | None = None, **kwargs: Any, ) -> None: nn.Module.__init__(self) @@ -68,6 +73,14 @@ def __init__( self.detect_orientation = detect_orientation self.detect_language = detect_language self.layout_predictor = layout_predictor.eval() if layout_predictor is not None else None + self.table_predictor = table_predictor.eval() if table_predictor is not None else None + # Layout class label whose regions are cropped and passed to the table model + self.table_class_name = "Table" + if self.table_predictor is not None and self.layout_predictor is None: + raise ValueError( + "`table_predictor` requires a `layout_predictor`: tables are located with the layout model, " + "cropped, and then passed to the table model." + ) @torch.inference_mode() def forward( @@ -81,6 +94,10 @@ def forward( origin_page_shapes = [page.shape[:2] for page in pages] + if not self.straighten_pages: + # Detect layout regions on the pages + regions = self.layout_predictor(pages, **kwargs) if self.layout_predictor is not None else None + # Localize text elements loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs) @@ -103,6 +120,9 @@ def forward( # update page shapes after straightening origin_page_shapes = [page.shape[:2] for page in pages] + # Detect layout regions on the pages + regions = self.layout_predictor(pages, **kwargs) if self.layout_predictor is not None else None + # Forward again to get predictions on straight pages loc_preds = self.det_predictor(pages, **kwargs) @@ -146,8 +166,12 @@ def forward( else: languages_dict = None - # Detect layout regions on the (possibly straightened) pages - regions = self.layout_predictor(pages, **kwargs) if self.layout_predictor is not None else None + # Recognize table structure: locate tables with the layout model, crop each one and run the table model + tables = ( + self._tables_from_regions(pages, regions, **kwargs) + if self.table_predictor is not None and regions is not None + else None + ) out = self.doc_builder( pages, @@ -159,5 +183,68 @@ def forward( orientations, languages_dict, regions, + tables, ) return out + + def _tables_from_regions( + self, + pages: list[np.ndarray], + regions: list[dict[str, Any]], + **kwargs: Any, + ) -> list[list[dict[str, Any]]]: + """Crop the table regions found by the layout model and run the table model on each crop. + + The table model is applied per cropped region, so a page naturally yields one structured table per + detected `Table` region. Cell geometries are mapped back from crop-relative to page-relative coordinates. + + Args: + pages: the (possibly straightened) page images + regions: the per-page layout predictions `{"class_names", "boxes", "scores"}` + **kwargs: keyword arguments forwarded to the table predictor + + Returns: + a per-page list of table grids `{"cells": [...], "num_rows": int, "num_cols": int}` in page-relative + coordinates + """ + crops: list[np.ndarray] = [] + crop_meta: list[tuple[int, tuple[float, float, float, float]]] = [] + for p_idx, (page, region) in enumerate(zip(pages, regions)): + if region is None: + continue + h, w = page.shape[:2] + for cls_name, box in zip(region["class_names"], region["boxes"]): + if cls_name != self.table_class_name: + continue + pts = np.asarray(box, dtype=np.float32).reshape(-1, 2) + x0, y0 = float(pts[:, 0].min()), float(pts[:, 1].min()) + x1, y1 = float(pts[:, 0].max()), float(pts[:, 1].max()) + # Relative box -> pixel crop (axis-aligned, clamped to the page) + px0, py0 = max(0, int(round(x0 * w))), max(0, int(round(y0 * h))) + px1, py1 = min(w, int(round(x1 * w))), min(h, int(round(y1 * h))) + if px1 - px0 < 2 or py1 - py0 < 2: + continue + crops.append(page[py0:py1, px0:px1]) + crop_meta.append((p_idx, (x0, y0, x1, y1))) + + tables_per_page: list[list[dict[str, Any]]] = [[] for _ in pages] + if len(crops) == 0: + return tables_per_page + + grids = self.table_predictor(crops, **kwargs) # type: ignore[misc] + for (p_idx, (x0, y0, x1, y1)), grid in zip(crop_meta, grids): + region_w, region_h = (x1 - x0), (y1 - y0) + remapped_cells: list[dict[str, Any]] = [] + for cell in grid["cells"]: + poly = np.asarray(cell["geometry"], dtype=np.float32).reshape(-1, 2) + poly[:, 0] = x0 + poly[:, 0] * region_w + poly[:, 1] = y0 + poly[:, 1] * region_h + new_cell = dict(cell) + new_cell["geometry"] = poly.tolist() + remapped_cells.append(new_cell) + tables_per_page[p_idx].append({ + "cells": remapped_cells, + "num_rows": grid["num_rows"], + "num_cols": grid["num_cols"], + }) + return tables_per_page diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index c816f84297..9fce34f20c 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -179,10 +179,10 @@ def export_model_to_onnx( model: the PyTorch model to be exported model_name: the name for the exported model dummy_input: the dummy input to the model - input_names: optional names for the model inputs. Defaults to ``["input"]`` (or ``["input", "masks"]`` - when ``dummy_input`` is a tuple). - output_names: optional names for the model outputs. Defaults to ``["logits"]`` (or - ``["logits", "pred_boxes"]`` when ``dummy_input`` is a tuple). Pass the names of every output when + input_names: optional names for the model inputs. Defaults to `["input"]` (or `["input", "masks"]` + when `dummy_input` is a tuple). + output_names: optional names for the model outputs. Defaults to `["logits"]` (or + `["logits", "pred_boxes"]` when `dummy_input` is a tuple). Pass the names of every output when the model returns more than one tensor (e.g. a multi-head model). dynamic_axes: optional dynamic axes. Defaults to a dynamic batch dimension on every input and output. kwargs: additional arguments to be passed to torch.onnx.export @@ -285,8 +285,8 @@ def _vocab_projections(model: nn.Module, vocab_size: int) -> list[nn.Linear]: def _anyascii_nearest_map(vocab: str, allowed: set[str]) -> dict[str, str]: """Map each forbidden character to the visually closest allowed one via transliteration. - Uses ``anyascii`` to fold characters to their ASCII form (e.g. ``ä -> a``, ``ł -> l``, - Cyrillic ``а -> a``); a forbidden character is mapped to an allowed character sharing the + Uses `anyascii` to fold characters to their ASCII form (e.g. `ä -> a`, `ł -> l`, + Cyrillic `а -> a`); a forbidden character is mapped to an allowed character sharing the same ASCII form. Forbidden characters without such a match are left unmapped (they fall back to plain masking). """ @@ -361,19 +361,19 @@ def add_whitelist( """Restrict a recognition model so it can only predict a subset of its vocabulary. The whitelist is enforced at the model's final projection layer, before the decoding - ``argmax``. Because the projection is the single point every logit flows through, the + `argmax`. Because the projection is the single point every logit flows through, the constraint also applies inside the autoregressive decoding loop of SAR, MASTER and PARSeq, so a forbidden character can never be produced -- not even fed back mid-word. The sequence - terminator (CTC ``blank`` / attention ````) is always kept so decoding still + terminator (CTC `blank` / attention ``) is always kept so decoding still terminates. It works with every recognition architecture and with any predictor wrapping one (`ocr_predictor`, `kie_predictor`, `recognition_predictor`). Two strategies are available: - * ``"mask"`` (default): the logits of forbidden characters are set to ``-inf``, so decoding + * `"mask"` (default): the logits of forbidden characters are set to `-inf`, so decoding falls back to the highest-scoring allowed character. - * ``"nearest"``: the score of each forbidden character is first reassigned to the closest - allowed character (so e.g. ``ä`` folds onto ``a``), then forbidden logits are masked. + * `"nearest"`: the score of each forbidden character is first reassigned to the closest + allowed character (so e.g. `ä` folds onto `a`), then forbidden logits are masked. Forbidden characters without a mapping fall back to masking. A whitelist can only restrict a model to characters it already knows: characters that are @@ -389,14 +389,14 @@ def add_whitelist( Args: model: an `ocr_predictor`, `kie_predictor`, `recognition_predictor`, or a recognition model. - vocabs: a vocabulary string (e.g. ``VOCABS["german"]``) or an iterable of vocabulary - strings (e.g. ``[VOCABS["polish"], VOCABS["german"]]``) whose characters are allowed. - strategy: ``"mask"`` (default) to drop forbidden characters, or ``"nearest"`` to fold + vocabs: a vocabulary string (e.g. `VOCABS["german"]`) or an iterable of vocabulary + strings (e.g. `[VOCABS["polish"], VOCABS["german"]]`) whose characters are allowed. + strategy: `"mask"` (default) to drop forbidden characters, or `"nearest"` to fold them onto the closest allowed character. - mapping: only used when ``strategy="nearest"``. ``None`` or ``"anyascii"`` builds the - forbidden-to-allowed map by transliteration (the default); ``"weights"`` derives it - from the projection weights (the model's own confusions); a ``dict`` of - ``{forbidden_char: allowed_char}`` overrides specific characters on top of the + mapping: only used when `"strategy="nearest"`". `None` or `"anyascii"` builds the + forbidden-to-allowed map by transliteration (the default); `"weights"` derives it + from the projection weights (the model's own confusions); a `dict` of + `{forbidden_char: allowed_char}` overrides specific characters on top of the transliteration map. verbose: if True, log how many characters were kept, forbidden and reassigned per model. diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index d1a83b6207..aebc09a859 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -10,6 +10,7 @@ from .layout.zoo import layout_predictor from .predictor import OCRPredictor from .recognition.zoo import recognition_predictor +from .table_structure.zoo import table_predictor __all__ = ["ocr_predictor", "kie_predictor"] @@ -29,6 +30,7 @@ def _predictor( detect_language: bool = False, detect_layout: bool = False, layout_arch: Any = "lw_detr_s", + detect_tables: bool = False, **kwargs, ) -> OCRPredictor: # Detection @@ -50,7 +52,7 @@ def _predictor( batch_size=reco_bs, ) - # Layout - optional + # Layout - required for table detection, so build it whenever layout or tables are requested layout_pred = ( layout_predictor( layout_arch, @@ -60,7 +62,18 @@ def _predictor( symmetric_pad=symmetric_pad, batch_size=det_bs, ) - if detect_layout + if (detect_layout or detect_tables) + else None + ) + + # Table structure - optional, applied on the cropped table regions found by the layout model + table_pred = ( + table_predictor( + "tablecenternet", + pretrained=pretrained, + batch_size=det_bs, + ) + if detect_tables else None ) @@ -74,6 +87,7 @@ def _predictor( straighten_pages=straighten_pages, detect_language=detect_language, layout_predictor=layout_pred, + table_predictor=table_pred, **kwargs, ) @@ -92,6 +106,7 @@ def ocr_predictor( detect_language: bool = False, detect_layout: bool = False, layout_arch: Any = "lw_detr_s", + detect_tables: bool = False, **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. @@ -128,6 +143,10 @@ def ocr_predictor( to each page. Doing so will slightly deteriorate the overall latency. layout_arch: name of the layout architecture or the model itself to use. + detect_tables: if True, table regions found by the layout model are cropped and passed to a table + structure model. Words falling inside a detected table are regrouped into a structured table + (accessible via `page.tables`) and removed from the regular text output. This enables the layout + model and slightly deteriorates the overall latency. kwargs: keyword args of `OCRPredictor` Returns: @@ -147,6 +166,7 @@ def ocr_predictor( detect_language=detect_language, detect_layout=detect_layout, layout_arch=layout_arch, + detect_tables=detect_tables, **kwargs, ) diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index 27a5a15ded..9940cb9bd8 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -403,7 +403,7 @@ def draw_boxes(boxes: np.ndarray, image: np.ndarray, color: tuple[int, int, int] """Draw an array of relative straight boxes on an image. Args: - boxes: array of relative boxes, of shape ``(*, 4)`` + boxes: array of relative boxes, of shape `(*, 4)` image: np array, float32 or uint8 color: color to use for bounding box edges **kwargs: keyword arguments from `matplotlib.pyplot.plot` diff --git a/pyproject.toml b/pyproject.toml index 14c2452ad3..b0b97fcc2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ dev = [ "onnxruntime>=1.11.0", "requests>=2.20.0", "psutil>=5.9.5", + "pandas>=3.0.0", # Quality "ruff>=0.3.0", "mypy>=1.0", diff --git a/references/detection/train.py b/references/detection/train.py index c91681ba3a..503f9c9bde 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -39,8 +39,8 @@ def convert_to_multiclass_targets(targets: list) -> list[dict[str, np.ndarray]]: """Convert detection targets to the multi-class format expected by the models. - Built-in datasets loaded with ``detection_task=True`` yield the boxes of each sample as a - plain ``np.ndarray``, whereas the models expect a mapping from class name to boxes. Targets + Built-in datasets loaded with `detection_task=True` yield the boxes of each sample as a + plain `np.ndarray`, whereas the models expect a mapping from class name to boxes. Targets coming from a :class:`~doctr.datasets.DetectionDataset` are already dictionaries and are returned unchanged. @@ -48,7 +48,7 @@ def convert_to_multiclass_targets(targets: list) -> list[dict[str, np.ndarray]]: targets: the batch of targets to normalize Returns: - the batch of targets as a list of ``{class_name: boxes}`` dictionaries + the batch of targets as a list of `{class_name: boxes}` dictionaries """ return [target if isinstance(target, dict) else {CLASS_NAME: target} for target in targets] diff --git a/tests/common/test_io_elements.py b/tests/common/test_io_elements.py index 7da0f4a004..dd1b7feb46 100644 --- a/tests/common/test_io_elements.py +++ b/tests/common/test_io_elements.py @@ -1,6 +1,7 @@ from xml.etree.ElementTree import ElementTree import numpy as np +import pandas as pd import pytest from doctr.file_utils import CLASS_NAME @@ -271,6 +272,96 @@ def test_layout_element(): assert region.export() == state_dict +def test_table_cell(): + geom = ((0.1, 0.1), (0.3, 0.2)) + cell = elements.TableCell( + value="hello", confidence=0.9, geometry=geom, row_start=0, row_end=1, col_start=2, col_end=2 + ) + + # Attribute checks + assert cell.value == "hello" + assert cell.confidence == 0.9 + assert cell.geometry == geom + assert (cell.row_start, cell.row_end, cell.col_start, cell.col_end) == (0, 1, 2, 2) + assert cell.row_span == 2 and cell.col_span == 1 + + # Render + assert cell.render() == "hello" + + # Export + assert cell.export() == { + "geometry": geom, + "value": "hello", + "confidence": 0.9, + "row_start": 0, + "row_end": 1, + "col_start": 2, + "col_end": 2, + } + + # Class method + cell2 = elements.TableCell.from_dict(cell.export()) + assert cell2.export() == cell.export() + + +def _mock_table(): + # 2 x 2 table + cells = [ + elements.TableCell("Name", 0.9, ((0.1, 0.1), (0.3, 0.2)), 0, 0, 0, 0), + elements.TableCell("Age", 0.9, ((0.3, 0.1), (0.5, 0.2)), 0, 0, 1, 1), + elements.TableCell("Alice", 0.9, ((0.1, 0.2), (0.3, 0.3)), 1, 1, 0, 0), + elements.TableCell("30", 0.9, ((0.3, 0.2), (0.5, 0.3)), 1, 1, 1, 1), + ] + return elements.Table(cells=cells, num_rows=2, num_cols=2, geometry=((0.1, 0.1), (0.5, 0.3)), confidence=0.9) + + +def test_table(): + table = _mock_table() + + # Attribute checks + assert table.num_rows == 2 and table.num_cols == 2 + assert len(table.cells) == 4 + assert all(isinstance(c, elements.TableCell) for c in table.cells) + + # Grid + render + assert table.to_grid() == [["Name", "Age"], ["Alice", "30"]] + assert table.render() == "Name\tAge\nAlice\t30" + + # Pandas + df = pd.DataFrame(table.to_grid()) + assert df.shape == (2, 2) + assert df.values.tolist() == [["Name", "Age"], ["Alice", "30"]] + # With a header row + table_grid = table.to_grid() + df_h = pd.DataFrame(table_grid[1:], columns=table_grid[0]) + assert list(df_h.columns) == ["Name", "Age"] + assert df_h.values.tolist() == [["Alice", "30"]] + + # Spanning cell: value placed at top-left of its span, the rest left empty + spanned = elements.Table( + cells=[ + elements.TableCell("merged", 0.9, ((0.0, 0.0), (1.0, 0.5)), 0, 0, 0, 1), + elements.TableCell("a", 0.9, ((0.0, 0.5), (0.5, 1.0)), 1, 1, 0, 0), + elements.TableCell("b", 0.9, ((0.5, 0.5), (1.0, 1.0)), 1, 1, 1, 1), + ], + num_rows=2, + num_cols=2, + geometry=((0.0, 0.0), (1.0, 1.0)), + ) + assert spanned.to_grid() == [["merged", ""], ["a", "b"]] + + # Export + exported = table.export() + assert set(exported.keys()) == {"geometry", "num_rows", "num_cols", "confidence", "cells"} + assert exported["cells"] == [c.export() for c in table.cells] + + # Class method round-trip + assert elements.Table.from_dict(table.export()).export() == table.export() + + # Repr + assert table.__repr__().startswith("Table(") + + def test_prediction(): prediction_str = "hello" conf = 0.8 @@ -372,6 +463,7 @@ def test_page(): "orientation": orientation, "language": language, "layout": [r.export() for r in layout], + "tables": [], } # Export XML diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py index ad695eb368..5e2587302e 100644 --- a/tests/common/test_models_builder.py +++ b/tests/common/test_models_builder.py @@ -3,7 +3,7 @@ from doctr.file_utils import CLASS_NAME from doctr.io import Document -from doctr.io.elements import KIEDocument, LayoutElement +from doctr.io.elements import KIEDocument, LayoutElement, Table from doctr.models import builder words_per_page = 10 @@ -248,6 +248,114 @@ def test_documentbuilder_layout(): assert isinstance(region.geometry, tuple) and len(region.geometry) == 4 +def _table_cell(x0, y0, x1, y1, rs, re, cs, ce, score=0.9): + return { + "geometry": [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], + "score": score, + "row_start": rs, + "row_end": re, + "col_start": cs, + "col_end": ce, + } + + +def test_documentbuilder_tables(): + doc_builder = builder.DocumentBuilder(resolve_lines=True) + + # 4 words inside a top table, 2 inside a bottom table, 1 caption outside both + def wbox(cx, cy, w=0.04, h=0.02): + return [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2] + + words = [ + ("Name", 0.17, 0.15), + ("Age", 0.36, 0.15), + ("Alice", 0.17, 0.27), + ("30", 0.36, 0.27), + ("City", 0.17, 0.63), + ("Pop", 0.36, 0.63), + ("caption", 0.30, 0.92), + ] + boxes = np.array([wbox(cx, cy) for _, cx, cy in words], dtype=np.float32) + text_preds = [[(w, 0.95) for w, _, _ in words]] + objectness_scores = np.full(len(words), 0.9, dtype=np.float32) + orientations = [[{"value": 0, "confidence": None}] * len(words)] + + # The OCR pipeline passes a list of grids (one per cropped table region), in page-relative coordinates. + # The bottom table uses offset (1-based) logical coordinates to exercise local re-indexing. + table_top = { + "cells": [ + _table_cell(0.10, 0.10, 0.25, 0.20, 0, 0, 0, 0), + _table_cell(0.28, 0.10, 0.45, 0.20, 0, 0, 1, 1), + _table_cell(0.10, 0.22, 0.25, 0.32, 1, 1, 0, 0), + _table_cell(0.28, 0.22, 0.45, 0.32, 1, 1, 1, 1), + ], + "num_rows": 2, + "num_cols": 2, + } + table_bottom = { + "cells": [ + _table_cell(0.10, 0.58, 0.25, 0.68, 1, 1, 1, 1), + _table_cell(0.28, 0.58, 0.45, 0.68, 1, 1, 2, 2), + ], + "num_rows": 99, # deliberately wrong dims -> recomputed from local coordinates + "num_cols": 99, + } + + out = doc_builder( + [np.zeros((100, 100, 3))], + [boxes], + [objectness_scores], + text_preds, + [(100, 100)], + orientations, + tables=[[table_top, table_bottom]], + ) + page = out.pages[0] + + # One Table per provided grid + assert len(page.tables) == 2 + assert all(isinstance(t, Table) for t in page.tables) + assert page.tables[0].to_grid() == [["Name", "Age"], ["Alice", "30"]] + # bottom table re-indexed from offset coordinates to a local 0-based 1 x 2 grid + assert (page.tables[1].num_rows, page.tables[1].num_cols) == (1, 2) + assert page.tables[1].to_grid() == [["City", "Pop"]] + + # Words assigned to a table are removed from the blocks; the caption remains + remaining = [w.value for b in page.blocks for line in b.lines for w in line.words] + assert remaining == ["caption"] + + # Tables are part of the page export + exported = page.export() + assert len(exported["tables"]) == 2 + assert page.tables[0].to_grid() == [["Name", "Age"], ["Alice", "30"]] + + # A single grid (dict) is also accepted -> one table + out_single = doc_builder( + [np.zeros((100, 100, 3))], + [boxes[:4]], + [objectness_scores[:4]], + [text_preds[0][:4]], + [(100, 100)], + [orientations[0][:4]], + tables=[table_top], + ) + assert len(out_single.pages[0].tables) == 1 + + # No tables -> empty page.tables and every word is kept in the blocks + out_none = doc_builder( + [np.zeros((100, 100, 3))], + [boxes], + [objectness_scores], + text_preds, + [(100, 100)], + orientations, + ) + assert out_none.pages[0].tables == [] + assert out_none.pages[0].export()["tables"] == [] + kept = sorted(w.value for b in out_none.pages[0].blocks for line in b.lines for w in line.words) + assert kept == sorted(w for w, _, _ in words) + + def test_kiedocumentbuilder_layout(): from doctr.io.elements import LayoutElement diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py index c5db89b6a9..e06709aa9a 100644 --- a/tests/pytorch/test_models_utils_pt.py +++ b/tests/pytorch/test_models_utils_pt.py @@ -274,7 +274,7 @@ def __init__(self): def _force_and_decode(model, target_char, **whitelist_kwargs): - """Bias the model to prefer ``target_char``, apply the whitelist, return the decoded word.""" + """Bias the model to prefer `target_char`, apply the whitelist, return the decoded word.""" from doctr.models.utils.pytorch import _vocab_projections, add_whitelist forbidden_idx = model.vocab.index(target_char) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 9b91f2196a..6590d17fe0 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -6,7 +6,7 @@ from doctr import models from doctr.file_utils import CLASS_NAME from doctr.io import Document, DocumentFile -from doctr.io.elements import KIEDocument, LayoutElement +from doctr.io.elements import KIEDocument, LayoutElement, Table from doctr.models import detection, layout, recognition from doctr.models.classification import mobilenet_v3_small_crop_orientation, mobilenet_v3_small_page_orientation from doctr.models.classification.zoo import crop_orientation_predictor, page_orientation_predictor @@ -19,6 +19,8 @@ from doctr.models.preprocessor import PreProcessor from doctr.models.recognition.predictor import RecognitionPredictor from doctr.models.recognition.zoo import recognition_predictor +from doctr.models.table_structure.predictor import TablePredictor +from doctr.models.table_structure.zoo import table_predictor # Create a dummy callback @@ -214,6 +216,56 @@ def test_ocrpredictor_layout(mock_pdf, mock_vocab, mock_payslip): assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." +def test_ocrpredictor_tables(mock_pdf, mock_vocab): + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=2), + detection.db_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, assume_straight_pages=True), + ) + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=32, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab), + ) + layout_pred = layout_predictor("lw_detr_s", pretrained=False) + table_pred = table_predictor("tablecenternet", pretrained=False) + + # A table predictor requires a layout predictor (tables are located with the layout model) + with pytest.raises(ValueError): + OCRPredictor(det_predictor, reco_predictor, table_predictor=table_pred) + + doc = DocumentFile.from_pdf(mock_pdf) + + # Without a table predictor -> pages carry an empty list of tables + predictor = OCRPredictor(det_predictor, reco_predictor) + assert predictor.table_predictor is None + out = predictor(doc) + assert all(page.tables == [] for page in out.pages) + assert all(page.export()["tables"] == [] for page in out.pages) + + # With layout + table predictors -> structured tables are attached and exported + predictor = OCRPredictor(det_predictor, reco_predictor, layout_predictor=layout_pred, table_predictor=table_pred) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + assert isinstance(predictor.table_predictor, TablePredictor) + out = predictor(doc) + assert isinstance(out, Document) + for page in out.pages: + assert isinstance(page.tables, list) + assert all(isinstance(t, Table) for t in page.tables) + exported = page.export() + assert "tables" in exported + assert exported["tables"] == [t.export() for t in page.tables] + + +def test_ocrpredictor_tables_factory(): + # The factory exposes a single `detect_tables` flag, which also enables the layout model + predictor = models.ocr_predictor("db_mobilenet_v3_large", "crnn_vgg16_bn", pretrained=False, detect_tables=True) + assert isinstance(predictor.table_predictor, TablePredictor) + assert isinstance(predictor.layout_predictor, LayoutPredictor) + + # No tables by default + predictor = models.ocr_predictor("db_mobilenet_v3_large", "crnn_vgg16_bn", pretrained=False) + assert predictor.table_predictor is None + + def test_trained_ocr_predictor(mock_pdf, mock_vocab, mock_payslip): det_predictor = DetectionPredictor( PreProcessor(output_size=(512, 512), batch_size=2), From 62416cf2700ff1d5671e992ce4b34ae9b284da76 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 16:17:30 +0200 Subject: [PATCH 2/5] Update deps --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b0b97fcc2d..2ec1bd38d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,8 @@ testing = [ "coverage[toml]>=4.5.4", "onnxruntime>=1.11.0", "requests>=2.20.0", - "psutil>=5.9.5" + "psutil>=5.9.5", + "pandas>=2.0.0", ] quality = [ "ruff>=0.1.5", @@ -104,7 +105,7 @@ dev = [ "onnxruntime>=1.11.0", "requests>=2.20.0", "psutil>=5.9.5", - "pandas>=3.0.0", + "pandas>=2.0.0", # Quality "ruff>=0.3.0", "mypy>=1.0", From e4d6c8da2e41fec3e737d0b3583b68498d805ed5 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 2 Jul 2026 09:08:23 +0200 Subject: [PATCH 3/5] Update demo & API --- api/app/routes/ocr.py | 23 ++++++++- api/app/schemas.py | 94 ++++++++++++++++++++++++++++++++---- api/tests/conftest.py | 2 + api/tests/routes/test_ocr.py | 29 +++++++++++ demo/app.py | 11 +++++ demo/backend/pytorch.py | 3 ++ demo/pt-requirements.txt | 1 + doctr/models/builder.py | 5 +- 8 files changed, 156 insertions(+), 12 deletions(-) diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py index 16d6c7295b..9c6fe97d6c 100644 --- a/api/app/routes/ocr.py +++ b/api/app/routes/ocr.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status -from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord +from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord, TableCellOut, TableOut from app.utils import get_documents, resolve_geometry from app.vision import init_predictor @@ -39,6 +39,27 @@ async def perform_ocr(request: OCRIn = Depends(), files: list[UploadFile] = [Fil ) for region in page.layout ], + tables=[ + TableOut( + num_rows=table.num_rows, + num_cols=table.num_cols, + geometry=resolve_geometry(table.geometry), + confidence=round(table.confidence, 2), + cells=[ + TableCellOut( + value=cell.value, + geometry=resolve_geometry(cell.geometry), + confidence=round(cell.confidence, 2), + row_start=cell.row_start, + row_end=cell.row_end, + col_start=cell.col_start, + col_end=cell.col_end, + ) + for cell in table.cells + ], + ) + for table in page.tables + ], items=[ OCRPage( blocks=[ diff --git a/api/app/schemas.py b/api/app/schemas.py index 4e5f779168..a7f0ee7e33 100644 --- a/api/app/schemas.py +++ b/api/app/schemas.py @@ -31,6 +31,7 @@ class OCRIn(KIEIn, BaseModel): resolve_lines: bool = Field(default=True, examples=[True]) resolve_blocks: bool = Field(default=False, examples=[False]) paragraph_break: float = Field(default=0.0035, examples=[0.0035]) + detect_tables: bool = Field(default=False, examples=[False]) class RecognitionIn(BaseModel): @@ -139,6 +140,37 @@ class LayoutElementOut(BaseModel): confidence: float = Field(..., examples=[0.99]) +class TableCellOut(BaseModel): + value: str = Field(..., examples=["example"]) + geometry: list[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + confidence: float = Field(..., examples=[0.99]) + row_start: int = Field(..., examples=[0]) + row_end: int = Field(..., examples=[0]) + col_start: int = Field(..., examples=[0]) + col_end: int = Field(..., examples=[0]) + + +class TableOut(BaseModel): + num_rows: int = Field(..., examples=[2]) + num_cols: int = Field(..., examples=[2]) + geometry: list[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]]) + confidence: float = Field(..., examples=[0.99]) + cells: list[TableCellOut] = Field( + ..., + examples=[ + { + "value": "example", + "geometry": [0.0, 0.0, 0.0, 0.0], + "confidence": 0.99, + "row_start": 0, + "row_end": 0, + "col_start": 0, + "col_end": 0, + } + ], + ) + + class OCROut(BaseModel): name: str = Field(..., examples=["example.jpg"]) orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}]) @@ -148,27 +180,55 @@ class OCROut(BaseModel): default=[], examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]], ) + tables: list[TableOut] = Field( + default=[], + examples=[ + [ + { + "num_rows": 2, + "num_cols": 2, + "geometry": [0.0, 0.0, 0.0, 0.0], + "confidence": 0.99, + "cells": [ + { + "value": "example", + "geometry": [0.0, 0.0, 0.0, 0.0], + "confidence": 0.99, + "row_start": 0, + "row_end": 0, + "col_start": 0, + "col_end": 0, + } + ], + } + ] + ], + ) items: list[OCRPage] = Field( ..., examples=[ { - "geometry": [0.0, 0.0, 0.0, 0.0], - "objectness_score": 0.99, - "lines": [ + "blocks": [ { "geometry": [0.0, 0.0, 0.0, 0.0], "objectness_score": 0.99, - "words": [ + "lines": [ { - "value": "example", "geometry": [0.0, 0.0, 0.0, 0.0], "objectness_score": 0.99, - "confidence": 0.99, - "crop_orientation": {"value": 0, "confidence": None}, + "words": [ + { + "value": "example", + "geometry": [0.0, 0.0, 0.0, 0.0], + "objectness_score": 0.99, + "confidence": 0.99, + "crop_orientation": {"value": 0, "confidence": None}, + } + ], } ], } - ], + ] } ], ) @@ -199,4 +259,20 @@ class KIEOut(BaseModel): default=[], examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]], ) - predictions: list[KIEElement] + predictions: list[KIEElement] = Field( + ..., + examples=[ + { + "class_name": "example", + "items": [ + { + "value": "example", + "geometry": [0.0, 0.0, 0.0, 0.0], + "objectness_score": 0.99, + "confidence": 0.99, + "crop_orientation": {"value": 0, "confidence": None}, + } + ], + } + ], + ) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 3d9d65a1aa..d410913215 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -158,6 +158,7 @@ def mock_ocr_response(): "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], "layout": [], + "tables": [], "items": [ { "blocks": [ @@ -207,6 +208,7 @@ def mock_ocr_response(): "language": {"value": None, "confidence": None}, "dimensions": [2339, 1654], "layout": [], + "tables": [], "items": [ { "blocks": [ diff --git a/api/tests/routes/test_ocr.py b/api/tests/routes/test_ocr.py index bdb54d0174..2e915db271 100644 --- a/api/tests/routes/test_ocr.py +++ b/api/tests/routes/test_ocr.py @@ -93,6 +93,35 @@ async def test_ocr_layout(test_app_asyncio, mock_detection_image): assert len(region["geometry"]) in (4, 8) +@pytest.mark.asyncio +async def test_ocr_tables(test_app_asyncio, mock_detection_image): + headers = { + "accept": "application/json", + } + params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_tables": True} + files = [ + ("files", ("test.jpg", mock_detection_image, "image/jpeg")), + ] + response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers) + assert response.status_code == 200 + json_response = response.json() + + assert isinstance(json_response, list) and len(json_response) == 1 + # detect_tables enables the layout model as well + assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list) + assert "tables" in json_response[0] and isinstance(json_response[0]["tables"], list) + for table in json_response[0]["tables"]: + assert isinstance(table["num_rows"], int) and isinstance(table["num_cols"], int) + assert isinstance(table["confidence"], (int, float)) + assert len(table["geometry"]) in (4, 8) + assert isinstance(table["cells"], list) + for cell in table["cells"]: + assert isinstance(cell["value"], str) + assert isinstance(cell["confidence"], (int, float)) + assert len(cell["geometry"]) in (4, 8) + assert all(isinstance(cell[k], int) for k in ("row_start", "row_end", "col_start", "col_end")) + + @pytest.mark.asyncio async def test_ocr_invalid_file(test_app_asyncio, mock_txt_file): headers = { diff --git a/demo/app.py b/demo/app.py index 85446708bb..aa502ccd9b 100644 --- a/demo/app.py +++ b/demo/app.py @@ -6,6 +6,7 @@ import cv2 import matplotlib.pyplot as plt import numpy as np +import pandas as pd import streamlit as st import torch from backend.pytorch import DET_ARCHS, LAYOUT_ARCHS, RECO_ARCHS, forward_image, load_predictor @@ -70,6 +71,8 @@ def main(det_archs, reco_archs, layout_archs): # Layout detection detect_layout = st.sidebar.checkbox("Detect layout", value=False) layout_arch = st.sidebar.selectbox("Layout detection model", layout_archs, disabled=not detect_layout) + # Table detection (relies on the layout model to locate tables) + detect_tables = st.sidebar.checkbox("Detect tables", value=False) st.sidebar.write("\n") # Binarization threshold bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1) @@ -97,6 +100,7 @@ def main(det_archs, reco_archs, layout_archs): device=forward_device, detect_layout=detect_layout, layout_arch=layout_arch, + detect_tables=detect_tables, ) with st.spinner("Analyzing..."): @@ -122,6 +126,13 @@ def main(det_archs, reco_archs, layout_archs): img = out.pages[0].synthesize() cols[3].image(img, clamp=True) + # Display extracted tables (if any) + if out.pages[0].tables: + st.markdown("\nExtracted tables:") + for idx, table in enumerate(out.pages[0].tables): + st.markdown(f"**Table {idx + 1}** ({table.num_rows} x {table.num_cols})") + st.dataframe(pd.DataFrame(table.to_grid())) + # Display JSON st.markdown("\nHere are your analysis results in JSON format:") st.json(page_export, expanded=False) diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index 1c4c9d941d..980b1137a0 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -50,6 +50,7 @@ def load_predictor( device: torch.device, detect_layout: bool, layout_arch: str, + detect_tables: bool, ) -> OCRPredictor: """Load a predictor from doctr.models @@ -66,6 +67,7 @@ def load_predictor( device: torch.device, the device to load the predictor on detect_layout: whether to run a layout detection model and attach the regions to each page layout_arch: layout architecture to use when detect_layout is True + detect_tables: whether to detect tables (via the layout model), structure them and attach them to each page Returns: instance of OCRPredictor @@ -82,6 +84,7 @@ def load_predictor( disable_crop_orientation=disable_crop_orientation, detect_layout=detect_layout, layout_arch=layout_arch, + detect_tables=detect_tables, ).to(device) predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh predictor.det_predictor.model.postprocessor.box_thresh = box_thresh diff --git a/demo/pt-requirements.txt b/demo/pt-requirements.txt index 3fdaaf120e..c2a256f7d0 100644 --- a/demo/pt-requirements.txt +++ b/demo/pt-requirements.txt @@ -1,2 +1,3 @@ -e "python-doctr[viz] @ git+https://github.com/mindee/doctr.git" streamlit>=1.0.0 +pandas>=2.0.0 diff --git a/doctr/models/builder.py b/doctr/models/builder.py index 7b40840ff8..3fa15ce709 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -4,6 +4,7 @@ # See LICENSE or go to for full license details. +from collections.abc import Sequence from typing import Any import numpy as np @@ -498,7 +499,7 @@ def __call__( orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, regions: list[dict[str, Any] | None] | None = None, - tables: list[dict[str, Any] | None] | None = None, + tables: Sequence[dict[str, Any] | list[dict[str, Any]] | None] | None = None, ) -> Document: """Re-arrange detected words into structured blocks @@ -619,7 +620,7 @@ def __call__( # type: ignore[override] orientations: list[dict[str, Any]] | None = None, languages: list[dict[str, Any]] | None = None, regions: list[dict[str, Any] | None] | None = None, - tables: list[list[dict[str, Any] | None] | None] | None = None, + tables: Sequence[dict[str, Any] | list[dict[str, Any]] | None] | None = None, ) -> KIEDocument: """Re-arrange detected words into structured predictions From bab903746e32e4b1286e22673b5f582361c5a9ef Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 2 Jul 2026 09:50:32 +0200 Subject: [PATCH 4/5] Update model init --- doctr/models/zoo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index aebc09a859..8637222dce 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -71,6 +71,9 @@ def _predictor( table_predictor( "tablecenternet", pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, batch_size=det_bs, ) if detect_tables From 25dcd29f28b1c43fd7ae382e1fc72c925b639a75 Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 2 Jul 2026 10:29:50 +0200 Subject: [PATCH 5/5] Update table cropping --- doctr/models/builder.py | 23 ++++++--- doctr/models/predictor/pytorch.py | 47 +++++++++++++------ .../table_structure/tablecenternet/base.py | 33 ++++--------- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/doctr/models/builder.py b/doctr/models/builder.py index 3fa15ce709..d68c31c1e4 100644 --- a/doctr/models/builder.py +++ b/doctr/models/builder.py @@ -357,6 +357,9 @@ def _build_tables( return [], consumed centers = self._word_centers(boxes) if num_words > 0 else np.empty((0, 2)) + # Geometry format follows the page's word geometry: straight 2-point boxes when the word boxes are + # (N, 4), 4-point polygons when they are (N, 4, 2). + straight = boxes.ndim != 3 tables_out: list[Table] = [] for table_dict in table_dicts: @@ -384,12 +387,19 @@ def _build_tables( confidence = float(np.mean([word_preds[i][1] for i in ordered])) else: value, confidence = "", float(cell["score"]) - geometry = tuple(tuple(float(c) for c in pt) for pt in poly.tolist()) + if straight: + xs, ys = poly[:, 0], poly[:, 1] + geometry: Any = ( + (float(xs.min()), float(ys.min())), + (float(xs.max()), float(ys.max())), + ) + else: + geometry = tuple(tuple(float(c) for c in pt) for pt in poly.tolist()) table_cells.append( TableCell( value=value, confidence=confidence, - geometry=geometry, # type: ignore[arg-type] + geometry=geometry, row_start=int(cell["row_start"]), row_end=int(cell["row_end"]), col_start=int(cell["col_start"]), @@ -397,11 +407,12 @@ def _build_tables( ) ) - # Enclosing geometry of the whole table (relative bbox) + # Enclosing geometry of the whole table all_pts = np.concatenate(cell_polys, axis=0) - table_geometry = ( - (float(all_pts[:, 0].min()), float(all_pts[:, 1].min())), - (float(all_pts[:, 0].max()), float(all_pts[:, 1].max())), + xmin, ymin = float(all_pts[:, 0].min()), float(all_pts[:, 1].min()) + xmax, ymax = float(all_pts[:, 0].max()), float(all_pts[:, 1].max()) + table_geometry: Any = ( + ((xmin, ymin), (xmax, ymax)) if straight else ((xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)) ) table_confidence = float(np.mean([cell["score"] for cell in cells])) diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 51696f7127..8f248c294c 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -5,6 +5,7 @@ from typing import Any +import cv2 import numpy as np import torch from torch import nn @@ -195,8 +196,10 @@ def _tables_from_regions( ) -> list[list[dict[str, Any]]]: """Crop the table regions found by the layout model and run the table model on each crop. - The table model is applied per cropped region, so a page naturally yields one structured table per - detected `Table` region. Cell geometries are mapped back from crop-relative to page-relative coordinates. + Each `Table` region is rectified to an upright crop with a perspective transform (so straight boxes and + rotated polygons are both handled), the table model is applied per crop, and the predicted cell polygons + are mapped back to page-relative coordinates through the inverse transform. A page therefore yields one + structured table per detected `Table` region. Args: pages: the (possibly straightened) page images @@ -208,7 +211,8 @@ def _tables_from_regions( coordinates """ crops: list[np.ndarray] = [] - crop_meta: list[tuple[int, tuple[float, float, float, float]]] = [] + # (page index, inverse transform, crop width, crop height, page width, page height) + crop_meta: list[tuple[int, np.ndarray, int, int, int, int]] = [] for p_idx, (page, region) in enumerate(zip(pages, regions)): if region is None: continue @@ -217,28 +221,41 @@ def _tables_from_regions( if cls_name != self.table_class_name: continue pts = np.asarray(box, dtype=np.float32).reshape(-1, 2) - x0, y0 = float(pts[:, 0].min()), float(pts[:, 1].min()) - x1, y1 = float(pts[:, 0].max()), float(pts[:, 1].max()) - # Relative box -> pixel crop (axis-aligned, clamped to the page) - px0, py0 = max(0, int(round(x0 * w))), max(0, int(round(y0 * h))) - px1, py1 = min(w, int(round(x1 * w))), min(h, int(round(y1 * h))) - if px1 - px0 < 2 or py1 - py0 < 2: + if pts.shape[0] == 2: # straight box (x_min, y_min, x_max, y_max) -> corners + (bx0, by0), (bx1, by1) = pts + src = np.array([[bx0, by0], [bx1, by0], [bx1, by1], [bx0, by1]], dtype=np.float32) + else: # rotated 4-point polygon, already ordered (top-left, top-right, bottom-right, bottom-left) + src = pts.copy() + # Relative -> absolute pixel corners + src[:, 0] *= w + src[:, 1] *= h + # Upright crop size from the region side lengths + crop_w = int(round(max(np.linalg.norm(src[1] - src[0]), np.linalg.norm(src[2] - src[3])))) + crop_h = int(round(max(np.linalg.norm(src[3] - src[0]), np.linalg.norm(src[2] - src[1])))) + if crop_w < 2 or crop_h < 2: continue - crops.append(page[py0:py1, px0:px1]) - crop_meta.append((p_idx, (x0, y0, x1, y1))) + # Full-extent destination corners so crop-relative [0, 1] maps exactly to [0, crop] pixels + dst = np.array([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]], dtype=np.float32) + transform = cv2.getPerspectiveTransform(src, dst) + inverse = cv2.getPerspectiveTransform(dst, src) + crops.append(cv2.warpPerspective(page, transform, (crop_w, crop_h))) + crop_meta.append((p_idx, inverse, crop_w, crop_h, w, h)) tables_per_page: list[list[dict[str, Any]]] = [[] for _ in pages] if len(crops) == 0: return tables_per_page grids = self.table_predictor(crops, **kwargs) # type: ignore[misc] - for (p_idx, (x0, y0, x1, y1)), grid in zip(crop_meta, grids): - region_w, region_h = (x1 - x0), (y1 - y0) + for (p_idx, inverse, crop_w, crop_h, w, h), grid in zip(crop_meta, grids): remapped_cells: list[dict[str, Any]] = [] for cell in grid["cells"]: + # Cell polygon is crop-relative -> crop pixels -> page pixels (inverse transform) -> page-relative poly = np.asarray(cell["geometry"], dtype=np.float32).reshape(-1, 2) - poly[:, 0] = x0 + poly[:, 0] * region_w - poly[:, 1] = y0 + poly[:, 1] * region_h + poly[:, 0] *= crop_w + poly[:, 1] *= crop_h + poly = cv2.perspectiveTransform(poly[None, :, :], inverse)[0] + poly[:, 0] /= w + poly[:, 1] /= h new_cell = dict(cell) new_cell["geometry"] = poly.tolist() remapped_cells.append(new_cell) diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py index cbdf03eadc..5cf56dd7cc 100644 --- a/doctr/models/table_structure/tablecenternet/base.py +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -12,6 +12,7 @@ from shapely.geometry import Point, Polygon from doctr.models.core import BaseModel +from doctr.utils import order_points __all__ = ["_TableCenterNet", "TableCenterNetPostProcessor"] @@ -77,23 +78,6 @@ def _lookup_logic(lc_map: np.ndarray, x: float, y: float) -> np.ndarray: return lc_map[:, yi, xi] -def _ensure_simple_quads(polys: np.ndarray) -> np.ndarray: - """Guarantee each predicted quad is a simple (non-self-intersecting) polygon. - - Args: - polys: predicted quads, shape (N, 4, 2) - - Returns: - the quads with every self-intersecting one reordered into a simple polygon, shape (N, 4, 2) - """ - for i in range(polys.shape[0]): - if not Polygon(polys[i]).is_valid: - centroid = polys[i].mean(axis=0) - angles = np.arctan2(polys[i, :, 1] - centroid[1], polys[i, :, 0] - centroid[0]) - polys[i] = polys[i][np.argsort(angles)] - return polys - - class TableCenterNetPostProcessor: """TableCenterNet post-processor turning the model's *decoded* key-points into table cells. @@ -202,12 +186,15 @@ def __call__(self, decoded: dict[str, np.ndarray]) -> list[dict[str, np.ndarray] cp, cs, logic = self._simple(decoded, b) if self.not_relocate else self._relocate(decoded, b) keep = cs >= self.center_thresh polys = cp[keep].reshape(-1, 4, 2) / scale # relative coordinates - polys = _ensure_simple_quads(np.clip(polys.astype(np.float32), 0, 1)) - cells = ( - np.concatenate([polys.min(axis=1), polys.max(axis=1)], axis=1).astype(np.float32) - if self.assume_straight_pages - else polys - ) + polys = np.clip(polys.astype(np.float32), 0, 1) + if self.assume_straight_pages: + cells = np.concatenate([polys.min(axis=1), polys.max(axis=1)], axis=1).astype(np.float32) + else: + cells = ( + np.stack([order_points(poly) for poly in polys]).astype(np.float32) + if polys.shape[0] + else polys.reshape(0, 4, 2).astype(np.float32) + ) results.append({ "polygons": cells, # (N, 4) boxes or (N, 4, 2) quads in relative coordinates "scores": cs[keep].astype(np.float32),