Author: Apoorv Nandan
Date created: 2020/05/13
Last modified: 2024/02/22
Description: Implement the Actor-Critic method in the CartPole environment.
This script shows an implementation of the Actor-Critic method on the CartPole-v0 environment.
As an agent takes actions and moves through an environment, it learns to map the observed state of the environment to two outputs:
1. Recommended action: a probability value for each action in the action space. The part of the agent responsible for this output is the actor.
2. Estimated rewards in the future: the sum of all rewards it expects to receive in the future. The part of the agent responsible for this output is the critic.
The actor and the critic learn to perform their tasks such that the recommended actions from the actor maximize the rewards.
A pole is attached to a cart placed on a frictionless track. The agent has to apply force to move the cart. It is rewarded for every time step the pole remains upright. The agent, therefore, must learn to keep the pole from falling over.
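Before the full script, here is a minimal, self-contained sketch of the environment interface it relies on: a random-action rollout, assuming the same classic gym API used below (env.reset() returning the observation, env.step() returning (state, reward, done, info)). It is only meant to make the state/reward loop concrete; the trained agent replaces the random action choice.

import gym
import numpy as np

env = gym.make("CartPole-v0")
state = env.reset()  # 4 floats: cart position, cart velocity, pole angle, pole angular velocity
total_reward = 0.0
done = False
while not done:
    action = np.random.choice(2)  # 0 = push cart left, 1 = push cart right
    state, reward, done, _ = env.step(action)  # reward is +1 for every step the pole stays up
    total_reward += reward
print("Random policy collected a total reward of", total_reward)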
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import gym
import numpy as np
import keras
from keras import ops
from keras import layers
import tensorflow as tf
# Configuration parameters for the whole setup
seed = 42
gamma = 0.99 # Discount factor for past rewards
max_steps_per_episode = 10000
env = gym.make("CartPole-v0") # Create the environment
env.seed(seed)
eps = np.finfo(np.float32).eps.item() # Smallest number such that 1.0 + eps != 1.0
This network learns two functions:
1. Actor: takes the state of our environment as input and returns a probability value for each action in its action space.
2. Critic: takes the state of our environment as input and returns an estimate of the total rewards in the future.
In our implementation, they share the initial layer.
num_inputs = 4
num_actions = 2
num_hidden = 128
inputs = layers.Input(shape=(num_inputs,))
common = layers.Dense(num_hidden, activation="relu")(inputs)
action = layers.Dense(num_actions, activation="softmax")(common)
critic = layers.Dense(1)(common)
model = keras.Model(inputs=inputs, outputs=[action, critic])
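# The model maps a batch of 4-dimensional states to a 2-way action distribution
# (the actor head) and a scalar value estimate (the critic head).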
optimizer = keras.optimizers.Adam(learning_rate=0.01)
huber_loss = keras.losses.Huber()
action_probs_history = []
critic_value_history = []
rewards_history = []
running_reward = 0
episode_count = 0
while True: # Run until solved
state = env.reset()
episode_reward = 0
with tf.GradientTape() as tape:
for timestep in range(1, max_steps_per_episode):
# env.render(); Adding this line would show the attempts
# of the agent in a pop up window.
state = ops.convert_to_tensor(state)
state = ops.expand_dims(state, 0)
# Predict action probabilities and estimated future rewards
# from environment state
action_probs, critic_value = model(state)
critic_value_history.append(critic_value[0, 0])
# Sample action from action probability distribution
action = np.random.choice(num_actions, p=np.squeeze(action_probs))
action_probs_history.append(ops.log(action_probs[0, action]))
# Apply the sampled action in our environment
state, reward, done, _ = env.step(action)
rewards_history.append(reward)
episode_reward += reward
if done:
break
# Update running reward to check condition for solving
running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
# Calculate expected value from rewards
# - At each timestep what was the total reward received after that timestep
# - Rewards in the past are discounted by multiplying them with gamma
# - These are the labels for our critic
returns = []
discounted_sum = 0
for r in rewards_history[::-1]:
discounted_sum = r + gamma * discounted_sum
returns.insert(0, discounted_sum)
# Normalize
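        # Normalizing keeps the returns on a comparable scale from episode to episode,
        # which stabilizes both the critic targets and the advantage estimates.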
returns = np.array(returns)
returns = (returns - np.mean(returns)) / (np.std(returns) + eps)
returns = returns.tolist()
# Calculating loss values to update our network
history = zip(action_probs_history, critic_value_history, returns)
actor_losses = []
critic_losses = []
for log_prob, value, ret in history:
# At this point in history, the critic estimated that we would get a
# total reward = `value` in the future. We took an action with log probability
# of `log_prob` and ended up receiving a total reward = `ret`.
# The actor must be updated so that it predicts an action that leads to
# high rewards (compared to critic's estimate) with high probability.
diff = ret - value
actor_losses.append(-log_prob * diff) # actor loss
# The critic must be updated so that it predicts a better estimate of
# the future rewards.
critic_losses.append(
huber_loss(ops.expand_dims(value, 0), ops.expand_dims(ret, 0))
)
# Backpropagation
loss_value = sum(actor_losses) + sum(critic_losses)
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# Clear the loss and reward history
action_probs_history.clear()
critic_value_history.clear()
rewards_history.clear()
# Log details
episode_count += 1
if episode_count % 10 == 0:
template = "running reward: {:.2f} at episode {}"
print(template.format(running_reward, episode_count))
if running_reward > 195: # Condition to consider the task solved
print("Solved at episode {}!".format(episode_count))
break
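To make the return calculation above concrete, here is a small standalone sketch (independent of the training loop; the helper name is illustrative) that recomputes the discounted returns for a toy episode. With gamma = 0.99, rewards [1, 1, 1] give returns of roughly [2.9701, 1.99, 1.0] before normalization; the gap between these targets and the critic's predictions is the `diff` that scales the actor loss.

def discounted_returns(rewards, gamma=0.99):
    # Walk backwards through the episode, accumulating the gamma-discounted sum.
    returns = []
    discounted_sum = 0.0
    for r in rewards[::-1]:
        discounted_sum = r + gamma * discounted_sum
        returns.insert(0, discounted_sum)
    return returns

print(discounted_returns([1.0, 1.0, 1.0]))  # approximately [2.9701, 1.99, 1.0]

The training loop logs the running reward every ten episodes; an example run looks like this: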
running reward: 8.82 at episode 10
running reward: 23.04 at episode 20
running reward: 28.41 at episode 30
running reward: 53.59 at episode 40
running reward: 53.71 at episode 50
running reward: 77.35 at episode 60
running reward: 74.76 at episode 70
running reward: 57.89 at episode 80
running reward: 46.59 at episode 90
running reward: 43.48 at episode 100
running reward: 63.77 at episode 110
running reward: 111.13 at episode 120
running reward: 142.77 at episode 130
running reward: 127.96 at episode 140
running reward: 113.92 at episode 150
running reward: 128.57 at episode 160
running reward: 139.95 at episode 170
running reward: 154.95 at episode 180
running reward: 171.45 at episode 190
running reward: 171.33 at episode 200
running reward: 177.74 at episode 210
running reward: 184.76 at episode 220
running reward: 190.88 at episode 230
running reward: 154.78 at episode 240
running reward: 114.38 at episode 250
running reward: 107.51 at episode 260
running reward: 128.99 at episode 270
running reward: 157.48 at episode 280
running reward: 174.54 at episode 290
running reward: 184.76 at episode 300
running reward: 190.87 at episode 310
running reward: 194.54 at episode 320
Solved at episode 322!
Visualizations of the agent's attempts (animated renders, not included here): one from the early stages of training and one from the later stages of training.
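As a rough sketch (not part of the original script) of how such renders can be produced: run the trained model greedily and call env.render() at each step, as the comment inside the training loop suggests. This assumes the same classic gym rendering API, where env.render() opens a pop-up window, and reuses the `model` and `env` objects from the script above.

# Watch the trained agent; assumes `model` and `env` from the script above are still in scope.
state = env.reset()
done = False
while not done:
    env.render()  # opens a pop-up window showing the cart and pole
    state_tensor = ops.expand_dims(ops.convert_to_tensor(state), 0)
    action_probs, _ = model(state_tensor)
    action = int(np.argmax(np.squeeze(action_probs)))  # act greedily instead of sampling
    state, reward, done, _ = env.step(action)
env.close()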