Vitgracer/ViT-from-scratch

Minimal Vision Transformer (ViT) implementation in PyTorch

Hey Friends!

Welcome to this tiny experiment where we compare a classic Convolutional Neural Network (CNN) against the modern Vision Transformer (ViT). The task: good old handwritten digit recognition.

Why digits? Because my PC would explode with anything more serious 😂 And because the goal here is not to chase state-of-the-art accuracy, but to see how ViTs actually work under the hood.


⚙️ Installation

Create a virtual environment and install dependencies:

python -m venv vit-venv
source vit-venv/Scripts/activate  # Windows (Git Bash); use vit-venv/bin/activate on Linux/macOS
pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
pip3 install einops torchsummary

How to run:

python train.py

🥊 CNN vs ViT

We train two models with roughly 2k trainable parameters each.
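To verify a parameter budget like this, you can sum the elements of all trainable tensors. The tiny model below is a placeholder, not the repo's actual architecture; only the counting helper matters.

```python
import torch.nn as nn

def count_trainable(model: nn.Module) -> int:
    # Sum the number of elements in every tensor that receives gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Placeholder CNN just to demonstrate the check:
# Conv2d(1, 4, 3) has 4*1*3*3 + 4 = 40 params, Linear(64, 10) has 64*10 + 10 = 650.
cnn = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(4),  # 4 channels * 4*4 spatial = 64 features
    nn.Flatten(),
    nn.Linear(64, 10),
)
print(count_trainable(cnn))  # 690
```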

🔵 CNN! The test accuracy looks like this:

[CNN test-accuracy plot]

🔴 ViT! And here’s the ViT’s performance:

[ViT test-accuracy plot]

🤔 Observations

  • Attention layers involve matrix multiplications over full sequences (O(N²) complexity), so the ViT is SLOWER. Not like a turtle... more like a turtle loaded with bags from the supermarket 😂
  • The ViT also gets lower accuracy, because Transformers work well when they can model long-range dependencies and have lots of data. On MNIST, the images are tiny (28×28) and the dataset is small. CNNs are simply better at extracting local patterns like edges, strokes, and curves.
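The O(N²) point above can be made concrete: self-attention scores form an N×N matrix, where N is the number of patch tokens, so compute and memory grow quadratically with sequence length. A minimal sketch (dimensions are illustrative):

```python
import torch

# 16 tokens (e.g. 7x7 patches of a 28x28 image) with 32-dim embeddings.
N, d = 16, 32
q = torch.randn(1, N, d)  # queries
k = torch.randn(1, N, d)  # keys

# Scaled dot-product scores: every token attends to every other token,
# producing an (N, N) matrix -- N**2 pairwise interactions per head.
scores = q @ k.transpose(-2, -1) / d ** 0.5
print(scores.shape)  # torch.Size([1, 16, 16])
```

Doubling the number of patches quadruples this matrix, which is why ViTs slow down much faster than CNNs as image resolution grows.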

In other words: asking a ViT to classify MNIST is like hiring a theoretical physicist to count apples at a grocery store. So choose your model wisely! 🧐

⚖️ Pros & Cons

CNN

  • ✅ Fast to train
  • ✅ Great at local pattern recognition (edges, textures, shapes)
  • ✅ Works very well with small datasets
  • ❌ Limited ability to capture global context

ViT

  • ✅ Elegant, unified architecture (no handcrafted kernels)
  • ✅ Scales with data (huge datasets)
  • ✅ Attention maps are interpretable (attention weights show where the model is “looking”)
  • ❌ Training is slower
  • ❌ Needs more data to reach CNN-level performance