
I am training a deep learning model on stacks of images with variable sizes (shape = [Batch, None, 256, 256, 1]), where None can vary from stack to stack.

I use tf.RaggedTensor.merge_dims(0, 1) to convert the ragged tensor to a shape of [None, 256, 256, 1] so it can be run through a pretrained Keras CNN model.

However, using the KerasLayer API results in the following error: TypeError: the object of type 'RaggedTensor' has no len()

When I apply .merge_dims outside the KerasLayer and pass the resulting tensor to the same pretrained model, I do not get this error.

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, output_signature=(tf.RaggedTensorSpec(shape=(None, 256, 256, 1), dtype=tf.float32, ragged_rank=1)))
ds = ds.repeat().batch(8)
print(next(iter(ds)).shape)

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x
x = tf.keras.layers.Lambda(merge)(inputs)
merged_inputs = x
# x = ResNet50(x) # Uncommenting this will result in `model` producing an error when run for inference.

model = tf.keras.Model(inputs, x)

# Run inference
data = next(iter(ds))
model(data).shape # Will be an error if ResNet50 is used
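As noted above, doing the merge eagerly, outside of any Keras layer, does not raise. Below is a minimal sketch of that working path, reusing the ds and ResNet50 objects defined in the snippet; the isinstance/to_tensor guard is my own defensive addition, not something from the question.

# Eager workaround sketch: merge the ragged batch outside of any Keras layer
# before calling the pretrained CNN. Reuses `ds` and `ResNet50` from above.
data = next(iter(ds))                      # RaggedTensor, shape (8, None, 256, 256, 1)
flat = data.merge_dims(0, 1)               # stacks flattened into one image axis
if isinstance(flat, tf.RaggedTensor):      # defensive: densify if still ragged
    flat = flat.to_tensor()
print(ResNet50(flat).shape)                # (total_images_in_batch, 1000)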

Here is a Colab notebook that demonstrates the problem: https://colab.research.google.com/drive/1kN78mf4_oNqxWOluV054NlqmakC5msli?usp=sharing


1 Answer


I am not sure whether the following answer or workaround is stable for complex network designs, but here are some pointers. The reason you get

Ragged Tensors have no len()

is that the ResNet model expects a tensor rather than a ragged_tensor. I am not sure whether ResNet(weights=None) can consume a ragged_tensor directly, so if we convert the ragged data before it is fed into ResNet, it may not complain. Below is complete working code based on that idea. Note, however, that more efficient approaches may be possible.


Data

import tensorflow as tf

# Synthetic Data Pipeline
def synthetic_gen():
  varShape = tf.random.uniform((), minval=1, maxval=12, dtype=tf.int32)
  image = tf.random.normal((varShape, 256, 256, 1))
  image = tf.RaggedTensor.from_tensor(image, ragged_rank=1)
  yield image

ds = tf.data.Dataset.from_generator(synthetic_gen, 
                                    output_signature=(tf.RaggedTensorSpec(
                                        shape=(None, 256, 256, 1), 
                                        dtype=tf.float32, ragged_rank=1
                                        )
                                    )
                                )
ds = ds.repeat().batch(8)
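
As a quick sanity check (not part of the original code), one can inspect the batched element spec and one concrete batch to see where the ragged dimension sits:

# Sanity check (assumption: run right after building `ds` above): inspect the
# batched element spec and one concrete ragged batch.
print(ds.element_spec)           # RaggedTensorSpec for batches of 8 stacks
sample = next(iter(ds))
print(sample.shape)              # (8, None, 256, 256, 1)
print(sample.row_lengths())      # number of images in each of the 8 stacks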

Base model

# Build Model
inputs = tf.keras.Input(
    type_spec=tf.RaggedTensorSpec(
        shape=(8, None, 256, 256, 1), 
        dtype=tf.float32, 
        ragged_rank=1))

ResNet50 = tf.keras.applications.ResNet50(
    include_top=True, 
    input_shape=(256, 256, 1),
    weights=None)

def merge(x):
  x = x.merge_dims(0, 1)
  return x

Ragged model

Here, we convert the ragged_tensor to a tensor before passing the data to ResNet.

class RagModel(tf.keras.Model):
    def __init__(self):
        super(RagModel, self).__init__()
        # base models 
        self.a = tf.keras.layers.Lambda(merge)
        # convert: tensor = ragged_tensor.to_tensor()
        self.b = tf.keras.layers.Lambda(lambda x: x.to_tensor())
        self.c = ResNet50
    
    def call(self, inputs, training=None, plot=False, **kwargs):
        x = self.a(inputs)
        x = self.b(x) if not plot else x
        x = self.c(x)
        return x
    
    # a helper function to plot 
    def build_graph(self):
        x = tf.keras.Input(type_spec=tf.RaggedTensorSpec(
            shape=(8, None, 256, 256, 1),
            dtype=tf.float32, ragged_rank=1)
        )
        return tf.keras.Model(inputs=[x],
                              outputs=self.call(x, plot=True))
   
x_model = RagModel()

data = next(iter(ds)); print(data.shape)
x_model(data).shape 
(8, None, 256, 256, 1)
TensorShape([39, 1000])
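
The leading dimension of the output (39 in the run above) is simply the total number of images across the 8 ragged stacks in that particular batch. A quick check, assuming data is the batch drawn above (this verification is my addition, not part of the original answer):

# Quick check: the output batch dimension should equal the total number of
# images across the ragged stacks in `data`.
total_images = int(tf.reduce_sum(data.row_lengths()))
print(total_images)              # should match x_model(data).shape[0]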

Plot

tf.keras.utils.plot_model(x_model.build_graph(), 
              show_shapes=True, show_layer_names=True)

(model architecture plot rendered by tf.keras.utils.plot_model)

x_model.build_graph().summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(8, None, 256, 256, 1)]  0         
_________________________________________________________________
lambda_2 (Lambda)            (None, 256, 256, 1)       0         
_________________________________________________________________
resnet50 (Functional)        (None, 1000)              25630440  
=================================================================
Total params: 25,630,440
Trainable params: 25,577,320
Non-trainable params: 53,120
_________________________________________________________________
Answered 2021-10-18T21:01:23.490