0

我正在尝试将一个组件的输出的最后一个元素连接到另一个组件的输入。一个例子如下所示:

import numpy as np
from openmdao.api import Component, Problem, Group

class C1(Component):
    def __init__(self):
        super(C1, self).__init__()
        self.add_param('fin', val=1.0)
        self.add_output('arr', val=np.zeros(5))

    def solve_nonlinear(self, params, unknowns, resids):
        fin = params['fin']
        unknowns['arr'] = np.array([2*fin])


class C2(Component):
    def __init__(self):
        super(C2, self).__init__()
        self.add_param('flt', val=0.0)
        self.add_output('fout', val=0.0)

    def solve_nonlinear(self, params, unknowns, resids):
        flt = params['flt']
        unknowns['fout'] = 2*flt

class A(Group):
    def __init__(self):
        super(A, self).__init__()

        self.add('c1', C1())
        self.add('c2', C2())

        self.connect('c1.arr[-1]', 'c2.flt')

if __name__ == '__main__':

    a = Problem()
    a.root = A()
    a.setup()

    a.run()
    print a.root.c2.unknowns['fout']

我收到错误:

openmdao.core.checks.ConnectError: Source 'c1.arr[-1]' cannot be connected to target 'c2.flt': 'c1.arr[-1]' does not exist.

有没有办法做到这一点?我知道它适用于旧版本的 OpenMDAO。

4

2 回答 2

3

OpenMDAO 支持通过使用 'src_indices' 参数连接到源的特定索引。例如:

self.connect('c1.arr', 'c2.flt', src_indices=[4])

当前不支持负索引。

于 2015-11-30T20:29:00.627 回答
1

这里有很多小问题。首先,C1 的 solve_nonlinear 方法的数组计算大小错误。它最终解决了,但你真的应该将数组设置为正确的大小(长度 5)。

对于数组的一部分(请参阅文档更高级的文档),您指定 src_indices 参数来连接。

import numpy as np
from openmdao.api import Component, Problem, Group

class C1(Component):
    def __init__(self):
        super(C1, self).__init__()
        self.add_param('fin', val=1.0)
        self.add_output('arr', val=np.zeros(5))

    def solve_nonlinear(self, params, unknowns, resids):
        fin = params['fin']
        unknowns['arr'] = fin*np.arange(5)

class C2(Component):
    def __init__(self):
        super(C2, self).__init__()
        self.add_param('flt', val=0.0)
        self.add_output('fout', val=0.0)

    def solve_nonlinear(self, params, unknowns, resids):
        flt = params['flt']
        unknowns['fout'] = 2*flt

class A(Group):
    def __init__(self):
        super(A, self).__init__()

        self.add('c1', C1())
        self.add('c2', C2())

        self.connect('c1.arr', 'c2.flt', src_indices=[4,])

if __name__ == '__main__':

    a = Problem()
    a.root = A()
    a.setup()

    a.run()
    print a.root.c2.unknowns['fout']
于 2015-11-30T20:36:44.870 回答