Eray.
Projelere Dön

MNIST Rakam Tanıma

PyTorch ve ONNX kullanarak tarayıcıda çalışan rakam tanıma uygulaması

Kaynak Kod
Machine LearningPyTorchONNXDeep LearningJavaScript
Teknolojiler:PyTorchONNX RuntimeNext.jsCanvas APITypeScript

MNIST Rakam Tanıma - İnteraktif Demo

Bu proje, PyTorch ile eğitilmiş bir Convolutional Neural Network (CNN) modelinin ONNX formatına export edilerek doğrudan tarayıcınızda çalıştırılmasını sağlıyor. Sunucu gerektirmeden, tamamen client-side inference yapıyor.

Aşağıdaki canvas alanına 0-9 arası bir rakam çizin ve "Tahmin Et" butonuna tıklayın. Model, çizdiğiniz rakamı tanıyacak ve her rakam için olasılık değerlerini gösterecektir.

MNIST Rakam Tanıma Demosu

Aşağıdaki canvas alanına 0-9 arası bir rakam çizin. Model, çizdiğiniz rakamı tanıyacak ve olasılık değerlerini gösterecektir.

Model yükleniyor / Loading model...


MNIST ile PyTorch'tan ONNX'e: Tarayıcıda Çalışan Makine Öğrenimi

Makine öğrenimi modellerini genellikle server-side çalıştırırız. Ancak ONNX Runtime Web sayesinde, eğitilmiş modelleri doğrudan kullanıcının tarayıcısında çalıştırabiliriz. Bu yazıda, MNIST digit recognition projemi nasıl yaptığımı adım adım anlatacağım.

Proje Hedefi

Amaç: Kullanıcının çizdiği rakamı tanıyan, tamamen tarayıcıda çalışan bir web uygulaması.

Gereksinimler:

  • Sunucu maliyeti sıfır
  • Gerçek zamanlı tahmin (<100ms)
  • Kullanıcı gizliliği (veri sunucuya gitmiyor)
  • %99+ doğruluk oranı

1. Model Mimarisi ve Tasarım

MNIST için klasik bir CNN (Convolutional Neural Network) mimarisi seçtim:

class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        # Conv Layer 1: 1 → 32 channels
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 → 14x14
        
        # Conv Layer 2: 32 → 64 channels
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 → 7x7
        
        # Fully Connected Layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x

Neden bu mimari?

  • 2 Convolutional Layer: Feature extraction için yeterli
  • MaxPooling: Spatial dimensions'ı azaltıyor, hesaplama maliyetini düşürüyor
  • Dropout (0.5): Overfitting'i önlüyor
  • Kompakt: ~100K parameters, ~500KB model size

2. Model Eğitimi

Data Preparation

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean & std
])
 
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
 
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

Training Loop

10 epoch eğitim yeterli oldu:

model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
 
for epoch in range(1, 11):
    # Training
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
    # Validation
    model.eval()
    test_accuracy = evaluate(model, test_loader)
    print(f'Epoch {epoch}: Accuracy = {test_accuracy:.2f}%')

Sonuçlar:

  • Epoch 1: 97.8%
  • Epoch 5: 99.0%
  • Epoch 10: 99.21%

3. ONNX Export

PyTorch modelini ONNX formatına export etmek oldukça basit:

# Model eval moduna al
model.eval()
 
# Dummy input (batch_size=1, channels=1, height=28, width=28)
dummy_input = torch.randn(1, 1, 28, 28, device='cpu')
 
