0

因此,我尝试在 SageMaker 中训练 AWS DeepAR 算法,以便预测明天的最高值。问题是我不想使用 S3 存储桶来训练模型。我有一个完整的字典列表,其中包含target和的 numpy 数组dynamic_feat。也许是因为目标和动态壮举变量是在 numpy 数组中编码的,但我真的别无选择,只能这样做。如果我需要将它们转换为列表,请告诉我。所以,让我们开始吧。这是代码:

estimator = sagemaker.estimator.Estimator(
    image_uri=image_name,
    sagemaker_session=sagemaker_session,
    role=role,
    train_instance_count=1,
    train_instance_type="ml.c4.xlarge",
    base_job_name="deepar-stock",
    output_path=s3_output_path,
)

hyperparameters = {
    "time_freq": freq,
    "epochs": "400",
    "early_stopping_patience": "40",
    "mini_batch_size": "64",
    "learning_rate": "5E-4",
    "context_length": str(context_length),
    "prediction_length": str(prediction_length),
}

estimator.set_hyperparameters(**hyperparameters)

training_set = []
for u in ts:
    training_set.append(str(ts[u]).encode('utf-8'))

estimator.fit(inputs={'train':training_set})

它不工作。这是错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<timed eval> in <module>

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in fit(self, inputs, wait, logs, job_name, experiment_config)
    678         self._prepare_for_training(job_name=job_name)
    679 
--> 680         self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
    681         self.jobs.append(self.latest_training_job)
    682         if wait:

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in start_new(cls, estimator, inputs, experiment_config)
   1449             all information about the started training job.
   1450         """
-> 1451         train_args = cls._get_train_args(estimator, inputs, experiment_config)
   1452         estimator.sagemaker_session.train(**train_args)
   1453 

/opt/conda/lib/python3.7/site-packages/sagemaker/estimator.py in _get_train_args(cls, estimator, inputs, experiment_config)
   1481                 )
   1482 
-> 1483         config = _Job._load_config(inputs, estimator)
   1484 
   1485         current_hyperparameters = estimator.hyperparameters()

/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _load_config(inputs, estimator, expand_role, validate_uri)
     65     def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
     66         """Placeholder docstring"""
---> 67         input_config = _Job._format_inputs_to_input_config(inputs, validate_uri)
     68         role = (
     69             estimator.sagemaker_session.expand_role(estimator.role)

/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _format_inputs_to_input_config(inputs, validate_uri)
    133         elif isinstance(inputs, dict):
    134             for k, v in inputs.items():
--> 135                 input_dict[k] = _Job._format_string_uri_input(v, validate_uri)
    136         elif isinstance(inputs, list):
    137             input_dict = _Job._format_record_set_list_input(inputs)

/opt/conda/lib/python3.7/site-packages/sagemaker/job.py in _format_string_uri_input(uri_input, validate_uri, content_type, input_mode, compression, target_attribute_name)
    198         raise ValueError(
    199             "Cannot format input {}. Expecting one of str, TrainingInput, file_input or "
--> 200             "FileSystemInput".format(uri_input)
    201         )
    202 

ValueError: Cannot format input [b"{'start': '2015-05-07', 'cat': [0, 1069, 1082], 'target': array([210.13999939, 272.72000122, 303.79998779, ...,   4.76000023,\n         4.64699984,   4.63999987]), 'dynamic_feat': [array([180.6000061 , 210.27999878, 274.3999939 , ...,   4.30000019,\n         4.36000013,   4.55000019]), array([185.5       , 210.27999878, 282.94000244, ...,   4.36000013,\n         4.44000006,   4.57999992]), array([209.30000305, 269.77999878, 290.5       , ...,   4.44999981,\n         4.57999992,   4.57999992]), array([235893.,  46807.,  39136., ..., 809000., 152600.,  52301.]), array([209.30000305, 269.77999878, 290.5       , ...,   4.44999981,\n         4.57999992,   4.57999992])]}", b"{'start': '2011-08-01', 'cat': [1, 1070, 1083], 'target': array([ 4.0999999 ,  4.09000015,  3.99000001, ..., 13.5       ,\n       13.63000011, 13.69999981]), 'dynamic_feat': [array([ 3.95000005,  3.95000005,  3.94000006, ..., 13.25      ,\n       13.14000034, 13.35999966]), array([ 4.0999999 ,  4.        ,  3.99000001, ..., 13.35999966,\n       13.14000034, 13.59000015]), array([ 4.        ,  3.97000003,  3.98000002, ..., 13.38000011,\n       13.47999954, 13.40999985]), array([ 16400.,   9300.,   4100., ...,  30100., 107900., 113438.]), array([ 3.25682783,  3.23240018,  3.24054289, ..., 13.38000011,\n       13.47999954, 13.40999985])]}", b"{'start': '2011-08-01', 'cat': [2, 1084, 1084], 'target': array([ 60.31999969,  59.20000076,  58.02999878, ..., 203.28999329,\n       203.28999329, 204.18739319]), 'dynamic_feat': [array([ 58.84999847,  57.79000092,  56.95000076, ..., 200.8999939 ,\n       202.16000366, 203.54499817]), array([ 60.29000092,  59.13999939,  57.5       , ..., 203.28999329,\n       203.02000427, 203.55000305]), array([ 59.20999908,  57.79000092,  57.97999954, ..., 202.32000732,\n       202.49000549, 203.61000061]), array([ 8400.,  9700., 10800., ..., 23200., 36500.,  5876.]), array([ 49.99819946,  48.79912567,  48.95956421, ..., 202.32000732,\n       202.49000549, 203.61000061])]}", b"{'start': '2017-07-13', 'cat': [3, 1084, 1084]
4

0 回答 0