feat/on-the-fly inference by YoniSchirris · Pull Request #87 · NKI-AI/ahcore · GitHub
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

feat/on-the-fly inference #87

Open
wants to merge 8 commits into base: main
Changes from all commits
5 changes: 3 additions & 2 deletions ahcore/callbacks/abstract_writer_callback.py
@@ -179,8 +179,9 @@ def _on_epoch_start(self, trainer: "pl.Trainer") -> None:
current_dataset: TiledWsiDataset
assert self._total_dataset
for current_dataset in self._total_dataset.datasets: # type: ignore
assert current_dataset.slide_image.identifier
self._dataset_sizes[current_dataset.slide_image.identifier] = len(current_dataset)
curr_filename = current_dataset._path
assert curr_filename
Author (YoniSchirris) commented:

The assertion is redundant, though, since a TiledWsiDataset always has a path.

self._dataset_sizes[str(curr_filename)] = len(current_dataset)

self._start_callback_workers()

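For reference, a minimal standalone sketch (illustrative, not part of the PR) of the bookkeeping this diff simplifies: dataset lengths are keyed by the dataset's path, and no assertion is needed since a TiledWsiDataset is always constructed from a path.

from typing import Dict, Sequence

from dlup.data.dataset import TiledWsiDataset


def dataset_sizes_by_path(datasets: Sequence[TiledWsiDataset]) -> Dict[str, int]:
    """Key each tiled dataset's length by its path, mirroring the loop in the diff above."""
    sizes: Dict[str, int] = {}
    for dataset in datasets:
        # `_path` is the same (private) attribute the diff reads; it is set on construction.
        sizes[str(dataset._path)] = len(dataset)
    return sizes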
18 changes: 10 additions & 8 deletions ahcore/data/dataset.py
@@ -14,7 +14,7 @@
from dlup.data.dataset import Dataset, TiledWsiDataset
from torch.utils.data import DataLoader, DistributedSampler, Sampler

from ahcore.utils.data import DataDescription, basemodel_to_uuid
from ahcore.utils.data import DataDescription, OnTheFlyDataDescription, basemodel_to_uuid
from ahcore.utils.io import fullname, get_cache_dir, get_logger
from ahcore.utils.manifest import DataManager, datasets_from_data_description
from ahcore.utils.types import DlupDatasetSample, _DlupDataset
@@ -87,10 +87,12 @@ def __len__(self) -> int:
return self.cumulative_sizes[-1]

@overload
def __getitem__(self, index: int) -> DlupDatasetSample: ...
def __getitem__(self, index: int) -> DlupDatasetSample:
...

@overload
def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ...
def __getitem__(self, index: slice) -> list[DlupDatasetSample]:
...

def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]:
"""Returns the sample at the given index."""
@@ -109,7 +111,7 @@ class DlupDataModule(pl.LightningDataModule):

def __init__(
self,
data_description: DataDescription,
data_description: DataDescription | OnTheFlyDataDescription,
pre_transform: Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]],
batch_size: int = 32, # noqa,pylint: disable=unused-argument
validate_batch_size: int | None = None, # noqa,pylint: disable=unused-argument
@@ -122,8 +124,8 @@ def __init__(

Parameters
----------
data_description : DataDescription
See `ahcore.utils.data.DataDescription` for more information.
data_description : DataDescription | OnTheFlyDataDescription
See `ahcore.utils.data.DataDescription` and `ahcore.utils.data.OnTheFlyDataDescription` for more information.
pre_transform : Callable
A pre-transform is a callable which is directly applied to the output of the dataset before collation in
the dataloader. The transforms typically convert the image in the output to a tensor, convert the
@@ -157,9 +159,9 @@ def __init__(
) # save all relevant hyperparams

# Data settings
self.data_description: DataDescription = data_description
self.data_description = data_description

self._data_manager = DataManager(database_uri=data_description.manifest_database_uri)
self._data_manager = DataManager(data_description)

self._batch_size = self.hparams.batch_size # type: ignore
self._validate_batch_size = self.hparams.validate_batch_size # type: ignore
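As a side note on the `pre_transform` factory signature used above (`Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]]`), here is a minimal sketch of such a callable. The `requires_target` flag name and the "image" sample key are assumptions for illustration, not taken from ahcore.

from typing import Any, Callable, Dict

import numpy as np
import torch

# Stand-in for ahcore.utils.types.DlupDatasetSample (a dict-like sample).
DlupDatasetSample = Dict[str, Any]


def pre_transform(requires_target: bool) -> Callable[[DlupDatasetSample], DlupDatasetSample]:
    """Factory returning a per-sample transform applied before collation in the dataloader."""

    def _transform(sample: DlupDatasetSample) -> DlupDatasetSample:
        # Convert the image in the sample to a CHW float tensor, as the docstring describes.
        image = np.asarray(sample["image"])
        sample["image"] = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        # A real pre-transform would also convert the target when requires_target is True.
        return sample

    return _transform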
63 changes: 63 additions & 0 deletions ahcore/utils/data.py
@@ -9,7 +9,12 @@
from typing import Dict, Optional, Tuple

from pydantic import BaseModel
from sqlalchemy import create_engine, exists
from sqlalchemy.engine import Engine
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Session, sessionmaker

from ahcore.utils.database_models import Base, Manifest, OnTheFlyBase
from ahcore.utils.types import NonNegativeInt, PositiveFloat, PositiveInt


@@ -42,6 +47,43 @@ def basemodel_to_uuid(base_model: BaseModel) -> uuid.UUID:
return unique_id


def open_db_from_engine(engine: Engine) -> Session:
SessionLocal = sessionmaker(bind=engine)
return SessionLocal()


def open_db_from_uri(
uri: str,
ensure_exists: bool = True,
) -> Session:
"""Open a database connection from a uri"""

# Set up the engine if no engine is given and uri is given.
engine = create_engine(uri)

if not ensure_exists:
# Create tables if they don't exist
create_tables(engine, base=Base)
else:
# Check if the "manifest" table exists
inspector = inspect(engine)
if "manifest" not in inspector.get_table_names():
raise RuntimeError("Manifest table does not exist. Likely you have set the wrong URI.")

# Check if the "manifest" table is not empty
with engine.connect() as connection:
result = connection.execute(exists().where(Manifest.id.isnot(None)).select())
if not result.scalar():
raise RuntimeError("Manifest table is empty. Likely you have set the wrong URI.")

return open_db_from_engine(engine)


def create_tables(engine: Engine, base: type[Base] | type[OnTheFlyBase]) -> None:
"""Create the database tables."""
base.metadata.create_all(bind=engine)


class GridDescription(BaseModel):
mpp: Optional[PositiveFloat]
tile_size: Tuple[PositiveInt, PositiveInt]
@@ -67,3 +109,24 @@ class DataDescription(BaseModel):
convert_mask_to_rois: bool = True
use_roi: bool = True
apply_color_profile: bool = False


class OnTheFlyDataDescription(BaseModel):
# Required
data_dir: Path
glob_pattern: str
num_classes: NonNegativeInt
inference_grid: GridDescription

# Preset?
convert_mask_to_rois: bool = True
use_roi: bool = True
apply_color_profile: bool = False

# Explicitly optional
annotations_dir: Optional[Path] = None # May be used to provide a mask.
mask_label: Optional[str] = None
mask_threshold: Optional[float] = None # This is only used for training
roi_name: Optional[str] = None
index_map: Optional[Dict[str, int]] = None
remap_labels: Optional[Dict[str, str]] = None
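A minimal usage sketch of the database helpers added above (the SQLite URIs are illustrative assumptions): create the on-the-fly schema in an in-memory database, or open and validate an existing manifest database.

from sqlalchemy import create_engine

from ahcore.utils.data import create_tables, open_db_from_engine, open_db_from_uri
from ahcore.utils.database_models import OnTheFlyBase

# In-memory database with the minimal on-the-fly schema.
engine = create_engine("sqlite:///:memory:")
create_tables(engine, base=OnTheFlyBase)
session = open_db_from_engine(engine)

# Existing manifest database: ensure_exists=True raises a RuntimeError when the
# manifest table is missing or empty (e.g. when the URI is wrong).
manifest_session = open_db_from_uri("sqlite:////data/manifest.db", ensure_exists=True)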
25 changes: 25 additions & 0 deletions ahcore/utils/database_models.py
@@ -210,3 +210,28 @@ class Split(Base):
split_definition: Mapped["SplitDefinitions"] = relationship("SplitDefinitions", back_populates="splits")

__table_args__ = (UniqueConstraint("split_definition_id", "patient_id", name="uq_patient_split_key"),)


class OnTheFlyBase(DeclarativeBase):
"""
Base for creating an in-memory DB on-the-fly for, e.g., segmentation inference on a directory of WSIs.
"""

pass


class MinimalImage(OnTheFlyBase):
"""Minimal image table for an in-memory db for instant inference"""

# TODO Link to annotations or masks
__tablename__ = "image"
id = Column(Integer, primary_key=True)
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
filename = Column(String, unique=True, nullable=False)
relative_filename = Column(String, unique=True, nullable=False)
reader = Column(String)
height = Column(Integer)
width = Column(Integer)
mpp = Column(Float)
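To connect the pieces, a sketch (an assumed workflow, not code from this PR) of populating the MinimalImage table from a directory of WSIs, as the OnTheFlyBase docstring suggests. The directory, glob pattern, and reader string are placeholders.

from pathlib import Path

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from ahcore.utils.database_models import MinimalImage, OnTheFlyBase

data_dir = Path("/data/wsis")   # placeholder, cf. OnTheFlyDataDescription.data_dir
glob_pattern = "**/*.svs"       # placeholder, cf. OnTheFlyDataDescription.glob_pattern

engine = create_engine("sqlite:///:memory:")
OnTheFlyBase.metadata.create_all(bind=engine)  # same effect as create_tables(engine, base=OnTheFlyBase)
session = sessionmaker(bind=engine)()

for path in sorted(data_dir.glob(glob_pattern)):
    session.add(
        MinimalImage(
            filename=str(path),
            relative_filename=str(path.relative_to(data_dir)),
            reader="OPENSLIDE",  # placeholder reader string
        )
    )
session.commit()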
10 changes: 9 additions & 1 deletion ahcore/utils/io.py
@@ -85,6 +85,7 @@ def print_config(
config: DictConfig,
fields: Sequence[str] = (
"trainer",
"data_description",
"model",
"experiment",
"transforms",
@@ -241,7 +242,14 @@ def load_weights(model: LightningModule, config: DictConfig) -> LightningModule:
return model
else:
# Load checkpoint weights
lit_ckpt = torch.load(config.ckpt_path)
accelerator = config.trainer.accelerator
if accelerator == "cpu":
map_location = "cpu"
elif accelerator == "gpu":
map_location = "cuda"
else:
raise ValueError(f"Accelerator must be either cpu or gpu, but config.trainer.accelerator={accelerator}")
lit_ckpt = torch.load(config.ckpt_path, map_location=map_location)
model.load_state_dict(lit_ckpt["state_dict"], strict=True)
return model
