0

我正在尝试使用该tf.keras.utils.model_to_dot()函数绘制我的模型,但不断收到以下错误:

TypeError: object of type 'Cluster' has no len()

这是我使用的代码:

import tensorflow
import pydot

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model_graph = tf.keras.utils.model_to_dot(model, expand_nested=True, subgraph=True)
graph = pydot.graph_from_dot_data(model_graph)
graph.write_png('model.png')

我在这里做错了什么?

4

1 回答 1

0

按照文档,这里是https://pydotplus.readthedocs.io/reference.html

pydotplus.graphviz.graph_from_dot_data(data)[source]
Load graph as defined by data in DOT format.

The data is assumed to be in DOT format. It will be parsed and a Dot class will be returned, representing the graph.

你的代码应该是这样的,

import pydot

(graph,) = pydot.graph_from_dot_file('somefile.dot')
graph.write_png('somefile.png')

感谢@Judge Maygarden,在 python 中将点转换为 png

基本上,它需要一个 .dot 文件,而不是模型本身。

也许你正在寻找tf.keras.utils.plot_model?例如,来自https://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model

tf.keras.utils.plot_model(
    model, to_file='model.png', show_shapes=False, show_layer_names=True,
    rankdir='TB', expand_nested=False, dpi=96
)
于 2020-08-26T06:50:08.873 回答