World Models

Can agents learn inside their own dreams? World Models proposes a three-part, biologically inspired cognitive system. A vision (sensory) component compresses high-dimensional observations into a low-dimensional latent vector. A memory component predicts future latent codes from historical information. A controller component decides which actions to take based only on the representations produced by the vision and memory components.
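The data flow between the three components can be sketched as a single loop. This is a minimal illustration, not the authors' code: `run_agent`, `encode`, `act` and `rnn_step` are hypothetical names, the observations come from a fixed list rather than from an environment, and the toy lambdas in the usage stand in for trained networks.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def run_agent(
    observations: Sequence[Vector],
    encode: Callable[[Vector], Vector],                    # V: frame -> z_t
    act: Callable[[Vector, Vector], Vector],               # C: (z_t, h_t) -> a_t
    rnn_step: Callable[[Vector, Vector, Vector], Vector],  # M: (z_t, a_t, h_t) -> h_{t+1}
    h0: Vector,
) -> List[Vector]:
    """Feed a sequence of observations through V, C and M; return the actions."""
    h = h0
    actions: List[Vector] = []
    for obs in observations:
        z = encode(obs)        # vision: compress the raw frame into z_t
        a = act(z, h)          # controller: sees only z_t and h_t
        h = rnn_step(z, a, h)  # memory: fold z_t and a_t into the hidden state
        actions.append(a)
    return actions


# Toy stand-ins (hypothetical, for illustration only)
actions = run_agent(
    observations=[[1.0, 3.0], [2.0, 2.0]],
    encode=lambda obs: [sum(obs) / len(obs)],
    act=lambda z, h: [z[0] + h[0]],
    rnn_step=lambda z, a, h: [0.5 * h[0] + a[0]],
    h0=[0.0],
)
print(actions)  # [[2.0], [4.0]]
```

The controller never sees the raw frame: everything it knows arrives through `z` and `h`.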

The vision component is modeled with a variational autoencoder (VAE) that compresses each frame it receives at time step $t$ into a low-dimensional latent vector $z_t$. The VAE's decoder can reconstruct the original image from this compressed representation.
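The encode/sample/decode path of a VAE can be sketched without any framework. This sketch assumes single linear layers with hand-picked toy weights in place of the convolutional networks used in practice; `encode`, `reparameterize`, `decode` and `kl_to_standard_normal` are hypothetical helper names. A VAE is trained on a reconstruction loss plus the KL term computed here.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def linear(W: Matrix, b: Vector, x: Vector) -> Vector:
    """Compute W @ x + b with plain lists."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def encode(x: Vector, W_mu: Matrix, b_mu: Vector,
           W_lv: Matrix, b_lv: Vector) -> Tuple[Vector, Vector]:
    """Map a flattened frame to the mean and log-variance of q(z | x)."""
    return linear(W_mu, b_mu, x), linear(W_lv, b_lv, x)


def reparameterize(mu: Vector, logvar: Vector, rng: random.Random) -> Vector:
    """Sample z_t = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def decode(z: Vector, W_dec: Matrix, b_dec: Vector) -> Vector:
    """Reconstruct a flattened frame from z_t; the sigmoid keeps pixels in (0, 1)."""
    return [1.0 / (1.0 + math.exp(-v)) for v in linear(W_dec, b_dec, z)]


def kl_to_standard_normal(mu: Vector, logvar: Vector) -> float:
    """KL(q(z|x) || N(0, I)), the regularizer in the VAE loss."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


# Toy usage: a 4-"pixel" frame compressed to a 2-D latent vector
rng = random.Random(0)
frame = [0.2, 0.8, 0.5, 0.1]
W_mu = [[0.1, 0.2, 0.0, -0.1], [0.0, 0.3, 0.1, 0.2]]
W_lv = [[0.0] * 4, [0.0] * 4]
mu, logvar = encode(frame, W_mu, [0.0, 0.0], W_lv, [0.0, 0.0])
z_t = reparameterize(mu, logvar, rng)
recon = decode(z_t, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]], [0.0] * 4)
```

Sampling through `reparameterize` rather than returning `mu` directly is what lets gradients flow through the stochastic latent during training.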

The memory component is modeled with a mixture density recurrent neural network (MDN-RNN), a predictive model of the future $z$ vectors that the variational autoencoder is expected to produce. Rather than a deterministic prediction of $z$, the model is trained to output a probability density function $p(z)$, approximated as a mixture of Gaussians. Its objective is to output the probability distribution of the next latent vector $z_{t+1}$ given the current and past information available to it:

$$ P(z_{t+1} \mid a_t, z_t, h_t) $$

where $a_t$ is the action taken at time $t$ and $h_t$ is the hidden state of the RNN at time $t$. During sampling, a temperature parameter $\tau$ can be adjusted to control model uncertainty: larger values of $\tau$ produce more random, less predictable samples.
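The mixture-density output head can be sketched for a single latent dimension: one function scores an observed $z_{t+1}$ under the mixture (its negative is the training loss), and one samples $z_{t+1}$ at temperature $\tau$. The RNN that would produce `logits`, `means` and `log_stds` from $(a_t, z_t, h_t)$ is omitted, and the temperature rule (divide the mixture logits by $\tau$, scale each standard deviation by $\sqrt{\tau}$) is an assumed convention rather than something stated above.

```python
import math
import random
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log of softmax(logits)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def mdn_log_likelihood(z: float, logits: List[float], means: List[float],
                       log_stds: List[float]) -> float:
    """log p(z) under a 1-D mixture of Gaussians (its negative is the training loss)."""
    terms = []
    for log_pi, mu, log_std in zip(log_softmax(logits), means, log_stds):
        std = math.exp(log_std)
        log_normal = -0.5 * ((z - mu) / std) ** 2 - log_std - 0.5 * math.log(2.0 * math.pi)
        terms.append(log_pi + log_normal)
    m = max(terms)  # log-sum-exp over mixture components
    return m + math.log(sum(math.exp(t - m) for t in terms))


def mdn_sample(logits: List[float], means: List[float], log_stds: List[float],
               tau: float, rng: random.Random) -> float:
    """Draw one value of z_{t+1} at temperature tau (tau > 0)."""
    weights = [math.exp(lp) for lp in log_softmax([x / tau for x in logits])]
    k = rng.choices(range(len(weights)), weights=weights)[0]  # pick a component
    std = math.exp(log_stds[k]) * math.sqrt(tau)              # widen or narrow it
    return means[k] + std * rng.gauss(0.0, 1.0)


# Toy usage: a two-component mixture sampled cold and hot
rng = random.Random(0)
logits, means, log_stds = [1.0, 0.0], [0.5, -1.0], [-1.0, -0.5]
print(mdn_log_likelihood(0.4, logits, means, log_stds))
cold = mdn_sample(logits, means, log_stds, tau=0.1, rng=rng)
hot = mdn_sample(logits, means, log_stds, tau=1.5, rng=rng)
```

A small $\tau$ sharpens the mixture weights toward the most likely component and shrinks its spread; a large $\tau$ flattens the weights and widens every Gaussian.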

The controller component is a linear mapping that uses the representations from the vision and memory components to select actions. It is deliberately kept as simple and small as possible, and it is trained separately from the vision and memory networks, so that most of the agent's complexity resides in the world model.

$$ a_t = W_c [z_t \; h_t] + b_c $$

where $W_c$ and $b_c$ are the weight matrix and bias vector that map the concatenated input vector $[z_t \; h_t]$ to the output action vector $a_t$.
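The controller equation is a single matrix-vector product, sketched below with plain lists. `controller` and `num_params` are hypothetical helper names, and the sizes passed to `num_params` (a 32-dimensional $z_t$, a 256-unit $h_t$, 3 action dimensions) are illustrative.

```python
from typing import List


def controller(z: List[float], h: List[float],
               W_c: List[List[float]], b_c: List[float]) -> List[float]:
    """a_t = W_c [z_t; h_t] + b_c, one output per action dimension."""
    x = z + h  # list concatenation gives [z_t; h_t]
    if any(len(row) != len(x) for row in W_c):
        raise ValueError("W_c must have len(z) + len(h) columns")
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W_c, b_c)]


def num_params(z_dim: int, h_dim: int, a_dim: int) -> int:
    """Number of entries in W_c plus b_c."""
    return (z_dim + h_dim) * a_dim + a_dim


print(controller([1.0, 2.0], [3.0], [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]], [0.5, 0.0]))
# [4.5, -1.0]
print(num_params(32, 256, 3))  # 867
```

Even with those sizes the controller has only 867 parameters, which is what makes it cheap to optimize with gradient-free search.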

Because the world model is trained to predict future states, it can generate synthetic sequences of observations on its own. The controller can then be trained entirely on these synthetic sequences of latent states, a process akin to dreaming.
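Training inside the dream can be sketched as follows: the controller's parameters are scored only by rollouts through a model, never through the real environment. Everything here is a toy under loud assumptions. `dream_return` uses hand-written dynamics in place of a trained VAE and MDN-RNN, and the original work optimized the controller with CMA-ES, whereas this sketch swaps in a much simpler (1+1) evolution strategy to stay dependency-free.

```python
import random
from typing import List, Tuple

Params = List[float]  # [w_z, w_h, b] for a toy 1-D linear controller


def dream_return(params: Params, z0: float = 1.0, steps: int = 20) -> float:
    """Total reward of one rollout inside a toy, hand-written 'world model'."""
    w_z, w_h, b = params
    z, h, total = z0, 0.0, 0.0
    for _ in range(steps):
        a = max(-1.0, min(1.0, w_z * z + w_h * h + b))  # linear controller, clipped
        h = 0.5 * h + z + a   # toy memory update from (z_t, a_t, h_t)
        z = 0.9 * z + a       # toy stand-in for sampling z_{t+1} from the MDN-RNN
        total += -z * z       # toy reward: keep the latent state near zero
    return total


def train_in_dream(iterations: int = 300, sigma: float = 0.1,
                   seed: int = 0) -> Tuple[Params, float]:
    """(1+1) evolution strategy: keep a mutation only if the dream return does not drop."""
    rng = random.Random(seed)
    best: Params = [0.0, 0.0, 0.0]
    best_return = dream_return(best)
    for _ in range(iterations):
        candidate = [p + sigma * rng.gauss(0.0, 1.0) for p in best]
        r = dream_return(candidate)
        if r >= best_return:
            best, best_return = candidate, r
    return best, best_return


params, ret = train_in_dream()
print(params, ret, dream_return([0.0, 0.0, 0.0]))
```

Because only `dream_return` is ever called, the controller's behavior is shaped entirely by the model's predictions; any flaw in those predictions is something the controller can learn to exploit, which is one reason the temperature $\tau$ matters when generating dreams.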