需要使用 `tf_state_ops.assign()` 或 `tf.compat.v1.scatter_update()` 来实现这一功能。下面是一个使用 `tf_state_ops.assign()` 的示例：
import tensorflow as tf
import tensorflow.keras.layers as KL
import tensorflow_probability as tfp
from tensorflow.python.ops import state_ops as tf_state_ops
class CustomLayer(KL.Layer):
    """Custom layer storing a moving average of the nth percentile of its inputs.

    On every call the layer computes the ``percentile``-th percentile of the
    incoming batch and folds it into a non-trainable scalar weight with an
    exponential moving average::

        moving_thresh <- alpha * moving_thresh + (1 - alpha) * batch_thresh

    The updated threshold is returned.
    """

    def __init__(
        self,
        percentile: float = 66.67,
        name: str = "thresh",
        alpha: float = 0.9,
        moving_thresh_initializer: float = 0.0,
        **kwargs
    ):
        """Layer initialization.

        Args:
            percentile (float, optional): percentile for thresholding.
                Defaults to 66.67.
            name (str, optional): name for the layer. Defaults to "thresh".
            alpha (float, optional): decay value for moving average.
                Defaults to 0.9.
            moving_thresh_initializer (float, optional): Initial threshold.
                Defaults to 0.0.
            **kwargs: forwarded to ``keras.layers.Layer``.
        """
        # The layer is always non-trainable. Drop any incoming ``trainable``
        # entry (Layer.get_config emits one) so that ``from_config`` does not
        # pass the keyword twice to the base constructor.
        kwargs.pop("trainable", None)
        super().__init__(trainable=False, name=name, **kwargs)
        self.percentile = percentile
        self.alpha = alpha
        # Keep the raw value so get_config stays serializable; add_weight
        # needs the initializer object below.
        self._moving_thresh_init_value = moving_thresh_initializer
        self.moving_thresh_initializer = tf.constant_initializer(
            value=moving_thresh_initializer
        )

    def build(self, input_shape):
        """Create the scalar, non-trainable ``moving_thresh`` weight."""
        self.moving_thresh = self.add_weight(
            shape=(),
            name="moving_thresh",
            initializer=self.moving_thresh_initializer,
            trainable=False,
        )
        super().build(input_shape)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Update the moving threshold with the current batch and return it.

        Args:
            inputs (tf.Tensor): samplewise values for a given batch. The
                percentile is taken over all elements, so e.g. both
                ``(batch,)`` and ``(batch, 1)`` inputs yield a scalar.

        Returns:
            tf.Tensor (shape = ()): the updated threshold value.
        """
        values = tf.cast(inputs, self.moving_thresh.dtype)
        # axis=None reduces over every dimension and returns a scalar that
        # matches the shape-() weight; axis=[0] would give shape (1,) for
        # (batch, 1) inputs, which cannot be assigned to a scalar variable.
        batch_thresh = tfp.stats.percentile(
            values, q=self.percentile, axis=None, interpolation="linear"
        )
        new_thresh = (
            self.alpha * self.moving_thresh + (1.0 - self.alpha) * batch_thresh
        )
        # Update the variable in place. Do NOT rebind self.moving_thresh:
        # the attribute must keep referring to the tracked weight.
        updated = tf_state_ops.assign(self.moving_thresh, new_thresh)
        return tf.identity(updated)

    def get_config(self) -> dict:
        """Setting up the layer config.

        Only constructor arguments are stored. The current threshold value is
        part of the layer's weights, not of its config.

        Returns:
            dict: config key-value pairs
        """
        base_config = super().get_config()
        config = {
            "alpha": self.alpha,
            "moving_thresh_initializer": self._moving_thresh_init_value,
            "percentile": self.percentile,
        }
        return {**base_config, **config}

    def compute_output_shape(self, input_shape: tuple) -> tuple:
        """Shape of the layer output: a scalar, regardless of input shape."""
        return ()
上述自定义层可以按如下方式在工作流中使用：
import numpy as np

# Any batch size works; the layer reduces over the whole batch.
batch_size = 32
thresholding_layer = CustomLayer()
# Dummy input: one value per sample, shape (batch_size, 1).
x = np.zeros((batch_size, 1), dtype=np.float32)
current_threshold = thresholding_layer(x)
有关上述自定义层的更多使用细节，以及 `tf.compat.v1.scatter_update()` 的用法，请参阅以下链接：
https://medium.com/dive-into-ml-ai/custom-layer-with-memory-in-keras-1d0c03e722e9