viscy-data
Data loading and Lightning
DataModules for AI × imaging tasks.
What's here
Lightning DataModules for OME-Zarr microscopy — HCS plates and wells,
triplet sampling for contrastive learning, and memory-mapped caching for
terabyte-scale datasets.
Optional extras
viscy-data[triplet] adds tensorstore-backed triplet sampling ·
viscy-data[livecell] adds LiveCell dataset support ·
viscy-data[mmap] adds memory-mapped caching · viscy-data[all] enables all.
API reference
BatchedConcatDataModule
Bases: ConcatDataModule
Concatenated data module with batched micro-batch GPU transforms.
Under DDP, attaches ShardedDistributedSampler so each rank
iterates a disjoint shard while preserving the existing
micro-batch-to-single-batch contract.
on_after_batch_transfer(batch, dataloader_idx)
Apply GPU transforms from constituent data modules to micro-batches.
setup(stage)
Mark each child as a BatchedConcat child before parent setup.
train_dataloader here uses batch_size as-is (loads N
indices, each yielding num_samples patches via the child's
RandWeightedCropd), so the divisibility constraint enforced
by HCSDataModule._train_transform for standalone use does
not apply. Setting the flag on each child before calling
super().setup (which iterates children's setup) lets
the check skip itself.
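The batch-size arithmetic described above can be sketched as follows. This is a minimal illustration rather than the library's code: both helper names are hypothetical, and the assumption that a standalone HCSDataModule divides batch_size by num_samples is inferred from the divisibility constraint mentioned above.

```python
def standalone_loader_batch_size(batch_size: int, num_samples: int) -> int:
    """Indices per loader batch when batch_size counts patches.

    Assumed standalone behaviour: each index yields num_samples patches,
    so batch_size must be divisible by num_samples.
    """
    if batch_size % num_samples != 0:
        raise ValueError(
            f"batch_size ({batch_size}) must be divisible by "
            f"num_samples ({num_samples})"
        )
    return batch_size // num_samples


def batched_concat_patch_count(batch_size: int, num_samples: int) -> int:
    """Patches per step when batch_size is used as-is (N indices loaded)."""
    return batch_size * num_samples
```

Under these assumptions, batch_size=16 with num_samples=4 loads 4 indices per step in standalone use, while the batched-concat path loads 16 indices and produces 64 patches.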
train_dataloader()
Return batched concatenated training data loader with optional DDP sampling.
val_dataloader()
Return batched concatenated validation data loader with optional DDP sampling.
BatchedConcatDataset
CTMCv1DataModule
Bases: GPUTransformDataModule
Autoregression data module for the CTMCv1 dataset.
Training and validation datasets are stored in separate HCS OME-Zarr stores.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_data_path | str or Path | Path to the training dataset. | required |
| val_data_path | str or Path | Path to the validation dataset. | required |
| train_cpu_transforms | list of MapTransform | List of CPU transforms for training. | required |
| val_cpu_transforms | list of MapTransform | List of CPU transforms for validation. | required |
| train_gpu_transforms | list of MapTransform | List of GPU transforms for training. | required |
| val_gpu_transforms | list of MapTransform | List of GPU transforms for validation. | required |
| batch_size | int | Batch size, by default 16. | 16 |
| num_workers | int | Number of dataloading workers, by default 8. | 8 |
| val_subsample_ratio | int | Skip every N frames for validation to reduce redundancy in video, by default 30. | 30 |
| channel_name | str | Name of the DIC channel, by default "DIC". | 'DIC' |
| pin_memory | bool | Pin memory for dataloaders, by default True. | True |
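As a rough illustration of val_subsample_ratio, the sketch below keeps every N-th frame of a video. The helper name is hypothetical, and reading "skip every N frames" as a stride of N is an assumption; the library may count frames differently.

```python
def subsample_frames(frame_indices: list[int], val_subsample_ratio: int = 30) -> list[int]:
    """Keep every val_subsample_ratio-th frame (assumed interpretation)."""
    return frame_indices[::val_subsample_ratio]


# A 90-frame video reduces to three validation frames at the default ratio.
kept = subsample_frames(list(range(90)))
```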
train_cpu_transforms
property
Return training CPU transforms.
train_gpu_transforms
property
Return training GPU transforms.
val_cpu_transforms
property
Return validation CPU transforms.
val_gpu_transforms
property
Return validation GPU transforms.
setup(stage)
Set up datasets for the given stage.
CachedConcatDataModule
Bases: LightningDataModule
Concatenated data module with distributed sampling support.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_modules | Sequence[LightningDataModule] | Data modules to concatenate. | required |
Notes
Trainer propagation to children happens in both prepare_data and
setup for the same reason as ConcatDataModule — see that
class's docstring.
prepare_data()
Prepare data for all constituent data modules.
setup(stage)
Set up constituent data modules and create concatenated datasets.
train_dataloader()
Return concatenated training data loader with optional DDP sampling.
val_dataloader()
Return concatenated validation data loader with optional DDP sampling.
CachedOmeZarrDataModule
Bases: GPUTransformDataModule, SelectWell
Data module for cached OME-Zarr arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | Path | Path to the HCS OME-Zarr dataset. | required |
| channels | str \| list[str] | Channel names to load. | required |
| batch_size | int | Batch size for training and validation. | required |
| num_workers | int | Number of workers for data-loaders. | required |
| split_ratio | float | Fraction of the FOVs used for the training split. The rest will be used for validation. | required |
| train_cpu_transforms | list[DictTransform] | Transforms to be applied on the CPU during training. | required |
| val_cpu_transforms | list[DictTransform] | Transforms to be applied on the CPU during validation. | required |
| train_gpu_transforms | list[DictTransform] | Transforms to be applied on the GPU during training. | required |
| val_gpu_transforms | list[DictTransform] | Transforms to be applied on the GPU during validation. | required |
| pin_memory | bool | Use page-locked memory in data-loaders, by default True. | True |
| skip_cache | bool | Skip caching for this dataset, by default False. | False |
| include_wells | list[str] | List of well names to include in the dataset, by default None (all). | None |
| exclude_fovs | list[str] | List of FOV names to exclude from the dataset, by default None (none). | None |
| prefetch_factor | int \| None | Number of batches loaded in advance by each worker. | None |
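The split_ratio parameter can be pictured with the sketch below. It is a minimal illustration under stated assumptions: the helper name is hypothetical, FOVs are split in listing order, and the training count is rounded down. The library may shuffle or round differently.

```python
def split_fovs(fov_names: list[str], split_ratio: float) -> tuple[list[str], list[str]]:
    """Split FOV names into (train, val) by fraction, without shuffling."""
    n_train = int(len(fov_names) * split_ratio)
    return fov_names[:n_train], fov_names[n_train:]


# Ten FOVs with split_ratio=0.8 give eight for training and two for validation.
fovs = [f"A/1/{i}" for i in range(10)]
train_fovs, val_fovs = split_fovs(fovs, 0.8)
```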
train_cpu_transforms
property
Return training CPU transforms.
train_gpu_transforms
property
Return training GPU transforms.
val_cpu_transforms
property
Return validation CPU transforms.
val_gpu_transforms
property
Return validation GPU transforms.
setup(stage)
Set up datasets for fit or validate stage.
CachedOmeZarrDataset
Bases: Dataset
Dataset for cached OME-Zarr arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | list[Position] | List of FOVs to load images from. | required |
| channel_names | list[str] | List of channel names to load. | required |
| cache_map | DictProxy | Shared dictionary for caching loaded volumes. | required |
| transform | Compose \| None | Composed transforms to be applied on the CPU, by default None. | None |
| array_key | str | The image array key name (multi-scale level), by default "0". | '0' |
| load_normalization_metadata | bool | Load normalization metadata in the sample dictionary, by default True. | True |
| skip_cache | bool | Skip caching to save RAM, by default False. | False |
CellDivisionTripletDataModule
Bases: HCSDataModule
Lightning data module for cell division triplet sampling.
__init__(data_path, source_channel, final_yx_patch_size=(64, 64), split_ratio=0.8, batch_size=16, num_workers=8, normalizations=None, augmentations=None, augment_validation=True, time_interval='any', return_negative=True, output_2d=False, persistent_workers=False, prefetch_factor=None, pin_memory=False)
Lightning data module for cell division triplet sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | str | Path to directory containing npy files | required |
| source_channel | str \| Sequence[str] | List of input channel names | required |
| final_yx_patch_size | tuple[int, int] | Output patch size, by default (64, 64) | (64, 64) |
| split_ratio | float | Ratio of training samples, by default 0.8 | 0.8 |
| batch_size | int | Batch size, by default 16 | 16 |
| num_workers | int | Number of data-loading workers, by default 8 | 8 |
| normalizations | list[MapTransform] or None | Normalization transforms, by default None | None |
| augmentations | list[MapTransform] or None | Augmentation transforms, by default None | None |
| augment_validation | bool | Apply augmentations to validation data, by default True | True |
| time_interval | Literal['any'] \| int | Future time interval to sample positive and anchor from, by default "any" | 'any' |
| return_negative | bool | Whether to return the negative sample during the fit stage, by default True | True |
| output_2d | bool | Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False | False |
| persistent_workers | bool | Whether to keep worker processes alive between iterations, by default False | False |
| prefetch_factor | int \| None | Number of batches loaded in advance by each worker, by default None | None |
| pin_memory | bool | Whether to pin memory in CPU for faster GPU transfer, by default False | False |
CellDivisionTripletDataset
Bases: Dataset
Dataset for triplet sampling of cell division data from npy files.
For the dataset from the paper: https://arxiv.org/html/2502.02182v1
__getitem__(index)
Return a triplet sample for the given index.
__init__(data_paths, channel_names, anchor_transform=None, positive_transform=None, negative_transform=None, fit=True, time_interval='any', return_negative=True, output_2d=False)
Dataset for triplet sampling of cell division data from npy files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_paths | list[Path] | List of paths to npy files containing cell division tracks (T,C,Y,X format) | required |
| channel_names | list[str] | Input channel names | required |
| anchor_transform | DictTransform \| None | Transforms applied to the anchor sample, by default None | None |
| positive_transform | DictTransform \| None | Transforms applied to the positive sample, by default None | None |
| negative_transform | DictTransform \| None | Transforms applied to the negative sample, by default None | None |
| fit | bool | Fitting mode in which the full triplet will be sampled, only sample anchor if False, by default True | True |
| time_interval | Literal['any'] \| int | Future time interval to sample positive and anchor from, by default "any" | 'any' |
| return_negative | bool | Whether to return the negative sample during the fit stage, by default True | True |
| output_2d | bool | Whether to return 2D tensors (C,Y,X) instead of 3D (C,1,Y,X), by default False | False |
__len__()
Return the number of valid anchor samples.
CellIndex
Bases: TypedDict
Ultrack tracking index carried in predict-mode batches.
All fields optional — presence depends on the source CSV columns. (fov_name, track_id, t) together uniquely identify a cell observation and are the join key back to valid_anchors.
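A minimal sketch of using (fov_name, track_id, t) as a join key. Plain dictionaries stand in for both the predict-mode CellIndex entries and the rows of valid_anchors (a DataFrame in the library), the helper names are hypothetical, and all three key fields are assumed to be present.

```python
def cell_key(record: dict) -> tuple:
    """Build the (fov_name, track_id, t) key identifying one cell observation."""
    return (record["fov_name"], record["track_id"], record["t"])


def join_to_anchors(cells: list[dict], anchors: list[dict]) -> list[dict]:
    """Return the anchor row matching each cell, skipping cells with no match."""
    anchors_by_key = {cell_key(row): row for row in anchors}
    return [anchors_by_key[cell_key(c)] for c in cells if cell_key(c) in anchors_by_key]


anchors = [
    {"fov_name": "A/1/0", "track_id": 7, "t": 3, "perturbation": "mock"},
    {"fov_name": "A/1/0", "track_id": 7, "t": 4, "perturbation": "mock"},
]
cells = [{"fov_name": "A/1/0", "track_id": 7, "t": 4}]
matched = join_to_anchors(cells, anchors)
```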
ChannelDropout
Bases: Module
Randomly zero out entire channels during training.
Designed for (B, C, Z, Y, X) tensors in the GPU augmentation pipeline. Applied after the scatter/gather augmentation chain in on_after_batch_transfer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| channels | list[int] | Channel indices to potentially drop. | required |
| p | float | Probability of dropping each specified channel per sample. Default: 0.5. | 0.5 |
forward(x)
Drop selected channels per-sample with probability self.p.
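A minimal sketch of the per-sample drop decision, assuming each listed channel is dropped independently with probability p. It is pure Python and only builds a (B, C) keep mask; the module itself operates on (B, C, Z, Y, X) tensors, where such a mask would be broadcast over the spatial axes. The function name is hypothetical.

```python
import random


def channel_keep_mask(batch_size, num_channels, channels, p=0.5, rng=None):
    """Return a batch_size x num_channels mask: 0.0 where a channel is dropped."""
    if rng is None:
        rng = random.Random()
    mask = [[1.0] * num_channels for _ in range(batch_size)]
    for sample in range(batch_size):
        for channel in channels:
            # Independent draw per (sample, channel) pair.
            if rng.random() < p:
                mask[sample][channel] = 0.0
    return mask


# Only channel 1 is eligible; channels 0 and 2 are always kept.
mask = channel_keep_mask(batch_size=4, num_channels=3, channels=[1], p=0.5)
```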
ChannelMap
Bases: TypedDict
Source channel names.
ChannelNormStats
Bases: TypedDict
Per-channel normalization statistics.
ClassificationDataModule
Bases: LightningDataModule
Lightning data module for cell classification tasks.
__init__(image_path, annotation_path, val_fovs, channel_name, z_range, train_exclude_timepoints, train_transforms, val_transforms, initial_yx_patch_size, batch_size, num_workers, label_column='infection_state')
Lightning data module for cell classification tasks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| image_path | Path | Path to the OME-Zarr image store | required |
| annotation_path | Path | Path to the annotation CSV file | required |
| val_fovs | list[str] | FOV names for validation | required |
| channel_name | str | Input channel name | required |
| z_range | tuple[int, int] | Range of Z-slices | required |
| train_exclude_timepoints | list[int] | Timepoints to exclude from training | required |
| train_transforms | list[Callable] \| None | Training transforms | required |
| val_transforms | list[Callable] \| None | Validation transforms | required |
| initial_yx_patch_size | tuple[int, int] | YX size of the initially sampled image patch | required |
| batch_size | int | Batch size | required |
| num_workers | int | Number of data-loading workers | required |
| label_column | str | Column name for the label, by default "infection_state" | 'infection_state' |
predict_dataloader()
Return predict data loader.
setup(stage=None)
Set up datasets for the given stage.
train_dataloader()
Return training data loader.
val_dataloader()
Return validation data loader.
ClassificationDataset
Bases: Dataset
Dataset for cell classification from annotated image data.
__getitem__(idx)
Return a sample for the given index.
__init__(plate, annotation, channel_name, z_range, transform, initial_yx_patch_size, return_indices=False, label_column='infection_state')
Dataset for cell classification from annotated image data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| plate | Plate | OME-Zarr plate store | required |
| annotation | DataFrame | Annotation dataframe with cell locations and labels | required |
| channel_name | str | Input channel name | required |
| z_range | tuple[int, int] | Range of Z-slices | required |
| transform | Callable \| None | Transform to apply to image patches | required |
| initial_yx_patch_size | tuple[int, int] | YX size of the initially sampled image patch | required |
| return_indices | bool | Whether to return index information, by default False | False |
| label_column | AnnotationColumns | Column name for the label, by default "infection_state" | 'infection_state' |
__len__()
Return the number of annotated samples.
CombineMode
Bases: Enum
Mode for combining multiple data modules.
CombinedDataModule
Bases: LightningDataModule
Wrapper for combining multiple data modules.
For supported modes, see lightning.pytorch.utilities.combined_loader.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_modules | Sequence[LightningDataModule] | data modules to combine | required |
| train_mode | CombineMode | mode in training stage, by default CombineMode.MAX_SIZE_CYCLE | MAX_SIZE_CYCLE |
| val_mode | CombineMode | mode in validation stage, by default CombineMode.SEQUENTIAL | SEQUENTIAL |
| test_mode | CombineMode | mode in testing stage, by default CombineMode.SEQUENTIAL | SEQUENTIAL |
| predict_mode | CombineMode | mode in prediction stage, by default CombineMode.SEQUENTIAL | SEQUENTIAL |
on_after_batch_transfer(batch, dataloader_idx)
Dispatch GPU transforms to child data modules.
CombinedLoader yields different batch formats:
- Training (max_size_cycle): list of sub-batches, one per child data module.
- Validation (sequential): single batch from one child at a time, identified by dataloader_idx.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | list \| dict \| Tensor | Batch from | required |
| dataloader_idx | int | Index of the active dataloader (meaningful in sequential mode). | required |
Returns:
| Type | Description |
|---|---|
| list \| dict \| Tensor | Transformed batch(es). |
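The two batch formats described for on_after_batch_transfer suggest a dispatch along the following lines. This is a sketch under assumptions, not the library's implementation: plain callables stand in for each child's GPU-transform hook, and the function name is hypothetical.

```python
from typing import Any, Callable, Sequence


def dispatch_gpu_transforms(
    batch: Any,
    dataloader_idx: int,
    child_hooks: Sequence[Callable[[Any], Any]],
) -> Any:
    """Route a combined batch to the matching child hook(s)."""
    if isinstance(batch, list):
        # max_size_cycle: one sub-batch per child, in child order.
        return [hook(sub) for hook, sub in zip(child_hooks, batch)]
    # sequential: a single batch from the child at dataloader_idx.
    return child_hooks[dataloader_idx](batch)


hooks = [lambda b: b + 1, lambda b: b * 10]
train_out = dispatch_gpu_transforms([1, 2], 0, hooks)
val_out = dispatch_gpu_transforms(3, 1, hooks)
```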
predict_dataloader()
Return combined predict data loader.
prepare_data()
Prepare data for all constituent data modules.
setup(stage)
Set up all constituent data modules.
test_dataloader()
Return combined test data loader.
train_dataloader()
Return combined training data loader.
val_dataloader()
Return combined validation data loader.
ConcatDataModule
Bases: LightningDataModule
Concatenate multiple data modules.
The concatenated data module will have the same batch size and number of workers as the first data module. Each element will be sampled uniformly regardless of their original data module.
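Sampling uniformly over concatenated elements amounts to drawing a global index and mapping it back to a (module, local index) pair. The sketch below shows one common way to do that mapping with cumulative lengths; the helper name is hypothetical and the library's internals may differ.

```python
from bisect import bisect_right
from itertools import accumulate


def locate(global_idx: int, lengths: list[int]) -> tuple[int, int]:
    """Map an index into the concatenation to (dataset index, local index)."""
    cumulative = list(accumulate(lengths))
    if not cumulative or not 0 <= global_idx < cumulative[-1]:
        raise IndexError(global_idx)
    dataset_idx = bisect_right(cumulative, global_idx)
    start = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, global_idx - start


# Two datasets with 3 and 5 elements: global index 3 is the first element of the second.
first_of_second = locate(3, [3, 5])
```

Because every global index is equally likely under a uniform draw, a larger constituent dataset contributes proportionally more samples.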
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_modules | Sequence[LightningDataModule] | Data modules to concatenate. | required |
Notes
Trainer propagation to children happens in both prepare_data and
setup because prepare_data_per_node=True causes Lightning to
invoke prepare_data only on rank 0 of each node. Without the
setup propagation, non-rank-0 children keep self.trainer = None
and silently skip trainer-gated paths such as
HCSDataModule.on_after_batch_transfer's
if self.trainer and self.trainer.training guard, producing
rank-asymmetric failures where non-rank-0 ranks receive un-cropped
batches because gpu_augmentations did not run.
FlexibleBatchSampler
Bases: Sampler[list[int]]
Composable batch sampler with batch grouping and stratification.
Each batch is constructed by a cascade:
- Group selection (batch_group_by): pick a single group to draw from, or draw from all samples.
- Leaky mixing (leaky): optionally inject a fraction of cross-group samples into group-restricted batches.
- Stratified sampling (stratify_by): within the selected pool, balance representation across groups defined by one or more DataFrame columns.
- Temporal enrichment (temporal_enrichment): concentrate batch indices around a randomly chosen focal HPI, with a configurable global fraction drawn from all timepoints.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| valid_anchors | DataFrame | DataFrame with a clean integer index (0..N-1). Must contain any columns referenced by | required |
| batch_size | int | Number of indices per batch. | 128 |
| batch_group_by | str \| list[str] \| None | Column(s) in valid_anchors that define batch-level groups. Each batch draws from a single group. | None |
| leaky | float | Fraction of the batch drawn from other groups when | 0.0 |
| group_weights | dict[str, float] \| None | Per-group sampling weight (keyed by group string key). Defaults to proportional to group size. | None |
| stratify_by | str \| list[str] \| None | Column name(s) in valid_anchors to stratify batches by. Groups are balanced equally within each batch. | 'perturbation' |
| temporal_enrichment | bool | If | False |
| temporal_window_hours | float | Half-width of the focal window around the chosen HPI. | 2.0 |
| temporal_global_fraction | float | Fraction of the batch drawn from all timepoints (global). | 0.3 |
| num_replicas | int | Number of DDP processes (1 for single-process). | 1 |
| rank | int | Rank of the current process (0 for single-process). | 0 |
| seed | int | Base RNG seed for deterministic sampling. | 0 |
| drop_last | bool | If | True |
__iter__()
Yield batch-sized lists of integer indices.
Builds batches lazily so the first batch is ready in milliseconds
instead of blocking on a full-epoch materialization. Every rank
still calls _build_one_batch on every index so the RNG draws
stay identical to the list-based implementation — only the
yield is rank-filtered, not the sampling. DDP correctness is
therefore bit-identical to the previous implementation; the only
change is that the main thread sees batch 0 after one
_build_one_batch call instead of total_batches calls.
limit_train_batches interacts with this: Lightning stops
pulling from the generator after its cap, so we never pay for
the unused suffix of the epoch.
The epoch counter auto-advances at the start of each iteration
so that the next __iter__ call reseeds the RNG with a fresh
seed + epoch and yields a different batch sequence. Advancing
at the start (not the end) is robust against early generator
termination from limit_train_batches: Lightning stops pulling
after its cap and garbage-collects the generator, which would
skip any end-of-iter bookkeeping.
PyTorch Lightning does not call set_epoch on custom
batch_sampler instances (use_distributed_sampler: false
with a batch sampler means Lightning's auto-wrap skips us), so
we self-advance. set_epoch still works if a caller wants
deterministic resume from a specific epoch — call it before the
iteration and the advance will take the resumed epoch as its
starting point.
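The rank-filtered lazy iteration and start-of-iteration epoch advance described above can be sketched with a toy sampler. This is an illustration under assumptions, not the FlexibleBatchSampler implementation: the class name is hypothetical, batches are plain uniform draws, and trailing batches that do not divide evenly across ranks are dropped.

```python
import random


class RankFilteredBatchSampler:
    """Toy sampler: every rank builds every batch but yields only its own."""

    def __init__(self, num_samples, batch_size, num_replicas=1, rank=0, seed=0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        return (self.num_samples // self.batch_size) // self.num_replicas

    def __iter__(self):
        # Seed from the current epoch, then advance immediately so that
        # early termination of the generator cannot skip the bookkeeping.
        rng = random.Random(self.seed + self.epoch)
        self.epoch += 1
        for batch_idx in range(len(self) * self.num_replicas):
            # Every rank draws every batch, keeping RNG state identical.
            batch = rng.sample(range(self.num_samples), self.batch_size)
            if batch_idx % self.num_replicas == self.rank:
                yield batch  # only the yield is rank-filtered
```

Because `__iter__` is a generator, the first batch is available after a single draw, and a consumer that stops early (as Lightning does under limit_train_batches) never pays for the rest of the epoch.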
__len__()
Return number of batches this rank will yield.
set_epoch(epoch)
Set epoch for deterministic shuffling across DDP ranks.
GPUTransformDataModule
Bases: ABC, LightningDataModule
Abstract data module with GPU transforms.
train_cpu_transforms
abstractmethod
property
Return training CPU transforms.
train_gpu_transforms
abstractmethod
property
Return training GPU transforms.
val_cpu_transforms
abstractmethod
property
Return validation CPU transforms.
val_gpu_transforms
abstractmethod
property
Return validation GPU transforms.
on_after_batch_transfer(batch, dataloader_idx)
Apply GPU transforms after batch transfer to device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict | Batch dict with channel-name keys mapped to | required |
| dataloader_idx | int | Dataloader index (unused). | required |
Returns:
| Type | Description |
|---|---|
| dict | Transformed batch (e.g., with |
train_dataloader()
Return training data loader.
val_dataloader()
Return validation data loader.
HCSDataModule
Bases: LightningDataModule
Lightning data module for a preprocessed HCS NGFF Store.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | str | Path to the data store. | required |
| source_channel | str or Sequence[str] | Name(s) of the source channel, e.g. 'Phase'. | required |
| target_channel | str or Sequence[str] | Name(s) of the target channel, e.g. ['Nuclei', 'Membrane']. | required |
| z_window_size | int | Z window size of the 2.5D U-Net, 1 for 2D. | required |
| split_ratio | float | Split ratio of the training subset in the fit stage, e.g. 0.8 means an 80/20 split between training/validation, by default 0.8. | 0.8 |
| batch_size | int | Batch size, defaults to 16. | 16 |
| num_workers | int | Number of data-loading workers, defaults to 8. | 8 |
| target_2d | bool | Whether the target is 2D (e.g. in a 2.5D model), defaults to False. | False |
| yx_patch_size | tuple[int, int] | Patch size in (Y, X), defaults to (256, 256). | (256, 256) |
| normalizations | list of MapTransform or None | MONAI dictionary transforms applied to selected channels, defaults to None (no normalization). | None |
| augmentations | list of MapTransform or None | MONAI dictionary transforms applied to the training set, defaults to None (no augmentation). | None |
| mmap_preload | bool | If | False |
| scratch_dir | Path or None | Directory for mmap cache files. Defaults to | None |
| ground_truth_masks | Path or None | Path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None. | None |
| persistent_workers | bool | Whether to keep the workers alive between fitting epochs, defaults to False. | False |
| prefetch_factor | int or None | Number of samples loaded in advance by each worker during fitting, defaults to None (2 per PyTorch default). | None |
| array_key | str | Name of the image arrays (multiscales level), by default "0". | '0' |
| min_nonzero_fraction | float | Minimum fraction of voxels above | 0.0 |
| nonzero_threshold | float | Intensity threshold for the nonzero fraction check, by default 0.0. | 0.0 |
| nonzero_channel | str or None | Channel to check. | None |
| max_nonzero_retries | int | Maximum retries when a patch fails the nonzero check, by default 100. | 100 |
| fg_mask_key | str or None | Zarr array key for precomputed foreground masks, by default None. | None |
| val_gpu_augmentations | list[MapTransform] or None | GPU transforms applied to validation batches in | None |
| include_fov_names | Iterable[str] or None | If given, only positions whose plate-relative name (e.g. | None |
| exclude_fov_names | Iterable[str] or None | If given, positions whose plate-relative name is in this collection are skipped. Useful to hold out test FOVs from a plate that also contains training FOVs, or to resume a predict run without redoing already-written positions. Applied after | None |
on_after_batch_transfer(batch, dataloader_idx)
Apply GPU augmentations and validate output spatial shape.
Training: applies gpu_augmentations if configured. When no
gpu_augmentations are set, validates that source spatial
dimensions match (z_window_size, *yx_patch_size).
Validation: applies val_gpu_augmentations if configured.
Test/predict: pass through unchanged.
When target_2d is set, the target center Z slice is extracted
after augmentations to save VRAM.
predict_dataloader()
Return predict data loader.
prepare_data()
Stage FOVs to a memory-mapped tensor buffer on local scratch.
Runs only when the current stage is fit or validate. Predict
and test stages skip mmap staging entirely (they read from
zarr directly and ignore FOV filters). Manual
dm.prepare_data() calls without a trainer attached (or
with a trainer whose state.fn is None) also run —
preserving existing standalone usage.
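The stage gating described above can be sketched as a small predicate. This is an assumption-laden illustration, not the library's code: the function name is hypothetical, and plain strings stand in for Lightning's TrainerFn values on trainer.state.fn.

```python
from types import SimpleNamespace


def should_stage_mmap(trainer) -> bool:
    """Decide whether prepare_data should stage FOVs to the mmap buffer."""
    if trainer is None:
        return True  # manual dm.prepare_data() call without a trainer
    fn = trainer.state.fn
    if fn is None:
        return True  # trainer attached but no stage running yet
    return fn in ("fit", "validate")


# Stand-in trainer objects for illustration only.
fit_trainer = SimpleNamespace(state=SimpleNamespace(fn="fit"))
predict_trainer = SimpleNamespace(state=SimpleNamespace(fn="predict"))
```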
setup(stage)
Set up datasets for the given stage.
test_dataloader()
Return test data loader.
train_dataloader()
Return training data loader.
val_dataloader()
Return validation data loader.
HCSStackIndex
Bases: NamedTuple
HCS stack index.
LevelNormStats
Bases: TypedDict
Per-level normalization statistics.
Not all fields are present for every level. The normalize transforms
access fields dynamically based on subtrahend and divisor
config (e.g. mean/std or median/iqr).
LiveCellDataModule
Bases: GPUTransformDataModule
Data module for LiveCell training and evaluation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_val_images | Path \| None | Path to the training/validation images directory. | None |
| test_images | Path \| None | Path to the test images directory. | None |
| train_annotations | Path \| None | Path to the training COCO annotations file. | None |
| val_annotations | Path \| None | Path to the validation COCO annotations file. | None |
| test_annotations | Path \| None | Path to the test COCO annotations file. | None |
| train_cpu_transforms | list[MapTransform] | CPU transforms for training. | None |
| val_cpu_transforms | list[MapTransform] | CPU transforms for validation. | None |
| train_gpu_transforms | list[MapTransform] | GPU transforms for training. | None |
| val_gpu_transforms | list[MapTransform] | GPU transforms for validation. | None |
| test_transforms | list[MapTransform] | Transforms for test stage. | None |
| batch_size | int | Batch size, by default 16. | 16 |
| num_workers | int | Number of dataloading workers, by default 8. | 8 |
| pin_memory | bool | Pin memory for dataloaders, by default True. | True |
train_cpu_transforms
property
Return training CPU transforms.
train_gpu_transforms
property
Return training GPU transforms.
val_cpu_transforms
property
Return validation CPU transforms.
val_gpu_transforms
property
Return validation GPU transforms.
setup(stage)
Set up datasets for the given stage.
test_dataloader()
Return test data loader.
LiveCellDataset
Bases: Dataset
LiveCell dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| images | list of Path | List of paths to single-page, single-channel TIFF files. | required |
| transform | Transform or Compose | Transform to apply to the dataset. | required |
| cache_map | DictProxy | Shared dictionary for caching images. | required |
LiveCellTestDataset
Bases: Dataset
LiveCell test dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| image_dir | Path | Directory containing the images. | required |
| transform | MapTransform \| Compose | Transform to apply to the dataset. | required |
| annotations | Path | Path to the COCO annotations file. | required |
| load_target | bool | Whether to load the target images (default is False). | False |
| load_labels | bool | Whether to load the labels (default is False). | False |
MaskTestDataset
Bases: SlidingWindowDataset
Test dataset with optional ground truth masks.
Each element is a window of (C, Z, Y, X) where C=2 (source and target)
and Z is z_window_size.
This is a testing-stage version of viscy_data.sliding_window.SlidingWindowDataset.
It can only be used with batch size 1 for efficiency (no padding for collation),
since the mask is not available for each stack.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | list[Position] | FOVs to include in dataset. | required |
| channels | ChannelMap | Source and target channel names, e.g. | required |
| z_window_size | int | Z window size of the 2.5D U-Net, 1 for 2D. | required |
| transform | DictTransform | A callable that transforms data, defaults to None. | None |
| ground_truth_masks | str \| None | Path to the ground truth masks. | None |
| array_key | str | Name of the image arrays (multiscales level), by default "0". | '0' |
__getitem__(index)
Return a sample with optional ground truth mask.
MmappedDataModule
Bases: GPUTransformDataModule, SelectWell
Data module for cached OME-Zarr arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | Path | Path to the HCS OME-Zarr dataset. | required |
| channels | str \| list[str] | Channel names to load. | required |
| batch_size | int | Batch size for training and validation. | required |
| num_workers | int | Number of workers for data-loaders. | required |
| split_ratio | float | Fraction of the FOVs used for the training split. The rest will be used for validation. | required |
| train_cpu_transforms | list[DictTransform] | Transforms to be applied on the CPU during training. | required |
| val_cpu_transforms | list[DictTransform] | Transforms to be applied on the CPU during validation. | required |
| train_gpu_transforms | list[DictTransform] | Transforms to be applied on the GPU during training. | required |
| val_gpu_transforms | list[DictTransform] | Transforms to be applied on the GPU during validation. | required |
| pin_memory | bool | Use page-locked memory in data-loaders, by default True | True |
| prefetch_factor | int \| None | Prefetching ratio for the torch dataloader, by default None | None |
| array_key | str | Name of the image arrays (multiscales level), by default "0" | '0' |
| scratch_dir | Path \| None | Path to the scratch directory, by default None (use OS temporary data directory) | None |
| include_wells | list[str] \| None | Include only a subset of wells, by default None (include all wells) | None |
| exclude_fovs | list[str] \| None | Exclude FOVs, by default None (do not exclude any FOVs) | None |
cache_dir
property
Return the cache directory path for memory-mapped tensors.
preprocessing_transforms
property
Return preprocessing transforms.
train_cpu_transforms
property
Return training CPU transforms.
train_gpu_transforms
property
Return training GPU transforms.
val_cpu_transforms
property
Return validation CPU transforms.
val_gpu_transforms
property
Return validation GPU transforms.
setup(stage)
Set up datasets for fit or validate stage.
MmappedDataset
Bases: Dataset
Dataset backed by memory-mapped tensors for efficient caching.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | list[Position] | List of FOVs to load images from. | required |
| channel_names | list[str] | Channel names to load. | required |
| cache_map | DictProxy | Shared dictionary for caching loaded volumes. | required |
| buffer | MemoryMappedTensor | Memory-mapped tensor buffer for cached volumes. | required |
| preprocess_transforms | Compose \| None | Preprocessing transforms, by default None. | None |
| cpu_transform | Compose \| None | CPU transforms, by default None. | None |
| array_key | str | The image array key name (multi-scale level), by default "0". | '0' |
| load_normalization_metadata | bool | Load normalization metadata in the sample dictionary, by default True. | True |
Sample
Bases: TypedDict
Image sample type for mini-batches.
All fields are optional.
SampleMeta
Bases: TypedDict
Biological metadata carried in train-mode batches for sampler debugging.
Joinable against valid_anchors on (global_track_id, t).
Core fields are defined here. Domain-specific fields should be added by
subclassing SampleMeta (e.g. OpsSampleMeta). The labels field
is an open-ended dict of integer labels that auxiliary heads can consume
via batch_key without requiring a subclass.
SegmentationDataModule
SegmentationDataset
SegmentationSample
Bases: TypedDict
Segmentation sample type for mini-batches.
SelectWell
Mixin class for filtering wells and FOVs from HCS plates.
ShardedDistributedSampler
Bases: DistributedSampler
Distributed sampler with sharded random permutations for DDP training.
__iter__()
Shard data across distributed ranks.
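One common way to give each rank a disjoint shard is to shuffle with a seed shared by all ranks and then stride through the permutation. The sketch below assumes that scheme, along with dropping the uneven tail; the function name is hypothetical and the library's sharding strategy may differ.

```python
import random


def shard_indices(num_samples, num_replicas, rank, seed=0, epoch=0):
    """Return this rank's disjoint shard of a shared random permutation."""
    rng = random.Random(seed + epoch)  # identical on every rank
    permutation = list(range(num_samples))
    rng.shuffle(permutation)
    # Drop the tail so that every rank receives the same number of indices.
    per_rank = num_samples // num_replicas
    permutation = permutation[: per_rank * num_replicas]
    return permutation[rank::num_replicas]


shards = [shard_indices(10, 3, rank) for rank in range(3)]
```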
SlidingWindowDataset
¶
Bases: Dataset
Sliding window dataset over HCS NGFF positions.
Each element is a window of (C, Z, Y, X) where C=2 (source and target)
and Z is z_window_size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | list[Position] | FOVs to include in dataset. | required |
| channels | ChannelMap | Source and target channel names, e.g. | required |
| z_window_size | int | Z window size of the 2.5D U-Net, 1 for 2D. | required |
| array_key | str | Name of the image arrays (multiscales level), by default "0". | '0' |
| transform | DictTransform \| None | A callable that transforms data, defaults to None. | None |
| load_normalization_metadata | bool | Whether to load normalization metadata, defaults to True. | True |
| min_nonzero_fraction | float | Minimum fraction of voxels above | 0.0 |
| nonzero_threshold | float | Intensity threshold for the nonzero fraction check. Default 0.0 means any nonzero voxel counts. | 0.0 |
| nonzero_channel | str or None | Channel name to check for nonzero fraction. | None |
| max_nonzero_retries | int | Maximum number of random re-samples when a patch fails the nonzero fraction check. Default 100. | 100 |
| fg_mask_key | str or None | Zarr array key for precomputed foreground masks. When set, masks are loaded alongside images and included in the sample as | None |
| preloaded_fovs | list[Tensor] or None | Pre-loaded FOV data, one tensor per position with shape | None |
TripletDataModule
¶
Bases: HCSDataModule
Lightning data module for triplet sampling of patches.
__init__(data_path, tracks_path, source_channel, z_range, initial_yx_patch_size=(512, 512), final_yx_patch_size=(224, 224), split_ratio=0.8, batch_size=16, num_workers=1, normalizations=None, augmentations=None, augment_validation=True, fit_include_wells=None, fit_exclude_fovs=None, predict_cells=False, include_fov_names=None, include_track_ids=None, time_interval='any', return_negative=True, persistent_workers=False, prefetch_factor=None, pin_memory=False, z_window_size=None, cache_pool_bytes=0)
¶
Lightning data module for triplet sampling of patches.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_path | str | Image dataset path | required |
| tracks_path | str | Tracks labels dataset path | required |
| source_channel | str \| Sequence[str] | Input channel name or list of input channel names | required |
| z_range | tuple[int, int] | Range of valid z-slices | required |
| initial_yx_patch_size | tuple[int, int] | YX size of the initially sampled image patch, by default (512, 512) | (512, 512) |
| final_yx_patch_size | tuple[int, int] | Output patch size, by default (224, 224) | (224, 224) |
| split_ratio | float | Ratio of training samples, by default 0.8 | 0.8 |
| batch_size | int | Batch size, by default 16 | 16 |
| num_workers | int | Number of thread workers, by default 1. Set to 0 to disable threading. Using more than 1 is not recommended. | 1 |
| normalizations | list[MapTransform] or None | Normalization transforms, by default None | None |
| augmentations | list[MapTransform] or None | Augmentation transforms, by default None | None |
| augment_validation | bool | Apply augmentations to validation data, by default True. Set to False for VAE training where clean validation is needed. | True |
| fit_include_wells | list[str] | Only include these wells for fitting, by default None | None |
| fit_exclude_fovs | list[str] | Exclude these FOVs for fitting, by default None | None |
| predict_cells | bool | Only predict for selected cells, by default False | False |
| include_fov_names | list[str] \| None | Only predict for selected FOVs, by default None | None |
| include_track_ids | list[int] \| None | Only predict for selected tracks, by default None | None |
| time_interval | Literal['any'] \| int | Future time interval to sample positive and anchor from, by default "any" ("any" means sampling the negative from another track at any time point and using the augmented anchor patch as the positive) | 'any' |
| return_negative | bool | Whether to return the negative sample during the fit stage (can be set to False when using a loss function like NT-Xent), by default True | True |
| persistent_workers | bool | Whether to keep worker processes alive between iterations, by default False | False |
| prefetch_factor | int \| None | Number of batches loaded in advance by each worker, by default None | None |
| pin_memory | bool | Whether to pin memory in CPU for faster GPU transfer, by default False | False |
| z_window_size | int | Size of the final Z window, by default None (inferred from z_range) | None |
| cache_pool_bytes | int | Size of the tensorstore cache pool in bytes, attached to the plate at | 0 |
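A configuration for this module might look like the keyword arguments below. This is only a sketch: the paths, the channel name, and the `z_range` values are placeholders, and `return_negative` is set to False on the assumption that an NT-Xent style loss is used. The remaining values repeat the documented defaults.

```python
# Placeholder paths and channel name; substitute your own.
triplet_kwargs = {
    "data_path": "/path/to/images.zarr",
    "tracks_path": "/path/to/tracks.zarr",
    "source_channel": ["Phase3D"],
    "z_range": (15, 45),
    "initial_yx_patch_size": (512, 512),
    "final_yx_patch_size": (224, 224),
    "split_ratio": 0.8,
    "batch_size": 16,
    "num_workers": 1,
    # "any": the positive is the augmented anchor patch.
    "time_interval": "any",
    # Skip the negative sample, e.g. for an NT-Xent style loss.
    "return_negative": False,
}

# With viscy-data installed, the module would be built along the lines of
# TripletDataModule(**triplet_kwargs). The import path is not shown in this
# reference, so it is left out here.
```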
on_after_batch_transfer(batch, dataloader_idx)
¶
Apply transforms after transferring to device.
predict_dataloader()
¶
Return predict data loader.
train_dataloader()
¶
Return training data loader.
val_dataloader()
¶
Return validation data loader.
TripletDataset
¶
Bases: Dataset
Dataset for triplet sampling of cells based on tracking.
__getitems__(indices)
¶
Return a batch of triplet samples for the given indices.
__init__(positions, tracks_tables, channel_names, initial_yx_patch_size, z_range, fit=True, predict_cells=False, include_fov_names=None, include_track_ids=None, time_interval='any', return_negative=True)
¶
Dataset for triplet sampling of cells based on tracking.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| positions | list[Position] | OME-Zarr images with consistent channel order | required |
| tracks_tables | list[DataFrame] | Data frames containing ultrack results | required |
| channel_names | list[str] | Input channel names | required |
| initial_yx_patch_size | tuple[int, int] | YX size of the initially sampled image patch | required |
| z_range | slice | Range of Z-slices | required |
| fit | bool | Fitting mode in which the full triplet will be sampled, only sample anchor if | True |
| predict_cells | bool | Only predict on selected cells, by default False | False |
| include_fov_names | list[str] \| None | Only predict on selected FOVs, by default None | None |
| include_track_ids | list[int] \| None | Only predict on selected track IDs, by default None | None |
| time_interval | Literal['any'] \| int | Future time interval to sample positive and anchor from, by default "any" (sample the negative from another track at any time point and use the augmented anchor patch as the positive) | 'any' |
| return_negative | bool | Whether to return the negative sample during the fit stage (can be set to False when using a loss function like NT-Xent), by default True | True |
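The effect of `time_interval` and `return_negative` on what a triplet contains can be sketched with plain dictionaries standing in for rows of a tracks table. `valid_anchors` and `sample_triplet` are hypothetical helpers. Drawing the negative from any other track when `time_interval` is an integer is an assumption, since this reference only describes the "any" case.

```python
import random


def valid_anchors(rows: list[dict], time_interval) -> list[dict]:
    """Keep rows that can serve as anchors for the given interval."""
    if time_interval == "any":
        return list(rows)
    present = {(r["track_id"], r["t"]) for r in rows}
    # An anchor needs the same track to exist time_interval frames later.
    return [r for r in rows if (r["track_id"], r["t"] + time_interval) in present]


def sample_triplet(rows, anchor, time_interval, rng, return_negative=True):
    """Pick the positive (and optionally the negative) for one anchor."""
    if time_interval == "any":
        # Same patch; augmentation later makes the two views differ.
        positive = anchor
    else:
        positive = {"track_id": anchor["track_id"], "t": anchor["t"] + time_interval}
    triplet = {"anchor": anchor, "positive": positive}
    if return_negative:
        others = [r for r in rows if r["track_id"] != anchor["track_id"]]
        triplet["negative"] = rng.choice(others)
    return triplet


rows = [
    {"track_id": 1, "t": 0},
    {"track_id": 1, "t": 1},
    {"track_id": 1, "t": 2},
    {"track_id": 2, "t": 0},
    {"track_id": 2, "t": 1},
]
anchors = valid_anchors(rows, time_interval=1)
triplet = sample_triplet(rows, anchors[0], 1, random.Random(0))
pair = sample_triplet(rows, anchors[0], 1, random.Random(0), return_negative=False)
```

In this toy table, the last frame of each track cannot be an anchor when `time_interval=1`, so the dataset length under that setting would be 3 rather than 5.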
__len__()
¶
Return the number of valid anchor samples.
TripletSample
¶
Bases: TypedDict
Triplet sample type for mini-batches.
read_cell_index(path)
¶
Read a cell index parquet into a pandas DataFrame.
String columns are materialized as NumPy object arrays instead of
ArrowStringArray. ArrowStringArray-backed columns route every
boolean mask slice through pyarrow.compute.take, which allocates
a fresh buffer per string column and can spike peak RSS by 50+ GiB
on 80M-row indices during train/val FOV partitioning. NumPy object
columns make df[mask] a cheap gather.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Path to parquet file. | required |
Returns:
| Type | Description |
|---|---|
| DataFrame | Cell index with correct dtypes. |
validate_cell_index(df, *, strict=False)
¶
Validate a cell index DataFrame against the canonical schema.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Cell index to validate. | required |
| strict | bool | If | False |
Returns:
| Type | Description |
|---|---|
| list[str] | Warnings (e.g. nullable columns that are entirely null). |
Raises:
| Type | Description |
|---|---|
| ValueError | If required columns are missing or |
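The two documented outcomes, a `ValueError` for missing required columns and a warning for nullable columns that are entirely null, can be sketched as follows. A dict of column lists stands in for the DataFrame, and the column names and `validate_columns` helper are hypothetical. The canonical schema is not listed in this reference, and the `strict` behavior is deliberately left out.

```python
REQUIRED_COLUMNS = ("cell_id", "fov_name", "t")  # hypothetical schema
NULLABLE_COLUMNS = ("lineage_id",)  # hypothetical schema


def validate_columns(table: dict[str, list]) -> list[str]:
    """Raise on missing required columns; warn on all-null nullable ones."""
    missing = [c for c in REQUIRED_COLUMNS if c not in table]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    warnings = []
    for column in NULLABLE_COLUMNS:
        if column in table and all(v is None for v in table[column]):
            warnings.append(f"Nullable column '{column}' is entirely null")
    return warnings


table = {
    "cell_id": ["c0", "c1"],
    "fov_name": ["A/1/0", "A/1/0"],
    "t": [0, 1],
    "lineage_id": [None, None],
}
warnings = validate_columns(table)

try:
    validate_columns({"cell_id": ["c0"]})
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
```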
write_cell_index(df, path, *, validate=True)
¶
Write a cell index DataFrame to parquet with the canonical schema.
Missing nullable columns are added as None before writing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Cell index to write. | required |
| path | str \| Path | Output parquet path. | required |
| validate | bool | Run :func: | True |
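The step of adding missing nullable columns as None before writing can be pictured with the sketch below. A dict of column lists stands in for the DataFrame, and `add_missing_nullable` and the column names are hypothetical. The parquet write itself is omitted.

```python
def add_missing_nullable(table: dict[str, list], nullable: tuple[str, ...]) -> dict[str, list]:
    """Return a copy of the table with absent nullable columns filled with None."""
    n_rows = len(next(iter(table.values()), []))
    filled = dict(table)
    for column in nullable:
        # Existing columns are left untouched.
        filled.setdefault(column, [None] * n_rows)
    return filled


table = {"cell_id": ["c0", "c1", "c2"], "t": [0, 1, 2]}
filled = add_missing_nullable(table, nullable=("lineage_id", "t"))
```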