
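I am trying to train different EfficientNet models to classify music genres. First, I convert my audio files to Mel spectrograms and train a model on them. This is a multi-label classification problem: I use a sigmoid layer as the final prediction layer and binary cross-entropy as the loss function. I have tried different metrics, such as binary accuracy and recall, but mainly binary accuracy. I use transfer learning to first adapt an existing EfficientNetB1 (or EfficientNetB3) model to my new dataset, which consists of about 67k images (spectrograms) and 16 genres/classes, where each label contains 1 to 4 classes (one-hot encoded). The model is then saved and loaded again for fine-tuning, with most layers unfrozen (all except the BatchNormalization layers). The learning rate is 0.01 for the transfer step and 0.001 for fine-tuning.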
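
To make the label format and the two metrics concrete, here is a minimal pure-Python sketch. The class IDs are the ones printed in the training log; the four example tracks and the helper names (`multi_hot`, `binary_accuracy`, `recall`) are hypothetical and only for illustration. It shows that a predictor that outputs all zeros already reaches a high binary accuracy with zero recall when each track has only 1 to 4 of 16 labels set, which may be relevant when reading the metrics in the log.

```python
# Class IDs as printed in the training log (16 genres).
CLASSES = sorted({32, 1, 66, 41, 42, 10, 12, 76, 107, 17, 18, 1235, 21, 25, 250, 27})


def multi_hot(genre_ids, classes=CLASSES):
    """Encode a set of genre IDs as a 0/1 vector with one slot per class."""
    index = {c: i for i, c in enumerate(classes)}
    vec = [0] * len(classes)
    for g in genre_ids:
        vec[index[g]] = 1
    return vec


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual label slots predicted correctly."""
    hits = total = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            hits += int((p > threshold) == bool(t))
            total += 1
    return hits / total


def recall(y_true, y_pred, threshold=0.5):
    """True positives divided by all actual positives (0.0 if there are none)."""
    tp = fn = 0
    for t_row, p_row in zip(y_true, y_pred):
        for t, p in zip(t_row, p_row):
            if t:
                if p > threshold:
                    tp += 1
                else:
                    fn += 1
    return tp / (tp + fn) if (tp + fn) else 0.0


# Hypothetical tracks with 1, 2, 3 and 4 genres respectively.
labels = [multi_hot([32]), multi_hot([1, 66]),
          multi_hot([10, 12, 76]), multi_hot([17, 18, 21, 25])]
all_zero = [[0.0] * len(CLASSES) for _ in labels]

print(binary_accuracy(labels, all_zero))  # 54 of 64 slots correct -> 0.84375
print(recall(labels, all_zero))           # no positives found -> 0.0
```

With labels this sparse, binary accuracy alone says little about whether the model predicts any genre at all, so it is worth watching recall (or precision) next to it.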

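After the transfer step, during fine-tuning, training eventually stops (though not in every run) with the following exception (I have included the full training log as well):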

Python Tensorflow Version (nightly version) - 2.4.0-dev20200704
CLASSES: {32, 1, 66, 41, 42, 10, 12, 76, 107, 17, 18, 1235, 21, 25, 250, 27}
TRAINING IMAGES: 59142
VALIDATION IMAGES: 7805
Total number of Training samples: 
TRAIN: 59136
VAL: 7800
Steps per epoch: 7392
Validation Steps: 975
2021-02-23 14:15:07.741865: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2021-02-23 14:15:07.771234: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:86:00.0 name: Tesla K80 computeCapability: 3.7
coreClock: 0.8235GHz coreCount: 13 deviceMemorySize: 11.17GiB deviceMemoryBandwidth: 223.96GiB/s
2021-02-23 14:15:07.771295: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-02-23 14:15:07.806741: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-02-23 14:15:07.808235: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2021-02-23 14:15:07.808550: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2021-02-23 14:15:07.810184: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2021-02-23 14:15:07.810985: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2021-02-23 14:15:07.829478: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2021-02-23 14:15:07.832361: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2021-02-23 14:15:07.832759: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-02-23 14:15:07.866705: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 3092840000 Hz
2021-02-23 14:15:07.867161: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x60a5b70 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-02-23 14:15:07.867204: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2021-02-23 14:15:07.953336: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x609d8b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2021-02-23 14:15:07.953391: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Tesla K80, Compute Capability 3.7
2021-02-23 14:15:07.955334: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:86:00.0 name: Tesla K80 computeCapability: 3.7
coreClock: 0.8235GHz coreCount: 13 deviceMemorySize: 11.17GiB deviceMemoryBandwidth: 223.96GiB/s
2021-02-23 14:15:07.955389: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-02-23 14:15:07.955456: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-02-23 14:15:07.955499: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2021-02-23 14:15:07.955542: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2021-02-23 14:15:07.955584: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2021-02-23 14:15:07.955625: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2021-02-23 14:15:07.955667: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2021-02-23 14:15:07.959120: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0
2021-02-23 14:15:07.959174: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-02-23 14:15:08.467663: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1257] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-02-23 14:15:08.467725: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1263]      0 
2021-02-23 14:15:08.467737: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1276] 0:   N 
2021-02-23 14:15:08.470827: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1402] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10617 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:86:00.0, compute capability: 3.7)
2021-02-23 14:15:14.300401: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session started.
2021-02-23 14:15:14.300477: I tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1391] Profiler found 1 GPUs
2021-02-23 14:15:14.300872: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: ~/.local/lib:/usr/local/cuda-10.1/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2021-02-23 14:15:14.300989: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcupti.so'; dlerror: libcupti.so: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: ~/.local/lib:/usr/local/cuda-10.1/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2021-02-23 14:15:14.301009: E tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1441] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
START OF TRAINING NOW!
Epoch 1/10
WARNING:tensorflow:From /home/user06/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
2021-02-23 14:15:24.859460: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-02-23 14:15:25.675793: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
7392/7392 [==============================] - ETA: 0s - loss: 0.3681 - binary_accuracy: 0.8819 - recall: 0.0837   
Epoch 00001: val_loss improved from inf to 0.32435, saving model to /home/user06/GenreClassification/GenreClassification/data/checkpoints/EfficientNetB1_transfer_01_old-test4_01_0.32.hdf5
7392/7392 [==============================] - 2204s 298ms/step - loss: 0.3681 - binary_accuracy: 0.8819 - recall: 0.0837 - val_loss: 0.3244 - val_binary_accuracy: 0.8756 - val_recall: 0.1175
Epoch 2/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3574 - binary_accuracy: 0.8835 - recall: 0.0803   
Epoch 00002: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1776s 240ms/step - loss: 0.3574 - binary_accuracy: 0.8835 - recall: 0.0803 - val_loss: 0.3474 - val_binary_accuracy: 0.8754 - val_recall: 0.0971
Epoch 3/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3565 - binary_accuracy: 0.8840 - recall: 0.0788   
Epoch 00003: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1582s 214ms/step - loss: 0.3565 - binary_accuracy: 0.8840 - recall: 0.0788 - val_loss: 0.3469 - val_binary_accuracy: 0.8683 - val_recall: 0.1250
Epoch 4/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3564 - binary_accuracy: 0.8840 - recall: 0.0790     
Epoch 00004: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1564s 212ms/step - loss: 0.3564 - binary_accuracy: 0.8840 - recall: 0.0790 - val_loss: 0.3337 - val_binary_accuracy: 0.8765 - val_recall: 0.1108
Epoch 5/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3569 - binary_accuracy: 0.8836 - recall: 0.0781   
Epoch 00005: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1564s 212ms/step - loss: 0.3569 - binary_accuracy: 0.8836 - recall: 0.0781 - val_loss: 0.3472 - val_binary_accuracy: 0.8735 - val_recall: 0.0962
Epoch 6/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3567 - binary_accuracy: 0.8837 - recall: 0.0801     
Epoch 00006: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1568s 212ms/step - loss: 0.3567 - binary_accuracy: 0.8837 - recall: 0.0801 - val_loss: 0.3266 - val_binary_accuracy: 0.8783 - val_recall: 0.1191
Epoch 7/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3565 - binary_accuracy: 0.8842 - recall: 0.0788   
Epoch 00007: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1568s 212ms/step - loss: 0.3565 - binary_accuracy: 0.8842 - recall: 0.0788 - val_loss: 0.3410 - val_binary_accuracy: 0.8756 - val_recall: 0.1026
Epoch 8/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3567 - binary_accuracy: 0.8841 - recall: 0.0788    
Epoch 00008: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1683s 228ms/step - loss: 0.3567 - binary_accuracy: 0.8841 - recall: 0.0788 - val_loss: 0.3749 - val_binary_accuracy: 0.8727 - val_recall: 0.0643
Epoch 9/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3575 - binary_accuracy: 0.8839 - recall: 0.0780   
Epoch 00009: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1634s 221ms/step - loss: 0.3575 - binary_accuracy: 0.8839 - recall: 0.0780 - val_loss: 0.3842 - val_binary_accuracy: 0.8561 - val_recall: 0.1199
Epoch 10/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3569 - binary_accuracy: 0.8840 - recall: 0.0776   
Epoch 00010: val_loss did not improve from 0.32435
7392/7392 [==============================] - 1585s 214ms/step - loss: 0.3569 - binary_accuracy: 0.8840 - recall: 0.0776 - val_loss: 0.3378 - val_binary_accuracy: 0.8808 - val_recall: 0.1001
2021-02-23 18:54:20.781432: I tensorflow/core/profiler/lib/profiler_session.cc:164] Profiler session started.
2021-02-23 18:54:20.781518: E tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1441] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
START OF TRAINING NOW!
Epoch 1/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3620 - binary_accuracy: 0.8917 - recall_1: 9.7708e-05   
Epoch 00001: val_loss improved from inf to 0.36043, saving model to /home/user06/GenreClassification/GenreClassification/data/checkpoints/EfficientNetB1_transfer_01_old-test4_tuned_01_0.36.hdf5
7392/7392 [==============================] - 3149s 426ms/step - loss: 0.3620 - binary_accuracy: 0.8917 - recall_1: 9.7708e-05 - val_loss: 0.3604 - val_binary_accuracy: 0.8823 - val_recall_1: 0.0000e+00
Epoch 2/10
7392/7392 [==============================] - ETA: 0s - loss: 0.3578 - binary_accuracy: 0.8918 - recall_1: 0.0000e+00Traceback (most recent call last):
  File "evaluate.py", line 87, in <module>
    r = evaluate(sys.argv[1])
IndexError: list index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "evaluate.py", line 89, in <module>
    r = evaluate()
  File "evaluate.py", line 73, in evaluate
    model_name, model = train.train_model(model, TRAIN_GEN, VAL_GEN, VAL_STEPS, strategy, CLASS_WEIGHTS, t_step=2)
  File "/home/user06/GenreClassification/GenreClassification/src/process/train.py", line 205, in train_model
    class_weight=class_weights)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1137, in fit
    return_dict=True)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1383, in evaluate
    tmp_logs = test_function(iterator)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 814, in _call
    results = self._stateful_fn(*args, **kwds)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2844, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1847, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1923, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/home/user06/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (EffNetB1/top_predictions_custom/Sigmoid:0) = ] [[nan nan nan...]...] [y (Cast_8/x:0) = ] [0]
     [[{{node assert_greater_equal/Assert/AssertGuard/else/_21/assert_greater_equal/Assert/AssertGuard/Assert}}]]
     [[div_no_nan_2/ReadVariableOp/_56]]
  (1) Invalid argument:  assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (EffNetB1/top_predictions_custom/Sigmoid:0) = ] [[nan nan nan...]...] [y (Cast_8/x:0) = ] [0]
     [[{{node assert_greater_equal/Assert/AssertGuard/else/_21/assert_greater_equal/Assert/AssertGuard/Assert}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_327987]

Function call stack:
test_function -> test_function

terminate called without an active exception
Aborted (core dumped)

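The exception seems to be thrown at random: sometimes training works fine and sometimes it does not, which is why this is such a big problem for me. I have been working on it for more than a week now and cannot seem to find a solution. I tried changing the batch size; I usually use 8, because on most of the GPUs I use there is not enough memory for more images per batch. I tried training without class weights, but the model still fails sometimes. I tried leaving out some metrics and using only accuracy, but training still fails at some point.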
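
Since the assertion reports NaN values coming out of the sigmoid layer, one further check is whether any input batch already contains NaN or Inf values (for example, a log-scaled spectrogram of a silent clip can contain -inf). The following is a minimal sketch, not part of my project code: `find_non_finite_batches` is a hypothetical helper, and it assumes the generator yields `(images, labels)` pairs of NumPy arrays. The demo data is random and only stands in for real spectrogram batches.

```python
import numpy as np


def find_non_finite_batches(batches, max_batches=None):
    """Return the indices of batches whose images or labels contain NaN/Inf."""
    bad = []
    for i, (images, labels) in enumerate(batches):
        # Stop early so that endlessly repeating generators terminate.
        if max_batches is not None and i >= max_batches:
            break
        if not (np.isfinite(images).all() and np.isfinite(labels).all()):
            bad.append(i)
    return bad


# Hypothetical stand-in for the real generator: three batches of 8 tiny "images".
rng = np.random.default_rng(0)
demo = [(rng.random((8, 4, 4, 3), dtype=np.float32),
         np.zeros((8, 16), dtype=np.float32)) for _ in range(3)]
demo[1][0][0, 0, 0, 0] = np.nan  # corrupt one pixel in the second batch

print(find_non_finite_batches(demo))  # [1]
```

If every batch is finite, the inputs can probably be ruled out and the NaN values more likely arise during the weight updates themselves.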

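Any help or guidance is much appreciated. Let me know if you need more information; this is my second question on Stack Overflow, so I am new to this.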

I realized I can also include the code for creating and training my model, so first the model creation (build_efficientNetB3_model):

from tensorflow.keras.applications.efficientnet import EfficientNetB3, EfficientNetB1, EfficientNetB0
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, concatenate, Conv2D, BatchNormalization, \
    Activation, MaxPool2D, Flatten, Dropout
from tensorflow.keras.optimizers import SGD, RMSprop, Adam
from tensorflow.keras.metrics import BinaryAccuracy, FalseNegatives, FalsePositives, Recall, Precision
import config as conf

def create_eff(TRANSFER: conf.Transfer):
    input_tensor = Input(shape=(conf.IM_SIZE[0], conf.IM_SIZE[1], conf.IM_DIM), name='input_specs')

    if conf.EFFICIENT_VERSION is conf.EfficientNet.B0:
        # rescaling of image size from 300x300 to 224x224 is done in one of the first layers of B0
        base_model = EfficientNetB0(include_top=False, input_tensor=input_tensor, weights='imagenet')
    elif conf.EFFICIENT_VERSION is conf.EfficientNet.B1:
        # rescaling of image size from 300x300 to 240x240 is done in one of the first layers of B1
        base_model = EfficientNetB1(include_top=False, input_tensor=input_tensor, weights='imagenet')
    elif conf.EFFICIENT_VERSION is conf.EfficientNet.B3:
        base_model = EfficientNetB3(include_top=False, input_tensor=input_tensor, weights='imagenet')

    # Freeze the pretrained weights here for Transfer Learning, all layers besides the newly added ones.
    if TRANSFER.value:
        base_model.trainable = False

    # Rebuild top
    x = GlobalAveragePooling2D(name="top_avg_pool_custom")(base_model.output)
    x = BatchNormalization(name="top_batch_normalization_custom")(x)
    x = Dropout(0.2, name="top_dropout_custom")(x)

    model = Model(base_model.input, x)

    print(f'Is the model trainable? Default is trainable, not trainable in case all layers are set to not trainable. '
          f'[{model.trainable}]')

    return model


def build_efficientNetB3_model(classes: set, TRANSFER: conf.Transfer):
    conf.CLASSES = classes

    eff_model = create_eff(TRANSFER)

    # A logistic layer -- with x classes
    predictions = Dense(len(conf.CLASSES), activation='sigmoid', name='top_predictions_custom')(eff_model.output)
    model = Model(inputs=eff_model.input, outputs=predictions, name=f"EffNet{conf.EFFICIENT_VERSION.value}")

    for i, layer in enumerate(model.layers):
        print(i, layer.name, layer.trainable)

    lr = 1e-3
    if TRANSFER.value:
        lr = 1e-2
    optimizer = Adam(learning_rate=lr)

    model.compile(
        optimizer=optimizer, loss='binary_crossentropy', metrics=[BinaryAccuracy(), Recall()]  # Specificity(), Precision()
    ) 

    return model


def unfreeze_model(model):
    # We unfreeze the top layers while leaving BatchNorm layers frozen
    for layer in model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = True

    # Once the model has converged to new data, the learning rate for re-training on unfrozen model should be low
    optimizer = Adam(learning_rate=1e-3)
    model.compile(
        optimizer=optimizer, loss='binary_crossentropy', metrics=[BinaryAccuracy(), Recall()]  # Specificity(), Precision()
    )

Now the training (the files are evaluate.py and train.py):

evaluate.py

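import getopt

import tensorflow as tf

# Project-specific imports (conf, train, uio, cws, k_net, Transfer) are not shown in this excerpt.

def evaluate(argv=None):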
    # Evaluate the command line argument, should for now be YES or NO to activate Transfer Learning
    try:
        args = getopt.getopt(argv, "")
        print(args)
    except getopt.GetoptError:
        print('args exception')

    # default for Transfer Learning is YES, argument needs to be set to not use Transfer Learning
    TRANSFER = Transfer['YES']
    if args[1] is not None:
        TRANSFER = Transfer[args[1]]

    subtrack_dict = uio.load_pickle(conf.SUB_PARSED_TRACKS_V3)
    TRACKS = subtrack_dict['parsed_tracks'] 
    conf.CURRENT_SET = conf.SET.SUB_LARGE_V3

    # Parse Dataset
    conf.CLASSES, TRAIN, VAL = train.parseTrainSet(tracks=TRACKS)
    # Prepare generators for training
    TRAIN_GEN, VAL_GEN, VAL_STEPS = train.prepare_generators(TRAIN, VAL)

    CLASS_WEIGHTS = cws.calculate_class_weights(tracks=TRACKS)

    strategy = tf.distribute.MirroredStrategy()

    # Get the EfficientNetB3 model
    with strategy.scope():
        MODEL = k_net.build_efficientNetB3_model(conf.CLASSES, TRANSFER)
    if not TRANSFER.value:  # Use no transfer learning
        # Train and return trained net
        model_name, model = train.train_model(MODEL, TRAIN_GEN, VAL_GEN, VAL_STEPS, strategy, CLASS_WEIGHTS)
        del model
    else:
        #  1. Train with transfer learning: only the top layers, for EPOCHS_TRANSFER epochs. (t_step=1)
        #  2. Adapt the model with a function such as "unfreeze()", which should also include the
        #  learning-rate adjustment.
        #  3. Fine-tune the adapted model for EPOCHS epochs with new checkpoints. (t_step=2)
        model_name, model = train.train_model(MODEL, TRAIN_GEN, VAL_GEN, VAL_STEPS, strategy, CLASS_WEIGHTS, t_step=1)
        with strategy.scope():
            k_net.unfreeze_model(model)
        model_name, model = train.train_model(model, TRAIN_GEN, VAL_GEN, VAL_STEPS, strategy, CLASS_WEIGHTS, t_step=2)

train.py

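import os
from datetime import datetime

from tensorflow.keras import callbacks

# Project-specific imports (conf, kio) are not shown in this excerpt.

def train_model(MODEL, training_generator, validation_generator, validation_steps, strategy, class_weights, t_step=-1):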
    logdir = os.path.join(conf.TENSORBOARD_LOGS_PATH,
                          conf.RUN_NAME + '_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

    with strategy.scope():
        if t_step == 2:  # compare integers by value, not identity
            checkdir = os.path.join(conf.CHECKPOINTS_PATH, conf.RUN_NAME + "_tuned"
                                    + "_{epoch:02d}_{val_loss:.2f}.hdf5")
        else:
            checkdir = os.path.join(conf.CHECKPOINTS_PATH, conf.RUN_NAME + "_{epoch:02d}_{val_loss:.2f}.hdf5")
        checkpoint = callbacks.ModelCheckpoint(checkdir, monitor='val_loss', verbose=1, save_best_only=True,
                                               save_weights_only=True, mode='auto', save_freq='epoch')

    if t_step == 1:  # compare integers by value, not identity
        epochs = conf.EPOCHS_TRANSFER
    else:
        epochs = conf.EPOCHS

    class_weights = class_weights if conf.USE_CLASS_WEIGHTS else None

    print('START OF TRAINING NOW!')
    hist = MODEL.fit(x=training_generator, epochs=epochs, verbose=1, callbacks=[checkpoint],
                     validation_data=validation_generator, validation_steps=validation_steps,
                     class_weight=class_weights)

    kio.saveModel(MODEL, hist, conf.CLASSES)  # save on existing model (correct behaviour)

    return conf.RUN_NAME, MODEL