前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Yolov8 源码解析(四十一)

Yolov8 源码解析(四十一)

作者头像
ApacheCN_飞龙
发布2024-09-13 17:20:16
1530
发布2024-09-13 17:20:16
举报
文章被收录于专栏:信数据得永生

.\yolov8\ultralytics\utils\callbacks\raytune.py

代码语言:python
复制
# Ultralytics YOLO 🚀, AGPL-3.0 license

# 从 ultralytics.utils 导入 SETTINGS 模块
from ultralytics.utils import SETTINGS

try:
    # 确保 SETTINGS 中的 "raytune" 键值为 True,验证集成已启用
    assert SETTINGS["raytune"] is True  # verify integration is enabled
    
    # 导入 ray 和相关的 tune、session 模块
    import ray
    from ray import tune
    from ray.tune import session as ray_session

except (ImportError, AssertionError):
    # 如果导入失败或者断言失败,将 tune 设置为 None
    tune = None


def on_fit_epoch_end(trainer):
    """
    Send training metrics to Ray Tune at the end of each epoch.

    The metrics are copied before the ``epoch`` key is added, so ``trainer.metrics`` itself is not mutated. Other
    integrations (e.g. the TensorBoard and W&B callbacks) also read ``trainer.metrics`` at this event and would
    otherwise see the extra ``epoch`` entry.

    Args:
        trainer: The trainer object; its ``metrics`` dict and ``epoch`` attribute are read.
    """
    # Private Ray API used as a replacement for the deprecated ray.tune.is_session_enabled()
    if ray.train._internal.session._get_session():
        metrics = {**trainer.metrics, "epoch": trainer.epoch}  # copy instead of mutating trainer.metrics
        ray_session.report(metrics)  # report metrics to Ray Tune


# Register the Ray Tune callback only when the optional imports above succeeded and the integration is enabled
# (``tune`` is set to None otherwise).
callbacks = {}
if tune:
    callbacks["on_fit_epoch_end"] = on_fit_epoch_end

.\yolov8\ultralytics\utils\callbacks\tensorboard.py

代码语言:python
复制
# 导入上下文管理工具
import contextlib

# 从ultralytics.utils模块中导入必要的组件:LOGGER, SETTINGS, TESTS_RUNNING, colorstr
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr

try:
    # 尝试导入TensorBoard的SummaryWriter
    from torch.utils.tensorboard import SummaryWriter

    # 确保不处于测试运行中,避免记录pytest
    assert not TESTS_RUNNING
    # 确保SETTINGS中的tensorboard选项为True,验证集成已启用
    assert SETTINGS["tensorboard"] is True
    # 初始化TensorBoard的SummaryWriter实例为None
    WRITER = None
    # 定义输出前缀为TensorBoard:
    PREFIX = colorstr("TensorBoard: ")

    # 如果启用了TensorBoard,则需要以下导入
    import warnings
    from copy import deepcopy
    from ultralytics.utils.torch_utils import de_parallel, torch

except (ImportError, AssertionError, TypeError, AttributeError):
    # 处理导入错误、断言错误、类型错误和属性错误异常
    # TypeError用于处理Windows中的'Descriptors cannot not be created directly.' protobuf错误
    # AttributeError: 如果未安装'tensorflow',则模块'tensorflow'没有'io'属性
    SummaryWriter = None


def _log_scalars(scalars, step=0):
    """
    Write every key/value pair of ``scalars`` to TensorBoard as a scalar at ``step``.

    Does nothing while the module-level ``WRITER`` has not been created.
    """
    if not WRITER:
        return
    for tag, value in scalars.items():
        WRITER.add_scalar(tag, value, step)


def _log_tensorboard_graph(trainer):
    """
    Log the model graph to TensorBoard.

    First tries to trace ``trainer.model`` directly. If that raises anything, it falls back to tracing a fused deep
    copy in which every submodule that has an ``export`` attribute is switched to TorchScript export mode. A failure
    of the fallback is logged as a warning, not raised.

    Args:
        trainer: The trainer object; ``trainer.args.imgsz`` and ``trainer.model`` are read.
    """

    # Input image size; an int is expanded to a square (h, w)
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    p = next(trainer.model.parameters())  # first model parameter, used only for its device and dtype
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

    with warnings.catch_warnings():
        # Silence tracing noise while building the graph
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try the simple method first (e.g. YOLO); any exception falls through to the fallback below
        with contextlib.suppress(Exception):
            # eval mode to avoid BatchNorm statistics changes
            # NOTE(review): trainer.model is left in eval mode; presumably the training loop switches it back
            trainer.model.eval()
            WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
            LOGGER.info(f"{PREFIX}模型图可视化已添加 ✅")
            return

        # Fallback to TorchScript export steps (e.g. RTDETR), operating on a copy so trainer.model is untouched
        try:
            model = deepcopy(de_parallel(trainer.model))
            model.eval()
            model = model.fuse(verbose=False)
            for m in model.modules():
                if hasattr(m, "export"):  # Detect-like heads such as RTDETRDecoder (presumably Detect subclasses)
                    m.export = True
                    m.format = "torchscript"
            model(im)  # dry run
            WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
            LOGGER.info(f"{PREFIX}模型图可视化已添加 ✅")
        except Exception as e:
            LOGGER.warning(f"{PREFIX}警告 ⚠️ TensorBoard模型图可视化失败 {e}")


def on_pretrain_routine_start(trainer):
    """
    Initialize TensorBoard logging by creating the module-level ``WRITER`` in ``trainer.save_dir``.

    Nothing happens when ``SummaryWriter`` is unavailable. If creating the writer fails, a warning is logged and
    the run continues without TensorBoard logging.
    """
    global WRITER
    if not SummaryWriter:
        return
    try:
        log_dir = trainer.save_dir
        WRITER = SummaryWriter(str(log_dir))
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {log_dir}', view at http://localhost:6006/")
    except Exception as e:
        LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):
    """Add the model graph to TensorBoard when training starts; skipped if no writer was created."""
    if not WRITER:
        return
    _log_tensorboard_graph(trainer)


def on_train_epoch_end(trainer):
    """Log the 'train'-prefixed loss items and the learning rate(s) to TensorBoard after a training epoch."""
    step = trainer.epoch + 1
    loss_items = trainer.label_loss_items(trainer.tloss, prefix="train")
    _log_scalars(loss_items, step)
    _log_scalars(trainer.lr, step)


def on_fit_epoch_end(trainer):
    """Log ``trainer.metrics`` to TensorBoard at the end of each fit epoch."""
    epoch_step = trainer.epoch + 1
    _log_scalars(trainer.metrics, epoch_step)


# TensorBoard callbacks. They are registered only when SummaryWriter was imported and the integration checks passed
# (SummaryWriter is set to None otherwise).
callbacks = {}
if SummaryWriter:
    callbacks = dict(
        on_pretrain_routine_start=on_pretrain_routine_start,
        on_train_start=on_train_start,
        on_fit_epoch_end=on_fit_epoch_end,
        on_train_epoch_end=on_train_epoch_end,
    )

.\yolov8\ultralytics\utils\callbacks\wb.py

代码语言:python
复制
# Ultralytics YOLO 🚀, AGPL-3.0 license

# 导入必要的模块和变量
from ultralytics.utils import SETTINGS, TESTS_RUNNING  # 从ultralytics.utils中导入SETTINGS和TESTS_RUNNING变量
from ultralytics.utils.torch_utils import model_info_for_loggers  # 从ultralytics.utils.torch_utils中导入model_info_for_loggers函数

try:
    assert not TESTS_RUNNING  # 确保不是在运行测试时记录日志,断言不应该是pytest
    assert SETTINGS["wandb"] is True  # 验证W&B集成是否启用

    # 尝试导入并验证wandb模块
    import wandb as wb
    assert hasattr(wb, "__version__")  # 确保wandb模块已经正确导入,而不是一个目录
    _processed_plots = {}

except (ImportError, AssertionError):
    wb = None  # 如果导入失败或者断言失败,则将wb设为None


def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Build (but do not log) a custom wandb area-under-curve plot, similar to wandb's precision-recall curve.

    The values are placed in a pandas DataFrame with columns 'class', 'y', 'x' and rounded to 3 decimals. The
    result is then passed to ``wb.plot_table`` using the 'wandb/area-under-curve/v0' preset.

    Args:
        x (List): Values for the x-axis; expected to have length N.
        y (List): Corresponding values for the y-axis; also expected to have length N.
        classes (List): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis; defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis; defaults to 'Precision'.

    Returns:
        (wandb.Object): The object returned by ``wb.plot_table``, ready to be passed to a wandb log call.
    """
    import pandas  # scoped import keeps 'import ultralytics' fast

    table_df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    return wb.plot_table(
        "wandb/area-under-curve/v0",
        wb.Table(dataframe=table_df),
        fields={"x": "x", "y": "y", "class": "class"},
        string_fields={"title": title, "x-axis-title": x_title, "y-axis-title": y_title},
    )


def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization to wandb.

    The curve is resampled onto ``num_x`` evenly spaced x values between ``x[0]`` and ``x[-1]``.
    With ``only_mean=True``, only the class-mean curve is logged, as a wandb line plot. Otherwise the mean curve and
    one curve per row of ``y`` are logged together as a custom table built by ``_custom_table``.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes.
        names (list, optional): Class names for the rows of ``y``; length C. Defaults to []. Rows without a
            corresponding name are labelled with their row index.
        id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'.
        title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'.
        x_title (str, optional): Label for the x-axis. Defaults to 'Recall'.
        y_title (str, optional): Label for the y-axis. Defaults to 'Precision'.
        num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100.
        only_mean (bool, optional): If True, log only the mean curve. Defaults to False.
    """
    import numpy as np

    if names is None:
        names = []
    # Resample onto num_x evenly spaced points between the first and last x values
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Mean curve across classes (axis 0), interpolated onto x_new
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # x values for this class
            y_log.extend(np.interp(x_new, x, yi))  # y values for this class, interpolated onto x_new
            # Fall back to the row index when fewer names than rows were provided (names defaults to [])
            label = names[i] if i < len(names) else str(i)
            classes.extend([label] * len(x_new))

        # Custom table, logged without committing the log step immediately
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
    """
    Log each plot image in ``plots`` to wandb at ``step``, skipping plots already logged with the same timestamp.

    Args:
        plots (dict): Maps plot paths (presumably ``pathlib.Path``, since ``.stem`` is used) to parameter dicts
            that contain a ``"timestamp"`` key.
        step (int): wandb step at which the images are logged.
    """
    # Shallow copy so the dict may be modified while we iterate
    snapshot = plots.copy()
    for path, params in snapshot.items():
        ts = params["timestamp"]
        if _processed_plots.get(path) == ts:
            continue  # this version of the plot was already logged
        wb.run.log({path.stem: wb.Image(str(path))}, step=step)
        _processed_plots[path] = ts


