0

我正在使用 gpt2、python 3.9 和 tensorflow 2.5,当连接到烧瓶(终端中运行烧瓶)时,我收到以下消息:

TypeError:无法根据规则“安全”将数组数据从 dtype('O') 转换为 dtype('int64')

这是 generator.py 中的代码

#!/usr/bin/env python3

import fire
import json
import os
import numpy as np
import tensorflow.compat.v1 as tf

# import model, sample, encoder
  from text_generator import model
  from text_generator import sample
  from text_generator import encoder


class AI:
 def generate_text(self, input_text):
    model_name = '117M_Trained'
    seed = None,
    nsamples = 1
    batch_size = 1
    length = 150
    temperature = 1
    top_k = 40
    top_p = 1
    models_dir = 'models'
    self.response = ''

    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    cur_path = os.path.dirname(__file__) + '/models' + '/' + model_name
    with open(cur_path + '/hparams.json') as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(cur_path)
        saver.restore(sess, ckpt)

        context_tokens = enc.encode(input_text)
        generated = 0
        for _ in range(nsamples // batch_size):
            out = sess.run(output, feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:, len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                self.response = text

    return self.response


  ai = AI()
  text = ai.generate_text('How are you?')
  print(text)

感谢任何帮助 ps 我还在整个回溯下方添加了

     * Serving Flask app 'text_generator' (lazy loading)
 * Environment: development
 * Debug mode: on
2021-09-14 19:58:08.687907: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Traceback (most recent call last):
  File "_mt19937.pyx", line 178, in numpy.random._mt19937.MT19937._legacy_seeding
TypeError: 'tuple' object cannot be interpreted as an integer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/dusandev/miniconda3/bin/flask", line 8, in <module>
    sys.exit(main())
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 990, in main
    cli.main(args=sys.argv[1:])
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 596, in main
    return super().main(*args, **kwargs)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1062, in main
    rv = self.invoke(ctx)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1668, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 763, in invoke
    return __callback(*args, **kwargs)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/decorators.py", line 84, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/click/core.py", line 763, in invoke
    return __callback(*args, **kwargs)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 845, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 321, in __init__
    self._load_unlocked()
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 346, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 402, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/dusandev/miniconda3/lib/python3.9/site-packages/flask/cli.py", line 256, in locate_app
    __import__(module_name)
  File "/Users/dusandev/Desktop/AI/text_generator/__init__.py", line 2, in <module>
    from .routes import generator
  File "/Users/dusandev/Desktop/AI/text_generator/routes.py", line 2, in <module>
    from .generator import ai
  File "/Users/dusandev/Desktop/AI/text_generator/generator.py", line 74, in <module>
    text = ai.generate_text('How are you?')
  File "/Users/dusandev/Desktop/AI/text_generator/generator.py", line 46, in generate_text
    np.random.seed(seed)
  File "mtrand.pyx", line 244, in numpy.random.mtrand.RandomState.seed
  File "_mt19937.pyx", line 166, in numpy.random._mt19937.MT19937._legacy_seeding
  File "_mt19937.pyx", line 186, in numpy.random._mt19937.MT19937._legacy_seeding
TypeError: Cannot cast array data from dtype('O') to dtype('int64') according to the rule 'safe'
4

1 回答 1

0

问题是None,代码中的行。这导致元组(None,)作为np.random.seed(seed). 它接受整数,但您正在发送元组。

于 2021-09-14T18:44:59.333 回答