
Implementing a CNN Garbage Classification Algorithm in TensorFlow

AI_robot
Published: March 31, 2021

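Shanghai has started enforcing garbage sorting. Can we use the machine learning and deep learning algorithms we usually study to build a simple garbage classification model? Below we use a CNN to classify garbage. The dataset contains 2527 photos of household waste, which its creators divided into six categories (these differ from the Shanghai standard):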

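- glass: 501 images
- paper: 594 images
- cardboard: 403 images
- plastic: 482 images
- metal: 410 images
- trash (general waste): 137 images

Each object was placed on a white board and photographed under sunlight or indoor lighting, and the images were resized to 512 × 384.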
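As a quick sanity check on the counts above, the short Python sketch below (plain arithmetic, no TensorFlow needed) sums the classes and computes the accuracy of a trivial classifier that always predicts the largest class. The counts come straight from the list; the majority-class baseline is an illustrative addition, not part of the original article.

```python
# Per-class image counts as listed above for dataset-resized.
counts = {
    "glass": 501,
    "paper": 594,
    "cardboard": 403,
    "plastic": 482,
    "metal": 410,
    "trash": 137,
}

total = sum(counts.values())
print("total:", total)

# Share of each class in the whole dataset, largest first.
for name, n in sorted(counts.items(), key=lambda kv: -kv[1]):
    print(f"{name:10s} {n:4d}  {n / total:.1%}")

# Accuracy of a trivial model that always predicts the largest class.
majority = max(counts, key=counts.get)
baseline = counts[majority] / total
print(f"majority-class baseline ({majority}): {baseline:.3f}")
```

The total comes out to 2527. Paper is the largest class at about 23.5% and trash is only about 5.4% of the data, so a model needs clearly more than 0.235 accuracy to beat always answering "paper".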

Dataset: https://github.com/garythung/trashnet/tree/master/data (unzip data/dataset-resized.zip)
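Unzipping and checking the folder layout can also be done from Python. The sketch below uses only the standard library. The helper names `unzip_dataset` and `count_per_class` are hypothetical. It also assumes the archive unpacks to a `dataset-resized/<class>/*.jpg` layout, which matches the `base_path` and glob pattern used later but has not been verified against your local copy.

```python
import glob
import os
import zipfile


def unzip_dataset(zip_path: str, dest_dir: str) -> str:
    """Extract the archive into dest_dir (created if missing) and return dest_dir."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    return dest_dir


def count_per_class(base_path: str) -> dict:
    """Count *.jpg files in each class sub-folder, using the same */*.jpg pattern as the article."""
    counts = {}
    for path in glob.glob(os.path.join(base_path, "*", "*.jpg")):
        cls = os.path.basename(os.path.dirname(path))  # folder name = class name
        counts[cls] = counts.get(cls, 0) + 1
    return counts
```

For example, `count_per_class(os.path.join(unzip_dataset("data/dataset-resized.zip", "data"), "dataset-resized"))` should report the six classes with the counts listed above, provided the archive really has that layout.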

Code download:

https://github.com/wennaz/Deep_Learning/

import glob
import os
import random

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array, array_to_img
from tensorflow.keras.layers import Conv2D, Flatten, MaxPooling2D, Dense
from tensorflow.keras.models import Sequential

# Folder obtained by unzipping dataset-resized.zip; change this to your own path.
base_path = '/Users/zhangwenna/Desktop/dataset-resized'
img_list = glob.glob(os.path.join(base_path, '*/*.jpg'))
print(len(img_list))

# There are 2527 images in total. Show 6 of them at random.
for i, img_path in enumerate(random.sample(img_list, 6)):
    img = load_img(img_path)
    img = img_to_array(img, dtype=np.uint8)
    plt.subplot(2, 3, i + 1)
    plt.imshow(img.squeeze())
plt.show()

output1:

2527


# Split the data.
# ImageDataGenerator() from keras.preprocessing.image is an image generator: it feeds the model
# one batch of batch_size samples at a time, and it can augment the samples in each batch on the fly
# (rotation, shearing, normalization, etc.), which enlarges the effective dataset and helps the model generalize.
train_datagen = ImageDataGenerator(
    rescale=1./255,          # map pixel values from [0, 255] to [0, 1], same as the validation generator
    shear_range=0.1,         # shear intensity (shear angle in the counter-clockwise direction, in degrees)
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    validation_split=0.1)    # fraction of images reserved for validation (strictly between 0 and 1)

test_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.1)

train_generator = train_datagen.flow_from_directory(
    base_path,
    target_size=(300, 300),
    batch_size=16,
    class_mode='categorical',
    subset='training',
    seed=0)

validation_generator = test_datagen.flow_from_directory(
    base_path,
    target_size=(300, 300),
    batch_size=16,
    class_mode='categorical',
    subset='validation',
    seed=0)

# Invert class_indices so that a predicted index can be mapped back to its class name.
labels = train_generator.class_indices
labels = dict((v, k) for k, v in labels.items())
print(labels)
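To make the idea of "feeding one augmented batch at a time" concrete, here is a simplified NumPy stand-in for what the training generator does with each batch. `augmented_batches` is a hypothetical helper, not Keras's implementation. It only rescales to [0, 1] and applies random horizontal and vertical flips, each with probability 0.5 (an assumption of this sketch), and it leaves out shear, zoom, and shifts.

```python
import numpy as np


def augmented_batches(images, labels, batch_size, rng):
    """Yield (x, y) batches forever, augmenting each batch on the fly.

    images: uint8 array of shape (N, H, W, C); labels: one-hot array of shape (N, num_classes).
    """
    n = len(images)
    while True:
        order = rng.permutation(n)  # reshuffle on every pass over the data
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            x = images[idx].astype(np.float32) / 255.0  # rescale [0, 255] -> [0, 1]
            for i in range(len(x)):
                if rng.random() < 0.5:
                    x[i] = np.flip(x[i], axis=1).copy()  # horizontal flip (width axis)
                if rng.random() < 0.5:
                    x[i] = np.flip(x[i], axis=0).copy()  # vertical flip (height axis)
            yield x, labels[idx]
```

Because the flips are drawn again for every batch, the model rarely sees exactly the same pixels twice, which is the effect the comment above describes as enlarging the dataset.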

output2:

Found 2276 images belonging to 6 classes.

Found 251 images belonging to 6 classes.

{0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}
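The 2276/251 split in output2 can be reproduced by hand. The sketch below assumes that the directory iterator reserves the first `int(validation_split * n)` files of each class folder for the validation subset and gives the rest to the training subset. This is our understanding of how `flow_from_directory` handles `validation_split`, not something the article states.

```python
# Per-class counts from the dataset description.
counts = {"cardboard": 403, "glass": 501, "metal": 410,
          "paper": 594, "plastic": 482, "trash": 137}
validation_split = 0.1

# Assumed rule: int() truncates the per-class validation share.
val_per_class = {c: int(validation_split * n) for c, n in counts.items()}
n_val = sum(val_per_class.values())
n_train = sum(counts.values()) - n_val

print(val_per_class)
print("training:", n_train, "validation:", n_val)
```

The totals match the "Found 2276 images" and "Found 251 images" lines. If this reading is right, the split depends on file order rather than on `seed`, and because both generators use the same `validation_split` over the same folders, the two subsets do not overlap.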

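model = Sequential([
    Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(300, 300, 3)),
    MaxPooling2D(pool_size=2),

    Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
    MaxPooling2D(pool_size=2),

    Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
    MaxPooling2D(pool_size=2),

    Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
    MaxPooling2D(pool_size=2),

    Flatten(),
    Dense(64, activation='relu'),
    Dense(6, activation='softmax')
])

# Categorical cross-entropy loss, matching the one-hot labels from class_mode='categorical'.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])

# model.fit accepts generators directly (Model.fit_generator is deprecated).
# Note: batch_size is 16, so 2276 // 32 = 71 steps cover only about half of the training
# images per epoch; steps_per_epoch=len(train_generator) would cover all of them.
model.fit(train_generator,
          epochs=100,
          steps_per_epoch=2276 // 32,
          validation_data=validation_generator,
          validation_steps=251 // 32)

# steps_per_epoch is obtained by dividing the number of training images by the batch size.
# For example, with 100 images and a batch size of 50, steps_per_epoch is 2.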
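The feature-map shapes and parameter counts of this model can be worked out by hand, which shows where most of its capacity sits. The sketch below assumes Keras's default 'valid' padding for MaxPooling2D (so odd sizes are floored) and counts one bias per filter or unit. `model.summary()` should report the same totals.

```python
def conv_params(k, c_in, c_out):
    # k*k*c_in weights per filter, plus one bias per filter
    return k * k * c_in * c_out + c_out


def dense_params(n_in, n_out):
    # full weight matrix plus one bias per output unit
    return n_in * n_out + n_out


size, channels, total = 300, 3, 0
for filters in (32, 64, 32, 32):
    total += conv_params(3, channels, filters)  # padding='same' keeps the spatial size
    channels = filters
    size //= 2                                  # MaxPooling2D(pool_size=2) halves it, flooring odd sizes

flat = size * size * channels                   # length of the Flatten() output
total += dense_params(flat, 64) + dense_params(64, 6)

print("feature map after last pool:", (size, size, channels))
print("flattened length:", flat)
print("trainable parameters:", total)
```

It prints (18, 18, 32), 10368 and 711110. The first Dense layer alone holds 663,616 of the 711,110 parameters (about 93%), so the convolutional layers are comparatively small.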

output3:

(Deprecation warning printed by the model.fit_generator call used for this run:) Please use Model.fit, which supports generators.

Epoch 1/100
71/71 [==============================] - 221s 3s/step - loss: 1.7205 - acc: 0.2553 - val_loss: 1.6083 - val_acc: 0.3750
Epoch 2/100
71/71 [==============================] - 121s 2s/step - loss: 1.5227 - acc: 0.3803 - val_loss: 1.4301 - val_acc: 0.4464
Epoch 3/100
71/71 [==============================] - 123s 2s/step - loss: 1.4023 - acc: 0.4428 - val_loss: 1.5464 - val_acc: 0.4286
Epoch 4/100
71/71 [==============================] - 118s 2s/step - loss: 1.3962 - acc: 0.4489 - val_loss: 1.4306 - val_acc: 0.4286
Epoch 5/100
71/71 [==============================] - 116s 2s/step - loss: 1.3639 - acc: 0.4384 - val_loss: 1.3748 - val_acc: 0.4196
Epoch 6/100
71/71 [==============================] - 118s 2s/step - loss: 1.2870 - acc: 0.4850 - val_loss: 1.2453 - val_acc: 0.5536
Epoch 7/100
71/71 [==============================] - 117s 2s/step - loss: 1.2601 - acc: 0.5133 - val_loss: 1.4683 - val_acc: 0.4464
Epoch 8/100
71/71 [==============================] - 117s 2s/step - loss: 1.2186 - acc: 0.5088 - val_loss: 1.2113 - val_acc: 0.5000
Epoch 9/100
71/71 [==============================] - 115s 2s/step - loss: 1.1850 - acc: 0.5214 - val_loss: 1.3347 - val_acc: 0.4464
Epoch 10/100
71/71 [==============================] - 122s 2s/step - loss: 1.1420 - acc: 0.5423 - val_loss: 1.2093 - val_acc: 0.5536
Epoch 11/100
71/71 [==============================] - 112s 2s/step - loss: 1.0990 - acc: 0.5678 - val_loss: 1.1321 - val_acc: 0.5089
Epoch 12/100
71/71 [==============================] - 112s 2s/step - loss: 1.0840 - acc: 0.5721 - val_loss: 1.1863 - val_acc: 0.5357
Epoch 13/100
71/71 [==============================] - 155s 2s/step - loss: 1.0766 - acc: 0.5979 - val_loss: 1.4430 - val_acc: 0.4554
Epoch 14/100
71/71 [==============================] - 169s 2s/step - loss: 0.9695 - acc: 0.6338 - val_loss: 0.9983 - val_acc: 0.6518
Epoch 15/100
71/71 [==============================] - 119s 2s/step - loss: 0.9901 - acc: 0.6294 - val_loss: 1.1473 - val_acc: 0.5625
Epoch 16/100
71/71 [==============================] - 115s 2s/step - loss: 1.0062 - acc: 0.6406 - val_loss: 1.0303 - val_acc: 0.6250
Epoch 17/100
71/71 [==============================] - 114s 2s/step - loss: 0.9439 - acc: 0.6397 - val_loss: 1.0116 - val_acc: 0.5714
Epoch 18/100
71/71 [==============================] - 116s 2s/step - loss: 0.9797 - acc: 0.6391 - val_loss: 1.1799 - val_acc: 0.5268
Epoch 19/100
71/71 [==============================] - 115s 2s/step - loss: 0.9340 - acc: 0.6459 - val_loss: 1.0967 - val_acc: 0.5804
Epoch 20/100
71/71 [==============================] - 114s 2s/step - loss: 0.8780 - acc: 0.6708 - val_loss: 1.0752 - val_acc: 0.5804
Epoch 21/100
71/71 [==============================] - 114s 2s/step - loss: 0.8546 - acc: 0.6824 - val_loss: 1.1991 - val_acc: 0.5714
Epoch 22/100
71/71 [==============================] - 115s 2s/step - loss: 0.8694 - acc: 0.6628 - val_loss: 1.2398 - val_acc: 0.5357
Epoch 23/100
71/71 [==============================] - 117s 2s/step - loss: 0.8411 - acc: 0.6901 - val_loss: 1.1025 - val_acc: 0.6786
Epoch 24/100
71/71 [==============================] - 117s 2s/step - loss: 0.8107 - acc: 0.7130 - val_loss: 1.1774 - val_acc: 0.5536
Epoch 25/100
71/71 [==============================] - 116s 2s/step - loss: 0.8752 - acc: 0.6673 - val_loss: 0.8081 - val_acc: 0.6696
Epoch 26/100
71/71 [==============================] - 111s 2s/step - loss: 0.8150 - acc: 0.7020 - val_loss: 0.9926 - val_acc: 0.6518
Epoch 27/100
71/71 [==============================] - 113s 2s/step - loss: 0.7882 - acc: 0.7104 - val_loss: 0.9890 - val_acc: 0.6339
Epoch 28/100
71/71 [==============================] - 113s 2s/step - loss: 0.7705 - acc: 0.7201 - val_loss: 1.0953 - val_acc: 0.6071
Epoch 29/100
71/71 [==============================] - 112s 2s/step - loss: 0.7430 - acc: 0.7377 - val_loss: 0.8792 - val_acc: 0.6429
Epoch 30/100
71/71 [==============================] - 112s 2s/step - loss: 0.7626 - acc: 0.7210 - val_loss: 0.8883 - val_acc: 0.6518
Epoch 31/100
71/71 [==============================] - 113s 2s/step - loss: 0.8552 - acc: 0.6815 - val_loss: 1.3025 - val_acc: 0.4732
Epoch 32/100
71/71 [==============================] - 112s 2s/step - loss: 0.7941 - acc: 0.7069 - val_loss: 0.8129 - val_acc: 0.7054
Epoch 33/100
71/71 [==============================] - 109s 2s/step - loss: 0.7429 - acc: 0.7313 - val_loss: 0.8716 - val_acc: 0.6696
Epoch 34/100
71/71 [==============================] - 112s 2s/step - loss: 0.6959 - acc: 0.7536 - val_loss: 1.0984 - val_acc: 0.6250
Epoch 35/100
71/71 [==============================] - 111s 2s/step - loss: 0.7375 - acc: 0.7394 - val_loss: 0.8002 - val_acc: 0.6786
Epoch 36/100
71/71 [==============================] - 111s 2s/step - loss: 0.7072 - acc: 0.7333 - val_loss: 0.7551 - val_acc: 0.7321
Epoch 37/100
71/71 [==============================] - 113s 2s/step - loss: 0.7440 - acc: 0.7403 - val_loss: 1.1043 - val_acc: 0.5982
Epoch 38/100
71/71 [==============================] - 111s 2s/step - loss: 0.7527 - acc: 0.7174 - val_loss: 0.8664 - val_acc: 0.6964
Epoch 39/100
71/71 [==============================] - 111s 2s/step - loss: 0.6643 - acc: 0.7473 - val_loss: 0.8213 - val_acc: 0.7232
Epoch 40/100
71/71 [==============================] - 112s 2s/step - loss: 0.7021 - acc: 0.7456 - val_loss: 0.8613 - val_acc: 0.7143
Epoch 41/100
71/71 [==============================] - 113s 2s/step - loss: 0.6386 - acc: 0.7720 - val_loss: 1.0223 - val_acc: 0.6250
Epoch 42/100
71/71 [==============================] - 111s 2s/step - loss: 0.6568 - acc: 0.7447 - val_loss: 0.9515 - val_acc: 0.6786
Epoch 43/100
71/71 [==============================] - 112s 2s/step - loss: 0.6394 - acc: 0.7852 - val_loss: 0.7870 - val_acc: 0.6786
Epoch 44/100
71/71 [==============================] - 113s 2s/step - loss: 0.6437 - acc: 0.7676 - val_loss: 0.8471 - val_acc: 0.6696
Epoch 45/100
71/71 [==============================] - 111s 2s/step - loss: 0.6337 - acc: 0.7623 - val_loss: 0.8950 - val_acc: 0.6964
Epoch 46/100
71/71 [==============================] - 111s 2s/step - loss: 0.5993 - acc: 0.7820 - val_loss: 1.0203 - val_acc: 0.5625
Epoch 47/100
71/71 [==============================] - 117s 2s/step - loss: 0.6087 - acc: 0.7799 - val_loss: 1.0065 - val_acc: 0.5982
Epoch 48/100
71/71 [==============================] - 138s 2s/step - loss: 0.6285 - acc: 0.7722 - val_loss: 0.8781 - val_acc: 0.6875
Epoch 49/100
71/71 [==============================] - 121s 2s/step - loss: 0.5612 - acc: 0.7984 - val_loss: 1.2335 - val_acc: 0.6429
Epoch 50/100
71/71 [==============================] - 116s 2s/step - loss: 0.5699 - acc: 0.7984 - val_loss: 0.9054 - val_acc: 0.6607
Epoch 51/100
71/71 [==============================] - 110s 2s/step - loss: 0.6030 - acc: 0.7909 - val_loss: 0.9821 - val_acc: 0.6429
Epoch 52/100
71/71 [==============================] - 111s 2s/step - loss: 0.6945 - acc: 0.7412 - val_loss: 0.8685 - val_acc: 0.6786
Epoch 53/100
71/71 [==============================] - 113s 2s/step - loss: 0.5679 - acc: 0.7905 - val_loss: 0.8510 - val_acc: 0.6875
Epoch 54/100
71/71 [==============================] - 113s 2s/step - loss: 0.6023 - acc: 0.7835 - val_loss: 0.8247 - val_acc: 0.6518
Epoch 55/100
71/71 [==============================] - 112s 2s/step - loss: 0.5590 - acc: 0.7896 - val_loss: 0.7802 - val_acc: 0.7321
Epoch 56/100
71/71 [==============================] - 114s 2s/step - loss: 0.5679 - acc: 0.8028 - val_loss: 0.7660 - val_acc: 0.6964
Epoch 57/100
71/71 [==============================] - 110s 2s/step - loss: 0.5839 - acc: 0.8028 - val_loss: 0.7611 - val_acc: 0.7321
Epoch 58/100
71/71 [==============================] - 111s 2s/step - loss: 0.5590 - acc: 0.7967 - val_loss: 1.0786 - val_acc: 0.6071
Epoch 59/100
71/71 [==============================] - 116s 2s/step - loss: 0.5194 - acc: 0.8275 - val_loss: 0.7342 - val_acc: 0.7321
Epoch 60/100
71/71 [==============================] - 110s 2s/step - loss: 0.4677 - acc: 0.8185 - val_loss: 0.9167 - val_acc: 0.6786
Epoch 61/100
71/71 [==============================] - 109s 2s/step - loss: 0.4906 - acc: 0.8052 - val_loss: 0.7638 - val_acc: 0.7321
Epoch 62/100
71/71 [==============================] - 112s 2s/step - loss: 0.5267 - acc: 0.8081 - val_loss: 1.2296 - val_acc: 0.5982
Epoch 63/100
71/71 [==============================] - 111s 2s/step - loss: 0.5880 - acc: 0.7909 - val_loss: 0.8299 - val_acc: 0.6964
Epoch 64/100
71/71 [==============================] - 110s 2s/step - loss: 0.5203 - acc: 0.8203 - val_loss: 0.6984 - val_acc: 0.7589
Epoch 65/100
71/71 [==============================] - 110s 2s/step - loss: 0.5617 - acc: 0.8007 - val_loss: 0.8506 - val_acc: 0.6964
Epoch 66/100
71/71 [==============================] - 112s 2s/step - loss: 0.4157 - acc: 0.8530 - val_loss: 0.9649 - val_acc: 0.6786
Epoch 67/100
71/71 [==============================] - 110s 2s/step - loss: 0.4726 - acc: 0.8363 - val_loss: 0.7467 - val_acc: 0.7411
Epoch 68/100
71/71 [==============================] - 115s 2s/step - loss: 0.4825 - acc: 0.8247 - val_loss: 0.9306 - val_acc: 0.6964
Epoch 69/100
71/71 [==============================] - 110s 2s/step - loss: 0.4757 - acc: 0.8363 - val_loss: 1.1517 - val_acc: 0.6161
Epoch 70/100
71/71 [==============================] - 111s 2s/step - loss: 0.4948 - acc: 0.8185 - val_loss: 0.8304 - val_acc: 0.7054
Epoch 71/100
71/71 [==============================] - 110s 2s/step - loss: 0.5174 - acc: 0.8096 - val_loss: 0.8679 - val_acc: 0.6518
Epoch 72/100
71/71 [==============================] - 111s 2s/step - loss: 0.4799 - acc: 0.8274 - val_loss: 0.8524 - val_acc: 0.7321
Epoch 73/100
71/71 [==============================] - 112s 2s/step - loss: 0.4212 - acc: 0.8495 - val_loss: 1.0715 - val_acc: 0.6875
Epoch 74/100
71/71 [==============================] - 111s 2s/step - loss: 0.5003 - acc: 0.8078 - val_loss: 0.8279 - val_acc: 0.7143
Epoch 75/100
71/71 [==============================] - 111s 2s/step - loss: 0.4267 - acc: 0.8425 - val_loss: 0.7447 - val_acc: 0.7500
Epoch 76/100
71/71 [==============================] - 111s 2s/step - loss: 0.4268 - acc: 0.8371 - val_loss: 0.8244 - val_acc: 0.7500
Epoch 77/100
71/71 [==============================] - 111s 2s/step - loss: 0.4720 - acc: 0.8247 - val_loss: 0.8961 - val_acc: 0.6786
Epoch 78/100
71/71 [==============================] - 112s 2s/step - loss: 0.4979 - acc: 0.8204 - val_loss: 0.8691 - val_acc: 0.6429
Epoch 79/100
71/71 [==============================] - 111s 2s/step - loss: 0.4445 - acc: 0.8461 - val_loss: 1.0964 - val_acc: 0.5982
Epoch 80/100
71/71 [==============================] - 112s 2s/step - loss: 0.4660 - acc: 0.8283 - val_loss: 0.9248 - val_acc: 0.6607
Epoch 81/100
71/71 [==============================] - 111s 2s/step - loss: 0.4824 - acc: 0.8222 - val_loss: 1.2059 - val_acc: 0.6339
Epoch 82/100
71/71 [==============================] - 110s 2s/step - loss: 0.4382 - acc: 0.8354 - val_loss: 0.8243 - val_acc: 0.6875
Epoch 83/100
71/71 [==============================] - 111s 2s/step - loss: 0.3791 - acc: 0.8603 - val_loss: 1.3547 - val_acc: 0.5804
Epoch 84/100
71/71 [==============================] - 112s 2s/step - loss: 0.4175 - acc: 0.8468 - val_loss: 1.1149 - val_acc: 0.7321
Epoch 85/100
71/71 [==============================] - 113s 2s/step - loss: 0.6471 - acc: 0.7740 - val_loss: 1.0958 - val_acc: 0.6250
Epoch 86/100
71/71 [==============================] - 113s 2s/step - loss: 0.4434 - acc: 0.8504 - val_loss: 0.8250 - val_acc: 0.6696
Epoch 87/100
71/71 [==============================] - 112s 2s/step - loss: 0.3719 - acc: 0.8559 - val_loss: 0.8524 - val_acc: 0.7589
Epoch 88/100
71/71 [==============================] - 109s 2s/step - loss: 0.3978 - acc: 0.8532 - val_loss: 0.8410 - val_acc: 0.7321
Epoch 89/100
71/71 [==============================] - 111s 2s/step - loss: 0.4387 - acc: 0.8398 - val_loss: 0.8426 - val_acc: 0.7232
Epoch 90/100
71/71 [==============================] - 110s 2s/step - loss: 0.4056 - acc: 0.8594 - val_loss: 0.8563 - val_acc: 0.7232
Epoch 91/100
71/71 [==============================] - 111s 2s/step - loss: 0.3897 - acc: 0.8592 - val_loss: 0.7448 - val_acc: 0.7321
Epoch 92/100
71/71 [==============================] - 110s 2s/step - loss: 0.3947 - acc: 0.8541 - val_loss: 0.7799 - val_acc: 0.7321
Epoch 93/100
71/71 [==============================] - 109s 2s/step - loss: 0.4416 - acc: 0.8488 - val_loss: 0.9649 - val_acc: 0.6518
Epoch 94/100
71/71 [==============================] - 116s 2s/step - loss: 0.3962 - acc: 0.8550 - val_loss: 1.2210 - val_acc: 0.6607
Epoch 95/100
71/71 [==============================] - 124s 2s/step - loss: 0.4087 - acc: 0.8577 - val_loss: 1.0710 - val_acc: 0.6607
Epoch 96/100
71/71 [==============================] - 117s 2s/step - loss: 0.3748 - acc: 0.8671 - val_loss: 0.8149 - val_acc: 0.7589
Epoch 97/100
71/71 [==============================] - 114s 2s/step - loss: 0.3882 - acc: 0.8550 - val_loss: 1.1649 - val_acc: 0.6875
Epoch 98/100
71/71 [==============================] - 115s 2s/step - loss: 0.3485 - acc: 0.8719 - val_loss: 0.9793 - val_acc: 0.6786
Epoch 99/100
71/71 [==============================] - 118s 2s/step - loss: 0.4128 - acc: 0.8477 - val_loss: 1.0489 - val_acc: 0.6964
Epoch 100/100
71/71 [==============================] - 115s 2s/step - loss: 0.3668 - acc: 0.8644 - val_loss: 1.1848 - val_acc: 0.6250
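The "71/71" in the log and the coarse jumps in val_acc follow from the steps_per_epoch and validation_steps arguments. The small arithmetic sketch below compares the values as written in the training code with the ones needed to cover every image once per epoch at batch_size=16.

```python
import math

n_train, n_val, batch_size = 2276, 251, 16

# As written in the training code: divides by 32 although batch_size is 16.
steps_as_written = n_train // 32
val_steps_as_written = n_val // 32

# Steps needed to see every image once per epoch (the last batch may be partial).
steps_full = math.ceil(n_train / batch_size)
val_steps_full = math.ceil(n_val / batch_size)

print(steps_as_written, steps_as_written * batch_size)          # 71 1136
print(val_steps_as_written, val_steps_as_written * batch_size)  # 7 112
print(steps_full, val_steps_full)                                # 143 16
```

Only 112 of the 251 validation images are evaluated per epoch, which is why every val_acc value in the log is a multiple of 1/112 (for example 0.3750 = 42/112 and 0.7589 ≈ 85/112). Passing `steps_per_epoch=len(train_generator)` and `validation_steps=len(validation_generator)` would use all images, since a Keras directory iterator's length is its number of batches, ceil(n / batch_size).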

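Results

Below we take one batch of 16 images from the (shuffled) validation generator and show each image together with its true label and the model's prediction. Most of these images are classified correctly. This is in line with the validation accuracy in the log above (0.6250 after the last epoch, peaking at 0.7589), so some misclassifications are still to be expected.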

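# Take batch index 1 (16 images) from the validation generator and predict their classes.
test_x, test_y = validation_generator[1]
preds = model.predict(test_x)

plt.figure(figsize=(16, 16))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    # argmax turns the softmax output and the one-hot label into class indices
    plt.title('pred:%s / truth:%s' % (labels[np.argmax(preds[i])], labels[np.argmax(test_y[i])]))
    plt.imshow(test_x[i])
plt.show()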
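The 16-image grid is only a spot check. To measure accuracy on the whole validation subset and see which classes get confused with each other, the predictions can be collected over all batches and compared with the labels. The helper below (`evaluate_predictions`, a hypothetical name) works on NumPy arrays of softmax outputs and one-hot labels. The commented lines sketch how it might be driven from the `model` and `validation_generator` defined earlier.

```python
import numpy as np


def evaluate_predictions(preds, onehot, num_classes):
    """Accuracy and confusion matrix from softmax outputs and one-hot labels.

    Rows of the confusion matrix are true classes, columns are predicted classes.
    """
    pred_idx = np.argmax(preds, axis=1)
    true_idx = np.argmax(onehot, axis=1)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (true_idx, pred_idx), 1)  # count each (truth, prediction) pair
    accuracy = float(np.mean(pred_idx == true_idx))
    return accuracy, cm


# Possible usage with the Keras objects above (not run here):
# all_preds, all_true = [], []
# for i in range(len(validation_generator)):
#     x, y = validation_generator[i]
#     all_preds.append(model.predict(x, verbose=0))
#     all_true.append(y)
# acc, cm = evaluate_predictions(np.concatenate(all_preds), np.concatenate(all_true), num_classes=6)
```

The rows and columns of `cm` follow the index order printed in output2 (cardboard, glass, metal, paper, plastic, trash). Given how small the trash class is, it may be worth checking that row in particular.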

output4:

