diff --git a/train-cifar.lua b/train-cifar.lua index fe4a889..3deb4a3 100644 --- a/train-cifar.lua +++ b/train-cifar.lua @@ -184,7 +184,10 @@ function forwardBackwardBatch(batch) --]] -- From https://github.com/bgshih/cifar.torch/blob/master/train.lua#L119-L128 - if sgdState.epochCounter < 80 then + if sgdState.nEvalCounter < 400 then + sgdState.learningRate = 0.01 + print "Warmup" + elseif sgdState.epochCounter < 80 then sgdState.learningRate = 0.1 elseif sgdState.epochCounter < 120 then sgdState.learningRate = 0.01