当前位置: 首页 > news >正文

T4—猴痘识别

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.导入数据

#设置gpu
from tensorflow import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow  as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:gpu0 = gpus[0]                                       tf.config.experimental.set_memory_growth(gpu0, True)  tf.config.set_visible_devices([gpu0],"GPU")
gpusdata_dir="data/45-data/"
data_dir=pathlib.Path(data_dir)image_count=len(list(data_dir.glob('*/*.jpg')))
print("图片的总数为:",image_count)Monkeypox=list(data_dir.glob('Monkeypox/*.jpg'))
PIL.Image.open(str(Monkeypox[0]))

2.加载数据

batch_size=32
img_height=224
img_width=224train_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height,img_width),batch_size=batch_size)val_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height,img_width),batch_size=batch_size)

3.数据可视化

class_names=train_ds.class_names
print(class_names)plt.figure(figsize=(20,10))for images,labels in train_ds.take(1):for i in range(20):ax=plt.subplot(5,10,i+1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")

4.检查数据

for image_batch,labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

5.配置数据与构建模型

AUTOTUNE=tf.data.AUTOTUNE
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)num_classes=2
model = models.Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), layers.AveragePooling2D((2, 2)),               layers.Conv2D(32, (3, 3), activation='relu'),  layers.AveragePooling2D((2, 2)),               layers.Dropout(0.3),  layers.Conv2D(64, (3, 3), activation='relu'),  layers.Dropout(0.3),  layers.Flatten(),                       layers.Dense(128, activation='relu'),  layers.Dense(num_classes)               
])model.summary()  

6.编译并训练模型

opt=tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])from tensorflow.keras.callbacks import ModelCheckpoint
epochs=50
checkpointer=ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=1,save_weights_only=True)
history=model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer])

7.结果可视化

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

8.预测

model.load_weights('best_model.h5')
from PIL import Image
import numpy as npimg=Image.open("data/45-data/Others/NM15_02_11.jpg")
image=tf.image.resize(img,[img_height,img_width])
img_array=tf.expand_dims(image,0)predictions=model.predict(img_array)
print("预测结果为:",class_names[np.argmax(predictions)])

总结:

1.shuffle()函数:

首先,Dataset会取所有数据的前buffer_size数据项,填充 buffer,如下图

然后,从buffer中随机选择一条数据输出,比如这里随机选中了item 7,那么bufferitem 7对应的位置就空出来了

然后,从Dataset中顺序选择最新的一条数据填充到buffer中,这里是item 10

然后在从Buffer中随机选择下一条数据输出。

需要说明的是,这里的数据项item,并不只是单单一条真实数据,如果有batch size,则一条数据项item包含了batch size条真实数据。

shuffle是防止数据过拟合的重要手段,然而不当的buffer size,会导致shuffle无意义

2.prefetch() :预取数据,加速运行

    CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态。


http://www.mrgr.cn/news/32987.html

相关文章:

  • Redis数据结构之哈希表
  • 【HTTP】请求“报头”,Referer 和 Cookie
  • 盘点3款.NetCore(C#)开源免费商城系统
  • C++(2)进阶语法
  • 十四、运算放大电路
  • 初中数学证明集锦之三角形内角和
  • 【小沐学GIS】blender导入OpenStreetMap城市建筑(blender-osm、blosm)
  • 结构体对齐、函数传参、库移植
  • Spring:统一结果私有属性造成的前端无法访问异常报错问题
  • 博客管理系统可行性分析报告
  • Elionix 电子束曝光系统
  • 分析redis实现分布式锁的思路
  • 【亿美软通-注册/登录安全分析报告】
  • 掌握 JavaScript 中的函数表达式
  • 安装黑群晖系统,并使用NAS公网助手访问教程(好文)
  • Android通知服务及相关概念
  • Flutter 获取手机连接的Wifi信息
  • Ribbon布局和尺寸调整
  • 详解lsof
  • NXP官方或正点原子mfgtool下载系统报错initialize the library falied error code:29