diff --git a/train-cifar.lua b/train-cifar.lua index e114829..d3b026e 100644 --- a/train-cifar.lua +++ b/train-cifar.lua @@ -178,7 +178,9 @@ function forwardBackwardBatch(batch) --]] -- From https://github.com/bgshih/cifar.torch/blob/master/train.lua#L119-L128 - if sgdState.epochCounter < 80 then + if sgdState.nEvalCounter < 400 then + sgdState.learningRate = 0.01 + elseif sgdState.epochCounter < 80 then sgdState.learningRate = 0.1 elseif sgdState.epochCounter < 120 then sgdState.learningRate = 0.01