
viscy-models

Neural network architectures, primarily for AI × imaging.

pip install viscy-models
uv add viscy-models

What's here

UNet variants for virtual staining and segmentation, contrastive encoders for self-supervised representation learning, and VAE components.

API reference

BetaVae25D

Bases: Module

2.5D Beta-VAE combining VaeEncoder and VaeDecoder.

forward(x)

Forward pass returning VAE outputs.

BetaVaeMonai

Bases: Module

Beta-VAE built on the MONAI architecture.

forward(x)

Forward pass returning VAE encoder outputs.

ContrastiveEncoder

Bases: Module

Contrastive encoder network using ConvNeXt (v1 or v2) and ResNet backbones from timm.

Parameters:

Name Type Description Default
backbone Literal['convnext_tiny', 'convnextv2_tiny', 'resnet50']

Name of the timm backbone architecture.

required
in_channels int

Number of input channels.

required
in_stack_depth int

Number of input Z slices.

required
stem_kernel_size tuple[int, int, int]

Stem kernel size, by default (5, 4, 4).

(5, 4, 4)
stem_stride tuple[int, int, int]

Stem stride, by default (5, 4, 4).

(5, 4, 4)
embedding_dim int

Embedded feature dimension that matches backbone output channels, by default 768 (convnext_tiny).

768
projection_dim int

Projection dimension for computing loss, by default 128.

128
drop_path_rate float

Probability that residual connections are dropped during training, by default 0.0.

0.0
pretrained bool

Whether to load pretrained weights for the backbone, by default False.

False
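As a rough guide to how the default stem reduces a Z-stack, the sketch below applies the standard output-size formula for a convolution. It assumes zero padding and no dilation, which this page does not confirm, and `stem_output_shape` is a hypothetical helper rather than part of viscy-models.

```python
def conv_output_size(size, kernel, stride):
    """Output length of an unpadded, undilated convolution along one axis."""
    return (size - kernel) // stride + 1


def stem_output_shape(zyx, kernel=(5, 4, 4), stride=(5, 4, 4)):
    """Apply the per-axis formula to a (Z, Y, X) input shape.

    Hypothetical helper for illustration; defaults mirror stem_kernel_size
    and stem_stride from the parameter table above.
    """
    return tuple(conv_output_size(n, k, s) for n, k, s in zip(zyx, kernel, stride))


# A 15-slice stack of 256x256 images under the default stem settings.
print(stem_output_shape((15, 256, 256)))  # (3, 64, 64)
```

Under these assumptions, in_stack_depth should be chosen so that the Z axis divides cleanly by the stem stride.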

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input image.

required

Returns:

Type Description
tuple[Tensor, Tensor]

The embedding tensor and the projection tensor.

CosineClassifier

Bases: Module

L2-normalised linear head with learnable temperature.

Parameters:

Name Type Description Default
in_dim int

Input feature dimensionality.

required
num_classes int

Number of output classes.

required
init_scale float

Initial value of the temperature scale (before log).

20.0
learn_scale bool

Whether to make the temperature a learnable parameter.

True
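A minimal sketch of what an L2-normalised linear head computes, written in plain Python so that it runs without PyTorch. The helper names are hypothetical, and the scale is applied directly, whereas the description of init_scale suggests the module stores it in log space.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length (zero vectors are returned as-is)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else list(vec)


def cosine_logits(features, class_weights, scale=20.0):
    """Return scale * cos(features, w_c) for every class weight vector w_c."""
    f = l2_normalize(features)
    logits = []
    for w in class_weights:
        w_hat = l2_normalize(w)
        logits.append(scale * sum(a * b for a, b in zip(f, w_hat)))
    return logits


# Two classes in a 2-D feature space; the feature points along class 0.
logits = cosine_logits([3.0, 0.0], [[1.0, 0.0], [0.0, 2.0]], scale=20.0)
print(logits)  # [20.0, 0.0]
```

Because both features and weights are normalised, the logits depend only on direction, and the temperature scale sets how peaked the resulting softmax can be.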

DINOv3Model

Bases: Module

Wrap a HuggingFace DINOv3 vision model for microscopy images.

The model accepts raw dataloader tensors (B, C, D, H, W) directly in forward(); preprocessing is applied inline. Alternatively, call preprocess_2d() manually before passing (B, C, H, W) input.

Z-slice selection is not handled here — configure z_range on the dataloader so it delivers the correct focal plane (see get_z_range() in the evaluation utilities).

Parameters:

Name Type Description Default
model_name str

HuggingFace model identifier, e.g. "facebook/dinov3-small-imagenet1k-1-layer".

required
freeze bool

If True (default), all backbone parameters are frozen and the model is kept in eval mode.

True
projection Module or None

Optional trainable projection head applied to backbone features. When provided, forward() returns (features, projection(features)). When None (default), returns (features, features). The projection is always trainable regardless of freeze.

None

forward(x)

Run the DINOv3 backbone on an image batch.

Preprocessing is applied inline, so raw dataloader tensors (B, C, D, H, W) or (B, C, H, W) can be passed directly.

Parameters:

Name Type Description Default
x Tensor

(B, C, D, H, W) or (B, C, H, W) — raw or preprocessed input.

required

Returns:

Type Description
tuple[Tensor, Tensor]

(features, projections) where features are the backbone pooler output (B, hidden_dim). If projection was provided at init, projections are self.projection(features); otherwise both elements are the same features tensor.

preprocess_2d(x, normalize=False)

Convert a raw dataloader tensor to a normalised RGB image.

Handles squeezing a singleton Z dim, repeating/trimming channels to 3, resizing to the model's expected spatial size, and applying ImageNet mean/std.

normalize defaults to False because the production path feeds dataloader-normalized input (e.g. NormalizeSampled per-FOV stats); per-image min-max on top of that reintroduces saturation from outlier patches. In that case the input is treated as already z-scored: it is clipped to ±3σ and mapped to [0, 1] deterministically before the ImageNet step.

Set normalize=True only when feeding raw zarr values without any upstream normalization — this reverts to per-image min-max.
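The difference between the two scaling paths can be sketched on a handful of pixel values. This is an illustration in plain Python, not the library code: both helper names are hypothetical, and the linear mapping from the ±3σ range onto [0, 1] is an assumption about how the clipped values are rescaled.

```python
def zscore_to_unit(values, clip=3.0):
    """Clip z-scored values to [-clip, clip], then map linearly onto [0, 1]."""
    return [(min(max(v, -clip), clip) + clip) / (2 * clip) for v in values]


def minmax_to_unit(values):
    """Rescale values to [0, 1] using the image's own minimum and maximum."""
    lo, hi = min(values), max(values)
    span = hi - lo
    return [(v - lo) / span if span > 0 else 0.0 for v in values]


pixels = [-5.0, 0.0, 1.5, 40.0]  # the last value is an outlier

# Fixed clip: the outlier saturates at 1.0, typical values keep their spread.
print(zscore_to_unit(pixels))  # [0.0, 0.5, 0.75, 1.0]

# Per-image min-max: the outlier defines the range and squeezes everything
# else towards the bottom of [0, 1].
print(minmax_to_unit(pixels))
```

The fixed clip gives the same output for the same input value regardless of what else is in the image, which is what makes it deterministic across patches.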

Parameters:

Name Type Description Default
x Tensor

(B, C, D, H, W) or (B, C, H, W).

required
normalize bool

Apply per-image min-max scale to [0, 1] before the ImageNet step. Default False (dataloader is expected to z-score upstream).

False

Returns:

Type Description
Tensor

(B, 3, H_target, W_target) ready for forward().

train(mode=True)

Override train to keep backbone in eval when frozen.

FullyConvolutionalMAE

Bases: Module

Fully Convolutional Masked Autoencoder (FCMAE) for self-supervised pretraining.

forward(x, mask_ratio=0.0)

Forward pass through the FCMAE.

Parameters:

Name Type Description Default
x Tensor

Input tensor (BCDHW).

required
mask_ratio float

Ratio of the feature maps to mask, defaults to 0.0.

0.0

Returns:

Type Description
Tensor

Reconstructed output tensor.

MLP

Bases: Module

Configurable MLP with optional classification head and penultimate-layer extraction.

Supports two modes:

  • Projection mode (num_classes=None, default): maps embeddings to a projection space for contrastive loss. Output norm is applied to the final layer via norm.
  • Classification mode (num_classes set): adds a classification head (linear or cosine) on top of the backbone. Use encode() to extract L2-normalised penultimate-layer representations.

Parameters:

Name Type Description Default
in_dims int

Input feature dimension.

required
hidden_dims int | list[int]

Hidden layer width. A single int gives one hidden layer (matching the legacy two-layer behaviour); a list gives one layer per element.

required
out_dims int or None

Output dimension of the final projection layer (ignored in classification mode — the backbone output feeds directly into head).

None
norm Literal['bn', 'ln']

Normalization applied after each hidden layer. "bn" = BatchNorm1d, "ln" = LayerNorm.

'bn'
activation Literal['relu', 'gelu', 'silu']

Hidden activation function.

'relu'
dropout float

Dropout rate applied after each hidden layer. 0.0 disables dropout.

0.0
num_classes int or None

When set, adds a classification head and enables encode(). When None (default), the MLP acts as a projection head.

None
cosine_classifier bool

Use CosineClassifier instead of a plain linear head. Only used when num_classes is set.

True
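To make the hidden_dims and out_dims conventions concrete, the sketch below lists the linear-layer shapes implied by a given configuration. `layer_dims` is a hypothetical helper written for this illustration; it ignores normalization, activation, dropout, and the classification head.

```python
def layer_dims(in_dims, hidden_dims, out_dims=None):
    """Return (in_features, out_features) pairs for each linear layer.

    Hypothetical helper: a single int is treated as one hidden layer and a
    list as one layer per element, as described in the table above.
    """
    hidden = [hidden_dims] if isinstance(hidden_dims, int) else list(hidden_dims)
    widths = [in_dims, *hidden]
    if out_dims is not None:
        # Projection mode: a final layer maps to the projection space.
        widths.append(out_dims)
    return list(zip(widths[:-1], widths[1:]))


print(layer_dims(768, 768, 128))         # [(768, 768), (768, 128)]
print(layer_dims(768, [512, 256], 128))  # [(768, 512), (512, 256), (256, 128)]
print(layer_dims(768, [512, 256]))       # [(768, 512), (512, 256)]
```

The last call corresponds to classification mode, where the backbone ends at hidden_dims[-1] and that output feeds the head.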

encode(x)

Return L2-normalised penultimate-layer representations.

Only valid when num_classes was set at construction.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, in_dims).

required

Returns:

Type Description
Tensor

L2-normalised backbone output of shape (B, hidden_dims[-1]).

Raises:

Type Description
RuntimeError

If called on a projection-mode MLP (num_classes=None).

forward(x)

Forward pass through backbone and optional head.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, in_dims).

required

Returns:

Type Description
Tensor

Projected or classified output.

NTXentHCL

Bases: NTXentLoss

NT-Xent loss with hard-negative concentration and optional temperature schedule.

When beta=0.0, produces identical results to standard NTXentLoss. When beta>0, up-weights hard negatives (high cosine similarity) in the denominator, focusing learning on difficult examples.

The HCL reweighting multiplies each negative pair's contribution in the denominator by exp(beta * sim(i, k)), concentrating gradient signal on negatives that are close to the anchor in embedding space.
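The reweighting can be sketched for a single anchor in plain Python. This follows the description above literally and is not the library implementation: keeping the positive term in the denominator and leaving the weights un-normalised are both assumptions, and `ntxent_hcl_anchor` is a hypothetical name.

```python
import math


def ntxent_hcl_anchor(pos_sim, neg_sims, temperature=0.07, beta=0.5):
    """NT-Xent loss for one anchor with hard-negative reweighting.

    Each negative term exp(sim / temperature) is multiplied by
    exp(beta * sim), so negatives close to the anchor count for more.
    """
    pos = math.exp(pos_sim / temperature)
    neg = sum(math.exp(beta * s) * math.exp(s / temperature) for s in neg_sims)
    return -math.log(pos / (pos + neg))


neg_sims = [0.9, 0.1, -0.3]  # cosine similarities of three negatives

standard = ntxent_hcl_anchor(0.8, neg_sims, beta=0.0)  # all weights are 1
weighted = ntxent_hcl_anchor(0.8, neg_sims, beta=0.5)  # hard negative up-weighted

print(standard, weighted)
```

With beta=0.0 every weight is exp(0) = 1, so the expression reduces to the unweighted loss; with beta>0 the negative at similarity 0.9 dominates the denominator even more strongly.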

Call step() at the start of each epoch to apply the temperature schedule.

Parameters:

Name Type Description Default
temperature float

Temperature scaling for cosine similarities. Default: 0.07.

0.07
beta float

Hard-negative concentration strength. 0.0 = standard NT-Xent. Higher values concentrate more on hard negatives. Default: 0.5.

0.5
temperature_schedule ('cosine', 'constant')

Inherited from NTXentLoss. Default: "constant".

"cosine"
temperature_start float

Inherited from NTXentLoss. Default: 0.1.

0.1
temperature_warmup_epochs int

Inherited from NTXentLoss. Default: 50.

50

OpenPhenomModel

Bases: Module

Wrap Recursion's OpenPhenom CA-MAE ViT-S/16 for microscopy images.

OpenPhenom accepts 1–11 channel uint8 input at 256×256 and normalises internally. preprocess_2d() handles Z-squeeze, resize, and float→uint8 conversion.

Parameters:

Name Type Description Default
model_name str

HuggingFace model identifier, e.g. "recursionpharma/OpenPhenom".

required
freeze bool

If True (default), all backbone parameters are frozen and the model is kept in eval mode.

True

forward(x)

Run the OpenPhenom backbone on a preprocessed image batch.

Parameters:

Name Type Description Default
x Tensor

Input of shape (B, C, 256, 256) uint8.

required

Returns:

Type Description
tuple[Tensor, Tensor]

(features, features) — both are the embedding of shape (B, 384). No separate projection head is used.

preprocess_2d(x)

Convert a raw dataloader tensor to uint8 input for OpenPhenom.

Handles squeezing a singleton Z dim, resizing to 256×256, and rescaling float values to [0, 255] uint8 (OpenPhenom normalises internally).

Unlike DINOv3, no channel manipulation is needed — OpenPhenom accepts 1–11 channels natively.

Parameters:

Name Type Description Default
x Tensor

(B, C, D, H, W) or (B, C, H, W).

required

Returns:

Type Description
Tensor

(B, C, 256, 256) uint8 tensor ready for forward().

train(mode=True)

Override train to keep backbone in eval when frozen.

ResNet3dEncoder

Bases: Module

3D ResNet encoder network that uses MONAI's ResNetFeatures.

Parameters:

Name Type Description Default
backbone str

Name of the backbone model.

required
in_channels int

Number of input channels.

1
embedding_dim int

Embedded feature dimension that matches backbone output channels, by default 512 (ResNet-18).

512
projection_dim int

Projection dimension for computing loss, by default 128.

128
pretrained bool

Whether to load pretrained weights for the backbone, by default False.

False

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input image.

required

Returns:

Type Description
tuple[Tensor, Tensor]

The embedding tensor and the projection tensor.

UNeXt2

Bases: Module

UNeXt2 model composing a timm encoder with custom stem, decoder, and head.

num_blocks property

Number of 2× downscaling steps down to the smallest feature map.

forward(x)

Forward pass through the UNeXt2 model.

Unet25d

Bases: Module

2.5D UNet for learning 3D-to-2D compression.

Architecture takes in stack of 2D inputs given as a 3D tensor and returns a 2D interpretation. Learns 3D information based upon input stack, but speeds up training by compressing 3D information before the decoding path. Uses interruption conv layers in the UNet skip paths to compress information with z-channel convolution.

Reference: https://elifesciences.org/articles/55502

Parameters:

Name Type Description Default
in_channels int

Number of feature channels in (1 or more).

1
out_channels int

Number of feature channels out (1 or more).

1
in_stack_depth int

Depth of input stack in z.

5
out_stack_depth int

Depth of output stack.

1
xy_kernel_size int or tuple of int

Size of x and y dimensions of conv kernels in blocks.

(3, 3)
residual bool

Whether to use residual connections.

False
dropout float

Probability of dropout, between 0 and 0.5.

0.2
num_blocks int

Number of convolutional blocks on encoder and decoder paths.

4
num_block_layers int

Number of layer sequences repeated per block.

2
num_filters tuple of int

Filter counts at each conv block depth.

()
task str

Network task, one of 'seg' or 'reg'.

'seg'

__name__()

Return model name.

forward(x)

Perform forward pass through the 2.5D UNet.

Call order:

  • num_blocks 3D convolutional blocks, with downsampling in between (encoder)
  • skip connections between corresponding blocks in encoder and decoder
  • num_blocks 2D (3D with 1 z-channel) convolutional blocks, with upsampling between them (decoder)
  • terminal block collapses to output dimensions

Parameters:

Name Type Description Default
x Tensor

Input image tensor.

required

Returns:

Type Description
Tensor

Output tensor with compressed z-dimension.

register_modules(module_list, name)

Register modules stored in a list to the model object.

Used to enable model graph creation with non-sequential model types and dynamic layer numbers.

Parameters:

Name Type Description Default
module_list list of torch.nn.Module

List of modules to register.

required
name str

Name of module type.

required

Unet2d

Bases: Module

2D UNet with variable input/output channels and depth.

Follows the 2D UNet architecture:

  1) UNet: https://arxiv.org/pdf/1505.04597.pdf
  2) Residual UNet: https://arxiv.org/pdf/1711.10684.pdf

Parameters:

Name Type Description Default
in_channels int

Number of feature channels in.

1
out_channels int

Number of feature channels out.

1
kernel_size int or tuple of int

Size of x and y dimensions of conv kernels in blocks.

(3, 3)
residual bool

Whether to use residual connections.

False
dropout float

Probability of dropout, between 0 and 0.5.

0.2
num_blocks int

Number of convolutional blocks on encoder and decoder.

4
num_block_layers int

Number of layers per block.

2
num_filters tuple of int

Filter counts at each conv block depth.

()
task str

Network task, one of 'seg' or 'reg'.

'seg'

__name__()

Return model name.

forward(x, validate_input=False)

Perform forward pass through the 2D UNet.

Call order:

  • num_blocks 2D convolutional blocks, with downsampling in between (encoder)
  • num_blocks 2D convolutional blocks, with upsampling between them (decoder)
  • skip connections between corresponding blocks on encoder and decoder
  • terminal block collapses to output dimensions

Parameters:

Name Type Description Default
x Tensor

Input image tensor.

required
validate_input bool

Whether to run the input assertions. They are redundant, and can stay deactivated, if the forward pass is being traced by the TensorBoard writer.

False

Returns:

Type Description
Tensor

Output tensor with same spatial dimensions as input.

register_modules(module_list, name)

Register modules stored in a list to the model object.

Used to enable model graph creation with non-sequential model types and dynamic layer numbers.

Parameters:

Name Type Description Default
module_list list of torch.nn.Module

List of modules to register.

required
name str

Name of module type.

required

Unet3d

Bases: UNet3DBase

3D U-Net following Ounkomol et al. 2018 (F-Net).

FNet-configured preset of the unified 3D U-Net base with BatchNorm, ReLU activations, non-residual double-conv blocks, and a convolutional bottleneck. Downsamples all three spatial dimensions.

All spatial dimensions (Z, Y, X) must be divisible by 2**depth.
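The divisibility constraint is easy to check before building a batch. The sketch below uses two hypothetical helpers (not part of viscy-models) to test a (Z, Y, X) shape and to compute the smallest padded shape that satisfies the constraint.

```python
def is_valid_shape(zyx, depth=4):
    """True if every spatial dimension is divisible by 2**depth."""
    factor = 2 ** depth
    return all(n % factor == 0 for n in zyx)


def padded_shape(zyx, depth=4):
    """Round each spatial dimension up to the next multiple of 2**depth."""
    factor = 2 ** depth
    return tuple(-(-n // factor) * factor for n in zyx)  # ceiling division


shape = (30, 250, 256)
print(is_valid_shape(shape))                # False
print(padded_shape(shape))                  # (32, 256, 256)
print(is_valid_shape(padded_shape(shape)))  # True
```

With the default depth of 4, each dimension must be a multiple of 16.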

Parameters:

Name Type Description Default
in_channels int

Number of input channels.

1
out_channels int

Number of output channels.

1
depth int

Number of downsampling levels.

4
mult_chan int

Base channel count at the first encoder level.

32
in_stack_depth int or None

Z-window size. Stored for engine compatibility (example_input_array, DivisiblePad, sliding window prediction). The model itself handles arbitrary Z as long as it is divisible by 2**depth.

None