首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >nnU-Net 入门实操教程

nnU-Net 入门实操教程

原创
作者头像
buzzfrog
修改2025-10-16 11:31:55
修改2025-10-16 11:31:55
10530
代码可运行
举报
文章被收录于专栏:云上修行云上修行
运行总次数:0
代码可运行

目录

  1. nnU-Net 简介
  2. 环境安装
  3. 数据集准备
  4. 数据预处理
  5. 模型训练
  6. 模型推理
  7. 结果可视化
  8. 常见问题

1. nnU-Net 简介

1.1 什么是 nnU-Net?

nnU-Net (no-new-UNet) 是一个用于医学图像分割的自适应深度学习框架,由德国癌症研究中心(DKFZ)开发。它最大的特点是自动配置,能够根据不同的数据集自动确定最优的网络架构、训练策略和数据增强方案。它对应的论文为:2020年12月《nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation》和2024年7月《nnU-Net Revisited: A Call for Rigorous Validation in 3D Medical Image Segmentation》

架构及效果
架构及效果

1.2 核心优势

  • 全自动配置:无需手动调参,自动适应不同数据集
  • 性能优异:在多个医学图像分割挑战赛中获得第一名
  • 标准化流程:提供完整的数据处理、训练、推理pipeline
  • 开箱即用:遵循标准数据格式,即可快速上手

1.3 应用场景

  • 医学图像分割(CT、MRI、超声等)
  • 器官分割(肝脏、肺、脑组织等)
  • 病灶检测(肿瘤、病变区域等)
  • 3D 和 2D 图像分割任务

1.4 工作流程

代码语言:txt
复制
原始数据 → 数据格式转换 → 预处理 → 训练 → 推理 → 结果可视化

2. 环境安装

2.1 系统要求

  • 操作系统:Linux、macOS 或 Windows
  • Python:3.9 或更高版本
  • GPU:推荐使用 NVIDIA GPU(至少 8GB 显存,缺省参数实际会达到17GB)
  • 存储空间:根据数据集大小,建议至少 50GB 可用空间

2.2 安装步骤

步骤 1:创建虚拟环境(推荐)

代码语言:bash
复制
# 使用 conda
conda create -n nnunet python=3.10
conda activate nnunet

# 或使用 venv
python3 -m venv nnunet_env
source nnunet_env/bin/activate  # Linux/macOS
# nnunet_env\Scripts\activate  # Windows

步骤 2:安装 PyTorch

根据您的系统和 CUDA 版本安装 PyTorch:

代码语言:bash
复制
# CUDA 11.8 (推荐)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# CUDA 12.1
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

# CPU 版本(仅用于测试,不推荐训练)
pip3 install torch torchvision torchaudio

# macOS (支持 MPS)
pip3 install torch torchvision torchaudio

步骤 3:安装 nnUNet

代码语言:bash
复制
pip3 install nnunetv2

步骤 4:安装可视化依赖

代码语言:bash
复制
pip3 install nibabel matplotlib

步骤 5:验证安装

代码语言:bash
复制
python3 -c "import nnunetv2; print('nnU-Net v2 安装成功!')"

2.3 环境变量设置

nnU-Net 需要三个环境变量来指定数据存储路径:

代码语言:bash
复制
# 创建项目目录
mkdir -p ~/nnUNet_workspace
cd ~/nnUNet_workspace

# 设置环境变量
export nnUNet_raw="$PWD/nnUNet_raw"
export nnUNet_preprocessed="$PWD/nnUNet_preprocessed"
export nnUNet_results="$PWD/nnUNet_results"

# 建议添加到 ~/.bashrc 或 ~/.zshrc,使其永久生效
echo 'export nnUNet_raw="$HOME/nnUNet_workspace/nnUNet_raw"' >> ~/.bashrc
echo 'export nnUNet_preprocessed="$HOME/nnUNet_workspace/nnUNet_preprocessed"' >> ~/.bashrc
echo 'export nnUNet_results="$HOME/nnUNet_workspace/nnUNet_results"' >> ~/.bashrc
source ~/.bashrc

目录说明:

  • nnUNet_raw:存放原始数据集
  • nnUNet_preprocessed:存放预处理后的数据
  • nnUNet_results:存放训练结果和模型权重

3. 数据集准备

3.1 示例数据集:Task04_Hippocampus

本教程使用 Hippocampus(海马体分割) 数据集作为示例。(为什么选它,因为这个数据集比较小,只有28.4MB)

数据集信息

  • 名称:Task04_Hippocampus
  • 来源:Medical Segmentation Decathlon
  • 提供方:Vanderbilt University Medical Center
  • 许可证:CC-BY-SA 4.0
  • 模态:MRI
  • 任务:海马体前部(Anterior)和后部(Posterior)分割
  • 训练样本:260 例
  • 测试样本:130 例

标签说明

  • 0:背景(background)
  • 1:海马体前部(Anterior)
  • 2:海马体后部(Posterior)

3.2 数据集下载

方法 1:官方下载地址

访问 Medical Segmentation Decathlon 官网:

代码语言:txt
复制
https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2
数据集下载页面
数据集下载页面

在 其中找到 Task04_Hippocampus.tar 文件并下载。

方法 2:命令行下载(需要 gdown)

代码语言:bash
复制
# 安装 gdown
pip3 install gdown

# 下载数据集(文件ID可能会变化,请从官网获取最新链接)
gdown https://drive.google.com/uc?id=1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C

方法 3:命令行下载(需要 wget)

代码语言:bash
复制
# 安装wget
sudo apt install wget

# 下载数据集(链接未来可能会失效)
wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar -o Task04_Hippocampus.tar

3.3 数据集目录结构

下载后,将 Task04_Hippocampus.tar 解压到 nnUNet_raw 目录:

代码语言:bash
复制
cd $nnUNet_raw
tar -xvf Task04_Hippocampus.tar

解压后的目录结构:

代码语言:bash
复制
$nnUNet_raw/Task04_Hippocampus/
├── dataset.json          # 数据集元数据
├── imagesTr/             # 训练图像 (260个 .nii.gz 文件)
│   ├── hippocampus_001.nii.gz
│   ├── hippocampus_002.nii.gz
│   └── ...
├── labelsTr/             # 训练标签 (260个 .nii.gz 文件)
│   ├── hippocampus_001.nii.gz
│   ├── hippocampus_002.nii.gz
│   └── ...
└── imagesTs/             # 测试图像 (130个 .nii.gz 文件)
    ├── hippocampus_267.nii.gz
    └── ...

3.4 数据格式转换(旧版本 → nnUNet v2)

如果数据集是旧版 nnUNet 格式(如 Task04_Hippocampus),需要转换为 v2 格式:

代码语言:bash
复制
# 重命名为标准格式(Task04 → Task004)
mv $nnUNet_raw/Task04_Hippocampus $nnUNet_raw/Task004_Hippocampus

# 使用转换工具
nnUNetv2_convert_old_nnUNet_dataset \
    $nnUNet_raw/Task004_Hippocampus \
    Dataset004_Hippocampus

转换后会生成新的数据集目录:

代码语言:bash
复制
$nnUNet_raw/Dataset004_Hippocampus/
├── dataset.json
├── imagesTr/
│   ├── hippocampus_001_0000.nii.gz  # 注意:文件名添加了 _0000 后缀
│   └── ...
├── labelsTr/
│   ├── hippocampus_001.nii.gz
│   └── ...
└── imagesTs/
    ├── hippocampus_267_0000.nii.gz
    └── ...

说明:

  • _0000 后缀表示第 0 个通道(模态),多模态数据会有 _0001, _0002
  • dataset.json 包含数据集的元信息(标签、模态、文件列表等)

4. 数据预处理

4.1 预处理步骤

nnU-Net 会自动执行以下预处理:

  1. 数据分析:分析图像尺寸、间距、强度分布
  2. 重采样:统一图像间距(voxel spacing)
  3. 归一化:标准化图像强度值
  4. 裁剪:去除背景区域
  5. 生成训练计划:确定网络架构、batch size、patch size 等

4.2 执行预处理

代码语言:bash
复制
nnUNetv2_plan_and_preprocess -d 4 --verify_dataset_integrity

参数说明:

  • -d 4:数据集 ID(对应 Dataset004_Hippocampus)
  • --verify_dataset_integrity:验证数据集完整性(检查文件是否缺失、标签是否正确等)

4.3 预处理输出

预处理完成后,会在 $nnUNet_preprocessed/Dataset004_Hippocampus 生成以下文件:

代码语言:bash
复制
$nnUNet_preprocessed/Dataset004_Hippocampus/
├── dataset_fingerprint.json      # 数据集指纹(统计信息)
├── dataset.json                  # 数据集元数据
├── nnUNetPlans.json              # 训练计划(网络配置)
├── splits_final.json             # 交叉验证划分
├── gt_segmentations/             # Ground truth 分割(用于验证)
├── nnUNetPlans_2d/               # 2D 预处理数据
└── nnUNetPlans_3d_fullres/       # 3D 全分辨率预处理数据

4.4 查看训练计划

代码语言:bash
复制
cat $nnUNet_preprocessed/Dataset004_Hippocampus/nnUNetPlans.json

该文件包含:

  • 网络配置:UNet 编码器/解码器层数、卷积核大小
  • 数据配置:patch size、batch size、数据增强策略
  • 训练配置:学习率、优化器、训练轮数

5. 模型训练

5.1 训练命令

nnUNet 支持多种训练配置:

5.1.1 2D 网络训练

代码语言:bash
复制
nnUNetv2_train 4 2d 0

参数说明:

  • 4:数据集 ID
  • 2d:网络配置(逐切片训练)
  • 0:交叉验证 fold(0-4,共5折)

5.1.2 3D 全分辨率网络训练

代码语言:bash
复制
nnUNetv2_train 4 3d_fullres 0

5.1.3 3D 低分辨率网络训练

代码语言:bash
复制
nnUNetv2_train 4 3d_lowres 0

5.1.4 3D 级联网络训练

代码语言:bash
复制
# 第一阶段:低分辨率
nnUNetv2_train 4 3d_lowres 0

# 第二阶段:级联
nnUNetv2_train 4 3d_cascade_fullres 0

5.2 GPU/CPU 设备选择

使用特定 GPU

代码语言:bash
复制
CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 4 2d 0

使用多 GPU

代码语言:bash
复制
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_train 4 2d 0

使用 macOS MPS(Apple Silicon)

代码语言:bash
复制
nnUNetv2_train -device mps 4 2d 0

使用 CPU(不推荐,速度慢)

代码语言:bash
复制
nnUNetv2_train -device cpu 4 2d 0

5.3 训练过程监控

训练日志会保存在:

代码语言:bash
复制
$nnUNet_results/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__2d/fold_0/
├── training_log.txt         # 训练日志
├── checkpoint_best.pth      # 最佳模型权重
├── checkpoint_final.pth     # 最终模型权重
├── progress.png             # 训练曲线图
└── validation_raw/          # 验证集预测结果
progress.png训练曲线图
progress.png训练曲线图

实时查看训练日志:

代码语言:bash
复制
tail -f $nnUNet_results/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__2d/fold_0/training_log.txt

5.4 训练时间估计

  • 2D 网络:GTX 1080 Ti 约 4-6 小时
  • 3D 全分辨率:GTX 1080 Ti 约 12-24 小时
  • Apple M1/M2 (MPS):约为 GPU 训练时间的 2-3 倍

5.5 训练所有交叉验证折

为了获得最佳性能,建议训练所有 5 折:

代码语言:bash
复制
# 自动训练所有折
nnUNetv2_train 4 2d all

