手撕OCC-NeRF:Occlusion-Free Scene Recovery via Neural Radiance Fields

链接:OCC-NeRF: Occlusion-Free Scene Recovery via Neural Radiance Fields

文件夹目录

occ-nerf/
├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── local1.txt
├── dataloader/
│   ├── any_folder.py
│   ├── local_save.py
│   ├── with_colmap.py
│   ├── with_feature.py
│   ├── with_feature_colmap.py
│   └── with_mask.py
├── models/
│   ├── depth_decoder.py
│   ├── intrinsics.py
│   ├── layers.py
│   ├── nerf_feature.py
│   ├── nerf_mask.py
│   ├── nerf_models.py
│   └── poses.py
├── utils/
│   ├── align_traj.py
│   ├── comp_ate.py
│   ├── comp_ray_dir.py
│   ├── lie_group_helper.py
│   ├── pos_enc.py
│   ├── pose_utils.py
│   ├── split_dataset.py
│   ├── training_utils.py
│   ├── vgg.py
│   ├── vis_cam_traj.py
│   └── volume_op.py
├── tasks/
│   └── ...
└── third_party/
    ├── ATE/
    │   └── README.md
    └── pytorch_ssim/

DEBUG 代码

dataloader

├── dataloader/
│   ├── any_folder.py
│   ├── local_save.py
│   ├── with_colmap.py
│   ├── with_feature.py
│   ├── with_feature_colmap.py
│   └── with_mask.py

any_folder.py

import os                                       # 操作系统接口模块
import torch                                    # PyTorch 深度学习框架
import numpy as np                              # 科学计算库
from tqdm import tqdm                           # 进度条显示模块
import imageio                                  # 图像 IO 处理库
from dataloader.with_colmap import resize_imgs  # 自定义图像缩放函数

def load_imgs(image_dir, num_img_to_load, start, end, skip, load_sorted, load_img):
    img_names = np.array(sorted(os.listdir(image_dir)))  # 获取并排序目录下所有文件名
    
    if end == -1:                                 # 从 start 开始按间隔 skip 取图
        img_names = img_names[start::skip]
    else:                                         # 取 start 到 end 区间按间隔 skip 取图
        img_names = img_names[start:end:skip]
    
    if not load_sorted:                           # 是否打乱图像顺序
        np.random.shuffle(img_names)
    
    if num_img_to_load > len(img_names):          # 检查请求数量是否超出范围
        print(f'图像请求数{num_img_to_load}超过可用数{len(img_names)}')
        exit()
    elif num_img_to_load == -1:                   # 加载全部可用图像
        print(f'加载全部{len(img_names)}张图像')
    else:                                         # 截取指定数量的图像
        print(f'从{len(img_names)}张中加载{num_img_to_load}张')
        img_names = img_names[:num_img_to_load]
    
    img_paths = [os.path.join(image_dir, n) for n in img_names]  # 构建完整文件路径
    N_imgs = len(img_paths)                         # 计算实际加载数量
    
    img_list = []
    if load_img:                                     # 实际加载图像数据
        for p in tqdm(img_paths):
            img = imageio.imread(p)[:, :, :3]        # 读取 RGB 三通道图像
            img_list.append(img)
        img_list = np.stack(img_list)                # 堆叠为 4D 数组
        img_list = torch.from_numpy(img_list).float() / 255  # 转换为浮点张量并归一化
        H, W = img_list.shape[1], img_list.shape[2]  # 获取图像尺寸
    else:                                            # 仅获取图像尺寸
        tmp_img = imageio.imread(img_paths[0])
        H, W = tmp_img.shape[0], tmp_img.shape[1]
    
    return {                                         # 返回结构化数据
        'imgs': img_list,        # 图像张量 (N, H, W, 3)
        'img_names': img_names,  # 图像文件名数组
        'N_imgs': N_imgs,        # 总图像数
        'H': H,                  # 图像高度
        'W': W,                  # 图像宽度
    }

class DataLoaderAnyFolder:
    def __init__(self, base_dir, scene_name, res_ratio, num_img_to_load, 
                 start, end, skip, load_sorted, load_img=True):  # 初始化参数
        self.base_dir = base_dir                  # 数据根目录
        self.scene_name = scene_name              # 场景名称
        self.res_ratio = res_ratio                # 分辨率缩放比例
        self.num_img_to_load = num_img_to_load    # 最大加载数量
        self.start = start                        # 起始索引
        self.end = end                            # 结束索引
        self.skip = skip                          # 采样间隔
        self.load_sorted = load_sorted            # 是否保持顺序
        self.load_img = load_img                  # 是否实际加载图像
        
        self.imgs_dir = os.path.join(self.base_dir, self.scene_name)  # 构建图像目录路径
        
        image_data = load_imgs(self.imgs_dir, self.num_img_to_load,  # 加载图像数据
                              self.start, self.end, self.skip,
                              self.load_sorted, self.load_img)
        
        self.imgs = image_data['imgs']             # 图像张量
        self.img_names = image_data['img_names']   # 文件名列表
        self.N_imgs = image_data['N_imgs']         # 图像总数
        self.ori_H = image_data['H']               # 原始高度
        self.ori_W = image_data['W']               # 原始宽度
        
        self.near = 0.0                            # 近裁剪面(NDC 坐标系)
        self.far = 1.0                             # 远裁剪面(NDC 坐标系)
        
        if self.res_ratio > 1:                     # 计算实际使用分辨率
            self.H = self.ori_H // self.res_ratio
            self.W = self.ori_W // self.res_ratio
        else:
            self.H = self.ori_H
            self.W = self.ori_W
        
        if self.load_img:                          # 执行图像缩放
            self.imgs = resize_imgs(self.imgs, self.H, self.W)

if __name__ == '__main__':
    base_dir = '/your/data/path'                   # 数据根目录配置示例
    scene_name = 'LLFF/fern/images'                # 场景路径配置示例
    resize_ratio = 8                               # 缩放比例配置
    num_img_to_load = -1                           # 加载全部图像
    start, end, skip = 0, -1, 1                    # 采样参数初始化
    load_sorted, load_img = True, True             # 加载配置参数
    
    scene = DataLoaderAnyFolder(                   # 创建数据加载实例
        base_dir=base_dir,
        scene_name=scene_name,
        res_ratio=resize_ratio,
        num_img_to_load=num_img_to_load,
        start=start,
        end=end,
        skip=skip,
        load_sorted=load_sorted,
        load_img=load_img)

local_save.py

import os
import torch
import numpy as np
from tqdm import tqdm
import imageio

from dataloader.with_colmap import resize_imgs
from torchvision import models
from torch.nn import functional as F

@torch.no_grad()
class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        # 调用父类的构造函数
        super().__init__()
        # 加载预训练的 VGG19 模型的特征提取部分
        self.vgg_pretrained_features = models.vgg19(pretrained=True).features
        # 如果不需要计算梯度,则将模型参数的 requires_grad 属性设置为 False
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False
        # 初始化特征图的形状为 None
        self.feature_shape = None

    def forward(self, X, indices=None):
        # 记录输入特征图的形状
        self.feature_shape = X.shape
        # 如果没有指定索引,则默认使用 [7, 25]
        if indices is None:
            indices = [7,25]
        # 存储提取的特征图
        out = []
        # 遍历到最后一个索引位置
        for i in range(indices[-1]):
            # 通过 VGG19 的第 i 层进行特征提取
            X = self.vgg_pretrained_features[i](X)
            # 如果当前层的索引加 1 在指定的索引列表中
            if (i+1) in indices:
                if self.feature_shape is None:
                    # 如果特征图形状为 None,则记录当前特征图的形状
                    self.feature_shape = X.shape
                else:
                    # 对特征图进行双线性插值,使其尺寸与输入特征图的尺寸一致
                    X = F.interpolate(X,self.feature_shape[-2:],mode='bilinear',align_corners=True)
                # 将处理后的特征图添加到输出列表中
                out.append(X)
        # 将所有提取的特征图在通道维度上拼接起来
        return torch.cat(out,1)

def load_imgs(image_dir, num_img_to_load, start, end, skip, load_sorted, load_img):
    # 获取图像目录下所有图像的文件名,并按字母顺序排序
    img_names = np.array(sorted(os.listdir(image_dir))) # all image names

    # 时间域下采样:根据 start、end 和 skip 参数选择图像
    if end == -1:
        img_names = img_names[start::skip]
    else:
        img_names = img_names[start:end:skip]

    # 如果不按顺序加载图像,则对图像文件名进行随机打乱
    if not load_sorted:
        np.random.shuffle(img_names)

    # 检查要加载的图像数量是否超过可用图像数量
    if num_img_to_load > len(img_names):
        print('Asked for {0:6d} images but only {1:6d} available. Exit.'.format(num_img_to_load, len(img_names)))
        exit()
    elif num_img_to_load == -1:
        print('Loading all available {0:6d} images'.format(len(img_names)))
    else:
        print('Loading {0:6d} images out of {1:6d} images.'.format(num_img_to_load, len(img_names)))
        # 截取前 num_img_to_load 个图像文件名
        img_names = img_names[:num_img_to_load]

    # 构建图像文件的完整路径
    img_paths = [os.path.join(image_dir, n) for n in img_names]
    # 图像的数量
    N_imgs = len(img_paths)

    # 存储加载的图像
    img_list = []
    if load_img:
        # 使用 tqdm 显示加载进度
        for p in tqdm(img_paths):
            # 读取图像并截取前三个通道(RGB)
            img = imageio.imread(p)[:, :, :3] # (H, W, 3) np.uint8
            # 将图像添加到列表中
            img_list.append(img)
        # 将图像列表转换为 numpy 数组
        img_list = np.stack(img_list) # (N, H, W, 3)
        # 将 numpy 数组转换为 PyTorch 张量,并将像素值归一化到 [0, 1] 范围
        img_list = torch.from_numpy(img_list).float() / 255 # (N, H, W, 3) torch.float32
        # 获取图像的高度和宽度
        H, W = img_list.shape[1], img_list.shape[2]
    else:
        # 如果不加载图像,则读取第一张图像以获取图像的高度和宽度
        tmp_img = imageio.imread(img_paths[0]) # load one image to get H, W
        H, W = tmp_img.shape[0], tmp_img.shape[1]

    # 存储加载图像的相关信息
    results = {
        'imgs': img_list, # (N, H, W, 3) torch.float32
        'img_names': img_names, # (N, )
        'N_imgs': N_imgs,
        'H': H,
        'W': W,
    }

    return results


class DataLoaderAnyFolder:
    """
    Most useful fields:
    self.c2ws: (N_imgs, 4, 4) torch.float32
    self.imgs (N_imgs, H, W, 4) torch.float32
    self.ray_dir_cam (H, W, 3) torch.float32
    self.H scalar
    self.W scalar
    self.N_imgs scalar
    """
    def __init__(self, base_dir, scene_name, res_ratio, num_img_to_load, start, end, skip, load_sorted, load_img=True, device='cpu'):
        """
        :param base_dir: 数据的基础目录
        :param scene_name: 场景的名称
        :param res_ratio: 整数,如 [1, 2, 4] 等,用于将图像调整为较低的分辨率。
        :param start/end/skip: 用于在时间域上控制帧的加载。
        :param load_sorted: 布尔值,是否按顺序加载图像。
        :param load_img: 布尔值。如果设置为 False:仅统计图像数量、获取图像的高度和宽度,
                         但不加载图像。在可视化位姿或调试等情况下很有用。
        """
        self.base_dir = base_dir
        self.scene_name = scene_name
        self.res_ratio = res_ratio
        self.num_img_to_load = num_img_to_load
        self.start = start
        self.end = end
        self.skip = skip
        self.load_sorted = load_sorted
        self.load_img = load_img

        # 构建图像目录的完整路径
        self.imgs_dir = os.path.join(self.base_dir, self.scene_name)

        # 调用 load_imgs 函数加载图像
        image_data = load_imgs(self.imgs_dir, self.num_img_to_load, self.start, self.end, self.skip,
                               self.load_sorted, self.load_img)
        # 加载的图像张量
        self.imgs = image_data['imgs'] # (N, H, W, 3) torch.float32
        # 图像的文件名
        self.img_names = image_data['img_names'] # (N, )
        # 图像的数量
        self.N_imgs = image_data['N_imgs']
        # 原始图像的高度
        self.ori_H = image_data['H']
        # 原始图像的宽度
        self.ori_W = image_data['W']
        # 初始化 VGG19 编码器
        self.encoder = Vgg19()

        # 近裁剪平面距离
        self.near = 0.0
        # 远裁剪平面距离
        self.far = 1.0

        # 如果需要调整图像分辨率
        if self.res_ratio > 1:
            self.H = self.ori_H // self.res_ratio
            self.W = self.ori_W // self.res_ratio
        else:
            self.H = self.ori_H
            self.W = self.ori_W

        if self.load_img:
            # 调整图像的分辨率
            self.imgs = resize_imgs(self.imgs, self.H, self.W) # (N, H, W, 3) torch.float32
            # 存储图像的特征
            self.features = []
            # 使用 tqdm 显示处理进度
            for img in tqdm(self.imgs):
                # 对图像进行通道维度的调整,并通过编码器提取特征
                self.features.append(self.encoder(img.permute(2,0,1)[None,...]))
            # 将所有图像的特征在批次维度上拼接起来
            self.features = torch.cat(self.features,0)
            # 特征图的尺寸
            self.feature_size = (self.features.shape[-2],self.features.shape[-1]) # (H,W)
            # 打印特征图的形状
            print(self.features.shape)

