Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

In development

- Add basic support for non-SARS-CoV-2 genomes via an optional reference FASTA.
Supply `--reference` to `import-alignments` and a `reference_fasta` key in the
inference config; both default to the built-in SARS-CoV-2 reference, so
existing workflows are unchanged.

## [1.0.2] - 2026-03-05

Maintenance release.
Expand Down
10 changes: 8 additions & 2 deletions docs/example_config.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@

# This is a path to the dataset, in VCZ format.
dataset="viridian_mafft_2024-10-14_v1.vcz.zip"
# The metadata field used for dates. For the Viridian dataset, this is
# "Date_tree" (which means, "date used to partition samples when building
# The metadata field used for dates. For the Viridian dataset, this is
# "Date_tree" (which means, "date used to partition samples when building
# the Viridian tree")
date_field="Date_tree"

# Optional: path to a reference genome FASTA. If omitted, the built-in
# SARS-CoV-2 reference (MN908947, Wuhan/Hu-1/2019) is used. Supply this to run
# inference on another genome; it must be the same reference the dataset was
# built with (see "sc2ts import-alignments --reference").
# reference_fasta="reference.fasta"

# The run_id is a prefix added to all output files. This is useful when
# running lots of different parameter combinations.
run_id="ex1"
Expand Down
31 changes: 28 additions & 3 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,29 @@ def setup_logging(verbosity, log_file=None, date=None):
"If true, initialise a new dataset. WARNING! This will erase an existing store"
),
)
@click.option(
"--reference",
type=click.Path(exists=True, dir_okay=False),
default=None,
help="Reference FASTA defining the genome (default: built-in SARS-CoV-2)",
)
@progress
@verbose
def import_alignments(dataset, fastas, initialise, progress, verbose):
def import_alignments(dataset, fastas, initialise, reference, progress, verbose):
"""
Import the alignments from all FASTAS into the dataset
"""
setup_logging(verbose)
if initialise:
sc2ts.Dataset.new(dataset)
if reference is None:
sc2ts.Dataset.new(dataset)
else:
ref_seq = data_import.get_reference_sequence(reference)
sc2ts.Dataset.new(
dataset,
sequence_length=len(ref_seq),
contig_id=data_import.get_reference_id(reference),
)

f_bar = tqdm.tqdm(sorted(fastas), desc="Files", disable=not progress, position=0)
for fasta_path in f_bar:
Expand Down Expand Up @@ -304,14 +318,15 @@ def infer(config_file, start, stop, force):

ts_file_pattern = str(results_dir / f"{run_id}_{{date}}.ts")
exclude_sites = config.pop("exclude_sites", [])
reference_fasta = config.pop("reference_fasta", None)

if start is None:
if match_db.exists() and not force:
click.confirm(
f"Do you want to overwrite MatchDB at {match_db}",
abort=True,
)
init_ts = si.initial_ts(exclude_sites)
init_ts = si.initial_ts(exclude_sites, reference_fasta=reference_fasta)
si.MatchDb.initialise(match_db)
base_ts = results_dir / f"{run_id}_init.ts"
init_ts.dump(base_ts)
Expand Down Expand Up @@ -341,6 +356,16 @@ def infer(config_file, start, stop, force):
raise ValueError(f"Unknown keys in config: {list(config.keys())}")
ds = sc2ts.Dataset(dataset, date_field=date_field)

if reference_fasta is not None:
reference_length = len(data_import.get_reference_sequence(reference_fasta))
contig_length = int(ds["contig_length"][0])
if reference_length != contig_length:
raise ValueError(
f"reference_fasta length ({reference_length}) does not match the "
f"dataset contig length ({contig_length}). The reference_fasta must "
"be the same genome the dataset was built with."
)

for date in np.unique(ds.metadata.sample_date):
if date >= stop:
break
Expand Down
5 changes: 5 additions & 0 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
REFERENCE_GENBANK = "MN908947"
REFERENCE_SEQUENCE_LENGTH = 29904

# Generic time-zero epoch used as the reference date when inference is run on a
# non-SARS-CoV-2 genome (i.e. a custom reference FASTA is supplied). It only
# needs to be a valid date that precedes all sample dates.
GENERIC_REFERENCE_DATE = "1900-01-01"

# We omit N here as it's mapped to -1. Make "-" the 5th allele
# as this is a valid allele for us.
# NOTE!! This string is also used in the jit module where it's
Expand Down
37 changes: 30 additions & 7 deletions sc2ts/data_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,39 @@ def __len__(self):
__cached_reference = None


def get_reference_sequence(as_array=False):
global __cached_reference
if __cached_reference is None:
reader = pyfaidx.Fasta(str(data_path / "reference.fasta"))
__cached_reference = reader[core.REFERENCE_GENBANK]
def get_reference_sequence(path=None, as_array=False):
"""
Return the reference sequence with an "X" prepended at position 0 so that
the genome uses 1-based coordinates. If ``path`` is None the built-in
SARS-CoV-2 reference is used; otherwise the first record in the FASTA at
``path`` is used.
"""
if path is None:
global __cached_reference
if __cached_reference is None:
reader = pyfaidx.Fasta(str(data_path / "reference.fasta"))
__cached_reference = reader[core.REFERENCE_GENBANK]
reference = __cached_reference
else:
reader = pyfaidx.Fasta(str(path))
reference = reader[list(reader.keys())[0]]
if as_array:
h = np.array(__cached_reference).astype(str)
h = np.array(reference).astype(str)
return np.append(["X"], h)
else:
return "X" + str(__cached_reference)
return "X" + str(reference)


