ENH: cogent3 Tree classes now record their source by GavinHuttley · Pull Request #2351 · cogent3/cogent3 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

ENH: cogent3 Tree classes now record their source #2351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 12, 2025
39 changes: 25 additions & 14 deletions src/cogent3/app/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from functools import singledispatch
from io import TextIOWrapper
from pathlib import Path
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from scitrack import get_text_hexdigest

from cogent3.core import alignment as old_alignment
from cogent3.core import new_alignment
from cogent3.core import tree as c3tree
from cogent3.core.table import Table
from cogent3.util.deserialise import deserialise_object
from cogent3.util.io import get_format_suffixes, open_
Expand All @@ -37,7 +38,7 @@
# used for log files, not-completed results
_special_suffixes = re.compile(r"\.(log|json)$")

StrOrBytes = Union[str, bytes]
StrOrBytes = str | bytes
NoneType = type(None)


Expand Down Expand Up @@ -748,42 +749,52 @@ def get_data_source(data: object) -> str | None:


@get_data_source.register
def _(data: old_alignment.SequenceCollection):
def _(data: old_alignment.SequenceCollection) -> str | None:
return get_data_source(data.info)


@get_data_source.register
def _(data: old_alignment.ArrayAlignment):
def _(data: old_alignment.ArrayAlignment) -> str | None:
return get_data_source(data.info)


@get_data_source.register
def _(data: old_alignment.Alignment):
def _(data: old_alignment.Alignment) -> str | None:
return get_data_source(data.info)


@get_data_source.register
def _(data: new_alignment.Alignment):
return get_data_source(data.source)
def _(data: new_alignment.Alignment) -> str | None:
return data.source


@get_data_source.register
def _(data: new_alignment.SequenceCollection):
return get_data_source(data.source)
def _(data: new_alignment.SequenceCollection) -> str | None:
return data.source


@get_data_source.register
def _(data: str):
def _(data: c3tree.TreeNode) -> str | None:
return data.source


@get_data_source.register
def _(data: c3tree.PhyloNode) -> str | None:
return data.source


@get_data_source.register
def _(data: str) -> str | None:
return get_data_source(Path(data))


@get_data_source.register
def _(data: Path):
return str(data.name)
def _(data: Path) -> str | None:
return data.name


@get_data_source.register
def _(data: dict):
def _(data: dict) -> str | None:
try:
source = data.get("info", {})["source"]
except KeyError:
Expand All @@ -792,7 +803,7 @@ def _(data: dict):


@get_data_source.register
def _(data: DataMemberABC):
def _(data: DataMemberABC) -> str | None:
return str(data.unique_id)


Expand Down
6 changes: 3 additions & 3 deletions src/cogent3/app/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cogent3.util.misc import get_true_spans