if __name__ == '__main__':
    # 数据的基础目录,需要替换为实际的路径
    base_dir = '/your/data/path'
    # 场景的名称
    scene_name = 'LLFF/fern/images'
    # 图像的缩放比例
    resize_ratio = 8
    # 要加载的图像数量,-1 表示加载所有图像
    num_img_to_load = -1
    # 开始加载图像的索引
    start = 0
    # 结束加载图像的索引,-1 表示加载到最后
    end = -1
    # 加载图像的间隔
    skip = 1
    # 是否按顺序加载图像
    load_sorted = True
    # 是否加载图像
    load_img = True

    # 初始化 DataLoaderAnyFolder 类
    scene = DataLoaderAnyFolder(base_dir=base_dir,
                                scene_name=scene_name,
                                res_ratio=resize_ratio,
                                num_img_to_load=num_img_to_load,
                                start=start,
                                end=end,
                                skip=skip,
                                load_sorted=load_sorted,
                                load_img=load_img)

with_colmap.py

import os
# os 模块提供了与操作系统进行交互的功能,
# 可以用来处理文件和目录路径、创建/删除目录、获取环境变量等,
# 在代码中主要用于构建文件和目录的路径。

import torch
# torch 是 PyTorch 的核心库,提供了张量(Tensor)数据结构,
# 支持自动求导机制,用于构建和训练深度学习模型,
# 可以在 CPU 或 GPU 上进行高效的数值计算。

import torch.nn.functional as F
# torch.nn.functional 提供了许多神经网络中常用的函数,
# 如激活函数、损失函数、卷积、池化等操作,
# 这些函数是无状态的,通常用于自定义神经网络层中的具体运算。

import numpy as np
# numpy 是 Python 中用于科学计算的基础库,
# 提供了高效的多维数组对象和各种数学函数,
# 可以进行数组操作、线性代数运算、随机数生成等,
# 在代码中主要用于处理图像数据和数组操作。

from tqdm import tqdm
# tqdm 是一个快速、可扩展的进度条工具,
# 可以在循环中显示进度条,方便用户了解代码执行的进度。

import imageio
# imageio 是一个用于读取和写入多种图像文件格式的库,
# 在代码中主要用于读取图像文件。

from utils.comp_ray_dir import comp_ray_dir_cam
# 从 utils 包中的 comp_ray_dir 模块导入 comp_ray_dir_cam 函数,
# 推测该函数用于计算相机坐标系下的光线方向。

from utils.pose_utils import center_poses
# 从 utils 包中的 pose_utils 模块导入 center_poses 函数,
# 推测该函数用于对相机位姿进行中心化处理。

from utils.lie_group_helper import convert3x4_4x4
# 从 utils 包中的 lie_group_helper 模块导入 convert3x4_4x4 函数,
# 推测该函数用于将 3x4 的相机位姿矩阵转换为 4x4 的齐次矩阵。


def resize_imgs(imgs, new_h, new_w):
    """
    :param imgs:    (N, H, W, 3)            torch.float32 RGB
    :param new_h:   int/torch int
    :param new_w:   int/torch int
    :return:        (N, new_H, new_W, 3)    torch.float32 RGB
    """
    # 将图像张量从 (N, H, W, 3) 转换为 (N, 3, H, W) 以适应 F.interpolate 函数的输入要求
    imgs = imgs.permute(0, 3, 1, 2)  # (N, 3, H, W)
    # 使用双线性插值方法将图像调整到指定的新高度和新宽度
    imgs = F.interpolate(imgs, size=(new_h, new_w), mode='bilinear')  # (N, 3, new_H, new_W)
    # 将图像张量从 (N, 3, new_H, new_W) 转换回 (N, new_H, new_W, 3)
    imgs = imgs.permute(0, 2, 3, 1)  # (N, new_H, new_W, 3)

    return imgs  # (N, new_H, new_W, 3) torch.float32 RGB


def load_imgs(image_dir, img_ids, new_h, new_w):
    # 获取图像目录下所有图像文件名,并按字母顺序排序
    img_names = np.array(sorted(os.listdir(image_dir)))  # all image names
    # 根据给定的图像索引筛选出需要的图像文件名
    img_names = img_names[img_ids]  # image name for this split

    # 构建每个图像的完整路径
    img_paths = [os.path.join(image_dir, n) for n in img_names]

    img_list = []
    # 使用 tqdm 显示加载图像的进度
    for p in tqdm(img_paths):
        # 读取图像并只保留前三个通道(RGB)
        img = imageio.imread(p)[:, :, :3]  # (H, W, 3) np.uint8
        img_list.append(img)
    # 将图像列表转换为 numpy 数组
    img_list = np.stack(img_list)  # (N, H, W, 3)
    # 将 numpy 数组转换为 PyTorch 张量,并将像素值归一化到 [0, 1] 范围
    img_list = torch.from_numpy(img_list).float() / 255  # (N, H, W, 3) torch.float32
    # 调用 resize_imgs 函数将图像调整到指定的新高度和新宽度
    img_list = resize_imgs(img_list, new_h, new_w)
    return img_list, img_names


def read_meta(in_dir, use_ndc):
    """
    Read the poses_bounds.npy file produced by LLFF imgs2poses.py.
    This function is modified from https://github.com/kwea123/nerf_pl.
    """
    # 加载 poses_bounds.npy 文件,该文件包含相机位姿和深度边界信息
    poses_bounds = np.load(os.path.join(in_dir, 'poses_bounds.npy'))  # (N_images, 17)

    # 提取相机位姿信息,将其重塑为 (N_images, 3, 5) 的形状
    c2ws = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
    # 提取深度边界信息
    bounds = poses_bounds[:, -2:]  # (N_images, 2)
    # 提取图像高度、宽度和焦距信息
    H, W, focal = c2ws[0, :, -1]

    # 修正相机位姿的旋转部分,将旋转形式从 "down right back" 改为 "right up back"
    # 参考 https://github.com/bmild/nerf/issues/34
    c2ws = np.concatenate([c2ws[..., 1:2], -c2ws[..., :1], c2ws[..., 2:4]], -1)

    # 对相机位姿进行中心化处理,返回中心化后的相机位姿和平均位姿
    # pose_avg @ c2ws -> centred c2ws
    c2ws, pose_avg = center_poses(c2ws)  # (N_images, 3, 4), (4, 4)

    if use_ndc:
        # 获取最近深度值
        near_original = bounds.min()
        # 计算缩放因子,将最近深度调整到稍大于 1.0 的位置
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # 对深度边界进行缩放
        bounds /= scale_factor
        # 对相机位姿的平移部分进行缩放
        c2ws[..., 3] /= scale_factor
    
    # 将 3x4 的相机位姿转换为 4x4 的齐次矩阵形式
    c2ws = convert3x4_4x4(c2ws)  # (N, 4, 4)

    results = {
        'c2ws': c2ws,       # (N, 4, 4) np
        'bounds': bounds,   # (N_images, 2) np
        'H': int(H),        # scalar
        'W': int(W),        # scalar
        'focal': focal,     # scalar
        'pose_avg': pose_avg,  # (4, 4) np
    }
    return results


class DataLoaderWithCOLMAP:
    """
    Most useful fields:
        self.c2ws:          (N_imgs, 4, 4)      torch.float32
        self.imgs           (N_imgs, H, W, 4)   torch.float32
        self.ray_dir_cam    (H, W, 3)           torch.float32
        self.H              scalar
        self.W              scalar
        self.N_imgs         scalar
    """
    def __init__(self, base_dir, scene_name, data_type, res_ratio, num_img_to_load, skip, use_ndc, load_img=True):
        """
        :param base_dir:
        :param scene_name:
        :param data_type:   'train' or 'val'.
        :param res_ratio:   int [1, 2, 4] etc to resize images to a lower resolution.
        :param num_img_to_load/skip: control frame loading in temporal domain.
        :param use_ndc      True/False, just centre the poses and scale them.
        :param load_img:    True/False. If set to false: only count number of images, get H and W,
                            but do not load imgs. Useful when vis poses or debug etc.
        """
        self.base_dir = base_dir
        self.scene_name = scene_name
        self.data_type = data_type
        self.res_ratio = res_ratio
        self.num_img_to_load = num_img_to_load
        self.skip = skip
        self.use_ndc = use_ndc
        self.load_img = load_img

        # 构建场景目录的完整路径
        self.scene_dir = os.path.join(self.base_dir, self.scene_name)
        # 构建图像目录的完整路径
        self.img_dir = os.path.join(self.scene_dir, 'images')

        # 读取所有的元信息,包括相机位姿、深度边界、图像尺寸和焦距等
        meta = read_meta(self.scene_dir, self.use_ndc)
        # 提取相机位姿信息
        self.c2ws = meta['c2ws']  # (N, 4, 4) all camera pose
        # 提取图像高度信息
        self.H = meta['H']
        # 提取图像宽度信息
        self.W = meta['W']
        # 提取焦距信息
        self.focal = float(meta['focal'])

        if self.res_ratio > 1:
            # 如果需要调整图像分辨率,对图像高度进行相应的缩放
            self.H = self.H // self.res_ratio
            # 如果需要调整图像分辨率,对图像宽度进行相应的缩放
            self.W = self.W // self.res_ratio
            # 如果需要调整图像分辨率,对焦距进行相应的缩放
            self.focal /= self.res_ratio

        # 近裁剪平面距离
        self.near = 0.0
        # 远裁剪平面距离
        self.far = 1.0
        # 加载图像并调整到指定的高度和宽度
        self.imgs, self.img_names = load_imgs(self.img_dir, np.arange(num_img_to_load), self.H, self.W)  # (N, H, W, 3) torch.float32
        # 截取前 num_img_to_load 个相机位姿
        self.c2ws = self.c2ws[:num_img_to_load]
        # 图像的数量
        self.N_imgs = self.c2ws.shape[0]

        # 生成相机坐标系下的光线方向
        self.ray_dir_cam = comp_ray_dir_cam(self.H, self.W, self.focal)  # (H, W, 3) torch.float32

        # 将相机位姿从 numpy 数组转换为 PyTorch 张量
        self.c2ws = torch.from_numpy(self.c2ws).float()  # (N, 4, 4) torch.float32
        # 将光线方向张量转换为 float32 类型
        self.ray_dir_cam = self.ray_dir_cam.float()  # (H, W, 3) torch.float32


if __name__ == '__main__':
    scene_name = 'LLFF/fern'
    use_ndc = True
    # 注意:需要将 /your/data/path 替换为实际的数据路径,
    # 这里创建了一个 DataLoaderWithCOLMAP 类的实例,用于加载指定场景的数据
    scene = DataLoaderWithCOLMAP(base_dir='/your/data/path',
                                 scene_name=scene_name,
                                 data_type='train',
                                 res_ratio=8,
                                 num_img_to_load=-1,
                                 skip=1,
                                 use_ndc=use_ndc)

with_feature_colmap.py

import os
# os 模块提供了与操作系统进行交互的功能,
# 可用于处理文件和目录路径、创建/删除目录、获取环境变量等,
# 在本代码里主要用于构建文件和目录的路径。

import torch
# torch 是 PyTorch 的核心库,提供了张量(Tensor)数据结构,
# 支持自动求导机制,用于构建和训练深度学习模型,
# 能够在 CPU 或 GPU 上高效地进行数值计算。

import numpy as np
# numpy 是 Python 中用于科学计算的基础库,
# 提供了高效的多维数组对象和各种数学函数,
# 可进行数组操作、线性代数运算、随机数生成等,
# 在代码中主要用于处理图像数据和数组操作。

from tqdm import tqdm
# tqdm 是一个快速、可扩展的进度条工具,
# 能在循环中显示进度条,方便用户了解代码执行的进度。

import imageio
# imageio 是一个用于读取和写入多种图像文件格式的库,
# 在代码中主要用于读取图像文件。

