ExTorch

Elixir bindings for libtorch -- production ML model serving on the BEAM.

ExTorch lets you load TorchScript models, run inference with OTP fault tolerance, define neural network architectures with an Elixir DSL, and monitor serving performance through telemetry and LiveDashboard.

Features

Requirements

Installation

Add extorch to your dependencies in mix.exs:

def deps do
  [
    {:extorch, "~> 0.3.0"}
  ]
end

ExTorch downloads libtorch automatically on first compile. To use a local installation instead, configure it in config/config.exs:

config :extorch, libtorch: [
  version: :local,
  folder: :python  # or an absolute path to libtorch
]

Quick Start

Loading and Serving a TorchScript Model

# Load a model exported from Python with torch.jit.script() or torch.jit.trace()
model = ExTorch.JIT.load("model.pt")
ExTorch.JIT.eval(model)

# Run inference
input = ExTorch.randn({1, 3, 224, 224})
output = ExTorch.JIT.forward(model, [input])

# Models that return tuples (or dicts) are converted to Elixir terms,
# so multi-output results can be pattern-matched directly
{logits, features} = ExTorch.JIT.forward(multi_output_model, [input])

GenServer-based Model Server

# Start a supervised model server
{:ok, _pid} = ExTorch.JIT.Server.start_link(
  path: "model.pt",
  device: :cpu,
  name: MyModel
)

# Run inference (thread-safe, serialized through GenServer)
output = ExTorch.JIT.Server.predict(MyModel, [input])

# Check server stats
ExTorch.JIT.Server.info(MyModel)
# => %{path: "model.pt", device: :cpu, inference_count: 42, error_count: 0, uptime_ms: 15000}
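The serialization guarantee comes from the GenServer pattern itself: each predict/2 is a GenServer.call, and the server process handles calls one at a time. The following is a hypothetical, pure-Elixir illustration of that idea, not ExTorch's implementation. SerializedServer is an invented name, and a plain function stands in for the model's forward call. It also keeps counters similar to the inference_count and error_count fields reported by info/1.

```elixir
defmodule SerializedServer do
  # Hypothetical illustration only; not ExTorch.JIT.Server's actual code.
  use GenServer

  # run_fun stands in for the model's forward call (takes the input list).
  def start_link(run_fun) when is_function(run_fun, 1) do
    GenServer.start_link(__MODULE__, run_fun)
  end

  def predict(server, inputs), do: GenServer.call(server, {:predict, inputs})
  def info(server), do: GenServer.call(server, :info)

  @impl true
  def init(run_fun) do
    {:ok, %{run: run_fun, inference_count: 0, error_count: 0}}
  end

  @impl true
  def handle_call({:predict, inputs}, _from, state) do
    # Calls are processed sequentially by this single process, so the
    # underlying model is never invoked concurrently.
    try do
      result = state.run.(inputs)
      {:reply, {:ok, result}, %{state | inference_count: state.inference_count + 1}}
    rescue
      e -> {:reply, {:error, e}, %{state | error_count: state.error_count + 1}}
    end
  end

  def handle_call(:info, _from, state) do
    {:reply, Map.take(state, [:inference_count, :error_count]), state}
  end
end
```

This sketch rescues exceptions and replies with {:error, exception}, so a failing call increments error_count instead of crashing the server. How ExTorch itself reports inference errors is not covered here.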

Neural Network DSL

defmodule MyMLP do
  use ExTorch.NN.Module

  deflayer :fc1, ExTorch.NN.Linear, in_features: 784, out_features: 128
  deflayer :relu, ExTorch.NN.ReLU
  deflayer :dropout, ExTorch.NN.Dropout, p: 0.5
  deflayer :fc2, ExTorch.NN.Linear, in_features: 128, out_features: 10

  def forward(model, x) do
    x
    |> layer(model, :fc1)
    |> layer(model, :relu)
    |> layer(model, :dropout)
    |> layer(model, :fc2)
  end
end

model = MyMLP.new()
input = ExTorch.randn({32, 784})
output = MyMLP.forward(model, input)
# => %ExTorch.Tensor{size: {32, 10}, ...}

# Inspect parameters
MyMLP.parameters(model)
# => [{"fc1.weight", #Tensor<[128, 784]>}, {"fc1.bias", #Tensor<[128]>}, ...]
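The names and shapes returned by parameters/1 follow from the layer options. A Linear layer holds a weight of shape [out_features, in_features] and a bias of shape [out_features], as in the fc1 entries shown. Here is a hypothetical helper (LinearParams is not an ExTorch API, and it handles Linear layers only) that derives the same name/shape list from plain layer specs and totals the parameter count:

```elixir
defmodule LinearParams do
  # Hypothetical helper, not part of ExTorch. Mirrors the "<layer>.weight" /
  # "<layer>.bias" naming shown by MyMLP.parameters/1, for Linear layers only.
  def shapes(layers) do
    Enum.flat_map(layers, fn
      {name, :linear, opts} ->
        in_f = Keyword.fetch!(opts, :in_features)
        out_f = Keyword.fetch!(opts, :out_features)
        [{"#{name}.weight", [out_f, in_f]}, {"#{name}.bias", [out_f]}]

      # Parameter-free layers such as ReLU and Dropout contribute nothing.
      {_name, _other, _opts} ->
        []
    end)
  end

  # Total number of scalar parameters across all shapes.
  def count(layers) do
    layers
    |> shapes()
    |> Enum.map(fn {_name, shape} -> Enum.product(shape) end)
    |> Enum.sum()
  end
end

# The MyMLP layers expressed as plain data.
layers = [
  {:fc1, :linear, in_features: 784, out_features: 128},
  {:relu, :relu, []},
  {:dropout, :dropout, p: 0.5},
  {:fc2, :linear, in_features: 128, out_features: 10}
]
```

For MyMLP this gives 784*128 + 128 + 128*10 + 10 = 101,770 trainable parameters.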

Available layers: Linear, Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, MaxPool1d, MaxPool2d, AvgPool1d, AvgPool2d, AdaptiveAvgPool1d, AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, LayerNorm, GroupNorm, InstanceNorm1d, InstanceNorm2d, Dropout, Embedding, LSTM, GRU, MultiheadAttention, Flatten, Unflatten, ReLU, LeakyReLU, GELU, ELU, SiLU, Mish, PReLU, Sigmoid, Tanh, Softmax, LogSoftmax.
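The deflayer declarations are compile-time metadata. One common way to build this kind of DSL is to store each declaration in an accumulating module attribute and read the attribute back in a @before_compile hook. The sketch below shows that technique on its own. TinyDSL, __layers__/0, and the Linear/ReLU aliases are hypothetical, and ExTorch.NN.Module is not necessarily implemented this way.

```elixir
defmodule TinyDSL do
  # Hypothetical sketch of a deflayer-style macro; not ExTorch's implementation.
  defmacro __using__(_opts) do
    quote do
      import TinyDSL, only: [deflayer: 2, deflayer: 3]
      Module.register_attribute(__MODULE__, :layers, accumulate: true)
      @before_compile TinyDSL
    end
  end

  # Record {name, module, opts}; nothing is constructed at compile time.
  defmacro deflayer(name, module, opts \\ []) do
    quote do
      @layers {unquote(name), unquote(module), unquote(opts)}
    end
  end

  defmacro __before_compile__(_env) do
    quote do
      # Accumulated attributes are stored newest-first; restore declaration order.
      def __layers__, do: Enum.reverse(@layers)
    end
  end
end

defmodule Demo do
  use TinyDSL

  deflayer :fc1, Linear, in_features: 784, out_features: 128
  deflayer :relu, ReLU
end
```

A runtime new/0 can then walk __layers__/0 and instantiate each layer in order.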

Loading Pre-trained Weights

There are two ways to use a trained model from Python with a DSL-defined module:

Option A: from_jit/1 -- Use the JIT model's forward directly (simplest):

# The JIT model's forward() runs the computation with pre-trained weights.
# The DSL definition is validated against the .pt file's submodules.
model = MyMLP.from_jit("trained_model.pt")
output = MyMLP.predict(model, [input])

Option B: load_weights/1 -- Copy weights into Elixir-defined layers:

# Creates DSL layers, then copies matching parameter tensors from the .pt file.
# The result is a regular DSL model that runs through your forward/2 function.
model = MyMLP.load_weights("trained_model.pt")
output = MyMLP.forward(model, input)

When your forward/2 mirrors the Python model's forward logic, both options produce identical outputs. Use from_jit when you want the exact Python forward logic. Use load_weights when your Elixir forward/2 differs (e.g., different dropout, custom post-processing) but you want the same trained parameters.
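Copying weights by name only works when the DSL's parameter names and shapes line up with the checkpoint. The hypothetical sketch below shows one way to sort out that correspondence before copying. WeightMatch is not an ExTorch module, and shapes are plain lists rather than tensors. Each expected parameter is classified as copyable, shape-mismatched, or missing, and checkpoint entries the DSL does not use are listed separately.

```elixir
defmodule WeightMatch do
  # Hypothetical sketch, not ExTorch's load_weights/1.
  # expected and checkpoint are maps of parameter name => shape (list of ints).
  def plan(expected, checkpoint) do
    initial = %{copy: [], shape_mismatch: [], missing: []}

    expected
    |> Enum.reduce(initial, fn {name, shape}, acc ->
      bucket =
        case Map.fetch(checkpoint, name) do
          {:ok, ^shape} -> :copy
          {:ok, _other_shape} -> :shape_mismatch
          :error -> :missing
        end

      Map.update!(acc, bucket, &[name | &1])
    end)
    # Sort each bucket for stable, readable output.
    |> Map.new(fn {bucket, names} -> {bucket, Enum.sort(names)} end)
    # Checkpoint entries that no DSL parameter asked for.
    |> Map.put(:unexpected, Enum.sort(Map.keys(checkpoint) -- Map.keys(expected)))
  end
end
```

Refusing to load when shape_mismatch or missing is non-empty is a reasonable default, because copying into the wrong shape fails and a silently skipped parameter leaves randomly initialized weights in the model.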

JIT Model Introspection

model = ExTorch.JIT.load("resnet18.pt")

# Extract structured schema
schema = ExTorch.NN.Introspect.schema(model)
schema.submodules
# => [%{name: "conv1", type_name: "...Conv2d", parameters: [%{name: "weight", shape: [64, 3, 7, 7], ...}]}, ...]

# View the computation graph
IO.puts(ExTorch.NN.Introspect.graph(model))

# Generate Elixir DSL source code from any .pt model
IO.puts(ExTorch.NN.Introspect.to_elixir(model, "ResNet18"))
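Because schema.submodules is plain data, you can process it with Enum. This sketch assumes the layout in the example comment: each submodule map has a :name and a :parameters list whose entries carry a :shape list. Under that assumption, it computes a per-submodule parameter count. SchemaSummary is a hypothetical name.

```elixir
defmodule SchemaSummary do
  # Assumes each submodule looks like
  # %{name: "...", parameters: [%{name: "...", shape: [d1, d2, ...]}, ...]}.
  def param_counts(submodules) do
    Map.new(submodules, fn sub ->
      total =
        sub.parameters
        |> Enum.map(&Enum.product(&1.shape))
        |> Enum.sum()

      {sub.name, total}
    end)
  end
end
```

For the ResNet18 conv1 weight of shape [64, 3, 7, 7], this yields 64 * 3 * 7 * 7 = 9,408 parameters.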

torch.export Inference (JIT-free)

ExTorch reads .pt2 files from torch.export.save and runs inference through a built-in ATen graph interpreter -- no Python, no JIT, no C++ ExportedProgram support needed:

# Python: export and save (model is a torch.nn.Module, example_input a sample input tensor)
import torch
exported = torch.export.export(model, (example_input,))
torch.export.save(exported, "model.pt2")

# Elixir: load and run inference directly
model = ExTorch.Export.load("model.pt2")
output = ExTorch.Export.forward(model, [input])

# Or read weights and generate DSL
weights = ExTorch.Export.read_weights("model.pt2")
IO.puts(ExTorch.Export.to_elixir("model.pt2", "MyModel"))

# Or load weights into a DSL module (e.g. the MyModel source generated above, once compiled)
model = MyModel.load_weights_from_export("model.pt2")
output = MyModel.forward(model, input)

Tested with AlexNet, ResNet18, MobileNetV2, VGG11, SqueezeNet, transformers, and autoencoders. The interpreter supports 60+ ATen operations. The DSL generator performs data flow analysis to correctly emit skip connections and branching architectures (e.g., ResNet residual blocks).
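A skip connection shows up in the graph as a value that more than one node consumes. The sketch below illustrates that single check on a toy node list. It is a hypothetical simplification, not ExTorch's data flow analysis, and the {output, op, inputs} node format is an assumption. A value reported here cannot simply be piped through |> in the generated forward/2. It must be bound to a variable so the later merge (e.g. the residual add) can reuse it.

```elixir
defmodule FanOut do
  # Hypothetical sketch; not ExTorch's generator.
  # Each node is {output_name, op, input_names}.
  # Returns the values consumed by more than one node (branch points).
  def branch_points(nodes) do
    nodes
    |> Enum.flat_map(fn {_out, _op, inputs} -> inputs end)
    |> Enum.frequencies()
    |> Enum.filter(fn {_value, uses} -> uses > 1 end)
    |> Enum.map(fn {value, _uses} -> value end)
    |> Enum.sort()
  end
end

# A residual block: x -> conv -> relu -> conv, then add the result back to x.
residual = [
  {"c1", :conv, ["x"]},
  {"r1", :relu, ["c1"]},
  {"c2", :conv, ["r1"]},
  {"out", :add, ["c2", "x"]}
]
```

In the residual block, "x" feeds both the first convolution and the final add, so it is the branch point.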

AOTI Compiled Models

For maximum inference throughput, load AOTInductor-compiled .pt2 packages:

# Python: compile and package
import torch
from torch._inductor import aoti_compile_and_package
exported = torch.export.export(model, (example_input,))
aoti_compile_and_package(exported, package_path="model.pt2")

# Elixir: load and run
model = ExTorch.AOTI.load("model.pt2")
[output] = ExTorch.AOTI.forward(model, [input])

# Inspect metadata
ExTorch.AOTI.metadata(model)
# => %{"AOTI_DEVICE_KEY" => "cpu", "AOTI_CPU_ISA" => "AVX2", ...}

Zero-Copy Tensor Exchange with Nx

# ExTorch tensor -> raw pointer (for passing to Torchx/Nx)
blob = ExTorch.Tensor.Blob.to_blob(tensor)
# => %Blob{ptr: 140234567890, shape: {3, 224, 224}, strides: [...], dtype: :float, ...}

# Foreign pointer -> ExTorch tensor (zero-copy, no data movement)
view = ExTorch.Tensor.Blob.from_blob(
  %{ptr: foreign_ptr, shape: {3, 224, 224}, dtype: :float32},
  owner: source_tensor  # keeps the source alive so its memory is not freed while the view exists
)
view.tensor  # => %ExTorch.Tensor{...}
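For a zero-copy view to be valid, the shape, strides, and dtype must describe the foreign buffer exactly. For a dense row-major buffer, the strides follow from the shape alone. This sketch assumes strides are counted in elements (PyTorch's convention) and that :float32 uses 4 bytes per element. Contiguous is a hypothetical helper, not an ExTorch module.

```elixir
defmodule Contiguous do
  # Row-major (C-contiguous) strides, measured in elements.
  def strides(shape) when is_tuple(shape) do
    shape
    |> Tuple.to_list()
    |> Enum.reverse()
    # Walk from the innermost dimension outward, accumulating the running product.
    |> Enum.map_reduce(1, fn dim, acc -> {acc, acc * dim} end)
    |> elem(0)
    |> Enum.reverse()
  end

  # Total buffer size in bytes for a contiguous tensor of this shape.
  def nbytes(shape, element_size) when is_tuple(shape) do
    shape |> Tuple.to_list() |> Enum.product() |> Kernel.*(element_size)
  end
end
```

For {3, 224, 224}, this gives strides [50176, 224, 1] and 602,112 bytes of float32 data. If the foreign buffer's strides differ, for example because it is a transposed or sliced view, it is not contiguous and its real strides must be passed instead.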

Telemetry & Metrics

# Enable metrics collection (call once from your Application.start/2 callback)
ExTorch.Metrics.setup()

# Metrics are automatically collected from JIT.Server telemetry events
ExTorch.Metrics.get("model.pt")
# => %{inference_count: 1500, error_count: 2, total_duration_ms: 4523.1,
#      min_duration_ms: 1.2, max_duration_ms: 89.2, load_duration_ms: 340.5, ...}

# Attach your own handler to telemetry events
# (MyApp.Telemetry is a placeholder for your own module defining handle_event/4)
:telemetry.attach("my-handler", [:extorch, :jit, :forward, :stop], &MyApp.Telemetry.handle_event/4, nil)
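The handle_event/4 callback passed to :telemetry.attach/4 is not defined in this README. A minimal handler module could look like the following. MyApp.Telemetry is a hypothetical name; attach it with &MyApp.Telemetry.handle_event/4. It assumes the :stop event carries a :duration measurement in native time units, following the :telemetry.span/3 convention, and may carry a :path metadata key. Check ExTorch's documentation for the actual measurement and metadata keys.

```elixir
defmodule MyApp.Telemetry do
  require Logger

  # Hypothetical handler. Assumes measurements include :duration in native
  # time units (the :telemetry.span/3 convention); not confirmed for ExTorch.
  def handle_event([:extorch, :jit, :forward, :stop], measurements, metadata, _config) do
    case duration_ms(measurements) do
      nil -> :ok
      ms -> Logger.info("forward took #{ms} ms (#{inspect(Map.get(metadata, :path))})")
    end
  end

  # Convert a native-unit duration to milliseconds as a float.
  def duration_ms(%{duration: native}) when is_integer(native) do
    System.convert_time_unit(native, :native, :microsecond) / 1000
  end

  def duration_ms(_measurements), do: nil
end
```

Keep handlers fast. :telemetry runs them synchronously in the process that emits the event, so slow work here adds directly to inference latency.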

LiveDashboard (optional): add {:phoenix_live_dashboard, "~> 0.8"} to your deps and register the ExTorch page in your Phoenix router:

live_dashboard "/dashboard",
  additional_pages: [
    extorch: {ExTorch.Observer.Dashboard, []}
  ]

CUDA Support

ExTorch.Native.cuda_is_available()   # => true/false
ExTorch.Native.cuda_device_count()   # => 2

# Load model on GPU
model = ExTorch.JIT.load("model.pt", device: {:cuda, 0})

# Monitor GPU memory
ExTorch.Native.cuda_memory_allocated(0)   # bytes currently allocated
ExTorch.Native.cuda_memory_reserved(0)    # bytes reserved by caching allocator
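When several model servers share a machine, a common pattern is to spread them across the available GPUs and fall back to CPU when CUDA is unavailable. This hypothetical helper is not part of ExTorch. It takes the results of cuda_is_available/0 and cuda_device_count/0 plus a worker index, and returns a device term in the same format as the device: option shown above.

```elixir
defmodule DevicePick do
  # Hypothetical helper: round-robin workers across GPUs, else fall back to :cpu.
  # Feed it ExTorch.Native.cuda_is_available() and ExTorch.Native.cuda_device_count().
  def pick(true, count, worker_index) when is_integer(count) and count > 0 do
    {:cuda, rem(worker_index, count)}
  end

  def pick(_cuda_available, _count, _worker_index), do: :cpu
end
```

With two GPUs, workers 0, 1, 2, 3 land on {:cuda, 0}, {:cuda, 1}, {:cuda, 0}, {:cuda, 1}. Passing the result to ExTorch.JIT.Server.start_link/1 assumes the server accepts the same device terms as ExTorch.JIT.load/2.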

Architecture

ExTorch uses a three-layer architecture:

  1. C++ -- Wraps libtorch APIs (torch::jit, torch::nn, tensor ops) behind shared pointer types
  2. Rust -- Bridges C++ to Erlang NIFs via cxx and Rustler, with type-safe encoding/decoding
  3. Elixir -- Macro-generated API with defbinding/nif_impl! for tensor ops, hand-written modules for JIT/NN/telemetry

License

MIT