Eray.
• Machine Learning, PyTorch, ONNX, Deep Learning, Web Development

MNIST with PyTorch to ONNX: Machine Learning in the Browser


Machine learning models are typically run server-side. However, with ONNX Runtime Web, we can run trained models directly in the user's browser. In this post, I'll walk through how I built my MNIST digit recognition project step by step.

Project Goal

Objective: A web app that recognizes handwritten digits, running entirely in the browser.

Requirements:

  • Zero server cost
  • Real-time prediction (<100ms)
  • User privacy (data never leaves browser)
  • 99%+ accuracy

1. Model Architecture and Design

I chose a classic CNN (Convolutional Neural Network) architecture for MNIST:

class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        # Conv Layer 1: 1 β†’ 32 channels
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 β†’ 14x14
        
        # Conv Layer 2: 32 β†’ 64 channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 β†’ 7x7
        
        # Fully Connected Layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x
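The spatial sizes noted in the comments (28 → 14 → 7) follow from the standard conv/pool output formula. As a quick sanity check, independent of PyTorch:

```python
def conv2d_out(size, kernel, padding, stride=1):
    """Output size of a conv layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size, kernel=2, stride=2):
    """Output size of a max-pool layer."""
    return (size - kernel) // stride + 1

size = 28
size = conv2d_out(size, kernel=3, padding=1)  # conv1 keeps 28x28 (padding=1)
size = maxpool_out(size)                      # pool1: 28 -> 14
size = conv2d_out(size, kernel=3, padding=1)  # conv2 keeps 14x14
size = maxpool_out(size)                      # pool2: 14 -> 7
print(size)  # 7 -- which is why fc1 expects 64 * 7 * 7 inputs
```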

Why this architecture?

  • 2 Convolutional Layers: Sufficient for feature extraction
  • MaxPooling: Reduces spatial dimensions, lowers computation cost
  • Dropout (0.5): Prevents overfitting
  • Compact: ~422K parameters (fc1 dominates), ~1.7MB in FP32
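The parameter count can be tallied directly from the layer shapes (weights plus biases), which shows that the first fully connected layer holds almost all of the model:

```python
# Parameters per layer: in*out*k*k + out for conv, in*out + out for linear
conv1 = 1 * 32 * 3 * 3 + 32      # 320
conv2 = 32 * 64 * 3 * 3 + 64     # 18,496
fc1   = 64 * 7 * 7 * 128 + 128   # 401,536 -- the bulk of the model
fc2   = 128 * 10 + 10            # 1,290
total = conv1 + conv2 + fc1 + fc2
print(total)  # 421,642 parameters (~1.7MB at 4 bytes each)
```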

2. Model Training

Data Preparation

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean & std
])
 
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
 
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Training Loop

10 epochs were sufficient:

model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
for epoch in range(1, 11):
    # Training
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    # Validation
    model.eval()
    test_accuracy = evaluate(model, test_loader)
    print(f'Epoch {epoch}: Accuracy = {test_accuracy:.2f}%')
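The `evaluate` helper isn't shown above; a minimal sketch (assuming accuracy is reported as a percentage, with the same device handling as the training loop) might look like this:

```python
import torch

def evaluate(model, loader, device='cpu'):
    """Return accuracy (%) of `model` over `loader`. Call model.eval() first."""
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)  # max logit index = predicted digit
            correct += (pred == target).sum().item()
            total += target.size(0)
    return 100.0 * correct / total
```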

Results:

  • Epoch 1: 97.8%
  • Epoch 5: 99.0%
  • Epoch 10: 99.21%

3. ONNX Export

Exporting a PyTorch model to ONNX is straightforward:

# Set model to eval mode
model.eval()
 
# Dummy input (batch_size=1, channels=1, height=28, width=28)
dummy_input = torch.randn(1, 1, 28, 28, device='cpu')
 
# Export
torch.onnx.export(
    model,
    dummy_input,
    'model/mnist_model.onnx',
    export_params=True,
    opset_version=12,  # Critical for browser compatibility!
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

Important Notes:

  • opset_version=12: For ONNX Runtime Web compatibility
  • do_constant_folding=True: Model optimization
  • dynamic_axes: Supports different batch sizes

ONNX Validation

import onnx
onnx_model = onnx.load('model/mnist_model.onnx')
onnx.checker.check_model(onnx_model)
print('✓ ONNX model is valid!')

4. Browser Inference

ONNX Runtime Web Setup

npm install onnxruntime-web

Model Loading

import * as ort from 'onnxruntime-web';
 
const session = await ort.InferenceSession.create('/demos/mnist/model.onnx');

Preprocessing

Converting the canvas image to the model's expected format is critical:

function preprocessCanvas(canvas: HTMLCanvasElement): Float32Array {
  // Resize to 28x28
  const tempCanvas = document.createElement('canvas');
  tempCanvas.width = 28;
  tempCanvas.height = 28;
  const tempCtx = tempCanvas.getContext('2d')!;
  tempCtx.drawImage(canvas, 0, 0, 28, 28);
  
  // Get ImageData
  const imageData = tempCtx.getImageData(0, 0, 28, 28);
  const pixels = imageData.data;
  
  // Grayscale + Invert + Normalize
  const input = new Float32Array(1 * 1 * 28 * 28);
  for (let i = 0; i < 28 * 28; i++) {
    const pixelIndex = i * 4;
    const gray = (pixels[pixelIndex] + pixels[pixelIndex + 1] + pixels[pixelIndex + 2]) / 3;
    // Invert: the canvas is black-on-white, MNIST is white-on-black
    const inverted = 255 - gray;
    const normalizedPixel = inverted / 255.0;
    
    // MNIST normalization
    input[i] = (normalizedPixel - 0.1307) / 0.3081;
  }
  
  return input;
}

Key considerations:

  1. Resize: Use bilinear interpolation
  2. Grayscale: RGB β†’ Gray conversion
  3. Normalization: MNIST mean=0.1307, std=0.3081
  4. Inversion: MNIST digits are white-on-black; the canvas is the opposite
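Steps 2–4 reduce to a single per-pixel formula. A small Python check of the two extreme canvas values (white background, black ink), assuming the inversion is done in preprocessing:

```python
MEAN, STD = 0.1307, 0.3081  # MNIST dataset statistics

def canvas_pixel_to_mnist(gray):
    """Map a canvas grayscale value (0 = black ink, 255 = white background)
    to the normalized range the model was trained on."""
    inverted = 255 - gray         # MNIST digits are white on black
    scaled = inverted / 255.0     # [0, 255] -> [0, 1]
    return (scaled - MEAN) / STD  # standard MNIST normalization

print(canvas_pixel_to_mnist(255))  # white background -> about -0.424
print(canvas_pixel_to_mnist(0))    # black ink        -> about  2.822
```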

Inference

async function runInference(session: ort.InferenceSession, input: Float32Array) {
  // Create tensor
  const tensor = new ort.Tensor('float32', input, [1, 1, 28, 28]);
  
  // Inference
  const feeds = { [session.inputNames[0]]: tensor };
  const results = await session.run(feeds);
  const output = results[session.outputNames[0]];
  
  // Softmax (logits β†’ probabilities)
  const logits = Array.from(output.data as Float32Array);
  const maxLogit = Math.max(...logits);
  const expLogits = logits.map(x => Math.exp(x - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  const probabilities = expLogits.map(x => x / sumExp);
  
  return probabilities;
}
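Subtracting the max logit before exponentiating is the standard numerical-stability trick: it prevents overflow for large logits without changing the resulting probabilities. The same computation, sketched in Python:

```python
import math

def softmax(logits):
    """Numerically stable softmax: shift by the max logit before exp."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
# Probabilities sum to 1, and a constant shift leaves them unchanged:
shifted = softmax([1002.0, 1001.0, 1000.1])  # no overflow despite huge logits
```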

5. Performance Optimization

Model Loading

The first load takes ~1 second (runtime plus model download); subsequent inferences reuse the cached session.

let modelSession: ort.InferenceSession | null = null;
 
export async function loadModel(modelPath: string) {
  if (modelSession) return modelSession;
  modelSession = await ort.InferenceSession.create(modelPath);
  return modelSession;
}

Inference Speed

  • Average: 15-30ms
  • Minimum: 10ms (powerful devices)
  • Maximum: 50ms (mobile devices)

Bundle Size

  • onnxruntime-web: ~1.5MB (gzipped: ~500KB)
  • Model file: ~1.7MB (FP32; quantization would shrink this)

Optimization strategies:

  • Lazy loading: Load model only when needed
  • Web Worker: Inference without blocking UI thread
  • WebGL backend: GPU acceleration (optional)

6. Challenges and Solutions

Challenge 1: Canvas Normalization

Problem: Canvas drawing didn't match MNIST format (inverted colors).

Solution: Apply inversion during preprocessing or use data augmentation during training.

Challenge 2: ONNX Opset Compatibility

Problem: Models exported with opset 13+ failed to load in the ONNX Runtime Web build I was using.

Solution: Use opset_version=12 for maximum compatibility.

Challenge 3: Mobile Touch Events

Problem: Touch events handled differently than mouse events.

Solution: Handle both onMouseMove and onTouchMove:

const getCoordinates = (e: MouseEvent | TouchEvent) => {
  const rect = canvas.getBoundingClientRect();
  if ('touches' in e) {
    return {
      x: e.touches[0].clientX - rect.left,
      y: e.touches[0].clientY - rect.top
    };
  }
  return {
    x: e.clientX - rect.left,
    y: e.clientY - rect.top
  };
};

7. Results and Takeaways

Successes

  • 99%+ accuracy
  • <50ms inference time
  • Zero server cost
  • User privacy preserved
  • Works on mobile

Lessons Learned

  • Power and limitations of ONNX Runtime Web
  • Browser ML inference optimization techniques
  • Image preprocessing with Canvas API
  • PyTorch β†’ ONNX export best practices

Future Improvements

  • Web Worker for async inference
  • WebGL backend for GPU acceleration
  • Model quantization (smaller size)
  • Multi-model comparison (different architectures)
  • Real-time prediction while drawing

8. Demo and Source Code

Live Demo: /projects/mnist-digit-recognition

GitHub Repository: erdoganeray/mnist-digit-recognition

Questions or feedback: Contact

Project date: February 2026