This repo contains a simple gRPC server for serving MNIST digit classification requests, built on a Rust stack. The purpose of this project is personal learning and experimentation with tonic, a Rust implementation of gRPC, and candle, an ML/tensor framework in Rust.
One nice feature of using Rust for ML inference is lightweight deployment. While Python deployments with deep learning frameworks like PyTorch often result in container images of multiple GBs, with candle the release binary of this server is only ~6 MB (without model weights).
This project supports two (trivial) neural network architectures for MNIST classification:
- **Multi-Layer Perceptron (MLP)**
  - 3 fully connected layers: 784 → 128 → 64 → 10
  - Simple feedforward network
- **Convolutional Neural Network (ConvNet)**
  - 2 convolutional layers (1 → 32 → 64 channels) with ReLU and max pooling
  - 2 fully connected layers: 3136 → 128 → 10
These models are defined in the mnist sub-crate.
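For reference, here is a minimal sketch of what the MLP could look like with candle-nn. Names such as `Mlp` and `fc1` are illustrative; the actual definitions live in the `mnist` sub-crate and may differ.

```rust
// Minimal sketch of the 784 -> 128 -> 64 -> 10 MLP using candle-nn.
// Struct and layer names are illustrative, not the repo's actual code.
use candle_core::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};

struct Mlp {
    fc1: Linear, // 784 -> 128
    fc2: Linear, // 128 -> 64
    fc3: Linear, // 64 -> 10
}

impl Mlp {
    fn new(vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            fc1: linear(784, 128, vb.pp("fc1"))?,
            fc2: linear(128, 64, vb.pp("fc2"))?,
            fc3: linear(64, 10, vb.pp("fc3"))?,
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Input is a flattened 28x28 image (784 values); output is 10 logits.
        let xs = self.fc1.forward(xs)?.relu()?;
        let xs = self.fc2.forward(&xs)?.relu()?;
        self.fc3.forward(&xs)
    }
}
```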
To build and run this project you will need:

- Rust
- Python 3.8+
- uv (Python package manager)
Model training is done in Python with PyTorch:
- Navigate to the training directory:

  ```bash
  cd training
  ```

- Install Python dependencies:

  ```bash
  uv sync
  ```

- Train the model and save weights:

  ```bash
  uv run python train.py --output ../models/mnist_convnet.safetensors
  ```
The training script will:
- Download the MNIST dataset automatically
- Train a ConvNet for 3 epochs
- Display training progress and final test accuracy
- Save the model weights in SafeTensors format
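On the Rust side, the server loads these weights with candle. Below is a hedged sketch of how a SafeTensors file can be turned into a `VarBuilder`; the repo's actual loading code may differ, and `ConvNet::new` in the comment is illustrative.

```rust
// Sketch: loading SafeTensors weights into a candle VarBuilder.
// How the repo actually wires this up may differ.
use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Safety: the weights file is memory-mapped and must not change while in use.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &["models/mnist_convnet.safetensors"],
            DType::F32,
            &device,
        )?
    };
    // The VarBuilder is then handed to a model constructor, e.g. ConvNet::new(vb)?.
    let _ = vb;
    Ok(())
}
```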
- Build the server:

  ```bash
  cargo build --release
  ```

- Start the gRPC server:

  ```bash
  cargo run --release --bin grpc-server -- --model-architecture conv --model-weights models/mnist_convnet.safetensors
  ```

The server will start on `localhost:50051` by default. You can check the other available CLI args with `--help`.
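For illustration, the CLI surface above could be declared with clap's derive API roughly as follows. Only `--model-architecture` and `--model-weights` are taken from this README; everything else here is an assumption.

```rust
// Hypothetical sketch of the server's CLI args using clap's derive API.
// Only --model-architecture and --model-weights appear in this README;
// the rest (types, docs, lack of defaults) is assumed.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Which network to load: "mlp" or "conv".
    #[arg(long)]
    model_architecture: String,

    /// Path to the SafeTensors weights produced by the training step.
    #[arg(long)]
    model_weights: String,
}

fn main() {
    let args = Args::parse();
    println!("{args:?}");
}
```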
The protocol expects the image as raw bytes; since grpcurl's JSON input supplies bytes fields as base64, you can build a request like this:
echo '{"data": "'$(base64 -w 0 -i ~/path_to_pic/four.png)'"}' > test_request.jsonUsing grpcurl, such requests can be sent to the server:
```bash
grpcurl -plaintext -proto ./proto/mnist.proto \
  -d @ \
  '[::1]:50051' mnist.Mnist.Predict \
  < test_request.json
```

This should respond with something like:
```json
{
  "label": 4,
  "probabilities": [0.001, 0.002, 0.003, 0.004, 0.985, 0.002, 0.001, 0.001, 0.001, 0.000]
}
```
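The same call can also be made programmatically with tonic. The sketch below assumes the module and message names (`mnist_client::MnistClient`, `PredictRequest`) that tonic's codegen would typically derive from the proto shown above; the repo's actual generated types may differ.

```rust
// Hypothetical tonic client; module/message names are assumptions based on
// the proto's package/service names and may not match the repo's generated code.
pub mod mnist {
    tonic::include_proto!("mnist");
}

use mnist::mnist_client::MnistClient;
use mnist::PredictRequest;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut client = MnistClient::connect("http://[::1]:50051").await?;
    // Send the raw PNG bytes, like the base64/grpcurl example above does by hand.
    let data = std::fs::read("four.png")?;
    let response = client.predict(PredictRequest { data }).await?;
    println!("{:?}", response.into_inner());
    Ok(())
}
```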