3D Gaussian splatting 07: Code Reading - Training Data Loading and Saving Results

Table of Contents

  • 3D Gaussian splatting 01: Environment Setup
  • 3D Gaussian splatting 02: Quick Evaluation
  • 3D Gaussian splatting 03: Training on Your Own Data and Viewing Results
  • 3D Gaussian splatting 04: Code Reading - Extracting Camera Poses and the Sparse Point Cloud
  • 3D Gaussian splatting 05: Code Reading - Overall Training Flow
  • 3D Gaussian splatting 06: Code Reading - Training Parameters
  • 3D Gaussian splatting 07: Code Reading - Training Data Loading and Saving Results
  • 3D Gaussian splatting 08: Building a Web Page to Display Your Model

Loading Training Data

The call stack for loading data in train.py is shown below. Since convert.py uses COLMAP for preprocessing, reading the data ultimately lands in the readColmapSceneInfo method:

Scene(dataset, gaussians)
└─sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
  └─readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8)
    ├─readColmapCameras(cam_extrinsics, cam_intrinsics, depths_params, images_folder, depths_folder, test_cam_names_list)
    ├─read_points3D_binary(path_to_model_file)
    ├─storePly(path, xyz, rgb)
    └─fetchPly(path)

The reading flow is:

  1. Read the camera intrinsics and the image extrinsics (poses) from cameras.bin and images.bin
  2. Split the cameras into a training set and a test set
  3. Read the 3D point cloud from points3D.bin

In readColmapSceneInfo(), if the --eval flag is set, the camera names are sorted and split by index: every camera whose index modulo llffhold equals 0 goes into the test set, the rest into the training set. Since llffhold is 8, the train/test ratio is 7:1. Without --eval, all cameras are used for training. To pick the test set manually, create a test.txt under sparse/0 (an example follows the code below) and change the default value of the llffhold parameter to 0.

if eval:
    if "360" in path:
        llffhold = 8
    if llffhold:
        print("------------LLFF HOLD-------------")
        cam_names = [cam_extrinsics[cam_id].name for cam_id in cam_extrinsics]
        cam_names = sorted(cam_names)
        test_cam_names_list = [name for idx, name in enumerate(cam_names) if idx % llffhold == 0]
    else:
        with open(os.path.join(path, "sparse/0", "test.txt"), 'r') as file:
            test_cam_names_list = [line.strip() for line in file]
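If you take the manual route, sparse/0/test.txt simply lists one camera name per line (each line is stripped of whitespace and compared against image_name, so the names must match exactly; the file names below are hypothetical):

IMG_0001.jpg
IMG_0009.jpg
IMG_0017.jpg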

With test_cam_names_list in hand, readColmapCameras sets each camera's is_test attribute when building the CameraInfo; this flag is used later during training, rendering, and evaluation to distinguish the training set from the test set.

cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, depth_params=depth_params,
                        image_path=image_path, image_name=image_name, depth_path=depth_path,
                        width=width, height=height, is_test=image_name in test_cam_names_list)

When the camera focal length is read, it is converted from a focal length into a field of view (shown here for the SIMPLE_PINHOLE camera model, where a single focal length serves both axes):

focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height)
FovX = focal2fov(focal_length_x, width)

The conversion helpers are defined as:

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2*math.atan(pixels/(2*focal))

Concepts:

  • Field of View (FOV): the angular extent of the camera's view, usually computed separately for the vertical and horizontal directions
  • Focal Length: the distance from the camera's optical center to the image plane (expressed in pixels here); the longer the focal length, the larger objects appear in the image

fov2focal(fov, pixels)

  • Converts a field of view (FOV) to a focal length
  • Parameters
    • fov: field of view in radians, the angular range the camera can observe
    • pixels: number of pixels of the image sensor along the corresponding dimension (width or height)

Based on the pinhole camera model, the focal length \(f\) is:

\[ f = \frac{\text{pixels}}{2 \cdot \tan(\text{FOV}/2)} \]

If the image height is 1080 pixels and the vertical FOV is 60° (math.radians(60) converts it to radians):

focal = 1080 / (2 * math.tan(math.radians(60) / 2))  # ≈ 935.3

focal2fov(focal, pixels)

  • Converts a focal length to a field of view (FOV)
  • Parameters
    • focal: focal length (in pixels)
    • pixels: number of pixels of the image sensor along the corresponding dimension

The field of view \(\theta\) is:

\[ \theta = 2 \cdot \arctan\left(\frac{\text{pixels}}{2 \cdot f}\right) \]

If the focal length is 935.3 pixels and the image height is 1080 pixels:

fov = 2 * math.atan(1080 / (2 * 935.3))  # ≈ 1.047 rad ≈ 60°
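The two helpers are inverses of each other; a minimal round-trip check (restating the definitions above so the snippet is self-contained):

import math

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))

fov = math.radians(60)
focal = fov2focal(fov, 1080)                       # ≈ 935.3 pixels
assert math.isclose(focal2fov(focal, 1080), fov)   # round trip recovers the FOV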

Next, the code checks whether points3D.ply exists: if it does, it is read directly; if not, it is first converted from the binary (or text) model and then read:

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

Reading the 3D point cloud from points3D.bin:

import numpy as np

def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """

    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

        # allocate uninitialized n x 3 arrays (contents are arbitrary until filled below)
        xyzs = np.empty((num_points, 3))
        rgbs = np.empty((num_points, 3))
        errors = np.empty((num_points, 1))

        for p_id in range(num_points):
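            # each record: point id (Q, 8 bytes) + xyz (ddd, 24) + rgb (BBB, 3) + reprojection error (d, 8) = 43 bytes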
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
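            # the point's track ((image_id, point2D_idx) pairs) is read to advance the file pointer but not used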
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors

The read_next_bytes helper used above reads a chunk of bytes from the binary file and unpacks it into typed values with struct.unpack according to the given format string:

import struct

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)
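A quick way to see read_next_bytes in action is to feed it an in-memory buffer packed in the same "QdddBBBd" point-record layout (a minimal sketch; the values are made up):

import io
import struct

# pack one fake point record: id, x, y, z, r, g, b, error
buf = io.BytesIO(struct.pack("<QdddBBBd", 7, 1.0, 2.0, 3.0, 255, 128, 0, 0.5))
print(read_next_bytes(buf, num_bytes=43, format_char_sequence="QdddBBBd"))
# (7, 1.0, 2.0, 3.0, 255, 128, 0, 0.5)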

BasicPointCloud

BasicPointCloud is the basic data structure representing a 3D point cloud, holding coordinates, colors, and normals:

from typing import NamedTuple

import numpy as np

class BasicPointCloud(NamedTuple):
    points : np.array
    colors : np.array
    normals : np.array
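For example, a BasicPointCloud can be built from raw arrays (a minimal sketch with random data; fetchPly returns the same structure after reading points3D.ply):

import numpy as np

pcd = BasicPointCloud(
    points=np.random.rand(100, 3),    # xyz positions
    colors=np.random.rand(100, 3),    # RGB in [0, 1]
    normals=np.zeros((100, 3)))       # stored but not used by training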

The same module (utils/graphics_utils.py) also defines the coordinate-transform helpers used when building cameras:

def geom_transform_points(points, transf_matrix):
    # convert points to homogeneous coordinates and apply the transform matrix, returning the
    # projected 3D coordinates; a PyTorch implementation that supports batched transforms
    P, _ = points.shape
    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
    points_hom = torch.cat([points, ones], dim=1)
    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))

    denom = points_out[..., 3:] + 0.0000001
    return (points_out[..., :3] / denom).squeeze(dim=0)
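Note that the points are multiplied as row vectors on the left of the matrix, so a translation sits in the last row of transf_matrix rather than the last column (matching the transposed matrix convention used elsewhere in the code). A small sketch:

import torch

pts = torch.tensor([[1.0, 2.0, 3.0]])
M = torch.eye(4)
M[3, 0] = 10.0                        # translate +10 in x, placed in the last ROW
print(geom_transform_points(pts, M))  # ≈ tensor([[11., 2., 3.]])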

def getWorld2View(R, t):
    # build the 4x4 world-to-camera transform; R: 3x3 rotation matrix, t: 3D translation vector
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0
    return np.float32(Rt)

def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
    # enhanced view transform that additionally translates and scales the scene,
    # implemented by adjusting the camera center in the camera-to-world inverse transform
    Rt = np.zeros((4, 4))
    Rt[:3, :3] = R.transpose()
    Rt[:3, 3] = t
    Rt[3, 3] = 1.0

    C2W = np.linalg.inv(Rt)
    cam_center = C2W[:3, 3]
    cam_center = (cam_center + translate) * scale
    C2W[:3, 3] = cam_center
    Rt = np.linalg.inv(C2W)
    return np.float32(Rt)
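With the default translate and scale, getWorld2View2 reduces to getWorld2View; a quick sanity check, assuming both functions are in scope:

import numpy as np

R = np.eye(3)                   # camera aligned with the world axes
t = np.array([0.0, 0.0, 4.0])   # world-to-camera translation
assert np.allclose(getWorld2View2(R, t), getWorld2View(R, t))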

def getProjectionMatrix(znear, zfar, fovX, fovY):
    # build a 4x4 perspective projection matrix from the near/far clipping planes
    # and the horizontal/vertical fields of view
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    top = tanHalfFovY * znear
    bottom = -top
    right = tanHalfFovX * znear
    left = -right

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def fov2focal(fov, pixels):
    # field of view to focal length (in pixels)
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    # focal length to field of view
    return 2*math.atan(pixels/(2*focal))
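Putting the helpers together: a minimal sketch that builds a projection matrix for a hypothetical 1920x1080 camera, reusing the 935.3 px focal length from the earlier example (assuming the functions above are in scope):

import math
import torch

fovx = focal2fov(935.3, 1920)   # horizontal FOV in radians
fovy = focal2fov(935.3, 1080)   # vertical FOV, ≈ 60°
P = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=fovx, fovY=fovy)
print(P.shape)                  # torch.Size([4, 4])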

Training Result Data Structure

Install pyntcloud:

pip install pyntcloud

Inspect the ply file:

>>> from pyntcloud import PyntCloud
>>> cloud = PyntCloud.from_file("output/1ed8e6a1-9/point_cloud/iteration_7000/point_cloud.ply")
>>> print(cloud)
PyntCloud
743269 points with 59 scalar fields
0 faces in mesh
0 kdtrees
0 voxelgrids
Centroid: 1.6537141799926758, -2.9306182861328125, -4.471662521362305

>>> type(cloud.points)
<class 'pandas.core.frame.DataFrame'>

The points are stored as a pandas DataFrame. Looking at the first point's attribute columns, each field is a float32 (4 bytes); the 62 columns are x, y, z plus the 59 scalar fields reported above, but most are elided in the default display:

>>> print(cloud.points.loc[0])
x          1.947371
y         -0.500535
z          1.388533
nx         0.000000
ny         0.000000
             ...   
scale_2   -4.380099
rot_0      0.840099
rot_1     -0.143527
rot_2      0.065419
rot_3      0.179504
Name: 0, Length: 62, dtype: float32

Lifting the pandas row display limit prints the full record:

>>> import pandas as pd
>>> pd.set_option('display.max_rows', None)
>>> print(cloud.points.loc[0])
x            1.947371
y           -0.500535
z            1.388533
nx           0.000000
ny           0.000000
nz           0.000000
f_dc_0      -0.264158
f_dc_1       0.352959
f_dc_2       0.361867
f_rest_0     0.012889
f_rest_1    -0.001385
f_rest_2     0.044487
f_rest_3     0.013909
# ... remaining f_rest_ fields omitted ...
f_rest_41   -0.038870
f_rest_42   -0.015730
f_rest_43    0.042109
f_rest_44    0.021378
opacity     -1.817663
scale_0     -5.108221
scale_1     -4.811676
scale_2     -4.380099
rot_0        0.840099
rot_1       -0.143527
rot_2        0.065419
rot_3        0.179504
Name: 0, dtype: float32

When the results are saved, the output attributes are produced by concatenating the parameter arrays:

attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)

The attributes have the following meanings (a sketch of regrouping them follows the list):

  • x, y, z: 3D position of the point
  • nx, ny, nz: normals, unused (always zero)
  • f_dc_0 - f_dc_2, f_rest_0 - f_rest_44: DC and higher-order components of the color features; degree-3 spherical harmonics, i.e. 16 coefficients per RGB channel (1 DC + 15 rest, 48 values in total)
  • opacity: opacity parameter (stored pre-activation, which is why the value above is negative; it is passed through a sigmoid when used)
  • scale_0 - scale_2: scaling parameters (stored in log space)
  • rot_0 - rot_3: rotation parameters (a quaternion)
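To work with these fields programmatically, the DataFrame columns can be regrouped into per-Gaussian arrays (a minimal sketch reusing the cloud object from the pyntcloud session above):

pts = cloud.points                                           # pandas DataFrame
xyz     = pts[["x", "y", "z"]].to_numpy()                    # (N, 3) positions
f_dc    = pts[[f"f_dc_{i}" for i in range(3)]].to_numpy()    # (N, 3) SH DC terms
f_rest  = pts[[f"f_rest_{i}" for i in range(45)]].to_numpy() # (N, 45) higher-order SH
opacity = pts["opacity"].to_numpy()                          # (N,) pre-sigmoid opacity
scale   = pts[[f"scale_{i}" for i in range(3)]].to_numpy()   # (N, 3) log-space scales
rot     = pts[[f"rot_{i}" for i in range(4)]].to_numpy()     # (N, 4) quaternions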