当您执行require "nn"
此加载时init.lua
,它会依次执行require('libnn')
. 这是torch/nn 的C 扩展。
如果你看init.c
你会发现:这是-edluaopen_libnn
时调用的初始化函数。libnn.so
require
该函数负责初始化 torch/nn 的所有部分,包括MSECriterion
viann_FloatMSECriterion_init(L)
和的原生部分nn_DoubleMSECriterion_init(L)
。
如果你看一下,generic/MSECriterion.c
你会发现通用(即为float
and扩展的宏double
)初始化函数:
static void nn_(MSECriterion_init)(lua_State *L)
{
luaT_pushmetatable(L, torch_Tensor);
luaT_registeratname(L, nn_(MSECriterion__), "nn");
lua_pop(L,1);
}
这个 init 函数修改了 any 的元表,torch.FloatTensor
因此它在keytorch.DoubleTensor
下填充了一堆函数(有关更多详细信息,请参阅Torch7 Lua C API)。这些函数是在之前定义的:nn
static const struct luaL_Reg nn_(MSECriterion__) [] = {
{"MSECriterion_updateOutput", nn_(MSECriterion_updateOutput)},
{"MSECriterion_updateGradInput", nn_(MSECriterion_updateGradInput)},
{NULL, NULL}
};
换句话说,任何张量都具有这些功能,这要归功于它的元表:
luajit -lnn
> print(torch.Tensor().nn.MSECriterion_updateOutput)
function: 0x40921df8
> print(torch.Tensor().nn.MSECriterion_updateGradInput)
function: 0x40921e20
注意:对于所有具有 C 本机实现对应的 torch/nn 模块,此机制都是相同的。
正如您在generic/MSECriterion.c上看到的那样input.nn.MSECriterion_updateOutput(self, input, target)
,调用效果也是如此。static int nn_(MSECriterion_updateOutput)(lua_State *L)
此函数计算输入张量之间的均方误差。