报错问题 

Failed to import pydot. You must `pip install pydot` and install graphviz


我们在调用keras里面的高级API——plot_model(),去画神经网络的结构图的时候可能会遇到两个报错问题。

第一个是说keras.utils里面不存在plot_model()这个用法。

cannot import name 'plot_model' from 'keras.utils'

这个问题好解决,因为keras里面确实没有plot_model()用法,但是他的好兄弟——TensorFlow里面有.....

直接这样导入:

from tensorflow.keras.utils import plot_model

就可以了。

第二个报错问题是:

Failed to import pydot. You must `pip install pydot` and install graphviz

意思是缺失两个包,一个pydot,一个graphviz。

我查了很多文章,很多方法比较麻烦,都需要手动下载,手动配置环境变量,后来看到一个很简单的方法,并且也测试有效。

直接在anaconda prompt里面:

conda install graphviz
conda install pydotplus

就可以了,不过安装过程好像会给你装上很多额外的包.....不过不影响环境,神经网络还是一样能跑。


画图测试

导入包,构建一个网络。这个网络是Model类,采用函数API实现,稍微复杂点,可以看图就会清楚他的结构。

导入包

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras.preprocessing import sequence
from keras.models import Sequential,Model
from keras.layers import Dense,Input, Dropout, Embedding, Flatten,MaxPooling1D,Conv1D,SimpleRNN,LSTM,GRU,Multiply
from keras.layers import Bidirectional,Activation,BatchNormalization
from keras.layers.merge import concatenate

from keras.callbacks import EarlyStopping
from tensorflow.keras import regularizers
from keras.utils.np_utils import to_categorical
from tensorflow.keras  import optimizers
from tensorflow.keras.utils import plot_model

定义模型:

inputs = Input(name='inputs',shape=[64,100], dtype='float64')
gru=Bidirectional(GRU(32,return_sequences=True,))(inputs)
mlp = Dense(64,activation='relu')(gru)
attention_probs = Dense(64, activation='softmax', name='attention_vec')(mlp)
attention_mul =  Multiply()([mlp, attention_probs])
mlp = Dense(64)(attention_mul) #原始的全连接
fla=Flatten()(mlp)
output = Dense(2, activation='softmax')(fla)
model = Model(inputs=[inputs], outputs=output)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

画出图形:

plot_model(model,'new_model.png',show_shapes=True)

第一参数是神经网络模型,第二个参数是储存的图片名称,第三个是在图片上打印出每层的数据形状。

 用这种图就能很方便的展示组建的模型的架构,多输入多输出都行。

show_shapes=True参数改为False,就可以简化展示图片,不打印形状。我这里换了一种结构的网络。

plot_model(model,'model2.png',show_shapes=False)

 

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