viscy-data

Data loading and Lightning DataModules for AI × imaging tasks.

pip install viscy-data
uv add viscy-data

What's here

Lightning DataModules for OME-Zarr microscopy — HCS plates and wells, triplet sampling for contrastive learning, and memory-mapped caching for terabyte-scale datasets.

Optional extras

viscy-data[triplet] adds tensorstore-backed triplet sampling · viscy-data[livecell] adds LiveCell dataset support · viscy-data[mmap] adds memory-mapped caching · viscy-data[all] enables all.

API reference

BatchedConcatDataModule

Bases: ConcatDataModule

Concatenated data module with batched micro-batch GPU transforms.

Under DDP, attaches ShardedDistributedSampler so each rank iterates a disjoint shard while preserving the existing micro-batch-to-single-batch contract.

on_after_batch_transfer(batch, dataloader_idx)

Apply GPU transforms from constituent data modules to micro-batches.

setup(stage)

Mark each child as a BatchedConcat child before parent setup.

train_dataloader here uses batch_size as-is (loads N indices, each yielding num_samples patches via the child's RandWeightedCropd), so the divisibility constraint enforced by HCSDataModule._train_transform for standalone use does not apply. Setting the flag on each child before calling super().setup (which iterates children's setup) lets the check skip itself.

train_dataloader()

Return batched concatenated training data loader with optional DDP sampling.

val_dataloader()

Return batched concatenated validation data loader with optional DDP sampling.

BatchedConcatDataset

Bases: ConcatDataset

Concatenated dataset with batched access by constituent dataset.

__getitem__(idx)

Not implemented; use __getitems__ for batched access.

__getitems__(indices)

Return micro-batches grouped by constituent dataset.
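
As a rough illustration of grouping by constituent dataset, the sketch below maps global indices of a concatenated dataset to per-dataset local indices. It does not use the library: the function name `group_indices_by_dataset` and the `dataset_sizes` argument are hypothetical, and the real `__getitems__` may differ in ordering and return type.

```python
from bisect import bisect_right
from collections import defaultdict
from itertools import accumulate


def group_indices_by_dataset(indices, dataset_sizes):
    """Map global indices to {dataset_idx: [local indices]} (illustrative only)."""
    # Cumulative sizes mark where each constituent dataset ends.
    cumulative = list(accumulate(dataset_sizes))
    groups = defaultdict(list)
    for idx in indices:
        if not 0 <= idx < cumulative[-1]:
            raise IndexError(f"index {idx} out of range")
        # Find which constituent dataset owns this global index.
        dataset_idx = bisect_right(cumulative, idx)
        offset = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
        groups[dataset_idx].append(idx - offset)
    return dict(groups)


print(group_indices_by_dataset([0, 4, 5, 9], dataset_sizes=[5, 5]))
```

Each value in the returned mapping can then be fetched from its own dataset in one call, which is what makes a per-dataset micro-batch possible.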

CTMCv1DataModule

Bases: GPUTransformDataModule

Autoregression data module for the CTMCv1 dataset.

Training and validation datasets are stored in separate HCS OME-Zarr stores.

Parameters:

Name Type Description Default
train_data_path str or Path

Path to the training dataset.

required
val_data_path str or Path

Path to the validation dataset.

required
train_cpu_transforms list of MapTransform

List of CPU transforms for training.

required
val_cpu_transforms list of MapTransform

List of CPU transforms for validation.

required
train_gpu_transforms list of MapTransform

List of GPU transforms for training.

required
val_gpu_transforms list of MapTransform

List of GPU transforms for validation.

required
batch_size int

Batch size, by default 16.

16
num_workers int

Number of dataloading workers, by default 8.

8
val_subsample_ratio int

Skip every N frames for validation to reduce redundancy in video, by default 30.

30
channel_name str

Name of the DIC channel, by default "DIC".

'DIC'
pin_memory bool

Pin memory for dataloaders, by default True.

True

train_cpu_transforms property

Return training CPU transforms.

train_gpu_transforms property

Return training GPU transforms.

val_cpu_transforms property

Return validation CPU transforms.

val_gpu_transforms property

Return validation GPU transforms.

setup(stage)

Set up datasets for the given stage.

CachedConcatDataModule

Bases: LightningDataModule

Concatenated data module with distributed sampling support.

Parameters:

Name Type Description Default
data_modules Sequence[LightningDataModule]

Data modules to concatenate.

required

Notes

Trainer propagation to children happens in both prepare_data and setup for the same reason as ConcatDataModule — see that class's docstring.

prepare_data()

Prepare data for all constituent data modules.

setup(stage)

Set up constituent data modules and create concatenated datasets.

train_dataloader()

Return concatenated training data loader with optional DDP sampling.

val_dataloader()

Return concatenated validation data loader with optional DDP sampling.

CachedOmeZarrDataModule

Bases: GPUTransformDataModule, SelectWell

Data module for cached OME-Zarr arrays.

Parameters:

Name Type Description Default
data_path Path

Path to the HCS OME-Zarr dataset.

required
channels str | list[str]

Channel names to load.

required
batch_size int

Batch size for training and validation.

required
num_workers int

Number of workers for data-loaders.

required
split_ratio float

Fraction of the FOVs used for the training split. The rest will be used for validation.

required
train_cpu_transforms list[DictTransform]

Transforms to be applied on the CPU during training.

required
val_cpu_transforms list[DictTransform]

Transforms to be applied on the CPU during validation.

required
train_gpu_transforms list[DictTransform]

Transforms to be applied on the GPU during training.

required
val_gpu_transforms list[DictTransform]

Transforms to be applied on the GPU during validation.

required
pin_memory bool

Use page-locked memory in data-loaders, by default True.

True
skip_cache bool

Skip caching for this dataset, by default False.

False
include_wells list[str]

List of well names to include in the dataset, by default None (all).

None
exclude_fovs list[str]

List of FOV names to exclude from the dataset, by default None (none).

None
prefetch_factor int | None

Number of batches loaded in advance by each worker.

None

train_cpu_transforms property

Return training CPU transforms.

train_gpu_transforms property

Return training GPU transforms.

val_cpu_transforms property

Return validation CPU transforms.

val_gpu_transforms property

Return validation GPU transforms.

setup(stage)

Set up datasets for fit or validate stage.

CachedOmeZarrDataset

Bases: Dataset

Dataset for cached OME-Zarr arrays.

Parameters:

Name Type Description Default
positions list[Position]

List of FOVs to load images from.

required
channel_names list[str]

List of channel names to load.

required
cache_map DictProxy

Shared dictionary for caching loaded volumes.

required
transform Compose | None

Composed transforms to be applied on the CPU, by default None.

None
array_key str

The image array key name (multi-scale level), by default "0".

'0'
load_normalization_metadata bool

Load normalization metadata in the sample dictionary, by default True.

True
skip_cache bool

Skip caching to save RAM, by default False.

False

__getitem__(idx)

Return a sample for the given index, using cache when available.

__len__()

Return total number of cached samples.

CellDivisionTripletDataModule

Bases: HCSDataModule

Lightning data module for cell division triplet sampling.

__init__(data_path, source_channel, final_yx_patch_size=(64, 64), split_ratio=0.8, batch_size=16, num_workers=8, normalizations=None, augmentations=None, augment_validation=True, time_interval='any', return_negative=True, output_2d=False, persistent_workers=False, prefetch_factor=None, pin_memory=False)

Lightning data module for cell division triplet sampling.

Parameters:

Name Type Description Default
data_path str

Path to directory containing npy files

required
source_channel str | Sequence[str]

List of input channel names

required
final_yx_patch_size tuple[int, int]

Output patch size, by default (64, 64)

(64, 64)
split_ratio float

Ratio of training samples, by default 0.8

0.8
batch_size int

Batch size, by default 16

16
num_workers int

Number of data-loading workers, by default 8

8
normalizations list[MapTransform] or None

Normalization transforms, by default None

None
augmentations list[MapTransform] or None

Augmentation transforms, by default None

None
augment_validation bool

Apply augmentations to validation data, by default True

True
time_interval Literal['any'] | int

Future time interval to sample positive and anchor from, by default "any"

'any'
return_negative bool

Whether to return the negative sample during the fit stage, by default True

True
output_2d bool

Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False

False
persistent_workers bool

Whether to keep worker processes alive between iterations, by default False

False
prefetch_factor int | None

Number of batches loaded in advance by each worker, by default None

None
pin_memory bool

Whether to pin memory in CPU for faster GPU transfer, by default False

False

CellDivisionTripletDataset

Bases: Dataset

Dataset for triplet sampling of cell division data from npy files.

For the dataset from the paper: https://arxiv.org/html/2502.02182v1

__getitem__(index)

Return a triplet sample for the given index.

__init__(data_paths, channel_names, anchor_transform=None, positive_transform=None, negative_transform=None, fit=True, time_interval='any', return_negative=True, output_2d=False)

Dataset for triplet sampling of cell division data from npy files.

Parameters:

Name Type Description Default
data_paths list[Path]

List of paths to npy files containing cell division tracks (T,C,Y,X format)

required
channel_names list[str]

Input channel names

required
anchor_transform DictTransform | None

Transforms applied to the anchor sample, by default None

None
positive_transform DictTransform | None

Transforms applied to the positive sample, by default None

None
negative_transform DictTransform | None

Transforms applied to the negative sample, by default None

None
fit bool

Fitting mode in which the full triplet will be sampled, only sample anchor if False, by default True

True
time_interval Literal['any'] | int

Future time interval to sample positive and anchor from, by default "any"

'any'
return_negative bool

Whether to return the negative sample during the fit stage, by default True

True
output_2d bool

Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False

False

__len__()

Return the number of valid anchor samples.

CellIndex

Bases: TypedDict

Ultrack tracking index carried in predict-mode batches.

All fields optional — presence depends on the source CSV columns. (fov_name, track_id, t) together uniquely identify a cell observation and are the join key back to valid_anchors.
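
A minimal sketch of how such an index could be typed and turned into a join key. The class name `CellIndexSketch`, the field types, and the helper `join_key` are assumptions for illustration; only the field names fov_name, track_id, and t come from the description above.

```python
from typing import TypedDict


class CellIndexSketch(TypedDict, total=False):
    # total=False marks every field as optional, mirroring "all fields optional".
    fov_name: str  # assumed type
    track_id: int  # assumed type
    t: int         # assumed type


def join_key(index: CellIndexSketch) -> tuple:
    """Build the (fov_name, track_id, t) key; raises KeyError if a field is absent."""
    return (index["fov_name"], index["track_id"], index["t"])


observation: CellIndexSketch = {"fov_name": "B/2/000000", "track_id": 7, "t": 12}
print(join_key(observation))
```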

ChannelDropout

Bases: Module

Randomly zero out entire channels during training.

Designed for (B, C, Z, Y, X) tensors in the GPU augmentation pipeline. Applied after the scatter/gather augmentation chain in on_after_batch_transfer.

Parameters:

Name Type Description Default
channels list[int]

Channel indices to potentially drop.

required
p float

Probability of dropping each specified channel per sample. Default: 0.5.

0.5

forward(x)

Drop selected channels per-sample with probability self.p.
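
The per-sample behavior can be sketched without tensors. The example below uses nested lists in place of a (B, C, Z, Y, X) tensor so that it runs with the standard library alone. The function `channel_dropout` is hypothetical and ignores the training-mode check that a real Module would apply.

```python
import random


def channel_dropout(batch, channels, p=0.5, rng=None):
    """Return a copy of batch with selected channels zeroed per sample.

    batch is a list of samples, each sample is a list of channels, and each
    channel is a flat list of floats standing in for a (Z, Y, X) volume.
    """
    rng = rng or random.Random()
    out = []
    for sample in batch:
        new_sample = [list(channel) for channel in sample]
        for c in channels:
            # Independent draw for every (sample, channel) pair.
            if rng.random() < p:
                new_sample[c] = [0.0] * len(new_sample[c])
        out.append(new_sample)
    return out


batch = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(channel_dropout(batch, channels=[1], p=0.5, rng=random.Random(0)))
```

Because the draw is per sample, one element of a batch can lose a channel while another keeps it.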

ChannelMap

Bases: TypedDict

Source channel names.

ChannelNormStats

Bases: TypedDict

Per-channel normalization statistics.

ClassificationDataModule

Bases: LightningDataModule

Lightning data module for cell classification tasks.

__init__(image_path, annotation_path, val_fovs, channel_name, z_range, train_exclude_timepoints, train_transforms, val_transforms, initial_yx_patch_size, batch_size, num_workers, label_column='infection_state')

Lightning data module for cell classification tasks.

Parameters:

Name Type Description Default
image_path Path

Path to the OME-Zarr image store

required
annotation_path Path

Path to the annotation CSV file

required
val_fovs list[str]

FOV names for validation

required
channel_name str

Input channel name

required
z_range tuple[int, int]

Range of Z-slices

required
train_exclude_timepoints list[int]

Timepoints to exclude from training

required
train_transforms list[Callable] | None

Training transforms

required
val_transforms list[Callable] | None

Validation transforms

required
initial_yx_patch_size tuple[int, int]

YX size of the initially sampled image patch

required
batch_size int

Batch size

required
num_workers int

Number of data-loading workers

required
label_column str

Column name for the label, by default "infection_state"

'infection_state'

predict_dataloader()

Return predict data loader.

setup(stage=None)

Set up datasets for the given stage.

train_dataloader()

Return training data loader.

val_dataloader()

Return validation data loader.

ClassificationDataset

Bases: Dataset

Dataset for cell classification from annotated image data.

__getitem__(idx)

Return a sample for the given index.

__init__(plate, annotation, channel_name, z_range, transform, initial_yx_patch_size, return_indices=False, label_column='infection_state')

Dataset for cell classification from annotated image data.

Parameters:

Name Type Description Default
plate Plate

OME-Zarr plate store

required
annotation DataFrame

Annotation dataframe with cell locations and labels

required
channel_name str

Input channel name

required
z_range tuple[int, int]

Range of Z-slices

required
transform Callable | None

Transform to apply to image patches

required
initial_yx_patch_size tuple[int, int]

YX size of the initially sampled image patch

required
return_indices bool

Whether to return index information, by default False

False
label_column AnnotationColumns

Column name for the label, by default "infection_state"

'infection_state'

__len__()

Return the number of annotated samples.

CombineMode

Bases: Enum

Mode for combining multiple data modules.

CombinedDataModule

Bases: LightningDataModule

Wrapper for combining multiple data modules.

For supported modes, see lightning.pytorch.utilities.combined_loader.

Parameters:

Name Type Description Default
data_modules Sequence[LightningDataModule]

data modules to combine

required
train_mode CombineMode

mode in training stage, by default CombineMode.MAX_SIZE_CYCLE

MAX_SIZE_CYCLE
val_mode CombineMode

mode in validation stage, by default CombineMode.SEQUENTIAL

SEQUENTIAL
test_mode CombineMode

mode in testing stage, by default CombineMode.SEQUENTIAL

SEQUENTIAL
predict_mode CombineMode

mode in prediction stage, by default CombineMode.SEQUENTIAL

SEQUENTIAL

on_after_batch_transfer(batch, dataloader_idx)

Dispatch GPU transforms to child data modules.

CombinedLoader yields different batch formats:

  • Training (max_size_cycle): list of sub-batches, one per child data module.
  • Validation (sequential): single batch from one child at a time, identified by dataloader_idx.

Parameters:

Name Type Description Default
batch list | dict | Tensor

Batch from CombinedLoader.

required
dataloader_idx int

Index of the active dataloader (meaningful in sequential mode).

required

Returns:

Type Description
list | dict | Tensor

Transformed batch(es).
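
The two batch formats suggest a dispatch along these lines. This is a sketch under assumptions, not the library's code: `dispatch_gpu_transforms` and `child_transforms` (one callable per child data module) are hypothetical names.

```python
def dispatch_gpu_transforms(batch, dataloader_idx, child_transforms):
    """Apply per-child transforms to a CombinedLoader-style batch (illustrative)."""
    if isinstance(batch, list):
        # max_size_cycle: one sub-batch per child, in child order.
        return [fn(sub) for fn, sub in zip(child_transforms, batch)]
    # sequential: a single batch that belongs to child number dataloader_idx.
    return child_transforms[dataloader_idx](batch)


def double(b):
    return {k: v * 2 for k, v in b.items()}


def negate(b):
    return {k: -v for k, v in b.items()}


print(dispatch_gpu_transforms([{"x": 1}, {"x": 2}], 0, [double, negate]))
print(dispatch_gpu_transforms({"x": 3}, 1, [double, negate]))
```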

predict_dataloader()

Return combined predict data loader.

prepare_data()

Prepare data for all constituent data modules.

setup(stage)

Set up all constituent data modules.

test_dataloader()

Return combined test data loader.

train_dataloader()

Return combined training data loader.

val_dataloader()

Return combined validation data loader.

ConcatDataModule

Bases: LightningDataModule

Concatenate multiple data modules.

The concatenated data module will have the same batch size and number of workers as the first data module. Each element will be sampled uniformly regardless of their original data module.

Parameters:

Name Type Description Default
data_modules Sequence[LightningDataModule]

Data modules to concatenate.

required

Notes

Trainer propagation to children happens in both prepare_data and setup because prepare_data_per_node=True causes Lightning to invoke prepare_data only on rank 0 of each node. Without the setup propagation, non-rank-0 children keep self.trainer = None and silently skip trainer-gated paths such as HCSDataModule.on_after_batch_transfer's if self.trainer and self.trainer.training guard, producing rank-asymmetric failures where non-rank-0 ranks receive un-cropped batches because gpu_augmentations did not run.
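
The propagation pattern described in this note can be sketched with stand-in classes. `Child` and `ConcatSketch` are hypothetical and do not inherit from LightningDataModule; the point is only that setup repeats the propagation because prepare_data may never run on a given rank.

```python
class Child:
    """Stand-in for a child data module."""

    def __init__(self):
        self.trainer = None

    def prepare_data(self):
        pass

    def setup(self, stage):
        pass


class ConcatSketch:
    """Stand-in for a parent that concatenates child data modules."""

    def __init__(self, children):
        self.children = children
        self.trainer = None

    def _propagate_trainer(self):
        for child in self.children:
            child.trainer = self.trainer

    def prepare_data(self):
        # Reached only on rank 0 of each node when prepare_data_per_node=True.
        self._propagate_trainer()
        for child in self.children:
            child.prepare_data()

    def setup(self, stage):
        # Reached on every rank, so children always see the trainer.
        self._propagate_trainer()
        for child in self.children:
            child.setup(stage)


# Simulate a non-rank-0 process: prepare_data is never called.
children = [Child(), Child()]
dm = ConcatSketch(children)
dm.trainer = object()
dm.setup("fit")
print(all(child.trainer is dm.trainer for child in children))
```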

prepare_data()

Prepare data for all constituent data modules.

setup(stage)

Set up constituent data modules and create concatenated datasets.

train_dataloader()

Return concatenated training data loader.

val_dataloader()

Return concatenated validation data loader.

FlexibleBatchSampler

Bases: Sampler[list[int]]

Composable batch sampler with batch grouping and stratification.

Each batch is constructed by a cascade:

  1. Group selection (batch_group_by): pick a single group to draw from, or draw from all samples.
  2. Leaky mixing (leaky): optionally inject a fraction of cross-group samples into group-restricted batches.
  3. Stratified sampling (stratify_by): within the selected pool, balance representation across groups defined by one or more DataFrame columns.
  4. Temporal enrichment (temporal_enrichment): concentrate batch indices around a randomly chosen focal HPI, with a configurable global fraction drawn from all timepoints.
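
The first three steps of the cascade can be sketched as follows. This is a simplified illustration, not the sampler's implementation: `build_batch` is a hypothetical function, rows are plain dicts rather than a DataFrame, groups are weighted by size, sampling is with replacement, and temporal enrichment is omitted.

```python
import random
from collections import defaultdict


def build_batch(rows, batch_size, group_col, stratify_col, leaky, rng):
    """Return row indices for one batch: group pick, leaky mix, stratify."""
    by_group = defaultdict(list)
    for i, row in enumerate(rows):
        by_group[row[group_col]].append(i)
    # 1. Group selection: weighted by group size (assumed default).
    names = sorted(by_group)
    chosen = rng.choices(names, weights=[len(by_group[n]) for n in names])[0]
    pool = by_group[chosen]
    others = [i for n in names if n != chosen for i in by_group[n]]
    # 2. Leaky mixing: reserve a fraction of the batch for other groups.
    n_leak = int(round(batch_size * leaky)) if others else 0
    n_main = batch_size - n_leak
    # 3. Stratified sampling: split the remainder evenly across strata.
    strata = defaultdict(list)
    for i in pool:
        strata[rows[i][stratify_col]].append(i)
    keys = sorted(strata)
    batch = []
    for k, key in enumerate(keys):
        quota = n_main // len(keys) + (1 if k < n_main % len(keys) else 0)
        batch += rng.choices(strata[key], k=quota)  # with replacement
    batch += rng.choices(others, k=n_leak) if n_leak else []
    return batch


rows = [
    {"experiment": e, "perturbation": p}
    for e in ("a", "b")
    for p in ("ctrl", "inf")
    for _ in range(3)
]
print(build_batch(rows, 4, "experiment", "perturbation", 0.0, random.Random(0)))
```

With leaky=0.0 every index in the batch comes from one experiment, and the two perturbation strata each contribute half of the batch.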

Parameters:

Name Type Description Default
valid_anchors DataFrame

DataFrame with a clean integer index (0..N-1). Must contain any columns referenced by batch_group_by, stratify_by, or temporal_enrichment.

required
batch_size int

Number of indices per batch.

128
batch_group_by str | list[str] | None

Column(s) in valid_anchors that define batch-level groups. Each batch draws from a single group. "experiment" — one experiment per batch (old experiment_aware=True). "marker" — one marker per batch. ["marker", "source_channel"] — one (marker, channel) per batch. None — no grouping, draw from all samples.

None
leaky float

Fraction of the batch drawn from other groups when batch_group_by is not None. Ignored otherwise.

0.0
group_weights dict[str, float] | None

Per-group sampling weight (keyed by group string key). Defaults to proportional to group size.

None
stratify_by str | list[str] | None

Column name(s) in valid_anchors to stratify batches by. Groups are balanced equally within each batch. None disables stratification.

'perturbation'
temporal_enrichment bool

If True, concentrate batch indices around a randomly chosen focal hours-post-infection (HPI) value. Requires "hours_post_perturbation" column in valid_anchors.

False
temporal_window_hours float

Half-width of the focal window around the chosen HPI.

2.0
temporal_global_fraction float

Fraction of the batch drawn from all timepoints (global).

0.3
num_replicas int

Number of DDP processes (1 for single-process).

1
rank int

Rank of the current process (0 for single-process).

0
seed int

Base RNG seed for deterministic sampling.

0
drop_last bool

If True, drop the last incomplete batch.

True

__iter__()

Yield batch-sized lists of integer indices.

Builds batches lazily so the first batch is ready in milliseconds instead of blocking on a full-epoch materialization. Every rank still calls _build_one_batch on every index so the RNG draws stay identical to the list-based implementation — only the yield is rank-filtered, not the sampling. DDP correctness is therefore bit-identical to the previous implementation; the only change is that the main thread sees batch 0 after one _build_one_batch call instead of total_batches calls.

limit_train_batches interacts with this: Lightning stops pulling from the generator after its cap, so we never pay for the unused suffix of the epoch.

The epoch counter auto-advances at the start of each iteration so that the next __iter__ call reseeds the RNG with a fresh seed + epoch and yields a different batch sequence. Advancing at the start (not the end) is robust against early generator termination from limit_train_batches: Lightning stops pulling after its cap and garbage-collects the generator, which would skip any end-of-iter bookkeeping.

PyTorch Lightning does not call set_epoch on custom batch_sampler instances (use_distributed_sampler: false with a batch sampler means Lightning's auto-wrap skips us), so we self-advance. set_epoch still works if a caller wants deterministic resume from a specific epoch — call it before the iteration and the advance will take the resumed epoch as its starting point.
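
A toy sampler can illustrate the three ideas above: lazy batch construction, a shared RNG stream with rank-filtered yield, and an epoch counter that advances when iteration starts. `LazyRankSampler` is a hypothetical class, and the round-robin rule for assigning batches to ranks is an assumption rather than the documented behavior.

```python
import random


class LazyRankSampler:
    """Toy batch sampler: shared RNG stream, rank-filtered yield."""

    def __init__(self, n_samples, batch_size, total_batches,
                 num_replicas=1, rank=0, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.total_batches = total_batches
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Seed from the current epoch, then advance immediately so the
        # bookkeeping survives early termination of the generator.
        rng = random.Random(self.seed + self.epoch)
        self.epoch += 1
        return self._generate(rng)

    def _generate(self, rng):
        for batch_idx in range(self.total_batches):
            # Every rank draws every batch, keeping RNG streams aligned.
            batch = [rng.randrange(self.n_samples) for _ in range(self.batch_size)]
            # Assumed round-robin rule for which rank keeps which batch.
            if batch_idx % self.num_replicas == self.rank:
                yield batch


sampler = LazyRankSampler(n_samples=100, batch_size=4, total_batches=6)
iterator = iter(sampler)
first = next(iterator)  # only one batch has been built so far
del iterator            # early termination, as with limit_train_batches
print(len(first), sampler.epoch)
```

Since the counter moves before the first batch is built, abandoning the generator early still leaves the sampler ready to reseed for the next epoch.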

__len__()

Return number of batches this rank will yield.

set_epoch(epoch)

Set epoch for deterministic shuffling across DDP ranks.

GPUTransformDataModule

Bases: ABC, LightningDataModule

Abstract data module with GPU transforms.

train_cpu_transforms abstractmethod property

Return training CPU transforms.

train_gpu_transforms abstractmethod property

Return training GPU transforms.

val_cpu_transforms abstractmethod property

Return validation CPU transforms.

val_gpu_transforms abstractmethod property

Return validation GPU transforms.

on_after_batch_transfer(batch, dataloader_idx)

Apply GPU transforms after batch transfer to device.

Parameters:

Name Type Description Default
batch dict

Batch dict with channel-name keys mapped to (B, 1, Z, Y, X) tensors from list_data_collate.

required
dataloader_idx int

Dataloader index (unused).

required

Returns:

Type Description
dict

Transformed batch (e.g., with source and target keys after BatchedStackChannelsd).

train_dataloader()

Return training data loader.

val_dataloader()

Return validation data loader.

HCSDataModule

Bases: LightningDataModule

Lightning data module for a preprocessed HCS NGFF Store.

Parameters:

Name Type Description Default
data_path str

Path to the data store.

required
source_channel str or Sequence[str]

Name(s) of the source channel, e.g. 'Phase'.

required
target_channel str or Sequence[str]

Name(s) of the target channel, e.g. ['Nuclei', 'Membrane'].

required
z_window_size int

Z window size of the 2.5D U-Net, 1 for 2D.

required
split_ratio float

Split ratio of the training subset in the fit stage, e.g. 0.8 means an 80/20 split between training/validation, by default 0.8.

0.8
batch_size int

Batch size, defaults to 16.

16
num_workers int

Number of data-loading workers, defaults to 8.

8
target_2d bool

Whether the target is 2D (e.g. in a 2.5D model), defaults to False.

False
yx_patch_size tuple[int, int]

Patch size in (Y, X), defaults to (256, 256).

(256, 256)
normalizations list of MapTransform or None

MONAI dictionary transforms applied to selected channels, defaults to None (no normalization).

None
augmentations list of MapTransform or None

MONAI dictionary transforms applied to the training set, defaults to None (no augmentation).

None
mmap_preload bool

If True, stage the entire dataset to a tensordict.MemoryMappedTensor buffer under scratch_dir during prepare_data() and serve training samples via mmap views. This eliminates zarr reads during the training loop. Point scratch_dir at tmpfs (e.g. /dev/shm) for RAM-backed I/O. Requires viscy-data[mmap] (tensordict). Default False. Only effective during fit/validate runs; predict and test stages silently ignore this flag and read from zarr directly.

False
scratch_dir Path or None

Directory for mmap cache files. Defaults to /tmp. On SLURM, uses /tmp/$SLURM_JOB_ID/.

None
ground_truth_masks Path or None

Path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None.

None
persistent_workers bool

Whether to keep the workers alive between fitting epochs, defaults to False.

False
prefetch_factor int or None

Number of batches loaded in advance by each worker during fitting, defaults to None (PyTorch default of 2).

None
array_key str

Name of the image arrays (multiscales level), by default "0".

'0'
min_nonzero_fraction float

Minimum fraction of voxels above nonzero_threshold for training. Default 0.0 disables filtering.

0.0
nonzero_threshold float

Intensity threshold for the nonzero fraction check, by default 0.0.

0.0
nonzero_channel str or None

Channel to check. None defaults to the first target channel.

None
max_nonzero_retries int

Maximum retries when a patch fails the nonzero check, by default 100.

100
fg_mask_key str or None

Zarr array key for precomputed foreground masks, by default None.

None
val_gpu_augmentations list[MapTransform] or None

GPU transforms applied to validation batches in on_after_batch_transfer. Use for validation-time spatial crops (e.g. BatchedDivisibleCropd) when the FOV is not compatible with the model's downsampling factor.

None
include_fov_names Iterable[str] or None

If given, only positions whose plate-relative name (e.g. "B/2/000000") is in this collection are used. Applied before the train/val split. Stored as a set for O(1) lookup. Honored during fit/validate (including mmap staging) and predict. Test stage uses all plate positions.

None
exclude_fov_names Iterable[str] or None

If given, positions whose plate-relative name is in this collection are skipped. Useful to hold out test FOVs from a plate that also contains training FOVs, or to resume a predict run without redoing already-written positions. Applied after include_fov_names. Stored as a set for O(1) lookup. Honored during fit/validate (including mmap staging) and predict. Test stage uses all plate positions.

None

on_after_batch_transfer(batch, dataloader_idx)

Apply GPU augmentations and validate output spatial shape.

Training: applies gpu_augmentations if configured. When no gpu_augmentations are set, validates that source spatial dimensions match (z_window_size, *yx_patch_size). Validation: applies val_gpu_augmentations if configured. Test/predict: pass through unchanged.

When target_2d is set, the target center Z slice is extracted after augmentations to save VRAM.

predict_dataloader()

Return predict data loader.

prepare_data()

Stage FOVs to a memory-mapped tensor buffer on local scratch.

Runs only when the current stage is fit or validate. Predict and test stages skip mmap staging entirely (they read from zarr directly and ignore FOV filters). Manual dm.prepare_data() calls without a trainer attached (or with a trainer whose state.fn is None) also run — preserving existing standalone usage.
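
The scratch_dir parameter description says cache files default to /tmp, or to /tmp/$SLURM_JOB_ID/ on SLURM. A resolution rule matching that description might look like the sketch below. `resolve_scratch_dir` is a hypothetical helper, and the precedence (explicit argument first, then the SLURM job ID, then /tmp) is an assumption.

```python
import os
from pathlib import Path


def resolve_scratch_dir(scratch_dir=None, environ=None):
    """Pick a scratch directory for mmap cache files (illustrative only)."""
    environ = os.environ if environ is None else environ
    if scratch_dir is not None:
        # An explicit setting wins, e.g. /dev/shm for RAM-backed I/O.
        return Path(scratch_dir)
    job_id = environ.get("SLURM_JOB_ID")
    if job_id:
        # Assumed per-job subdirectory on SLURM nodes.
        return Path("/tmp") / job_id
    return Path("/tmp")


print(resolve_scratch_dir(environ={"SLURM_JOB_ID": "123"}))
```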

setup(stage)

Set up datasets for the given stage.

test_dataloader()

Return test data loader.

train_dataloader()

Return training data loader.

val_dataloader()

Return validation data loader.

HCSStackIndex

Bases: NamedTuple

HCS stack index.

LevelNormStats

Bases: TypedDict

Per-level normalization statistics.

Not all fields are present for every level. The normalize transforms access fields dynamically based on subtrahend and divisor config (e.g. mean/std or median/iqr).
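
Dynamic field access of this kind can be sketched as below. The `normalize` function and the two example dictionaries are hypothetical; only the field names mean, std, median, and iqr and the subtrahend/divisor configuration come from the description above.

```python
def normalize(value, stats, subtrahend="mean", divisor="std"):
    """Normalize a value using stats looked up by configurable field names."""
    # Field names are resolved at call time, so a level only needs to
    # provide the statistics that the configuration actually requests.
    return (value - stats[subtrahend]) / stats[divisor]


stats_a = {"mean": 10.0, "std": 2.0}    # hypothetical level without median/iqr
stats_b = {"median": 8.0, "iqr": 4.0}   # hypothetical level without mean/std

print(normalize(14.0, stats_a))                   # 2.0
print(normalize(14.0, stats_b, "median", "iqr"))  # 1.5
```

Requesting a field that a level does not carry raises KeyError, which is why the configuration and the stored statistics have to agree.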

LiveCellDataModule

Bases: GPUTransformDataModule

Data module for LiveCell training and evaluation.

Parameters:

Name Type Description Default
train_val_images Path | None

Path to the training/validation images directory.

None
test_images Path | None

Path to the test images directory.

None
train_annotations Path | None

Path to the training COCO annotations file.

None
val_annotations Path | None

Path to the validation COCO annotations file.

None
test_annotations Path | None

Path to the test COCO annotations file.

None
train_cpu_transforms list[MapTransform]

CPU transforms for training.

None
val_cpu_transforms list[MapTransform]

CPU transforms for validation.

None
train_gpu_transforms list[MapTransform]

GPU transforms for training.

None
val_gpu_transforms list[MapTransform]

GPU transforms for validation.

None
test_transforms list[MapTransform]

Transforms for test stage.

None
batch_size int

Batch size, by default 16.

16
num_workers int

Number of dataloading workers, by default 8.

8
pin_memory bool

Pin memory for dataloaders, by default True.

True

train_cpu_transforms property

Return training CPU transforms.

train_gpu_transforms property

Return training GPU transforms.

val_cpu_transforms property

Return validation CPU transforms.

val_gpu_transforms property

Return validation GPU transforms.

setup(stage)

Set up datasets for the given stage.

test_dataloader()

Return test data loader.

LiveCellDataset

Bases: Dataset

LiveCell dataset.

Parameters:

Name Type Description Default
images list of Path

List of paths to single-page, single-channel TIFF files.

required
transform Transform or Compose

Transform to apply to the dataset.

required
cache_map DictProxy

Shared dictionary for caching images.

required

__getitem__(idx)

Return a sample for the given index, using cache when available.

__len__()

Return total number of images.

LiveCellTestDataset

Bases: Dataset

LiveCell test dataset.

Parameters:

Name Type Description Default
image_dir Path

Directory containing the images.

required
transform MapTransform | Compose

Transform to apply to the dataset.

required
annotations Path

Path to the COCO annotations file.

required
load_target bool

Whether to load the target images (default is False).

False
load_labels bool

Whether to load the labels (default is False).

False

__getitem__(idx)

Return a sample for the given index.

__len__()

Return total number of test images.

MaskTestDataset

Bases: SlidingWindowDataset

Test dataset with optional ground truth masks.

Each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is z_window_size.

This is a testing-stage version of viscy_data.sliding_window.SlidingWindowDataset, and can only be used with batch size 1 for efficiency (no padding for collation), since the mask is not available for each stack.

Parameters:

Name Type Description Default
positions list[Position]

FOVs to include in dataset.

required
channels ChannelMap

Source and target channel names, e.g. {'source': 'Phase', 'target': ['Nuclei', 'Membrane']}.

required
z_window_size int

Z window size of the 2.5D U-Net, 1 for 2D.

required
transform DictTransform

A callable that transforms data, defaults to None.

None
ground_truth_masks str | None

Path to the ground truth masks.

None
array_key str

Name of the image arrays (multiscales level), by default "0".

'0'

__getitem__(index)

Return a sample with optional ground truth mask.

MmappedDataModule

Bases: GPUTransformDataModule, SelectWell

Data module for cached OME-Zarr arrays.

Parameters:

Name Type Description Default
data_path Path

Path to the HCS OME-Zarr dataset.

required
channels str | list[str]

Channel names to load.

required
batch_size int

Batch size for training and validation.

required
num_workers int

Number of workers for data-loaders.

required
split_ratio float

Fraction of the FOVs used for the training split. The rest will be used for validation.

required
train_cpu_transforms list[DictTransform]

Transforms to be applied on the CPU during training.

required
val_cpu_transforms list[DictTransform]

Transforms to be applied on the CPU during validation.

required
train_gpu_transforms list[DictTransform]

Transforms to be applied on the GPU during training.

required
val_gpu_transforms list[DictTransform]

Transforms to be applied on the GPU during validation.

required
pin_memory bool

Use page-locked memory in data-loaders, by default True

True
prefetch_factor int | None

Prefetching ratio for the torch dataloader, by default None

None
array_key str

Name of the image arrays (multiscales level), by default "0"

'0'
scratch_dir Path | None

Path to the scratch directory, by default None (use OS temporary data directory)

None
include_wells list[str] | None

Include only a subset of wells, by default None (include all wells)

None
exclude_fovs list[str] | None

Exclude FOVs, by default None (do not exclude any FOVs)

None

cache_dir property

Return the cache directory path for memory-mapped tensors.

preprocessing_transforms property

Return preprocessing transforms.

train_cpu_transforms property

Return training CPU transforms.

train_gpu_transforms property

Return training GPU transforms.

val_cpu_transforms property

Return validation CPU transforms.

val_gpu_transforms property

Return validation GPU transforms.

setup(stage)

Set up datasets for fit or validate stage.

MmappedDataset

Bases: Dataset

Dataset backed by memory-mapped tensors for efficient caching.

Parameters:

Name Type Description Default
positions list[Position]

List of FOVs to load images from.

required
channel_names list[str]

Channel names to load.

required
cache_map DictProxy

Shared dictionary for caching loaded volumes.

required
buffer MemoryMappedTensor

Memory-mapped tensor buffer for cached volumes.

required
preprocess_transforms Compose | None

Preprocessing transforms, by default None.

None
cpu_transform Compose | None

CPU transforms, by default None.

None
array_key str

The image array key name (multi-scale level), by default "0".

'0'
load_normalization_metadata bool

Load normalization metadata in the sample dictionary, by default True.

True

__getitem__(idx)

Return a sample for the given index, using mmap cache.

__len__()

Return total number of cached samples.
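
The caching pattern described above (load a volume once, write it into a memory-mapped buffer, and serve later reads from that buffer) can be sketched with the standard library alone. This is a minimal illustration and not the library's implementation: it swaps MemoryMappedTensor for an anonymous mmap, uses a plain dict in place of the shared DictProxy, and the class name MmapCacheSketch and its loaders argument are hypothetical.

```python
import mmap
import struct


class MmapCacheSketch:
    """Cache fixed-length float32 records in an anonymous memory map."""

    def __init__(self, loaders, record_len):
        self.loaders = loaders  # hypothetical: one zero-argument callable per sample
        self.fmt = f"{record_len}f"  # record_len float32 values per sample
        self.record_bytes = struct.calcsize(self.fmt)
        # Stand-in for the memory-mapped tensor buffer.
        self.buffer = mmap.mmap(-1, self.record_bytes * len(loaders))
        # Stand-in for the shared cache map: index -> True once written.
        self.cache_map = {}

    def __len__(self):
        return len(self.loaders)

    def __getitem__(self, idx):
        offset = idx * self.record_bytes
        if idx not in self.cache_map:
            # First access: do the expensive load and write it to the buffer.
            values = self.loaders[idx]()
            struct.pack_into(self.fmt, self.buffer, offset, *values)
            self.cache_map[idx] = True
        # Every access reads back from the memory-mapped buffer.
        return list(struct.unpack_from(self.fmt, self.buffer, offset))
```

In this sketch the expensive loader runs at most once per index; repeated reads of the same index only touch the buffer.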

Sample

Bases: TypedDict

Image sample type for mini-batches.

All fields are optional.

SampleMeta

Bases: TypedDict

Biological metadata carried in train-mode batches for sampler debugging.

Joinable against valid_anchors on (global_track_id, t).

Core fields are defined here. Domain-specific fields should be added by subclassing SampleMeta (e.g. OpsSampleMeta). The labels field is an open-ended dict of integer labels that auxiliary heads can consume via batch_key without requiring a subclass.
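
As a rough illustration of the subclassing pattern and the (global_track_id, t) join described above, the sketch below defines stand-in TypedDicts. The field types, the gene_name field, the class names ending in Sketch, and the join_on_track_time helper are all assumptions made for this example; valid_anchors is modeled as a list of dict rows rather than the real table.

```python
from typing import TypedDict


class SampleMetaSketch(TypedDict, total=False):
    """Stand-in for SampleMeta; field types are assumptions."""

    global_track_id: str
    t: int
    labels: dict[str, int]  # open-ended integer labels


class OpsSampleMetaSketch(SampleMetaSketch, total=False):
    """Hypothetical domain-specific subclass adding one extra field."""

    gene_name: str


def join_on_track_time(metas, valid_anchors):
    """Pair each meta with the anchor row sharing (global_track_id, t)."""
    index = {(row["global_track_id"], row["t"]): row for row in valid_anchors}
    return [
        (meta, index.get((meta["global_track_id"], meta["t"])))
        for meta in metas
    ]
```

Metadata with no matching anchor row is paired with None, which makes sampler mismatches easy to spot when debugging.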

SegmentationDataModule

Bases: LightningDataModule

Lightning data module for evaluating segmentation predictions.

setup(stage)

Set up the test dataset.

test_dataloader()

Return test data loader.

SegmentationDataset

Bases: Dataset

Dataset for evaluating segmentation predictions against targets.

__getitem__(idx)

Return prediction and target tensors for a given index.

__len__()

Return number of test samples.

SegmentationSample

Bases: TypedDict

Segmentation sample type for mini-batches.

SelectWell

Mixin class for filtering wells and FOVs from HCS plates.

ShardedDistributedSampler

Bases: DistributedSampler

Distributed sampler with sharded random permutations for DDP training.

__iter__()

Shard data across distributed ranks.
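
One plausible reading of "sharded random permutations" is that each rank owns a fixed contiguous shard of the index range and reshuffles only within that shard each epoch. The sketch below illustrates that reading in plain Python; it is an assumption about the behavior and not the sampler's actual code, and the function name shard_for_rank is hypothetical.

```python
import random


def shard_for_rank(num_samples, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Return this rank's indices: a contiguous shard, shuffled locally."""
    shard_size = -(-num_samples // num_replicas)  # ceiling division
    start = rank * shard_size
    # The last shard may be shorter; a real DDP sampler would need to pad
    # or drop samples so that every rank runs the same number of steps.
    shard = list(range(start, min(start + shard_size, num_samples)))
    if shuffle:
        # Seeding with seed + epoch gives a new order each epoch while
        # keeping the membership of each shard unchanged.
        random.Random(seed + epoch).shuffle(shard)
    return shard
```

Keeping shard membership stable across epochs means each rank revisits the same samples, which suits per-rank caches.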

SlidingWindowDataset

Bases: Dataset

Sliding window dataset over HCS NGFF positions.

Each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is z_window_size.

Parameters:

Name Type Description Default
positions list[Position]

FOVs to include in dataset.

required
channels ChannelMap

Source and target channel names, e.g. {'source': 'Phase', 'target': ['Nuclei', 'Membrane']}.

required
z_window_size int

Z window size of the 2.5D U-Net, 1 for 2D.

required
array_key str

Name of the image arrays (multiscales level), by default "0".

'0'
transform DictTransform | None

A callable that transforms data, defaults to None.

None
load_normalization_metadata bool

Whether to load normalization metadata, defaults to True.

True
min_nonzero_fraction float

Minimum fraction of voxels above nonzero_threshold for a patch to be used. Patches below this fraction are retried up to max_nonzero_retries times. Default 0.0 disables filtering.

0.0
nonzero_threshold float

Intensity threshold for the nonzero fraction check. Default 0.0 means any nonzero voxel counts.

0.0
nonzero_channel str or None

Channel name to check for nonzero fraction. None defaults to the first target channel.

None
max_nonzero_retries int

Maximum number of random re-samples when a patch fails the nonzero fraction check. Default 100.

100
fg_mask_key str or None

Zarr array key for precomputed foreground masks. When set, masks are loaded alongside images and included in the sample as "fg_mask". Default None disables mask loading.

None
preloaded_fovs list[Tensor] or None

Pre-loaded FOV data, one tensor per position with shape (T, C, Z, Y, X). Channels are source + target in order. When set, bypasses zarr reads via .clone() copies from the preloaded data. Default None reads from zarr.

None
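
The interaction of min_nonzero_fraction, nonzero_threshold, and max_nonzero_retries can be sketched as a simple rejection loop. This is an illustration under stated assumptions: patches are nested lists rather than tensors, the sample_patch callable is hypothetical, and returning the last draw once retries run out is a guess, since the fallback behavior is not documented here.

```python
def nonzero_fraction(patch, threshold=0.0):
    """Fraction of values in a 2D nested list strictly above threshold."""
    flat = [value for row in patch for value in row]
    return sum(value > threshold for value in flat) / len(flat)


def sample_with_retries(
    sample_patch,
    min_nonzero_fraction=0.0,
    nonzero_threshold=0.0,
    max_nonzero_retries=100,
):
    """Draw patches until one passes the nonzero check or retries run out."""
    patch = sample_patch()
    if min_nonzero_fraction <= 0.0:
        return patch  # a fraction of 0.0 disables filtering
    for _ in range(max_nonzero_retries):
        if nonzero_fraction(patch, nonzero_threshold) >= min_nonzero_fraction:
            return patch
        patch = sample_patch()  # re-sample and check again
    # Assumption: fall back to the last draw when retries are exhausted.
    return patch
```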

__getitem__(index)

Return a sample for the given index.

__len__()

Return total number of windows.
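
To make the window count concrete, the sketch below assumes each FOV with T time points and Z slices contributes T × (Z − z_window_size + 1) windows, and maps a flat dataset index back to a FOV, a time point, and a Z slice. The counting rule and the function names are assumptions for illustration, not the dataset's confirmed internals.

```python
from bisect import bisect_right
from itertools import accumulate


def window_counts(fov_shapes, z_window_size):
    """Windows per FOV, where fov_shapes holds (T, Z) pairs."""
    counts = []
    for num_t, num_z in fov_shapes:
        z_starts = num_z - z_window_size + 1
        if z_starts < 1:
            raise ValueError("z_window_size exceeds the number of Z slices")
        counts.append(num_t * z_starts)
    return counts


def locate_window(index, fov_shapes, z_window_size):
    """Map a flat index to (fov_index, t, z_slice)."""
    ends = list(accumulate(window_counts(fov_shapes, z_window_size)))
    if not 0 <= index < ends[-1]:
        raise IndexError(index)
    fov = bisect_right(ends, index)  # first FOV whose cumulative end > index
    local = index - (ends[fov - 1] if fov else 0)
    z_starts = fov_shapes[fov][1] - z_window_size + 1
    t, z0 = divmod(local, z_starts)
    return fov, t, slice(z0, z0 + z_window_size)
```

Under these assumptions, two FOVs with (T, Z) shapes (2, 5) and (3, 7) and a z_window_size of 5 yield 2 + 9 = 11 windows in total.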

TripletDataModule

Bases: HCSDataModule

Lightning data module for triplet sampling of patches.

__init__(data_path, tracks_path, source_channel, z_range, initial_yx_patch_size=(512, 512), final_yx_patch_size=(224, 224), split_ratio=0.8, batch_size=16, num_workers=1, normalizations=None, augmentations=None, augment_validation=True, fit_include_wells=None, fit_exclude_fovs=None, predict_cells=False, include_fov_names=None, include_track_ids=None, time_interval='any', return_negative=True, persistent_workers=False, prefetch_factor=None, pin_memory=False, z_window_size=None, cache_pool_bytes=0)

Parameters:

Name Type Description Default
data_path str

Image dataset path

required
tracks_path str

Tracks labels dataset path

required
source_channel str | Sequence[str]

List of input channel names

required
z_range tuple[int, int]

Range of valid z-slices

required
initial_yx_patch_size tuple[int, int]

YX size of the initially sampled image patch, by default (512, 512)

(512, 512)
final_yx_patch_size tuple[int, int]

Output patch size, by default (224, 224)

(224, 224)
split_ratio float

Ratio of training samples, by default 0.8

0.8
batch_size int

Batch size, by default 16

16
num_workers int

Number of thread workers. Set to 0 to disable threading. Using more than 1 is not recommended. By default 1

1
normalizations list[MapTransform] or None

Normalization transforms, by default None

None
augmentations list[MapTransform] or None

Augmentation transforms, by default None

None
augment_validation bool

Apply augmentations to validation data, by default True. Set to False for VAE training where clean validation is needed.

True
fit_include_wells list[str] | None

Only include these wells for fitting, by default None

None
fit_exclude_fovs list[str] | None

Exclude these FOVs for fitting, by default None

None
predict_cells bool

Only predict for selected cells, by default False

False
include_fov_names list[str] | None

Only predict for selected FOVs, by default None

None
include_track_ids list[int] | None

Only predict for selected tracks, by default None

None
time_interval Literal['any'] | int

Future time interval separating the anchor and the positive sample. "any" means sampling the negative from another track at any time point and using the augmented anchor patch as the positive. By default "any"

'any'
return_negative bool

Whether to return the negative sample during the fit stage (can be set to False when using a loss function like NT-Xent), by default True

True
persistent_workers bool

Whether to keep worker processes alive between iterations, by default False

False
prefetch_factor int | None

Number of batches loaded in advance by each worker, by default None

None
pin_memory bool

Whether to pin memory in CPU for faster GPU transfer, by default False

False
z_window_size int | None

Size of the final Z window, by default None (inferred from z_range)

None
cache_pool_bytes int

Size of the tensorstore cache pool in bytes, attached to the plate at open_ome_zarr time via iohub.core.config.TensorStoreConfig, by default 0.

0

on_after_batch_transfer(batch, dataloader_idx)

Apply transforms after transferring to device.

predict_dataloader()

Return predict data loader.

train_dataloader()

Return training data loader.

val_dataloader()

Return validation data loader.

TripletDataset

Bases: Dataset

Dataset for triplet sampling of cells based on tracking.

__getitems__(indices)

Return a batch of triplet samples for the given indices.

__init__(positions, tracks_tables, channel_names, initial_yx_patch_size, z_range, fit=True, predict_cells=False, include_fov_names=None, include_track_ids=None, time_interval='any', return_negative=True)

Parameters:

Name Type Description Default
positions list[Position]

OME-Zarr images with consistent channel order

required
tracks_tables list[DataFrame]

Data frames containing ultrack results

required
channel_names list[str]

Input channel names

required
initial_yx_patch_size tuple[int, int]

YX size of the initially sampled image patch

required
z_range slice

Range of Z-slices

required
fit bool

Fitting mode, in which the full triplet is sampled. If False, only the anchor is sampled. By default True

True
predict_cells bool

Only predict on selected cells, by default False

False
include_fov_names list[str] | None

Only predict on selected FOVs, by default None

None
include_track_ids list[int] | None

Only predict on selected track IDs, by default None

None
time_interval Literal['any'] | int

Future time interval separating the anchor and the positive sample, by default "any" (sample the negative from another track at any time point and use the augmented anchor patch as the positive)

'any'
return_negative bool

Whether to return the negative sample during the fit stage (can be set to False when using a loss function like NT-Xent), by default True

True

__len__()

Return the number of valid anchor samples.
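
The relationship between time_interval, valid anchors, and the sampled triplet can be sketched on a toy track table. Everything here is an assumption made for illustration: tracks are a dict mapping a track ID to its time points, an anchor counts as valid when its track also has a time point at t + time_interval, and the negative is drawn from any other track at any time point. The real dataset works on tracking data frames and image patches.

```python
import random


def valid_anchors(tracks, time_interval):
    """List (track_id, t) pairs that can serve as anchors."""
    anchors = []
    for track_id, times in tracks.items():
        present = set(times)
        for t in times:
            # With an integer interval, the positive must exist in the
            # same track at t + time_interval.
            if time_interval == "any" or t + time_interval in present:
                anchors.append((track_id, t))
    return anchors


def sample_triplet(anchor, tracks, time_interval, rng):
    """Return (anchor, positive, negative) as (track_id, t) pairs."""
    track_id, t = anchor
    if time_interval == "any":
        positive = (track_id, t)  # same patch, to be augmented differently
    else:
        positive = (track_id, t + time_interval)
    other = rng.choice([key for key in tracks if key != track_id])
    negative = (other, rng.choice(tracks[other]))
    return anchor, positive, negative
```

In this sketch the dataset length would be len(valid_anchors(...)), so a larger time_interval shrinks the dataset because fewer anchors have a matching future time point.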

TripletSample

Bases: TypedDict

Triplet sample type for mini-batches.

read_cell_index(path)

Read a cell index parquet into a pandas DataFrame.

String columns are materialized as NumPy object arrays instead of ArrowStringArray. ArrowStringArray-backed columns route every boolean mask slice through pyarrow.compute.take, which allocates a fresh buffer per string column and can spike peak RSS by 50+ GiB on 80M-row indices during train/val FOV partitioning. NumPy object columns make df[mask] a cheap gather.

Parameters:

Name Type Description Default
path str | Path

Path to parquet file.

required

Returns:

Type Description
DataFrame

Cell index with correct dtypes.

validate_cell_index(df, *, strict=False)

Validate a cell index DataFrame against the canonical schema.

Parameters:

Name Type Description Default
df DataFrame

Cell index to validate.

required
strict bool

If True, require all schema columns (not just core + grouping).

False

Returns:

Type Description
list[str]

Warnings (e.g. nullable columns that are entirely null).

Raises:

Type Description
ValueError

If required columns are missing or (cell_id, channel_name) is not unique.
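
The checks described above (required columns present, (cell_id, channel_name) unique, warnings for nullable columns that are entirely null) can be sketched without pandas by treating the cell index as a list of dict rows. The function name validate_rows and its required and nullable arguments are hypothetical; the canonical schema itself is not reproduced here.

```python
def validate_rows(rows, required, nullable=()):
    """Return warnings; raise ValueError on schema or uniqueness errors."""
    columns = set()
    for row in rows:
        columns.update(row)
    # The uniqueness key columns are always required in this sketch.
    needed = list(dict.fromkeys([*required, "cell_id", "channel_name"]))
    missing = [name for name in needed if name not in columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    keys = [(row.get("cell_id"), row.get("channel_name")) for row in rows]
    if len(keys) != len(set(keys)):
        raise ValueError("(cell_id, channel_name) is not unique")
    return [
        f"Column {name!r} is entirely null"
        for name in nullable
        if name in columns and all(row.get(name) is None for row in rows)
    ]
```

Returning warnings as a list while raising on hard errors mirrors the split between the Returns and Raises tables above.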

write_cell_index(df, path, *, validate=True)

Write a cell index DataFrame to parquet with the canonical schema.

Missing nullable columns are added as None before writing.

Parameters:

Name Type Description Default
df DataFrame

Cell index to write.

required
path str | Path

Output parquet path.

required
validate bool

Run validate_cell_index before writing.

True