Commit 64b8435

In progress: Sokoban workshop code samples

1 parent c287f1e
13 files changed: +711 -16 lines

dev/fehu/docs/RL_Introduction-REINFORCE.md

Lines changed: 190 additions & 16 deletions
@@ -2,9 +2,6 @@

 Welcome to reinforcement learning! If you're familiar with supervised learning and neural network training, you're about to discover a fundamentally different approach to machine learning.

-{pause .block}
-This presentation is work-in-progress!
-
 ## What is Reinforcement Learning? {#rl-definition}

 {.definition title="Reinforcement Learning"}
@@ -34,6 +31,15 @@ Instead of learning from labeled examples, an **agent** learns by **acting** in

 {pause}

+### Workshop Setup: Your First Environment
+
+Let's start by creating a simple grid world environment using Fehu:
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide1.ml}
+
+{pause}
+
 Think of it like learning to play a game:
 - You (the neural network) don't know the rules initially
 - You try actions and see what happens to the environment
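
The included `slide1.ml` appears in full at the end of this commit. For readers who want to poke at the environment before moving on, a minimal random-rollout sketch against the same interface might look like the following; it assumes `create_simple_gridworld` from slide1.ml is in scope and that `step` is called record-style like `reset` (only `env.reset ()` is shown verbatim in this commit):

```ocaml
(* A hedged sketch: roll the gridworld forward with random actions, using the
   reset/step result shapes defined in slide1.ml. Assumes create_simple_gridworld
   from slide1.ml is in scope and that env.step mirrors env.reset's record-style call. *)
open Fehu

let () =
  let dev = Rune.c in
  let env = create_simple_gridworld 5 in
  let _obs, _info = env.reset () in
  let total_reward = ref 0. in
  let finished = ref false in
  let steps = ref 0 in
  while (not !finished) && !steps < 100 do
    (* Encode the discrete action as a scalar tensor, as slide1.ml's step expects *)
    let action = Rune.scalar dev Rune.float32 (float_of_int (Random.int 4)) in
    let _obs, reward, terminated, truncated, _info = env.step action in
    total_reward := !total_reward +. reward;
    finished := terminated || truncated;
    incr steps
  done;
  Printf.printf "Random episode return after %d steps: %.2f\n" !steps !total_reward
```
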
@@ -106,6 +112,13 @@ Both environment and information states are **Markovian** - they capture all rel
 - **Action sampling**: Choose "down" with 60% probability
 - **Learned parameters**: θ represents all network weights and biases

+{pause}
+
+### Workshop Part 2: Create Your First Policy Network
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide2.ml}
+
 ***

 {pause up #episodes}
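
`slide2.ml` is referenced above but not included in this commit. Below is a sketch of the kind of tiny policy network such a slide could define; the `Kaun.Layer.sequential`, `flatten`, `linear`, and `relu` combinators are assumptions about Kaun's layer API (only `Kaun.init`, `Kaun.apply`, and the Rune calls appear elsewhere in this diff), and the sampler reads probabilities out with `Rune.unsafe_get` in the same style slide1.ml uses for scalars:

```ocaml
(* A hedged sketch of a policy network for the 5x5 gridworld.
   Assumption: Kaun exposes Layer.sequential / Layer.flatten / Layer.linear / Layer.relu;
   only Kaun.init and Kaun.apply are confirmed by the GRPO code later in this diff. *)
open Fehu

let create_policy_network grid_size n_actions =
  Kaun.Layer.sequential [
    Kaun.Layer.flatten ();  (* assumed to turn the grid observation into a flat vector *)
    Kaun.Layer.linear ~in_features:(grid_size * grid_size) ~out_features:64 ();
    Kaun.Layer.relu ();
    Kaun.Layer.linear ~in_features:64 ~out_features:n_actions ();  (* action logits *)
  ]

(* Sample an action index from the policy, e.g. "down" with 60% probability. *)
let sample_action policy_net params obs ~n_actions =
  let logits = Kaun.apply policy_net params ~training:false obs in
  let probs = Rune.exp (Rune.log_softmax ~axis:(-1) logits) in
  let u = Random.float 1.0 in
  let rec pick i acc =
    if i >= n_actions - 1 then n_actions - 1
    else
      let acc = acc +. Rune.unsafe_get [i] probs in
      if u <= acc then i else pick (i + 1) acc
  in
  pick 0 0.
```
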
@@ -141,6 +154,13 @@ $$V^\pi(s) = \mathbb{E}_\pi[G_t | S_t = s]$$

 But how do we compute gradients when the "target" (return) depends on our own actions?

+{pause}
+
+### Workshop Part 3: Collect an Episode
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide3.ml}
+
 ***

 {pause center #reinforce-intro}
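
`slide3.ml` is not part of this commit either, but the GRPO code further down relies on a `compute_returns` helper and on trajectory records with `states`, `actions`, and `rewards` fields. A hedged sketch of the return computation itself, in plain OCaml:

```ocaml
(* A hedged sketch of the compute_returns helper the GRPO code below calls:
   discounted returns-to-go, G_t = r_t + gamma * G_{t+1}, filled in backwards.
   The trajectory record (states/actions/rewards arrays) is assumed to come from
   the episode-collection code in slide3.ml. *)
let compute_returns rewards gamma =
  let n = Array.length rewards in
  let returns = Array.make n 0. in
  let running = ref 0. in
  for t = n - 1 downto 0 do
    running := rewards.(t) +. (gamma *. !running);
    returns.(t) <- !running
  done;
  returns
```
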
@@ -206,6 +226,12 @@ From Sutton & Barto:
 - Update: $\theta \leftarrow \theta + \alpha G_t \nabla_\theta \ln \pi(A_t|S_t,\theta)$

 {pause up=algorithm-reinforce}
+
+### Workshop Part 4: Implement Basic REINFORCE
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide4.ml}
+
 ### Key Properties: High Variance Problem

 From Sutton & Barto:
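
`slide4.ml` is also not shown here. Expressed as a loss for automatic differentiation, the update above becomes minimising $-\frac{1}{T}\sum_t G_t \ln \pi(A_t|S_t,\theta)$. A hedged sketch, built only from the `Kaun`/`Rune` calls that appear in the GRPO block at the end of this file, plus the `compute_returns` sketch above and the assumed trajectory record:

```ocaml
(* A hedged sketch of one REINFORCE update on a single trajectory:
   loss = -(1/T) * sum_t G_t * log pi(a_t | s_t, theta).
   All Kaun/Rune calls below also appear in the GRPO code later in this diff;
   the trajectory record fields (states, actions, rewards) are assumptions. *)
let reinforce_loss_and_grads policy_net params trajectory ~gamma =
  let device = Rune.c in
  let returns = compute_returns trajectory.rewards gamma in
  Kaun.value_and_grad (fun p ->
    let total = ref (Rune.zeros device Rune.float32 [||]) in
    Array.iteri (fun t state ->
      let logits = Kaun.apply policy_net p ~training:true state in
      let log_probs = Rune.log_softmax ~axis:(-1) logits in
      let log_prob = Rune.gather log_probs trajectory.actions.(t) in
      let g_t = Rune.scalar device Rune.float32 returns.(t) in
      (* Gradient descent on -G_t * log pi is gradient ascent on the REINFORCE objective *)
      total := Rune.add !total (Rune.neg (Rune.mul g_t log_prob))
    ) trajectory.states;
    let n = float_of_int (Array.length trajectory.states) in
    Rune.div !total (Rune.scalar device Rune.float32 n)
  ) params
```

The resulting gradients would then be applied with `Kaun.Optimizer.adam`, exactly as the training loop at the end of this diff does.
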
@@ -298,6 +324,12 @@ From Sutton & Barto:
 > - **Learned** to predict V(s) using gradient descent
 > - More complex but much more effective

+{pause down}
+### Workshop Part 5: Add a Simple Baseline
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide5.ml}
+
 ***

 {pause center #reinforce-baseline}
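
`slide5.ml` is not included; the "simple baseline" suggested above can be as small as subtracting the mean return of a batch of episodes from every $G_t$, which lowers variance without biasing the gradient. A pure-OCaml sketch:

```ocaml
(* A hedged sketch of a constant (mean-return) baseline: given the per-step
   returns of several episodes, subtract the overall mean before the policy
   gradient step. Variance drops; the expected gradient is unchanged. *)
let subtract_mean_baseline (episode_returns : float array array) =
  let all = Array.concat (Array.to_list episode_returns) in
  let baseline = Array.fold_left (+.) 0. all /. float_of_int (Array.length all) in
  Array.map (Array.map (fun g -> g -. baseline)) episode_returns
```
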
@@ -320,6 +352,13 @@ The baseline **neural network** is learned to predict expected returns, reducing
 - **Policy network**: θ parameters, outputs action probabilities
 - **Baseline network**: w parameters, outputs state value estimates

+{pause}
+
+### Workshop Part 6: Learned Baseline (Actor-Critic)
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide6.ml}
+
 ***

 {pause center}
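
`slide6.ml` is likewise not in this commit. A hedged sketch of the second, learned network and its regression loss follows; the `Kaun.Layer.*` combinators are the same assumption as in the policy-network sketch earlier, and shape handling for the scalar value output is glossed over:

```ocaml
(* A hedged sketch of the learned baseline (critic): a value network with its own
   parameters w predicts V(s); it is trained by squared error against the observed
   return, and the policy update then uses the advantage G_t - V(s_t).
   Kaun.Layer.* combinators are assumed, as in the earlier policy-network sketch. *)
let create_value_network grid_size =
  Kaun.Layer.sequential [
    Kaun.Layer.flatten ();
    Kaun.Layer.linear ~in_features:(grid_size * grid_size) ~out_features:64 ();
    Kaun.Layer.relu ();
    Kaun.Layer.linear ~in_features:64 ~out_features:1 ();  (* scalar value estimate *)
  ]

(* Mean squared error of the value predictions against the returns of one episode. *)
let value_loss value_net w trajectory returns =
  let device = Rune.c in
  let total = ref (Rune.zeros device Rune.float32 [||]) in
  Array.iteri (fun t state ->
    let v = Kaun.apply value_net w ~training:true state in
    let g = Rune.scalar device Rune.float32 returns.(t) in
    let err = Rune.sub v g in
    total := Rune.add !total (Rune.mul err err)
  ) trajectory.states;
  Rune.div !total
    (Rune.scalar device Rune.float32 (float_of_int (Array.length trajectory.states)))
```
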
@@ -358,6 +397,13 @@ REINFORCE with baseline is a simple actor-critic method:
 - Implement REINFORCE
 - Enhance with the constant baseline

+{pause}
+
+### Workshop Summary: What We Built
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide_pip.ml}
+
 ***

 {pause center #policy-ratios}
@@ -436,6 +482,13 @@ $$L^{CLIP}(\theta) = \min\left(\text{ratio}_t \cdot A_t, \; \text{clip}(\text{ra
 - `ratio = 0.01` → clipped to `0.8` → prevents tiny updates too
 - `ratio = 1.1` → no clipping needed, within [0.8, 1.2]

+{pause}
+
+### Workshop Part 7: Add Clipping for Stability
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide7.ml}
+
 ***

 {pause center #kl-penalty}
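
`slide7.ml` is not shown, and the GRPO code at the end of this file calls a `clip_ratio` helper that the commit never defines. A hedged sketch of that helper: `Rune.minimum` appears in this diff, while `Rune.maximum` is assumed to exist by symmetry.

```ocaml
(* A hedged sketch of the clip_ratio helper used by the GRPO step below:
   clamp the probability ratio to [1 - epsilon, 1 + epsilon].
   Rune.minimum is used elsewhere in this diff; Rune.maximum is an assumption. *)
let clip_ratio ratio epsilon =
  let device = Rune.c in
  let lo = Rune.scalar device Rune.float32 (1. -. epsilon) in
  let hi = Rune.scalar device Rune.float32 (1. +. epsilon) in
  Rune.minimum hi (Rune.maximum lo ratio)
```

With `epsilon = 0.2` this reproduces the `[0.8, 1.2]` examples listed above.
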
@@ -469,6 +522,13 @@ $$L_{total}(\theta) = L_{policy}(\theta) - \beta \cdot D_{KL}[\pi_{old} \| \pi_{
 >
 > The penalty acts like a "trust region" - we trust small changes more than large ones.

+{pause}
+
+### Workshop Part 8: Add KL Penalty
+
+{pause down="~duration:15"}
+{slip include src=../example/sokoban/workshop/slide8.ml}
+
 {pause up=kl-objective}
 > ### Why Both Clipping AND KL Penalty?
 >
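
`slide8.ml` is not included. The penalty only needs a per-action estimate of the KL term; averaging $\ln \pi_{old} - \ln \pi_{new}$ over actions sampled from $\pi_{old}$ gives such an estimate, and the GRPO step below scales the same per-action difference by $\beta$. A pure-OCaml sketch on plain floats:

```ocaml
(* A hedged sketch: a simple Monte Carlo estimate of the KL penalty term from
   per-action log-probabilities stored as plain floats. Averaging
   log pi_old(a|s) - log pi_new(a|s) over actions sampled from pi_old estimates
   D_KL[pi_old || pi_new]. *)
let kl_estimate old_log_probs new_log_probs =
  let n = Array.length old_log_probs in
  let sum = ref 0. in
  for i = 0 to n - 1 do
    sum := !sum +. (old_log_probs.(i) -. new_log_probs.(i))
  done;
  !sum /. float_of_int n
```
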
@@ -532,20 +592,134 @@ Now we can understand GRPO: **REINFORCE + GRPO Innovation + Clipping + KL Penalt
 {pause down=grpo-for-llms}

 {#grpo-implementation}
-### GRPO Implementation Reality
-
-```python
-# For each query, generate G=4 responses
-responses = model.generate(query, num_return_sequences=4)
+### GRPO Implementation in Fehu

-# Compute group-relative advantages
-rewards = [reward_model(r) for r in responses]
-advantages = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-8)
-
-# Clipped policy gradient update
-ratios = new_probs / old_probs
-clipped_loss = min(ratios * advantages,
-                   clip(ratios, 0.8, 1.2) * advantages)
+```ocaml
+(* Workshop Part 9: GRPO Implementation *)
+open Fehu
+
+(* Generate multiple trajectories for the same initial state *)
+let collect_group_trajectories env policy_net params init_state group_size =
+  let trajectories = Array.init group_size (fun _ ->
+    (* Each trajectory starts from same state but different random actions *)
+    collect_episode_from_state env policy_net params init_state 100
+  ) in
+  trajectories
+
+(* Compute group-relative advantages *)
+let compute_group_advantages rewards =
+  let mean = Array.fold_left (+.) 0. rewards /.
+             float_of_int (Array.length rewards) in
+  let variance = Array.fold_left (fun acc r ->
+    acc +. (r -. mean) ** 2.) 0. rewards /.
+    float_of_int (Array.length rewards) in
+  let std = sqrt variance in
+
+  (* Normalize advantages within group *)
+  Array.map (fun r -> (r -. mean) /. (std +. 1e-8)) rewards
+
+(* GRPO training step *)
+let train_grpo_step env policy_net params old_params group_size epsilon beta =
+  let device = Rune.c in
+
+  (* Get initial state *)
+  let init_obs, _ = env.reset () in
+
+  (* Collect group of trajectories from same starting point *)
+  let group = collect_group_trajectories env policy_net params init_obs group_size in
+
+  (* Extract returns for each trajectory *)
+  let group_returns = Array.map (fun traj ->
+    let returns = compute_returns traj.rewards 0.99 in
+    returns.(0) (* Total return *)
+  ) group in
+
+  (* Compute group-relative advantages *)
+  let group_advantages = compute_group_advantages group_returns in
+
+  (* Update policy using clipped objective with KL penalty *)
+  let loss, grads = Kaun.value_and_grad (fun p ->
+    let total_loss = ref (Rune.zeros device Rune.float32 [||]) in
+
+    Array.iteri (fun g_idx trajectory ->
+      let advantage = group_advantages.(g_idx) in
+
+      Array.iteri (fun t state ->
+        let action = trajectory.actions.(t) in
+
+        (* Compute new and old log probs *)
+        let new_logits = Kaun.apply policy_net p ~training:true state in
+        let new_log_probs = Rune.log_softmax ~axis:(-1) new_logits in
+        let new_action_log_prob = Rune.gather new_log_probs action in
+
+        let old_logits = Kaun.apply policy_net old_params ~training:false state in
+        let old_log_probs = Rune.log_softmax ~axis:(-1) old_logits in
+        let old_action_log_prob = Rune.gather old_log_probs action in
+
+        (* Compute ratio and clip *)
+        let log_ratio = Rune.sub new_action_log_prob old_action_log_prob in
+        let ratio = Rune.exp log_ratio in
+        let clipped_ratio = clip_ratio ratio epsilon in
+
+        (* Clipped objective *)
+        let adv_scalar = Rune.scalar device Rune.float32 advantage in
+        let obj1 = Rune.mul ratio adv_scalar in
+        let obj2 = Rune.mul clipped_ratio adv_scalar in
+        let clipped_obj = Rune.minimum obj1 obj2 in
+
+        (* Add KL penalty *)
+        let kl_penalty = Rune.mul
+          (Rune.scalar device Rune.float32 beta)
+          (Rune.sub old_action_log_prob new_action_log_prob) in
+
+        let step_loss = Rune.sub (Rune.neg clipped_obj) kl_penalty in
+        total_loss := Rune.add !total_loss step_loss
+      ) trajectory.states
+    ) group;
+
+    (* Average over all steps and trajectories *)
+    let total_steps = Array.fold_left (fun acc traj ->
+      acc + Array.length traj.states) 0 group in
+    Rune.div !total_loss (Rune.scalar device Rune.float32 (float_of_int total_steps))
+  ) params in
+
+  (loss, grads)
+
+(* Complete GRPO training loop *)
+let train_grpo env n_iterations group_size learning_rate epsilon beta =
+  let device = Rune.c in
+  let rng = Rune.Rng.key 42 in
+
+  (* Initialize policy *)
+  let policy_net = create_policy_network 5 4 in
+  let dummy_obs = Rune.zeros device Rune.float32 [|5; 5|] in
+  let params = Kaun.init policy_net ~rngs:rng dummy_obs in
+  let old_params = ref (Ptree.copy params) in (* Keep old params for ratios *)
+
+  (* Optimizer *)
+  let optimizer = Kaun.Optimizer.adam ~lr:learning_rate () in
+  let opt_state = ref (optimizer.init params) in
+
+  for iter = 1 to n_iterations do
+    (* GRPO update *)
+    let loss, grads = train_grpo_step env policy_net params !old_params
+                        group_size epsilon beta in
+
+    (* Apply gradients *)
+    let updates, new_state = optimizer.update !opt_state params grads in
+    opt_state := new_state;
+    Kaun.Optimizer.apply_updates_inplace params updates;
+
+    (* Update old params periodically *)
+    if iter mod 5 = 0 then
+      old_params := Ptree.copy params;

+    if iter mod 10 = 0 then
+      Printf.printf "Iteration %d: Loss = %.4f\n"
+        iter (Rune.unsafe_get [] loss)
+  done;
+
+  (policy_net, params)
 ```

 ***
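
A hedged usage sketch tying the pieces together: `create_simple_gridworld` comes from slide1.ml and `train_grpo` from the block above; the hyperparameter values are illustrative only, not tuned.

```ocaml
(* A hedged usage sketch: run GRPO on the 5x5 gridworld from slide1.ml.
   Hyperparameters are illustrative only. *)
let () =
  let env = create_simple_gridworld 5 in
  let _policy_net, _params =
    train_grpo env
      200   (* n_iterations *)
      4     (* group_size: G trajectories per starting state *)
      0.001 (* learning_rate *)
      0.2   (* epsilon: clip to [0.8, 1.2] *)
      0.01  (* beta: KL penalty weight *)
  in
  print_endline "GRPO training finished"
```
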

dev/fehu/example/dune

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 (executables
  (names cartpole_random dqn_cartpole)
+ (modules cartpole_random dqn_cartpole)
  (libraries fehu kaun classic unix))

dev/fehu/example/sokoban/dune

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 (executables
  (names sokoban_reinforce sokoban_grpo sokoban_dqn backoff_tabular visualize_logs test_room_gen test_visualization test_solvability test_deadlock test_wall_segments test_dqn_curriculum)
+ (modules sokoban_reinforce sokoban_grpo sokoban_dqn backoff_tabular visualize_logs test_room_gen test_visualization test_solvability test_deadlock test_wall_segments test_dqn_curriculum)
  (libraries fehu kaun sokoban unix))
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+(executables
+ (names slide1 slide2 slide3 slide4 slide5 slide6 slide7 slide8 slide_pip)
+ (libraries fehu rune kaun sokoban unix)
+ (modules slide1 slide2 slide3 slide4 slide5 slide6 slide7 slide8 slide_pip)
+ (modes exe))
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+(*
+```ocaml
+*)
+open Fehu
+let dev = Rune.metal ()
+(* Workshop Part 1: Define a simple grid world *)
+let create_simple_gridworld size =
+  (* Mutable state for agent position *)
+  let agent_pos = ref (0, 0) in
+  let goal_pos = (size - 1, size - 1) in
+  (* Define observation and action spaces *)
+  let observation_space =
+    Space.Box {
+      low = Rune.zeros dev Rune.float32 [|size; size|];
+      high = Rune.ones dev Rune.float32 [|size; size|];
+      shape = [|size; size|];
+    }
+  in
+  (* Up, Down, Left, Right *)
+  let action_space = Space.Discrete 4 in
+  (* Reset function - initialize episode *)
+  let reset ?seed () =
+    let _ = Option.map Random.init seed in
+    agent_pos := (0, 0);
+    let obs = Rune.zeros dev Rune.float32 [|size; size|] in
+    (* Mark agent position *)
+    Rune.unsafe_set [0; 0] 1.0 obs;
+    (obs, [])
+  in
+  (* Step function - take action and return new state *)
+  let step action =
+    let action_idx =
+      Rune.unsafe_get [] action |> int_of_float in
+    let x, y = !agent_pos in
+    (* Move based on action *)
+    let new_pos = match action_idx with
+      | 0 -> (x, max 0 (y - 1)) (* Up *)
+      | 1 -> (x, min (size-1) (y + 1)) (* Down *)
+      | 2 -> (max 0 (x - 1), y) (* Left *)
+      | 3 -> (min (size-1) (x + 1), y) (* Right *)
+      | _ -> (x, y)
+    in
+    agent_pos := new_pos;
+    (* Create observation *)
+    let obs = Rune.zeros dev Rune.float32 [|size; size|] in
+    let x, y = !agent_pos in
+    Rune.unsafe_set [x; y] 1.0 obs;
+    (* Compute reward *)
+    let reward = if new_pos = goal_pos then 10.0 else -0.1 in
+    let terminated = new_pos = goal_pos in
+    (obs, reward, terminated, false, [])
+  in
+  Env.make ~observation_space ~action_space ~reset ~step ()
+(* Test the environment *)
+let () =
+  let env = create_simple_gridworld 5 in
+  let obs, _ = env.reset () in
+  print_endline "Initial state:";
+  Rune.print obs
+(*
+```
+*)
