
I am training a cigarette-pack classifier with the PyTorch framework, using a pretrained EfficientNet-v2-b3 model from "https://github.com/rwightman/pytorch-image-models". The training setup is as follows:

  1. There are 1100 classes, each corresponding to one cigarette specification. All images are stored in a directory named original_dataset_20210805, where each sub-directory holds the images of one class.

  2. Remove every class that has fewer than 50 images; 959 classes remain.

  3. For each class, randomly select 10 images into a validation dataset named "valData" and roughly 1/10 of the images into a test dataset named "testData"; the remaining images go into a training dataset named "trainData". A sketch of this split follows.
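
A minimal sketch of this split, assuming one source sub-directory per class and copying rather than moving (the directory names are the ones above; the helper name and exact ratios are illustrative):

import os
import random
import shutil

def split_one_class(src_dir, dst_root, cls_name, n_val=10, test_ratio=0.1):
    # src_dir: one class sub-directory of original_dataset_20210805
    files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
    random.shuffle(files)
    n_test = int(len(files) * test_ratio)
    parts = {'valData': files[:n_val],
             'testData': files[n_val:n_val + n_test],
             'trainData': files[n_val + n_test:]}
    for part, names in parts.items():
        dst_dir = os.path.join(dst_root, part, cls_name)
        os.makedirs(dst_dir, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))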

  4. Resize every image to w×h = 200×300.

  5. To augment the data, rotate each image by 90° and treat all the rotated images of a class as a new class. For example, for a cigarette specification A, rotate all of A's images by 90° and name the rotated set a new class A-rot1. Rotating by 180° gives A-rot2, and by 270° gives A-rot3. Doing this for every class yields 959×4 = 3836 classes. A sketch of this rotation step follows.
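
A sketch of the rotation step (assuming PIL; transpose gives exact, lossless 90° rotations, and the helper name is mine):

import os
from PIL import Image

def make_rotated_classes(data_root, cls_name):
    # creates cls_name-rot1 (90°), cls_name-rot2 (180°), cls_name-rot3 (270°) beside cls_name
    src_dir = os.path.join(data_root, cls_name)
    for k, op in enumerate((Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270), start=1):
        dst_dir = os.path.join(data_root, '{}-rot{}'.format(cls_name, k))
        os.makedirs(dst_dir, exist_ok=True)
        for name in os.listdir(src_dir):
            src = os.path.join(src_dir, name)
            if os.path.isfile(src):
                Image.open(src).transpose(op).save(os.path.join(dst_dir, name))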

  6. "trainData" contains 502,172 images, "valData" 38,360 images, and "testData" 21,463 images.

  7. Start training from the pretrained model. The best model is saved like this:

if train_acc > last_train_acc and val_acc > last_val_acc:
  save_best_model()

Training exits once train_acc >= 0.99 and val_acc >= 0.99.

  8. At epoch 121, training exits with train_acc 0.9911 and val_acc 0.9902.

  9. Inferring on testData with the best model gives an accuracy of 0.981. Using the best model on trainData, I expected the accuracy to be above 0.99, but it is actually 0.84; on valData it is 0.82. This is strange. I then ran the best model on another dataset, original_dataset_20210709, which differs somewhat from the original_dataset_20210805 above and whose images were not resized to w×h = 200×300. The accuracy there is 0.969.

  10. The inference code is as follows:

def infer(cfg:Config):  
    transform_test = build_transforms(cfg.img_height, cfg.img_width, 'test')
    model = get_model(cfg, 'test')    
    model = model.to(cfg.get_device())
    model.eval() 

    records = []

    sub_classes = os.listdir(cfg.test_data_dirname)
    if sub_classes is None or len(sub_classes) < 1: 
        return 
    sub_classes = sorted(sub_classes)
    classid_dict = {}
    with open(cfg.classid_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

        for line in lines:
            line = line.strip()
            tokens = line.split(',')
            classid_dict[int(tokens[0])] = tokens[1]
     

    records.append(cfg.test_data_dirname + ',' + str(len(sub_classes)) + ' classes\n')
    records.append('image, prediction result\n')
    start_time = datetime.now()
    elapsed = 0.0    
    count = 0

    with torch.no_grad():
        for sub_cls in sub_classes:
            print('process sub-directory ' + sub_cls)

            files = os.listdir(os.path.join(cfg.test_data_dirname, sub_cls))
            count += len(files)
            if files is None or len(files) < 1:
                print('The sub-directory ' + sub_cls + " has no files")
                continue

            for file in files:
                try:
                    img_path = os.path.join(cfg.test_data_dirname, sub_cls, file)
                    if os.path.isfile(img_path): 
                        img_test = Image.open(img_path)                      
                        img = img_test
                        img = transform_test(img).to(cfg.get_device()) 
                        img = torch.unsqueeze(img, 0) 
                        output = model(img)
                        _, preds = torch.max(output.data, 1)
                        id = preds[0].item() 

                        if classid_dict.get(id) is not None:
                            #print(img_path + ' is predicted as:' + classid_dict[id])
                            records.append(sub_cls + '/' + file + ',' + classid_dict[id] + '\n')
                            log_func(sub_cls + '/' + file + '  is predicted as:' + classid_dict[id]) 
                            pass
                        else: 
                            records.append(sub_cls + '/' + file + ', unknown class\n') 
                except Exception as e:
                    print(str(e))


    elapsed = (datetime.now() - start_time).total_seconds() 
    records.append('elapsed {:.4f} sec, average elapsed {:.4f} sec\n'.format(elapsed, elapsed/count))
    
    result_path = os.path.join(cfg.results_dir, 'infer_' + cfg.backbone + '_' + str(cfg.num_classes) + '_'  +  format_datetime(datetime.now()) + '.csv')
    with open(result_path, 'w', encoding='utf-8') as f:
        f.writelines(records)
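
The accuracies I quote come from comparing each prediction with the name of the sub-directory the image was read from (assuming the sub-directory names match the class names in cfg.classid_file). Roughly, over the records list built above, before the final 'elapsed' line is appended:

# records[0] and records[1] are header lines; data lines look like 'A/img001.jpg,A\n'
correct = 0
total = 0
for line in records[2:]:
    img_rel, predicted = line.strip().split(',', 1)
    total += 1
    if predicted.strip() == img_rel.split('/')[0]:
        correct += 1
print('accuracy: {:.4f}'.format(correct / max(total, 1)))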
  11. I went through the Python code and suspect the cause lies in the transformation applied to each image before it is fed into the model. The transform code is as follows:
def build_transforms(img_height, img_width, run_mode="train", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    if run_mode == 'train':
        transform = T.Compose([
            # Use OpenCV to open the image
            T.Lambda(lambda img: random_rotate_bound(img, 30)),
            T.Lambda(lambda img: random_translate(img, 20)),
            T.Lambda(lambda img: random_zoom(img)),

            T.Lambda(lambda img: sameScaleZoom(img, img_height, img_width)),
            T.RandomChoice([T.Lambda(lambda img: random_AffineTransform(img)),
                            T.Lambda(lambda img: random_warpPerspective(img))]),
            T.RandomChoice([T.Lambda(lambda img: random_degarde_img(img)),
                            T.Lambda(lambda img: random_mosaic(img)),
                            T.Lambda(lambda img: random_motion_blur(img)),
                            T.Lambda(lambda img: random_focus_blur(img))]),
            # Convert the OpenCV-format image into PIL before continuing
            T.ToPILImage('RGB'),
            T.RandomOrder([T.ColorJitter(brightness=0.5),
                           T.ColorJitter(contrast=(0.2, 1.8)),
                           T.ColorJitter(saturation=(0.2, 1.8)),
                           T.ColorJitter(hue=0.08)]),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])
    else:
        transform = T.Compose([
            #T.Lambda(lambda img: sameScaleZoom(img, img_height, img_width)),
            # In this case, open the image with PIL rather than OpenCV
            T.Resize(size=(img_height, img_width)),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])

    return transform
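
Note the asymmetry between the two branches: the train branch consumes an OpenCV image and fits it to img_height×img_width with sameScaleZoom, while the test branch consumes a PIL image and stretches it with T.Resize. To separate geometry from the random augmentations, one more branch could keep the training geometry but drop all randomness (a sketch; it assumes sameScaleZoom is an aspect-preserving resize, and its input must be an OpenCV image like in the train branch):

    elif run_mode == 'val':
        transform = T.Compose([
            # training geometry (sameScaleZoom), but deterministic: no rotation/blur/jitter
            T.Lambda(lambda img: sameScaleZoom(img, img_height, img_width)),
            T.ToPILImage('RGB'),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])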

To verify my guess, I ran inference on "valData" (not on "trainData", which would take too long) after changing the transform from transform_test = build_transforms(cfg.img_height, cfg.img_width, 'test') to transform_test = build_transforms(cfg.img_height, cfg.img_width, 'train'). The accuracy is 0.9918, as expected.
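
In code, that verification amounts to predicting single images the same way TheDataset feeds the model during training, i.e. with OpenCV-style loading (a sketch; predict_one is an illustrative name, not a function in my project):

import cv2
import torch

def predict_one(model, transform, img_path, device):
    # mirror TheDataset.__getitem__: OpenCV read, BGR -> RGB
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(x).argmax(dim=1).item()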

My questions are:

  • For reference, the trained model reaches an accuracy of 0.989 on testData, but only about 0.84 on trainData and about 0.82 on valData.
  • What am I doing wrong in the transforms?
  • Or is there some other cause of this strange behavior?

Thanks to everyone who is willing to answer.

Appendix 1: 12) The validation code is as follows:

def val(cfg:Config, model, criterion, transform=None):
    start_time = datetime.now()
    val_loss = 0
    total = 0
    val_correct = 0
    model.eval()
    if transform is None:
        transform = build_transforms(cfg.img_height, cfg.img_width)
    dset_loader, dset_size = load_data(cfg, transform, run_mode='val', shuffle=False)

    for data in dset_loader:
        inputs, labels = data            

        if cfg.is_use_cuda:
            #inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            inputs = inputs.cuda()
            labels = torch.stack([anno.cuda() for anno in labels])
        else:
            #inputs, labels = Variable(inputs), Variable(labels)
            pass

        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs.data, 1)      
            val_loss += loss.data.item()*inputs.size(0)
            val_correct += torch.sum(preds == labels.data)
                
    val_loss /= dset_size
    val_acc = val_correct.item()*1.0/dset_size
    elapsed = (datetime.now() - start_time).total_seconds()
    log_func('exit val, {} samples, elapsed {:.4f} sec, average elapsed {:.4f} sec'.format(dset_size, elapsed, elapsed/dset_size))

    return val_loss, val_acc
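
One detail I notice while posting: when transform is None, the fallback build_transforms(cfg.img_height, cfg.img_width) uses the default run_mode="train", so validation then runs with the augmenting transform. If the intent is to validate the way inference runs, the fallback could be made explicit (only this line changes):

    if transform is None:
        transform = build_transforms(cfg.img_height, cfg.img_width, 'test')  # explicit, instead of the implicit 'train' default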
  13. The load_data code is:
def load_data(cfg:Config, transform, run_mode='train', shuffle=True): 

    if run_mode == 'train':
        dataset = TheDataset(cfg, transform, run_mode)
        data_loader = DataLoader(dataset, batch_size=cfg.train_batch_size, shuffle=shuffle, num_workers=cfg.num_workers)
        return data_loader, len(dataset)

    else:
        dataset = TheDataset(cfg, transform, run_mode)
        data_loader = DataLoader(dataset, batch_size=cfg.val_batch_size, shuffle=shuffle, num_workers=cfg.num_workers)
        return data_loader, len(dataset)
  14. The class TheDataset is defined as follows:
class TheDataset(Dataset): 

    def __init__(self, cfg:Config, transforms, run_mode='train') -> None: 
        super().__init__()
        self.img_mode = cfg.img_mode 
        self.transforms = transforms
        self.config = cfg
        self.run_mode = run_mode 
        assert cfg is not None, "The config object cannot be none"
        assert cfg.train_data_dirname is not None, "The train data cannot be none"
        assert transforms is not None, 'The transforms cannot be none'
        
        self.label_list = list()
        self.path_list = list() 
        self.label_2_path_index_list = {}  # key: a label; value: list of indices into path_list for the image paths of that label

        if run_mode == 'train':
            self.dirname = cfg.train_data_dirname
            self.file_path = cfg.train_data_file_list
        elif run_mode == 'val':
            self.dirname = cfg.val_data_dirname
            self.file_path = cfg.val_data_file_list
        elif run_mode == 'test':
            self.dirname = cfg.test_data_dirname
            self.file_path = cfg.test_data_file_list
        else:
            self.dirname = cfg.train_data_dirname
            self.file_path = cfg.train_data_file_list

        index = 0
        with open(self.file_path, 'r') as f:
            for line in f:
                if line is not None and len(line) > 5:
                    a_path, a_label = line.strip().split(',')
                    if a_path is not None and a_label is not None:
                        a_label = int(a_label)
                        self.path_list.append(os.path.join(self.dirname, a_path.strip()))
                        self.label_list.append(a_label)
                        if self.label_2_path_index_list.get(a_label) is None:
                            self.label_2_path_index_list[a_label] = []
                        self.label_2_path_index_list[a_label].append(index)
                        index += 1
        
    def __getitem__(self, index):
        img_path = self.path_list[index]
        img_label = self.label_list[index]

        img = cv2.imread(img_path)  # OpenCV load: numpy BGR array (None if the file cannot be read)
        if self.img_mode == 'RGB':
            try:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            except:
                msg = 'cannot convert to RGB:' + img_path
                log_func(msg)
        
        img = self.transforms(img)

        return img, img_label


    def __len__(self):
        return len(self.label_list)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "TheDataset info: datasize={}, num_labels={}".format(len(self.path_list), len(self.label_2_path_index_list))

Appendix 2: 15) The whole train.py is:

from pathlib import WindowsPath
import sys
import json
import os
import cv2
import torch 
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime
import pandas as pd  
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast
from torchvision import transforms, datasets
from efficientnet_pytorch import EfficientNet
import torch.nn.functional as F

from part01_data import load_data
from part03_transform import build_transforms 
from part02_model import get_model, exp_lr_scheduler
from utils import print, set_logpath, format_datetime, write_one_log_record
from config import Config, ConfigEncoder

log_path = ''


def val(cfg:Config, model, criterion, transform=None):
    start_time = datetime.now()
    val_loss = 0
    total = 0
    val_correct = 0
    model.eval()
    if transform is None:
        transform = build_transforms(cfg.img_height, cfg.img_width)  # run_mode defaults to 'train'
    dset_loader, dset_size = load_data(cfg, transform, run_mode='val', shuffle=False)

    for data in dset_loader:
        inputs, labels = data            

        if cfg.is_use_cuda:
            #inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            inputs = inputs.cuda()
            labels = torch.stack([anno.cuda() for anno in labels])
        else:
            #inputs, labels = Variable(inputs), Variable(labels)
            pass

        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs.data, 1)      
            val_loss += loss.data.item()*inputs.size(0)
            val_correct += torch.sum(preds == labels.data)
                
    val_loss /= dset_size
    val_acc = val_correct.item()*1.0/dset_size
    elapsed = (datetime.now() - start_time).total_seconds()
    print('val exit, {} samples, elapsed {:.4f} sec, average elapsed {:.4f} sec'.format(dset_size, elapsed, elapsed/dset_size))

    return val_loss, val_acc


def train(cfg:Config, shuffle=True):
    train_log_path = os.path.join(cfg.results_dir, cfg.backbone + '_' + str(cfg.num_classes) + 'classes_' + format_datetime(datetime.now()) + '.csv')
    print('Begin to train,the data directory:' + cfg.train_data_dirname)
    if cfg.is_use_apex:
        scaler = GradScaler()
    
    # step 1:Preparation
    best_acc = 0.0
    best_val_acc = 0.0
    start_epoch = -1

    criterion = nn.CrossEntropyLoss()    
    model_ft, optimizer_args, start_epoch, best_acc, best_val_acc = get_model(cfg, 'train')       
    if cfg.is_use_cuda:
        model_ft = model_ft.cuda()
        criterion = criterion.cuda()                        
 
    optimizer = optim.SGD(model_ft.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0004) 
    if optimizer_args is not None:
        optimizer.load_state_dict(optimizer_args)

    since = datetime.now()
    best_model_wts = model_ft.state_dict()
    
    transform = build_transforms(cfg.img_height, cfg.img_width)  # run_mode defaults to 'train'
    print('the transforms are as follows:')
    print(str(transform))
    print('preparation is finished')
    write_one_log_record('epoch, train loss, train accuracy, validation loss, validation accuracy, elapsed/minute\n', train_log_path, 'w')

    start_epoch_dt = datetime.now()    

    for epoch in range(start_epoch+1,cfg.num_epochs):
        # step 2:load data and adjust optimizer
        model_ft.train(True)
        dset_loader, dset_size = load_data(cfg, transform, run_mode='train', shuffle=shuffle)
        
        print('Epoch: {}/{},totally {} images'.format(epoch+1, cfg.num_epochs, dset_size))
        
        optimizer = exp_lr_scheduler(optimizer, epoch)

        running_loss = 0.0
        running_corrects = 0
        count = 0
        batch_count = len(dset_loader)        
        start_batches_dt = datetime.now()

        # step 3:begin batch train
        for data in dset_loader:      
            # step 3.1:detach sample and label and move them to the device 
            inputs, labels = data            

            if cfg.is_use_cuda:
                #inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
                inputs = inputs.cuda()
                labels = torch.stack([anno.cuda() for anno in labels])
            else:
                #inputs, labels = Variable(inputs), Variable(labels)
                pass
            
            # step 3.2:compute and forward
            optimizer.zero_grad()
            if cfg.is_use_apex:
                with autocast():
                    outputs = model_ft(inputs)
                    loss = criterion(outputs, labels)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model_ft(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()    

            # step 3.3:detach label and compute loss and correct count
            _, preds = torch.max(outputs.data, 1)    
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)        

            # step 3.4:print batch info
            count += 1
            start_batches_dt = output_batch_info(cfg, epoch, count, batch_count, loss.item(), outputs.size()[0], start_batches_dt)            

        # step 4:exit this epoch and compute the loss
        train_loss = running_loss / dset_size
        train_acc = running_corrects.double() / dset_size
        val_loss, val_acc = val(cfg, model_ft, criterion, transform)  # note: validated with the training transform built above
        
        # step 5:judge the best model and save it
        best_model_wts, best_acc, best_val_acc = save_best_model(cfg, model_ft, best_model_wts, train_acc, best_acc, val_acc, best_val_acc)
            
        # step 6:save the last checkpoint
        save_newest_checkpoint(cfg, model_ft, optimizer, epoch, best_acc, best_val_acc)

        # step 7:save the middle checkpoint
        save_checkpoint_per_epochs(cfg, model_ft, optimizer, epoch, best_acc, best_val_acc)
        
        # step 8:compute the loss, accuracy and elapsed time in this epoch
        start_epoch_dt = summarize_epoch_info(start_epoch_dt, epoch, train_loss, train_acc, val_loss, val_acc, train_log_path)

        # step 9:judge it is proper to exit the train process
        if have_meet_acc_requirement_or_not(cfg, epoch, train_loss, train_acc, val_loss, val_acc):
            break         

    time_elapsed = (datetime.now() - since).total_seconds()
    print('train complete,elapsed {}hours {:.4f} minutes'.format(time_elapsed//3600, (time_elapsed - (time_elapsed//3600)*3600)/60))

    return best_model_wts

def output_batch_info(cfg:Config, epoch, count, batch_count, loss_per_sample, size_of_this_batch, start_batches_dt): 
    flag = ''
    elapsed = (datetime.now() - start_batches_dt).total_seconds() 
    if count % cfg.print_per_batch == 0:   
        flag = str(cfg.print_per_batch)
        more_time = (batch_count - count) * elapsed/cfg.print_per_batch
    if size_of_this_batch < cfg.train_batch_size:  # the last, partial batch
        flag = 'last'
        more_time = (batch_count - count) * elapsed
    if len(flag) > 0:
        print(' Epoch: {}, batch: {}/{}, average train loss of each sample: {:.4f}, batch {} elapsed: {:.4f} sec, this epoch needs {:.4f} more sec'.format(epoch+1, count, batch_count, loss_per_sample, flag, elapsed, more_time))
        return datetime.now()
    return start_batches_dt

def have_meet_acc_requirement_or_not(cfg: Config, epoch, train_loss, train_acc, val_loss, val_acc):  
    if train_acc < cfg.acc_valve or (cfg.is_check_best_with_val_loss and val_acc < cfg.acc_valve):
        return False 
    return True

def summarize_epoch_info(start_epoch_dt, epoch, train_loss, train_acc, val_loss, val_acc, output_path): 
    elapsed = (datetime.now() - start_epoch_dt).total_seconds()/60    
    remained_minutes = (cfg.num_epochs - epoch - 1)*elapsed  # cfg here is the module-level Config created in __main__
    remained_hours = remained_minutes//60
    remained_minutes = remained_minutes - remained_hours*60 
    record = '{},{:.4f},{:.4f},{:.4f},{:.4f},{:.4f}\n'.format(epoch+1, train_loss, train_acc, val_loss, val_acc, elapsed)    
    write_one_log_record(record, output_path, 'a')

    return datetime.now()

def save_one_checkpoint(model, optimizer, epoch, best_acc, best_val_acc, output_path): 
    checkpoint = {
        'net': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'best_val_acc': best_val_acc
    }
    torch.save(checkpoint, output_path)

def save_checkpoint_per_epochs(cfg:Config, model, optimizer, epoch, best_acc, best_val_acc):     
    if cfg.save_per_epoch > 0 and (epoch+1)%cfg.save_per_epoch == 0: 
        checkpoint_path = cfg.resume_ckpt_dir + "/" + cfg.backbone + f'_checkpoint_{epoch+1}_' + str(cfg.num_classes) + 'classes.pth'
        save_one_checkpoint(model, optimizer, epoch, best_acc, best_val_acc, checkpoint_path)

def save_newest_checkpoint(cfg:Config, model, optimizer, epoch, best_acc, best_val_acc):  
    checkpoint_path = cfg.resume_ckpt_dir + "/" + cfg.backbone + '_checkpoint_last_' + str(cfg.num_classes) + 'classes.pth'
    save_one_checkpoint(model, optimizer, epoch, best_acc, best_val_acc, checkpoint_path)

def save_best_model(cfg:Config, model, best_model_weights, train_acc, best_acc, val_acc, best_val_acc): 
    if train_acc <= best_acc or (cfg.is_check_best_with_val_loss and val_acc <= best_val_acc):
        return best_model_weights, best_acc, best_val_acc
 
    
    best_model_weights = model.state_dict()      
    model_out_path = cfg.models_dir + "/" + cfg.backbone  + '_best_' + str(cfg.num_classes) + 'classes.pth'
    torch.save(best_model_weights, model_out_path)    
    
    best_acc = train_acc
    best_val_acc = val_acc if val_acc > best_val_acc else best_val_acc

    return best_model_weights, train_acc, best_val_acc

def infer(cfg:Config):  
    transform_test = build_transforms(cfg.img_height, cfg.img_width, 'test')
    #transform_test = build_transforms(cfg.img_height, cfg.img_width, 'train')
    model = get_model(cfg, 'test')
    model = model.to(cfg.get_device())
    model.eval() 

    records = []

    sub_classes = os.listdir(cfg.test_data_dirname)
    if sub_classes is None or len(sub_classes) < 1: 
        return 
    sub_classes = sorted(sub_classes)
    classid_dict = {}
    with open(cfg.classid_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

        for line in lines:
            line = line.strip()
            tokens = line.split(',')
            classid_dict[int(tokens[0])] = tokens[1]
     

    records.append(cfg.test_data_dirname + ',' + str(len(sub_classes)) + ' classes\n')
    records.append('image, prediction\n')
    start_time = datetime.now()
    elapsed = 0.0    
    count = 0
    with torch.no_grad():
        for sub_cls in sub_classes: 
            files = os.listdir(os.path.join(cfg.test_data_dirname, sub_cls))
            count += len(files)
            if files is None or len(files) < 1: 
                continue
            for file in files:
                try:
                    img_path = os.path.join(cfg.test_data_dirname, sub_cls, file)
                    if os.path.isfile(img_path):
                        # When the transform is built in 'test' mode, i.e.
                        # transform = build_transforms(cfg.img_height, cfg.img_width, 'test'),
                        # the image is opened with PIL: img = Image.open(img_path)
                        img_test = Image.open(img_path)  

                        img = img_test
                        img = transform_test(img).to(cfg.get_device()) 
                        img = torch.unsqueeze(img, 0) 
                        output = model(img)
                        _, preds = torch.max(output.data, 1)
                        id = preds[0].item() 

                        if classid_dict.get(id) is not None: 
                            records.append(sub_cls + '/' + file + ',' + classid_dict[id] + '\n')
                            print(sub_cls + '/' + file + '  is predicted as:' + classid_dict[id]) 
                            pass
                        else: 
                            records.append(sub_cls + '/' + file + ', unknown\n') 
                except Exception as e:
                    print(str(e))


    elapsed = (datetime.now() - start_time).total_seconds() 

    records.append('elapsed {:.4f} sec, average elapsed {:.4f} sec\n'.format(elapsed, elapsed/count))
    result_path = os.path.join(cfg.results_dir, 'infer_' + cfg.backbone + '_' + str(cfg.num_classes) + '_'  +  format_datetime(datetime.now()) + '.csv')
    with open(result_path, 'w', encoding='utf-8') as f:
        f.writelines(records)

def use_one_model(cfg:Config, model_name):  
    cfg.backbone = model_name
    log_path = os.path.join(cfg.log_dir, cfg.backbone + '_' + str(cfg.num_classes) +  'classes_' +  format_datetime(datetime.now()) + '.log')
    set_logpath(log_path)

    start_time = datetime.now()
    torch.cuda.empty_cache()    
    
    print('start, the args are:=====') 
    args = json.dumps(cfg, ensure_ascii=False, cls=ConfigEncoder, indent=2)
    print(args)
    try:
        #train(cfg)
        infer(cfg)
    except Exception as e:
        print(str(e))
    elapsed = (datetime.now() - start_time).total_seconds()
    hours = elapsed//3600
    minutes = (elapsed - hours*3600)/60 

def use_many_models(cfg:Config): 
    #backbones = ['efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'adv-efficientnet-b0', 'adv-efficientnet-b1', 'adv-efficientnet-b2', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns','tf_efficientnet_b2_ns', 'efficientnet-b3', 'adv-efficientnet-b3', 'tf_efficientnet_b3_ns']
    backbones = ['tf_efficientnetv2_b0', 'tf_efficientnetv2_b1', 'tf_efficientnetv2_b2', 'tf_efficientnetv2_b3', 'tf_efficientnetv2_s']

    for backbone in backbones:
        use_one_model(cfg, backbone)
if __name__ == '__main__':
    cfg = Config()       
    use_one_model(cfg, cfg.backbone) 
    

  