七、VGG16+BN(Batch Normalization)实现鸟类数据库分类
文章目录前文加利福尼亚理工学院鸟类数据库分类VGG16+BN版本数据生成器图像显示VGG16+BN模型构建VGG16+BN模型编译与拟合注意:GitHub下载地址:前文一、Windows系统下安装Tensorflow2.x(2.6)二、深度学习-读取数据三、Tensorflow图像处理预算四、线性回归模型的tensorflow实现五、深度学习-逻辑回归模型六、AlexNet实现中文字体识别——隶书
·
前文
- 一、Windows系统下安装Tensorflow2.x(2.6)
- 二、深度学习-读取数据
- 三、Tensorflow图像处理预算
- 四、线性回归模型的tensorflow实现
- 五、深度学习-逻辑回归模型
- 六、AlexNet实现中文字体识别——隶书和行楷
- 七、VGG16实现鸟类数据库分类
加利福尼亚理工学院鸟类数据库分类VGG16+BN版本
数据生成器
from keras.preprocessing.image import ImageDataGenerator
IMSIZE = 224
train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/data_vgg/train',
target_size=(IMSIZE, IMSIZE),
batch_size=20,
class_mode='categorical'
)
validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/data_vgg/test',
target_size=(IMSIZE, IMSIZE),
batch_size=20,
class_mode='categorical'
)
)
图像显示
from matplotlib import pyplot as plt
plt.figure()
fig, ax = plt.subplots(2, 5)
fig.set_figheight(6)
fig.set_figwidth(15)
ax = ax.flatten()
X, Y = next(validation_generator)
for i in range(15): ax[i].imshow(X[i, :, :, ])
VGG16+BN模型构建
#VGG16+BN实现
#VGG16+BN模型构建
from keras.layers import Conv2D, BatchNormalization, MaxPooling2D
from keras.layers import Flatten, Dense, Input, Activation
from keras import Model
from keras.layers import GlobalAveragePooling2D
input_shape = (IMSIZE, IMSIZE, 3)
input_layer = Input(input_shape)
x = input_layer
x = BatchNormalization(axis=3)(x)
x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = BatchNormalization(axis=3)(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = GlobalAveragePooling2D()(x)
x = Dense(315)(x)
x = Activation('softmax')(x)
output_layer = x
model_vgg16_b = Model(input_layer, output_layer)
model_vgg16_b.summary()
VGG16+BN模型编译与拟合
from keras.optimizers import Adam
model_vgg16.compile(loss='categorical_crossentropy',
optimizer=Adam(lr=0.001),
metrics=['accuracy'])
model_vgg16.fit_generator(train_generator,
epochs=20,
validation_data=validation_generator)
注意:
因为自己是使用tensorflow-GPU版本,自己电脑是1050Ti,4G显存。实际运行时候batch_size设置不到15大小,太大了就显存资源不足。
但是batch_size太小,总的数据集较大较多,所以最后消耗时间就较长。
所以为了效率和烧显卡,请酌情考虑
数据集来源:kaggle平台315种鸟类:315 Bird Species - Classification | Kaggle
GitHub下载地址:
更多推荐
已为社区贡献2条内容
所有评论(0)