def get_reference_id(path=None):
"""
Return the identifier for the reference genome. If ``path`` is None this is
the built-in SARS-CoV-2 GenBank accession; otherwise it is the name of the
first record in the FASTA at ``path``.
"""
if path is None:
return core.REFERENCE_GENBANK
reader = pyfaidx.Fasta(str(path))
return list(reader.keys())[0]


__cached_genes = None
Expand Down
16 changes: 13 additions & 3 deletions sc2ts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,24 @@ def reorder(self, path, additional_fields=list(), show_progress=False):
self.copy(path, sample_id=sample_id[index], show_progress=show_progress)

@staticmethod
def new(path, samples_chunk_size=None, variants_chunk_size=None):
def new(
path,
samples_chunk_size=None,
variants_chunk_size=None,
sequence_length=None,
contig_id=None,
):
if samples_chunk_size is None:
samples_chunk_size = 10_000
if variants_chunk_size is None:
variants_chunk_size = 100
if sequence_length is None:
sequence_length = core.REFERENCE_SEQUENCE_LENGTH
if contig_id is None:
contig_id = core.REFERENCE_STRAIN

logger.info(f"Creating new dataset at {path}")
L = core.REFERENCE_SEQUENCE_LENGTH - 1
L = sequence_length - 1
N = 0 # Samples must be added
store = zarr.DirectoryStore(path)
root = zarr.open(store, mode="w")
Expand Down Expand Up @@ -508,7 +518,7 @@ def new(path, samples_chunk_size=None, variants_chunk_size=None):
dtype="str",
compressor=DEFAULT_ZARR_COMPRESSOR,
)
z[0] = core.REFERENCE_STRAIN
z[0] = contig_id
z.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]

z = root.empty(
Expand Down
32 changes: 21 additions & 11 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,22 @@ def mirror_ts_coordinates(ts):
return tables.tree_sequence()


def initial_ts(problematic_sites=None):
def initial_ts(problematic_sites=None, reference_fasta=None):
if problematic_sites is None:
problematic_sites = []
reference = data_import.get_reference_sequence()
L = core.REFERENCE_SEQUENCE_LENGTH
assert L == len(reference)
reference = data_import.get_reference_sequence(reference_fasta)
L = len(reference)
if reference_fasta is None:
genbank_id = core.REFERENCE_GENBANK
reference_strain = core.REFERENCE_STRAIN
reference_date = core.REFERENCE_DATE
else:
contig_id = data_import.get_reference_id(reference_fasta)
genbank_id = contig_id
reference_strain = contig_id
# Generic epoch for non-SARS-CoV-2 genomes: the reference defines time
# zero, and this just needs to be a valid date preceding all samples.
reference_date = core.GENERIC_REFERENCE_DATE
problematic_sites = set(problematic_sites)

logger.info(f"Masking out {len(problematic_sites)} sites")
Expand All @@ -281,15 +291,15 @@ def initial_ts(problematic_sites=None):
base_schema = tskit.MetadataSchema.permissive_json().schema
tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
tables.reference_sequence.metadata = {
"genbank_id": core.REFERENCE_GENBANK,
"genbank_id": genbank_id,
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates",
}
tables.reference_sequence.data = reference

tables.metadata_schema = tskit.MetadataSchema(base_schema)
tables.metadata = {
"sc2ts": {
"date": core.REFERENCE_DATE,
"date": reference_date,
"samples_strain": [],
"daily_stats": {},
"cumulative_stats": {
Expand Down Expand Up @@ -328,8 +338,8 @@ def initial_ts(problematic_sites=None):
flags=core.NODE_IS_REFERENCE,
time=0,
metadata={
"strain": core.REFERENCE_STRAIN,
"date": core.REFERENCE_DATE,
"strain": reference_strain,
"date": reference_date,
"sc2ts": {"notes": "Reference sequence"},
},
)
Expand Down Expand Up @@ -806,12 +816,12 @@ def add_sample_to_tables(sample, tables, group_id=None, time=0):
return tables.nodes.add_row(flags=sample.flags, metadata=metadata, time=time)


def match_path_ts(group):
def match_path_ts(group, sequence_length):
"""
Given the specified SampleGroup return the tree sequence rooted at
zero representing the data.
"""
tables = tskit.TableCollection(core.REFERENCE_SEQUENCE_LENGTH)
tables = tskit.TableCollection(sequence_length)
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
tables.mutations.metadata_schema = tskit.MetadataSchema.permissive_json()
site_id_map = {}
Expand Down Expand Up @@ -1039,7 +1049,7 @@ def add_matching_results(
f"{group.summary()}"
)
continue
flat_ts = match_path_ts(group)
flat_ts = match_path_ts(group, ts.sequence_length)
if flat_ts.num_mutations == 0 or flat_ts.num_samples == 1:
poly_ts = flat_ts
else:
Expand Down
Loading
Loading