1

我试图了解 CTC 实现在 TensorFlow 中的工作原理。我写了一个简单的例子来测试 CTC 功能,但出于某种原因,我inf对一些目标/输入值感到很困惑,我确定为什么会这样!?

代码:

import tensorflow as tf
import numpy as np

# https://github.com/philipperemy/tensorflow-ctc-speech-recognition/blob/master/utils.py
def sparse_tuple_from(sequences, dtype=np.int32):
    """Create a sparse representention of x.
    Args:
        sequences: a list of lists of type dtype where each element is a sequence
    Returns:
        A tuple with (indices, values, shape)
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

batch_size = 1
seq_length = 2
n_labels = 2

seq_len = tf.placeholder(tf.int32, [None])
targets = tf.sparse_placeholder(tf.int32)
logits = tf.constant(np.random.random((batch_size, seq_length, n_labels+1)),dtype=tf.float32) # +1 for the blank label
loss = tf.reduce_mean(tf.nn.ctc_loss(targets, logits, seq_len, time_major = False))


with tf.Session() as sess:
    for it in range(10):
        rand_target = np.random.randint(n_labels, size=(seq_length))
        sample_target = sparse_tuple_from([rand_target])

        logitsval = sess.run(logits)
        lossval = sess.run(loss, feed_dict={seq_len: [seq_length], targets: sample_target})
        print('******* Iter: %d *******'%it)
        print('logits:', logitsval)
        print('rand_target:', rand_target)
        print('rand_sparse_target:', sample_target)
        print('loss:', lossval)
        print()

样本输出:

******* Iter: 0 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 1 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 2 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 3 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 4 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 5 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 6 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 0], dtype=int32), array([1, 2]))
loss: 1.59766

******* Iter: 7 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [1 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([1, 1], dtype=int32), array([1, 2]))
loss: inf

******* Iter: 8 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 1]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 1], dtype=int32), array([1, 2]))
loss: 2.61521

******* Iter: 9 *******
logits: [[[ 0.10151503  0.88581538  0.56466645]
  [ 0.76043415  0.52718711  0.01166286]]]
rand_target: [0 0]
rand_sparse_target: (array([[0, 0],
       [0, 1]]), array([0, 0], dtype=int32), array([1, 2]))
loss: inf

知道我在那里想念什么!?

4

1 回答 1

2

仔细查看您的输入文本(rand_target),我相信您会看到一些与 inf 损失值相关的简单模式;-)

对正在发生的事情的简短解释:CTC 通过允许重复每个字符来对文本进行编码,并且它还允许在字符之间插入非字符标记(称为“CTC 空白标签”)。撤消此编码(或解码)仅意味着丢弃重复的字符,然后丢弃所有空白。举一些例子(“...”对应于文本,“...”对应于编码,“-”对应于空白标签):

  • "to" -> 'tttooo', or 'to' or 't-oo', or 'to', 等等...
  • "too" -> 'to-o',或 'tttoo---oo',或 '---too--',但不是 'too'(想想解码后的 'too' 的样子)

现在我们已经足够了解为什么您的某些样本会失败:

  • 输入文本的长度为 2
  • 编码的长度为 2
  • 如果输入字符重复(例如'11',或作为python列表:[1, 1]),那么编码它的唯一方法是在两者之间放置一个空格(想想大量解码'11'和'1 -1')。但是编码的长度为 3。
  • 因此,无法将长度为 2 且带有重复字符的文本编码为长度为 2 的编码,因此 TF 损失实现返回 inf

您还可以将编码想象为状态机 - 参见下图。文本“11”可以由所有可能的路径来表示,这些路径从一个开始状态(两个最左边的状态)开始,到一个最终状态(两个最右边的状态)结束。如您所见,最短的路径是“1-1”。

在此处输入图像描述

总而言之,您必须为输入文本中的每个重复字符至少插入一个额外的空白。也许这篇文章有助于理解 CTC:https ://towardsdatascience.com/3797e43a86c

于 2018-09-30T16:01:13.923 回答