正如标题所示,我正在尝试让 AI 模型预测有关天气的各种特征。我的模型基于Google 的MetNet模型的openclimatefix实现。具体来说,我正在尝试使用MetNet2。然而,在实现模型时出现了一些问题(例如代码中的拼写错误以及需要降低某些参数以确保它可以在本地运行),但经过一些工作,我完成了。然而,在尝试可视化一些预测之后,我得到了非常奇怪的结果,如下所示。
顶部图像是基本事实,底部图像是我的模型的预测。我不知道如何解决这个问题。我已经尝试多次更改模型的各种参数大小和学习率,但似乎没有任何帮助。下面是运行它的代码的重要部分。我还有用于计算损失(使用 MS-SSIM)和加载数据的单独文件
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import xarray as xr
from numpy import float32
from torch.utils.data import DataLoader
from loss import MS_SSIMLoss
plt.rcParams["figure.figsize"] = (20, 12)
BATCH_SIZE = 1
EPOCHS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
from metnet import MetNet2
model = MetNet2(
forecast_steps=24,
upsample_method = "interp",
input_channels=1,
sat_channels=1,
input_size=1024,
num_input_timesteps=12,
upsampler_channels=64,
lstm_channels=64,
encoder_channels=64,
output_channels=1,
center_crop_size=16
)
optimiser = optim.Adam(model.parameters(), lr=.01)
criterion = MS_SSIMLoss(channels=24) # produces less blurry images than nn.MSELoss()
losses = []
for epoch in range(EPOCHS):
print(f"Epoch {epoch + 1}")
running_loss = 0
i = 0
count = 0
for batch_coordinates, batch_features, batch_targets in ch_dataloader:
optimiser.zero_grad()
batch_predictions = model(batch_features.to(device).unsqueeze(dim=2))
batch_loss = criterion(batch_predictions.squeeze(dim=2), batch_targets.to(device))
batch_loss.backward()
optimiser.step()
running_loss += batch_loss.item() * batch_predictions.shape[0]
count += batch_predictions.shape[0]
i += 1
print(f"Completed batch {i} of epoch {epoch + 1} with loss {batch_loss.item()} -- processed {count} image sequences ({12 * count} images)")
losses.append(running_loss / count)
print(f"Loss for epoch {epoch + 1}/{EPOCHS}: {losses[-1]}")
for batch_coordinates, batch_features, batch_targets in ch_dataloader:
print(batch_features.shape)
p=model(batch_features.unsqueeze(dim=2)).squeeze(dim=2).detach().numpy()
fig, (ax1, ax2) = plt.subplots(1, 12, figsize=(20,8))
print(p.shape)
for i, img in enumerate(p[0][:12]):
ax2[i].imshow(img, cmap='viridis')
ax2[i].get_xaxis().set_visible(False)
ax2[i].get_yaxis().set_visible(False)
for i, img in enumerate(batch_targets[0][:12].numpy()):
ax1[i].imshow(img, cmap='viridis')
ax1[i].get_xaxis().set_visible(False)
ax1[i].get_yaxis().set_visible(False)
fig.tight_layout()
fig.subplots_adjust(wspace=0, hspace=0)
print(criterion(torch.from_numpy(p),batch_targets))
break
我应该如何继续?任何帮助,将不胜感激。谢谢!