MNIST with PyTorch to ONNX: Machine Learning in the Browser
Machine learning models are typically run server-side. However, with ONNX Runtime Web, we can run trained models directly in the user's browser. In this post, I'll walk through how I built my MNIST digit recognition project step by step.
Project Goal
Objective: A web app that recognizes handwritten digits, running entirely in the browser.
Requirements:
- Zero server cost
- Real-time prediction (<100ms)
- User privacy (data never leaves the browser)
- 99%+ accuracy
1. Model Architecture and Design
I chose a classic CNN (Convolutional Neural Network) architecture for MNIST:
```python
import torch.nn as nn

class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        # Conv Layer 1: 1 → 32 channels
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 → 14x14
        # Conv Layer 2: 32 → 64 channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 → 7x7
        # Fully Connected Layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x
```
Why this architecture?
- 2 Convolutional Layers: Sufficient for feature extraction
- MaxPooling: Reduces spatial dimensions, lowers computation cost
- Dropout (0.5): Prevents overfitting
- Compact: ~420K parameters (~1.7MB as fp32 weights)
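The parameter count can be tallied by hand from the layer shapes above; almost all of the weights sit in fc1 (a quick pure-Python tally, no PyTorch required):

```python
def conv2d_params(in_ch, out_ch, k):
    # weight tensor (out_ch * in_ch * k * k) plus one bias per output channel
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_f, out_f):
    # weight matrix (out_f * in_f) plus one bias per output unit
    return out_f * in_f + out_f

total = (
    conv2d_params(1, 32, 3)            # conv1: 320
    + conv2d_params(32, 64, 3)         # conv2: 18,496
    + linear_params(64 * 7 * 7, 128)   # fc1: 401,536
    + linear_params(128, 10)           # fc2: 1,290
)
print(total)  # 421642
```

Since fc1 alone holds over 95% of the weights, shrinking that layer is the first lever to pull if you need a smaller model.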
2. Model Training
Data Preparation
```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean & std
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
```
Training Loop
10 epochs were sufficient:
```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 11):
    # Training
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    # Validation
    model.eval()
    test_accuracy = evaluate(model, test_loader)
    print(f'Epoch {epoch}: Accuracy = {test_accuracy:.2f}%')
```
Results:
- Epoch 1: 97.8%
- Epoch 5: 99.0%
- Epoch 10: 99.21%
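The `evaluate` helper called in the training loop isn't shown in the post; a minimal sketch of my own, returning accuracy as a percentage to match the print format:

```python
import torch

def evaluate(model, loader, device='cpu'):
    """Return classification accuracy (%) of `model` over a DataLoader."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)  # class with highest logit
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total
```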
3. ONNX Export
Exporting a PyTorch model to ONNX is straightforward:
```python
import torch

# Set model to eval mode
model.eval()

# Dummy input (batch_size=1, channels=1, height=28, width=28)
dummy_input = torch.randn(1, 1, 28, 28, device='cpu')

# Export
torch.onnx.export(
    model,
    dummy_input,
    'model/mnist_model.onnx',
    export_params=True,
    opset_version=12,  # Critical for browser compatibility!
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
```
Important Notes:
- `opset_version=12`: for ONNX Runtime Web compatibility
- `do_constant_folding=True`: folds constant subgraphs at export time for a smaller, faster graph
- `dynamic_axes`: supports different batch sizes
ONNX Validation
```python
import onnx

onnx_model = onnx.load('model/mnist_model.onnx')
onnx.checker.check_model(onnx_model)
print('✓ ONNX model is valid!')
```
4. Browser Inference
ONNX Runtime Web Setup
```shell
npm install onnxruntime-web
```
Model Loading
```typescript
import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('/demos/mnist/model.onnx');
```
Converting the canvas image to the model's expected format is critical:
```typescript
function preprocessCanvas(canvas: HTMLCanvasElement): Float32Array {
  // Resize to 28x28
  const tempCanvas = document.createElement('canvas');
  tempCanvas.width = 28;
  tempCanvas.height = 28;
  const tempCtx = tempCanvas.getContext('2d')!;
  tempCtx.drawImage(canvas, 0, 0, 28, 28);

  // Get ImageData
  const imageData = tempCtx.getImageData(0, 0, 28, 28);
  const pixels = imageData.data;

  // Grayscale + invert + normalize
  const input = new Float32Array(1 * 1 * 28 * 28);
  for (let i = 0; i < 28 * 28; i++) {
    const pixelIndex = i * 4;
    const gray = (pixels[pixelIndex] + pixels[pixelIndex + 1] + pixels[pixelIndex + 2]) / 3;
    // Canvas drawing is black-on-white; MNIST expects white-on-black, so invert
    const normalizedPixel = (255 - gray) / 255.0;
    // MNIST normalization
    input[i] = (normalizedPixel - 0.1307) / 0.3081;
  }
  return input;
}
```
Key considerations:
- Resize: Use bilinear interpolation
- Grayscale: RGB → gray conversion (channel average)
- Normalization: MNIST mean=0.1307, std=0.3081
- Inversion: MNIST is white-on-black, our canvas is opposite
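The per-pixel mapping implied by these considerations — grayscale, inversion, then MNIST normalization — can be checked in isolation (a hypothetical helper of my own, pure Python):

```python
MEAN, STD = 0.1307, 0.3081  # MNIST dataset statistics

def canvas_pixel_to_input(gray):
    """Map an 8-bit canvas gray value (255 = white background,
    0 = black stroke) to the normalized value the model expects."""
    inverted = (255 - gray) / 255.0      # white-on-black, scaled to [0, 1]
    return (inverted - MEAN) / STD       # MNIST normalization

print(round(canvas_pixel_to_input(255), 4))  # background -> -0.4242
print(round(canvas_pixel_to_input(0), 4))    # stroke     ->  2.8215
```

Strokes map to large positive values and background to a small negative one, matching what the network saw during training.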
Inference
```typescript
async function runInference(session: ort.InferenceSession, input: Float32Array) {
  // Create tensor
  const tensor = new ort.Tensor('float32', input, [1, 1, 28, 28]);

  // Inference
  const feeds = { [session.inputNames[0]]: tensor };
  const results = await session.run(feeds);
  const output = results[session.outputNames[0]];

  // Softmax (logits → probabilities), shifted by the max for numerical stability
  const logits = Array.from(output.data as Float32Array);
  const maxLogit = Math.max(...logits);
  const expLogits = logits.map(x => Math.exp(x - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  const probabilities = expLogits.map(x => x / sumExp);
  return probabilities;
}
```
5. Performance Optimization
Model Loading
The first load takes about a second; subsequent calls reuse the cached session.
```typescript
let modelSession: ort.InferenceSession | null = null;

export async function loadModel(modelPath: string) {
  if (modelSession) return modelSession;
  modelSession = await ort.InferenceSession.create(modelPath);
  return modelSession;
}
```
Inference Speed
- Average: 15-30ms
- Minimum: 10ms (powerful devices)
- Maximum: 50ms (mobile devices)
Bundle Size
- onnxruntime-web: ~1.5MB (gzipped: ~500KB)
- Model file: ~1.7MB
Optimization strategies:
- Lazy loading: Load model only when needed
- Web Worker: Inference without blocking UI thread
- WebGL backend: GPU acceleration (optional)
6. Challenges and Solutions
Challenge 1: Canvas Normalization
Problem: Canvas drawing didn't match MNIST format (inverted colors).
Solution: Apply inversion during preprocessing or use data augmentation during training.
Challenge 2: ONNX Opset Compatibility
Problem: Opset 13+ not supported in some browsers.
Solution: Use opset_version=12 for maximum compatibility.
Challenge 3: Mobile Touch Events
Problem: Touch events handled differently than mouse events.
Solution: Handle both onMouseMove and onTouchMove:
```typescript
const getCoordinates = (e: MouseEvent | TouchEvent) => {
  const rect = canvas.getBoundingClientRect();
  if ('touches' in e) {
    return {
      x: e.touches[0].clientX - rect.left,
      y: e.touches[0].clientY - rect.top
    };
  }
  return {
    x: e.clientX - rect.left,
    y: e.clientY - rect.top
  };
};
```
7. Results and Takeaways
Successes
- 99%+ accuracy
- <50ms inference time
- Zero server cost
- User privacy preserved
- Works on mobile
Lessons Learned
- Power and limitations of ONNX Runtime Web
- Browser ML inference optimization techniques
- Image preprocessing with Canvas API
- PyTorch → ONNX export best practices
Future Improvements
- Web Worker for async inference
- WebGL backend for GPU acceleration
- Model quantization (smaller size)
- Multi-model comparison (different architectures)
- Real-time prediction while drawing
8. Demo and Source Code
Live Demo: /projects/mnist-digit-recognition
GitHub Repository: erdoganeray/mnist-digit-recognition
Questions or feedback: Contact
Project date: February 2026