JaggedTensor: Efficient Batching for Variable-Length Data
This tutorial introduces JaggedTensor, fVDB’s data structure for efficiently handling batches of tensors with varying sizes. We’ll explore how to create, manipulate, and perform operations on jagged tensors, with special emphasis on the type-safe PyTorch function overloads.
What is a JaggedTensor?
In many machine learning applications, we need to process batches of data where each item has a different number of elements. For example:
Point clouds with different numbers of points per scene
Graphs with different numbers of nodes
Sequences with different lengths
Sparse voxel grids with different numbers of active voxels
A JaggedTensor is conceptually a list of tensors with varying first dimensions: [N_0, *], [N_1, *], ..., [N_{B-1}, *] where B is the batch size, N_i is the number of elements in the i-th batch item, and * represents additional dimensions that match across all tensors.

Internally, JaggedTensor concatenates these tensors into a single flat jdata tensor of shape [N_0 + N_1 + ... + N_{B-1}, *] for efficient GPU processing. It maintains indexing structures (jidx, joffsets) to track batch boundaries.
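The layout described above can be sketched in plain Python. This is only an illustration of the bookkeeping, not fVDB's implementation: nested lists stand in for tensors, and the helper name `flatten_jagged` is hypothetical.

```python
from itertools import accumulate

def flatten_jagged(tensors):
    """Concatenate per-item rows and record batch boundaries.

    Hypothetical helper: plain lists stand in for tensors.
    """
    sizes = [len(t) for t in tensors]
    # joffsets-style boundaries: item i occupies rows offsets[i]:offsets[i + 1]
    offsets = [0] + list(accumulate(sizes))
    # jidx-style labels: one batch index per row of the flat data
    idx = [i for i, n in enumerate(sizes) for _ in range(n)]
    # jdata-style storage: all rows back to back
    data = [row for t in tensors for row in t]
    return data, idx, offsets

batch = [
    [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],                     # item 0: 2 rows
    [[2.0, 2.0, 2.0]],                                      # item 1: 1 row
    [[3.0, 3.0, 3.0], [4.0, 4.0, 4.0], [5.0, 5.0, 5.0]],    # item 2: 3 rows
]
data, idx, offsets = flatten_jagged(batch)
print(len(data))   # 6
print(idx)         # [0, 0, 1, 2, 2, 2]
print(offsets)     # [0, 2, 3, 6]
```

Both index structures describe the same boundaries: the per-row labels are convenient for scatter-style operations, while the offsets make slicing out one item cheap.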

Creating JaggedTensors
From a List of Tensors
The most common way to create a JaggedTensor is from a list of PyTorch tensors:
import torch
import fvdb
# Create three tensors with different first dimensions
t0 = torch.randn(100, 3) # 100 points, 3D
t1 = torch.randn(150, 3) # 150 points, 3D
t2 = torch.randn(120, 3) # 120 points, 3D
# Create a JaggedTensor containing all three
jt = fvdb.JaggedTensor([t0, t1, t2])
print(f"Number of tensors: {jt.num_tensors}") # 3
print(f"Total elements: {jt.jdata.shape[0]}") # 370 (100 + 150 + 120)
print(f"Element shape: {jt.rshape}") # (3,)
Using Factory Functions
fVDB provides factory functions similar to PyTorch’s tensor creation functions:
# Create jagged tensor filled with random values
jt_rand = fvdb.JaggedTensor.from_rand(lsizes=[100, 150, 120], rsizes=[3])
# lsizes: list of sizes for each tensor in the batch
# rsizes: shape of the regular (non-jagged) dimensions
# Create jagged tensor filled with zeros
jt_zeros = fvdb.JaggedTensor.from_zeros(lsizes=[100, 150, 120], rsizes=[3], device="cuda")
# Create jagged tensor filled with ones
jt_ones = fvdb.JaggedTensor.from_ones(lsizes=[100, 150, 120], rsizes=[3])
# Create jagged tensor filled with normal random values
jt_randn = fvdb.JaggedTensor.from_randn(lsizes=[100, 150, 120], rsizes=[3])
From Flat Data and Indices
If you already have flattened data and indexing information:
# Flattened data: 370 total elements
data = torch.randn(370, 3)
# Indices indicating which batch each element belongs to
indices = torch.tensor([0]*100 + [1]*150 + [2]*120)
# Create JaggedTensor
jt = fvdb.JaggedTensor.from_data_and_indices(data, indices, num_tensors=3)
Or using offsets:
# Offsets marking boundaries: [0, 100, 250, 370]
offsets = torch.tensor([0, 100, 250, 370])
jt = fvdb.JaggedTensor.from_data_and_offsets(data, offsets)
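The two descriptions carry the same information, which a small pure-Python sketch can make concrete. The helper names below are hypothetical and lists stand in for tensors; the sketch assumes the indices are sorted, as in the example above.

```python
from collections import Counter
from itertools import accumulate

def indices_to_offsets(indices, num_tensors):
    """Convert sorted per-row batch indices into boundary offsets."""
    counts = Counter(indices)
    sizes = [counts.get(i, 0) for i in range(num_tensors)]
    return [0] + list(accumulate(sizes))

def offsets_to_indices(offsets):
    """Expand boundary offsets back into one batch index per row."""
    return [i for i in range(len(offsets) - 1)
            for _ in range(offsets[i + 1] - offsets[i])]

indices = [0] * 100 + [1] * 150 + [2] * 120
offsets = indices_to_offsets(indices, num_tensors=3)
print(offsets)  # [0, 100, 250, 370]
assert offsets_to_indices(offsets) == indices
```

The sketch also suggests why an explicit tensor count is useful alongside indices: an empty tensor contributes no rows, so it cannot be recovered from the index list alone.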
Accessing JaggedTensor Data
The Underlying Data: jdata
The jdata property provides access to the flattened concatenated tensor:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Access the underlying flattened data
print(jt.jdata.shape) # torch.Size([250, 3])
# Modify the data directly
jt.jdata *= 2.0
Indexing and Iteration
You can index into a JaggedTensor to extract individual tensors:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3), torch.randn(120, 3)])
# Access the first tensor (returns a JaggedTensor containing one tensor)
first = jt[0]
print(first.jdata.shape) # torch.Size([100, 3])
print(first.num_tensors) # 1
# Slice to get a subset
subset = jt[1:3] # Get tensors 1 and 2
print(subset.num_tensors) # 2
# Iterate over tensors
for i, tensor in enumerate(jt):
    print(f"Tensor {i}: {tensor.jdata.shape}")
Unbinding to Regular Tensors
To convert a JaggedTensor back to a list of regular PyTorch tensors:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Get list of individual tensors
tensors = jt.unbind()
print(len(tensors)) # 2
print(tensors[0].shape) # torch.Size([100, 3])
print(tensors[1].shape) # torch.Size([150, 3])
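Conceptually, unbinding amounts to slicing the flat data at the stored boundaries. A minimal sketch, with a hypothetical helper `unbind_flat` and plain lists in place of tensors:

```python
def unbind_flat(data, offsets):
    """Split flat rows back into one list per batch item (hypothetical helper)."""
    return [data[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]

data = list(range(6))      # six flat "rows"
offsets = [0, 2, 3, 6]     # three items of sizes 2, 1, 3
parts = unbind_flat(data, offsets)
print(parts)  # [[0, 1], [2], [3, 4, 5]]
```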
PyTorch Function Overloads: torch.* vs fvdb.*
A key feature of JaggedTensor is its integration with PyTorch’s function dispatch system through __torch_function__. This allows many PyTorch functions to work directly with JaggedTensor at runtime. However, static type checkers cannot understand this dynamic dispatch mechanism.
The solution: fVDB provides type-safe wrappers in the fvdb namespace that work with both JaggedTensor and regular Tensor, with proper type inference.
The Type Safety Problem
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# This works at RUNTIME due to __torch_function__
result1 = torch.relu(jt) # type: ignore
# BUT: Type checker sees result1 as Tensor, not JaggedTensor
# You need type: ignore to suppress the warning
# This is type-safe: it works at runtime AND passes static type checking
result2 = fvdb.relu(jt)
# Type checker correctly infers result2 as JaggedTensor
Both approaches produce identical results at runtime, but fvdb.relu() provides proper type safety for static analysis tools like mypy and pyright.
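One common way for a wrapper to give type checkers this information is `typing.overload`. The sketch below shows the idea with hypothetical stand-in classes (`Dense`, `Jagged`); it does not claim to show how fVDB declares its wrappers.

```python
from typing import overload

class Dense:
    """Hypothetical stand-in for torch.Tensor."""
    def __init__(self, values):
        self.values = list(values)

class Jagged:
    """Hypothetical stand-in for JaggedTensor."""
    def __init__(self, values, offsets):
        self.values = list(values)
        self.offsets = list(offsets)

@overload
def relu(x: Dense) -> Dense: ...
@overload
def relu(x: Jagged) -> Jagged: ...

def relu(x):
    # One runtime implementation; the overloads above only inform the type checker.
    clipped = [max(v, 0.0) for v in x.values]
    if isinstance(x, Jagged):
        return Jagged(clipped, x.offsets)
    return Dense(clipped)

out = relu(Jagged([-1.0, 2.0, -3.0], offsets=[0, 1, 3]))
print(type(out).__name__, out.values)  # Jagged [0.0, 2.0, 0.0]
```

With declarations like these, a checker such as mypy or pyright infers `Jagged` for `out` without any `type: ignore` comment.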
Why Use fvdb.* Functions?
Type Safety: Static type checkers correctly infer return types
IDE Support: Better autocomplete and inline documentation
Code Clarity: Explicit about working with both Tensor types
Zero Overhead: Direct delegation to PyTorch functions
Elementwise Operations
Elementwise operations apply to each element independently, preserving the jagged structure:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Activation functions
activated = fvdb.relu(jt) # ReLU activation
sigmoid_out = fvdb.sigmoid(jt) # Sigmoid activation
tanh_out = fvdb.tanh(jt) # Tanh activation
# Mathematical operations
sqrt_out = fvdb.sqrt(jt.abs()) # Square root (on absolute values)
exp_out = fvdb.exp(jt) # Exponential
log_out = fvdb.log(jt.abs()) # Natural logarithm
# Rounding operations
floor_out = fvdb.floor(jt) # Round down
ceil_out = fvdb.ceil(jt) # Round up
round_out = fvdb.round(jt) # Round to nearest
# Utility operations
clamped = fvdb.clamp(jt, min=-1.0, max=1.0) # Clamp values
safe = fvdb.nan_to_num(jt) # Replace NaN/inf with numbers
In-Place Operations
Some operations have in-place variants (ending with _):
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# In-place ReLU (modifies jt directly)
fvdb.relu_(jt)
# Compare with torch.* (works but needs type: ignore)
torch.relu_(jt) # type: ignore
Binary Operations
Binary operations work between two JaggedTensors, or between a JaggedTensor and a scalar:
jt1 = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
jt2 = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Arithmetic operations with scalars
scaled = fvdb.mul(jt1, 2.0) # Multiply by scalar
shifted = fvdb.add(jt1, 1.0) # Add scalar
divided = fvdb.true_divide(jt1, 3.0) # Divide by scalar
# Arithmetic operations between JaggedTensors
sum_jt = fvdb.add(jt1, jt2) # Elementwise addition
diff_jt = fvdb.sub(jt1, jt2) # Elementwise subtraction
prod_jt = fvdb.mul(jt1, jt2) # Elementwise multiplication
quot_jt = fvdb.true_divide(jt1, jt2) # Elementwise division
# Power and other operations
powered = fvdb.pow(jt1, 2.0) # Square each element
remainder = fvdb.remainder(jt1, 2.0) # Modulo operation
# Element-wise min/max
max_jt = fvdb.maximum(jt1, jt2) # Elementwise maximum
min_jt = fvdb.minimum(jt1, jt2) # Elementwise minimum
Comparison Operations
Comparison operations return boolean JaggedTensors:
jt1 = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
jt2 = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Comparisons with scalars
gt_zero = fvdb.gt(jt1, 0.0) # Greater than
ge_zero = fvdb.ge(jt1, 0.0) # Greater than or equal
lt_zero = fvdb.lt(jt1, 0.0) # Less than
le_zero = fvdb.le(jt1, 0.0) # Less than or equal
eq_zero = fvdb.eq(jt1, 0.0) # Equal to
ne_zero = fvdb.ne(jt1, 0.0) # Not equal to
# Comparisons between JaggedTensors
gt_mask = fvdb.gt(jt1, jt2) # Elementwise comparison
eq_mask = fvdb.eq(jt1, jt2) # Elementwise equality
# Conditional selection with where
# Select from jt1 where mask is True, otherwise from jt2
selected = fvdb.where(gt_mask, jt1, jt2)
Reduction Operations
Reductions that preserve the leading (flattened) dimension work seamlessly:
# Create JaggedTensor with shape [100, 10, 3] and [150, 10, 3]
jt = fvdb.JaggedTensor([torch.randn(100, 10, 3), torch.randn(150, 10, 3)])
print(jt.jdata.shape) # torch.Size([250, 10, 3])
# Reduce over non-primary dimensions
# These preserve the jagged structure (batch boundaries)
summed = fvdb.sum(jt, dim=-1) # Sum over last dimension
print(summed.jdata.shape) # torch.Size([250, 10])
print(summed.num_tensors) # 2 (structure preserved!)
mean_val = fvdb.mean(jt, dim=1) # Mean over dimension 1
print(mean_val.jdata.shape) # torch.Size([250, 3])
max_val = fvdb.amax(jt, dim=-1) # Max over last dimension
min_val = fvdb.amin(jt, dim=-1) # Min over last dimension
# Standard deviation and variance
std_val = fvdb.std(jt, dim=1) # Standard deviation
var_val = fvdb.var(jt, dim=1) # Variance
# Norms
l2_norm = fvdb.norm(jt, p=2, dim=-1) # L2 norm over last dimension
l1_norm = fvdb.norm(jt, p=1, dim=-1) # L1 norm
# Logical reductions
all_positive = fvdb.all(fvdb.gt(jt, 0.0), dim=-1) # All elements > 0?
any_positive = fvdb.any(fvdb.gt(jt, 0.0), dim=-1) # Any element > 0?
# Argmax and argmin
max_indices = fvdb.argmax(jt, dim=-1) # Indices of max values
min_indices = fvdb.argmin(jt, dim=-1) # Indices of min values
Important: These reductions only work on non-primary dimensions. To reduce over the jagged dimension itself (collapsing different-length tensors), use the specialized j* methods described below.
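The reason non-primary reductions keep the jagged structure can be seen in a small pure-Python sketch (lists stand in for tensors; `sum_last_dim` is a hypothetical helper). Each flat row is reduced on its own, so the row count, and therefore the batch boundaries, stay valid.

```python
def sum_last_dim(rows):
    """Reduce each flat row independently; the number of rows is unchanged."""
    return [sum(row) for row in rows]

rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # flat data for two items
offsets = [0, 1, 3]                            # item 0 has 1 row, item 1 has 2

reduced = sum_last_dim(rows)
print(reduced)                       # [3.0, 7.0, 11.0]
print(len(reduced) == offsets[-1])   # True: the same offsets still describe the batch
```

A reduction over the leading dimension would instead mix rows from different batch items, which is why it needs boundary-aware handling.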
Chaining Operations
All fvdb functions return the same type as their input, enabling seamless chaining:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Chain multiple operations (all type-safe!)
result = fvdb.sigmoid(
    fvdb.add(
        fvdb.relu(jt),
        1.0
    )
)
# Or more readable step-by-step:
activated = fvdb.relu(jt) # Apply ReLU
shifted = fvdb.add(activated, 1.0) # Add 1.0
normalized = fvdb.sigmoid(shifted) # Apply sigmoid
# Type checker knows the type at each step
print(type(activated)) # JaggedTensor
print(type(shifted)) # JaggedTensor
print(type(normalized)) # JaggedTensor
Jagged-Specific Operations: j* Methods
While PyTorch functions preserve the jagged structure, sometimes you need to operate directly on the jagged dimensions. fVDB provides specialized j* methods for this:
Jagged Sum: jsum()
Sum along the jagged dimension to reduce varying-length tensors:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3), torch.randn(120, 3)])
print(jt.jdata.shape) # torch.Size([370, 3])
# Sum each tensor in the batch (reduces jagged dimension)
summed = jt.jsum(dim=0)
print(summed.jdata.shape) # torch.Size([3, 3])
print(summed.num_tensors) # 3
# Each row is the sum over one tensor in the batch
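A jagged sum of this kind is a segmented reduction: rows are summed only within their own batch boundaries. The following pure-Python sketch illustrates the idea with a hypothetical helper `segment_sum`; it is not fVDB's implementation.

```python
def segment_sum(rows, offsets):
    """Sum rows column-wise within each batch item (hypothetical helper)."""
    width = len(rows[0]) if rows else 0
    out = []
    for start, stop in zip(offsets[:-1], offsets[1:]):
        segment = rows[start:stop]
        out.append([sum(r[c] for r in segment) for c in range(width)])
    return out

rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
offsets = [0, 1, 3]   # item 0 owns row 0, item 1 owns rows 1 and 2
print(segment_sum(rows, offsets))  # [[1.0, 2.0], [8.0, 10.0]]
```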
Jagged Max/Min: jmax(), jmin()
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Max along jagged dimension (returns values and indices)
max_values, max_indices = jt.jmax(dim=0)
print(max_values.jdata.shape) # torch.Size([2, 3])
print(max_indices.jdata.shape) # torch.Size([2, 3])
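A segmented maximum follows the same pattern as the jagged sum. The sketch below is pure Python with a hypothetical helper `segment_max`, and it assumes every segment is non-empty. It reports each index as a position within its own segment; whether fVDB reports positions relative to each tensor or to the flat data is not stated here, so treat that choice as an assumption.

```python
def segment_max(values, offsets):
    """Return (max value, position within the segment) for each batch item."""
    results = []
    for start, stop in zip(offsets[:-1], offsets[1:]):
        segment = values[start:stop]
        # Position of the largest value inside this segment only
        best = max(range(len(segment)), key=segment.__getitem__)
        results.append((segment[best], best))
    return results

values = [0.5, 2.0, -1.0, 3.0, 1.5]
offsets = [0, 2, 5]   # item 0 owns values[0:2], item 1 owns values[2:5]
print(segment_max(values, offsets))  # [(2.0, 1), (3.0, 1)]
```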
Jagged Reshape: jreshape()
Reshape the jagged structure:
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# Reshape to different jagged sizes
# Note: Total number of elements must match
reshaped = jt.jreshape(lshape=[50, 50, 75, 75]) # Split into 4 tensors
print(reshaped.num_tensors) # 4
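The size constraint in the note above can be checked by hand. This sketch assumes a jagged reshape only regroups rows and leaves the flat data untouched (an assumption, not something confirmed here); `reshape_offsets` is a hypothetical helper.

```python
from itertools import accumulate

def reshape_offsets(old_offsets, new_sizes):
    """Compute boundaries for a new split of the same flat rows."""
    total = old_offsets[-1]
    if sum(new_sizes) != total:
        raise ValueError(f"new sizes sum to {sum(new_sizes)}, expected {total}")
    return [0] + list(accumulate(new_sizes))

old_offsets = [0, 100, 250]                      # sizes 100 and 150
new_offsets = reshape_offsets(old_offsets, [50, 50, 75, 75])
print(new_offsets)  # [0, 50, 100, 175, 250]
```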
Jagged Flatten: jflatten()
Flatten nested jagged structures:
# Create nested jagged structure
nested = fvdb.JaggedTensor([[torch.randn(10, 3), torch.randn(20, 3)],
                            [torch.randn(15, 3)]])
print(nested.ldim) # 2 (two levels of jagging)
# Flatten to single level
flattened = nested.jflatten()
print(flattened.ldim) # 1
print(flattened.num_tensors) # 3
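In terms of per-tensor sizes, flattening one level of nesting simply merges the inner lists in order. A minimal pure-Python sketch (hypothetical helper, sizes standing in for the tensors themselves):

```python
def flatten_one_level(nested):
    """Merge a list of lists of items into a single list of items."""
    return [item for group in nested for item in group]

nested_sizes = [[10, 20], [15]]       # two levels of jaggedness
flat_sizes = flatten_one_level(nested_sizes)
print(flat_sizes)       # [10, 20, 15]
print(len(flat_sizes))  # 3
```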
Complete Example: Processing Variable-Length Point Clouds
Here’s a complete example demonstrating JaggedTensor usage for batch processing point clouds:
# Load or generate point clouds with different numbers of points
# (simulating 3 different scenes)
points_scene1 = torch.randn(1523, 3, device="cuda") # Scene 1: 1523 points
points_scene2 = torch.randn(2847, 3, device="cuda") # Scene 2: 2847 points
points_scene3 = torch.randn(1102, 3, device="cuda") # Scene 3: 1102 points
# Create JaggedTensor for batch processing
jt_points = fvdb.JaggedTensor([points_scene1, points_scene2, points_scene3])
print(f"Batched {jt_points.num_tensors} point clouds")
print(f"Total points: {jt_points.jdata.shape[0]}")
# Generate per-point features (simulating a neural network output)
features = fvdb.JaggedTensor.from_randn(lsizes=jt_points.lshape, rsizes=[64], device="cuda")
# Apply a series of transformations
# Step 1: Apply ReLU activation
features = fvdb.relu(features)
# Step 2: L2 normalize each feature vector
feature_norms = fvdb.norm(features, p=2, dim=-1, keepdim=True)
features = fvdb.true_divide(features, fvdb.add(feature_norms, 1e-8))
# Step 3: Apply attention-like scaling
# Compute one attention score per point (sum over the feature dimension)
attention_scores = fvdb.sigmoid(fvdb.sum(features, dim=-1, keepdim=True))
features = fvdb.mul(features, attention_scores)
# Step 4: Aggregate features per scene
# Sum features across all points in each scene
scene_features = features.jsum(dim=0)
print(f"Scene features: {scene_features.jdata.shape}") # [3, 64]
# Step 5: Apply final transformation
scene_features = fvdb.relu(scene_features)
# Convert back to list if needed
feature_list = scene_features.unbind()
print(f"Scene 1 features: {feature_list[0].shape}") # [64]
print(f"Scene 2 features: {feature_list[1].shape}") # [64]
print(f"Scene 3 features: {feature_list[2].shape}") # [64]
When to Use torch.* vs fvdb.*
Here’s a decision guide:
Use fvdb.* functions when:
Working with JaggedTensor and you want type safety
Building typed APIs or libraries
Using static type checkers (mypy, pyright)
You want better IDE autocomplete and documentation
Use torch.* functions when:
Working only with regular torch.Tensor
Type safety is not a concern
You’re okay with adding type: ignore comments
Remember: Both approaches are identical at runtime for supported operations. The difference is purely in static type checking.
Common Patterns and Best Practices
Pattern 1: Batch Processing with Variable Sizes
def process_batch(data_list):
    """Process a batch of variable-size data efficiently."""
    # Bundle into JaggedTensor for efficient GPU processing
    jt = fvdb.JaggedTensor(data_list)
    # Apply transformations (each op runs once over the flat jdata, not once per tensor)
    jt = fvdb.relu(jt)
    jt = fvdb.mul(jt, 2.0)
    # Unbind if you need individual results
    return jt.unbind()
Pattern 2: Safe Operations with Type Checking
def safe_activation(jt: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Apply activation with proper type hints."""
    # Type checker knows return type is JaggedTensor
    activated = fvdb.relu(jt)
    normalized = fvdb.sigmoid(activated)
    return normalized  # Type-safe!
Pattern 3: Mixing Regular and Jagged Dimensions
def process_with_reductions(jt: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
    """Reduce over regular dimensions, preserve jagged structure."""
    # Input: [N_0, D, C] and [N_1, D, C] and ...
    # where N_i varies (jagged), D and C are fixed (regular)
    # Mean over D dimension (preserves jagged structure)
    reduced = fvdb.mean(jt, dim=1)  # Now shape: [N_0, C] and [N_1, C] and ...
    # Apply activation
    activated = fvdb.relu(reduced)
    return activated
Pattern 4: Avoiding Invalid Operations
jt = fvdb.JaggedTensor([torch.randn(100, 3), torch.randn(150, 3)])
# WRONG: Cannot reduce over primary dimension with torch/fvdb functions
# result = fvdb.sum(jt) # This would fail!
# result = fvdb.sum(jt, dim=0) # This would also fail!
# RIGHT: Use j* methods for jagged dimension operations
result = jt.jsum(dim=0) # Correctly reduces jagged dimension
Integration with GridBatch
JaggedTensor is the foundation for GridBatch operations in fVDB. When working with sparse voxel grids, you’ll frequently encounter jagged tensors as inputs and outputs:
# Create point clouds (using JaggedTensor)
points = fvdb.JaggedTensor([torch.randn(1000, 3), torch.randn(1500, 3)])
# Build grids from points
grids = fvdb.GridBatch.from_points(points, voxel_sizes=0.1)
# Grid operations return JaggedTensors
voxel_coords = grids.ijk # JaggedTensor of voxel coordinates
print(type(voxel_coords)) # JaggedTensor
# You can use fvdb.* functions on the results
# Example: Normalize coordinates
normalized = fvdb.true_divide(voxel_coords.float(), 10.0)
Summary
JaggedTensor efficiently represents batches of variable-length data
fvdb.* functions provide type-safe operations that work with both JaggedTensor and regular Tensor
torch.* functions work at runtime via __torch_function__ but lack static type safety
j* methods (jsum, jmax, jreshape, etc.) operate on the jagged dimension itself
Use fvdb.* for better type checking and IDE support
Use j* methods when you need to manipulate the jagged structure
For more examples, see:
Basic Concepts - Overview of JaggedTensor in fVDB
Single Grid - Using Grid (non-batched) vs GridBatch with JaggedTensor
Basic Grid Operations - How GridBatch operations use JaggedTensor