From fc0936a7ad8153b1035f7f28570fb5f8f5d4c1e6 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 26 Jun 2026 11:28:12 +0100 Subject: [PATCH] Add optional reference FASTA for non-SARS-CoV-2 genomes Make the reference genome an optional input so inference can run on genomes other than SARS-CoV-2, as a non-breaking change: all new parameters default to the built-in SARS-CoV-2 reference. - import-alignments gains a --reference option, sizing/labelling the dataset from the supplied FASTA (Dataset.new gains sequence_length and contig_id kwargs). - infer reads an optional reference_fasta config key, threaded through initial_ts; genome length and identity metadata are derived from the FASTA and the length is checked against the dataset contig length. - match_path_ts takes the sequence length from the working tree sequence rather than the hardcoded constant. - Non-SARS-CoV-2 runs use a generic time-zero epoch for the reference. --- CHANGELOG.md | 5 ++ docs/example_config.toml | 10 +++- sc2ts/cli.py | 31 +++++++++- sc2ts/core.py | 5 ++ sc2ts/data_import.py | 37 +++++++++--- sc2ts/dataset.py | 16 ++++- sc2ts/inference.py | 32 ++++++---- tests/test_cli.py | 125 +++++++++++++++++++++++++++++++++++++++ tests/test_dataset.py | 44 ++++++++++++++ tests/test_inference.py | 20 +++++++ 10 files changed, 299 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ebb60bd..07dd0e29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ In development +- Add basic support for non-SARS-CoV-2 genomes via an optional reference FASTA. + Supply `--reference` to `import-alignments` and a `reference_fasta` key in the + inference config; both default to the built-in SARS-CoV-2 reference, so + existing workflows are unchanged. + ## [1.0.2] - 2026-03-05 Maintenance release. diff --git a/docs/example_config.toml b/docs/example_config.toml index daf187b3..da41e4ea 100644 --- a/docs/example_config.toml +++ b/docs/example_config.toml @@ -1,11 +1,17 @@ # This is a path to the dataset, in VCZ format. dataset="viridian_mafft_2024-10-14_v1.vcz.zip" -# The metadata field used for dates. For the Viridian dataset, this is -# "Date_tree" (which means, "date used to partition samples when building +# The metadata field used for dates. For the Viridian dataset, this is +# "Date_tree" (which means, "date used to partition samples when building # the Viridian tree") date_field="Date_tree" +# Optional: path to a reference genome FASTA. If omitted, the built-in +# SARS-CoV-2 reference (MN908947, Wuhan/Hu-1/2019) is used. Supply this to run +# inference on another genome; it must be the same reference the dataset was +# built with (see "sc2ts import-alignments --reference"). +# reference_fasta="reference.fasta" + # The run_id is a prefix added to all output files. This is useful when # running lots of different parameter combinations. run_id="ex1" diff --git a/sc2ts/cli.py b/sc2ts/cli.py index abbb45ed..1e15ffdf 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -127,15 +127,29 @@ def setup_logging(verbosity, log_file=None, date=None): "If true, initialise a new dataset. WARNING! This will erase an existing store" ), ) +@click.option( + "--reference", + type=click.Path(exists=True, dir_okay=False), + default=None, + help="Reference FASTA defining the genome (default: built-in SARS-CoV-2)", +) @progress @verbose -def import_alignments(dataset, fastas, initialise, progress, verbose): +def import_alignments(dataset, fastas, initialise, reference, progress, verbose): """ Import the alignments from all FASTAS into the dataset """ setup_logging(verbose) if initialise: - sc2ts.Dataset.new(dataset) + if reference is None: + sc2ts.Dataset.new(dataset) + else: + ref_seq = data_import.get_reference_sequence(reference) + sc2ts.Dataset.new( + dataset, + sequence_length=len(ref_seq), + contig_id=data_import.get_reference_id(reference), + ) f_bar = tqdm.tqdm(sorted(fastas), desc="Files", disable=not progress, position=0) for fasta_path in f_bar: @@ -304,6 +318,7 @@ def infer(config_file, start, stop, force): ts_file_pattern = str(results_dir / f"{run_id}_{{date}}.ts") exclude_sites = config.pop("exclude_sites", []) + reference_fasta = config.pop("reference_fasta", None) if start is None: if match_db.exists() and not force: @@ -311,7 +326,7 @@ def infer(config_file, start, stop, force): f"Do you want to overwrite MatchDB at {match_db}", abort=True, ) - init_ts = si.initial_ts(exclude_sites) + init_ts = si.initial_ts(exclude_sites, reference_fasta=reference_fasta) si.MatchDb.initialise(match_db) base_ts = results_dir / f"{run_id}_init.ts" init_ts.dump(base_ts) @@ -341,6 +356,16 @@ def infer(config_file, start, stop, force): raise ValueError(f"Unknown keys in config: {list(config.keys())}") ds = sc2ts.Dataset(dataset, date_field=date_field) + if reference_fasta is not None: + reference_length = len(data_import.get_reference_sequence(reference_fasta)) + contig_length = int(ds["contig_length"][0]) + if reference_length != contig_length: + raise ValueError( + f"reference_fasta length ({reference_length}) does not match the " + f"dataset contig length ({contig_length}). The reference_fasta must " + "be the same genome the dataset was built with." + ) + for date in np.unique(ds.metadata.sample_date): if date >= stop: break diff --git a/sc2ts/core.py b/sc2ts/core.py index 4a04c94a..a527e874 100644 --- a/sc2ts/core.py +++ b/sc2ts/core.py @@ -17,6 +17,11 @@ REFERENCE_GENBANK = "MN908947" REFERENCE_SEQUENCE_LENGTH = 29904 +# Generic time-zero epoch used as the reference date when inference is run on a +# non-SARS-CoV-2 genome (i.e. a custom reference FASTA is supplied). It only +# needs to be a valid date that precedes all sample dates. +GENERIC_REFERENCE_DATE = "1900-01-01" + # We omit N here as it's mapped to -1. Make "-" the 5th allele # as this is a valid allele for us. # NOTE!! This string is also used in the jit module where it's diff --git a/sc2ts/data_import.py b/sc2ts/data_import.py index dc14e961..ca8cdf4c 100644 --- a/sc2ts/data_import.py +++ b/sc2ts/data_import.py @@ -34,16 +34,39 @@ def __len__(self): __cached_reference = None -def get_reference_sequence(as_array=False): - global __cached_reference - if __cached_reference is None: - reader = pyfaidx.Fasta(str(data_path / "reference.fasta")) - __cached_reference = reader[core.REFERENCE_GENBANK] +def get_reference_sequence(path=None, as_array=False): + """ + Return the reference sequence with an "X" prepended at position 0 so that + the genome uses 1-based coordinates. If ``path`` is None the built-in + SARS-CoV-2 reference is used; otherwise the first record in the FASTA at + ``path`` is used. + """ + if path is None: + global __cached_reference + if __cached_reference is None: + reader = pyfaidx.Fasta(str(data_path / "reference.fasta")) + __cached_reference = reader[core.REFERENCE_GENBANK] + reference = __cached_reference + else: + reader = pyfaidx.Fasta(str(path)) + reference = reader[list(reader.keys())[0]] if as_array: - h = np.array(__cached_reference).astype(str) + h = np.array(reference).astype(str) return np.append(["X"], h) else: - return "X" + str(__cached_reference) + return "X" + str(reference) + + +def get_reference_id(path=None): + """ + Return the identifier for the reference genome. If ``path`` is None this is + the built-in SARS-CoV-2 GenBank accession; otherwise it is the name of the + first record in the FASTA at ``path``. + """ + if path is None: + return core.REFERENCE_GENBANK + reader = pyfaidx.Fasta(str(path)) + return list(reader.keys())[0] __cached_genes = None diff --git a/sc2ts/dataset.py b/sc2ts/dataset.py index 26e62803..8cd6cd48 100644 --- a/sc2ts/dataset.py +++ b/sc2ts/dataset.py @@ -469,14 +469,24 @@ def reorder(self, path, additional_fields=list(), show_progress=False): self.copy(path, sample_id=sample_id[index], show_progress=show_progress) @staticmethod - def new(path, samples_chunk_size=None, variants_chunk_size=None): + def new( + path, + samples_chunk_size=None, + variants_chunk_size=None, + sequence_length=None, + contig_id=None, + ): if samples_chunk_size is None: samples_chunk_size = 10_000 if variants_chunk_size is None: variants_chunk_size = 100 + if sequence_length is None: + sequence_length = core.REFERENCE_SEQUENCE_LENGTH + if contig_id is None: + contig_id = core.REFERENCE_STRAIN logger.info(f"Creating new dataset at {path}") - L = core.REFERENCE_SEQUENCE_LENGTH - 1 + L = sequence_length - 1 N = 0 # Samples must be added store = zarr.DirectoryStore(path) root = zarr.open(store, mode="w") @@ -508,7 +518,7 @@ def new(path, samples_chunk_size=None, variants_chunk_size=None): dtype="str", compressor=DEFAULT_ZARR_COMPRESSOR, ) - z[0] = core.REFERENCE_STRAIN + z[0] = contig_id z.attrs["_ARRAY_DIMENSIONS"] = ["contigs"] z = root.empty( diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 9d89f17c..0e5d46e1 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -264,12 +264,22 @@ def mirror_ts_coordinates(ts): return tables.tree_sequence() -def initial_ts(problematic_sites=None): +def initial_ts(problematic_sites=None, reference_fasta=None): if problematic_sites is None: problematic_sites = [] - reference = data_import.get_reference_sequence() - L = core.REFERENCE_SEQUENCE_LENGTH - assert L == len(reference) + reference = data_import.get_reference_sequence(reference_fasta) + L = len(reference) + if reference_fasta is None: + genbank_id = core.REFERENCE_GENBANK + reference_strain = core.REFERENCE_STRAIN + reference_date = core.REFERENCE_DATE + else: + contig_id = data_import.get_reference_id(reference_fasta) + genbank_id = contig_id + reference_strain = contig_id + # Generic epoch for non-SARS-CoV-2 genomes: the reference defines time + # zero, and this just needs to be a valid date preceding all samples. + reference_date = core.GENERIC_REFERENCE_DATE problematic_sites = set(problematic_sites) logger.info(f"Masking out {len(problematic_sites)} sites") @@ -281,7 +291,7 @@ def initial_ts(problematic_sites=None): base_schema = tskit.MetadataSchema.permissive_json().schema tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema) tables.reference_sequence.metadata = { - "genbank_id": core.REFERENCE_GENBANK, + "genbank_id": genbank_id, "notes": "X prepended to alignment to map from 1-based to 0-based coordinates", } tables.reference_sequence.data = reference @@ -289,7 +299,7 @@ def initial_ts(problematic_sites=None): tables.metadata_schema = tskit.MetadataSchema(base_schema) tables.metadata = { "sc2ts": { - "date": core.REFERENCE_DATE, + "date": reference_date, "samples_strain": [], "daily_stats": {}, "cumulative_stats": { @@ -328,8 +338,8 @@ def initial_ts(problematic_sites=None): flags=core.NODE_IS_REFERENCE, time=0, metadata={ - "strain": core.REFERENCE_STRAIN, - "date": core.REFERENCE_DATE, + "strain": reference_strain, + "date": reference_date, "sc2ts": {"notes": "Reference sequence"}, }, ) @@ -806,12 +816,12 @@ def add_sample_to_tables(sample, tables, group_id=None, time=0): return tables.nodes.add_row(flags=sample.flags, metadata=metadata, time=time) -def match_path_ts(group): +def match_path_ts(group, sequence_length): """ Given the specified SampleGroup return the tree sequence rooted at zero representing the data. """ - tables = tskit.TableCollection(core.REFERENCE_SEQUENCE_LENGTH) + tables = tskit.TableCollection(sequence_length) tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() tables.mutations.metadata_schema = tskit.MetadataSchema.permissive_json() site_id_map = {} @@ -1039,7 +1049,7 @@ def add_matching_results( f"{group.summary()}" ) continue - flat_ts = match_path_ts(group) + flat_ts = match_path_ts(group, ts.sequence_length) if flat_ts.num_mutations == 0 or flat_ts.num_samples == 1: poly_ts = flat_ts else: diff --git a/tests/test_cli.py b/tests/test_cli.py index a2a6d7eb..8f12508d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -42,6 +42,131 @@ def test_duplicate_aligments(self, tmp_path, fx_alignments_fasta): assert result.exit_code == 1 +def write_custom_genome(tmp_path, length=60): + """ + Write a small synthetic reference FASTA, a matching-length alignments FASTA + and a metadata TSV. Returns (reference_path, dataset_path, metadata_path). + """ + rng = np.random.RandomState(42) + bases = np.array(list("ACGT")) + ref = "".join(rng.choice(bases, size=length)) + ref_path = tmp_path / "ref.fasta" + ref_path.write_text(f">chr_test synthetic genome\n{ref}\n") + + aln_path = tmp_path / "aln.fasta" + rows = [] + for j in range(4): + # Each sample is the reference with a single mutation. + h = list(ref) + h[10 + j] = "A" if h[10 + j] != "A" else "C" + rows.append(f">s{j}\n{''.join(h)}\n") + aln_path.write_text("".join(rows)) + + meta_path = tmp_path / "meta.tsv" + meta_lines = ["Run\tdate"] + for j in range(4): + meta_lines.append(f"s{j}\t2020-01-0{j + 1}") + meta_path.write_text("\n".join(meta_lines) + "\n") + return ref_path, aln_path, meta_path + + +class TestCustomReferenceGenome: + def build_dataset(self, tmp_path): + ref_path, aln_path, meta_path = write_custom_genome(tmp_path) + ds_path = tmp_path / "ds.zarr" + runner = ct.CliRunner() + result = runner.invoke( + cli.cli, + f"import-alignments {ds_path} {aln_path} -i " + f"--reference {ref_path} --no-progress", + catch_exceptions=False, + ) + assert result.exit_code == 0 + result = runner.invoke( + cli.cli, + f"import-metadata {ds_path} {meta_path}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + return ref_path, ds_path + + def test_import_alignments_reference(self, tmp_path): + ref_path, ds_path = self.build_dataset(tmp_path) + ds = sc2ts.Dataset(ds_path, date_field="date") + assert ds["contig_id"][0] == "chr_test" + assert int(ds["contig_length"][0]) == 61 + assert ds.num_samples == 4 + + def test_infer_custom_reference(self, tmp_path): + ref_path, ds_path = self.build_dataset(tmp_path) + config = { + "dataset": str(ds_path), + "date_field": "date", + "run_id": "test", + "results_dir": str(tmp_path / "results"), + "log_dir": str(tmp_path / "logs"), + "matches_dir": str(tmp_path / "matches"), + "reference_fasta": str(ref_path), + "extend_parameters": {"min_group_size": 1, "num_threads": 0}, + } + config_file = tmp_path / "config.toml" + with open(config_file, "w") as f: + f.write(tomli_w.dumps(config)) + + runner = ct.CliRunner() + # Run the full pipeline over the sample dates so the matching machinery + # (match_path_ts) exercises the custom genome length end-to-end. + result = runner.invoke( + cli.cli, + f"infer {config_file} --stop 2020-02-01", + catch_exceptions=False, + ) + assert result.exit_code == 0 + + results_dir = tmp_path / "results" / "test" + init_ts = tskit.load(results_dir / "test_init.ts") + expected = si.initial_ts(reference_fasta=str(ref_path)) + expected.tables.assert_equals(init_ts.tables) + # The initial ts is sized to the custom genome and labelled from it. + assert init_ts.sequence_length == 61 + assert init_ts.node(1).metadata["strain"] == "chr_test" + + # A dated tree sequence is produced for each sample date, at the custom + # genome length. + for day in range(1, 5): + dated = tskit.load(results_dir / f"test_2020-01-0{day}.ts") + assert dated.sequence_length == 61 + assert dated.num_sites == 60 + + def test_infer_reference_length_mismatch(self, tmp_path): + # Build a length-60 custom dataset, then point the config at a reference + # of a different length -> the guard rejects the inconsistency. + _, ds_path = self.build_dataset(tmp_path) + ref_path = tmp_path / "wrong_ref.fasta" + ref_path.write_text(">chr_other\n" + "ACGT" * 10 + "\n") + config = { + "dataset": str(ds_path), + "date_field": "date", + "run_id": "test", + "results_dir": str(tmp_path / "results"), + "log_dir": str(tmp_path / "logs"), + "matches_dir": str(tmp_path / "matches"), + "reference_fasta": str(ref_path), + "extend_parameters": {}, + } + config_file = tmp_path / "config.toml" + with open(config_file, "w") as f: + f.write(tomli_w.dumps(config)) + + runner = ct.CliRunner() + with pytest.raises(ValueError, match="does not match the dataset"): + runner.invoke( + cli.cli, + f"infer {config_file} --stop 2020-01-01 -f", + catch_exceptions=False, + ) + + class TestImportMetadata: def test_suite_data(self, tmp_path, fx_metadata_tsv, fx_alignments_fasta): ds_path = tmp_path / "ds.zarr" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 10d8ad8e..9f35801a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -40,6 +40,27 @@ def test_massaged_viridian_metadata(fx_raw_viridian_metadata_df): assert np.sum(df["Genbank_N"]) > 0 +class TestReferenceSequence: + def test_builtin_default(self): + ref = data_import.get_reference_sequence() + assert ref[0] == "X" + assert data_import.get_reference_id() == "MN908947" + + def test_custom_fasta(self, tmp_path): + seq = "ACGTACGTACGT" + fasta = tmp_path / "ref.fasta" + fasta.write_text(f">chr_test some description\n{seq}\n") + assert data_import.get_reference_sequence(str(fasta)) == "X" + seq + assert data_import.get_reference_id(str(fasta)) == "chr_test" + + def test_custom_fasta_as_array(self, tmp_path): + seq = "ACGTACGTACGT" + fasta = tmp_path / "ref.fasta" + fasta.write_text(f">chr_test\n{seq}\n") + ref = data_import.get_reference_sequence(str(fasta), as_array=True) + nt.assert_array_equal(ref, ["X"] + list(seq)) + + class TestCreateDataset: def test_new(self, tmp_path): path = tmp_path / "dataset.vcz" @@ -54,6 +75,29 @@ def test_new(self, tmp_path): } # TODO check various properties of the dataset + def test_new_custom_genome(self, tmp_path): + path = tmp_path / "dataset.vcz" + seq = "ACGTACGTACGTAACCGGTT" + sc2ts.Dataset.new(path, sequence_length=len(seq) + 1, contig_id="chr_test") + sg_ds = load_dataset(path) + assert dict(sg_ds.sizes) == { + "variants": len(seq), + "samples": 0, + "ploidy": 1, + "contigs": 1, + "alleles": 16, + } + assert sg_ds["contig_id"].values[0] == "chr_test" + assert sg_ds["contig_length"].values[0] == len(seq) + 1 + + # Length-M encoded alignments round-trip through append. + h = jit.encode_alleles(np.array(list(seq))) + sc2ts.Dataset.append_alignments(path, {"s0": h}) + sg_ds = load_dataset(path) + assert sg_ds.sizes["samples"] == 1 + H = sg_ds["call_genotype"].values.squeeze(2).T + nt.assert_array_equal(h, H[0]) + @pytest.mark.parametrize( ["num_samples", "chunk_size"], [ diff --git a/tests/test_inference.py b/tests/test_inference.py index 0f429372..d365a772 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -160,6 +160,26 @@ def test_reference_node(self): } assert node.flags == sc2ts.NODE_IS_REFERENCE + def test_custom_reference_fasta(self, tmp_path): + seq = "ACGTACGTACGTAACCGGTT" + fasta = tmp_path / "ref.fasta" + fasta.write_text(f">chr_test some description\n{seq}\n") + + ts = si.initial_ts(reference_fasta=str(fasta)) + # sequence_length and number of sites are driven by the reference length + # (with the "X" prepended at position 0 for 1-based coordinates). + assert ts.sequence_length == len(seq) + 1 + assert ts.num_sites == len(seq) + assert ts.reference_sequence.data == "X" + seq + # Ancestral states match the reference (1-based coordinates). + for site in ts.sites(): + assert site.ancestral_state == seq[int(site.position) - 1] + # Identity metadata is derived from the FASTA header, not SARS-CoV-2. + assert ts.reference_sequence.metadata["genbank_id"] == "chr_test" + node = ts.node(1) + assert node.metadata["strain"] == "chr_test" + assert node.metadata["date"] == "1900-01-01" + class TestMatchTsinfer: def match_tsinfer(self, samples, ts, mirror_coordinates=False, **kwargs):