def on_pretrain_routine_start(trainer):
    """Start a wandb run for this training job unless a run is already active."""
    if wb.run:
        return
    args = trainer.args
    wb.init(project=args.project or "YOLOv8", name=args.name, config=vars(args))


def on_fit_epoch_end(trainer):
    """
    Log epoch metrics and any new trainer/validator plots to wandb at step ``epoch + 1``.

    When ``trainer.epoch == 0``, model information is logged as well.
    """
    step = trainer.epoch + 1
    wb.run.log(trainer.metrics, step=step)
    for plot_dict in (trainer.plots, trainer.validator.plots):
        _log_plots(plot_dict, step=step)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=step)


def on_train_epoch_end(trainer):
    """
    Log 'train'-prefixed loss items and the learning rate(s) to wandb at step ``epoch + 1``.

    When ``trainer.epoch == 1``, the trainer plots are logged too.
    """
    step = trainer.epoch + 1
    run = wb.run
    run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=step)
    run.log(trainer.lr, step=step)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=step)


def on_train_end(trainer):
    """
    Finalize the wandb run at the end of training.

    The steps are:
    - Log any remaining validator and trainer plots.
    - Upload ``trainer.best`` as a "model" artifact aliased "best", if that file exists.
    - Log one curve per validator metric curve.
    - Finish the run.
    """
    final_step = trainer.epoch + 1
    _log_plots(trainer.validator.plots, step=final_step)
    _log_plots(trainer.plots, step=final_step)

    artifact = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        artifact.add_file(trainer.best)
        wb.run.log_artifact(artifact, aliases=["best"])

    val_metrics = trainer.validator.metrics
    for curve_name, (x, y, x_title, y_title) in zip(val_metrics.curves, val_metrics.curves_results):
        _plot_curve(
            x,
            y,
            names=list(val_metrics.names.values()),
            id=f"curves/{curve_name}",
            title=curve_name,
            x_title=x_title,
            y_title=y_title,
        )
    wb.run.finish()  # required or run continues on dashboard


# W&B callbacks. They are registered only when wandb was imported and the integration checks passed
# (wb is set to None otherwise).
callbacks = {}
if wb:
    callbacks = dict(
        on_pretrain_routine_start=on_pretrain_routine_start,
        on_train_epoch_end=on_train_epoch_end,
        on_fit_epoch_end=on_fit_epoch_end,
        on_train_end=on_train_end,
    )

.\yolov8\ultralytics\utils\callbacks\__init__.py

代码语言:python
复制
# 导入必要的模块和函数
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
# Public names of this package (re-exported from .base); also what `from ... import *` exposes
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

.\yolov8\ultralytics\utils\checks.py

代码语言:python
复制
# 导入所需的标准库和第三方库
import contextlib  # 提供了对上下文管理器的支持
import glob  # 文件名匹配库
import inspect  # 检查对象,例如获取函数的源代码
import math  # 数学函数库
import os  # 提供了与操作系统交互的功能
import platform  # 提供了访问平台相关信息的函数
import re  # 正则表达式库
import shutil  # 文件操作工具
import subprocess  # 启动和管理子进程的库
import time  # 提供了各种时间相关的功能
from importlib import metadata  # 用于访问导入的模块元数据
from pathlib import Path  # 提供了处理文件路径的功能
from typing import Optional  # 提供类型提示支持

import cv2  # OpenCV库,用于计算机视觉
import numpy as np  # 数值计算库,支持多维数组和矩阵运算
import requests  # 发送HTTP请求的库
import torch  # PyTorch深度学习框架

from ultralytics.utils import (
    ASSETS,  # 从ultralytics.utils中导入ASSETS常量
    AUTOINSTALL,  # 从ultralytics.utils中导入AUTOINSTALL常量
    IS_COLAB,  # 从ultralytics.utils中导入IS_COLAB常量
    IS_JUPYTER,  # 从ultralytics.utils中导入IS_JUPYTER常量
    IS_KAGGLE,  # 从ultralytics.utils中导入IS_KAGGLE常量
    IS_PIP_PACKAGE,  # 从ultralytics.utils中导入IS_PIP_PACKAGE常量
    LINUX,  # 从ultralytics.utils中导入LINUX常量
    LOGGER,  # 从ultralytics.utils中导入LOGGER常量
    ONLINE,  # 从ultralytics.utils中导入ONLINE常量
    PYTHON_VERSION,  # 从ultralytics.utils中导入PYTHON_VERSION常量
    ROOT,  # 从ultralytics.utils中导入ROOT常量
    TORCHVISION_VERSION,  # 从ultralytics.utils中导入TORCHVISION_VERSION常量
    USER_CONFIG_DIR,  # 从ultralytics.utils中导入USER_CONFIG_DIR常量
    Retry,  # 从ultralytics.utils中导入Retry类
    SimpleNamespace,  # 从ultralytics.utils中导入SimpleNamespace类
    ThreadingLocked,  # 从ultralytics.utils中导入ThreadingLocked类
    TryExcept,  # 从ultralytics.utils中导入TryExcept类
    clean_url,  # 从ultralytics.utils中导入clean_url函数
    colorstr,  # 从ultralytics.utils中导入colorstr函数
    downloads,  # 从ultralytics.utils中导入downloads函数
    emojis,  # 从ultralytics.utils中导入emojis函数
    is_github_action_running,  # 从ultralytics.utils中导入is_github_action_running函数
    url2file,  # 从ultralytics.utils中导入url2file函数
)


def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file (or an installed package's metadata), ignoring comments.

    Empty lines and lines starting with '#' are skipped, and any text after '#' on a line is ignored.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Installed Python package whose metadata requirements are used instead of
            ``file_path``, i.e. package='ultralytics'. Requirements guarded by an ``extra == `` marker are skipped.

    Returns:
        (List[SimpleNamespace]): Parsed requirements, each with ``name`` and ``specifier`` attributes
            (``specifier`` is "" when no version constraint is given).

    Example:
        ```python
        from ultralytics.utils.checks import parse_requirements

        parse_requirements(package='ultralytics')
        ```
    """
    if package:
        # Distribution.requires is None when the package declares no requirements
        requires = [x for x in (metadata.distribution(package).requires or []) if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.split("#")[0].strip()  # drop inline comment
            match = re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line)
            if match:
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements


def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of at most three integers, ignoring any non-numeric parts.

    This replaces the deprecated ``pkg_resources.parse_version(v)``.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'.

    Returns:
        (tuple): The first (up to three) integer groups found, i.e. (2, 0, 1). If parsing raises, a warning is
            logged and (0, 0, 0) is returned.
    """
    try:
        numbers = re.findall(r"\d+", version)
        return tuple(int(n) for n in numbers[:3])  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0


def is_ascii(s) -> bool:
    """
    Check whether ``str(s)`` consists only of ASCII characters (code points below 128).

    Args:
        s (str): Value to check; non-strings (lists, tuples, None, ...) are converted with ``str()`` first.

    Returns:
        (bool): True if every character is ASCII (also True for an empty string), False otherwise.
    """
    return str(s).isascii()
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify that the image size is a multiple of the given stride in each dimension.

    Each dimension that is not a multiple of the stride is rounded up to the nearest multiple, and any value below
    ``floor`` is raised to ``floor``.

    Args:
        imgsz (int | List[int] | str): Image size, e.g. 640, [640, 480], or the strings '640', '[640, 480]'
            or '640,480'.
        stride (int | torch.Tensor): Stride value; for a tensor, its maximum is used.
        min_dim (int): Minimum number of dimensions. With 2, a single size is duplicated to [h, w]. With 1, a
            single size is returned as a plain int.
        max_dim (int): Maximum number of dimensions. With 1, a longer size is reduced to its maximum (with a
            warning). Otherwise, a longer size raises ValueError.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size.

    Raises:
        TypeError: If ``imgsz`` is of an unsupported type.
        ValueError: If ``imgsz`` is a string that is not a Python literal, or has more than ``max_dim``
            dimensions while ``max_dim != 1``.
    """
    import ast  # only needed for string inputs

    # A tensor stride (e.g. one stride per level) is reduced to its maximum
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Normalize imgsz to a list
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            # literal_eval instead of eval: imgsz can come from the CLI and must never execute code
            try:
                parsed = ast.literal_eval(imgsz.strip())
            except (ValueError, SyntaxError) as e:
                raise ValueError(f"'imgsz={imgsz}' is not a valid image size string.") from e
            # '640,480' parses to a tuple and ' 640' to an int; normalize both to a list
            imgsz = list(parsed) if isinstance(parsed, (list, tuple)) else [parsed]
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Round each dimension up to a multiple of the stride, but not below floor
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    if len(sz) == 1:
        if min_dim == 2:
            return [sz[0], sz[0]]
        if min_dim == 1:
            return sz[0]
    return sz


def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check the current version against the required version or range.

    Args:
        current (str): Current version, or a package name to look the installed version up for.
        required (str): Required version or range in pip-style format, e.g. '>=22.04' or '>20.04,<22.04'.
            Constraints are comma-separated; a constraint without an operator is treated as '>='. Operators
            other than ==, !=, >=, <=, > and < (e.g. '~=') are not checked.
        name (str, optional): Name to be used in the warning message.
        hard (bool, optional): If True, raise ModuleNotFoundError if the requirement is not met.
        verbose (bool, optional): If True, log a warning if the requirement is not met.
        msg (str, optional): Extra message appended to the warning.

    Returns:
        (bool): True if the requirement is met, False otherwise.

    Raises:
        ModuleNotFoundError: If ``hard`` is True and the package is not installed or the requirement is not met.
        ValueError: If a constraint in ``required`` contains no version number.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # Check if current version is greater than or equal to 22.04 ('>=' is assumed when no operator is given)
        check_version(current='22.10', required='22.04')

        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # Check if current version is strictly between 20.04 and 22.04
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    if not current:  # empty string or None
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # package name rather than a version string, i.e. current='ultralytics'
        try:
            name = current  # report the package name in messages
            current = metadata.version(current)  # resolve the installed version string
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:  # empty string or None: any version is acceptable
        return True

    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    failed = None  # first (op, version) constraint that is not satisfied, reported in the warning
    for r in required.strip(",").split(","):
        # Strip whitespace such as in '>=1.0, <2.0'; otherwise ' <' would not match any operator and pass silently
        r = r.strip()
        if not r:
            continue  # skip empty segments, e.g. '>=1.0,,<2.0'
        match = re.match(r"([^0-9]*)([\d.]+)", r)  # '>=22.04' -> ('>=', '22.04')
        if match is None:
            raise ValueError(f"invalid version requirement '{r}' in check_version(required='{required}')")
        op, version = match[1].strip(), match[2]
        v = parse_version(version)
        if op == "==":
            ok = c == v
        elif op == "!=":
            ok = c != v
        elif op in {">=", ""}:  # no operator given: assume '>=required'
            ok = c >= v
        elif op == "<=":
            ok = c <= v
        elif op == ">":
            ok = c > v
        elif op == "<":
            ok = c < v
        else:
            ok = True  # unsupported operator (e.g. '~='): constraint is not checked
        if not ok and failed is None:
            failed = (op, version)

    if failed is None:
        return True

    op, version = failed
    warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}"
    if hard:
        raise ModuleNotFoundError(emojis(warning))  # assert version requirements met
    if verbose:
        LOGGER.warning(warning)
    return False
def check_latest_pypi_version(package_name="ultralytics"):
    """
    Return the latest version string of a PyPI package without downloading or installing it.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str | None): The latest version of the package. None if the request raises or does not return HTTP 200.
    """
    latest = None
    with contextlib.suppress(Exception):
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        url = f"https://pypi.org/pypi/{package_name}/json"
        response = requests.get(url, timeout=3)
        if response.status_code == 200:
            latest = response.json()["info"]["version"]
    return latest


def check_pip_update_available():
    """
    Check PyPI for a newer release of the installed ultralytics package.

    The check only runs when ONLINE and IS_PIP_PACKAGE are both truthy. Any exception during the check is
    suppressed.

    Returns:
        (bool): True if a newer version is available (an update hint is logged), False otherwise.
    """
    if not (ONLINE and IS_PIP_PACKAGE):
        return False
    with contextlib.suppress(Exception):
        from ultralytics import __version__

        latest = check_latest_pypi_version()
        if check_version(__version__, f"<{latest}"):  # installed version is older than the latest release
            LOGGER.info(
                f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
                f"Update with 'pip install -U ultralytics'"
            )
            return True
    return False


@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Locate a font, downloading it to the user config directory if it is not found locally.

    The function is wrapped in the ``ThreadingLocked`` decorator. Lookup order:
    1. ``USER_CONFIG_DIR / <font file name>``.
    2. The first system font whose path contains ``font``.
    3. A download from the Ultralytics assets GitHub release into the user config directory.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path | str | None): The user-config font path, the matching system font path (a str), or None when the
            download URL check fails.
    """
    from matplotlib import font_manager

    font_name = Path(font).name
    local_file = USER_CONFIG_DIR / font_name
    if local_file.exists():
        return local_file

    candidates = [path for path in font_manager.findSystemFonts() if font in path]
    if any(candidates):
        return candidates[0]

    # Not found locally: download to the user config directory
    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{font_name}"
    if not downloads.is_url(url, check=True):
        return None
    downloads.safe_download(url=url, file=local_file)
    return local_file