from dataloader.with_colmap import resize_imgs
# 从 dataloader.with_colmap 模块导入 resize_imgs 函数,
# 该函数用于调整图像的尺寸。

from torchvision import models
# torchvision 是 PyTorch 中用于计算机视觉任务的库,
# models 子模块提供了预训练的深度学习模型。

from torch.nn import functional as F
# torch.nn.functional 提供了许多神经网络中常用的函数,
# 例如激活函数、损失函数、卷积、池化等操作,
# 这些函数是无状态的,通常用于自定义神经网络层中的具体运算。

from utils.comp_ray_dir import comp_ray_dir_cam
# 从 utils 包中的 comp_ray_dir 模块导入 comp_ray_dir_cam 函数,
# 推测该函数用于计算相机坐标系下的光线方向。

from utils.pose_utils import center_poses
# 从 utils 包中的 pose_utils 模块导入 center_poses 函数,
# 推测该函数用于对相机位姿进行中心化处理。

from utils.lie_group_helper import convert3x4_4x4
# 从 utils 包中的 lie_group_helper 模块导入 convert3x4_4x4 函数,
# 推测该函数用于将 3x4 的相机位姿矩阵转换为 4x4 的齐次矩阵。

from utils.vgg import Vgg19
# 从 utils.vgg 模块导入 Vgg19 类,可能用于特征提取。


def load_imgs(image_dir, num_img_to_load, start, end, skip, load_sorted, load_img):
    # 获取图像目录下所有图像文件名,并按字母顺序排序
    img_names = np.array(sorted(os.listdir(image_dir)))  # all image names

    # 在时间域上对帧进行下采样
    if end == -1:
        img_names = img_names[start::skip]
    else:
        img_names = img_names[start:end:skip]

    # 如果不按顺序加载图像,则对图像文件名进行随机打乱
    if not load_sorted:
        np.random.shuffle(img_names)

    # 加载下采样后的图像
    if num_img_to_load > len(img_names):
        print('Asked for {0:6d} images but only {1:6d} available. Exit.'.format(num_img_to_load, len(img_names)))
        exit()
    elif num_img_to_load == -1:
        print('Loading all available {0:6d} images'.format(len(img_names)))
    else:
        print('Loading {0:6d} images out of {1:6d} images.'.format(num_img_to_load, len(img_names)))
        img_names = img_names[:num_img_to_load]

    # 构建每个图像的完整路径
    img_paths = [os.path.join(image_dir, n) for n in img_names]
    # 图像的数量
    N_imgs = len(img_paths)

    img_list = []
    if load_img:
        # 使用 tqdm 显示加载图像的进度
        for p in tqdm(img_paths):
            # 读取图像并只保留前三个通道(RGB)
            img = imageio.imread(p)[:, :, :3]  # (H, W, 3) np.uint8
            img_list.append(img)
        # 将图像列表转换为 numpy 数组
        img_list = np.stack(img_list)  # (N, H, W, 3)
        # 将 numpy 数组转换为 PyTorch 张量,并将像素值归一化到 [0, 1] 范围
        img_list = torch.from_numpy(img_list).float() / 255  # (N, H, W, 3) torch.float32
        # 获取图像的高度和宽度
        H, W = img_list.shape[1], img_list.shape[2]
    else:
        # 如果不加载图像,则读取第一张图像以获取图像的高度和宽度
        tmp_img = imageio.imread(img_paths[0])  # load one image to get H, W
        H, W = tmp_img.shape[0], tmp_img.shape[1]

    result = {
        'imgs': img_list,  # (N, H, W, 3) torch.float32
        'img_names': img_names,  # (N, )
        'N_imgs': N_imgs,
        'H': H,
        'W': W,
    }
    return result


def read_meta(in_dir, use_ndc):
    """
    Read the poses_bounds.npy file produced by LLFF imgs2poses.py.
    This function is modified from https://github.com/kwea123/nerf_pl.
    """
    # 加载 poses_bounds.npy 文件,该文件包含相机位姿和深度边界信息
    poses_bounds = np.load(os.path.join(in_dir, '../poses_bounds.npy'))  # (N_images, 17)

    # 提取相机位姿信息,将其重塑为 (N_images, 3, 5) 的形状
    c2ws = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
    # 提取深度边界信息
    bounds = poses_bounds[:, -2:]  # (N_images, 2)
    # 提取图像高度、宽度和焦距信息
    H, W, focal = c2ws[0, :, -1]

    # 修正相机位姿的旋转部分,将旋转形式从 "down right back" 改为 "right up back"
    # 参考 https://github.com/bmild/nerf/issues/34
    c2ws = np.concatenate([c2ws[..., 1:2], -c2ws[..., :1], c2ws[..., 2:4]], -1)

    # 对相机位姿进行中心化处理,返回中心化后的相机位姿和平均位姿
    # pose_avg @ c2ws -> centred c2ws
    c2ws, pose_avg = center_poses(c2ws)  # (N_images, 3, 4), (4, 4)

    if use_ndc:
        # 修正尺度,使最近的深度略大于 1.0
        # 参考 https://github.com/bmild/nerf/issues/34
        near_original = bounds.min()
        # 0.75 是默认参数
        scale_factor = near_original * 0.75  
        # 最近的深度约为 1/0.75 = 1.33
        bounds /= scale_factor
        c2ws[..., 3] /= scale_factor
    
    # 将 3x4 的相机位姿矩阵转换为 4x4 的齐次矩阵形式
    c2ws = convert3x4_4x4(c2ws)  # (N, 4, 4)

    results = {
        'c2ws': c2ws,       # (N, 4, 4) np
        'bounds': bounds,   # (N_images, 2) np
        'H': int(H),        # scalar
        'W': int(W),        # scalar
        'focal': focal,     # scalar
        'pose_avg': pose_avg,  # (4, 4) np
    }
    return results


class Dataloader_feature_n_colmap:
    """
    Most useful fields:
        self.c2ws:          (N_imgs, 4, 4)      torch.float32
        self.imgs           (N_imgs, H, W, 4)   torch.float32
        self.ray_dir_cam    (H, W, 3)           torch.float32
        self.H              scalar
        self.W              scalar
        self.N_imgs         scalar
    """
    def __init__(self, base_dir, scene_name, res_ratio, num_img_to_load, start=0, end=-1, skip=1,
                 load_sorted=True, load_img=True, use_ndc=True, device='cpu'):
        """
        :param base_dir: 数据的基础目录
        :param scene_name: 场景的名称
        :param res_ratio: 整数,如 [1, 2, 4] 等,用于将图像调整为较低的分辨率。
        :param start/end/skip: 用于在时间域上控制帧的加载。
        :param load_sorted: 布尔值,是否按顺序加载图像。
        :param load_img: 布尔值。如果设置为 false:仅统计图像数量、获取图像的高度和宽度,
                         但不加载图像。在可视化位姿或调试等情况下很有用。
        """
        self.base_dir = base_dir
        self.scene_name = scene_name
        self.res_ratio = res_ratio
        self.num_img_to_load = num_img_to_load
        self.start = start
        self.end = end
        self.skip = skip
        self.use_ndc = use_ndc
        self.load_sorted = load_sorted
        self.load_img = load_img

        # 构建图像目录的完整路径
        self.imgs_dir = os.path.join(self.base_dir, self.scene_name)
        
        # 读取所有的元信息,包括相机位姿、深度边界、图像尺寸和焦距等
        meta = read_meta(self.imgs_dir, self.use_ndc)
        # 提取相机位姿信息并转换为 PyTorch 张量
        self.c2ws = torch.Tensor(meta['c2ws'])  # (N, 4, 4) all camera pose
        # 提取焦距信息
        self.focal = float(meta['focal'])
        # 根据 start、end 和 skip 参数对相机位姿进行筛选
        if self.end == -1:
            self.c2ws = self.c2ws[self.start::self.skip]
        else:
            self.c2ws = self.c2ws[self.start:self.end:self.skip]
        # 加载图像数据
        image_data = load_imgs(self.imgs_dir, self.num_img_to_load, self.start, self.end, self.skip,
                                self.load_sorted, self.load_img)
        # 提取加载的图像数据
        self.imgs = image_data['imgs']  # (N, H, W, 3) torch.float32
        # 提取图像文件名
        self.img_names = image_data['img_names']  # (N, )
        # 图像的数量
        self.N_imgs = image_data['N_imgs']
        # 原始图像的高度
        self.ori_H = image_data['H']
        # 原始图像的宽度
        self.ori_W = image_data['W']
        # 初始化 Vgg19 编码器并将其移动到指定设备
        self.encoder = Vgg19().to(device)

        # 始终使用归一化设备坐标(NDC)
        self.near = 0.0
        self.far = 1.0

        # 如果需要调整图像分辨率
        if self.res_ratio > 1:
            # 计算调整后的图像高度
            self.H = self.ori_H // self.res_ratio
            # 计算调整后的图像宽度
            self.W = self.ori_W // self.res_ratio
        else:
            self.H = self.ori_H
            self.W = self.ori_W
        # 调整焦距
        self.focal /= self.res_ratio

        if self.load_img:
            # 调整图像的分辨率并将其移动到指定设备
            self.imgs = resize_imgs(self.imgs, self.H, self.W).to(device)  # (N, H, W, 3) torch.float32
            self.features = []
            # 使用 tqdm 显示处理进度
            for img in tqdm(self.imgs):
                # 对图像进行通道维度的调整,并通过编码器提取特征
                self.features.append(self.encoder(img.permute(2, 0, 1)[None, ...]))
            # 这里注释掉了特征拼接的代码,可根据需要取消注释
            # self.features = torch.cat(self.features, 0)
            # print(self.features.shape)


if __name__ == '__main__':
    # 数据的基础目录,需要替换为实际的路径
    base_dir = '/your/data/path'
    # 场景的名称
    scene_name = 'LLFF/fern/images'
    # 图像的缩放比例
    resize_ratio = 8
    # 要加载的图像数量,-1 表示加载所有图像
    num_img_to_load = -1
    # 开始加载图像的索引
    start = 0
    # 结束加载图像的索引,-1 表示加载到最后
    end = -1
    # 加载图像的间隔
    skip = 1
    # 是否按顺序加载图像
    load_sorted = True
    # 是否加载图像
    load_img = True
    # 是否使用归一化设备坐标(NDC)
    use_ndc = True

    # 初始化 Dataloader_feature_n_colmap 类
    scene = Dataloader_feature_n_colmap(base_dir=base_dir,
                                scene_name=scene_name,
                                res_ratio=resize_ratio,
                                num_img_to_load=num_img_to_load,
                                start=start,
                                end=end,
                                skip=skip,
                                load_sorted=load_sorted,
                                load_img=load_img,
                                use_ndc=use_ndc)

with_feature.py

import os
# os 模块提供了与操作系统进行交互的功能,
# 可用于处理文件和目录路径、创建/删除目录、获取环境变量等,
# 在本代码里主要用于构建文件和目录的路径。

import torch
# torch 是 PyTorch 的核心库,提供了张量(Tensor)数据结构,
# 支持自动求导机制,用于构建和训练深度学习模型,
# 能够在 CPU 或 GPU 上高效地进行数值计算。

import torch.nn.functional as F
# torch.nn.functional 提供了许多神经网络中常用的函数,
# 例如激活函数、损失函数、卷积、池化等操作,
# 这些函数是无状态的,通常用于自定义神经网络层中的具体运算。

import numpy as np
# numpy 是 Python 中用于科学计算的基础库,
# 提供了高效的多维数组对象和各种数学函数,
# 可进行数组操作、线性代数运算、随机数生成等,
# 在代码中主要用于处理图像数据和数组操作。

from tqdm import tqdm
# tqdm 是一个快速、可扩展的进度条工具,
# 能在循环中显示进度条,方便用户了解代码执行的进度。

import imageio
# imageio 是一个用于读取和写入多种图像文件格式的库,
# 在代码中主要用于读取图像文件。

from utils.comp_ray_dir import comp_ray_dir_cam
# 从 utils 包中的 comp_ray_dir 模块导入 comp_ray_dir_cam 函数,
# 推测该函数用于计算相机坐标系下的光线方向。

from utils.pose_utils import center_poses
# 从 utils 包中的 pose_utils 模块导入 center_poses 函数,
# 推测该函数用于对相机位姿进行中心化处理。

from utils.lie_group_helper import convert3x4_4x4
# 从 utils 包中的 lie_group_helper 模块导入 convert3x4_4x4 函数,
# 推测该函数用于将 3x4 的相机位姿矩阵转换为 4x4 的齐次矩阵。


