From 5b924d66ed2ad0690a2771bc026ccfa3787b2dc1 Mon Sep 17 00:00:00 2001 From: Sergio Souza Costa Date: Fri, 12 Jun 2026 16:59:56 -0300 Subject: [PATCH] =?UTF-8?q?chore(release):=20vers=C3=A3o=200.6.1=20?= =?UTF-8?q?=E2=80=94=20fix=20de=20truncamento=20de=20dtype=20no=20GeoTIFF?= =?UTF-8?q?=20e=20su=C3=ADtes=20de=20teste=20io/visualization?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Fable 5 --- .gitignore | 3 + CHANGELOG.md | 19 +++ dissmodel/io/raster.py | 16 ++ pyproject.toml | 2 +- tests/executor/test_utils_config.py | 47 ++++++ tests/io/test_convert.py | 157 ++++++++++++++++++ tests/io/test_dispatch.py | 166 +++++++++++++++++++ tests/io/test_io_utils.py | 236 ++++++++++++++++++++++++++++ tests/io/test_raster_io.py | 202 ++++++++++++++++++++++++ tests/io/test_vector_io.py | 70 +++++++++ tests/io/test_xarray_io.py | 81 ++++++++++ tests/visualization/test_chart.py | 198 +++++++++++++++++++++++ tests/visualization/test_map.py | 173 ++++++++++++++++++++ 13 files changed, 1369 insertions(+), 1 deletion(-) create mode 100644 tests/executor/test_utils_config.py create mode 100644 tests/io/test_convert.py create mode 100644 tests/io/test_dispatch.py create mode 100644 tests/io/test_io_utils.py create mode 100644 tests/io/test_raster_io.py create mode 100644 tests/io/test_vector_io.py create mode 100644 tests/io/test_xarray_io.py create mode 100644 tests/visualization/test_chart.py create mode 100644 tests/visualization/test_map.py diff --git a/.gitignore b/.gitignore index d19d0de..79d9b17 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ raster_map_frames testes +chart_frames/ +map_frames/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/CHANGELOG.md b/CHANGELOG.md index dd1bf08..91f54ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [0.6.1] — 2026-06-12 + +### Fixed +- GeoTIFF write with mixed-dtype `band_spec` (e.g. int32 categorical + + float32 continuous) silently truncated float bands to the first band's + dtype. Bands are now promoted to the common NumPy result type + (`np.result_type`) before writing. + +### Added +- Test suites for `dissmodel.io` (utils, dispatch, storage, raster, vector, + convert, xarray) and `dissmodel.visualization` (chart, map, env detection). + Coverage: 55% → 79% (319 → 441 tests). + +### Notes +- GeoTIFFs containing mixed-dtype bands saved with v0.6.0 may have truncated + float bands. Re-exporting those files is recommended. + +--- + ## [0.6.0] — 2026-06-11 ### Breaking Changes diff --git a/dissmodel/io/raster.py b/dissmodel/io/raster.py index 755bd29..07754d0 100644 --- a/dissmodel/io/raster.py +++ b/dissmodel/io/raster.py @@ -139,6 +139,13 @@ def save_geotiff( Affine geotransform. Overrides meta["transform"]. compress : str Compression algorithm. Default: "deflate". + + Notes + ----- + GeoTIFF stores a single dtype per file. When band dtypes differ + (e.g. an int32 categorical band alongside a float32 continuous band), + all bands are promoted to their common NumPy result type + (``np.result_type``) before writing. """ if not HAS_RASTERIO: raise ImportError("rasterio is required — pip install rasterio") @@ -212,6 +219,15 @@ def _write_geotiff( if transform is None: transform = rasterio.transform.from_bounds(0, 0, cols, rows, cols, rows) + # GeoTIFF requires a single dtype for all bands. When bands have mixed + # dtypes (e.g. int32 categorical + float32 continuous), promote every + # band to the common NumPy result type instead of silently truncating + # to the first band's dtype. + dtypes = {arr.dtype for arr in arrays} + if len(dtypes) > 1: + common = np.result_type(*dtypes) + arrays = [arr.astype(common) for arr in arrays] + with rasterio.open( path, "w", driver = "GTiff", diff --git a/pyproject.toml b/pyproject.toml index 481d8e2..772f290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dissmodel" -version = "0.6.0" +version = "0.6.1" description = "Discrete Spatial Modeling framework for raster and vector simulations" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.10" diff --git a/tests/executor/test_utils_config.py b/tests/executor/test_utils_config.py new file mode 100644 index 0000000..78673b0 --- /dev/null +++ b/tests/executor/test_utils_config.py @@ -0,0 +1,47 @@ +""" +tests/executor/test_utils_config.py +==================================== +Tests for dissmodel.executor.utils — default_output_uri — and +dissmodel.executor.config — Settings. +""" +from __future__ import annotations + +from dissmodel.executor.utils import default_output_uri +from dissmodel.executor.config import Settings, settings +from dissmodel.io import _storage + + +class TestDefaultOutputUri: + + def teardown_method(self): + _storage.set_default_client(None) + + def test_s3_uri_when_minio_reachable(self): + _storage.set_default_client(object()) # any non-None client + uri = default_output_uri("exp-123", "tif") + assert uri == "s3://dissmodel-outputs/experiments/exp-123/output.tif" + + def test_local_fallback_when_minio_unreachable(self, monkeypatch): + _storage.set_default_client(None) + + def boom(): + raise RuntimeError("no MinIO") + + monkeypatch.setattr( + "dissmodel.io._storage.get_default_client", boom + ) + uri = default_output_uri("exp-123", "gpkg") + assert uri == "./outputs/exp-123/output.gpkg" + + +class TestSettings: + + def test_default_output_base(self): + assert Settings().default_output_base == "./outputs" + + def test_module_level_singleton_exists(self): + assert isinstance(settings, Settings) + + def test_env_var_overrides_default(self, monkeypatch): + monkeypatch.setenv("DEFAULT_OUTPUT_BASE", "/data/runs") + assert Settings().default_output_base == "/data/runs" diff --git a/tests/io/test_convert.py b/tests/io/test_convert.py new file mode 100644 index 0000000..ec0b5fa --- /dev/null +++ b/tests/io/test_convert.py @@ -0,0 +1,157 @@ +""" +tests/io/test_convert.py +========================= +Tests for dissmodel.io.convert — vector_to_raster_backend. + +Covers: GeoDataFrame and file-path sources, attrs as list and dict, +mask band, nodata_value sentinel, CRS validation/reprojection, +error paths, and the deprecated shapefile_to_raster_backend alias. +""" +from __future__ import annotations + +import numpy as np +import pytest + +gpd = pytest.importorskip("geopandas") +pytest.importorskip("rasterio") + +from shapely.geometry import Polygon # noqa: E402 + +from dissmodel.io.convert import ( # noqa: E402 + vector_to_raster_backend, + shapefile_to_raster_backend, +) + + +@pytest.fixture +def gdf(): + """Two unit squares side by side covering x∈[0,2], y∈[0,1].""" + geoms = [ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]), + ] + return gpd.GeoDataFrame( + {"uso": [1, 2], "alt": [0.25, 0.75]}, + geometry=geoms, + crs="EPSG:31984", + ) + + +class TestVectorToRasterBackend: + + def test_grid_dimensions_follow_bounds_and_resolution(self, gdf): + b = vector_to_raster_backend(gdf, resolution=0.5, attrs=["uso"]) + # bounds 2x1, resolution 0.5 → 4 cols × 2 rows + assert b.shape == (2, 4) + + def test_integer_column_rasterized_as_int32(self, gdf): + b = vector_to_raster_backend(gdf, resolution=0.5, attrs=["uso"]) + arr = b.arrays["uso"] + assert arr.dtype == np.int32 + # Left half burned with 1, right half with 2 + assert np.all(arr[:, :2] == 1) + assert np.all(arr[:, 2:] == 2) + + def test_float_column_rasterized_as_float32(self, gdf): + b = vector_to_raster_backend(gdf, resolution=0.5, attrs=["alt"]) + arr = b.arrays["alt"] + assert arr.dtype == np.float32 + np.testing.assert_allclose(arr[:, :2], 0.25) + np.testing.assert_allclose(arr[:, 2:], 0.75) + + def test_mask_band_added_by_default(self, gdf): + b = vector_to_raster_backend(gdf, resolution=0.5, attrs=["uso"]) + assert "mask" in b.arrays + assert np.all(b.arrays["mask"] == 1.0) # fully covered extent + + def test_add_mask_false_omits_band(self, gdf): + b = vector_to_raster_backend( + gdf, resolution=0.5, attrs=["uso"], add_mask=False + ) + assert "mask" not in b.arrays + + def test_attrs_dict_with_per_column_defaults(self, gdf): + b = vector_to_raster_backend( + gdf, resolution=0.5, attrs={"uso": -1, "alt": -9999.0} + ) + assert set(b.arrays) >= {"uso", "alt"} + + def test_nodata_value_sentinel_outside_coverage(self): + # Two squares at opposite corners of a 2x2 extent — the other two + # corner cells of the grid fall outside any geometry. + gdf2 = gpd.GeoDataFrame( + {"uso": [5, 7]}, + geometry=[ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1.9, 1.9), (2, 1.9), (2, 2), (1.9, 2)]), + ], + crs="EPSG:31984", + ) + b = vector_to_raster_backend( + gdf2, resolution=1.0, attrs=["uso"], nodata_value=-1 + ) + arr = b.arrays["uso"] + mask = b.arrays["mask"].astype(bool) + assert np.all(arr[~mask] == -1) # sentinel outside coverage + assert b.nodata_value == -1 + + def test_source_gdf_is_not_mutated(self, gdf): + original_crs = gdf.crs + vector_to_raster_backend(gdf, resolution=0.5, attrs=["uso"], crs="EPSG:4326") + assert gdf.crs == original_crs + + def test_reprojection_applied_when_crs_given(self, gdf): + b = vector_to_raster_backend( + gdf, resolution=0.00001, attrs=["uso"], crs="EPSG:4326" + ) + assert b.crs is not None + assert "4326" in str(b.crs) + + def test_file_path_source(self, gdf, tmp_path): + path = tmp_path / "grid.gpkg" + gdf.to_file(str(path), driver="GPKG") + b = vector_to_raster_backend(str(path), resolution=0.5, attrs=["uso"]) + assert b.shape == (2, 4) + + # ── error paths ─────────────────────────────────────────────────────────── + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + vector_to_raster_backend( + str(tmp_path / "missing.shp"), resolution=1, attrs=["uso"] + ) + + def test_empty_attrs_raises(self, gdf): + with pytest.raises(ValueError, match="must not be empty"): + vector_to_raster_backend(gdf, resolution=0.5, attrs=[]) + + def test_missing_column_raises(self, gdf): + with pytest.raises(ValueError, match="not found"): + vector_to_raster_backend(gdf, resolution=0.5, attrs=["inexistente"]) + + def test_no_crs_anywhere_raises(self, gdf): + # Build a GeoDataFrame without CRS from scratch — overriding + # .crs on an existing one is deprecated by pandas/geopandas. + naked = gpd.GeoDataFrame( + gdf.drop(columns="geometry"), + geometry=list(gdf.geometry), # raw shapely geoms — no CRS + ) + assert naked.crs is None + with pytest.raises(ValueError, match="no CRS"): + vector_to_raster_backend(naked, resolution=0.5, attrs=["uso"]) + + def test_gdf_without_crs_but_explicit_crs_ok(self, gdf): + naked = gdf.copy() + naked = naked.set_crs("EPSG:31984", allow_override=True) + b = vector_to_raster_backend( + naked, resolution=0.5, attrs=["uso"], crs="EPSG:31984" + ) + assert b.shape == (2, 4) + + +class TestDeprecatedAlias: + + def test_shapefile_to_raster_backend_warns(self, gdf): + with pytest.warns(FutureWarning, match="deprecated"): + b = shapefile_to_raster_backend(gdf, resolution=0.5, attrs=["uso"]) + assert b.shape == (2, 4) \ No newline at end of file diff --git a/tests/io/test_dispatch.py b/tests/io/test_dispatch.py new file mode 100644 index 0000000..383649f --- /dev/null +++ b/tests/io/test_dispatch.py @@ -0,0 +1,166 @@ +""" +tests/io/test_dispatch.py +========================== +Tests for dissmodel.io._dispatch — format routing of load_dataset and +save_dataset — and for dissmodel.io._storage — default MinIO client. + +Routing is verified with monkeypatched loaders/savers so each test +asserts only the dispatch decision, not the underlying I/O. +""" +from __future__ import annotations + +import sys + +import pytest + +from dissmodel.io import load_dataset, save_dataset +from dissmodel.io import _storage + + +# ══════════════════════════════════════════════════════════════════════════════ +# load_dataset routing +# ══════════════════════════════════════════════════════════════════════════════ + +class TestLoadDispatch: + + def test_routes_vector(self, monkeypatch): + import dissmodel.io.vector as vector_mod + monkeypatch.setattr( + vector_mod, "load_gdf", + lambda uri, minio_client=None, **kw: ("GDF", "sum"), + ) + assert load_dataset("data/grid.gpkg") == ("GDF", "sum") + + def test_routes_raster(self, monkeypatch): + import dissmodel.io.raster as raster_mod + monkeypatch.setattr( + raster_mod, "load_geotiff", + lambda uri, minio_client=None, **kw: (("BACKEND", {}), "sum"), + ) + assert load_dataset("data/scene.tif") == (("BACKEND", {}), "sum") + + def test_routes_xarray(self, monkeypatch): + import dissmodel.io._xarray as xr_mod + monkeypatch.setattr( + xr_mod, "load_xarray", + lambda uri, minio_client=None, **kw: ("DS", "sum"), + ) + assert load_dataset("data/cube.nc") == ("DS", "sum") + + def test_explicit_fmt_overrides_extension(self, monkeypatch): + import dissmodel.io.raster as raster_mod + monkeypatch.setattr( + raster_mod, "load_geotiff", + lambda uri, minio_client=None, **kw: ("RASTER", "sum"), + ) + # .dat is not a recognized extension — fmt= must win + assert load_dataset("data/scene.dat", fmt="raster") == ("RASTER", "sum") + + def test_unsupported_fmt_raises(self): + with pytest.raises(ValueError, match="Unsupported format"): + load_dataset("data/scene.tif", fmt="hologram") + + def test_unknown_extension_raises(self): + with pytest.raises(ValueError): + load_dataset("data/file.xyz") + + +# ══════════════════════════════════════════════════════════════════════════════ +# save_dataset routing +# ══════════════════════════════════════════════════════════════════════════════ + +class TestSaveDispatch: + + def test_routes_vector(self, monkeypatch): + import dissmodel.io.vector as vector_mod + monkeypatch.setattr( + vector_mod, "save_gdf", + lambda data, uri, minio_client=None, **kw: "checksum-v", + ) + assert save_dataset("GDF", "out/grid.gpkg") == "checksum-v" + + def test_routes_raster(self, monkeypatch): + import dissmodel.io.raster as raster_mod + monkeypatch.setattr( + raster_mod, "save_geotiff", + lambda data, uri, minio_client=None, **kw: "checksum-r", + ) + assert save_dataset(("BACKEND", {}), "out/scene.tif") == "checksum-r" + + def test_routes_xarray(self, monkeypatch): + import dissmodel.io._xarray as xr_mod + monkeypatch.setattr( + xr_mod, "save_xarray", + lambda data, uri, minio_client=None, **kw: "checksum-x", + ) + assert save_dataset("DS", "out/cube.zarr") == "checksum-x" + + def test_unsupported_fmt_raises(self): + with pytest.raises(ValueError, match="Unsupported format"): + save_dataset("DATA", "out/scene.tif", fmt="hologram") + + +# ══════════════════════════════════════════════════════════════════════════════ +# _storage — default MinIO client +# ══════════════════════════════════════════════════════════════════════════════ + +class TestStorage: + + def setup_method(self): + _storage.set_default_client(None) + + def teardown_method(self): + _storage.set_default_client(None) + + def test_set_default_client_overrides(self): + sentinel = object() + _storage.set_default_client(sentinel) + assert _storage.get_default_client() is sentinel + + def test_get_default_client_builds_from_env(self, monkeypatch): + created = {} + + class FakeMinio: + def __init__(self, endpoint, access_key, secret_key, secure): + created.update( + endpoint=endpoint, + access_key=access_key, + secret_key=secret_key, + secure=secure, + ) + + fake_module = type(sys)("minio") + fake_module.Minio = FakeMinio + monkeypatch.setitem(sys.modules, "minio", fake_module) + monkeypatch.setenv("MINIO_ENDPOINT", "host:9000") + monkeypatch.setenv("MINIO_ACCESS_KEY", "ak") + monkeypatch.setenv("MINIO_SECRET_KEY", "sk") + monkeypatch.setenv("MINIO_SECURE", "1") + + client = _storage.get_default_client() + assert isinstance(client, FakeMinio) + assert created == { + "endpoint": "host:9000", + "access_key": "ak", + "secret_key": "sk", + "secure": True, + } + + def test_get_default_client_is_cached(self, monkeypatch): + class FakeMinio: + def __init__(self, *a, **kw): + pass + + fake_module = type(sys)("minio") + fake_module.Minio = FakeMinio + monkeypatch.setitem(sys.modules, "minio", fake_module) + + first = _storage.get_default_client() + second = _storage.get_default_client() + assert first is second + + def test_missing_minio_package_raises_importerror(self, monkeypatch): + # sys.modules[name] = None makes ``import name`` raise ImportError + monkeypatch.setitem(sys.modules, "minio", None) + with pytest.raises(ImportError, match="minio"): + _storage.get_default_client() diff --git a/tests/io/test_io_utils.py b/tests/io/test_io_utils.py new file mode 100644 index 0000000..a6978b3 --- /dev/null +++ b/tests/io/test_io_utils.py @@ -0,0 +1,236 @@ +""" +tests/io/test_io_utils.py +========================== +Tests for dissmodel.io._utils — format detection, checksums, and +generic URI read/write helpers. + +s3:// paths are exercised with an in-memory fake MinIO client so no +network or server is required. +""" +from __future__ import annotations + +import hashlib +import io + +import pytest + +from dissmodel.io._utils import ( + detect_format, + sha256_bytes, + sha256_file, + resolve_uri, + read_bytes, + read_text, + write_bytes, + write_text, +) +from dissmodel.io import _storage + + +# ── fakes ───────────────────────────────────────────────────────────────────── + +class _FakeObject: + def __init__(self, data: bytes): + self._data = data + + def read(self) -> bytes: + return self._data + + +class FakeMinioClient: + """Minimal in-memory stand-in for the MinIO client.""" + + def __init__(self): + self.store: dict[tuple[str, str], bytes] = {} + + def get_object(self, bucket: str, key: str) -> _FakeObject: + return _FakeObject(self.store[(bucket, key)]) + + def put_object(self, bucket_name, object_name, data, length, content_type): + self.store[(bucket_name, object_name)] = data.read() + + +@pytest.fixture +def fake_client(): + client = FakeMinioClient() + _storage.set_default_client(client) + yield client + _storage.set_default_client(None) + + +# ══════════════════════════════════════════════════════════════════════════════ +# detect_format +# ══════════════════════════════════════════════════════════════════════════════ + +class TestDetectFormat: + + @pytest.mark.parametrize("uri", [ + "data/grid.shp", + "data/grid.gpkg", + "data/grid.geojson", + "data/grid.json", + "data/grid.zip", + "s3://bucket/key.GPKG", # case-insensitive + ]) + def test_vector_extensions(self, uri): + assert detect_format(uri) == "vector" + + @pytest.mark.parametrize("uri", [ + "data/scene.tif", + "data/scene.tiff", + "https://example.org/scene.TIF", + ]) + def test_raster_extensions(self, uri): + assert detect_format(uri) == "raster" + + @pytest.mark.parametrize("uri", [ + "data/cube.zarr", + "data/cube.nc", + "data/cube.nc4", + ]) + def test_xarray_extensions(self, uri): + assert detect_format(uri) == "xarray" + + def test_query_string_is_stripped(self): + assert detect_format("https://host/scene.tif?token=abc") == "raster" + + def test_unknown_extension_raises(self): + with pytest.raises(ValueError, match="Cannot detect format"): + detect_format("data/file.xyz") + + def test_no_extension_raises(self): + with pytest.raises(ValueError): + detect_format("data/file") + + +# ══════════════════════════════════════════════════════════════════════════════ +# checksums +# ══════════════════════════════════════════════════════════════════════════════ + +class TestChecksums: + + def test_sha256_bytes_matches_hashlib(self): + payload = b"dissmodel" + assert sha256_bytes(payload) == hashlib.sha256(payload).hexdigest() + + def test_sha256_bytes_empty(self): + assert sha256_bytes(b"") == hashlib.sha256(b"").hexdigest() + + def test_sha256_file_matches_bytes_digest(self, tmp_path): + payload = b"x" * 200_000 # larger than one 65536-byte chunk + path = tmp_path / "blob.bin" + path.write_bytes(payload) + assert sha256_file(str(path)) == sha256_bytes(payload) + + +# ══════════════════════════════════════════════════════════════════════════════ +# resolve_uri +# ══════════════════════════════════════════════════════════════════════════════ + +class TestResolveUri: + + def test_local_path(self, tmp_path): + path = tmp_path / "data.txt" + path.write_bytes(b"hello") + content, checksum = resolve_uri(str(path)) + assert content == b"hello" + assert checksum == sha256_bytes(b"hello") + + def test_local_path_missing_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + resolve_uri(str(tmp_path / "missing.txt")) + + def test_s3_with_explicit_client(self): + client = FakeMinioClient() + client.store[("bucket", "dir/data.bin")] = b"payload" + content, checksum = resolve_uri("s3://bucket/dir/data.bin", client) + assert content == b"payload" + assert checksum == sha256_bytes(b"payload") + + def test_s3_falls_back_to_default_client(self, fake_client): + fake_client.store[("bucket", "key")] = b"default" + content, _ = resolve_uri("s3://bucket/key") + assert content == b"default" + + def test_http_uses_urllib(self, monkeypatch): + class _FakeResponse: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def read(self): + return b"web-content" + + import urllib.request + monkeypatch.setattr( + urllib.request, "urlopen", lambda uri: _FakeResponse() + ) + content, checksum = resolve_uri("https://example.org/data.bin") + assert content == b"web-content" + assert checksum == sha256_bytes(b"web-content") + + +# ══════════════════════════════════════════════════════════════════════════════ +# read_bytes / read_text +# ══════════════════════════════════════════════════════════════════════════════ + +class TestRead: + + def test_read_bytes_local(self, tmp_path): + path = tmp_path / "data.bin" + path.write_bytes(b"\x00\x01") + assert read_bytes(str(path)) == b"\x00\x01" + + def test_read_text_local(self, tmp_path): + path = tmp_path / "config.toml" + path.write_text("key = 'value'", encoding="utf-8") + assert read_text(str(path)) == "key = 'value'" + + def test_read_bytes_s3(self, fake_client): + fake_client.store[("bucket", "data.bin")] = b"remote" + assert read_bytes("s3://bucket/data.bin") == b"remote" + + +# ══════════════════════════════════════════════════════════════════════════════ +# write_bytes / write_text +# ══════════════════════════════════════════════════════════════════════════════ + +class TestWrite: + + def test_write_bytes_local(self, tmp_path): + target = tmp_path / "out" / "data.bin" # parent does not exist yet + checksum = write_bytes(b"payload", str(target)) + assert target.read_bytes() == b"payload" + assert checksum == sha256_bytes(b"payload") + + def test_write_bytes_accepts_bytesio(self, tmp_path): + buf = io.BytesIO(b"stream-data") + target = tmp_path / "data.bin" + checksum = write_bytes(buf, str(target)) + assert target.read_bytes() == b"stream-data" + assert checksum == sha256_bytes(b"stream-data") + + def test_write_bytes_rejects_text_stream(self, tmp_path): + with pytest.raises(TypeError, match="bytes or a binary stream"): + write_bytes(io.StringIO("text"), str(tmp_path / "x.bin")) + + def test_write_bytes_rejects_str(self, tmp_path): + with pytest.raises(TypeError): + write_bytes("not-bytes", str(tmp_path / "x.bin")) + + def test_write_text_local(self, tmp_path): + target = tmp_path / "report.md" + checksum = write_text("# Report", str(target)) + assert target.read_text(encoding="utf-8") == "# Report" + assert checksum == sha256_bytes(b"# Report") + + def test_write_bytes_s3(self, fake_client): + checksum = write_bytes(b"upload", "s3://bucket/dir/file.bin") + assert fake_client.store[("bucket", "dir/file.bin")] == b"upload" + assert checksum == sha256_bytes(b"upload") + + def test_write_text_s3(self, fake_client): + write_text("csv,data", "s3://bucket/results.csv", content_type="text/csv") + assert fake_client.store[("bucket", "results.csv")] == b"csv,data" diff --git a/tests/io/test_raster_io.py b/tests/io/test_raster_io.py new file mode 100644 index 0000000..cf4b295 --- /dev/null +++ b/tests/io/test_raster_io.py @@ -0,0 +1,202 @@ +""" +tests/io/test_raster_io.py +=========================== +Tests for dissmodel.io.raster — GeoTIFF load/save round-trips. + +Covers: local round-trip, band_spec selection, zip archives, nodata +band skipping, s3:// upload via fake client, and metadata recovery +(CRS + transform). +""" +from __future__ import annotations + +import zipfile + +import numpy as np +import pytest + +pytest.importorskip("rasterio") + +import rasterio # noqa: E402 +import rasterio.transform # noqa: E402 + +from dissmodel.geo.raster.backend import RasterBackend # noqa: E402 +from dissmodel.io.raster import load_geotiff, save_geotiff # noqa: E402 +from dissmodel.io._utils import sha256_file # noqa: E402 + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture +def backend(): + b = RasterBackend(shape=(4, 5)) + b.arrays["uso"] = np.arange(20, dtype=np.int32).reshape(4, 5) + b.arrays["alt"] = np.linspace(0, 1, 20, dtype=np.float32).reshape(4, 5) + return b + + +@pytest.fixture +def saved_tif(backend, tmp_path): + """A GeoTIFF written by save_geotiff, with known CRS and transform.""" + path = tmp_path / "scene.tif" + transform = rasterio.transform.from_bounds(0, 0, 500, 400, 5, 4) + band_spec = [("uso", "int32", -1), ("alt", "float32", -9999.0)] + checksum = save_geotiff( + (backend, {}), + str(path), + band_spec=band_spec, + crs="EPSG:31984", + transform=transform, + ) + return path, band_spec, checksum + + +class FakeMinioClient: + def __init__(self): + self.store: dict[tuple[str, str], bytes] = {} + + def put_object(self, bucket_name, object_name, data, length, content_type): + self.store[(bucket_name, object_name)] = data.read() + + def get_object(self, bucket, key): + payload = self.store[(bucket, key)] + + class _Obj: + def read(self_inner): + return payload + + return _Obj() + + +# ══════════════════════════════════════════════════════════════════════════════ +# save_geotiff +# ══════════════════════════════════════════════════════════════════════════════ + +class TestSaveGeotiff: + + def test_returns_file_checksum(self, saved_tif): + path, _, checksum = saved_tif + assert checksum == sha256_file(str(path)) + + def test_creates_parent_directories(self, backend, tmp_path): + target = tmp_path / "nested" / "dir" / "scene.tif" + save_geotiff((backend, {}), str(target)) + assert target.exists() + + def test_without_band_spec_writes_all_arrays(self, backend, tmp_path): + path = tmp_path / "all.tif" + save_geotiff((backend, {}), str(path)) + with rasterio.open(str(path)) as ds: + assert ds.count == 2 + + def test_meta_crs_and_transform_are_used(self, backend, tmp_path): + path = tmp_path / "meta.tif" + transform = rasterio.transform.from_bounds(10, 20, 510, 420, 5, 4) + meta = {"crs": "EPSG:4326", "transform": transform} + save_geotiff((backend, meta), str(path)) + with rasterio.open(str(path)) as ds: + assert ds.crs.to_string() == "EPSG:4326" + assert ds.transform == transform + + def test_explicit_crs_overrides_meta(self, backend, tmp_path): + path = tmp_path / "override.tif" + save_geotiff((backend, {"crs": "EPSG:4326"}), str(path), crs="EPSG:31984") + with rasterio.open(str(path)) as ds: + assert "31984" in ds.crs.to_string() + + def test_band_spec_fills_missing_band_with_nodata(self, backend, tmp_path): + path = tmp_path / "fill.tif" + spec = [("uso", "int32", -1), ("inexistente", "int32", -1)] + save_geotiff((backend, {}), str(path), band_spec=spec) + with rasterio.open(str(path)) as ds: + assert ds.count == 2 + assert np.all(ds.read(2) == -1) + + def test_s3_upload(self, backend): + client = FakeMinioClient() + checksum = save_geotiff( + (backend, {}), "s3://bucket/exp/scene.tif", minio_client=client + ) + payload = client.store[("bucket", "exp/scene.tif")] + assert len(payload) > 0 + assert isinstance(checksum, str) and len(checksum) == 64 + + def test_mixed_dtype_bands_are_not_truncated(self, backend, tmp_path): + """Regression: int32 + float32 bands must promote to a common + dtype instead of truncating floats to the first band's dtype.""" + path = tmp_path / "mixed.tif" + spec = [("uso", "int32", -1), ("alt", "float32", -9999.0)] + save_geotiff((backend, {}), str(path), band_spec=spec) + + with rasterio.open(str(path)) as ds: + alt = ds.read(2) + np.testing.assert_allclose(alt, backend.arrays["alt"], rtol=1e-6) + + +# ══════════════════════════════════════════════════════════════════════════════ +# load_geotiff +# ══════════════════════════════════════════════════════════════════════════════ + +class TestLoadGeotiff: + + def test_roundtrip_with_band_spec(self, backend, saved_tif): + path, band_spec, _ = saved_tif + (loaded, meta), checksum = load_geotiff(str(path), band_spec=band_spec) + + assert loaded.shape == backend.shape + np.testing.assert_array_equal(loaded.arrays["uso"], backend.arrays["uso"]) + np.testing.assert_allclose(loaded.arrays["alt"], backend.arrays["alt"]) + assert checksum == sha256_file(str(path)) + + def test_roundtrip_recovers_crs_and_transform(self, saved_tif): + path, band_spec, _ = saved_tif + (_, meta), _ = load_geotiff(str(path), band_spec=band_spec) + assert "31984" in meta["crs"].to_string() + assert meta["transform"] is not None + + def test_without_band_spec_recovers_tag_names(self, saved_tif): + path, _, _ = saved_tif + (loaded, _), _ = load_geotiff(str(path)) + # save_geotiff writes a 'name' tag per band — loader recovers it + assert set(loaded.arrays) == {"uso", "alt"} + + def test_band_spec_skips_all_nodata_band(self, tmp_path): + b = RasterBackend(shape=(2, 2)) + b.arrays["uso"] = np.ones((2, 2), dtype=np.int32) + b.arrays["vazio"] = np.full((2, 2), -1, dtype=np.int32) + path = tmp_path / "skip.tif" + spec = [("uso", "int32", -1), ("vazio", "int32", -1)] + save_geotiff((b, {}), str(path), band_spec=spec) + + (loaded, _), _ = load_geotiff(str(path), band_spec=spec) + assert "uso" in loaded.arrays + assert "vazio" not in loaded.arrays # uninitialised band skipped + + def test_band_spec_longer_than_band_count_is_truncated(self, saved_tif): + path, _, _ = saved_tif + spec = [ + ("uso", "int32", -1), + ("alt", "float32", -9999.0), + ("extra", "int32", -1), # no third band in the file + ] + (loaded, _), _ = load_geotiff(str(path), band_spec=spec) + assert "extra" not in loaded.arrays + + def test_load_from_zip_archive(self, saved_tif, tmp_path): + path, band_spec, _ = saved_tif + zip_path = tmp_path / "scene.zip" + with zipfile.ZipFile(zip_path, "w") as z: + z.write(path, arcname="scene.tif") + + (loaded, _), checksum = load_geotiff(str(zip_path), band_spec=band_spec) + assert loaded.shape == (4, 5) + assert checksum == sha256_file(str(zip_path)) + + def test_load_from_s3(self, saved_tif): + path, band_spec, _ = saved_tif + client = FakeMinioClient() + client.store[("bucket", "scene.tif")] = path.read_bytes() + + (loaded, _), _ = load_geotiff( + "s3://bucket/scene.tif", minio_client=client, band_spec=band_spec + ) + assert loaded.shape == (4, 5) diff --git a/tests/io/test_vector_io.py b/tests/io/test_vector_io.py new file mode 100644 index 0000000..bbbe7d7 --- /dev/null +++ b/tests/io/test_vector_io.py @@ -0,0 +1,70 @@ +""" +tests/io/test_vector_io.py +=========================== +Tests for dissmodel.io.vector — GeoDataFrame load/save round-trips +(GeoPackage), plus s3:// upload via fake client. +""" +from __future__ import annotations + +import pytest + +gpd = pytest.importorskip("geopandas") + +from shapely.geometry import Polygon # noqa: E402 + +from dissmodel.io.vector import load_gdf, save_gdf # noqa: E402 +from dissmodel.io._utils import sha256_file # noqa: E402 + + +@pytest.fixture +def gdf(): + geoms = [ + Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]), + Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]), + ] + return gpd.GeoDataFrame( + {"uso": [1, 2], "alt": [0.5, 1.5]}, + geometry=geoms, + crs="EPSG:31984", + ) + + +class FakeMinioClient: + def __init__(self): + self.store: dict[tuple[str, str], bytes] = {} + + def put_object(self, bucket_name, object_name, data, length, content_type): + self.store[(bucket_name, object_name)] = data.read() + + +class TestVectorIO: + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_roundtrip_gpkg(self, gdf, tmp_path): + path = tmp_path / "grid.gpkg" + checksum = save_gdf(gdf, str(path)) + + loaded, load_checksum = load_gdf(str(path)) + assert list(loaded["uso"]) == [1, 2] + assert loaded.crs is not None + assert checksum == load_checksum == sha256_file(str(path)) + + def test_save_custom_layer(self, gdf, tmp_path): + path = tmp_path / "layers.gpkg" + save_gdf(gdf, str(path), layer="simulation") + loaded = gpd.read_file(str(path), layer="simulation") + assert len(loaded) == 2 + + def test_save_to_s3(self, gdf): + client = FakeMinioClient() + checksum = save_gdf(gdf, "s3://bucket/out/grid.gpkg", minio_client=client) + payload = client.store[("bucket", "out/grid.gpkg")] + assert len(payload) > 0 + assert isinstance(checksum, str) and len(checksum) == 64 + + def test_load_geojson(self, gdf, tmp_path): + path = tmp_path / "grid.geojson" + gdf.to_file(str(path), driver="GeoJSON") + loaded, checksum = load_gdf(str(path)) + assert len(loaded) == 2 + assert checksum == sha256_file(str(path)) diff --git a/tests/io/test_xarray_io.py b/tests/io/test_xarray_io.py new file mode 100644 index 0000000..ee6158a --- /dev/null +++ b/tests/io/test_xarray_io.py @@ -0,0 +1,81 @@ +""" +tests/io/test_xarray_io.py +=========================== +Tests for dissmodel.io._xarray — NetCDF and Zarr round-trips through +RasterBackend.to_xarray() / from_xarray(). +""" +from __future__ import annotations + +import numpy as np +import pytest + +xr = pytest.importorskip("xarray") + +from dissmodel.geo.raster.backend import RasterBackend # noqa: E402 +from dissmodel.io._xarray import load_xarray, save_xarray, _file_checksum # noqa: E402 +from dissmodel.io._utils import sha256_file # noqa: E402 + + +@pytest.fixture +def backend(): + b = RasterBackend(shape=(3, 4)) + b.arrays["uso"] = np.arange(12, dtype=np.int32).reshape(3, 4) + b.arrays["alt"] = np.linspace(0, 1, 12, dtype=np.float32).reshape(3, 4) + return b + + +class TestNetCDF: + + @pytest.mark.filterwarnings("ignore::UserWarning") + def test_roundtrip(self, backend, tmp_path): + path = tmp_path / "snapshot.nc" + checksum = save_xarray(backend, str(path)) + + loaded, load_checksum = load_xarray(str(path)) + assert loaded.shape == backend.shape + np.testing.assert_array_equal(loaded.arrays["uso"], backend.arrays["uso"]) + np.testing.assert_allclose(loaded.arrays["alt"], backend.arrays["alt"]) + assert checksum == load_checksum == sha256_file(str(path)) + + def test_save_with_step_attaches_time(self, backend, tmp_path): + path = tmp_path / "step42.nc" + save_xarray(backend, str(path), step=42) + ds = xr.open_dataset(str(path)) + assert "time" in ds.coords + assert int(ds["time"].values) == 42 + ds.close() + + def test_save_accepts_raw_dataset(self, backend, tmp_path): + ds = backend.to_xarray() + path = tmp_path / "raw.nc" + checksum = save_xarray(ds, str(path)) + assert path.exists() + assert checksum == sha256_file(str(path)) + + def test_load_missing_file_raises(self, tmp_path): + with pytest.raises(FileNotFoundError): + load_xarray(str(tmp_path / "missing.nc")) + + +class TestZarr: + + def test_roundtrip(self, backend, tmp_path): + pytest.importorskip("zarr") + path = tmp_path / "cube.zarr" + checksum = save_xarray(backend, str(path)) + assert checksum == "" # remote/zarr checksum deferred (documented) + + loaded, load_checksum = load_xarray(str(path)) + assert load_checksum == "" + np.testing.assert_array_equal(loaded.arrays["uso"], backend.arrays["uso"]) + + +class TestFileChecksum: + + def test_existing_file(self, tmp_path): + path = tmp_path / "f.bin" + path.write_bytes(b"abc") + assert _file_checksum(str(path)) == sha256_file(str(path)) + + def test_missing_file_returns_empty(self, tmp_path): + assert _file_checksum(str(tmp_path / "missing")) == "" diff --git a/tests/visualization/test_chart.py b/tests/visualization/test_chart.py new file mode 100644 index 0000000..a78d658 --- /dev/null +++ b/tests/visualization/test_chart.py @@ -0,0 +1,198 @@ +""" +tests/visualization/test_chart.py +================================== +Tests for dissmodel.visualization.chart — track_plot decorator and the +Chart component. + +All rendering paths are exercised headlessly (Agg backend): +- headless save_frames (PNG per step) +- Streamlit path (fake plot_area) +- notebook path with and without an anchored Output widget +- select= filtering and styling options (legend, grid, title) +""" +from __future__ import annotations + +import matplotlib +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import pytest + +from dissmodel.core import Environment, Model +from dissmodel.visualization.chart import Chart, track_plot + + +# ── helper model with tracked attributes ────────────────────────────────────── + +@track_plot(label="Infected", color="red") +@track_plot(label="Recovered", color="green") +class SIR(Model): + def setup(self): + self.infected = 100 + self.recovered = 0 + + def execute(self): + self.infected -= 10 + self.recovered += 10 + + +@pytest.fixture(autouse=True) +def close_figures(): + yield + plt.close("all") + + +@pytest.fixture(autouse=True) +def fresh_plot_info(): + """track_plot stores data buffers on the class — reset between tests.""" + for info in SIR._plot_info.values(): + info["data"] = [] + yield + + +# ══════════════════════════════════════════════════════════════════════════════ +# track_plot decorator +# ══════════════════════════════════════════════════════════════════════════════ + +class TestTrackPlot: + + def test_registers_plot_info_per_label(self): + assert set(SIR._plot_info) == {"infected", "recovered"} + assert SIR._plot_info["infected"]["color"] == "red" + assert SIR._plot_info["infected"]["plot_type"] == "line" + + def test_attribute_assignment_appends_to_buffer(self): + Environment(start_time=0, end_time=1) + model = SIR() + model.infected = 90 + # setup() assigned 100, then 90 + assert SIR._plot_info["infected"]["data"][-2:] == [100, 90] + + def test_tracked_labels_registered_in_environment(self): + env = Environment(start_time=0, end_time=1) + SIR() + assert "Infected" in env._plot_metadata + assert "Recovered" in env._plot_metadata + + +# ══════════════════════════════════════════════════════════════════════════════ +# Chart — setup and rendering +# ══════════════════════════════════════════════════════════════════════════════ + +class TestChartSetup: + + def test_defaults(self): + Environment(start_time=0, end_time=1) + chart = Chart() + assert chart.select is None + assert chart.show_legend is True + assert chart.show_grid is False + assert chart.title == "Variable History" + assert chart.fig is not None # created outside notebooks + + def test_custom_options(self): + Environment(start_time=0, end_time=1) + chart = Chart( + select=["Infected"], + show_legend=False, + show_grid=True, + title="SIR", + pause=False, + ) + assert chart.select == ["Infected"] + assert chart.title == "SIR" + + +class TestChartRender: + + def _env_with_data(self, **chart_kwargs): + env = Environment(start_time=0, end_time=2) + SIR() + chart = Chart(pause=False, **chart_kwargs) + return env, chart + + def test_render_plots_all_tracked_variables(self): + env, chart = self._env_with_data(show_grid=True) + env.run() + # one line per tracked label + labels = [line.get_label() for line in chart.ax.get_lines()] + assert "Infected" in labels and "Recovered" in labels + + def test_select_filters_variables(self): + env, chart = self._env_with_data(select=["Infected"]) + env.run() + labels = [line.get_label() for line in chart.ax.get_lines()] + assert "Infected" in labels + assert "Recovered" not in labels + + def test_time_points_accumulate(self): + env, chart = self._env_with_data() + env.run() + assert len(chart.time_points) > 0 + + +# ══════════════════════════════════════════════════════════════════════════════ +# Chart — execution targets +# ══════════════════════════════════════════════════════════════════════════════ + +class TestChartExecuteTargets: + + def test_headless_save_frames_writes_png(self, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + env = Environment(start_time=0, end_time=2) + SIR() + Chart(save_frames=True, pause=False) + env.run() + + frames = sorted((tmp_path / "chart_frames").glob("chart_step_*.png")) + assert len(frames) >= 1 + + def test_headless_without_save_frames_still_saves_on_agg( + self, tmp_path, monkeypatch + ): + # Agg is not an interactive backend → falls into _save_frame branch + monkeypatch.chdir(tmp_path) + env = Environment(start_time=0, end_time=1) + SIR() + Chart(save_frames=False, pause=False) + env.run() + assert (tmp_path / "chart_frames").exists() + + def test_streamlit_plot_area_receives_figure(self): + class FakePlotArea: + def __init__(self): + self.figures = [] + + def pyplot(self, fig): + self.figures.append(fig) + + area = FakePlotArea() + env = Environment(start_time=0, end_time=2) + SIR() + Chart(plot_area=area, pause=False) + env.run() + assert len(area.figures) >= 1 + + def test_notebook_path_with_anchored_widget(self, monkeypatch): + monkeypatch.setattr( + "dissmodel.visualization.chart.is_notebook", lambda: True + ) + env = Environment(start_time=0, end_time=1) + SIR() + chart = Chart(pause=False) + + if chart._out is None: + pytest.skip("ipywidgets not installed — fallback covered elsewhere") + env.run() # must not raise + + def test_notebook_path_without_widget_uses_image_fallback(self, monkeypatch): + import sys + monkeypatch.setattr( + "dissmodel.visualization.chart.is_notebook", lambda: True + ) + monkeypatch.setitem(sys.modules, "ipywidgets", None) + env = Environment(start_time=0, end_time=1) + SIR() + chart = Chart(pause=False) + assert chart._out is None + env.run() # exercises clear_output + Image fallback without raising diff --git a/tests/visualization/test_map.py b/tests/visualization/test_map.py new file mode 100644 index 0000000..9eeed66 --- /dev/null +++ b/tests/visualization/test_map.py @@ -0,0 +1,173 @@ +""" +tests/visualization/test_map.py +================================ +Tests for dissmodel.visualization.map — the Map (choropleth) component — +and dissmodel.visualization._utils — environment/backend detection. + +All paths run headlessly (Agg backend). +""" +from __future__ import annotations + +import matplotlib +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import pytest + +from dissmodel.core import Environment +from dissmodel.geo import vector_grid +from dissmodel.visualization.map import Map +from dissmodel.visualization import _utils + + +@pytest.fixture(autouse=True) +def close_figures(): + yield + plt.close("all") + + +@pytest.fixture +def gdf(): + grid = vector_grid(dimension=(3, 3), resolution=1.0) + grid["state"] = range(len(grid)) + return grid + + +# ══════════════════════════════════════════════════════════════════════════════ +# Map — setup +# ══════════════════════════════════════════════════════════════════════════════ + +class TestMapSetup: + + def test_defaults(self, gdf): + Environment(start_time=0, end_time=1) + m = Map(gdf=gdf, plot_params={"column": "state"}) + assert m.figsize == (10, 6) + assert m.pause is True + assert m.fig is not None and m.ax is not None + + def test_custom_figsize(self, gdf): + Environment(start_time=0, end_time=1) + m = Map(gdf=gdf, plot_params={}, figsize=(4, 3)) + assert tuple(m.fig.get_size_inches()) == (4.0, 3.0) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Map — rendering and execution targets +# ══════════════════════════════════════════════════════════════════════════════ + +class TestMapExecute: + + def test_render_sets_step_title(self, gdf): + Environment(start_time=0, end_time=1) + m = Map(gdf=gdf, plot_params={"column": "state"}, pause=False) + fig = m._render(step=3) + assert "Step 3" in fig.axes[0].get_title() + + def test_headless_save_frames_writes_png(self, gdf, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + env = Environment(start_time=0, end_time=2) + Map(gdf=gdf, plot_params={"column": "state"}, save_frames=True, pause=False) + env.run() + + frames = sorted((tmp_path / "map_frames").glob("state_step_*.png")) + assert len(frames) >= 1 + + def test_save_frame_uses_default_name_without_column( + self, gdf, tmp_path, monkeypatch + ): + monkeypatch.chdir(tmp_path) + env = Environment(start_time=0, end_time=1) + Map(gdf=gdf, plot_params={}, save_frames=True, pause=False) + env.run() + frames = list((tmp_path / "map_frames").glob("map_step_*.png")) + assert len(frames) >= 1 + + def test_streamlit_plot_area_receives_figure(self, gdf): + class FakePlotArea: + def __init__(self): + self.figures = [] + + def pyplot(self, fig): + self.figures.append(fig) + + area = FakePlotArea() + env = Environment(start_time=0, end_time=2) + Map(gdf=gdf, plot_params={"column": "state"}, plot_area=area, pause=False) + env.run() + assert len(area.figures) >= 1 + + def test_notebook_path_without_widget(self, gdf, monkeypatch): + import sys + monkeypatch.setattr( + "dissmodel.visualization.map.is_notebook", lambda: True + ) + monkeypatch.setitem(sys.modules, "ipywidgets", None) + env = Environment(start_time=0, end_time=1) + m = Map(gdf=gdf, plot_params={"column": "state"}, pause=False) + assert m._out is None + env.run() # exercises notebook fallback branch without raising + + def test_notebook_path_with_anchored_widget(self, gdf, monkeypatch): + monkeypatch.setattr( + "dissmodel.visualization.map.is_notebook", lambda: True + ) + env = Environment(start_time=0, end_time=1) + m = Map(gdf=gdf, plot_params={"column": "state"}, pause=False) + if m._out is None: + pytest.skip("ipywidgets not installed — fallback covered elsewhere") + env.run() + + +# ══════════════════════════════════════════════════════════════════════════════ +# _utils — backend / environment detection +# ══════════════════════════════════════════════════════════════════════════════ + +class TestVizUtils: + + def test_agg_is_not_interactive(self, monkeypatch): + monkeypatch.setattr(matplotlib, "get_backend", lambda: "Agg") + assert _utils.is_interactive_backend() is False + + @pytest.mark.parametrize("backend", ["TkAgg", "QtAgg", "MacOSX"]) + def test_interactive_backends_detected(self, monkeypatch, backend): + monkeypatch.setattr(matplotlib, "get_backend", lambda: backend) + assert _utils.is_interactive_backend() is True + + def test_is_notebook_reflects_cached_env(self, monkeypatch): + monkeypatch.setattr(_utils, "_ENV", "jupyter") + assert _utils.is_notebook() is True + monkeypatch.setattr(_utils, "_ENV", "headless") + assert _utils.is_notebook() is False + monkeypatch.setattr(_utils, "_ENV", "colab") + assert _utils.is_notebook() is True + + def test_detect_environment_headless_in_pytest(self): + # pytest runs without an IPython kernel — both ImportError and + # get_ipython() is None resolve to 'headless' + assert _utils._detect_environment() in ("headless", "ipython") + + def test_detect_environment_jupyter_shell(self, monkeypatch): + class ZMQInteractiveShell: + pass + + import IPython + monkeypatch.setattr( + IPython, "get_ipython", lambda: ZMQInteractiveShell() + ) + assert _utils._detect_environment() == "jupyter" + + def test_detect_environment_terminal_shell(self, monkeypatch): + class TerminalInteractiveShell: + pass + + import IPython + monkeypatch.setattr( + IPython, "get_ipython", lambda: TerminalInteractiveShell() + ) + assert _utils._detect_environment() == "ipython" + + def test_detect_environment_colab(self, monkeypatch): + import sys + monkeypatch.setitem(sys.modules, "google.colab", object()) + assert _utils._detect_environment() == "colab"