viscy-models
Neural network architectures, primarily for AI × imaging.
What's here
UNet variants for virtual staining and segmentation, contrastive encoders for self-supervised representation learning, and VAE components.
API reference
BetaVae25D
Bases: Module
2.5D Beta-VAE combining VaeEncoder and VaeDecoder.
forward(x)
Forward pass returning VAE outputs.
BetaVaeMonai
Bases: Module
Beta-VAE with a MONAI-based architecture.
forward(x)
Forward pass returning VAE encoder outputs.
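Both classes are Beta-VAEs. For reference, the standard Beta-VAE objective adds a β-weighted KL divergence to the reconstruction error, and it samples latents with the reparameterization trick. The sketch below shows those two pieces for a diagonal Gaussian posterior using plain Python lists. It is a generic illustration under those textbook assumptions, not the loss or forward() used by these classes, and the function names are hypothetical.

```python
import math
import random

# Generic Beta-VAE pieces on plain lists (hypothetical helpers, not this package's API).


def reparameterize(mu: list[float], logvar: list[float], rng: random.Random) -> list[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), where sigma = exp(logvar / 2)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_diag_gaussian(mu: list[float], logvar: list[float]) -> float:
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_error: float, mu: list[float], logvar: list[float], beta: float) -> float:
    """Reconstruction term plus beta-weighted KL divergence."""
    return recon_error + beta * kl_diag_gaussian(mu, logvar)


mu, logvar = [0.5, -0.2], [0.0, -1.0]
z = reparameterize(mu, logvar, random.Random(0))
loss = beta_vae_loss(recon_error=1.0, mu=mu, logvar=logvar, beta=4.0)
```

With beta = 1 this reduces to the ordinary VAE objective. Larger beta values push the posterior toward the prior.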
ContrastiveEncoder
Bases: Module
Contrastive encoder network using ConvNeXt (v1 and v2) and ResNet backbones from timm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| backbone | Literal['convnext_tiny', 'convnextv2_tiny', 'resnet50'] | Name of the timm backbone architecture. | required |
| in_channels | int | Number of input channels. | required |
| in_stack_depth | int | Number of input Z slices. | required |
| stem_kernel_size | tuple[int, int, int] | Stem kernel size, by default (5, 4, 4). | (5, 4, 4) |
| stem_stride | tuple[int, int, int] | Stem stride, by default (5, 4, 4). | (5, 4, 4) |
| embedding_dim | int | Embedded feature dimension that matches backbone output channels, by default 768 (convnext_tiny). | 768 |
| projection_dim | int | Projection dimension for computing loss, by default 128. | 128 |
| drop_path_rate | float | Stochastic depth rate: probability that a block's residual branch is dropped during training, by default 0.0. | 0.0 |
| pretrained | bool | Whether to load pretrained weights for the backbone, by default False. | False |
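With the default stem_kernel_size and stem_stride of (5, 4, 4), the stem consumes five Z slices per step and downsamples Y and X by 4. The sketch below assumes the stem behaves like an unpadded strided 3D convolution. That is an assumption, because the stem internals are not documented here. Under it, the output size per axis follows the usual floor((n - k) / s) + 1 rule. The helpers are hypothetical.

```python
def conv_out_size(n: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded, strided convolution along one axis."""
    if n < kernel:
        raise ValueError(f"input size {n} is smaller than kernel {kernel}")
    return (n - kernel) // stride + 1


def stem_output_shape(
    in_shape: tuple[int, int, int],
    kernel: tuple[int, int, int] = (5, 4, 4),
    stride: tuple[int, int, int] = (5, 4, 4),
) -> tuple[int, int, int]:
    """(Z, Y, X) output shape of a hypothetical unpadded 3D stem."""
    return tuple(conv_out_size(n, k, s) for n, k, s in zip(in_shape, kernel, stride))


# 5 Z slices of 256 x 256 collapse to a single plane of 64 x 64.
print(stem_output_shape((5, 256, 256)))  # (1, 64, 64)
```

Because the Z kernel equals the Z stride, setting in_stack_depth to a multiple of 5 avoids silently discarding slices under this assumption.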
forward(x)
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | The embedding tensor and the projection tensor. |
CosineClassifier
Bases: Module
L2-normalised linear head with learnable temperature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input feature dimensionality. | required |
| num_classes | int | Number of output classes. | required |
| init_scale | float | Initial value of the temperature scale (before log). | 20.0 |
| learn_scale | bool | Whether to make the temperature a learnable parameter. | True |
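A cosine classifier scores each class by the cosine similarity between the L2-normalised input and the L2-normalised class weight vector, multiplied by a temperature scale. The phrase "before log" in the init_scale description suggests the scale is stored as a logarithm, which keeps it positive when learned. The sketch below makes that assumption explicit and uses plain lists instead of tensors. The class name is hypothetical, and learn_scale would decide whether log_scale is trainable.

```python
import math


def l2_normalize(v: list[float], eps: float = 1e-12) -> list[float]:
    """Scale v to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


class ToyCosineClassifier:
    """Hypothetical stand-in for a cosine classification head."""

    def __init__(self, weights: list[list[float]], init_scale: float = 20.0):
        self.weights = weights  # one weight vector per class
        self.log_scale = math.log(init_scale)  # assumed log-space storage

    def logits(self, x: list[float]) -> list[float]:
        x_hat = l2_normalize(x)
        scale = math.exp(self.log_scale)
        return [scale * sum(a * b for a, b in zip(x_hat, l2_normalize(w))) for w in self.weights]


clf = ToyCosineClassifier(weights=[[1.0, 0.0], [0.0, 1.0]])
print(clf.logits([3.0, 0.0]))  # approximately [20.0, 0.0]
```

Because both sides are normalised, the logits depend only on direction, not on the magnitude of the input features.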
DINOv3Model
Bases: Module
Wrap a HuggingFace DINOv3 vision model for microscopy images.
The model accepts raw dataloader tensors (B, C, D, H, W) directly in forward(); preprocessing is applied inline. Alternatively, call preprocess_2d() manually before passing (B, C, H, W) input.
Z-slice selection is not handled here: configure z_range on the dataloader so that it delivers the correct focal plane (see get_z_range() in the evaluation utilities).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | HuggingFace model identifier. | required |
| freeze | bool | If True, the backbone is frozen. | True |
| projection | Module or None | Optional trainable projection head applied to backbone features. When provided, … | None |
forward(x)
Run the DINOv3 backbone on an image batch.
Preprocessing is applied inline, so raw dataloader tensors
(B, C, D, H, W) or (B, C, H, W) can be passed directly.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor |  | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] |  |
preprocess_2d(x, normalize=False)
Convert a raw dataloader tensor to a normalised RGB image.
Handles squeezing a singleton Z dim, repeating/trimming channels to 3, resizing to the model's expected spatial size, and applying ImageNet mean/std.
normalize defaults to False because the production path feeds dataloader-normalized input (e.g. NormalizeSampled with per-FOV statistics), and per-image min-max scaling on top of that reintroduces saturation from outlier patches. With normalize=False, the input is therefore treated as already z-scored: it is clipped to ±3σ and mapped deterministically to [0, 1] before the ImageNet step.
Set normalize=True only when feeding raw zarr values without any upstream normalization; this reverts to per-image min-max scaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor |  | required |
| normalize | bool | Apply per-image min-max scaling to the input. | False |
Returns:
| Type | Description |
|---|---|
| Tensor |  |
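The per-pixel arithmetic described for preprocess_2d can be sketched in plain Python. With normalize=False, a z-scored value is clipped to ±3σ and mapped linearly to [0, 1]. With normalize=True, per-image min-max scaling is used instead. The result is then standardised with the ImageNet mean/std for its RGB channel; the constants below are the standard torchvision ImageNet statistics. Several details are assumptions: the exact linear map (x + 3) / 6, the handling of a constant image, and the channel repeat/trim rule. The helper names are hypothetical. Resizing and the Z squeeze are omitted.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_unit_range(pixels: list[float], normalize: bool = False) -> list[float]:
    """Map one channel's pixels to [0, 1]."""
    if normalize:
        # Per-image min-max scaling (intended only for raw, unnormalized input).
        lo, hi = min(pixels), max(pixels)
        span = hi - lo
        return [0.0 if span == 0 else (p - lo) / span for p in pixels]
    # Treat input as z-scored: clip to [-3, 3], then map linearly to [0, 1].
    return [(min(max(p, -3.0), 3.0) + 3.0) / 6.0 for p in pixels]


def to_rgb_channels(channels: list[list[float]]) -> list[list[float]]:
    """Repeat (cycling) or trim a list of channels so exactly 3 remain (assumed rule)."""
    if not channels:
        raise ValueError("need at least one channel")
    return [channels[i % len(channels)] for i in range(3)]


def imagenet_standardize(rgb: list[list[float]]) -> list[list[float]]:
    """Apply ImageNet mean/std per RGB channel."""
    return [[(p - IMAGENET_MEAN[c]) / IMAGENET_STD[c] for p in ch] for c, ch in enumerate(rgb)]


print(to_unit_range([-5.0, 0.0, 1.5, 4.0]))  # [0.0, 0.5, 0.75, 1.0]
```

The fixed ±3σ window gives every patch the same mapping. Min-max scaling instead stretches each patch by its own extremes, which is how outliers cause saturation.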
train(mode=True)
Override train() to keep the backbone in eval mode when frozen.
FullyConvolutionalMAE
Bases: Module
Fully Convolutional Masked Autoencoder (FCMAE) for self-supervised pretraining.
forward(x, mask_ratio=0.0)
Forward pass through the FCMAE.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor (B, C, D, H, W). | required |
| mask_ratio | float | Ratio of the feature maps to mask, defaults to 0.0. | 0.0 |
Returns:
| Type | Description |
|---|---|
| Tensor | Reconstructed output tensor. |
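mask_ratio sets the fraction of the feature map hidden during pretraining; at the default of 0.0 nothing is masked. The sketch below draws a random mask over an H × W grid of feature-map positions, choosing masked positions uniformly without replacement. This follows the general FCMAE idea from ConvNeXt V2. The grid size, the rounding rule, and the function name are assumptions, not this class's implementation.

```python
import random
from typing import List, Optional


def random_patch_mask(h: int, w: int, mask_ratio: float, seed: Optional[int] = None) -> List[List[bool]]:
    """Return an h x w grid where True marks a masked feature-map position."""
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError("mask_ratio must be in [0, 1]")
    n = h * w
    num_masked = round(n * mask_ratio)
    rng = random.Random(seed)
    masked = set(rng.sample(range(n), num_masked))  # positions chosen without replacement
    return [[(i * w + j) in masked for j in range(w)] for i in range(h)]


mask = random_patch_mask(8, 8, mask_ratio=0.5, seed=0)
print(sum(cell for row in mask for cell in row))  # 32
```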
MLP
Bases: Module
Configurable MLP with optional classification head and penultimate-layer extraction.
Supports two modes:
- Projection mode (num_classes=None, default): maps embeddings to a projection space for contrastive loss. Output normalization is applied to the final layer via norm.
- Classification mode (num_classes set): adds a classification head (linear or cosine) on top of the backbone. Use encode() to extract L2-normalised penultimate-layer representations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dims | int | Input feature dimension. | required |
| hidden_dims | int \| list[int] | Hidden layer width. A single … | required |
| out_dims | int | Output dimension of the final projection layer (ignored in classification mode; the backbone output feeds directly into the classification head). | None |
| norm | Literal['bn', 'ln'] | Normalization applied after each hidden layer. | 'bn' |
| activation | Literal['relu', 'gelu', 'silu'] | Hidden activation function. | 'relu' |
| dropout | float | Dropout rate applied after each hidden layer. | 0.0 |
| num_classes | int or None | When set, adds a classification head and enables encode(). | None |
| cosine_classifier | bool | Use CosineClassifier instead of a linear layer for the classification head. | True |
encode(x)
Return L2-normalised penultimate-layer representations.
Only valid when num_classes was set at construction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape … | required |
Returns:
| Type | Description |
|---|---|
| Tensor | L2-normalised backbone output of shape … |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If called on a projection-mode MLP (num_classes=None). |
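encode() is only meaningful in classification mode. The plain-Python sketch below mirrors the documented contract. It raises RuntimeError when num_classes was not set, and otherwise L2-normalises the penultimate-layer (backbone) output. The class, its backbone callable, and the error message are hypothetical stand-ins for the real module.

```python
import math
from typing import Callable, List, Optional

Vector = List[float]


class ToyMLPEncodeContract:
    """Hypothetical stand-in that mimics the documented MLP.encode() contract."""

    def __init__(self, backbone: Callable[[Vector], Vector], num_classes: Optional[int] = None):
        self.backbone = backbone  # stands in for the hidden layers before the head
        self.num_classes = num_classes

    def encode(self, x: Vector) -> Vector:
        if self.num_classes is None:
            # Documented behaviour: projection-mode MLPs reject encode().
            raise RuntimeError("encode() is only valid when num_classes is set")
        h = self.backbone(x)
        norm = math.sqrt(sum(v * v for v in h)) or 1.0  # avoid dividing by zero
        return [v / norm for v in h]


def double(x: Vector) -> Vector:
    """Toy 'backbone' that scales its input by 2."""
    return [2.0 * v for v in x]


print(ToyMLPEncodeContract(double, num_classes=3).encode([3.0, 4.0]))  # [0.6, 0.8]
```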
forward(x)
Forward pass through backbone and optional head.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape … | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Projected or classified output. |
NTXentHCL
Bases: NTXentLoss
NT-Xent loss with hard-negative concentration and optional temperature schedule.
With beta=0.0, the loss produces results identical to the standard NTXentLoss. With beta > 0, it up-weights hard negatives (those with high cosine similarity to the anchor) in the denominator, focusing learning on difficult examples.
The HCL reweighting multiplies each negative pair's contribution in the denominator by exp(beta * sim(i, k)), concentrating gradient signal on negatives that are close to the anchor in embedding space.
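For a single anchor, the reweighted denominator described above can be written out directly. The sketch below assumes cosine similarities are already computed and applies the weight exp(beta * sim) to each negative term exactly as stated. Implementations (possibly including this class) may also normalise the weights or detach them from the gradient; that detail is not documented here. The function name is hypothetical.

```python
import math
from typing import Sequence


def ntxent_hcl_anchor_loss(
    pos_sim: float,
    neg_sims: Sequence[float],
    temperature: float = 0.07,
    beta: float = 0.5,
) -> float:
    """NT-Xent loss for one anchor, weighting each negative by exp(beta * sim)."""
    pos = math.exp(pos_sim / temperature)
    # Hard negatives (high similarity) receive larger weights when beta > 0.
    neg = sum(math.exp(beta * s) * math.exp(s / temperature) for s in neg_sims)
    return -math.log(pos / (pos + neg))


plain = ntxent_hcl_anchor_loss(0.9, [0.1, 0.5], beta=0.0)
hcl = ntxent_hcl_anchor_loss(0.9, [0.1, 0.5], beta=0.5)
print(hcl > plain)  # True
```

Setting beta=0.0 makes every weight exp(0) = 1, which recovers the standard NT-Xent term.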
Call step() at the start of each epoch to apply the temperature schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature scaling for cosine similarities. Default: 0.07. | 0.07 |
| beta | float | Hard-negative concentration strength. 0.0 = standard NT-Xent. Higher values concentrate more on hard negatives. Default: 0.5. | 0.5 |
| temperature_schedule | ('cosine', 'constant') | Inherited from … | "cosine" |
| temperature_start | float | Inherited from … | 0.1 |
| temperature_warmup_epochs | int | Inherited from … | 50 |
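The temperature_* parameters suggest a warmup from temperature_start (0.1) to the final temperature (0.07) over temperature_warmup_epochs (50), advanced by calling step() once per epoch. The exact curve is not documented here. The sketch below assumes a half-cosine interpolation that holds the final value after warmup, and a "constant" schedule that always returns the final temperature. The function name is hypothetical.

```python
import math


def scheduled_temperature(
    epoch: int,
    schedule: str = "cosine",
    start: float = 0.1,
    final: float = 0.07,
    warmup_epochs: int = 50,
) -> float:
    """Temperature at a given epoch under an assumed cosine warmup."""
    if schedule == "constant" or warmup_epochs <= 0:
        return final
    if schedule != "cosine":
        raise ValueError(f"unknown schedule: {schedule!r}")
    if epoch >= warmup_epochs:
        return final
    # Half-cosine from `start` (epoch 0) down to `final` (epoch == warmup_epochs).
    progress = epoch / warmup_epochs
    return final + 0.5 * (start - final) * (1.0 + math.cos(math.pi * progress))


print([round(scheduled_temperature(e), 4) for e in (0, 25, 50)])  # [0.1, 0.085, 0.07]
```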
OpenPhenomModel
Bases: Module
Wrap Recursion's OpenPhenom CA-MAE ViT-S/16 for microscopy images.
OpenPhenom accepts 1–11-channel uint8 input at 256×256 and normalises internally. preprocess_2d() handles the Z squeeze, resize, and float→uint8 conversion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_name | str | HuggingFace model identifier. | required |
| freeze | bool | If True, the backbone is frozen. | True |
forward(x)
Run the OpenPhenom backbone on a preprocessed image batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input of shape … | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] |  |
preprocess_2d(x)
Convert a raw dataloader tensor to uint8 input for OpenPhenom.
Handles squeezing a singleton Z dim, resizing to 256×256, and rescaling float values to [0, 255] uint8 (OpenPhenom normalises internally).
Unlike DINOv3, no channel manipulation is needed — OpenPhenom accepts 1–11 channels natively.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor |  | required |
Returns:
| Type | Description |
|---|---|
| Tensor |  |
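preprocess_2d rescales float values to uint8 in [0, 255] because OpenPhenom normalises internally. This page does not say how the float range is chosen. The sketch below assumes per-image min-max scaling followed by rounding and clamping. It leaves out the Z squeeze and the resize to 256×256, and the helper name is hypothetical.

```python
def float_to_uint8(pixels: list[float]) -> list[int]:
    """Min-max rescale floats to integers in [0, 255] (assumed strategy)."""
    lo, hi = min(pixels), max(pixels)
    span = hi - lo
    if span == 0:
        # A constant image carries no contrast; map it to zeros (assumption).
        return [0] * len(pixels)
    return [min(255, max(0, round(255 * (p - lo) / span))) for p in pixels]


print(float_to_uint8([-1.0, 0.0, 1.0]))  # [0, 128, 255]
```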
train(mode=True)
Override train() to keep the backbone in eval mode when frozen.
ResNet3dEncoder
Bases: Module
3D ResNet encoder network that uses MONAI's ResNetFeatures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| backbone | str | Name of the backbone model. | required |
| in_channels | int | Number of input channels. | 1 |
| embedding_dim | int | Embedded feature dimension that matches backbone output channels, by default 512 (ResNet-18). | 512 |
| projection_dim | int | Projection dimension for computing loss, by default 128. | 128 |
| pretrained | bool | Whether to load pretrained weights for the backbone, by default False. | False |
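embedding_dim must match the backbone's output channels. In standard ResNets, the basic-block variants (ResNet-10/18/34) end with 512 channels. The bottleneck variants (ResNet-50/101/152/200) end with 2048 channels because their block expansion is 4. MONAI's ResNet follows the same block design. The lookup below is a hypothetical convenience helper; the backbone strings it accepts are assumptions, not the names this class validates.

```python
# Final-stage channel counts for standard ResNet depths (block expansion 1 vs 4).
_BASIC_BLOCK_DEPTHS = {10, 18, 34}
_BOTTLENECK_DEPTHS = {50, 101, 152, 200}


def resnet_embedding_dim(backbone: str) -> int:
    """Expected embedding_dim for names like 'resnet18' (hypothetical helper)."""
    suffix = backbone[len("resnet"):]
    if not backbone.startswith("resnet") or not suffix.isdigit():
        raise ValueError(f"unrecognised backbone name: {backbone!r}")
    depth = int(suffix)
    if depth in _BASIC_BLOCK_DEPTHS:
        return 512
    if depth in _BOTTLENECK_DEPTHS:
        return 512 * 4
    raise ValueError(f"unsupported ResNet depth: {depth}")


print(resnet_embedding_dim("resnet18"))  # 512
print(resnet_embedding_dim("resnet50"))  # 2048
```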
forward(x)
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | The embedding tensor and the projection tensor. |
UNeXt2
Unet25d
Bases: Module
2.5D UNet for learning 3D-to-2D compression.
The architecture takes a stack of 2D inputs, given as a 3D tensor, and returns a 2D output. It learns 3D information from the input stack but speeds up training by compressing that information before the decoding path, using interruption conv layers in the UNet skip paths that convolve along the z dimension.
Reference: https://elifesciences.org/articles/55502
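The z compression in the skip paths can be reasoned about with simple shape arithmetic. An unpadded, stride-1 convolution with kernel depth k maps in_stack_depth slices to in_stack_depth - k + 1 slices. Choosing k = in_stack_depth - out_stack_depth + 1 therefore yields exactly out_stack_depth slices; with the defaults, a kernel depth of 5 collapses 5 slices to 1. This is a sketch under that assumption, since the kernel choice inside Unet25d is not stated here, and the helpers are hypothetical.

```python
def skip_kernel_depth(in_stack_depth: int, out_stack_depth: int) -> int:
    """Z kernel size that compresses in_stack_depth slices to out_stack_depth (no padding, stride 1)."""
    if not 1 <= out_stack_depth <= in_stack_depth:
        raise ValueError("need 1 <= out_stack_depth <= in_stack_depth")
    return in_stack_depth - out_stack_depth + 1


def z_after_conv(depth: int, kernel_depth: int) -> int:
    """Z size after an unpadded, stride-1 convolution along z."""
    return depth - kernel_depth + 1


k = skip_kernel_depth(5, 1)
print(k, z_after_conv(5, k))  # 5 1
```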
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input feature channels (1 or more). | 1 |
| out_channels | int | Number of output feature channels (1 or more). | 1 |
| in_stack_depth | int | Depth of input stack in z. | 5 |
| out_stack_depth | int | Depth of output stack. | 1 |
| xy_kernel_size | int or tuple of int | Size of x and y dimensions of conv kernels in blocks. | (3, 3) |
| residual | bool | Whether to use residual connections. | False |
| dropout | float | Probability of dropout, between 0 and 0.5. | 0.2 |
| num_blocks | int | Number of convolutional blocks on encoder and decoder paths. | 4 |
| num_block_layers | int | Number of layer sequences repeated per block. | 2 |
| num_filters | tuple of int | Filter counts at each conv block depth. | () |
| task | str | Network task, one of 'seg' or 'reg'. | 'seg' |
__name__()
Return model name.
forward(x)
Perform forward pass through the 2.5D UNet.
Call order:
- num_blocks 3D convolutional blocks, with downsampling in between (encoder)
- skip connections between corresponding blocks in encoder and decoder
- num_blocks 2D (3D with 1 z-channel) convolutional blocks, with upsampling between them (decoder)
- terminal block collapses to output dimensions
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor with compressed z-dimension. |
register_modules(module_list, name)
Register modules stored in a list to the model object.
Used to enable model graph creation with non-sequential model types and dynamic layer numbers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module_list | list of torch.nn.Module | List of modules to register. | required |
| name | str | Name of module type. | required |
Unet2d
Bases: Module
2D UNet with variable input/output channels and depth.
Follows the 2D UNet architecture:
1) UNet: https://arxiv.org/pdf/1505.04597.pdf
2) Residual UNet: https://arxiv.org/pdf/1711.10684.pdf
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input feature channels. | 1 |
| out_channels | int | Number of output feature channels. | 1 |
| kernel_size | int or tuple of int | Size of x and y dimensions of conv kernels in blocks. | (3, 3) |
| residual | bool | Whether to use residual connections. | False |
| dropout | float | Probability of dropout, between 0 and 0.5. | 0.2 |
| num_blocks | int | Number of convolutional blocks on encoder and decoder. | 4 |
| num_block_layers | int | Number of layers per block. | 2 |
| num_filters | tuple of int | Filter counts at each conv block depth. | () |
| task | str | Network task, one of 'seg' or 'reg'. | 'seg' |
__name__()
Return model name.
forward(x, validate_input=False)
Perform forward pass through the 2D UNet.
Call order:
- num_blocks 2D convolutional blocks, with downsampling in between (encoder)
- num_blocks 2D convolutional blocks, with upsampling between them (decoder)
- skip connections between corresponding blocks on encoder and decoder
- terminal block collapses to output dimensions
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input image tensor. | required |
| validate_input | bool | Deactivates assertions that are redundant if the forward pass is being traced by the TensorBoard writer. | False |
Returns:
| Type | Description |
|---|---|
| Tensor | Output tensor with same spatial dimensions as input. |
register_modules(module_list, name)
Register modules stored in a list to the model object.
Used to enable model graph creation with non-sequential model types and dynamic layer numbers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| module_list | list of torch.nn.Module | List of modules to register. | required |
| name | str | Name of module type. | required |
Unet3d
Bases: UNet3DBase
3D U-Net following Ounkomol et al. 2018 (F-Net).
FNet-configured preset of the unified 3D U-Net base with BatchNorm, ReLU activations, non-residual double-conv blocks, and a convolutional bottleneck. Downsamples all three spatial dimensions.
All spatial dimensions (Z, Y, X) must be divisible by 2**depth.
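Each of the depth levels halves Z, Y and X, so every spatial size must be a multiple of 2**depth; with the default depth of 4, that is 16. A small pre-flight check, such as the hypothetical helpers below, can catch a bad patch size before it reaches the model. Rounding down to the nearest valid size is one way to pick crop sizes. It is not behaviour this class is documented to provide.

```python
def check_unet3d_shape(zyx: tuple[int, int, int], depth: int = 4) -> None:
    """Raise ValueError if any spatial size is not divisible by 2**depth."""
    factor = 2 ** depth
    bad = [n for n in zyx if n % factor != 0]
    if bad:
        raise ValueError(f"sizes {bad} are not divisible by 2**{depth} = {factor}")


def largest_valid_size(n: int, depth: int = 4) -> int:
    """Largest size <= n that is divisible by 2**depth (useful for choosing crops)."""
    factor = 2 ** depth
    return (n // factor) * factor


check_unet3d_shape((32, 256, 256))  # passes with depth=4
print(largest_valid_size(100))  # 96
```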
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels. | 1 |
| out_channels | int | Number of output channels. | 1 |
| depth | int | Number of downsampling levels. | 4 |
| mult_chan | int | Base channel count at the first encoder level. | 32 |
| in_stack_depth | int or None | Z-window size. Stored for engine compatibility (…) | None |