我正在尝试实现一个自定义 tfd.Distribution ,它代表一个复杂的状态转换模型(STM)。我需要从抽象方法 tfd.Distribution._sample_n 的实现中返回具有不同维度的 2 个数组的元组。但是,当包装器方法 (tfd.Distribution.sample) 尝试打包这些数组时,我遇到了麻烦。
STM 表征存在于多个互斥状态的群体。随着时间的推移,人口中的个体根据随机过程在状态之间转换。为了表示 STM(即样本)的实现,您最终会得到一个长度为 T 的向量,其中包含转换发生的时间,以及一个形状为 [T, M, N] 的多维数组,其中 T 是时间步数, M 是州的数量,N 是人口中的个体数量。
到目前为止,我有:
class Foo(tfd.Distribution):
def __init__(self):
super().__init__(dtype=tf.float32,
#...other config here
)
def _sample_n(self, n, seed=None):
# Sampling algorithm here
# t.shape = [T]
# y.shape = [T, M, N]
return t, y
foo = Foo()
foo.sample()
期望的结果:调用foo.sample()
应该返回一个 (tf.tensor, tf.tensor) 的元组,其形状分别为 [T] 和 [T, M, N]。
实际的:
ValueError: Shapes must be equal rank, but are 1 and 3
From merging shape 0 with other shapes. for 'MyEpidemic/sample/Shape/packed' (op: 'Pack') with input shapes: [11], [11,3,1000].