Implementing 3D Gaussian Splatting from Scratch: A Complete PyTorch Walkthrough of the Rendering Pipeline

3D Gaussian Splatting (3DGS) has become almost a standard technique in 3D vision. NVIDIA has integrated it into COSMOS, and Meta's new AR glasses run 3DGS on-device for real-time environment capture and rendering. The technique is no longer confined to papers; it is reaching products remarkably quickly.

In this article we implement the original 3DGS paper from scratch in PyTorch, keeping the code to a few hundred lines. Despite the compact implementation, the rendering quality reaches SOTA level.

Note that the focus here is on implementation details; we will not cover the mathematical derivations behind each formula.

How the Scene Is Represented

3DGS represents a scene as a collection of anisotropic 3D Gaussians. Unlike NeRF, which encodes the scene implicitly inside a neural network, the 3DGS representation is explicit and differentiable.

Here we load an already-trained scene and run only the forward pass; training and backpropagation are not covered.

import torch

pos = torch.load('trained_gaussians/kitchen/pos_7000.pt').cuda()
opacity_raw = torch.load('trained_gaussians/kitchen/opacity_raw_7000.pt').cuda()
f_dc = torch.load('trained_gaussians/kitchen/f_dc_7000.pt').cuda()
f_rest = torch.load('trained_gaussians/kitchen/f_rest_7000.pt').cuda()
scale_raw = torch.load('trained_gaussians/kitchen/scale_raw_7000.pt').cuda()
q_raw = torch.load('trained_gaussians/kitchen/q_rot_7000.pt').cuda()

Each 3D Gaussian is described by a mean (pos) and a covariance matrix, which together define where the Gaussian sits in 3D space and what shape it takes, whether stretched out or nearly spherical. The covariance matrix is parameterized by a scale (scale_raw) and a rotation (q_raw), corresponding to Eq. (6) in the paper.
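
For reference, Eq. (6) composes the 3D covariance from these two factors,

Σ = R S Sᵀ Rᵀ

where R is the rotation matrix derived from the quaternion q_raw and S is the diagonal scale matrix derived from scale_raw (the reference implementation stores the scales in log space and exponentiates them; the same convention is assumed here).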

Each Gaussian also stores an opacity (opacity_raw) and color information. As in NeRF, the color is view-dependent, so it cannot be stored as a fixed RGB value. Instead, Spherical Harmonics (SH) are used to express how the color changes with viewing direction; the coefficients are f_dc and f_rest.

The Rendering Pipeline

3DGS rendering can be split into two main stages:

The first stage is preprocessing: project the 3D Gaussians onto the image plane, sort them by depth, and organize them into tiles. This makes parallel processing more efficient and avoids redundant computation.

The second stage is per-tile volume rendering: iterate over the Gaussians that overlap each tile and accumulate their contributions with the volume rendering equation. The pixel color is then obtained by alpha compositing, conceptually the same as NeRF's rendering but implemented explicitly in screen space.

Note that 3DGS parallelizes rendering over pixels (tiles), though it could also parallelize over Gaussians. The advantage of pixel parallelism is that neighboring pixels share spatial information, irrelevant Gaussians can be culled early, and the rendering speed is largely independent of the number of Gaussians. Both strategies are discussed in the original paper.

The code structure mirrors these two stages:

@torch.no_grad()
def render(pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy,
           near=2e-3, far=100, pix_guard=64, T=16, min_conis=1e-6,
           chi_square_clip=9.21, alpha_max=0.99, alpha_cutoff=1/255.):

    # Block 1: Project the Gaussians onto the image plane

    # Block 2: Global depth sorting

    # Block 3: Tiling

    # Block 4: Compute the inverse covariance matrix

    final_image = torch.zeros((H * W, 3), device=pos.device, dtype=pos.dtype)

    # Iterate over tiles
    for tile_id, s0, s1 in zip(unique_tile_ids.tolist(), start.tolist(), end.tolist()):

        current_gaussian_ids = gaussian_ids[s0:s1]

        # Block 5: Compute pixel coordinates for this tile

        # Block 6: Apply the volumetric rendering equation

    return final_image.reshape((H, W, 3)).clamp(0, 1)

1. Projecting the Gaussians

The 3D means and covariances are projected using the camera intrinsics (fx, fy, cx, cy) and the camera-to-world transform (c2w). Eq. (5) in the paper defines how the covariance matrix Σ is mapped from 3D world space into 2D camera space.
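
For reference, Eq. (5) is

Σ' = J W Σ Wᵀ Jᵀ

where W is the world-to-camera rotation and J is the Jacobian of the affine approximation of the perspective projection; the code below builds J explicitly and applies this product.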

uv, x, y, z = project_points(pos, c2w, fx, fy, cx, cy)
in_guard = (uv[:, 0] > -pix_guard) & (uv[:, 0] < W + pix_guard) & (
    uv[:, 1] > -pix_guard) & (uv[:, 1] < H + pix_guard) & (z > near) & (z < far)

uv = uv[in_guard]
pos = pos[in_guard]
color = color[in_guard]
opacity = torch.sigmoid(opacity_raw[in_guard]).clamp(0, 0.999)
z = z[in_guard]
x = x[in_guard]
y = y[in_guard]
sigma = sigma[in_guard]
idx = torch.nonzero(in_guard, as_tuple=False).squeeze(1)

# Project the covariance
Rcw = c2w[:3, :3]
Rwc = Rcw.t()
invz = 1 / z.clamp_min(1e-6)
invz2 = invz * invz
J = torch.zeros((pos.shape[0], 2, 3), device=pos.device, dtype=pos.dtype)
J[:, 0, 0] = fx * invz
J[:, 1, 1] = fy * invz
J[:, 0, 2] = -fx * x * invz2
J[:, 1, 2] = -fy * y * invz2
tmp = Rwc.unsqueeze(0) @ sigma @ Rwc.t().unsqueeze(0)  # Eq. 5
sigma_camera = J @ tmp @ J.transpose(1, 2)
# Enforce symmetry
sigma_camera = 0.5 * (sigma_camera + sigma_camera.transpose(1, 2))
# Ensure positive definiteness
evals, evecs = torch.linalg.eigh(sigma_camera)
evals = torch.clamp(evals, min=1e-6, max=1e4)
sigma_camera = evecs @ torch.diag_embed(evals) @ evecs.transpose(1, 2)

2. Global Depth Sorting

To make alpha blending correct, the Gaussians are sorted by depth from near to far, the same front-to-back accumulation order used in NeRF.

# Global depth sorting
order = torch.argsort(z, descending=False)
uv = uv[order]
u = uv[:, 0]
v = uv[:, 1]
color = color[order]
opacity = opacity[order]
sigma_camera = sigma_camera[order]
evals = evals[order]
idx = idx[order]

3. Tiling

The screen is split into tiles (e.g. 16×16 pixels), and rendering is parallelized over the pixels within each tile. This matches the tile-based, pixel-parallel strategy described in the Pulsar paper.

# Tiling
major_variance = evals[:, 1].clamp_min(1e-12).clamp_max(1e4)  # [N]
radius = torch.ceil(3.0 * torch.sqrt(major_variance)).to(torch.int64)
umin = torch.floor(u - radius).to(torch.int64)
umax = torch.floor(u + radius).to(torch.int64)
vmin = torch.floor(v - radius).to(torch.int64)
vmax = torch.floor(v + radius).to(torch.int64)

on_screen = (umax >= 0) & (umin < W) & (vmax >= 0) & (vmin < H)
if not on_screen.any():
    raise Exception("All projected points are off-screen")
u, v = u[on_screen], v[on_screen]
color = color[on_screen]
opacity = opacity[on_screen]
sigma_camera = sigma_camera[on_screen]
umin, umax = umin[on_screen], umax[on_screen]
vmin, vmax = vmin[on_screen], vmax[on_screen]
idx = idx[on_screen]
umin = umin.clamp(0, W - 1)
umax = umax.clamp(0, W - 1)
vmin = vmin.clamp(0, H - 1)
vmax = vmax.clamp(0, H - 1)

# Tile index for each AABB
umin_tile = (umin // T).to(torch.int64)  # [N]
umax_tile = (umax // T).to(torch.int64)  # [N]
vmin_tile = (vmin // T).to(torch.int64)  # [N]
vmax_tile = (vmax // T).to(torch.int64)  # [N]

# Number of tiles each gaussian intersects
n_u = umax_tile - umin_tile + 1  # [N]
n_v = vmax_tile - vmin_tile + 1  # [N]

# Max number of tiles
max_u = int(n_u.max().item())
max_v = int(n_v.max().item())

nb_gaussians = umin_tile.shape[0]
span_indices_u = torch.arange(max_u, device=pos.device, dtype=torch.int64)  # [max_u]
span_indices_v = torch.arange(max_v, device=pos.device, dtype=torch.int64)  # [max_v]
tile_u = (umin_tile[:, None, None] + span_indices_u[None, :, None]
          ).expand(nb_gaussians, max_u, max_v)  # [N, max_u, max_v]
tile_v = (vmin_tile[:, None, None] + span_indices_v[None, None, :]
          ).expand(nb_gaussians, max_u, max_v)  # [N, max_u, max_v]
mask = (span_indices_u[None, :, None] < n_u[:, None, None]
        ) & (span_indices_v[None, None, :] < n_v[:, None, None])  # [N, max_u, max_v]
flat_tile_u = tile_u[mask]  # [0, 0, 1, 1, 2, ...]
flat_tile_v = tile_v[mask]  # [0, 1, 0, 1, 2]

nb_tiles_per_gaussian = n_u * n_v  # [N]
gaussian_ids = torch.repeat_interleave(
    torch.arange(nb_gaussians, device=pos.device, dtype=torch.int64),
    nb_tiles_per_gaussian)  # [0, 0, 0, 0, 1 ...]
nb_tiles_u = (W + T - 1) // T
flat_tile_id = flat_tile_v * nb_tiles_u + flat_tile_u  # [0, 0, 0, 0, 1 ...]

idx_z_order = torch.arange(nb_gaussians, device=pos.device, dtype=torch.int64)
M = nb_gaussians + 1
comp = flat_tile_id * M + idx_z_order[gaussian_ids]
comp_sorted, perm = torch.sort(comp)
gaussian_ids = gaussian_ids[perm]
tile_ids_1d = torch.div(comp_sorted, M, rounding_mode='floor')

# tile_ids_1d [0, 0, 0, 1, 1, 2, 2, 2, 2]
# nb_gaussian_per_tile [3, 2, 4]
# start [0, 3, 5]
# end [3, 5, 9]
unique_tile_ids, nb_gaussian_per_tile = torch.unique_consecutive(tile_ids_1d, return_counts=True)
start = torch.zeros_like(unique_tile_ids)
start[1:] = torch.cumsum(nb_gaussian_per_tile[:-1], dim=0)
end = start + nb_gaussian_per_tile

4. Inverse Covariance Matrix

To compute each Gaussian's opacity contribution we first need its Gaussian probability density, Eq. (4) in the paper, which requires the inverse covariance matrix. It is computed directly from the covariance in camera space.
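
For reference, Eq. (4) evaluates a 2D Gaussian at the pixel offset Δ from its projected mean,

G(Δ) = exp(−0.5 · Δᵀ Σ'⁻¹ Δ)

which is why only the inverse of the camera-space covariance Σ' is needed at render time.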

inverse_covariance = inv2x2(sigma_camera)
inverse_covariance[:, 0, 0] = torch.clamp(
        inverse_covariance[:, 0, 0], min=min_conis)
inverse_covariance[:, 1, 1] = torch.clamp(
        inverse_covariance[:, 1, 1], min=min_conis)

5. Pixel Coordinates

Eq. (4) also depends on the distance from each pixel to the Gaussian center. Within each tile we first compute the screen-space pixel coordinates; the next step uses them to evaluate Eq. (4) for all Gaussians efficiently.

txi = tile_id % nb_tiles_u
tyi = tile_id // nb_tiles_u
x0, y0 = txi * T, tyi * T
x1, y1 = min((txi + 1) * T, W), min((tyi + 1) * T, H)
if x0 >= x1 or y0 >= y1:
    continue

xs = torch.arange(x0, x1, device=pos.device, dtype=pos.dtype)
ys = torch.arange(y0, y1, device=pos.device, dtype=pos.dtype)
pu, pv = torch.meshgrid(xs, ys, indexing='xy')
px_u = pu.reshape(-1)  # [T * T]
px_v = pv.reshape(-1)
pixel_idx_1d = (px_v * W + px_u).to(torch.int64)

6. The Volume Rendering Equation

Finally, colors are accumulated with standard alpha compositing. Each Gaussian contributes its color weighted by its own opacity and by the transmittance accumulated along the viewing ray. This step corresponds to the core volume rendering principle in NeRF.
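
For reference, front-to-back alpha compositing computes

C = Σ_i c_i · α_i · T_i,   T_i = Π_{j<i} (1 − α_j),   α_i = o_i · G(Δ_i)

where the sum runs over the depth-sorted Gaussians overlapping the pixel; the cumprod in the code below implements T_i.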

gaussian_i_u = u[current_gaussian_ids]  # [N]
gaussian_i_v = v[current_gaussian_ids]  # [N]
gaussian_i_color = color[current_gaussian_ids]  # [N, 3]
gaussian_i_opacity = opacity[current_gaussian_ids]  # [N]
gaussian_i_inverse_covariance = inverse_covariance[current_gaussian_ids]  # [N, 2, 2]

du = px_u.unsqueeze(0) - gaussian_i_u.unsqueeze(-1)  # [N, T * T]
dv = px_v.unsqueeze(0) - gaussian_i_v.unsqueeze(-1)  # [N, T * T]
A11 = gaussian_i_inverse_covariance[:, 0, 0].unsqueeze(-1)  # [N, 1]
A12 = gaussian_i_inverse_covariance[:, 0, 1].unsqueeze(-1)
A22 = gaussian_i_inverse_covariance[:, 1, 1].unsqueeze(-1)
q = A11 * du * du + 2 * A12 * du * dv + A22 * dv * dv   # [N, T * T]

inside = q <= chi_square_clip
g = torch.exp(-0.5 * torch.clamp(q, max=chi_square_clip))  # [N, T * T]
g = torch.where(inside, g, torch.zeros_like(g))
alpha_i = (gaussian_i_opacity.unsqueeze(-1) * g).clamp_max(alpha_max)  # [N, T * T]
alpha_i = torch.where(alpha_i >= alpha_cutoff, alpha_i, torch.zeros_like(alpha_i))
one_minus_alpha_i = 1 - alpha_i  # [N, T * T]

T_i = torch.cumprod(one_minus_alpha_i, dim=0)
T_i = torch.concatenate([
    torch.ones((1, alpha_i.shape[-1]), device=pos.device, dtype=pos.dtype),
    T_i[:-1]], dim=0)
alive = (T_i > 1e-4).float()
w = alpha_i * T_i * alive  # [N, T * T]

final_image[pixel_idx_1d] = (w.unsqueeze(-1) * gaussian_i_color.unsqueeze(1)).sum(dim=0)

The Spherical Harmonics Representation

Rather than assigning a single RGB value to each Gaussian, Spherical Harmonics (SH) represent the color as a smooth function of the viewing direction.

Roughly speaking, spherical harmonics are the Fourier basis on the sphere: a function defined over all directions (such as a view-dependent color) can be decomposed into a weighted sum of basis functions. Each Gaussian learns a set of SH coefficients that encode how its color varies with viewing direction, capturing effects such as specular highlights or surface-dependent lighting.
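
In compact form, the color of a Gaussian seen from direction d is

color(d) = sigmoid( Σ_{l,m} c_{l,m} · Y_{l,m}(d) )

where the Y_{l,m} are the SH basis functions (16 of them for degree 3) and the c_{l,m} are the learned coefficients stored in f_dc and f_rest.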

The following code evaluates the SH color for each Gaussian:

SH_C0 = 0.28209479177387814
SH_C1_x = 0.4886025119029199
SH_C1_y = 0.4886025119029199
SH_C1_z = 0.4886025119029199
SH_C2_xy = 1.0925484305920792
SH_C2_xz = 1.0925484305920792
SH_C2_yz = 1.0925484305920792
SH_C2_zz = 0.31539156525252005
SH_C2_xx_yy = 0.5462742152960396
SH_C3_yxx_yyy = 0.5900435899266435
SH_C3_xyz = 2.890611442640554
SH_C3_yzz_yxx_yyy = 0.4570457994644658
SH_C3_zzz_zxx_zyy = 0.3731763325901154
SH_C3_xzz_xxx_xyy = 0.4570457994644658
SH_C3_zxx_zyy = 1.445305721320277
SH_C3_xxx_xyy = 0.5900435899266435

def evaluate_sh(f_dc, f_rest, points, c2w):

    sh = torch.empty((points.shape[0], 16, 3),
                     device=points.device, dtype=points.dtype)
    sh[:, 0] = f_dc
    sh[:, 1:, 0] = f_rest[:, :15]  # R
    sh[:, 1:, 1] = f_rest[:, 15:30]  # G
    sh[:, 1:, 2] = f_rest[:, 30:45]  # B

    view_dir = points - c2w[:3, 3].unsqueeze(0)  # [N, 3]
    view_dir = view_dir / (view_dir.norm(dim=-1, keepdim=True) + 1e-8)
    x, y, z = view_dir[:, 0], view_dir[:, 1], view_dir[:, 2]

    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z

    Y0 = torch.full_like(x, SH_C0)  # [N]
    Y1 = - SH_C1_y * y
    Y2 = SH_C1_z * z
    Y3 = - SH_C1_x * x
    Y4 = SH_C2_xy * xy
    Y5 = SH_C2_yz * yz
    Y6 = SH_C2_zz * (3 * zz - 1)
    Y7 = SH_C2_xz * xz
    Y8 = SH_C2_xx_yy * (xx - yy)
    Y9 = SH_C3_yxx_yyy * y * (3 * xx - yy)
    Y10 = SH_C3_xyz * x * y * z
    Y11 = SH_C3_yzz_yxx_yyy * y * (4 * zz - xx - yy)
    Y12 = SH_C3_zzz_zxx_zyy * z * (2 * zz - 3 * xx - 3 * yy)
    Y13 = SH_C3_xzz_xxx_xyy * x * (4 * zz - xx - yy)
    Y14 = SH_C3_zxx_zyy * z * (xx - yy)
    Y15 = SH_C3_xxx_xyy * x * (xx - 3 * yy)
    Y = torch.stack([Y0, Y1, Y2, Y3, Y4, Y5, Y6, Y7, Y8, Y9, Y10, Y11, Y12, Y13, Y14, Y15],
                    dim=1)  # [N, 16]
    return torch.sigmoid((sh * Y.unsqueeze(2)).sum(dim=1))

This representation is compact yet expressive enough to model complex view-dependent effects, without running a neural network at render time. In practice, 3DGS typically uses degree-3 spherical harmonics, i.e. 16 coefficients per color channel, which strikes a good balance between visual realism and memory efficiency.

A common interview question: how much storage does a 3DGS representation with 1 million Gaussians need?

Each Gaussian stores 48 SH coefficients (16×3), 3 floats for position, 1 for opacity, 3 for scale, and 4 for the rotation quaternion: 59 floats in total. At 4 bytes per float that is about 236 bytes per Gaussian, so 1 million Gaussians take roughly 225 MB.

A follow-up question: what if only degree-2 spherical harmonics are used?

Degree-2 SH has 9 basis functions per channel, so 9×3 = 27 coefficients instead of 48. The total drops from 59 to 38 floats per Gaussian, saving roughly 35% of the memory.
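
As a quick sanity check, here is a small standalone sketch that reproduces the arithmetic above (plain Python, independent of the renderer):

BYTES_PER_FLOAT = 4
N_GAUSSIANS = 1_000_000

def floats_per_gaussian(sh_degree):
    sh_coeffs = (sh_degree + 1) ** 2 * 3   # SH coefficients across the 3 color channels
    return sh_coeffs + 3 + 1 + 3 + 4       # + position, opacity, scale, quaternion

for degree in (3, 2):
    n_floats = floats_per_gaussian(degree)
    mb = N_GAUSSIANS * n_floats * BYTES_PER_FLOAT / 2**20
    print(f"degree {degree}: {n_floats} floats per Gaussian, ~{mb:.0f} MB for 1M Gaussians")
# degree 3: 59 floats per Gaussian, ~225 MB for 1M Gaussians
# degree 2: 38 floats per Gaussian, ~145 MB for 1M Gaussians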

Helper Functions

Now let's implement the helper functions used above. First, the inverse of a 2×2 matrix, which is straightforward:

def inv2x2(M, eps=1e-12):
    a = M[:, 0, 0]
    b = M[:, 0, 1]
    c = M[:, 1, 0]
    d = M[:, 1, 1]
    det = a * d - b * c
    safe_det = torch.clamp(det, min=eps)
    inv = torch.empty_like(M)
    inv[:, 0, 0] = d / safe_det
    inv[:, 0, 1] = -b / safe_det
    inv[:, 1, 0] = -c / safe_det
    inv[:, 1, 1] = a / safe_det
    return inv

Next is the perspective camera projection, which maps 3D points from world space onto the 2D image plane. This step uses the camera intrinsics and extrinsics to determine each Gaussian's position in screen space.

Finally, the 3D covariance matrix is built from the learned parameters. Each Gaussian is defined by a scale and a rotation quaternion, which determine its shape and orientation in 3D space: the quaternion is converted to a rotation matrix and combined with the diagonal scale matrix to form the full 3D covariance, which is projected into camera space at render time. A sketch of this function is given after the project_points code below.

def project_points(pc, c2w, fx, fy, cx, cy):
    w2c = torch.eye(4, device=pc.device)
    R = c2w[:3, :3]
    t = c2w[:3, 3]
    w2c[:3, :3] = R.t()
    w2c[:3, 3] = -R.t() @ t

    PC = ((w2c @ torch.concatenate(
        [pc, torch.ones_like(pc[:, :1])], dim=1).t()).t())[:, :3]
    x, y, z = PC[:, 0], PC[:, 1], PC[:, 2]  # Camera space

    uv = torch.stack([fx * x / z + cx, fy * y / z + cy], dim=-1)
    return uv, x, y, z
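
The original listing does not include build_sigma_from_params, which the main block below calls. A minimal sketch, assuming the usual 3DGS conventions (log-space scales, unnormalized quaternion in wxyz order) so that it matches the call build_sigma_from_params(scale_raw, q_raw):

def build_sigma_from_params(scale_raw, q_raw):
    # Scales are assumed to be stored in log space
    S = torch.exp(scale_raw)  # [N, 3]

    # Normalize the quaternion (assumed wxyz order) and convert it to a rotation matrix
    q = q_raw / q_raw.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)

    # Eq. (6): Sigma = R S S^T R^T, with S the diagonal scale matrix
    M = R * S.unsqueeze(1)   # equivalent to R @ diag(S)
    return M @ M.transpose(1, 2)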

Putting It All Together

Finally, we assemble all the pieces:

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

if __name__ == "__main__":

    pos = torch.load('trained_gaussians/kitchen/pos_7000.pt').cuda()
    opacity_raw = torch.load('trained_gaussians/kitchen/opacity_raw_7000.pt').cuda()
    f_dc = torch.load('trained_gaussians/kitchen/f_dc_7000.pt').cuda()
    f_rest = torch.load('trained_gaussians/kitchen/f_rest_7000.pt').cuda()
    scale_raw = torch.load('trained_gaussians/kitchen/scale_raw_7000.pt').cuda()
    q_raw = torch.load('trained_gaussians/kitchen/q_rot_7000.pt').cuda()

    cam_parameters = np.load('out_colmap/kitchen/cam_meta.npy', 
                             allow_pickle=True).item()
    orbit_c2ws = torch.load('camera_trajectories/kitchen_orbit.pt').cuda()

    sigma = build_sigma_from_params(scale_raw, q_raw)

    with torch.no_grad():
        for i, c2w_i in tqdm(enumerate(orbit_c2ws)):

            c2w = c2w_i
            H = cam_parameters['height'] // 2
            W = cam_parameters['width'] // 2
            H_src = cam_parameters['height']
            W_src = cam_parameters['width']
            fx, fy = cam_parameters['fx'], cam_parameters['fy']
            cx, cy = W_src / 2, H_src / 2
            fx, fy, cx, cy = scale_intrinsics(H, W, H_src, W_src, fx, fy, cx, cy)

            color = evaluate_sh(f_dc, f_rest, pos, c2w)
            img = render(pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy)

            Image.fromarray((img.cpu().detach().numpy() * 255).astype(np.uint8)
                           ).save(f'novel_views/frame_{i:04d}.png')
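
The main block also calls scale_intrinsics, which is not shown in the original listing. A minimal sketch, assuming it simply rescales the focal lengths and principal point to the rendering resolution:

def scale_intrinsics(H, W, H_src, W_src, fx, fy, cx, cy):
    # Rescale intrinsics from the source resolution to the render resolution
    sx, sy = W / W_src, H / H_src
    return fx * sx, fy * sy, cx * sx, cy * sy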

Summary

In this article we implemented the complete 3D Gaussian Splatting rendering pipeline in pure PyTorch, in just a few hundred lines of code. The implementation revolves around two core stages: a preprocessing stage that projects the 3D Gaussians onto the 2D image plane, sorts them by depth, and splits them into tiles; and a rendering stage that performs alpha compositing via the volume rendering equation.

Full code:

https://avoid.overfit.cn/post/0da3f0cac46049a7a3fdb063fb50f757