def check_python(minimum: str = "3.8.0", hard: bool = True) -> bool:
    """
    Check that the running Python version is at least ``minimum``.

    Args:
        minimum (str): Minimum required Python version, e.g. '3.8.0'; a bare version is compared with '>='.
        hard (bool, optional): If True, ``check_version`` raises when the requirement is not met.

    Returns:
        (bool): Whether the installed Python version meets the minimum constraint.
    """
    return check_version(current=PYTHON_VERSION, required=minimum, name="Python", hard=hard)


@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): Path to a requirements.txt file (must be a ``Path``), a single
            package requirement as a string, or a list of package requirements as strings.
        exclude (Tuple[str]): Tuple of package names to exclude from checking (applies to requirements files).
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True if all requirements are met or were just installed, False otherwise.

    Example:
        ```python
        from pathlib import Path

        from ultralytics.utils.checks import check_requirements

        # Check a requirements.txt file (a plain str is treated as a package requirement, not a path)
        check_requirements(Path('path/to/requirements.txt'))

        # Check a single package
        check_requirements('ultralytics>=8.0.0')

        # Check multiple packages
        check_requirements(['numpy', 'ultralytics>=8.0.0'])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")
    check_python()  # check python version
    check_torchvision()  # check torch-torchvision compatibility

    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.split("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        if not match:
            LOGGER.warning(f"{prefix} WARNING ⚠️ could not parse requirement '{r}', skipping.")
            continue
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            # Explicit check instead of `assert`, which is stripped under `python -O`
            satisfied = check_version(metadata.version(name), required)
        except metadata.PackageNotFoundError:
            satisfied = False
        if not satisfied:
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands):
        """Attempt pip install command with retries on failure."""
        # NOTE(review): shell=True with interpolated strings; assumes requirement specs and cmds are trusted
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()

    s = " ".join(f'"{x}"' for x in pkgs)  # console string of missing/outdated packages
    if not s:
        return True  # all requirements met

    if not (install and AUTOINSTALL):  # check environment variable
        return False

    n = len(pkgs)  # number of packages updates
    LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
    if not ONLINE:  # explicit check instead of `assert`, which is stripped under `python -O`
        LOGGER.warning(f"{prefix} ❌ AutoUpdate skipped (offline)")
        return False
    try:
        t = time.time()
        LOGGER.info(attempt_install(s, cmds))
        dt = time.time() - t
        LOGGER.info(
            f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
            f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
        )
    except Exception as e:
        LOGGER.warning(f"{prefix} ❌ {e}")
        return False

    return True
def check_torchvision():
    """
    Warn if the installed torchvision is incompatible with the installed torch.

    Only major.minor versions are compared, using a compatibility table based on
    https://github.com/pytorch/vision#installation. torch versions missing from the table are not checked. The
    warning is printed to stdout.
    """
    # torch major.minor -> compatible torchvision major.minor versions
    compatibility_table = {
        "2.3": ["0.18"],
        "2.2": ["0.17"],
        "2.1": ["0.16"],
        "2.0": ["0.15"],
        "1.13": ["0.14"],
        "1.12": ["0.13"],
    }

    def major_minor(full_version):
        # '2.1.0+cu118' -> '2.1'
        return ".".join(full_version.split("+")[0].split(".")[:2])

    v_torch = major_minor(torch.__version__)
    compatible_versions = compatibility_table.get(v_torch)
    if compatible_versions is None:
        return

    v_torchvision = major_minor(TORCHVISION_VERSION)
    if v_torchvision not in compatible_versions:
        print(
            f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
            f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
            "'pip install -U torch torchvision' to update both.\n"
            "For a full compatibility table see https://github.com/pytorch/vision#installation"
        )


# Validate that one or more file names carry an accepted suffix
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """
    Check file(s) for acceptable suffix.

    Args:
        file (str | Path | list | tuple): A single file name, or a list/tuple of file names, to check.
        suffix (str | tuple): One accepted suffix, or a tuple of accepted suffixes (e.g. ('.yaml', '.yml')).
        msg (str): Text prepended to the assertion message on failure.

    Raises:
        AssertionError: If a file has a non-empty (lower-cased) suffix that is not one of `suffix`.
            Files without any suffix always pass. Nothing is checked if `file` or `suffix` is falsy.
    """
    if not (file and suffix):
        return
    accepted = (suffix,) if isinstance(suffix, str) else suffix
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for candidate in candidates:
        ext = Path(candidate).suffix.lower().strip()  # file suffix
        if ext:
            assert ext in accepted, f"{msg}{candidate} acceptable suffix is {accepted}, not {ext}"


# Map legacy YOLOv3/YOLOv5 file names to their 'u' counterparts (and 'u' YAMLs back to the base YAML)
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.

    Args:
        file (str): File name or path to check.
        verbose (bool): If True, log a tip when a '*.pt' name was rewritten.

    Returns:
        (str): The possibly rewritten name. Names containing 'yolov3'/'yolov5' are handled as follows:
            '*u.yaml' becomes '*.yaml'; a '*.pt' name whose string contains no 'u' has 'yolov5{n,s,m,l,x}.pt',
            'yolov5{n,s,m,l,x}6.pt' and 'yolov3[-tiny|-spp].pt' rewritten to the '*u.pt' variant. All other names
            are returned unchanged.
    """
    if "yolov3" not in file and "yolov5" not in file:
        return file
    if "u.yaml" in file:
        return file.replace("u.yaml", ".yaml")  # e.g. 'yolov5nu.yaml' -> 'yolov5n.yaml'
    if ".pt" not in file or "u" in file:
        return file

    # Applied in order, exactly like three consecutive re.sub() calls
    rewrites = (
        (r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt"),  # 'yolov5n.pt' -> 'yolov5nu.pt'
        (r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt"),  # 'yolov5n6.pt' -> 'yolov5n6u.pt'
        (r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt"),  # 'yolov3-spp.pt' -> 'yolov3-sppu.pt'
    )
    original_file = file
    for pattern, replacement in rewrites:
        file = re.sub(pattern, replacement, file)

    if verbose and file != original_file:
        LOGGER.info(
            f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
            f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
            f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
        )
    return file
# Expand a bare model stem (e.g. 'yolov8n') into a '.pt' file name when it names a known GitHub asset
def check_model_file_from_stem(model="yolov8n"):
    """
    Return a '.pt' file name for a bare model stem that matches a known GitHub release asset.

    Args:
        model (str | Path): Model name, stem, or path.

    Returns:
        (Path | str): `Path(model).with_suffix('.pt')` when `model` is non-empty, has no suffix, and its stem is in
            `downloads.GITHUB_ASSETS_STEMS`; otherwise `model` unchanged.
    """
    if not model:
        return model
    candidate = Path(model)
    if candidate.suffix or candidate.stem not in downloads.GITHUB_ASSETS_STEMS:
        return model
    return candidate.with_suffix(".pt")  # add suffix, i.e. yolov8n -> yolov8n.pt


# Search for or download a file and return its local path
def check_file(file, suffix="", download=True, download_dir=".", hard=True):
    """
    Search/download a file (if necessary) and return its path.

    Args:
        file (str | Path): A local path, a URL, or a file name to search for under ROOT.
        suffix (str | tuple): Accepted suffix(es), validated with check_suffix(); '' disables the check.
        download (bool): Whether 'https://', 'http://', 'rtsp://', 'rtmp://' and 'tcp://' URLs may be downloaded.
        download_dir (str | Path): Directory in which downloaded files are stored.
        hard (bool): Raise FileNotFoundError if the search finds no match or more than one match.

    Returns:
        (str | list): `file` itself if it is empty, an existing local path, or a 'grpc://' URL; the local path of a
            downloaded (or previously downloaded) URL; otherwise the first search match, or [] if nothing matched
            and `hard` is False.
    """
    check_suffix(file, suffix)  # optional
    file = check_yolov5u_filename(str(file).strip())  # yolov5n -> yolov5nu

    # '://' check required in Windows Python<3.10; grpc:// addresses Triton servers and is passed through untouched
    if not file or ("://" not in file and Path(file).exists()) or file.lower().startswith("grpc://"):
        return file

    remote_schemes = ("https://", "http://", "rtsp://", "rtmp://", "tcp://")
    if download and file.lower().startswith(remote_schemes):
        url = file  # keep the raw string: Pathlib would collapse '://' into ':/'
        local = Path(download_dir) / url2file(file)  # '%2F' -> '/', split https://url.com/file.txt?auth
        if local.exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {local}")  # file already exists
        else:
            downloads.safe_download(url=url, file=local, unzip=False)
        return str(local)

    # Search under ROOT recursively, then next to ROOT
    matches = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))
    if hard:
        if not matches:
            raise FileNotFoundError(f"'{file}' does not exist")
        if len(matches) > 1:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {matches}")
    return matches[0] if matches else []


# Resolve a YAML file (searching/downloading as needed) and verify its suffix
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Search/download a YAML file (if necessary) and return its path, checking the suffix.

    Args:
        file (str | Path): YAML file path, name, or URL.
        suffix (tuple): Accepted suffixes.
        hard (bool): Raise FileNotFoundError if the file cannot be found unambiguously.

    Returns:
        (str | list): Whatever check_file() returns for these arguments.
    """
    resolved = check_file(file, suffix=suffix, hard=hard)
    return resolved


# Guard against path traversal: is `path` an existing location inside `basedir`?
def check_is_path_safe(basedir, path):
    """
    Check that the resolved `path` exists and lies inside the resolved `basedir`.

    Args:
        basedir (str | Path): The directory `path` is expected to stay within.
        path (str | Path): The path to validate.

    Returns:
        (bool): True if `path` exists and its resolved path components start with those of `basedir`,
            otherwise False.
    """
    base_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.exists():
        return False
    return target.parts[: len(base_parts)] == base_parts


# Check whether the current environment can display images
def check_imshow(warn=False):
    """
    Check if the environment supports displaying images with cv2.imshow().

    On Linux, Colab/Kaggle environments and sessions without a DISPLAY environment variable are rejected before any
    window is opened. Otherwise a tiny test image is shown and the window is closed again.

    Args:
        warn (bool): If True, log a warning when displaying images is not supported.

    Returns:
        (bool): True if the test image could be shown, False otherwise.
    """
    try:
        if LINUX:
            # Explicit raises instead of `assert`: assertions are stripped under `python -O`, which would let
            # cv2.imshow() run in an environment without a display.
            if IS_COLAB or IS_KAGGLE:
                raise RuntimeError("Colab/Kaggle environment detected.")
            if "DISPLAY" not in os.environ:
                raise RuntimeError("The DISPLAY environment variable isn't set.")
        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8x8 RGB image
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:  # only warn when asked to; previously the warning was logged unconditionally
            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False
def check_yolo(verbose=True, device=""):
    """
    Log a human-readable YOLO software and hardware summary.

    In Jupyter, wandb is uninstalled when check_requirements("wandb", install=False) reports it as installed, and in
    Colab the 'sample_data' directory is removed. The function returns None; the summary goes to LOGGER.

    Args:
        verbose (bool): Append CPU count, total RAM and disk usage of "/" to the summary, and clear IPython output
            first when IPython is importable.
        device (str): Device string forwarded to select_device().
    """
    import psutil

    from ultralytics.utils.torch_utils import select_device

    if IS_JUPYTER:
        if check_requirements("wandb", install=False):
            os.system("pip uninstall -y wandb")  # avoid the unwanted wandb account-creation prompt that can hang
        if IS_COLAB:
            shutil.rmtree("sample_data", ignore_errors=True)  # remove Colab's /sample_data directory

    summary = ""
    if verbose:
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, _, free = shutil.disk_usage("/")
        summary = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        with contextlib.suppress(Exception):  # clear display if ipython is installed
            from IPython import display

            display.clear_output()

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {summary}")


def collect_system_info():
    """
    Collect and log relevant system information: OS, environment, Python, install type, RAM, CPU and CUDA, followed
    by the installed version of every requirement of the 'ultralytics' package.

    check_yolo() is called first, so its setup summary is logged too. When is_github_action_running() is True, the
    GitHub Actions runner details are logged at the end.
    """
    import psutil

    from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR
    from ultralytics.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # Convert bytes to GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if IS_GIT_DIR else 'pip' if IS_PIP_PACKAGE else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            # hard=False so an unmet requirement is reported as "❌" below. hard=True is expected to make
            # check_version() raise on a mismatch (confirm), which made the "❌" branch unreachable and aborted
            # the whole report on the first outdated package.
            is_met = "✅ " if check_version(current, str(r.specifier), hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )


def check_amp(model):
    """
    Check if Automatic Mixed Precision (AMP) works correctly with a YOLOv8 model.

    A freshly loaded YOLOv8n is run on a sample image in FP32 and under AMP, and the detected boxes are compared. If
    the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
    results, so AMP will be disabled during training.

    Args:
        model (nn.Module): A YOLOv8 model instance. Only the device of its parameters is used.

    Example:
        ```py
        from ultralytics import YOLO
        from ultralytics.utils.checks import check_amp

        model = YOLO('yolov8n.pt').model.cuda()
        check_amp(model)
        ```

    Returns:
        (bool): False on CPU/MPS devices or when AMP results differ from FP32 results; True when the check passes or
            has to be skipped (YOLOv8n could not be downloaded or loaded).
    """
    device = next(model.parameters()).device  # Get the device of the model
    if device.type in {"cpu", "mps"}:
        return False  # AMP checks only run on CUDA devices

    # Imported only once the check will actually run
    from ultralytics.utils.torch_utils import autocast

    def amp_allclose(m, im):
        """All close FP32 vs AMP results."""
        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
        with autocast(enabled=True):
            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
        del m
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close within 0.5 absolute tolerance

    im = ASSETS / "bus.jpg"  # image to check
    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from ultralytics import YOLO

        assert amp_allclose(YOLO("yolov8n.pt"), im)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(
            f"{prefix}checks skipped ⚠️. "
            f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
        )
    except AssertionError:
        LOGGER.warning(
            f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
        )
        return False
    return True
def git_describe(path=ROOT):  # path must be a directory
    """
    Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.

    Args:
        path (str | Path): Directory inside the git repository to describe.

    Returns:
        (str): Output of `git describe --tags --long --always` without its last character (the trailing newline),
            or "" if the command cannot be run or fails.
    """
    with contextlib.suppress(Exception):
        # Argument list without a shell: directories containing spaces or shell metacharacters are passed intact
        return subprocess.check_output(
            ["git", "-C", str(path), "describe", "--tags", "--long", "--always"]
        ).decode()[:-1]
    return ""


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """
    Log the arguments of the calling function, or of an explicitly supplied args dict.

    Args:
        args (dict, optional): Mapping of argument names to values. When None, the caller's argument names and their
            current values are read from the calling frame.
        show_file (bool): Prefix the line with the caller's file (relative to ROOT without suffix, or else its stem).
        show_func (bool): Prefix the line with the caller's function name.
    """

    def strip_auth(v):
        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        if isinstance(v, str) and v.startswith("http") and len(v) > 100:
            return clean_url(v)
        return v

    caller = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(caller)
    if args is None:  # get args automatically
        arg_names, _, _, frame_locals = inspect.getargvalues(caller)
        args = {name: value for name, value in frame_locals.items() if name in arg_names}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:  # caller lives outside ROOT
        file = Path(file).stem

    prefix_parts = []
    if show_file:
        prefix_parts.append(f"{file}: ")
    if show_func:
        prefix_parts.append(f"{func}: ")
    rendered = ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())
    LOGGER.info(colorstr("".join(prefix_parts)) + rendered)


def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment.

    The count is read from the first line printed by `nvidia-smi --query-gpu=count`. If nvidia-smi is missing,
    exits with an error, or prints something that is not an integer, 0 is returned.

    Returns:
        (int): The number of NVIDIA GPUs available.
    """
    query = ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"]
    try:
        lines = subprocess.check_output(query, encoding="utf-8").strip().split("\n")
        return int(lines[0])
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        return 0  # assume no GPUs are available


def cuda_is_available() -> bool:
    """
    Check if CUDA is available in the environment.

    Availability is decided solely by the GPU count from cuda_device_count().

    Returns:
        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
    """
    gpu_count = cuda_device_count()
    return gpu_count > 0


# Define constants
# Whether the running Python is at least 3.10. hard=False is passed, presumably so a failed check yields False
# instead of raising (confirm in check_python).
IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
# Whether the PYTHON_VERSION string starts with "3.12"
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")

.\yolov8\ultralytics\utils\dist.py

代码语言:python
复制
# 导入必要的模块和函数
import os  # 系统操作模块
import shutil  # 文件操作模块
import socket  # 网络通信模块
import sys  # 系统模块
import tempfile  # 临时文件模块

from . import USER_CONFIG_DIR  # 导入当前目录下的 USER_CONFIG_DIR 变量
from .torch_utils import TORCH_1_9  # 导入 TORCH_1_9 变量和函数

# Ask the OS for an unused TCP port on the loopback interface
def find_free_network_port() -> int:
    """
    Finds a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.

    Returns:
        (int): A TCP port that was free when checked. The probing socket is closed before returning, so another
            process could still claim the port afterwards.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0: let the OS choose a free port
        _, port = sock.getsockname()
    return port


# Write a temporary Python script that rebuilds and runs the trainer in each DDP process
def generate_ddp_file(trainer):
    """
    Generates a DDP file and returns its file name.

    The generated script imports the trainer's class from its module, rebuilds it with `vars(trainer.args)` as
    overrides, sets `args.model` and calls `train()`. It is written to USER_CONFIG_DIR / 'DDP' with delete=False,
    so it stays on disk until the caller removes it (e.g. with ddp_cleanup()).

    Args:
        trainer: The trainer to serialize. Its class must be importable as `<module>.<ClassName>`, and every value in
            `vars(trainer.args)` must have a repr that is valid Python, because that dict is embedded literally.

    Returns:
        (str): Path of the generated temporary file.
    """
    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
    # Embedded with !r below so the literal is correctly quoted and escaped (e.g. Windows paths with backslashes)
    model = str(getattr(trainer.hub_session, "model_url", trainer.args.model))

    # Only the trainer's arguments are serialized: vars(trainer) would embed reprs of models, devices etc.,
    # which are not valid Python source.
    content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}

if __name__ == "__main__":
    from {module} import {name}
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')   # 处理额外的键 'save_dir'
    trainer = {name}(cfg=cfg, overrides=overrides)
    trainer.args.model = {model!r}
    results = trainer.train()
"""
    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)  # directory holding the temporary DDP files
    with tempfile.NamedTemporaryFile(
        prefix="_temp_",
        suffix=f"{id(trainer)}.py",
        mode="w+",
        encoding="utf-8",
        dir=USER_CONFIG_DIR / "DDP",
        delete=False,
    ) as file:
        file.write(content)
    return file.name


# Build the torch.distributed launcher command for multi-GPU training
def generate_ddp_command(world_size, trainer):
    """
    Generates and returns command for distributed training.

    Unless the trainer is resuming, its save_dir is deleted first. A temporary DDP script is then generated with
    generate_ddp_file(), and a free localhost port is chosen for the master process.

    Args:
        world_size (int): Number of processes to start on this node (passed as --nproc_per_node).
        trainer: The trainer to launch.

    Returns:
        (tuple[list, str]): The command as an argument list, and the path of the temporary DDP file.
    """
    import __main__  # local import kept from the original; NOTE(review): presumably a launcher workaround, confirm before removing

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = generate_ddp_file(trainer)
    # NOTE(review): TORCH_1_9 presumably flags torch>=1.9, where torch.distributed.run is available
    launcher = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
    master_port = find_free_network_port()
    cmd = [
        sys.executable,
        "-m",
        launcher,
        "--nproc_per_node",
        f"{world_size}",
        "--master_port",
        f"{master_port}",
        file,
    ]
    return cmd, file


# Remove the temporary DDP script generated for this trainer
def ddp_cleanup(trainer, file):
    """
    Delete temp file if created.

    The file is removed only if its name contains f"{id(trainer)}.py", the suffix generate_ddp_file() uses;
    any other file is left untouched.

    Args:
        trainer: The trainer whose temporary DDP file should be removed.
        file (str): Path of the file that was launched.
    """
    marker = f"{id(trainer)}.py"
    if marker not in file:
        return
    os.remove(file)

.\yolov8\ultralytics\utils\downloads.py

代码语言:python
复制
# Ultralytics YOLO 🚀, AGPL-3.0 license

# 导入必要的库
import contextlib  # 提供上下文管理工具的标准库
import re  # 提供正则表达式操作的模块
import shutil  # 提供高级文件操作的模块
import subprocess  # 提供运行外部命令的功能
from itertools import repeat  # 提供迭代工具函数
from multiprocessing.pool import ThreadPool  # 提供多线程池的功能
from pathlib import Path  # 提供处理文件路径的类和函数
from urllib import parse, request  # 提供处理 URL 相关的模块

import requests  # 提供进行 HTTP 请求的模块
import torch  # PyTorch 深度学习框架

# 从 Ultralytics 的 utils 模块中导入特定函数和类
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file

# GitHub repository hosting the Ultralytics release assets, and the file names of the assets it provides
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = (
    [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
    + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
    + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
    + [f"yolov8{k}-world.pt" for k in "smlx"]
    + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
    + [f"yolov9{k}.pt" for k in "tsmce"]
    + [f"yolov10{k}.pt" for k in "nsmblx"]
    + [f"yolo_nas_{k}.pt" for k in "sml"]
    + [f"sam_{k}.pt" for k in "bl"]
    + [f"FastSAM-{k}.pt" for k in "sx"]
    + [f"rtdetr-{k}.pt" for k in "lx"]
    + ["mobile_sam.pt"]
    + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"]
)
# Asset names with only the final suffix removed (e.g. 'yolov8n'; the '.npy.zip' entry keeps '.npy')
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]


def is_url(url, check=False):
    """
    Validates if the given string is a URL and optionally checks if the URL exists online.

    Args:
        url (str): The string to be validated as a URL (converted with str() first).
        check (bool, optional): If True, additionally open the URL and require HTTP status 200. Defaults to False.

    Returns:
        (bool): True if `url` has both a scheme and a network location (and, when `check` is True, opening it returns
            status 200). False otherwise, including when parsing or the request raises.

    Example:
        ```py
        valid = is_url("https://www.example.com")
        ```
    """
    with contextlib.suppress(Exception):
        url = str(url)
        result = parse.urlparse(url)
        # Explicit test rather than `assert`: assertions are stripped under `python -O`, which made every string a URL
        if not (result.scheme and result.netloc):
            return False
        if check:
            with request.urlopen(url) as response:
                return response.getcode() == 200  # check if exists online
        return True
    return False


def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
    """
    Deletes all ".DS_Store" files and "__MACOSX" entries under a specified directory.

    Args:
        path (str | Path): The directory to clean.
        files_to_delete (tuple): Names of files or directories to delete, matched recursively with Path.rglob().

    Example:
        ```py
        from ultralytics.utils.downloads import delete_dsstore

        delete_dsstore('path/to/dir')
        ```

    Note:
        ".DS_Store" files are created by Apple operating systems and contain metadata about files and folders. They are
        hidden system files and can cause issues when transferring files between different operating systems.
    """
    for file in files_to_delete:
        matches = list(Path(path).rglob(file))
        LOGGER.info(f"Deleting {file} files: {matches}")
        for f in matches:
            if f.is_dir() and not f.is_symlink():
                # Matches such as '__MACOSX' may be directories, which Path.unlink() cannot remove
                shutil.rmtree(f)
            elif f.exists() or f.is_symlink():
                # exists() is False when an earlier rmtree() already removed a parent directory of this match
                f.unlink()
# Extract a ZIP file into a directory, skipping entries that match the exclude list
def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True):
    """
    Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list.

    If the zipfile does not contain a single top-level directory, the function will create a new
    directory with the same name as the zipfile (without the extension) to extract its contents.
    If a path is not provided, the function will use the parent directory of the zipfile as the default path.

    Args:
        file (str): The path to the zipfile to be extracted.
        path (str, optional): The path to extract the zipfile to. Defaults to None.
        exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
        exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
        progress (bool, optional): Whether to display a progress bar. Defaults to True.

    Raises:
        BadZipFile: If the provided file does not exist or is not a valid zipfile.

    Returns:
        (Path): The path to the directory where the zipfile was extracted.

    Example:
        ```py
        from ultralytics.utils.downloads import unzip_file

        dir = unzip_file('path/to/file.zip')
        ```
    """
    from zipfile import BadZipFile, ZipFile, is_zipfile

    if not (Path(file).exists() and is_zipfile(file)):
        raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
    if path is None:
        path = Path(file).parent  # default path

    with ZipFile(file) as zipObj:
        # Skip entries whose name contains any of the excluded substrings
        files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
        top_level_dirs = {Path(f).parts[0] for f in files}

        if len(top_level_dirs) == 1:
            # A single top-level entry: extract into `path`, result lives at path/<top-level entry>
            extract_path = path
            path = Path(path) / list(top_level_dirs)[0]
        else:
            # Several top-level entries: extract into a new sub-directory named after the zip file
            path = extract_path = Path(path) / Path(file).stem

        # Do not overwrite a non-empty destination unless exist_ok is set
        if path.exists() and any(path.iterdir()) and not exist_ok:
            LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.")
            return path

        for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress):
            # Refuse entries containing '..' components to prevent path traversal out of extract_path
            if ".." in Path(f).parts:
                LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
                continue
            zipObj.extract(f, extract_path)

    return path  # return unzip dir
# Check that enough free disk space is available before downloading a file
def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=None, sf=1.5, hard=True):
    """
    Check if there is sufficient disk space to download and store a file.

    The file size is taken from the Content-Length header of a HEAD request. Any request problem (exception or HTTP
    status >= 400) skips the check and returns True.

    Args:
        url (str): The URL of the file to be downloaded.
        path (str | Path, optional): Path on the drive whose free space is checked. Defaults to the current working
            directory at call time.
        sf (float): Safety factor; `file size * sf` must be below the free space.
        hard (bool): If True, raise MemoryError on insufficient space; otherwise log a warning and return False.

    Returns:
        (bool): True if there is sufficient disk space (or the size could not be determined), False otherwise.

    Raises:
        MemoryError: If space is insufficient and `hard` is True.
    """
    if path is None:
        path = Path.cwd()
    try:
        # requests.head() does not follow redirects by default; without allow_redirects a redirecting URL returns no
        # Content-Length and the size check silently becomes a no-op. The timeout keeps a stalled server from hanging.
        r = requests.head(url, allow_redirects=True, timeout=30)
    except Exception:
        return True  # requests issue, default to True
    if r.status_code >= 400:  # explicit check instead of an `assert`, which `python -O` would strip
        return True

    gib = 1 << 30  # bytes per GiB
    data = int(r.headers.get("Content-Length", 0)) / gib  # file size (GB)
    _, _, free = (x / gib for x in shutil.disk_usage(path))  # free space (GB)

    if data * sf < free:
        return True  # sufficient space

    text = (
        f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
        f"Please free {data * sf - free:.1f} GB additional disk space and try again."
    )
    if hard:
        raise MemoryError(text)
    LOGGER.warning(text)
    return False


def get_google_drive_file_info(link):
    """
    Retrieve the direct download link and filename for a shareable Google Drive file link.

    Args:
        link (str): Shareable link of the form 'https://drive.google.com/file/d/<file_id>/view?usp=sharing'.

    Returns:
        (str): Direct download URL, with a '&confirm=<token>' parameter appended when Drive sets a
            'download_warning*' cookie.
        (str | None): Filename parsed from the response's content-disposition header, or None if unavailable.

    Raises:
        ConnectionError: If Google Drive reports that the download quota is exceeded.
        IndexError: If `link` does not contain '/d/'.
    """
    file_id = link.split("/d/")[1].split("/view")[0]
    drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    filename = None  # stays None when no content-disposition filename is returned

    with requests.Session() as session:
        response = session.get(drive_url, stream=True)
        if "quota exceeded" in str(response.content.lower()):
            raise ConnectionError(
                emojis(
                    f"❌  Google Drive file download quota exceeded. "
                    f"Please try again later or download this file manually at {link}."
                )
            )
        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
                drive_url += f"&confirm={v}"  # v is token
        cd = response.headers.get("content-disposition")
        if cd:
            match = re.search('filename="(.+)"', cd)  # same pattern as before, but no IndexError when absent
            if match:
                filename = match[1]
    return drive_url, filename
# Download a file with retries, then optionally unzip it and delete the archive

def safe_download(
    url,
    file=None,
    dir=None,
    unzip=True,
    delete=False,
    curl=False,
    retry=3,
    min_bytes=1e0,
    exist_ok=False,
    progress=True,
):
    """
    Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str | Path, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file. Default: True.
        delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False.
        curl (bool, optional): Whether to use curl command line tool for downloading. Default: False.
        retry (int, optional): The number of times to retry the download in case of failure. Default: 3.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
            a successful download. Default: 1E0.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.
        progress (bool, optional): Whether to display a progress bar during the download. Default: True.

    Returns:
        (Path | None): The extraction directory when the unzip step runs, otherwise None.

    Raises:
        ConnectionError: If the first attempt fails while offline, or the last retry fails with an exception.

    Example:
        ```py
        from ultralytics.utils.downloads import safe_download

        link = "https://ultralytics.com/assets/bus.jpg"
        path = safe_download(link)
        ```
    """
    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
            "https://ultralytics.com/assets/",  # assets alias
        )
        desc = f"Downloading {uri} to '{f}'"
        LOGGER.info(f"{desc}...")
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        check_disk_space(url, path=f.parent)
        for i in range(retry + 1):
            try:
                if curl or i > 0:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                else:  # first attempt without curl
                    method = "torch"
                    if method == "torch":
                        torch.hub.download_url_to_file(url, f, progress=progress)
                    else:
                        with request.urlopen(url) as response, TQDM(
                            total=int(response.getheader("Content-Length", 0)),
                            desc=desc,
                            disable=not progress,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                        ) as pbar:
                            with open(f, "wb") as f_opened:
                                for data in response:
                                    f_opened.write(data)
                                    pbar.update(len(data))

                if f.exists():
                    if f.stat().st_size > min_bytes:
                        break  # success
                    f.unlink()  # remove partial downloads
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌  Download failure for {uri}. Retry limit reached.")) from e
                LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...")

    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
        from zipfile import is_zipfile

        # Path() so a plain-string `dir` (as documented) works too: str has no .resolve()
        unzip_dir = Path(dir or f.parent).resolve()  # unzip to dir if provided else unzip in place
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)
        elif f.suffix in {".tar", ".gz"}:
            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove the archive
        return unzip_dir
# 从 GitHub 仓库中获取指定版本的标签和资产列表。如果未指定版本,则获取最新发布的资产。
def get_github_assets(repo="ultralytics/assets", version="latest", retry=False):
    """
    Fetch the release tag and asset filenames of a GitHub release via the GitHub REST API.

    Args:
        repo (str): Repository in ``owner/name`` form.
        version (str): Release tag such as ``"v6.2"``, or ``"latest"`` for the newest release.
        retry (bool): If True, issue the request a second time when the first response is not 200,
            unless the response reason string is ``"rate limit exceeded"``.

    Returns:
        (tuple): ``(tag, assets)`` — the release tag name and a list of asset filenames,
            e.g. ``('v8.2.0', ['yolov8n.pt', 'yolov8s.pt', ...])``. On a non-200 final response a
            warning is logged and ``("", [])`` is returned.
    """
    release = version if version == "latest" else f"tags/{version}"
    url = f"https://api.github.com/repos/{repo}/releases/{release}"

    response = requests.get(url)  # github api
    should_retry = retry and response.status_code != 200 and response.reason != "rate limit exceeded"
    if should_retry:
        response = requests.get(url)  # second and final attempt

    if response.status_code != 200:
        LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {response.status_code} {response.reason}")
        return "", []

    payload = response.json()
    return payload["tag_name"], [asset["name"] for asset in payload["assets"]]


# 尝试从 GitHub 发布资产中下载文件,如果本地不存在。首先检查本地文件,然后尝试从指定的 GitHub 仓库版本下载。
def attempt_download_asset(file, repo="ultralytics/assets", release="v8.2.0", **kwargs):
    """
    Resolve ``file`` to a local path, downloading it first if it is not already present.

    Resolution order:
        1. ``file`` exists as given -> return it.
        2. ``file`` exists under ``SETTINGS["weights_dir"]`` -> return that path.
        3. ``file`` is an http(s) URL -> download it unless a local file with the parsed name exists.
        4. ``repo`` is ``GITHUB_ASSETS_REPO`` and the name is a known asset -> download from ``release``.
        5. Otherwise list the assets of ``release`` (falling back to the latest release) and download
           the file if it is listed.

    Args:
        file (str | Path): Local path, URL, or asset filename.
        repo (str): GitHub repository in ``owner/name`` form.
        release (str): Release tag to download from.
        **kwargs: Forwarded to ``safe_download``.

    Returns:
        (str): The resolved file path. In case 5 this is returned even when the name is not among the
            release assets and nothing was downloaded.
    """
    from ultralytics.utils import SETTINGS  # scoped import to avoid a circular import

    # Normalize v5 filenames and strip whitespace / stray single quotes
    file = Path(checks.check_yolov5u_filename(str(file)).strip().replace("'", ""))

    if file.exists():
        return str(file)

    weights_path = SETTINGS["weights_dir"] / file
    if weights_path.exists():
        return str(weights_path)

    source = str(file)
    name = Path(parse.unquote(source)).name  # decode e.g. '%2F' -> '/' before taking the basename
    download_url = f"https://github.com/{repo}/releases/download"

    if source.startswith(("http:/", "https:/")):
        url = source.replace(":/", "://")  # Path collapses '://' into ':/', so restore it
        file = url2file(name)  # strip query/auth parts from the filename
        if not Path(file).is_file():
            safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
        else:
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")
    elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
        safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
    else:
        tag, assets = get_github_assets(repo, release)
        if not assets:
            tag, assets = get_github_assets(repo)  # fall back to the latest release
        if name in assets:
            safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)

    return str(file)
# 定义了一个下载函数,用于从指定的 URL 下载文件到指定目录。支持并发下载如果指定了多个线程。
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False):
    """
    Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are
    specified.

    Args:
        url (str | Path | list): The URL or list of URLs of the files to be downloaded. A single ``str``/``Path``
            is always treated as one URL, whatever the value of ``threads``.
        dir (Path, optional): The directory where the files will be saved. Defaults to the current working
            directory as of module import (the default expression is evaluated once, at definition time).
        unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True.
        delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False.
        curl (bool, optional): Flag to use curl for downloading. Defaults to False.
        threads (int, optional): Number of threads to use for concurrent downloads. Defaults to 1.
        retry (int, optional): Number of retries in case of download failure. Defaults to 3.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False.

    Example:
        ```py
        download('https://ultralytics.com/assets/example.zip', dir='path/to/dir', unzip=True)
        ```
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # create the target directory (and parents) if missing

    # Normalize to a list up front. Without this, zipping a bare string in the threaded branch iterates it
    # character by character (one bogus download per character), and a bare Path raises TypeError.
    urls = [url] if isinstance(url, (str, Path)) else list(url)

    if threads > 1:
        with ThreadPool(threads) as pool:
            pool.map(
                lambda x: safe_download(
                    url=x[0],
                    dir=x[1],
                    unzip=unzip,
                    delete=delete,
                    curl=curl,
                    retry=retry,
                    exist_ok=exist_ok,
                    progress=False,  # threads > 1 here; concurrent progress bars would interleave
                ),
                zip(urls, repeat(dir)),
            )
            pool.close()
            pool.join()
    else:
        for u in urls:
            safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)

.\yolov8\ultralytics\utils\errors.py

代码语言:python
复制
# 导入从ultralytics.utils包中导入emojis函数
from ultralytics.utils import emojis

# 定义一个自定义异常类HUBModelError,用于处理Ultralytics YOLO模型获取相关的错误
class HUBModelError(Exception):
    """
    Raised when an Ultralytics HUB model cannot be found or retrieved.

    The message is passed through ``emojis`` (from ``ultralytics.utils``) before being handed to the
    ``Exception`` base class, so the exception's args hold the processed text.

    Attributes:
        message (str): Error text shown when the exception is raised; defaults to a generic
            "Model not found" hint.
    """

    def __init__(self, message="Model not found. Please check model URL and try again."):
        """Build the exception from an emoji-processed ``message``."""
        rendered = emojis(message)
        super().__init__(rendered)

.\yolov8\ultralytics\utils\files.py

代码语言:python
复制
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib                   # 导入上下文管理模块
import glob                         # 导入文件路径模块
import os                           # 导入操作系统接口模块
import shutil                       # 导入文件操作模块
import tempfile                     # 导入临时文件和目录模块
from contextlib import contextmanager  # 导入上下文管理器装饰器
from datetime import datetime       # 导入日期时间模块
from pathlib import Path            # 导入路径操作模块


class WorkingDirectory(contextlib.ContextDecorator):
    """
    Temporarily switch the process working directory.

    Usable as a decorator (``@WorkingDirectory(dir)``) or a context manager (``with WorkingDirectory(dir):``).
    On exit, the directory that was current when the object was constructed is restored, including when
    the wrapped code raises; exceptions are not suppressed.

    Args:
        new_dir (str | Path): Directory to change into on entry.
    """

    def __init__(self, new_dir):
        """Record the directory to return to and the directory to enter."""
        self.cwd = Path.cwd().resolve()  # captured at construction time, not at __enter__
        self.dir = new_dir

    def __enter__(self):
        """Change into ``self.dir``."""
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):  # noqa
        """Change back to the directory recorded in ``__init__``."""
        os.chdir(self.cwd)


@contextmanager
def spaces_in_path(path):
    """
    Context manager that hands out a space-free copy of ``path`` and syncs it back on exit.

    If ``str(path)`` contains a space, the file or directory is copied into a fresh temporary directory
    under its basename with spaces replaced by underscores. That temporary path is yielded. When the block
    exits (normally or via an exception), the temporary copy is copied back over the original location.
    Paths without spaces are yielded unchanged.

    Args:
        path (str | Path): The original path.

    Yields:
        (str | Path): The temporary underscore path (``str`` if ``path`` was a ``str``, else ``Path``) when
            the path contains spaces; otherwise the original ``path`` object itself.

    Example:
        ```py
        from ultralytics.utils.files import spaces_in_path

        with spaces_in_path('/path/with spaces') as new_path:
            # Your code here
        ```
    """
    if " " not in str(path):
        yield path  # nothing to work around
        return

    as_str = isinstance(path, str)
    original = Path(path)

    with tempfile.TemporaryDirectory() as tmp_root:
        staged = Path(tmp_root) / original.name.replace(" ", "_")

        # Stage a copy of the source under the space-free name
        if original.is_dir():
            shutil.copytree(original, staged)
        elif original.is_file():
            staged.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(original, staged)

        try:
            yield str(staged) if as_str else staged
        finally:
            # Sync whatever is now at the staged path back to the original location
            if staged.is_dir():
                shutil.copytree(staged, original, dirs_exist_ok=True)
            elif staged.is_file():
                shutil.copy2(staged, original)


def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """
    Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    If ``path`` exists and ``exist_ok`` is False, the first free path ``{path}{sep}{n}{suffix}`` for
    n = 2, 3, ... is returned. For an existing file, the number is inserted before the extension
    (``a.txt`` -> ``a2.txt``).

    Args:
        path (str | Path): Path to increment.
        exist_ok (bool): If True, return ``path`` unchanged even when it already exists.
        sep (str): Separator placed between the path and the increment number.
        mkdir (bool): If True, create the resulting path as a directory (including parents).

    Returns:
        (Path): The (possibly incremented) path. When ``exist_ok`` is False, this is never an existing path.
    """
    path = Path(path)

    if path.exists() and not exist_ok:
        # Keep the extension for files so 'a.txt' becomes 'a2.txt' rather than 'a.txt2'
        path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")

        # Unbounded probe: with a fixed upper bound, exhausting it would hand back a path that already
        # exists. The filesystem is finite, so this loop always terminates.
        n = 2
        while os.path.exists(f"{path}{sep}{n}{suffix}"):
            n += 1
        path = Path(f"{path}{sep}{n}{suffix}")

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # create the directory (and parents)

    return path
def file_age(path=__file__):
    """Return the number of whole days since ``path`` (default: this module) was last modified."""
    now = datetime.now()
    modified = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return (now - modified).days  # whole days only; fractional part is dropped


def file_date(path=__file__):
    """Return the modification date of ``path`` as 'YEAR-MONTH-DAY' without zero padding, e.g. '2021-3-26'."""
    stamp = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return "-".join(str(part) for part in (stamp.year, stamp.month, stamp.day))


def file_size(path):
    """
    Return the size of a file, or the total size of all files beneath a directory, in MiB.

    Inputs that are not ``str``/``Path``, or that name neither an existing file nor a directory, give 0.0.
    """
    if not isinstance(path, (str, Path)):
        return 0.0
    target = Path(path)
    bytes_per_mib = 1024**2
    if target.is_file():
        return target.stat().st_size / bytes_per_mib
    if target.is_dir():
        total = sum(entry.stat().st_size for entry in target.glob("**/*") if entry.is_file())
        return total / bytes_per_mib
    return 0.0


def get_latest_run(search_dir="."):
    """
    Return the 'last*.pt' file under ``search_dir`` (searched recursively) with the newest ctime, e.g. for --resume.

    Returns an empty string when no such file exists.
    """
    candidates = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
    if not candidates:
        return ""
    return max(candidates, key=os.path.getctime)


def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
    """
    Load each named YOLO model, optionally refresh its class names, and re-save it under 'updated_models'.

    Each model is loaded from ``source_dir / name``, ``half()`` is called on it, and it is saved to
    ``source_dir / "updated_models" / name`` with ``use_dill=False``.

    Args:
        model_names (Iterable[str], optional): Model filenames to update, defaults to ("yolov8n.pt",).
        source_dir (Path, optional): Directory containing the models and the output subdirectory,
            defaults to the current directory.
        update_names (bool, optional): If True, replace ``model.model.names`` with
            ``default_class_names("coco8.yaml")`` before saving.

    Example:
        ```py
        from ultralytics.utils.files import update_models

        model_names = (f"rtdetr-{size}.pt" for size in "lx")
        update_models(model_names)
        ```
    """
    from ultralytics import YOLO
    from ultralytics.nn.autobackend import default_class_names

    out_dir = source_dir / "updated_models"
    out_dir.mkdir(parents=True, exist_ok=True)  # Ensure target directory exists

    for name in model_names:
        src = source_dir / name
        print(f"Loading model from {src}")

        model = YOLO(src)
        model.half()  # presumably converts weights to FP16 before re-saving — NOTE(review): confirm
        if update_names:
            model.model.names = default_class_names("coco8.yaml")

        dst = out_dir / name
        print(f"Re-saving {name} model to {dst}")
        model.save(dst, use_dill=False)

.\yolov8\ultralytics\utils\instance.py

代码语言:python
复制
# Ultralytics YOLO 🚀, AGPL-3.0 license

# 导入必要的模块和库
from collections import abc
from itertools import repeat
from numbers import Number
from typing import List

import numpy as np

# 从本地导入自定义的操作函数
from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh

# 定义一个辅助函数_ntuple,用于解析参数为可迭代对象或重复值
def _ntuple(n):
    """From PyTorch internals."""
    
    def parse(x):
        """Parse bounding boxes format between XYWH and LTWH."""
        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
    
    return parse

# Scalar-to-tuple helpers: to_2tuple(x) -> (x, x), to_4tuple(x) -> (x, x, x, x); iterables pass through unchanged
to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)

# Supported bounding box formats:
# `xyxy`: top-left and bottom-right corner coordinates
# `xywh`: center point coordinates plus width and height (YOLO format)
# `ltwh`: top-left corner coordinates plus width and height (COCO format)
_formats = ["xyxy", "xywh", "ltwh"]

# Public names exported by `from ... import *`.
# NOTE(review): `Instances` is defined in this module but not listed here — confirm that is intentional.
__all__ = ("Bboxes",)  # tuple or list

# 定义边界框类 Bboxes
class Bboxes:
    """
    A class for handling bounding boxes.

    The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'.
    Bounding box data should be provided in numpy arrays.

    Attributes:
        bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array of shape (N, 4).
        format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').

    Note:
        This class does not handle normalization or denormalization of bounding boxes.
    """

    def __init__(self, bboxes, format="xyxy") -> None:
        """
        Initialize with bounding box data in the given format.

        Args:
            bboxes (numpy.ndarray): Array of shape (N, 4), or a single box of shape (4,).
            format (str): One of 'xyxy', 'xywh', 'ltwh'.

        Raises:
            AssertionError: If ``format`` is unknown or ``bboxes`` is not (N, 4) after promotion.
        """
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        # Promote a single box of shape (4,) to a one-row matrix of shape (1, 4)
        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
        assert bboxes.ndim == 2
        assert bboxes.shape[1] == 4
        self.bboxes = bboxes
        self.format = format

    def convert(self, format):
        """Convert the boxes in place to ``format`` ('xyxy', 'xywh' or 'ltwh'); no-op if already in it."""
        assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
        if self.format == format:
            return
        elif self.format == "xyxy":
            func = xyxy2xywh if format == "xywh" else xyxy2ltwh
        elif self.format == "xywh":
            func = xywh2xyxy if format == "xyxy" else xywh2ltwh
        else:
            func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
        self.bboxes = func(self.bboxes)
        self.format = format

    def areas(self):
        """Return box areas as a 1D array: (x2 - x1) * (y2 - y1) for 'xyxy', width * height otherwise."""
        return (
            (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])  # format xyxy
            if self.format == "xyxy"
            else self.bboxes[:, 3] * self.bboxes[:, 2]  # format xywh or ltwh
        )

    def mul(self, scale):
        """
        Multiply the four box columns in place by per-column factors.

        Args:
            scale (tuple | list | int): Four factors, or a single number applied to all four columns.
        """
        if isinstance(scale, Number):
            scale = to_4tuple(scale)
        assert isinstance(scale, (tuple, list))
        assert len(scale) == 4
        self.bboxes[:, 0] *= scale[0]
        self.bboxes[:, 1] *= scale[1]
        self.bboxes[:, 2] *= scale[2]
        self.bboxes[:, 3] *= scale[3]

    def add(self, offset):
        """
        Add per-column offsets to the four box columns in place.

        Args:
            offset (tuple | list | int): Four offsets, or a single number applied to all four columns.
        """
        if isinstance(offset, Number):
            offset = to_4tuple(offset)
        assert isinstance(offset, (tuple, list))
        assert len(offset) == 4
        self.bboxes[:, 0] += offset[0]
        self.bboxes[:, 1] += offset[1]
        self.bboxes[:, 2] += offset[2]
        self.bboxes[:, 3] += offset[3]

    def __len__(self):
        """Return the number of boxes."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
        """
        Concatenate a list or tuple of Bboxes objects into a single Bboxes object.

        Args:
            boxes_list (List[Bboxes]): The Bboxes objects to concatenate.
            axis (int, optional): Axis along which to concatenate the box arrays. Defaults to 0.

        Returns:
            (Bboxes): A new Bboxes holding the concatenated boxes, a (0, 4) Bboxes for an empty input,
                or the sole element itself when the input has exactly one element.
        """
        assert isinstance(boxes_list, (list, tuple))
        if not boxes_list:
            # (0, 4) rather than (0,): a 1D empty array would be promoted to (1, 0) and fail the (N, 4) check
            return cls(np.empty((0, 4)))
        assert all(isinstance(box, Bboxes) for box in boxes_list)

        if len(boxes_list) == 1:
            return boxes_list[0]
        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))

    def __getitem__(self, index) -> "Bboxes":
        """
        Retrieve a specific bounding box or a set of bounding boxes using indexing.

        Args:
            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
                                               the desired bounding boxes.

        Returns:
            (Bboxes): A new Bboxes object containing the selected bounding boxes.

        Raises:
            AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of bounding boxes.
        """
        if isinstance(index, int):
            # reshape, not view: numpy's ndarray.view() takes a dtype, not a target shape
            return Bboxes(self.bboxes[index].reshape(1, -1))
        b = self.bboxes[index]
        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
        return Bboxes(b)
class Instances:
    """
    Container for bounding boxes, segments, and keypoints of detected objects in an image.

    Attributes:
        _bboxes (Bboxes): Internal object for handling bounding box operations.
        keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None.
        normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
        segments (ndarray): Segments array with shape [N, 1000, 2] after resampling.

    Args:
        bboxes (ndarray): An array of bounding boxes with shape [N, 4].
        segments (list | ndarray, optional): A list or array of object segments. Default is None.
        keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None.
        bbox_format (str, optional): The format of bounding boxes ('xywh', 'xyxy' or 'ltwh'). Default is 'xywh'.
        normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True.

    Examples:
        ```py
        # Create an Instances object
        instances = Instances(
            bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]], dtype=np.float32),
            segments=np.array([[[5, 5], [10, 10]], [[15, 15], [20, 20]]], dtype=np.float32),
            keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]], dtype=np.float32),
        )
        ```

    Note:
        The bounding box format is determined by the `bbox_format` argument.
        This class does not perform input validation, and it assumes the inputs are well-formed.
        The coordinate-transforming methods (scale, normalize, flip, clip, padding) index ``segments`` with
        ``[..., 0]``, so ``segments`` must be a numpy array (an empty one is fine) when they are used.
    """

    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
        """
        Args:
            bboxes (ndarray): bboxes with shape [N, 4].
            segments (list | ndarray): segments.
            keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3].
        """
        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)  # box storage and format conversion
        self.keypoints = keypoints
        self.normalized = normalized
        self.segments = segments

    def convert_bbox(self, format):
        """Convert bounding box format in place."""
        self._bboxes.convert(format=format)

    @property
    def bbox_areas(self):
        """Return the area of each bounding box."""
        return self._bboxes.areas()

    def scale(self, scale_w, scale_h, bbox_only=False):
        """Scale coordinates by (scale_w, scale_h) in place without touching the ``normalized`` flag."""
        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
        if bbox_only:
            return
        self.segments[..., 0] *= scale_w
        self.segments[..., 1] *= scale_h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= scale_w
            self.keypoints[..., 1] *= scale_h

    def denormalize(self, w, h):
        """Convert boxes, segments, and keypoints from normalized to absolute coordinates (no-op if not normalized)."""
        if not self.normalized:
            return
        self._bboxes.mul(scale=(w, h, w, h))
        self.segments[..., 0] *= w
        self.segments[..., 1] *= h
        if self.keypoints is not None:
            self.keypoints[..., 0] *= w
            self.keypoints[..., 1] *= h
        self.normalized = False

    def normalize(self, w, h):
        """Normalize boxes, segments, and keypoints by image width/height (no-op if already normalized)."""
        if self.normalized:
            return
        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
        self.segments[..., 0] /= w
        self.segments[..., 1] /= h
        if self.keypoints is not None:
            self.keypoints[..., 0] /= w
            self.keypoints[..., 1] /= h
        self.normalized = True

    def add_padding(self, padw, padh):
        """Shift all coordinates by (padw, padh) for rect and mosaic padding; requires absolute coordinates."""
        assert not self.normalized, "you should add padding with absolute coordinates."
        self._bboxes.add(offset=(padw, padh, padw, padh))
        self.segments[..., 0] += padw
        self.segments[..., 1] += padh
        if self.keypoints is not None:
            self.keypoints[..., 0] += padw
            self.keypoints[..., 1] += padh

    def __getitem__(self, index) -> "Instances":
        """
        Retrieve a specific instance or a set of instances using indexing.

        Args:
            index (int, slice, or np.ndarray): The index, slice, or boolean array to select
                                               the desired instances.

        Returns:
            (Instances): A new Instances object containing the selected bounding boxes,
                         segments, and keypoints if present.

        Note:
            When using boolean indexing, make sure to provide a boolean array with the same
            length as the number of instances.
        """
        segments = self.segments[index] if len(self.segments) else self.segments
        keypoints = self.keypoints[index] if self.keypoints is not None else None
        bboxes = self.bboxes[index]
        bbox_format = self._bboxes.format
        return Instances(
            bboxes=bboxes,
            segments=segments,
            keypoints=keypoints,
            bbox_format=bbox_format,
            normalized=self.normalized,
        )

    def flipud(self, h):
        """Flip boxes, segments, and keypoints vertically, in place, within an image of height ``h``."""
        fmt = self._bboxes.format
        if fmt == "xyxy":
            y1 = self.bboxes[:, 1].copy()
            y2 = self.bboxes[:, 3].copy()
            self.bboxes[:, 1] = h - y2
            self.bboxes[:, 3] = h - y1
        elif fmt == "ltwh":
            # The new top edge is the mirror of the old bottom edge (top + height)
            self.bboxes[:, 1] = h - self.bboxes[:, 1] - self.bboxes[:, 3]
        else:
            # xywh: mirror the center y; height is unchanged
            self.bboxes[:, 1] = h - self.bboxes[:, 1]
        self.segments[..., 1] = h - self.segments[..., 1]
        if self.keypoints is not None:
            self.keypoints[..., 1] = h - self.keypoints[..., 1]

    def fliplr(self, w):
        """Flip boxes, segments, and keypoints horizontally, in place, within an image of width ``w``."""
        fmt = self._bboxes.format
        if fmt == "xyxy":
            x1 = self.bboxes[:, 0].copy()
            x2 = self.bboxes[:, 2].copy()
            self.bboxes[:, 0] = w - x2
            self.bboxes[:, 2] = w - x1
        elif fmt == "ltwh":
            # The new left edge is the mirror of the old right edge (left + width)
            self.bboxes[:, 0] = w - self.bboxes[:, 0] - self.bboxes[:, 2]
        else:
            # xywh: mirror the center x; width is unchanged
            self.bboxes[:, 0] = w - self.bboxes[:, 0]
        self.segments[..., 0] = w - self.segments[..., 0]
        if self.keypoints is not None:
            self.keypoints[..., 0] = w - self.keypoints[..., 0]

    def clip(self, w, h):
        """Clip boxes, segments, and keypoints to the image bounds [0, w] x [0, h], in place."""
        ori_format = self._bboxes.format
        self.convert_bbox(format="xyxy")  # clipping is done on corner coordinates
        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
        if ori_format != "xyxy":
            self.convert_bbox(format=ori_format)
        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
        if self.keypoints is not None:
            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)

    def remove_zero_area_boxes(self):
        """
        Remove zero-area boxes (e.g. left behind by clipping) along with their segments and keypoints.

        Returns:
            (np.ndarray): Boolean mask over the original boxes; True where a box was kept.
        """
        good = self.bbox_areas > 0
        if not good.all():
            self._bboxes = self._bboxes[good]
            if len(self.segments):
                self.segments = self.segments[good]
            if self.keypoints is not None:
                self.keypoints = self.keypoints[good]
        return good

    def update(self, bboxes, segments=None, keypoints=None):
        """Replace the boxes (keeping the current format) and, when given, the segments and keypoints."""
        self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
        if segments is not None:
            self.segments = segments
        if keypoints is not None:
            self.keypoints = keypoints

    def __len__(self):
        """Return the number of instances."""
        return len(self.bboxes)

    @classmethod
    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
        """
        Concatenates a list of Instances objects into a single Instances object.

        Args:
            instances_list (List[Instances]): A list of Instances objects to concatenate.
            axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0.

        Returns:
            (Instances): A new Instances object containing the concatenated bounding boxes,
                         segments, and keypoints if present. An empty input gives an Instances with a
                         (0, 4) box array; a single-element input returns that element itself.

        Note:
            The `Instances` objects in the list should have the same properties, such as
            the format of the bounding boxes, whether keypoints are present, and if the
            coordinates are normalized. Box format, keypoint presence, and the normalized flag are
            taken from the first element.
        """
        assert isinstance(instances_list, (list, tuple))
        if not instances_list:
            # (0, 4) satisfies Bboxes' (N, 4) check; np.empty(0) would be promoted to (1, 0) and fail it
            return cls(np.empty((0, 4)))
        assert all(isinstance(instance, Instances) for instance in instances_list)

        if len(instances_list) == 1:
            return instances_list[0]

        use_keypoint = instances_list[0].keypoints is not None
        bbox_format = instances_list[0]._bboxes.format
        normalized = instances_list[0].normalized

        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

    @property
    def bboxes(self):
        """Return the underlying bounding box array."""
        return self._bboxes.bboxes
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2024-09-06,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • .\yolov8\ultralytics\utils\callbacks\raytune.py
  • .\yolov8\ultralytics\utils\callbacks\tensorboard.py
  • .\yolov8\ultralytics\utils\callbacks\wb.py
  • .\yolov8\ultralytics\utils\callbacks\__init__.py
  • .\yolov8\ultralytics\utils\checks.py
  • .\yolov8\ultralytics\utils\dist.py
  • .\yolov8\ultralytics\utils\downloads.py
  • .\yolov8\ultralytics\utils\errors.py
  • .\yolov8\ultralytics\utils\files.py
  • .\yolov8\ultralytics\utils\instance.py
相关产品与服务
灰盒安全测试
腾讯知识图谱(Tencent Knowledge Graph,TKG)是一个集成图数据库、图计算引擎和图可视化分析的一站式平台。支持抽取和融合异构数据,支持千亿级节点关系的存储和计算,支持规则匹配、机器学习、图嵌入等图数据挖掘算法,拥有丰富的图数据渲染和展现的可视化方案。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档