这是我自己用matlab语言实现的梯度下降算法
m = height(data_training); % number of samples
cols = {'x1', 'x2', 'x3', 'x4', 'x5', 'x6',...
'x7', 'x8','x9', 'x10', 'x11', 'x12', 'x13', 'x14', 'x15'};
y = data_training{:, {'y'}}';
X = [ones(m,1) data_training{:,cols}]';
theta = zeros(1,width(data_training));
alpha = 1e-2; % learning rate
iter = 400;
dJ = zeros(1,width(data_training));
J_seq = zeros(1, iter);
for n = 1:iter
err = (theta*X - y);
for j = 1:width(data_training)
dJ(j) = 1/m*sum(err*X(j,:)');
end
J = 1/(2*m)*sum((theta*X-y).^2);
theta = theta - alpha.*dJ;
J_seq(n) = J;
if mod(n,100) == 0
plot(1:iter, J_seq);
end
end
编辑 工作算法
我已将此算法应用于以下训练数据集。最后一列是输出变量。在这里,我们有 15 个不同的功能。
由于我不知道的原因,当我在 50 次迭代后绘制成本函数 J 以检查它是否趋于收敛时,我发现它没有收敛。你能帮我理解吗?是实施错误还是我应该做点什么?
36 27 71 8.1 3.34 11.4 81.5 3243 8.8 42.6 11.7 21 15 59 59 921.87
35 23 72 11.1 3.14 11 78.8 4281 3.6 50.7 14.4 8 10 39 57 997.88
44 29 74 10.4 3.21 9.8 81.6 4260 0.8 39.4 12.4 6 6 33 54 962.35
47 45 79 6.5 3.41 11.1 77.5 3125 27.1 50.2 20.6 18 8 24 56 982.29
43 35 77 7.6 3.44 9.6 84.6 6441 24.4 43.7 14.3 43 38 206 55 1071.3
53 45 80 7.7 3.45 10.2 66.8 3325 38.5 43.1 25.5 30 32 72 54 1030.4
43 30 74 10.9 3.23 12.1 83.9 4679 3.5 49.2 11.3 21 32 62 56 934.7
45 30 73 9.3 3.29 10.6 86 2140 5.3 40.4 10.5 6 4 4 56 899.53
36 24 70 9 3.31 10.5 83.2 6582 8.1 42.5 12.6 18 12 37 61 1001.9
36 27 72 9.5 3.36 10.7 79.3 4213 6.7 41 13.2 12 7 20 59 912.35
52 42 79 7.7 3.39 9.6 69.2 2302 22.2 41.3 24.2 18 8 27 56 1017.6
33 26 76 8.6 3.2 10.9 83.4 6122 16.3 44.9 10.7 88 63 278 58 1024.9
40 34 77 9.2 3.21 10.2 77 4101 13 45.7 15.1 26 26 146 57 970.47
35 28 71 8.8 3.29 11.1 86.3 3042 14.7 44.6 11.4 31 21 64 60 985.95
37 31 75 8 3.26 11.9 78.4 4259 13.1 49.6 13.9 23 9 15 58 958.84
35 46 85 7.1 3.22 11.8 79.9 1441 14.8 51.2 16.1 1 1 1 54 860.1
36 30 75 7.5 3.35 11.4 81.9 4029 12.4 44 12 6 4 16 58 936.23
15 30 73 8.2 3.15 12.2 84.2 4824 4.7 53.1 12.7 17 8 28 38 871.77
31 27 74 7.2 3.44 10.8 87 4834 15.8 43.5 13.6 52 35 124 59 959.22
30 24 72 6.5 3.53 10.8 79.5 3694 13.1 33.8 12.4 11 4 11 61 941.18
31 45 85 7.3 3.22 11.4 80.7 1844 11.5 48.1 18.5 1 1 1 53 891.71
31 24 72 9 3.37 10.9 82.8 3226 5.1 45.2 12.3 5 3 10 61 871.34
42 40 77 6.1 3.45 10.4 71.8 2269 22.7 41.4 19.5 8 3 5 53 971.12
43 27 72 9 3.25 11.5 87.1 2909 7.2 51.6 9.5 7 3 10 56 887.47
46 55 84 5.6 3.35 11.4 79.7 2647 21 46.9 17.9 6 5 1 59 952.53
39 29 76 8.7 3.23 11.4 78.6 4412 15.6 46.6 13.2 13 7 33 60 968.66
35 31 81 9.2 3.1 12 78.3 3262 12.6 48.6 13.9 7 4 4 55 919.73
43 32 74 10.1 3.38 9.5 79.2 3214 2.9 43.7 12 11 7 32 54 844.05
11 53 68 9.2 2.99 12.1 90.6 4700 7.8 48.9 12.3 648 319 130 47 861.83
30 35 71 8.3 3.37 9.9 77.4 4474 13.1 42.6 17.7 38 37 193 57 989.26
50 42 82 7.3 3.49 10.4 72.5 3497 36.7 43.3 26.4 15 10 34 59 1006.5
60 67 82 10 2.98 11.5 88.6 4657 13.6 47.3 22.4 3 1 1 60 861.44
30 20 69 8.8 3.26 11.1 85.4 2934 5.8 44 9.4 33 23 125 64 929.15
25 12 73 9.2 3.28 12.1 83.1 2095 2 51.9 9.8 20 11 26 50 857.62
45 40 80 8.3 3.32 10.1 70.3 2682 21 46.1 24.1 17 14 78 56 961.01
46 30 72 10.2 3.16 11.3 83.2 3327 8.8 45.3 12.2 4 3 8 58 923.23