A Simple Convolutional U-Net
In this tutorial, you will learn how to build a simple sparse convolutional neural network with fVDB.
A popular framework for implementing Sparse U-Nets is the MinkowskiEngine. The appendix below includes a step-by-step guide to porting MinkowskiEngine code to fVDB.
In this simple example, we build a residual U-Net (Res-UNet) with four encoder stages and four matching decoder stages; each stage consists of two residual blocks.
First, we import basic fvdb libraries:
import fvdb
import fvdb.nn as fvnn
from fvdb import ConvolutionPlan, GridBatch, JaggedTensor
import torch
Here fvdb.nn is a namespace analogous to torch.nn: it provides the sparse counterparts of common neural network layers, such as convolutions and batch normalization.
Every fvdb.nn layer takes explicit (data: JaggedTensor, plan_or_grid) arguments — topology and features are always passed separately. A ConvolutionPlan precomputes the necessary acceleration structures for a given grid and kernel configuration, and is passed alongside the data to convolution layers.
We could then build a basic block as follows:
class Downsample1x1(torch.nn.Module):
    """Project residual features to a new channel count.

    Applies a kernel-size-1, stride-1 sparse convolution followed by batch
    normalization. The grid topology is unchanged: the convolution plan uses
    the same grid as both source and target.
    """

    def __init__(self, in_ch, out_ch):
        """Create the projection.

        Args:
            in_ch: Number of input feature channels.
            out_ch: Number of output feature channels.
        """
        super().__init__()
        self.conv = fvnn.SparseConv3d(in_ch, out_ch, kernel_size=1, stride=1)
        self.bn = fvnn.BatchNorm(out_ch)

    def forward(self, data: JaggedTensor, grid: GridBatch) -> JaggedTensor:
        """Return the projected, batch-normalized features on ``grid``."""
        # Kernel size 1, stride 1, identical source and target grid.
        pointwise_plan = ConvolutionPlan.from_grid_batch(
            1, 1, source_grid=grid, target_grid=grid
        )
        projected = self.conv(data, pointwise_plan)
        return self.bn(projected, grid)
class BasicBlock(torch.nn.Module):
    """Residual block made of two 3x3x3 sparse convolutions.

    Each convolution is followed by batch normalization. The block input is
    added back to the result (optionally after ``downsample`` adapts it), and
    a final ReLU is applied to the sum.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, downsample=None, bn_momentum: float = 0.1):
        """Create the block.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output feature channels.
            downsample: Optional module called as ``downsample(data, grid)``
                to transform the residual before the addition.
            bn_momentum: Momentum passed to both batch-norm layers.
        """
        super().__init__()
        self.conv1 = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=3, stride=1)
        self.norm1 = fvnn.BatchNorm(out_channels, momentum=bn_momentum)
        self.conv2 = fvnn.SparseConv3d(out_channels, out_channels, kernel_size=3, stride=1)
        self.norm2 = fvnn.BatchNorm(out_channels, momentum=bn_momentum)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, data: JaggedTensor, plan: ConvolutionPlan) -> JaggedTensor:
        """Run the block; ``plan`` is used for both convolutions."""
        # Batch norm and the optional projection operate on the plan's target grid.
        target = plan.target_grid_batch

        hidden = self.conv1(data, plan)
        hidden = self.norm1(hidden, target)
        hidden = self.relu(hidden)

        hidden = self.conv2(hidden, plan)
        hidden = self.norm2(hidden, target)

        shortcut = data if self.downsample is None else self.downsample(data, target)
        return fvdb.relu(hidden + shortcut)
All the network layers are fully compatible with torch.nn. The key difference is that fvdb.nn layers take explicit (JaggedTensor, ConvolutionPlan) or (JaggedTensor, GridBatch) arguments instead of wrapping them in a carrier object.
A full network definition could then be built as:
class FVDBUNetBase(torch.nn.Module):
    """Sparse residual U-Net built from fVDB layers.

    The encoder applies a stem convolution at the input resolution and then
    four stages, each a stride-2 convolution followed by residual blocks.
    The decoder mirrors it with four stride-2 transposed convolutions; after
    each one the features are concatenated with the encoder features of the
    same grid and refined by residual blocks. A final kernel-size-1
    convolution maps to ``out_channels`` on the input grid.
    """

    # Number of residual blocks in each of the eight stages (4 encoder, 4 decoder).
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
    # Output channels of each stage, in the same order as LAYERS.
    CHANNELS = (32, 64, 128, 256, 256, 128, 96, 96)
    # Channels produced by the stem convolution.
    INIT_DIM = 32

    def __init__(self, in_channels: int, out_channels: int):
        """Create all layers.

        Args:
            in_channels: Number of input feature channels per voxel.
            out_channels: Number of output feature channels per voxel.
        """
        super().__init__()
        # ``self.inplanes`` tracks the channel count flowing into the next
        # layer; ``_make_layer`` updates it as stages are created.
        self.inplanes = self.INIT_DIM

        # --- Encoder ---
        self.conv0p1s1 = fvnn.SparseConv3d(in_channels, self.inplanes, kernel_size=5, stride=1, bias=False)
        self.bn0 = fvnn.BatchNorm(self.inplanes)

        self.conv1p1s2 = fvnn.SparseConv3d(self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn1 = fvnn.BatchNorm(self.inplanes)
        self.block1 = self._make_layer(BasicBlock, self.CHANNELS[0], self.LAYERS[0])

        self.conv2p2s2 = fvnn.SparseConv3d(self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn2 = fvnn.BatchNorm(self.inplanes)
        self.block2 = self._make_layer(BasicBlock, self.CHANNELS[1], self.LAYERS[1])

        self.conv3p4s2 = fvnn.SparseConv3d(self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn3 = fvnn.BatchNorm(self.inplanes)
        self.block3 = self._make_layer(BasicBlock, self.CHANNELS[2], self.LAYERS[2])

        self.conv4p8s2 = fvnn.SparseConv3d(self.inplanes, self.inplanes, kernel_size=2, stride=2, bias=False)
        self.bn4 = fvnn.BatchNorm(self.inplanes)
        self.block4 = self._make_layer(BasicBlock, self.CHANNELS[3], self.LAYERS[3])

        # --- Decoder ---
        # Decoder uses SparseConvTranspose3d (separate class). After each
        # upsampling, ``self.inplanes`` is set to the upsampled channels plus
        # the channels of the encoder features concatenated in ``forward``.
        self.convtr4p16s2 = fvnn.SparseConvTranspose3d(
            self.inplanes, self.CHANNELS[4], kernel_size=2, stride=2, bias=False)
        self.bntr4 = fvnn.BatchNorm(self.CHANNELS[4])
        self.inplanes = self.CHANNELS[4] + self.CHANNELS[2]
        self.block5 = self._make_layer(BasicBlock, self.CHANNELS[4], self.LAYERS[4])

        self.convtr5p8s2 = fvnn.SparseConvTranspose3d(
            self.inplanes, self.CHANNELS[5], kernel_size=2, stride=2, bias=False)
        self.bntr5 = fvnn.BatchNorm(self.CHANNELS[5])
        self.inplanes = self.CHANNELS[5] + self.CHANNELS[1]
        self.block6 = self._make_layer(BasicBlock, self.CHANNELS[5], self.LAYERS[5])

        self.convtr6p4s2 = fvnn.SparseConvTranspose3d(
            self.inplanes, self.CHANNELS[6], kernel_size=2, stride=2, bias=False)
        self.bntr6 = fvnn.BatchNorm(self.CHANNELS[6])
        self.inplanes = self.CHANNELS[6] + self.CHANNELS[0]
        self.block7 = self._make_layer(BasicBlock, self.CHANNELS[6], self.LAYERS[6])

        self.convtr7p2s2 = fvnn.SparseConvTranspose3d(
            self.inplanes, self.CHANNELS[7], kernel_size=2, stride=2, bias=False)
        self.bntr7 = fvnn.BatchNorm(self.CHANNELS[7])
        self.inplanes = self.CHANNELS[7] + self.INIT_DIM
        self.block8 = self._make_layer(BasicBlock, self.CHANNELS[7], self.LAYERS[7])

        self.final = fvnn.SparseConv3d(self.CHANNELS[7], out_channels, kernel_size=1)
        self.relu = torch.nn.ReLU(inplace=True)

    def _make_layer(self, block, planes, blocks):
        """Build one stage of residual blocks.

        Args:
            block: Block class to instantiate. It must expose an ``expansion``
                attribute and accept ``(in_channels, out_channels, downsample=...)``.
            planes: Base output channel count of the stage.
            blocks: Number of blocks in the stage.

        Returns:
            A ``torch.nn.ModuleList`` with ``blocks`` instances of ``block``.

        Side effect: ``self.inplanes`` is set to ``planes * block.expansion``.
        """
        out_planes = planes * block.expansion

        # Only the first block can change the channel count, so only it needs
        # a projection on the residual path.
        downsample = None
        if self.inplanes != out_planes:
            downsample = Downsample1x1(self.inplanes, out_planes)

        # Instantiate the requested ``block`` class rather than a fixed one so
        # the argument is actually honored.
        layers = [block(self.inplanes, planes, downsample=downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return torch.nn.ModuleList(layers)

    def _run_block(self, block, data, plan):
        """Apply every layer of ``block`` in order, passing ``plan`` to each."""
        for layer in block:
            data = layer(data, plan)
        return data

    def forward(self, features: JaggedTensor, grid: GridBatch) -> JaggedTensor:
        """Run the U-Net.

        Args:
            features: Per-voxel input features for ``grid``.
            grid: Grid batch describing the input topology.

        Returns:
            Per-voxel features with ``out_channels`` channels on ``grid``.
        """
        # The plans below are created from (kernel_size, stride, grids) only,
        # and one plan is already shared by layers of different widths within
        # a stage. The 3x3x3 plan of each resolution is therefore built once
        # in the encoder and reused by the decoder stage on the same grid.

        # --- Encoder ---
        # stride=1: initial convolution
        grid1 = grid
        plan0 = ConvolutionPlan.from_grid_batch(5, 1, source_grid=grid1, target_grid=grid1)
        out = self.relu(self.bn0(self.conv0p1s1(features, plan0), grid1))
        out_p1 = out

        # stride=2 downsample; the plan derives the coarser target grid.
        plan_d1 = ConvolutionPlan.from_grid_batch(2, 2, source_grid=grid1)
        grid2 = plan_d1.target_grid_batch
        out = self.relu(self.bn1(self.conv1p1s2(out, plan_d1), grid2))
        plan_s1 = ConvolutionPlan.from_grid_batch(3, 1, source_grid=grid2, target_grid=grid2)
        out_b1p2 = self._run_block(self.block1, out, plan_s1)

        plan_d2 = ConvolutionPlan.from_grid_batch(2, 2, source_grid=grid2)
        grid4 = plan_d2.target_grid_batch
        out = self.relu(self.bn2(self.conv2p2s2(out_b1p2, plan_d2), grid4))
        plan_s2 = ConvolutionPlan.from_grid_batch(3, 1, source_grid=grid4, target_grid=grid4)
        out_b2p4 = self._run_block(self.block2, out, plan_s2)

        plan_d3 = ConvolutionPlan.from_grid_batch(2, 2, source_grid=grid4)
        grid8 = plan_d3.target_grid_batch
        out = self.relu(self.bn3(self.conv3p4s2(out_b2p4, plan_d3), grid8))
        plan_s3 = ConvolutionPlan.from_grid_batch(3, 1, source_grid=grid8, target_grid=grid8)
        out_b3p8 = self._run_block(self.block3, out, plan_s3)

        plan_d4 = ConvolutionPlan.from_grid_batch(2, 2, source_grid=grid8)
        grid16 = plan_d4.target_grid_batch
        out = self.relu(self.bn4(self.conv4p8s2(out_b3p8, plan_d4), grid16))
        plan_s4 = ConvolutionPlan.from_grid_batch(3, 1, source_grid=grid16, target_grid=grid16)
        out = self._run_block(self.block4, out, plan_s4)

        # --- Decoder ---
        # Transposed convolutions use from_grid_batch_transposed with target grids from encoder
        plan_u4 = ConvolutionPlan.from_grid_batch_transposed(2, 2, source_grid=grid16, target_grid=grid8)
        out = self.relu(self.bntr4(self.convtr4p16s2(out, plan_u4), grid8))
        out = fvdb.jcat([out, out_b3p8], dim=1)
        out = self._run_block(self.block5, out, plan_s3)  # same grid8 plan as block3

        plan_u5 = ConvolutionPlan.from_grid_batch_transposed(2, 2, source_grid=grid8, target_grid=grid4)
        out = self.relu(self.bntr5(self.convtr5p8s2(out, plan_u5), grid4))
        out = fvdb.jcat([out, out_b2p4], dim=1)
        out = self._run_block(self.block6, out, plan_s2)  # same grid4 plan as block2

        plan_u6 = ConvolutionPlan.from_grid_batch_transposed(2, 2, source_grid=grid4, target_grid=grid2)
        out = self.relu(self.bntr6(self.convtr6p4s2(out, plan_u6), grid2))
        out = fvdb.jcat([out, out_b1p2], dim=1)
        out = self._run_block(self.block7, out, plan_s1)  # same grid2 plan as block1

        plan_u7 = ConvolutionPlan.from_grid_batch_transposed(2, 2, source_grid=grid2, target_grid=grid1)
        out = self.relu(self.bntr7(self.convtr7p2s2(out, plan_u7), grid1))
        out = fvdb.jcat([out, out_p1], dim=1)
        # The encoder has no 3x3x3 stage at full resolution (the stem uses a
        # 5x5x5 kernel), so this plan has to be built here.
        plan_s8 = ConvolutionPlan.from_grid_batch(3, 1, source_grid=grid1, target_grid=grid1)
        out = self._run_block(self.block8, out, plan_s8)

        plan_final = ConvolutionPlan.from_grid_batch(1, 1, source_grid=grid1, target_grid=grid1)
        return self.final(out, plan_final)
Please note that here, when we apply strided transposed convolution layers, we build a ConvolutionPlan using from_grid_batch_transposed with the encoder-side grid as the target_grid.
This is needed to guide the output domain of the network, because for perception networks, the output grid topology should align with the input topology.
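For clarity, the example above constructs its ConvolutionPlans inside forward, so they are rebuilt on every call. In practice, ConvolutionPlans should be built once for a given grid and reused across forward passes for efficiency.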
Note: fVDB also provides a built-in `fvdb.nn.SimpleUNet` that implements this pattern with plan caching and other optimizations. See `fvdb/nn/simple_unet.py` for the reference implementation.
To perform inference with the network, pass the features and grid explicitly:
device = 'cuda'

# Two random point sets of different sizes, wrapped in a single JaggedTensor.
point_sets = [
    torch.randn(10_000, 3, device=device),
    torch.randn(11_000, 3, device=device),
]
coords = fvdb.JaggedTensor(point_sets)

# Build the sparse grid topology from the points.
grid = fvdb.GridBatch.from_points(coords)

# One random 32-channel feature vector per voxel, matching the model's input width.
voxel_features = torch.randn(grid.total_voxels, 32, device=device)
features = grid.jagged_like(voxel_features)

model = FVDBUNetBase(32, 1).to(device)
output = model(features, grid)
The output will carry gradients during training, and you could train the sparse network accordingly.
You can find an interactive example for training a sparse convolutional UNet in the Convolution Lesson jupyter notebook available in the fvdb-core repository.
Appendix: Porting from MinkowskiEngine
If you were previously using MinkowskiEngine to tackle sparse 3D data, this section provides a step-by-step comparison to help you smoothly transfer your code to fVDB and enjoy speed-ups and memory savings.
BasicBlock comparison
The fVDB BasicBlock defined above corresponds to the following MinkowskiEngine equivalent:
import torch
import MinkowskiEngine as ME
class BasicBlock(torch.nn.Module):
    """MinkowskiEngine residual block with two 3x3x3 convolutions.

    Each convolution is followed by batch normalization. The block input is
    added back to the result (optionally after ``downsample`` adapts it), and
    a final ReLU is applied to the sum.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, downsample=None, bn_momentum: float = 0.1):
        """Create the block.

        Args:
            in_channels: Number of input feature channels.
            out_channels: Number of output feature channels.
            downsample: Optional module called as ``downsample(x)`` to
                transform the residual before the addition.
            bn_momentum: Momentum passed to both batch-norm layers.
        """
        super().__init__()
        # Both convolutions share the same spatial configuration.
        conv_config = dict(kernel_size=3, stride=1, dilation=1, dimension=3)
        self.conv1 = ME.MinkowskiConvolution(in_channels, out_channels, **conv_config)
        self.norm1 = ME.MinkowskiBatchNorm(out_channels, momentum=bn_momentum)
        self.conv2 = ME.MinkowskiConvolution(out_channels, out_channels, **conv_config)
        self.norm2 = ME.MinkowskiBatchNorm(out_channels, momentum=bn_momentum)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Run the block on the sparse tensor ``x`` and return the result."""
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        residual = x if self.downsample is None else self.downsample(x)

        # Keep the in-place accumulation used by the reference implementation.
        out += residual
        return self.relu(out)
The key difference is that fVDB layers take explicit (JaggedTensor, ConvolutionPlan) or (JaggedTensor, GridBatch) arguments instead of wrapping them in a carrier object, while all the network layers remain fully compatible with torch.nn.