nnU-Net (no-new-UNet) 是一个用于医学图像分割的自适应深度学习框架,由德国癌症研究中心(DKFZ)开发。它最大的特点是自动配置,能够根据不同的数据集自动确定最优的网络架构、训练策略和数据增强方案。它对应的论文为:2020年12月《nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation》和2024年7月《nnU-Net Revisited: A Call for Rigorous Validation in 3D Medical Image Segmentation》
原始数据 → 数据格式转换 → 预处理 → 训练 → 推理 → 结果可视化
# 使用 conda
conda create -n nnunet python=3.10
conda activate nnunet
# 或使用 venv
python3 -m venv nnunet_env
source nnunet_env/bin/activate # Linux/macOS
# nnunet_env\Scripts\activate # Windows
根据您的系统和 CUDA 版本安装 PyTorch:
# CUDA 11.8 (推荐)
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# CUDA 12.1
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# CPU 版本(仅用于测试,不推荐训练)
pip3 install torch torchvision torchaudio
# macOS (支持 MPS)
pip3 install torch torchvision torchaudio
pip3 install nnunetv2
pip3 install nibabel matplotlib
python3 -c "import nnunetv2; print('nnU-Net v2 安装成功!')"
nnU-Net 需要三个环境变量来指定数据存储路径:
# 创建项目目录
mkdir -p ~/nnUNet_workspace
cd ~/nnUNet_workspace
# 设置环境变量
export nnUNet_raw="$PWD/nnUNet_raw"
export nnUNet_preprocessed="$PWD/nnUNet_preprocessed"
export nnUNet_results="$PWD/nnUNet_results"
# 建议添加到 ~/.bashrc 或 ~/.zshrc,使其永久生效
echo 'export nnUNet_raw="$HOME/nnUNet_workspace/nnUNet_raw"' >> ~/.bashrc
echo 'export nnUNet_preprocessed="$HOME/nnUNet_workspace/nnUNet_preprocessed"' >> ~/.bashrc
echo 'export nnUNet_results="$HOME/nnUNet_workspace/nnUNet_results"' >> ~/.bashrc
source ~/.bashrc
目录说明:
nnUNet_raw
:存放原始数据集nnUNet_preprocessed
:存放预处理后的数据nnUNet_results
:存放训练结果和模型权重本教程使用 Hippocampus(海马体分割) 数据集作为示例。(为什么选它,因为这个数据集比较小,只有28.4MB)
0
:背景(background)1
:海马体前部(Anterior)2
:海马体后部(Posterior)访问 Medical Segmentation Decathlon 官网:
https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2
在 其中找到 Task04_Hippocampus.tar
文件并下载。
# 安装 gdown
pip3 install gdown
# 下载数据集(文件ID可能会变化,请从官网获取最新链接)
gdown https://drive.google.com/uc?id=1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C
# 安装wget
sudo apt install wget
# 下载数据集(链接未来可能会失效)
wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar -o Task04_Hippocampus.tar
下载后,将 Task04_Hippocampus.tar
解压到 nnUNet_raw
目录:
cd $nnUNet_raw
tar -xvf Task04_Hippocampus.tar
解压后的目录结构:
$nnUNet_raw/Task04_Hippocampus/
├── dataset.json # 数据集元数据
├── imagesTr/ # 训练图像 (260个 .nii.gz 文件)
│ ├── hippocampus_001.nii.gz
│ ├── hippocampus_002.nii.gz
│ └── ...
├── labelsTr/ # 训练标签 (260个 .nii.gz 文件)
│ ├── hippocampus_001.nii.gz
│ ├── hippocampus_002.nii.gz
│ └── ...
└── imagesTs/ # 测试图像 (130个 .nii.gz 文件)
├── hippocampus_267.nii.gz
└── ...
如果数据集是旧版 nnUNet 格式(如 Task04_Hippocampus),需要转换为 v2 格式:
# 重命名为标准格式(Task04 → Task004)
mv $nnUNet_raw/Task04_Hippocampus $nnUNet_raw/Task004_Hippocampus
# 使用转换工具
nnUNetv2_convert_old_nnUNet_dataset \
$nnUNet_raw/Task004_Hippocampus \
Dataset004_Hippocampus
转换后会生成新的数据集目录:
$nnUNet_raw/Dataset004_Hippocampus/
├── dataset.json
├── imagesTr/
│ ├── hippocampus_001_0000.nii.gz # 注意:文件名添加了 _0000 后缀
│ └── ...
├── labelsTr/
│ ├── hippocampus_001.nii.gz
│ └── ...
└── imagesTs/
├── hippocampus_267_0000.nii.gz
└── ...
说明:
_0000
后缀表示第 0 个通道(模态),多模态数据会有 _0001
, _0002
等dataset.json
包含数据集的元信息(标签、模态、文件列表等)nnU-Net 会自动执行以下预处理:
nnUNetv2_plan_and_preprocess -d 4 --verify_dataset_integrity
参数说明:
-d 4
:数据集 ID(对应 Dataset004_Hippocampus)--verify_dataset_integrity
:验证数据集完整性(检查文件是否缺失、标签是否正确等)预处理完成后,会在 $nnUNet_preprocessed/Dataset004_Hippocampus
生成以下文件:
$nnUNet_preprocessed/Dataset004_Hippocampus/
├── dataset_fingerprint.json # 数据集指纹(统计信息)
├── dataset.json # 数据集元数据
├── nnUNetPlans.json # 训练计划(网络配置)
├── splits_final.json # 交叉验证划分
├── gt_segmentations/ # Ground truth 分割(用于验证)
├── nnUNetPlans_2d/ # 2D 预处理数据
└── nnUNetPlans_3d_fullres/ # 3D 全分辨率预处理数据
cat $nnUNet_preprocessed/Dataset004_Hippocampus/nnUNetPlans.json
该文件包含:
nnUNet 支持多种训练配置:
nnUNetv2_train 4 2d 0
参数说明:
4
:数据集 ID2d
:网络配置(逐切片训练)0
:交叉验证 fold(0-4,共5折)nnUNetv2_train 4 3d_fullres 0
nnUNetv2_train 4 3d_lowres 0
# 第一阶段:低分辨率
nnUNetv2_train 4 3d_lowres 0
# 第二阶段:级联
nnUNetv2_train 4 3d_cascade_fullres 0
CUDA_VISIBLE_DEVICES=0 nnUNetv2_train 4 2d 0
CUDA_VISIBLE_DEVICES=0,1 nnUNetv2_train 4 2d 0
nnUNetv2_train -device mps 4 2d 0
nnUNetv2_train -device cpu 4 2d 0
训练日志会保存在:
$nnUNet_results/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__2d/fold_0/
├── training_log.txt # 训练日志
├── checkpoint_best.pth # 最佳模型权重
├── checkpoint_final.pth # 最终模型权重
├── progress.png # 训练曲线图
└── validation_raw/ # 验证集预测结果
实时查看训练日志:
tail -f $nnUNet_results/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__2d/fold_0/training_log.txt
为了获得最佳性能,建议训练所有 5 折:
# 自动训练所有折
nnUNetv2_train 4 2d all
# 或手动训练每一折
for fold in 0 1 2 3 4; do
nnUNetv2_train 4 2d $fold
done
使用训练好的模型对新数据进行预测:
nnUNetv2_predict \
-i $nnUNet_raw/Dataset004_Hippocampus/imagesTs \
-o $nnUNet_results/Dataset004_Hippocampus/predictions \
-d 4 \
-c 2d \
-f 0
参数说明:
-i
:输入图像目录-o
:输出分割结果目录-d
:数据集 ID-c
:配置(2d、3d_fullres等)-f
:使用哪个 fold 的模型(可指定多个:-f 0 1 2 3 4
)nnUNetv2_predict \
-i $nnUNet_raw/Dataset004_Hippocampus/imagesTs \
-o $nnUNet_results/Dataset004_Hippocampus/predictions_ensemble \
-d 4 \
-c 2d \
-f 0 1 2 3 4 # 使用所有5折进行集成
集成预测通常比单折预测效果更好。
nnUNetv2_predict -i INPUT -o OUTPUT -d 4 -c 2d -f 0 --disable_tta
nnUNetv2_predict -i INPUT -o OUTPUT -d 4 -c 2d -f 0 --save_probabilities
项目中提供了四个可视化脚本:
python3 show_nii.py $nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz
显示轴向、矢状、冠状三个切面的中间切片。
python3 show_nii.py $nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz -i --view axial
使用滑块逐层浏览切片。
视图选项:
axial
:轴向切片(水平切面)sagittal
:矢状切片(侧面切面)coronal
:冠状切片(正面切面)python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode surface
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode scatter
python3 show_nii_3d_interactive.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
可以使用鼠标旋转
python3 show_nii_overlay_3d.py \
nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz \
nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
在3D互动图中,你会看到:
#!/usr/bin/env python3
"""
NIfTI 文件可视化脚本
用法: python3 show_nii.py <nii_file_path>
# 三个切片视图解释
- 1. 轴向切片(Axial)
也叫横断面或水平切面
就像把人体水平切开,从头顶往下看
类似 CT/MRI 扫描时,躺在床上被一层层扫描的那个方向
可以看到左右对称的结构
- 2. 矢状切片(Sagittal)
也叫侧面切面
就像把人体从中间垂直切开,从左侧或右侧看
可以看到前后方向的结构(比如鼻子、脑、脊柱的前后关系)
- 3. 冠状切片(Coronal)
也叫额状切面或正面切面
就像把人体从前往后垂直切开,从正面或背面看
可以看到左右和上下的结构
# 为什么需要三个视图?
因为医学图像数据是 3D 的(有宽度、高度、深度),但屏幕是 2D 的。通过这三个互相垂直的切面,医生可以从不同角度全面观察器官的形态、病变的位置和大小。
在你的脚本中:
- 非交互模式:同时显示三个视图的中间切片,让你快速了解整个 3D 数据
- 交互模式(-i 参数):选择一个视图,然后用滑块逐层浏览所有切片
例子:
python3 show_nii.py nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz
交互模式:
python3 show_nii.py nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz -i
"""
import argparse
import sys
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
# 配置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'STHeiti', 'Microsoft YaHei',
'PingFang HK', 'Heiti TC', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
class NiftiViewer:
def __init__(self, nii_path):
"""初始化 NIfTI 查看器"""
if not os.path.exists(nii_path):
raise FileNotFoundError(f"文件不存在: {nii_path}")
print(f"正在加载文件: {nii_path}")
self.nii_img = nib.load(nii_path)
self.data = self.nii_img.get_fdata()
self.nii_path = nii_path
print(f"数据形状: {self.data.shape}")
print(f"数据类型: {self.data.dtype}")
print(f"数值范围: [{self.data.min():.2f}, {self.data.max():.2f}]")
# 归一化数据以便更好地显示
self.data_norm = self._normalize_data(self.data)
def _normalize_data(self, data):
"""归一化数据到 [0, 1] 范围"""
data_min = data.min()
data_max = data.max()
if data_max - data_min > 0:
return (data - data_min) / (data_max - data_min)
else:
return data
def show_slices(self):
"""显示三个正交切面的中间切片"""
# 获取中间切片索引
mid_axial = self.data.shape[2] // 2
mid_sagittal = self.data.shape[0] // 2
mid_coronal = self.data.shape[1] // 2
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# 轴向切片 (Axial)
axes[0].imshow(self.data_norm[:, :, mid_axial].T, cmap='gray', origin='lower')
axes[0].set_title(f'轴向切片 (Axial)\n切片: {mid_axial}/{self.data.shape[2]}')
axes[0].axis('off')
# 矢状切片 (Sagittal)
axes[1].imshow(self.data_norm[mid_sagittal, :, :].T, cmap='gray', origin='lower')
axes[1].set_title(f'矢状切片 (Sagittal)\n切片: {mid_sagittal}/{self.data.shape[0]}')
axes[1].axis('off')
# 冠状切片 (Coronal)
axes[2].imshow(self.data_norm[:, mid_coronal, :].T, cmap='gray', origin='lower')
axes[2].set_title(f'冠状切片 (Coronal)\n切片: {mid_coronal}/{self.data.shape[1]}')
axes[2].axis('off')
plt.suptitle(f'NIfTI 查看器: {os.path.basename(self.nii_path)}', fontsize=14)
plt.tight_layout()
plt.show()
def show_interactive(self, view='axial'):
"""显示可交互的切片浏览器"""
fig, ax = plt.subplots(figsize=(10, 8))
plt.subplots_adjust(bottom=0.15)
if view == 'axial':
max_slice = self.data.shape[2] - 1
init_slice = max_slice // 2
img_data = self.data_norm[:, :, init_slice].T
view_name = '轴向 (Axial)'
elif view == 'sagittal':
max_slice = self.data.shape[0] - 1
init_slice = max_slice // 2
img_data = self.data_norm[init_slice, :, :].T
view_name = '矢状 (Sagittal)'
elif view == 'coronal':
max_slice = self.data.shape[1] - 1
init_slice = max_slice // 2
img_data = self.data_norm[:, init_slice, :].T
view_name = '冠状 (Coronal)'
else:
raise ValueError(f"不支持的视图: {view}")
im = ax.imshow(img_data, cmap='gray', origin='lower')
ax.set_title(f'{view_name} - 切片: {init_slice}/{max_slice}')
ax.axis('off')
# 添加滑块
ax_slider = plt.axes([0.2, 0.05, 0.6, 0.03])
slider = Slider(ax_slider, '切片', 0, max_slice, valinit=init_slice, valstep=1)
def update(val):
slice_idx = int(slider.val)
if view == 'axial':
img_data = self.data_norm[:, :, slice_idx].T
elif view == 'sagittal':
img_data = self.data_norm[slice_idx, :, :].T
elif view == 'coronal':
img_data = self.data_norm[:, slice_idx, :].T
im.set_data(img_data)
ax.set_title(f'{view_name} - 切片: {slice_idx}/{max_slice}')
fig.canvas.draw_idle()
slider.on_changed(update)
plt.suptitle(f'文件: {os.path.basename(self.nii_path)}', fontsize=12)
plt.show()
def main():
parser = argparse.ArgumentParser(
description='NIfTI (.nii / .nii.gz) 文件可视化工具',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
python show_nii.py image.nii.gz
python show_nii.py image.nii.gz --view axial
python show_nii.py image.nii.gz --interactive --view sagittal
""")
parser.add_argument('nii_file', help='NIfTI 文件路径 (.nii 或 .nii.gz)')
parser.add_argument('--interactive', '-i', action='store_true',
help='交互模式,使用滑块浏览切片')
parser.add_argument('--view', '-v', choices=['axial', 'sagittal', 'coronal'],
default='axial', help='选择视图方向 (默认: axial)')
args = parser.parse_args()
try:
viewer = NiftiViewer(args.nii_file)
if args.interactive:
viewer.show_interactive(view=args.view)
else:
viewer.show_slices()
except Exception as e:
print(f"错误: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == '__main__':
main()
#!/usr/bin/env python3
"""
NIfTI 文件 3D 可视化脚本
用法: python3 show_nii_3d.py <nii_file_path>
例子:
# 方案一:matplotlib 3D 表面渲染
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode surface
# 方案一:散点图模式(更快)
python3 show_nii_3d.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz --mode scatter
"""
import argparse
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from skimage import measure
# 配置中文支持
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'STHeiti',
'Microsoft YaHei', 'PingFang HK', 'Heiti TC']
plt.rcParams['axes.unicode_minus'] = False
def show_3d_surface(nii_path, threshold=None, downsample=2):
"""
使用等值面(isosurface)显示 3D 结构
Args:
nii_path: NIfTI 文件路径
threshold: 等值面阈值(None 则自动选择)
downsample: 降采样因子(减少数据量以提高速度)
"""
print(f"加载文件: {nii_path}")
img = nib.load(nii_path)
data = img.get_fdata()
print(f"原始数据形状: {data.shape}")
print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
# 降采样以提高渲染速度
if downsample > 1:
data = data[::downsample, ::downsample, ::downsample]
print(f"降采样后形状: {data.shape}")
# 自动选择阈值
if threshold is None:
# 对于分割标签,使用 0.5
if np.unique(data).size < 10: # 可能是分割 mask
threshold = 0.5
else:
# 对于灰度图像,使用均值
threshold = np.mean(data) + 0.5 * np.std(data)
print(f"使用阈值: {threshold}")
# 生成等值面(marching cubes 算法)
print("生成 3D 网格(这可能需要一些时间)...")
try:
verts, faces, normals, values = measure.marching_cubes(
data, level=threshold, spacing=(1.0, 1.0, 1.0)
)
except Exception as e:
print(f"错误: {e}")
print("尝试调整阈值或检查数据...")
return
print(f"生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
# 创建 3D 图形
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')
# 绘制 3D 表面
mesh = ax.plot_trisurf(
verts[:, 0], verts[:, 1], faces, verts[:, 2],
cmap='viridis', alpha=0.8, linewidth=0, antialiased=True
)
# 设置标签和标题
ax.set_xlabel('X 轴', fontsize=12)
ax.set_ylabel('Y 轴', fontsize=12)
ax.set_zlabel('Z 轴', fontsize=12)
ax.set_title(f'3D 表面渲染\n阈值: {threshold:.2f}', fontsize=14, pad=20)
# 添加颜色条
fig.colorbar(mesh, ax=ax, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
def show_3d_scatter(nii_path, threshold=None, max_points=50000):
"""
使用散点图显示 3D 体素
Args:
nii_path: NIfTI 文件路径
threshold: 显示阈值
max_points: 最大显示点数
"""
print(f"加载文件: {nii_path}")
img = nib.load(nii_path)
data = img.get_fdata()
print(f"数据形状: {data.shape}")
print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
# 自动选择阈值
if threshold is None:
if np.unique(data).size < 10:
threshold = 0.5
else:
threshold = np.percentile(data, 70) # 只显示前 30% 的高值
print(f"使用阈值: {threshold}")
# 找到所有超过阈值的体素
mask = data > threshold
coords = np.argwhere(mask)
values = data[mask]
print(f"找到 {len(coords)} 个体素")
# 如果点太多,随机采样
if len(coords) > max_points:
print(f"随机采样到 {max_points} 个点")
indices = np.random.choice(len(coords), max_points, replace=False)
coords = coords[indices]
values = values[indices]
# 创建 3D 散点图
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(
coords[:, 0], coords[:, 1], coords[:, 2],
c=values, cmap='hot', alpha=0.6, s=1
)
ax.set_xlabel('X 轴', fontsize=12)
ax.set_ylabel('Y 轴', fontsize=12)
ax.set_zlabel('Z 轴', fontsize=12)
ax.set_title(f'3D 体素分布\n阈值: {threshold:.2f}', fontsize=14, pad=20)
fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
def main():
parser = argparse.ArgumentParser(
description='NIfTI 文件 3D 可视化工具',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 表面渲染(适合分割标签)
python show_nii_3d.py image.nii.gz --mode surface
# 散点图(快速预览)
python show_nii_3d.py image.nii.gz --mode scatter
# 自定义阈值
python show_nii_3d.py image.nii.gz --mode surface --threshold 0.5
""")
parser.add_argument('nii_file', help='NIfTI 文件路径')
parser.add_argument('--mode', '-m', choices=['surface', 'scatter'],
default='surface', help='显示模式(默认: surface)')
parser.add_argument('--threshold', '-t', type=float, default=None,
help='阈值(默认自动选择)')
parser.add_argument('--downsample', '-d', type=int, default=2,
help='降采样因子(仅 surface 模式,默认: 2)')
parser.add_argument('--max-points', '-p', type=int, default=50000,
help='最大显示点数(仅 scatter 模式,默认: 50000)')
args = parser.parse_args()
try:
if args.mode == 'surface':
show_3d_surface(args.nii_file, args.threshold, args.downsample)
else:
show_3d_scatter(args.nii_file, args.threshold, args.max_points)
except Exception as e:
print(f"错误: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
main()
#!/usr/bin/env python3
"""
NIfTI 文件交互式 3D 可视化(使用 Plotly)
用法: python3 show_nii_3d_interactive.py <nii_file_path>
例子:
python3 show_nii_3d_interactive.py nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
"""
import argparse
import numpy as np
import nibabel as nib
import plotly.graph_objects as go
from skimage import measure
def show_interactive_3d(nii_path, threshold=None, downsample=2):
"""使用 Plotly 创建可交互的 3D 可视化"""
print(f"加载文件: {nii_path}")
img = nib.load(nii_path)
data = img.get_fdata()
print(f"原始数据形状: {data.shape}")
print(f"数值范围: [{data.min():.2f}, {data.max():.2f}]")
# 降采样
if downsample > 1:
data = data[::downsample, ::downsample, ::downsample]
print(f"降采样后形状: {data.shape}")
# 自动选择阈值
if threshold is None:
if np.unique(data).size < 10:
threshold = 0.5
else:
threshold = np.mean(data) + 0.5 * np.std(data)
print(f"使用阈值: {threshold}")
print("生成 3D 网格...")
# 生成等值面
verts, faces, normals, values = measure.marching_cubes(
data, level=threshold, spacing=(1.0, 1.0, 1.0)
)
print(f"生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
# 创建 3D mesh
x, y, z = verts.T
i, j, k = faces.T
fig = go.Figure(data=[
go.Mesh3d(
x=x, y=y, z=z,
i=i, j=j, k=k,
opacity=0.8,
colorscale='Viridis',
intensity=z, # 根据 z 坐标着色
name='',
showscale=True,
hoverinfo='text',
text=f'3D 结构<br>阈值: {threshold:.2f}'
)
])
# 设置布局
fig.update_layout(
title=f'交互式 3D 可视化<br><sub>阈值: {threshold:.2f}</sub>',
scene=dict(
xaxis_title='X 轴',
yaxis_title='Y 轴',
zaxis_title='Z 轴',
aspectmode='data'
),
width=1000,
height=800,
)
print("打开交互式窗口...")
print("提示: 可以用鼠标旋转、缩放、平移")
fig.show()
def main():
parser = argparse.ArgumentParser(description='NIfTI 文件交互式 3D 可视化')
parser.add_argument('nii_file', help='NIfTI 文件路径')
parser.add_argument('--threshold', '-t', type=float, default=None,
help='阈值(默认自动选择)')
parser.add_argument('--downsample', '-d', type=int, default=2,
help='降采样因子(默认: 2)')
args = parser.parse_args()
try:
show_interactive_3d(args.nii_file, args.threshold, args.downsample)
except Exception as e:
print(f"错误: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
main()
#!/usr/bin/env python3
"""
NIfTI 文件原始图像和标签叠加的 3D 交互式可视化
用法: python3 show_nii_overlay_3d.py <image_file> <label_file>
例子:
python3 show_nii_overlay_3d.py \
nnUNet_raw/Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz \
nnUNet_raw/Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz
"""
import argparse
import numpy as np
import nibabel as nib
import plotly.graph_objects as go
from skimage import measure
def create_mesh_data(data, threshold, downsample, color, name, opacity=0.8):
"""
从3D数据创建mesh数据
Args:
data: 3D numpy数组
threshold: 等值面阈值
downsample: 降采样因子
color: 网格颜色
name: 网格名称
opacity: 透明度
Returns:
plotly Mesh3d对象,如果失败返回None
"""
# 降采样
if downsample > 1:
data = data[::downsample, ::downsample, ::downsample]
print(f" 降采样后形状: {data.shape}")
print(f" 数值范围: [{data.min():.2f}, {data.max():.2f}]")
print(f" 使用阈值: {threshold}")
# 检查数据是否有效
if data.max() <= threshold:
print(f" 警告: 最大值 {data.max():.2f} 小于等于阈值 {threshold:.2f},跳过此数据")
return None
print(f" 生成 3D 网格...")
try:
# 生成等值面
verts, faces, normals, values = measure.marching_cubes(
data, level=threshold, spacing=(1.0, 1.0, 1.0)
)
print(f" 生成了 {len(verts)} 个顶点和 {len(faces)} 个面")
# 创建 3D mesh
x, y, z = verts.T
i, j, k = faces.T
mesh = go.Mesh3d(
x=x, y=y, z=z,
i=i, j=j, k=k,
opacity=opacity,
color=color,
name=name,
hoverinfo='text',
text=f'{name}<br>顶点数: {len(verts)}'
)
return mesh
except Exception as e:
print(f" 错误: {e}")
return None
def show_overlay_3d(image_path, label_path,
image_threshold=None, label_threshold=None,
downsample=2,
image_opacity=0.3, label_opacity=0.8,
image_color='lightblue', label_color='red'):
"""
在同一个交互式3D图中显示原始图像和标签
Args:
image_path: 原始图像文件路径
label_path: 标签文件路径
image_threshold: 原始图像阈值
label_threshold: 标签阈值
downsample: 降采样因子
image_opacity: 原始图像透明度
label_opacity: 标签透明度
image_color: 原始图像颜色
label_color: 标签颜色
"""
meshes = []
# 加载并处理原始图像
print(f"\n加载原始图像: {image_path}")
img = nib.load(image_path)
image_data = img.get_fdata()
print(f"原始图像形状: {image_data.shape}")
# 自动选择原始图像阈值
if image_threshold is None:
image_threshold = np.percentile(image_data[image_data > 0], 75)
print(f"自动选择原始图像阈值: {image_threshold:.2f}")
image_mesh = create_mesh_data(
image_data,
image_threshold,
downsample,
image_color,
'原始图像',
image_opacity
)
if image_mesh:
meshes.append(image_mesh)
# 加载并处理标签
print(f"\n加载标签: {label_path}")
label_img = nib.load(label_path)
label_data = label_img.get_fdata()
print(f"标签形状: {label_data.shape}")
# 检查形状是否匹配
if image_data.shape != label_data.shape:
print(f"警告: 图像形状 {image_data.shape} 与标签形状 {label_data.shape} 不匹配")
# 处理多标签的情况
unique_labels = np.unique(label_data)
unique_labels = unique_labels[unique_labels > 0] # 排除背景
print(f"发现 {len(unique_labels)} 个标签: {unique_labels}")
# 为每个标签创建mesh
colors = ['red', 'green', 'yellow', 'cyan', 'magenta', 'orange', 'purple', 'pink']
for idx, label_val in enumerate(unique_labels):
label_mask = (label_data == label_val).astype(float)
if label_threshold is None:
current_threshold = 0.5
else:
current_threshold = label_threshold
current_color = colors[idx % len(colors)] if idx > 0 else label_color
print(f"\n处理标签 {label_val}:")
label_mesh = create_mesh_data(
label_mask,
current_threshold,
downsample,
current_color,
f'标签 {int(label_val)}',
label_opacity
)
if label_mesh:
meshes.append(label_mesh)
# 创建图形
if not meshes:
print("\n错误: 没有生成任何网格,请检查阈值设置")
return
print(f"\n创建交互式3D可视化(共 {len(meshes)} 个网格)...")
fig = go.Figure(data=meshes)
# 设置布局
fig.update_layout(
title='原始图像与标签的 3D 叠加可视化<br><sub>提示: 鼠标可旋转、缩放、平移</sub>',
scene=dict(
xaxis_title='X 轴',
yaxis_title='Y 轴',
zaxis_title='Z 轴',
aspectmode='data',
camera=dict(
eye=dict(x=1.5, y=1.5, z=1.5)
)
),
width=1200,
height=900,
showlegend=True,
legend=dict(
x=0.02,
y=0.98,
bgcolor='rgba(255, 255, 255, 0.8)',
bordercolor='black',
borderwidth=1
)
)
print("打开交互式窗口...")
print("提示:")
print(" - 鼠标左键拖动: 旋转")
print(" - 鼠标滚轮: 缩放")
print(" - 鼠标右键拖动: 平移")
print(" - 点击图例可以隐藏/显示对应的网格")
fig.show()
def main():
parser = argparse.ArgumentParser(
description='NIfTI 文件原始图像和标签的 3D 交互式叠加可视化',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 基本用法
python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz
# 自定义阈值
python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-threshold 100 --label-threshold 0.5
# 调整透明度
python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-opacity 0.5 --label-opacity 0.9
# 更改颜色
python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --image-color blue --label-color yellow
# 减少降采样以提高质量(更慢)
python3 show_nii_overlay_3d.py image.nii.gz label.nii.gz --downsample 1
""")
parser.add_argument('image_file', help='原始图像 NIfTI 文件路径')
parser.add_argument('label_file', help='标签 NIfTI 文件路径')
parser.add_argument('--image-threshold', '-it', type=float, default=None,
help='原始图像阈值(默认自动选择第75百分位)')
parser.add_argument('--label-threshold', '-lt', type=float, default=None,
help='标签阈值(默认: 0.5)')
parser.add_argument('--downsample', '-d', type=int, default=2,
help='降采样因子,越大速度越快但质量越低(默认: 2)')
parser.add_argument('--image-opacity', '-io', type=float, default=0.3,
help='原始图像透明度 (0-1)(默认: 0.3)')
parser.add_argument('--label-opacity', '-lo', type=float, default=0.8,
help='标签透明度 (0-1)(默认: 0.8)')
parser.add_argument('--image-color', '-ic', type=str, default='lightblue',
help='原始图像颜色(默认: lightblue)')
parser.add_argument('--label-color', '-lc', type=str, default='red',
help='标签颜色(默认: red)')
args = parser.parse_args()
try:
show_overlay_3d(
args.image_file,
args.label_file,
args.image_threshold,
args.label_threshold,
args.downsample,
args.image_opacity,
args.label_opacity,
args.image_color,
args.label_color
)
except Exception as e:
print(f"\n错误: {e}")
import traceback
traceback.print_exc()
if __name__ == '__main__':
main()
A: 确保先安装 PyTorch,再安装 nnUNetv2:
pip3 install torch torchvision torchaudio
pip3 install nnunetv2
A: 每次使用前确保设置了三个环境变量:
export nnUNet_raw="/path/to/nnUNet_raw"
export nnUNet_preprocessed="/path/to/nnUNet_preprocessed"
export nnUNet_results="/path/to/nnUNet_results"
或在脚本中指定:
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
export nnUNet_raw="${SCRIPT_DIR}/nnUNet_raw"
export nnUNet_preprocessed="${SCRIPT_DIR}/nnUNet_preprocessed"
export nnUNet_results="${SCRIPT_DIR}/nnUNet_results"
# 后续命令...
A: nnUNet v2 的 dataset.json 格式示例:
{
"name": "Hippocampus",
"description": "Left and right hippocampus segmentation",
"reference": "Vanderbilt University Medical Center",
"licence": "CC-BY-SA 4.0",
"release": "1.0 04/05/2018",
"channel_names": {
"0": "MRI"
},
"labels": {
"background": 0,
"Anterior": 1,
"Posterior": 2
},
"numTraining": 260,
"file_ending": ".nii.gz"
}
A: 检查文件命名:
case_001_0000.nii.gz
(必须有 _0000
后缀)case_001.nii.gz
(无后缀)A: 尝试以下方法:
A:
A: nnUNet 会自动保存 checkpoint,直接重新运行相同的命令即可恢复:
nnUNetv2_train 4 2d 0 # 会自动从 checkpoint 恢复
A: 检查:
_0000.nii.gz
后缀)A: 使用 nnUNet 内置评估工具:
nnUNetv2_evaluate_folder \
$nnUNet_results/Dataset004_Hippocampus/predictions \
$nnUNet_raw/Dataset004_Hippocampus/labelsTs \
-l 1 2 # 评估标签 1 和 2
会输出 Dice、IoU、95% Hausdorff 距离等指标。
A: 按照以下步骤:
mkdir -p $nnUNet_raw/Dataset999_MyDataset/{imagesTr,labelsTr,imagesTs}
case_001_0000.nii.gz
case_001.nii.gz
{
"name": "MyDataset",
"description": "My custom dataset",
"channel_names": {
"0": "CT"
},
"labels": {
"background": 0,
"target": 1
},
"numTraining": 100,
"file_ending": ".nii.gz"
}
nnUNetv2_plan_and_preprocess -d 999 --verify_dataset_integrity
nnUNetv2_train 999 2d 0
本教程涵盖了 nnUNet 的基础使用流程:
将以下内容保存为 train.sh
:
#!/bin/bash
# 获取脚本所在目录
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
# 设置环境变量
export nnUNet_raw="${SCRIPT_DIR}/nnUNet_raw"
export nnUNet_preprocessed="${SCRIPT_DIR}/nnUNet_preprocessed"
export nnUNet_results="${SCRIPT_DIR}/nnUNet_results"
# 1. 解压数据集
echo "==> 解压数据集..."
tar -xvf ./nnUNet_raw/Task04_Hippocampus.tar -C ./nnUNet_raw
# 2. 重命名数据集
echo "==> 重命名数据集..."
mv ./nnUNet_raw/Task04_Hippocampus ./nnUNet_raw/Task004_Hippocampus
# 3. 删除旧的 Dataset004(如果存在)
rm -rf ${nnUNet_raw}/Dataset004_Hippocampus
# 4. 转换为 nnUNet v2 格式
echo "==> 转换数据集格式..."
nnUNetv2_convert_old_nnUNet_dataset \
./nnUNet_raw/Task004_Hippocampus \
Dataset004_Hippocampus
# 5. 数据预处理
echo "==> 数据预处理..."
nnUNetv2_plan_and_preprocess -d 4 --verify_dataset_integrity
# 6. 训练 2D 模型
echo "==> 开始训练 2D 模型..."
nnUNetv2_train -device mps 4 2d 0
# 7. (可选)训练 3D 模型
# echo "==> 开始训练 3D 模型..."
# nnUNetv2_train -device mps 4 3d_fullres 0
echo "==> 训练完成!"
运行脚本:
chmod +x train.sh
./train.sh
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。