# random_agent.py
import gym
from gym.wrappers import RecordVideo
from itertools import count
from src.my_utils import create_plot, plot_durations
# Gym environment id: passed to gym.make() in main() and reused there as the
# name prefix of the recorded episode videos.
ENV_NAME = 'CartPole-v1'
# Random search strategy implementation
def random_search(env, episodes, on_episode_end=None):
    """Run ``episodes`` episodes on ``env`` choosing a random action each step.

    Args:
        env: Environment exposing ``reset()``, ``action_space.sample()`` and a
            ``step(action)`` that returns the 5-tuple
            ``(observation, reward, terminated, truncated, info)``.
        episodes: Number of episodes to play.
        on_episode_end: Optional callable invoked after every episode with the
            list of episode durations (in steps) collected so far. When left
            as ``None`` the module's ``plot_durations`` helper is called,
            which is the original behaviour.

    Returns:
        A list with the total reward of each episode, in play order.
    """
    rewards = []
    episode_durations = []
    for _ in range(episodes):
        env.reset()
        total = 0
        # count(1) makes `step` the number of steps taken so far.
        for step in count(1):
            # Sample a random action from action space
            action = env.action_space.sample()
            # Apply action on environment
            _, reward, terminated, truncated, _ = env.step(action)
            # Update total reward
            total += reward
            # An episode ends either because the environment reached a
            # terminal state or because it was cut short (truncated);
            # ignoring `truncated` would keep stepping a finished episode.
            if terminated or truncated:
                break
        episode_durations.append(step)
        # Add total to the list of rewards
        rewards.append(total)
        if on_episode_end is None:
            plot_durations(episode_durations)
        else:
            on_episode_end(episode_durations)
    return rewards
def main():
    """Run the random agent, record sample episodes, and report the rewards.

    Builds the ``ENV_NAME`` environment, wraps it in ``RecordVideo``, plays
    ``n_episodes`` episodes with ``random_search``, then plots the per-episode
    rewards and prints their average and maximum.
    """
    # Random agent
    n_episodes = 150
    path_video = './episodes/random_agent/'
    plot_title = 'Random Strategy'

    # visualize pygame window (UI) change render_mode to "human"
    env = gym.make(ENV_NAME, render_mode="rgb_array")
    env.action_space.seed(42)
    # Record one episode every 25 episodes
    env = RecordVideo(
        env,
        path_video,
        episode_trigger=lambda episode_id: episode_id % 25 == 0,
        name_prefix=ENV_NAME,
    )

    print("Starting with", n_episodes, "episodes ...")
    try:
        total = random_search(env, n_episodes)
    finally:
        # Always close the wrapped environment so the video recorder and the
        # renderer release their resources, even if the run fails midway.
        env.close()

    create_plot(plot_title, total, n_episodes)
    average_rewards = sum(total) / len(total)
    print(average_rewards)
    print(max(total))
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()