def resize_imgs(imgs, new_h, new_w):
    """
    :param imgs:    (N, H, W, 3)            torch.float32 格式的 RGB 图像
    :param new_h:   整数或 torch 整数类型,表示新的图像高度
    :param new_w:   整数或 torch 整数类型,表示新的图像宽度
    :return:        (N, new_H, new_W, 3)    torch.float32 格式的 RGB 图像
    """
    # 将图像张量的维度从 (N, H, W, 3) 调整为 (N, 3, H, W),以适配 F.interpolate 函数的输入要求
    imgs = imgs.permute(0, 3, 1, 2)  # 变为 (N, 3, H, W)
    # 使用双线性插值方法将图像调整到指定的新高度和新宽度
    imgs = F.interpolate(imgs, size=(new_h, new_w), mode='bilinear')  # 变为 (N, 3, new_H, new_W)
    # 将图像张量的维度从 (N, 3, new_H, new_W) 调整回 (N, new_H, new_W, 3)
    imgs = imgs.permute(0, 2, 3, 1)  # 变为 (N, new_H, new_W, 3)

    return imgs  # 返回 (N, new_H, new_W, 3) 格式的 torch.float32 类型 RGB 图像


def load_imgs(image_dir, img_ids, new_h, new_w):
    # 获取图像目录下所有图像文件名,并按字母顺序排序
    img_names = np.array(sorted(os.listdir(image_dir)))  # 得到所有图像文件名
    # 根据给定的图像索引筛选出本次需要的图像文件名
    img_names = img_names[img_ids]  # 得到本次分割所需的图像文件名

    # 构建每个图像的完整路径
    img_paths = [os.path.join(image_dir, n) for n in img_names]

    img_list = []
    # 使用 tqdm 显示加载图像的进度
    for p in tqdm(img_paths):
        # 读取图像并只保留前三个通道(RGB)
        img = imageio.imread(p)[:, :, :3]  # 得到 (H, W, 3) 格式的 np.uint8 类型图像
        img_list.append(img)
    # 将图像列表转换为 numpy 数组
    img_list = np.stack(img_list)  # 变为 (N, H, W, 3) 格式
    # 将 numpy 数组转换为 PyTorch 张量,并将像素值归一化到 [0, 1] 范围
    img_list = torch.from_numpy(img_list).float() / 255  # 变为 (N, H, W, 3) 格式的 torch.float32 类型
    # 调用 resize_imgs 函数将图像调整到指定的新高度和新宽度
    img_list = resize_imgs(img_list, new_h, new_w)
    return img_list, img_names


def read_meta(in_dir, use_ndc):
    """
    读取由 LLFF 的 imgs2poses.py 生成的 poses_bounds.npy 文件。
    此函数改编自 https://github.com/kwea123/nerf_pl。
    """
    # 加载 poses_bounds.npy 文件,该文件包含相机位姿和深度边界信息
    poses_bounds = np.load(os.path.join(in_dir, 'poses_bounds.npy'))  # 得到 (N_images, 17) 格式的数组

    # 提取相机位姿信息,将其重塑为 (N_images, 3, 5) 的形状
    c2ws = poses_bounds[:, :15].reshape(-1, 3, 5)  # 变为 (N_images, 3, 5) 格式
    # 提取深度边界信息
    bounds = poses_bounds[:, -2:]  # 变为 (N_images, 2) 格式
    # 提取图像高度、宽度和焦距信息
    H, W, focal = c2ws[0, :, -1]

    # 修正相机位姿的旋转部分,将旋转形式从 "下 右 后" 改为 "右 上 后"
    # 参考 https://github.com/bmild/nerf/issues/34
    c2ws = np.concatenate([c2ws[..., 1:2], -c2ws[..., :1], c2ws[..., 2:4]], -1)

    # 对相机位姿进行中心化处理,返回中心化后的相机位姿和平均位姿
    # pose_avg @ c2ws 得到中心化后的 c2ws
    c2ws, pose_avg = center_poses(c2ws)  # 分别得到 (N_images, 3, 4) 和 (4, 4) 格式的数组

    if use_ndc:
        # 获取最近深度值
        near_original = bounds.min()
        # 计算缩放因子,使最近深度调整到稍大于 1.0 的位置
        scale_factor = near_original * 0.75  # 0.75 是默认参数
        # 现在最近深度约为 1/0.75 = 1.33
        # 对深度边界进行缩放
        bounds /= scale_factor
        # 对相机位姿的平移部分进行缩放
        c2ws[..., 3] /= scale_factor
    
    # 将 3x4 的相机位姿矩阵转换为 4x4 的齐次矩阵形式
    c2ws = convert3x4_4x4(c2ws)  # 变为 (N, 4, 4) 格式

    results = {
        'c2ws': c2ws,       # (N, 4, 4) 格式的 numpy 数组
        'bounds': bounds,   # (N_images, 2) 格式的 numpy 数组
        'H': int(H),        # 标量,图像高度
        'W': int(W),        # 标量,图像宽度
        'focal': focal,     # 标量,焦距
        'pose_avg': pose_avg,  # (4, 4) 格式的 numpy 数组
    }
    return results


class DataLoaderWithCOLMAP:
    """
    最有用的字段:
        self.c2ws:          (N_imgs, 4, 4) 格式的 torch.float32 类型张量,表示相机位姿
        self.imgs           (N_imgs, H, W, 4) 格式的 torch.float32 类型张量,表示图像
        self.ray_dir_cam    (H, W, 3) 格式的 torch.float32 类型张量,表示相机坐标系下的光线方向
        self.H              标量,图像高度
        self.W              标量,图像宽度
        self.N_imgs         标量,图像数量
    """
    def __init__(self, base_dir, scene_name, data_type, res_ratio, num_img_to_load, skip, use_ndc, load_img=True):
        """
        :param base_dir: 数据的基础目录
        :param scene_name: 场景的名称
        :param data_type: 数据类型,'train' 或 'val'。
        :param res_ratio: 整数,如 [1, 2, 4] 等,用于将图像调整为较低的分辨率。
        :param num_img_to_load/skip: 用于在时间域上控制帧的加载。
        :param use_ndc: 布尔值,是否对相机位姿进行中心化和缩放。
        :param load_img: 布尔值。如果设置为 False:仅统计图像数量、获取图像的高度和宽度,
                         但不加载图像。在可视化位姿或调试等情况下很有用。
        """
        self.base_dir = base_dir
        self.scene_name = scene_name
        self.data_type = data_type
        self.res_ratio = res_ratio
        self.num_img_to_load = num_img_to_load
        self.skip = skip
        self.use_ndc = use_ndc
        self.load_img = load_img

        # 构建场景目录的完整路径
        self.scene_dir = os.path.join(self.base_dir, self.scene_name)
        # 构建图像目录的完整路径
        self.img_dir = os.path.join(self.scene_dir, 'images')

        # 读取所有的元信息,包括相机位姿、深度边界、图像尺寸和焦距等
        meta = read_meta(self.scene_dir, self.use_ndc)
        # 提取相机位姿信息
        self.c2ws = meta['c2ws']  # (N, 4, 4) 格式的 numpy 数组,表示所有相机位姿
        # 提取图像高度信息
        self.H = meta['H']
        # 提取图像宽度信息
        self.W = meta['W']
        # 提取焦距信息
        self.focal = float(meta['focal'])

        if self.res_ratio > 1:
            # 如果需要调整图像分辨率,对图像高度进行相应的缩放
            self.H = self.H // self.res_ratio
            # 如果需要调整图像分辨率,对图像宽度进行相应的缩放
            self.W = self.W // self.res_ratio
            # 如果需要调整图像分辨率,对焦距进行相应的缩放
            self.focal /= self.res_ratio

        # 近裁剪平面距离
        self.near = 0.0
        # 远裁剪平面距离
        self.far = 1.0
        # 加载图像并调整到指定的高度和宽度
        self.imgs, self.img_names = load_imgs(self.img_dir, np.arange(num_img_to_load), self.H, self.W)  # (N, H, W, 3) 格式的 torch.float32 类型张量
        # 截取前 num_img_to_load 个相机位姿
        self.c2ws = self.c2ws[:num_img_to_load]
        # 图像的数量
        self.N_imgs = self.c2ws.shape[0]

        # 生成相机坐标系下的光线方向
        self.ray_dir_cam = comp_ray_dir_cam(self.H, self.W, self.focal)  # (H, W, 3) 格式的 torch.float32 类型张量

        # 将相机位姿从 numpy 数组转换为 PyTorch 张量
        self.c2ws = torch.from_numpy(self.c2ws).float()  # (N, 4, 4) 格式的 torch.float32 类型张量
        # 将光线方向张量转换为 float32 类型
        self.ray_dir_cam = self.ray_dir_cam.float()  # (H, W, 3) 格式的 torch.float32 类型张量


if __name__ == '__main__':
    scene_name = 'LLFF/fern'
    use_ndc = True
    # 注意:需要将 /your/data/path 替换为实际的数据路径,
    # 这里创建了一个 DataLoaderWithCOLMAP 类的实例,用于加载指定场景的数据
    scene = DataLoaderWithCOLMAP(base_dir='/your/data/path',
                                 scene_name=scene_name,
                                 data_type='train',
                                 res_ratio=8,
                                 num_img_to_load=-1,
                                 skip=1,
                                 use_ndc=use_ndc)

with_mask.py

他们的mask其实是掩码文件,有没有可能只基于掩码文件去做呢?

import os
# os 模块提供了与操作系统进行交互的功能,
# 可用于处理文件和目录路径、创建/删除目录、获取环境变量等,
# 在本代码里主要用于构建文件和目录的路径。

import torch
# torch 是 PyTorch 的核心库,提供了张量(Tensor)数据结构,
# 支持自动求导机制,用于构建和训练深度学习模型,
# 能够在 CPU 或 GPU 上高效地进行数值计算。

import numpy as np
# numpy 是 Python 中用于科学计算的基础库,
# 提供了高效的多维数组对象和各种数学函数,
# 可进行数组操作、线性代数运算、随机数生成等,
# 在代码中主要用于处理图像数据和数组操作。

from tqdm import tqdm
# tqdm 是一个快速、可扩展的进度条工具,
# 能在循环中显示进度条,方便用户了解代码执行的进度。

import imageio
# imageio 是一个用于读取和写入多种图像文件格式的库,
# 在代码中主要用于读取图像文件。

from dataloader.with_colmap import resize_imgs
# 从 dataloader.with_colmap 模块导入 resize_imgs 函数,
# 该函数用于调整图像的尺寸。


def load_imgs(image_dir, mask_dir, num_img_to_load, start, end, skip, load_sorted, load_img):
    # 获取图像目录下所有图像文件名,并按字母顺序排序
    img_names = np.array(sorted(os.listdir(image_dir)))  # all image names

    # 在时间域上对帧进行下采样
    if end == -1 and len(os.listdir(mask_dir)) == len(img_names):
        # 若 end 为 -1 且掩码目录和图像目录文件数量相同,则按 skip 间隔选取
        img_names = img_names[start::skip]
    else:
        # 取 end 和掩码目录文件数量的最小值,避免越界
        end = min(end, len(os.listdir(mask_dir)))
        img_names = img_names[start:end:skip]

    # 如果不按顺序加载图像,则对图像文件名进行随机打乱
    if not load_sorted:
        np.random.shuffle(img_names)

    # 加载下采样后的图像
    if num_img_to_load > len(img_names):
        print('Asked for {0:6d} images but only {1:6d} available. Exit.'.format(num_img_to_load, len(img_names)))
        exit()
    elif num_img_to_load == -1:
        print('Loading all available {0:6d} images'.format(len(img_names)))
    else:
        print('Loading {0:6d} images out of {1:6d} images.'.format(num_img_to_load, len(img_names)))
        img_names = img_names[:num_img_to_load]

    # 构建每个图像的完整路径
    img_paths = [os.path.join(image_dir, n) for n in img_names]
    # 构建每个掩码图像的完整路径,假设掩码图像为 png 格式,且文件名和图像文件名对应
    mask_paths = [os.path.join(mask_dir, n[:-4]+'.png') for n in img_names]
    # 图像的数量
    N_imgs = len(img_paths)

    img_list, mask_list = [], []
    if load_img:
        # 使用 tqdm 显示加载图像的进度
        for i, p in tqdm(enumerate(img_paths)):
            # 读取图像并只保留前三个通道(RGB)
            img = imageio.imread(p)[:, :, :3]  # (H, W, 3) np.uint8
            img_list.append(img)
            # 读取对应的掩码图像,只取第一个通道
            img = imageio.imread(mask_paths[i])[:, :, [0]]  # (H, W, 1)
            mask_list.append(img)
        # 将图像列表转换为 numpy 数组
        img_list = np.stack(img_list)  # (N, H, W, 3)
        # 将掩码列表转换为 numpy 数组
        mask_list = np.stack(mask_list)
        # 将 numpy 数组转换为 PyTorch 张量,并将像素值归一化到 [0, 1] 范围
        img_list = torch.from_numpy(img_list).float() / 255  # (N, H, W, 3) torch.float32
        mask_list = torch.from_numpy(mask_list).float() / 255
        # 获取图像的高度和宽度
        H, W = img_list.shape[1], img_list.shape[2]
    else:
        # 如果不加载图像,则读取第一张图像以获取图像的高度和宽度
        tmp_img = imageio.imread(img_paths[0])  # load one image to get H, W
        H, W = tmp_img.shape[0], tmp_img.shape[1]

    results = {
        'imgs': img_list,  # (N, H, W, 3) torch.float32
        'img_names': img_names,  # (N, )
        'masks': mask_list,  # 掩码图像张量
        'N_imgs': N_imgs,
        'H': H,
        'W': W,
    }

    return results


