From f50637c1169098dd326b09b4be002fbb9288a6cc Mon Sep 17 00:00:00 2001 From: felix Date: Thu, 25 Jun 2026 16:09:00 +0200 Subject: [PATCH 01/14] first draft --- docs/source/modules/models.rst | 18 +- docs/source/using_doctr/using_models.rst | 38 +- doctr/models/__init__.py | 1 + doctr/models/factory/hub.py | 7 +- doctr/models/table_structure/__init__.py | 2 + .../table_structure/predictor/__init__.py | 1 + .../table_structure/predictor/pytorch.py | 79 +++ .../tablecenternet/__init__.py | 2 + .../table_structure/tablecenternet/base.py | 428 ++++++++++++++ .../table_structure/tablecenternet/pytorch.py | 515 +++++++++++++++++ doctr/models/table_structure/zoo.py | 82 +++ doctr/models/utils/pytorch.py | 35 +- references/table/README.md | 1 + references/table/evaluate.py | 143 +++++ references/table/latency.py | 54 ++ references/table/train.py | 538 ++++++++++++++++++ references/table/utils.py | 115 ++++ tests/common/test_models_table_structure.py | 26 + .../pytorch/test_models_table_structure_pt.py | 165 ++++++ 19 files changed, 2223 insertions(+), 27 deletions(-) create mode 100644 doctr/models/table_structure/__init__.py create mode 100644 doctr/models/table_structure/predictor/__init__.py create mode 100644 doctr/models/table_structure/predictor/pytorch.py create mode 100644 doctr/models/table_structure/tablecenternet/__init__.py create mode 100644 doctr/models/table_structure/tablecenternet/base.py create mode 100644 doctr/models/table_structure/tablecenternet/pytorch.py create mode 100644 doctr/models/table_structure/zoo.py create mode 100644 references/table/README.md create mode 100644 references/table/evaluate.py create mode 100644 references/table/latency.py create mode 100644 references/table/train.py create mode 100644 references/table/utils.py create mode 100644 tests/common/test_models_table_structure.py create mode 100644 tests/pytorch/test_models_table_structure_pt.py diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index 
3515c77ecc..60985aa6c2 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -85,7 +85,13 @@ doctr.models.layout .. autofunction:: doctr.models.layout.lw_detr_m -.. autofunction:: doctr.models.layout.layout_predictor + +doctr.models.table_structure +---------------------------- + +.. autofunction:: doctr.models.table_structure.tablecenternet + +.. autofunction:: doctr.models.table_structure.table_predictor doctr.models.recognition @@ -128,13 +134,3 @@ doctr.models.factory .. autofunction:: doctr.models.factory.from_hub .. autofunction:: doctr.models.factory.push_to_hf_hub - - -doctr.models.utils ------------------- - -.. currentmodule:: doctr.models.utils - -.. autofunction:: export_model_to_onnx - -.. autofunction:: add_whitelist diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 3c589b2799..21d7a9b78c 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -29,6 +29,8 @@ Which predictor should I use? - :py:meth:`detection_predictor ` * - Transcribe pre-cropped word images to strings - :py:meth:`recognition_predictor ` + * - Detect the structure of a table (cell bounding-boxes and logical coordinates) + - :py:meth:`table_predictor ` For :doc:`custom model loading ` or sharing models, see the dedicated pages. @@ -121,8 +123,8 @@ Text Recognition The task consists of transcribing the character sequence in a given image. 
-Available recognition architectures -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ The following architectures are currently supported: @@ -256,6 +258,37 @@ For instance, this snippet instantiates a layout predictor able to detect text o predictor = layout_predictor('lw_detr_s', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True) +Table Structure Recognition +--------------------------- + +The task consists of parsing the structure of a table into a machine-understandable representation: localizing every +cell (its spatial structure) and recovering the row and column it spans (its logical structure). + +Available table architectures +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following architectures are currently supported: + +* :py:meth:`tablecenternet ` + +Table structure predictors +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +:py:meth:`table_predictor ` wraps your table model so it can be used directly on +document images. For each page it returns the list of detected cells, each with its geometry, its confidence score and its logical coordinates, together with the inferred number of rows +and columns. + +.. 
code:: python3 + + import numpy as np + from doctr.models import table_predictor + model = table_predictor('tablecenternet', pretrained=True) + table_crop = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + out = model([table_crop]) + # out[0] -> {"cells": [{"geometry": ..., "score": ..., "row_start": 0, "row_end": 0, + # "col_start": 0, "col_end": 0}, ...], "num_rows": ..., "num_cols": ...} + + End-to-End OCR -------------- @@ -673,4 +706,3 @@ learned confusions, or a ``{forbidden_char: allowed_char}`` dict to override spe handle = add_whitelist(predictor, VOCABS["latin"], strategy="nearest") out = predictor(input_page) handle.remove() - diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py index 8bdcccd1dd..893998d536 100644 --- a/doctr/models/__init__.py +++ b/doctr/models/__init__.py @@ -2,5 +2,6 @@ from .detection import * from .recognition import * from .layout import * +from .table_structure import * from .zoo import * from .factory import * diff --git a/doctr/models/factory/hub.py b/doctr/models/factory/hub.py index 57839f5646..5d75535472 100644 --- a/doctr/models/factory/hub.py +++ b/doctr/models/factory/hub.py @@ -30,6 +30,7 @@ "detection": models.detection.zoo.ARCHS, "recognition": models.recognition.zoo.ARCHS, "layout": models.layout.zoo.ARCHS, + "table_structure": models.table_structure.zoo.ARCHS, } @@ -96,8 +97,8 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: # if run_config is None and arch is None: raise ValueError("run_config or arch must be specified") - if task not in ["classification", "detection", "recognition", "layout"]: - raise ValueError("task must be one of classification, detection, recognition, layout") + if task not in ["classification", "detection", "recognition", "layout", "table_structure"]: + raise ValueError("task must be one of classification, detection, recognition, layout, table_structure") # default readme readme = f"""--- @@ -218,6 +219,8 @@ def from_hub(repo_id: str, **kwargs: 
Any): model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"]) elif task == "layout": model = models.layout.__dict__[arch](pretrained=False, class_names=cfg["class_names"]) + elif task == "table_structure": + model = models.table_structure.__dict__[arch](pretrained=False) # update model cfg model.cfg = cfg diff --git a/doctr/models/table_structure/__init__.py b/doctr/models/table_structure/__init__.py new file mode 100644 index 0000000000..4cf653000a --- /dev/null +++ b/doctr/models/table_structure/__init__.py @@ -0,0 +1,2 @@ +from .zoo import * +from .tablecenternet import * diff --git a/doctr/models/table_structure/predictor/__init__.py b/doctr/models/table_structure/predictor/__init__.py new file mode 100644 index 0000000000..e3c861310c --- /dev/null +++ b/doctr/models/table_structure/predictor/__init__.py @@ -0,0 +1 @@ +from .pytorch import * diff --git a/doctr/models/table_structure/predictor/pytorch.py b/doctr/models/table_structure/predictor/pytorch.py new file mode 100644 index 0000000000..d3c93051fa --- /dev/null +++ b/doctr/models/table_structure/predictor/pytorch.py @@ -0,0 +1,79 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any + +import numpy as np +import torch +from torch import nn + +from doctr.models.detection._utils import _remove_padding +from doctr.models.preprocessor import PreProcessor +from doctr.models.utils import set_device_and_dtype + +__all__ = ["TablePredictor"] + + +class TablePredictor(nn.Module): + """Implements an object able to recognize the cell structure of tables in a document. 
+ + Args: + pre_processor: transform inputs for easier batched model inference + model: core table-structure-recognition architecture + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.inference_mode() + def forward(self, pages: list[np.ndarray], **kwargs: Any) -> list[dict[str, Any]]: + # Extract parameters from the preprocessor + preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio + symmetric_pad = self.pre_processor.resize.symmetric_pad + assume_straight_pages = self.model.assume_straight_pages + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + _params = next(self.model.parameters()) + self.model, processed_batches = set_device_and_dtype( # type: ignore[assignment] + self.model, processed_batches, _params.device, _params.dtype + ) + predicted_batches = [self.model(batch, return_preds=True, **kwargs) for batch in processed_batches] + preds = [pred for batch in predicted_batches for pred in batch["preds"]] + + rectified = _remove_padding( + pages, + [{"polygons": pred["polygons"]} for pred in preds], + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + assume_straight_pages=assume_straight_pages, # type: ignore[arg-type] + ) + + results: list[dict[str, Any]] = [] + for pred, rect in zip(preds, rectified): + polygons = rect["polygons"] # * np.array([w, h], dtype=np.float32) # relative -> absolute pixels + scores, logical = pred["scores"], pred["logical"] + cells, max_row, max_col = [], 0, 0 + for poly, score, lc in zip(polygons, scores, logical): + start_col, end_col, start_row, end_row = (int(v) for v in lc) + max_row, max_col = max(max_row, end_row), max(max_col, end_col) + cells.append({ + "geometry": poly.tolist(), 
# 4 points (TL, TR, BR, BL) in relative coordinates + "score": float(score), + "row_start": start_row, + "row_end": end_row, + "col_start": start_col, + "col_end": end_col, + }) + results.append({"cells": cells, "num_rows": max_row, "num_cols": max_col}) + return results diff --git a/doctr/models/table_structure/tablecenternet/__init__.py b/doctr/models/table_structure/tablecenternet/__init__.py new file mode 100644 index 0000000000..b2a33fe1a3 --- /dev/null +++ b/doctr/models/table_structure/tablecenternet/__init__.py @@ -0,0 +1,2 @@ +from .base import * +from .pytorch import * diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py new file mode 100644 index 0000000000..5ff4f6ffe1 --- /dev/null +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -0,0 +1,428 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: decode logic ported from https://github.com/dreamy-xay/TableCenterNet + +import math + +import numpy as np +from scipy.interpolate import griddata +from shapely.geometry import Point, Polygon + +__all__ = ["TableCenterNetPostProcessor"] + + +# TODO: It should be organized like in LinkNet for example or LWDETR _LWDETR (which builds target) +# TODO: and a LWDETRPostProcessor (which decodes the model's output). + + +def _get_logic_coords(lc_logic: np.ndarray, col_span: int, row_span: int) -> tuple[int, int, int, int]: + """Resolve a cell's logical coordinates (start/end column and row) from the per-corner logical + predictions (``lc_logic`` is a (4, 2) array of [col, row] for corners TL, TR, BR, BL) and the cell span. 
+ Pure numpy port of the reference ``get_logic_coords``.""" + col_span = max(1, col_span) + row_span = max(1, row_span) + col_lc = [max(1, int(round(float(p)))) for p in lc_logic[:, 0]] + row_lc = [max(1, int(round(float(p)))) for p in lc_logic[:, 1]] + cols, rows = lc_logic[:, 0], lc_logic[:, 1] + + if col_lc[0] == col_lc[3]: + start_col = col_lc[0] + end_col = start_col + col_span - 1 + elif col_lc[1] == col_lc[2]: + end_col = max(col_span + 1, col_lc[1]) - 1 + start_col = end_col + 1 - col_span + elif abs(cols[0] - cols[3]) <= abs(cols[1] - cols[2]): + start_col = max(1, int(round((cols[0] + cols[3]) / 2.0))) + end_col = start_col + col_span - 1 + else: + end_col = max(col_span + 1, int(round((cols[1] + cols[2]) / 2.0))) - 1 + start_col = end_col + 1 - col_span + + if row_lc[0] == row_lc[1]: + start_row = row_lc[0] + end_row = start_row + row_span - 1 + elif row_lc[2] == row_lc[3]: + end_row = max(row_span + 1, row_lc[2]) - 1 + start_row = end_row + 1 - row_span + elif abs(rows[0] - rows[1]) <= abs(rows[2] - rows[3]): + start_row = max(1, int(round((rows[0] + rows[1]) / 2.0))) + end_row = start_row + row_span - 1 + else: + end_row = max(row_span + 1, int(round((rows[2] + rows[3]) / 2.0))) - 1 + start_row = end_row + 1 - row_span + + return start_col, end_col, start_row, end_row + + +def _bbox_overlap_query(center_polys: np.ndarray, corner_polys: np.ndarray) -> list[np.ndarray]: + """For each center polygon, the indices of corner polygons whose axis-aligned bounding boxes overlap + (equivalent to the reference ``BoxesFinder``).""" + c_xmin, c_xmax = center_polys[:, 0::2].min(1), center_polys[:, 0::2].max(1) + c_ymin, c_ymax = center_polys[:, 1::2].min(1), center_polys[:, 1::2].max(1) + k_xmin, k_xmax = corner_polys[:, 0::2].min(1), corner_polys[:, 0::2].max(1) + k_ymin, k_ymax = corner_polys[:, 1::2].min(1), corner_polys[:, 1::2].max(1) + out = [] + for i in range(center_polys.shape[0]): + x_ok = (k_xmin <= c_xmax[i]) & (k_xmax >= c_xmin[i]) + y_ok = (k_ymin <= 
c_ymax[i]) & (k_ymax >= c_ymin[i]) + out.append(np.nonzero(x_ok & y_ok)[0]) + return out + + +def _lookup_logic(lc_map: np.ndarray, x: float, y: float) -> np.ndarray: + """Sample the (2, H, W) logical-coordinate map at a clamped pixel location.""" + h, w = lc_map.shape[1:] + xi = 0 if x < 0 else (w - 1 if x >= w else int(x)) + yi = 0 if y < 0 else (h - 1 if y >= h else int(y)) + return lc_map[:, yi, xi] + + +class TableCenterNetPostProcessor: + """Torch-free post-processor turning the model's *decoded* key-points into table cells. + + All tensor-heavy operations (heat-map NMS, top-k, gather) are performed inside the model's decoder + (which requires torch and is skipped during ONNX export). This object only consumes numpy arrays, so + it never blocks an export and can be tested without torch. + + The cell geometry is returned in **relative** coordinates ([0, 1] w.r.t. the model input), so the + predictor can undo the pre-processor's padding/resize like the other docTR predictors. + + Args: + center_thresh: minimum score for a cell center to be kept + corner_thresh: minimum score for a corner to be used during relocation + not_relocate: if True, skip the corner-relocation step (faster, less accurate) + """ + + def __init__( + self, + center_thresh: float = 0.3, + corner_thresh: float = 0.3, + not_relocate: bool = False, + ) -> None: + self.center_thresh = center_thresh + self.corner_thresh = corner_thresh + self.not_relocate = not_relocate + # Cell score decay (reference defaults): cells optimised on <= 2 corners get their score scaled. 
+ self.cell_min_optimize_count = 2 + self.cell_decay_thresh = 0.4 + + def _relocate(self, decoded: dict[str, np.ndarray], b: int): + cp = decoded["center_polygons"][b].copy() # (Kc, 8) + cs = decoded["center_scores"][b].copy() # (Kc,) + spans = decoded["center_spans"][b] # (Kc, 2) + corner_polys = decoded["corner_polygons"][b] # (Kn, 8) + corner_scores = decoded["corner_scores"][b] # (Kn,) + corner_pts = decoded["corner_points"][b] # (Kn, 2) + corner_logics = decoded["corner_logics"][b] # (Kn, 2) + lc_map = decoded["lc"][b] # (2, H, W) + + valid_c = np.nonzero(cs >= self.center_thresh)[0] + valid_k = np.nonzero(corner_scores >= self.corner_thresh)[0] + queries = ( + _bbox_overlap_query(cp[valid_c], corner_polys[valid_k]) + if valid_k.size + else [np.array([], int)] * valid_c.size + ) + + logic = np.zeros((cp.shape[0], 4), dtype=np.int32) + corner_count = np.zeros(cp.shape[0], dtype=np.int32) + for qi, i in enumerate(valid_c): + center_poly = Polygon(cp[i].reshape(4, 2)) + cell = cp[i].reshape(4, 2) + origin = decoded["center_polygons"][b][i].reshape(4, 2) + lc_logic: list[np.ndarray | None] = [None, None, None, None] + n_used = n_repeat = 0 + for j in valid_k[queries[qi]]: + cx, cy = corner_pts[j] + if not any(Point(p).within(center_poly) for p in corner_polys[j].reshape(4, 2)): + continue + # nearest corner index is computed on the ORIGINAL polygon (matches find_near_corner_index) + idx = int(np.argmin(((origin - [cx, cy]) ** 2).sum(1))) + ox, oy = origin[idx] + px, py = cell[idx] + if px == ox and py == oy: + n_used += 1 + cell[idx] = [cx, cy] + lc_logic[idx] = corner_logics[j] + elif (ox - px) ** 2 + (oy - py) ** 2 >= (ox - cx) ** 2 + (oy - cy) ** 2: + n_repeat += 1 + cell[idx] = [cx, cy] + lc_logic[idx] = corner_logics[j] + corner_count[i] = n_used + n_repeat + for k in range(4): + if lc_logic[k] is None: + lc_logic[k] = _lookup_logic(lc_map, cell[k][0], cell[k][1]) + col_span, row_span = int(round(float(spans[i][0]))), int(round(float(spans[i][1]))) + logic[i] 
= _get_logic_coords(np.stack(lc_logic), col_span, row_span) # type: ignore[arg-type] + cp[i] = cell.reshape(8) + + # Score decay for under-optimised cells, then re-sort + keep_high = cs >= self.center_thresh + decay = keep_high & (corner_count <= self.cell_min_optimize_count) + cs[decay] *= self.cell_decay_thresh + order = np.argsort(-cs) + return cp[order], cs[order], logic[order] + + def _simple(self, decoded: dict[str, np.ndarray], b: int): + cp = decoded["center_polygons"][b] + cs = decoded["center_scores"][b] + spans = decoded["center_spans"][b] + lc_map = decoded["lc"][b] + logic = np.zeros((cp.shape[0], 4), dtype=np.int32) + for i in np.nonzero(cs >= self.center_thresh)[0]: + cell = cp[i].reshape(4, 2) + lc_logic = np.stack([_lookup_logic(lc_map, cell[k][0], cell[k][1]) for k in range(4)]) + col_span, row_span = int(round(float(spans[i][0]))), int(round(float(spans[i][1]))) + logic[i] = _get_logic_coords(lc_logic, col_span, row_span) + return cp, cs, logic + + def __call__(self, decoded: dict[str, np.ndarray]) -> list[dict[str, np.ndarray]]: + feat_h, feat_w = decoded["feat_size"] + scale = np.array([feat_w, feat_h], dtype=np.float32) + results: list[dict[str, np.ndarray]] = [] + for b in range(decoded["center_polygons"].shape[0]): + cp, cs, logic = self._simple(decoded, b) if self.not_relocate else self._relocate(decoded, b) + keep = cs >= self.center_thresh + polys = cp[keep].reshape(-1, 4, 2) / scale # relative coordinates + results.append({ + "polygons": np.clip(polys.astype(np.float32), 0, 1), # (N, 4, 2) TL, TR, BR, BL + "scores": cs[keep].astype(np.float32), + "logical": logic[keep].astype(np.int32), # start_col, end_col, start_row, end_row + }) + return results + + +# --------------------------------------------------------------------------------------------------------- +# Dense-target rendering (numpy/scipy, ported from the reference dataset/target builder). 
+# Used by ``TableCenterNet.build_target`` to render the maps consumed by ``compute_loss``. +# --------------------------------------------------------------------------------------------------------- +def _gaussian_radius(det_size: tuple[float, float], min_overlap: float = 0.7) -> float: + height, width = det_size + a1, b1, c1 = 1, height + width, width * height * (1 - min_overlap) / (1 + min_overlap) + r1 = (b1 + math.sqrt(max(b1**2 - 4 * a1 * c1, 0))) / 2 + a2, b2, c2 = 4, 2 * (height + width), (1 - min_overlap) * width * height + r2 = (b2 + math.sqrt(max(b2**2 - 4 * a2 * c2, 0))) / 2 + a3, b3, c3 = 4 * min_overlap, -2 * min_overlap * (height + width), (min_overlap - 1) * width * height + r3 = (b3 + math.sqrt(max(b3**2 - 4 * a3 * c3, 0))) / 2 + return min(r1, r2, r3) + + +def _gaussian_2d(shape: tuple[int, int], sigma: float = 1.0) -> np.ndarray: + m, n = ((s - 1.0) / 2.0 for s in shape) + y, x = np.ogrid[-m : m + 1, -n : n + 1] + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h + + +def _draw_umich_gaussian(heatmap: np.ndarray, center: np.ndarray, radius: int, k: float = 1.0) -> None: + diameter = 2 * radius + 1 + gaussian = _gaussian_2d((diameter, diameter), sigma=diameter / 6) + x, y = int(center[0]), int(center[1]) + height, width = heatmap.shape[:2] + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + masked_heatmap = heatmap[y - top : y + bottom, x - left : x + right] + masked_gaussian = gaussian[radius - top : radius + bottom, radius - left : radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + + +def _polygon_area(points: list[tuple[float, float]]) -> float: + n = len(points) + area = 0.0 + for i in range(n): + x1, y1 = points[i] + x2, y2 = points[(i + 1) % n] + area += x1 * y2 - x2 * y1 + return abs(area) / 2.0 + + +def 
_interpolate_polygons(polygons: list[list[tuple]], img_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]: + """Fill each polygon's interior with the linear interpolation of its per-corner value (the ``"sort"`` + variant of the reference ``interpolate_polygons``).""" + final_image = np.zeros(img_size, dtype=np.float32) + mask = np.zeros(img_size, dtype=bool) + areas = [_polygon_area([(x, y) for x, y, _ in poly]) for poly in polygons] + for _, poly in sorted(zip(areas, polygons), key=lambda t: t[0]): + xs = [p[0] for p in poly] + ys = [p[1] for p in poly] + x_min, y_min = math.floor(min(xs)), math.floor(min(ys)) + x_max, y_max = math.ceil(max(xs)), math.ceil(max(ys)) + bw, bh = x_max - x_min + 1, y_max - y_min + 1 + if bw <= 0 or bh <= 0: + continue + points = np.array([(x - x_min, y - y_min) for x, y, _ in poly], dtype=np.float32) + values = np.array([v for _, _, v in poly], dtype=np.float32) + gx, gy = np.meshgrid(np.arange(bw), np.arange(bh)) + grid_points = np.vstack((gx.ravel(), gy.ravel())).T + try: + interp = griddata(points, values, grid_points, method="linear", fill_value=-1) + except Exception: + continue + for (gi, gj), value in zip(grid_points, interp): + i, j = gi + x_min, gj + y_min + if value >= 0 and 0 <= j < img_size[0] and 0 <= i < img_size[1] and not mask[j, i]: + final_image[j, i] = value + mask[j, i] = True + return final_image, mask + + +def _interpolate_logical_map(cells_logic_coords: list, output_size: tuple[int, int]): + """Build the dense (2, H, W) logical-coordinate map (column map, row map) and its mask.""" + if not cells_logic_coords: + return np.zeros((2, *output_size), np.float32), np.zeros((2, *output_size), np.float32) + cols = [[(x, y, col) for (x, y, col, _row) in cell] for cell in cells_logic_coords] + rows = [[(x, y, row) for (x, y, _col, row) in cell] for cell in cells_logic_coords] + col_img, col_mask = _interpolate_polygons(cols, output_size) + row_img, row_mask = _interpolate_polygons(rows, output_size) + lc = 
np.stack([col_img, row_img], axis=0) + lc_mask = np.stack([col_mask, row_mask], axis=0).astype(np.float32) + return lc, lc_mask + + +def build_table_target( + cells: np.ndarray, + logic: np.ndarray, + output_size: tuple[int, int], + max_objects: int = 300, + max_corners: int = 1200, +) -> dict[str, np.ndarray]: + """Render the dense TableCenterNet targets consumed by ``TableCenterNet.compute_loss``. + + Args: + cells: (N, 4, 2) cell quadrilaterals (corner order TL, TR, BR, BL) in **output-grid** coordinates + logic: (N, 4) integer logical coordinates ``[start_col, end_col, start_row, end_row]`` (0-indexed) + output_size: (H, W) of the model output grid (input size // down_ratio) + max_objects: maximum number of cells + max_corners: maximum number of distinct corners + + Returns: + the dense target dictionary (numpy arrays) matching the reference schema + """ + out_h, out_w = output_size + hm = np.zeros((2, out_h, out_w), np.float32) + reg = np.zeros((max_objects * 5, 2), np.float32) + ct2cn = np.zeros((max_objects, 8), np.float32) + cn2ct = np.zeros((max_corners, 8), np.float32) + reg_ind = np.zeros((max_objects * 5,), np.int64) + reg_mask = np.zeros((max_objects * 5,), np.float32) + ct_ind = np.zeros((max_objects,), np.int64) + ct_mask = np.zeros((max_objects,), np.float32) + cn_ind = np.zeros((max_corners,), np.int64) + cn_mask = np.zeros((max_corners,), np.float32) + ct_cn_ind = np.zeros((max_objects * 4,), np.int64) + lc_ind = np.zeros((max_objects, 4), np.int64) + lc_span = np.zeros((max_objects, 2), np.float32) + + corner_dict: dict[str, int] = {} + cells_logic_coords: list = [] + + for i in range(min(len(cells), max_objects)): + corners = cells[i].reshape(8).astype(np.float32).copy() + corners[0::2] = np.clip(corners[0::2], 0, out_w - 1) + corners[1::2] = np.clip(corners[1::2], 0, out_h - 1) + if len(set(corners[0::2].tolist())) < 2 or len(set(corners[1::2].tolist())) < 2: + continue # not an effective quad + xs, ys = corners[0::2], corners[1::2] + max_x, 
min_x, max_y, min_y = xs.max(), xs.min(), ys.max(), ys.min() + bbox_h, bbox_w = max_y - min_y, max_x - min_x + if bbox_h <= 0 or bbox_w <= 0: + continue + + radius = max(0, int(_gaussian_radius((math.ceil(bbox_h), math.ceil(bbox_w))))) + center = np.array([(max_x + min_x) / 2.0, (max_y + min_y) / 2.0], np.float32) + ci = center.astype(np.int32) + flat = ci[1] * out_w + ci[0] + reg[i] = center - ci + reg_ind[i], reg_mask[i] = flat, 1 + ct_ind[i], ct_mask[i] = flat, 1 + _draw_umich_gaussian(hm[0], ci, radius) + ct2cn[i] = center[[0, 1, 0, 1, 0, 1, 0, 1]] - corners + + start_col, end_col, start_row, end_row = (int(v) + 1 for v in logic[i]) + x1, y1, x2, y2, x3, y3, x4, y4 = corners.tolist() + clc = [ + (x1, y1, start_col, start_row), + (x2, y2, end_col + 1, start_row), + (x3, y3, end_col + 1, end_row + 1), + (x4, y4, start_col, end_row + 1), + ] + cells_logic_coords.append(clc) + for j, (x, y, _c, _r) in enumerate(clc): + lc_ind[i, j] = int(y) * out_w + int(x) + lc_span[i] = (end_col - start_col + 1, end_row - start_row + 1) + + for j in range(4): + si = j * 2 + corner = corners[si : si + 2] + cint = corner.astype(np.int32) + key = f"{cint[0]}_{cint[1]}" + if key not in corner_dict: + nc = len(corner_dict) + if nc >= max_corners: + break + corner_dict[key] = nc + reg[max_objects + nc] = np.abs(corner - cint) + reg_ind[max_objects + nc] = cint[1] * out_w + cint[0] + reg_mask[max_objects + nc] = 1 + cn_ind[nc] = cint[1] * out_w + cint[0] + cn_mask[nc] = 1 + _draw_umich_gaussian(hm[1], cint, 2) + cn2ct[nc][si : si + 2] = corner - center + ct_cn_ind[4 * i + j] = nc * 4 + j + else: + idx = corner_dict[key] + cn2ct[idx][si : si + 2] = corner - center + ct_cn_ind[4 * i + j] = idx * 4 + j + + lc, lc_mask = _interpolate_logical_map(cells_logic_coords, (out_h, out_w)) + return { + "hm": hm, + "reg": reg, + "reg_ind": reg_ind, + "reg_mask": reg_mask, + "ct_ind": ct_ind, + "ct_mask": ct_mask, + "cn_ind": cn_ind, + "cn_mask": cn_mask, + "ct2cn": ct2cn, + "cn2ct": cn2ct, + 
"ct_cn_ind": ct_cn_ind, + "lc": lc, + "lc_mask": lc_mask, + "lc_ind": lc_ind, + "lc_span": lc_span, + } + + +def build_target( + target: list[dict[str, np.ndarray]], output_shape: tuple[int, int], max_objects: int, max_corners: int +) -> dict[str, np.ndarray]: + """Render the dense training targets for a batch from per-image cell annotations. + + Args: + target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image + output_shape: (H, W) of the model output grid + max_objects: maximum number of cells + max_corners: maximum number of distinct corners + + Returns: + the batched dense target dictionary (numpy arrays) + """ + out_h, out_w = output_shape + scale = np.array([out_w, out_h], dtype=np.float32) + per_image = [ + build_table_target( + np.asarray(t["cells"], dtype=np.float32).reshape(-1, 4, 2) * scale, + np.asarray(t["logic"], dtype=np.int64).reshape(-1, 4), + (out_h, out_w), + max_objects, + max_corners, + ) + for t in target + ] + return {k: np.stack([img[k] for img in per_image], axis=0) for k in per_image[0]} diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py new file mode 100644 index 0000000000..1b949fafe2 --- /dev/null +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -0,0 +1,515 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +# Credits: architecture and loss ported from https://github.com/dreamy-xay/TableCenterNet + +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from ....models.classification import starnet_s3 +from ...modules.layers.pytorch import DCNv2 +from ...utils import load_pretrained_params +from .base import TableCenterNetPostProcessor, build_target + +__all__ = ["TableCenterNet", "tablecenternet"] + +default_cfgs: dict[str, dict[str, Any]] = { + "tablecenternet": { + "input_shape": (3, 1024, 1024), + "mean": (0.798, 0.785, 0.772), + "std": (0.264, 0.2749, 0.287), + "url": None, + }, +} + +# Helper functions + + +def _gather_feat(feat: torch.Tensor, ind: torch.Tensor) -> torch.Tensor: + """Gather features at specific indices.""" + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + return feat.gather(1, ind) + + +def _transpose_and_gather_feat(feat: torch.Tensor, ind: torch.Tensor) -> torch.Tensor: + """Transpose and gather features at specific indices.""" + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + return _gather_feat(feat, ind) + + +# Layers + + +class DeformConv(nn.Module): + """A deformable convolution layer, as described in ``_. + + Args: + chi: number of input channels + cho: number of output channels + """ + + def __init__(self, chi: int, cho: int): + super().__init__() + self.actf = nn.Sequential(nn.BatchNorm2d(cho, momentum=0.1), nn.ReLU(inplace=True)) + self.conv = DCNv2(chi, cho, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.actf(self.conv(x)) + + +class IDAUp(nn.Module): + """Iterative Deep Aggregation for Upsampling, as described in ``_. 
+ + Args: + o: number of output channels + channels: list of number of channels for each input feature map + up_f: list of upsampling factors for each input feature map + """ + + def __init__(self, o: int, channels: list[int], up_f: list[int]): + super().__init__() + for i in range(1, len(channels)): + c, f = channels[i], int(up_f[i]) + setattr(self, "proj_" + str(i), DeformConv(c, o)) + setattr(self, "node_" + str(i), DeformConv(o, o)) + up = nn.ConvTranspose2d(o, o, f * 2, stride=f, padding=f // 2, output_padding=0, groups=o, bias=False) + setattr(self, "up_" + str(i), up) + + def forward(self, layers: list[torch.Tensor], startp: int, endp: int) -> None: + for i in range(startp + 1, endp): + upsample = getattr(self, "up_" + str(i - startp)) + project = getattr(self, "proj_" + str(i - startp)) + layers[i] = upsample(project(layers[i])) + node = getattr(self, "node_" + str(i - startp)) + layers[i] = node(layers[i - 1] + layers[i]) + + +class DLAUp(nn.Module): + """Deep Layer Aggregation for Upsampling, as described in ``_. 
+ + Args: + startp: index of the first backbone map fed to the decoder + channels: list of number of channels for each input feature map + scales: list of upsampling factors for each input feature map + in_channels: list of number of input channels for each input feature map (optional) + """ + + def __init__(self, startp: int, channels: list[int], scales: list[int], in_channels: list[int] | None = None): + super().__init__() + self.startp = startp + if in_channels is None: + in_channels = channels + self.channels = channels + channels = list(channels) + np_scales = np.array(scales, dtype=int) + for i in range(len(channels) - 1): + j = -i - 2 + setattr(self, f"ida_{i}", IDAUp(channels[j], in_channels[j:], (np_scales[j:] // np_scales[j]).tolist())) + np_scales[j + 1 :] = np_scales[j] + in_channels[j + 1 :] = [channels[j] for _ in channels[j + 1 :]] + + def forward(self, layers: list[torch.Tensor]) -> list[torch.Tensor]: + out = [layers[-1]] + for i in range(len(layers) - self.startp - 1): + ida = getattr(self, f"ida_{i}") + ida(layers, len(layers) - i - 2, len(layers)) + out.insert(0, layers[-1]) + return out + + +# Model + + +class TableCenterNet(nn.Module): + """TableCenterNet for table-structure recognition, as described in the official implementation + ``_. + + A StarNet backbone feeds a deformable-convolution DLA decoder, followed by six dense heads + (``hm``, ``reg``, ``ct2cn``, ``cn2ct``, ``lc``, ``sp``) describing cell centers, corners and their + logical coordinates. 
+ + Args: + feat_extractor: the StarNet backbone serving as feature extractor (returns the stem + 4 stage maps) + heads: mapping from head name to its number of output channels + head_conv: number of channels in the hidden layer of each head + first_level: index of the first backbone map fed to the decoder + last_level: index (exclusive) of the last backbone map fed to the decoder + center_thresh: minimum score for a cell center to be kept + corner_thresh: minimum score for a corner to be used during relocation + center_k: maximum number of cell centers + corner_k: maximum number of corners + not_relocate: if True, skip the corner-relocation step + assume_straight_pages: if True, the predictor will fit straight boxes to the cells + exportable: onnx exportable returns only the raw head maps + cfg: the configuration dict of the model + """ + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + heads: dict[str, int] = {"hm": 2, "reg": 2, "ct2cn": 8, "cn2ct": 8, "lc": 2, "sp": 2}, + head_conv: int = 256, + first_level: int = 1, + last_level: int = 4, + center_thresh: float = 0.3, + corner_thresh: float = 0.3, + center_k: int = 3000, + corner_k: int = 5000, + not_relocate: bool = False, + max_objects: int = 300, + max_corners: int = 1200, + assume_straight_pages: bool = False, + exportable: bool = False, + cfg: dict[str, Any] | None = None, + ) -> None: + super().__init__() + self.heads = heads + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + self.cfg = cfg + self.first_level, self.last_level = first_level, last_level + self.center_k, self.corner_k = center_k, corner_k + self.max_objects, self.max_corners = max_objects, max_corners + + self.feat_extractor = feat_extractor + # Identify the number of channels for the decoder initialization + _is_training = self.feat_extractor.training + self.feat_extractor = self.feat_extractor.eval() + with torch.no_grad(): + out = self.feat_extractor(torch.zeros((1, 3, 256, 256))) + 
channels = [v.shape[1] for v in out.values()] + if _is_training: + self.feat_extractor = self.feat_extractor.train() + + scales = [2**i for i in range(len(channels[first_level:]))] + self.dla_up = DLAUp(first_level, channels[first_level:], scales) + out_channel = channels[first_level] + self.ida_up = IDAUp( + out_channel, channels[first_level:last_level], [2**i for i in range(last_level - first_level)] + ) + for head, out_ch in self.heads.items(): + fc = nn.Sequential( + nn.Conv2d(channels[first_level], head_conv, 3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.Conv2d(head_conv, out_ch, 1, stride=1, padding=0, bias=True), + ) + # Reference head initialisation: detection-style bias for heatmaps, zeroed bias otherwise. + final = fc[2] + if isinstance(final, nn.Conv2d) and final.bias is not None: + nn.init.constant_(final.bias, -2.19 if "hm" in head else 0.0) + self.__setattr__(head, fc) + + self.postprocessor = TableCenterNetPostProcessor( + center_thresh=center_thresh, corner_thresh=corner_thresh, not_relocate=not_relocate + ) + + def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: + """Load pretrained parameters onto the model + + Args: + path_or_url: the path or URL to the model parameters (checkpoint) + **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params` + """ + load_pretrained_params(self, path_or_url, **kwargs) + + def _polygons_decode(self, heatmap: torch.Tensor, vec: torch.Tensor, reg: torch.Tensor, k: int): + """Decode key-points (cell centers or corners) into the four points of a quadrilateral.""" + batch = heatmap.size(0) + k = min(k, heatmap.size(2) * heatmap.size(3)) # never request more points than there are locations + # NMS on heatmaps + pad = (3 - 1) // 2 + hmax = F.max_pool2d(heatmap, (3, 3), stride=1, padding=pad) + heatmap = heatmap * (hmax == heatmap).float() + # Top-K key-points + batch, cat, height, width = heatmap.size() + k = min(k, height * width) # never request more points than 
there are locations + topk_scores, topk_inds = torch.topk(heatmap.view(batch, cat, -1), k) + topk_inds = topk_inds % (height * width) + topk_ys = (topk_inds // width).int().float() + topk_xs = (topk_inds % width).int().float() + scores, topk_ind = torch.topk(topk_scores.view(batch, -1), k) + indexes = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, k) + ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, k) + xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, k) + + scores = scores.view(batch, k, 1) + offset = _transpose_and_gather_feat(reg, indexes) + xs = xs.view(batch, k, 1) + offset[:, :, 0:1] + ys = ys.view(batch, k, 1) + offset[:, :, 1:2] + v = _transpose_and_gather_feat(vec, indexes) + polygons = torch.cat( + [ + xs - v[..., 0:1], + ys - v[..., 1:2], + xs - v[..., 2:3], + ys - v[..., 3:4], + xs - v[..., 4:5], + ys - v[..., 5:6], + xs - v[..., 6:7], + ys - v[..., 7:8], + ], + dim=2, + ) + return scores, indexes, xs, ys, polygons + + def _forward_heads(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + """Run the model and return the raw head maps.""" + feats = self.feat_extractor(x) + feats = [feats[str(idx)] for idx in range(len(feats))] + layers = self.dla_up(feats) + y = [layers[i].clone() for i in range(self.last_level - self.first_level)] + self.ida_up(y, 0, len(y)) + return {head: self.__getattr__(head)(y[-1]) for head in self.heads} # type: ignore[operator] + + @torch.compiler.disable + def _decode(self, heads: dict[str, torch.Tensor]) -> dict[str, Any]: + """Decode the raw head maps into cell polygons, scores, logical coordinates and corner points.""" + hm = heads["hm"].sigmoid() + reg = heads["reg"] + c_scores, c_ind, _, _, c_poly = self._polygons_decode(hm[:, 0:1], heads["ct2cn"], reg, self.center_k) + k_scores, k_ind, k_xs, k_ys, k_poly = self._polygons_decode(hm[:, 1:2], heads["cn2ct"], reg, self.corner_k) + spans = _transpose_and_gather_feat(heads["sp"], c_ind) + corner_logics = 
_transpose_and_gather_feat(heads["lc"], k_ind) + feat_h, feat_w = hm.shape[2], hm.shape[3] + + def _np(t: torch.Tensor) -> np.ndarray: + # Cast to float32 first: numpy has no bfloat16 (relevant under autocast/AMP) + return t.detach().float().cpu().numpy() + + return { + "center_polygons": _np(c_poly), + "center_scores": _np(c_scores.squeeze(-1)), + "center_spans": _np(spans), + "corner_polygons": _np(k_poly), + "corner_scores": _np(k_scores.squeeze(-1)), + "corner_points": _np(torch.cat([k_xs, k_ys], dim=2)), + "corner_logics": _np(corner_logics), + "lc": _np(heads["lc"]), + "feat_size": (feat_h, feat_w), + } + + def forward( + self, + x: torch.Tensor, + target: list[dict[str, np.ndarray]] | dict[str, torch.Tensor] | None = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> dict[str, Any]: + heads_out = self._forward_heads(x) + + out: dict[str, Any] = {} + + if self.exportable: + return heads_out + + if return_model_output: + out["heads_out"] = heads_out + + if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable + def _postprocess(heads_out): + return self.postprocessor(self._decode(heads_out)) + + out["preds"] = _postprocess(heads_out) + + if target is not None: + # Build target + @torch.compiler.disable + def _compute_loss(heads_out, target): + processed_targets = self.build_target(target, self.class_names) + return self.compute_loss(heads_out, processed_targets) + + out["loss"] = _compute_loss(heads_out, target) + + return out + + def compute_loss( + self, + output: dict[str, torch.Tensor], + target: list[dict[str, np.ndarray]], + ) -> torch.Tensor: + """Compute the multi-task TableCenterNet loss. 
+ + Args: + output: the raw head maps returned by the model + target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image + + Returns: + the scalar training loss + """ + device = output["hm"].device + out_h, out_w = int(output["hm"].shape[-2]), int(output["hm"].shape[-1]) + dense_np = build_target(target, (out_h, out_w), self.max_objects, self.max_corners) + dense = {k: torch.from_numpy(v).to(device) for k, v in dense_np.items()} + return self._loss_from_dense(output, dense) + + def _loss_from_dense(self, output: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor: + """Compute the multi-task TableCenterNet loss. + + Args: + output: the raw head maps returned by the model + target: dense targets matching the reference schema ( + keys: + hm, reg, reg_ind, reg_mask, ct_ind, ct_mask, + ct2cn, cn_ind, cn_mask, cn2ct, ct_cn_ind, + lc, lc_mask, lc_ind, lc_span + ) + + Returns: + the scalar training loss + """ + eps = 1e-4 + hm = torch.clamp(output["hm"].sigmoid(), min=eps, max=1 - eps) + + # Focal loss on the center/corner heat-maps + gt = target["hm"] + pos_inds, neg_inds = gt.eq(1).float(), gt.lt(1).float() + neg_weights = torch.pow(1 - gt, 4) + pos_loss = (torch.log(hm) * torch.pow(1 - hm, 2) * pos_inds).sum() + neg_loss = (torch.log(1 - hm) * torch.pow(hm, 2) * neg_weights * neg_inds).sum() + num_pos = pos_inds.sum() + hm_loss = -neg_loss if num_pos == 0 else -(pos_loss + neg_loss) / num_pos + + # L1 on the sub-pixel offsets + reg_pred = self._transpose_and_gather_feat(output["reg"], target["reg_ind"]) + reg_mask = target["reg_mask"].unsqueeze(2).expand_as(reg_pred).float() + reg_loss = F.l1_loss(reg_pred * reg_mask, target["reg"] * reg_mask, reduction="sum") / (reg_mask.sum() + eps) + + # Vector-pair loss (center<->corner offsets) + ct2cn_loss, cn2ct_loss, invalid_loss = self._vec_pair_loss(output, target, eps) + + # Logical-coordinate + span loss + lc_coord_loss, span_diff_loss, span_loss = 
self._logic_coord_loss(output, target, eps) + + return ( + hm_loss + reg_loss + ct2cn_loss + (cn2ct_loss + invalid_loss) + (lc_coord_loss + span_diff_loss + span_loss) + ) + + @staticmethod + def _vec_pair_loss(output, target, eps): + ct2cn_pred = _transpose_and_gather_feat(output["ct2cn"], target["ct_ind"]) + cn2ct_pred = _transpose_and_gather_feat(output["cn2ct"], target["cn_ind"]) + cn2ct_pred_temp, cn2ct_gt_temp = cn2ct_pred, target["cn2ct"] + b, m, n = ct2cn_pred.size(0), ct2cn_pred.size(1), cn2ct_pred.size(1) + + ct_cn_ind = target["ct_cn_ind"].unsqueeze(2).expand(b, 4 * m, 2) + cn2ct_pred = cn2ct_pred.view(b, 4 * n, 2).gather(1, ct_cn_ind).view(b, m, 8) + cn2ct_gt = target["cn2ct"].view(b, 4 * n, 2).gather(1, ct_cn_ind).view(b, m, 8) + + ct_mask = target["ct_mask"].unsqueeze(2).expand_as(ct2cn_pred).float() + num_ct = ct_mask.sum() + eps + cn_mask = target["cn_mask"].unsqueeze(2).expand_as(cn2ct_pred_temp) + + delta = (torch.abs(ct2cn_pred - target["ct2cn"]) + torch.abs(cn2ct_pred - cn2ct_gt)) / ( + torch.abs(target["ct2cn"]) + eps + ) + weight = torch.sin(1.570796 * torch.min(delta, torch.tensor(1.0, device=delta.device))) + ct2cn_loss = ( + F.l1_loss(ct2cn_pred * ct_mask * weight, target["ct2cn"] * ct_mask * weight, reduction="sum") / num_ct + ) + cn2ct_loss = F.l1_loss(cn2ct_pred * ct_mask * weight, cn2ct_gt * ct_mask * weight, reduction="sum") / num_ct + + invalid_vec_mask = cn2ct_gt_temp == 0 + invalid_vec_cn_mask = (invalid_vec_mask == cn_mask).float() + invalid_loss = F.l1_loss( + cn2ct_pred_temp * invalid_vec_cn_mask, cn2ct_gt_temp * invalid_vec_cn_mask, reduction="sum" + ) / (invalid_vec_cn_mask.sum() + eps) + return ct2cn_loss, 0.5 * cn2ct_loss, 0.2 * invalid_loss + + @staticmethod + def _logic_coord_loss(output, target, eps): + coord, span = output["lc"], output["sp"] + b, num = target["lc_span"].size(0), target["lc_span"].size(1) + + coords_pred = _transpose_and_gather_feat(coord, target["lc_ind"].view(b, num * 4)).view(b, num, 4, 2) + cols_pred, 
rows_pred = coords_pred[..., 0], coords_pred[..., 1] + span_pred = _transpose_and_gather_feat(span, target["ct_ind"]) + span_mask = target["ct_mask"].unsqueeze(2).expand(b, num, 2).float() + num_span_mask = span_mask.sum() + eps + + coord_gt, coord_mask = target["lc"], target["lc_mask"] + coord_weight = torch.square(1.0 - torch.abs(coord_gt - torch.round(coord_gt))) + coord_loss = F.l1_loss( + coord * coord_mask * coord_weight, coord_gt * coord_mask * coord_weight, reduction="sum" + ) / (coord_mask.sum() + eps) + + col_span_diff_pred = cols_pred[..., [1, 2]] - cols_pred[..., [0, 3]] + row_span_diff_pred = rows_pred[..., [3, 2]] - rows_pred[..., [0, 1]] + col_span_pred = span_pred[..., 0].unsqueeze(2).expand(b, num, 2) + row_span_pred = span_pred[..., 1].unsqueeze(2).expand(b, num, 2) + col_span_gt = target["lc_span"][..., 0].unsqueeze(2).expand(b, num, 2) + row_span_gt = target["lc_span"][..., 1].unsqueeze(2).expand(b, num, 2) + + def span_weight(out1, out2, tgt): + scaled = (torch.abs(out1 - tgt) + torch.abs(out2 - tgt)) * 5.0 + delta = torch.min(scaled, torch.tensor(1.0, device=tgt.device)) + return torch.sin(1.570796 * delta) + + col_w = span_weight(col_span_pred, col_span_diff_pred, col_span_gt) + row_w = span_weight(row_span_pred, row_span_diff_pred, row_span_gt) + sp_weight = torch.stack([(col_w[..., 0] + col_w[..., 1]) / 2.0, (row_w[..., 0] + row_w[..., 1]) / 2.0], dim=-1) + + col_span_diff_loss = ( + F.l1_loss(col_span_diff_pred * span_mask * col_w, col_span_gt * span_mask * col_w, reduction="sum") + / num_span_mask + ) + row_span_diff_loss = ( + F.l1_loss(row_span_diff_pred * span_mask * row_w, row_span_gt * span_mask * row_w, reduction="sum") + / num_span_mask + ) + span_diff_loss = col_span_diff_loss + row_span_diff_loss + + span_loss = ( + F.l1_loss(span_pred * span_mask * sp_weight, target["lc_span"] * span_mask * sp_weight, reduction="sum") + / num_span_mask + ) + return coord_loss, span_diff_loss, span_loss + + +def _tablecenternet( + arch: str, + 
pretrained: bool, + backbone_fn: Callable[..., nn.Module], + pretrained_backbone: bool = True, + **kwargs: Any, +) -> TableCenterNet: + pretrained_backbone = pretrained_backbone and not pretrained + backbone = backbone_fn(pretrained_backbone) + # 4 stages - all starnet variants + feat_extractor = IntermediateLayerGetter(backbone, {str(idx): str(idx) for idx in range(5)}) + + model = TableCenterNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=None) + + return model + + +def tablecenternet(pretrained: bool = False, **kwargs: Any) -> TableCenterNet: + """TableCenterNet with a StarNet-S3 backbone, matching the official checkpoint. + + >>> import torch + >>> from doctr.models import tablecenternet + >>> model = tablecenternet(pretrained=False) + >>> out = model(torch.rand((1, 3, 1024, 1024), dtype=torch.float32), return_preds=True) + + Args: + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the TableCenterNet architecture + + Returns: + A TableCenterNet model with a StarNet-S3 backbone + """ + return _tablecenternet("tablecenternet", pretrained, starnet_s3, **kwargs) diff --git a/doctr/models/table_structure/zoo.py b/doctr/models/table_structure/zoo.py new file mode 100644 index 0000000000..12aa56f5f3 --- /dev/null +++ b/doctr/models/table_structure/zoo.py @@ -0,0 +1,82 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any + +from doctr.models.utils import _CompiledModule + +from .. 
import table_structure +from ..preprocessor import PreProcessor +from .predictor import TablePredictor + +__all__ = ["table_predictor"] + +ARCHS: list[str] = ["tablecenternet"] + + +def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = False, **kwargs: Any) -> TablePredictor: + if isinstance(arch, str): + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + _model = table_structure.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages) + else: + allowed_archs = [table_structure.TableCenterNet, _CompiledModule] + if not isinstance(arch, tuple(allowed_archs)): + raise ValueError(f"unknown architecture: {type(arch)}") + _model = arch + _model.assume_straight_pages = assume_straight_pages # type: ignore[attr-defined] + + kwargs.pop("pretrained_backbone", None) + kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) + kwargs["std"] = kwargs.get("std", _model.cfg["std"]) + kwargs["batch_size"] = kwargs.get("batch_size", 2) + kwargs.setdefault("preserve_aspect_ratio", True) + kwargs.setdefault("symmetric_pad", True) + predictor = TablePredictor( + PreProcessor(_model.cfg["input_shape"][1:], **kwargs), + _model, + ) + return predictor + + +def table_predictor( + arch: Any = "tablecenternet", + pretrained: bool = False, + assume_straight_pages: bool = False, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + batch_size: int = 2, + **kwargs: Any, +) -> TablePredictor: + """Table structure recognition architecture. + + >>> import numpy as np + >>> from doctr.models import table_predictor + >>> model = table_predictor(arch='tablecenternet', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + arch: name of the architecture or model itself to use (e.g. 
'tablecenternet') + pretrained: If True, returns a model pre-trained on a table structure recognition dataset + assume_straight_pages: if True, fit straight boxes to the detected cells + preserve_aspect_ratio: if True, pad the input document image to preserve the aspect ratio before + running the model on it + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel + **kwargs: optional keyword arguments passed to the architecture + + Returns: + Table structure recognition predictor + """ + return _predictor( + arch=arch, + pretrained=pretrained, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=batch_size, + **kwargs, + ) diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index 50dd9205ef..c816f84297 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -159,7 +159,13 @@ def set_device_and_dtype( def export_model_to_onnx( - model: nn.Module, model_name: str, dummy_input: torch.Tensor | tuple[torch.Tensor, torch.Tensor], **kwargs: Any + model: nn.Module, + model_name: str, + dummy_input: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + input_names: list[str] | None = None, + output_names: list[str] | None = None, + dynamic_axes: dict[str, dict[int, str]] | None = None, + **kwargs: Any, ) -> str: """Export model to ONNX format. @@ -173,25 +179,32 @@ def export_model_to_onnx( model: the PyTorch model to be exported model_name: the name for the exported model dummy_input: the dummy input to the model + input_names: optional names for the model inputs. Defaults to ``["input"]`` (or ``["input", "masks"]`` + when ``dummy_input`` is a tuple). + output_names: optional names for the model outputs. Defaults to ``["logits"]`` (or + ``["logits", "pred_boxes"]`` when ``dummy_input`` is a tuple). 
Pass the names of every output when + the model returns more than one tensor (e.g. a multi-head model). + dynamic_axes: optional dynamic axes. Defaults to a dynamic batch dimension on every input and output. kwargs: additional arguments to be passed to torch.onnx.export Returns: the path to the exported model """ + is_tuple = isinstance(dummy_input, tuple) + if input_names is None: + input_names = ["input", "masks"] if is_tuple else ["input"] + if output_names is None: + output_names = ["logits", "pred_boxes"] if is_tuple else ["logits"] + if dynamic_axes is None: + dynamic_axes = {name: {0: "batch_size"} for name in [*input_names, *output_names]} + torch.onnx.export( model, dummy_input, # type: ignore[arg-type] f"{model_name}.onnx", - input_names=["input", "masks"] if isinstance(dummy_input, tuple) else ["input"], - output_names=["logits", "pred_boxes"] if isinstance(dummy_input, tuple) else ["logits"], - dynamic_axes={ - "input": {0: "batch_size"}, - "masks": {0: "batch_size"}, - "logits": {0: "batch_size"}, - "pred_boxes": {0: "batch_size"}, - } - if isinstance(dummy_input, tuple) - else {"input": {0: "batch_size"}, "logits": {0: "batch_size"}}, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, export_params=True, dynamo=False, verbose=False, diff --git a/references/table/README.md b/references/table/README.md new file mode 100644 index 0000000000..4a807f4f1b --- /dev/null +++ b/references/table/README.md @@ -0,0 +1 @@ +# TODO: Write the readme like in references/detection | references/layout diff --git a/references/table/evaluate.py b/references/table/evaluate.py new file mode 100644 index 0000000000..1228ceff26 --- /dev/null +++ b/references/table/evaluate.py @@ -0,0 +1,143 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. 
+ +import multiprocessing as mp +import os +import time + +import numpy as np +import torch +from torch.utils.data import DataLoader, SequentialSampler +from torchvision.transforms import Normalize + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import TableStructureDataset +from doctr.models import table_structure +from doctr.utils.metrics import TableCellMetric + + +@torch.inference_mode() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + model.eval() + val_metric.reset() + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + if amp: + with torch.amp.autocast("cuda"): + out = model(images, target=targets, return_preds=True) + else: + out = model(images, target=targets, return_preds=True) + + for target, pred in zip(targets, out["preds"]): + val_metric.update( + np.asarray(target["cells"], dtype=np.float32).reshape(-1, 4, 2), + np.asarray(target["logic"], dtype=np.int64).reshape(-1, 4), + pred["polygons"], + pred["logical"], + ) + + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return val_loss, metrics["recall"], metrics["precision"], metrics["f1"], metrics["structure_acc"] + + +def main(args): + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + pbar = tqdm(disable=False if slack_token and slack_channel else True) + if slack_token and slack_channel: + pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) + pbar.write(str(args)) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + tmp_model = table_structure.__dict__[args.arch](pretrained=False) + 
input_shape = (args.size, args.size) if isinstance(args.size, int) else tmp_model.cfg["input_shape"][-2:] + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + st = time.time() + ds = TableStructureDataset( + img_folder=os.path.join(args.dataset_path, "images"), + label_path=os.path.join(args.dataset_path, "labels.json"), + sample_transforms=T.Resize( + input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad + ), + ) + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(ds), + pin_memory=torch.cuda.is_available(), + collate_fn=ds.collate_fn, + ) + pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)") + + model = table_structure.__dict__[args.arch](pretrained=not isinstance(args.resume, str)).eval() + batch_transforms = Normalize(mean=mean, std=std) + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + elif torch.cuda.is_available(): + args.device = 0 + else: + pbar.write("No accessible GPU, target device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + metric = TableCellMetric(iou_thresh=args.iou_thresh) + pbar.write("Running evaluation") + val_loss, recall, precision, f1, struct = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp) + pbar.write( + f"Validation loss: {val_loss:.6f} | Recall: {(recall or 0):.2%} | Precision: {(precision or 0):.2%} " + f"| F1: {(f1 or 0):.2%} | Structure acc: {(struct or 0):.2%}" + ) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="docTR evaluation script for table structure recognition (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("arch", type=str, help="table model to evaluate") + parser.add_argument("dataset_path", type=str, help="path to the dataset folder (images/ + labels.json)") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for evaluation") + parser.add_argument("--device", default=None, type=int, help="device") + parser.add_argument("--size", type=int, default=None, help="model input size, H = W") + parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") + parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") + parser.add_argument("--iou_thresh", type=float, default=0.5, help="IoU threshold for cell matching") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") + parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + return 
parser.parse_args() + + +if __name__ == "__main__": + main(parse_args()) diff --git a/references/table/latency.py b/references/table/latency.py new file mode 100644 index 0000000000..aa25f7e805 --- /dev/null +++ b/references/table/latency.py @@ -0,0 +1,54 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +"""Table structure recognition latency benchmark""" + +import argparse +import time + +import numpy as np +import torch + +from doctr.models import table_structure + + +@torch.inference_mode() +def main(args): + device = torch.device("cuda:0" if args.gpu else "cpu") + + model = table_structure.__dict__[args.arch](pretrained=args.pretrained).eval().to(device=device) + + img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) + + # Warmup + for _ in range(5): + _ = model(img_tensor, return_preds=True) + + timings = [] + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor, return_preds=True) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="docTR latency benchmark for table structure recognition (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help="Should the benchmark be performed on GPU", action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument( + "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true" + ) + args 
= parser.parse_args() + + main(args) diff --git a/references/table/train.py b/references/table/train.py new file mode 100644 index 0000000000..485fdfc958 --- /dev/null +++ b/references/table/train.py @@ -0,0 +1,538 @@ +# Copyright (C) 2021-2026, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +import datetime +import hashlib +import logging +import multiprocessing +import os +import time +from pathlib import Path + +import numpy as np +import torch + +# The following import is required for DDP +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import ( + CosineAnnealingLR, + LinearLR, + MultiplicativeLR, + OneCycleLR, + PolynomialLR, + SequentialLR, +) +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler +from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort + +if os.getenv("TQDM_SLACK_TOKEN") and os.getenv("TQDM_SLACK_CHANNEL"): + from tqdm.contrib.slack import tqdm +else: + from tqdm.auto import tqdm + +from doctr import transforms as T +from doctr.datasets import TableStructureDataset +from doctr.models import table_structure +from doctr.utils.metrics import TableCellMetric +from utils import EarlyStopper, build_param_groups, plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. 
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + optimizer.defaults["lr"] = start_lr + for pgroup in optimizer.param_groups: + pgroup["lr"] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma**idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.amp.GradScaler("cuda") + + for batch_idx, (images, targets) in enumerate(train_loader): + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + + optimizer.zero_grad() + if amp: + with torch.amp.autocast("cuda"): + train_loss = model(images, target=targets)["loss"] + scaler.scale(train_loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, target=targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) + optimizer.step() + scheduler.step() + + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + break + loss_recorder.append(train_loss.item()) + if batch_idx + 1 == num_it: + break + + return lr_recorder[: len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0): + if amp: + scaler = torch.amp.GradScaler("cuda") + + model.train() + epoch_train_loss, batch_cnt = 0, 0 + pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0)) + for images, targets in pbar: + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + + optimizer.zero_grad() + if amp: + with torch.amp.autocast("cuda"): + 
train_loss = model(images, target=targets)["loss"] + scaler.scale(train_loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, target=targets)["loss"] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) + optimizer.step() + + scheduler.step() + last_lr = scheduler.get_last_lr()[0] + + pbar.set_description(f"Training loss: {train_loss.item():.6f} | LR: {last_lr:.6f}") + if log: + log(train_loss=train_loss.item(), lr=last_lr) + + epoch_train_loss += train_loss.item() + batch_cnt += 1 + + epoch_train_loss /= batch_cnt + return epoch_train_loss, last_lr + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=None): + model.eval() + val_metric.reset() + val_loss, batch_cnt = 0, 0 + pbar = tqdm(val_loader, dynamic_ncols=True) + for images, targets in pbar: + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + if amp: + with torch.amp.autocast("cuda"): + out = model(images, target=targets, return_preds=True) + else: + out = model(images, target=targets, return_preds=True) + + # Cells & logical coords are compared in the (relative) model-input space + for target, pred in zip(targets, out["preds"]): + val_metric.update( + np.asarray(target["cells"], dtype=np.float32).reshape(-1, 4, 2), + np.asarray(target["logic"], dtype=np.int64).reshape(-1, 4), + pred["polygons"], + pred["logical"], + ) + + pbar.set_description(f"Validation loss: {out['loss'].item():.6f}") + if log: + log(val_loss=out["loss"].item()) + val_loss += out["loss"].item() + batch_cnt += 1 + + val_loss /= batch_cnt + metrics = val_metric.summary() + return val_loss, metrics["recall"], metrics["precision"], metrics["f1"], metrics["structure_acc"] + + +def main(args): + world_size = int(os.environ.get("WORLD_SIZE", 1)) + distributed = world_size > 1 + + if 
distributed: + rank = int(os.environ.get("LOCAL_RANK", 0)) + dist.init_process_group(backend=args.backend) + device = torch.device("cuda", rank) + torch.cuda.set_device(device) + else: + rank = 0 + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + device = torch.device("cuda", args.device) + elif torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + logging.warning("No accessible GPU, target device set to CPU.") + device = torch.device("cpu") + + slack_token = os.getenv("TQDM_SLACK_TOKEN") + slack_channel = os.getenv("TQDM_SLACK_CHANNEL") + pbar = tqdm(disable=False if (slack_token and slack_channel) and (rank == 0) else True) + if slack_token and slack_channel: + pbar.write = lambda msg: pbar.sio.client.chat_postMessage(channel=slack_channel, text=msg) + pbar.write(str(args)) + + if not isinstance(args.workers, int): + args.workers = min(16, multiprocessing.cpu_count()) + + if rank == 0 and args.output_dir: + os.makedirs(args.output_dir, exist_ok=True) + + torch.backends.cudnn.benchmark = True + + # Temporary model to recover the configuration (mean/std) + tmp_model = table_structure.__dict__[args.arch](pretrained=False) + mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] + + # Validation data + val_hash = None + if rank == 0: + st = time.time() + val_set = TableStructureDataset( + img_folder=os.path.join(args.val_path, "images"), + label_path=os.path.join(args.val_path, "labels.json"), + sample_transforms=T.SampleCompose([ + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ]), + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + 
pbar.write( + f"Validation set loaded in {time.time() - st:.4f}s ({len(val_set)} samples in {len(val_loader)} batches)" + ) + with open(os.path.join(args.val_path, "labels.json"), "rb") as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + batch_transforms = Normalize(mean=mean, std=std) + + model = table_structure.__dict__[args.arch](pretrained=args.pretrained) + if isinstance(args.resume, str): + pbar.write(f"Resuming {args.resume}") + model.from_pretrained(args.resume) + + if rank == 0: + val_metric = TableCellMetric(iou_thresh=args.iou_thresh) + + if rank == 0 and args.test_only: + pbar.write("Running evaluation") + model = model.to(device) + val_loss, recall, precision, f1, struct = evaluate( + model, val_loader, batch_transforms, val_metric, amp=args.amp + ) + pbar.write( + f"Validation loss: {val_loss:.6f} | Recall: {(recall or 0):.2%} | Precision: {(precision or 0):.2%} " + f"| F1: {(f1 or 0):.2%} | Structure acc: {(struct or 0):.2%}" + ) + return + + st = time.time() + # Image-only augmentations + img_transforms = T.OneOf([ + Compose([ + T.RandomApply(T.ColorInversion(), 0.3), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2), + ]), + T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.3)), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), + lambda x: x, # identity + ]) + # Image + geometry augmentations (letterbox to a square; the model renders the dense targets) + sample_transforms = T.SampleCompose([ + T.RandomHorizontalFlip(0.15), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ]) + + train_set = TableStructureDataset( + img_folder=os.path.join(args.train_path, "images"), + label_path=os.path.join(args.train_path, "labels.json"), + img_transforms=img_transforms, + sample_transforms=sample_transforms, + ) + sampler = ( + DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) + if distributed + else RandomSampler(train_set) + ) + train_loader = DataLoader( + train_set, 
+ batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=sampler, + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + if rank == 0: + pbar.write( + f"Train set loaded in {time.time() - st:.4f}s ({len(train_set)} samples in {len(train_loader)} batches)" + ) + with open(os.path.join(args.train_path, "labels.json"), "rb") as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if rank == 0 and args.show_samples: + images, targets = next(iter(train_loader)) + plot_samples(images, targets) + return + + if args.freeze_backbone: + for p in model.feat_extractor.parameters(): + p.requires_grad = False + + if torch.cuda.is_available(): + torch.cuda.set_device(device) + model = model.to(device) + if distributed: + model = DDP(model, device_ids=[rank]) + + backbone_lr = args.lr * 0.1 if args.pretrained or args.resume is not None else args.lr + param_groups = build_param_groups( + model, lr=args.lr, backbone_lr=backbone_lr, weight_decay=args.weight_decay or 1e-4 + ) + optimizer = ( + torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) + if args.optim == "adamw" + else torch.optim.Adam(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8) + ) + + if rank == 0 and args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + total_steps = args.epochs * len(train_loader) + warmup_steps = max(1, min(2000, int(0.05 * total_steps))) + if args.sched == "cosine": + warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps) + cosine = CosineAnnealingLR(optimizer, T_max=max(1, total_steps - warmup_steps), eta_min=args.lr * 0.01) + scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps]) + elif args.sched == "onecycle": + scheduler = OneCycleLR( + optimizer, + max_lr=[g["lr"] for g in optimizer.param_groups], + total_steps=total_steps, + 
pct_start=warmup_steps / total_steps, + div_factor=100, + final_div_factor=100, + anneal_strategy="cos", + ) + else: # poly + warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps) + poly = PolynomialLR(optimizer, total_iters=total_steps - warmup_steps, power=1.0) + scheduler = SequentialLR(optimizer, schedulers=[warmup, poly], milestones=[warmup_steps]) + + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + if rank == 0: + config = { + "learning_rate": args.lr, + "backbone_learning_rate": backbone_lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": args.optim, + "framework": "pytorch", + "scheduler": args.sched, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "amp": args.amp, + } + + global global_step + global_step = 0 + + if args.wb: + import wandb + + run = wandb.init(name=exp_name, project="table-structure-recognition", config=config) + + def wandb_log_at_step(train_loss=None, val_loss=None, lr=None): + wandb.log({ + **({"train_loss_step": train_loss} if train_loss is not None else {}), + **({"val_loss_step": val_loss} if val_loss is not None else {}), + **({"step_lr": lr} if lr is not None else {}), + }) + + if args.clearml: + from clearml import Logger, Task + + task = Task.init(project_name="docTR/table-structure-recognition", task_name=exp_name, reuse_last_task_id=False) + task.upload_artifact("config", config) + + def clearml_log_at_step(train_loss=None, val_loss=None, lr=None): + logger = Logger.current_logger() + if train_loss is not None: + logger.report_scalar( + title="Training Step Loss", series="train_loss_step", iteration=global_step, value=train_loss + ) + if val_loss is not None: + logger.report_scalar( + title="Validation Step Loss", series="val_loss_step", 
iteration=global_step, value=val_loss + ) + if lr is not None: + logger.report_scalar(title="Step Learning Rate", series="step_lr", iteration=global_step, value=lr) + + def log_at_step(train_loss=None, val_loss=None, lr=None): + global global_step + if args.wb: + wandb_log_at_step(train_loss, val_loss, lr) + if args.clearml: + clearml_log_at_step(train_loss, val_loss, lr) + global_step += 1 + + min_loss = np.inf + if args.early_stop: + early_stopper = EarlyStopper(patience=args.early_stop_epochs, min_delta=args.early_stop_delta) + + for epoch in range(args.epochs): + if distributed: + sampler.set_epoch(epoch) + train_loss, actual_lr = fit_one_epoch( + model, train_loader, batch_transforms, optimizer, scheduler, amp=args.amp, log=log_at_step, rank=rank + ) + + if rank == 0: + pbar.write(f"Epoch {epoch + 1}/{args.epochs} - Training loss: {train_loss:.6f} | LR: {actual_lr:.6f}") + val_loss, recall, precision, f1, struct = evaluate( + model, val_loader, batch_transforms, val_metric, amp=args.amp, log=log_at_step + ) + params = model.module if hasattr(model, "module") else model + if val_loss < min_loss: + pbar.write(f"Validation loss decreased {min_loss:.6f} --> {val_loss:.6f}: saving state...") + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") + min_loss = val_loss + if args.save_interval_epoch: + torch.save(params.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6f} " + if any(v is None for v in (recall, precision, f1)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += ( + f"| Recall: {recall:.2%} | Precision: {precision:.2%} " + f"| F1: {f1:.2%} | Structure acc: {(struct or 0):.2%}" + ) + pbar.write(log_msg) + if args.wb: + wandb.log({ + "train_loss": train_loss, + "val_loss": val_loss, + "learning_rate": actual_lr, + "recall": recall, + "precision": precision, + "f1": f1, + "structure_acc": 
struct, + }) + if args.early_stop and early_stopper.early_stop(val_loss): + pbar.write("Training halted early due to reaching patience limit.") + break + + if rank == 0 and args.wb: + run.finish() + if distributed: + dist.destroy_process_group() + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="docTR training script for table structure recognition (PyTorch)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for torch.distributed") + parser.add_argument( + "--device", default=None, type=int, help="GPU index for single-GPU training (ignored under DDP)" + ) + parser.add_argument("arch", type=str, help="table model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") + parser.add_argument( + "--train_path", type=str, required=True, help="path to the training data folder (images/ + labels.json)" + ) + parser.add_argument( + "--val_path", type=str, required=True, help="path to the validation data folder (images/ + labels.json)" + ) + parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") + parser.add_argument("--epochs", type=int, default=200, help="number of epochs to train the model on") + parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") + parser.add_argument( + "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" + ) + parser.add_argument("--input_size", type=int, default=1024, help="model input size, H = W") + parser.add_argument("--lr", type=float, default=1e-4, help="learning rate for the optimizer (Adam or AdamW)") + parser.add_argument("--wd", "--weight-decay", default=1e-4, type=float, help="weight decay", dest="weight_decay") + parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for 
dataloading") + parser.add_argument("--resume", type=str, default=None, help="Path to your checkpoint") + parser.add_argument("--iou_thresh", type=float, default=0.5, help="IoU threshold for cell matching in the metric") + parser.add_argument("--test-only", dest="test_only", action="store_true", help="Run the validation loop") + parser.add_argument( + "--freeze-backbone", dest="freeze_backbone", action="store_true", help="freeze model backbone for fine-tuning" + ) + parser.add_argument( + "--show-samples", dest="show_samples", action="store_true", help="Display unnormalized training samples" + ) + parser.add_argument("--wb", dest="wb", action="store_true", help="Log to Weights & Biases") + parser.add_argument("--clearml", dest="clearml", action="store_true", help="Log to ClearML") + parser.add_argument( + "--pretrained", dest="pretrained", action="store_true", help="Load pretrained parameters before training" + ) + parser.add_argument("--optim", type=str, default="adamw", choices=["adam", "adamw"], help="optimizer to use") + parser.add_argument( + "--sched", type=str, default="cosine", choices=["cosine", "onecycle", "poly"], help="scheduler to use" + ) + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") + parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") + parser.add_argument("--early-stop-epochs", type=int, default=5, help="Patience for early stopping") + parser.add_argument("--early-stop-delta", type=float, default=0.01, help="Minimum Delta for early stopping") + return parser.parse_args() + + +if __name__ == "__main__": + main(parse_args()) diff --git a/references/table/utils.py b/references/table/utils.py new file mode 100644 index 0000000000..873a42eb65 --- /dev/null +++ b/references/table/utils.py @@ -0,0 +1,115 @@ +# Copyright (C) 2021-2026, Mindee. 
+ +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from typing import Any + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images: list[Any], targets: list[dict[str, np.ndarray]], max_samples: int = 4) -> None: + """Display a few training samples with their ground-truth cells overlaid.""" + nb_samples = min(len(images), max_samples) + _, axes = plt.subplots(2, nb_samples, figsize=(20, 6)) + if nb_samples == 1: + axes = np.expand_dims(axes, axis=1) + + for idx in range(nb_samples): + img = (255 * images[idx].detach().cpu().numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + axes[0][idx].imshow(img) + axes[0][idx].set_title("Image") + + overlay = img.copy() + cells = targets[idx]["cells"].copy() + cells[..., 0] *= img.shape[1] + cells[..., 1] *= img.shape[0] + for quad in cells.round().astype(np.intp): + cv2.polylines(overlay, [quad], True, (255, 0, 0), 1) + axes[1][idx].imshow(overlay) + axes[1][idx].set_title("GT cells") + + for ax in axes.ravel(): + ax.axis("off") + plt.tight_layout() + plt.show() + + +def build_param_groups(model: Any, lr: float, backbone_lr: float, weight_decay: float): + """Build optimizer parameter groups, separating backbone from head parameters and applying weight decay + only to non-bias / non-norm tensors.""" + no_decay_keys = ("bias", "norm", ".bn", "embed") + + def is_backbone(name: str) -> bool: + return name.removeprefix("module.").startswith("feat_extractor.") + + groups: dict[tuple[bool, bool], list[Any]] = { + (False, True): [], + (False, False): [], + (True, True): [], + (True, False): [], + } + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + decay = not (p.ndim <= 1 or any(k in n.lower() for k in no_decay_keys)) + groups[(is_backbone(n), decay)].append(p) + + return [ + {"params": groups[(False, True)], "lr": lr, 
"weight_decay": weight_decay}, + {"params": groups[(False, False)], "lr": lr, "weight_decay": 0.0}, + {"params": groups[(True, True)], "lr": backbone_lr, "weight_decay": weight_decay}, + {"params": groups[(True, False)], "lr": backbone_lr, "weight_decay": 0.0}, + ] + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + smoothed_losses = [] + avg_loss = 0.0 + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + data_slice = slice(min(len(loss_recorder) // 10, 10), len(loss_recorder)) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[: min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale("log") + plt.xlabel("Learning Rate") + plt.ylabel("Training loss") + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle="--", axis="x") + plt.show(**kwargs) + + +class EarlyStopper: + def __init__(self, patience: int = 5, min_delta: float = 0.01): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.min_validation_loss = float("inf") + + def early_stop(self, validation_loss: float) -> bool: + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > (self.min_validation_loss + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/tests/common/test_models_table_structure.py 
b/tests/common/test_models_table_structure.py new file mode 100644 index 0000000000..02f09a6993 --- /dev/null +++ b/tests/common/test_models_table_structure.py @@ -0,0 +1,26 @@ +import numpy as np + +from doctr.models.table_structure.tablecenternet import TableCenterNetPostProcessor + + +def test_tablecenternet_postprocessor(): + postprocessor = TableCenterNetPostProcessor(center_thresh=0.0) + kc, kn, feat = 12, 16, 64 + decoded = { + "center_polygons": (np.random.rand(1, kc, 8) * feat).astype(np.float32), + "center_scores": np.random.rand(1, kc).astype(np.float32), + "center_spans": np.random.randint(1, 3, (1, kc, 2)).astype(np.float32), + "corner_polygons": (np.random.rand(1, kn, 8) * feat).astype(np.float32), + "corner_scores": np.random.rand(1, kn).astype(np.float32), + "corner_points": (np.random.rand(1, kn, 2) * feat).astype(np.float32), + "corner_logics": np.random.rand(1, kn, 2).astype(np.float32), + "lc": (np.random.rand(1, 2, feat, feat) * 5).astype(np.float32), + "feat_size": (feat, feat), + } + res = postprocessor(decoded) + assert len(res) == 1 and res[0]["polygons"].shape[1:] == (4, 2) + assert res[0]["logical"].shape[1] == 4 + if res[0]["polygons"].size: + assert res[0]["polygons"].max() <= 1.0 # relative coordinates + # not_relocate path + assert len(TableCenterNetPostProcessor(center_thresh=0.0, not_relocate=True)(decoded)) == 1 diff --git a/tests/pytorch/test_models_table_structure_pt.py b/tests/pytorch/test_models_table_structure_pt.py new file mode 100644 index 0000000000..337158f568 --- /dev/null +++ b/tests/pytorch/test_models_table_structure_pt.py @@ -0,0 +1,165 @@ +import gc +import os +import tempfile + +import numpy as np +import onnxruntime +import pytest +import torch + +from doctr.models import table_structure +from doctr.models.table_structure import TableCenterNet +from doctr.models.table_structure.predictor import TablePredictor +from doctr.models.utils import _CompiledModule, export_model_to_onnx + +_HEADS = {"hm": 2, "reg": 2, 
"ct2cn": 8, "cn2ct": 8, "lc": 2, "sp": 2} + + +def _grid_target(rows=2, cols=3): + """A relative-coordinate {"cells", "logic"} target for a rows x cols grid (the dataset's output).""" + xs, ys = np.linspace(0.1, 0.9, cols + 1), np.linspace(0.1, 0.9, rows + 1) + cells, logic = [], [] + for r in range(rows): + for c in range(cols): + cells.append([[xs[c], ys[r]], [xs[c + 1], ys[r]], [xs[c + 1], ys[r + 1]], [xs[c], ys[r + 1]]]) + logic.append([c, c, r, r]) + return {"cells": np.array(cells, np.float32), "logic": np.array(logic, np.int64)} + + +@pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["tablecenternet", (3, 1024, 1024)], + ], +) +def test_table_models(arch_name, input_shape, train_mode): + batch_size = 2 + model = table_structure.__dict__[arch_name](pretrained=True) + model = model.train() if train_mode else model.eval() + assert isinstance(model, TableCenterNet) + input_tensor = torch.rand((batch_size, *input_shape)) + target = [_grid_target(), _grid_target()] + + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor, target, return_model_output=True, return_preds=not train_mode) + assert isinstance(out, dict) + assert len(out) == 3 if not train_mode else len(out) == 2 + # Check head maps + assert "out_map" in out + for name, channels in _HEADS.items(): + assert out["out_map"][name].shape[:2] == (batch_size, channels) + assert out["out_map"][name].dtype == torch.float32 + + # Check Preds + if not train_mode: + assert len(out["preds"]) == batch_size + for pred in out["preds"]: + assert set(pred) == {"polygons", "scores", "logical"} + # Check logical coordinates have 4 entries per cell (start/end col, start/end row) + assert pred["logical"].shape[1] == 4 + # Check that the number of cells, scores and logical coordinates are the same + assert len(pred["polygons"]) == len(pred["scores"]) == len(pred["logical"]) + if pred["polygons"].size: + 
assert pred["polygons"].shape[1:] == (4, 2) + # Check that cells are in the range [0, 1] + assert np.all(pred["polygons"] >= 0) and np.all(pred["polygons"] <= 1) + # Check that scores are between 0 and 1 + assert np.all(pred["scores"] >= 0) and np.all(pred["scores"] <= 1) + # Check loss + assert isinstance(out["loss"], torch.Tensor) + assert hasattr(model, "from_pretrained") + gc.collect() + + +@pytest.mark.parametrize( + "arch_name", + [ + "tablecenternet", + ], +) +def test_table_structure_zoo(arch_name): + # Model + predictor = table_structure.zoo.table_predictor(arch_name, pretrained=False) + predictor.model = predictor.model.eval() + # object check + assert isinstance(predictor, TablePredictor) + input_tensor = np.random.rand(2, 1024, 1024, 3).astype(np.float32) + if torch.cuda.is_available(): + predictor.model.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == 2 + assert all(isinstance(page, dict) for page in out) + assert all({"cells", "num_rows", "num_cols"} <= set(page) for page in out) + for page in out: + assert all(np.asarray(cell["geometry"]).shape == (4, 2) for cell in page["cells"]) + assert all({"score", "row_start", "row_end", "col_start", "col_end"} <= set(cell) for cell in page["cells"]) + assert all(0 <= cell["score"] <= 1 for cell in page["cells"]) + gc.collect() + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["tablecenternet", (3, 1024, 1024)], + ], +) +def test_models_onnx_export(arch_name, input_shape): + # Model + batch_size = 2 + model = table_structure.__dict__[arch_name](pretrained=False, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + head_names = list(model.heads.keys()) + pt = model(dummy_input) + pt_out = {name: pt[name].detach().cpu().numpy() for name in head_names} + with tempfile.TemporaryDirectory() as tmpdir: + # Export (the multi-head model relies on the generalized export helper to name every 
output) + model_path = export_model_to_onnx( + model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input, output_names=head_names + ) + assert os.path.exists(model_path) + # Inference + ort_session = onnxruntime.InferenceSession( + os.path.join(tmpdir, "model.onnx"), providers=["CPUExecutionProvider"] + ) + ort_outs = ort_session.run(head_names, {"input": dummy_input.numpy()}) + + assert isinstance(ort_outs, list) and len(ort_outs) == len(head_names) + # Check head map shapes + for name, ort_o in zip(head_names, ort_outs): + assert ort_o.shape == pt_out[name].shape + # Check that the output is close to the PyTorch output - only warn if not close + try: + for name, ort_o in zip(head_names, ort_outs): + assert np.allclose(pt_out[name], ort_o, atol=1e-4) + except AssertionError: + max_diff = max(np.max(np.abs(pt_out[name] - ort_o)) for name, ort_o in zip(head_names, ort_outs)) + pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {max_diff}") + + +@pytest.mark.parametrize( + "arch_name", + [ + "tablecenternet", + ], +) +def test_torch_compiled_models(arch_name): + page = (255 * np.random.rand(1024, 1024, 3)).astype(np.uint8) + predictor = table_structure.zoo.table_predictor(arch_name, pretrained=False) + assert isinstance(predictor, TablePredictor) + out = predictor([page]) + + # Compile the model + base = table_structure.__dict__[arch_name](pretrained=True).eval() + compiled_model = torch.compile(base) + assert isinstance(compiled_model, _CompiledModule) + compiled_predictor = table_structure.zoo.table_predictor(compiled_model) + compiled_out = compiled_predictor([page]) + + # Compare that outputs are close + assert len(out) == len(compiled_out) == 1 + assert {"cells", "num_rows", "num_cols"} <= set(compiled_out[0]) From eda5b6dd910361bbe56efb7d6d5531a2868b3284 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 26 Jun 2026 09:14:07 +0200 Subject: [PATCH 02/14] Add pre-model --- doctr/models/table_structure/tablecenternet/pytorch.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index 1b949fafe2..234a079e93 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -26,7 +26,7 @@ "input_shape": (3, 1024, 1024), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": None, + "url": "https://github.com/mindee/doctr/releases/download/v1.0.1/tablecenternet-27736590.pt", }, } From 1d8268a3523793399424884bfa6af550528594fe Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 26 Jun 2026 12:28:47 +0200 Subject: [PATCH 03/14] update model --- doctr/models/layout/lw_detr/pytorch.py | 5 +- doctr/models/modules/layers/pytorch.py | 15 +++ .../table_structure/predictor/pytorch.py | 5 +- .../table_structure/tablecenternet/base.py | 98 +++++++++++-------- .../table_structure/tablecenternet/pytorch.py | 29 +++--- tests/common/test_models_table_structure.py | 79 ++++++++++++++- .../pytorch/test_models_table_structure_pt.py | 3 - 7 files changed, 173 insertions(+), 61 deletions(-) diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 42a140f824..8fb5a2e609 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -16,7 +16,7 @@ from doctr.models.classification import vit_det_m, vit_det_s -from ...utils import load_pretrained_params +from ...utils import _bf16_to_float32, load_pretrained_params from .base import _LWDETR, LWDETRPostProcessor from .layers import ( LWDETRDecoder, @@ -556,6 +556,9 @@ def forward( out: dict[str, Any] = {} + logits = _bf16_to_float32(logits) + pred_boxes = _bf16_to_float32(pred_boxes) + if self.exportable: out["logits"] = logits out["pred_boxes"] = pred_boxes diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py index 29a1a9d8a7..89d0fffa0c 100644 --- 
a/doctr/models/modules/layers/pytorch.py +++ b/doctr/models/modules/layers/pytorch.py @@ -171,6 +171,21 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_channels)) channels_ = deformable_groups * 3 * kernel_size[0] * kernel_size[1] self.conv_offset_mask = nn.Conv2d(in_channels, channels_, kernel_size, stride, padding, bias=True) + self.reset_parameters() + + def reset_parameters(self) -> None: + # Standard DCN initialization: the regular conv weight is initialized like a vanilla conv, while + # the offset/mask predictor is zero-initialized so the layer starts as a plain convolution + # (offsets = 0, modulation = 0.5). Without this, weight/bias keep their uninitialized + # torch.empty values, which makes the deformable conv explode and the loss diverge to NaN. + n = self.weight.shape[1] + for k in self.weight.shape[2:]: + n *= k + stdv = 1.0 / (n**0.5) + nn.init.uniform_(self.weight, -stdv, stdv) + nn.init.zeros_(self.bias) + nn.init.zeros_(self.conv_offset_mask.weight) + nn.init.zeros_(self.conv_offset_mask.bias) # type: ignore[arg-type] def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.conv_offset_mask(x) diff --git a/doctr/models/table_structure/predictor/pytorch.py b/doctr/models/table_structure/predictor/pytorch.py index d3c93051fa..d55129741b 100644 --- a/doctr/models/table_structure/predictor/pytorch.py +++ b/doctr/models/table_structure/predictor/pytorch.py @@ -63,7 +63,7 @@ def forward(self, pages: list[np.ndarray], **kwargs: Any) -> list[dict[str, Any] for pred, rect in zip(preds, rectified): polygons = rect["polygons"] # * np.array([w, h], dtype=np.float32) # relative -> absolute pixels scores, logical = pred["scores"], pred["logical"] - cells, max_row, max_col = [], 0, 0 + cells, max_row, max_col = [], -1, -1 for poly, score, lc in zip(polygons, scores, logical): start_col, end_col, start_row, end_row = (int(v) for v in lc) max_row, max_col = max(max_row, end_row), max(max_col, end_col) @@ -75,5 +75,6 @@ def forward(self, pages: 
list[np.ndarray], **kwargs: Any) -> list[dict[str, Any] "col_start": start_col, "col_end": end_col, }) - results.append({"cells": cells, "num_rows": max_row, "num_cols": max_col}) + # logical coordinates are 0-indexed, so the table size is the largest index + 1 (0 if no cells) + results.append({"cells": cells, "num_rows": max_row + 1, "num_cols": max_col + 1}) return results diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py index 5ff4f6ffe1..39863cf7ae 100644 --- a/doctr/models/table_structure/tablecenternet/base.py +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -11,11 +11,9 @@ from scipy.interpolate import griddata from shapely.geometry import Point, Polygon -__all__ = ["TableCenterNetPostProcessor"] +from doctr.models.core import BaseModel - -# TODO: It should be organized like in LinkNet for example or LWDETR _LWDETR (which builds target) -# TODO: and a LWDETRPostProcessor (which decodes the model's output). 
+__all__ = ["_TableCenterNet", "TableCenterNetPostProcessor"] def _get_logic_coords(lc_logic: np.ndarray, col_span: int, row_span: int) -> tuple[int, int, int, int]: @@ -58,8 +56,7 @@ def _get_logic_coords(lc_logic: np.ndarray, col_span: int, row_span: int) -> tup def _bbox_overlap_query(center_polys: np.ndarray, corner_polys: np.ndarray) -> list[np.ndarray]: - """For each center polygon, the indices of corner polygons whose axis-aligned bounding boxes overlap - (equivalent to the reference ``BoxesFinder``).""" + """For each center polygon, the indices of corner polygons whose axis-aligned bounding boxes overlap.""" c_xmin, c_xmax = center_polys[:, 0::2].min(1), center_polys[:, 0::2].max(1) c_ymin, c_ymax = center_polys[:, 1::2].min(1), center_polys[:, 1::2].max(1) k_xmin, k_xmax = corner_polys[:, 0::2].min(1), corner_polys[:, 0::2].max(1) @@ -105,7 +102,7 @@ def __init__( self.center_thresh = center_thresh self.corner_thresh = corner_thresh self.not_relocate = not_relocate - # Cell score decay (reference defaults): cells optimised on <= 2 corners get their score scaled. + # Cell score decay: cells optimised on <= 2 corners get their score scaled. 
self.cell_min_optimize_count = 2 self.cell_decay_thresh = 0.4 @@ -139,7 +136,7 @@ def _relocate(self, decoded: dict[str, np.ndarray], b: int): cx, cy = corner_pts[j] if not any(Point(p).within(center_poly) for p in corner_polys[j].reshape(4, 2)): continue - # nearest corner index is computed on the ORIGINAL polygon (matches find_near_corner_index) + # nearest corner index is computed on the ORIGINAL polygon idx = int(np.argmin(((origin - [cx, cy]) ** 2).sum(1))) ox, oy = origin[idx] px, py = cell[idx] @@ -187,18 +184,17 @@ def __call__(self, decoded: dict[str, np.ndarray]) -> list[dict[str, np.ndarray] cp, cs, logic = self._simple(decoded, b) if self.not_relocate else self._relocate(decoded, b) keep = cs >= self.center_thresh polys = cp[keep].reshape(-1, 4, 2) / scale # relative coordinates + # _get_logic_coords reconstructs 1-indexed logical coordinates (column/row lines start at 1, + # mirroring the +1 offset applied when rendering the target). Shift back to the 0-indexed + # convention used by the dataset and TableCellMetric so predictions and GT are comparable. results.append({ "polygons": np.clip(polys.astype(np.float32), 0, 1), # (N, 4, 2) TL, TR, BR, BL "scores": cs[keep].astype(np.float32), - "logical": logic[keep].astype(np.int32), # start_col, end_col, start_row, end_row + "logical": (logic[keep] - 1).astype(np.int32), # start_col, end_col, start_row, end_row (0-indexed) }) return results -# --------------------------------------------------------------------------------------------------------- -# Dense-target rendering (numpy/scipy, ported from the reference dataset/target builder). -# Used by ``TableCenterNet.build_target`` to render the maps consumed by ``compute_loss``. 
-# --------------------------------------------------------------------------------------------------------- def _gaussian_radius(det_size: tuple[float, float], min_overlap: float = 0.7) -> float: height, width = det_size a1, b1, c1 = 1, height + width, width * height * (1 - min_overlap) / (1 + min_overlap) @@ -211,8 +207,8 @@ def _gaussian_radius(det_size: tuple[float, float], min_overlap: float = 0.7) -> def _gaussian_2d(shape: tuple[int, int], sigma: float = 1.0) -> np.ndarray: - m, n = ((s - 1.0) / 2.0 for s in shape) - y, x = np.ogrid[-m : m + 1, -n : n + 1] + m, n = ((s - 1) / 2 for s in shape) + y, x = np.ogrid[-m : m + 1, -n : n + 1] # type: ignore[misc] h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) h[h < np.finfo(h.dtype).eps * h.max()] = 0 return h @@ -284,14 +280,14 @@ def _interpolate_logical_map(cells_logic_coords: list, output_size: tuple[int, i return lc, lc_mask -def build_table_target( +def _build_table_target( cells: np.ndarray, logic: np.ndarray, output_size: tuple[int, int], max_objects: int = 300, max_corners: int = 1200, ) -> dict[str, np.ndarray]: - """Render the dense TableCenterNet targets consumed by ``TableCenterNet.compute_loss``. + """Render the dense TableCenterNet targets (for a single image) consumed by ``TableCenterNet.compute_loss``. Args: cells: (N, 4, 2) cell quadrilaterals (corner order TL, TR, BR, BL) in **output-grid** coordinates @@ -399,30 +395,48 @@ def build_table_target( } -def build_target( - target: list[dict[str, np.ndarray]], output_shape: tuple[int, int], max_objects: int, max_corners: int -) -> dict[str, np.ndarray]: - """Render the dense training targets for a batch from per-image cell annotations. 
- - Args: - target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image - output_shape: (H, W) of the model output grid - max_objects: maximum number of cells - max_corners: maximum number of distinct corners +class _TableCenterNet(BaseModel): + """TableCenterNet for table-structure recognition, as described in the official implementation + ``_. - Returns: - the batched dense target dictionary (numpy arrays) + This base class holds the framework-agnostic target rendering (``build_target``), mirroring the + organization of the detection (``_LinkNet``) and layout (``_LWDETR``) models: the dense maps consumed + by ``compute_loss`` are produced here, while ``TableCenterNetPostProcessor`` decodes the model output. """ - out_h, out_w = output_shape - scale = np.array([out_w, out_h], dtype=np.float32) - per_image = [ - build_table_target( - np.asarray(t["cells"], dtype=np.float32).reshape(-1, 4, 2) * scale, - np.asarray(t["logic"], dtype=np.int64).reshape(-1, 4), - (out_h, out_w), - max_objects, - max_corners, - ) - for t in target - ] - return {k: np.stack([img[k] for img in per_image], axis=0) for k in per_image[0]} + + max_objects: int = 300 + max_corners: int = 1200 + assume_straight_pages: bool = False + + def build_target( + self, + target: list[dict[str, np.ndarray]], + output_shape: tuple[int, int], + ) -> dict[str, np.ndarray]: + """Render the dense training targets for a batch from per-image cell annotations. 
+ + Args: + target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image + output_shape: (H, W) of the model output grid (input size // down_ratio) + + Returns: + the batched dense target dictionary (numpy arrays) matching the reference schema + """ + if any(np.asarray(t["cells"], dtype=np.float32).dtype != np.float32 for t in target): + raise AssertionError("the expected dtype of target 'cells' entry is 'np.float32'.") + if any(np.any((np.asarray(t["cells"]) > 1) | (np.asarray(t["cells"]) < 0)) for t in target): + raise ValueError("the 'cells' entry of the target is expected to take values between 0 & 1.") + + out_h, out_w = output_shape + scale = np.array([out_w, out_h], dtype=np.float32) + per_image = [ + _build_table_target( + np.asarray(t["cells"], dtype=np.float32).reshape(-1, 4, 2) * scale, + np.asarray(t["logic"], dtype=np.int64).reshape(-1, 4), + (out_h, out_w), + self.max_objects, + self.max_corners, + ) + for t in target + ] + return {k: np.stack([img[k] for img in per_image], axis=0) for k in per_image[0]} diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index 234a079e93..451d2ab74f 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -16,8 +16,8 @@ from ....models.classification import starnet_s3 from ...modules.layers.pytorch import DCNv2 -from ...utils import load_pretrained_params -from .base import TableCenterNetPostProcessor, build_target +from ...utils import _bf16_to_float32, load_pretrained_params +from .base import TableCenterNetPostProcessor, _TableCenterNet __all__ = ["TableCenterNet", "tablecenternet"] @@ -130,12 +130,12 @@ def forward(self, layers: list[torch.Tensor]) -> list[torch.Tensor]: # Model -class TableCenterNet(nn.Module): +class TableCenterNet(nn.Module, _TableCenterNet): """TableCenterNet for table-structure recognition, as described in the official 
implementation ``_. A StarNet backbone feeds a deformable-convolution DLA decoder, followed by six dense heads - (``hm``, ``reg``, ``ct2cn``, ``cn2ct``, ``lc``, ``sp``) describing cell centers, corners and their + (`hm`, `reg`, `ct2cn`, `cn2ct`, `lc`, `sp`) describing cell centers, corners and their logical coordinates. Args: @@ -301,19 +301,22 @@ def _np(t: torch.Tensor) -> np.ndarray: def forward( self, x: torch.Tensor, - target: list[dict[str, np.ndarray]] | dict[str, torch.Tensor] | None = None, + target: list[dict[str, np.ndarray]] | None = None, return_model_output: bool = False, return_preds: bool = False, ) -> dict[str, Any]: heads_out = self._forward_heads(x) + heads_out = {head: _bf16_to_float32(heads_out[head]) for head in self.heads} # cast to float32 (AMP safe-guard) + out: dict[str, Any] = {} if self.exportable: return heads_out if return_model_output: - out["heads_out"] = heads_out + # Cast to float32 (the heads can be bfloat16/float16 under autocast) + out["out_map"] = heads_out if target is None or return_preds: # Disable for torch.compile compatibility @@ -324,11 +327,10 @@ def _postprocess(heads_out): out["preds"] = _postprocess(heads_out) if target is not None: - # Build target + # Disable for torch.compile compatibility (the target rendering relies on numpy/scipy) @torch.compiler.disable def _compute_loss(heads_out, target): - processed_targets = self.build_target(target, self.class_names) - return self.compute_loss(heads_out, processed_targets) + return self.compute_loss(heads_out, target) out["loss"] = _compute_loss(heads_out, target) @@ -348,10 +350,13 @@ def compute_loss( Returns: the scalar training loss """ - device = output["hm"].device out_h, out_w = int(output["hm"].shape[-2]), int(output["hm"].shape[-1]) - dense_np = build_target(target, (out_h, out_w), self.max_objects, self.max_corners) + # Render the dense targets (numpy/scipy) from the relative cell annotations + dense_np = self.build_target(target, (out_h, out_w)) + device = 
output["hm"].device dense = {k: torch.from_numpy(v).to(device) for k, v in dense_np.items()} + # AMP safe-guard: compute the loss in float32 + output = {k: v.float() for k, v in output.items()} return self._loss_from_dense(output, dense) def _loss_from_dense(self, output: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor: @@ -382,7 +387,7 @@ def _loss_from_dense(self, output: dict[str, torch.Tensor], target: dict[str, to hm_loss = -neg_loss if num_pos == 0 else -(pos_loss + neg_loss) / num_pos # L1 on the sub-pixel offsets - reg_pred = self._transpose_and_gather_feat(output["reg"], target["reg_ind"]) + reg_pred = _transpose_and_gather_feat(output["reg"], target["reg_ind"]) reg_mask = target["reg_mask"].unsqueeze(2).expand_as(reg_pred).float() reg_loss = F.l1_loss(reg_pred * reg_mask, target["reg"] * reg_mask, reduction="sum") / (reg_mask.sum() + eps) diff --git a/tests/common/test_models_table_structure.py b/tests/common/test_models_table_structure.py index 02f09a6993..05a6a4966a 100644 --- a/tests/common/test_models_table_structure.py +++ b/tests/common/test_models_table_structure.py @@ -1,6 +1,18 @@ import numpy as np +import pytest -from doctr.models.table_structure.tablecenternet import TableCenterNetPostProcessor +from doctr.models.table_structure.tablecenternet import TableCenterNetPostProcessor, _TableCenterNet + + +def _grid_target(rows: int, cols: int) -> dict[str, np.ndarray]: + """A relative-coordinate ``{"cells", "logic"}`` target for a ``rows x cols`` grid (the dataset's output).""" + xs, ys = np.linspace(0.1, 0.9, cols + 1), np.linspace(0.1, 0.9, rows + 1) + cells, logic = [], [] + for r in range(rows): + for c in range(cols): + cells.append([[xs[c], ys[r]], [xs[c + 1], ys[r]], [xs[c + 1], ys[r + 1]], [xs[c], ys[r + 1]]]) + logic.append([c, c, r, r]) + return {"cells": np.array(cells, np.float32), "logic": np.array(logic, np.int64)} def test_tablecenternet_postprocessor(): @@ -24,3 +36,68 @@ def 
test_tablecenternet_postprocessor(): assert res[0]["polygons"].max() <= 1.0 # relative coordinates # not_relocate path assert len(TableCenterNetPostProcessor(center_thresh=0.0, not_relocate=True)(decoded)) == 1 + + +def test_tablecenternet_build_target(): + model = _TableCenterNet() + out_h, out_w = 64, 64 + # Two images of different sizes + one empty image + target = [ + _grid_target(2, 3), + _grid_target(1, 2), + { + "cells": np.zeros((0, 4, 2), np.float32), + "logic": np.zeros((0, 4), np.int64), + }, + ] + + dense = model.build_target(target, (out_h, out_w)) + + # The dense schema matches what compute_loss consumes + expected_keys = { + "hm", + "reg", + "reg_ind", + "reg_mask", + "ct_ind", + "ct_mask", + "cn_ind", + "cn_mask", + "ct2cn", + "cn2ct", + "ct_cn_ind", + "lc", + "lc_mask", + "lc_ind", + "lc_span", + } + assert set(dense) == expected_keys + # Every entry is batched over the images + assert all(v.shape[0] == len(target) for v in dense.values()) + # Heat-maps and logical-coordinate maps cover the 2 (center, corner) channels at output resolution + assert dense["hm"].shape == (len(target), 2, out_h, out_w) + assert dense["lc"].shape == (len(target), 2, out_h, out_w) + assert dense["lc_mask"].shape == (len(target), 2, out_h, out_w) + # Vector-pair / span widths + assert dense["ct2cn"].shape[-1] == 8 + assert dense["cn2ct"].shape[-1] == 8 + assert dense["lc_span"].shape[-1] == 2 + # The heat-map is a valid Gaussian field in [0, 1] + assert dense["hm"].min() >= 0.0 and dense["hm"].max() <= 1.0 + # One positive cell-center per ground-truth cell, none for the empty image + assert dense["ct_mask"][0].sum() == 6 # 2 x 3 grid + assert dense["ct_mask"][1].sum() == 2 # 1 x 2 grid + assert dense["ct_mask"][2].sum() == 0 # empty image + # Index tensors stay int64 (used for gather), map tensors stay float32 + assert dense["ct_ind"].dtype == np.int64 + assert dense["hm"].dtype == np.float32 + + # Cells outside the [0, 1] relative range are rejected + bad = [ + { + 
"cells": np.array([[[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]], np.float32), + "logic": np.array([[0, 0, 0, 0]], np.int64), + } + ] + with pytest.raises(ValueError): + model.build_target(bad, (out_h, out_w)) diff --git a/tests/pytorch/test_models_table_structure_pt.py b/tests/pytorch/test_models_table_structure_pt.py index 337158f568..c559c99c8c 100644 --- a/tests/pytorch/test_models_table_structure_pt.py +++ b/tests/pytorch/test_models_table_structure_pt.py @@ -1,4 +1,3 @@ -import gc import os import tempfile @@ -71,7 +70,6 @@ def test_table_models(arch_name, input_shape, train_mode): # Check loss assert isinstance(out["loss"], torch.Tensor) assert hasattr(model, "from_pretrained") - gc.collect() @pytest.mark.parametrize( @@ -99,7 +97,6 @@ def test_table_structure_zoo(arch_name): assert all(np.asarray(cell["geometry"]).shape == (4, 2) for cell in page["cells"]) assert all({"score", "row_start", "row_end", "col_start", "col_end"} <= set(cell) for cell in page["cells"]) assert all(0 <= cell["score"] <= 1 for cell in page["cells"]) - gc.collect() @pytest.mark.parametrize( From 7e631b1f201eddd86232fbc46a7a9c7e0adc1eea Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 26 Jun 2026 14:04:40 +0200 Subject: [PATCH 04/14] Update train --- references/table/train.py | 1 - references/table/utils.py | 44 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/references/table/train.py b/references/table/train.py index 485fdfc958..cd117e4cda 100644 --- a/references/table/train.py +++ b/references/table/train.py @@ -283,7 +283,6 @@ def main(args): ]) # Image + geometry augmentations (letterbox to a square; the model renders the dense targets) sample_transforms = T.SampleCompose([ - T.RandomHorizontalFlip(0.15), T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), ]) diff --git a/references/table/utils.py b/references/table/utils.py index 873a42eb65..97a03806a3 100644 --- 
a/references/table/utils.py +++ b/references/table/utils.py @@ -11,7 +11,7 @@ import numpy as np -def plot_samples(images: list[Any], targets: list[dict[str, np.ndarray]], max_samples: int = 4) -> None: +def plot_samples(images: list[Any], targets: list[dict[str, np.ndarray]], max_samples: int = 2) -> None: """Display a few training samples with their ground-truth cells overlaid.""" nb_samples = min(len(images), max_samples) _, axes = plt.subplots(2, nb_samples, figsize=(20, 6)) @@ -28,12 +28,50 @@ def plot_samples(images: list[Any], targets: list[dict[str, np.ndarray]], max_sa overlay = img.copy() cells = targets[idx]["cells"].copy() + logic = targets[idx]["logic"] + cells[..., 0] *= img.shape[1] cells[..., 1] *= img.shape[0] - for quad in cells.round().astype(np.intp): + + for quad, (start_col, end_col, start_row, end_row) in zip( + cells.round().astype(np.intp), + logic, + ): cv2.polylines(overlay, [quad], True, (255, 0, 0), 1) + + center = quad.mean(axis=0) + + # Corner order: 0=TL, 1=TR, 2=BR, 3=BL + for corner_idx, corner in enumerate(quad): + # Move the label from the corner toward the polygon center. 
+ label_position = corner + 0.18 * (center - corner) + x, y = label_position.astype(np.intp) + + cv2.putText( + overlay, + str(corner_idx), + (x, y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 0, 255), # Pink + 1, + cv2.LINE_AA, + ) + + center_x, center_y = center.astype(np.intp) + cv2.putText( + overlay, + f"C:{start_col}-{end_col} R:{start_row}-{end_row}", + (center_x, center_y), + cv2.FONT_HERSHEY_SIMPLEX, + 0.35, + (0, 100, 0), # Dark green + 1, + cv2.LINE_AA, + ) + axes[1][idx].imshow(overlay) - axes[1][idx].set_title("GT cells") + axes[1][idx].set_title("GT cells | corners: TL, TR, BR, BL") for ax in axes.ravel(): ax.axis("off") From dc7b991ba8a42c7cd59feb3c0040497267fbc647 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 12:25:11 +0200 Subject: [PATCH 05/14] Update scripts --- doctr/datasets/table_structure.py | 15 +++- .../table_structure/tablecenternet/base.py | 88 +++++++++++++++++-- .../table_structure/tablecenternet/pytorch.py | 5 +- doctr/utils/metrics.py | 30 ++++--- references/table/evaluate.py | 10 ++- references/table/train.py | 73 +++++++++++---- tests/common/test_models_table_structure.py | 76 +++++++++++++--- tests/common/test_utils_metrics.py | 33 ++++--- tests/pytorch/test_datasets_pt.py | 16 +++- .../pytorch/test_models_table_structure_pt.py | 39 +++++--- 10 files changed, 306 insertions(+), 79 deletions(-) diff --git a/doctr/datasets/table_structure.py b/doctr/datasets/table_structure.py index b97318ae94..0b6bd175e4 100644 --- a/doctr/datasets/table_structure.py +++ b/doctr/datasets/table_structure.py @@ -32,7 +32,9 @@ class TableStructureDataset(AbstractDataset): ... } - Each sample yields the image and a ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` target. + Each sample yields the image and a target containing relative cells and their logical coordinates. Cells have + shape ``(N, 4)`` by default, or ``(N, 4, 2)`` when ``use_polygons=True``. Logical coordinates have shape + ``(N, 4)``. 
>>> from doctr.datasets import TableStructureDataset >>> from doctr.transforms import Resize @@ -45,10 +47,17 @@ class TableStructureDataset(AbstractDataset): Args: img_folder: folder with all the dataset images label_path: path to the JSON labels + use_polygons: whether to keep cell polygons instead of converting them to straight boxes **kwargs: keyword arguments from `AbstractDataset` (e.g. ``img_transforms``, ``sample_transforms``) """ - def __init__(self, img_folder: str, label_path: str, **kwargs: Any) -> None: + def __init__( + self, + img_folder: str, + label_path: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: super().__init__(img_folder, **kwargs) if not os.path.exists(label_path): @@ -69,6 +78,8 @@ def __init__(self, img_folder: str, label_path: str, **kwargs: Any) -> None: raise ValueError(f"cells are expected to have shape (N, 4, 2), got {cells.shape}") if logic.shape[0] != cells.shape[0] or logic.shape[1] != 4: # pragma: no cover raise ValueError(f"logic is expected to have shape (N, 4), got {logic.shape}") + if not use_polygons: + cells = np.concatenate((cells.min(axis=1), cells.max(axis=1)), axis=1) self.data.append((img_name, {"cells": cells, "logic": logic})) # NOTE: Override basic dataset method __getitem__ to handle table-specific targets diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py index 39863cf7ae..aec8216275 100644 --- a/doctr/models/table_structure/tablecenternet/base.py +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -77,6 +77,31 @@ def _lookup_logic(lc_map: np.ndarray, x: float, y: float) -> np.ndarray: return lc_map[:, yi, xi] +def _ensure_simple_quads(polys: np.ndarray) -> np.ndarray: + """Guarantee each predicted quad is a simple (non-self-intersecting) polygon. + + The center decode (cell built from the ``ct2cn`` offset vectors) and the corner-relocation step can + occasionally yield a self-intersecting "bow-tie" quad - e.g. 
a mis-predicted cell whose ``TL``/``TR`` + (or any two) corners cross over. Such polygons are invalid for shapely and make + :func:`doctr.utils.metrics.polygon_iou` raise a ``TopologyException`` (side location conflict) during + evaluation. Reordering the four points by their angle around the centroid produces the simple polygon + spanned by the *same four corners* (identical cell region), which is the natural recovery; quads that + are already valid keep their original corner order untouched. + + Args: + polys: predicted quads, shape (N, 4, 2) + + Returns: + the quads with every self-intersecting one reordered into a simple polygon, shape (N, 4, 2) + """ + for i in range(polys.shape[0]): + if not Polygon(polys[i]).is_valid: + centroid = polys[i].mean(axis=0) + angles = np.arctan2(polys[i, :, 1] - centroid[1], polys[i, :, 0] - centroid[0]) + polys[i] = polys[i][np.argsort(angles)] + return polys + + class TableCenterNetPostProcessor: """Torch-free post-processor turning the model's *decoded* key-points into table cells. @@ -85,12 +110,15 @@ class TableCenterNetPostProcessor: it never blocks an export and can be tested without torch. The cell geometry is returned in **relative** coordinates ([0, 1] w.r.t. the model input), so the - predictor can undo the pre-processor's padding/resize like the other docTR predictors. + predictor can undo the pre-processor's padding/resize like the other docTR predictors. When + ``assume_straight_pages=True``, geometries are axis-aligned boxes of shape ``(N, 4)``; otherwise they + are quadrilaterals of shape ``(N, 4, 2)``. 
Args: center_thresh: minimum score for a cell center to be kept corner_thresh: minimum score for a corner to be used during relocation not_relocate: if True, skip the corner-relocation step (faster, less accurate) + assume_straight_pages: whether the pages are assumed to be straight (i.e., no rotation) """ def __init__( @@ -98,10 +126,12 @@ def __init__( center_thresh: float = 0.3, corner_thresh: float = 0.3, not_relocate: bool = False, + assume_straight_pages: bool = True, ) -> None: self.center_thresh = center_thresh self.corner_thresh = corner_thresh self.not_relocate = not_relocate + self.assume_straight_pages = assume_straight_pages # Cell score decay: cells optimised on <= 2 corners get their score scaled. self.cell_min_optimize_count = 2 self.cell_decay_thresh = 0.4 @@ -184,11 +214,19 @@ def __call__(self, decoded: dict[str, np.ndarray]) -> list[dict[str, np.ndarray] cp, cs, logic = self._simple(decoded, b) if self.not_relocate else self._relocate(decoded, b) keep = cs >= self.center_thresh polys = cp[keep].reshape(-1, 4, 2) / scale # relative coordinates + # Guarantee simple (non-self-intersecting) quads so shapely-based IoU (TableCellMetric) never + # sees an invalid geometry. Applied after the relative rescale; logical coords are unaffected. + polys = _ensure_simple_quads(np.clip(polys.astype(np.float32), 0, 1)) + cells = ( + np.concatenate([polys.min(axis=1), polys.max(axis=1)], axis=1).astype(np.float32) + if self.assume_straight_pages + else polys + ) # _get_logic_coords reconstructs 1-indexed logical coordinates (column/row lines start at 1, # mirroring the +1 offset applied when rendering the target). Shift back to the 0-indexed # convention used by the dataset and TableCellMetric so predictions and GT are comparable. 
results.append({ - "polygons": np.clip(polys.astype(np.float32), 0, 1), # (N, 4, 2) TL, TR, BR, BL + "polygons": cells, # (N, 4) boxes or (N, 4, 2) quads in relative coordinates "scores": cs[keep].astype(np.float32), "logical": (logic[keep] - 1).astype(np.int32), # start_col, end_col, start_row, end_row (0-indexed) }) @@ -395,6 +433,32 @@ def _build_table_target( } +def _cells_to_polygons(cells: np.ndarray) -> np.ndarray: + """Convert table cells to quadrilaterals. + + Args: + cells: relative axis-aligned boxes of shape ``(N, 4)`` in ``(xmin, ymin, xmax, ymax)`` format, + or quadrilaterals of shape ``(N, 4, 2)``. + + Returns: + Relative quadrilaterals of shape ``(N, 4, 2)`` in TL, TR, BR, BL order. + """ + if cells.ndim == 3 and cells.shape[1:] == (4, 2): + return cells + if cells.ndim == 2 and cells.shape[1:] == (4,): + xmin, ymin, xmax, ymax = cells.T + return np.stack( + [ + np.stack([xmin, ymin], axis=-1), + np.stack([xmax, ymin], axis=-1), + np.stack([xmax, ymax], axis=-1), + np.stack([xmin, ymax], axis=-1), + ], + axis=1, + ).astype(np.float32, copy=False) + raise ValueError(f"cells are expected to have shape (N, 4) or (N, 4, 2), got {cells.shape}") + + class _TableCenterNet(BaseModel): """TableCenterNet for table-structure recognition, as described in the official implementation ``_. @@ -416,27 +480,33 @@ def build_target( """Render the dense training targets for a batch from per-image cell annotations. 
Args: - target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image + target: one ``{"cells": (N, 4) relative boxes or (N, 4, 2) relative polygons, + "logic": (N, 4)}`` dict per image output_shape: (H, W) of the model output grid (input size // down_ratio) Returns: the batched dense target dictionary (numpy arrays) matching the reference schema """ - if any(np.asarray(t["cells"], dtype=np.float32).dtype != np.float32 for t in target): - raise AssertionError("the expected dtype of target 'cells' entry is 'np.float32'.") - if any(np.any((np.asarray(t["cells"]) > 1) | (np.asarray(t["cells"]) < 0)) for t in target): - raise ValueError("the 'cells' entry of the target is expected to take values between 0 & 1.") + cells_per_image: list[np.ndarray] = [] + for t in target: + cells = np.asarray(t["cells"]) + if cells.dtype != np.float32: + raise AssertionError("the expected dtype of target 'cells' entry is 'np.float32'.") + cells = _cells_to_polygons(cells) + if np.any((cells > 1) | (cells < 0)): + raise ValueError("the 'cells' entry of the target is expected to take values between 0 & 1.") + cells_per_image.append(cells) out_h, out_w = output_shape scale = np.array([out_w, out_h], dtype=np.float32) per_image = [ _build_table_target( - np.asarray(t["cells"], dtype=np.float32).reshape(-1, 4, 2) * scale, + cells * scale, np.asarray(t["logic"], dtype=np.int64).reshape(-1, 4), (out_h, out_w), self.max_objects, self.max_corners, ) - for t in target + for t, cells in zip(target, cells_per_image) ] return {k: np.stack([img[k] for img in per_image], axis=0) for k in per_image[0]} diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index 451d2ab74f..406e98fe8e 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -210,7 +210,10 @@ def __init__( self.__setattr__(head, fc) self.postprocessor = 
TableCenterNetPostProcessor( - center_thresh=center_thresh, corner_thresh=corner_thresh, not_relocate=not_relocate + center_thresh=center_thresh, + corner_thresh=corner_thresh, + not_relocate=not_relocate, + assume_straight_pages=self.assume_straight_pages, ) def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index d0bca4ea39..4364d0b0a0 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -318,25 +318,30 @@ def reset(self) -> None: class TableCellMetric: r"""Implements a table-structure-recognition metric. - Predicted cells are matched to ground-truth cells by maximising the total polygon IoU (Hungarian - assignment); a pair counts as a match when its IoU is at least ``iou_thresh``. From the matches it - reports cell-detection recall / precision / F1 and the **structure accuracy** (the fraction of matched - cells whose logical coordinates ``[start_col, end_col, start_row, end_row]`` exactly equal the - ground-truth ones). + Predicted cells are matched to ground-truth cells by maximising the total IoU (Hungarian assignment); a pair + counts as a match when its IoU is at least ``iou_thresh``. From the matches it reports cell-detection recall, + precision, F1 and the **structure accuracy** (the fraction of matched cells whose logical coordinates + ``[start_col, end_col, start_row, end_row]`` exactly equal the ground-truth ones). 
>>> import numpy as np >>> from doctr.utils import TableCellMetric >>> metric = TableCellMetric(iou_thresh=0.5) - >>> gt = np.array([[[0, 0], [1, 0], [1, 1], [0, 1]]], dtype=np.float32) + >>> gt = np.array([[0, 0, 1, 1]], dtype=np.float32) >>> metric.update(gt, np.array([[0, 0, 0, 0]]), gt, np.array([[0, 0, 0, 0]])) >>> metric.summary() Args: - iou_thresh: minimum polygon IoU for a predicted/ground-truth cell pair to be considered a match + iou_thresh: minimum IoU for a predicted/ground-truth cell pair to be considered a match + use_polygons: if set to True, predictions and targets will be expected to have polygon format """ - def __init__(self, iou_thresh: float = 0.5) -> None: + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + ) -> None: self.iou_thresh = iou_thresh + self.use_polygons = use_polygons self.reset() def update( @@ -349,9 +354,9 @@ def update( """Update the metric with one sample. Args: - gt_cells: ground-truth cell polygons, shape (N, 4, 2) + gt_cells: ground-truth cells, shape (N, 4) or (N, 4, 2) when ``use_polygons=True`` gt_logic: ground-truth logical coordinates, shape (N, 4) - pred_cells: predicted cell polygons, shape (M, 4, 2) + pred_cells: predicted cells, shape (M, 4) or (M, 4, 2) when ``use_polygons=True`` pred_logic: predicted logical coordinates, shape (M, 4) """ self.num_gts += gt_cells.shape[0] @@ -359,7 +364,10 @@ def update( if gt_cells.shape[0] == 0 or pred_cells.shape[0] == 0: return - iou_mat = polygon_iou(gt_cells, pred_cells) # (N, M) + if self.use_polygons: + iou_mat = polygon_iou(gt_cells, pred_cells) + else: + iou_mat = box_iou(gt_cells, pred_cells) gt_idx, pred_idx = linear_sum_assignment(-iou_mat) for gi, pi in zip(gt_idx, pred_idx): if iou_mat[gi, pi] >= self.iou_thresh: diff --git a/references/table/evaluate.py b/references/table/evaluate.py index 1228ceff26..f2aa3add83 100644 --- a/references/table/evaluate.py +++ b/references/table/evaluate.py @@ -67,7 +67,7 @@ def main(args): 
torch.backends.cudnn.benchmark = True - tmp_model = table_structure.__dict__[args.arch](pretrained=False) + tmp_model = table_structure.__dict__[args.arch](pretrained=False, assume_straight_pages=not args.rotation) input_shape = (args.size, args.size) if isinstance(args.size, int) else tmp_model.cfg["input_shape"][-2:] mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] @@ -75,6 +75,7 @@ def main(args): ds = TableStructureDataset( img_folder=os.path.join(args.dataset_path, "images"), label_path=os.path.join(args.dataset_path, "labels.json"), + use_polygons=args.rotation, sample_transforms=T.Resize( input_shape, preserve_aspect_ratio=args.keep_ratio, symmetric_pad=args.symmetric_pad ), @@ -90,7 +91,9 @@ def main(args): ) pbar.write(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in {len(test_loader)} batches)") - model = table_structure.__dict__[args.arch](pretrained=not isinstance(args.resume, str)).eval() + model = table_structure.__dict__[args.arch]( + pretrained=not isinstance(args.resume, str), assume_straight_pages=not args.rotation + ).eval() batch_transforms = Normalize(mean=mean, std=std) if isinstance(args.resume, str): pbar.write(f"Resuming {args.resume}") @@ -109,7 +112,7 @@ def main(args): torch.cuda.set_device(args.device) model = model.cuda() - metric = TableCellMetric(iou_thresh=args.iou_thresh) + metric = TableCellMetric(iou_thresh=args.iou_thresh, use_polygons=args.rotation) pbar.write("Running evaluation") val_loss, recall, precision, f1, struct = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp) pbar.write( @@ -133,6 +136,7 @@ def parse_args(): parser.add_argument("--keep_ratio", action="store_true", help="keep the aspect ratio of the input image") parser.add_argument("--symmetric_pad", action="store_true", help="pad the image symmetrically") parser.add_argument("--iou_thresh", type=float, default=0.5, help="IoU threshold for cell matching") + parser.add_argument("--rotation", action="store_true", help="use 
rotation augmentation") parser.add_argument("-j", "--workers", type=int, default=None, help="number of workers used for dataloading") parser.add_argument("--resume", type=str, default=None, help="Checkpoint to resume") parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") diff --git a/references/table/train.py b/references/table/train.py index cd117e4cda..95241984f5 100644 --- a/references/table/train.py +++ b/references/table/train.py @@ -163,7 +163,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False, log=Non # Cells & logical coords are compared in the (relative) model-input space for target, pred in zip(targets, out["preds"]): val_metric.update( - np.asarray(target["cells"], dtype=np.float32).reshape(-1, 4, 2), + np.asarray(target["cells"], dtype=np.float32), np.asarray(target["logic"], dtype=np.int64).reshape(-1, 4), pred["polygons"], pred["logical"], @@ -219,19 +219,35 @@ def main(args): torch.backends.cudnn.benchmark = True # Temporary model to recover the configuration (mean/std) - tmp_model = table_structure.__dict__[args.arch](pretrained=False) + tmp_model = table_structure.__dict__[args.arch](pretrained=False, assume_straight_pages=not args.rotation) mean, std = tmp_model.cfg["mean"], tmp_model.cfg["std"] # Validation data val_hash = None if rank == 0: st = time.time() + # Validation sample transforms (shared by both data sources) + val_sample_transforms = T.SampleCompose( + ( + [T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True)] + if not args.rotation or args.eval_straight + else [] + ) + + ( + [ + T.Resize(args.input_size, preserve_aspect_ratio=True), # This does not pad + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + if args.rotation and not args.eval_straight + else [] + ) + ) val_set = TableStructureDataset( 
img_folder=os.path.join(args.val_path, "images"), label_path=os.path.join(args.val_path, "labels.json"), - sample_transforms=T.SampleCompose([ - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), - ]), + sample_transforms=val_sample_transforms, + use_polygons=args.rotation and not args.eval_straight, ) val_loader = DataLoader( val_set, @@ -250,13 +266,13 @@ def main(args): batch_transforms = Normalize(mean=mean, std=std) - model = table_structure.__dict__[args.arch](pretrained=args.pretrained) + model = table_structure.__dict__[args.arch](pretrained=args.pretrained, assume_straight_pages=not args.rotation) if isinstance(args.resume, str): pbar.write(f"Resuming {args.resume}") model.from_pretrained(args.resume) if rank == 0: - val_metric = TableCellMetric(iou_thresh=args.iou_thresh) + val_metric = TableCellMetric(iou_thresh=args.iou_thresh, use_polygons=args.rotation and not args.eval_straight) if rank == 0 and args.test_only: pbar.write("Running evaluation") @@ -271,26 +287,46 @@ def main(args): return st = time.time() - # Image-only augmentations + # Augmentations + # Image augmentations img_transforms = T.OneOf([ Compose([ T.RandomApply(T.ColorInversion(), 0.3), T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2), ]), + Compose([ + T.RandomApply(T.RandomShadow(), 0.3), + T.RandomApply(T.GaussianNoise(), 0.1), + T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3), + T.ImageTorchvisionTransform(RandomGrayscale(p=0.15)), + ]), T.ImageTorchvisionTransform(RandomPhotometricDistort(p=0.3)), - T.ImageTorchvisionTransform(RandomGrayscale(p=0.1)), - lambda x: x, # identity - ]) - # Image + geometry augmentations (letterbox to a square; the model renders the dense targets) - sample_transforms = T.SampleCompose([ - T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + lambda x: x, # Identity no transformation ]) + # Image + target augmentations + sample_transforms = T.SampleCompose( + ( + 
[ + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + if not args.rotation + else [ + T.RandomResize(scale_range=(0.4, 0.9), preserve_aspect_ratio=0.5, symmetric_pad=0.5, p=0.25), + # Rotation augmentation + T.Resize(args.input_size, preserve_aspect_ratio=True), + T.RandomApply(T.RandomRotate(90, expand=True), 0.5), + T.Resize((args.input_size, args.input_size), preserve_aspect_ratio=True, symmetric_pad=True), + ] + ) + ) train_set = TableStructureDataset( img_folder=os.path.join(args.train_path, "images"), label_path=os.path.join(args.train_path, "labels.json"), img_transforms=img_transforms, sample_transforms=sample_transforms, + use_polygons=args.rotation and not args.eval_straight, ) sampler = ( DistributedSampler(train_set, rank=rank, shuffle=True, drop_last=True) @@ -381,6 +417,7 @@ def main(args): "train_hash": train_hash, "val_hash": val_hash, "pretrained": args.pretrained, + "rotation": args.rotation, "amp": args.amp, } @@ -498,7 +535,7 @@ def parse_args(): "--val_path", type=str, required=True, help="path to the validation data folder (images/ + labels.json)" ) parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") - parser.add_argument("--epochs", type=int, default=200, help="number of epochs to train the model on") + parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=2, help="batch size for training") parser.add_argument( "--save-interval-epoch", dest="save_interval_epoch", action="store_true", help="Save model every epoch" @@ -521,6 +558,12 @@ def parse_args(): parser.add_argument( "--pretrained", dest="pretrained", action="store_true", help="Load pretrained parameters before training" ) + parser.add_argument("--rotation", dest="rotation", action="store_true", 
help="train with rotated documents") + parser.add_argument( + "--eval-straight", + action="store_true", + help="metrics evaluation with straight boxes instead of polygons to save time + memory", + ) parser.add_argument("--optim", type=str, default="adamw", choices=["adam", "adamw"], help="optimizer to use") parser.add_argument( "--sched", type=str, default="cosine", choices=["cosine", "onecycle", "poly"], help="scheduler to use" diff --git a/tests/common/test_models_table_structure.py b/tests/common/test_models_table_structure.py index 05a6a4966a..68811dca1a 100644 --- a/tests/common/test_models_table_structure.py +++ b/tests/common/test_models_table_structure.py @@ -4,19 +4,27 @@ from doctr.models.table_structure.tablecenternet import TableCenterNetPostProcessor, _TableCenterNet -def _grid_target(rows: int, cols: int) -> dict[str, np.ndarray]: - """A relative-coordinate ``{"cells", "logic"}`` target for a ``rows x cols`` grid (the dataset's output).""" +def _grid_target(rows: int, cols: int, use_polygons: bool) -> dict[str, np.ndarray]: + """A relative-coordinate ``{"cells", "logic"}`` target for a ``rows x cols`` grid.""" xs, ys = np.linspace(0.1, 0.9, cols + 1), np.linspace(0.1, 0.9, rows + 1) cells, logic = [], [] for r in range(rows): for c in range(cols): cells.append([[xs[c], ys[r]], [xs[c + 1], ys[r]], [xs[c + 1], ys[r + 1]], [xs[c], ys[r + 1]]]) logic.append([c, c, r, r]) - return {"cells": np.array(cells, np.float32), "logic": np.array(logic, np.int64)} + cell_array = np.asarray(cells, dtype=np.float32).reshape(-1, 4, 2) + if not use_polygons: + cell_array = np.concatenate([cell_array.min(axis=1), cell_array.max(axis=1)], axis=1) + return {"cells": cell_array, "logic": np.asarray(logic, dtype=np.int64).reshape(-1, 4)} -def test_tablecenternet_postprocessor(): - postprocessor = TableCenterNetPostProcessor(center_thresh=0.0) + +@pytest.mark.parametrize("assume_straight_pages", [True, False]) +def test_tablecenternet_postprocessor(assume_straight_pages): + 
postprocessor = TableCenterNetPostProcessor( + center_thresh=0.0, + assume_straight_pages=assume_straight_pages, + ) kc, kn, feat = 12, 16, 64 decoded = { "center_polygons": (np.random.rand(1, kc, 8) * feat).astype(np.float32), @@ -30,23 +38,33 @@ def test_tablecenternet_postprocessor(): "feat_size": (feat, feat), } res = postprocessor(decoded) - assert len(res) == 1 and res[0]["polygons"].shape[1:] == (4, 2) + assert len(res) == 1 + assert res[0]["polygons"].shape[1:] == ((4,) if assume_straight_pages else (4, 2)) assert res[0]["logical"].shape[1] == 4 if res[0]["polygons"].size: - assert res[0]["polygons"].max() <= 1.0 # relative coordinates - # not_relocate path - assert len(TableCenterNetPostProcessor(center_thresh=0.0, not_relocate=True)(decoded)) == 1 + assert res[0]["polygons"].max() <= 1.0 + assert res[0]["polygons"].min() >= 0.0 + + # not_relocate path follows the same geometry contract + simple_res = TableCenterNetPostProcessor( + center_thresh=0.0, + not_relocate=True, + assume_straight_pages=assume_straight_pages, + )(decoded) + assert len(simple_res) == 1 + assert simple_res[0]["polygons"].shape[1:] == ((4,) if assume_straight_pages else (4, 2)) -def test_tablecenternet_build_target(): +@pytest.mark.parametrize("use_polygons", [False, True]) +def test_tablecenternet_build_target(use_polygons): model = _TableCenterNet() out_h, out_w = 64, 64 # Two images of different sizes + one empty image target = [ - _grid_target(2, 3), - _grid_target(1, 2), + _grid_target(2, 3, use_polygons), + _grid_target(1, 2, use_polygons), { - "cells": np.zeros((0, 4, 2), np.float32), + "cells": np.zeros((0, 4, 2) if use_polygons else (0, 4), np.float32), "logic": np.zeros((0, 4), np.int64), }, ] @@ -95,9 +113,39 @@ def test_tablecenternet_build_target(): # Cells outside the [0, 1] relative range are rejected bad = [ { - "cells": np.array([[[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]], np.float32), + "cells": ( + np.array([[0.0, 0.0, 2.0, 2.0]], np.float32) + if not 
use_polygons + else np.array([[[0.0, 0.0], [2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]], np.float32) + ), "logic": np.array([[0, 0, 0, 0]], np.int64), } ] with pytest.raises(ValueError): model.build_target(bad, (out_h, out_w)) + + +def test_tablecenternet_build_target_box_polygon_equivalence(): + model = _TableCenterNet() + box_dense = model.build_target([_grid_target(2, 3, use_polygons=False)], (64, 64)) + polygon_dense = model.build_target([_grid_target(2, 3, use_polygons=True)], (64, 64)) + + assert box_dense.keys() == polygon_dense.keys() + for key in box_dense: + np.testing.assert_array_equal(box_dense[key], polygon_dense[key]) + + +@pytest.mark.parametrize( + "cells", + [ + np.zeros((1, 5), dtype=np.float32), + np.zeros((1, 3, 2), dtype=np.float32), + np.zeros((4,), dtype=np.float32), + ], +) +def test_tablecenternet_build_target_rejects_invalid_cell_shape(cells): + with pytest.raises(ValueError, match="cells are expected to have shape"): + _TableCenterNet().build_target( + [{"cells": cells, "logic": np.zeros((len(cells), 4), dtype=np.int64)}], + (64, 64), + ) diff --git a/tests/common/test_utils_metrics.py b/tests/common/test_utils_metrics.py index 83c41efda1..57bf48e750 100644 --- a/tests/common/test_utils_metrics.py +++ b/tests/common/test_utils_metrics.py @@ -472,32 +472,45 @@ def _square(x, y): return [[x, y], [x + 1, y], [x + 1, y + 1], [x, y + 1]] -def test_table_cell_metric(): - gt_cells = np.asarray([_square(0, 0), _square(2, 0), _square(0, 2)], dtype=np.float32) +def _table_cells(use_polygons): + polygons = np.asarray([_square(0, 0), _square(2, 0), _square(0, 2)], dtype=np.float32) + if use_polygons: + return polygons + return np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1) + + +@pytest.mark.parametrize("use_polygons", [False, True]) +def test_table_cell_metric(use_polygons): + gt_cells = _table_cells(use_polygons) gt_logic = np.asarray([[0, 0, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1]], dtype=np.int64) # Perfect match -> everything is 1 - 
metric = metrics.TableCellMetric(iou_thresh=0.5) + metric = metrics.TableCellMetric(iou_thresh=0.5, use_polygons=use_polygons) metric.update(gt_cells, gt_logic, gt_cells.copy(), gt_logic.copy()) res = metric.summary() - assert res["recall"] == 1.0 and res["precision"] == 1.0 and res["f1"] == 1.0 and res["structure_acc"] == 1.0 + assert res["recall"] == 1.0 and res["precision"] == 1.0 and res["f1"] == 1.0 + assert res["structure_acc"] == 1.0 # One wrong logical coordinate -> geometry perfect, structure accuracy 2/3 bad_logic = gt_logic.copy() bad_logic[1] = [5, 5, 5, 5] - metric = metrics.TableCellMetric(iou_thresh=0.5) + metric = metrics.TableCellMetric(iou_thresh=0.5, use_polygons=use_polygons) metric.update(gt_cells, gt_logic, gt_cells.copy(), bad_logic) res = metric.summary() - assert res["recall"] == 1.0 and abs(res["structure_acc"] - 2 / 3) < 1e-6 + assert res["recall"] == 1.0 and res["structure_acc"] == pytest.approx(2 / 3) # A missing prediction -> recall 2/3, precision 1 - metric = metrics.TableCellMetric(iou_thresh=0.5) + metric = metrics.TableCellMetric(iou_thresh=0.5, use_polygons=use_polygons) metric.update(gt_cells, gt_logic, gt_cells[:2], gt_logic[:2]) res = metric.summary() - assert abs(res["recall"] - 2 / 3) < 1e-6 and res["precision"] == 1.0 + assert res["recall"] == pytest.approx(2 / 3) and res["precision"] == 1.0 # Empty edge cases - metric = metrics.TableCellMetric() - metric.update(gt_cells, gt_logic, np.zeros((0, 4, 2), np.float32), np.zeros((0, 4), np.int64)) + empty_cells = np.zeros((0, 4, 2) if use_polygons else (0, 4), dtype=np.float32) + metric = metrics.TableCellMetric(use_polygons=use_polygons) + metric.update(gt_cells, gt_logic, empty_cells, np.zeros((0, 4), dtype=np.int64)) res = metric.summary() assert res["recall"] == 0.0 and res["precision"] is None and res["structure_acc"] is None + + metric.reset() + assert metric.num_gts == metric.num_preds == metric.matches == metric.struct_matches == 0 diff --git 
a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index c369e43591..573c7debcc 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -304,13 +304,15 @@ def test_layout_dataset(mock_image_folder, mock_layout_label, use_polygons): json.dump(original_labels, f) -def test_table_dataset(mock_image_folder, mock_table_label): +@pytest.mark.parametrize("use_polygons", [False, True]) +def test_table_dataset(mock_image_folder, mock_table_label, use_polygons): input_size = (1024, 1024) ds = datasets.TableStructureDataset( img_folder=mock_image_folder, label_path=mock_table_label, sample_transforms=SampleCompose([Resize(input_size, preserve_aspect_ratio=True, symmetric_pad=True)]), + use_polygons=use_polygons, ) assert len(ds) == 5 @@ -318,17 +320,25 @@ def test_table_dataset(mock_image_folder, mock_table_label): img, target = sample.image, sample.target assert isinstance(img, torch.Tensor) and img.dtype == torch.float32 assert img.shape[-2:] == input_size - # Target carries relative cell polygons and integer logical coordinates assert isinstance(target, dict) and set(target) == {"cells", "logic"} assert isinstance(target["cells"], np.ndarray) and target["cells"].dtype == np.float32 - assert target["cells"].ndim == 3 and target["cells"].shape[1:] == (4, 2) + if use_polygons: + assert target["cells"].ndim == 3 and target["cells"].shape[1:] == (4, 2) + else: + assert target["cells"].ndim == 2 and target["cells"].shape[1:] == (4,) assert np.all(np.logical_and(target["cells"] >= 0, target["cells"] <= 1)) + assert target["logic"].dtype == np.int64 assert target["logic"].shape == (target["cells"].shape[0], 4) loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) images, targets = next(iter(loader)) assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) assert isinstance(targets, list) and all(set(t) == {"cells", "logic"} for t in targets) + for batch_target in targets: + if use_polygons: + 
assert batch_target["cells"].shape[1:] == (4, 2) + else: + assert batch_target["cells"].shape[1:] == (4,) # File existence check img_name, _ = ds.data[0] diff --git a/tests/pytorch/test_models_table_structure_pt.py b/tests/pytorch/test_models_table_structure_pt.py index c559c99c8c..7f57f15c63 100644 --- a/tests/pytorch/test_models_table_structure_pt.py +++ b/tests/pytorch/test_models_table_structure_pt.py @@ -14,31 +14,42 @@ _HEADS = {"hm": 2, "reg": 2, "ct2cn": 8, "cn2ct": 8, "lc": 2, "sp": 2} -def _grid_target(rows=2, cols=3): - """A relative-coordinate {"cells", "logic"} target for a rows x cols grid (the dataset's output).""" +def _grid_target(rows=2, cols=3, use_polygons=True): + """A relative-coordinate {"cells", "logic"} target for a rows x cols grid.""" xs, ys = np.linspace(0.1, 0.9, cols + 1), np.linspace(0.1, 0.9, rows + 1) cells, logic = [], [] for r in range(rows): for c in range(cols): cells.append([[xs[c], ys[r]], [xs[c + 1], ys[r]], [xs[c + 1], ys[r + 1]], [xs[c], ys[r + 1]]]) logic.append([c, c, r, r]) - return {"cells": np.array(cells, np.float32), "logic": np.array(logic, np.int64)} + + cell_array = np.asarray(cells, dtype=np.float32).reshape(-1, 4, 2) + if not use_polygons: + cell_array = np.concatenate([cell_array.min(axis=1), cell_array.max(axis=1)], axis=1) + return {"cells": cell_array, "logic": np.asarray(logic, dtype=np.int64).reshape(-1, 4)} @pytest.mark.parametrize("train_mode", [True, False]) +@pytest.mark.parametrize("assume_straight_pages", [True, False]) @pytest.mark.parametrize( "arch_name, input_shape", [ ["tablecenternet", (3, 1024, 1024)], ], ) -def test_table_models(arch_name, input_shape, train_mode): +def test_table_models(arch_name, input_shape, train_mode, assume_straight_pages): batch_size = 2 - model = table_structure.__dict__[arch_name](pretrained=True) + model = table_structure.__dict__[arch_name]( + pretrained=True, + assume_straight_pages=assume_straight_pages, + ) model = model.train() if train_mode else model.eval() 
assert isinstance(model, TableCenterNet) input_tensor = torch.rand((batch_size, *input_shape)) - target = [_grid_target(), _grid_target()] + target = [ + _grid_target(use_polygons=not assume_straight_pages), + _grid_target(use_polygons=not assume_straight_pages), + ] if torch.cuda.is_available(): model.cuda() @@ -55,14 +66,15 @@ def test_table_models(arch_name, input_shape, train_mode): # Check Preds if not train_mode: assert len(out["preds"]) == batch_size + expected_shape = (4,) if assume_straight_pages else (4, 2) for pred in out["preds"]: assert set(pred) == {"polygons", "scores", "logical"} # Check logical coordinates have 4 entries per cell (start/end col, start/end row) assert pred["logical"].shape[1] == 4 # Check that the number of cells, scores and logical coordinates are the same assert len(pred["polygons"]) == len(pred["scores"]) == len(pred["logical"]) + assert pred["polygons"].shape[1:] == expected_shape if pred["polygons"].size: - assert pred["polygons"].shape[1:] == (4, 2) # Check that cells are in the range [0, 1] assert np.all(pred["polygons"] >= 0) and np.all(pred["polygons"] <= 1) # Check that scores are between 0 and 1 @@ -72,15 +84,19 @@ def test_table_models(arch_name, input_shape, train_mode): assert hasattr(model, "from_pretrained") +@pytest.mark.parametrize("assume_straight_pages", [True, False]) @pytest.mark.parametrize( "arch_name", [ "tablecenternet", ], ) -def test_table_structure_zoo(arch_name): - # Model - predictor = table_structure.zoo.table_predictor(arch_name, pretrained=False) +def test_table_structure_zoo(arch_name, assume_straight_pages): + predictor = table_structure.zoo.table_predictor( + arch_name, + pretrained=False, + assume_straight_pages=assume_straight_pages, + ) predictor.model = predictor.model.eval() # object check assert isinstance(predictor, TablePredictor) @@ -93,8 +109,9 @@ def test_table_structure_zoo(arch_name): assert isinstance(out, list) and len(out) == 2 assert all(isinstance(page, dict) for page in out) 
assert all({"cells", "num_rows", "num_cols"} <= set(page) for page in out) + expected_shape = (4,) if assume_straight_pages else (4, 2) for page in out: - assert all(np.asarray(cell["geometry"]).shape == (4, 2) for cell in page["cells"]) + assert all(np.asarray(cell["geometry"]).shape == expected_shape for cell in page["cells"]) assert all({"score", "row_start", "row_end", "col_start", "col_end"} <= set(cell) for cell in page["cells"]) assert all(0 <= cell["score"] <= 1 for cell in page["cells"]) From e190ab1252d547009678c82f273141413ef77965 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 14:24:36 +0200 Subject: [PATCH 06/14] Add reference jobs --- .github/workflows/references.yml | 92 +++++++++++++++++++ doctr/datasets/table_structure.py | 4 +- .../table_structure/tablecenternet/base.py | 49 +++------- .../table_structure/tablecenternet/pytorch.py | 7 +- doctr/utils/metrics.py | 10 +- tests/common/test_models_table_structure.py | 2 +- 6 files changed, 118 insertions(+), 46 deletions(-) diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml index 2ef8c4d0e1..74e39ccf52 100644 --- a/.github/workflows/references.yml +++ b/.github/workflows/references.yml @@ -344,3 +344,95 @@ jobs: pip install -e .[viz,html] --upgrade - name: Benchmark latency run: python references/layout/latency.py lw_detr_s --it 5 --size 512 + + train-table-structure-recognition: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v7 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ 
hashFiles('requirements-pt.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_table_set-ea091e15.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_table_set-ea091e15.zip -d table_set + - name: Train for a short epoch + run: python references/table/train.py tablecenternet --train_path ./table_set --val_path ./table_set -b 2 --epochs 1 + + evaluate-table-structure-recognition: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v7 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_table_set-ea091e15.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_table_set-ea091e15.zip -d table_set + - name: Evaluate table structure recognition + run: python references/table/evaluate.py tablecenternet ./table_set + + latency-table-structure-recognition: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: ["3.10"] + steps: + - uses: actions/checkout@v7 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v5 + with: + path: 
~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[viz,html] --upgrade + - name: Benchmark latency + run: python references/table/latency.py tablecenternet --it 5 --size 512 diff --git a/doctr/datasets/table_structure.py b/doctr/datasets/table_structure.py index 0b6bd175e4..aa3bfc38ab 100644 --- a/doctr/datasets/table_structure.py +++ b/doctr/datasets/table_structure.py @@ -33,8 +33,8 @@ class TableStructureDataset(AbstractDataset): } Each sample yields the image and a target containing relative cells and their logical coordinates. Cells have - shape ``(N, 4)`` by default, or ``(N, 4, 2)`` when ``use_polygons=True``. Logical coordinates have shape - ``(N, 4)``. + shape `(N, 4)` by default, or `(N, 4, 2)` when `use_polygons=True`. Logical coordinates have shape + `(N, 4)`. >>> from doctr.datasets import TableStructureDataset >>> from doctr.transforms import Resize diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py index aec8216275..057a15294a 100644 --- a/doctr/models/table_structure/tablecenternet/base.py +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -18,8 +18,8 @@ def _get_logic_coords(lc_logic: np.ndarray, col_span: int, row_span: int) -> tuple[int, int, int, int]: """Resolve a cell's logical coordinates (start/end column and row) from the per-corner logical - predictions (``lc_logic`` is a (4, 2) array of [col, row] for corners TL, TR, BR, BL) and the cell span. - Pure numpy port of the reference ``get_logic_coords``.""" + predictions (`lc_logic` is a (4, 2) array of [col, row] for corners TL, TR, BR, BL) and the cell span. 
+ Pure numpy port of the reference `get_logic_coords`.""" col_span = max(1, col_span) row_span = max(1, row_span) col_lc = [max(1, int(round(float(p)))) for p in lc_logic[:, 0]] @@ -80,14 +80,6 @@ def _lookup_logic(lc_map: np.ndarray, x: float, y: float) -> np.ndarray: def _ensure_simple_quads(polys: np.ndarray) -> np.ndarray: """Guarantee each predicted quad is a simple (non-self-intersecting) polygon. - The center decode (cell built from the ``ct2cn`` offset vectors) and the corner-relocation step can - occasionally yield a self-intersecting "bow-tie" quad - e.g. a mis-predicted cell whose ``TL``/``TR`` - (or any two) corners cross over. Such polygons are invalid for shapely and make - :func:`doctr.utils.metrics.polygon_iou` raise a ``TopologyException`` (side location conflict) during - evaluation. Reordering the four points by their angle around the centroid produces the simple polygon - spanned by the *same four corners* (identical cell region), which is the natural recovery; quads that - are already valid keep their original corner order untouched. - Args: polys: predicted quads, shape (N, 4, 2) @@ -103,16 +95,12 @@ def _ensure_simple_quads(polys: np.ndarray) -> np.ndarray: class TableCenterNetPostProcessor: - """Torch-free post-processor turning the model's *decoded* key-points into table cells. - - All tensor-heavy operations (heat-map NMS, top-k, gather) are performed inside the model's decoder - (which requires torch and is skipped during ONNX export). This object only consumes numpy arrays, so - it never blocks an export and can be tested without torch. + """TableCenterNet post-processor turning the model's *decoded* key-points into table cells. The cell geometry is returned in **relative** coordinates ([0, 1] w.r.t. the model input), so the predictor can undo the pre-processor's padding/resize like the other docTR predictors. 
When - ``assume_straight_pages=True``, geometries are axis-aligned boxes of shape ``(N, 4)``; otherwise they - are quadrilaterals of shape ``(N, 4, 2)``. + `assume_straight_pages=True`, geometries are axis-aligned boxes of shape `(N, 4)`; otherwise they + are quadrilaterals of shape `(N, 4, 2)`. Args: center_thresh: minimum score for a cell center to be kept @@ -214,17 +202,12 @@ def __call__(self, decoded: dict[str, np.ndarray]) -> list[dict[str, np.ndarray] cp, cs, logic = self._simple(decoded, b) if self.not_relocate else self._relocate(decoded, b) keep = cs >= self.center_thresh polys = cp[keep].reshape(-1, 4, 2) / scale # relative coordinates - # Guarantee simple (non-self-intersecting) quads so shapely-based IoU (TableCellMetric) never - # sees an invalid geometry. Applied after the relative rescale; logical coords are unaffected. polys = _ensure_simple_quads(np.clip(polys.astype(np.float32), 0, 1)) cells = ( np.concatenate([polys.min(axis=1), polys.max(axis=1)], axis=1).astype(np.float32) if self.assume_straight_pages else polys ) - # _get_logic_coords reconstructs 1-indexed logical coordinates (column/row lines start at 1, - # mirroring the +1 offset applied when rendering the target). Shift back to the 0-indexed - # convention used by the dataset and TableCellMetric so predictions and GT are comparable. 
results.append({ "polygons": cells, # (N, 4) boxes or (N, 4, 2) quads in relative coordinates "scores": cs[keep].astype(np.float32), @@ -276,8 +259,7 @@ def _polygon_area(points: list[tuple[float, float]]) -> float: def _interpolate_polygons(polygons: list[list[tuple]], img_size: tuple[int, int]) -> tuple[np.ndarray, np.ndarray]: - """Fill each polygon's interior with the linear interpolation of its per-corner value (the ``"sort"`` - variant of the reference ``interpolate_polygons``).""" + """Fill each polygon's interior with the linear interpolation of its per-corner value.""" final_image = np.zeros(img_size, dtype=np.float32) mask = np.zeros(img_size, dtype=bool) areas = [_polygon_area([(x, y) for x, y, _ in poly]) for poly in polygons] @@ -325,11 +307,11 @@ def _build_table_target( max_objects: int = 300, max_corners: int = 1200, ) -> dict[str, np.ndarray]: - """Render the dense TableCenterNet targets (for a single image) consumed by ``TableCenterNet.compute_loss``. + """Render the dense TableCenterNet targets (for a single image) consumed by `TableCenterNet.compute_loss`. Args: cells: (N, 4, 2) cell quadrilaterals (corner order TL, TR, BR, BL) in **output-grid** coordinates - logic: (N, 4) integer logical coordinates ``[start_col, end_col, start_row, end_row]`` (0-indexed) + logic: (N, 4) integer logical coordinates `[start_col, end_col, start_row, end_row]` (0-indexed) output_size: (H, W) of the model output grid (input size // down_ratio) max_objects: maximum number of cells max_corners: maximum number of distinct corners @@ -437,11 +419,11 @@ def _cells_to_polygons(cells: np.ndarray) -> np.ndarray: """Convert table cells to quadrilaterals. Args: - cells: relative axis-aligned boxes of shape ``(N, 4)`` in ``(xmin, ymin, xmax, ymax)`` format, - or quadrilaterals of shape ``(N, 4, 2)``. + cells: relative axis-aligned boxes of shape `(N, 4)` in `(xmin, ymin, xmax, ymax)` format, + or quadrilaterals of shape `(N, 4, 2)`. 
Returns: - Relative quadrilaterals of shape ``(N, 4, 2)`` in TL, TR, BR, BL order. + Relative quadrilaterals of shape `(N, 4, 2)` in TL, TR, BR, BL order. """ if cells.ndim == 3 and cells.shape[1:] == (4, 2): return cells @@ -463,9 +445,8 @@ class _TableCenterNet(BaseModel): """TableCenterNet for table-structure recognition, as described in the official implementation ``_. - This base class holds the framework-agnostic target rendering (``build_target``), mirroring the - organization of the detection (``_LinkNet``) and layout (``_LWDETR``) models: the dense maps consumed - by ``compute_loss`` are produced here, while ``TableCenterNetPostProcessor`` decodes the model output. + This base class holds the framework-agnostic target rendering (`build_target`): the dense maps consumed + by `compute_loss` are produced here, while `TableCenterNetPostProcessor` decodes the model output. """ max_objects: int = 300 @@ -480,8 +461,8 @@ def build_target( """Render the dense training targets for a batch from per-image cell annotations. Args: - target: one ``{"cells": (N, 4) relative boxes or (N, 4, 2) relative polygons, - "logic": (N, 4)}`` dict per image + target: one `{"cells": (N, 4) relative boxes or (N, 4, 2) relative polygons, + "logic": (N, 4)}` dict per image output_shape: (H, W) of the model output grid (input size // down_ratio) Returns: diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index 406e98fe8e..e76bfcc28c 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -203,7 +203,7 @@ def __init__( nn.ReLU(inplace=True), nn.Conv2d(head_conv, out_ch, 1, stride=1, padding=0, bias=True), ) - # Reference head initialisation: detection-style bias for heatmaps, zeroed bias otherwise. 
+ # Reference head initialisation: detection-style bias for heatmaps, zeroed bias otherwise final = fc[2] if isinstance(final, nn.Conv2d) and final.bias is not None: nn.init.constant_(final.bias, -2.19 if "hm" in head else 0.0) @@ -228,7 +228,6 @@ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: def _polygons_decode(self, heatmap: torch.Tensor, vec: torch.Tensor, reg: torch.Tensor, k: int): """Decode key-points (cell centers or corners) into the four points of a quadrilateral.""" batch = heatmap.size(0) - k = min(k, heatmap.size(2) * heatmap.size(3)) # never request more points than there are locations # NMS on heatmaps pad = (3 - 1) // 2 hmax = F.max_pool2d(heatmap, (3, 3), stride=1, padding=pad) @@ -286,7 +285,7 @@ def _decode(self, heads: dict[str, torch.Tensor]) -> dict[str, Any]: feat_h, feat_w = hm.shape[2], hm.shape[3] def _np(t: torch.Tensor) -> np.ndarray: - # Cast to float32 first: numpy has no bfloat16 (relevant under autocast/AMP) + # Cast to float32 first: relevant under autocast/AMP return t.detach().float().cpu().numpy() return { @@ -348,7 +347,7 @@ def compute_loss( Args: output: the raw head maps returned by the model - target: one ``{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}`` dict per image + target: one `{"cells": (N, 4, 2) relative polygons, "logic": (N, 4)}` dict per image Returns: the scalar training loss diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index 4364d0b0a0..342913bfa3 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -319,9 +319,9 @@ class TableCellMetric: r"""Implements a table-structure-recognition metric. Predicted cells are matched to ground-truth cells by maximising the total IoU (Hungarian assignment); a pair - counts as a match when its IoU is at least ``iou_thresh``. From the matches it reports cell-detection recall, + counts as a match when its IoU is at least `iou_thresh`. 
From the matches it reports cell-detection recall, precision, F1 and the **structure accuracy** (the fraction of matched cells whose logical coordinates - ``[start_col, end_col, start_row, end_row]`` exactly equal the ground-truth ones). + `[start_col, end_col, start_row, end_row]` exactly equal the ground-truth ones). >>> import numpy as np >>> from doctr.utils import TableCellMetric @@ -354,9 +354,9 @@ def update( """Update the metric with one sample. Args: - gt_cells: ground-truth cells, shape (N, 4) or (N, 4, 2) when ``use_polygons=True`` + gt_cells: ground-truth cells, shape (N, 4) or (N, 4, 2) when `use_polygons=True` gt_logic: ground-truth logical coordinates, shape (N, 4) - pred_cells: predicted cells, shape (M, 4) or (M, 4, 2) when ``use_polygons=True`` + pred_cells: predicted cells, shape (M, 4) or (M, 4, 2) when `use_polygons=True` pred_logic: predicted logical coordinates, shape (M, 4) """ self.num_gts += gt_cells.shape[0] @@ -379,7 +379,7 @@ def summary(self) -> dict[str, float | None]: """Compute the aggregated metrics. 
Returns: - a dict with ``recall``, ``precision``, ``f1`` (cell detection) and ``structure_acc`` + a dict with `recall`, `precision`, `f1` (cell detection) and `structure_acc` """ recall = self.matches / self.num_gts if self.num_gts > 0 else None precision = self.matches / self.num_preds if self.num_preds > 0 else None diff --git a/tests/common/test_models_table_structure.py b/tests/common/test_models_table_structure.py index 68811dca1a..e835ba38c5 100644 --- a/tests/common/test_models_table_structure.py +++ b/tests/common/test_models_table_structure.py @@ -5,7 +5,7 @@ def _grid_target(rows: int, cols: int, use_polygons: bool) -> dict[str, np.ndarray]: - """A relative-coordinate ``{"cells", "logic"}`` target for a ``rows x cols`` grid.""" + """A relative-coordinate `{"cells", "logic"}` target for a `rows x cols` grid.""" xs, ys = np.linspace(0.1, 0.9, cols + 1), np.linspace(0.1, 0.9, rows + 1) cells, logic = [], [] for r in range(rows): From b34d0b6aca3c072ca229a77da3d85bb53b885261 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 14:32:03 +0200 Subject: [PATCH 07/14] Add reference jobs --- doctr/models/table_structure/tablecenternet/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index e76bfcc28c..a4378b8068 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -168,7 +168,7 @@ def __init__( not_relocate: bool = False, max_objects: int = 300, max_corners: int = 1200, - assume_straight_pages: bool = False, + assume_straight_pages: bool = True, exportable: bool = False, cfg: dict[str, Any] | None = None, ) -> None: From 7cc6b1637da448e1327d0a16a5ed7b85b0d287c4 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 14:42:25 +0200 Subject: [PATCH 08/14] mypy & eval table script --- doctr/models/table_structure/tablecenternet/base.py | 2 +- 
references/table/evaluate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/models/table_structure/tablecenternet/base.py b/doctr/models/table_structure/tablecenternet/base.py index 057a15294a..cbdf03eadc 100644 --- a/doctr/models/table_structure/tablecenternet/base.py +++ b/doctr/models/table_structure/tablecenternet/base.py @@ -229,7 +229,7 @@ def _gaussian_radius(det_size: tuple[float, float], min_overlap: float = 0.7) -> def _gaussian_2d(shape: tuple[int, int], sigma: float = 1.0) -> np.ndarray: m, n = ((s - 1) / 2 for s in shape) - y, x = np.ogrid[-m : m + 1, -n : n + 1] # type: ignore[misc] + y, x = np.ogrid[-m : m + 1, -n : n + 1] h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) h[h < np.finfo(h.dtype).eps * h.max()] = 0 return h diff --git a/references/table/evaluate.py b/references/table/evaluate.py index f2aa3add83..125904ff27 100644 --- a/references/table/evaluate.py +++ b/references/table/evaluate.py @@ -40,7 +40,7 @@ def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): for target, pred in zip(targets, out["preds"]): val_metric.update( - np.asarray(target["cells"], dtype=np.float32).reshape(-1, 4, 2), + np.asarray(target["cells"], dtype=np.float32), np.asarray(target["logic"], dtype=np.int64).reshape(-1, 4), pred["polygons"], pred["logical"], From fe9412747f1bdc53e277cd1f168d194ecd855ce4 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 14:48:55 +0200 Subject: [PATCH 09/14] Add references Readme --- references/table/README.md | 133 ++++++++++++++++++++++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/references/table/README.md b/references/table/README.md index 4a807f4f1b..7a5b6d40d1 100644 --- a/references/table/README.md +++ b/references/table/README.md @@ -1 +1,132 @@ -# TODO: Write the readme like in references/detection | references/layout +# Table structure recognition + +The sample scripts in this folder let you train, evaluate and benchmark table structure 
recognition models with docTR. +A table structure model localizes every cell of a table (its spatial structure) and recovers the rows and columns each cell spans (its logical structure). + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . --upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in PyTorch: + +```shell +python references/table/train.py tablecenternet --train_path path/to/your/train_set --val_path path/to/your/val_set --epochs 5 +``` + +To try the pipeline end-to-end on a small toy dataset: + +```shell +wget https://github.com/mindee/doctr/releases/download/v1.0.1/toy_table_set-ea091e15.zip +unzip toy_table_set-ea091e15.zip -d table_set +python references/table/train.py tablecenternet --train_path ./table_set --val_path ./table_set -b 2 --epochs 1 +``` + +### Multi-GPU support + +We now use the built-in [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) launcher to spawn your DDP workers. `torchrun` will set all the necessary environment variables (`LOCAL_RANK`, `RANK`, etc.) for you. Arguments are the same as for single-GPU training, except: + +- `--backend`: you can specify another `backend` for `DistributedDataParallel` if the default one is not available on +your operating system. Fastest one is `nccl` according to [PyTorch Documentation](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). + +#### Key `torchrun` parameters + +- `--nproc_per_node=N` + Spawn `N` processes on the local machine (typically equal to the number of GPUs you want to use). +- `--nnodes=M` + (Optional) Total number of nodes in your job. Default is 1. +- `--rdzv_backend`, `--rdzv_endpoint`, `--rdzv_id` + (Optional) Rendezvous settings for multi-node jobs. See the [torchrun docs](https://pytorch.org/docs/stable/elastic/run.html) for details. + +#### GPU selection + +By default all visible GPUs will be used.
To limit which GPUs participate, set the `CUDA_VISIBLE_DEVICES` environment variable **before** running `torchrun`. For example, to use only CUDA devices 0 and 2: + +```shell +CUDA_VISIBLE_DEVICES=0,2 \ +torchrun --nproc_per_node=2 references/table/train.py \ + tablecenternet \ + --train_path path/to/train \ + --val_path path/to/val \ + --epochs 5 \ + --backend nccl + ``` + +## Evaluation + +You can evaluate a model (the pretrained one by default, or your own checkpoint with `--resume`) on a dataset: + +```shell +python references/table/evaluate.py tablecenternet path/to/your/dataset +python references/table/evaluate.py tablecenternet path/to/your/dataset --resume path/to/your/checkpoint.pt +``` + +The script reports the cell-detection recall, precision and F1, along with the structure accuracy (the share of cells whose logical coordinates are correctly predicted). Cells are matched to the ground truth above the IoU threshold set with `--iou_thresh` (default `0.5`). + +## Latency benchmark + +You can measure the inference latency of an architecture: + +```shell +python references/table/latency.py tablecenternet --it 100 --size 1024 --gpu +``` + +## Data format + +You need to provide both `train_path` and `val_path` arguments to start training (`evaluate.py` takes a single `dataset_path`). +Each path must lead to a folder with 1 subfolder and 1 file: + +```shell +├── images +│ ├── sample_img_01.png +│ ├── sample_img_02.png +│ ├── sample_img_03.png +│ └── ... +└── labels.json +``` + +`labels.json` is a dictionary mapping each image file name to its annotation. Each annotation has 2 entries: + +- `cells`: the list of cell polygons. Each polygon is a quadrilateral given as 4 `(x, y)` **absolute** coordinates ordered top-left, top-right, bottom-right, bottom-left. +- `logic`: the list of logical coordinates, one per cell, given as `[start_col, end_col, start_row, end_row]`. 
Indices are **0-indexed** and ends are inclusive, so a cell that spans a single row and a single column has equal start and end indices. + +Both lists must have the same length (one entry per cell): `cells` has shape `(N, 4, 2)` and `logic` has shape `(N, 4)`. + +labels.json + +```shell +{ + "sample_img_01.png": { + "cells": [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + "logic": [[start_col, end_col, start_row, end_row], ...] + }, + "sample_img_02.png": { + "cells": [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...], + "logic": [[start_col, end_col, start_row, end_row], ...] + } + ... +} +``` + +## Slack Logging with tqdm + +To enable Slack logging using `tqdm`, you need to set the following environment variables: + +- `TQDM_SLACK_TOKEN`: the Slack Bot Token +- `TQDM_SLACK_CHANNEL`: you can retrieve it using `Right Click on Channel > Copy > Copy link`. You should get something like `https://xxxxxx.slack.com/archives/yyyyyyyy`. Keep only the `yyyyyyyy` part. + +You can follow this page on [how to create a Slack App](https://api.slack.com/quickstart). + +## Advanced options + +Feel free to inspect the many script options to customize your training to your own needs! + +```shell +python references/table/train.py --help +``` From 34f7e9f419b9f49835beb34442fb9d9bea984f72 Mon Sep 17 00:00:00 2001 From: felix Date: Mon, 29 Jun 2026 15:16:17 +0200 Subject: [PATCH 10/14] Add docs --- docs/source/modules/models.rst | 12 ++++++++++++ .../source/using_doctr/custom_models_training.rst | 1 + docs/source/using_doctr/using_models.rst | 15 +++++++++++++++ 3 files changed, 28 insertions(+) diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index 60985aa6c2..661cd0b14a 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -85,6 +85,8 @@ doctr.models.layout .. autofunction:: doctr.models.layout.lw_detr_m +..
autofunction:: doctr.models.layout.layout_predictor + doctr.models.table_structure ---------------------------- @@ -134,3 +136,13 @@ doctr.models.factory .. autofunction:: doctr.models.factory.from_hub .. autofunction:: doctr.models.factory.push_to_hf_hub + + +doctr.models.utils +------------------ + +.. currentmodule:: doctr.models.utils + +.. autofunction:: export_model_to_onnx + +.. autofunction:: add_whitelist diff --git a/docs/source/using_doctr/custom_models_training.rst b/docs/source/using_doctr/custom_models_training.rst index 912b9af06f..b1b95c6110 100644 --- a/docs/source/using_doctr/custom_models_training.rst +++ b/docs/source/using_doctr/custom_models_training.rst @@ -9,6 +9,7 @@ For details on the training process and the necessary data and data format, refe - `detection `_ - `recognition `_ - `layout `_ +- `table structure `_ If you’re looking for a lightweight yet efficient tool to annotate small amounts of data, especially tailored for docTR, check out the `docTR Labeling Tool `_. 
diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 21d7a9b78c..731d00e36a 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -271,6 +271,21 @@ The following architectures are currently supported: * :py:meth:`tablecenternet ` +For a comprehensive comparison, we have compiled a detailed benchmark on a publicly available dataset: + ++--------------------------------------------------+-----------------+---------------+--------------+---------------+------------+-------------------+--------------------+ +| **Architecture** | **Input shape** | **# params** | **Recall** | **Precision** | **F1** | **Structure acc** | **sec/it (B: 1)** | ++==================================================+=================+===============+==============+===============+============+===================+====================+ +| tablecenternet | (1024, 1024, 3) | 7.1 M | | | | | 0.7 | ++--------------------------------------------------+-----------------+---------------+--------------+---------------+------------+-------------------+--------------------+ + +.. note:: + + The reported metrics are produced by ``references/table/evaluate.py`` using the + :py:class:`TableCellMetric `: cell-detection **Recall**, **Precision** and + **F1** (cells matched above an IoU threshold of 0.5), and **Structure acc**, the share of matched cells whose + logical (row/column) coordinates are correctly predicted. 
+ Table structure predictors ^^^^^^^^^^^^^^^^^^^^^^^^^^^ From d41e328e2b2a1e21c8dbf183043a012bc3a990ee Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 10:34:41 +0200 Subject: [PATCH 11/14] Update table model checkpoint --- doctr/models/table_structure/tablecenternet/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/table_structure/tablecenternet/pytorch.py b/doctr/models/table_structure/tablecenternet/pytorch.py index a4378b8068..222c133f9c 100644 --- a/doctr/models/table_structure/tablecenternet/pytorch.py +++ b/doctr/models/table_structure/tablecenternet/pytorch.py @@ -26,7 +26,7 @@ "input_shape": (3, 1024, 1024), "mean": (0.798, 0.785, 0.772), "std": (0.264, 0.2749, 0.287), - "url": "https://github.com/mindee/doctr/releases/download/v1.0.1/tablecenternet-27736590.pt", + "url": "https://github.com/mindee/doctr/releases/download/v1.0.1/tablecenternet-ea5b30a3.pt", }, } From f8b4a2a676be479229354c77066d53773da5f877 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 10:44:14 +0200 Subject: [PATCH 12/14] Update docs --- docs/source/using_doctr/using_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 731d00e36a..f1ec642553 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -276,7 +276,7 @@ For a comprehensive comparison, we have compiled a detailed benchmark on a publi +--------------------------------------------------+-----------------+---------------+--------------+---------------+------------+-------------------+--------------------+ | **Architecture** | **Input shape** | **# params** | **Recall** | **Precision** | **F1** | **Structure acc** | **sec/it (B: 1)** | +==================================================+=================+===============+==============+===============+============+===================+====================+ -| 
tablecenternet | (1024, 1024, 3) | 7.1 M | | | | | 0.7 | +| tablecenternet | (1024, 1024, 3) | 7.1 M | 82.31 | 96.01 | 88.64 | 77.53 | 0.7 | +--------------------------------------------------+-----------------+---------------+--------------+---------------+------------+-------------------+--------------------+ .. note:: From 587da2310ad1a51c03e5b6d5e1a8162e15a144ba Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 11:06:00 +0200 Subject: [PATCH 13/14] Upload model to hf --- docs/source/using_doctr/sharing_models.rst | 12 ++++++++++++ tests/pytorch/test_models_factory.py | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/docs/source/using_doctr/sharing_models.rst b/docs/source/using_doctr/sharing_models.rst index d7206040eb..37f73ef12b 100644 --- a/docs/source/using_doctr/sharing_models.rst +++ b/docs/source/using_doctr/sharing_models.rst @@ -70,6 +70,8 @@ We suggest using the following naming conventions for your models: **Layout:** ``doctr-`` +**Table structure:** ``doctr-`` + Classification -------------- @@ -113,3 +115,13 @@ Layout +=================================+===================================================+========================+ | lw_detr_s (dummy) | Felix92/doctr-dummy-torch-lw-detr-s | PyTorch | +---------------------------------+---------------------------------------------------+------------------------+ + + +Table structure +--------------- + ++---------------------------------+---------------------------------------------------+------------------------+ +| **Architecture** | **Repo_ID** | **Framework** | ++=================================+===================================================+========================+ +| tablecenternet (dummy) | Felix92/doctr-dummy-torch-tablecenternet | PyTorch | ++---------------------------------+---------------------------------------------------+------------------------+ diff --git a/tests/pytorch/test_models_factory.py b/tests/pytorch/test_models_factory.py index db9aed1cdf..65035b68a2 
100644 --- a/tests/pytorch/test_models_factory.py +++ b/tests/pytorch/test_models_factory.py @@ -24,6 +24,7 @@ def test_push_to_hf_hub(): @pytest.mark.parametrize( "arch_name, task_name, dummy_model_id", [ + # Classification ["vgg16_bn_r", "classification", "Felix92/doctr-dummy-torch-vgg16-bn-r"], ["resnet18", "classification", "Felix92/doctr-dummy-torch-resnet18"], ["resnet31", "classification", "Felix92/doctr-dummy-torch-resnet31"], @@ -36,12 +37,14 @@ def test_push_to_hf_hub(): ["vit_s", "classification", "Felix92/doctr-dummy-torch-vit-s"], ["vit_det_s", "classification", "Felix92/doctr-dummy-torch-vit-det-s"], ["textnet_tiny", "classification", "Felix92/doctr-dummy-torch-textnet-tiny"], + # Detection ["db_resnet34", "detection", "Felix92/doctr-dummy-torch-db-resnet34"], ["db_resnet50", "detection", "Felix92/doctr-dummy-torch-db-resnet50"], ["db_mobilenet_v3_large", "detection", "Felix92/doctr-dummy-torch-db-mobilenet-v3-large"], ["linknet_resnet18", "detection", "Felix92/doctr-dummy-torch-linknet-resnet18"], ["linknet_resnet34", "detection", "Felix92/doctr-dummy-torch-linknet-resnet34"], ["linknet_resnet50", "detection", "Felix92/doctr-dummy-torch-linknet-resnet50"], + # Recognition ["crnn_vgg16_bn", "recognition", "Felix92/doctr-dummy-torch-crnn-vgg16-bn"], ["crnn_mobilenet_v3_small", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-small"], ["crnn_mobilenet_v3_large", "recognition", "Felix92/doctr-dummy-torch-crnn-mobilenet-v3-large"], @@ -50,7 +53,10 @@ def test_push_to_hf_hub(): ["vitstr_small", "recognition", "Felix92/doctr-dummy-torch-vitstr-small"], ["parseq", "recognition", "Felix92/doctr-dummy-torch-parseq"], ["viptr_tiny", "recognition", "Felix92/doctr-dummy-torch-viptr-tiny"], + # Layout ["lw_detr_s", "layout", "Felix92/doctr-dummy-torch-lw-detr-s"], + # Table structure + ["tablecenternet", "table_structure", "Felix92/doctr-dummy-torch-tablecenternet"], ], ) def test_models_huggingface_hub(arch_name, task_name, dummy_model_id, tmpdir): 
From 4250fdaa20069a38ef4f801281c2b112fb3a526c Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 1 Jul 2026 11:56:08 +0200 Subject: [PATCH 14/14] Update tests --- tests/pytorch/test_models_table_structure_pt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_models_table_structure_pt.py b/tests/pytorch/test_models_table_structure_pt.py index 7f57f15c63..e742f7519e 100644 --- a/tests/pytorch/test_models_table_structure_pt.py +++ b/tests/pytorch/test_models_table_structure_pt.py @@ -94,7 +94,7 @@ def test_table_models(arch_name, input_shape, train_mode, assume_straight_pages) def test_table_structure_zoo(arch_name, assume_straight_pages): predictor = table_structure.zoo.table_predictor( arch_name, - pretrained=False, + pretrained=True, assume_straight_pages=assume_straight_pages, ) predictor.model = predictor.model.eval() @@ -125,7 +125,7 @@ def test_table_structure_zoo(arch_name, assume_straight_pages): def test_models_onnx_export(arch_name, input_shape): # Model batch_size = 2 - model = table_structure.__dict__[arch_name](pretrained=False, exportable=True).eval() + model = table_structure.__dict__[arch_name](pretrained=True, exportable=True).eval() dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) head_names = list(model.heads.keys()) pt = model(dummy_input) @@ -163,7 +163,7 @@ def test_models_onnx_export(arch_name, input_shape): ) def test_torch_compiled_models(arch_name): page = (255 * np.random.rand(1024, 1024, 3)).astype(np.uint8) - predictor = table_structure.zoo.table_predictor(arch_name, pretrained=False) + predictor = table_structure.zoo.table_predictor(arch_name, pretrained=True) assert isinstance(predictor, TablePredictor) out = predictor([page])