一般来说,实现这一目标的唯一方法是使用昂贵的拒绝采样器。然后你就没有易处理的密度。一般来说,TFP 要求我们所有的分布都具有易于处理的密度(即dist.prob(x)
)。我们确实有一个 autodiff 友好的TruncatedNormal
,或者你注意到HalfNormal
的。
如果你想实现其他东西,它可能很简单:
class Rejection(tfd.Distribution):
def __init__(self, underlying, condition, name=None):
self._u = underlying
self._c = condition
super().__init__(dtype=underlying.dtype,
name=name or f'rejection_{underlying}',
reparameterization_type=tfd.NOT_REPARAMETERIZED,
validate_args=underlying.validate_args,
allow_nan_stats=underlying.allow_nan_stats)
def _batch_shape(self):
return self._u.batch_shape
def _batch_shape_tensor(self):
return self._u.batch_shape_tensor()
def _event_shape(self):
return self._u.event_shape
def _event_shape_tensor(self):
return self._u.event_shape_tensor()
def _sample_n(self, n, seed=None):
return tf.while_loop(
lambda samples: not tf.reduce_all(self._c(samples)),
lambda samples: (tf.where(self._c(samples), samples, self._u.sample(n, seed=seed)),),
(self._u.sample(n, seed=seed),))[0]
d = Rejection(tfd.Normal(0,1), lambda x: x > -.3)
s = d.sample(100).numpy()
print(s.min())