
We implement an advanced, end-to-end Kornia tutorial and demonstrate how modern, differentiable computer vision can be built entirely in PyTorch. We start by constructing GPU-accelerated, synchronized augmentation pipelines for images, masks, and keypoints, then move into differentiable geometry by optimizing a homography directly through gradient descent. We also show how learned feature matching with LoFTR integrates with Kornia’s RANSAC to estimate robust homographies and produce a simple stitched output, with a synthetic fallback image pair when the sample images cannot be downloaded. Finally, we ground these ideas in practice by training a lightweight CNN on CIFAR-10 using Kornia’s GPU augmentations, highlighting how research-grade vision pipelines translate naturally into learning systems. Check out the FULL CODES here.
import os
import random
import subprocess
import sys
import time
import urllib.request
from dataclasses import dataclass
from typing import Tuple
def pip_install(pkgs):
    """Quietly pip-install ``pkgs`` into the environment of the running interpreter.

    ``sys.executable -m pip`` is used so the packages land in the same
    environment as this process (e.g. the notebook kernel).

    Args:
        pkgs: iterable of pip requirement strings, e.g. ``"kornia==0.8.2"``.
            Lists, tuples and generators are all accepted.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    requirements = list(pkgs)
    if not requirements:
        # `pip install` with no requirement errors out; there is nothing to do.
        return
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *requirements])
# Pin Kornia so the augmentation / RANSAC / LoFTR APIs used below match this tutorial.
pip_install([
    "kornia==0.8.2",
    "torch",
    "torchvision",
    "matplotlib",
    "numpy",
    "opencv-python-headless",
])
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import cv2
import kornia
import kornia.augmentation as K
import kornia.geometry.transform as KG
from kornia.geometry.ransac import RANSAC
from kornia.feature import LoFTR
# Seed every RNG the tutorial draws from (torch, numpy, stdlib random) for repeatable runs.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Every later section moves tensors/modules to `device`; it must be defined here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Torch:", torch.__version__)
print("Kornia:", kornia.__version__)
print("Device:", device)
We begin by setting up a fully reproducible environment, installing Kornia and its core dependencies to ensure GPU-accelerated, differentiable computer vision runs smoothly in Google Colab. We then import and organize PyTorch, Kornia, and supporting libraries, establishing a clean foundation for geometry, augmentation, and feature-matching workflows. We set the random seed and select the available compute device so that all subsequent experiments remain deterministic, debuggable, and performance-aware. Check out the FULL CODES here.
def to_tensor_img_uint8(img_bgr_uint8: np.ndarray) -> torch.Tensor:
    """Convert an OpenCV-style uint8 image into a ``(1, 3, H, W)`` float RGB tensor in [0, 1].

    Args:
        img_bgr_uint8: ``(H, W, 3)`` BGR array (as returned by ``cv2.imread``) or
            ``(H, W, 4)`` BGRA array (alpha is dropped). A 2-D ``(H, W)``
            grayscale array is also accepted and replicated to three channels.

    Returns:
        float32 tensor of shape ``(1, 3, H, W)``; uint8 values are divided by 255.
    """
    if img_bgr_uint8.ndim == 2:
        img_rgb = np.repeat(img_bgr_uint8[..., None], 3, axis=2)
    else:
        # Channels 2,1,0 turn B,G,R(,A) into R,G,B; copy so torch gets positive strides.
        img_rgb = np.ascontiguousarray(img_bgr_uint8[..., 2::-1])
    t = torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
    return t.unsqueeze(0)
def show(img_t: torch.Tensor, title: str = "", max_size: int = 900):
    """Display the first image of a batch with matplotlib.

    Args:
        img_t: ``(B, C, H, W)`` or unbatched ``(C, H, W)`` tensor; values are
            clamped to [0, 1]. A single channel is replicated to RGB; otherwise
            ``C`` is expected to be 3.
        title: figure title.
        max_size: longest displayed side in pixels; larger images are
            downscaled with area interpolation before plotting.
    """
    x = img_t.detach().float().cpu().clamp(0, 1)
    if x.ndim == 3:
        x = x.unsqueeze(0)  # accept a single unbatched image
    if x.shape[1] == 1:
        x = x.repeat(1, 3, 1, 1)
    # permute() yields a strided view; hand OpenCV/matplotlib a contiguous HWC array.
    x = np.ascontiguousarray(x[0].permute(1, 2, 0).numpy())
    h, w = x.shape[:2]
    scale = min(1.0, max_size / max(h, w))
    if scale < 1.0:
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))  # cv2 expects (width, height)
        x = cv2.resize(x, new_size, interpolation=cv2.INTER_AREA)
    plt.figure(figsize=(7, 5))
    plt.imshow(x)
    plt.axis("off")
    plt.title(title)
    plt.show()
def show_mask(mask_t: torch.Tensor, title: str = ""):
    """Display the first mask of a batch as a grayscale image on a fixed [0, 1] scale.

    Args:
        mask_t: mask tensor of shape ``(H, W)``, ``(C, H, W)`` or ``(B, C, H, W)``;
            the first 2-D slice is shown. Values are clamped to [0, 1].
        title: figure title.
    """
    x = mask_t.detach().float().cpu().clamp(0, 1)
    # Collapse all leading dims so every supported rank reduces to its first (H, W) slice.
    x = x.reshape(-1, x.shape[-2], x.shape[-1])[0].numpy()
    plt.figure(figsize=(6, 4))
    # Fixed vmin/vmax: a constant mask would otherwise be auto-scaled to an arbitrary colour.
    plt.imshow(x, cmap="gray", vmin=0.0, vmax=1.0)
    plt.axis("off")
    plt.title(title)
    plt.show()
def download(url: str, path: str):
    """Fetch ``url`` to ``path`` unless ``path`` already exists.

    The data is written to ``path + ".part"`` first and renamed into place only
    after a complete transfer. This prevents an interrupted download from leaving
    a truncated file that the existence check would later treat as finished.

    Raises:
        urllib.error.URLError / OSError: if the fetch or the file write fails
            (the temporary ``.part`` file is removed).
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        os.makedirs(parent, exist_ok=True)
    if os.path.exists(path):
        return
    tmp_path = os.fspath(path) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def safe_download(url: str, path: str) -> bool:
    """Best-effort download of ``url`` to ``path``; reports failure instead of raising.

    The fetch is skipped when ``path`` already exists. Data goes to
    ``path + ".part"`` and is renamed into place only after a complete transfer.
    This way a failed or interrupted download cannot leave a truncated file that
    later calls would report as present.

    Returns:
        True if ``path`` exists afterwards (already present or freshly downloaded);
        False if anything failed (the error is printed).
    """
    tmp_path = os.fspath(path) + ".part"
    try:
        parent = os.path.dirname(path)
        if parent:
            # os.makedirs("") raises, which made bare filenames always "fail".
            os.makedirs(parent, exist_ok=True)
        if not os.path.exists(path):
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, path)
        return True
    except Exception as e:  # deliberate best-effort boundary: report and signal failure
        print("Download failed:", e)
        return False
    finally:
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
def make_grid_mask(h: int, w: int, cell: int = 32) -> torch.Tensor:
    """Return a ``(1, 1, h, w)`` float32 checkerboard mask made of ``cell``-pixel squares.

    The top-left square is 0; squares alternate 0/1 along both axes.

    Raises:
        ValueError: if ``cell`` is not positive.
    """
    if cell <= 0:
        raise ValueError(f"cell must be a positive integer, got {cell}")
    yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # Parity of the row-square index XOR parity of the column-square index.
    m = (((yy // cell) % 2) ^ ((xx // cell) % 2)).float()
    return m.unsqueeze(0).unsqueeze(0)
def draw_matches(img0_rgb: np.ndarray, img1_rgb: np.ndarray, pts0: np.ndarray, pts1: np.ndarray, max_draw: int = 200) -> np.ndarray:
    """Render two images side by side and join matched points with white lines.

    Args:
        img0_rgb, img1_rgb: ``(H, W, 3)`` RGB images (the canvas is uint8). 2-D
            grayscale arrays are replicated to three channels.
        pts0, pts1: ``(N, 2)`` arrays of ``(x, y)`` pixel coordinates; row ``i``
            of ``pts0`` is paired with row ``i`` of ``pts1``. Only the first
            ``min(len(pts0), len(pts1))`` rows can form pairs.
        max_draw: maximum number of pairs drawn. When more pairs exist, a random
            subset is chosen with the global ``np.random`` state.

    Returns:
        ``(max(H0, H1), W0 + W1, 3)`` uint8 canvas with image 0 on the left.
    """

    def _as_rgb(im: np.ndarray) -> np.ndarray:
        # Grayscale (H, W) -> (H, W, 3) so it fits the 3-channel canvas.
        return np.repeat(im[..., None], 3, axis=2) if im.ndim == 2 else im

    img0_rgb, img1_rgb = _as_rgb(img0_rgb), _as_rgb(img1_rgb)
    h0, w0 = img0_rgb.shape[:2]
    h1, w1 = img1_rgb.shape[:2]
    out = np.zeros((max(h0, h1), w0 + w1, 3), dtype=np.uint8)
    out[:h0, :w0] = img0_rgb
    out[:h1, w0:w0 + w1] = img1_rgb
    # Sample from the number of *pairs*: drawing indices from len(pts0) alone
    # could select a row past the end of a shorter pts1.
    n_pairs = min(len(pts0), len(pts1))
    n = min(n_pairs, max_draw)
    if n <= 0:
        return out
    idx = np.random.choice(n_pairs, size=n, replace=False) if n_pairs > n else np.arange(n)
    white = (255, 255, 255)
    for i in idx:
        x0, y0 = pts0[i]
        x1, y1 = pts1[i]
        p0 = (int(round(x0)), int(round(y0)))
        p1 = (int(round(x1 + w0)), int(round(y1)))  # shift into the right-hand panel
        cv2.circle(out, p0, 2, white, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, p1, 2, white, -1, lineType=cv2.LINE_AA)
        cv2.line(out, p0, p1, white, 1, lineType=cv2.LINE_AA)
    return out
def normalize_img_for_loftr(img_rgb01: torch.Tensor) -> torch.Tensor:
    """Prepare an image batch for LoFTR.

    A 3-channel batch (channel axis 1) is converted with
    ``kornia.color.rgb_to_grayscale``. Any other channel count is returned
    unchanged (the very same tensor object).
    """
    num_channels = img_rgb01.shape[1]
    if num_channels != 3:
        return img_rgb01
    return kornia.color.rgb_to_grayscale(img_rgb01)
We define a set of reusable helper utilities for image conversion, visualization, safe data downloading, and synthetic mask generation, keeping the vision pipeline clean and modular. We also implement robust visualization and matching helpers that allow us to inspect augmented images, masks, and LoFTR correspondences directly during experimentation. We normalize image inputs to the exact tensor formats expected by Kornia and LoFTR, ensuring that all downstream geometry and feature-matching components operate consistently and correctly. Check out the FULL CODES here.
# Synthetic sample: one random RGB image, a checkerboard mask and four (x, y) keypoints in pixels.
B, C, H, W = 1, 3, 256, 384
img = torch.rand(B, C, H, W, device=device)
mask = make_grid_mask(H, W, cell=24).to(device)
kps = torch.tensor([[
    [40.0, 40.0],
    [W - 50.0, 50.0],
    [W * 0.6, H * 0.8],
    [W * 0.25, H * 0.65],
]], device=device)
# One pipeline for all three inputs; `data_keys` tells Kornia how to treat each
# positional argument so the sampled geometric parameters are shared between them.
aug = K.AugmentationSequential(
    K.RandomResizedCrop((224, 224), scale=(0.6, 1.0), ratio=(0.8, 1.25), p=1.0),
    K.RandomHorizontalFlip(p=0.5),
    K.RandomRotation(degrees=18.0, p=0.7),
    K.ColorJiggle(0.2, 0.2, 0.2, 0.1, p=0.8),
    data_keys=["input", "mask", "keypoints"],
    same_on_batch=True,
).to(device)
img_aug, mask_aug, kps_aug = aug(img, mask, kps)
print("image:", tuple(img.shape), "->", tuple(img_aug.shape))
print("mask :", tuple(mask.shape), "->", tuple(mask_aug.shape))
print("kps :", tuple(kps.shape), "->", tuple(kps_aug.shape))
print("Example keypoints (before -> after):")
# Columns: x, y in the original frame | x, y in the 224x224 output frame.
print(torch.cat([kps[0], kps_aug[0]], dim=1))
show(img, "Original (synthetic)")
show_mask(mask, "Original mask (synthetic)")
show(img_aug, "Augmented (synced)")
show_mask(mask_aug, "Augmented mask (synced)")
We construct a synchronized, fully differentiable augmentation pipeline that applies the same geometric transformations to images, masks, and keypoints on the GPU. We generate synthetic data to clearly demonstrate how spatial consistency is preserved across modalities while still introducing realistic variability through cropping, rotation, flipping, and color jitter. We visualize the before-and-after results to verify that the augmented images, segmentation masks, and keypoints remain perfectly aligned after transformation. Check out the FULL CODES here.
# Random single-channel texture to be warped.
base = torch.rand(1, 1, 240, 320, device=device)
show(base, "Base image (grayscale)")
# Ground-truth homography in pixel coordinates: translation (18, -12) px,
# small off-diagonal linear terms and a slight perspective component.
true_H_px = torch.eye(3, device=device).unsqueeze(0)
true_H_px[:, 0, 2] = 18.0
true_H_px[:, 1, 2] = -12.0
true_H_px[:, 0, 1] = 0.03
true_H_px[:, 1, 0] = -0.02
true_H_px[:, 2, 0] = 1e-4
true_H_px[:, 2, 1] = -8e-5
# warp_perspective takes dsize as (height, width).
target = KG.warp_perspective(base, true_H_px, dsize=(base.shape[-2], base.shape[-1]), align_corners=True)
show(target, "Target (base warped by true homography)")
# Eight free homography parameters (H[2, 2] stays 1), initialised at the identity.
p = torch.zeros(1, 8, device=device, requires_grad=True)
def params_to_H(p8: torch.Tensor) -> torch.Tensor:
    """Assemble a batch of 3x3 homographies from 8 free parameters per row of ``p8``.

    Layout (``H[2, 2]`` fixed to 1)::

        [[1 + p0, p1,     p2],
         [p3,     1 + p4, p5],
         [p6,     p7,     1 ]]

    so an all-zero ``p8`` yields identities. The matrix is built from ``p8`` with
    differentiable ops, so gradients flow back into it. The output has the
    default floating dtype, lives on ``p8``'s device and has shape ``(B, 3, 3)``.
    """
    batch = p8.shape[0]
    ones = torch.ones_like(p8[:, 0])
    entries = [
        1.0 + p8[:, 0], p8[:, 1], p8[:, 2],
        p8[:, 3], 1.0 + p8[:, 4], p8[:, 5],
        p8[:, 6], p8[:, 7], ones,
    ]
    flat = torch.stack(entries, dim=1).to(torch.get_default_dtype())
    return flat.view(batch, 3, 3)
# Per-parameter scales so one Adam learning rate suits every entry of H.
# Adam's first steps are about `lr` in size whatever the gradient magnitude. On raw
# parameters that moves the perspective terms (true value ~1e-4) by ~lr per step,
# making the denominator ~1 + lr * 320 at the image edge and collapsing the warp.
# With these scales the true parameters are all O(0.1 - 2).
p_scale = torch.tensor([[0.1, 0.1, 10.0, 0.1, 0.1, 10.0, 1e-3, 1e-3]], device=device)
out_hw = (base.shape[-2], base.shape[-1])


def _blur(x: torch.Tensor, sigma: float) -> torch.Tensor:
    """Gaussian-blur ``x`` with standard deviation ``sigma`` px; ``sigma <= 0`` is a no-op."""
    if sigma <= 0:
        return x
    k = 2 * int(round(3 * sigma)) + 1  # odd kernel spanning +-3 sigma
    return kornia.filters.gaussian_blur2d(x, (k, k), (sigma, sigma))


# Coarse-to-fine schedule. A photometric loss on raw i.i.d. noise only carries
# useful gradient within ~1 px, far less than the 18 px true shift. Heavily
# blurred levels widen that basin, and later levels refine on sharper images.
blur_sigmas = (16.0, 8.0, 4.0, 0.0)
steps_per_level = 60
opt = torch.optim.Adam([p], lr=0.05)
losses = []
for sigma in blur_sigmas:
    target_s = _blur(target, sigma)
    for _ in range(steps_per_level):
        opt.zero_grad(set_to_none=True)
        H_est = params_to_H(p * p_scale)
        pred = KG.warp_perspective(base, H_est, dsize=out_hw, align_corners=True)
        # Blur *after* warping so the blurred prediction equals target_s exactly at the true H.
        loss_photo = (_blur(pred, sigma) - target_s).abs().mean()
        loss_reg = 1e-4 * (p ** 2).mean()  # mild pull toward identity in scaled units
        loss = loss_photo + loss_reg
        loss.backward()
        opt.step()
        losses.append(loss.item())
print("Final loss:", losses[-1])
plt.figure(figsize=(6, 4))
plt.plot(losses)
plt.title("Homography optimization loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()
H_est_final = params_to_H((p * p_scale).detach())
pred_final = KG.warp_perspective(base, H_est_final, dsize=out_hw, align_corners=True)
show(pred_final, "Recovered warp (optimized)")
show((pred_final - target).abs(), "Abs error (recovered vs target)")
print("True H (pixel):\n", true_H_px.squeeze(0).detach().cpu().numpy())
print("Est H:\n", H_est_final.squeeze(0).detach().cpu().numpy())
# Mean displacement of the image corners under true vs. estimated H: a pixel-unit
# summary that weighs all 8 parameters by their actual effect.
hh, ww = out_hw
corners = torch.tensor(
    [[0.0, 0.0, 1.0], [ww - 1.0, 0.0, 1.0], [ww - 1.0, hh - 1.0, 1.0], [0.0, hh - 1.0, 1.0]],
    device=device,
)


def _project(Hm: torch.Tensor, pts: torch.Tensor) -> torch.Tensor:
    """Apply the first homography in ``Hm`` to homogeneous points ``pts`` (N, 3) -> (N, 2)."""
    q = pts @ Hm[0].T
    return q[:, :2] / q[:, 2:3]


corner_err = (_project(true_H_px, corners) - _project(H_est_final, corners)).norm(dim=1).mean()
print("Mean corner error (px):", round(corner_err.item(), 3))
We demonstrate that geometric alignment can be treated as a differentiable optimization problem by directly recovering a homography via gradient descent. We first generate a target image by warping a base image with a known homography and then learn the transformation parameters by minimizing a photometric reconstruction loss with regularization. Also, we visualize the optimized warp and error map to confirm that the estimated homography closely matches the ground-truth transformation. Check out the FULL CODES here.
# Two views of a planar scene from OpenCV's sample data. If either download
# fails, fall back to a synthetic pair related by a known homography.
data_dir = "/content/kornia_demo"
os.makedirs(data_dir, exist_ok=True)
img0_path = os.path.join(data_dir, "img0.png")
img1_path = os.path.join(data_dir, "img1.png")
ok0 = safe_download(
    "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/graf1.png",
    img0_path,
)
ok1 = safe_download(
    "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/graf3.png",
    img1_path,
)
if not (ok0 and ok1):
    print("⚠️ Using synthetic fallback images (no network / blocked downloads)")
    base_rgb = torch.rand(1, 3, 480, 640, device=device)
    H_syn = torch.tensor([[
        [1.0, 0.05, 40.0],
        [-0.03, 1.0, 25.0],
        [1e-4, -8e-5, 1.0],
    ]], device=device)
    t0 = base_rgb
    t1 = KG.warp_perspective(base_rgb, H_syn, dsize=(480, 640), align_corners=True)
    # uint8 HWC copies are only used to draw the match visualisation.
    img0_rgb = (t0[0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
    img1_rgb = (t1[0].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
else:
    img0_bgr = cv2.imread(img0_path, cv2.IMREAD_COLOR)
    img1_bgr = cv2.imread(img1_path, cv2.IMREAD_COLOR)
    if img0_bgr is None or img1_bgr is None:
        raise RuntimeError("Failed to load downloaded images.")
    img0_rgb = cv2.cvtColor(img0_bgr, cv2.COLOR_BGR2RGB)
    img1_rgb = cv2.cvtColor(img1_bgr, cv2.COLOR_BGR2RGB)
    t0 = to_tensor_img_uint8(img0_bgr).to(device)
    t1 = to_tensor_img_uint8(img1_bgr).to(device)
show(t0, "Image 0")
show(t1, "Image 1")

g0 = normalize_img_for_loftr(t0)
g1 = normalize_img_for_loftr(t1)
# NOTE(review): the pretrained LoFTR weights are fetched over the network on first
# use, so this section still needs connectivity even when the synthetic fallback ran.
loftr = LoFTR(pretrained="outdoor").to(device).eval()
with torch.inference_mode():
    correspondences = loftr({"image0": g0, "image1": g1})
mkpts0 = correspondences["keypoints0"]
mkpts1 = correspondences["keypoints1"]
mconf = correspondences.get("confidence", None)
print("Raw matches:", mkpts0.shape[0])
if mkpts0.shape[0] < 8:
    raise RuntimeError("Too few matches to estimate homography.")
if mconf is not None:
    # Keep only the most confident matches to bound RANSAC cost.
    topk = min(2000, mkpts0.shape[0])
    idx = torch.topk(mconf.detach(), k=topk, largest=True).indices
    mkpts0 = mkpts0[idx]
    mkpts1 = mkpts1[idx]
    print("Kept top matches:", mkpts0.shape[0])

# Each of the max_iter rounds scores batch_size random minimal-sample hypotheses.
ransac = RANSAC(
    model_type="homography",
    inl_th=3.0,
    batch_size=4096,
    max_iter=10,
    confidence=0.999,
    max_lo_iters=5,
).to(device)
with torch.inference_mode():
    H01, inliers = ransac(mkpts0, mkpts1)
inliers = inliers.reshape(-1).bool()
num_inliers = int(inliers.sum().item())
print("Estimated H shape:", tuple(H01.shape))
print("Inliers:", num_inliers, "/", int(inliers.numel()))
# A homography needs at least 4 point pairs; with fewer inliers the estimate is meaningless.
if num_inliers < 4:
    raise RuntimeError("RANSAC found too few inliers to trust the homography.")

# Draw only the RANSAC inliers: the correspondences the homography was fitted to.
inlier_np = inliers.cpu().numpy()
vis = draw_matches(
    img0_rgb,
    img1_rgb,
    mkpts0.detach().cpu().numpy()[inlier_np],
    mkpts1.detach().cpu().numpy()[inlier_np],
    max_draw=250,
)
plt.figure(figsize=(10, 5))
plt.imshow(vis)
plt.axis("off")
plt.title("LoFTR matches (RANSAC inliers, subset)")
plt.show()

# ransac(mkpts0, mkpts1) fits H01 from image-0 points to image-1 points, so warping
# t0 by it resamples image 0 into image 1's pixel frame.
H01 = H01.unsqueeze(0) if H01.ndim == 2 else H01
with torch.inference_mode():
    warped0 = KG.warp_perspective(t0, H01, dsize=(t1.shape[-2], t1.shape[-1]), align_corners=True)
    stitched = torch.max(warped0, t1)
show(warped0, "Image0 warped into Image1 frame (via RANSAC homography)")
show(stitched, "Simple stitched blend (max)")
We perform learned feature matching using LoFTR to establish dense correspondences between two images, while ensuring robustness through a network-safe fallback mechanism. We then apply Kornia’s RANSAC to estimate a stable homography from these matches and warp one image into the coordinate frame of the other. We visualize the correspondences and produce a simple stitched result to validate the geometric alignment end-to-end. Check out the FULL CODES here.
# CIFAR-10 training split without a `transform=`; tensor conversion happens in `collate`.
cifar = torchvision.datasets.CIFAR10(root="/content/data", train=True, download=True)
# A fixed random subset keeps the demo quick (np.random was seeded earlier).
num_samples = 4096
indices = np.random.permutation(len(cifar))[:num_samples]
subset = torch.utils.data.Subset(cifar, indices.tolist())
def collate(batch):
    """Turn a list of ``(image, label)`` samples into ``(images, labels)`` tensors.

    Each image goes through ``TF.to_tensor`` and the results are stacked along a
    new leading batch axis. Labels become a single 1-D tensor.
    """
    pairs = list(batch)
    images = [TF.to_tensor(image) for image, _ in pairs]
    targets = [label for _, label in pairs]
    return torch.stack(images, 0), torch.tensor(targets)
loader = torch.utils.data.DataLoader(
    subset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-CUDA copies; on a CPU-only
    # machine it just triggers a warning.
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate,
)
# Batch-level augmentations, applied to whole batches after they are moved to `device`.
aug_train = K.ImageSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomAffine(degrees=12.0, translate=(0.08, 0.08), scale=(0.9, 1.1), p=0.7),
    K.ColorJiggle(0.2, 0.2, 0.2, 0.1, p=0.8),
    K.RandomGaussianBlur((3, 3), (0.1, 1.5), p=0.3),
).to(device)
class TinyCifarNet(nn.Module):
    """Small CNN: three 3x3 conv layers, global average pooling and a linear head.

    conv1 and conv2 are each followed by ReLU and a 2x2 max-pool. conv3 is
    followed by ReLU and a mean over the spatial axes, so any input of at least
    4x4 pixels works (e.g. 32x32 CIFAR images).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Attribute names (and creation order) define state_dict keys and init order.
        self.conv1 = nn.Conv2d(3, 48, 3, padding=1)
        self.conv2 = nn.Conv2d(48, 96, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 128, 3, padding=1)
        self.head = nn.Linear(128, num_classes)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        pooled = F.relu(self.conv3(x)).mean(dim=(-2, -1))
        return self.head(pooled)
model = TinyCifarNet().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4)
model.train()
# Train for a fixed number of optimizer steps. One pass over the loader yields only
# len(loader) batches (4096 samples / batch_size 256 = 16), so a single
# `for ... in loader` ended after 16 steps and never reached the intended 40.
# Keep starting new epochs until max_steps is reached.
max_steps = 40
if len(loader) == 0:
    raise RuntimeError("Training loader is empty.")
t_start = time.time()
running = []
train_step = 0
while train_step < max_steps:
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        xb = aug_train(xb)  # augmentation runs on `device`, per batch
        logits = model(xb)
        loss = F.cross_entropy(logits, yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        running.append(loss.item())
        train_step += 1
        if train_step % 10 == 0:
            print(f"iter {train_step:03d}/{max_steps} | loss {np.mean(running[-10:]):.4f}")
        if train_step >= max_steps:
            break
print("Done in", round(time.time() - t_start, 2), "sec")
plt.figure(figsize=(6, 4))
plt.plot(running)
plt.title("Training loss (quick demo)")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.show()
# A fresh batch of 8 images to compare before/after augmentation.
xb0, yb0 = next(iter(loader))
xb0 = xb0[:8].to(device)
xbA = aug_train(xb0)
def tile8(x):
    """Lay out a batch of images as a 4-column grid and return it as an ``(H, W, C)`` numpy array.

    Values are clamped to [0, 1] after detaching and moving to the CPU.
    """
    images = x.detach().cpu().clamp(0, 1)
    return torchvision.utils.make_grid(images, nrow=4).permute(1, 2, 0).numpy()
plt.figure(figsize=(10, 5))
plt.imshow(tile8(xb0))
plt.axis("off")
plt.title("CIFAR batch (original)")
plt.show()
plt.figure(figsize=(10, 5))
plt.imshow(tile8(xbA))
plt.axis("off")
# Report the device the augmentation actually ran on (it is CPU when no GPU is present).
plt.title(f"CIFAR batch (Kornia-augmented on {device})")
plt.show()
print("\n✅ Tutorial complete.")
print("Next ideas:")
print("- Feathered stitching (soft masks) instead of max-blend.")
print("- Compare LoFTR vs DISK/LightGlue using kornia.feature.")
print("- Multi-scale homography optimization + SSIM/Charbonnier losses.")
We demonstrate how Kornia’s GPU-based augmentations integrate directly into a standard training loop by applying them on the fly to a subset of the CIFAR-10 dataset. We train a lightweight convolutional network end-to-end, showing that batch-level augmentation on the device slots into the loop with a single call while increasing data diversity. Finally, we visualize original versus augmented batches to check visually that the transformations are applied as intended.
In conclusion, we demonstrated that Kornia enables a unified vision workflow where data augmentation, geometric reasoning, feature matching, and learning remain differentiable and GPU-friendly within a single framework. By combining LoFTR matching, RANSAC-based homography estimation, and optimization-driven alignment with a practical training loop, we showed how classical vision and deep learning complement each other rather than compete. It serves as a foundation for extending toward production-grade stitching, robust pose estimation, or large-scale training pipelines, and we emphasize that the same patterns we used here scale naturally to more complex, real-world vision systems.
Check out the FULL CODES here. Also, feel free to follow us on Twitter and don’t forget to join our 100k+ ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.

