17

我最近在阅读的一些 Lua 源文件中经常看到这种类型的语法,这是什么意思,尤其是第二对括号一个示例, https://github.com/karpathy/char-rnn/blob中的第 8 行/master/model/LSTM.lua

local LSTM = {}
function LSTM.lstm(input_size, rnn_size, n, dropout)
  dropout = dropout or 0 

  -- there will be 2*n+1 inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()())  -- line 8
  -- ...

https://github.com/torch/nn/blob/master/Identity.lua的源代码nn.Identity

********** 更新 **************

()() 模式在火炬库 'nn' 中被大量使用。第一对括号创建容器/节点的对象,第二对括号引用依赖节点。

例如,y = nn.Linear(2,4)(x) 表示 x 连接到 y,从 1*2 到 1*4 的变换是线性的。我只是了解用法,它的接线方式似乎可以通过以下答案之一来回答。

无论如何,接口的使用在下面有很好的记录。 https://github.com/torch/nngraph/blob/master/README.md

4

3 回答 3

14

不,()()在 Lua 中没有特殊含义,只是两个调用操作符()在一起。

操作数可能是一个返回函数的函数(或实现call元方法的表)。例如:

function foo()
  return function() print(42) end
end

foo()()   -- 42
于 2015-06-22T15:14:32.377 回答
13

作为对余浩的回答的补充,让我给出一些与 Torch 相关的精度:

  • nn.Identity()创建一个身份模块,
  • ()调用此模块触发器nn.Module __call__(感谢 Torch 类系统将其正确连接到元表中),
  • 默认情况下,此__call__方法执行向前/向后,
  • 但是这里使用了torch/nngraph并且nngraph 覆盖了此方法,如您在此处看到的。

因此,每个nn.Identity()()调用都在这里返回一个nngraph.Node({module=self})节点,其中 self 指的是当前nn.Identity()实例。

--

更新:可以在此处找到LSTM-s上下文中此语法的说明:

local i2h = nn.Linear(input_size, 4 * rnn_size)(input)  -- input to hidden

如果您不熟悉nngraph,我们正在构建一个模块并且已经使用图形节点再次调用它可能看起来很奇怪。实际发生的情况是第二个调用将nn.Moduleto转换为nngraph.gModule并且参数指定它是 graph 中的父级

于 2015-06-24T10:29:18.583 回答
2
  • 第一个()调用init函数,第二个()调用call函数
  • 如果该类不具有这些函数中的任何一个,则调用父函数。
  • 在 nn.Identity()() 的情况下,nn.Identity 既没有 init 函数也没有调用函数,因此 Identity 父级 nn.Module 的 init 和 call 函数称为。附上插图

    require 'torch'
    
    -- define some dummy A class
    local A = torch.class('A')
    function A:__init(stuff)
      self.stuff = stuff
      print('inside __init of A')
    end
    
    function A:__call__(arg1)
    print('inside __call__ of A')
    end
    
    -- define some dummy B class, inheriting from A
    local B,parent = torch.class('B', 'A')
    
    function B:__init(stuff)
      self.stuff = stuff
      print('inside __init of B')
    end
    
    function B:__call__(arg1)
    print('inside __call__ of B')
    end
    a=A()()
    b=B()()
    

    输出

    inside __init of A
    inside __call__ of A
    inside __init of B
    inside __call__ of B
    

另一个代码示例

    require 'torch'

    -- define some dummy A class
    local A = torch.class('A')
    function A:__init(stuff)
      self.stuff = stuff
      print('inside __init of A')
    end

    function A:__call__(arg1)
    print('inside __call__ of A')
    end

    -- define some dummy B class, inheriting from A
    local B,parent = torch.class('B', 'A')

    b=B()()

输出

    inside __init of A
    inside __call__ of A
于 2016-05-18T15:56:16.840 回答