I'm building a mixture of TensorFlow Probability models. The joint distribution of a single model is:
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

one_network_prior = tfd.JointDistributionNamed(
    model=dict(
        # mu_g: shape (D,), values squashed into (-1, 1).
        mu_g=tfb.Sigmoid(low=-1.0, high=1.0, validate_args=True, name="mu_g")(
            tfd.Normal(loc=tf.zeros((D,)), scale=0.5, validate_args=True)
        ),
        epsilon=tfd.Gamma(concentration=0.4, rate=1.0, validate_args=True, name="epsilon"),
        # mu_s: mu_g replicated S times, so loc has shape (S, D), with scale epsilon.
        mu_s=lambda mu_g, epsilon: tfb.Sigmoid(low=-1.0, high=1.0, validate_args=True, name="mu_s")(
            tfd.Normal(loc=tf.stack([mu_g] * S), scale=epsilon, validate_args=True)
        ),
        sigma=tfd.Gamma(concentration=0.3, rate=1.0, validate_args=True, name="sigma"),
        # mu_s_t: mu_s replicated T times, so loc has shape (T, S, D), with scale sigma.
        mu_s_t=lambda mu_s, sigma: tfb.Sigmoid(low=-1.0, high=1.0, validate_args=True, name="mu_s_t")(
            tfd.Normal(loc=tf.stack([mu_s] * T), scale=sigma, validate_args=True)
        )
    )
)
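For reference, a single draw from this prior has the following component shapes (these follow directly from the loc shapes above):

one_sample = one_network_prior.sample()
# one_sample["mu_g"]    -> shape (D,)
# one_sample["epsilon"] -> shape ()
# one_sample["mu_s"]    -> shape (S, D)
# one_sample["sigma"]   -> shape ()
# one_sample["mu_s_t"]  -> shape (T, S, D)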
I then need to "mix" the models, but the mixing is fairly custom, so I do it by hand in a custom log-probability function log_prob_fn:
def log_prob_fn(mu_g, epsilon, mu_s, sigma, mu_s_t, kappa, spatial, observed):
    log_probs_per_network = []
    probs_per_network = []
    for l in range(L):
        # Prior log-probability of the l-th network.
        log_probs_per_network.append(
            tf.reduce_sum(
                one_network_prior.log_prob({
                    "mu_g": mu_g[l],
                    "epsilon": epsilon[l],
                    "mu_s": mu_s[l],
                    "sigma": sigma[l],
                    "mu_s_t": mu_s_t[l]
                })
            )
        )
        # Likelihood of the observations under the l-th network,
        # built dynamically for each network.
        dist = tfb.Sigmoid(low=-1.0, high=1.0, validate_args=True)(
            tfd.Normal(loc=tf.stack([mu_s_t[l]] * N), scale=kappa)
        )
        probs_per_network.append(tf.reduce_prod(dist.prob(observed), axis=-1))
    kappa_log_prob = kappa_prior.log_prob(kappa)
    # Mix the per-network likelihoods with the spatial weights and marginalize.
    mixed_probs = spatial * tf.stack(probs_per_network, axis=-1)
    margin_prob = tf.reduce_sum(mixed_probs, axis=-1)
    mix_log_prob = tf.reduce_sum(tf.math.log(margin_prob))
    return tf.reduce_sum(log_probs_per_network) + kappa_log_prob + mix_log_prob
(I know this function is not efficient, but I couldn't easily sample from my previous model with batch shapes, so for now I have to loop over the networks.)
Note that the distribution dist is created dynamically for each network.
The goal is then to fit this model to data. I generated an initial state using one_network_prior and mixed the data by hand to obtain observed data of shape (N, T, S, D). A rough sketch of the initial-state part of that step is just below, and the observed tensor is then fed into the MCMC setup shown after it:
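The sketch (kappa's starting value and the uniform spatial weights are purely illustrative, not my exact code):

# Illustrative sketch only; kappa's starting value and the uniform spatial
# weights are made up, and L is the number of networks as before.
per_network = [one_network_prior.sample() for _ in range(L)]  # sample each network in a loop

# Stack each variable across networks, in the same order as log_prob_fn's arguments.
init_state = [
    tf.stack([s["mu_g"] for s in per_network]),     # (L, D)
    tf.stack([s["epsilon"] for s in per_network]),  # (L,)
    tf.stack([s["mu_s"] for s in per_network]),     # (L, S, D)
    tf.stack([s["sigma"] for s in per_network]),    # (L,)
    tf.stack([s["mu_s_t"] for s in per_network]),   # (L, T, S, D)
    tf.constant(0.1),                               # kappa, illustrative starting value
    tf.fill([L], 1.0 / L),                          # spatial mixture weights, uniform here
]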
hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda *params: log_prob_fn(*params, observed=observed),
    step_size=0.065,
    num_leapfrog_steps=5
)

# One bijector per state part, in the same order as log_prob_fn's arguments:
# mu_g, epsilon, mu_s, sigma, mu_s_t, kappa, spatial.
unconstraining_bijectors = [
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_g
    tfb.Softplus(),                   # epsilon
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_s
    tfb.Softplus(),                   # sigma
    tfb.Sigmoid(low=-1.0, high=1.0),  # mu_s_t
    tfb.Softplus(),                   # kappa
    tfb.SoftmaxCentered()             # spatial
]

transformed_kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=hmc_kernel,
    bijector=unconstraining_bijectors
)

adapted_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
    inner_kernel=transformed_kernel,
    num_adaptation_steps=400,
    target_accept_prob=0.65
)

@tf.function
def run_chain(initial_state, num_results=1000, num_burnin_steps=500):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=num_burnin_steps,
        current_state=initial_state,
        kernel=adapted_kernel
    )

samples, kernel_results = run_chain(
    initial_state=init_state,
    num_results=20000,
    num_burnin_steps=5000
)
But when I run the run_chain function, after a few iterations I get this error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-45-2ff348713067> in <module>
----> 1 samples, kernel_results = run_chain(
2 initial_state=init_state,
3 num_results=20000,
4 num_burnin_steps=5000
5 )
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
778 else:
779 compiler = "nonXla"
--> 780 result = self._call(*args, **kwds)
781
782 new_tracing_count = self._get_tracing_count()
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
844 *args, **kwds)
845 # If we did not create any variables the trace we have is good enough.
--> 846 return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
847
848 def fn_with_cond(*inner_args, **inner_kwds):
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
1841 `args` and `kwargs`.
1842 """
-> 1843 return self._call_flat(
1844 [t for t in nest.flatten((args, kwargs), expand_composites=True)
1845 if isinstance(t, (ops.Tensor,
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1921 and executing_eagerly):
1922 # No tape is watching; skip to running the function.
-> 1923 return self._build_call_outputs(self._inference_function.call(
1924 ctx, args, cancellation_manager=cancellation_manager))
1925 forward_backward = self._select_forward_and_backward_functions(
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
543 with _InterpolateFunctionError(self):
544 if cancellation_manager is None:
--> 545 outputs = execute.execute(
546 str(self.signature.name),
547 num_outputs=self._num_outputs,
~/.pyenv/versions/3.8.0/envs/Kong2019-env/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 try:
58 ctx.ensure_initialized()
---> 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
InvalidArgumentError: assertion failed: [Argument `scale` must be positive.] [Condition x > 0 did not hold element-wise:] [x (mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/strided_slice_1:0) = ] [-nan]
[[{{node mcmc_sample_chain/trace_scan/while/body/_415/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/body/_2366/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/body/_3200/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointDistributionNamed/log_prob/Normal/assert_positive/assert_less/Assert/AssertGuard/else/_3580/mcmc_sample_chain/trace_scan/while/smart_for_loop/while/simple_step_size_adaptation___init__/_one_step/transformed_kernel_one_step/mh_one_step/hmc_kernel_one_step/leapfrog_integrate/while/leapfrog_integrate_one_step/maybe_call_fn_and_grads/value_and_gradients/JointDistributionNamed/log_prob/Normal/assert_positive/assert_less/Assert/AssertGuard/Assert}}]] [Op:__inference_run_chain_169987]
Function call stack:
run_chain
My understanding is that a negative kappa is being fed into dist, but shouldn't the Softplus bijector make that impossible? Stranger still, when I reverse the order of all my bijectors the function keeps running, which is odd because the shapes should break on SoftmaxCentered. So it feels like my bijectors are simply being ignored. What am I missing?
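For reference, this is the behaviour I expected from Softplus (just a sanity check, not part of the script above):

# Softplus maps any real number to a strictly positive value, so I expected the
# kappa reaching the Normal's `scale` to always be positive after the transform.
print(tfb.Softplus().forward([-5.0, 0.0, 5.0]).numpy())  # roughly [0.0067, 0.6931, 5.0067]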
Thanks in advance for your help :)