MNIST ile PyTorch'tan ONNX'e: Tarayıcıda Çalışan Makine Öğrenimi
MNIST ile PyTorch'tan ONNX'e: Tarayıcıda Çalışan Makine Öğrenimi
Makine öğrenimi modellerini genellikle server-side çalıştırırız. Ancak ONNX Runtime Web sayesinde, eğitilmiş modelleri doğrudan kullanıcının tarayıcısında çalıştırabiliriz. Bu yazıda, MNIST digit recognition projemi nasıl yaptığımı adım adım anlatacağım.
Proje Hedefi
Amaç: Kullanıcının çizdiği rakamı tanıyan, tamamen tarayıcıda çalışan bir web uygulaması.
Gereksinimler:
- Sunucu maliyeti sıfır
- Gerçek zamanlı tahmin (
<100ms) - Kullanıcı gizliliği (veri sunucuya gitmiyor)
- %99+ doğruluk oranı
1. Model Mimarisi ve Tasarım
MNIST için klasik bir CNN (Convolutional Neural Network) mimarisi seçtim:
class MNISTNet(nn.Module):
def __init__(self):
super(MNISTNet, self).__init__()
# Conv Layer 1: 1 → 32 channels
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2, 2) # 28x28 → 14x14
# Conv Layer 2: 32 → 64 channels
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2, 2) # 14x14 → 7x7
# Fully Connected Layers
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.relu3 = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = self.flatten(x)
x = self.dropout(self.relu3(self.fc1(x)))
x = self.fc2(x)
return xNeden bu mimari?
- 2 Convolutional Layer: Feature extraction için yeterli
- MaxPooling: Spatial dimensions'ı azaltıyor, hesaplama maliyetini düşürüyor
- Dropout (0.5): Overfitting'i önlüyor
- Kompakt: ~100K parameters, ~500KB model size
2. Model Eğitimi
Data Preparation
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean & std
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)Training Loop
10 epoch eğitim yeterli oldu:
model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 11):
# Training
model.train()
for data, target in train_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# Validation
model.eval()
test_accuracy = evaluate(model, test_loader)
print(f'Epoch {epoch}: Accuracy = {test_accuracy:.2f}%')Sonuçlar:
- Epoch 1: 97.8%
- Epoch 5: 99.0%
- Epoch 10: 99.21%
3. ONNX Export
PyTorch modelini ONNX formatına export etmek oldukça basit:
# Model eval moduna al
model.eval()
# Dummy input (batch_size=1, channels=1, height=28, width=28)
dummy_input = torch.randn(1, 1, 28, 28, device='cpu')
# Export
torch.onnx.export(
model,
dummy_input,
'model/mnist_model.onnx',
export_params=True,
opset_version=12, # Browser compatibility için kritik!
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
}
)Önemli Notlar:
opset_version=12: ONNX Runtime Web ile uyumluluk içindo_constant_folding=True: Model optimizationdynamic_axes: Farklı batch size'ları destekler
ONNX Doğrulama
import onnx
onnx_model = onnx.load('model/mnist_model.onnx')
onnx.checker.check_model(onnx_model)
print('✓ ONNX model is valid!')4. Browser'da Inference
ONNX Runtime Web Kurulumu
npm install onnxruntime-webModel Loading
import * as ort from 'onnxruntime-web';
const session = await ort.InferenceSession.create('/demos/mnist/model.onnx');Preprocessing
Canvas'tan alınan görüntüyü modelin beklediği formata çevirmek kritik:
function preprocessCanvas(canvas: HTMLCanvasElement): Float32Array {
const ctx = canvas.getContext('2d');
// 28x28'e resize
const tempCanvas = document.createElement('canvas');
tempCanvas.width = 28;
tempCanvas.height = 28;
const tempCtx = tempCanvas.getContext('2d');
tempCtx.drawImage(canvas, 0, 0, 28, 28);
// ImageData al
const imageData = tempCtx.getImageData(0, 0, 28, 28);
const pixels = imageData.data;
// Grayscale + Normalize
const input = new Float32Array(1 * 1 * 28 * 28);
for (let i = 0; i < 28 * 28; i++) {
const pixelIndex = i * 4;
const gray = (pixels[pixelIndex] + pixels[pixelIndex + 1] + pixels[pixelIndex + 2]) / 3;
const normalizedPixel = gray / 255.0;
// MNIST normalization
input[i] = (normalizedPixel - 0.1307) / 0.3081;
}
return input;
}Dikkat edilmesi gerekenler:
- Resize: Bilinear interpolation kullan
- Grayscale: RGB → Gray conversion
- Normalization: MNIST mean=0.1307, std=0.3081
- Inversion: MNIST beyaz üzerine siyah, canvas'ımız tam tersi
Inference
async function runInference(session: ort.InferenceSession, input: Float32Array) {
// Tensor oluştur
const tensor = new ort.Tensor('float32', input, [1, 1, 28, 28]);
// Inference
const feeds = { [session.inputNames[0]]: tensor };
const results = await session.run(feeds);
const output = results[session.outputNames[0]];
// Softmax (logits → probabilities)
const logits = Array.from(output.data as Float32Array);
const maxLogit = Math.max(...logits);
const expLogits = logits.map(x => Math.exp(x - maxLogit));
const sumExp = expLogits.reduce((a, b) => a + b, 0);
const probabilities = expLogits.map(x => x / sumExp);
return probabilities;
}5. Performance Optimization
Model Loading
İlk yükleme ~1 saniye. Sonraki inference'lar için cache'lenmiş session kullanılıyor.
let modelSession: ort.InferenceSession | null = null;
export async function loadModel(modelPath: string) {
if (modelSession) return modelSession;
modelSession = await ort.InferenceSession.create(modelPath);
return modelSession;
}Inference Speed
- Ortalama: 15-30ms
- Minimum: 10ms (güçlü cihazlar)
- Maximum: 50ms (mobil cihazlar)
Bundle Size
onnxruntime-web: ~1.5MB (gzip: ~500KB)- Model dosyası: ~500KB
Optimization stratejileri:
- Lazy loading: Model sadece gerektiğinde yükle
- Web Worker: UI thread'i bloklamadan inference
- WebGL backend: GPU acceleration (opsiyonel)
6. Challenges ve Çözümler
Challenge 1: Canvas Normalization
Problem: Canvas'taki çizim MNIST formatına uymuyordu (ters renkler).
Çözüm: Preprocessing sırasında inversion uygula veya model eğitiminde data augmentation kullan.
Challenge 2: ONNX Opset Compatibility
Problem: Opset 13+ bazı browser'larda desteklenmiyor.
Çözüm: opset_version=12 kullan, maksimum compatibility sağlıyor.
Challenge 3: Mobile Touch Events
Problem: Touch eventi mouse event'ten farklı işleniyor.
Çözüm: Hem onMouseMove hem onTouchMove handle et:
const getCoordinates = (e: MouseEvent | TouchEvent) => {
const rect = canvas.getBoundingClientRect();
if ('touches' in e) {
return {
x: e.touches[0].clientX - rect.left,
y: e.touches[0].clientY - rect.top
};
}
return {
x: e.clientX - rect.left,
y: e.clientY - rect.top
};
};7. Sonuçlar ve İzlenimler
Başarılar
%99+ doğruluk oranı
<50ms inference süresi
Sıfır sunucu maliyeti
Kullanıcı gizliliği korunuyor
Mobile'da da çalışıyor
Öğrendiklerim
- ONNX Runtime Web'in gücü ve limitations
- Browser'da ML inference optimization teknikleri
- Canvas API ile image preprocessing
- PyTorch → ONNX export best practices
Gelecek İyileştirmeler
- Web Worker ile async inference
- WebGL backend ile GPU acceleration
- Model quantization (smaller size)
- Multi-model comparison (different architectures)
- Real-time prediction while drawing
8. Demo ve Kaynak Kodlar
Live Demo: /projects/mnist-digit-recognition
GitHub Repository: erdoganeray/mnist-digit-recognition
9. İlgili Kaynaklar
Sorular veya geri bildirimler için: İletişim
Proje tarihi: Şubat 2026