class DataLoaderAnyFolder:
    """
    Most useful fields:
        self.c2ws:          (N_imgs, 4, 4)      torch.float32
        self.imgs           (N_imgs, H, W, 4)   torch.float32
        self.ray_dir_cam    (H, W, 3)           torch.float32
        self.H              scalar
        self.W              scalar
        self.N_imgs         scalar
    """
    def __init__(self, base_dir, scene_name, res_ratio, num_img_to_load, start, end, skip, load_sorted, load_img=True):
        """
        :param base_dir: 数据的基础目录
        :param scene_name: 场景的名称
        :param res_ratio: 整数,如 [1, 2, 4] 等,用于将图像调整为较低的分辨率。
        :param start/end/skip: 用于在时间域上控制帧的加载。
        :param load_sorted: 布尔值,是否按顺序加载图像。
        :param load_img: 布尔值。如果设置为 false:仅统计图像数量、获取图像的高度和宽度,
                         但不加载图像。在可视化位姿或调试等情况下很有用。
        """
        self.base_dir = base_dir
        self.scene_name = scene_name
        self.res_ratio = res_ratio
        self.num_img_to_load = num_img_to_load
        self.start = start
        self.end = end
        self.skip = skip
        self.load_sorted = load_sorted
        self.load_img = load_img

        # 构建图像目录的完整路径
        self.imgs_dir = os.path.join(self.base_dir, self.scene_name)
        # 构建掩码目录的完整路径,假设掩码目录在图像目录的上一级的 mask 文件夹下
        self.mask_dir = os.path.join(self.imgs_dir, '../mask/')

        # 调用 load_imgs 函数加载图像和掩码数据
        image_data = load_imgs(self.imgs_dir, self.mask_dir, self.num_img_to_load, self.start, self.end, self.skip,
                               self.load_sorted, self.load_img)

        # 提取加载的图像数据
        self.imgs = image_data['imgs']  # (N, H, W, 3) torch.float32
        # 提取图像文件名
        self.img_names = image_data['img_names']  # (N, )
        # 提取掩码数据
        self.masks = image_data['masks']
        # 图像的数量
        self.N_imgs = image_data['N_imgs']
        # 原始图像的高度
        self.ori_H = image_data['H']
        # 原始图像的宽度
        self.ori_W = image_data['W']

        # 始终使用归一化设备坐标(NDC),设置近裁剪平面距离
        self.near = 0.0
        # 始终使用归一化设备坐标(NDC),设置远裁剪平面距离
        self.far = 1.0

        # 如果需要调整图像分辨率
        if self.res_ratio > 1:
            # 计算调整后的图像高度
            self.H = self.ori_H // self.res_ratio
            # 计算调整后的图像宽度
            self.W = self.ori_W // self.res_ratio
        else:
            self.H = self.ori_H
            self.W = self.ori_W

        if self.load_img:
            # 调整图像的分辨率
            self.imgs = resize_imgs(self.imgs, self.H, self.W)  # (N, H, W, 3) torch.float32
            # 调整掩码图像的分辨率
            self.masks = resize_imgs(self.masks, self.H, self.W)


if __name__ == '__main__':
    # 数据的基础目录,需要替换为实际的路径
    base_dir = '/your/data/path'
    # 场景的名称
    scene_name = 'LLFF/fern/images'
    # 图像的缩放比例
    resize_ratio = 8
    # 要加载的图像数量,-1 表示加载所有图像
    num_img_to_load = -1
    # 开始加载图像的索引
    start = 0
    # 结束加载图像的索引,-1 表示加载到最后
    end = -1
    # 加载图像的间隔
    skip = 1
    # 是否按顺序加载图像
    load_sorted = True
    # 是否加载图像
    load_img = True

    # 初始化 DataLoaderAnyFolder 类
    scene = DataLoaderAnyFolder(base_dir=base_dir,
                                scene_name=scene_name,
                                res_ratio=resize_ratio,
                                num_img_to_load=num_img_to_load,
                                start=start,
                                end=end,
                                skip=skip,
                                load_sorted=load_sorted,
                                load_img=load_img)

models

├── models/  # 模型文件夹
│   ├── depth_decoder.py  # 深度解码器脚本文件
│   ├── intrinsics.py  # 内参相关脚本文件
│   ├── layers.py  # 层相关脚本文件
│   ├── nerf_feature.py  # NeRF特征相关脚本文件
│   ├── nerf_mask.py  # NeRF掩码相关脚本文件
│   ├── nerf_models.py  # NeRF模型相关脚本文件
│   └── poses.py  # 位姿相关脚本文件

depth_decoder.py

# 版权所有 Niantic 2019。专利申请中。保留所有权利。
#
# 本软件遵循 Monodepth2 许可证的条款,
# 该许可证仅允许非商业用途,完整条款可在 LICENSE 文件中获取。

from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.layers import *

# 定义一个卷积层,输入通道数为 in_planes,输出通道数为 out_planes,卷积核大小为 kernel_size
# 如果 instancenorm 为 True,则使用实例归一化;否则使用批量归一化
def conv(in_planes, out_planes, kernel_size, instancenorm=False):
    if instancenorm:
        # 构建一个包含卷积层、实例归一化层和 LeakyReLU 激活函数的序列
        m = nn.Sequential(
            # 卷积层,使用指定的输入和输出通道数、卷积核大小,步长为 1,填充为 (kernel_size - 1) // 2,无偏置
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                      stride=1, padding=(kernel_size - 1) // 2, bias=False),
            # 实例归一化层
            nn.InstanceNorm2d(out_planes),
            # LeakyReLU 激活函数,负斜率为 0.1,原地操作
            nn.LeakyReLU(0.1, inplace=True),
        )
    else:
        # 构建一个包含卷积层、批量归一化层和 LeakyReLU 激活函数的序列
        m = nn.Sequential(
            # 卷积层,使用指定的输入和输出通道数、卷积核大小,步长为 1,填充为 (kernel_size - 1) // 2,无偏置
            nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                      stride=1, padding=(kernel_size - 1) // 2, bias=False),
            # 批量归一化层
            nn.BatchNorm2d(out_planes),
            # LeakyReLU 激活函数,负斜率为 0.1,原地操作
            nn.LeakyReLU(0.1, inplace=True)
        )
    return m

# 深度解码器类,继承自 nn.Module
class DepthDecoder(nn.Module):
    # 将元组转换为字符串,用于作为字典的键
    def tuple_to_str(self, key_tuple):
        key_str = '-'.join(str(key_tuple))
        return key_str

    def __init__(self, num_ch_enc, embedder, embedder_out_dim,
                 use_alpha=False, scales=range(4), num_output_channels=4,
                 use_skips=True, sigma_dropout_rate=0.0, **kwargs):
        # 调用父类的构造函数
        super(DepthDecoder, self).__init__()

        # 输出通道数
        self.num_output_channels = num_output_channels
        # 是否使用跳跃连接
        self.use_skips = use_skips
        # 上采样模式
        self.upsample_mode = 'nearest'
        # 要处理的尺度
        self.scales = scales
        # 是否使用 alpha
        self.use_alpha = use_alpha
        # sigma 的丢弃率
        self.sigma_dropout_rate = sigma_dropout_rate

        # 嵌入器
        self.embedder = embedder
        # 嵌入器的输出维度
        self.E = embedder_out_dim

        # 编码器最后一层的输出通道数
        final_enc_out_channels = num_ch_enc[-1]
        # 最大池化层,用于下采样
        self.downsample = nn.MaxPool2d(3, stride=2, padding=1)
        # 最近邻上采样层,用于上采样
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
        # 第一个下采样卷积层
        self.conv_down1 = conv(final_enc_out_channels, 512, 1, False)
        # 第二个下采样卷积层
        self.conv_down2 = conv(512, 256, 3, False)
        # 第一个上采样卷积层
        self.conv_up1 = conv(256, 256, 3, False)
        # 第二个上采样卷积层
        self.conv_up2 = conv(256, final_enc_out_channels, 1, False)

        # 编码器各层的通道数
        self.num_ch_enc = num_ch_enc
        print("num_ch_enc=", num_ch_enc)
        # 将编码器各层的通道数加上嵌入器的输出维度
        self.num_ch_enc = [x + self.E for x in self.num_ch_enc]
        # 解码器各层的通道数
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])
        # self.num_ch_enc = np.array([64, 64, 128, 256, 512])

        # 解码器的卷积层,使用 nn.ModuleDict 存储
        self.convs = nn.ModuleDict()
        # 从 4 到 0 遍历
        for i in range(4, -1, -1):
            # 上卷积层 0
            # 如果 i 为 4,则输入通道数为编码器最后一层的通道数;否则为解码器上一层的通道数
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            # 输出通道数为解码器当前层的通道数
            num_ch_out = self.num_ch_dec[i]
            # 创建卷积块并添加到 convs 字典中
            self.convs[self.tuple_to_str(("upconv", i, 0))] = ConvBlock(num_ch_in, num_ch_out)
            print("upconv_{}_{}".format(i, 0), num_ch_in, num_ch_out)

            # 上卷积层 1
            # 输入通道数为解码器当前层的通道数
            num_ch_in = self.num_ch_dec[i]
            # 如果使用跳跃连接且 i 大于 0,则输入通道数加上编码器上一层的通道数
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            # 输出通道数为解码器当前层的通道数
            num_ch_out = self.num_ch_dec[i]
            # 创建卷积块并添加到 convs 字典中
            self.convs[self.tuple_to_str(("upconv", i, 1))] = ConvBlock(num_ch_in, num_ch_out)
            print("upconv_{}_{}".format(i, 1), num_ch_in, num_ch_out)

        # 遍历要处理的尺度
        for s in self.scales:
            # 创建一个 3x3 的卷积层并添加到 convs 字典中
            self.convs[self.tuple_to_str(("dispconv", s))] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        # Sigmoid 激活函数,用于将输出映射到 [0, 1] 范围
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features, disparity):
        # 获取输入视差的批次大小和序列长度
        B, S = disparity.size()
        # 对输入视差进行嵌入操作,然后增加两个维度
        disparity = self.embedder(disparity.reshape(B * S, 1)).unsqueeze(2).unsqueeze(3)
---------------------------------------------------------------------------------------------------------------------
tensor_1d 是一个一维张量,直接使用 torch.tensor 创建,其形状为 (3,)。
tensor_2d 是通过对 tensor_1d 使用 unsqueeze(1) 增加一个维度得到的,形状为 (3, 1)import torch
import matplotlib.pyplot as plt

# 创建形状为 (3,) 的一维张量
tensor_1d = torch.tensor([1, 2, 3])

print("tensor_1d 是否定义:", 'tensor_1d' in locals())
print("tensor_1d 的值:", tensor_1d)

# 绘制一维张量
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(len(tensor_1d)), tensor_1d, marker='o')
plt.title('Shape (3,) Tensor')
plt.xlabel('Index')
plt.ylabel('Value')

# 创建形状为 (3, 1) 的二维张量
tensor_2d = tensor_1d.unsqueeze(1)

# 绘制二维张量
plt.subplot(1, 2, 2)
plt.bar(range(len(tensor_2d)), tensor_2d.flatten(), width=0.5)
plt.title('Shape (3, 1) Tensor')
plt.xlabel('Index')
plt.ylabel('Value')

