0

我有一种情况,其中一个组件的梯度必须在另一个组件中计算。我试图做的只是让梯度成为第一个组件的输出和第二个组件的输入。我已将其设置为 pass_by_obj,以免影响其他计算。任何关于这是否是最好的方法的建议将不胜感激。不过,使用 check_partial_derivatives() 时出现错误。对于指定为 pass_by_obj 的任何输出,这似乎都是一个错误。这是一个简单的案例:

import numpy as np
from openmdao.api import Group, Problem, Component, ScipyGMRES, ExecComp, IndepVarComp

class Comp1(Component):
    def __init__(self):
        super(Comp1, self).__init__()
        self.add_param('x', shape=1)

        self.add_output('y', shape=1)
        self.add_output('dz_dy', shape=1, pass_by_obj=True)

    def solve_nonlinear(self, params, unknowns, resids):

        x = params['x']

        unknowns['y'] = 4.0*x + 1.0
        unknowns['dz_dy'] = 2.0*x

    def linearize(self, params, unknowns, resids):

        J = {}
        J['y', 'x'] = 4.0
        return J

class Comp2(Component):
    def __init__(self):
        super(Comp2, self).__init__()
        self.add_param('y', shape=1)
        self.add_param('dz_dy', shape=1, pass_by_obj=True)

        self.add_output('z', shape=1)

    def solve_nonlinear(self, params, unknowns, resids):
        y = params['y']
        unknowns['z'] = y*2.0

    def linearize(self, params, unknowns, resids):
        J = {}
        J['z', 'y'] = params['dz_dy']
        return J

class TestGroup(Group):
    def __init__(self):
        super(TestGroup, self).__init__()
        self.add('x', IndepVarComp('x', 0.0), promotes=['*'])
        self.add('c1', Comp1(), promotes=['*'])
        self.add('c2', Comp2(), promotes=['*'])

p = Problem()
p.root = TestGroup()
p.setup(check=False)

p['x'] = 2.0

p.run()

print p['z']
print 'gradients'
test_grad = open('partial_gradients_test.txt', 'w')
partial = p.check_partial_derivatives(out_stream=test_grad)

我收到以下错误消息:

partial = p.check_partial_derivatives(out_stream=test_grad)
  File "/usr/local/lib/python2.7/site-packages/openmdao/core/problem.py", line 1699, in check_partial_derivatives
    dresids._dat[u_name].val[idx] = 1.0
TypeError: '_ByObjWrapper' object does not support item assignment

我之前询问过在 check_partial_derivatives() 中检查 pass_by_obj 的参数,这可能只是检查 pass_by_obj 的未知数的问题。

4

1 回答 1

0

您得到的错误是与 check_partial_derivatives 函数相关的另一个错误。它应该很容易修复,但同时您可以删除 pass_by_obj 设置。由于您正在计算一个组件中的值并将其传递给另一个组件,因此根本不需要 pass_by_obj (如果不这样做会更有效)。

你说你这样做是为了“不影响其他计算”,但我不太明白你的意思。除非您在 solve_nonlinear 方法中使用它,否则它不会影响任何事情。

于 2015-12-29T22:08:37.567 回答