您可以通过 Keras 的 Lambda 层使用 TensorFlow 的 tf.split。
使用 Lambda 将形状为 (64,16,16) 的张量拆分为 (64,1,1,256),
然后按所需的任何索引对其进行子集化。
import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Lambda
# Dummy input batch: 3 all-ones samples, each of shape (64, 16, 16).
data = np.ones(shape=(3, 64, 16, 16), dtype=float)
# Function wrapped by the Lambda layer: folds axes 2 and 3 into a new last axis.
def lambda_fun(x):
    """Move axes 2 and 3 of ``x`` into a single trailing axis.

    A new axis is appended at position 4. Then, for axis 2 and afterwards
    axis 3, the tensor is cut into 16 equal pieces along that axis and the
    pieces are concatenated along axis 4.

    For the ``(batch, 64, 16, 16)`` input used in this file the result has
    shape ``(batch, 64, 1, 1, 256)``, with
    ``out[b, c, 0, 0, 16 * j + i] == x[b, c, i, j]``.

    Args:
        x: A 4-D tensor whose axes 2 and 3 can each be split into 16 equal
            pieces.

    Returns:
        The rearranged 5-D tensor.
    """
    out = K.expand_dims(x, 4)
    # Same split-then-concatenate step for each axis; order matters
    # (axis 2 first, then axis 3) for the layout of the last axis.
    for axis in (2, 3):
        pieces = tf.split(out, 16, axis)
        out = K.concatenate(pieces, 4)
    return out
# Check that the splitting works: wrap lambda_fun in a one-layer model and
# run the all-ones batch through it.
# The tensor is named input_tensor (not `input`) so that the builtin
# input() is not shadowed.
input_tensor = Input(shape=(64, 16, 16))
ll = Lambda(lambda_fun)(input_tensor)
model = Model(inputs=input_tensor, outputs=ll)
res = model.predict(data)
print(np.shape(res))  # (3, 64, 1, 1, 256)