我正在对非线性系统网络进行简单的模拟。特别是我有 N 个节点,每个节点由 m 个单元组成。每个单元的输出函数既取决于它的活动,也取决于同一节点中其他单元的活动。
我实现的模拟是在 scipy + jitcode 中。
我实现的第一个版本是根据 softmax 分布,因此我实现了这个简单的函数来计算每个单元的输出。
def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
sum_hc = 0
for unit in node:
sum_hc += symengine.exp(unit * G)
for unit in node:
act.append(symengine.exp(unit * G)/sum_hc)
return act
现在,我想用一个简单的函数替换上面的函数,对于每个节点,为活动度最高的单元输出 1,在其他单元中输出 0。长话短说,对于每个节点,只有一个单元输出 1。
我现在面临的主要问题是如何使用 symengine 执行此操作,以便 jitcode 可以使用它。我在下面实现的功能由于明显的原因不起作用。我猜 if 条件不是很有象征意义。
def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
max_act = symengine.Max(*node)
for unit in node:
if unit >= max_act:
act.append(1)
else:
act.append(0)
return act
我没有找到任何 symengine.argmax() 函数或任何智能替代解决方案。你有什么建议吗?
更新
def max_activation(activities):
act = []
for hc in activities:
sum_hc = 0
max_act = symengine.Max(*hc)
for mc in hc:
act.append(symengine.GreaterThan(mc, max_act))
print(act)
return act
测试这个功能:
max_activation([[y(1), y(2)], [y(3), y(4)]])
我得到以下有希望的输出。一旦我有一些测试,我会更新。
[max(y(2), y(1)) <= y(1), max(y(2), y(1)) <= y(2)]
[max(y(4), y(3)) <= y(3), max(y(4), y(3)) <= y(4)]