# ExTorch
Elixir bindings for libtorch -- production ML model serving on the BEAM.
ExTorch lets you load TorchScript models, run inference with OTP fault tolerance, define neural network architectures with an Elixir DSL, and monitor serving performance through telemetry and LiveDashboard.
## Features
- JIT Model Serving -- Load `.pt` models, run inference with full IValue support (tensors, tuples, dicts, scalars), and serve behind a GenServer with process isolation.
- AOTI Compiled Models -- Load AOTInductor `.pt2` packages for optimized inference with fused kernels.
- torch.export Reader & Interpreter -- Pure Elixir reader for `torch.export.save` `.pt2` archives with a built-in ATen graph interpreter (60+ ops). Load and run inference on exported models directly -- tested with AlexNet, ResNet18, MobileNetV2, VGG11, SqueezeNet, and transformers.
- Neural Network DSL -- Define PyTorch-compatible layers declaratively in Elixir with `deflayer`, backed by libtorch's C++ nn modules.
- JIT IR Introspection -- Extract model architecture, parameters, and computation graphs from any TorchScript model. Generate Elixir DSL source code from `.pt` files.
- Zero-Copy Tensor Exchange -- Share tensor memory with Nx/Torchx via raw pointer exchange (`data_ptr`/`from_blob`), no copies.
- Telemetry & Observability -- `:telemetry` events for load/inference, ETS-backed metrics (latency, throughput, errors), optional LiveDashboard page.
- Tensor Operations -- 200+ wrapped libtorch ops for tensor creation, manipulation, pointwise math, comparison, reduction, and indexing.
## Requirements
- Elixir >= 1.14
- Rust (stable toolchain)
- libtorch (automatically downloaded, or use a local PyTorch installation)
## Installation
Add `extorch` to your dependencies in `mix.exs`:
```elixir
def deps do
  [
    {:extorch, "~> 0.3.0"}
  ]
end
```
ExTorch will download libtorch automatically on first compile. To use a local installation, configure it in `config/config.exs`:
```elixir
config :extorch, libtorch: [
  version: :local,
  folder: :python # or an absolute path to libtorch
]
```

## Quick Start
### Loading and Serving a TorchScript Model
```elixir
# Load a model exported from Python with torch.jit.script() or torch.jit.trace()
model = ExTorch.JIT.load("model.pt")
ExTorch.JIT.eval(model)

# Run inference
input = ExTorch.randn({1, 3, 224, 224})
output = ExTorch.JIT.forward(model, [input])

# Models returning tuples/dicts work naturally
{logits, features} = ExTorch.JIT.forward(multi_output_model, [input])
```

### GenServer-based Model Server
```elixir
# Start a supervised model server
{:ok, _pid} = ExTorch.JIT.Server.start_link(
  path: "model.pt",
  device: :cpu,
  name: MyModel
)

# Run inference (thread-safe, serialized through the GenServer)
output = ExTorch.JIT.Server.predict(MyModel, [input])

# Check server stats
ExTorch.JIT.Server.info(MyModel)
# => %{path: "model.pt", device: :cpu, inference_count: 42, error_count: 0, uptime_ms: 15000}
```
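To get the fault tolerance mentioned above, run the server under your application's supervision tree so a crashed model process restarts automatically. A minimal sketch, assuming `ExTorch.JIT.Server` provides the standard child spec that `GenServer`-based modules get by default (`MyApp` is a hypothetical application):

```elixir
defmodule MyApp.Application do
  use Application

  @impl true
  def start(_type, _args) do
    children = [
      # Same options as ExTorch.JIT.Server.start_link/1 above
      {ExTorch.JIT.Server, path: "model.pt", device: :cpu, name: MyModel}
    ]

    Supervisor.start_link(children, strategy: :one_for_one, name: MyApp.Supervisor)
  end
end
```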
### Neural Network DSL

```elixir
defmodule MyMLP do
  use ExTorch.NN.Module

  deflayer :fc1, ExTorch.NN.Linear, in_features: 784, out_features: 128
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :dropout, ExTorch.NN.Dropout, p: 0.5
  deflayer :fc2, ExTorch.NN.Linear, in_features: 128, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :fc1)
    |> layer(model, :relu)
    |> layer(model, :dropout)
    |> layer(model, :fc2)
  end
end

model = MyMLP.new()
input = ExTorch.randn({32, 784})
output = MyMLP.forward(model, input)
# => %ExTorch.Tensor{size: {32, 10}, ...}

# Inspect parameters
MyMLP.parameters(model)
# => [{"fc1.weight", #Tensor<[128, 784]>}, {"fc1.bias", #Tensor<[128]>}, ...]
```
Available layers: Linear, Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, MaxPool1d, MaxPool2d, AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, Dropout, Embedding, LSTM, GRU, MultiheadAttention, Flatten, Unflatten, ReLU, LeakyReLU, GELU, ELU, SiLU, Mish, PReLU, Sigmoid, Tanh, Softmax, LogSoftmax.
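As a sketch of how these layers compose, here is a small hypothetical CNN. The convolution/pooling option names (`in_channels:`, `out_channels:`, `kernel_size:`) are assumed to mirror the PyTorch constructor arguments, the same way the `Linear` options above do:

```elixir
defmodule MyCNN do
  use ExTorch.NN.Module

  deflayer :conv1, ExTorch.NN.Conv2d, in_channels: 3, out_channels: 16, kernel_size: 3
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :pool, ExTorch.NN.MaxPool2d, kernel_size: 2
  deflayer :flatten, ExTorch.NN.Flatten
  # For {N, 3, 32, 32} input: conv (3x3, no padding) -> 30x30, pool -> 15x15
  deflayer :fc, ExTorch.NN.Linear, in_features: 16 * 15 * 15, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :conv1)
    |> layer(model, :relu)
    |> layer(model, :pool)
    |> layer(model, :flatten)
    |> layer(model, :fc)
  end
end
```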
### Loading Pre-trained Weights
There are two ways to use a trained model from Python with a DSL-defined module:
Option A: `from_jit/1` -- Use the JIT model's forward directly (simplest):
```elixir
# The JIT model's forward() runs the computation with pre-trained weights.
# The DSL definition is validated against the .pt file's submodules.
model = MyMLP.from_jit("trained_model.pt")
output = MyMLP.predict(model, [input])
```

Option B: `load_weights/1` -- Copy weights into Elixir-defined layers:
```elixir
# Creates DSL layers, then copies matching parameter tensors from the .pt file.
# The result is a regular DSL model that runs through your forward/2 function.
model = MyMLP.load_weights("trained_model.pt")
output = MyMLP.forward(model, input)
```
Both produce identical outputs. Use `from_jit` when you want the exact Python forward logic. Use `load_weights` when your Elixir `forward/2` differs (e.g., different dropout, custom post-processing) but you want the same trained parameters.
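For instance, a hypothetical inference-only variant of the earlier `MyMLP` could keep the trained parameters but drop the dropout step from its `forward/2`. The layer definitions stay identical so parameter names still match the `.pt` file:

```elixir
defmodule MyMLPInference do
  use ExTorch.NN.Module

  # Same layers and shapes as the trained MyMLP, so load_weights/1 can
  # match parameters by name.
  deflayer :fc1, ExTorch.NN.Linear, in_features: 784, out_features: 128
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :dropout, ExTorch.NN.Dropout, p: 0.5
  deflayer :fc2, ExTorch.NN.Linear, in_features: 128, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :fc1)
    |> layer(model, :relu)
    # :dropout is intentionally not applied at inference time
    |> layer(model, :fc2)
  end
end

model = MyMLPInference.load_weights("trained_model.pt")
output = MyMLPInference.forward(model, input)
```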
### JIT Model Introspection
model = ExTorch.JIT.load("resnet18.pt")
# Extract structured schema
schema = ExTorch.NN.Introspect.schema(model)
schema.submodules
# => [%{name: "conv1", type_name: "...Conv2d", parameters: [%{name: "weight", shape: [64, 3, 7, 7], ...}]}, ...]
# View the computation graph
IO.puts(ExTorch.NN.Introspect.graph(model))
# Generate Elixir DSL source code from any .pt model
IO.puts(ExTorch.NN.Introspect.to_elixir(model, "ResNet18"))torch.export Inference (JIT-free)
ExTorch reads `.pt2` files from `torch.export.save` and runs inference through a built-in ATen graph interpreter -- no Python, no JIT, no C++ ExportedProgram support needed:
```python
# Python: export and save
exported = torch.export.export(model, (example_input,))
torch.export.save(exported, "model.pt2")
```

```elixir
# Elixir: load and run inference directly
model = ExTorch.Export.load("model.pt2")
output = ExTorch.Export.forward(model, [input])

# Or read weights and generate DSL
weights = ExTorch.Export.read_weights("model.pt2")
IO.puts(ExTorch.Export.to_elixir("model.pt2", "MyModel"))

# Or load weights into a DSL module
model = MyModel.load_weights_from_export("model.pt2")
output = MyModel.forward(model, input)
```

Tested with AlexNet, ResNet18, MobileNetV2, VGG11, SqueezeNet, transformers, and autoencoders. The interpreter supports 60+ ATen operations. The DSL generator performs data flow analysis to correctly emit skip connections and branching architectures (e.g., ResNet residual blocks).
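The same skip-connection pattern can be written by hand in the DSL, since a residual add is just ordinary Elixir data flow. A sketch, assuming an elementwise `ExTorch.add/2` among the wrapped tensor ops and PyTorch-style Conv2d/BatchNorm2d options (`padding:`, `num_features:`):

```elixir
defmodule ResidualBlock do
  use ExTorch.NN.Module

  deflayer :conv1, ExTorch.NN.Conv2d, in_channels: 64, out_channels: 64, kernel_size: 3, padding: 1
  deflayer :bn1, ExTorch.NN.BatchNorm2d, num_features: 64
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :conv2, ExTorch.NN.Conv2d, in_channels: 64, out_channels: 64, kernel_size: 3, padding: 1
  deflayer :bn2, ExTorch.NN.BatchNorm2d, num_features: 64

  def forward(model, x) do
    # The branch runs through the layers; the skip is just the bound `x`.
    branch =
      x
      |> layer(model, :conv1)
      |> layer(model, :bn1)
      |> layer(model, :relu)
      |> layer(model, :conv2)
      |> layer(model, :bn2)

    branch
    |> ExTorch.add(x)
    |> layer(model, :relu)
  end
end
```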
### AOTI Compiled Models
For maximum inference throughput, load AOTInductor-compiled `.pt2` packages:
```python
# Python: compile and package
from torch._inductor import aoti_compile_and_package
exported = torch.export.export(model, (example_input,))
aoti_compile_and_package(exported, package_path="model.pt2")
```

```elixir
# Elixir: load and run
model = ExTorch.AOTI.load("model.pt2")
[output] = ExTorch.AOTI.forward(model, [input])

# Inspect metadata
ExTorch.AOTI.metadata(model)
# => %{"AOTI_DEVICE_KEY" => "cpu", "AOTI_CPU_ISA" => "AVX2", ...}
```

### Zero-Copy Tensor Exchange with Nx
```elixir
# ExTorch tensor -> raw pointer (for passing to Torchx/Nx)
blob = ExTorch.Tensor.Blob.to_blob(tensor)
# => %Blob{ptr: 140234567890, shape: {3, 224, 224}, strides: [...], dtype: :float, ...}

# Foreign pointer -> ExTorch tensor (zero-copy, no data movement)
view = ExTorch.Tensor.Blob.from_blob(
  %{ptr: foreign_ptr, shape: {3, 224, 224}, dtype: :float32},
  owner: source_tensor # prevents GC of source memory
)
view.tensor # => %ExTorch.Tensor{...}
```

### Telemetry & Metrics
```elixir
# Enable metrics collection (call in your Application.start)
ExTorch.Metrics.setup()

# Metrics are automatically collected from JIT.Server telemetry events
ExTorch.Metrics.get("model.pt")
# => %{inference_count: 1500, error_count: 2, total_duration_ms: 4523.1,
#      min_duration_ms: 1.2, max_duration_ms: 89.2, load_duration_ms: 340.5, ...}

# Attach your own handlers to telemetry events
:telemetry.attach("my-handler", [:extorch, :jit, :forward, :stop], &handle_event/4, nil)
```
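A minimal sketch of the `handle_event/4` referenced above, assuming the stop event carries a native-time `:duration` measurement, as is conventional for `:telemetry` span events:

```elixir
defmodule MyApp.ExTorchHandler do
  require Logger

  # Logs inference latency; :duration in native time units is an assumption
  # based on the usual :telemetry span convention.
  def handle_event([:extorch, :jit, :forward, :stop], measurements, metadata, _config) do
    ms = System.convert_time_unit(measurements.duration, :native, :millisecond)
    Logger.info("inference took #{ms}ms (#{inspect(metadata)})")
  end
end

:telemetry.attach("my-handler", [:extorch, :jit, :forward, :stop],
  &MyApp.ExTorchHandler.handle_event/4, nil)
```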
live_dashboard "/dashboard",
additional_pages: [
extorch: {ExTorch.Observer.Dashboard, []}
]CUDA Support
```elixir
ExTorch.Native.cuda_is_available() # => true/false
ExTorch.Native.cuda_device_count() # => 2

# Load model on GPU
model = ExTorch.JIT.load("model.pt", device: {:cuda, 0})

# Monitor GPU memory
ExTorch.Native.cuda_memory_allocated(0) # bytes currently allocated
ExTorch.Native.cuda_memory_reserved(0)  # bytes reserved by caching allocator
```
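Combined with the model server, a GPU-backed server is just the device option; this sketch assumes `JIT.Server.start_link/1` passes `device:` through to `JIT.load/2`, as the `:cpu` example above suggests:

```elixir
# Hypothetical GPU-backed server (assumes device: {:cuda, 0} is accepted
# by JIT.Server the same way it is by JIT.load).
{:ok, _pid} = ExTorch.JIT.Server.start_link(
  path: "model.pt",
  device: {:cuda, 0},
  name: GpuModel
)

output = ExTorch.JIT.Server.predict(GpuModel, [input])
```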
## Architecture

ExTorch uses a three-layer architecture:
- C++ -- Wraps libtorch APIs (`torch::jit`, `torch::nn`, tensor ops) behind shared pointer types
- Rust -- Bridges C++ to Erlang NIFs via cxx and Rustler, with type-safe encoding/decoding
- Elixir -- Macro-generated API with `defbinding`/`nif_impl!` for tensor ops, hand-written modules for JIT/NN/telemetry
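For orientation, the Elixir/Rust boundary follows the standard Rustler pattern; this is a generic illustration of that pattern, not ExTorch's actual source (ExTorch generates its wrappers with `defbinding`/`nif_impl!`):

```elixir
defmodule MyApp.Native do
  # Loads the compiled Rust crate; each stub below is replaced by the
  # Rust NIF implementation when the shared library loads.
  use Rustler, otp_app: :my_app, crate: "my_native"

  def tensor_add(_a, _b), do: :erlang.nif_error(:nif_not_loaded)
end
```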
## License
MIT