0

我正在尝试为 R ( github ) 中的 keras 包实现自定义层。

我正在实现的层基于此处可用的 AttentionWithContext 层:gist

这是我的代码:

AttentionWithContext <- R6::R6Class("AttentionWithContext",

                     inherit = KerasLayer,

                     public = list(

                       W_regularizer = NULL, 
                       b_regularizer = NULL, 
                       u_regularizer = NULL,
                       W_constraint=NULL, 
                       b_constraint=NULL, 
                       u_constraint=NULL,
                       bias=NULL,
                       b=NULL,
                       W=NULL,
                       u=NULL,
                       supports_masking=NULL,
                       init=NULL,
                       name = NULL,

                      initialize = function(name = 'attention',
                                             W_regularizer = NULL, 
                                             b_regularizer = NULL, 
                                             u_regularizer = NULL,
                                             W_constraint=NULL, 
                                             b_constraint=NULL, 
                                             u_constraint=NULL,
                                             bias=TRUE ) {

                         self$supports_masking = TRUE
                         self$init = keras::initializer_glorot_uniform()
                         self$W_regularizer = W_regularizer
                         self$b_regularizer = b_regularizer
                         self$u_regularizer = u_regularizer
                         self$W_constraint = W_constraint
                         self$b_constraint = b_constraint
                         self$u_constraint = u_constraint
                         self$bias = bias
                         self$name = name
                       },

                       build = function(input_shape) {
                         assertthat::assert_that(length(input_shape) == 3)

                         self$W = self$add_weight(shape = reticulate::tuple(input_shape[[3]],input_shape[[3]], NULL), 
                                                  initializer = self$init,
                                                  name=stringr::str_interp('${self$name}_W'),
                                                  regularizer = self$W_regularizer,
                                                  constraint = self$W_constraint)

                         if (self$bias) {

                           self$b = self$add_weight(shape = reticulate::tuple(input_shape[[3]]), 
                                                    initializer='zero',
                                                    name = stringr::str_interp('${self$name}_b'),
                                                    regularizer = self$b_regularizer,
                                                    constraint = self$b_constraint)
                         }

                         self$u = self$add_weight(shape = reticulate::tuple(input_shape[[3]]), 
                                                  initializer=self$init,
                                                  name = stringr::str_interp('${self$name}_u'),
                                                  regularizer = self$u_regularizer,
                                                  constraint = self$u_constraint)

                       },

                       compute_mask = function(input, input_mask=NULL) {
                         return(NULL)
                       },

                       call = function(x, mask = NULL) {
                         uit = keras::k_squeeze(keras::k_dot(x, keras::k_expand_dims(self$W)), axis=-1)

                         if (self$bias) {
                           uit = uit + self$b
                         }

                         uit = keras::k_tanh(uit)
                         ait = keras::k_dot(uit, self$u)
                         a = keras::k_exp(ait)

                         if (!is.null(mask)) {
                           a = a * keras::k_cast(mask, keras::k_floatx())
                         }

                         a = a/keras::k_cast(keras::k_sum(a, axis = 1, keepdims = TRUE) + keras::k_epsilon(), keras::k_floatx()) 
                         weighted_input = x * keras::k_expand_dims(a)

                         keras::k_sum(weighted_input, axis=1)
                     },

                       compute_output_shape = function(input_shape) {
                          list(input_shape[[1]], input_shape[[3]])
                       }
                     )
)

# define layer wrapper function
layer_attention_with_context <- function(object, W_regularizer = NULL, 
                        b_regularizer = NULL, 
                        u_regularizer = NULL,
                        W_constraint=NULL, 
                        b_constraint=NULL, 
                        u_constraint=NULL, 
                        bias=TRUE,
                        name = 'attention_with_context') {
 create_layer(AttentionWithContext, object, list(W_regularizer =  W_regularizer, 
                                   b_regularizer = b_regularizer, 
                                   u_regularizer = u_regularizer, 
                                   W_constraint= W_constraint, 
                                   b_constraint=b_constraint, 
                                   u_constraint=u_constraint, 
                                   bias=bias,
                                   name = name
))
}

# Example 
model <- keras_model_sequential()
model %>%
layer_embedding(input_dim = 20000,
              output_dim = 128,
              input_length = 30) %>%
layer_lstm(64, return_sequences = TRUE) %>%
layer_attention_with_context() %>%
time_distributed(layer_dense(units=10))

当我运行它时,我收到一条神秘的错误消息:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  RuntimeError: Evaluation error: TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'. 

我试图探索这个错误,我认为它可能来自这一行: reticulate::tuple(input_shape[[3]],input_shape[[3]], NULL)

在原始代码中,在 python 中,我们可以看到: (input_shape[-1], input_shape[-1],)

我找不到在 R 中创建这种结构的方法。

有任何想法吗 ?

4

0 回答 0