# Export
torch.onnx.export(
    model,
    dummy_input,
    'model/mnist_model.onnx',
    export_params=True,
    opset_version=12,  # Browser compatibility için kritik!
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

Önemli Notlar:

  • opset_version=12: ONNX Runtime Web ile uyumluluk için
  • do_constant_folding=True: Model optimization
  • dynamic_axes: Farklı batch size'ları destekler

ONNX Doğrulama

import onnx
onnx_model = onnx.load('model/mnist_model.onnx')
onnx.checker.check_model(onnx_model)
print('✓ ONNX model is valid!')

4. Browser'da Inference

ONNX Runtime Web Kurulumu

npm install onnxruntime-web

Model Loading

import * as ort from 'onnxruntime-web';
 
const session = await ort.InferenceSession.create('/demos/mnist/model.onnx');

Preprocessing

Canvas'tan alınan görüntüyü modelin beklediği formata çevirmek kritik:

function preprocessCanvas(canvas: HTMLCanvasElement): Float32Array {
  const ctx = canvas.getContext('2d');
  
  // 28x28'e resize
  const tempCanvas = document.createElement('canvas');
  tempCanvas.width = 28;
  tempCanvas.height = 28;
  const tempCtx = tempCanvas.getContext('2d');
  tempCtx.drawImage(canvas, 0, 0, 28, 28);
  
  // ImageData al
  const imageData = tempCtx.getImageData(0, 0, 28, 28);
  const pixels = imageData.data;
  
  // Grayscale + Normalize
  const input = new Float32Array(1 * 1 * 28 * 28);
  for (let i = 0; i < 28 * 28; i++) {
    const pixelIndex = i * 4;
    const gray = (pixels[pixelIndex] + pixels[pixelIndex + 1] + pixels[pixelIndex + 2]) / 3;
    const normalizedPixel = gray / 255.0;
    
    // MNIST normalization
    input[i] = (normalizedPixel - 0.1307) / 0.3081;
  }
  
  return input;
}

Dikkat edilmesi gerekenler:

  1. Resize: Bilinear interpolation kullan
  2. Grayscale: RGB → Gray conversion
  3. Normalization: MNIST mean=0.1307, std=0.3081
  4. Inversion: MNIST beyaz üzerine siyah, canvas'ımız tam tersi

Inference

async function runInference(session: ort.InferenceSession, input: Float32Array) {
  // Tensor oluştur
  const tensor = new ort.Tensor('float32', input, [1, 1, 28, 28]);
  
  // Inference
  const feeds = { [session.inputNames[0]]: tensor };
  const results = await session.run(feeds);
  const output = results[session.outputNames[0]];
  
  // Softmax (logits → probabilities)
  const logits = Array.from(output.data as Float32Array);
  const maxLogit = Math.max(...logits);
  const expLogits = logits.map(x => Math.exp(x - maxLogit));
  const sumExp = expLogits.reduce((a, b) => a + b, 0);
  const probabilities = expLogits.map(x => x / sumExp);
  
  return probabilities;
}

5. Performance Optimization

Model Loading

İlk yükleme ~1 saniye. Sonraki inference'lar için cache'lenmiş session kullanılıyor.

let modelSession: ort.InferenceSession | null = null;
 
export async function loadModel(modelPath: string) {
  if (modelSession) return modelSession;
  modelSession = await ort.InferenceSession.create(modelPath);
  return modelSession;
}

Inference Speed

  • Ortalama: 15-30ms
  • Minimum: 10ms (güçlü cihazlar)
  • Maximum: 50ms (mobil cihazlar)

Bundle Size

  • onnxruntime-web: ~1.5MB (gzip: ~500KB)
  • Model dosyası: ~500KB

Optimization stratejileri:

  • Lazy loading: Model sadece gerektiğinde yükle
  • Web Worker: UI thread'i bloklamadan inference
  • WebGL backend: GPU acceleration (opsiyonel)

6. Challenges ve Çözümler

Challenge 1: Canvas Normalization

Problem: Canvas'taki çizim MNIST formatına uymuyordu (ters renkler).

Çözüm: Preprocessing sırasında inversion uygula veya model eğitiminde data augmentation kullan.

Challenge 2: ONNX Opset Compatibility

Problem: Opset 13+ bazı browser'larda desteklenmiyor.

Çözüm: opset_version=12 kullan, maksimum compatibility sağlıyor.

Challenge 3: Mobile Touch Events

Problem: Touch eventi mouse event'ten farklı işleniyor.

Çözüm: Hem onMouseMove hem onTouchMove handle et:

const getCoordinates = (e: MouseEvent | TouchEvent) => {
  const rect = canvas.getBoundingClientRect();
  if ('touches' in e) {
    return {
      x: e.touches[0].clientX - rect.left,
      y: e.touches[0].clientY - rect.top
    };
  }
  return {
    x: e.clientX - rect.left,
    y: e.clientY - rect.top
  };
};

7. Sonuçlar ve İzlenimler

Başarılar

✓ %99+ doğruluk oranı
<50ms inference süresi
✓ Sıfır sunucu maliyeti
✓ Kullanıcı gizliliği korunuyor
✓ Mobile'da da çalışıyor

Öğrendiklerim

  • ONNX Runtime Web'in gücü ve limitations
  • Browser'da ML inference optimization teknikleri
  • Canvas API ile image preprocessing
  • PyTorch → ONNX export best practices

Gelecek İyileştirmeler

  • Web Worker ile async inference
  • WebGL backend ile GPU acceleration
  • Model quantization (smaller size)
  • Multi-model comparison (different architectures)
  • Real-time prediction while drawing

8. Demo ve Kaynak Kodlar

Live Demo: /projects/mnist-digit-recognition

GitHub Repository: erdoganeray/mnist-digit-recognition

9. İlgili Kaynaklar


Sorular veya geri bildirimler için: İletişim

Proje tarihi: Şubat 2026