YOLOv8结合SAHI推理图像和视频

文章目录

  • 前言
  • 视频效果
  • 必要环境
  • 一、完整代码
  • 二、运行方法
    • 1、 推理图像
    • 2、 推理视频
  • 总结


前言

在上一篇文章中,我们深入探讨了如何通过结合YOLOv8和SAHI来增强小目标检测效果
,并计算了相关评估指标,虽然我们也展示了可视化功能,但是这些功能往往需要结合实际的ground truth(GT)数据进行对比,这在实际操作中可能会稍显不便。

为了进一步简化操作,这篇文章将直接分享可以用来推理图像和视频的代码,通过这段代码,我们能够更加方便地使用SAHI进行小目标检测,而不需要反复处理和对比GT数据。

不多说啦,以下是完整的代码示例,供大家参考使用
在这里插入图片描述


视频效果

b站链接:使用SAHI增强YOLOv8推理,提升小目标检测效果(附教程)


必要环境

  1. 参考上期博客
    地址:使用YOLOv8+SAHI增强小目标检测效果并计算评估指标

一、完整代码

import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import argparse
from tqdm import tqdm
import time

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='test/images', help='Path to the images folder or video file')
parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output images or video')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='yolov8n.pt', help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.5, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')
parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')
parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')
parser.add_argument('--videos_format', type=str, nargs='+', default=['.mp4', '.avi'],
                    help='List of acceptable video formats')
args = parser.parse_args()


def get_color(idx):
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )


def process_image(image_name, model):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
    )

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        name = pred.category.name
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, get_color(cls), thickness=2)

    if args.visualize_predictions:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # 保存结果图像
    output_path = os.path.join(args.output_dir, image_name)
    os.makedirs(args.output_dir, exist_ok=True)
    cv2.imwrite(output_path, img_vis)


def process_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    output_path = os.path.join(args.output_dir, os.path.basename(video_path))
    os.makedirs(args.output_dir, exist_ok=True)

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = None
    fps_list = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        start_time = time.time()

        result = get_sliced_prediction(
            frame,
            model,
            slice_height=args.slice_height,
            slice_width=args.slice_width,
            overlap_height_ratio=args.overlap_height_ratio,
            overlap_width_ratio=args.overlap_width_ratio,
        )

        for pred in result.object_prediction_list:
            bbox = pred.bbox
            cls = pred.category.id
            score = pred.score.value
            name = pred.category.name
            xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

            cv2.rectangle(frame, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls), 2, cv2.LINE_AA)
            cv2.putText(frame, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls), thickness=2)

        end_time = time.time()
        fps = 1 / (end_time - start_time)
        fps_list.append(fps)
        cv2.putText(frame, f"FPS: {fps:.2f}", (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 4)

        if out is None:
            frame_height, frame_width = frame.shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, cap.get(cv2.CAP_PROP_FPS), (frame_width, frame_height))

        out.write(frame)

        if args.visualize_predictions:
            cv2.imshow('Video', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

    avg_fps = sum(fps_list) / len(fps_list)
    print(f"Average FPS: {avg_fps:.2f}")


def main():
    detection_model = load_detection_model()

    if os.path.isfile(args.filepath) and os.path.splitext(args.filepath)[1].lower() in args.videos_format:
        process_video(args.filepath, detection_model)
    else:
        image_names = [name for name in os.listdir(args.filepath) if
                       os.path.splitext(name)[1].lower() in args.images_format]

        for i, image_name in enumerate(tqdm(image_names, desc="Processing images")):
            process_image(image_name, detection_model)


if __name__ == "__main__":
    main()

二、运行方法

1、 推理图像

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath test/images  --output_dir output --model_type yolov8 --model_path yolov8n.pt
    

关键参数详解:

  • –filepath: 指定要推理的图像文件夹的路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述

2、 推理视频

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath inputvideo.mp4  --output_dir output --model_type yolov8 --model_path yolov8n.pt --visualize_predictions
    

关键参数详解:

  • –filepath: 指定要推理的视频路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/780919.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Open3D 点云的圆柱形邻域搜索

目录 一、概述 1.1原理 1.2应用 二、代码实现 2.1完整代码 2.2程序说明 三、实现效果 3.1原始点云 3.2搜索后点云 一、概述 1.1原理 圆柱邻域搜索的基本思想是确定点云中的哪些点位于给定圆柱的内部。一个圆柱可以由以下几个参数定义: 中心点:…

SpEL表达式相关知识点

SpEL表达式 知识点 Spel概述 Spring 表达式,即 Spring Expression Language,简称 SpEL。 那么是什么SpEL表达式呢? SpEL (Spring Expression Language) 是一种在Spring框架中用于处理表达式的语言。SpEL中的表达式可以支持调用bean的方法…

如何利用Github Action实现自动Merge PR

我是蚂蚁背大象(Apache EventMesh PMC&Committer),文章对你有帮助给项目rocketmq-rust star,关注我GitHub:mxsm,文章有不正确的地方请您斧正,创建ISSUE提交PR~谢谢! Emal:mxsmapache.com 1. 引言 GitHub Actions 是 GitHub 提供的一种强大而灵活的自…

苍穹外卖 ...待更新

苍穹外卖 1、 阿里云OSS2、菜品分类查询 1、 阿里云OSS 工具类 package com.sky.utils;import com.aliyun.oss.ClientException; import com.aliyun.oss.OSS; import com.aliyun.oss.OSSClientBuilder; import com.aliyun.oss.OSSException; import lombok.AllArgsConstructor…

【QT】显示类控件

显示类控件 显示类控件1. label - 标签2. LCD Number - 显示数字的控件3. ProgressBar - 进度条4. Calendar Widget - 日历5. Line Edit - 输入框6. Text Edit - 多行输入框7. Combo Box - 下拉框8. Spin Box - 微调框9. Date Edit & Time Edit - 日期微调框10. Dial - 旋钮…

自注意力机制和多头注意力机制区别

Ref:小白看得懂的 Transformer (图解) Ref:一文彻底搞懂 Transformer(图解手撕) 多头注意力机制(Multi-Head Attention)和自注意力机制(Self-Attention)是现代深度学习模型&#x…

浅尝Apache Mesos

文章目录 1. Mesos是什么2. 共享集群3. Apache Mesos3.1 Mesos主节点3.2 Mesos代理3.3 Mesos框架 4. 资源管理4.1 资源提供4.2 资源角色4.3 资源预留4.4 资源权重与配额 5. 实现框架5.1 框架主类5.3 实现执行器 6. 小结参考 1. Mesos是什么 Mesos是什么,Mesos是一个…

8、Redis 的线程模型、I/O 模型和多线程

Redis 的线程模型、I/O 模型和多线程 1. Redis 的线程模型 Redis 以其高效的单线程模型著称,从设计之初,Redis 就选择了单线程模式,这在很大程度上简化了其内部实现和维护。单线程模式避免了多线程编程中常见的竞争条件和锁机制问题&#x…

【WebRTC实现点对点视频通话】

介绍 WebRTC (Web Real-Time Communications) 是一个实时通讯技术,也是实时音视频技术的标准和框架。简单来说WebRTC是一个集大成的实时音视频技术集,包含了各种客户端api、音视频编/解码lib、流媒体传输协议、回声消除、安全传输等。对于开发者来说可以…

【云原生】Prometheus监控Docker指标并接入Grafana

目录 一、前言 二、docker监控概述 2.1 docker常用监控指标 2.2 docker常用监控工具 三、CAdvisor概述 3.1 CAdvisor是什么 3.2 CAdvisor功能特点 3.3 CAdvisor使用场景 四、CAdvisor对接Prometheus与Grafana 4.1 环境准备 4.2 docker部署CAdvisor 4.2.2 docker部署…

汉诺塔与青蛙跳台阶

1.汉诺塔 根据汉诺塔 - 维基百科 介绍 1.1 背景 最早发明这个问题的人是法国数学家爱德华卢卡斯。 传说越南河内某间寺院有三根银棒,上串 64 个金盘。寺院里的僧侣依照一个古老的预言,以上述规则移动这些盘子;预言说当这些盘子移动完毕&am…

Java项目:基于SSM框架实现的共享客栈管理系统分前后台【ssm+B/S架构+源码+数据库+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的共享客栈管理系统 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、功能…

网页生成二维码、在线演示

https://andi.cn/page/621504.html

go语言day11 错误 defer(),panic(),recover()

错误: 创建错误 1)fmt包下提供的方法 fmt.Errorf(" 格式化字符串信息 " , 空接口类型对象 ) 2)errors包下提供的方法 errors.New(" 字符串信息 ") 创建自定义错误 需要实现error接口,而error接口…

