
20.6 SGD and Backpropagation


obtain that the loss, as a function of $W_{t-1}$, can be written as
\[
g_t(W_{t-1}) \;=\; \ell_t(o_t) \;=\; \ell_t(\sigma(a_t)) \;=\; \ell_t\bigl(\sigma(W_{t-1} o_{t-1})\bigr).
\]
It would be convenient to rewrite this as follows. Let $w_{t-1} \in \mathbb{R}^{k_{t-1} k_t}$ be the column vector obtained by concatenating the rows of $W_{t-1}$ and then taking the transpose of the resulting long vector. Define by $O_{t-1}$ the $k_t \times (k_{t-1} k_t)$ matrix
\[
O_{t-1} \;=\;
\begin{pmatrix}
o_{t-1}^\top & 0 & \cdots & 0 \\
0 & o_{t-1}^\top & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
0 & 0 & \cdots & o_{t-1}^\top
\end{pmatrix}.
\tag{20.2}
\]
Then, $W_{t-1} o_{t-1} = O_{t-1} w_{t-1}$, so we can also write
\[
g_t(w_{t-1}) \;=\; \ell_t\bigl(\sigma(O_{t-1} w_{t-1})\bigr).
\]
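As a quick numerical sanity check of this identity (an illustration, not part of the derivation), the following NumPy sketch builds $O_{t-1}$ from $o_{t-1}$ and verifies that $O_{t-1} w_{t-1}$ reproduces $W_{t-1} o_{t-1}$; the dimensions $k_{t-1}, k_t$ and the random entries are arbitrary choices.

```python
import numpy as np

k_prev, k_t = 3, 2                      # k_{t-1} and k_t, arbitrary sizes
rng = np.random.default_rng(0)

W = rng.normal(size=(k_t, k_prev))      # W_{t-1}
o = rng.normal(size=k_prev)             # o_{t-1}

# w_{t-1}: concatenate the rows of W_{t-1} into one long column vector
w = W.reshape(-1)

# O_{t-1}: k_t x (k_{t-1} k_t) block matrix with o_{t-1}^T in each diagonal block
O = np.kron(np.eye(k_t), o)

assert np.allclose(W @ o, O @ w)        # W_{t-1} o_{t-1} = O_{t-1} w_{t-1}
```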
Therefore, applying the chain rule, we obtain that
\[
J_{w_{t-1}}(g_t) \;=\; J_{\sigma(O_{t-1} w_{t-1})}(\ell_t)\, \mathrm{diag}\bigl(\sigma'(O_{t-1} w_{t-1})\bigr)\, O_{t-1}.
\]
Using our notation we have $o_t = \sigma(O_{t-1} w_{t-1})$ and $a_t = O_{t-1} w_{t-1}$, which yields
\[
J_{w_{t-1}}(g_t) \;=\; J_{o_t}(\ell_t)\, \mathrm{diag}\bigl(\sigma'(a_t)\bigr)\, O_{t-1}.
\]
Let us also denote $\delta_t = J_{o_t}(\ell_t)$. Then, we can further rewrite the preceding as
\[
J_{w_{t-1}}(g_t) \;=\; \Bigl(\delta_{t,1}\, \sigma'(a_{t,1})\, o_{t-1}^\top \,,\; \ldots \,,\; \delta_{t,k_t}\, \sigma'(a_{t,k_t})\, o_{t-1}^\top \Bigr).
\tag{20.3}
\]
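The block structure of Equation (20.3) can also be checked numerically. The sketch below is purely illustrative: it uses a sigmoid for $\sigma$, arbitrary sizes, and a random vector standing in for $\delta_t$, and compares $\delta_t\,\mathrm{diag}(\sigma'(a_t))\,O_{t-1}$ with the concatenation of the blocks $\delta_{t,i}\,\sigma'(a_{t,i})\,o_{t-1}^\top$.

```python
import numpy as np

k_prev, k_t = 3, 2
rng = np.random.default_rng(1)

sigma = lambda z: 1.0 / (1.0 + np.exp(-z))   # one possible choice of sigma
dsigma = lambda z: sigma(z) * (1.0 - sigma(z))

W = rng.normal(size=(k_t, k_prev))           # W_{t-1}
o_prev = rng.normal(size=k_prev)             # o_{t-1}
a = W @ o_prev                               # a_t = W_{t-1} o_{t-1}
delta = rng.normal(size=k_t)                 # stands in for delta_t = J_{o_t}(ell_t)

O = np.kron(np.eye(k_t), o_prev)             # O_{t-1} as in Equation (20.2)

# Vectorized form: J_{w_{t-1}}(g_t) = delta_t diag(sigma'(a_t)) O_{t-1}
lhs = delta @ np.diag(dsigma(a)) @ O

# Block form of Equation (20.3), one block per row of W_{t-1}
rhs = np.concatenate([delta[i] * dsigma(a[i]) * o_prev for i in range(k_t)])

assert np.allclose(lhs, rhs)
```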
It is left to calculate the vector $\delta_t = J_{o_t}(\ell_t)$ for every $t$. This is the gradient of $\ell_t$ at $o_t$. We calculate this in a recursive manner. First observe that for the last layer we have that $\ell_T(u) = \Delta(u, y)$, where $\Delta$ is the loss function. Since we assume that $\Delta(u, y) = \frac{1}{2}\|u - y\|^2$ we obtain that $J_u(\ell_T) = (u - y)$. In particular, $\delta_T = J_{o_T}(\ell_T) = (o_T - y)$. Next, note that
\[
\ell_t(u) \;=\; \ell_{t+1}\bigl(\sigma(W_t u)\bigr).
\]
Therefore, by the chain rule,
\[
J_u(\ell_t) \;=\; J_{\sigma(W_t u)}(\ell_{t+1})\, \mathrm{diag}\bigl(\sigma'(W_t u)\bigr)\, W_t.
\]

In particular,
\[
\begin{aligned}
\delta_t \;=\; J_{o_t}(\ell_t) \;&=\; J_{\sigma(W_t o_t)}(\ell_{t+1})\, \mathrm{diag}\bigl(\sigma'(W_t o_t)\bigr)\, W_t \\
&=\; J_{o_{t+1}}(\ell_{t+1})\, \mathrm{diag}\bigl(\sigma'(a_{t+1})\bigr)\, W_t \\
&=\; \delta_{t+1}\, \mathrm{diag}\bigl(\sigma'(a_{t+1})\bigr)\, W_t.
\end{aligned}
\]
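To see one step of this recursion in action, one can take $t+1 = T$, so that $\ell_{t+1}(v) = \frac{1}{2}\|v - y\|^2$, and compare $\delta_t = \delta_{t+1}\,\mathrm{diag}(\sigma'(a_{t+1}))\,W_t$ with a finite-difference gradient of $\ell_t$ at $o_t$. The sketch below is an illustration only; the sigmoid $\sigma$, the dimensions, and the random data are assumptions made for the example.

```python
import numpy as np

k_t, k_next = 3, 2
rng = np.random.default_rng(2)

sigma = lambda z: 1.0 / (1.0 + np.exp(-z))
dsigma = lambda z: sigma(z) * (1.0 - sigma(z))

W_t = rng.normal(size=(k_next, k_t))
o_t = rng.normal(size=k_t)
y = rng.normal(size=k_next)

# ell_t(u) = ell_{t+1}(sigma(W_t u)) with ell_{t+1}(v) = 0.5 * ||v - y||^2  (t + 1 = T)
ell_t = lambda u: 0.5 * np.sum((sigma(W_t @ u) - y) ** 2)

a_next = W_t @ o_t                       # a_{t+1}
o_next = sigma(a_next)                   # o_{t+1}
delta_next = o_next - y                  # delta_T = o_T - y for the squared loss

# Backward recursion: delta_t = delta_{t+1} diag(sigma'(a_{t+1})) W_t
delta_t = delta_next @ np.diag(dsigma(a_next)) @ W_t

# Central finite-difference gradient of ell_t at o_t
eps = 1e-6
num = np.array([(ell_t(o_t + eps * e) - ell_t(o_t - eps * e)) / (2 * eps)
                for e in np.eye(k_t)])

assert np.allclose(delta_t, num, atol=1e-6)
```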

In summary, we can first calculate the vectors $\{a_t, o_t\}$ from the bottom of the network to its top. Then, we calculate the vectors $\{\delta_t\}$ from the top of the network back to its bottom. Once we have all of these vectors, the partial derivatives are easily obtained using Equation (20.3). We have thus shown that the pseudocode of backpropagation indeed calculates the gradient.
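Putting the pieces together, the following sketch (my own illustration of the procedure just described, not the book's pseudocode) runs a forward pass to collect $\{a_t, o_t\}$, a backward pass to collect $\{\delta_t\}$, and reads off the partial derivatives from Equation (20.3), with a finite-difference check on one weight entry. The sigmoid activation, layer widths, and random data are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
sigma = lambda z: 1.0 / (1.0 + np.exp(-z))
dsigma = lambda z: sigma(z) * (1.0 - sigma(z))

# Layer widths k_0, ..., k_T and weight matrices W_0, ..., W_{T-1}
k = [4, 5, 3, 2]                                  # arbitrary architecture (T = 3)
W = [rng.normal(size=(k[t + 1], k[t])) for t in range(len(k) - 1)]
x = rng.normal(size=k[0])
y = rng.normal(size=k[-1])

def backprop(W, x, y):
    # Forward pass: compute a_t and o_t from the bottom of the network to its top
    o, a = [x], [None]
    for W_t in W:
        a.append(W_t @ o[-1])
        o.append(sigma(a[-1]))
    # Backward pass: delta_T = o_T - y, then delta_t = delta_{t+1} diag(sigma'(a_{t+1})) W_t
    T = len(W)
    delta = [None] * (T + 1)
    delta[T] = o[T] - y
    for t in range(T - 1, 0, -1):
        delta[t] = (delta[t + 1] * dsigma(a[t + 1])) @ W[t]
    # Equation (20.3): the gradient w.r.t. row i of W_t is delta_{t+1,i} sigma'(a_{t+1,i}) o_t
    grads = [np.outer(delta[t + 1] * dsigma(a[t + 1]), o[t]) for t in range(T)]
    return grads

grads = backprop(W, x, y)

# Finite-difference check on one entry of W_0
def loss(W):
    o = x
    for W_t in W:
        o = sigma(W_t @ o)
    return 0.5 * np.sum((o - y) ** 2)

eps = 1e-6
W_plus = [w.copy() for w in W]
W_plus[0][1, 2] += eps
W_minus = [w.copy() for w in W]
W_minus[0][1, 2] -= eps
num = (loss(W_plus) - loss(W_minus)) / (2 * eps)
assert np.isclose(grads[0][1, 2], num, atol=1e-6)
```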