Page 254 - Understanding Machine Learning

Neural Networks

                 same cryptographic assumption, any hypothesis class which contains intersections
                 of halfspaces cannot be learned efficiently.
                    A widely used heuristic for training neural networks relies on the SGD frame-
                 work we studied in Chapter 14. There, we have shown that SGD is a successful
                 learner if the loss function is convex. In neural networks, the loss function is highly
                 nonconvex. Nevertheless, we can still implement the SGD algorithm and hope
                 it will find a reasonable solution (as happens to be the case in several practical
                 tasks).
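To make the point concrete, here is a minimal sketch of SGD applied to a nonconvex objective. The particular function and all parameter values are illustrative, not taken from the text; the noisy gradient stands in for the stochastic estimate computed from a single example.

```python
import random

def sgd(grad, w0, eta, steps, noise=0.1, seed=0):
    """Plain SGD: repeatedly step against a noisy gradient estimate."""
    rng = random.Random(seed)
    w = w0
    for _ in range(steps):
        # Gaussian noise models the stochasticity of a one-example gradient.
        g = grad(w) + rng.gauss(0.0, noise)
        w -= eta * g
    return w

# Illustrative nonconvex objective f(w) = w^4 - 3 w^2 + w,
# with gradient f'(w) = 4 w^3 - 6 w + 1.
grad = lambda w: 4 * w**3 - 6 * w + 1
w = sgd(grad, w0=2.0, eta=0.01, steps=2000)
```

Started from w0 = 2.0, the iterate settles near the local minimum around w ≈ 1.13; a different start (or different noise) can land in a different basin, which is exactly why SGD on nonconvex objectives is only a heuristic.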




                 20.6 SGD AND BACKPROPAGATION
The problem of finding a hypothesis in H_{V,E,σ} with a low risk amounts to the problem of tuning the weights over the edges. In this section we show how to apply a heuristic search for good weights using the SGD algorithm. Throughout this section we assume that σ is the sigmoid function, σ(a) = 1/(1 + e^{−a}), but the derivation holds for any differentiable scalar function.
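A small sketch of the sigmoid and its derivative; the identity σ'(a) = σ(a)(1 − σ(a)) is the standard fact that backpropagation uses at each neuron.

```python
import math

def sigmoid(a):
    """Logistic sigmoid: sigma(a) = 1 / (1 + exp(-a))."""
    return 1.0 / (1.0 + math.exp(-a))

def sigmoid_prime(a):
    """Derivative sigma'(a) = sigma(a) * (1 - sigma(a))."""
    s = sigmoid(a)
    return s * (1.0 - s)
```

The closed-form derivative is one reason the sigmoid is convenient here: once σ(a) has been computed in the forward pass, σ'(a) costs a single extra multiplication.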
Since E is a finite set, we can think of the weight function as a vector w ∈ R^{|E|}. Suppose the network has n input neurons and k output neurons, and denote by h_w : R^n → R^k the function calculated by the network if the weight function is defined by w. Let us denote by ℓ(h_w(x), y) the loss of predicting h_w(x) when the target is y ∈ Y. For concreteness, we will take ℓ to be the squared loss, ℓ(h_w(x), y) = (1/2) ‖h_w(x) − y‖²; however, a similar derivation can be obtained for every differentiable function. Finally, given a distribution D over the examples domain, R^n × R^k, let L_D(w) be the risk of the network, namely,

    L_D(w) = E_{(x,y)∼D} [ℓ(h_w(x), y)].
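The definitions above can be sketched in code. The tiny "network" below is a hypothetical stand-in (a single sigmoid neuron, so k = 1) used only to make the squared loss and the empirical estimate of L_D(w) concrete; averaging the loss over a sample approximates the expectation over D.

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def h_w(w, x):
    """Hypothetical network with n inputs and k = 1 output:
    h_w(x) = sigma(<w, x>).  Stands in for a full layered network."""
    return [sigmoid(sum(wi * xi for wi, xi in zip(w, x)))]

def sq_loss(pred, y):
    """Squared loss: l(h_w(x), y) = (1/2) * ||h_w(x) - y||^2."""
    return 0.5 * sum((p, t)[0] ** 2 for p, t in [(pi - ti, 0) for pi, ti in zip(pred, y)])

def empirical_risk(w, sample):
    """Average loss over a finite sample approximates L_D(w)."""
    return sum(sq_loss(h_w(w, x), y) for x, y in sample) / len(sample)

# Illustrative sample with n = 2 inputs and k = 1 output.
sample = [([1.0, 0.0], [1.0]), ([0.0, 1.0], [0.0])]
risk = empirical_risk([2.0, -2.0], sample)
```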

                    Recall the SGD algorithm for minimizing the risk function L D (w). We repeat
                 the pseudocode from Chapter 14 with a few modifications, which are relevant to the
                 neural network application because of the nonconvexity of the objective function.
                 First, while in Chapter 14 we initialized w to be the zero vector, here we initialize w
                 to be a randomly chosen vector with values close to zero. This is because an initial-
                 ization with the zero vector will lead all hidden neurons to have the same weights
                 (if the network is a full layered network). In addition, the hope is that if we repeat
                 the SGD procedure several times, where each time we initialize the process with
                 a new random vector, one of the runs will lead to a good local minimum. Second,
while a fixed step size, η, is guaranteed to be good enough for convex problems, here we utilize a variable step size, η_t, as defined in Section 14.4.2. Because of the nonconvexity of the loss function, the choice of the sequence η_t is more significant, and in practice it is tuned by trial and error. Third, we output the best performing vector on a validation set. In addition, it is sometimes helpful to add regularization on the weights, with parameter λ. That is, we try to minimize L_D(w) + (λ/2) ‖w‖². Finally, the gradient does not have a closed form solution. Instead, it is implemented using the backpropagation algorithm, which will be described in the sequel.
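The modifications above can be sketched as follows. This is an illustrative toy, not the book's pseudocode: the network is a single sigmoid neuron so that its gradient has a simple closed form via the chain rule (in a real network this step is replaced by backpropagation), and the schedule η_t = η₀/√t is one common choice of decaying step size.

```python
import math
import random

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def predict(w, x):
    # Toy "network": a single sigmoid neuron, h_w(x) = sigma(<w, x>).
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

def loss(w, x, y):
    return 0.5 * (predict(w, x) - y) ** 2

def grad(w, x, y):
    # Chain rule for the single neuron with squared loss:
    # d loss / d w_j = (h - y) * h * (1 - h) * x_j.
    # A full network would compute this via backpropagation instead.
    h = predict(w, x)
    return [(h - y) * h * (1 - h) * xj for xj in x]

def sgd_train(train, val, dim, eta0=1.0, lam=0.001, epochs=200,
              restarts=3, seed=0):
    """Illustrative SGD with the modifications described in the text."""
    rng = random.Random(seed)
    best_w, best_err = None, float("inf")
    for _ in range(restarts):
        # 1. Initialize near, but not exactly at, zero: a zero vector
        #    would make all hidden neurons identical in a full network.
        w = [rng.uniform(-0.1, 0.1) for _ in range(dim)]
        t = 0
        for _ in range(epochs):
            for x, y in train:
                t += 1
                # 2. Decaying step size eta_t (one common schedule).
                eta = eta0 / math.sqrt(t)
                g = grad(w, x, y)
                # 3. Regularization (lam/2)||w||^2 adds lam * w to the gradient.
                w = [wi - eta * (gi + lam * wi) for wi, gi in zip(w, g)]
        # 4. Keep the run that performs best on the validation set.
        err = sum(loss(w, x, y) for x, y in val) / len(val)
        if err < best_err:
            best_w, best_err = w, err
    return best_w

train = [([1.0, 0.0], 1.0), ([0.0, 1.0], 0.0)]
w = sgd_train(train, val=train, dim=2)
```

Here the validation set reuses the training examples only to keep the sketch short; in practice it would be held-out data.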