From 4ca0f18127722190d29ec843e5711e2c189a4683 Mon Sep 17 00:00:00 2001 From: goldokpa Date: Mon, 15 Jun 2026 22:00:22 +0100 Subject: [PATCH 1/4] feat(flooding): SAR flood detection model, ensemble, OSM impact, and tests Co-Authored-By: Claude Opus 4.8 --- config/train_flood.yaml | 59 ++++ notebooks/02_flood_detection.ipynb | 314 ++++++++++++++++++ .../analysis/flooding_ensemble.py | 224 +++++++++++++ src/climatevision/data/sar_preprocessing.py | 172 ++++++++++ src/climatevision/impact/__init__.py | 13 + src/climatevision/impact/osm_roads.py | 209 ++++++++++++ src/climatevision/models/flood_unet.py | 227 +++++++++++++ tests/test_api_flood.py | 49 +++ tests/test_flooding.py | 72 ++++ tests/test_sar_preprocessing.py | 54 +++ 10 files changed, 1393 insertions(+) create mode 100644 config/train_flood.yaml create mode 100644 notebooks/02_flood_detection.ipynb create mode 100644 src/climatevision/analysis/flooding_ensemble.py create mode 100644 src/climatevision/data/sar_preprocessing.py create mode 100644 src/climatevision/impact/__init__.py create mode 100644 src/climatevision/impact/osm_roads.py create mode 100644 src/climatevision/models/flood_unet.py create mode 100644 tests/test_api_flood.py create mode 100644 tests/test_flooding.py create mode 100644 tests/test_sar_preprocessing.py diff --git a/config/train_flood.yaml b/config/train_flood.yaml new file mode 100644 index 0000000..221c1cc --- /dev/null +++ b/config/train_flood.yaml @@ -0,0 +1,59 @@ +# ============================================================ +# ClimateVision — Flood Detection Training Config +# ============================================================ + +# --- Data -------------------------------------------------- +data: + dir: data/processed/flood + image_size: 256 + batch_size: 8 + num_workers: 4 + use_weighted_sampler: true + pin_memory: true + +# --- Model ------------------------------------------------- +model: + architecture: flood_unet_s2only + in_channels: 3 + num_classes: 3 + encoder: efficientnet-b7 + +# --- Loss -------------------------------------------------- +loss: + type: combined + focal_weight: 0.5 + focal_alpha: 0.25 + focal_gamma: 2.0 + use_class_weights: true + +# --- Optimiser -------------------------------------------- +optimizer: + learning_rate: 1.0e-4 + weight_decay: 1.0e-4 + min_lr: 1.0e-6 + +# --- Schedule --------------------------------------------- +schedule: + epochs: 20 + warmup_epochs: 3 + checkpoint_interval: 5 + +# --- Regularisation / Tricks ------------------------------ +training: + mixed_precision: true + grad_clip: 1.0 + use_ema: true + ema_decay: 0.99 + early_stopping_patience: 10 + +# --- Outputs ---------------------------------------------- +output: + save_dir: models + run_name: "" + +# --- Normalisation stats ---------------------------------- +normalizer_stats: "" + +# --- Analysis type ---------------------------------------- +analysis: + type: flooding diff --git a/notebooks/02_flood_detection.ipynb b/notebooks/02_flood_detection.ipynb new file mode 100644 index 0000000..8ba6f1a --- /dev/null +++ b/notebooks/02_flood_detection.ipynb @@ -0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ClimateVision Flood Detection Validation Notebook\n", + "\n", + "This notebook demonstrates end-to-end flood detection using ClimateVision's production pipeline.\n", + "\n", + "**Requirements:**\n", + "- Trained flood model: `models/unet_flood.pth` (or `models/unet_flood_sar.pth` for SAR)\n", + "- GEE credentials (for real satellite data) OR sample GeoTIFF files\n", + "\n", + "**What it covers:**\n", + "1. Load trained model\n", + "2. Run inference on sample data\n", + "3. Visualize predictions (RGB, MNDWI, predicted mask)\n", + "4. Change detection (pre vs post event)\n", + "5. OSM road impact assessment\n", + "6. Compare against GFM ensemble baseline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '../src')\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import rasterio\n", + "from pathlib import Path\n", + "\n", + "from climatevision.inference.pipeline import run_inference_from_file, run_bitemporal_inference\n", + "from climatevision.models.flood_unet import build_flood_model\n", + "from climatevision.analysis.flooding_ensemble import EnsembleFloodPipeline\n", + "from climatevision.impact.osm_roads import assess_flood_impact\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Load Trained Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Path to trained weights\n", + "MODEL_PATH = '../models/unet_flood.pth'\n", + "\n", + "if Path(MODEL_PATH).exists():\n", + " model = build_flood_model(use_sar=False, weights_path=MODEL_PATH)\n", + " print(f\"Loaded flood model from {MODEL_PATH}\")\n", + "else:\n", + " print(f\"WARNING: Model not found at {MODEL_PATH}. Using untrained weights.\")\n", + " model = build_flood_model(use_sar=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Load Sample Data\n", + "\n", + "Use either:\n", + "- Real GeoTIFF from `data/processed/flood/test/images/`\n", + "- GEE download for a specific region and date" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Option A: Load from local test data\n", + "sample_dir = Path('../data/processed/flood/test/images')\n", + "sample_files = sorted(sample_dir.glob('*.tif'))\n", + "\n", + "if sample_files:\n", + " sample_path = str(sample_files[0])\n", + " with rasterio.open(sample_path) as src:\n", + " image = src.read().astype(np.float32)\n", + " print(f\"Loaded sample: {sample_path}, shape={image.shape}\")\n", + "else:\n", + " print(\"No local samples found. Generate test data first:\")\n", + " print(\" python scripts/prepare_data.py --mode synthetic --analysis-type flooding --n-patches 50\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Run Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = run_inference_from_file(\n", + " sample_path,\n", + " analysis_type='flooding'\n", + ")\n", + "\n", + "print(\"Inference Result:\")\n", + "print(f\" Mean confidence: {result['inference']['mean_confidence']:.3f}\")\n", + "print(f\" Flooded: {result['inference'].get('flooded_percentage', 0):.2f}%\")\n", + "print(f\" Water: {result['inference'].get('water_percentage', 0):.2f}%\")\n", + "print(f\" Dry: {result['inference'].get('dry_percentage', 0):.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Visualize Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n", + "\n", + "# RGB composite (B03=Green as pseudo-R, B08=NIR as pseudo-G, B11=SWIR as pseudo-B)\n", + "rgb = np.stack([image[0], image[1], image[2]], axis=-1)\n", + "rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)\n", + "axes[0].imshow(rgb)\n", + "axes[0].set_title('Input (B03/B08/B11)')\n", + "axes[0].axis('off')\n", + "\n", + "# MNDWI\n", + "green = image[0].astype(np.float64)\n", + "swir = image[2].astype(np.float64)\n", + "mndwi = (green - swir) / (green + swir + 1e-8)\n", + "axes[1].imshow(mndwi, cmap='RdYlBu', vmin=-1, vmax=1)\n", + "axes[1].set_title('MNDWI')\n", + "axes[1].axis('off')\n", + "\n", + "# Predicted mask (we need to re-run to get the mask array)\n", + "import torch\n", + "from climatevision.inference.pipeline import _load_model\n", + "model_loaded, device = _load_model('flooding')\n", + "tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0).to(device)\n", + "with torch.no_grad():\n", + " pred = model_loaded(tensor).argmax(dim=1).squeeze().cpu().numpy()\n", + "\n", + "axes[2].imshow(pred, cmap='tab10', vmin=0, vmax=2)\n", + "axes[2].set_title('Predicted Mask')\n", + "axes[2].axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Change Detection (Bitemporal)\n", + "\n", + "Simulate a pre-event and post-event pair to detect newly flooded areas." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use two different samples as pre/post (or same sample with modification)\n", + "if len(sample_files) >= 2:\n", + " with rasterio.open(sample_files[0]) as src:\n", + " pre_image = src.read().astype(np.float32)\n", + " with rasterio.open(sample_files[1]) as src:\n", + " post_image = src.read().astype(np.float32)\n", + " \n", + " change_result = run_bitemporal_inference(\n", + " pre_image, post_image,\n", + " analysis_type='flooding'\n", + " )\n", + " \n", + " cd = change_result['change_detection']\n", + " print(f\"Newly flooded: {cd['newly_flooded_percentage']:.2f}% ({cd['newly_flooded_pixels']} pixels)\")\n", + " print(f\"Receded: {cd['receded_percentage']:.2f}% ({cd['receded_pixels']} pixels)\")\n", + "else:\n", + " print(\"Need at least 2 samples for change detection.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. GFM Ensemble Baseline\n", + "\n", + "Compare the deep learning result against the physics-based ensemble fallback." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate SAR VH backscatter from the optical data (simplified)\n", + "vh = -20.0 + 5.0 * (image[0] / image[0].max()) # rough approximation\n", + "\n", + "ensemble = EnsembleFloodPipeline()\n", + "ensemble_result = ensemble.detect(post_vh=vh)\n", + "\n", + "fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n", + "axes[0].imshow(ensemble_result['list_mask'], cmap='gray')\n", + "axes[0].set_title('LIST (Change Det)')\n", + "axes[1].imshow(ensemble_result['dlr_mask'], cmap='gray')\n", + "axes[1].set_title('DLR (Otsu)')\n", + "axes[2].imshow(ensemble_result['tuw_mask'], cmap='gray')\n", + "axes[2].set_title('TUW (Bayesian)')\n", + "axes[3].imshow(ensemble_result['ensemble_mask'], cmap='gray')\n", + "axes[3].set_title('Ensemble (Majority Vote)')\n", + "for ax in axes:\n", + " ax.axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. OSM Road Impact Assessment\n", + "\n", + "Requires `osmnx` to be installed. Falls back gracefully if unavailable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use Nairobi bbox as example\n", + "nairobi_bbox = [36.7, -1.4, 37.0, -1.1]\n", + "\n", + "try:\n", + " impact = assess_flood_impact(\n", + " flood_mask=pred,\n", + " bbox=nairobi_bbox,\n", + " pixel_size_m=100 # GEE download scale\n", + " )\n", + " print(f\"Affected road km: {impact['affected_road_km']:.2f}\")\n", + "except Exception as exc:\n", + " print(f\"OSM impact assessment skipped: {exc}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook validated:\n", + "- [x] Model loading and inference\n", + "- [x] MNDWI computation and visualization\n", + "- [x] 3-class segmentation mask prediction\n", + "- [x] Bitemporal change detection\n", + "- [x] GFM-style ensemble baseline comparison\n", + "- [x] OSM road impact assessment\n", + "\n", + "**Next steps for production:**\n", + "1. Train on real flood datasets (Sen1Floods11, WorldFloods)\n", + "2. Fine-tune on Kenya/Nairobi-specific events\n", + "3. Deploy API with trained weights\n", + "4. Set up automated GEE monitoring pipeline" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/climatevision/analysis/flooding_ensemble.py b/src/climatevision/analysis/flooding_ensemble.py new file mode 100644 index 0000000..1342507 --- /dev/null +++ b/src/climatevision/analysis/flooding_ensemble.py @@ -0,0 +1,224 @@ +""" +GFM-style ensemble flood detection using three independent SAR algorithms. + +No deep learning required — operates purely on Sentinel-1 backscatter +using well-established physics-based and statistical methods. + +Algorithms: + 1. LIST-style: change detection (pre/post differencing) + histogram thresholding + 2. DLR-style: tile-based Otsu thresholding on VH + fuzzy slope filtering + 3. TUW-style: per-pixel Bayesian classification using backscatter distributions + +Ensemble: majority vote (≥2 of 3 must agree to classify as flooded). +""" +from __future__ import annotations + +import logging +from typing import Optional + +import numpy as np + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# LIST-style change detection +# --------------------------------------------------------------------------- + +class LISTFloodDetector: + """ + Change-detection based flood mapping. + + Detects flooding by comparing a pre-event reference image to a post-event + image. Uses histogram thresholding on the backscatter difference. + """ + + def __init__(self, diff_threshold_db: float = -3.0): + self.diff_threshold_db = diff_threshold_db + + def detect( + self, + pre_vh: np.ndarray, + post_vh: np.ndarray, + ) -> np.ndarray: + """ + Args: + pre_vh: Pre-event VH backscatter in dB, shape (H, W). + post_vh: Post-event VH backscatter in dB, shape (H, W). + + Returns: + Binary flood mask (H, W), 1=flooded. + """ + diff = post_vh - pre_vh + # Flood typically lowers VH backscatter by several dB + flooded = diff < self.diff_threshold_db + return flooded.astype(np.uint8) + + +# --------------------------------------------------------------------------- +# DLR-style fuzzy thresholding +# --------------------------------------------------------------------------- + +class DLRFloodDetector: + """ + Hierarchical tile-based thresholding with fuzzy logic refinement. + + Uses Otsu thresholding on post-event VH backscatter, then refines using + terrain slope and water body size constraints. + """ + + def __init__( + self, + min_water_size: int = 10, + slope_mask: Optional[np.ndarray] = None, + ): + self.min_water_size = min_water_size + self.slope_mask = slope_mask + + def detect(self, post_vh: np.ndarray) -> np.ndarray: + """ + Args: + post_vh: Post-event VH backscatter in dB, shape (H, W). + + Returns: + Binary flood mask (H, W), 1=flooded. + """ + try: + from skimage.filters import threshold_otsu + from skimage.morphology import remove_small_objects + except ImportError: + raise ImportError("scikit-image is required for DLR detector. Install: pip install scikit-image") + + # Water has very low VH backscatter + thresh = threshold_otsu(post_vh) + water = post_vh < thresh + + # Remove small noise pixels + water = remove_small_objects(water, min_size=self.min_water_size) + + # Apply slope mask if provided (mask out steep terrain) + if self.slope_mask is not None: + water = water & (~self.slope_mask) + + return water.astype(np.uint8) + + +# --------------------------------------------------------------------------- +# TUW-style Bayesian classification +# --------------------------------------------------------------------------- + +class TUWFloodDetector: + """ + Bayesian flood classification using backscatter distribution modeling. + + Models water and land as Gaussian distributions in VH backscatter space, + then classifies each pixel by posterior probability. + """ + + def __init__( + self, + water_mean_db: float = -24.0, + water_std_db: float = 3.0, + land_mean_db: float = -18.0, + land_std_db: float = 4.0, + prior_water: float = 0.3, + ): + self.water_mean = water_mean_db + self.water_std = water_std_db + self.land_mean = land_mean_db + self.land_std = land_std_db + self.prior_water = prior_water + self.prior_land = 1.0 - prior_water + + def detect(self, post_vh: np.ndarray) -> np.ndarray: + """ + Args: + post_vh: Post-event VH backscatter in dB, shape (H, W). + + Returns: + Binary flood mask (H, W), 1=flooded. + """ + # Gaussian likelihoods + def _gaussian_pdf(x, mu, sigma): + return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi)) + + p_vh_water = _gaussian_pdf(post_vh, self.water_mean, self.water_std) + p_vh_land = _gaussian_pdf(post_vh, self.land_mean, self.land_std) + + # Posterior probability of water + posterior_water = (p_vh_water * self.prior_water) / ( + p_vh_water * self.prior_water + p_vh_land * self.prior_land + 1e-10 + ) + + flooded = posterior_water > 0.5 + return flooded.astype(np.uint8) + + +# --------------------------------------------------------------------------- +# Ensemble pipeline +# --------------------------------------------------------------------------- + +class EnsembleFloodPipeline: + """ + GFM-style ensemble combining LIST, DLR, and TUW detectors. + + A pixel is classified as flooded only if at least 2 of 3 algorithms agree. + """ + + def __init__( + self, + list_detector: Optional[LISTFloodDetector] = None, + dlr_detector: Optional[DLRFloodDetector] = None, + tuw_detector: Optional[TUWFloodDetector] = None, + ): + self.list_det = list_detector or LISTFloodDetector() + self.dlr_det = dlr_detector or DLRFloodDetector() + self.tuw_det = tuw_detector or TUWFloodDetector() + + def detect( + self, + post_vh: np.ndarray, + pre_vh: Optional[np.ndarray] = None, + ) -> dict[str, np.ndarray]: + """ + Run all three detectors and return ensemble result. + + Args: + post_vh: Post-event VH backscatter in dB, shape (H, W). + pre_vh: Optional pre-event VH for change detection. + + Returns: + Dict with keys: + - list_mask: LIST detector result + - dlr_mask: DLR detector result + - tuw_mask: TUW detector result + - ensemble_mask: Majority vote result + - agreement: Number of algorithms agreeing per pixel (0-3) + """ + list_mask = ( + self.list_det.detect(pre_vh, post_vh) + if pre_vh is not None + else np.zeros_like(post_vh, dtype=np.uint8) + ) + dlr_mask = self.dlr_det.detect(post_vh) + tuw_mask = self.tuw_det.detect(post_vh) + + # Stack and sum votes + votes = list_mask.astype(np.uint8) + dlr_mask.astype(np.uint8) + tuw_mask.astype(np.uint8) + ensemble_mask = (votes >= 2).astype(np.uint8) + + logger.info( + "Ensemble vote counts: 0=%d, 1=%d, 2=%d, 3=%d", + int((votes == 0).sum()), + int((votes == 1).sum()), + int((votes == 2).sum()), + int((votes == 3).sum()), + ) + + return { + "list_mask": list_mask, + "dlr_mask": dlr_mask, + "tuw_mask": tuw_mask, + "ensemble_mask": ensemble_mask, + "agreement": votes, + } diff --git a/src/climatevision/data/sar_preprocessing.py b/src/climatevision/data/sar_preprocessing.py new file mode 100644 index 0000000..ecb6da1 --- /dev/null +++ b/src/climatevision/data/sar_preprocessing.py @@ -0,0 +1,172 @@ +""" +Sentinel-1 SAR preprocessing for flood detection. + +Handles speckle filtering, terrain flattening, and backscatter conversion +for C-band VV/VH imagery from COPERNICUS/S1_GRD. +""" +from __future__ import annotations + +import logging +from typing import Optional + +import numpy as np + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Speckle filtering +# --------------------------------------------------------------------------- + +class RefinedLeeSpeckleFilter: + """ + Refined Lee adaptive speckle filter for SAR imagery. + + Uses local statistics (mean, variance) within a sliding window to + adaptively smooth homogeneous regions while preserving edges. + + Reference: + Lee, J.-S. (1981). Speckle analysis and smoothing of synthetic + aperture radar images. Computer Graphics and Image Processing. + """ + + def __init__(self, window_size: int = 7, num_looks: float = 1.0): + assert window_size % 2 == 1, "window_size must be odd" + self.window_size = window_size + self.half = window_size // 2 + self.num_looks = num_looks + self.cu = 1.0 / np.sqrt(num_looks) # theoretical speckle std/mean + + def __call__(self, image: np.ndarray) -> np.ndarray: + """ + Apply filter to a (H, W) or (C, H, W) array. + + Args: + image: Linear intensity or amplitude image (NOT dB). + + Returns: + Filtered image with same shape. + """ + if image.ndim == 2: + return self._filter_band(image) + elif image.ndim == 3: + return np.stack([self._filter_band(image[i]) for i in range(image.shape[0])], axis=0) + else: + raise ValueError(f"image must be 2-D or 3-D, got shape {image.shape}") + + def _filter_band(self, band: np.ndarray) -> np.ndarray: + from scipy.ndimage import uniform_filter + + band = band.astype(np.float64) + h, w = band.shape + + # Local mean and mean-of-squares + mean = uniform_filter(band, size=self.window_size, mode="reflect") + mean_sq = uniform_filter(band ** 2, size=self.window_size, mode="reflect") + var = mean_sq - mean ** 2 + var = np.clip(var, 0, None) + + std = np.sqrt(var) + cv = std / (mean + 1e-8) # coefficient of variation + + # Refined Lee weights + # Three cases: homogeneous, heterogeneous, point target + cu2 = self.cu ** 2 + cmax2 = 2.0 * cu2 # upper threshold for heterogeneous region + + weight = np.zeros_like(band) + homogeneous = cv <= self.cu + heterogeneous = (cv > self.cu) & (cv < np.sqrt(cmax2)) + point_target = cv >= np.sqrt(cmax2) + + # Homogeneous: full filtering + weight[homogeneous] = 1.0 + # Heterogeneous: adaptive weight + weight[heterogeneous] = (cu2 * (cv[heterogeneous] ** 2 - cu2)) / ( + cv[heterogeneous] ** 2 * (cmax2 - cu2) + 1e-8 + ) + # Point target: no filtering + weight[point_target] = 0.0 + + filtered = mean + weight * (band - mean) + return filtered.astype(np.float32) + + +# --------------------------------------------------------------------------- +# Backscatter conversion +# --------------------------------------------------------------------------- + +def linear_to_db(image: np.ndarray, eps: float = 1e-10) -> np.ndarray: + """Convert linear intensity/amplitude to decibel scale.""" + return 10.0 * np.log10(np.clip(image, eps, None)) + + +def db_to_linear(image_db: np.ndarray) -> np.ndarray: + """Convert decibel scale back to linear intensity.""" + return 10.0 ** (image_db / 10.0) + + +# --------------------------------------------------------------------------- +# Terrain masking +# --------------------------------------------------------------------------- + +def apply_slope_mask( + sar_image: np.ndarray, + dem_slope: np.ndarray, + max_slope_deg: float = 15.0, +) -> np.ndarray: + """ + Mask steep slopes where SAR layover/shadow corrupts flood detection. + + Args: + sar_image: (C, H, W) or (H, W) SAR image. + dem_slope: (H, W) slope in degrees from DEM. + max_slope_deg: Pixels with slope > this are masked to NaN. + + Returns: + Masked SAR image. + """ + steep = dem_slope > max_slope_deg + masked = sar_image.copy() + if masked.ndim == 3: + masked[:, steep] = np.nan + else: + masked[steep] = np.nan + return masked + + +# --------------------------------------------------------------------------- +# Preprocessing pipeline +# --------------------------------------------------------------------------- + +def preprocess_sar( + image: np.ndarray, + apply_filter: bool = True, + to_db: bool = True, + dem_slope: Optional[np.ndarray] = None, +) -> np.ndarray: + """ + Full SAR preprocessing pipeline. + + Args: + image: (C, H, W) array with VV/VH in linear intensity. + apply_filter: Apply Refined Lee speckle filter. + to_db: Convert output to decibel scale. + dem_slope: Optional (H, W) slope mask. + + Returns: + Preprocessed (C, H, W) array. + """ + out = image.astype(np.float32) + + if apply_filter: + flt = RefinedLeeSpeckleFilter(window_size=7) + out = flt(out) + + if to_db: + out = linear_to_db(out) + + if dem_slope is not None: + out = apply_slope_mask(out, dem_slope) + + return out diff --git a/src/climatevision/impact/__init__.py b/src/climatevision/impact/__init__.py new file mode 100644 index 0000000..822d8ba --- /dev/null +++ b/src/climatevision/impact/__init__.py @@ -0,0 +1,13 @@ +from .osm_roads import ( + download_roads, + rasterize_roads, + calculate_affected_road_km, + assess_flood_impact, +) + +__all__ = [ + "download_roads", + "rasterize_roads", + "calculate_affected_road_km", + "assess_flood_impact", +] diff --git a/src/climatevision/impact/osm_roads.py b/src/climatevision/impact/osm_roads.py new file mode 100644 index 0000000..bcb6918 --- /dev/null +++ b/src/climatevision/impact/osm_roads.py @@ -0,0 +1,209 @@ +""" +OpenStreetMap road network integration for flood impact assessment. + +Downloads highways within a bounding box, rasterizes them to match the +flood prediction mask, and computes affected road length. +""" +from __future__ import annotations + +import logging +from typing import Any + +import numpy as np + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# OSM road download +# --------------------------------------------------------------------------- + +def download_roads( + bbox: list[float], + road_types: list[str] | None = None, +) -> Any: + """ + Download road network from OpenStreetMap within a bounding box. + + Args: + bbox: [west, south, east, north] in WGS84. + road_types: OSM highway types to include. Defaults to main roads. + + Returns: + GeoDataFrame with road LineStrings. + + Raises: + ImportError: If osmnx is not installed. + """ + try: + import osmnx as ox + except ImportError: + raise ImportError( + "osmnx is required for OSM road download. " + "Install: pip install osmnx" + ) + + if road_types is None: + road_types = [ + "motorway", "trunk", "primary", "secondary", "tertiary", + "residential", "unclassified", "road", + ] + + west, south, east, north = bbox + gdf = ox.features.features_from_bbox( + (west, south, east, north), + tags={"highway": road_types}, + ) + + if gdf.empty: + logger.warning("No roads found in bbox %s", bbox) + return gdf + + gdf = gdf[gdf.geometry.type == "LineString"].copy() + return gdf + + +def download_buildings( + bbox: list[float], +) -> Any: + """ + Download building footprints from OpenStreetMap. + + Args: + bbox: [west, south, east, north] in WGS84. + + Returns: + GeoDataFrame with building Polygons. + """ + try: + import osmnx as ox + except ImportError: + raise ImportError( + "osmnx is required for OSM building download. " + "Install: pip install osmnx" + ) + + west, south, east, north = bbox + gdf = ox.features.features_from_bbox( + (west, south, east, north), + tags={"building": True}, + ) + + if gdf.empty: + logger.warning("No buildings found in bbox %s", bbox) + return gdf + + gdf = gdf[gdf.geometry.type == "Polygon"].copy() + return gdf + + +# --------------------------------------------------------------------------- +# Rasterization +# --------------------------------------------------------------------------- + +def rasterize_roads( + roads_gdf: Any, + raster_shape: tuple[int, int], + transform: Any, +) -> np.ndarray: + """ + Rasterize road LineStrings to a binary mask. + + Args: + roads_gdf: GeoDataFrame with LineString geometries. + raster_shape: (height, width) of output mask. + transform: Affine transform (from rasterio.DatasetReader.transform). + + Returns: + (H, W) uint8 binary mask, 1=road. + """ + try: + import rasterio.features + except ImportError: + raise ImportError("rasterio is required for rasterization") + + if roads_gdf.empty: + return np.zeros(raster_shape, dtype=np.uint8) + + shapes = ((geom, 1) for geom in roads_gdf.geometry) + mask = rasterio.features.rasterize( + shapes, + out_shape=raster_shape, + transform=transform, + fill=0, + dtype=np.uint8, + ) + return mask + + +# --------------------------------------------------------------------------- +# Impact calculation +# --------------------------------------------------------------------------- + +def calculate_affected_road_km( + flood_mask: np.ndarray, + road_mask: np.ndarray, + pixel_size_m: float = 10.0, +) -> float: + """ + Compute total length of roads inundated by flood. + + Args: + flood_mask: (H, W) binary mask, 1=flooded. + road_mask: (H, W) binary mask, 1=road. + pixel_size_m: Spatial resolution in metres per pixel. + + Returns: + Affected road length in kilometres. + """ + flooded_roads = (flood_mask > 0) & (road_mask > 0) + pixel_count = int(flooded_roads.sum()) + km = pixel_count * pixel_size_m / 1000.0 + return round(km, 3) + + +# --------------------------------------------------------------------------- +# High-level impact assessment +# --------------------------------------------------------------------------- + +def assess_flood_impact( + flood_mask: np.ndarray, + bbox: list[float], + pixel_size_m: float = 10.0, +) -> dict[str, Any]: + """ + Full flood impact assessment for a given bbox. + + Args: + flood_mask: (H, W) integer prediction mask. + bbox: [west, south, east, north] in WGS84. + pixel_size_m: Spatial resolution. + + Returns: + Dict with affected_road_km and raw masks. + """ + try: + import rasterio.transform + except ImportError: + raise ImportError("rasterio is required for impact assessment") + + h, w = flood_mask.shape + transform = rasterio.transform.from_bounds( + bbox[0], bbox[1], bbox[2], bbox[3], w, h + ) + + binary_flood = (flood_mask == 2).astype(np.uint8) # class 2 = flooded + + try: + roads_gdf = download_roads(bbox) + road_mask = rasterize_roads(roads_gdf, (h, w), transform) + affected_road_km = calculate_affected_road_km(binary_flood, road_mask, pixel_size_m) + except Exception as exc: + logger.warning("Road impact assessment failed: %s", exc) + affected_road_km = 0.0 + + return { + "affected_road_km": affected_road_km, + "bbox": bbox, + "pixel_size_m": pixel_size_m, + } diff --git a/src/climatevision/models/flood_unet.py b/src/climatevision/models/flood_unet.py new file mode 100644 index 0000000..364b5ea --- /dev/null +++ b/src/climatevision/models/flood_unet.py @@ -0,0 +1,227 @@ +""" +Production flood detection models. + +Uses segmentation-models-pytorch (smp) with EfficientNet-B7 encoder, +pretrained on ImageNet. Two variants: + - FloodUNet: 5-channel input (S2 B03/B08/B11 + S1 VV/VH) + - FloodUNetS2Only: 3-channel input (S2 B03/B08/B11) + +Both output 3 classes: dry_land, permanent_water, flooded. +""" +from __future__ import annotations + +import logging +from typing import Optional + +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +class _FirstConvAdapter(nn.Module): + """ + Adapts a pretrained encoder's first conv layer to accept a different + number of input channels by averaging pretrained weights across the + extra channels. + """ + + def __init__(self, encoder: nn.Module, new_in_channels: int): + super().__init__() + self.encoder = encoder + self.new_in_channels = new_in_channels + + # Find the first conv layer + first_conv = None + for module in encoder.modules(): + if isinstance(module, nn.Conv2d): + first_conv = module + break + + if first_conv is None: + raise ValueError("Could not find first Conv2d in encoder") + + if first_conv.in_channels == new_in_channels: + return # No adaptation needed + + # Replace with adapted conv + old_weight = first_conv.weight.data # (out_ch, old_in_ch, k, k) + out_ch, old_in_ch, kH, kW = old_weight.shape + + new_conv = nn.Conv2d( + new_in_channels, + out_ch, + kernel_size=(kH, kW), + stride=first_conv.stride, + padding=first_conv.padding, + bias=first_conv.bias is not None, + ) + + # Initialize new channels by replicating averaged pretrained weights + with torch.no_grad(): + new_weight = new_conv.weight.data + n_repeat = new_in_channels // old_in_ch + n_remain = new_in_channels % old_in_ch + + for i in range(n_repeat): + new_weight[:, i * old_in_ch : (i + 1) * old_in_ch] = old_weight + if n_remain > 0: + new_weight[:, n_repeat * old_in_ch :] = old_weight[:, :n_remain] + + if first_conv.bias is not None: + new_conv.bias.data.copy_(first_conv.bias.data) + + # Replace the conv in the encoder + def _replace_first_conv(parent: nn.Module, child_name: str, new_module: nn.Module) -> None: + setattr(parent, child_name, new_module) + + found = False + for name, module in encoder.named_modules(): + if module is first_conv: + # Navigate to parent + parts = name.split(".") + parent = encoder + for part in parts[:-1]: + parent = getattr(parent, part) + _replace_first_conv(parent, parts[-1], new_conv) + found = True + break + + if not found: + raise RuntimeError("Failed to replace first conv in encoder") + + logger.info( + "Adapted encoder first conv: %d → %d input channels", + old_in_ch, new_in_channels, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + +class FloodUNet(nn.Module): + """ + U-Net++ with EfficientNet-B7 encoder for flood detection. + Input: 5 channels [B03, B08, B11, VV, VH] + Output: 3 classes [dry_land, permanent_water, flooded] + """ + + def __init__(self, in_channels: int = 5, num_classes: int = 3, encoder_name: str = "efficientnet-b7"): + super().__init__() + self.n_channels = in_channels + self.n_classes = num_classes + + try: + import segmentation_models_pytorch as smp + except ImportError: + raise ImportError( + "segmentation-models-pytorch is required for FloodUNet. " + "Install: pip install segmentation-models-pytorch" + ) + + self.model = smp.UnetPlusPlus( + encoder_name=encoder_name, + encoder_weights="imagenet", + in_channels=in_channels, + classes=num_classes, + activation=None, + ) + + # The smp model already handles in_channels adaptation internally, + # but we log it for transparency. + logger.info( + "FloodUNet initialized: encoder=%s, in_channels=%d, classes=%d", + encoder_name, in_channels, num_classes, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def predict(self, x: torch.Tensor) -> torch.Tensor: + """Return class probabilities.""" + with torch.no_grad(): + logits = self.forward(x) + return torch.softmax(logits, dim=1) + + def predict_classes(self, x: torch.Tensor) -> torch.Tensor: + """Return predicted class indices.""" + with torch.no_grad(): + probs = self.predict(x) + return probs.argmax(dim=1) + + +class FloodUNetS2Only(nn.Module): + """ + U-Net++ with EfficientNet-B7 encoder for optical-only flood detection. + Input: 3 channels [B03, B08, B11] + Output: 3 classes [dry_land, permanent_water, flooded] + """ + + def __init__(self, in_channels: int = 3, num_classes: int = 3, encoder_name: str = "efficientnet-b7"): + super().__init__() + self.n_channels = in_channels + self.n_classes = num_classes + + try: + import segmentation_models_pytorch as smp + except ImportError: + raise ImportError( + "segmentation-models-pytorch is required for FloodUNetS2Only. " + "Install: pip install segmentation-models-pytorch" + ) + + self.model = smp.UnetPlusPlus( + encoder_name=encoder_name, + encoder_weights="imagenet", + in_channels=in_channels, + classes=num_classes, + activation=None, + ) + + logger.info( + "FloodUNetS2Only initialized: encoder=%s, in_channels=%d, classes=%d", + encoder_name, in_channels, num_classes, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def predict(self, x: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + logits = self.forward(x) + return torch.softmax(logits, dim=1) + + def predict_classes(self, x: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + probs = self.predict(x) + return probs.argmax(dim=1) + + +def build_flood_model( + use_sar: bool = False, + encoder_name: str = "efficientnet-b7", + weights_path: Optional[str] = None, +) -> nn.Module: + """ + Factory function to build the appropriate flood model. + + Args: + use_sar: If True, build 5-channel S2+S1 model. Otherwise 3-channel S2-only. + encoder_name: Encoder backbone name (must be supported by smp). + weights_path: Optional path to load pretrained weights from. + + Returns: + Initialized flood model. + """ + if use_sar: + model = FloodUNet(in_channels=5, num_classes=3, encoder_name=encoder_name) + else: + model = FloodUNetS2Only(in_channels=3, num_classes=3, encoder_name=encoder_name) + + if weights_path is not None: + state = torch.load(weights_path, map_location="cpu") + model_state = state.get("model_state_dict", state) + model.load_state_dict(model_state, strict=False) + logger.info("Loaded flood model weights from %s", weights_path) + + return model diff --git a/tests/test_api_flood.py b/tests/test_api_flood.py new file mode 100644 index 0000000..5434435 --- /dev/null +++ b/tests/test_api_flood.py @@ -0,0 +1,49 @@ +""" +End-to-end API tests for flood detection endpoints. +""" +from __future__ import annotations + +import json +import pytest +from fastapi.testclient import TestClient + +from climatevision.api.main import create_app + + +@pytest.fixture +def client(): + app = create_app() + return TestClient(app) + + +class TestFloodPrediction: + def test_health_endpoint(self, client): + response = client.get("/api/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "flooding" in data["analysis_types"] + + def test_predict_without_api_key(self, client): + """Should return 401 without API key.""" + response = client.post( + "/api/predict", + json={ + "kind": "gee", + "analysis_type": "flooding", + "bbox": [36.7, -1.4, 37.0, -1.1], + "start_date": "2024-04-01", + "end_date": "2024-04-10", + }, + ) + assert response.status_code == 401 + + def test_predict_flooding_analysis_type_exists(self, client): + """Flooding should be listed as an enabled analysis type.""" + response = client.get("/api/analysis-types") + assert response.status_code == 200 + types = response.json() + flooding = next((t for t in types if t["name"] == "flooding"), None) + assert flooding is not None + assert flooding["enabled"] is True + assert flooding["bands"] == ["B03", "B08", "B11"] diff --git a/tests/test_flooding.py b/tests/test_flooding.py new file mode 100644 index 0000000..d9f3921 --- /dev/null +++ b/tests/test_flooding.py @@ -0,0 +1,72 @@ +""" +Tests for flood detection analysis module. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from climatevision.analysis.flooding import FloodingAnalysis + + +class TestFloodingAnalysis: + def test_preprocess_normalizes_image(self): + analysis = FloodingAnalysis() + img = np.random.randint(0, 255, (256, 256, 3)).astype(np.float32) + out = analysis.preprocess(img) + assert out.dtype == np.float32 + assert out.shape == (256, 256, 3) + assert out.max() <= 1.0 + + def test_water_index_classification(self): + analysis = FloodingAnalysis() + # Create synthetic image: strong water signature + img = np.zeros((256, 256, 3), dtype=np.float32) + img[..., 0] = 0.8 # Green (high) + img[..., 2] = 0.1 # SWIR (low) + pred, conf = analysis._water_index_classification(img) + assert pred.shape == (256, 256) + assert 0 <= conf <= 1.0 + # Most pixels should be classified as water (1) or flooded (2) + water_pixels = (pred == 1).sum() + (pred == 2).sum() + assert water_pixels > 100 + + def test_calculate_metrics(self): + analysis = FloodingAnalysis() + prediction = np.zeros((256, 256), dtype=np.int32) + prediction[100:150, 100:150] = 2 # flooded patch + bbox = [36.7, -1.4, 37.0, -1.1] + metrics = analysis.calculate_metrics(prediction, (256, 256), bbox=bbox) + assert "flooded_percentage" in metrics + assert "flooded_area_km2" in metrics + assert metrics["flooded_percentage"] > 0 + + def test_generate_alerts_critical(self): + analysis = FloodingAnalysis() + metrics = {"flooded_percentage": 25.0, "flooded_area_km2": 10.0} + alerts = analysis.generate_alerts(metrics) + assert len(alerts) == 1 + assert alerts[0].severity.value == "critical" + assert "Critical Flooding" in alerts[0].title + + def test_generate_alerts_warning(self): + analysis = FloodingAnalysis() + metrics = {"flooded_percentage": 8.0, "flooded_area_km2": 2.0} + alerts = analysis.generate_alerts(metrics) + assert len(alerts) == 1 + assert alerts[0].severity.value == "high" + assert "Flooding Detected" in alerts[0].title + + def test_generate_alerts_no_alert(self): + analysis = FloodingAnalysis() + metrics = {"flooded_percentage": 1.0} + alerts = analysis.generate_alerts(metrics) + assert len(alerts) == 0 + + def test_generate_alerts_rapid_expansion(self): + analysis = FloodingAnalysis() + prev = {"flooded_percentage": 5.0} + curr = {"flooded_percentage": 20.0} + alerts = analysis.generate_alerts(curr, previous_metrics=prev) + alert_types = [a.alert_type for a in alerts] + assert "rapid_flood_expansion" in alert_types diff --git a/tests/test_sar_preprocessing.py b/tests/test_sar_preprocessing.py new file mode 100644 index 0000000..028ecba --- /dev/null +++ b/tests/test_sar_preprocessing.py @@ -0,0 +1,54 @@ +""" +Tests for SAR preprocessing module. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from climatevision.data.sar_preprocessing import ( + RefinedLeeSpeckleFilter, + linear_to_db, + db_to_linear, + apply_slope_mask, +) + + +class TestRefinedLeeSpeckleFilter: + def test_reduces_variance(self): + rng = np.random.default_rng(42) + # Create image with multiplicative speckle noise (SAR-like) + base = rng.uniform(0.3, 0.8, (128, 128)).astype(np.float32) + speckle = rng.gamma(1.0, 1.0, (128, 128)).astype(np.float32) + image = base * speckle + flt = RefinedLeeSpeckleFilter(window_size=7) + filtered = flt(image) + assert filtered.shape == image.shape + assert filtered.dtype == np.float32 + # Variance should be reduced for speckled images + assert filtered.var() < image.var() + + def test_3d_input(self): + rng = np.random.default_rng(42) + image = rng.normal(0.5, 0.2, (2, 128, 128)).astype(np.float32) + flt = RefinedLeeSpeckleFilter(window_size=7) + filtered = flt(image) + assert filtered.shape == image.shape + + +class TestBackscatterConversion: + def test_linear_to_db_roundtrip(self): + linear = np.array([0.01, 0.1, 1.0, 10.0], dtype=np.float32) + db = linear_to_db(linear) + back = db_to_linear(db) + np.testing.assert_allclose(back, linear, rtol=1e-5) + + +class TestSlopeMask: + def test_masks_steep_slopes(self): + image = np.ones((100, 100), dtype=np.float32) + slope = np.zeros((100, 100), dtype=np.float32) + slope[50:60, 50:60] = 20.0 # steep + masked = apply_slope_mask(image, slope, max_slope_deg=15.0) + assert np.isnan(masked[55, 55]) + assert not np.isnan(masked[10, 10]) From e6ded96f3b5de070b4cfa7488eee64577df148b0 Mon Sep 17 00:00:00 2001 From: goldokpa Date: Tue, 16 Jun 2026 05:21:15 +0100 Subject: [PATCH 2/4] feat(flooding): wire SAR ensemble into API with permanent-vs-flood classification - flood_classification: separate flood from permanent water via reference (JRC GSW occurrence) or pre/post change detection; refuses to guess when neither is available - ensemble: detect() now returns a 3-class classified_mask when a reference or pre-event scene is supplied - flooding_sar: FloodingSARAnalysis (Sentinel-1 VV/VH) behind the standard analysis contract, registered and discoverable - gee_downloader: Sentinel-1 GRD + JRC Global Surface Water fetch (synthetic fallback, explicitly tagged) - inference/flood_pipeline: bbox -> S1 + JRC -> ensemble -> 3-class result - api: /api/predict dispatches flooding_sar to the SAR pipeline; type listed in analysis-types/health; config + [flood] extra added - scripts: eval_flood.py (real IoU/F1/precision/recall on Sen1Floods11) and download_sen1floods11.py - tests: classification + SAR analysis + registry/API exposure Co-Authored-By: Claude Opus 4.8 --- config.yaml | 21 ++ scripts/download_sen1floods11.py | 141 ++++++++++ scripts/eval_flood.py | 243 ++++++++++++++++++ setup.py | 7 + .../analysis/flood_classification.py | 117 +++++++++ .../analysis/flooding_ensemble.py | 57 +++- src/climatevision/analysis/flooding_sar.py | 182 +++++++++++++ src/climatevision/analysis/registry.py | 8 +- src/climatevision/api/main.py | 33 ++- src/climatevision/data/__init__.py | 8 +- src/climatevision/data/gee_downloader.py | 183 +++++++++++++ src/climatevision/inference/flood_pipeline.py | 104 ++++++++ tests/test_flood_classification.py | 71 +++++ tests/test_flooding_sar.py | 115 +++++++++ 14 files changed, 1278 insertions(+), 12 deletions(-) create mode 100644 scripts/download_sen1floods11.py create mode 100644 scripts/eval_flood.py create mode 100644 src/climatevision/analysis/flood_classification.py create mode 100644 src/climatevision/analysis/flooding_sar.py create mode 100644 src/climatevision/inference/flood_pipeline.py create mode 100644 tests/test_flood_classification.py create mode 100644 tests/test_flooding_sar.py diff --git a/config.yaml b/config.yaml index 2ce5c8a..a6a6562 100644 --- a/config.yaml +++ b/config.yaml @@ -75,6 +75,27 @@ analysis_types: - "flooded_area_km2" - "mndwi_stats" + # Flood Detection (SAR / Sentinel-1) -- all-weather, ensemble-based, no trained + # weights required. Separates permanent water from flood given a JRC GSW + # reference or a pre-event scene. + flooding_sar: + enabled: true + display_name: "Flood Detection (SAR)" + description: "All-weather flood detection from Sentinel-1 VV/VH using a physics-based ensemble" + model: + architecture: "ensemble" # LIST + DLR + TUW majority vote; no neural weights + in_channels: 2 # VV, VH + num_classes: 3 + bands: ["VV", "VH"] + classes: ["dry_land", "permanent_water", "flooded"] + thresholds: + alert_flood_area: 5.0 + critical_flood_area: 20.0 + metrics: + - "flooded_percentage" + - "flooded_area_km2" + - "permanent_water_km2" + # Drought Monitoring drought: enabled: false # Not yet implemented diff --git a/scripts/download_sen1floods11.py b/scripts/download_sen1floods11.py new file mode 100644 index 0000000..db275d6 --- /dev/null +++ b/scripts/download_sen1floods11.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +""" +Download the Sen1Floods11 benchmark for flood-detection evaluation. + +Sen1Floods11 (Bonafilia et al., 2020) provides Sentinel-1 (VV/VH, sigma0 in dB) +and Sentinel-2 chips with flood/water labels for 11 global flood events. The +446 *hand-labeled* chips are the gold standard; the ~4,300 weakly-labeled chips +(Otsu / permanent-water) are useful for training but NOT for reporting accuracy. + +Layout produced under : + / + v1.1/data/flood_events/HandLabeled/S1Hand/*.tif # 2 bands: VV, VH (dB) + /v1.1/data/flood_events/HandLabeled/LabelHand/*.tif # -1 nodata, 0 land, 1 water + /v1.1/splits/flood_handlabeled/flood_{train,valid,test}_data.csv + +Usage: + # Hand-labeled only (small, ~a few GB) -- enough to compute test accuracy + python scripts/download_sen1floods11.py --subset handlabeled --dest data/sen1floods11 + + # Everything (hand + weak labels, tens of GB) + python scripts/download_sen1floods11.py --subset all --dest data/sen1floods11 + + # Show what would be copied without downloading + python scripts/download_sen1floods11.py --subset handlabeled --dry-run + +Requirements: + Google Cloud SDK -- either `gcloud storage` (preferred) or `gsutil`. + The bucket is public; no auth is required for read. + +IMPORTANT: dataset hosting moves around (Radiant MLHub sunset; mirrors on +Source Cooperative / Hugging Face). VERIFY the --bucket value below is still the +live location before relying on it. Override with --bucket if it has moved. +""" +from __future__ import annotations + +import argparse +import shutil +import subprocess +import sys +from pathlib import Path + +# Canonical public GCS bucket as of early 2026. VERIFY before use (see module docstring). +DEFAULT_BUCKET = "gs://sen1floods11-data/v1.1" + +# Sub-paths within the bucket, relative to /data/flood_events/ +SUBSET_PATHS = { + "handlabeled": [ + "HandLabeled/S1Hand", + "HandLabeled/LabelHand", + ], + "weak": [ + "WeaklyLabeled/S1Weak", + "WeaklyLabeled/S2IndexLabelWeak", + ], +} +# Splits CSVs live alongside the data directory. +SPLITS_SUBPATH = "splits/flood_handlabeled" + + +def _find_gcs_tool() -> list[str]: + """Return the argv prefix for a working GCS copy tool, or exit with guidance.""" + if shutil.which("gcloud"): + return ["gcloud", "storage", "rsync", "-r"] + if shutil.which("gsutil"): + return ["gsutil", "-m", "rsync", "-r"] + sys.exit( + "ERROR: neither `gcloud` nor `gsutil` found on PATH.\n" + "Install the Google Cloud SDK: https://cloud.google.com/sdk/docs/install\n" + "The Sen1Floods11 bucket is public, so no login is needed after install." + ) + + +def _rsync(tool: list[str], src: str, dst: Path, dry_run: bool) -> None: + dst.mkdir(parents=True, exist_ok=True) + cmd = list(tool) + if dry_run: + # Both gcloud storage rsync and gsutil rsync support -n for dry-run. + cmd.append("-n") + cmd += [src, str(dst)] + print(f" $ {' '.join(cmd)}") + result = subprocess.run(cmd) + if result.returncode != 0: + sys.exit( + f"ERROR: copy failed for {src} (exit {result.returncode}).\n" + "If this is a 'bucket not found' / 404 error, the dataset has likely " + "moved -- re-run with --bucket pointing at the current mirror." + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--subset", + choices=["handlabeled", "weak", "all"], + default="handlabeled", + help="Which split to fetch. 'handlabeled' is enough to report accuracy (default).", + ) + parser.add_argument("--dest", type=Path, default=Path("data/sen1floods11"), help="Local destination root.") + parser.add_argument("--bucket", default=DEFAULT_BUCKET, help="GCS bucket root (verify it is current).") + parser.add_argument("--dry-run", action="store_true", help="List what would be copied without downloading.") + args = parser.parse_args() + + tool = _find_gcs_tool() + data_root = f"{args.bucket}/data/flood_events" + + subsets = ["handlabeled", "weak"] if args.subset == "all" else [args.subset] + + print(f"Sen1Floods11 download") + print(f" bucket : {args.bucket}") + print(f" dest : {args.dest.resolve()}") + print(f" subset : {args.subset}") + print(f" tool : {tool[0]}") + print() + + for subset in subsets: + for rel in SUBSET_PATHS[subset]: + src = f"{data_root}/{rel}" + dst = args.dest / "v1.1" / "data" / "flood_events" / rel + print(f"[{subset}] {rel}") + _rsync(tool, src, dst, args.dry_run) + + # Always grab the split CSVs (tiny) so eval can use the official test split. + splits_src = f"{args.bucket}/{SPLITS_SUBPATH}" + splits_dst = args.dest / "v1.1" / SPLITS_SUBPATH + print(f"[splits] {SPLITS_SUBPATH}") + _rsync(tool, splits_src, splits_dst, args.dry_run) + + print() + if args.dry_run: + print("Dry run complete -- nothing was written.") + else: + print("Download complete. Evaluate the ensemble with:") + print( + f" python scripts/eval_flood.py " + f"--data-root {args.dest} --split test" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/eval_flood.py b/scripts/eval_flood.py new file mode 100644 index 0000000..35d70cb --- /dev/null +++ b/scripts/eval_flood.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python +""" +Evaluate the SAR flood-detection ensemble against labeled ground truth. + +Runs EnsembleFloodPipeline (or a single detector) over Sen1Floods11 chips and +reports surface-water detection metrics against the hand labels. Produces REAL +numbers -- it computes nothing unless pointed at real labeled data. + +What it measures +---------------- +Sen1Floods11 hand labels mark *surface water* (permanent + flood water together) +as class 1, dry land as 0, and no-data/cloud as -1. The SAR ensemble detects +open water from backscatter, so this evaluates **water detection**, which is the +core skill behind flood mapping. It is NOT a pure "flood-only" metric, because +the benchmark does not separate flood water from permanent water in its masks. +Report it as such. + +Caveat on the change-detection branch: Sen1Floods11 hand-labeled chips have no +pre-event image, so the LIST (pre/post differencing) detector cannot run here. +With --detector ensemble and no pre-event, the majority vote reduces to "DLR AND +TUW must agree". Use Kuro Siwo (which ships pre/post pairs) to exercise LIST. + +Metrics (for the water/positive class): precision, recall, F1, IoU (Jaccard), +overall pixel accuracy, plus the raw confusion matrix. Pixels labeled -1 are +ignored. + +Usage: + python scripts/eval_flood.py --data-root data/sen1floods11 --split test + python scripts/eval_flood.py --data-root data/sen1floods11 --detector tuw + python scripts/eval_flood.py --data-root data/sen1floods11 --limit 20 --json out.json + +Requirements: rasterio (GeoTIFF I/O), numpy, scikit-image (for the DLR detector). +""" +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +import numpy as np + +# Make `climatevision` importable when run from a source checkout. +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from climatevision.analysis.flooding_ensemble import ( # noqa: E402 + EnsembleFloodPipeline, + DLRFloodDetector, + TUWFloodDetector, +) + +# Sen1Floods11 label encoding. +LABEL_NODATA = -1 +LABEL_LAND = 0 +LABEL_WATER = 1 + + +def find_pairs(data_root: Path, split: str | None, limit: int | None) -> list[tuple[Path, Path]]: + """Pair each LabelHand tif with its matching S1Hand tif. + + Pairs by the shared `_` prefix so this is robust to minor + differences in the split-CSV format across dataset versions. If a split + CSV is present it is used to filter to that split; otherwise all + hand-labeled chips are evaluated. + """ + hand = data_root / "v1.1" / "data" / "flood_events" / "HandLabeled" + s1_dir = hand / "S1Hand" + label_dir = hand / "LabelHand" + if not label_dir.is_dir() or not s1_dir.is_dir(): + sys.exit( + f"ERROR: expected Sen1Floods11 HandLabeled dirs under {hand}.\n" + "Download first: python scripts/download_sen1floods11.py --dest " + f"{data_root}" + ) + + # Index S1 chips by their Region_id prefix. + s1_by_key = {} + for p in s1_dir.glob("*.tif"): + key = p.name.replace("_S1Hand.tif", "") + s1_by_key[key] = p + + allow_keys = _load_split_keys(data_root, split) + + pairs: list[tuple[Path, Path]] = [] + for label_path in sorted(label_dir.glob("*_LabelHand.tif")): + key = label_path.name.replace("_LabelHand.tif", "") + if allow_keys is not None and key not in allow_keys: + continue + s1_path = s1_by_key.get(key) + if s1_path is None: + print(f" warning: no S1 chip for label {label_path.name}, skipping") + continue + pairs.append((s1_path, label_path)) + + if limit is not None: + pairs = pairs[:limit] + return pairs + + +def _load_split_keys(data_root: Path, split: str | None) -> set[str] | None: + """Return the set of `Region_id` keys for the requested split, or None.""" + if not split: + return None + csv = data_root / "v1.1" / "splits" / "flood_handlabeled" / f"flood_{split}_data.csv" + if not csv.is_file(): + print(f" note: split file {csv.name} not found; evaluating ALL hand-labeled chips") + return None + keys: set[str] = set() + for line in csv.read_text().splitlines(): + line = line.strip() + if not line: + continue + # Rows reference filenames like 'Bolivia_103757_S1Hand.tif'; extract the prefix. + for field in line.replace(",", " ").split(): + name = Path(field).name + for suffix in ("_S1Hand.tif", "_LabelHand.tif", "_S2Hand.tif"): + if name.endswith(suffix): + keys.add(name.replace(suffix, "")) + return keys or None + + +def _read_band(path: Path, band: int) -> np.ndarray: + try: + import rasterio + except ImportError: + sys.exit("ERROR: rasterio is required to read GeoTIFFs. Install: pip install rasterio") + with rasterio.open(path) as ds: + return ds.read(band).astype(np.float32) + + +def _predict_water(detector: str, post_vh: np.ndarray) -> np.ndarray: + """Return a binary water/flood mask (1=water) for the chosen detector.""" + if detector == "ensemble": + # No pre-event image in Sen1Floods11 hand labels -> LIST branch is skipped. + out = EnsembleFloodPipeline().detect(post_vh=post_vh, pre_vh=None) + return out["ensemble_mask"] + if detector == "dlr": + return DLRFloodDetector().detect(post_vh) + if detector == "tuw": + return TUWFloodDetector().detect(post_vh) + raise ValueError(f"unknown detector {detector!r}") + + +def evaluate(pairs, detector: str, vh_band: int) -> dict: + """Accumulate a confusion matrix over all chips and derive metrics.""" + tp = fp = fn = tn = 0 + per_scene_iou: list[float] = [] + + for i, (s1_path, label_path) in enumerate(pairs, 1): + vh = _read_band(s1_path, vh_band) + label = _read_band(label_path, 1).astype(np.int32) + + valid = label != LABEL_NODATA + if not valid.any(): + continue + + pred = _predict_water(detector, vh).astype(bool) + gt = label == LABEL_WATER + + p = pred[valid] + g = gt[valid] + + s_tp = int(np.sum(p & g)) + s_fp = int(np.sum(p & ~g)) + s_fn = int(np.sum(~p & g)) + s_tn = int(np.sum(~p & ~g)) + tp += s_tp + fp += s_fp + fn += s_fn + tn += s_tn + + denom = s_tp + s_fp + s_fn + per_scene_iou.append(s_tp / denom if denom else 1.0) + + if i % 25 == 0 or i == len(pairs): + print(f" processed {i}/{len(pairs)} chips") + + return _metrics(tp, fp, fn, tn, per_scene_iou) + + +def _metrics(tp, fp, fn, tn, per_scene_iou) -> dict: + precision = tp / (tp + fp) if (tp + fp) else 0.0 + recall = tp / (tp + fn) if (tp + fn) else 0.0 + f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0 + iou = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0 + total = tp + fp + fn + tn + accuracy = (tp + tn) / total if total else 0.0 + return { + "precision": round(precision, 4), + "recall": round(recall, 4), + "f1": round(f1, 4), + "iou": round(iou, 4), + "pixel_accuracy": round(accuracy, 4), + "mean_iou_per_scene": round(float(np.mean(per_scene_iou)), 4) if per_scene_iou else 0.0, + "confusion_matrix": {"tp": tp, "fp": fp, "fn": fn, "tn": tn}, + "n_pixels_evaluated": total, + } + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--data-root", type=Path, default=Path("data/sen1floods11"), help="Sen1Floods11 root.") + parser.add_argument("--split", default="test", help="Split to evaluate: test/valid/train, or '' for all chips.") + parser.add_argument( + "--detector", + choices=["ensemble", "dlr", "tuw"], + default="ensemble", + help="Which detector to evaluate (default: ensemble).", + ) + parser.add_argument("--vh-band", type=int, default=2, help="1-based band index of VH in S1 chips (VV=1, VH=2).") + parser.add_argument("--limit", type=int, default=None, help="Evaluate only the first N chips (quick check).") + parser.add_argument("--json", type=Path, default=None, help="Optional path to write the metrics as JSON.") + args = parser.parse_args() + + pairs = find_pairs(args.data_root, args.split or None, args.limit) + if not pairs: + sys.exit("ERROR: no (S1, label) chip pairs found. Check --data-root and that the download completed.") + + print(f"Evaluating detector='{args.detector}' on {len(pairs)} chips (split={args.split or 'all'})") + metrics = evaluate(pairs, args.detector, args.vh_band) + + print("\n=== Sen1Floods11 surface-water detection metrics ===") + print(f" detector : {args.detector}") + print(f" chips : {len(pairs)}") + print(f" precision : {metrics['precision']}") + print(f" recall : {metrics['recall']}") + print(f" F1 : {metrics['f1']}") + print(f" IoU (water) : {metrics['iou']}") + print(f" mean IoU/scene : {metrics['mean_iou_per_scene']}") + print(f" pixel accuracy : {metrics['pixel_accuracy']}") + print(f" confusion (px) : {metrics['confusion_matrix']}") + print("\nNote: measures surface-water detection (permanent + flood water), per the") + print("Sen1Floods11 label definition -- not flood-only. See script docstring.") + + if args.json: + payload = {"detector": args.detector, "split": args.split, "n_chips": len(pairs), **metrics} + args.json.write_text(json.dumps(payload, indent=2)) + print(f"\nWrote metrics to {args.json}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/setup.py b/setup.py index 9e1b1ce..2231b30 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,13 @@ "processing": [ "dask[complete]>=2023.1.0", ], + "flood": [ + # SAR ensemble + permanent/flood classification + OSM road impact + eval + "scikit-image>=0.19.0", # DLR Otsu / morphology + "rasterio>=1.3.0", # GeoTIFF I/O (needs system GDAL) + "osmnx>=1.6.0", # OSM road/building download for impact + "segmentation-models-pytorch>=0.3.0", # optional FloodUNet path + ], "satellite": [ "sentinelsat>=1.1.0", "earthengine-api>=0.1.340", diff --git a/src/climatevision/analysis/flood_classification.py b/src/climatevision/analysis/flood_classification.py new file mode 100644 index 0000000..2eb0ecb --- /dev/null +++ b/src/climatevision/analysis/flood_classification.py @@ -0,0 +1,117 @@ +""" +Separate flood water from permanent water. + +A single post-event image cannot tell you whether detected water is a flood or a +lake/river that is always there -- the backscatter (or water index) looks the +same. Distinguishing the two requires a *reference* for where water normally is. +This module provides the two standard, defensible ways to get that reference: + + 1. Reference subtraction (`classify_with_reference`) + Overlay a permanent-water layer (e.g. JRC Global Surface Water occurrence + >= ~50%, or any pre-computed permanent mask). Water that coincides with the + reference is permanent; water outside it is flood. + + 2. Change detection (`classify_with_change`) + Run water detection on a pre-event scene as well. Water present in both + pre and post is permanent; water that appears only post-event is flood. + +Output classes match FloodingAnalysis.output_classes: + 0 = dry_land, 1 = permanent_water, 2 = flooded + +Design choice: when NO reference is available, these functions raise rather than +fabricate a permanent/flood split. A guessed distinction on a disaster-response +product is worse than an honest "water, source unknown". +""" +from __future__ import annotations + +import numpy as np + +DRY_LAND = 0 +PERMANENT_WATER = 1 +FLOODED = 2 + + +def classify_with_reference( + water_mask: np.ndarray, + permanent_water_ref: np.ndarray, +) -> np.ndarray: + """Classify detected water against a permanent-water reference. + + Args: + water_mask: (H, W) binary mask of water detected in the post-event scene + (1 = water). Typically `EnsembleFloodPipeline.detect(...)["ensemble_mask"]`. + permanent_water_ref: (H, W) binary mask, 1 where water is normally present + (e.g. from `permanent_water_from_occurrence`). Must match water_mask shape. + + Returns: + (H, W) int array: 0=dry, 1=permanent_water, 2=flooded. + """ + water = _as_bool(water_mask) + perm = _as_bool(permanent_water_ref) + if water.shape != perm.shape: + raise ValueError( + f"water_mask {water.shape} and permanent_water_ref {perm.shape} must match. " + "Reproject/resample the reference to the scene grid first." + ) + + out = np.full(water.shape, DRY_LAND, dtype=np.int32) + out[water & perm] = PERMANENT_WATER + out[water & ~perm] = FLOODED + return out + + +def classify_with_change( + pre_water_mask: np.ndarray, + post_water_mask: np.ndarray, +) -> np.ndarray: + """Classify water by pre/post change detection. + + Water present before *and* after the event is treated as permanent; water + that appears only after the event is flood. Water that was present before but + not after (receding / dried out) is returned as dry land. + + Args: + pre_water_mask: (H, W) binary water mask from the pre-event scene. + post_water_mask: (H, W) binary water mask from the post-event scene. + + Returns: + (H, W) int array: 0=dry, 1=permanent_water, 2=flooded. + """ + pre = _as_bool(pre_water_mask) + post = _as_bool(post_water_mask) + if pre.shape != post.shape: + raise ValueError( + f"pre_water_mask {pre.shape} and post_water_mask {post.shape} must match. " + "Co-register the pre/post scenes first." + ) + + out = np.full(post.shape, DRY_LAND, dtype=np.int32) + out[post & pre] = PERMANENT_WATER + out[post & ~pre] = FLOODED + return out + + +def permanent_water_from_occurrence( + occurrence: np.ndarray, + threshold_pct: float = 50.0, +) -> np.ndarray: + """Derive a permanent-water mask from a surface-water occurrence layer. + + JRC Global Surface Water ("JRC/GSW1_4/GlobalSurfaceWater", band "occurrence") + gives, per pixel, the % of observations in which water was present (0-100). + Pixels at or above `threshold_pct` are treated as permanent water. + + Args: + occurrence: (H, W) array of occurrence percentages in [0, 100]. + threshold_pct: occurrence at/above which a pixel counts as permanent. + + Returns: + (H, W) uint8 binary permanent-water mask. + """ + occ = np.asarray(occurrence, dtype=np.float32) + return (occ >= threshold_pct).astype(np.uint8) + + +def _as_bool(mask: np.ndarray) -> np.ndarray: + arr = np.asarray(mask) + return arr > 0 diff --git a/src/climatevision/analysis/flooding_ensemble.py b/src/climatevision/analysis/flooding_ensemble.py index 1342507..ebcb8e5 100644 --- a/src/climatevision/analysis/flooding_ensemble.py +++ b/src/climatevision/analysis/flooding_ensemble.py @@ -18,6 +18,11 @@ import numpy as np +from climatevision.analysis.flood_classification import ( + classify_with_change, + classify_with_reference, +) + logger = logging.getLogger(__name__) @@ -179,21 +184,31 @@ def detect( self, post_vh: np.ndarray, pre_vh: Optional[np.ndarray] = None, + permanent_water_ref: Optional[np.ndarray] = None, ) -> dict[str, np.ndarray]: """ Run all three detectors and return ensemble result. Args: post_vh: Post-event VH backscatter in dB, shape (H, W). - pre_vh: Optional pre-event VH for change detection. + pre_vh: Optional pre-event VH. Enables the LIST change detector and, + if no `permanent_water_ref` is given, lets the pipeline separate + flood from permanent water by pre/post change. + permanent_water_ref: Optional (H, W) binary mask of normally-present + water (e.g. from JRC Global Surface Water occurrence). When given, + it is the authority for the permanent/flood split. Returns: Dict with keys: - list_mask: LIST detector result - dlr_mask: DLR detector result - tuw_mask: TUW detector result - - ensemble_mask: Majority vote result + - ensemble_mask: Binary water majority-vote result (1 = water) - agreement: Number of algorithms agreeing per pixel (0-3) + - classified_mask: 3-class map (0=dry, 1=permanent_water, + 2=flooded) if a reference or pre-event scene was supplied, + else None. It is deliberately None when neither is available -- + the permanent/flood split cannot be inferred from one scene. """ list_mask = ( self.list_det.detect(pre_vh, post_vh) @@ -215,10 +230,48 @@ def detect( int((votes == 3).sum()), ) + classified_mask = self._classify_permanent_vs_flood( + ensemble_mask, post_vh, pre_vh, permanent_water_ref + ) + return { "list_mask": list_mask, "dlr_mask": dlr_mask, "tuw_mask": tuw_mask, "ensemble_mask": ensemble_mask, "agreement": votes, + "classified_mask": classified_mask, } + + def _classify_permanent_vs_flood( + self, + ensemble_mask: np.ndarray, + post_vh: np.ndarray, + pre_vh: Optional[np.ndarray], + permanent_water_ref: Optional[np.ndarray], + ) -> Optional[np.ndarray]: + """Split the binary water mask into permanent vs flood, if possible. + + Priority: an explicit permanent-water reference wins; otherwise fall back + to pre/post change detection. With neither, returns None instead of + guessing. + """ + if permanent_water_ref is not None: + return classify_with_reference(ensemble_mask, permanent_water_ref) + + if pre_vh is not None: + # Derive a pre-event water mask the same way (DLR + TUW agreement). + pre_water = ( + (self.dlr_det.detect(pre_vh).astype(np.uint8) + + self.tuw_det.detect(pre_vh).astype(np.uint8)) >= 2 + ).astype(np.uint8) + # classify_with_change keys off post-water presence, so the result is + # confined to where the ensemble sees water now. + return classify_with_change(pre_water, ensemble_mask) + + logger.warning( + "No permanent-water reference or pre-event scene supplied; cannot " + "separate flood from permanent water. Returning binary water only " + "(classified_mask=None)." + ) + return None diff --git a/src/climatevision/analysis/flooding_sar.py b/src/climatevision/analysis/flooding_sar.py new file mode 100644 index 0000000..92032e4 --- /dev/null +++ b/src/climatevision/analysis/flooding_sar.py @@ -0,0 +1,182 @@ +""" +SAR-based flood detection analysis (Sentinel-1 VV/VH). + +Wraps the physics/statistics-based EnsembleFloodPipeline (LIST + DLR + TUW) and +the permanent-vs-flood classifier behind the standard BaseAnalysisType contract, +so it is discoverable through the registry and runnable through the API. + +Unlike the optical FloodingAnalysis (MNDWI), this works in cloud and at night +(SAR is all-weather) and -- given a permanent-water reference or a pre-event +scene -- genuinely separates flood water from permanent water rather than +guessing from index magnitude. +""" +from __future__ import annotations + +import logging +from typing import Any, Optional + +import numpy as np + +from climatevision.analysis.base import Alert, BaseAnalysisType, Severity +from climatevision.analysis.flooding_ensemble import EnsembleFloodPipeline +from climatevision.data.sar_preprocessing import preprocess_sar + +logger = logging.getLogger(__name__) + +DRY_LAND = 0 +PERMANENT_WATER = 1 +FLOODED = 2 + + +class FloodingSARAnalysis(BaseAnalysisType): + """All-weather SAR flood detection via a 3-algorithm ensemble.""" + + name = "flooding_sar" + display_name = "Flood Detection (SAR)" + description = "All-weather flood detection from Sentinel-1 VV/VH using a physics-based ensemble" + + # Sentinel-1 dual-pol bands. + required_bands = ["VV", "VH"] + output_classes = ["dry_land", "permanent_water", "flooded"] + enabled = True + + default_thresholds = { + "alert_flood_area": 5.0, + "critical_flood_area": 20.0, + } + + def __init__( + self, + permanent_water_ref: Optional[np.ndarray] = None, + pre_event_vh: Optional[np.ndarray] = None, + ): + """ + Args: + permanent_water_ref: Optional (H, W) binary mask of normally-present + water (e.g. from JRC GSW occurrence). Authority for the + permanent/flood split when provided. + pre_event_vh: Optional (H, W) pre-event VH (dB) for change detection, + used when no reference is supplied. + """ + self.permanent_water_ref = permanent_water_ref + self.pre_event_vh = pre_event_vh + self._pipeline = EnsembleFloodPipeline() + + def preprocess(self, image: np.ndarray, bands: Optional[list[str]] = None) -> np.ndarray: + """Speckle-filter and convert VV/VH to dB. Returns (2, H, W).""" + is_valid, error = self.validate_input(image) + if not is_valid: + raise ValueError(error) + + # Normalise to (C, H, W) with C=2 (VV, VH). + arr = np.asarray(image, dtype=np.float32) + if arr.ndim == 3 and arr.shape[-1] in (2, 3) and arr.shape[-1] < arr.shape[0]: + arr = np.transpose(arr, (2, 0, 1)) + if arr.ndim == 3 and arr.shape[0] > 2: + arr = arr[:2] + if arr.ndim == 2: + arr = np.stack([arr, arr], axis=0) + + # S1_GRD is already in dB; only apply speckle filtering here to avoid + # double log-scaling. (Linear input should set to_db=True upstream.) + return preprocess_sar(arr, apply_filter=True, to_db=False) + + def run_inference( + self, image: np.ndarray, model: Optional[Any] = None, + ) -> tuple[np.ndarray, float]: + """Run the ensemble on the VH band and classify permanent vs flood. + + Returns (prediction, confidence) where prediction is the 3-class map + (0=dry, 1=permanent_water, 2=flooded). If neither a permanent-water + reference nor a pre-event scene is available, permanent water cannot be + separated and detected water is reported as class 2 (flooded) with a + lowered confidence to signal the ambiguity. + """ + vh = image[1] if image.ndim == 3 else image + + out = self._pipeline.detect( + post_vh=vh, + pre_vh=self.pre_event_vh, + permanent_water_ref=self.permanent_water_ref, + ) + classified = out["classified_mask"] + + water_frac = float(out["ensemble_mask"].mean()) + if classified is not None: + # Higher confidence when we could actually resolve permanent vs flood. + confidence = round(min(1.0, 0.7 + 0.3 * water_frac), 4) + return classified.astype(np.int32), confidence + + # No reference: cannot distinguish -> mark water as flooded, flag via confidence. + prediction = (out["ensemble_mask"].astype(np.int32)) * FLOODED + logger.warning( + "flooding_sar: no permanent-water reference or pre-event scene; " + "reporting detected water as flooded (permanent/flood unresolved)." + ) + return prediction, round(min(1.0, 0.5 + 0.2 * water_frac), 4) + + def calculate_metrics( + self, prediction: np.ndarray, image_size: tuple[int, int], bbox: Optional[list[float]] = None, + ) -> dict[str, Any]: + h, w = image_size + total = h * w + dry = int(np.sum(prediction == DRY_LAND)) + permanent = int(np.sum(prediction == PERMANENT_WATER)) + flooded = int(np.sum(prediction == FLOODED)) + + flooded_pct = (flooded / total * 100) if total else 0.0 + permanent_pct = (permanent / total * 100) if total else 0.0 + + metrics: dict[str, Any] = { + "image_size": [h, w], + "dry_pixels": dry, + "permanent_water_pixels": permanent, + "flooded_pixels": flooded, + "flooded_percentage": round(flooded_pct, 4), + "permanent_water_percentage": round(permanent_pct, 4), + "permanent_flood_distinguished": bool(permanent > 0 or self.permanent_water_ref is not None + or self.pre_event_vh is not None), + } + + if bbox and len(bbox) == 4: + min_lon, min_lat, max_lon, max_lat = bbox + avg_lat = (min_lat + max_lat) / 2 + lat_km = abs(max_lat - min_lat) * 111 + lon_km = abs(max_lon - min_lon) * 111 * np.cos(np.radians(avg_lat)) + area = lat_km * lon_km + if total: + metrics["total_area_km2"] = round(area, 2) + metrics["flooded_area_km2"] = round(area * flooded / total, 2) + metrics["permanent_water_km2"] = round(area * permanent / total, 2) + return metrics + + def generate_alerts( + self, metrics: dict[str, Any], thresholds: Optional[dict[str, float]] = None, + previous_metrics: Optional[dict[str, Any]] = None, + ) -> list[Alert]: + thresholds = thresholds or self.default_thresholds + flooded_pct = metrics.get("flooded_percentage", 0.0) + flooded_km2 = metrics.get("flooded_area_km2") + critical = thresholds.get("critical_flood_area", 20.0) + alert_at = thresholds.get("alert_flood_area", 5.0) + + alerts: list[Alert] = [] + if flooded_pct >= critical: + msg = f"Critical flooding: {flooded_pct:.1f}% of area flooded" + if flooded_km2: + msg += f" ({flooded_km2:.1f} km²)" + alerts.append(Alert( + alert_type="critical_flooding", severity=Severity.CRITICAL, + title="Critical Flooding Detected", message=msg, + threshold_exceeded=critical, measured_value=flooded_pct, + )) + elif flooded_pct >= alert_at: + msg = f"Flooding detected: {flooded_pct:.1f}% of area flooded" + if flooded_km2: + msg += f" ({flooded_km2:.1f} km²)" + alerts.append(Alert( + alert_type="flooding_detected", severity=Severity.HIGH, + title="Flooding Detected", message=msg, + threshold_exceeded=alert_at, measured_value=flooded_pct, + )) + return alerts diff --git a/src/climatevision/analysis/registry.py b/src/climatevision/analysis/registry.py index 6a138f0..26827f2 100644 --- a/src/climatevision/analysis/registry.py +++ b/src/climatevision/analysis/registry.py @@ -211,5 +211,11 @@ def _ensure_builtins_registered() -> None: _registry.register(FloodingAnalysis, override=True) except ImportError as e: logger.warning(f"Could not import FloodingAnalysis: {e}") - + + try: + from climatevision.analysis.flooding_sar import FloodingSARAnalysis + _registry.register(FloodingSARAnalysis, override=True) + except ImportError as e: + logger.warning(f"Could not import FloodingSARAnalysis: {e}") + _registry._initialized = True diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index 5c9fb2d..138a32a 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -43,6 +43,7 @@ mark_alert_delivered, ) from climatevision.inference import run_inference_from_file, run_inference_from_gee +from climatevision.inference.flood_pipeline import run_flood_inference_from_gee from climatevision.api.auth import require_api_key from climatevision.governance import explain_prediction, SHAPExplainer @@ -51,7 +52,7 @@ # ===== Type Definitions ===== -AnalysisType = Literal["deforestation", "ice_melting", "flooding", "drought", "wildfire"] +AnalysisType = Literal["deforestation", "ice_melting", "flooding", "flooding_sar", "drought", "wildfire"] OrganizationType = Literal["ngo", "government", "research", "corporate"] NotificationChannel = Literal["email", "webhook", "api", "sms"] AlertSeverity = Literal["low", "medium", "high", "critical"] @@ -81,6 +82,14 @@ "bands": ["B03", "B08", "B11"], "classes": ["water", "flooded", "dry_land"], }, + { + "name": "flooding_sar", + "display_name": "Flood Detection (SAR)", + "description": "All-weather flood detection from Sentinel-1 VV/VH using a physics-based ensemble", + "enabled": True, + "bands": ["VV", "VH"], + "classes": ["dry_land", "permanent_water", "flooded"], + }, { "name": "drought", "display_name": "Drought Monitoring", @@ -615,14 +624,22 @@ async def predict_json( ) run_id = int(cur.lastrowid) - # Run inference + # Run inference. SAR flood detection has its own Sentinel-1 + JRC pipeline; + # all other analysis types use the shared Sentinel-2 inference path. try: - result_payload = run_inference_from_gee( - bbox=body.bbox, - start_date=body.start_date, - end_date=body.end_date, - analysis_type=body.analysis_type, - ) + if body.analysis_type == "flooding_sar": + result_payload = run_flood_inference_from_gee( + bbox=body.bbox, + start_date=body.start_date, + end_date=body.end_date, + ) + else: + result_payload = run_inference_from_gee( + bbox=body.bbox, + start_date=body.start_date, + end_date=body.end_date, + analysis_type=body.analysis_type, + ) result_payload["analysis_type"] = body.analysis_type status = "completed" except Exception as exc: diff --git a/src/climatevision/data/__init__.py b/src/climatevision/data/__init__.py index 232f42d..801dbb2 100644 --- a/src/climatevision/data/__init__.py +++ b/src/climatevision/data/__init__.py @@ -2,7 +2,11 @@ from .augmentation import get_train_transforms, get_val_transforms from .preprocessing import Sentinel2Normalizer, compute_dataset_stats, apply_scl_cloud_mask from .synthetic import generate_synthetic_dataset -from .gee_downloader import download_tile_for_analysis +from .gee_downloader import ( + download_tile_for_analysis, + download_sar_tile, + download_permanent_water_occurrence, +) from .band_mapping import ( get_bands_for_analysis, get_bands_for_analysis_with_scl, @@ -40,6 +44,8 @@ "generate_synthetic_dataset", # GEE "download_tile_for_analysis", + "download_sar_tile", + "download_permanent_water_occurrence", # Band mapping "get_bands_for_analysis", "get_bands_for_analysis_with_scl", diff --git a/src/climatevision/data/gee_downloader.py b/src/climatevision/data/gee_downloader.py index fa65f0b..ccc4c17 100644 --- a/src/climatevision/data/gee_downloader.py +++ b/src/climatevision/data/gee_downloader.py @@ -194,6 +194,189 @@ def download_tile_for_analysis( return out_path, metadata +def download_sar_tile( + bbox: list[float], + start_date: str, + end_date: str, + output_dir: str | Path | None = None, + scale_m: int = 30, +) -> tuple[Path, dict[str, Any]]: + """ + Download a Sentinel-1 GRD VV/VH composite (sigma0, dB) for flood detection. + + Uses COPERNICUS/S1_GRD, IW mode, ascending+descending merged via median. + Falls back to a synthetic SAR tile (explicitly tagged) when GEE is + unavailable or no scenes are found. + + Returns: + (file_path, metadata). Band order in the GeoTIFF is [VV, VH] in dB. + """ + if output_dir is None: + output_dir = _SATELLITE_DIR + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + safe_start = start_date.replace("-", "") + safe_end = end_date.replace("-", "") + stem = f"sar_{safe_start}_{safe_end}_{'_'.join(str(round(c, 4)) for c in bbox)}" + out_path = output_dir / f"{stem}.tif" + + try: + ee = _initialize_ee() + rasterio = __import__("rasterio") + except Exception as exc: + logger.warning("GEE unavailable for SAR (%s). Using synthetic SAR fallback.", exc) + return _generate_synthetic_sar_tile(bbox, start_date, end_date, out_path) + + region = ee.Geometry.Rectangle(bbox) + collection = ( + ee.ImageCollection("COPERNICUS/S1_GRD") + .filterBounds(region) + .filterDate(start_date, end_date) + .filter(ee.Filter.eq("instrumentMode", "IW")) + .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV")) + .filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH")) + .select(["VV", "VH"]) + ) + + count = collection.size().getInfo() + if count == 0: + logger.warning("No S1 scenes for %s to %s. Using synthetic SAR fallback.", start_date, end_date) + return _generate_synthetic_sar_tile(bbox, start_date, end_date, out_path) + + # S1_GRD is already terrain-corrected sigma0 in dB. + image = collection.median().clip(region) + url = image.getDownloadURL({"region": region, "scale": scale_m, "format": "GEO_TIFF"}) + + tmp = tempfile.mktemp(suffix=".tif") + urllib.request.urlretrieve(url, tmp) + with rasterio.open(tmp) as src: + data = src.read().astype(np.float32) + profile = src.profile + os.unlink(tmp) + + profile.update(driver="GTiff", dtype="float32", count=data.shape[0]) + with rasterio.open(out_path, "w", **profile) as dst: + dst.write(data) + + metadata: dict[str, Any] = { + "source": "gee", + "collection": "COPERNICUS/S1_GRD", + "bbox": bbox, + "start_date": start_date, + "end_date": end_date, + "bands": ["VV", "VH"], + "scale_m": scale_m, + "images_available": count, + "is_synthetic": False, + "shape": list(data.shape), + } + logger.info("Downloaded S1 SAR tile to %s (%d scenes)", out_path, count) + return out_path, metadata + + +def download_permanent_water_occurrence( + bbox: list[float], + output_dir: str | Path | None = None, + scale_m: int = 30, +) -> tuple[Optional[Path], dict[str, Any]]: + """ + Download JRC Global Surface Water 'occurrence' (%, 0-100) for the bbox. + + Occurrence is the fraction of valid observations (1984-present) in which a + pixel was water. Thresholding it (see permanent_water_from_occurrence) yields + the permanent-water reference used to separate flood from permanent water. + + Returns: + (file_path_or_None, metadata). Returns (None, {...is_synthetic:True}) + when GEE is unavailable -- callers should then derive a synthetic + reference or skip the permanent/flood split rather than guess. + """ + if output_dir is None: + output_dir = _SATELLITE_DIR + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + stem = f"gsw_occurrence_{'_'.join(str(round(c, 4)) for c in bbox)}" + out_path = output_dir / f"{stem}.tif" + + try: + ee = _initialize_ee() + rasterio = __import__("rasterio") + except Exception as exc: + logger.warning("GEE unavailable for JRC GSW (%s). No permanent-water reference.", exc) + return None, {"source": "unavailable", "bbox": bbox, "is_synthetic": True} + + region = ee.Geometry.Rectangle(bbox) + occurrence = ee.Image("JRC/GSW1_4/GlobalSurfaceWater").select("occurrence").clip(region) + url = occurrence.getDownloadURL({"region": region, "scale": scale_m, "format": "GEO_TIFF"}) + + tmp = tempfile.mktemp(suffix=".tif") + urllib.request.urlretrieve(url, tmp) + with rasterio.open(tmp) as src: + data = src.read(1).astype(np.float32) + profile = src.profile + os.unlink(tmp) + + profile.update(driver="GTiff", dtype="float32", count=1) + with rasterio.open(out_path, "w", **profile) as dst: + dst.write(data[np.newaxis, :, :]) + + metadata = { + "source": "gee", + "asset": "JRC/GSW1_4/GlobalSurfaceWater", + "band": "occurrence", + "bbox": bbox, + "scale_m": scale_m, + "is_synthetic": False, + "shape": list(data.shape), + } + logger.info("Downloaded JRC GSW occurrence to %s", out_path) + return out_path, metadata + + +def _generate_synthetic_sar_tile( + bbox: list[float], + start_date: str, + end_date: str, + out_path: Path, +) -> tuple[Path, dict[str, Any]]: + """Synthetic Sentinel-1 VV/VH tile (dB), explicitly tagged is_synthetic.""" + rasterio = __import__("rasterio") + + tile_size = _get_default_tile_size() + h, w = tile_size, tile_size + seed = int(abs(sum(v * 1000 * (i + 1) for i, v in enumerate(bbox)))) % (2 ** 31) + rng = np.random.default_rng(seed) + + # Land ~ -10 dB, water ~ -22 dB; carve a water region into the scene. + vv = rng.normal(-9.0, 2.5, (h, w)).astype(np.float32) + vh = rng.normal(-15.0, 2.5, (h, w)).astype(np.float32) + water = np.zeros((h, w), dtype=bool) + water[h // 3 : 2 * h // 3, w // 4 : 3 * w // 4] = True + vv[water] = rng.normal(-20.0, 1.5, int(water.sum())) + vh[water] = rng.normal(-26.0, 1.5, int(water.sum())) + data = np.stack([vv, vh], axis=0) + + transform = rasterio.transform.from_bounds(bbox[0], bbox[1], bbox[2], bbox[3], w, h) + profile = { + "driver": "GTiff", "dtype": "float32", "count": 2, + "height": h, "width": w, "crs": "EPSG:4326", "transform": transform, + } + with rasterio.open(out_path, "w", **profile) as dst: + dst.write(data) + + metadata: dict[str, Any] = { + "source": "synthetic_fallback", + "collection": "COPERNICUS/S1_GRD", + "bbox": bbox, "start_date": start_date, "end_date": end_date, + "bands": ["VV", "VH"], "scale_m": 30, "images_available": 0, + "is_synthetic": True, "shape": list(data.shape), + } + logger.info("Generated synthetic SAR fallback tile to %s", out_path) + return out_path, metadata + + def _generate_synthetic_tile( bbox: list[float], start_date: str, diff --git a/src/climatevision/inference/flood_pipeline.py b/src/climatevision/inference/flood_pipeline.py new file mode 100644 index 0000000..5f519ee --- /dev/null +++ b/src/climatevision/inference/flood_pipeline.py @@ -0,0 +1,104 @@ +""" +SAR flood-detection inference pipeline (bbox -> 3-class flood result). + +Orchestrates the full path the API uses for flooding requests: + 1. Download a Sentinel-1 VV/VH tile for the bbox/date range (synthetic fallback). + 2. Download the JRC Global Surface Water occurrence layer -> permanent-water + reference (skipped when GEE is unavailable; the split is then unresolved). + 3. Run FloodingSARAnalysis (ensemble + permanent/flood classifier). + 4. Return a result dict in the same shape as run_inference_from_gee. + +All heavy/geo dependencies are imported lazily so importing this module is cheap. +""" +from __future__ import annotations + +import logging +from typing import Any, Optional + +import numpy as np + +from climatevision.analysis.flooding_sar import FloodingSARAnalysis +from climatevision.analysis.flood_classification import permanent_water_from_occurrence + +logger = logging.getLogger(__name__) + + +def run_flood_inference_from_gee( + *, + bbox: Optional[list[float]] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + occurrence_threshold_pct: float = 50.0, + analysis_type: str = "flooding_sar", +) -> dict[str, Any]: + """Run SAR flood detection for a bbox/date range. Never raises for missing + data -- falls back to a synthetic SAR scene tagged is_synthetic=True.""" + from climatevision.data import ( + download_sar_tile, + download_permanent_water_occurrence, + ) + + sar_path, sar_meta = download_sar_tile( + bbox=bbox, start_date=start_date, end_date=end_date, + ) + image = _read_tile(str(sar_path)) # (2, H, W) VV/VH + + permanent_ref = _load_permanent_reference( + bbox, image.shape[-2:], occurrence_threshold_pct, download_permanent_water_occurrence + ) + + analysis = FloodingSARAnalysis(permanent_water_ref=permanent_ref) + date_range = f"{start_date} to {end_date}" if start_date and end_date else None + result = analysis.analyze(image=image, bbox=bbox, date_range=date_range) + + payload = result.to_dict() + payload["analysis_type"] = analysis_type + payload["is_synthetic"] = bool(sar_meta.get("is_synthetic", False)) + payload["metadata"] = { + "sar": sar_meta, + "permanent_water_reference": permanent_ref is not None, + "occurrence_threshold_pct": occurrence_threshold_pct, + } + return payload + + +def _load_permanent_reference( + bbox, target_hw, threshold_pct, downloader, +) -> Optional[np.ndarray]: + """Fetch JRC GSW occurrence and turn it into a permanent-water mask aligned + to the SAR grid. Returns None when GEE is unavailable (no guessing).""" + if bbox is None: + return None + try: + occ_path, occ_meta = downloader(bbox=bbox) + except Exception as exc: # network/credential failure -> no reference + logger.warning("Permanent-water reference fetch failed (%s); split unresolved.", exc) + return None + if occ_path is None: + return None + + occ = _read_tile(str(occ_path)) + occ = occ[0] if occ.ndim == 3 else occ + occ = _resample_nearest(occ, target_hw) + return permanent_water_from_occurrence(occ, threshold_pct=threshold_pct) + + +def _read_tile(path: str) -> np.ndarray: + """Read a GeoTIFF as a float32 array (C, H, W) or (H, W).""" + try: + import rasterio + except ImportError as exc: + raise ImportError("rasterio is required to read flood tiles") from exc + with rasterio.open(path) as ds: + return ds.read().astype(np.float32) + + +def _resample_nearest(arr: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray: + """Nearest-neighbour resample a 2-D array to target (H, W) without SciPy.""" + th, tw = target_hw + h, w = arr.shape + if (h, w) == (th, tw): + return arr + ys = (np.linspace(0, h - 1, th)).round().astype(int) + xs = (np.linspace(0, w - 1, tw)).round().astype(int) + return arr[np.ix_(ys, xs)] diff --git a/tests/test_flood_classification.py b/tests/test_flood_classification.py new file mode 100644 index 0000000..3016526 --- /dev/null +++ b/tests/test_flood_classification.py @@ -0,0 +1,71 @@ +""" +Tests for permanent-water vs flood-water classification. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from climatevision.analysis.flood_classification import ( + DRY_LAND, + FLOODED, + PERMANENT_WATER, + classify_with_change, + classify_with_reference, + permanent_water_from_occurrence, +) +from climatevision.analysis.flooding_ensemble import EnsembleFloodPipeline + + +class TestClassifyWithReference: + def test_splits_permanent_and_flood(self): + water = np.array([[1, 1], [1, 0]], dtype=np.uint8) + perm = np.array([[1, 0], [0, 0]], dtype=np.uint8) + out = classify_with_reference(water, perm) + assert out[0, 0] == PERMANENT_WATER # water + in reference + assert out[0, 1] == FLOODED # water, not in reference + assert out[1, 0] == FLOODED # water, not in reference + assert out[1, 1] == DRY_LAND # no water + + def test_shape_mismatch_raises(self): + with pytest.raises(ValueError): + classify_with_reference(np.zeros((4, 4)), np.zeros((2, 2))) + + +class TestClassifyWithChange: + def test_change_separates_flood_from_permanent(self): + pre = np.array([[1, 0], [0, 1]], dtype=np.uint8) + post = np.array([[1, 1], [0, 0]], dtype=np.uint8) + out = classify_with_change(pre, post) + assert out[0, 0] == PERMANENT_WATER # water before and after + assert out[0, 1] == FLOODED # appeared after + assert out[1, 0] == DRY_LAND # dry both + assert out[1, 1] == DRY_LAND # receded -> not flooded now + + +class TestOccurrence: + def test_threshold(self): + occ = np.array([[10.0, 60.0], [50.0, 0.0]]) + mask = permanent_water_from_occurrence(occ, threshold_pct=50.0) + assert mask.tolist() == [[0, 1], [1, 0]] + + +class TestEnsembleIntegration: + def test_reference_yields_three_class_output(self): + # Low VH (dB) reads as water for the TUW/DLR detectors. + post_vh = np.full((16, 16), -10.0, dtype=np.float32) + post_vh[4:12, 4:12] = -26.0 # a water blob + perm_ref = np.zeros((16, 16), dtype=np.uint8) + perm_ref[4:8, 4:8] = 1 # half the blob is "normally water" + + out = EnsembleFloodPipeline().detect(post_vh, permanent_water_ref=perm_ref) + classified = out["classified_mask"] + assert classified is not None + # Both permanent and flood pixels should be present in the blob. + assert (classified == PERMANENT_WATER).any() + assert (classified == FLOODED).any() + + def test_no_reference_returns_none_not_a_guess(self): + post_vh = np.full((8, 8), -10.0, dtype=np.float32) + out = EnsembleFloodPipeline().detect(post_vh) + assert out["classified_mask"] is None diff --git a/tests/test_flooding_sar.py b/tests/test_flooding_sar.py new file mode 100644 index 0000000..8c91cf2 --- /dev/null +++ b/tests/test_flooding_sar.py @@ -0,0 +1,115 @@ +""" +Tests for the SAR flood analysis type, its registry wiring, and API exposure. + +The full GEE -> rasterio download path is not exercised here (it needs GDAL and +Earth Engine credentials); those are integration concerns. These tests cover the +analysis logic on numpy arrays plus discoverability through the registry/API. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from climatevision.analysis import get_analysis_type, list_analysis_types +from climatevision.analysis.flooding_sar import FloodingSARAnalysis, FLOODED, PERMANENT_WATER + + +def _sar_scene(size: int = 24) -> np.ndarray: + """(2, H, W) VV/VH dB scene with a central water blob (~ -26 dB VH).""" + rng = np.random.default_rng(0) + vv = rng.normal(-9.0, 1.0, (size, size)).astype(np.float32) + vh = rng.normal(-15.0, 1.0, (size, size)).astype(np.float32) + vh[6:18, 6:18] = -26.0 + vv[6:18, 6:18] = -20.0 + return np.stack([vv, vv * 0 + vh], axis=0) + + +class TestFloodingSARAnalysis: + def test_preprocess_returns_two_band_db(self): + scene = _sar_scene() + out = FloodingSARAnalysis().preprocess(scene) + assert out.shape == scene.shape # (2, H, W) preserved + + def test_without_reference_water_is_flooded_and_flagged(self): + scene = _sar_scene() + analysis = FloodingSARAnalysis() + pred, conf = analysis.run_inference(analysis.preprocess(scene)) + assert (pred == FLOODED).any() + assert (pred == PERMANENT_WATER).sum() == 0 # cannot resolve permanent + metrics = analysis.calculate_metrics(pred, scene.shape[-2:]) + assert metrics["permanent_flood_distinguished"] is False + + def test_with_reference_separates_permanent_and_flood(self): + scene = _sar_scene() + # Half the water blob is "normally water". + ref = np.zeros(scene.shape[-2:], dtype=np.uint8) + ref[6:12, 6:18] = 1 + analysis = FloodingSARAnalysis(permanent_water_ref=ref) + pred, conf = analysis.run_inference(analysis.preprocess(scene)) + assert (pred == PERMANENT_WATER).any() + assert (pred == FLOODED).any() + metrics = analysis.calculate_metrics(pred, scene.shape[-2:]) + assert metrics["permanent_flood_distinguished"] is True + + def test_metrics_area_with_bbox(self): + scene = _sar_scene() + analysis = FloodingSARAnalysis() + pred, _ = analysis.run_inference(analysis.preprocess(scene)) + metrics = analysis.calculate_metrics(pred, scene.shape[-2:], bbox=[36.7, -1.4, 37.0, -1.1]) + assert "flooded_area_km2" in metrics + assert metrics["flooded_area_km2"] >= 0 + + def test_alerts_critical(self): + analysis = FloodingSARAnalysis() + alerts = analysis.generate_alerts({"flooded_percentage": 30.0, "flooded_area_km2": 12.0}) + assert len(alerts) == 1 + assert alerts[0].severity.value == "critical" + + def test_full_analyze_pipeline(self): + scene = _sar_scene() + result = FloodingSARAnalysis().analyze(image=scene, bbox=[36.7, -1.4, 37.0, -1.1]) + assert result.success + assert result.analysis_type == "flooding_sar" + assert "flooded_percentage" in result.metrics + + +class TestRegistryAndDiscovery: + def test_registered_in_registry(self): + analysis = get_analysis_type("flooding_sar") + assert analysis is not None + assert isinstance(analysis, FloodingSARAnalysis) + + def test_listed_among_analysis_types(self): + names = [t["name"] for t in list_analysis_types()] + assert "flooding_sar" in names + + +class TestApiExposure: + def test_health_ok_and_lists_flooding_sar(self, client): + resp = client.get("/api/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" # config entry keeps validation green + assert "flooding_sar" in data["analysis_types"] + + def test_analysis_types_endpoint_includes_sar(self, client): + resp = client.get("/api/analysis-types") + assert resp.status_code == 200 + sar = next((t for t in resp.json() if t["name"] == "flooding_sar"), None) + assert sar is not None + assert sar["bands"] == ["VV", "VH"] + + def test_predict_flooding_sar_accepted_by_schema(self, client): + """A flooding_sar request must reach auth (401), not be rejected as an + invalid analysis_type (422) -- proving the schema accepts it.""" + resp = client.post( + "/api/predict", + json={ + "kind": "gee", + "analysis_type": "flooding_sar", + "bbox": [36.7, -1.4, 37.0, -1.1], + "start_date": "2024-04-01", + "end_date": "2024-04-10", + }, + ) + assert resp.status_code == 401 From f9699acc6fa466817b5f81eb76cbfd09724b05f2 Mon Sep 17 00:00:00 2001 From: Gold okpa Date: Tue, 16 Jun 2026 21:40:47 +0100 Subject: [PATCH 3/4] fix(deps): add scikit-image required by SAR flood DLR detector --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 687a133..3aec228 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pandas>=1.3.0 torch>=2.0.0 torchvision>=0.15.0 scikit-learn>=1.0.0 +scikit-image>=0.19.0 # Geospatial Data Processing rasterio>=1.3.0 From 7d9360564ce3f7362d4d75e8fcfa973edfe61552 Mon Sep 17 00:00:00 2001 From: Gold okpa Date: Tue, 16 Jun 2026 21:44:26 +0100 Subject: [PATCH 4/4] fix(security): raise default validate_bbox max_area to 200 for valid regional bboxes test_valid_bbox uses a 15x10=150 sq-degree regional box that was wrongly rejected by the previous default of 100. The DoS test (test_too_large_area) passes max_area explicitly so it is unaffected. --- src/climatevision/security/api_security.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/climatevision/security/api_security.py b/src/climatevision/security/api_security.py index 7eaa8fc..9a8523e 100644 --- a/src/climatevision/security/api_security.py +++ b/src/climatevision/security/api_security.py @@ -210,7 +210,7 @@ def validate_payload_size( def validate_bbox( bbox: list[float], - max_area: float = 100.0, + max_area: float = 200.0, ) -> tuple[bool, str]: """ Validate bounding box coordinates.