plt.tight_layout()
plt.show()
---------------------------------------------------------------------------------------------------------------------
        # 扩展编码器的输出以增加感受野
        # 获取编码器最后一层的输出
        encoder_out = input_features[-1]
        # 对编码器输出进行下采样,然后通过第一个下采样卷积层
        conv_down1 = self.conv_down1(self.downsample(encoder_out))
        # 对第一个下采样卷积层的输出进行下采样,然后通过第二个下采样卷积层
        conv_down2 = self.conv_down2(self.downsample(conv_down1))
        # 对第二个下采样卷积层的输出进行上采样,然后通过第一个上采样卷积层
        conv_up1 = self.conv_up1(self.upsample(conv_down2))
        # 对第一个上采样卷积层的输出进行上采样,然后通过第二个上采样卷积层
        conv_up2 = self.conv_up2(self.upsample(conv_up1))

        # 重复 / 重塑特征
        # 获取第二个上采样卷积层输出的通道数、高度和宽度
        _, C_feat, H_feat, W_feat = conv_up2.size()
        # 对第二个上采样卷积层的输出进行扩展和重塑
        feat_tmp = conv_up2.unsqueeze(1).expand(B, S, C_feat, H_feat, W_feat) \
            .contiguous().view(B * S, C_feat, H_feat, W_feat)
        # 对视差进行重复操作以匹配特征图的大小
        disparity_BsCHW = disparity.repeat(1, 1, H_feat, W_feat)
        # 将扩展后的特征和视差拼接在一起
        conv_up2 = torch.cat((feat_tmp, disparity_BsCHW), dim=1)

        # 重复 / 重塑输入特征
        for i, feat in enumerate(input_features):
            # 获取输入特征的通道数、高度和宽度
            _, C_feat, H_feat, W_feat = feat.size()
            # 对输入特征进行扩展和重塑
            feat_tmp = feat.unsqueeze(1).expand(B, S, C_feat, H_feat, W_feat) \
                .contiguous().view(B * S, C_feat, H_feat, W_feat)
            # 对视差进行重复操作以匹配特征图的大小
            disparity_BsCHW = disparity.repeat(1, 1, H_feat, W_feat)
            # 将扩展后的特征和视差拼接在一起
            input_features[i] = torch.cat((feat_tmp, disparity_BsCHW), dim=1)

        # 解码器部分
        # 存储输出结果的字典
        outputs = {}
        # 初始输入为扩展后的第二个上采样卷积层的输出
        x = conv_up2
        # 从 4 到 0 遍历
        for i in range(4, -1, -1):
            # 通过上卷积层 0
            x = self.convs[self.tuple_to_str(("upconv", i, 0))](x)
            # 进行上采样
            x = [upsample(x)]
            # 如果使用跳跃连接且 i 大于 0
            if self.use_skips and i > 0:
                # 将编码器上一层的特征添加到列表中
                x += [input_features[i - 1]]
            # 将列表中的特征拼接在一起
            x = torch.cat(x, 1)
            # 通过上卷积层 1
            x = self.convs[self.tuple_to_str(("upconv", i, 1))](x)
            # 如果当前尺度在要处理的尺度列表中
            if i in self.scales:
                # 通过视差卷积层得到输出
                output = self.convs[self.tuple_to_str(("dispconv", i))](x)
                # 获取输出的高度和宽度
                H_mpi, W_mpi = output.size(2), output.size(3)
                # 调整输出的维度
                mpi = output.view(B, S, 4, H_mpi, W_mpi)
                # 对 RGB 通道应用 Sigmoid 激活函数
                mpi_rgb = self.sigmoid(mpi[:, :, 0:3, :, :])
                # 如果不使用 alpha,则取绝对值并加上一个小的常数;否则应用 Sigmoid 激活函数
                mpi_sigma = torch.abs(mpi[:, :, 3:, :, :]) + 1e-4 \
                        if not self.use_alpha \
                        else self.sigmoid(mpi[:, :, 3:, :, :])

                # 如果 sigma 丢弃率大于 0 且处于训练模式
                if self.sigma_dropout_rate > 0.0 and self.training:
                    # 对 sigma 通道应用 2D Dropout
                    mpi_sigma = F.dropout2d(mpi_sigma, p=self.sigma_dropout_rate)

                # 将 RGB 和 sigma 通道拼接在一起,并存储到输出字典中
                outputs[("disp", i)] = torch.cat((mpi_rgb, mpi_sigma), dim=2)

        return outputs

intrinsics.py

import torch
import torch.nn as nn
import numpy as np

# 定义一个用于学习焦距的神经网络模块
class LearnFocal(nn.Module):
    def __init__(self, H, W, req_grad, fx_only, order=2, init_focal=None, learn_distortion=False):
        # 调用父类 nn.Module 的构造函数
        super(LearnFocal, self).__init__()
        # 图像的高度
        self.H = H
        # 图像的宽度
        self.W = W
        # 一个布尔值,如果为 True,只输出 [fx, fx];如果为 False,输出 [fx, fy]
        self.fx_only = fx_only  
        # 焦距初始化的阶数,检查我们的补充部分有相关说明
        self.order = order  

        # 畸变相关
        # 是否学习畸变参数
        self.learn_distortion = learn_distortion
        if learn_distortion:
            # 第一个畸变系数,可根据 req_grad 设置是否需要梯度
            self.k1 = nn.Parameter(torch.tensor(0.0, dtype=torch.float32), requires_grad=req_grad)
            # 第二个畸变系数,可根据 req_grad 设置是否需要梯度
            self.k2 = nn.Parameter(torch.tensor(0.0, dtype=torch.float32), requires_grad=req_grad)

        if self.fx_only:
            if init_focal is None:
                # 如果没有提供初始焦距,将 fx 初始化为 1.0,可根据 req_grad 设置是否需要梯度
                self.fx = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad)  # (1, )
            else:
                if self.order == 2:
                    # 根据公式 a**2 * W = fx 计算系数 a,即 a**2 = fx / W
                    coe_x = torch.tensor(np.sqrt(init_focal / float(W)), requires_grad=False).float()
                elif self.order == 1:
                    # 根据公式 a * W = fx 计算系数 a,即 a = fx / W
                    coe_x = torch.tensor(init_focal / float(W), requires_grad=False).float()
                else:
                    print('焦距初始化阶数需要为 1 或 2。退出')
                    exit()
                # 将计算得到的系数作为 fx,可根据 req_grad 设置是否需要梯度
                self.fx = nn.Parameter(coe_x, requires_grad=req_grad)  # (1, )
        else:
            if init_focal is None:
                # 如果没有提供初始焦距,将 fx 初始化为 1.0,可根据 req_grad 设置是否需要梯度
                self.fx = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad)  # (1, )
                # 如果没有提供初始焦距,将 fy 初始化为 1.0,可根据 req_grad 设置是否需要梯度
                self.fy = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=req_grad)  # (1, )
            else:
                if self.order == 2:
                    # 根据公式 a**2 * W = fx 计算 x 方向的系数 a,即 a**2 = fx / W
                    coe_x = torch.tensor(np.sqrt(init_focal / float(W)), requires_grad=False).float()
                    # 根据公式 a**2 * H = fy 计算 y 方向的系数 a,即 a**2 = fy / H
                    coe_y = torch.tensor(np.sqrt(init_focal / float(H)), requires_grad=False).float()
                elif self.order == 1:
                    # 根据公式 a * W = fx 计算 x 方向的系数 a,即 a = fx / W
                    coe_x = torch.tensor(init_focal / float(W), requires_grad=False).float()
                    # 根据公式 a * H = fy 计算 y 方向的系数 a,即 a = fy / H
                    coe_y = torch.tensor(init_focal / float(H), requires_grad=False).float()
                else:
                    print('焦距初始化阶数需要为 1 或 2。退出')
                    exit()
                # 将计算得到的 x 方向系数作为 fx,可根据 req_grad 设置是否需要梯度
                self.fx = nn.Parameter(coe_x, requires_grad=req_grad)  # (1, )
                # 将计算得到的 y 方向系数作为 fy,可根据 req_grad 设置是否需要梯度
                self.fy = nn.Parameter(coe_y, requires_grad=req_grad)  # (1, )

    def forward(self, i=None):  # 参数 i=None 只是为了支持多 GPU 训练
        if self.fx_only:
            if self.order == 2:
                # 根据公式计算 fx 和 fy,因为 fx_only 为 True,所以 fy 等于 fx
                fxfy = torch.stack([self.fx ** 2 * self.W, self.fx ** 2 * self.W])
            else:
                # 根据公式计算 fx 和 fy,因为 fx_only 为 True,所以 fy 等于 fx
                fxfy = torch.stack([self.fx * self.W, self.fx * self.W])
        else:
            if self.order == 2:
                # 根据公式计算 fx 和 fy
                fxfy = torch.stack([self.fx**2 * self.W, self.fy**2 * self.H])
            else:
                # 根据公式计算 fx 和 fy
                fxfy = torch.stack([self.fx * self.W, self.fy * self.H])
        if self.learn_distortion:
            # 如果要学习畸变参数,返回焦距和畸变系数
            return fxfy, self.k1, self.k2
        else:
            # 否则只返回焦距
            return fxfy

layers.py

# 版权所有 Niantic 2019。专利申请中。保留所有权利。
#
# 本软件遵循 Monodepth2 许可证的条款,
# 该许可证仅允许非商业用途,完整条款可在 LICENSE 文件中获取。

from __future__ import absolute_import, division, print_function

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# 将网络的 sigmoid 输出转换为深度预测
# 此转换公式在论文的“额外考虑”部分给出
def disp_to_depth(disp, min_depth, max_depth):
    """
    将网络的 sigmoid 输出转换为深度预测
    该转换公式在论文的“额外考虑”部分给出。
    """
    # 最小视差,为最大深度的倒数
    min_disp = 1 / max_depth
    # 最大视差,为最小深度的倒数
    max_disp = 1 / min_depth
    # 缩放后的视差
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    # 深度值,为缩放后视差的倒数
    depth = 1 / scaled_disp
    return scaled_disp, depth

# 将网络输出的 (轴角, 平移) 转换为 4x4 矩阵
def transformation_from_parameters(axisangle, translation, invert=False):
    """将网络的 (轴角, 平移) 输出转换为 4x4 矩阵
    """
    # 从轴角表示转换为旋转矩阵
    R = rot_from_axisangle(axisangle)
    # 克隆平移向量
    t = translation.clone()

    if invert:
        # 如果需要反转,对旋转矩阵进行转置
        R = R.transpose(1, 2)
        # 平移向量取负
        t *= -1

    # 将平移向量转换为 4x4 变换矩阵
    T = get_translation_matrix(t)

    if invert:
        # 如果需要反转,先旋转再平移
        M = torch.matmul(R, T)
    else:
        # 正常情况下,先平移再旋转
        M = torch.matmul(T, R)

    return M

# 将平移向量转换为 4x4 变换矩阵
def get_translation_matrix(translation_vector):
    """将平移向量转换为 4x4 变换矩阵
    """
    # 初始化一个全零的 4x4 矩阵,形状为 (batch_size, 4, 4)
    T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)

    # 将平移向量调整为 (batch_size, 3, 1) 的形状
    t = translation_vector.contiguous().view(-1, 3, 1)

    # 设置矩阵的对角元素为 1
    T[:, 0, 0] = 1
    T[:, 1, 1] = 1
    T[:, 2, 2] = 1
    T[:, 3, 3] = 1
    # 将平移向量添加到矩阵的最后一列
    T[:, :3, 3, None] = t

    return T

# 将轴角旋转表示转换为 4x4 变换矩阵
# (改编自 https://github.com/Wallacoloo/printipi)
# 输入 'vec' 必须是 Bx1x3 的形状
def rot_from_axisangle(vec):
    """将轴角旋转表示转换为 4x4 变换矩阵
    (改编自 https://github.com/Wallacoloo/printipi)
    输入 'vec' 必须是 Bx1x3 的形状
    """
    # 计算轴角的模长
    angle = torch.norm(vec, 2, 2, True)
    # 计算单位轴向量
    axis = vec / (angle + 1e-7)

    # 计算角度的余弦值
    ca = torch.cos(angle)
    # 计算角度的正弦值
    sa = torch.sin(angle)
    # 计算 1 - cos(angle)
    C = 1 - ca

    # 提取轴向量的 x 分量,并增加一个维度
    x = axis[..., 0].unsqueeze(1)
    # 提取轴向量的 y 分量,并增加一个维度
    y = axis[..., 1].unsqueeze(1)
    # 提取轴向量的 z 分量,并增加一个维度
    z = axis[..., 2].unsqueeze(1)

    # 计算 x * sin(angle)
    xs = x * sa
    # 计算 y * sin(angle)
    ys = y * sa
    # 计算 z * sin(angle)
    zs = z * sa
    # 计算 x * (1 - cos(angle))
    xC = x * C
    # 计算 y * (1 - cos(angle))
    yC = y * C
    # 计算 z * (1 - cos(angle))
    zC = z * C
    # 计算 x * y * (1 - cos(angle))
    xyC = x * yC
    # 计算 y * z * (1 - cos(angle))
    yzC = y * zC
    # 计算 z * x * (1 - cos(angle))
    zxC = z * xC

    # 初始化一个全零的 4x4 旋转矩阵,形状为 (batch_size, 4, 4)
    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)

    # 设置旋转矩阵的元素
    rot[:, 0, 0] = torch.squeeze(x * xC + ca)
    rot[:, 0, 1] = torch.squeeze(xyC - zs)
    rot[:, 0, 2] = torch.squeeze(zxC + ys)
    rot[:, 1, 0] = torch.squeeze(xyC + zs)
    rot[:, 1, 1] = torch.squeeze(y * yC + ca)
    rot[:, 1, 2] = torch.squeeze(yzC - xs)
    rot[:, 2, 0] = torch.squeeze(zxC - ys)
    rot[:, 2, 1] = torch.squeeze(yzC + xs)
    rot[:, 2, 2] = torch.squeeze(z * zC + ca)
    rot[:, 3, 3] = 1

    return rot