【JAVA多线程】线程池概论

目录 1.概述 2.ThreadPoolExector 2.1.参数 2.2.新任务提交流程 2.3.拒绝策略 2.4.代码示例 1.概述 线程池的核心: 线程池的实现原理是个标准的生产消费者模型,调用方不停向线程池中写数据,线程池中的线程组不停从队列中取任务。 实现…

“未来已来·智能共融”高峰论坛在京成功举办

在人工智能技术的澎湃浪潮中,其与传统产业的深度融合正逐步成为驱动区域经济增长的新引擎。2024年7月4号,一场以“未来已来智能共融——探索人类智能与人工智能共生共进的新路径”为主题的高峰论坛在北京电子科技职业学院图书馆圆满落幕,为北京经济技术开发区(简称“北京经开区…

时间处理的未来:Java 8全新日期与时间API完全解析

文章目录 一、改进背景二、本地日期时间三、时区日期时间四、格式化 一、改进背景 Java 8针对时间处理进行了全面的改进,重新设计了所有日期时间、日历及时区相关的 API。并把它们都统一放置在 java.time 包和子包下。 Java5的不足之处: 非线程安全&…

JAVA 课设 满汉楼餐厅点餐系统

一、代码详解 1.总体结构展示 2.总体代码 2.1 libs文件 链接:https://pan.baidu.com/s/1nH-I7gIlsqyMpXDDCFRuOA 提取码:3404 2.2 配置的德鲁连接池 #keyvalue driverClassNamecom.mysql.cj.jdbc.Driver urljdbc:mysql://localhost:3306/mhl?rewriteBa…

SAP_MM模块-特殊业务场景下的系统实现方案

一、业务背景 目前公司有一种电商业务,卖的是备品配件,是公司先跟供应商采购,然后再销售给客户,系统账就是按照正常业务来流转,公司进行采购订单入库,然后销售订单出库。 不过这种备品配件,实…

Android使用http加载自建服务器静态网页

最终效果如下图,成功加载了电脑端的静态网页内容,这是一个xml文件。 电脑端搭建http服务器 使用“Apache Http Server”,下载地址是:https://httpd.apache.org/download.cgi。具体操作步骤,参考:Apache …