Welcome to reinforcement learning! If you're familiar with supervised learning and neural network training, you're about to discover a fundamentally different approach to machine learning.

## What is Reinforcement Learning? {#rl-definition}

{.definition title="Reinforcement Learning"}
{pause}

### Workshop Setup: Your First Environment

Let's start by creating a simple grid world environment using Fehu:

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide1.ml}

{pause}
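If you're reading the source rather than the rendered slides, the included file isn't visible here. As a rough mental model only (the `env` record and `run_episode` below are hypothetical stand-ins, not the Fehu API), the interaction loop the workshop builds looks like this:

```ocaml
(* Conceptual agent-environment loop: the agent picks an action, the
   environment returns the next observation, a reward, and a done flag.
   Types here are illustrative only. *)
type 'obs env = {
  reset : unit -> 'obs;
  step : int -> 'obs * float * bool;  (* action -> (obs, reward, done) *)
}

let run_episode (env : 'obs env) ~(policy : 'obs -> int) ~max_steps =
  let obs = ref (env.reset ()) in
  let total_reward = ref 0. in
  let finished = ref false in
  let t = ref 0 in
  while (not !finished) && !t < max_steps do
    let action = policy !obs in
    let obs', reward, done_ = env.step action in
    obs := obs';
    total_reward := !total_reward +. reward;
    finished := done_;
    incr t
  done;
  !total_reward
```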
Think of it like learning to play a game:
- You (the neural network) don't know the rules initially
- You try actions and see what happens to the environment
- **Action sampling**: Choose "down" with 60% probability
- **Learned parameters**: θ represents all network weights and biases

{pause}

### Workshop Part 2: Create Your First Policy Network

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide2.ml}
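To make "action sampling" concrete: given the probabilities the policy network outputs (say 60% for "down"), we draw an action index in proportion to them. A minimal stand-alone sketch in plain OCaml, independent of the Kaun network in the included slide:

```ocaml
(* Sample an action index from a categorical distribution.
   probs must be non-negative and sum to (approximately) 1. *)
let sample_action probs =
  let u = Random.float 1.0 in
  let n = Array.length probs in
  let rec go i acc =
    if i >= n - 1 then i
    else
      let acc = acc +. probs.(i) in
      if u < acc then i else go (i + 1) acc
  in
  go 0 0.

let () =
  Random.self_init ();
  (* e.g. [| up; down; left; right |] with 60% mass on "down" *)
  let probs = [| 0.2; 0.6; 0.1; 0.1 |] in
  let counts = Array.make 4 0 in
  for _ = 1 to 10_000 do
    let a = sample_action probs in
    counts.(a) <- counts.(a) + 1
  done;
  (* roughly 60% of the draws land on action 1 ("down") *)
  Array.iteri (fun i c -> Printf.printf "action %d: %d\n" i c) counts
```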
***

{pause up #episodes}
But how do we compute gradients when the "target" (return) depends on our own actions?

{pause}

### Workshop Part 3: Collect an Episode

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide3.ml}
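The answer starts with the discounted return $G_t$ computed backwards over each collected episode. A minimal sketch of that computation (plain OCaml; the helper in the included slide may differ in details):

```ocaml
(* G_t = r_t + gamma * G_{t+1}, computed backwards over one episode. *)
let compute_returns rewards gamma =
  let n = Array.length rewards in
  let returns = Array.make n 0. in
  let acc = ref 0. in
  for t = n - 1 downto 0 do
    acc := rewards.(t) +. (gamma *. !acc);
    returns.(t) <- !acc
  done;
  returns

let () =
  let g = compute_returns [| 0.; 0.; 1. |] 0.99 in
  (* prints G_0 = 0.9801, G_1 = 0.9900, G_2 = 1.0000 *)
  Array.iteri (fun t v -> Printf.printf "G_%d = %.4f\n" t v) g
```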
***

{pause center #reinforce-intro}
- Update: $\theta \leftarrow \theta + \alpha G_t \nabla_\theta \ln \pi(A_t|S_t,\theta)$

{pause up=algorithm-reinforce}

### Workshop Part 4: Implement Basic REINFORCE

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide4.ml}
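In code, the update rule above is typically implemented by minimizing the surrogate loss $-\frac{1}{T}\sum_t G_t \ln \pi(A_t|S_t,\theta)$ and letting automatic differentiation supply $\nabla_\theta$. A scalar sketch of that objective (illustration only; the tensor-based version differentiates through the policy network's log-probabilities):

```ocaml
(* REINFORCE surrogate: minimizing -G_t * log pi(a_t|s_t) yields the
   update theta <- theta + alpha * G_t * grad log pi(a_t|s_t). *)
let reinforce_loss ~log_probs ~returns =
  let n = Array.length log_probs in
  let total = ref 0. in
  for t = 0 to n - 1 do
    total := !total -. (returns.(t) *. log_probs.(t))
  done;
  !total /. float_of_int n
```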
### Key Properties: High Variance Problem

From Sutton & Barto:
> - **Learned** to predict V(s) using gradient descent
> - More complex but much more effective

{pause down}
### Workshop Part 5: Add a Simple Baseline

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide5.ml}
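A minimal sketch of the simple baseline, assuming (as in the summary's "enhance with the constant baseline") that it is just the mean return of the collected batch:

```ocaml
(* Constant baseline: subtract the mean return of the batch so that
   better-than-average episodes get positive weight, worse-than-average
   episodes get negative weight. *)
let advantages_with_constant_baseline returns =
  let n = float_of_int (Array.length returns) in
  let baseline = Array.fold_left (+.) 0. returns /. n in
  Array.map (fun g -> g -. baseline) returns
```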
***

{pause center #reinforce-baseline}
- **Policy network**: θ parameters, outputs action probabilities
- **Baseline network**: w parameters, outputs state value estimates

{pause}

### Workshop Part 6: Learned Baseline (Actor-Critic)

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide6.ml}
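In scalar form, the two networks fit together like this: the critic's value estimate turns the return into an advantage for the actor, and the critic itself is regressed towards the observed returns. A sketch of those two quantities (the tensor-based version does the same on Rune tensors):

```ocaml
(* Actor-critic quantities in scalar form:
   - the actor's weight for a step is the advantage A_t = G_t - V_w(s_t)
   - the critic minimizes the squared error between V_w(s_t) and G_t *)
let advantages ~returns ~values =
  Array.mapi (fun t g -> g -. values.(t)) returns

let value_loss ~returns ~values =
  let n = Array.length returns in
  let total = ref 0. in
  for t = 0 to n - 1 do
    let d = values.(t) -. returns.(t) in
    total := !total +. (d *. d)
  done;
  !total /. float_of_int n
```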
***

{pause center}
- Implement REINFORCE
- Enhance with the constant baseline

{pause}

### Workshop Summary: What We Built

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide_pip.ml}
***

{pause center #policy-ratios}
- `ratio = 0.01` → clipped to `0.8` → prevents tiny updates too
- `ratio = 1.1` → no clipping needed, within [0.8, 1.2]

{pause}

### Workshop Part 7: Add Clipping for Stability

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide7.ml}
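A scalar sketch of the rule behind those two bullets (the workshop's `clip_ratio` operates on Rune tensors; this is just the arithmetic, with $\epsilon = 0.2$):

```ocaml
(* clip(ratio, 1 - eps, 1 + eps) with eps = 0.2 gives the range [0.8, 1.2]. *)
let clip_ratio ratio epsilon =
  Float.min (1. +. epsilon) (Float.max (1. -. epsilon) ratio)

(* Clipped surrogate: min(ratio * A, clip(ratio) * A). *)
let clipped_objective ~ratio ~advantage ~epsilon =
  Float.min (ratio *. advantage) (clip_ratio ratio epsilon *. advantage)

let () =
  Printf.printf "%.2f\n" (clip_ratio 0.01 0.2);  (* 0.80 *)
  Printf.printf "%.2f\n" (clip_ratio 1.1 0.2)    (* 1.10 *)
```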
***

{pause center #kl-penalty}
>
> The penalty acts like a "trust region" - we trust small changes more than large ones.

{pause}

### Workshop Part 8: Add KL Penalty

{pause down="~ duration:15"}
{slip include src=../example/sokoban/workshop/slide8.ml}
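For a discrete action space the penalty can be computed exactly per state as $D_{KL}[\pi_{old} \| \pi_{new}] = \sum_a \pi_{old}(a|s) \log \frac{\pi_{old}(a|s)}{\pi_{new}(a|s)}$. A small sketch in plain OCaml (the GRPO code later uses the simpler per-sample estimate $\log \pi_{old}(a|s) - \log \pi_{new}(a|s)$ at the sampled action instead):

```ocaml
(* Exact KL divergence between two discrete action distributions. *)
let kl_divergence p q =
  let acc = ref 0. in
  Array.iteri
    (fun i pi -> if pi > 0. then acc := !acc +. (pi *. (log pi -. log q.(i))))
    p;
  !acc

let () =
  let pi_old = [| 0.25; 0.25; 0.25; 0.25 |] in
  let pi_new = [| 0.10; 0.60; 0.15; 0.15 |] in
  Printf.printf "KL(old || new) = %.4f\n" (kl_divergence pi_old pi_new)
```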
{pause up=kl-objective}
> ### Why Both Clipping AND KL Penalty?
>
{pause down=grpo-for-llms}

{#grpo-implementation}
### GRPO Implementation in Fehu

```ocaml
(* Workshop Part 9: GRPO Implementation *)
(* Relies on helpers from earlier workshop parts: create_policy_network,
   collect_episode_from_state, compute_returns and clip_ratio. *)
open Fehu

(* Generate multiple trajectories for the same initial state *)
let collect_group_trajectories env policy_net params init_state group_size =
  Array.init group_size (fun _ ->
      (* Each trajectory starts from the same state but samples its own actions *)
      collect_episode_from_state env policy_net params init_state 100)

(* Compute group-relative advantages *)
let compute_group_advantages rewards =
  let n = float_of_int (Array.length rewards) in
  let mean = Array.fold_left (+.) 0. rewards /. n in
  let variance =
    Array.fold_left (fun acc r -> acc +. ((r -. mean) ** 2.)) 0. rewards /. n
  in
  let std = sqrt variance in
  (* Normalize advantages within the group *)
  Array.map (fun r -> (r -. mean) /. (std +. 1e-8)) rewards

(* GRPO training step *)
let train_grpo_step env policy_net params old_params group_size epsilon beta =
  let device = Rune.c in

  (* Get the initial state *)
  let init_obs, _ = env.reset () in

  (* Collect a group of trajectories from the same starting point *)
  let group =
    collect_group_trajectories env policy_net params init_obs group_size
  in

  (* Extract the return of each trajectory *)
  let group_returns =
    Array.map
      (fun traj ->
        let returns = compute_returns traj.rewards 0.99 in
        returns.(0) (* discounted return from the first step *))
      group
  in

  (* Compute group-relative advantages *)
  let group_advantages = compute_group_advantages group_returns in

  (* Update the policy using the clipped objective with a KL penalty *)
  let loss, grads =
    Kaun.value_and_grad
      (fun p ->
        let total_loss = ref (Rune.zeros device Rune.float32 [||]) in

        Array.iteri
          (fun g_idx trajectory ->
            let advantage = group_advantages.(g_idx) in

            (* Assumes states.(t) is the observation at which actions.(t) was taken *)
            Array.iteri
              (fun t state ->
                let action = trajectory.actions.(t) in

                (* New and old log-probabilities of the sampled action *)
                let new_logits = Kaun.apply policy_net p ~training:true state in
                let new_log_probs = Rune.log_softmax ~axis:(-1) new_logits in
                let new_action_log_prob = Rune.gather new_log_probs action in

                let old_logits =
                  Kaun.apply policy_net old_params ~training:false state
                in
                let old_log_probs = Rune.log_softmax ~axis:(-1) old_logits in
                let old_action_log_prob = Rune.gather old_log_probs action in

                (* Probability ratio and its clipped version *)
                let log_ratio = Rune.sub new_action_log_prob old_action_log_prob in
                let ratio = Rune.exp log_ratio in
                let clipped_ratio = clip_ratio ratio epsilon in

                (* Clipped objective *)
                let adv_scalar = Rune.scalar device Rune.float32 advantage in
                let obj1 = Rune.mul ratio adv_scalar in
                let obj2 = Rune.mul clipped_ratio adv_scalar in
                let clipped_obj = Rune.minimum obj1 obj2 in

                (* KL penalty: beta * (log pi_old - log pi_new) at the sampled
                   action is a per-sample estimate of D_KL(pi_old || pi_new) *)
                let kl_penalty =
                  Rune.mul
                    (Rune.scalar device Rune.float32 beta)
                    (Rune.sub old_action_log_prob new_action_log_prob)
                in

                (* Minimize -(clipped objective) + beta * KL estimate *)
                let step_loss = Rune.add (Rune.neg clipped_obj) kl_penalty in
                total_loss := Rune.add !total_loss step_loss)
              trajectory.states)
          group;

        (* Average over all steps and trajectories *)
        let total_steps =
          Array.fold_left (fun acc traj -> acc + Array.length traj.states) 0 group
        in
        Rune.div !total_loss
          (Rune.scalar device Rune.float32 (float_of_int total_steps)))
      params
  in

  (loss, grads)

(* Complete GRPO training loop *)
let train_grpo env n_iterations group_size learning_rate epsilon beta =
  let device = Rune.c in
  let rng = Rune.Rng.key 42 in

  (* Initialize the policy *)
  let policy_net = create_policy_network 5 4 in
  let dummy_obs = Rune.zeros device Rune.float32 [| 5; 5 |] in
  let params = Kaun.init policy_net ~rngs:rng dummy_obs in
  let old_params = ref (Ptree.copy params) in (* keep old params for the ratios *)

  (* Optimizer *)
  let optimizer = Kaun.Optimizer.adam ~lr:learning_rate () in
  let opt_state = ref (optimizer.init params) in

  for iter = 1 to n_iterations do
    (* GRPO update *)
    let loss, grads =
      train_grpo_step env policy_net params !old_params group_size epsilon beta
    in

    (* Apply gradients *)
    let updates, new_state = optimizer.update !opt_state params grads in
    opt_state := new_state;
    Kaun.Optimizer.apply_updates_inplace params updates;

    (* Refresh the old parameters periodically *)
    if iter mod 5 = 0 then old_params := Ptree.copy params;

    if iter mod 10 = 0 then
      Printf.printf "Iteration %d: Loss = %.4f\n" iter (Rune.unsafe_get [] loss)
  done;

  (policy_net, params)
```
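As a quick sanity check of `compute_group_advantages`, a worked example with four trajectories collected from the same starting state:

```ocaml
let () =
  (* Total rewards 1, 0, 0, 3: mean = 1.0, std ≈ 1.22,
     so advantages ≈ +0.00, -0.82, -0.82, +1.63. *)
  let advantages = compute_group_advantages [| 1.0; 0.0; 0.0; 3.0 |] in
  Array.iteri
    (fun i a -> Printf.printf "trajectory %d: advantage %+.2f\n" i a)
    advantages
```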
***