# 定义一个卷积块,包含卷积层、批量归一化层和 ELU 激活函数
class ConvBlock(nn.Module):
    """执行卷积后接 ELU 的层
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        # 3x3 卷积层
        self.conv = Conv3x3(in_channels, out_channels)
        # ELU 激活函数,原地操作
        self.nonlin = nn.ELU(inplace=True)
        # 批量归一化层
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # 通过卷积层
        out = self.conv(x)
        # 通过批量归一化层
        out = self.bn(out)
        # 通过 ELU 激活函数
        out = self.nonlin(out)
        return out

# 定义一个 3x3 卷积层,包含填充操作
class Conv3x3(nn.Module):
    """对输入进行填充和卷积的层
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        if use_refl:
            # 使用反射填充
            self.pad = nn.ReflectionPad2d(1)
        else:
            # 使用零填充
            self.pad = nn.ZeroPad2d(1)
        # 3x3 卷积层
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        # 进行填充操作
        out = self.pad(x)
        # 进行卷积操作
        out = self.conv(out)
        return out

# 将深度图像转换为点云的层
class BackprojectDepth(nn.Module):
    """将深度图像转换为点云的层
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        # 批量大小
        self.batch_size = batch_size
        # 图像高度
        self.height = height
        # 图像宽度
        self.width = width

        # 生成二维网格坐标
        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        # 将网格坐标堆叠在一起,并转换为 float32 类型
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
        # 将网格坐标转换为 PyTorch 张量,并设置为不需要梯度
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)

        # 初始化一个全为 1 的张量,形状为 (batch_size, 1, height * width),并设置为不需要梯度
        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        # 调整网格坐标的形状,并重复 batch_size 次
        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
        # 将网格坐标和全 1 张量拼接在一起,并设置为不需要梯度
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        # 将逆相机内参矩阵与像素坐标相乘
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        # 将深度值与相机坐标相乘
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
        # 将相机坐标和全 1 张量拼接在一起
        cam_points = torch.cat([cam_points, self.ones], 1)

        return cam_points

# 将 3D 点投影到具有内参 K 和位置 T 的相机中的层
class Project3D(nn.Module):
    """将 3D 点投影到具有内参 K 和位置 T 的相机中的层
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        # 批量大小
        self.batch_size = batch_size
        # 图像高度
        self.height = height
        # 图像宽度
        self.width = width
        # 防止除零的小常数
        self.eps = eps

    def forward(self, points, K, T):
        # 计算投影矩阵 P
        P = torch.matmul(K, T)[:, :3, :]

        # 将投影矩阵 P 与 3D 点相乘
        cam_points = torch.matmul(P, points)

        # 计算像素坐标
        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
        # 调整像素坐标的形状
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        # 交换维度
        pix_coords = pix_coords.permute(0, 2, 3, 1)
        # 归一化像素坐标
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        # 将像素坐标映射到 [-1, 1] 范围
        pix_coords = (pix_coords - 0.5) * 2
        return pix_coords