# 或手动训练每一折
for fold in 0 1 2 3 4; do
    nnUNetv2_train 4 2d $fold
done

6. 模型推理

6.1 推理命令

使用训练好的模型对新数据进行预测:

代码语言:bash
复制
nnUNetv2_predict \
    -i $nnUNet_raw/Dataset004_Hippocampus/imagesTs \
    -o $nnUNet_results/Dataset004_Hippocampus/predictions \
    -d 4 \
    -c 2d \
    -f 0

参数说明:

  • -i:输入图像目录
  • -o:输出分割结果目录
  • -d:数据集 ID
  • -c:配置(2d、3d_fullres等)
  • -f:使用哪个 fold 的模型(可指定多个:-f 0 1 2 3 4

6.2 使用所有折的集成预测

代码语言:bash
复制
nnUNetv2_predict \
    -i $nnUNet_raw/Dataset004_Hippocampus/imagesTs \
    -o $nnUNet_results/Dataset004_Hippocampus/predictions_ensemble \
    -d 4 \
    -c 2d \
    -f 0 1 2 3 4  # 使用所有5折进行集成

集成预测通常比单折预测效果更好。

6.3 调整推理设置

禁用测试时增强(TTA)

代码语言:bash
复制
nnUNetv2_predict -i INPUT -o OUTPUT -d 4 -c 2d -f 0 --disable_tta

保存概率图

代码语言:bash
复制
nnUNetv2_predict -i INPUT -o OUTPUT -d 4 -c 2d -f 0 --save_probabilities

7. 结果可视化

7.1 使用提供的可视化脚本

项目中提供了四个可视化脚本:

7.1.1 快速查看(三视图)

代码语言:bash
复制
python3 show_nii.py $nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz

显示轴向、矢状、冠状三个切面的中间切片。

7.1.2 交互式浏览

代码语言:bash
复制
python3 show_nii.py $nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz -i --view axial

使用滑块逐层浏览切片。

视图选项:

  • axial:轴向切片(水平切面)
  • sagittal:矢状切片(侧面切面)
  • coronal:冠状切片(正面切面)

7.1.3 3D 表面渲染

代码语言:bash
复制
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode surface

7.1.4 散点图模式

代码语言:bash
复制
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode scatter

7.1.5 交互式 3D

代码语言:bash
复制
python3 show_nii_3d_interactive.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz

可以使用鼠标旋转

7.1.6 3D整体效果

代码语言:bash
复制
python3 show_nii_overlay_3d.py \
    nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz \
    nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz

在3D互动图中,你会看到:

  • 透明的浅蓝色"云":包裹在外围的原始图像轮廓
  • 红色实体(Anterior海马体的前部区域):位于其中的主要分割目标
  • 绿色实体(Posterior海马体的后部区域) 你可以通过点击图例来隐藏/显示任意层,方便单独查看某个区域。

7.2 可视化代码

show_nii.py

代码语言:python
代码运行次数:0
运行
复制
#!/usr/bin/env python3
"""
NIfTI 文件可视化脚本
用法: python3 show_nii.py <nii_file_path>

# 三个切片视图解释
 - 1. 轴向切片(Axial)
也叫横断面或水平切面
就像把人体水平切开,从头顶往下看
类似 CT/MRI 扫描时,躺在床上被一层层扫描的那个方向
可以看到左右对称的结构

 - 2. 矢状切片(Sagittal)
也叫侧面切面
就像把人体从中间垂直切开,从左侧或右侧看
可以看到前后方向的结构(比如鼻子、脑、脊柱的前后关系)

 - 3. 冠状切片(Coronal)
也叫额状切面或正面切面
就像把人体从前往后垂直切开,从正面或背面看
可以看到左右和上下的结构

# 为什么需要三个视图?
因为医学图像数据是 3D 的(有宽度、高度、深度),但屏幕是 2D 的。通过这三个互相垂直的切面,医生可以从不同角度全面观察器官的形态、病变的位置和大小。
在你的脚本中:
 - 非交互模式:同时显示三个视图的中间切片,让你快速了解整个 3D 数据
 - 交互模式(-i 参数):选择一个视图,然后用滑块逐层浏览所有切片

例子:
python3 show_nii.py nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz
交互模式:
python3 show_nii.py nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz -i
"""

import argparse
import sys
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

# 配置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'STHeiti', 'Microsoft YaHei', 
                                     'PingFang HK', 'Heiti TC', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题


class NiftiViewer:
    def __init__(self, nii_path):
        """初始化 NIfTI 查看器"""
        if not os.path.exists(nii_path):
            raise FileNotFoundError(f"文件不存在: {nii_path}")
        
        print(f"正在加载文件: {nii_path}")
        self.nii_img = nib.load(nii_path)
        self.data = self.nii_img.get_fdata()
        self.nii_path = nii_path
        
        print(f"数据形状: {self.data.shape}")
        print(f"数据类型: {self.data.dtype}")
        print(f"数值范围: [{self.data.min():.2f}, {self.data.max():.2f}]")
        
        # 归一化数据以便更好地显示
        self.data_norm = self._normalize_data(self.data)
        
    def _normalize_data(self, data):
        """归一化数据到 [0, 1] 范围"""
        data_min = data.min()
        data_max = data.max()
        if data_max - data_min > 0:
            return (data - data_min) / (data_max - data_min)
        else:
            return data
    
    def show_slices(self):
        """显示三个正交切面的中间切片"""
        # 获取中间切片索引
        mid_axial = self.data.shape[2] // 2
        mid_sagittal = self.data.shape[0] // 2
        mid_coronal = self.data.shape[1] // 2
        
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        # 轴向切片 (Axial)
        axes[0].imshow(self.data_norm[:, :, mid_axial].T, cmap='gray', origin='lower')
        axes[0].set_title(f'轴向切片 (Axial)\n切片: {mid_axial}/{self.data.shape[2]}')
        axes[0].axis('off')
        
        # 矢状切片 (Sagittal)
        axes[1].imshow(self.data_norm[mid_sagittal, :, :].T, cmap='gray', origin='lower')
        axes[1].set_title(f'矢状切片 (Sagittal)\n切片: {mid_sagittal}/{self.data.shape[0]}')
        axes[1].axis('off')
        
        # 冠状切片 (Coronal)
        axes[2].imshow(self.data_norm[:, mid_coronal, :].T, cmap='gray', origin='lower')
        axes[2].set_title(f'冠状切片 (Coronal)\n切片: {mid_coronal}/{self.data.shape[1]}')
        axes[2].axis('off')
        
        plt.suptitle(f'NIfTI 查看器: {os.path.basename(self.nii_path)}', fontsize=14)
        plt.tight_layout()
        plt.show()
    
    def show_interactive(self, view='axial'):
        """显示可交互的切片浏览器"""
        fig, ax = plt.subplots(figsize=(10, 8))
        plt.subplots_adjust(bottom=0.15)
        
        if view == 'axial':
            max_slice = self.data.shape[2] - 1
            init_slice = max_slice // 2
            img_data = self.data_norm[:, :, init_slice].T
            view_name = '轴向 (Axial)'
        elif view == 'sagittal':
            max_slice = self.data.shape[0] - 1
            init_slice = max_slice // 2
            img_data = self.data_norm[init_slice, :, :].T
            view_name = '矢状 (Sagittal)'
        elif view == 'coronal':
            max_slice = self.data.shape[1] - 1
            init_slice = max_slice // 2
            img_data = self.data_norm[:, init_slice, :].T
            view_name = '冠状 (Coronal)'
        else:
            raise ValueError(f"不支持的视图: {view}")
        
        im = ax.imshow(img_data, cmap='gray', origin='lower')
        ax.set_title(f'{view_name} - 切片: {init_slice}/{max_slice}')
        ax.axis('off')
        
        # 添加滑块
        ax_slider = plt.axes([0.2, 0.05, 0.6, 0.03])
        slider = Slider(ax_slider, '切片', 0, max_slice, valinit=init_slice, valstep=1)
        
        def update(val):
            slice_idx = int(slider.val)
            if view == 'axial':
                img_data = self.data_norm[:, :, slice_idx].T
            elif view == 'sagittal':
                img_data = self.data_norm[slice_idx, :, :].T
            elif view == 'coronal':
                img_data = self.data_norm[:, slice_idx, :].T
            
            im.set_data(img_data)
            ax.set_title(f'{view_name} - 切片: {slice_idx}/{max_slice}')
            fig.canvas.draw_idle()
        
        slider.on_changed(update)
        
        plt.suptitle(f'文件: {os.path.basename(self.nii_path)}', fontsize=12)
        plt.show()


def main():
    parser = argparse.ArgumentParser(
        description='NIfTI (.nii / .nii.gz) 文件可视化工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python show_nii.py image.nii.gz
  python show_nii.py image.nii.gz --view axial
  python show_nii.py image.nii.gz --interactive --view sagittal
        """)
    
    parser.add_argument('nii_file', help='NIfTI 文件路径 (.nii 或 .nii.gz)')
    parser.add_argument('--interactive', '-i', action='store_true',
                        help='交互模式,使用滑块浏览切片')
    parser.add_argument('--view', '-v', choices=['axial', 'sagittal', 'coronal'],
                        default='axial', help='选择视图方向 (默认: axial)')
    
    args = parser.parse_args()
    
    try:
        viewer = NiftiViewer(args.nii_file)
        
        if args.interactive:
            viewer.show_interactive(view=args.view)
        else:
            viewer.show_slices()
            
    except Exception as e:
        print(f"错误: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()

show_nii_3d.py

代码语言:python
代码运行次数:0
运行
复制
#!/usr/bin/env python3
"""
NIfTI 文件 3D 可视化脚本
用法: python3 show_nii_3d.py <nii_file_path>
例子:
# 方案一:matplotlib 3D 表面渲染
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode surface

# 方案一:散点图模式(更快)
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode scatter

"""

import argparse
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from skimage import measure

# 配置中文支持
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'STHeiti', 
                                     'Microsoft YaHei', 'PingFang HK', 'Heiti TC']
plt.rcParams['axes.unicode_minus'] = False


def show_3d_surface(nii_path, threshold=None, downsample=2):
    """
    使用等值面(isosurface)显示 3D 结构
    
    Args:
        nii_path: NIfTI 文件路径
        threshold: 等值面阈值(None 则自动选择)
        downsample: 降采样因子(减少数据量以提高速度)
    """
    print(f"加载文件: {nii_path}")
    img = nib.load(nii_path)
    data = img.get_fdata()
    
    print(f"原始数据形状: {data.shape}")
    print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
    
    # 降采样以提高渲染速度
    if downsample > 1:
        data = data[::downsample, ::downsample, ::downsample]
        print(f"降采样后形状: {data.shape}")
    
    # 自动选择阈值
    if threshold is None:
        # 对于分割标签,使用 0.5
        if np.unique(data).size < 10:  # 可能是分割 mask
            threshold = 0.5
        else:
            # 对于灰度图像,使用均值
            threshold = np.mean(data) + 0.5 * np.std(data)
    
    print(f"使用阈值: {threshold}")
    
    # 生成等值面(marching cubes 算法)
    print("生成 3D 网格(这可能需要一些时间)...")
    try:
        verts, faces, normals, values = measure.marching_cubes(
            data, level=threshold, spacing=(1.0, 1.0, 1.0)
        )
    except Exception as e:
        print(f"错误: {e}")
        print("尝试调整阈值或检查数据...")
        return
    
    print(f"生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
    
    # 创建 3D 图形
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    
    # 绘制 3D 表面
    mesh = ax.plot_trisurf(
        verts[:, 0], verts[:, 1], faces, verts[:, 2],
        cmap='viridis', alpha=0.8, linewidth=0, antialiased=True
    )
    
    # 设置标签和标题
    ax.set_xlabel('X 轴', fontsize=12)
    ax.set_ylabel('Y 轴', fontsize=12)
    ax.set_zlabel('Z 轴', fontsize=12)
    ax.set_title(f'3D 表面渲染\n阈值: {threshold:.2f}', fontsize=14, pad=20)
    
    # 添加颜色条
    fig.colorbar(mesh, ax=ax, shrink=0.5, aspect=5)
    
    plt.tight_layout()
    plt.show()


def show_3d_scatter(nii_path, threshold=None, max_points=50000):
    """
    使用散点图显示 3D 体素
    
    Args:
        nii_path: NIfTI 文件路径
        threshold: 显示阈值
        max_points: 最大显示点数
    """
    print(f"加载文件: {nii_path}")
    img = nib.load(nii_path)
    data = img.get_fdata()
    
    print(f"数据形状: {data.shape}")
    print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
    
    # 自动选择阈值
    if threshold is None:
        if np.unique(data).size < 10:
            threshold = 0.5
        else:
            threshold = np.percentile(data, 70)  # 只显示前 30% 的高值
    
    print(f"使用阈值: {threshold}")
    
    # 找到所有超过阈值的体素
    mask = data > threshold
    coords = np.argwhere(mask)
    values = data[mask]
    
    print(f"找到 {len(coords)} 个体素")
    
    # 如果点太多,随机采样
    if len(coords) > max_points:
        print(f"随机采样到 {max_points} 个点")
        indices = np.random.choice(len(coords), max_points, replace=False)
        coords = coords[indices]
        values = values[indices]
    
    # 创建 3D 散点图
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    
    scatter = ax.scatter(
        coords[:, 0], coords[:, 1], coords[:, 2],
        c=values, cmap='hot', alpha=0.6, s=1
    )
    
    ax.set_xlabel('X 轴', fontsize=12)
    ax.set_ylabel('Y 轴', fontsize=12)
    ax.set_zlabel('Z 轴', fontsize=12)
    ax.set_title(f'3D 体素分布\n阈值: {threshold:.2f}', fontsize=14, pad=20)
    
    fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5)
    
    plt.tight_layout()
    plt.show()


def main():
    parser = argparse.ArgumentParser(
        description='NIfTI 文件 3D 可视化工具',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 表面渲染(适合分割标签)
  python show_nii_3d.py image.nii.gz --mode surface
  
  # 散点图(快速预览)
  python show_nii_3d.py image.nii.gz --mode scatter
  
  # 自定义阈值
  python show_nii_3d.py image.nii.gz --mode surface --threshold 0.5
        """)
    
    parser.add_argument('nii_file', help='NIfTI 文件路径')
    parser.add_argument('--mode', '-m', choices=['surface', 'scatter'],
                        default='surface', help='显示模式(默认: surface)')
    parser.add_argument('--threshold', '-t', type=float, default=None,
                        help='阈值(默认自动选择)')
    parser.add_argument('--downsample', '-d', type=int, default=2,
                        help='降采样因子(仅 surface 模式,默认: 2)')
    parser.add_argument('--max-points', '-p', type=int, default=50000,
                        help='最大显示点数(仅 scatter 模式,默认: 50000)')
    
    args = parser.parse_args()
    
    try:
        if args.mode == 'surface':
            show_3d_surface(args.nii_file, args.threshold, args.downsample)
        else:
            show_3d_scatter(args.nii_file, args.threshold, args.max_points)
    except Exception as e:
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    main()

show_nii_3d_interactive.py

代码语言:python
代码运行次数:0
运行
复制
#!/usr/bin/env python3
"""
NIfTI 文件交互式 3D 可视化(使用 Plotly)
用法: python3 show_nii_3d_interactive.py <nii_file_path>
例子:
python3 show_nii_3d_interactive.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
"""

import argparse
import numpy as np
import nibabel as nib
import plotly.graph_objects as go
from skimage import measure


def show_interactive_3d(nii_path, threshold=None, downsample=2):
    """使用 Plotly 创建可交互的 3D 可视化"""
    print(f"加载文件: {nii_path}")
    img = nib.load(nii_path)
    data = img.get_fdata()
    
    print(f"原始数据形状: {data.shape}")
    print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
    
    # 降采样
    if downsample > 1:
        data = data[::downsample, ::downsample, ::downsample]
        print(f"降采样后形状: {data.shape}")
    
    # 自动选择阈值
    if threshold is None:
        if np.unique(data).size < 10:
            threshold = 0.5
        else:
            threshold = np.mean(data) + 0.5 * np.std(data)
    
    print(f"使用阈值: {threshold}")
    print("生成 3D 网格...")
    
    # 生成等值面
    verts, faces, normals, values = measure.marching_cubes(
        data, level=threshold, spacing=(1.0, 1.0, 1.0)
    )
    
    print(f"生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
    
    # 创建 3D mesh
    x, y, z = verts.T
    i, j, k = faces.T
    
    fig = go.Figure(data=[
        go.Mesh3d(
            x=x, y=y, z=z,
            i=i, j=j, k=k,
            opacity=0.8,
            colorscale='Viridis',
            intensity=z,  # 根据 z 坐标着色
            name='',
            showscale=True,
            hoverinfo='text',
            text=f'3D 结构<br>阈值: {threshold:.2f}'
        )
    ])
    
    # 设置布局
    fig.update_layout(
        title=f'交互式 3D 可视化<br><sub>阈值: {threshold:.2f}</sub>',
        scene=dict(
            xaxis_title='X 轴',
            yaxis_title='Y 轴',
            zaxis_title='Z 轴',
            aspectmode='data'
        ),
        width=1000,
        height=800,
    )
    
    print("打开交互式窗口...")
    print("提示: 可以用鼠标旋转、缩放、平移")
    fig.show()


def main():
    parser = argparse.ArgumentParser(description='NIfTI 文件交互式 3D 可视化')
    parser.add_argument('nii_file', help='NIfTI 文件路径')
    parser.add_argument('--threshold', '-t', type=float, default=None,
                        help='阈值(默认自动选择)')
    parser.add_argument('--downsample', '-d', type=int, default=2,
                        help='降采样因子(默认: 2)')
    
    args = parser.parse_args()
    
    try:
        show_interactive_3d(args.nii_file, args.threshold, args.downsample)
    except Exception as e:
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    main()

show_nii_overlay_3d.py

代码语言:python
代码运行次数:0
运行
复制
#!/usr/bin/env python3
"""
NIfTI 文件原始图像和标签叠加的 3D 交互式可视化
用法: python3 show_nii_overlay_3d.py <image_file> <label_file>
例子:
python3 show_nii_overlay_3d.py \
    nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz \
    nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
"""

import argparse
import numpy as np
import nibabel as nib
import plotly.graph_objects as go
from skimage import measure


def create_mesh_data(data, threshold, downsample, color, name, opacity=0.8):
    """
    从3D数据创建mesh数据
    
    Args:
        data: 3D numpy数组
        threshold: 等值面阈值
        downsample: 降采样因子
        color: 网格颜色
        name: 网格名称
        opacity: 透明度
    
    Returns:
        plotly Mesh3d对象,如果失败返回None
    """
    # 降采样
    if downsample > 1:
        data = data[::downsample, ::downsample, ::downsample]
    
    print(f"  降采样后形状: {data.shape}")
    print(f"  数值范围: [{data.min():.2f}, {data.max():.2f}]")
    print(f"  使用阈值: {threshold}")
    
    # 检查数据是否有效
    if data.max() <= threshold:
        print(f"  警告: 最大值 {data.max():.2f} 小于等于阈值 {threshold:.2f},跳过此数据")
        return None
    
    print(f"  生成 3D 网格...")
    
    try:
        # 生成等值面
        verts, faces, normals, values = measure.marching_cubes(
            data, level=threshold, spacing=(1.0, 1.0, 1.0)
        )
        
        print(f"  生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
        
        # 创建 3D mesh
        x, y, z = verts.T
        i, j, k = faces.T
        
        mesh = go.Mesh3d(
            x=x, y=y, z=z,
            i=i, j=j, k=k,
            opacity=opacity,
            color=color,
            name=name,
            hoverinfo='text',
            text=f'{name}<br>顶点数: {len(verts)}'
        )
        
        return mesh
        
    except Exception as e:
        print(f"  错误: {e}")
        return None


def show_overlay_3d(image_path, label_path, 
                    image_threshold=None, label_threshold=None,
                    downsample=2,
                    image_opacity=0.3, label_opacity=0.8,
                    image_color='lightblue', label_color='red'):
    """
    在同一个交互式3D图中显示原始图像和标签
    
    Args:
        image_path: 原始图像文件路径
        label_path: 标签文件路径
        image_threshold: 原始图像阈值
        label_threshold: 标签阈值
        downsample: 降采样因子
        image_opacity: 原始图像透明度
        label_opacity: 标签透明度
        image_color: 原始图像颜色
        label_color: 标签颜色
    """
    meshes = []
    
    # 加载并处理原始图像
    print(f"\n加载原始图像: {image_path}")
    img = nib.load(image_path)
    image_data = img.get_fdata()
    print(f"原始图像形状: {image_data.shape}")
    
    # 自动选择原始图像阈值
    if image_threshold is None:
        image_threshold = np.percentile(image_data[image_data > 0], 75)
        print(f"自动选择原始图像阈值: {image_threshold:.2f}")
    
    image_mesh = create_mesh_data(
        image_data, 
        image_threshold, 
        downsample, 
        image_color, 
        '原始图像',
        image_opacity
    )
    
    if image_mesh:
        meshes.append(image_mesh)
    
    # 加载并处理标签
    print(f"\n加载标签: {label_path}")
    label_img = nib.load(label_path)
    label_data = label_img.get_fdata()
    print(f"标签形状: {label_data.shape}")
    
    # 检查形状是否匹配
    if image_data.shape != label_data.shape:
        print(f"警告: 图像形状 {image_data.shape} 与标签形状 {label_data.shape} 不匹配")
    
    # 处理多标签的情况
    unique_labels = np.unique(label_data)
    unique_labels = unique_labels[unique_labels > 0]  # 排除背景
    print(f"发现 {len(unique_labels)} 个标签: {unique_labels}")
    
    # 为每个标签创建mesh
    colors = ['red', 'green', 'yellow', 'cyan', 'magenta', 'orange', 'purple', 'pink']
    
    for idx, label_val in enumerate(unique_labels):
        label_mask = (label_data == label_val).astype(float)
        
        if label_threshold is None:
            current_threshold = 0.5
        else:
            current_threshold = label_threshold
        
        current_color = colors[idx % len(colors)] if idx > 0 else label_color
        
        print(f"\n处理标签 {label_val}:")
        label_mesh = create_mesh_data(
            label_mask,
            current_threshold,
            downsample,
            current_color,
            f'标签 {int(label_val)}',
            label_opacity
        )
        
        if label_mesh:
            meshes.append(label_mesh)
    
    # 创建图形
    if not meshes:
        print("\n错误: 没有生成任何网格,请检查阈值设置")
        return
    
    print(f"\n创建交互式3D可视化(共 {len(meshes)} 个网格)...")
    fig = go.Figure(data=meshes)
    
    # 设置布局
    fig.update_layout(
        title='原始图像与标签的 3D 叠加可视化<br><sub>提示: 鼠标可旋转、缩放、平移</sub>',
        scene=dict(
            xaxis_title='X 轴',
            yaxis_title='Y 轴',
            zaxis_title='Z 轴',
            aspectmode='data',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        width=1200,
        height=900,
        showlegend=True,
        legend=dict(
            x=0.02,
            y=0.98,
            bgcolor='rgba(255, 255, 255, 0.8)',
            bordercolor='black',
            borderwidth=1
        )
    )
    
    print("打开交互式窗口...")
    print("提示:")
    print("  - 鼠标左键拖动: 旋转")
    print("  - 鼠标滚轮: 缩放")
    print("  - 鼠标右键拖动: 平移")
    print("  - 点击图例可以隐藏/显示对应的网格")
    
    fig.show()


def main():
    parser = argparse.ArgumentParser(
        description='NIfTI 文件原始图像和标签的 3D 交互式叠加可视化',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 基本用法
  python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz
  
  # 自定义阈值
  python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-threshold 100 --label-threshold 0.5
  
  # 调整透明度
  python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-opacity 0.5 --label-opacity 0.9
  
  # 更改颜色
  python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-color blue --label-color yellow
  
  # 减少降采样以提高质量(更慢)
  python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --downsample 1
        """)
    
    parser.add_argument('image_file', help='原始图像 NIfTI 文件路径')
    parser.add_argument('label_file', help='标签 NIfTI 文件路径')
    
    parser.add_argument('--image-threshold', '-it', type=float, default=None,
                        help='原始图像阈值(默认自动选择第75百分位)')
    parser.add_argument('--label-threshold', '-lt', type=float, default=None,
                        help='标签阈值(默认: 0.5)')
    
    parser.add_argument('--downsample', '-d', type=int, default=2,
                        help='降采样因子,越大速度越快但质量越低(默认: 2)')
    
    parser.add_argument('--image-opacity', '-io', type=float, default=0.3,
                        help='原始图像透明度 (0-1)(默认: 0.3)')
    parser.add_argument('--label-opacity', '-lo', type=float, default=0.8,
                        help='标签透明度 (0-1)(默认: 0.8)')
    
    parser.add_argument('--image-color', '-ic', type=str, default='lightblue',
                        help='原始图像颜色(默认: lightblue)')
    parser.add_argument('--label-color', '-lc', type=str, default='red',
                        help='标签颜色(默认: red)')
    
    args = parser.parse_args()
    
    try:
        show_overlay_3d(
            args.image_file, 
            args.label_file,
            args.image_threshold,
            args.label_threshold,
            args.downsample,
            args.image_opacity,
            args.label_opacity,
            args.image_color,
            args.label_color
        )
    except Exception as e:
        print(f"\n错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    main()

8. 常见问题

8.1 环境和安装问题

Q1: 安装 nnUNetv2 时报错?

A: 确保先安装 PyTorch,再安装 nnUNetv2:

代码语言:bash
复制
pip3 install torch torchvision torchaudio
pip3 install nnunetv2

Q2: 找不到环境变量?

A: 每次使用前确保设置了三个环境变量:

代码语言:bash
复制
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"

或在脚本中指定:

代码语言:bash
复制
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export nnUNet_raw="${SCRIPT_DIR}/nnUNet_raw"
export nnUNet_preprocessed="${SCRIPT_DIR}/nnUNet_preprocessed"
export nnUNet_results="${SCRIPT_DIR}/nnUNet_results"

# 后续命令...

8.2 数据集问题

Q3: dataset.json 格式不正确?

A: nnUNet v2 的 dataset.json 格式示例:

代码语言:json
复制
{
    "name": "Hippocampus",
    "description": "Left and right hippocampus segmentation",
    "reference": "Vanderbilt University Medical Center",
    "licence": "CC-BY-SA 4.0",
    "release": "1.0 04/05/2018",
    "channel_names": {
        "0": "MRI"
    },
    "labels": {
        "background": 0,
        "Anterior": 1,
        "Posterior": 2
    },
    "numTraining": 260,
    "file_ending": ".nii.gz"
}

Q4: 图像和标签不匹配?

A: 检查文件命名:

  • 图像:case_001_0000.nii.gz(必须有 _0000 后缀)
  • 标签:case_001.nii.gz(无后缀)

8.3 训练问题

Q5: 显存不足(CUDA out of memory)?

A: 尝试以下方法:

  1. 减小 batch size(修改 nnUNetPlans.json)
  2. 使用 2D 网络而非 3D
  3. 使用更小的 GPU(nnUNet 会自动调整)
  4. 启用混合精度训练(默认启用)

Q6: 训练速度很慢?

A:

  • 确保使用 GPU 而非 CPU
  • 检查数据是否在 SSD 上(而非机械硬盘)
  • 关闭不必要的后台程序

Q7: 如何恢复中断的训练?

A: nnUNet 会自动保存 checkpoint,直接重新运行相同的命令即可恢复:

代码语言:bash
复制
nnUNetv2_train 4 2d 0  # 会自动从 checkpoint 恢复

8.4 推理和结果问题

Q8: 推理结果全是背景?

A: 检查:

  1. 输入图像格式是否正确(需要 _0000.nii.gz 后缀)
  2. 是否使用了正确的模型和配置
  3. 图像强度范围是否异常

Q9: 如何评估模型性能?

A: 使用 nnUNet 内置评估工具:

代码语言:bash
复制
nnUNetv2_evaluate_folder \
    $nnUNet_results/Dataset004_Hippocampus/predictions \
    $nnUNet_raw/Dataset004_Hippocampus/labelsTs \
    -l 1 2  # 评估标签 1 和 2

会输出 Dice、IoU、95% Hausdorff 距离等指标。

8.5 自定义数据集

Q10: 如何准备自己的数据集?

A: 按照以下步骤:

  1. 创建目录结构
代码语言:bash
复制
mkdir -p $nnUNet_raw/Dataset999_MyDataset/{imagesTr,labelsTr,imagesTs}
  1. 准备图像和标签
    • 图像命名:case_001_0000.nii.gz
    • 标签命名:case_001.nii.gz
    • 确保图像和标签尺寸一致
  2. 创建 dataset.json
代码语言:json
复制
{
    "name": "MyDataset",
    "description": "My custom dataset",
    "channel_names": {
        "0": "CT"
    },
    "labels": {
        "background": 0,
        "target": 1
    },
    "numTraining": 100,
    "file_ending": ".nii.gz"
}
  1. 运行预处理和训练
代码语言:bash
复制
nnUNetv2_plan_and_preprocess -d 999 --verify_dataset_integrity
nnUNetv2_train 999 2d 0

总结

本教程涵盖了 nnUNet 的基础使用流程:

  1. ✅ 环境安装和配置
  2. ✅ 数据集下载和准备
  3. ✅ 数据预处理
  4. ✅ 模型训练(2D/3D)
  5. ✅ 模型推理
  6. ✅ 结果可视化

进阶学习资源

  • 官方文档https://github.com/MIC-DKFZ/nnUNet
  • 论文原文:Isensee et al., Nature Methods 2021
  • 视频教程:YouTube 搜索 "nnUNet tutorial"
  • 社区讨论:GitHub Issues

完整训练脚本示例

将以下内容保存为 train.sh

代码语言:bash
复制
#!/bin/bash

# 获取脚本所在目录
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"

# 设置环境变量
export nnUNet_raw="${SCRIPT_DIR}/nnUNet_raw"
export nnUNet_preprocessed="${SCRIPT_DIR}/nnUNet_preprocessed"
export nnUNet_results="${SCRIPT_DIR}/nnUNet_results"

# 1. 解压数据集
echo "==> 解压数据集..."
tar -xvf ./nnUNet_raw/Task04_Hippocampus.tar -C ./nnUNet_raw

# 2. 重命名数据集
echo "==> 重命名数据集..."
mv ./nnUNet_raw/Task04_Hippocampus ./nnUNet_raw/Task004_Hippocampus

# 3. 删除旧的 Dataset004(如果存在)
rm -rf ${nnUNet_raw}/Dataset004_Hippocampus

# 4. 转换为 nnUNet v2 格式
echo "==> 转换数据集格式..."
nnUNetv2_convert_old_nnUNet_dataset \
    ./nnUNet_raw/Task004_Hippocampus \
    Dataset004_Hippocampus

# 5. 数据预处理
echo "==> 数据预处理..."
nnUNetv2_plan_and_preprocess -d 4 --verify_dataset_integrity

# 6. 训练 2D 模型
echo "==> 开始训练 2D 模型..."
nnUNetv2_train -device mps 4 2d 0

# 7. (可选)训练 3D 模型
# echo "==> 开始训练 3D 模型..."
# nnUNetv2_train -device mps 4 3d_fullres 0

echo "==> 训练完成!"

运行脚本:

代码语言:bash
复制
chmod +x train.sh
./train.sh

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 目录
  • 1. nnU-Net 简介
    • 1.1 什么是 nnU-Net?
    • 1.2 核心优势
    • 1.3 应用场景
    • 1.4 工作流程
  • 2. 环境安装
    • 2.1 系统要求
    • 2.2 安装步骤
      • 步骤 1:创建虚拟环境(推荐)
      • 步骤 2:安装 PyTorch
      • 步骤 3:安装 nnUNet
      • 步骤 4:安装可视化依赖
      • 步骤 5:验证安装
    • 2.3 环境变量设置
  • 3. 数据集准备
    • 3.1 示例数据集:Task04_Hippocampus
      • 数据集信息
      • 标签说明
    • 3.2 数据集下载
      • 方法 1:官方下载地址
      • 方法 2:命令行下载(需要 gdown)
      • 方法 3:命令行下载(需要 wget)
    • 3.3 数据集目录结构
    • 3.4 数据格式转换(旧版本 → nnUNet v2)
  • 4. 数据预处理
    • 4.1 预处理步骤
    • 4.2 执行预处理
    • 4.3 预处理输出
    • 4.4 查看训练计划
  • 5. 模型训练
    • 5.1 训练命令
      • 5.1.1 2D 网络训练
      • 5.1.2 3D 全分辨率网络训练
      • 5.1.3 3D 低分辨率网络训练
      • 5.1.4 3D 级联网络训练
    • 5.2 GPU/CPU 设备选择
      • 使用特定 GPU
      • 使用多 GPU
      • 使用 macOS MPS(Apple Silicon)
      • 使用 CPU(不推荐,速度慢)
    • 5.3 训练过程监控
    • 5.4 训练时间估计
    • 5.5 训练所有交叉验证折
  • 6. 模型推理
    • 6.1 推理命令
    • 6.2 使用所有折的集成预测
    • 6.3 调整推理设置
      • 禁用测试时增强(TTA)
      • 保存概率图
  • 7. 结果可视化
    • 7.1 使用提供的可视化脚本
      • 7.1.1 快速查看(三视图)
      • 7.1.2 交互式浏览
      • 7.1.3 3D 表面渲染
      • 7.1.4 散点图模式
      • 7.1.5 交互式 3D
      • 7.1.6 3D整体效果
    • 7.2 可视化代码
      • show_nii.py
      • show_nii_3d.py
      • show_nii_3d_interactive.py
      • show_nii_overlay_3d.py
  • 8. 常见问题
    • 8.1 环境和安装问题
      • Q1: 安装 nnUNetv2 时报错?
      • Q2: 找不到环境变量?
    • 8.2 数据集问题
      • Q3: dataset.json 格式不正确?
      • Q4: 图像和标签不匹配?
    • 8.3 训练问题
      • Q5: 显存不足(CUDA out of memory)?
      • Q6: 训练速度很慢?
      • Q7: 如何恢复中断的训练?
    • 8.4 推理和结果问题
      • Q8: 推理结果全是背景?
      • Q9: 如何评估模型性能?
    • 8.5 自定义数据集
      • Q10: 如何准备自己的数据集?
  • 总结
    • 进阶学习资源
    • 完整训练脚本示例
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档