
  • [Code Review] Mip-NeRF Code Breakdown
    NeRF 2023. 7. 26. 17:58

    * This post is based on the PyTorch implementation of Mip-NeRF by Kakao.

    ▶ Mip-NeRF paper review: 2023.07.23 - [Papers] - [Paper Review] Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields

     


    ▶ NeRF code review: 2023.07.03 - [NeRF] - [Code Review] NeRF Code Breakdown

     



    This post reviews only the parts that differ from the NeRF code.

     

    1. Cone tracing - Sampling
    2. Integrated positional encoding (IPE)

     

    Cone Tracing

    Unlike the original NeRF, which samples points along each ray, Mip-NeRF casts a cone and samples it as a sequence of equally spaced volumes (conical frustums), a procedure called cone tracing.

    1. Casting Cone

    This step uses $\mu_t$, $\sigma_t^2$, and $\sigma_r^2$ to construct a multivariate Gaussian. Here, the multivariate Gaussian characterizes the volume inside each conical frustum.

    
      
    def cast_rays(t_vals, origins, directions, radii, ray_shape):
        t0 = t_vals[..., :-1]
        t1 = t_vals[..., 1:]
        if ray_shape == "cone":
            gaussian_fn = conical_frustum_to_gaussian
        elif ray_shape == "cylinder":
            gaussian_fn = cylinder_to_gaussian
        else:
            assert False
        means, covs = gaussian_fn(directions, t0, t1, radii)
        means = means + origins[..., None, :]
        return means, covs

    1-1) Conical Frustum

    The set of points $\mathbf{x}$ inside the conical frustum between two $t$ values is

    $$F(\mathbf{x},\mathbf{o},\mathbf{d},\dot r,t_0,t_1)=\mathbb{1}\left\{\left(t_0<\frac{\mathbf{d}^{\top}(\mathbf{x}-\mathbf{o})}{\|\mathbf{d}\|_2^2}<t_1\right)\wedge\left(\frac{\mathbf{d}^{\top}(\mathbf{x}-\mathbf{o})}{\|\mathbf{d}\|_2\,\|\mathbf{x}-\mathbf{o}\|_2}>\frac{1}{\sqrt{1+(\dot r/\|\mathbf{d}\|_2)^2}}\right)\right\}$$

    If $\mathbf{x}$ lies inside the conical frustum defined by $(\mathbf{o},\mathbf{d},\dot r,t_0,t_1)$, then $F(\mathbf{x},\cdot)=1$. The positional encoding over all coordinates inside the conical frustum is then

    $$\gamma^{*}(\mathbf{o},\mathbf{d},\dot r,t_0,t_1)=\frac{\int\gamma(\mathbf{x})\,F(\mathbf{x},\mathbf{o},\mathbf{d},\dot r,t_0,t_1)\,d\mathbf{x}}{\int F(\mathbf{x},\mathbf{o},\mathbf{d},\dot r,t_0,t_1)\,d\mathbf{x}}$$

    To approximate this integral with a Gaussian, we first compute $\mu_t$, $\sigma_t^2$, and $\sigma_r^2$.
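The indicator $F$ can be checked directly for a single point. Below is a minimal NumPy sketch (my own verification helper, `in_frustum`, which is not part of the Kakao code):

```python
import numpy as np

def in_frustum(x, o, d, r_dot, t0, t1):
    """Evaluate the indicator F(x, o, d, r_dot, t0, t1) for one point x."""
    diff = x - o
    proj = np.dot(d, diff) / np.dot(d, d)                 # d^T(x-o) / ||d||^2
    cos_angle = np.dot(d, diff) / (np.linalg.norm(d) * np.linalg.norm(diff))
    cos_cutoff = 1.0 / np.sqrt(1.0 + (r_dot / np.linalg.norm(d)) ** 2)
    return bool((t0 < proj < t1) and (cos_angle > cos_cutoff))

o = np.zeros(3)
d = np.array([0.0, 0.0, 1.0])   # unit ray direction
r_dot = 0.1                     # cone radius at unit distance

print(in_frustum(o + 1.5 * d, o, d, r_dot, t0=1.0, t1=2.0))  # True: on the axis, between t0 and t1
print(in_frustum(o + 3.0 * d, o, d, r_dot, t0=1.0, t1=2.0))  # False: beyond t1
```

Points far off the cone axis fail the second (angular) condition even when their projection lies between $t_0$ and $t_1$.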

    
      
    def conical_frustum_to_gaussian(d, t0, t1, radius):
        mu = (t0 + t1) / 2  # midpoint
        delta = (t1 - t0) / 2  # half-width
        t_mean = mu + (2 * mu * delta**2) / (3 * mu**2 + delta**2)  # mean distance along the ray
        t_var = (delta**2) / 3 - (4 / 15) * (
            (delta**4 * (12 * mu**2 - delta**2)) / (3 * mu**2 + delta**2) ** 2
        )  # variance along the ray
        r_var = radius**2 * (
            (mu**2) / 4
            + (5 / 12) * delta**2
            - 4 / 15 * (delta**4) / (3 * mu**2 + delta**2)
        )  # variance perpendicular to the ray
        return translate_gaussian(d, t_mean, t_var, r_var)
    • Midpoint: $t_\mu=(t_0+t_1)/2$
    • Half-width of the interval: $t_\delta=(t_1-t_0)/2$
    • Mean distance along the ray: $\mu_t=t_\mu+\dfrac{2t_\mu t_\delta^2}{3t_\mu^2+t_\delta^2}$
    • Variance along the ray: $\sigma_t^2=\dfrac{t_\delta^2}{3}-\dfrac{4t_\delta^4(12t_\mu^2-t_\delta^2)}{15(3t_\mu^2+t_\delta^2)^2}$
    • Variance perpendicular to the ray: $\sigma_r^2=\dot r^2\left(\dfrac{t_\mu^2}{4}+\dfrac{5t_\delta^2}{12}-\dfrac{4t_\delta^4}{15(3t_\mu^2+t_\delta^2)}\right)$
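These closed-form moments can be checked against the exact moments of $t$ inside a cone, where the cross-sectional area (and hence the volume density) grows as $t^2$, so $\mathbb{E}[t^n]=\int_{t_0}^{t_1}t^{n+2}\,dt\big/\int_{t_0}^{t_1}t^{2}\,dt$. A quick numerical check (my own verification sketch, not from the Kakao code):

```python
def frustum_t_moments(t0, t1):
    """Exact mean/variance of t inside a conical frustum (density proportional to t^2)."""
    m0 = (t1**3 - t0**3) / 3   # integral of t^2
    m1 = (t1**4 - t0**4) / 4   # integral of t^3
    m2 = (t1**5 - t0**5) / 5   # integral of t^4
    mean = m1 / m0
    var = m2 / m0 - mean**2
    return mean, var

def closed_form(t0, t1):
    """mu_t and sigma_t^2 from the bullet list above."""
    mu, delta = (t0 + t1) / 2, (t1 - t0) / 2
    t_mean = mu + (2 * mu * delta**2) / (3 * mu**2 + delta**2)
    t_var = delta**2 / 3 - (4 / 15) * (
        delta**4 * (12 * mu**2 - delta**2) / (3 * mu**2 + delta**2) ** 2
    )
    return t_mean, t_var

print(frustum_t_moments(1.0, 1.25))
print(closed_form(1.0, 1.25))   # agrees to floating-point precision
```

The agreement is exact (up to floating point): the paper's expressions are the true moments of the frustum, merely rewritten in a numerically stable form; only the Gaussian itself is an approximation.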

    1-2) Cylinder

    
      
    def cylinder_to_gaussian(d, t0, t1, radius):
        t_mean = (t0 + t1) / 2
        r_var = radius**2 / 4
        t_var = (t1 - t0) ** 2 / 12
        return translate_gaussian(d, t_mean, t_var, r_var)
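The cylinder case uses standard uniform-distribution moments: $t$ uniform on $[t_0,t_1]$ has variance $(t_1-t_0)^2/12$, and a coordinate of a point uniform on a disk of radius $r$ has variance $r^2/4$. A quick Monte-Carlo check (my own sketch, with a fixed seed):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# variance of t uniform on [t0, t1] -> (t1 - t0)^2 / 12
t0, t1 = 2.0, 5.0
t = rng.uniform(t0, t1, n)
print(t.var(), (t1 - t0) ** 2 / 12)   # both close to 0.75

# per-coordinate variance of a uniform point on a disk of radius r -> r^2 / 4
r = 1.0
rad = r * np.sqrt(rng.uniform(0.0, 1.0, n))   # sqrt gives uniform area density
theta = rng.uniform(0.0, 2 * np.pi, n)
x = rad * np.cos(theta)
print(x.var(), r**2 / 4)              # both close to 0.25
```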

    2) Translate to Gaussian

    This step transforms the Gaussian from the conical frustum's own coordinate frame into world coordinates:

    $$\boldsymbol{\mu}=\mathbf{o}+\mu_t\mathbf{d},\qquad\boldsymbol{\Sigma}=\sigma_t^2(\mathbf{d}\mathbf{d}^{\top})+\sigma_r^2\left(\mathbf{I}-\frac{\mathbf{d}\mathbf{d}^{\top}}{\|\mathbf{d}\|_2^2}\right)$$

    * The code below is named lift_gaussian(.) in the Kakao implementation, but since that clashes with the paper's meaning of Gaussian lifting, I renamed the function here.

    
      
    def translate_gaussian(d, t_mean, t_var, r_var):
        mean = d[..., None, :] * t_mean[..., None]
        d_mag_sq = torch.sum(d**2, dim=-1, keepdim=True)
        thresholds = torch.ones_like(d_mag_sq) * 1e-10
        d_mag_sq = torch.fmax(d_mag_sq, thresholds)  # clamp to 1e-10 rather than 0, since we divide by it
        d_outer_diag = d**2
        null_outer_diag = 1 - d_outer_diag / d_mag_sq
        t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
        xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
        cov_diag = t_cov_diag + xy_cov_diag
        return mean, cov_diag
    • mean: $\boldsymbol{\mu}=\mathbf{o}+\mu_t\mathbf{d}$ (the origin $\mathbf{o}$ is added in cast_rays)
    • d_mag_sq: $\|\mathbf{d}\|_2^2$
    • d_outer_diag: $\operatorname{diag}(\mathbf{d}\mathbf{d}^{\top})$
    • null_outer_diag: $\operatorname{diag}\!\big(\mathbf{I}-\mathbf{d}\mathbf{d}^{\top}/\|\mathbf{d}\|_2^2\big)$
    • t_cov_diag: $\sigma_t^2\operatorname{diag}(\mathbf{d}\mathbf{d}^{\top})$
    • xy_cov_diag: $\sigma_r^2\operatorname{diag}\!\big(\mathbf{I}-\mathbf{d}\mathbf{d}^{\top}/\|\mathbf{d}\|_2^2\big)$
    • cov_diag: $\operatorname{diag}(\boldsymbol{\Sigma})$, the sum of the two terms above
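For intuition: with a unit direction along the $z$-axis, the diagonal covariance reduces to $(\sigma_r^2,\sigma_r^2,\sigma_t^2)$ — perpendicular variance in $x,y$ and along-ray variance in $z$. A minimal NumPy sketch of that diagonal-only math (mirroring translate_gaussian for a single ray):

```python
import numpy as np

def diag_cov(d, t_var, r_var):
    """Diagonal of Sigma = sigma_t^2 (dd^T) + sigma_r^2 (I - dd^T/||d||^2)."""
    d = np.asarray(d, dtype=float)
    d_mag_sq = max(np.dot(d, d), 1e-10)       # clamp like the PyTorch code
    d_outer_diag = d**2                        # diag(dd^T)
    null_outer_diag = 1 - d_outer_diag / d_mag_sq
    return t_var * d_outer_diag + r_var * null_outer_diag

print(diag_cov([0.0, 0.0, 1.0], t_var=0.09, r_var=0.01))
# -> [0.01, 0.01, 0.09]
```

Storing only the diagonal is sufficient because the IPE (next section) needs only the per-dimension variances, not the full covariance matrix.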

    2. Sampling

    
      
    def sample_along_rays(
        rays_o,
        rays_d,
        radii,
        num_samples,
        near,
        far,
        randomized,
        lindisp,
        ray_shape,
    ):
        bsz = rays_o.shape[0]
        t_vals = torch.linspace(0.0, 1.0, num_samples + 1, device=rays_o.device)
        if lindisp:
            t_vals = 1.0 / (1.0 / near * (1.0 - t_vals) + 1.0 / far * t_vals)
        else:
            t_vals = near * (1.0 - t_vals) + far * t_vals
        if randomized:
            mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
            upper = torch.cat([mids, t_vals[..., -1:]], -1)
            lower = torch.cat([t_vals[..., :1], mids], -1)
            t_rand = torch.rand((bsz, num_samples + 1), device=rays_o.device)
            t_vals = lower + (upper - lower) * t_rand
        else:
            t_vals = torch.broadcast_to(t_vals, (bsz, num_samples + 1))
        means, covs = cast_rays(t_vals, rays_o, rays_d, radii, ray_shape)
        return t_vals, (means, covs)
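The two depth parameterizations above can be sketched in isolation: lindisp=False spaces the $t$ values linearly in depth, while lindisp=True spaces them linearly in disparity (inverse depth), concentrating samples near the camera. A small NumPy sketch of just that spacing logic:

```python
import numpy as np

def make_t_vals(near, far, num_samples, lindisp):
    t = np.linspace(0.0, 1.0, num_samples + 1)
    if lindisp:
        # linear in inverse depth (disparity)
        return 1.0 / (1.0 / near * (1.0 - t) + 1.0 / far * t)
    # linear in depth
    return near * (1.0 - t) + far * t

print(make_t_vals(2.0, 6.0, 4, lindisp=False))  # [2. 3. 4. 5. 6.]
print(make_t_vals(2.0, 6.0, 4, lindisp=True))   # [2.  2.4 3.  4.  6. ]
```

With lindisp=True the intervals grow with distance, so nearby geometry gets finer sampling; the randomized branch then jitters one sample uniformly within each interval (stratified sampling).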

    Integrated Positional Encoding

    The IPE is computed from the multivariate Gaussians produced by cone tracing.

    
      
    t_vals, samples = sample_along_rays(
        rays_o,
        rays_d,
        radius,
        num_samples,
        near,
        far,
        randomized,
        lindisp,
        ray_shape,
    )
    samples_enc = integrated_pos_enc(samples, 2, 16)

    def integrated_pos_enc(samples, min_deg=0, max_deg=16):
        scales, shape = pe_fourier(samples, min_deg, max_deg)
        y, y_var = lift_gaussian(samples, scales, shape)
        samples_enc = expected_sin(
            torch.cat([y, y + 0.5 * np.pi], axis=-1), torch.cat([y_var] * 2, axis=-1)
        )[0]
        return samples_enc

    1) Rewrite the PE as a Fourier feature

     

    $$\mathbf{P}=\begin{bmatrix}1&0&0&2&0&0&\cdots&2^{L-1}&0&0\\0&1&0&0&2&0&\cdots&0&2^{L-1}&0\\0&0&1&0&0&2&\cdots&0&0&2^{L-1}\end{bmatrix}^{\!\top},\qquad\gamma(\mathbf{x})=\begin{bmatrix}\sin(\mathbf{P}\mathbf{x})\\\cos(\mathbf{P}\mathbf{x})\end{bmatrix}$$

    
      
    def pe_fourier(samples, min_deg=0, max_deg=16):
        x, x_cov_diag = samples
        scales = torch.tensor([2**i for i in range(min_deg, max_deg)]).type_as(x)
        shape = list(x.shape[:-1]) + [-1]
        return scales, shape
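Note that the code never materializes $\mathbf{P}$: multiplying by $\mathbf{P}$ is equivalent to broadcasting each scale $2^i$ against $\mathbf{x}$ and flattening, which is what pe_fourier and lift_gaussian compute. A small check of that equivalence (my own sketch; the explicit $\mathbf{P}$ is built only for comparison):

```python
import numpy as np

min_deg, max_deg = 0, 4
scales = np.array([2.0**i for i in range(min_deg, max_deg)])

# explicit P: stacked blocks 2^i * I_3, so P @ x = concat_i (2^i * x)
P = np.concatenate([s * np.eye(3) for s in scales], axis=0)   # shape (12, 3)

x = np.array([0.3, -1.2, 2.0])
via_matrix = P @ x                                             # (12,)
via_broadcast = (x[None, :] * scales[:, None]).reshape(-1)     # broadcast-and-flatten
print(np.allclose(via_matrix, via_broadcast))  # True
```

The same identity applied to the covariance gives the diagonal of $\mathbf{P}\boldsymbol{\Sigma}\mathbf{P}^{\top}$ as the per-dimension variances scaled by $4^i$, which is the `scales[:, None] ** 2` factor in lift_gaussian.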

    2) Lift the multivariate Gaussian

    $$\boldsymbol{\mu}_\gamma=\mathbf{P}\boldsymbol{\mu},\qquad\boldsymbol{\Sigma}_\gamma=\mathbf{P}\boldsymbol{\Sigma}\mathbf{P}^{\top}$$

    
      
    def lift_gaussian(samples, scales, shape):
        x, x_cov_diag = samples
        y = torch.reshape(x[..., None, :] * scales[:, None], shape)
        y_var = torch.reshape(x_cov_diag[..., None, :] * scales[:, None] ** 2, shape)
        return y, y_var

    3) Expectations over the lifted multivariate Gaussian

    $$\mathbb{E}_{x\sim\mathcal{N}(\mu,\sigma^2)}[\sin(x)]=\sin(\mu)\exp\!\left(-\tfrac{1}{2}\sigma^2\right),\qquad\mathbb{E}_{x\sim\mathcal{N}(\mu,\sigma^2)}[\cos(x)]=\cos(\mu)\exp\!\left(-\tfrac{1}{2}\sigma^2\right)$$

    
      
    def expected_sin(x, x_var):
        y = torch.exp(-0.5 * x_var) * torch.sin(x)
        y_var = 0.5 * (1 - torch.exp(-2 * x_var) * torch.cos(2 * x)) - y**2
        y_var = torch.fmax(torch.zeros_like(y_var), y_var)
        return y, y_var
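The closed form is easy to verify by Monte Carlo: for $x\sim\mathcal{N}(\mu,\sigma^2)$, the sample mean of $\sin(x)$ should approach $\sin(\mu)e^{-\sigma^2/2}$, and the sample variance should approach the y_var expression above. A quick NumPy check (my own sketch, independent of the PyTorch code):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.7, 0.5
x = rng.normal(mu, sigma, 500_000)

mc_mean = np.sin(x).mean()
closed_mean = np.sin(mu) * np.exp(-0.5 * sigma**2)
print(mc_mean, closed_mean)   # both approximately 0.57

# Var[sin x] = (1 - cos(2mu) exp(-2 sigma^2)) / 2 - E[sin x]^2, as in expected_sin
mc_var = np.sin(x).var()
closed_var = 0.5 * (1 - np.exp(-2 * sigma**2) * np.cos(2 * mu)) - closed_mean**2
print(mc_var, closed_var)
```

The $e^{-\sigma^2/2}$ factor is what makes the IPE anti-aliased: high-frequency components (large $2^i$, hence large lifted variance $4^i\sigma^2$) are smoothly attenuated toward zero instead of oscillating.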