
pl_bolts.models.rl.dqn_model module

Deep Q Network

class pl_bolts.models.rl.dqn_model.DQN(env, gpus=0, eps_start=1.0, eps_end=0.02, eps_last_frame=150000, sync_rate=1000, gamma=0.99, learning_rate=0.0001, batch_size=32, replay_size=100000, warm_start_size=10000, num_samples=500, **kwargs)[source]

Bases: pytorch_lightning.LightningModule

Basic DQN Model

PyTorch Lightning implementation of DQN

Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller.

Model implemented by:

  • Donal Byrne <https://github.com/djbyrne>

Example

>>> from pl_bolts.models.rl.dqn_model import DQN
>>> model = DQN("PongNoFrameskip-v4")

Train:

from pytorch_lightning import Trainer

trainer = Trainer()
trainer.fit(model)
Parameters
  • env (str) – gym environment tag

  • gpus (int) – number of gpus being used

  • eps_start (float) – starting value of epsilon for the epsilon-greedy exploration

  • eps_end (float) – final value of epsilon for the epsilon-greedy exploration

  • eps_last_frame (int) – the final frame of the epsilon decay; at this frame epsilon = eps_end (see the schedule sketch after this list)

  • sync_rate (int) – the number of iterations between syncing up the target network with the train network

  • gamma (float) – discount factor

  • learning_rate (float) – learning rate

  • batch_size (int) – size of minibatch pulled from the DataLoader

  • replay_size (int) – total capacity of the replay buffer

  • warm_start_size (int) – number of random steps taken through the environment at the start of training to pre-fill the replay buffer

  • num_samples (int) – the number of samples to pull from the dataset iterator and feed to the DataLoader
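The three eps_* parameters define a simple linear anneal of the exploration rate. The helper below is only an illustrative sketch of such a schedule; the function name and the exact decay rule are assumptions, not the library's code.

def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.02, eps_last_frame=150000):
    # linearly anneal epsilon from eps_start to eps_end over eps_last_frame frames
    slope = (eps_start - eps_end) / eps_last_frame
    return max(eps_end, eps_start - slope * frame_idx)

# epsilon_by_frame(0) -> 1.0, epsilon_by_frame(150000) -> 0.02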

Note

This example is based on:

https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter06/02_dqn_pong.py

Note

Currently only supports CPU and single GPU training with distributed_backend=dp
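
Putting the pieces together, a training run might look like the sketch below. The DQN arguments are simply the defaults from the signature above; the Trainer settings (gpus, distributed_backend, max_epochs) are assumptions chosen to match the note about single-GPU dp training.

import pytorch_lightning as pl
from pl_bolts.models.rl.dqn_model import DQN

model = DQN(
    "PongNoFrameskip-v4",
    eps_start=1.0,
    eps_end=0.02,
    eps_last_frame=150000,
    sync_rate=1000,
    gamma=0.99,
    learning_rate=0.0001,
    batch_size=32,
    replay_size=100000,
    warm_start_size=10000,
)

# single-GPU training with the dp backend, per the note above; max_epochs is an assumed value
trainer = pl.Trainer(gpus=1, distributed_backend="dp", max_epochs=500)
trainer.fit(model)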

static add_model_specific_args(arg_parser)[source]

Adds arguments for DQN model

Note: these parameters are fine-tuned for the Pong environment

Parameters

arg_parser (ArgumentParser) – parent parser

Return type

ArgumentParser
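
A typical way to wire this into a training script is sketched below. The exact flag names added by add_model_specific_args are not listed here, so the example assumes an --env flag exists and simply forwards everything that was parsed into the constructor via **kwargs.

import argparse

import pytorch_lightning as pl
from pl_bolts.models.rl.dqn_model import DQN

parser = argparse.ArgumentParser()
parser = DQN.add_model_specific_args(parser)
args = parser.parse_args()

# assumes the parser exposes the environment tag as args.env; adjust if it does not
model = DQN(**vars(args))
trainer = pl.Trainer()
trainer.fit(model)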

build_networks()[source]

Initializes the DQN train and target networks

Return type

None
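
build_networks sets up the online (train) network and a separate target network. The helper below is not the library's code; it only illustrates the periodic weight copy that sync_rate controls.

def sync_target_network(net, target_net, global_step, sync_rate=1000):
    # every `sync_rate` training steps, copy the online weights into the target network
    if global_step % sync_rate == 0:
        target_net.load_state_dict(net.state_dict())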

configure_optimizers()[source]

Initialize Adam optimizer

Return type

List[Optimizer]
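
The optimizer setup amounts to a single Adam instance returned in a list, roughly as sketched below; the attribute names (self.net, self.hparams) are assumptions.

import torch.optim as optim

def configure_optimizers(self):
    # one Adam optimizer over the online network's parameters
    optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.learning_rate)
    return [optimizer]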

forward(x)[source]

Passes a state x through the network and returns the Q-values for each action

Parameters

x (Tensor) – environment state

Return type

Tensor

Returns

q values
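
Since forward maps a state to per-action Q-values, choosing the greedy action is a single argmax. The observation shape below is only an assumption for an Atari-style, frame-stacked environment, and model refers to the instance created in the example above.

import torch

state = torch.zeros(1, 4, 84, 84)   # assumed (batch, stacked frames, height, width)
q_values = model(state)              # shape: (1, n_actions)
action = torch.argmax(q_values, dim=1).item()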

populate(warm_start)[source]

Populates the buffer with initial experience

Return type

None

prepare_data()[source]

Initialize the Replay Buffer dataset used for retrieving experiences

Return type

None
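
Conceptually the replay buffer is a fixed-capacity FIFO of transitions that the DataLoader samples minibatches from. The class below is a generic sketch of that idea, not the buffer used internally.

import random
from collections import deque

class SimpleReplayBuffer:
    # minimal illustration of a buffer with `replay_size` capacity

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def append(self, transition):
        # transition: (state, action, reward, done, next_state)
        self.buffer.append(transition)

    def sample(self, batch_size=32):
        return random.sample(self.buffer, batch_size)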

test_dataloader()[source]

Get test loader

Return type

DataLoader

test_epoch_end(outputs)[source]

Log the average of the test results

Return type

Dict[str, Tensor]

test_step(*args, **kwargs)[source]

Evaluate the agent for 10 episodes

Return type

Dict[str, Tensor]

train_dataloader()[source]

Get train loader

Return type

DataLoader

training_step(batch, _)[source]

Carries out a single step through the environment to update the replay buffer, then calculates the loss based on the minibatch received

Parameters

  • batch – the current minibatch of experiences sampled from the replay buffer

  • _ – batch index (unused)
Return type

OrderedDict

Returns

Training loss and log metrics
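
The loss computed here is the standard DQN temporal-difference loss. The function below is a generic sketch of that computation (tensor names and the use of MSE are assumptions, not the library's exact code).

import torch
import torch.nn as nn

def dqn_loss(net, target_net, batch, gamma=0.99):
    states, actions, rewards, dones, next_states = batch

    # Q(s, a) for the actions that were actually taken
    state_action_values = net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)

    with torch.no_grad():
        # max_a' Q_target(s', a'), zeroed out on terminal transitions
        next_state_values = target_net(next_states).max(1)[0]
        next_state_values[dones] = 0.0

    expected_values = rewards + gamma * next_state_values
    return nn.MSELoss()(state_action_values, expected_values)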
