0

我在 R 中使用 fit_generator 时出错...这是我的代码..`

model <- keras_model_sequential()

model %>%
  layer_conv_2d(32, c(3,3), input_shape = c(64, 64, 3)) %>%
  layer_activation("relu") %>%
  layer_max_pooling_2d(pool_size = c(2,2)) %>%
  layer_conv_2d(32, c(3, 3)) %>%
  layer_activation("relu") %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_flatten() %>%
  layer_dense(128) %>%
  layer_activation("relu") %>%
  layer_dense(128) %>%
  layer_activation("relu") %>%
  layer_dense(2) %>%
  layer_activation("softmax")

opt <- optimizer_adam(lr = 0.001, decay = 1e-6)

model %>%
  compile(loss = "categorical_crossentropy", optimizer = opt, metrics = "accuracy")

train_gen <- image_data_generator(rescale = 1./255,
                                  shear_range = 0.2,
                                  zoom_range = 0.2,
                                  horizontal_flip = T)

test_gen <- image_data_generator(rescale = 1./255)

train_set = train_gen$flow_from_directory('dataset/training_set',
                                          target_size = c(64, 64),
                                          class_mode = "categorical")

test_set = test_gen$flow_from_directory('dataset/test_set',
                                        target_size = c(64, 64),
                                        batch_size = 32,
                                        class_mode = 'categorical')

model$fit_generator(train_set,
                    steps_per_epoch = 50,
                    epochs = 10)

错误:py_call_impl 中的错误(可调用,dots$args,dots$keywords):StopIteration:'float' 对象不能解释为整数

如果我放置验证集,它也会有另一个错误 bool(validation_data)。浮动错误..

4

1 回答 1

1

如果没有最小的可重复示例,很难帮助您。

我猜您在尝试运行时会收到此错误

train_set = train_gen$flow_from_directory('dataset/training_set',
                                          target_size = c(64, 64),
                                          class_mode = "categorical")

在这里,您自己使用reticulate而不是keras(R 包)包装器调用 python 函数。这可能有效,但您必须更明确地了解 type 和 use target_size = as.integer(c(64, 64)),因为 python 需要一个整数。

或者,我建议查看包flow_images_from_directory()中包含的功能keras


这同样适用于

model$fit_generator(train_set,
                    steps_per_epoch = 50,
                    epochs = 10)

我建议调查

model %>% 
  fit_generator()

相反,它是keras包的一部分。

于 2018-05-09T22:44:10.507 回答