I'm learning LSTM networks and decided to try a synthetic test. I want an LSTM network, fed with points (x,y), to distinguish between three basic function classes:
- line: y = k*x + b
- parabola: y = k*x^2 + b
- sqrt: y = k*sqrt(x) + b
I'm using Lua + Torch.
The dataset is entirely virtual: it is created on the fly by the 'dataset' object. When the training loop asks for another minibatch of samples, the mt.__index metamethod builds one dynamically: it randomly selects one of the three functions above and samples some random points from it.
The idea is that the LSTM network would learn features that let it recognize which kind of function the last points belong to.
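To illustrate, a single sample returned by the dataset has the following structure (the values below are invented; only the shape matters):
-- one sample for class 2 (parabola), with batchSize = 8:
-- inputs  = { torch.Tensor{6.1, 40.2}, torch.Tensor{9.7, 89.1}, ... }  -- 8 (x,y) points
-- targets = { 2, 2, 2, 2, 2, 2, 2, 2 }                                 -- one label per point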
The full (yet simple) source script is included below:
require "torch"
require "nn"
require "rnn"
-- hyper-parameters
batchSize = 8
rho = 5 -- maximum number of BPTT steps for the LSTM
hiddenSize = 100
outputSize = 3
lr = 0.001
-- Initialize the synthetic dataset.
-- dataset[index] returns a table of the form {inputs, targets},
-- where inputs is a set of points (x,y) from a randomly selected function (line, parabola, or sqrt)
-- and targets is the corresponding function class for each point (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.size = function (self)
  return 1000
end
local mt = {}
mt.__index = function (self, i)
  local class = math.random(3)
  local t = torch.Tensor(3):zero() -- one-hot encoding of the class (currently unused)
  t[class] = 1
  local targets = {}
  for i = 1, batchSize do table.insert(targets, class) end
  local inputs = {}
  local k = math.random()
  local b = math.random()*5
  -- Line
  if class == 1 then
    for i = 1, batchSize do
      local x = math.random()*10 + 5
      local y = k*x + b
      local input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end
  -- Parabola
  elseif class == 2 then
    for i = 1, batchSize do
      local x = math.random()*10 + 5
      local y = k*x*x + b
      local input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end
  -- Sqrt
  else
    for i = 1, batchSize do
      local x = math.random()*5 + 5
      local y = k*math.sqrt(x) + b
      local input = torch.Tensor(2)
      input[1] = x
      input[2] = y
      table.insert(inputs, input)
    end
  end
  return { inputs, targets }
end -- dataset.__index metamethod
setmetatable(dataset, mt)
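-- (Illustrative sanity check, not part of the original training flow;
-- uncomment to inspect one generated sample.)
-- local sample = dataset[1]
-- print(#sample[1])    -- number of (x,y) points in the sequence (= batchSize)
-- print(sample[2][1])  -- the class label shared by all points of the sample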
-- Initialize random number generator
math.randomseed( os.time() )
-- build a simple recurrent neural network
local model = nn.Sequencer(
  nn.Sequential()
    :add( nn.LSTM(2, hiddenSize, rho) )
    :add( nn.Linear(hiddenSize, outputSize) )
    :add( nn.LogSoftMax() )
)
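-- Note: nn.Sequencer expects its input as a table of time steps, so each
-- sample here is treated as a sequence of batchSize (x,y) tensors.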
print(model)
-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
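-- SequencerCriterion applies the wrapped ClassNLLCriterion to the output of
-- each time step and accumulates the per-step errors.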
-- training
model:training()
local epoch = 1
while true do
  print("Epoch "..tostring(epoch).." started")
  for iteration = 1, dataset:size() do
    -- 1. Load a minibatch of samples
    local sample = dataset[iteration] -- the dataset always returns a freshly generated random sample
    local inputs = sample[1]
    local targets = sample[2]
    -- 2. Forward pass: compute the outputs and the error
    local outputs = model:forward(inputs)
    local err = criterion:forward(outputs, targets)
    print(string.format("Epoch %d Iteration %d Error = %f", epoch, iteration, err))
    -- 3. Backward pass through the model (i.e. backprop through time)
    local gradOutputs = criterion:backward(outputs, targets)
    -- the Sequencer handles backpropagation through time internally
    model:backward(inputs, gradOutputs)
    model:updateParameters(lr)
    model:zeroGradParameters()
  end -- for dataset
  epoch = epoch + 1
end -- while epoch
The problem is that the network does not converge. Could you share any ideas about what I'm doing wrong?