# 将输入张量上采样 2 倍
def upsample(x):
    """将输入张量上采样 2 倍
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")

# 计算视差图像的平滑损失
# 彩色图像用于边缘感知平滑
def get_smooth_loss(disp, img):
    """计算视差图像的平滑损失
    彩色图像用于边缘感知平滑
    """
    # 计算视差在 x 方向的梯度
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    # 计算视差在 y 方向的梯度
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    # 计算图像在 x 方向的平均梯度
    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    # 计算图像在 y 方向的平均梯度
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

    # 根据图像梯度对视差梯度进行加权
    grad_disp_x *= torch.exp(-grad_img_x)
    grad_disp_y *= torch.exp(-grad_img_y)

    # 返回视差梯度的平均值
    return grad_disp_x.mean() + grad_disp_y.mean()

# 计算一对图像之间 SSIM 损失的层
class SSIM(nn.Module):
    """计算一对图像之间 SSIM 损失的层
    """
    def __init__(self):
        super(SSIM, self).__init__()
        # 3x3 平均池化层,用于计算均值
        self.mu_x_pool   = nn.AvgPool2d(3, 1)
        self.mu_y_pool   = nn.AvgPool2d(3, 1)
        # 3x3 平均池化层,用于计算方差
        self.sig_x_pool  = nn.AvgPool2d(3, 1)
        self.sig_y_pool  = nn.AvgPool2d(3, 1)
        # 3x3 平均池化层,用于计算协方差
        self.sig_xy_pool = nn.AvgPool2d(3, 1)

        # 反射填充层
        self.refl = nn.ReflectionPad2d(1)

        # 常数 C1
        self.C1 = 0.01 ** 2
        # 常数 C2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        # 对输入图像进行反射填充
        x = self.refl(x)
        y = self.refl(y)

        # 计算图像 x 的均值
        mu_x = self.mu_x_pool(x)
        # 计算图像 y 的均值
        mu_y = self.mu_y_pool(y)

        # 计算图像 x 的方差
        sigma_x  = self.sig_x_pool(x ** 2) - mu_x ** 2
        # 计算图像 y 的方差
        sigma_y  = self.sig_y_pool(y ** 2) - mu_y ** 2
        # 计算图像 x 和 y 的协方差
        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        # 计算 SSIM 分子
        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
        # 计算 SSIM 分母
        SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)

        # 计算 SSIM 损失,并将结果限制在 [0, 1] 范围内
        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)

# 计算预测深度和真实深度之间的误差指标
def compute_depth_errors(gt, pred):
    """计算预测深度和真实深度之间的误差指标
    """
    # 计算预测深度和真实深度的比值的最大值
    thresh = torch.max((gt / pred), (pred / gt))
    # 计算阈值小于 1.25 的比例
    a1 = (thresh < 1.25     ).float().mean()
    # 计算阈值小于 1.25^2 的比例
    a2 = (thresh < 1.25 ** 2).float().mean()
    # 计算阈值小于 1.25^3 的比例
    a3 = (thresh < 1.25 ** 3).float().mean()

    # 计算均方误差
    rmse = (gt - pred) ** 2
    rmse = torch.sqrt(rmse.mean())

    # 计算对数均方误差
    rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
    rmse_log = torch.sqrt(rmse_log.mean())

    # 计算绝对相对误差
    abs_rel = torch.mean(torch.abs(gt - pred) / gt)

    # 计算平方相对误差
    sq_rel = torch.mean((gt - pred) ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a

nerf_feature.py

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data

# 定义一个名为 NerfWFeatures 的神经网络模块
class NerfWFeatures(nn.Module):
    def __init__(self, pos_in_dims, dir_in_dims, D):
        """
        :param pos_in_dims: 标量,编码后位置的通道数
        :param dir_in_dims: 标量,编码后方向的通道数
        :param D:           标量,隐藏层的维度数
        """
        # 调用父类 nn.Module 的构造函数
        super().__init__()

        # 存储编码后位置的通道数
        self.pos_in_dims = pos_in_dims
        # 存储编码后方向的通道数
        self.dir_in_dims = dir_in_dims

        # 定义第一层神经网络块,包含四个线性层和 ReLU 激活函数
        self.layers0 = nn.Sequential(
            nn.Linear(pos_in_dims, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )

        # 定义第二层神经网络块,包含四个线性层和 ReLU 激活函数,有一个跳跃连接
        self.layers1 = nn.Sequential(
            nn.Linear(D + pos_in_dims + 32, D), nn.ReLU(),  # 跳跃连接
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )

        # 定义用于计算密度的全连接层,最后使用 Softplus 激活函数
        self.fc_density = nn.Sequential(
            nn.Linear(D, 1), nn.Softplus()
        )
        # 定义用于提取特征的全连接层
        self.fc_feature = nn.Sequential(
            nn.Linear(D, D)
        )
        # 定义用于处理特征和方向信息以生成中间特征的层
        self.rgb_layers = nn.Sequential(nn.Linear(D + dir_in_dims, D // 2), nn.ReLU())
        # 定义用于从中间特征生成 RGB 颜色的全连接层
        self.fc_rgb = nn.Sequential(nn.Linear(D // 2, 3))

        # 以下代码被注释掉,原本用于初始化偏置
        # self.fc_density[0].bias.data = torch.tensor([0.1]).float()
        # self.fc_rgb[0].bias.data = torch.tensor([0.02, 0.02, 0.02]).float()

    def forward(self, pos_enc, dir_enc, cost_volume):
        """
        :param pos_enc: (H, W, N_sample, pos_in_dims) 编码后的位置
        :param dir_enc: (H, W, N_sample, dir_in_dims) 编码后的方向
        :return: rgb_density (H, W, N_sample, 4)
        """
        # 通过第一层神经网络块处理编码后的位置
        x = self.layers0(pos_enc)  # (H, W, N_sample, D)
        # 将处理后的结果、原始编码位置和代价体进行拼接
        x = torch.cat([x, pos_enc, cost_volume], dim=3)  # (H, W, N_sample, D+pos_in_dims)
        # 通过第二层神经网络块处理拼接后的结果
        x = self.layers1(x)  # (H, W, N_sample, D)

        # 计算密度
        density = self.fc_density(x)  # (H, W, N_sample, 1)

        # 提取特征
        feat = self.fc_feature(x)  # (H, W, N_sample, D)
        # 将提取的特征和编码后的方向进行拼接
        x = torch.cat([feat, dir_enc], dim=3)  # (H, W, N_sample, D+dir_in_dims)
        # 通过 rgb_layers 层处理拼接后的结果
        x = self.rgb_layers(x)  # (H, W, N_sample, D/2)
        # 生成 RGB 颜色
        rgb = self.fc_rgb(x)  # (H, W, N_sample, 3)

        # 将 RGB 颜色和密度进行拼接
        rgb_den = torch.cat([rgb, density], dim=3)  # (H, W, N_sample, 4)
        return rgb_den

nerf_mask.py

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data

# 定义一个名为 OfficialNerf 的神经网络模块,继承自 nn.Module
class OfficialNerf(nn.Module):
    def __init__(self, pos_in_dims, dir_in_dims, D):
        """
        :param pos_in_dims: 标量,编码后位置的通道数
        :param dir_in_dims: 标量,编码后方向的通道数
        :param D: 标量,隐藏层的维度数
        """
        # 调用父类的构造函数
        super(OfficialNerf, self).__init__()

        # 存储编码后位置的通道数
        self.pos_in_dims = pos_in_dims
        # 存储编码后方向的通道数
        self.dir_in_dims = dir_in_dims

        # 定义第一层神经网络序列,包含四个线性层和 ReLU 激活函数
        self.layers0 = nn.Sequential(
            nn.Linear(pos_in_dims, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )

        # 定义第二层神经网络序列,包含四个线性层和 ReLU 激活函数,有一个跳跃连接
        self.layers1 = nn.Sequential(
            nn.Linear(D + pos_in_dims, D), nn.ReLU(),  # 跳跃连接
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )
        # 定义掩码网络,包含两个线性层,中间有 ReLU 激活函数,最后有 Sigmoid 激活函数
        self.mask_net = nn.Sequential(nn.Linear(D, D), nn.ReLU(), nn.Linear(D, 1), nn.Sigmoid())
        # 定义密度预测网络,包含一个线性层和 Softplus 激活函数
        self.fc_density = nn.Sequential(nn.Linear(D, 1), nn.Softplus())
        # 定义特征提取的线性层
        self.fc_feature = nn.Linear(D, D)
        # 定义 RGB 处理的神经网络序列,包含一个线性层和 ReLU 激活函数
        self.rgb_layers = nn.Sequential(nn.Linear(D + dir_in_dims, D // 2), nn.ReLU())
        # 定义 RGB 预测的神经网络序列,包含一个线性层
        self.fc_rgb = nn.Sequential(nn.Linear(D // 2, 3))

        # 以下代码被注释掉,原本用于初始化偏置
        # self.fc_density[0].bias.data = torch.tensor([0.1]).float()
        # self.fc_rgb[0].bias.data = torch.tensor([0.02, 0.02, 0.02]).float()

    def forward(self, pos_enc, dir_enc):
        """
        :param pos_enc: (H, W, N_sample, pos_in_dims) 编码后的位置
        :param dir_enc: (H, W, N_sample, dir_in_dims) 编码后的方向
        :return: rgb_density (H, W, N_sample, 4)
        """
        # 通过第一层神经网络序列处理编码后的位置
        x = self.layers0(pos_enc)  # 输出形状为 (H, W, N_sample, D)
        # 将处理后的结果和原始编码位置在第 3 维拼接
        x = torch.cat([x, pos_enc], dim=3)  # 输出形状为 (H, W, N_sample, D + pos_in_dims)
        # 通过第二层神经网络序列处理拼接后的结果
        x = self.layers1(x)  # 输出形状为 (H, W, N_sample, D)

        # 通过掩码网络得到掩码概率
        mask_prob = self.mask_net(x)
        # 通过密度预测网络得到密度
        density = self.fc_density(x)  # 输出形状为 (H, W, N_sample, 1)

        # 通过特征提取线性层得到特征
        feat = self.fc_feature(x)  # 输出形状为 (H, W, N_sample, D)
        # 将特征和编码后的方向在第 3 维拼接
        x = torch.cat([feat, dir_enc], dim=3)  # 输出形状为 (H, W, N_sample, D + dir_in_dims)
        # 通过 RGB 处理神经网络序列处理拼接后的结果
        x = self.rgb_layers(x)  # 输出形状为 (H, W, N_sample, D / 2)
        # 通过 RGB 预测神经网络序列得到 RGB 值
        rgb = self.fc_rgb(x)  # 输出形状为 (H, W, N_sample, 3)

        # 将 RGB 值、密度和掩码概率在第3维拼接
        rgb_den = torch.cat([rgb, density, mask_prob], dim=3)  # 输出形状为 (H, W, N_sample, 4)
        return rgb_den

nerf_models.py

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data

# 定义 OfficialNerf 类,继承自 nn.Module,用于实现官方的 NeRF 模型
class OfficialNerf(nn.Module):
    def __init__(self, pos_in_dims, dir_in_dims, D):
        """
        :param pos_in_dims: 标量,编码后位置的通道数
        :param dir_in_dims: 标量,编码后方向的通道数
        :param D: 标量,隐藏层的维度数
        """
        # 调用父类的构造函数
        super(OfficialNerf, self).__init__()

        # 存储编码后位置的通道数
        self.pos_in_dims = pos_in_dims
        # 存储编码后方向的通道数
        self.dir_in_dims = dir_in_dims

        # 定义第一层神经网络序列,包含四个线性层和 ReLU 激活函数
        self.layers0 = nn.Sequential(
            nn.Linear(pos_in_dims, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )

        # 定义第二层神经网络序列,包含四个线性层和 ReLU 激活函数,有一个跳跃连接
        self.layers1 = nn.Sequential(
            nn.Linear(D + pos_in_dims, D), nn.ReLU(),  # 跳跃连接
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
            nn.Linear(D, D), nn.ReLU(),
        )

        # 定义密度预测网络,包含一个线性层和 Softplus 激活函数
        self.fc_density = nn.Sequential(nn.Linear(D, 1), nn.Softplus())
        # 定义特征提取的线性层
        self.fc_feature = nn.Linear(D, D)
        # 定义 RGB 处理的神经网络序列,包含一个线性层和 ReLU 激活函数
        self.rgb_layers = nn.Sequential(nn.Linear(D + dir_in_dims, D // 2), nn.ReLU())
        # 定义 RGB 预测的神经网络序列,包含一个线性层
        self.fc_rgb = nn.Sequential(nn.Linear(D // 2, 3))

        # 以下代码被注释掉,原本用于初始化偏置
        # self.fc_density[0].bias.data = torch.tensor([0.1]).float()
        # self.fc_rgb[0].bias.data = torch.tensor([0.02, 0.02, 0.02]).float()

    def forward(self, pos_enc, dir_enc):
        """
        :param pos_enc: (H, W, N_sample, pos_in_dims) 编码后的位置
        :param dir_enc: (H, W, N_sample, dir_in_dims) 编码后的方向
        :return: rgb_density (H, W, N_sample, 4)
        """
        # 通过第一层神经网络序列处理编码后的位置
        x = self.layers0(pos_enc)  # 输出形状为 (H, W, N_sample, D)
        # 将处理后的结果和原始编码位置在第 3 维拼接
        x = torch.cat([x, pos_enc], dim=3)  # 输出形状为 (H, W, N_sample, D + pos_in_dims)
        # 通过第二层神经网络序列处理拼接后的结果
        x = self.layers1(x)  # 输出形状为 (H, W, N_sample, D)

        # 通过密度预测网络得到密度
        density = self.fc_density(x)  # 输出形状为 (H, W, N_sample, 1)

        # 通过特征提取线性层得到特征
        feat = self.fc_feature(x)  # 输出形状为 (H, W, N_sample, D)
        # 将特征和编码后的方向在第 3 维拼接
        x = torch.cat([feat, dir_enc], dim=3)  # 输出形状为 (H, W, N_sample, D + dir_in_dims)
        # 通过 RGB 处理神经网络序列处理拼接后的结果
        x = self.rgb_layers(x)  # 输出形状为 (H, W, N_sample, D / 2)
        # 通过 RGB 预测神经网络序列得到 RGB 值
        rgb = self.fc_rgb(x)  # 输出形状为 (H, W, N_sample, 3)

        # 将 RGB 值和密度在第 3 维拼接
        rgb_den = torch.cat([rgb, density], dim=3)  # 输出形状为 (H, W, N_sample, 4)
        return rgb_den

# 定义 fullNeRF 类,继承自 nn.Module,用于实现完整的 NeRF 模型
class fullNeRF(nn.Module):
    def __init__(self, in_channels_xyz, in_channels_dir, W, D=8, skips=[4]):
        # 调用父类的构造函数
        super().__init__()
        # 存储网络的深度
        self.D = D
        # 存储隐藏层的宽度
        self.W = W
        # 存储跳跃连接的位置
        self.skips = skips
        # 存储输入位置编码的通道数
        self.in_channels_xyz = in_channels_xyz
        # 存储输入方向编码的通道数
        self.in_channels_dir = in_channels_dir

        # 定义位置编码层
        for i in range(D):
            if i == 0:
                # 第一层,输入维度为输入位置编码的通道数,输出维度为隐藏层宽度
                layer = nn.Linear(in_channels_xyz, W)
            elif i in skips:
                # 跳跃连接层,输入维度为隐藏层宽度加上输入位置编码的通道数,输出维度为隐藏层宽度
                layer = nn.Linear(W + in_channels_xyz, W)
            else:
                # 普通层,输入和输出维度均为隐藏层宽度
                layer = nn.Linear(W, W)
            # 为每个层添加 ReLU 激活函数
            layer = nn.Sequential(layer, nn.ReLU(True))
            # 将层添加到模型中
            setattr(self, f"xyz_encoding_{i + 1}", layer)
        # 定义位置编码的最终线性层
        self.xyz_encoding_final = nn.Linear(W, W)
        # 定义方向编码层
        self.dir_encoding = nn.Sequential(
            nn.Linear(W + in_channels_dir, W), nn.ReLU(True),
            nn.Linear(W, W // 2), nn.ReLU(True)
        )

        # 定义静态输出层
        # 静态密度预测层,包含一个线性层和 Softplus 激活函数
        self.static_sigma = nn.Sequential(nn.Linear(W, 1), nn.Softplus())
        # 静态 RGB 预测层,包含两个线性层,中间有 ReLU 激活函数
        self.static_rgb = nn.Sequential(nn.Linear(W // 2, W // 2), nn.ReLU(inplace=True),
                                        nn.Linear(W // 2, 3))

    def forward(self, input_xyz, input_dir_a):
        # 存储输入的位置编码
        xyz_ = input_xyz
        for i in range(self.D):
            if i in self.skips:
                # 如果是跳跃连接位置,将输入的位置编码和当前处理结果拼接
                xyz_ = torch.cat([input_xyz, xyz_], -1)
            # 通过相应的位置编码层处理
            xyz_ = getattr(self, f"xyz_encoding_{i + 1}")(xyz_)

        # 通过静态密度预测层得到静态密度
        static_sigma = self.static_sigma(xyz_)  # 输出形状为 (B, 1)

        # 通过位置编码的最终线性层得到最终位置编码
        xyz_encoding_final = self.xyz_encoding_final(xyz_)
        # 将最终位置编码和输入的方向编码拼接
        dir_encoding_input = torch.cat([xyz_encoding_final, input_dir_a], -1)
        # 通过方向编码层处理拼接后的结果
        dir_encoding = self.dir_encoding(dir_encoding_input)

        # 通过静态 RGB 预测层得到静态 RGB 值
        static_rgb = self.static_rgb(dir_encoding)
        # 将静态 RGB 值和静态密度拼接
        static = torch.cat([static_rgb, static_sigma], -1)  # 输出形状为 (B, 4)

        return static

poses.py

import torch
import torch.nn as nn
from utils.lie_group_helper import make_c2w

# 定义 LearnPose 类,继承自 nn.Module,用于学习相机位姿
class LearnPose(nn.Module):
    def __init__(self, num_cams, learn_R, learn_t, init_c2w=None):
        """
        :param num_cams: 相机的数量
        :param learn_R: 是否学习旋转部分,布尔值
        :param learn_t: 是否学习平移部分,布尔值
        :param init_c2w: (N, 4, 4) 的 torch 张量,表示初始的相机到世界的变换矩阵
        """
        # 调用父类的构造函数
        super(LearnPose, self).__init__()
        # 存储相机的数量
        self.num_cams = num_cams
        # 初始化初始相机到世界的变换矩阵为 None
        self.init_c2w = None
        if init_c2w is not None:
            # 如果提供了初始变换矩阵,将其作为不可训练的参数存储
            self.init_c2w = nn.Parameter(init_c2w, requires_grad=False)

        # 定义旋转参数,初始化为全零,是否可训练由 learn_R 决定
        self.r = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=learn_R)  # (N, 3)
        # 定义平移参数,初始化为全零,是否可训练由 learn_t 决定
        self.t = nn.Parameter(torch.zeros(size=(num_cams, 3), dtype=torch.float32), requires_grad=learn_t)  # (N, 3)

    def forward(self, cam_id):
        # 根据相机 ID 提取对应的旋转参数,形状为 (3, ),表示轴角
        r = self.r[cam_id]  # (3, ) 轴角
        # 根据相机 ID 提取对应的平移参数,形状为 (3, )
        t = self.t[cam_id]  # (3, )
        # 使用 make_c2w 函数将轴角和平移参数转换为相机到世界的变换矩阵,形状为 (4, 4)
        c2w = make_c2w(r, t)  # (4, 4)

        # 如果提供了初始变换矩阵,学习初始位姿和目标位姿之间的增量位姿
        if self.init_c2w is not None:
            # 将当前计算得到的变换矩阵与初始变换矩阵相乘
            c2w = c2w @ self.init_c2w[cam_id]

        return c2w

utils

├── utils/  # 工具文件夹
│   ├── align_traj.py  # 轨迹对齐脚本文件,用于对不同轨迹数据进行对齐操作
│   ├── comp_ate.py  # 计算绝对轨迹误差(Absolute Trajectory Error, ATE)的脚本文件
│   ├── comp_ray_dir.py  # 计算光线方向的脚本文件,常用于计算机视觉和三维重建中的光线追踪等场景
│   ├── lie_group_helper.py  # 李群相关辅助函数的脚本文件,李群在机器人运动学、计算机视觉中的位姿表示等方面有应用
│   ├── pos_enc.py  # 位置编码脚本文件,在深度学习模型(如NeRF)中用于对位置信息进行编码
│   ├── pose_utils.py  # 位姿处理工具脚本文件,包含处理相机位姿或物体位姿的相关函数
│   ├── split_dataset.py  # 数据集划分脚本文件,用于将数据集划分为训练集、验证集和测试集等
│   ├── training_utils.py  # 训练辅助工具脚本文件,包含训练模型时常用的工具函数,如学习率调整、损失函数计算等
│   ├── vgg.py  # VGG网络相关脚本文件,可能包含VGG模型的定义、加载预训练权重等操作
│   ├── vis_cam_traj.py  # 可视化相机轨迹的脚本文件,用于将相机在三维空间中的运动轨迹进行可视化展示
│   └── volume_op.py  # 体操作脚本文件,在三维重建、体渲染等场景中对三维体数据进行操作