1

我训练了一个具有 1000 次迭代的网络,并且希望在不从头开始的情况下继续这种训练直到 2000 次迭代。我阅读了针对这个问题的不同方法并编写了下面的代码,所以最后我的参数位于“saved_pa​​rams”中。但是从现在开始,我不明白我必须用这些参数做什么。

有人可以解释我该怎么做吗?如何将这些参数用于我的训练过程?

from __future__ import print_function
import numpy as np
import theano
import lasagne
import pickle


input_var=None
ini = lasagne.init.HeUniform()

l_in = lasagne.layers.InputLayer(shape=(None, 1, 120, 120), input_var=input_var)
b= np.zeros((1, 4), dtype=theano.config.floatX)
b = b.flatten()

loc_l1 = lasagne.layers.MaxPool2DLayer(l_in, pool_size=(2, 2))
loc_l2 = lasagne.layers.Conv2DLayer(loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
loc_l3 = lasagne.layers.MaxPool2DLayer(loc_l2, pool_size=(2, 2))
loc_l4 = lasagne.layers.Conv2DLayer(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
loc_l5 = lasagne.layers.DenseLayer(loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
network = lasagne.layers.DenseLayer(loc_l5, num_units=4, b=b, W=lasagne.init.Constant(0.0), nonlinearity=lasagne.nonlinearities.identity)


def save_network(filename,param_values):
    f = open(filename, 'wb')
    pickle.dump(param_values,f,protocol=-1)
    f.close()

def load_network(filename):
    f = open(filename, 'rb')
    param_values = pickle.load(f)
    f.close()
    return param_values


save_network("model.npz",lasagne.layers.get_all_param_values(network))

saved_params = load_network("model.npz")
lasagne.layers.set_all_param_values(network, saved_params)
4

3 回答 3

0

此代码只是一个示例。它正在做以下事情: 1. 从上次训练中加载训练的权重 2. 使用相同的测试训练数据(否则你在测试数据上训练) 3. 启动网络的拟合方法(net_loaded.fit(parameters))正在使用模型的加载权重

要从这个级联中得到一个图表,你必须保存你的准确性值在 epoch 图表上或者你用来可视化组合结果的任何东西。

于 2017-11-03T06:00:53.767 回答
0

您可以只使用 load 然后调用 fit 方法还是更改了参数?如果你想要一个图表,那么将你的错误保存为 1000 个时期

于 2017-10-25T16:21:23.837 回答
0
if(load):
        net1 = Lenet(classes, num_epochs)
        net1.load_weights_from('Lenet.npz')
        network = net1
        train_X = np.float32(train_X)
        print("train_x",train_X)
        print("train_y",train_Y)
        train_Y = np.int16(train_Y)
        network = net1.fit(train_X, train_Y, num_epochs)
        print ("Loading weights successfully done.")
于 2017-10-25T16:24:41.527 回答