Overview

When policy evaluation is stopped after just one update of each state, the algorithm is called value iteration. It can be written as a particularly simple update operation that combines the policy improvement and truncated policy evaluation steps [1]. This post looks at the value iteration algorithm.

Value iteration

One drawback of policy iteration is that each of its iterations involves a policy evaluation, which may itself be a protracted iterative computation requiring multiple sweeps through the state set [1]. Furthermore, if the evaluation is done iteratively, then convergence to $V_{\pi}$ occurs only in the limit [1].

Given these limitations of policy iteration, a natural question is whether we could stop the evaluation earlier [1]. Fortunately, the policy evaluation step of policy iteration can be truncated without losing the convergence guarantees of the method. Moreover, this can be done in several ways [1].

In particular, when policy evaluation is stopped after just one update of each state, the algorithm is called value iteration. It can be written as a particularly simple update operation that combines the policy improvement and truncated policy evaluation steps [1]:

$$V_{k+1}(s) = \max_{a}\sum_{s', r}p(s', r \mid s, a)\left[r + \gamma V_{k}(s')\right], \quad \forall s \in \mathcal{S}$$

Figure 1: Value iteration algorithm. Image from [1].

Value iteration is obtained simply by turning the Bellman optimality equation into an update rule [1]. Unlike the policy evaluation update, it requires the maximum to be taken over all actions. The algorithm terminates once the largest change of the value function during a sweep over the states falls below a small threshold.

Code

import scala.collection.mutable.ArrayBuffer
import scala.util.control.Breaks._
import scala.math.max
import scala.collection.mutable.ArrayBuffer

import scala.util.control.Breaks._

import scala.math.max
object Grid{

    /** One cell of the grid, identified by `idx`.
      *
      * Each cell has four neighbor slots, one per action index 0..3, holding the
      * index of the cell that action leads to. Every slot starts at -1 (unassigned).
      */
    class State(val idx: Int){

        // Slot k stores the neighbor reached by action k; -1 means "not set yet".
        val neigbors: ArrayBuffer[Int] = ArrayBuffer.fill(4)(-1)

        /** Assigns all four neighbor slots at once: slot k receives `neighbors(k)`.
          *
          * @param neighbors exactly four neighbor state indices
          */
        def addNeighbors(neighbors: Array[Int]): Unit = {
            require(neighbors.length == 4)
            neighbors.indices.foreach(slot => addNeighbor(slot, neighbors(slot)))
        }

        /** Stores `nIdx` as the neighbor for slot `idx` (must be below 4). */
        def addNeighbor(idx: Int, nIdx: Int): Unit = {
            require(idx < 4)
            neigbors.update(idx, nIdx)
        }

        /** Reads the neighbor state index kept in slot `idx` (must be below 4). */
        def getNeighbor(idx: Int): Int = {
            require(idx < 4)
            neigbors(idx)
        }
    }
}
defined object Grid
/** A rectangular grid world whose cells are stored row-major as [[Grid.State]]s.
  *
  * The neighbor table built by `create` uses action indices
  * 0 = up, 1 = right, 2 = down, 3 = left. A move off the edge of the grid
  * leaves the agent in the same cell.
  */
class Grid{

    /** All cells; `states(i).idx == i` after `create`. */
    val states = new ArrayBuffer[Grid.State]()

    def nStates : Int = states.length
    def nActions: Int = 4

    /** Returns `(weight, nextStateIdx, reward, done)` for taking `action` in `state`.
      *
      * Every action has weight 0.25 and reward 1.0, and never signals `done`.
      * NOTE(review): 0.25 presumably models a uniform choice among the 4 actions
      * rather than a transition probability — confirm the intended model.
      */
    def envDynamics(state: Grid.State, action: Int): (Double, Int, Double, Boolean) = {
        (0.25, states(state.idx).getNeighbor(action), 1.0, false)
    }

    /** Returns the cell with index `idx`.
      *
      * @throws IllegalArgumentException if `idx` is outside `[0, nStates)`
      */
    def getState(idx: Int): Grid.State = {
        require(idx >= 0 && idx < nStates, s"state index $idx out of range [0, $nStates)")
        states(idx)
    }

    /** (Re)builds the grid with `rows * cols` cells, replacing any existing cells.
      *
      * With the defaults this produces the 3x3, 9-state grid:
      * cell 0 -> (0, 1, 3, 0), cell 4 -> (1, 5, 7, 3), cell 8 -> (5, 8, 8, 7), etc.
      *
      * @param rows number of grid rows (must be positive)
      * @param cols number of grid columns (must be positive)
      */
    def create(rows: Int = 3, cols: Int = 3): Unit = {
        require(rows > 0 && cols > 0, s"grid dimensions must be positive, got ${rows}x$cols")

        // Start from an empty grid so calling create() twice does not duplicate cells.
        states.clear()

        for(s <- 0 until rows * cols){
            val row = s / cols
            val col = s % cols

            // Moving into a wall keeps the agent in place.
            val up    = if(row == 0) s else s - cols
            val right = if(col == cols - 1) s else s + 1
            val down  = if(row == rows - 1) s else s + cols
            val left  = if(col == 0) s else s - 1

            val state = new Grid.State(s)
            state.addNeighbors(Array(up, right, down, left))
            states += state
        }
    }

}
defined class Grid
/** Tabular value iteration (Sutton & Barto, eq. 4.10) over a [[Grid]].
  *
  * Each sweep replaces every state's value with its best one-step lookahead
  * value. The update is done in place, so states later in a sweep already see
  * values updated earlier in that sweep. Training stops when the residual
  * (largest absolute value change in a sweep) drops below `tolerance`, or
  * after `numIterations` sweeps.
  *
  * @param numIterations maximum number of sweeps performed by `train`
  * @param tolerance     training stops once a sweep's residual is below this value
  * @param gamma         discount factor applied to the successor state's value
  */
class ValueIteration(val numIterations: Int, val tolerance: Double,
                    val gamma: Double){

    /** Current value estimate, indexed by state index; rebuilt by `train`. */
    val valueF = new ArrayBuffer[Double]()

    /** Largest absolute change of any state's value during the latest sweep. */
    var residual = 1.0

    /** Runs value iteration on `grid` from an all-zero value function,
      * printing the iteration number and current residual before each sweep.
      */
    def train(grid: Grid): Unit = {

        valueF.clear()
        valueF ++= Seq.fill(grid.nStates)(0.0)

        // Reset so a repeated call does not start from the previous run's residual.
        residual = 1.0

        breakable {

            for(itr <- Range(0, numIterations)){

                println("> Learning iteration " + itr)
                println("> Learning residual " + residual)

                step(grid)

                if(residual < tolerance) break()
            }
        }
    }

    /** Performs one in-place sweep over all states and updates `residual`. */
    def step(grid: Grid): Unit = {

        var delta: Double = 0.0

        for(sIdx <- 0 until grid.nStates){

            // One-step lookahead; the best action value is the new estimate.
            val maxActionValue = one_step_lookahead(grid, grid.getState(sIdx)).max

            // Absolute change of this state's value: decreases count as much as increases.
            delta = max(delta, (maxActionValue - valueF(sIdx)).abs)

            // Update the value function. Ref: Sutton book eq. 4.10.
            valueF(sIdx) = maxActionValue
        }

        this.residual = delta
    }

    /** Computes the expected value of every action in `state`.
      *
      * @return a buffer of length `grid.nActions`, where entry `a` is
      *         `weight * (reward + gamma * V(next))` for action `a`
      */
    def one_step_lookahead(grid: Grid, state: Grid.State): ArrayBuffer[Double] = {

        val values = new ArrayBuffer[Double](grid.nActions)

        for(a <- 0 until grid.nActions){
            // The `done` flag is not used by this update.
            val (prob, nextState, reward, _) = grid.envDynamics(state, a)
            values += prob * (reward + this.gamma * valueF(nextState))
        }

        values
    }

}
defined class ValueIteration
// Build the 9-state grid world used in this example
val grid = new Grid
grid.create()
grid: Grid = ammonite.$sess.cmd2$Helper$Grid@794b701a
// At most 100 sweeps, stop once the residual drops below 1e-4, no discounting (gamma = 1.0)
val valueFunction = new ValueIteration(100, 1.0e-4, 1.0)
valueFunction: ValueIteration = ammonite.$sess.cmd29$Helper$ValueIteration@44c65bd3
// Run value iteration; the iteration number and residual are printed before each sweep
valueFunction.train(grid)
> Learning iteration 0
> Learning residual 1.0
> Learning iteration 1
> Learning residual 0.3330078125
> Learning iteration 2
> Learning residual 0.078125
> Learning iteration 3
> Learning residual 0.0048828125
> Learning iteration 4
> Learning residual 3.0517578125E-4

References

  1. Richard S. Sutton and Andrew G. Barto, Reinforcement Learning: An Introduction.