From 1f1da14354c2c1febb1f4644b58a40e6df66b93e Mon Sep 17 00:00:00 2001 From: Paulo9631 Date: Tue, 30 Jun 2026 13:07:42 -0300 Subject: [PATCH] Remove-code-smell-too-many-branches --- doctr/datasets/coco_text.py | 70 +++--- doctr/datasets/cord.py | 57 +++-- doctr/models/_utils.py | 124 +++++----- doctr/models/detection/_utils/base.py | 54 +++-- .../differentiable_binarization/base.py | 80 ++++--- doctr/models/kie_predictor/pytorch.py | 91 +++++--- doctr/models/layout/lw_detr/pytorch.py | 82 ++++--- doctr/transforms/modules/pytorch.py | 220 ++++++++++-------- doctr/utils/metrics.py | 185 +++++++-------- doctr/utils/visualization.py | 170 ++++++++------ 10 files changed, 612 insertions(+), 521 deletions(-) diff --git a/doctr/datasets/coco_text.py b/doctr/datasets/coco_text.py index d1df3f0c5c..8edfa3154c 100644 --- a/doctr/datasets/coco_text.py +++ b/doctr/datasets/coco_text.py @@ -102,40 +102,50 @@ def __init__( for annotation in annotations: x, y, w, h = annotation["bbox"] - if use_polygons: - # (x, y) coordinates of top left, top right, bottom right, bottom left corners - box = np.array( - [ - [x, y], - [x + w, y], - [x + w, y + h], - [x, y + h], - ], - dtype=np_dtype, - ) - else: - # (xmin, ymin, xmax, ymax) coordinates - box = [x, y, x + w, y + h] + box = self._build_box(x, y, w, h, use_polygons, np_dtype) _targets.append((annotation["utf8_string"], box)) text_targets, box_targets = zip(*_targets) - - if recognition_task: - crops = crop_bboxes_from_image( - img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0) - ) - for crop, label in zip(crops, list(text_targets)): - if label and " " not in label: - self.data.append((crop, label)) - - elif detection_task: - self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) - else: - self.data.append(( - img_path, - dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)), - )) + self._process_task_sample(img_path, text_targets, box_targets, recognition_task, detection_task, tmp_root) self.root = tmp_root + @staticmethod + def _build_box(x: float, y: float, w: float, h: float, use_polygons: bool, np_dtype: type) -> list[float] | np.ndarray: + if use_polygons: + return np.array( + [ + [x, y], + [x + w, y], + [x + w, y + h], + [x, y + h], + ], + dtype=np_dtype, + ) + return [x, y, x + w, y + h] + + def _process_task_sample( + self, + img_path: str, + text_targets: tuple[str, ...], + box_targets: tuple[list[float] | np.ndarray, ...], + recognition_task: bool, + detection_task: bool, + tmp_root: str, + ) -> None: + if recognition_task: + crops = crop_bboxes_from_image( + img_path=os.path.join(tmp_root, img_path), geoms=np.asarray(box_targets, dtype=int).clip(min=0) + ) + for crop, label in zip(crops, list(text_targets)): + if label and " " not in label: + self.data.append((crop, label)) + elif detection_task: + self.data.append((img_path, np.asarray(box_targets, dtype=int).clip(min=0))) + else: + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)), + )) + def extra_repr(self) -> str: return f"train={self.train}" diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py index d58376bd3b..147f9258cb 100644 --- a/doctr/datasets/cord.py +++ b/doctr/datasets/cord.py @@ -84,31 +84,7 @@ def __init__( if not os.path.exists(os.path.join(tmp_root, img_path)): raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") - stem = Path(img_path).stem - _targets = [] - with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f: - label = json.load(f) - for line in label["valid_line"]: - for word in line["words"]: - if len(word["text"]) > 0: - x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"] - y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"] - box: list[float] | np.ndarray - if use_polygons: - # (x, y) coordinates of top left, top right, bottom right, bottom left corners - box = np.array( - [ - [x[0], y[0]], - [x[1], y[1]], - [x[2], y[2]], - [x[3], y[3]], - ], - dtype=np_dtype, - ) - else: - # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax - box = [min(x), min(y), max(x), max(y)] - _targets.append((word["text"], box)) + _targets = self._process_image(img_path, tmp_root, use_polygons, np_dtype) text_targets, box_targets = zip(*_targets) @@ -129,5 +105,36 @@ def __init__( self.root = tmp_root + def _process_image( + self, + img_path: str, + tmp_root: str, + use_polygons: bool, + np_dtype: np.dtype, + ) -> list[tuple[str, list[float] | np.ndarray]]: + stem = Path(img_path).stem + _targets: list[tuple[str, list[float] | np.ndarray]] = [] + with open(os.path.join(self.root, "json", f"{stem}.json"), "rb") as f: + label = json.load(f) + for line in label["valid_line"]: + for word in line["words"]: + if len(word["text"]) > 0: + x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"] + y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"] + if use_polygons: + box = np.array( + [ + [x[0], y[0]], + [x[1], y[1]], + [x[2], y[2]], + [x[3], y[3]], + ], + dtype=np_dtype, + ) + else: + box = [min(x), min(y), max(x), max(y)] + _targets.append((word["text"], box)) + return _targets + def extra_repr(self) -> str: return f"train={self.train}" diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py index 905eaf1e1f..5f36b535a8 100644 --- a/doctr/models/_utils.py +++ b/doctr/models/_utils.py @@ -31,6 +31,58 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float: return max(w / h, h / w) +def _compute_contour_angles( + contours: list[np.ndarray], + n_ct: int, + ratio_threshold_for_lines: float, +) -> list[float]: + angles = [] + for contour in contours[:n_ct]: + _, (w, h), angle = cv2.minAreaRect(contour) + if w < h: + w, h = h, w + angle -= 90 + while angle <= -90: + angle += 180 + while angle > 90: + angle -= 180 + if h > 0: + if w / h > ratio_threshold_for_lines: + angles.append(angle) + elif w / h < 1 / ratio_threshold_for_lines: + angles.append(angle - 90) + return angles + + +def _compute_median_skew_angle(angles: list[float]) -> int: + if len(angles) == 0: + return 0 + median = -median_low(angles) + skew_angle = -round(median) if abs(median) != 0 else 0 + if abs(skew_angle) == 90: + skew_angle = 0 + return skew_angle + + +def _resolve_final_angle( + base_angle: int, + skew_angle: int, + is_confident: bool, + page_orientation: int, +) -> int: + final_angle = base_angle + skew_angle + while final_angle > 180: + final_angle -= 360 + while final_angle <= -180: + final_angle += 360 + if is_confident: + if abs(skew_angle) % 90 == 0: + return page_orientation + if abs(skew_angle) == abs(page_orientation) and page_orientation != 0: + return page_orientation + return int(final_angle) + + def estimate_orientation( img: np.ndarray, general_page_orientation: tuple[int, float] | None = None, @@ -56,7 +108,6 @@ def estimate_orientation( """ assert len(img.shape) == 3 and img.shape[-1] in [1, 3], f"Image shape {img.shape} not supported" - # Convert image to grayscale if necessary if img.shape[-1] == 3: gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) gray_img = cv2.medianBlur(gray_img, 5) @@ -69,87 +120,24 @@ def estimate_orientation( base_angle = page_orientation if is_confident else 0 if is_confident: - # We rotate the image to the general orientation which improves the detection - # No expand needed bitmap is already padded thresh = rotate_image(thresh, -base_angle) - else: # That's only required if we do not work on the detection models bin map - # try to merge words in lines + else: (h, w) = img.shape[:2] k_x = max(1, (floor(w / 100))) k_y = max(1, (floor(h / 100))) kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y)) thresh = cv2.dilate(thresh, kernel, iterations=1) - # extract contours contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - - # Filter & Sort contours contours = sorted( [contour for contour in contours if cv2.contourArea(contour) > lower_area], key=get_max_width_length_ratio, reverse=True, ) - angles = [] - for contour in contours[:n_ct]: - _, (w, h), angle = cv2.minAreaRect(contour) - - # OpenCV version-proof normalization: force 'w' to be the long side - # so the angle is consistently relative to the major axis. - # https://github.com/opencv/opencv/pull/28051/changes - if w < h: - w, h = h, w - angle -= 90 - - # Normalize angle to be within [-90, 90] - while angle <= -90: - angle += 180 - while angle > 90: - angle -= 180 - - if h > 0: - if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines - angles.append(angle) - elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree - angles.append(angle - 90) - - if len(angles) == 0: - skew_angle = 0 # in case no angles is found - else: - # median_low picks a value from the data to avoid outliers - median = -median_low(angles) - skew_angle = -round(median) if abs(median) != 0 else 0 - - # Resolve the 90-degree flip ambiguity. - # If the estimation is exactly 90/-90, it's usually a vertical detection of horizontal lines. - if abs(skew_angle) == 90: - skew_angle = 0 - - # combine with the general orientation and the estimated angle - # Apply the detected skew to our base orientation - final_angle = base_angle + skew_angle - - # Standardize result to [-179, 180] range to handle wrap-around cases (e.g., 180 + -31) - while final_angle > 180: - final_angle -= 360 - while final_angle <= -180: - final_angle += 360 - - if is_confident: - # If the estimated angle is perpendicular, treat it as 0 to avoid wrong flips - if abs(skew_angle) % 90 == 0: - return page_orientation - - # special case where the estimated angle is mostly wrong: - # case 1: - and + swapped - # case 2: estimated angle is completely wrong - # so in this case we prefer the general page orientation - if abs(skew_angle) == abs(page_orientation) and page_orientation != 0: - return page_orientation - - return int( - final_angle - ) # return the clockwise angle (negative - left side rotation, positive - right side rotation) + angles = _compute_contour_angles(contours, n_ct, ratio_threshold_for_lines) + skew_angle = _compute_median_skew_angle(angles) + return _resolve_final_angle(base_angle, skew_angle, is_confident, page_orientation) def rectify_crops( diff --git a/doctr/models/detection/_utils/base.py b/doctr/models/detection/_utils/base.py index 142b44e924..40a6a7af20 100644 --- a/doctr/models/detection/_utils/base.py +++ b/doctr/models/detection/_utils/base.py @@ -9,6 +9,35 @@ __all__ = ["_remove_padding"] +def _adjust_coords( + loc_pred: np.ndarray, + ratio: float, + symmetric_pad: bool, + assume_straight_pages: bool, + axis: int, +) -> None: + """Adjust coordinates along a given axis to remove padding + + Args: + loc_pred: localization predictions + ratio: aspect ratio multiplier + symmetric_pad: whether the padding was symmetric + assume_straight_pages: whether the pages are assumed to be straight + axis: 0 for x coordinates, 1 for y coordinates + """ + if assume_straight_pages: + cols = [axis, axis + 2] + if symmetric_pad: + loc_pred[:, cols] = (loc_pred[:, cols] - 0.5) * ratio + 0.5 + else: + loc_pred[:, cols] *= ratio + else: + if symmetric_pad: + loc_pred[:, :, axis] = (loc_pred[:, :, axis] - 0.5) * ratio + 0.5 + else: + loc_pred[:, :, axis] *= ratio + + def _remove_padding( pages: list[np.ndarray], loc_preds: list[dict[str, np.ndarray]], @@ -29,35 +58,14 @@ def _remove_padding( list of unpaded localization predictions """ if preserve_aspect_ratio: - # Rectify loc_preds to remove padding rectified_preds = [] for page, dict_loc_preds in zip(pages, loc_preds): for k, loc_pred in dict_loc_preds.items(): h, w = page.shape[0], page.shape[1] if h > w: - # y unchanged, dilate x coord - if symmetric_pad: - if assume_straight_pages: - loc_pred[:, [0, 2]] = (loc_pred[:, [0, 2]] - 0.5) * h / w + 0.5 - else: - loc_pred[:, :, 0] = (loc_pred[:, :, 0] - 0.5) * h / w + 0.5 - else: - if assume_straight_pages: - loc_pred[:, [0, 2]] *= h / w - else: - loc_pred[:, :, 0] *= h / w + _adjust_coords(loc_pred, h / w, symmetric_pad, assume_straight_pages, axis=0) elif w > h: - # x unchanged, dilate y coord - if symmetric_pad: - if assume_straight_pages: - loc_pred[:, [1, 3]] = (loc_pred[:, [1, 3]] - 0.5) * w / h + 0.5 - else: - loc_pred[:, :, 1] = (loc_pred[:, :, 1] - 0.5) * w / h + 0.5 - else: - if assume_straight_pages: - loc_pred[:, [1, 3]] *= w / h - else: - loc_pred[:, :, 1] *= w / h + _adjust_coords(loc_pred, w / h, symmetric_pad, assume_straight_pages, axis=1) rectified_preds.append({k: np.clip(loc_pred, 0, 1)}) return rectified_preds return loc_preds diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 5f8d1e90e5..e78c074b47 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -85,6 +85,45 @@ def polygon_to_box( else order_points(cv2.boxPoints(cv2.minAreaRect(expanded_points))) ) + def _process_straight_contour( + self, + pred: np.ndarray, + contour: np.ndarray, + width: int, + height: int, + min_size_box: int, + ) -> list[float] | None: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + if score < self.box_thresh: + return None + _box = self.polygon_to_box(points) + if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: + return None + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + return [xmin, ymin, xmax, ymax, score] + + def _process_rotated_contour( + self, + pred: np.ndarray, + contour: np.ndarray, + width: int, + height: int, + min_size_box: int, + ) -> np.ndarray | None: + score = self.box_score(pred, contour, assume_straight_pages=False) + if score < self.box_thresh: + return None + _box = self.polygon_to_box(np.squeeze(contour)) + if np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: + return None + if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): + raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") + _box[:, 0] /= width + _box[:, 1] /= height + return np.vstack([_box, np.array([0.0, score])]) + def bitmap_to_boxes( self, pred: np.ndarray, @@ -105,49 +144,16 @@ def bitmap_to_boxes( height, width = bitmap.shape[:2] min_size_box = 2 boxes: list[np.ndarray | list[float]] = [] - # get contours from connected components on the bitmap contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) for contour in contours: - # Check whether smallest enclosing bounding box is not too small if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): continue - # Compute objectness if self.assume_straight_pages: - x, y, w, h = cv2.boundingRect(contour) - points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) - score = self.box_score(pred, points, assume_straight_pages=True) + box = self._process_straight_contour(pred, contour, width, height, min_size_box) else: - score = self.box_score(pred, contour, assume_straight_pages=False) - - if score < self.box_thresh: # remove polygons with a weak objectness - continue - - if self.assume_straight_pages: - _box = self.polygon_to_box(points) - else: - _box = self.polygon_to_box(np.squeeze(contour)) - - # Remove too small boxes - if self.assume_straight_pages: - if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: - continue - elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: - continue - - if self.assume_straight_pages: - x, y, w, h = _box - # compute relative polygon to get rid of img shape - xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height - boxes.append([xmin, ymin, xmax, ymax, score]) - else: - # compute relative box to get rid of img shape, in that case _box is a 4pt polygon - if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): - raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") - _box[:, 0] /= width - _box[:, 1] /= height - # Add score to box as (0, score) - boxes.append(np.vstack([_box, np.array([0.0, score])])) - + box = self._process_rotated_contour(pred, contour, width, height, min_size_box) + if box is not None: + boxes.append(box) if not self.assume_straight_pages: return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5, 2), dtype=pred.dtype) else: diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 153d97f2d4..90e0e2c371 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -65,6 +65,57 @@ def __init__( self.detect_orientation = detect_orientation self.detect_language = detect_language + def _handle_orientation_and_straighten( + self, + pages: list[np.ndarray], + seg_maps: list[np.ndarray], + origin_page_shapes: list[tuple[int, int]], + **kwargs: Any, + ) -> tuple[list[np.ndarray], list[tuple[int, int]], Any]: + + if self.detect_orientation: + general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) + orientations = [ + {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations + ] + else: + orientations = None + general_pages_orientations = None + origin_pages_orientations = None + + if self.straighten_pages: + pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) + origin_page_shapes = [page.shape[:2] for page in pages] + + return pages, origin_page_shapes, orientations + + def _prepare_crops_and_orientations( + self, + pages: list[np.ndarray], + dict_loc_preds: dict[str, list[np.ndarray]], + ) -> tuple[dict, dict[str, list[np.ndarray]], Any]: + + crops = {} + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( + pages, + dict_loc_preds[class_name], + assume_straight_pages=self.assume_straight_pages, + assume_horizontal=self._page_orientation_disabled, + ) + + crop_orientations: Any = {} + if not self.assume_straight_pages: + for class_name in dict_loc_preds.keys(): + crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( + crops[class_name], dict_loc_preds[class_name] + ) + crop_orientations[class_name] = [ + {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations + ] + + return crops, dict_loc_preds, crop_orientations + @torch.inference_mode() def forward( self, @@ -87,21 +138,12 @@ def forward( ) for out_map in out_maps ] - if self.detect_orientation: - general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps) - orientations = [ - {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations - ] - else: - orientations = None - general_pages_orientations = None - origin_pages_orientations = None - if self.straighten_pages: - pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations) - # update page shapes after straightening - origin_page_shapes = [page.shape[:2] for page in pages] + pages, origin_page_shapes, orientations = self._handle_orientation_and_straighten( + pages, seg_maps, origin_page_shapes, **kwargs + ) - # Forward again to get predictions on straight pages + # Forward again to get predictions on straight pages + if self.straighten_pages: loc_preds = self.det_predictor(pages, **kwargs) dict_loc_preds: dict[str, list[np.ndarray]] = invert_data_structure(loc_preds) # type: ignore[assignment] @@ -117,25 +159,8 @@ def forward( for hook in self.hooks: dict_loc_preds = hook(dict_loc_preds) - # Crop images - crops = {} - for class_name in dict_loc_preds.keys(): - crops[class_name], dict_loc_preds[class_name] = self._prepare_crops( - pages, - dict_loc_preds[class_name], - assume_straight_pages=self.assume_straight_pages, - assume_horizontal=self._page_orientation_disabled, - ) - # Rectify crop orientation - crop_orientations: Any = {} - if not self.assume_straight_pages: - for class_name in dict_loc_preds.keys(): - crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops( - crops[class_name], dict_loc_preds[class_name] - ) - crop_orientations[class_name] = [ - {"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations - ] + # Crop images and rectify orientation + crops, dict_loc_preds, crop_orientations = self._prepare_crops_and_orientations(pages, dict_loc_preds) # Identify character sequences word_preds = { diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 521f6df7cc..1e1cedca28 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -341,8 +341,10 @@ def __init__( assume_straight_pages=self.assume_straight_pages, ) + self._init_weights() + + def _init_weights(self) -> None: for n, m in self.named_modules(): - # Don't override the initialization of the backbone if n.startswith("feat_extractor."): continue @@ -368,18 +370,15 @@ def __init__( if m.bias is not None: nn.init.constant_(m.bias, bias_value) - # Initialize the iterative refinement heads to predict zero deltas (i.e. identity refinement) - # at the start of training, to stabilize training in the early stages when the encoder proposals are still noisy with torch.no_grad(): for head in [self.bbox_embed, *self.enc_out_bbox_embed]: last = head.layers[-1] last.weight.zero_() last.bias.zero_() - last.bias[5] = 1.0 # cosθ of the rotation delta -> identity rotation + last.bias[5] = 1.0 - # The reference point embedding acts as a learned delta composed with the encoder proposals self.reference_point_embed.weight.zero_() - self.reference_point_embed.weight[:, 5] = 1.0 # cosθ + self.reference_point_embed.weight[:, 5] = 1.0 def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None: """Load pretrained parameters onto the model @@ -573,45 +572,56 @@ def _postprocess(logits, boxes): out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy()) if target is not None: - # Build target - processed_targets = self.build_target(target, self.class_names) + out["loss"] = self._compute_losses( + logits, pred_boxes, target, input, group_detr, + intermediate, intermediate_reference_points, + all_group_enc_logits, all_group_enc_coords, + ) - # ProbIoU is computed in pixel coordinates - box_scale = float(max(input.shape[-2], input.shape[-1])) + return out - # Main loss from final decoder layer (group DETR: each group is matched independently) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) + def _compute_losses( + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + target: list[dict[str, np.ndarray]], + input: torch.Tensor, + group_detr: int, + intermediate: torch.Tensor, + intermediate_reference_points: list[torch.Tensor], + all_group_enc_logits: list[torch.Tensor], + all_group_enc_coords: list[torch.Tensor], + ) -> torch.Tensor: + processed_targets = self.build_target(target, self.class_names) + box_scale = float(max(input.shape[-2], input.shape[-1])) - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) - loss = main_loss / group_detr + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) - # Auxiliary losses from intermediate decoder layers - # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i) - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss = main_loss / group_detr - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) - loss += aux_loss / group_detr + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - # Auxiliary losses for the selected encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) - loss += enc_loss / group_detr + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss += aux_loss / group_detr - out["loss"] = loss + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) + loss += enc_loss / group_detr - return out + return loss def compute_loss( self, diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index c6b7881b56..e0c9256277 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -92,6 +92,118 @@ def _resize_target( return np.clip(target, 0, 1) + def _resize_mask(self, mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor: + return F.resize( + mask, + size, + interpolation=F.InterpolationMode.NEAREST, + antialias=False, + ).squeeze(0) + + def _build_return_sample( + self, + sample: Sample, + img: torch.Tensor, + mask: torch.Tensor | None, + target: np.ndarray | dict | str | None, + padding_mask: torch.Tensor | None, + resize_mask: bool, + ) -> Sample: + if target is not None: + if self.return_padding_mask: + return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) + if self.return_padding_mask: + return sample.replace(image=img, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, mask=mask if resize_mask else sample.mask) + + def _resize_targets( + self, + target: np.ndarray | dict | str | None, + raw_shape: tuple[int, int], + final_shape: tuple[int, int], + half_pad: tuple[int, int] | None, + ) -> np.ndarray | dict | str | None: + if target is None: + return target + + if self.symmetric_pad: + offset = ( + half_pad[0] / final_shape[-1], + half_pad[1] / final_shape[-2], + ) + else: + offset = (0, 0) + + if isinstance(target, str) or (isinstance(target, np.ndarray) and target.shape == (1,)): + return target + elif isinstance(target, dict): + return { + cls_name: self._resize_target( + arr, + raw_shape, + final_shape, + symmetric_pad=self.symmetric_pad, + offset=offset, + ) + for cls_name, arr in target.items() + } + else: + return self._resize_target( + target, + raw_shape, + final_shape, + symmetric_pad=self.symmetric_pad, + offset=offset, + ) + + def _resize_preserve_aspect_ratio( + self, + sample: Sample, + img: torch.Tensor, + mask: torch.Tensor | None, + target: np.ndarray | dict | str | None, + resize_mask: bool, + ) -> Sample: + target_ratio = self.size[0] / self.size[1] + actual_ratio = img.shape[-2] / img.shape[-1] + + if actual_ratio > target_ratio: + tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) + else: + tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + + img = F.resize(img, tmp_size, self.interpolation, antialias=True) + + if resize_mask: + mask = self._resize_mask(mask, tmp_size) + + raw_shape = img.shape[-2:] + padding_mask = None + half_pad = (0, 0) + + if isinstance(self.size, (tuple, list)): + _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) + + if self.symmetric_pad: + half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) + _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) + + img = pad(img, _pad) + + if resize_mask and mask is not None: + mask = pad(mask, _pad) + + if self.return_padding_mask: + h, w = self.size + padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) + left, right, top, bottom = _pad + padding_mask[top : h - bottom, left : w - right] = True + + target = self._resize_targets(target, raw_shape, img.shape[-2:], half_pad) + + return self._build_return_sample(sample, img, mask, target, padding_mask, resize_mask) + def forward( self, sample: Sample, @@ -100,118 +212,28 @@ def forward( target = sample.target mask = sample.mask - # Resize mask alongside image if provided - # Masks should use nearest interpolation to preserve label integrity resize_mask = mask is not None - if resize_mask and mask is not None and mask.ndim == 2: + if resize_mask and mask.ndim == 2: mask = mask.unsqueeze(0) target_ratio = self.size[0] / self.size[1] actual_ratio = img.shape[-2] / img.shape[-1] if not self.preserve_aspect_ratio or (target_ratio == actual_ratio): - # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one - # We can use with the regular resize img = super().forward(img) if resize_mask: - mask = F.resize( - mask, - self.size, - interpolation=F.InterpolationMode.NEAREST, - antialias=False, - ).squeeze(0) - - if self.return_padding_mask: - padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device) - - if target is not None: - if self.return_padding_mask: - return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) - if self.return_padding_mask: - return sample.replace(image=img, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, mask=mask if resize_mask else sample.mask) - - else: - # Resize - if actual_ratio > target_ratio: - tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) - else: - tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + mask = self._resize_mask(mask, self.size) - # Scale image - img = F.resize(img, tmp_size, self.interpolation, antialias=True) + padding_mask = ( + torch.zeros(self.size, dtype=torch.bool, device=img.device) + if self.return_padding_mask + else None + ) - if resize_mask: - mask = F.resize( - mask, - tmp_size, - interpolation=F.InterpolationMode.NEAREST, - antialias=False, - ).squeeze(0) - - raw_shape = img.shape[-2:] - - if isinstance(self.size, (tuple, list)): - # Pad (inverted in pytorch) - _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) - - if self.symmetric_pad: - half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) - _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) - # Pad image - img = pad(img, _pad) - - if resize_mask and mask is not None: - mask = pad(mask, _pad) - - if self.return_padding_mask: - h, w = self.size - padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) - left, right, top, bottom = _pad - padding_mask[top : h - bottom, left : w - right] = True - - # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) - if target is not None: - if self.symmetric_pad: - offset = ( - half_pad[0] / img.shape[-1], - half_pad[1] / img.shape[-2], - ) - else: - offset = (0, 0) - - if isinstance(target, str) or (isinstance(target, np.ndarray) and target.shape == (1,)): - # Special case for orientation targets and other non-box targets, which should not be resized - pass - elif isinstance(target, dict): - target = { - cls_name: self._resize_target( - arr, - raw_shape, - img.shape[-2:], - symmetric_pad=self.symmetric_pad, - offset=offset, - ) - for cls_name, arr in target.items() - } - else: - target = self._resize_target( - target, - raw_shape, - img.shape[-2:], - symmetric_pad=self.symmetric_pad, - offset=offset, - ) + return self._build_return_sample(sample, img, mask, target, padding_mask, resize_mask) - if target is not None: - if self.return_padding_mask: - return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) - if self.return_padding_mask: - return sample.replace(image=img, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, mask=mask if resize_mask else sample.mask) + return self._resize_preserve_aspect_ratio(sample, img, mask, target, resize_mask) def __repr__(self) -> str: interpolate_str = self.interpolation.value diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index 5d01cc8bd1..04192beb2a 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -656,113 +656,16 @@ def summary(self) -> dict[str, float | dict[float, float]]: if len(self._gts) == 0: raise AssertionError("No samples added") - # Determine classes - if self.num_classes is None: - labels = [] - for g in self._gts: - labels.extend(g["labels"].tolist()) - for p in self._preds: - labels.extend(p["labels"].tolist()) - classes = np.unique(labels) - else: - classes = np.arange(self.num_classes) - + classes = self._determine_classes() ap_per_iou = {} for iou_thresh in self.iou_thresholds: class_aps = [] - for c in classes: - # Collect GTs per image - gt_by_image = {} - total_gt = 0 - - for img_idx, gt in enumerate(self._gts): - mask = gt["labels"] == c - gt_boxes = gt["boxes"][mask] - - gt_by_image[img_idx] = { - "boxes": gt_boxes, - "matched": np.zeros(len(gt_boxes), dtype=bool), - } - - total_gt += len(gt_boxes) - - if total_gt == 0: - continue - - # Collect all detections globally - detections = [] - - for img_idx, pred in enumerate(self._preds): - mask = pred["labels"] == c - - pred_boxes = pred["boxes"][mask] - pred_scores = pred["scores"][mask] - - for box, score in zip(pred_boxes, pred_scores): - detections.append({ - "image_id": img_idx, - "box": box, - "score": float(score), - }) - - if len(detections) == 0: - class_aps.append(0.0) - continue - - # Global sorting COCO-style - detections.sort(key=lambda x: -x["score"]) - - tp = np.zeros(len(detections)) - fp = np.zeros(len(detections)) - - # Evaluate detections - for det_idx, det in enumerate(detections): - img_idx = det["image_id"] - pred_box = det["box"] - - gt_data = gt_by_image[img_idx] - gt_boxes = gt_data["boxes"] - - if len(gt_boxes) == 0: - fp[det_idx] = 1 - continue - - # Compute IoUs - if self.use_polygons: - iou_mat = polygon_iou( - gt_boxes, - np.expand_dims(pred_box, axis=0), - ) - else: - iou_mat = box_iou( - gt_boxes, - np.expand_dims(pred_box, axis=0), - ) - - ious = iou_mat[:, 0] - - best_gt = np.argmax(ious) - best_iou = ious[best_gt] - - if best_iou >= iou_thresh and not gt_data["matched"][best_gt]: - tp[det_idx] = 1 - gt_data["matched"][best_gt] = True - else: - fp[det_idx] = 1 - - # Precision / Recall - tp_cum = np.cumsum(tp) - fp_cum = np.cumsum(fp) - - recall = tp_cum / total_gt - precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8) - - ap = self._compute_ap(recall, precision) - class_aps.append(ap) - - ap_per_iou[float(iou_thresh)] = float(np.mean(class_aps)) if len(class_aps) > 0 else 0.0 + ap = self._evaluate_class(c, iou_thresh) + if ap is not None: + class_aps.append(ap) + ap_per_iou[float(iou_thresh)] = float(np.mean(class_aps)) if class_aps else 0.0 map_value = float(np.mean(list(ap_per_iou.values()))) ap50 = ap_per_iou.get(0.5, 0.0) @@ -775,6 +678,84 @@ def summary(self) -> dict[str, float | dict[float, float]]: "AP_per_IoU": ap_per_iou, } + def _determine_classes(self) -> np.ndarray: + if self.num_classes is None: + labels = [] + for g in self._gts: + labels.extend(g["labels"].tolist()) + for p in self._preds: + labels.extend(p["labels"].tolist()) + return np.unique(labels) + return np.arange(self.num_classes) + + def _evaluate_class(self, class_label: int, iou_thresh: float) -> float | None: + gt_by_image = {} + total_gt = 0 + + for img_idx, gt in enumerate(self._gts): + mask = gt["labels"] == class_label + gt_boxes = gt["boxes"][mask] + gt_by_image[img_idx] = { + "boxes": gt_boxes, + "matched": np.zeros(len(gt_boxes), dtype=bool), + } + total_gt += len(gt_boxes) + + if total_gt == 0: + return None + + detections = [] + for img_idx, pred in enumerate(self._preds): + mask = pred["labels"] == class_label + pred_boxes = pred["boxes"][mask] + pred_scores = pred["scores"][mask] + for box, score in zip(pred_boxes, pred_scores): + detections.append({ + "image_id": img_idx, + "box": box, + "score": float(score), + }) + + if len(detections) == 0: + return 0.0 + + detections.sort(key=lambda x: -x["score"]) + + tp = np.zeros(len(detections)) + fp = np.zeros(len(detections)) + + for det_idx, det in enumerate(detections): + img_idx = det["image_id"] + pred_box = det["box"] + gt_data = gt_by_image[img_idx] + gt_boxes = gt_data["boxes"] + + if len(gt_boxes) == 0: + fp[det_idx] = 1 + continue + + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, np.expand_dims(pred_box, axis=0)) + else: + iou_mat = box_iou(gt_boxes, np.expand_dims(pred_box, axis=0)) + + ious = iou_mat[:, 0] + best_gt = np.argmax(ious) + best_iou = ious[best_gt] + + if best_iou >= iou_thresh and not gt_data["matched"][best_gt]: + tp[det_idx] = 1 + gt_data["matched"][best_gt] = True + else: + fp[det_idx] = 1 + + tp_cum = np.cumsum(tp) + fp_cum = np.cumsum(fp) + recall = tp_cum / total_gt + precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8) + + return self._compute_ap(recall, precision) + def _compute_ap(self, recall: np.ndarray, precision: np.ndarray) -> float: """Computes the Average Precision using the 101-point interpolation method from COCO diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py index bf634fe1d9..71efd893d6 100644 --- a/doctr/utils/visualization.py +++ b/doctr/utils/visualization.py @@ -152,6 +152,103 @@ def get_colors(num_colors: int) -> list[tuple[float, float, float]]: return colors +def _render_layout( + ax, + page, + interactive, + add_labels, + artists, + **kwargs, +): + region_classes = sorted({region["type"] for region in page["layout"]}) + layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)} + for region in page["layout"]: + rect = create_obj_patch( + region["geometry"], + page["dimensions"], + label=f"{region['type']} (confidence: {region['confidence']:.2%})", + color=layout_colors[region["type"]], + linewidth=2, + fill=False, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels and len(region["geometry"]) == 2: + ax.text( + int(page["dimensions"][1] * region["geometry"][0][0]), + int(page["dimensions"][0] * region["geometry"][0][1]), + region["type"], + size=9, + alpha=0.7, + color=layout_colors[region["type"]], + ) + + +def _render_word( + ax, + word, + page, + interactive, + add_labels, + artists, + **kwargs, +): + rect = create_obj_patch( + word["geometry"], + page["dimensions"], + label=f"{word['value']} (confidence: {word['confidence']:.2%})", + color=(0, 0, 1), + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels: + if len(word["geometry"]) == 5: + text_loc = ( + int(page["dimensions"][1] * (word["geometry"][0] - word["geometry"][2] / 2)), + int(page["dimensions"][0] * (word["geometry"][1] - word["geometry"][3] / 2)), + ) + else: + text_loc = ( + int(page["dimensions"][1] * word["geometry"][0][0]), + int(page["dimensions"][0] * word["geometry"][0][1]), + ) + + if len(word["geometry"]) == 2: + ax.text( + *text_loc, + word["value"], + size=10, + alpha=0.5, + color=(0, 0, 1), + ) + + +def _render_artefacts( + ax, + block, + page, + interactive, + artists, + **kwargs, +): + for artefact in block["artefacts"]: + rect = create_obj_patch( + artefact["geometry"], + page["dimensions"], + label="artefact", + color=(0.5, 0.5, 0.5), + linewidth=1, + **kwargs, + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + def visualize_page( page: dict[str, Any], image: np.ndarray, @@ -198,35 +295,13 @@ def visualize_page( # hide both axis ax.axis("off") + artists: list[patches.Patch] | None = None if interactive: - artists: list[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + artists = [] # instantiate an empty list of patches (to be drawn on the page) # Draw layout regions first so text boxes are overlaid on top of them if display_layout and page.get("layout"): - region_classes = sorted({region["type"] for region in page["layout"]}) - layout_colors = {cls: color for color, cls in zip(get_colors(max(len(region_classes), 1)), region_classes)} - for region in page["layout"]: - rect = create_obj_patch( - region["geometry"], - page["dimensions"], - label=f"{region['type']} (confidence: {region['confidence']:.2%})", - color=layout_colors[region["type"]], - linewidth=2, - fill=False, - **kwargs, - ) - ax.add_patch(rect) - if interactive: - artists.append(rect) - elif add_labels and len(region["geometry"]) == 2: - ax.text( - int(page["dimensions"][1] * region["geometry"][0][0]), - int(page["dimensions"][0] * region["geometry"][0][1]), - region["type"], - size=9, - alpha=0.7, - color=layout_colors[region["type"]], - ) + _render_layout(ax, page, interactive, add_labels, artists, **kwargs) for block in page["blocks"]: if not words_only: @@ -249,51 +324,10 @@ def visualize_page( artists.append(rect) for word in line["words"]: - rect = create_obj_patch( - word["geometry"], - page["dimensions"], - label=f"{word['value']} (confidence: {word['confidence']:.2%})", - color=(0, 0, 1), - **kwargs, - ) - ax.add_patch(rect) - if interactive: - artists.append(rect) - elif add_labels: - if len(word["geometry"]) == 5: - text_loc = ( - int(page["dimensions"][1] * (word["geometry"][0] - word["geometry"][2] / 2)), - int(page["dimensions"][0] * (word["geometry"][1] - word["geometry"][3] / 2)), - ) - else: - text_loc = ( - int(page["dimensions"][1] * word["geometry"][0][0]), - int(page["dimensions"][0] * word["geometry"][0][1]), - ) - - if len(word["geometry"]) == 2: - # We draw only if boxes are in straight format - ax.text( - *text_loc, - word["value"], - size=10, - alpha=0.5, - color=(0, 0, 1), - ) + _render_word(ax, word, page, interactive, add_labels, artists, **kwargs) if display_artefacts: - for artefact in block["artefacts"]: - rect = create_obj_patch( - artefact["geometry"], - page["dimensions"], - label="artefact", - color=(0.5, 0.5, 0.5), - linewidth=1, - **kwargs, - ) - ax.add_patch(rect) - if interactive: - artists.append(rect) + _render_artefacts(ax, block, page, interactive, artists, **kwargs) if interactive: import mplcursors