# ExTorch

Elixir bindings for libtorch -- production ML model serving on the BEAM.
ExTorch lets you load TorchScript models, run inference with OTP fault tolerance, define neural network architectures with an Elixir DSL, and monitor serving performance through telemetry and LiveDashboard.
## Features

- **JIT Model Serving** -- Load `.pt` models, run inference with full IValue support (tensors, tuples, dicts, scalars), and serve behind a GenServer with process isolation.
- **Neural Network DSL** -- Define PyTorch-compatible layers declaratively in Elixir with `deflayer`, backed by libtorch's C++ nn modules.
- **JIT IR Introspection** -- Extract model architecture, parameters, and computation graphs from any TorchScript model. Generate Elixir DSL source code from `.pt` files.
- **Zero-Copy Tensor Exchange** -- Share tensor memory with Nx/Torchx via raw pointer exchange (`data_ptr`/`from_blob`), no copies.
- **Telemetry & Observability** -- `:telemetry` events for load/inference, ETS-backed metrics (latency, throughput, errors), optional LiveDashboard page.
- **Tensor Operations** -- 200+ wrapped libtorch ops for tensor creation, manipulation, pointwise math, comparison, reduction, and indexing.
## Requirements
- Elixir >= 1.14
- Rust (stable toolchain)
- libtorch (automatically downloaded, or use a local PyTorch installation)
## Installation

Add `extorch` to your dependencies in `mix.exs`:
```elixir
def deps do
  [
    {:extorch, "~> 0.2.0"}
  ]
end
```
ExTorch will download libtorch automatically on first compile. To use a local installation, configure it in `config/config.exs`:
```elixir
config :extorch, libtorch: [
  version: :local,
  folder: :python # or an absolute path to libtorch
]
```

## Quick Start
### Loading and Serving a TorchScript Model
```elixir
# Load a model exported from Python with torch.jit.script() or torch.jit.trace()
model = ExTorch.JIT.load("model.pt")
ExTorch.JIT.eval(model)

# Run inference
input = ExTorch.randn({1, 3, 224, 224})
output = ExTorch.JIT.forward(model, [input])

# Models returning tuples/dicts work naturally
{logits, features} = ExTorch.JIT.forward(multi_output_model, [input])
```

### GenServer-based Model Server
```elixir
# Start a supervised model server
{:ok, _pid} = ExTorch.JIT.Server.start_link(
  path: "model.pt",
  device: :cpu,
  name: MyModel
)

# Run inference (thread-safe, serialized through the GenServer)
output = ExTorch.JIT.Server.predict(MyModel, [input])

# Check server stats
ExTorch.JIT.Server.info(MyModel)
# => %{path: "model.pt", device: :cpu, inference_count: 42, error_count: 0, uptime_ms: 15000}
```

### Neural Network DSL
```elixir
defmodule MyMLP do
  use ExTorch.NN.Module

  deflayer :fc1, ExTorch.NN.Linear, in_features: 784, out_features: 128
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :dropout, ExTorch.NN.Dropout, p: 0.5
  deflayer :fc2, ExTorch.NN.Linear, in_features: 128, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :fc1)
    |> layer(model, :relu)
    |> layer(model, :dropout)
    |> layer(model, :fc2)
  end
end

model = MyMLP.new()
input = ExTorch.randn({32, 784})
output = MyMLP.forward(model, input)
# => %ExTorch.Tensor{size: {32, 10}, ...}

# Inspect parameters
MyMLP.parameters(model)
# => [{"fc1.weight", #Tensor<[128, 784]>}, {"fc1.bias", #Tensor<[128]>}, ...]
```
Available layers: `Linear`, `Conv1d`, `Conv2d`, `Conv3d`, `ConvTranspose1d`, `ConvTranspose2d`, `MaxPool1d`, `MaxPool2d`, `AvgPool1d`, `AvgPool2d`, `AdaptiveAvgPool1d`, `AdaptiveAvgPool2d`, `BatchNorm1d`, `BatchNorm2d`, `LayerNorm`, `GroupNorm`, `InstanceNorm1d`, `InstanceNorm2d`, `Dropout`, `Embedding`, `LSTM`, `GRU`, `MultiheadAttention`, `Flatten`, `Unflatten`, `ReLU`, `LeakyReLU`, `GELU`, `ELU`, `SiLU`, `Mish`, `PReLU`, `Sigmoid`, `Tanh`, `Softmax`, `LogSoftmax`.
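As an example of combining these, a small convolutional classifier can be sketched with the same DSL. The layer option names below (`in_channels:`, `kernel_size:`, etc.) mirror PyTorch's and are assumptions here, as is the 28x28 input size:

```elixir
defmodule SmallCNN do
  use ExTorch.NN.Module

  deflayer :conv1, ExTorch.NN.Conv2d, in_channels: 1, out_channels: 16, kernel_size: 3
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :pool, ExTorch.NN.MaxPool2d, kernel_size: 2
  deflayer :flatten, ExTorch.NN.Flatten
  deflayer :fc, ExTorch.NN.Linear, in_features: 16 * 13 * 13, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :conv1)   # {n, 1, 28, 28} -> {n, 16, 26, 26}
    |> layer(model, :relu)
    |> layer(model, :pool)    # -> {n, 16, 13, 13}
    |> layer(model, :flatten) # -> {n, 16 * 13 * 13}
    |> layer(model, :fc)      # -> {n, 10}
  end
end
```

`Flatten` collapses everything after the batch dimension, so the final `Linear`'s `in_features` must match the pooled feature-map size.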
### Loading Pre-trained Weights
There are two ways to use a trained model from Python with a DSL-defined module:
Option A: `from_jit/1` -- Use the JIT model's forward directly (simplest):
```elixir
# The JIT model's forward() runs the computation with pre-trained weights.
# The DSL definition is validated against the .pt file's submodules.
model = MyMLP.from_jit("trained_model.pt")
output = MyMLP.predict(model, [input])
```

Option B: `load_weights/1` -- Copy weights into Elixir-defined layers:
```elixir
# Creates DSL layers, then copies matching parameter tensors from the .pt file.
# The result is a regular DSL model that runs through your forward/2 function.
model = MyMLP.load_weights("trained_model.pt")
output = MyMLP.forward(model, input)
```
Both produce identical outputs. Use `from_jit/1` when you want the exact Python forward logic. Use `load_weights/1` when your Elixir `forward/2` differs (e.g., different dropout, custom post-processing) but you want the same trained parameters.
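To illustrate the second case, a hypothetical eval-time variant of `MyMLP` can reuse the trained `fc1`/`fc2` parameters while swapping the forward pass (no dropout, probabilities instead of logits; the `Softmax` `dim:` option is an assumption):

```elixir
defmodule MyMLPEval do
  use ExTorch.NN.Module

  # Same layer names and shapes as the trained model, so parameters match by name
  deflayer :fc1, ExTorch.NN.Linear, in_features: 784, out_features: 128
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :fc2, ExTorch.NN.Linear, in_features: 128, out_features: 10
  deflayer :softmax, ExTorch.NN.Softmax, dim: 1

  # Custom forward: dropout removed, softmax added
  def forward(model, x) do
    x
    |> layer(model, :fc1)
    |> layer(model, :relu)
    |> layer(model, :fc2)
    |> layer(model, :softmax)
  end
end

model = MyMLPEval.load_weights("trained_model.pt")
probs = MyMLPEval.forward(model, ExTorch.randn({32, 784}))
```

Because `Dropout` and `Softmax` carry no parameters, the tensors copied by `load_weights/1` are exactly those of `fc1` and `fc2`.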
### JIT Model Introspection
```elixir
model = ExTorch.JIT.load("resnet18.pt")

# Extract structured schema
schema = ExTorch.NN.Introspect.schema(model)
schema.submodules
# => [%{name: "conv1", type_name: "...Conv2d", parameters: [%{name: "weight", shape: [64, 3, 7, 7], ...}]}, ...]

# View the computation graph
IO.puts(ExTorch.NN.Introspect.graph(model))

# Generate Elixir DSL source code from any .pt model
IO.puts(ExTorch.NN.Introspect.to_elixir(model, "ResNet18"))
```

### Zero-Copy Tensor Exchange with Nx
```elixir
# ExTorch tensor -> raw pointer (for passing to Torchx/Nx)
blob = ExTorch.Tensor.Blob.to_blob(tensor)
# => %Blob{ptr: 140234567890, shape: {3, 224, 224}, strides: [...], dtype: :float, ...}

# Foreign pointer -> ExTorch tensor (zero-copy, no data movement)
view = ExTorch.Tensor.Blob.from_blob(
  %{ptr: foreign_ptr, shape: {3, 224, 224}, dtype: :float32},
  owner: source_tensor # prevents GC of source memory
)
view.tensor # => %ExTorch.Tensor{...}
```

### Telemetry & Metrics
```elixir
# Enable metrics collection (call in your Application.start)
ExTorch.Metrics.setup()

# Metrics are automatically collected from JIT.Server telemetry events
ExTorch.Metrics.get("model.pt")
# => %{inference_count: 1500, error_count: 2, total_duration_ms: 4523.1,
#      min_duration_ms: 1.2, max_duration_ms: 89.2, load_duration_ms: 340.5, ...}
```
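A handler for these events is an ordinary 4-arity function. A minimal sketch (the module name is hypothetical, and the `:duration` measurement key follows the usual `:telemetry` span convention rather than anything ExTorch documents):

```elixir
defmodule MyApp.ExTorchHandler do
  # Invoked by :telemetry each time an inference span completes
  def handle_event([:extorch, :jit, :forward, :stop], measurements, _metadata, _config) do
    # Span durations are reported in native time units
    ms = System.convert_time_unit(measurements.duration, :native, :millisecond)
    IO.puts("inference took #{ms} ms")
  end
end
```

Pass `&MyApp.ExTorchHandler.handle_event/4` as the handler function when attaching, as shown below.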
```elixir
# Attach your own handlers to telemetry events
:telemetry.attach("my-handler", [:extorch, :jit, :forward, :stop], &handle_event/4, nil)
```

LiveDashboard (optional): Add `{:phoenix_live_dashboard, "~> 0.8"}` to your deps and configure:
```elixir
live_dashboard "/dashboard",
  additional_pages: [
    extorch: {ExTorch.Observer.Dashboard, []}
  ]
```

## CUDA Support
```elixir
ExTorch.Native.cuda_is_available() # => true/false
ExTorch.Native.cuda_device_count() # => 2

# Load model on GPU
model = ExTorch.JIT.load("model.pt", device: {:cuda, 0})

# Monitor GPU memory
ExTorch.Native.cuda_memory_allocated(0) # bytes currently allocated
ExTorch.Native.cuda_memory_reserved(0)  # bytes reserved by caching allocator
```

## Architecture
ExTorch uses a three-layer architecture:

- **C++** -- Wraps libtorch APIs (`torch::jit`, `torch::nn`, tensor ops) behind shared pointer types
- **Rust** -- Bridges C++ to Erlang NIFs via cxx and Rustler, with type-safe encoding/decoding
- **Elixir** -- Macro-generated API with `defbinding`/`nif_impl!` for tensor ops, hand-written modules for JIT/NN/telemetry
## License
MIT