
I'm new to PyTorch Lightning. I built a very simple model with PL. I checked the model's weights before and after training, and they are exactly the same, even though the loss decreases during training.

def main(args, df_train, df_dev, df_test):
    """ main function"""
    # Wandb connect
    wandb_connect()
    wandb_logger = WandbLogger(project="project name", name="Run name")

    # Tokenization
    [df_train, df_dev, df_test], params, tokenizer_qid, tokenizer_uid, tokenizer_qu_id, tokenizer_rank = apply_tokenization([df_train, df_dev, df_test])

    # Dataloaders
    [train_loader, dev_loader, test_loader] = list(map(lambda x: Dataset_SM(x).get_dataloader(args.batch_size), [df_train, df_dev, df_test]))

    # Model definition
    model = NCM(**params).to(device)
    # Weight before training
    WW = model.emb_qid.weight
    print(torch.mean(model.emb_qid.weight))
    # Train & Eval
    es = EarlyStopping(monitor='dev_loss', patience=4)
    checkpoint_callback = ModelCheckpoint(dirpath=args.result_path)
    trainer = pl.Trainer(max_epochs=args.n_epochs, callbacks=[es, checkpoint_callback], val_check_interval=args.val_check_interval,
                         logger=wandb_logger, gpus=1)
    trainer.fit(model, train_loader, dev_loader)
    trainer.save_checkpoint(args.result_path + "example.ckpt")
    loaded_model = NCM.load_from_checkpoint(checkpoint_path=args.result_path + "example.ckpt", **params)
    print(loaded_model.emb_qid.weight == WW)

Can someone tell me if I'm missing something?


1 Answer


Merely assigning .weight to a new variable does not copy the data; it only passes a reference to the same underlying tensor.

Here is a short example:

import copy
import torch

def is_same(a, b):
    return ((a - b).float().abs().sum() < 1e-6).item()

a = torch.nn.Conv2d(5, 5, 3)
b = a.weight                   # b is a reference to the same tensor
c = copy.deepcopy(a.weight)    # c is an independent copy of the data

# Re-initialize a's weights in place; b follows the change, c does not
torch.nn.init.xavier_uniform_(a.weight.data)

print("a == b: " + str(is_same(a.weight, b)))
print("a == c: " + str(is_same(a.weight, c)))
# output:
# a == b: True
# a == c: False
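
Applied to the code in the question, a minimal sketch of the fix (reusing the model, trainer, and loader objects from the question): snapshot the embedding weights by value with detach().clone() before training, then compare with torch.equal afterwards.

    # Snapshot by value: detach from the graph and clone the storage
    before = model.emb_qid.weight.detach().clone()

    trainer.fit(model, train_loader, dev_loader)

    # torch.equal returns a single bool; False means training updated the weights
    print(torch.equal(model.emb_qid.weight.detach().cpu(), before.cpu()))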