I would like to implement the following function:

def bernoulli_tensor(shape, probability) -> tf.Tensor:
    ...

1 Answer

Here is one possible implementation:

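from typing import List

import tensorflow as tf
import tensorflow_probability as tfp


def bernoulli(shape: List[int], truth_probability: float = 0.5) -> tf.Tensor:
    """Return a boolean tensor of `shape`; each entry is True with probability `truth_probability`."""
    # Bernoulli.sample yields integer 0/1 values by default, so cast them to bool.
    distribution = tfp.distributions.Bernoulli(probs=truth_probability)
    return tf.cast(distribution.sample(shape), dtype=tf.dtypes.bool)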
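
If you would rather not depend on tensorflow_probability, a similar result can probably be had by thresholding uniform samples. The following is only a sketch under that assumption; the name `bernoulli_uniform` is made up for illustration and is not part of either library:

```python
from typing import Sequence

import tensorflow as tf


def bernoulli_uniform(shape: Sequence[int], truth_probability: float = 0.5) -> tf.Tensor:
    """Hypothetical helper: boolean tensor whose entries are True with the given probability."""
    # tf.random.uniform draws float32 samples from the half-open interval [0, 1).
    uniform = tf.random.uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32)
    # A sample falls below p with probability p, so the comparison is a Bernoulli(p) draw.
    return tf.less(uniform, truth_probability)


# Example usage: a 2x3 boolean mask that is True roughly 30% of the time.
mask = bernoulli_uniform([2, 3], truth_probability=0.3)
```

Because the samples lie in [0, 1), a probability of 0.0 never yields True and a probability of 1.0 always does.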
Answered 2020-02-24T19:35:46.827