我有一个数据集——X
由 15 个变量和 64 个观察值和一个列向量——Y
代表目标(标签)的 64 个值组成。我尝试Y
使用 PyTorch 将参数拟合到二次函数以返回观察值 ( ),但出现错误。我在文章末尾以 json 格式提供了数据集,以实现可重复性。
如果我有一个示例,我的代码可能是:
X = torch.from_numpy(X)
X.requires_grad = True
W = np.random.randn(15,15)
W = np.triu(W, k=0)
W = torch.from_numpy(W)
W.requires_grad = True
# define parameters for gradient descent
max_iter=100
lr_rate = 1e-3
# we will do gradient descent for max_iter iteration
for i in range(max_iter):
# compute the loss
loss = Y - (X@torch.transpose(X, 1,0) * W).sum()
# use torch.autograd.grad to compute the gradient
W = W - lr_rate*torch.autograd.grad(out, W)[0]
print(f"{i}: {out}")
您能否提供一个使用我在下面提供的数据以矢量化方式实现所述目标(将参数拟合到数据)的正确实现示例?
数据: X:
'{"embed_item_dim":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":-1.0,"13":1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":-1.0,"51":1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"embed_category_dim":{"0":-1.0,"1":-1.0,"2":1.0,"3":1.0,"4":-1.0,"5":-1.0,"6":1.0,"7":1.0,"8":-1.0,"9":-1.0,"10":1.0,"11":1.0,"12":-1.0,"13":-1.0,"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":1.0,"19":1.0,"20":-1.0,"21":-1.0,"22":1.0,"23":1.0,"24":-1.0,"25":-1.0,"26":1.0,"27":1.0,"28":-1.0,"29":-1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":1.0,"35":1.0,"36":-1.0,"37":-1.0,"38":1.0,"39":1.0,"40":-1.0,"41":-1.0,"42":1.0,"43":1.0,"44":-1.0,"45":-1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":1.0,"51":1.0,"52":-1.0,"53":-1.0,"54":1.0,"55":1.0,"56":-1.0,"57":-1.0,"58":1.0,"59":1.0,"60":-1.0,"61":-1.0,"62":1.0,"63":1.0},"embed_shop_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":1.0,"5":1.0,"6":1.0,"7":1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":1.0,"13":1.0,"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":1.0,"21":1.0,"22":1.0,"23":1.0,"24":-1.0,"25":-1.0,"26":-1.0,"27":-1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":1.0,"37":1.0,"38":1.0,"39":1.0,"40":-1.0,"41":-1.0,"42":-1.0,"43":-1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":-1.0,"51":-1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":-1.0,"57":-1.0,"58":-1.0,"59":-1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"categorical_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":1.0,"9":1.0,"10":1.0,"11":1.0,"12":1.0,"13":1.0,"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":-1.0,"21":-1.0,"22":-1.0,"23":-1.0,"24":1.0,"25":1.0,"26":1.0,"27":1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":-1.0,"37":-1.0,"38":-1.0,"39":-1.0,"40":1.0,"41":1.0,"42":1.0,"43":1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":-1.0,"51":-1.0,"52":-1.0,"53":-1.0,"54":-1.0,"55":-1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"categorical_dropout":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":-1.0,"13":-1.0,"14":-1.0,"15":-1.0,"16":1.0,"17":1.0,"18":1.0,"19":1.0,"20":1.0,"21":1.0,"22":1.0,"23":1.0,"24":1.0,"25":1.0,"26":1.0,"27":1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":-1.0,"37":-1.0,"38":-1.0,"39":-1.0,"40":-1.0,"41":-1.0,"42":-1.0,"43":-1.0,"44":-1.0,"45":-1.0,"46":-1.0,"47":-1.0,"48":1.0,"49":1.0,"50":1.0,"51":1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"numerical_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":-1.0,"13":-1.0,"14":-1.0,"15":-1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":-1.0,"21":-1.0,"22":-1.0,"23":-1.0,"24":-1.0,"25":-1.0,"26":-1.0,"27":-1.0,"28":-1.0,"29":-1.0,"30":-1.0,"31":-1.0,"32":1.0,"33":1.0,"34":1.0,"35":1.0,"36":1.0,"37":1.0,"38":1.0,"39":1.0,"40":1.0,"41":1.0,"42":1.0,"43":1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":1.0,"49":1.0,"50":1.0,"51":1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"numerical_dropout":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":1.0,"5":-1.0,"6":-1.0,"7":1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":1.0,"21":-1.0,"22":-1.0,"23":1.0,"24":-1.0,"25":1.0,"26":1.0,"27":-1.0,"28":1.0,"29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":1.0,"37":-1.0,"38":-1.0,"39":1.0,"40":-1.0,"41":1.0,"42":1.0,"43":-1.0,"44":1.0,"45":-1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":-1.0,"57":1.0,"58":1.0,"59":-1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dim1":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":1.0,"9":-1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":1.0,"23":-1.0,"24":1.0,"25":-1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":1.0,"39":-1.0,"40":1.0,"41":-1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":1.0,"55":-1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dropout1":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dim2":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dropout2":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"mixed_dim3":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"mixed_dropout3":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"last_layer_dim":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":1.0,"19":-1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":1.0,"45":-1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"last_layer_dropout":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":1.0,"29":-1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":1.0,"35":-1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0}}'
是
'{"0":2.0561221309,"1":2.0649733606,"2":2.0733728925,"3":2.0594125771,"4":2.0949032045,"5":2.0294939058,"6":2.0436441327,"7":2.1209041954,"8":2.0496001055,"9":2.148755921,"10":2.0937250525,"11":2.0629058135,"12":2.0641746866,"13":2.0592979107,"14":2.1166172412,"15":2.1125198086,"16":2.0525522671,"17":2.0687485594,"18":2.0649582587,"19":2.0818384718,"20":2.0839422046,"21":2.043783441,"22":2.05290516,"23":2.0565277924,"24":2.0550897444,"25":2.0663609971,"26":2.0895415003,"27":2.0706054531,"28":2.0639581304,"29":2.0889003421,"30":2.0436977626,"31":2.1350170653,"32":2.0395688425,"33":2.079368626,"34":2.0439947954,"35":2.072433023,"36":2.050665861,"37":2.037977855,"38":2.0527567514,"39":2.050903715,"40":2.0381965719,"41":2.0673631206,"42":2.085004701,"43":2.0458497661,"44":2.0540644062,"45":2.050330556,"46":2.0859451303,"47":2.0323004844,"48":2.05113558,"49":2.046360857,"50":2.0572361143,"51":2.0659940765,"52":2.0583657215,"53":2.0520969623,"54":2.0683284923,"55":2.0491708591,"56":2.0932832342,"57":2.0416396082,"58":2.0703974941,"59":2.0464359665,"60":2.0591405783,"61":2.0527808995,"62":2.0670555565,"63":2.0898413706}'