I want to implement the following function:
def bernoulli_tensor(shape, probability) -> tf.Tensor:
    ...
Here is one possible implementation:
from typing import List

import tensorflow as tf
import tensorflow_probability as tfp


def bernoulli(shape: List[int], truth_probability: float = 0.5) -> tf.Tensor:
    # Each sampled entry is 1 with probability truth_probability, else 0.
    distribution = tfp.distributions.Bernoulli(probs=truth_probability)
    # Cast the integer 0/1 samples to booleans.
    return tf.cast(distribution.sample(shape), dtype=tf.dtypes.bool)
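
If pulling in tensorflow_probability only for this feels heavy, the same kind of tensor can probably be built with core TensorFlow alone. The sketch below is an alternative to the implementation above, not part of it: it draws uniform values in [0, 1) and compares them against the probability. The name `bernoulli_uniform` is a hypothetical one chosen for illustration.

```python
from typing import Sequence

import tensorflow as tf


def bernoulli_uniform(shape: Sequence[int], truth_probability: float = 0.5) -> tf.Tensor:
    # Hypothetical helper (assumption, not from the original post).
    # Draw independent uniform samples in the half-open interval [0, 1).
    uniform = tf.random.uniform(shape, minval=0.0, maxval=1.0)
    # An entry is True exactly when its sample falls below the probability,
    # which happens with probability truth_probability.
    return uniform < truth_probability
```

Calling `bernoulli_uniform([2, 3], 0.3)` should give a boolean tensor of shape (2, 3) in which each entry is True with probability 0.3. Because the uniform samples never reach 1.0, a probability of 1.0 yields all True and 0.0 yields all False.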