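I'm having trouble correctly loading a .csv file to use as the input to a very simple dense NN model. The CSV file contains all the input features plus one "target" column, which serves as the output for the regression.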
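A minimal pandas sketch of splitting such a CSV into a feature matrix and a target vector, both cast to float32 so they match the default float32 dtype of Gluon parameters. The inline CSV text and the helper name `load_xy` are hypothetical stand-ins for `some_file.csv` and the pre-processing in the script below.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical stand-in for some_file.csv: an index column, two features, and 'target'.
CSV_TEXT = """id,f1,f2,target
0,1.0,2.0,3.5
1,0.5,1.5,2.0
2,3.0,0.0,1.0
3,2.0,2.0,4.0
"""


def load_xy(df, target="target", dtype=np.float32):
    """Split a DataFrame into features X and target y, both cast to `dtype`."""
    y = df[target].to_numpy(dtype=dtype)
    X = df.drop(columns=[target]).to_numpy(dtype=dtype)
    return X, y


df_data = pd.read_csv(io.StringIO(CSV_TEXT), index_col=0)
X, y = load_xy(df_data)
print(X.shape, X.dtype, y.shape, y.dtype)  # (4, 2) float32 (4,) float32
```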
So far, this is what I'm doing:
import numpy as np
import pandas as pd
import mxnet as mx
from mxnet import autograd, gluon, init
from mxnet.gluon import nn
from mxnet.gluon import loss as gloss


def main():
    batch_size = 500
    ## load input file
    df_data = pd.read_csv('some_file.csv', index_col=0)
    ## random train/test split
    df_train = df_data.sample(frac=0.8, random_state=200)
    df_test = df_data.drop(df_train.index)
    ## data pre-processing
    df_train.reset_index(drop=True, inplace=True)
    df_test.reset_index(drop=True, inplace=True)
    y_train = df_train['target'].to_numpy(dtype=np.float64)
    y_test = df_test['target'].to_numpy(dtype=np.float64)
    X_train = df_train.drop(['target'], axis=1).to_numpy(dtype=np.float64)
    X_test = df_test.drop(['target'], axis=1).to_numpy(dtype=np.float64)
    dataset = mx.gluon.data.dataset.ArrayDataset(X_train, y_train)
    data_loader = mx.gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    ## building model
    model = nn.Sequential()
    model.add(nn.Dense(150))
    model.add(nn.Dense(1))
    model.initialize(init.Normal(sigma=0.01))
    ## loss function (squared loss)
    loss = gloss.L2Loss()
    ## optimization algorithm, specify:
    trainer = gluon.Trainer(model.collect_params(), 'sgd', {'learning_rate': 0.03})
    ## training
    num_epochs = 10
    for epoch in range(1, num_epochs + 1):
        for X_batch, Y_batch in data_loader:
            with autograd.record():
                l = loss(model(X_batch), Y_batch)
            l.backward()
            trainer.step(batch_size)
        # overall (entire dataset) loss after epoch
        l = loss(model(X_train), y_train)
        print(f'\nEpoch {epoch}, loss: {l.mean().asnumpy()}')


if __name__ == '__main__':
    main()
I get the following error:
mxnet.base.MXNetError: [16:09:03] src/operator/numpy/linalg/./../../tensor/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node at 1-th input: expected float64, got float32
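The message says one input of an operator is float64 while another is float32. A likely cause is that the arrays built with `to_numpy(dtype=np.float64)` keep their float64 dtype through the `DataLoader`, while Gluon's `nn.Dense` parameters are float32 by default. A minimal pure-NumPy sketch of an early guard that reports such a mismatch in plain terms; `check_dtypes` is a hypothetical helper, not part of MXNet.

```python
import numpy as np


def check_dtypes(expected, **arrays):
    """Raise a readable TypeError if any array's dtype differs from `expected`.

    Hypothetical helper: call it on the training arrays before building the
    Gluon dataset, instead of waiting for the error deep inside an operator.
    """
    expected = np.dtype(expected)
    bad = {name: a.dtype for name, a in arrays.items() if a.dtype != expected}
    if bad:
        details = ", ".join(f"{name}={dt}" for name, dt in bad.items())
        raise TypeError(f"expected {expected}, got {details}")


X_train = np.zeros((4, 3))  # NumPy defaults to float64
y_train = np.zeros(4, dtype=np.float32)
try:
    check_dtypes(np.float32, X_train=X_train, y_train=y_train)
except TypeError as err:
    print(err)  # expected float32, got X_train=float64

# After an explicit cast the check passes silently.
check_dtypes(np.float32, X_train=X_train.astype(np.float32), y_train=y_train)
```

In the script above, passing `dtype=np.float32` to the four `to_numpy` calls would satisfy this check. Any place that calls the model directly on the NumPy arrays, such as the whole-dataset loss, likely still needs NDArray inputs, e.g. `loss(model(mx.nd.array(X_train)), mx.nd.array(y_train))`, because Gluon HybridBlocks such as `nn.Dense` do not accept plain NumPy arrays.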
So I tried converting the data by switching np.float64 to np.float32, but then I get:
File "/home/lews/anaconda3/envs/gluon/lib/python3.7/site-packages/mxnet/gluon/block.py", line 1136, in forward
raise ValueError('In HybridBlock, there must be one NDArray or one Symbol in the input.'
ValueError: In HybridBlock, there must be one NDArray or one Symbol in the input. Please check the type of the args.
What is the correct way to load this data?
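The train/test split in the script samples 80% of the rows and then removes them from the full frame with `drop`, which works by index label. A small pandas sketch on a hypothetical 10-row frame shows that, with a unique index, the two parts are disjoint and together cover every row.

```python
import numpy as np
import pandas as pd

# Hypothetical frame standing in for df_data (unique integer index).
df_data = pd.DataFrame({"f1": np.arange(10.0), "target": np.arange(10.0) * 2})

df_train = df_data.sample(frac=0.8, random_state=200)
df_test = df_data.drop(df_train.index)

# With a unique index, the parts are disjoint and cover all rows.
assert df_train.index.intersection(df_test.index).empty
assert len(df_train) + len(df_test) == len(df_data)
print(len(df_train), len(df_test))  # 8 2
```

If the column read with `index_col=0` contains repeated values, `drop` removes every row sharing a sampled label, so the test set can shrink silently; `df_data.index.is_unique` is a quick check before splitting.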