因此,我尝试在 SageMaker 中训练 AWS DeepAR 算法,以便预测明天的最高值。问题是我不想使用 S3 存储桶来训练模型。我有一个完整的字典列表,其中包含target
和的 numpy 数组dynamic_feat
。也许是因为目标和动态壮举变量是在 numpy 数组中编码的,但我真的别无选择,只能这样做。如果我需要将它们转换为列表,请告诉我。所以,让我们开始吧。这是代码:
estimator = sagemaker.estimator.Estimator(
image_uri=image_name,
sagemaker_session=sagemaker_session,
role=role,
train_instance_count=1,
train_instance_type="ml.c4.xlarge",
base_job_name="deepar-stock",
output_path=s3_output_path,
)
hyperparameters = {
"time_freq": freq,
"epochs": "400",
"early_stopping_patience": "40",
"mini_batch_size": "64",
"learning_rate": "5E-4",
"context_length": str(context_length),
"prediction_length": str(prediction_length),
}
estimator.set_hyperparameters(**hyperparameters)
training_set = []
for u in ts:
training_set.append(str(ts[u]).encode('utf-8'))
estimator.fit(inputs={'train':training_set})
它不工作。这是错误:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<timed eval> in <module>
/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name, experiment_config)
678 self._prepare_for_training(job_name=job_name)
679
--> 680 self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
681 self.jobs.append(self.latest_training_job)
682 if wait:
/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in start_new(cls, estimator, inputs, experiment_config)
1449 all information about the started training job.
1450 """
-> 1451 train_args = cls._get_train_args(estimator, inputs, experiment_config)
1452 estimator.sagemaker_session.train(**train_args)
1453
/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in _get_train_args(cls, estimator, inputs, experiment_config)
1481 )
1482
-> 1483 config = _Job._load_config(inputs, estimator)
1484
1485 current_hyperparameters = estimator.hyperparameters()
/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _load_config(inputs, estimator, expand_role, validate_uri)
65 def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
66 """Placeholder docstring"""
---> 67 input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
68 role = (
69 estimator.sagemaker_session.expand_role(estimator.role)
/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _format_inputs_to_input_config(inputs, validate_uri)
133 elif isinstance(inputs, dict):
134 for k, v in inputs.items():
--> 135 input_dict[k] = _Job._format_string_uri_input(v, validate_uri)
136 elif isinstance(inputs, list):
137 input_dict = _Job._format_record_set_list_input(inputs)
/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _format_string_uri_input(uri_input, validate_uri, content_type, input_mode, compression, target_attribute_name)
198 raise ValueError(
199 "Cannot format input {}. Expecting one of str, TrainingInput, file_input or "
--> 200 "FileSystemInput".format(uri_input)
201 )
202
ValueError: Cannot format input [b"{'start': '2015-05-07', 'cat': [0, 1069, 1082], 'target': array([210.13999939, 272.72000122, 303.79998779, ..., 4.76000023,\n 4.64699984, 4.63999987]), 'dynamic_feat': [array([180.6000061 , 210.27999878, 274.3999939 , ..., 4.30000019,\n 4.36000013, 4.55000019]), array([185.5 , 210.27999878, 282.94000244, ..., 4.36000013,\n 4.44000006, 4.57999992]), array([209.30000305, 269.77999878, 290.5 , ..., 4.44999981,\n 4.57999992, 4.57999992]), array([235893., 46807., 39136., ..., 809000., 152600., 52301.]), array([209.30000305, 269.77999878, 290.5 , ..., 4.44999981,\n 4.57999992, 4.57999992])]}", b"{'start': '2011-08-01', 'cat': [1, 1070, 1083], 'target': array([ 4.0999999 , 4.09000015, 3.99000001, ..., 13.5 ,\n 13.63000011, 13.69999981]), 'dynamic_feat': [array([ 3.95000005, 3.95000005, 3.94000006, ..., 13.25 ,\n 13.14000034, 13.35999966]), array([ 4.0999999 , 4. , 3.99000001, ..., 13.35999966,\n 13.14000034, 13.59000015]), array([ 4. , 3.97000003, 3.98000002, ..., 13.38000011,\n 13.47999954, 13.40999985]), array([ 16400., 9300., 4100., ..., 30100., 107900., 113438.]), array([ 3.25682783, 3.23240018, 3.24054289, ..., 13.38000011,\n 13.47999954, 13.40999985])]}", b"{'start': '2011-08-01', 'cat': [2, 1084, 1084], 'target': array([ 60.31999969, 59.20000076, 58.02999878, ..., 203.28999329,\n 203.28999329, 204.18739319]), 'dynamic_feat': [array([ 58.84999847, 57.79000092, 56.95000076, ..., 200.8999939 ,\n 202.16000366, 203.54499817]), array([ 60.29000092, 59.13999939, 57.5 , ..., 203.28999329,\n 203.02000427, 203.55000305]), array([ 59.20999908, 57.79000092, 57.97999954, ..., 202.32000732,\n 202.49000549, 203.61000061]), array([ 8400., 9700., 10800., ..., 23200., 36500., 5876.]), array([ 49.99819946, 48.79912567, 48.95956421, ..., 202.32000732,\n 202.49000549, 203.61000061])]}", b"{'start': '2017-07-13', 'cat': [3, 1084, 1084]