极验九宫格模型训练
项目地址:https://github.com/ultralytics/ultralyticsYOLOv8 是 Ultralytics 公司于 2023 年推出的最新一代 YOLO(You Only Look Once)实时目标检测模型,继承并优化了 YOLO 系列“单阶段、端到端、高精度、高速度”的核心理念。它在架构设计上融合了先进的骨干网络(如 C2f 模块)、更高效的特征金字塔(PAN-FP
参考文章:
https://www.52pojie.cn/forum.php?mod=viewthread&tid=2032868&extra=page%3D2&page=1
1.模型介绍和流程介绍
YOLOv8模型:
项目地址:https://github.com/ultralytics/ultralytics
YOLOv8 是 Ultralytics 公司于 2023 年推出的最新一代 YOLO(You Only Look Once)实时目标检测模型,继承并优化了 YOLO 系列“单阶段、端到端、高精度、高速度”的核心理念。它在架构设计上融合了先进的骨干网络(如 C2f 模块)、更高效的特征金字塔(PAN-FPN)以及改进的损失函数和数据增强策略,在保持推理速度优势的同时显著提升了检测精度。YOLOv8 支持目标检测、实例分割、姿态估计和图像分类等多种任务,并提供从 nano 到 extra-large 的多个模型尺寸,适用于从边缘设备到云端服务器的广泛场景。其开源、易用、训练部署一体化的特点,使其成为工业界和学术界广泛采用的目标检测框架之一。
clip模型:
项目地址:https://github.com/OpenAI/CLIP?tab=readme-ov-file
https://github.com/OFA-Sys/Chinese-CLIP
CLIP(Contrastive Language–Image Pretraining)是由 OpenAI 于 2021 年提出的一种多模态预训练模型,其核心思想是通过对比学习(Contrastive Learning)将图像和文本映射到同一个语义向量空间中,从而实现强大的零样本迁移能力(zero-shot transfer)。上面第一个链接是国外原版地址,下面是对中文进行适配的汉化版,专为 中文图文理解与匹配任务 设计。
这里提供两种可行方案,一种是使用yolov分类模型,一种是使用clip模型,也可以两种混合使用(使用yolov分类识别点击目标图,在使用clip识别九宫格)
2.yolov8训练
1.数据集准备
极验的九宫格图片分为上下两个部分,一个是目标点击图,一个是九宫格区域,将图片进行裁剪为小图,注意最好区分开目标图和九宫格图的名称方便后续处理。


将所有裁剪的小图放到同一个文件,然后使用视觉模型识别,这里我使用的ai平台(ai),将所有图片进行识别分类保存,大致保存格式如图所示。

2.模型训练
将数据集划分,比例可自行设定,我设定的是0.7,0.2,0.1,并创建data.yaml。然后就可以了,训练配置如下所示
from ultralytics import YOLO
def main():
model = YOLO('yolov8m-cls.pt')
# 分类训练配置
model.train(
data=r'E:\yolov8\ultralytics\nine_data\data.yaml',
epochs=30,
batch=32,
imgsz=96,
device='0', # 使用GPU 0
)
if __name__ == '__main__':
main() # 确保训练代码在 __main__ 中执行
yaml文件:
# YOLOv8 分类数据集配置文件
path: E:\yolov8\ultralytics\nine_data # 数据集根目录
train: train # 训练集相对路径
val: val # 验证集相对路径
test: test # 测试集相对路径
# 类别数量
nc: 90
# 类别名称
names: ['乌龟', '书', '井盖', '交通信号灯', '企鹅', '伞', '兔子', '公交车', '公鸡', '冰箱', '剪刀', '叉子', '口红', '台灯', '台球', '听诊器', '喷泉', '地球仪', '头盔', '帽子', '手套', '手电筒', '手表', '打印机', '拉链', '捕蚊拍', '插座', '摩托车', '救护车', '斧头', '方向盘', '望远镜', '桌子', '桥', '梳子', '椅子', '气球', '水壶', '注射器', '火箭', '烟斗', '熊猫', '牙刷', '牛', '狗', '狮子', '猪', '猫', '猴子', '瓢虫', '电钻', '相机', '眼镜', '碗', '秋千', '积木', '笔', '纽扣', '羊', '羽毛球', '老虎', '船', '蝴蝶', '螺丝刀', '行李箱', '袋鼠', '袜子', '计算器', '订书机', '贝壳', '足球', '轮椅', '轮胎', '过山车', '钥匙', '钱包', '铁轨', '铲子', '锅', '键盘', '长颈鹿', '音响', '领带', '骆驼', '鱼', '鳄鱼', '鸟', '鹿', '鼠标', '齿轮']
训练完成后,会在runs目录下生成一个classfy文件夹,里面就是分类模型训练结果。

然后就可以使用代码识别测试效果了,测试代码如下所示。
import os
from ultralytics import YOLO
# 加载模型
onnx_model = YOLO(r"E:\yolov8\ultralytics\runs\classify\suduku4\weights\best.pt", task='classify')
dirpath = r'E:\yolov8\ultralytics\datasets\方向盘\0aa85105a3bc43b38375640d8f86d431_6.jpg'
results = onnx_model.predict(source=dirpath, imgsz=96)
# 处理结果
for result in results:
# 获取预测结果
top1_index = result.probs.top1 # 最高概率类别索引[5](@ref)
top1_confidence = result.probs.top1conf # 最高概率置信度
top5_indices = result.probs.top5 # 前5个最高概率类别索引
top5_confidences = result.probs.top5conf # 前5个最高概率置信度
# 获取类别名称映射字典[1,3,5](@ref)
class_names_dict = result.names # 或者使用 onnx_model.names
# 获取最高概率对应的类别名称[1,3](@ref)
top1_class_name = class_names_dict[top1_index]
# 获取前5个预测结果对应的类别名称
top5_class_names = [class_names_dict[idx] for idx in top5_indices]
# 打印结果
print(f"最高概率类别索引: {top1_index}")
print(f"最高概率类别名称: {top1_class_name}")
print(f"最高概率置信度: {top1_confidence:.4f}")
print("\n前5个预测结果:")
for i, (idx, conf, name) in enumerate(zip(top5_indices, top5_confidences, top5_class_names)):
print(f"{i + 1}. 类别索引: {idx}, 类别名称: {name}, 置信度: {conf:.4f}")

分类模型就训练完了,识别时就将裁剪的图片使用模型识别并获取坐标即可,不过想要准确率高就需要很多不同类型的九宫格图片。
这种方法训练出来效果还可以,多优化几次准确率基本到90%以上。

3.clip模型
这个模型相比较就比较方便了,可以选择自己训练,也可以选择直接使用他训练好的模型。
测试代码如下图所示,只需要修改类别和传入图片,不过注意他这里时返回的所有类别的相似度,需要自己处理一下取最高相似度和对应类别名称。
import torch
from PIL import Image
import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models
print("Available models:", available_models())
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
device = "cuda" if torch.cuda.is_available() else "cpu"
# 如本地模型不存在,自动从ModelScope下载模型,需要提前安装`modelscope`包
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./', use_modelscope=True)
model.eval()
image = preprocess(Image.open("examples/pokemon.jpeg")).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
# 对特征进行归一化,请使用归一化后的图文特征用于下游任务
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
logits_per_image, logits_per_text = model.get_similarity(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs) # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]
如果想要训练自己的数据集可以参考项目给出的文档,使用官方模型效果也还可以,准确率稳定在80以上。

更多推荐

所有评论(0)