from .composable import NotCompleted, define_app
from .data_store import get_data_source
from .typing import (
AlignedSeqsType,
PairwiseDistanceType,
Expand Down Expand Up @@ -48,7 +49,6 @@ class fast_slow_dist:

Uses fast (but less numerically robust) approach where possible, slow (robust)
approach when not.

"""

def __init__(
Expand Down Expand Up @@ -170,7 +170,7 @@ def main(
else:
empty = dict.fromkeys(itertools.product(aln.names, aln.names), 0)
dists = DistanceMatrix(empty)
dists.source = aln.info.source
dists.source = get_data_source(aln)
if self._sm:
for a in dists.template.names[0]:
for b in dists.template.names[1]:
Expand Down Expand Up @@ -237,7 +237,7 @@ def jaccard_dist(seq_coll: UnalignedSeqsType, k: int = 10) -> PairwiseDistanceTy
f"could not compute distances between {names}",
source=seq_coll,
)
return DistanceMatrix(dists)
return DistanceMatrix(dists, source=get_data_source(seq_coll))


@define_app
Expand Down
15 changes: 12 additions & 3 deletions src/cogent3/app/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cogent3.util.misc import is_url

from .composable import define_app
from .data_store import get_data_source
from .typing import PairwiseDistanceType, SerialisableType, TreeType

NoneType = type(None)
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
self._min_length = min_length

def main(self, tree: TreeType) -> SerialisableType | TreeType:
source = tree.source
scalar = self._scalar
min_length = self._min_length
tree = tree.deepcopy()
Expand All @@ -61,14 +63,17 @@ def main(self, tree: TreeType) -> SerialisableType | TreeType:
length = length / scalar if length else min_length
edge.length = abs(length) # force to be positive

tree.source = source
return tree


@define_app
class uniformize_tree:
"""Standardises the orientation of unrooted trees."""

def __init__(self, root_at="midpoint", ordered_names=None) -> None:
def __init__(
self, root_at: str = "midpoint", ordered_names: list[str] | None = None
) -> None:
"""
Parameters
----------
Expand All @@ -82,14 +87,18 @@ def __init__(self, root_at="midpoint", ordered_names=None) -> None:
self._ordered_names = ordered_names

def main(self, tree: TreeType) -> SerialisableType | TreeType:
source = tree.source
if self._root_at == "midpoint":
new = tree.root_at_midpoint()
else:
new = tree.rooted_with_tip(self._root_at)

if self._ordered_names is None:
self._ordered_names = tree.get_tip_names()
return new.sorted(self._ordered_names)

result = new.sorted(self._ordered_names)
result.source = source
return result


@define_app
Expand Down Expand Up @@ -126,7 +135,7 @@ def main(self, dists: PairwiseDistanceType) -> SerialisableType | TreeType:
(result,) = gnj(dists.to_dict(), keep=1, show_progress=False)
_, tree = result

tree.params["source"] = dists.source
tree.source = get_data_source(dists)
return tree


Expand Down
3 changes: 2 additions & 1 deletion src/cogent3/core/new_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2637,7 +2637,8 @@ def prep_for_seqs_data(
for name, seq in data.items():
name = seq_namer(seq=seq, name=name) # noqa: PLW2901
seq_data = coerce_to_raw_seq_data(seq, moltype, name=name)
offsets[seq_data.parent_name or name] = seq_data.offset
if seq_data.offset:
offsets[seq_data.parent_name or name] = seq_data.offset
seqs[seq_data.parent_name or seq_data.name] = seq_data.seq
if seq_data.is_reversed:
rvd.add(name)
Expand Down
49 changes: 39 additions & 10 deletions src/cogent3/core/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,28 @@ class TreeNode:
Parameters:
name: label for the node, assumed to be unique.
children: list of the node's children.
parent: parent to this node
params: dict containing arbitrary parameters for the node.
name_loaded: ?
"""

_exclude_from_copy = dict.fromkeys(["_parent", "children"])
_exclude_from_copy = frozenset(["_parent", "children"])

def __init__(
self,
name=None,
children=None,
parent=None,
params=None,
name_loaded=True,
name: str | None = None,
children: list[TreeNode] | None = None,
parent: TreeNode | None = None,
params: dict[str, object] | None = None,
name_loaded: bool = True,
**kwargs,
) -> None:
"""Returns new TreeNode object."""
self.name = name
self.name_loaded = name_loaded
self.params = params or {}
self.children = []
if children is not None:
if children:
self.extend(children)
self._parent = parent
if parent is not None and self not in parent.children:
Expand Down Expand Up @@ -142,6 +143,18 @@ def __gt__(self, other):

return self_name > other_name

@property
def source(self) -> str | None:
return self.params.get("source")

@source.setter
def source(self, value: str | None) -> None:
"""Sets the source of the node."""
if value:
self.params["source"] = value
else:
self.params.pop("source", None)

def compare_name(self, other):
"""Compares TreeNode by name"""
return True if self is other else self.name == other.name
Expand Down Expand Up @@ -2320,6 +2333,7 @@ def make_tree(
tip_names: list[str] | None = None,
format: str | None = None,
underscore_unmunge: bool = False,
source: str | None = None,
) -> PhyloNode | TreeNode:
"""Initialises a tree.

Expand All @@ -2335,6 +2349,8 @@ def make_tree(
underscore_unmunge : bool
replace underscores with spaces in all names read, i.e. "sp_name"
becomes "sp name"
source
path to file tree came from, string value assigned to tree.source

Notes
-----
Expand All @@ -2346,10 +2362,13 @@ def make_tree(
PhyloNode
"""
assert treestring or tip_names, "must provide either treestring or tip_names"
source = str(source) if source else None
if tip_names:
tree_builder = TreeBuilder().create_edge
tips = [tree_builder([], str(tip_name), {}) for tip_name in tip_names]
return tree_builder(tips, "root", {})
result = tree_builder(tips, "root", {})
result.source = source
return result

if format is None and treestring.startswith("<"):
format = "xml"
Expand All @@ -2363,6 +2382,7 @@ def make_tree(
if not tree.name_loaded:
tree.name = "root"

tree.source = source
return tree


Expand Down Expand Up @@ -2390,16 +2410,25 @@ def load_tree(
of the Newick format. Only the cogent3 json and xml tree formats are
supported.

filename is assigned to root node tree.source attribute.

Returns
-------
PhyloNode
"""
file_format, _ = get_format_suffixes(filename)
format = format or file_format
if format == "json":
return load_from_json(filename, (TreeNode, PhyloNode))
tree = load_from_json(filename, (TreeNode, PhyloNode))
tree.source = str(filename)
return tree

with open_(filename) as tfile:
treestring = tfile.read()

return make_tree(treestring, format=format, underscore_unmunge=underscore_unmunge)
return make_tree(
treestring,
format=format,
underscore_unmunge=underscore_unmunge,
source=filename,
)
2 changes: 2 additions & 0 deletions tests/test_app/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from cogent3.core.location import gap_coords_to_map

DNA = get_moltype("dna")


_NEW_TYPE = "COGENT3_NEW_TYPE" in os.environ

if _NEW_TYPE:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_app/test_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from cogent3 import (
get_app,
get_dataset,
load_aligned_seqs,
make_aligned_seqs,
make_tree,
Expand Down Expand Up @@ -1002,3 +1003,13 @@ def test_model_bounds_kappa():
rules = result.lf.get_param_rules()
kappa_bounds = {(r["lower"], r["upper"]) for r in rules if r["par_name"] == "kappa"}
assert kappa_bounds == {(lower_kappa, upper_kappa)}


def test_source_propagated():
aln = get_dataset("brca1")
three = get_app("take_n_seqs", 4)
dcalc = get_app("fast_slow_dist", fast_calc="pdist", moltype="dna")
qtree = get_app("quick_tree")
app = three + dcalc + qtree
result = app(aln)
assert result.source == pathlib.Path(aln.source).name
11 changes: 8 additions & 3 deletions tests/test_app/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

DNA = get_moltype("dna")

NEW_TYPE = "COGENT3_NEW_TYPE" in os.environ


@pytest.fixture
def tmp_dir(tmp_path_factory):
Expand Down Expand Up @@ -588,12 +590,15 @@ def uniqid(source):
assert m.unique_id == expect


src_attr = "source" if NEW_TYPE else "info"


@pytest.mark.parametrize(
("writer", "data", "attr", "dstore"),
[
("write_seqs", seqs(), "info", dir_dstore),
("write_db", seqs(), "info", db_dstore),
("write_json", seqs(), "info", dir_dstore),
("write_seqs", seqs(), src_attr, dir_dstore),
("write_db", seqs(), src_attr, db_dstore),
("write_json", seqs(), src_attr, dir_dstore),
("write_tabular", table(), "source", dir_dstore),
],
)
Expand Down
Loading
0