74 changes: 61 additions & 13 deletions README.md
@@ -4,22 +4,42 @@

# TwisteRL

A minimalistic, high-performance Reinforcement Learning framework implemented in **Rust** and **Mojo**.

The current version is a *Proof of Concept*; stay tuned for future releases!

## Implementations

| Language | Location | Status |
|----------|----------|--------|
| Rust | `rust/` | Core implementation with Python bindings |
| Mojo | `mojo/` | Optimized port with competitive performance |

## Install

### Python (Rust backend)
```shell
pip install .
```

### Mojo
```shell
cd mojo
# Requires Mojo 25.5+ from https://docs.modular.com/mojo/manual/get-started
```

## Use

### Training (Python/Rust)
```shell
python -m twisterl.train --config examples/ppo_puzzle8_v1.json
```

### Training (Mojo)
```shell
cd mojo
mojo run train_puzzle.mojo
```
This example trains a model to play the popular "8 puzzle":

```
(example output elided in collapsed diff)
```

@@ -121,29 +141,56 @@ convert_pt_to_safetensors("model.pt") # Creates model.safetensors

- [Permutation twists in environments](docs/twists.md)

## 🚀 Key Features
- **High-Performance Core**: RL episode loop implemented in Rust and Mojo for faster training and inference
- **Inference-Ready**: Easy compilation and bundling of models with environments into portable binaries for inference
- **Symmetry-Aware Training via Twists**: Environments can expose observation/action permutations ("twists") so policies automatically exploit device or puzzle symmetries for faster learning
- **Modular Design**: Support for multiple algorithms (PPO, AlphaZero, Evolution Strategies) with interchangeable training and inference
- **Language Interoperability**: Core in Rust/Mojo with Python interface
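The twist mechanism can be illustrated with plain permutations, independent of TwisteRL's actual API: if an environment is symmetric under a relabeling of observation positions and actions, any transition can be remapped into an equivalent one and reused as extra training data. A minimal Python sketch (function name and conventions are illustrative):

```python
def apply_twist(obs, action, obs_perm, act_perm):
    """Remap an (observation, action) pair through an environment symmetry.

    obs_perm[i] gives the position that index i maps to under the
    symmetry; act_perm relabels actions accordingly.
    """
    twisted = [0] * len(obs)
    for i, v in enumerate(obs):
        twisted[obs_perm[i]] = v
    return twisted, act_perm[action]
```

For example, mirroring a 3-cell row swaps the outer cells and exchanges the left/right actions, so one collected trajectory yields two consistent training samples.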

## ⚡ Performance (Mojo vs Rust)

Benchmark results on 8-puzzle environment (100K iterations):

| Operation | Mojo | Rust | Winner |
|-----------|------|------|--------|
| Reset | 0.06 μs | 0.13 μs | Mojo 2.2x faster |
| Combined RL step | 0.07 μs | 0.28 μs | Mojo 4x faster |
| Episode rollout | 0.09 ms/episode (10K episodes) | 0.53 ms/episode (1K episodes) | Mojo 5.8x faster |

Run benchmarks:
```shell
# Mojo
cd mojo && mojo run benchmark.mojo

# Rust
cd rust && cargo run --release --bin benchmark
```
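For quick sanity checks, the same per-operation timing pattern used by the Mojo and Rust benchmarks can be reproduced from Python with `time.perf_counter_ns` (a generic harness, not part of TwisteRL):

```python
import time

def bench(fn, iters=100_000):
    """Average microseconds per call of `fn`, mirroring the
    timing style of the Mojo and Rust benchmark harnesses."""
    start = time.perf_counter_ns()
    for _ in range(iters):
        fn()
    elapsed_us = (time.perf_counter_ns() - start) / 1000.0
    return elapsed_us / iters
```

For instance, `bench(my_env.reset)` would time resets of a hypothetical Python environment wrapper the same way the tables above were produced.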


## 🏗️ Current State (PoC)

### Rust Implementation
- Hybrid rust-python implementation:
- Data collection and inference in Rust
- Training in Python (PyTorch)
- Supported algorithms: PPO, AlphaZero
- Support for native Rust environments and Python environments through a wrapper
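For reference, the heart of PPO's update, the clipped surrogate objective optimized on the Python side, fits in a few lines. This is the generic textbook form, not TwisteRL's implementation:

```python
def ppo_clip_objective(ratio, advantage, eps=0.2):
    """Per-sample clipped surrogate: min(r * A, clip(r, 1-eps, 1+eps) * A).

    ratio: pi_new(a|s) / pi_old(a|s); advantage: estimated advantage A.
    The clip keeps the new policy from moving too far from the old one.
    """
    clipped = max(1.0 - eps, min(1.0 + eps, ratio))
    return min(ratio * advantage, clipped * advantage)
```

In practice this quantity is averaged over a minibatch and negated to form the policy loss.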

### Mojo Implementation
- Pure Mojo implementation with no external dependencies
- Training uses Evolution Strategies (derivative-free, no backprop needed)
- Optimized with `InlineArray` for stack allocation and `@always_inline`
- Model save/load support
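Evolution Strategies needs only forward evaluations of the policy: sample Gaussian perturbations of the parameters, score each, and step along the fitness-weighted average of the noise. A toy pure-Python sketch of one such update (illustrative only, not the Mojo code):

```python
import random

def es_step(theta, fitness, pop=50, sigma=0.1, lr=0.05, rng=None):
    """One Evolution Strategies update: estimate an ascent direction
    for `fitness` from the scores of Gaussian-perturbed parameter probes."""
    rng = rng or random.Random()
    noises, scores = [], []
    for _ in range(pop):
        eps = [rng.gauss(0.0, 1.0) for _ in theta]
        noises.append(eps)
        scores.append(fitness([t + sigma * e for t, e in zip(theta, eps)]))
    # Normalize scores so the step size is insensitive to the fitness scale
    mean = sum(scores) / pop
    std = (sum((s - mean) ** 2 for s in scores) / pop) ** 0.5 or 1.0
    weights = [(s - mean) / std for s in scores]
    return [
        t + lr / (pop * sigma) * sum(n[i] * w for n, w in zip(noises, weights))
        for i, t in enumerate(theta)
    ]
```

Because no gradients flow through the environment or the network, this style of update maps cleanly onto a dependency-free Mojo core.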

### Common
- Focus on discrete observation and action spaces
- Support for native Rust environments and for Python environments through a wrapper
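"Discrete" here means integer observations and a finite action set with legality masks. A minimal, framework-independent Python sketch of the 8-puzzle interface the examples revolve around (class and method names are illustrative, not TwisteRL's API):

```python
import random

class Puzzle8:
    """Sliding 8 puzzle on a 3x3 board; 0 marks the blank tile."""

    GOAL = tuple(range(9))
    OFFSETS = (-3, 3, -1, 1)  # blank moves: up, down, left, right

    def __init__(self, scramble_depth=5, seed=None):
        self.rng = random.Random(seed)
        self.scramble_depth = scramble_depth
        self.reset()

    def masks(self):
        """Legality mask over the four actions for the current blank cell."""
        row, col = divmod(self.board.index(0), 3)
        return [row > 0, row < 2, col > 0, col < 2]

    def step(self, action):
        if not self.masks()[action]:
            return  # illegal action: no-op
        b = self.board.index(0)
        t = b + self.OFFSETS[action]
        self.board[b], self.board[t] = self.board[t], self.board[b]

    def is_final(self):
        return tuple(self.board) == self.GOAL

    def reset(self):
        """Scramble by applying random legal moves from the solved state."""
        self.board = list(self.GOAL)
        for _ in range(self.scramble_depth):
            legal = [a for a, ok in enumerate(self.masks()) if ok]
            self.step(self.rng.choice(legal))
```

Scrambling with legal moves from the solved state guarantees every episode starts from a solvable board.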


## 🚧 Roadmap
Upcoming Features (Alpha Version)

- Full training in Rust (without PyTorch dependency)
- Mojo GPU acceleration with MAX
- Extended support for:
- Continuous observation spaces
- Continuous action spaces
@@ -173,7 +220,8 @@ Perfect for:
## 🔧 Current Limitations

- Limited to discrete observation and action spaces
- Python environments may create performance bottlenecks (Rust)
- Mojo version currently supports Evolution Strategies only (no PPO/AlphaZero yet)
- Documentation and testing coverage is currently minimal
- WebAssembly support is in development

179 changes: 179 additions & 0 deletions mojo/benchmark.mojo
@@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# (C) Copyright 2025 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0.

"""
Benchmark for the TwisteRL Mojo implementation.
Compares performance with the Rust implementation.

Optimized version using InlineArray for stack-allocated data.
"""

from time import perf_counter_ns
from collections import List

from twisterl.envs.puzzle import PuzzleEnv, NUM_ACTIONS, PUZZLE_SIZE


alias NUM_ITERATIONS = 100000 # Increased for more accurate timing
alias NUM_EPISODES = 10000


fn benchmark_env_operations():
    print("\n--- Environment Operations Benchmark (Optimized) ---")

    # Benchmark: Create environment
    var start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var env = PuzzleEnv(3, 3, 5, 2, 20)
        _ = env  # Prevent optimization
    var elapsed = perf_counter_ns() - start
    var elapsed_us = Float64(elapsed) / 1000.0
    print("Create env (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: Reset
    var env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        env.reset()
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Reset (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: Step
    env = PuzzleEnv(3, 3, 5, 2, 20)
    env.reset()
    start = perf_counter_ns()
    for i in range(NUM_ITERATIONS):
        env.step(i % 4)
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Step (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: Observe (optimized InlineArray version)
    env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var obs = env.observe()
        _ = obs
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Observe (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: Masks (optimized InlineArray version)
    env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var masks = env.masks()
        _ = masks
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Masks (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: Clone (optimized - direct InlineArray copy)
    env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var cloned = env.clone()
        _ = cloned
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Clone (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: is_final
    env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var is_final = env.is_final()
        _ = is_final
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("is_final (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")

    # Benchmark: reward
    env = PuzzleEnv(3, 3, 5, 2, 20)
    start = perf_counter_ns()
    for _ in range(NUM_ITERATIONS):
        var reward = env.reward()
        _ = reward
    elapsed = perf_counter_ns() - start
    elapsed_us = Float64(elapsed) / 1000.0
    print("Reward (", NUM_ITERATIONS, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / NUM_ITERATIONS, "us/iter)")


fn benchmark_episode_rollout():
    print("\n--- Episode Rollout Benchmark (Optimized) ---")

    var start = perf_counter_ns()
    for _ in range(NUM_EPISODES):
        var env = PuzzleEnv(3, 3, 5, 2, 20)
        env.reset()

        var step_count = 0
        while not env.is_final() and step_count < 100:
            var masks = env.masks()
            # Simple policy: pick first valid action
            var action = 0
            for i in range(NUM_ACTIONS):
                if masks[i]:
                    action = i
                    break
            env.step(action)
            var obs = env.observe()
            var reward = env.reward()
            _ = obs
            _ = reward
            step_count += 1

    var elapsed = perf_counter_ns() - start
    var elapsed_ms = Float64(elapsed) / 1_000_000.0
    print("Episode rollout (", NUM_EPISODES, " episodes):", elapsed_ms, "ms (", elapsed_ms / NUM_EPISODES, "ms/episode)")


fn benchmark_combined_operations():
    print("\n--- Combined Operations Benchmark (Optimized) ---")

    # Simulate what happens in a typical RL step
    var iterations = NUM_ITERATIONS
    var start = perf_counter_ns()

    for _ in range(iterations):
        var env = PuzzleEnv(3, 3, 5, 2, 20)
        env.reset()
        var obs = env.observe()
        var masks = env.masks()
        var action = 0
        for i in range(NUM_ACTIONS):
            if masks[i]:
                action = i
                break
        env.step(action)
        var reward = env.reward()
        var is_final = env.is_final()
        _ = obs
        _ = reward
        _ = is_final

    var elapsed = perf_counter_ns() - start
    var elapsed_us = Float64(elapsed) / 1000.0
    print("Combined RL step (", iterations, " iterations):", elapsed_us / 1000.0, "ms (", elapsed_us / iterations, "us/iter)")


fn main():
    print("============================================================")
    print("TwisteRL Mojo Benchmark (Optimized with InlineArray)")
    print("============================================================")
    print("Optimizations applied:")
    print("  - InlineArray[Int, 9] for puzzle state (stack vs heap)")
    print("  - InlineArray[Bool, 4] for action masks")
    print("  - @always_inline for hot methods")
    print("  - Pre-allocated capacity for List operations")

    benchmark_env_operations()
    benchmark_episode_rollout()
    benchmark_combined_operations()

    print("\n============================================================")
    print("Benchmark complete!")
    print("============================================================")
8 changes: 8 additions & 0 deletions mojo/mojoproject.toml
@@ -0,0 +1,8 @@
[project]
name = "twisterl"
version = "0.1.0"
description = "TwisterL - Reinforcement Learning Library in Mojo"
authors = ["IBM"]
license = "Apache-2.0"

[dependencies]