Commit 89c7d75 — "Upgrade to Gymnasium rather than Gym (#1381)"
1 parent: 32fdb49

File tree: 7 files changed (+27 additions, −33 deletions)

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ https://pytorch.org/examples/
2121
- [Variational Auto-Encoders](./vae/README.md)
2222
- [Superresolution using an efficient sub-pixel convolutional neural network](./super_resolution/README.md)
2323
- [Hogwild training of shared ConvNets across multiple processes on MNIST](mnist_hogwild)
24-
- [Training a CartPole to balance in OpenAI Gym with actor-critic](./reinforcement_learning/README.md)
24+
- [Training a CartPole to balance with actor-critic](./reinforcement_learning/README.md)
2525
- [Natural Language Inference (SNLI) with GloVe vectors, LSTMs, and torchtext](snli)
2626
- [Time sequence prediction - use an LSTM to learn Sine waves](./time_sequence_prediction/README.md)
2727
- [Implement the Neural Style Transfer algorithm on images](./fast_neural_style/README.md)

distributed/rpc/batch/reinforce.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
import gym
2+
import gymnasium as gym
33
import os
44
import threading
55
import time
@@ -68,7 +68,7 @@ class Observer:
6868
def __init__(self, batch=True):
6969
self.id = rpc.get_worker_info().id - 1
7070
self.env = gym.make('CartPole-v1')
71-
self.env.seed(args.seed)
71+
self.env.reset(seed=args.seed)
7272
self.select_action = Agent.select_action_batch if batch else Agent.select_action
7373

7474
def run_episode(self, agent_rref, n_steps):
@@ -92,10 +92,10 @@ def run_episode(self, agent_rref, n_steps):
9292
)
9393

9494
# apply the action to the environment, and get the reward
95-
state, reward, done, _ = self.env.step(action)
95+
state, reward, terminated, truncated, _ = self.env.step(action)
9696
rewards[step] = reward
9797

98-
if done or step + 1 >= n_steps:
98+
if terminated or truncated or step + 1 >= n_steps:
9999
curr_rewards = rewards[start_step:(step + 1)]
100100
R = 0
101101
for i in range(curr_rewards.numel() -1, -1, -1):
@@ -226,8 +226,7 @@ def run_worker(rank, world_size, n_episode, batch, print_log=True):
226226
last_reward, running_reward = agent.run_episode(n_steps=NUM_STEPS)
227227

228228
if print_log:
229-
print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
230-
i_episode, last_reward, running_reward))
229+
print(f'Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: {running_reward:.2f}')
231230
else:
232231
# other ranks are the observer
233232
rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
distributed/rpc/batch/requirements.txt  (filename missing in the scrape — inferred from the preceding hunk's path; verify against the original commit)

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
torch==2.2.0
22
torchvision==0.7.0
33
numpy
4-
gym
4+
gymnasium

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -88,12 +88,12 @@ experiment with PyTorch.
8888
`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/mnist_hogwild>`__ :opticon:`link-external`
8989

9090
---
91-
Training a CartPole to balance in OpenAI Gym with actor-critic
92-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
91+
Training a CartPole to balance with actor-critic
92+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9393

9494
This reinforcement learning tutorial demonstrates how to train a
9595
CartPole to balance
96-
in the `OpenAI Gym <https://gym.openai.com/>`__ toolkit by using the
96+
in the `Gymnasium <https://gymnasium.farama.org/>`__ toolkit by using the
9797
`Actor-Critic <https://proceedings.neurips.cc/paper/1999/file/6449f44a102fde848669bdd9eb6b76fa-Paper.pdf>`__ method.
9898

9999
`GO TO EXAMPLE <https://github.com/pytorch/examples/blob/main/reinforcement_learning>`__ :opticon:`link-external`

reinforcement_learning/actor_critic.py

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
import gym
2+
import gymnasium as gym
33
import numpy as np
44
from itertools import count
55
from collections import namedtuple
@@ -24,7 +24,8 @@
2424
args = parser.parse_args()
2525

2626

27-
env = gym.make('CartPole-v1')
27+
render_mode = "human" if args.render else None
28+
env = gym.make('CartPole-v1', render_mode=render_mode)
2829
env.reset(seed=args.seed)
2930
torch.manual_seed(args.seed)
3031

@@ -152,14 +153,11 @@ def main():
152153
action = select_action(state)
153154

154155
# take the action
155-
state, reward, done, _, _ = env.step(action)
156-
157-
if args.render:
158-
env.render()
156+
state, reward, terminated, truncated, _ = env.step(action)
159157

160158
model.rewards.append(reward)
161159
ep_reward += reward
162-
if done:
160+
if terminated or truncated:
163161
break
164162

165163
# update cumulative reward
@@ -170,13 +168,12 @@ def main():
170168

171169
# log results
172170
if i_episode % args.log_interval == 0:
173-
print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
174-
i_episode, ep_reward, running_reward))
171+
print(f'Episode {i_episode}\tLast reward: {ep_reward:.2f}\tAverage reward: {running_reward:.2f}')
175172

176173
# check if we have "solved" the cart pole problem
177174
if running_reward > env.spec.reward_threshold:
178-
print("Solved! Running reward is now {} and "
179-
"the last episode runs to {} time steps!".format(running_reward, t))
175+
print(f"Solved! Running reward is now {running_reward} and "
176+
f"the last episode runs to {t} time steps!")
180177
break
181178

182179

reinforcement_learning/reinforce.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import argparse
2-
import gym
2+
import gymnasium as gym
33
import numpy as np
44
from itertools import count
55
from collections import deque
@@ -22,7 +22,8 @@
2222
args = parser.parse_args()
2323

2424

25-
env = gym.make('CartPole-v1')
25+
render_mode = "human" if args.render else None
26+
env = gym.make('CartPole-v1', render_mode=render_mode)
2627
env.reset(seed=args.seed)
2728
torch.manual_seed(args.seed)
2829

@@ -85,22 +86,20 @@ def main():
8586
ep_reward = 0
8687
for t in range(1, 10000): # Don't infinite loop while learning
8788
action = select_action(state)
88-
state, reward, done, _, _ = env.step(action)
89+
state, reward, terminated, truncated, _ = env.step(action)
8990
if args.render:
9091
env.render()
9192
policy.rewards.append(reward)
9293
ep_reward += reward
93-
if done:
94+
if terminated or truncated:
9495
break
9596

9697
running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
9798
finish_episode()
9899
if i_episode % args.log_interval == 0:
99-
print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
100-
i_episode, ep_reward, running_reward))
100+
print(f'Episode {i_episode}\tLast reward: {ep_reward:.2f}\tAverage reward: {running_reward:.2f}')
101101
if running_reward > env.spec.reward_threshold:
102-
print("Solved! Running reward is now {} and "
103-
"the last episode runs to {} time steps!".format(running_reward, t))
102+
print(f"Solved! Running reward is now {running_reward} and the last episode runs to {t} time steps!")
104103
break
105104

106105

reinforcement_learning/requirements.txt  (filename missing in the scrape — inferred from the preceding hunk's path; verify against the original commit)

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
11
torch
2-
numpy<2
3-
gym
4-
pygame
2+
numpy
3+
gymnasium[classic-control]

Comments (0)