Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion api/app/routes/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status

from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord
from app.schemas import LayoutElementOut, OCRBlock, OCRIn, OCRLine, OCROut, OCRPage, OCRWord, TableCellOut, TableOut
from app.utils import get_documents, resolve_geometry
from app.vision import init_predictor

Expand Down Expand Up @@ -39,6 +39,27 @@ async def perform_ocr(request: OCRIn = Depends(), files: list[UploadFile] = [Fil
)
for region in page.layout
],
tables=[
TableOut(
num_rows=table.num_rows,
num_cols=table.num_cols,
geometry=resolve_geometry(table.geometry),
confidence=round(table.confidence, 2),
cells=[
TableCellOut(
value=cell.value,
geometry=resolve_geometry(cell.geometry),
confidence=round(cell.confidence, 2),
row_start=cell.row_start,
row_end=cell.row_end,
col_start=cell.col_start,
col_end=cell.col_end,
)
for cell in table.cells
],
)
for table in page.tables
],
items=[
OCRPage(
blocks=[
Expand Down
94 changes: 85 additions & 9 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class OCRIn(KIEIn, BaseModel):
resolve_lines: bool = Field(default=True, examples=[True])
resolve_blocks: bool = Field(default=False, examples=[False])
paragraph_break: float = Field(default=0.0035, examples=[0.0035])
detect_tables: bool = Field(default=False, examples=[False])


class RecognitionIn(BaseModel):
Expand Down Expand Up @@ -139,6 +140,37 @@ class LayoutElementOut(BaseModel):
confidence: float = Field(..., examples=[0.99])


class TableCellOut(BaseModel):
value: str = Field(..., examples=["example"])
geometry: list[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]])
confidence: float = Field(..., examples=[0.99])
row_start: int = Field(..., examples=[0])
row_end: int = Field(..., examples=[0])
col_start: int = Field(..., examples=[0])
col_end: int = Field(..., examples=[0])


class TableOut(BaseModel):
num_rows: int = Field(..., examples=[2])
num_cols: int = Field(..., examples=[2])
geometry: list[float] = Field(..., examples=[[0.0, 0.0, 0.0, 0.0]])
confidence: float = Field(..., examples=[0.99])
cells: list[TableCellOut] = Field(
...,
examples=[
{
"value": "example",
"geometry": [0.0, 0.0, 0.0, 0.0],
"confidence": 0.99,
"row_start": 0,
"row_end": 0,
"col_start": 0,
"col_end": 0,
}
],
)


class OCROut(BaseModel):
name: str = Field(..., examples=["example.jpg"])
orientation: dict[str, float | None] = Field(..., examples=[{"value": 0.0, "confidence": 0.99}])
Expand All @@ -148,27 +180,55 @@ class OCROut(BaseModel):
default=[],
examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]],
)
tables: list[TableOut] = Field(
default=[],
examples=[
[
{
"num_rows": 2,
"num_cols": 2,
"geometry": [0.0, 0.0, 0.0, 0.0],
"confidence": 0.99,
"cells": [
{
"value": "example",
"geometry": [0.0, 0.0, 0.0, 0.0],
"confidence": 0.99,
"row_start": 0,
"row_end": 0,
"col_start": 0,
"col_end": 0,
}
],
}
]
],
)
items: list[OCRPage] = Field(
...,
examples=[
{
"geometry": [0.0, 0.0, 0.0, 0.0],
"objectness_score": 0.99,
"lines": [
"blocks": [
{
"geometry": [0.0, 0.0, 0.0, 0.0],
"objectness_score": 0.99,
"words": [
"lines": [
{
"value": "example",
"geometry": [0.0, 0.0, 0.0, 0.0],
"objectness_score": 0.99,
"confidence": 0.99,
"crop_orientation": {"value": 0, "confidence": None},
"words": [
{
"value": "example",
"geometry": [0.0, 0.0, 0.0, 0.0],
"objectness_score": 0.99,
"confidence": 0.99,
"crop_orientation": {"value": 0, "confidence": None},
}
],
}
],
}
],
]
}
],
)
Expand Down Expand Up @@ -199,4 +259,20 @@ class KIEOut(BaseModel):
default=[],
examples=[[{"type": "Title", "geometry": [0.0, 0.0, 0.0, 0.0], "confidence": 0.99}]],
)
predictions: list[KIEElement]
predictions: list[KIEElement] = Field(
...,
examples=[
{
"class_name": "example",
"items": [
{
"value": "example",
"geometry": [0.0, 0.0, 0.0, 0.0],
"objectness_score": 0.99,
"confidence": 0.99,
"crop_orientation": {"value": 0, "confidence": None},
}
],
}
],
)
2 changes: 2 additions & 0 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def mock_ocr_response():
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"tables": [],
"items": [
{
"blocks": [
Expand Down Expand Up @@ -207,6 +208,7 @@ def mock_ocr_response():
"language": {"value": None, "confidence": None},
"dimensions": [2339, 1654],
"layout": [],
"tables": [],
"items": [
{
"blocks": [
Expand Down
29 changes: 29 additions & 0 deletions api/tests/routes/test_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,35 @@
assert len(region["geometry"]) in (4, 8)


@pytest.mark.asyncio
async def test_ocr_tables(test_app_asyncio, mock_detection_image):
headers = {
"accept": "application/json",
}
params = {"det_arch": "db_resnet50", "reco_arch": "crnn_vgg16_bn", "detect_tables": True}
files = [
("files", ("test.jpg", mock_detection_image, "image/jpeg")),
]
response = await test_app_asyncio.post("/ocr", params=params, files=files, headers=headers)
assert response.status_code == 200
json_response = response.json()

assert isinstance(json_response, list) and len(json_response) == 1
# detect_tables enables the layout model as well
assert "layout" in json_response[0] and isinstance(json_response[0]["layout"], list)
assert "tables" in json_response[0] and isinstance(json_response[0]["tables"], list)

Check warning on line 112 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L112

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for table in json_response[0]["tables"]:
assert isinstance(table["num_rows"], int) and isinstance(table["num_cols"], int)

Check warning on line 114 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L114

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(table["confidence"], (int, float))

Check warning on line 115 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L115

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert len(table["geometry"]) in (4, 8)

Check warning on line 116 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L116

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(table["cells"], list)

Check warning on line 117 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L117

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
for cell in table["cells"]:
assert isinstance(cell["value"], str)

Check warning on line 119 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L119

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert isinstance(cell["confidence"], (int, float))

Check warning on line 120 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L120

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert len(cell["geometry"]) in (4, 8)

Check warning on line 121 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L121

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
assert all(isinstance(cell[k], int) for k in ("row_start", "row_end", "col_start", "col_end"))

Check warning on line 122 in api/tests/routes/test_ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/tests/routes/test_ocr.py#L122

Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.


@pytest.mark.asyncio
async def test_ocr_invalid_file(test_app_asyncio, mock_txt_file):
headers = {
Expand Down
11 changes: 11 additions & 0 deletions demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from backend.pytorch import DET_ARCHS, LAYOUT_ARCHS, RECO_ARCHS, forward_image, load_predictor
Expand Down Expand Up @@ -70,6 +71,8 @@ def main(det_archs, reco_archs, layout_archs):
# Layout detection
detect_layout = st.sidebar.checkbox("Detect layout", value=False)
layout_arch = st.sidebar.selectbox("Layout detection model", layout_archs, disabled=not detect_layout)
# Table detection (relies on the layout model to locate tables)
detect_tables = st.sidebar.checkbox("Detect tables", value=False)
st.sidebar.write("\n")
# Binarization threshold
bin_thresh = st.sidebar.slider("Binarization threshold", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
Expand Down Expand Up @@ -97,6 +100,7 @@ def main(det_archs, reco_archs, layout_archs):
device=forward_device,
detect_layout=detect_layout,
layout_arch=layout_arch,
detect_tables=detect_tables,
)

with st.spinner("Analyzing..."):
Expand All @@ -122,6 +126,13 @@ def main(det_archs, reco_archs, layout_archs):
img = out.pages[0].synthesize()
cols[3].image(img, clamp=True)

# Display extracted tables (if any)
if out.pages[0].tables:
st.markdown("\nExtracted tables:")
for idx, table in enumerate(out.pages[0].tables):
st.markdown(f"**Table {idx + 1}** ({table.num_rows} x {table.num_cols})")
st.dataframe(pd.DataFrame(table.to_grid()))

# Display JSON
st.markdown("\nHere are your analysis results in JSON format:")
st.json(page_export, expanded=False)
Expand Down
3 changes: 3 additions & 0 deletions demo/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def load_predictor(
device: torch.device,
detect_layout: bool,
layout_arch: str,
detect_tables: bool,
) -> OCRPredictor:
"""Load a predictor from doctr.models

Expand All @@ -66,6 +67,7 @@ def load_predictor(
device: torch.device, the device to load the predictor on
detect_layout: whether to run a layout detection model and attach the regions to each page
layout_arch: layout architecture to use when detect_layout is True
detect_tables: whether to detect tables (via the layout model), structure them and attach them to each page

Returns:
instance of OCRPredictor
Expand All @@ -82,6 +84,7 @@ def load_predictor(
disable_crop_orientation=disable_crop_orientation,
detect_layout=detect_layout,
layout_arch=layout_arch,
detect_tables=detect_tables,
).to(device)
predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
Expand Down
1 change: 1 addition & 0 deletions demo/pt-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
-e "python-doctr[viz] @ git+https://github.com/mindee/doctr.git"
streamlit>=1.0.0
pandas>=2.0.0
33 changes: 33 additions & 0 deletions docs/source/using_doctr/custom_models_training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,36 @@ Loading your custom trained orientation classification model
# Overwrite the default orientation models
predictor.crop_orientation_predictor = crop_orientation_predictor(custom_crop_orientation_model)
predictor.page_orientation_predictor = page_orientation_predictor(custom_page_orientation_model)

Custom table structure recognition models
-----------------------------------------

If you work with documents containing tables and make use of the table structure recognition feature by passing the following arguments:

* `detect_tables=True`

You can train your own table structure recognition model using the docTR library. For details on the training process and the necessary data and data format, refer to the following link:

- `table structure recognition <https://github.com/mindee/doctr/blob/main/references/table/README.md#usage>`_

Loading your custom trained table structure recognition model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code:: python3

import torch
from doctr.io import DocumentFile
from doctr.models import ocr_predictor, tablecenternet
from doctr.models.table_structure.zoo import table_predictor

custom_table_structure_model = tablecenternet(pretrained=False)
custom_table_structure_model.from_pretrained('<path_to_pt>')

predictor = ocr_predictor(
pretrained=True,
detect_layout=True,
detect_tables=True,
)

# Overwrite the default table structure model
predictor.table_predictor = table_predictor(custom_table_structure_model)
32 changes: 32 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,38 @@ and columns.
# out[0] -> {"cells": [{"geometry": ..., "score": ..., "row_start": 0, "row_end": 0,
# "col_start": 0, "col_end": 0}, ...], "num_rows": ..., "num_cols": ...}

Tables in the OCR pipeline
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Passing ``detect_tables=True`` to :py:meth:`ocr_predictor <doctr.models.ocr_predictor>` runs table structure
recognition. It relies on the layout model: each region the layout model labels as
a table is cropped and passed to the table model, so a page yields one structured table per detected table region
(``detect_tables=True`` therefore also enables the layout model, whose regions are attached to the page). The words
whose center falls inside a detected cell are regrouped into a structured table, attached to the page as
``page.tables``, and **removed from the regular** ``blocks`` **output** so the same text is not returned twice. Each
table exposes a dense row/column grid that loads directly into pandas.

.. code:: python3

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

model = ocr_predictor(pretrained=True, detect_tables=True)
doc = DocumentFile.from_images("invoice_with_table.png")
result = model(doc)

page = result.pages[0]
# Structured tables (one or more per page), kept out of the regular text blocks
for i, table in enumerate(page.tables):
df = table.to_grid() # for a plain list of lists
print(f"Table {i} ({table.num_rows}x{table.num_cols}):")
print(df)

# The remaining (non-table) text is still available as usual
print(page.render())

Tables are included in :meth:`Page.export` under the ``tables`` key, so they are preserved in the JSON export as well.


End-to-End OCR
--------------
Expand Down
2 changes: 1 addition & 1 deletion doctr/datasets/table_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TableStructureDataset(AbstractDataset):
img_folder: folder with all the dataset images
label_path: path to the JSON labels
use_polygons: whether to keep cell polygons instead of converting them to straight boxes
**kwargs: keyword arguments from `AbstractDataset` (e.g. ``img_transforms``, ``sample_transforms``)
**kwargs: keyword arguments from `AbstractDataset` (e.g. `img_transforms`, `sample_transforms`)
"""

def __init__(
Expand Down
Loading
Loading