当前位置: 首页 > news >正文

T10打卡—数据增强

​​​​​​​

  •   🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.导入及查看数据 

import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
import os,PIL,pathlib
import warnings
warnings.filterwarnings('ignore')
data_dir="data/T8"
data_dir=pathlib.Path(data_dir)
image_count=len(list(data_dir.glob('*/*')))
print("图片总数:",image_count)

​​

2.加载数据

batch_size=64
img_hight=224
img_width=224import tensorflow as tftrain_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height,img_width),batch_size=batch_size
)val_ds=tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height,img_width),batch_size=batch_size
)

​​

3.创建验证集

val_batches=tf.data.experimental.cardinality(val_ds)
test_ds=val_ds.take(val_batches // 5)
val_ds=val_ds.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(val_ds))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_ds))
#查看分类
class_names=train_ds.class_names
print(class_names)

​​

4配置数据集

AUTOTUNE=tf.data.AUTOTUNE
def preprocess_image(image,label):return(image/255,label)
train_ds=train_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
val_ds=val_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)
test_ds=test_ds.map(preprocess_image,num_parallel_calls=AUTOTUNE)train_ds=train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

5.数据可视化及数据增强

plt.figure(figsize=(15,10))
for images,labels in train_ds.take(1):for i in range(8):ax=plt.subplot(5,8,i+1)plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

data_arguementation=tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
image=tf.expand_dims(images[i],0)
plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = data_augmentation(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])plt.axis("off")

6.构建模型

model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])

​​

7.编译及训练模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])epochs=20
history=model.fit(train_ds,validation_data=val_ds,epochs=epochs
)

​​

8.查看准确率

loss,acc=model.evaluate(test_ds)
print("Accuracy:",acc)

​​

9.自定义增强函数

import random
def aug_img(image):seed=(random.randint(0,9),0)stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)return stateless_random_brightnessimage=tf.expand_dims(images[3]*255,0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = aug_img(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0].numpy().astype("uint8"))plt.axis("off")

​​

总结:

1.数据增强

我们可以使用 tf.keras.layers.experimental.preprocessing.RandomFliptf.keras.layers.experimental.preprocessing.RandomRotation 进行数据增强

  • tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像。
  • tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。

更多的数据增强方式可以参考:https://www.tensorflow.org/api_docs/python/tf/keras/layers/RandomRotation

2.增强方式

​​​​​​方式一:将其嵌入model中

model = tf.keras.Sequential([data_augmentation,layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),
])

这样做的好处是:

  • 数据增强这块的工作可以得到GPU的加速

注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

方式二:在Dataset数据集中进行数据增强

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNEdef prepare(ds):ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)return dstrain_ds = prepare(train_ds)


http://www.mrgr.cn/news/63336.html

相关文章:

  • AI的崛起:它将如何改变IT行业的职业景象?
  • 势均力敌(C++ 三级题--使用vector和push_back)
  • 苍穹外卖07——来单提醒和客户催单(涉及SpringTask、WebSocket协议、苍穹外卖跳过微信支付同时保证可以收到订单功能)
  • 微信小程序防止重复点击事件
  • Qt 智能指针
  • 华为数据治理方法论深入解读+全文阅读
  • 一文了解运维监控体系的方方面面
  • 低压电容补偿不用时会有电流损耗吗?
  • 力扣11.1
  • 创建线程池时为什么不建议使用Executors进行创建
  • VMware Workstation 17.0虚拟机安装Ubuntu Server 22.04.5 LTS并配置SSH与XFTP详细过程
  • 基于Matlab GUI的说话人识别测试平台
  • 基于SpringBoot的健身房系统的设计与实现(源码+定制+开发)
  • 国标GB28181摄像机接入EasyGBS国标GB28181软件与国标协议对接解决方案
  • 聚“芯”而行,华普微亮相第五届Silicon Labs Works With大会
  • HashSet 和 TreeSet 分别是如何实现去重的
  • Java 批量导出Word模板生成ZIP文件到浏览器默认下载位置
  • 【经验分享】从网页下载内嵌PDF的小妙招,亲测好用
  • OpenEuler 使用ffmpeg x11grab捕获屏幕流,rtsp推流,并用vlc播放
  • React04 State变量 组件渲染
  • fasdsdsadsa
  • 2024高性价比电容笔推荐!盘点实测西圣、绿联、酷盟电容笔!
  • qt QStackedWidget详解
  • Gemini API 和 Google AI Studio 升级,提升搜索准确性和响应能力
  • L 波段射频信号采集回放系统
  • window与Linux基础-1