  • [Code Review] NeRF Code Breakdown
    NeRF 2023. 7. 3. 15:53

    (Work in progress)

     

    Prepare Dataset

    # Imports used throughout this post (Colab-style notebook).
    import os
    from typing import Optional, Tuple, List, Callable

    import numpy as np
    import torch
    from torch import nn
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from matplotlib.colors import Normalize
    from tqdm import trange

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not os.path.exists('tiny_nerf_data.npz'):
      !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
      
    data = np.load('tiny_nerf_data.npz')
    images = data['images']
    poses = data['poses']
    focal = data['focal']
    
    print(f'Images shape: {images.shape}')
    print(f'Poses shape: {poses.shape}')
    print(f'Focal length: {focal}')
    
    height, width = images.shape[1:3]
    near, far = 2., 6.
    
    n_training = 100
    testimg_idx = 101
    testimg, testpose = images[testimg_idx], poses[testimg_idx]
    
    plt.imshow(testimg)
    print('Pose')
    print(testpose)

     


    Images shape: (106, 100, 100, 3)
    Poses shape: (106, 4, 4)
    Focal length: 138.88887889922103
    Pose
    [[ 6.8935126e-01  5.3373039e-01 -4.8982298e-01 -1.9745398e+00]
     [-7.2442728e-01  5.0788772e-01 -4.6610624e-01 -1.8789345e+00]
     [ 1.4901163e-08  6.7615211e-01  7.3676193e-01  2.9699826e+00]
     [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]

     

    Origins and Directions

    (a) one of the 106 poses and (b) pose to direction

     

    dirs = np.stack([np.sum([0, 0, -1] * pose[:3, :3], axis=-1) for pose in poses]) # (106, 3)
    origins = poses[:, :3, -1] # (106, 3)
    
    ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')
    _ = ax.quiver(
      origins[..., 0].flatten(), # X (pose)
      origins[..., 1].flatten(), # Y (pose)
      origins[..., 2].flatten(), # Z (pose)
      dirs[..., 0].flatten(), # U (dir)
      dirs[..., 1].flatten(), # V (dir)
      dirs[..., 2].flatten(), # W (dir)
      length=0.5, normalize=True)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('z')
    plt.show()
    dir
    print(np.stack([[0, 0, -1] * pose[:3, :3] for pose in poses])[0])
    # [[-0. 0. 0.01334572] [-0. -0. -0.95394367] [-0. 0. -0.29968831]]
    dirs = np.stack([np.sum([0, 0, -1] * pose[:3, :3], axis=-1) for pose in poses])
    # equivalent to: dirs = np.stack([[0, 0, -1] @ pose[:3, :3].T for pose in poses])
    print(dirs[0])
    # [ 0.01334572 -0.95394367 -0.29968831]

    origin

    print(poses[:, :3, -1][0])
    # [-0.05379832  3.8454704   1.2080823 ]

    For a single object, there are 106 poses.
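
    To spell out the convention used above (my notation, not from the original code): each camera-to-world pose stores a rotation $R$ and a translation $\mathbf{t}$, and the arrow plotted for a pose is

    $$\mathbf{o} = \mathbf{t}, \qquad \mathbf{d} = R \, [0, 0, -1]^\top,$$

    i.e. the camera center and the camera's $-z$ axis rotated into world coordinates, which is exactly what `origins = poses[:, :3, -1]` and `np.sum([0, 0, -1] * pose[:3, :3], axis=-1)` compute.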

     

    def get_rays(
      height: int,
      width: int,
      focal_length: float,
      c2w: torch.Tensor # camera-to-world pose matrix of the view we want to cast rays from
    ) -> Tuple[torch.Tensor, torch.Tensor]:
      r"""
      Find origin and direction of rays through every pixel and camera origin.
      """
    
      # Apply pinhole camera model to gather directions at each pixel
      i, j = torch.meshgrid(
          torch.arange(width, dtype=torch.float32).to(c2w),
          torch.arange(height, dtype=torch.float32).to(c2w),
          indexing='ij')
      i, j = i.transpose(-1, -2), j.transpose(-1, -2)
      directions = torch.stack([(i - width * .5) / focal_length,
                                -(j - height * .5) / focal_length,
                                -torch.ones_like(i)
                               ], dim=-1)
    
      # Apply camera pose to directions
      rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)
    
      # Origin is same for all directions (the optical center)
      rays_o = c2w[:3, -1].expand(rays_d.shape)
      return rays_o, rays_d
    # Gather as torch tensors
    images = torch.from_numpy(data['images'][:n_training]).to(device)
    poses = torch.from_numpy(data['poses']).to(device)
    focal = torch.from_numpy(data['focal']).to(device)
    testimg = torch.from_numpy(data['images'][testimg_idx]).to(device)
    testpose = torch.from_numpy(data['poses'][testimg_idx]).to(device)
    
    # Grab rays from sample image
    height, width = images.shape[1:3]
    with torch.no_grad():
      ray_origin, ray_direction = get_rays(height, width, focal, testpose)
    def sampling_color(dat, step=1):
        if not isinstance(dat, np.ndarray):
            dat = dat.cpu()
        if dat.shape[-1] != 1:
            dat = dat[..., 0]
        if step > 1:
            dat = dat.flatten()[::step]
        size = len(dat)
        sampled = np.linspace(0, 2*np.pi, size)
        u = np.cos(sampled)
        v = np.sin(sampled)
        colors = np.arctan2(u, v)
        norm = Normalize()
        norm.autoscale(colors)
        colormap = cm.Paired
        return colormap(norm(colors))
    step = 1
    # Reuse the rays from `get_rays` above (moved to CPU so matplotlib can plot them).
    rays_o, rays_d = ray_origin.cpu(), ray_direction.cpu()
    ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')
    size = len(rays_o[..., 0].flatten()[::step])
    sampled = np.linspace(0, 2*np.pi, size)
    
    u = np.cos(sampled)
    v = np.sin(sampled)
    colors = np.arctan2(u, v)
    
    norm = Normalize()
    norm.autoscale(colors)
    # we need to normalize our colors array to match it colormap domain
    # which is [0, 1]
    
    colormap = cm.Paired
    
    _ = ax.quiver(
      rays_o[..., 0].flatten()[::step],
      rays_o[..., 1].flatten()[::step],
      rays_o[..., 2].flatten()[::step],
      rays_d[..., 0].flatten()[::step],
      rays_d[..., 1].flatten()[::step],
      rays_d[..., 2].flatten()[::step], color=colormap(norm(colors)), length=0.1, normalize=False, arrow_length_ratio=.05)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('z')
    plt.show()

    (a) location of the test image; (b) rays plotted from the 100x100 image pixels after applying the test image's pose matrix

    Stratified Sampling

    Now that we have these lines, defined as origin and direction vectors, we can begin the process of sampling them. Recall that NeRF takes a coarse-to-fine sampling strategy, starting with the stratified sampling approach.

    The stratified sampling approach splits the ray into evenly-spaced bins and randomly samples within each bin.
    The perturb setting determines whether to sample points uniformly from each bin or to simply use the bin center as the point.
    In most cases, we want to keep perturb = True as it will encourage the network to learn over a continuously sampled space. It may be useful to disable it for debugging.
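
    For reference, this is Eq. (2) of the NeRF paper: with the ray restricted to $[t_n, t_f]$ and split into $N$ evenly spaced bins, each sample is drawn as

    $$t_i \sim \mathcal{U}\!\left[\, t_n + \tfrac{i-1}{N}(t_f - t_n),\; t_n + \tfrac{i}{N}(t_f - t_n) \,\right].$$

    The `sample_stratified` function below implements the same idea, except that it jitters each of the $N$ evenly spaced depths within the interval bounded by the neighboring midpoints rather than using the bin edges directly.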

    def sample_stratified(
      rays_o: torch.Tensor, # 100x100x3
      rays_d: torch.Tensor, # 100x100x3
      near: float, 
      far: float,
      n_samples: int,
      perturb: Optional[bool] = True, # regularly-spaced or not
      inverse_depth: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
      r"""
      Sample along ray from regularly-spaced bins.
      """
    
      # Grab samples for space integration along ray
      t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
      if not inverse_depth:
        # Sample linearly between `near` and `far` (i.e. map [0, 1] onto [near, far])
        z_vals = near * (1.-t_vals) + far * (t_vals)
        # [Out] tensor([2.0000, 2.5714, 3.1429, 3.7143, 4.2857, 4.8571, 5.4286, 6.0000], device='cuda:0')
      else:
        # Sample linearly in inverse depth (disparity)
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    
      # Draw uniform samples from bins along ray
      if perturb:
        mids = .5 * (z_vals[1:] + z_vals[:-1])
        # tensor([2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143], device='cuda:0')
        upper = torch.concat([mids, z_vals[-1:]], dim=-1) # (mids, z_vals[-1])
        # tensor([2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143, 6.0000], device='cuda:0')
        lower = torch.concat([z_vals[:1], mids], dim=-1) # (z_vals[0], mids)
        # tensor([2.0000, 2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143], device='cuda:0')
        t_rand = torch.rand([n_samples], device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand
        # [Out] tensor([2.2012, 2.6339, 3.1944, 3.9895, 4.4883, 4.9825, 5.3318, 5.8038], device='cuda:0')
        # == (1 - t_rand) * lower + t_rand * upper, the same linear-interpolation form as
        #    `near * (1.-t_vals) + far * t_vals` above
        # == Sample linearly between (z_vals[0], mids) and (mids, z_vals[-1])
      z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples]) # expand z_vals from shape (8,) to (1e4, 8)
      '''
      [Out] tensor([[2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038],
            [2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038],
            [2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038],
            ...,
            [2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038],
            [2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038],
            [2.2012, 2.6339, 3.1944,  ..., 4.9825, 5.3318, 5.8038]],
           device='cuda:0')
      '''
    
      # Apply scale from `rays_d` and offset from `rays_o` to samples
      # pts: (width, height, n_samples, 3)
      pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
      # rays_o/rays_d : (1e4, 3) -> (1e4, 1, 3)
      # z_vals : (1e4, 8) -> (1e4, 8, 1)
      # pts: (1e4, 8, 3)
      return pts, z_vals
    # Draw stratified samples from example
    rays_o = ray_origin.view([-1, 3])
    rays_d = ray_direction.view([-1, 3])
    n_samples = 8
    perturb = True
    inverse_depth = False
    with torch.no_grad():
      pts, z_vals = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                      perturb=perturb, inverse_depth=inverse_depth)
    
    print('Input Points')
    print(pts.shape) # (10000, n_samples, 3)
    print('')
    print('Distances Along Ray')
    print(z_vals.shape) # (10000, n_samples)
    def plot_rays_3d(ax, ray_origin_cpu, ray_direction_cpu, color='black', length=0.1, arrow_ratio=.05, step=9):
        ray_origin_cpu = ray_origin_cpu.cpu()
        ray_direction_cpu = ray_direction_cpu.cpu()
        
        _ = ax.quiver(
            ray_origin_cpu[..., 0].flatten()[::step],
            ray_origin_cpu[..., 1].flatten()[::step],
            ray_origin_cpu[..., 2].flatten()[::step],
            ray_direction_cpu[..., 0].flatten()[::step],
            ray_direction_cpu[..., 1].flatten()[::step],
            ray_direction_cpu[..., 2].flatten()[::step], length=length, color=color, normalize=False, arrow_length_ratio=arrow_ratio
        )
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
    n = 8
    colors = np.linspace(0, 1, n)
    colors = cm.Paired(colors)
    ax = plt.figure(figsize=(16, 12)).add_subplot(projection='3d')
    margin = 0.5
    ax.margins(margin, margin, margin)
    arrow_ratio=.1
    step = 10000
    
    for i in range(n):
        plot_rays_3d(ax, torch.zeros_like(rays_o.reshape(100, 100, 3)), rays_o.reshape(100, 100, 3), color='black', length=1, arrow_ratio=arrow_ratio, step=step)
        plot_rays_3d(ax, torch.zeros_like(rays_o.reshape(100, 100, 3)), rays_d.reshape(100, 100, 3), color='gray', length=1, arrow_ratio=arrow_ratio, step=step)
        plot_rays_3d(ax, rays_o.reshape(100, 100, 3), (pts[:, i, :]).reshape(100, 100, 3) - rays_o.reshape(100, 100, 3), color=colors[i], length=1, arrow_ratio=arrow_ratio, step=step)
        plot_rays_3d(ax, rays_o.reshape(100, 100, 3), rays_d.reshape(100, 100, 3), color='gray', length=1, arrow_ratio=arrow_ratio, step=step)
    
    plt.show()

    (1) black: rays_o (2) gray: rays_d & rays_o - rays_d (3) colors: pts

     

    Perturbed vs. Unperturbed

    Now we visualize these sampled points.

    The unperturbed blue points are the bin "centers." The red points are a sampling of perturbed points.

    Notice how the red points are slightly offset from the blue points above them (the vertical offset is only there for visibility), but all are constrained between near (2.) and far (6.).

    y_vals = torch.zeros_like(z_vals)
    
    _, z_vals_unperturbed = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                      perturb=False, inverse_depth=inverse_depth)
    plt.plot(z_vals_unperturbed[0].cpu().numpy(), 1 + y_vals[0].cpu().numpy(), 'b-o')
    plt.plot(z_vals[0].cpu().numpy(), y_vals[0].cpu().numpy(), 'r-o')
    plt.ylim([-1, 2])
    plt.title('Stratified Sampling (blue) with Perturbation (red)')
    ax = plt.gca()
    ax.axes.yaxis.set_visible(False)
    plt.grid(True)

     

    Positional Encoder

    Much like Transformers, NeRFs make use of positional encoders. In this case, it's to map the inputs to a higher frequency space to compensate for the bias that neural networks have for learning lower-frequency functions.

    Here we build a simple torch.nn.Module for our positional encoder.

    The same encoder implementation can be applied to both the input samples and the view directions.

    However, we choose different parameters for these inputs.

     

    We use the default settings from the original implementation.
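
    Concretely, the encoding from Section 5.1 of the paper applied to each scalar input $p$ is

    $$\gamma(p) = \big(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\big),$$

    with $L$ = `n_freqs`. The implementation below drops the $\pi$ factor and also keeps the raw input itself, so each of the 3 input dimensions expands to $1 + 2L$ values, which is where `d_output = d_input * (1 + 2 * n_freqs)` comes from.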

    A good post about positional encoding (in Korean): "How is Positional Encoding done? (Transformer models)", pongdangstory.tistory.com

    class PositionalEncoder(nn.Module):
      r"""
      Sine-cosine positional encoder for input points.
      """
      def __init__(
        self,
        d_input: int,
        n_freqs: int,
        log_space: bool = False
      ):
        super().__init__()
        self.d_input = d_input
        self.n_freqs = n_freqs
        self.log_space = log_space
        self.d_output = d_input * (1 + 2 * self.n_freqs)
        self.embed_fns = [lambda x: x]
    
        # Define frequencies in either linear or log scale
        if self.log_space:
          freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs)
        else:
          freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs)
    
        # Alternate sin and cos
        for freq in freq_bands:
          self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
          self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))
    
      def forward(
        self,
        x
      ) -> torch.Tensor:
        r"""
        Apply positional encoding to input.
        """
        return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)
    # Create encoders for points and view directions
    encoder = PositionalEncoder(3, 10)
    viewdirs_encoder = PositionalEncoder(3, 4)
    
    # Grab flattened points and view directions
    pts_flattened = pts.reshape(-1, 3) # (1e4, 8, 3) -> (8e4, 3)
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) # (1e4, 3)
    # torch.norm measures the length (L2 norm) of a vector.
    # viewdirs are the unit vectors of rays_d.
    flattened_viewdirs = viewdirs[:, None, ...].expand(pts.shape).reshape((-1, 3)) # match the shape of pts
    
    # Encode inputs
    encoded_points = encoder(pts_flattened) # (8e4, 63): n_freqs(10) * 2 (sin, cos) + identity = 21 per coordinate; x has 3 coordinates -> 21 * 3 = 63
    encoded_viewdirs = viewdirs_encoder(flattened_viewdirs) # (8e4, 27): n_freqs(4) * 2 + identity = 9 per coordinate -> 9 * 3 = 27
    
    print('Encoded Points')
    print(encoded_points.shape)
    print(torch.min(encoded_points), torch.max(encoded_points), torch.mean(encoded_points))
    print('')
    
    print('Encoded Viewdirs')
    print(encoded_viewdirs.shape)
    print(torch.min(encoded_viewdirs), torch.max(encoded_viewdirs), torch.mean(encoded_viewdirs))
    print('')

     

    NeRF Model

    Here we define the NeRF model, which consists primarily of a ModuleList of Linear layers, separated by non-linear activation functions and the occasional residual connection.

    This model features an optional input for view directions, which will alter the model architecture if provided at instantiation.

    This implementation is based on Section 3 of the original "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" paper and uses the same defaults.

    1. Basic NeRF Model

    class NeRF(nn.Module):
      r"""
      Neural radiance fields module.
      """
      def __init__(
        self,
        d_input: int = 3,
        n_layers: int = 8,
        d_filter: int = 256,
        skip: Tuple[int] = (4,),
        d_viewdirs: Optional[int] = None # e.g. 27 for the encoded view directions
      ):
        super().__init__()
        self.d_input = d_input
        self.skip = skip
        self.act = nn.functional.relu
        self.d_viewdirs = d_viewdirs
    
        # Create model layers
        self.layers = nn.ModuleList(
          [nn.Linear(self.d_input, d_filter)] +
          [nn.Linear(d_filter + self.d_input, d_filter) if i + 1 in skip \
           else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)]
        )
    
        # Bottleneck layers
        if self.d_viewdirs is not None:
          # If using viewdirs, split alpha and RGB
          self.alpha_out = nn.Linear(d_filter, 1)
          self.rgb_filters = nn.Linear(d_filter, d_filter)
          self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2) # (256 + 27, 256 // 2)
          self.output = nn.Linear(d_filter // 2, 3) # concatenated with alpha_out later to give 4 output channels
        else:
          # If no viewdirs, use simpler output
          self.output = nn.Linear(d_filter, 4)
    
      def forward(
        self,
        x: torch.Tensor,
        viewdirs: Optional[torch.Tensor] = None
      ) -> torch.Tensor:
        r"""
        Forward pass with optional view direction.
        """
    
        # Cannot use viewdirs if instantiated with d_viewdirs = None
        if self.d_viewdirs is None and viewdirs is not None:
          raise ValueError('Cannot input x_direction if d_viewdirs was not given.')
    
        # Apply forward pass up to bottleneck
        x_input = x
        for i, layer in enumerate(self.layers):
          x = torch.cat([x, x_input], dim=-1) if i in self.skip else x
          x = self.act(layer(x))
    
        # Apply bottleneck
        if self.d_viewdirs is not None:
          # Split alpha from network output
          alpha = self.alpha_out(x)
    
          # Pass through bottleneck to get RGB
          x = self.rgb_filters(x)
          x = torch.concat([x, viewdirs], dim=-1)
          x = self.act(self.branch(x))
          x = self.output(x)
    
          # Concatenate alphas to output
          x = torch.concat([x, alpha], dim=-1)
        else:
          # Simple output
          x = self.output(x)
        return x
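
    As a quick sanity check (my own snippet, assuming the encoders and encoded tensors from the Positional Encoder section are still in scope), the encoded points can be pushed through the model to confirm the output shape:

    test_model = NeRF(d_input=encoder.d_output, d_viewdirs=viewdirs_encoder.d_output).to(device)
    with torch.no_grad():
      raw = test_model(encoded_points, viewdirs=encoded_viewdirs)
    print(raw.shape)  # torch.Size([80000, 4]): RGB (3) + raw density (1) per sampled point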

     


    2. MipNeRF Model (for reference: the JAX/Flax implementation)

    class MipNerfModel(nn.Module):
      """Nerf NN Model with both coarse and fine MLPs."""
      num_samples: int = 128  # The number of samples per level.
      num_levels: int = 2  # The number of sampling levels.
      resample_padding: float = 0.01  # Dirichlet/alpha "padding" on the histogram.
      stop_level_grad: bool = True  # If True, don't backprop across levels.
      use_viewdirs: bool = True  # If True, use view directions as a condition.
      lindisp: bool = False  # If True, sample linearly in disparity, not in depth.
      ray_shape: str = 'cone'  # The shape of cast rays ('cone' or 'cylinder').
      min_deg_point: int = 0  # Min degree of positional encoding for 3D points.
      max_deg_point: int = 16  # Max degree of positional encoding for 3D points.
      deg_view: int = 4  # Degree of positional encoding for viewdirs.
      density_activation: Callable[..., Any] = nn.softplus  # Density activation.
      density_noise: float = 0.  # Standard deviation of noise added to raw density.
      density_bias: float = -1.  # The shift added to raw densities pre-activation.
      rgb_activation: Callable[..., Any] = nn.sigmoid  # The RGB activation.
      rgb_padding: float = 0.001  # Padding added to the RGB outputs.
      disable_integration: bool = False  # If True, use PE instead of IPE.
    
      @nn.compact
      def __call__(self, rng, rays, randomized, white_bkgd):
        """The mip-NeRF Model.
    
        Args:
          rng: jnp.ndarray, random number generator.
          rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
          randomized: bool, use randomized stratified sampling.
          white_bkgd: bool, if True, use white as the background (black o.w.).
    
        Returns:
          ret: list, [*(rgb, distance, acc)]
        """
        # Construct the MLP.
        mlp = MLP()
    
        ret = []
        for i_level in range(self.num_levels):
          key, rng = random.split(rng)
          if i_level == 0:
            # Stratified sampling along rays
            t_vals, samples = mip.sample_along_rays(
                key,
                rays.origins,
                rays.directions,
                rays.radii,
                self.num_samples,
                rays.near,
                rays.far,
                randomized,
                self.lindisp,
                self.ray_shape,
            )
          else:
            t_vals, samples = mip.resample_along_rays(
                key,
                rays.origins,
                rays.directions,
                rays.radii,
                t_vals,
                weights,
                randomized,
                self.ray_shape,
                self.stop_level_grad,
                resample_padding=self.resample_padding,
            )
          if self.disable_integration:
            samples = (samples[0], jnp.zeros_like(samples[1]))
          samples_enc = mip.integrated_pos_enc(
              samples,
              self.min_deg_point,
              self.max_deg_point,
          )
    
          # Point attribute predictions
          if self.use_viewdirs:
            viewdirs_enc = mip.pos_enc(
                rays.viewdirs,
                min_deg=0,
                max_deg=self.deg_view,
                append_identity=True,
            )
            raw_rgb, raw_density = mlp(samples_enc, viewdirs_enc)
          else:
            raw_rgb, raw_density = mlp(samples_enc)
    
          # Add noise to regularize the density predictions if needed.
          if randomized and (self.density_noise > 0):
            key, rng = random.split(rng)
            raw_density += self.density_noise * random.normal(
                key, raw_density.shape, dtype=raw_density.dtype)
    
          # Volumetric rendering.
          rgb = self.rgb_activation(raw_rgb)
          rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding
          density = self.density_activation(raw_density + self.density_bias)
          comp_rgb, distance, acc, weights = mip.volumetric_rendering(
              rgb,
              density,
              t_vals,
              rays.directions,
              white_bkgd=white_bkgd,
          )
          ret.append((comp_rgb, distance, acc))
    
        return ret

     

     

    Volume Rendering

    We still need to convert the raw NeRF outputs into an image.

    This is where we apply the volume integration described in Equations 1-3 in Section 4 of the paper.

    Essentially, we take the weighted sum of all samples along the ray of each pixel to get the estimated color value at that pixel.

    Each RGB sample is weighted by its alpha value.

    Higher alpha values indicate a higher likelihood that the sampled area is opaque, therefore points further along the ray are likelier to be occluded.

    The cumulative product ensures that those further points are dampened.
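
    In equation form (the quadrature version of Eqs. 1-3 in the paper), the estimated color of a ray is

    $$\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \, \alpha_i \, \mathbf{c}_i, \qquad \alpha_i = 1 - \exp(-\sigma_i \delta_i), \qquad T_i = \prod_{j=1}^{i-1} (1 - \alpha_j),$$

    where $\sigma_i$ is the predicted density, $\mathbf{c}_i$ the predicted color, and $\delta_i = t_{i+1} - t_i$ the distance between adjacent samples. The weights $w_i = T_i \alpha_i$ are exactly what `raw2outputs` below computes, using `cumprod_exclusive` for the $T_i$ term.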

    def cumprod_exclusive(
      tensor: torch.Tensor
    ) -> torch.Tensor:
      r"""
      (Courtesy of https://github.com/krrish94/nerf-pytorch)
    
      Mimic functionality of tf.math.cumprod(..., exclusive=True), as it isn't available in PyTorch.
    
      Args:
      tensor (torch.Tensor): Tensor whose cumprod (cumulative product, see `torch.cumprod`) along dim=-1
        is to be computed.
      Returns:
      cumprod (torch.Tensor): cumprod of Tensor along dim=-1, mimicking the functionality of
        tf.math.cumprod(..., exclusive=True) (see `tf.math.cumprod` for details).
      """
    
      # Compute regular cumprod first (this is equivalent to `tf.math.cumprod(..., exclusive=False)`).
      cumprod = torch.cumprod(tensor, -1)
      # "Roll" the elements along dimension 'dim' by 1 element.
      cumprod = torch.roll(cumprod, 1, -1)
      # Replace the first element by "1" as this is what tf.cumprod(..., exclusive=True) does.
      cumprod[..., 0] = 1.
    
      return cumprod
    
    def raw2outputs(
      raw: torch.Tensor, # (n_rays, n_samples, 4)
      z_vals: torch.Tensor, # (n_rays, n_samples)
      rays_d: torch.Tensor, # (n_rays, 3)
      raw_noise_std: float = 0.0,
      white_bkgd: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
      r"""
      Convert the raw NeRF output into RGB and other maps.
      """
    
      # Difference between consecutive elements of `z_vals`.
      dists = z_vals[..., 1:] - z_vals[..., :-1] # (n_rays, n_samples - 1)
      dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1) # (n_rays, n_samples)
    
      # Multiply each distance by the norm of its corresponding direction ray
      # to convert to real world distance (accounts for non-unit directions).
      dists = dists * torch.norm(rays_d[..., None, :], dim=-1) # (n_rays, n_samples)
    
      # Add noise to model's predictions for density. Can be used to
      # regularize network during training (prevents floater artifacts).
      noise = 0.
      if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std
    
      # Predict density of each sample along each ray. Higher values imply
      # higher likelihood of being absorbed at this point. [n_rays, n_samples]
      alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists) # (n_rays, n_samples)
    
      # Compute weight for RGB of each sample along each ray. [n_rays, n_samples]
      # The higher the alpha, the lower subsequent weights are driven.
      weights = alpha * cumprod_exclusive(1. - alpha + 1e-10) # (n_rays, n_samples)
    
      # Compute weighted RGB map.
      rgb = torch.sigmoid(raw[..., :3])  # (n_rays, n_samples, 3)
      rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # (n_rays, 3)
    
      # Estimated depth map is predicted distance.
      depth_map = torch.sum(weights * z_vals, dim=-1)
    
      # Disparity map is inverse depth.
      disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                                depth_map / torch.sum(weights, -1))
    
      # Sum of weights along each ray. In [0, 1] up to numerical error.
      acc_map = torch.sum(weights, dim=-1)
      # Since the weights behave like alpha values, this sum should stay within [0, 1].
    
      # To composite onto a white background, use the accumulated alpha map.
      if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])
    
      return rgb_map, depth_map, acc_map, weights
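
    A tiny numeric example (mine, not from the original post) of how the exclusive cumulative product dampens samples further along a ray:

    alpha = torch.tensor([0.2, 0.5, 0.9])          # three samples with increasing opacity
    weights = alpha * cumprod_exclusive(1. - alpha)
    print(weights)                                  # tensor([0.2000, 0.4000, 0.3600])
    print(weights.sum())                            # tensor(0.9600), i.e. the ray is almost fully absorbed

    Even though the last sample is the most opaque, it receives a smaller weight than the second one because the earlier samples have already absorbed most of the light.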

     

    Hierarchical Volume Sampling

    The 3D space is in fact very sparse with occlusions, so most points don't contribute much to the rendered image.

    It is therefore more beneficial to oversample regions with a high likelihood of contributing to the integral.

    Here we apply learned, normalized weights to the first set of samples to create a PDF across the ray, then apply inverse transform sampling to this PDF to gather a second set of samples concentrated on the regions that actually contribute to the rendering.
    (The weights can be treated as normalized because they are built from alpha values in [0, 1].)
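
    A minimal sketch of inverse transform sampling on a toy histogram (my own example; `sample_pdf` below is the batched version of the same idea):

    w = torch.tensor([0.1, 0.1, 0.7, 0.1])                        # toy weights: most mass in the third bin
    pdf = w / w.sum()
    cdf = torch.cat([torch.zeros(1), torch.cumsum(pdf, dim=-1)])  # [0.0, 0.1, 0.2, 0.9, 1.0]
    u = torch.rand(1000)                                          # uniform samples in [0, 1)
    idx = torch.searchsorted(cdf, u, right=True) - 1              # bin index for each sample
    print(torch.bincount(idx, minlength=4) / 1000.)               # roughly [0.1, 0.1, 0.7, 0.1]

    Bins with higher weight receive proportionally more samples, which is exactly what we want for concentrating the fine samples near surfaces.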

    def sample_pdf(
      bins: torch.Tensor, # (n_rays, n_samples - 1); z_vals_mid is passed in here
      weights: torch.Tensor, # (n_rays, n_samples - 2); weights[..., 1:-1], the alpha-based weights from raw2outputs
      n_samples: int, # 64
      perturb: bool = False
    ) -> torch.Tensor:
      r"""
      Apply inverse transform sampling to a weighted set of points.
      """
    
      # Normalize weights to get PDF.
      pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdim=True) # (n_rays, weights.shape[-1]) == (1e4, 62)
    
      # Convert PDF to CDF.
      cdf = torch.cumsum(pdf, dim=-1) # [n_rays, weights.shape[-1]] # (1e4, 62)
      cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) 
      # (n_rays, weights.shape[-1] + 1) == (n_rays, 62 + 1)
    
      # Take sample positions to grab from CDF. Linear when perturb == 0.
      if not perturb:
        u = torch.linspace(0., 1., n_samples, device=cdf.device) # [1, n_samples]
        tar_shape = list(cdf.shape[:-1]) + [n_samples] # = [n_rays, n_samples]
        u = u.expand(tar_shape) # [n_rays, n_samples]
      else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device) 
        # (n_rays, n_samples) # (1e4, 64)
    
      # Find indices along CDF where values in u would be placed.
      u = u.contiguous() # Returns contiguous tensor with same values.
      inds = torch.searchsorted(cdf, u, right=True) # (n_rays, n_samples)
      # For each of the 64 values in u, find its index among the 63 CDF values,
      # so at least one bin must contain two sampled points.
    
      # Clamp indices that are out of bounds.
      below = torch.clamp(inds - 1, min=0) # (n_rays, n_samples)
      above = torch.clamp(inds, max=cdf.shape[-1] - 1) # (n_rays, n_samples)
      inds_g = torch.stack([below, above], dim=-1) # (n_rays, n_samples, 2)
    
      # Sample from cdf and the corresponding bin centers.
      matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]] # (3,)
      cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1,
                           index=inds_g) # (n_rays, n_samples, 2)
      # cdf.shape = (n_rays, n_bins) with n_bins = weights.shape[-1] + 1
      # -> cdf.unsqueeze(-2).shape = (n_rays, 1, n_bins)
      # -> ~.expand(matched_shape) = (n_rays, n_samples, n_bins)
      # -> torch.gather(~, dim=-1, index=inds_g) = (n_rays, n_samples, 2) ~ shape of inds_g
      bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1,
                            index=inds_g) # (n_rays, n_samples, 2) ~ shape of inds_g
    
      # Convert samples to ray length.
      denom = (cdf_g[..., 1] - cdf_g[..., 0]) # (n_rays, n_samples)
      denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) # clamp denominators smaller than 1e-5 to 1
      # (n_rays, n_samples)
       
      t = (u - cdf_g[..., 0]) / denom # (n_rays, n_samples)
      samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) # (n_rays, n_samples)
      '''
      t: the position of u between below and above, expressed as a ratio
         (a ratio of 1 corresponds to the full above-below distance).
      below + t * denom recovers the actual position of u on the CDF.
      bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) maps that CDF position back onto the ray,
      i.e. it is the depth along the ray at which the new sample is placed.
      '''
      
      return samples # (n_rays, n_samples)

     

    def sample_hierarchical(
      rays_o: torch.Tensor, # (100, 100, 3)
      rays_d: torch.Tensor, # (100, 100, 3)
      z_vals: torch.Tensor, # (n_rays, n_samples)
      weights: torch.Tensor, # (n_rays, n_samples)
      n_samples: int, # 64
      perturb: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
      r"""
      Apply hierarchical sampling to the rays.
      """
    
      # Draw samples from PDF using z_vals as bins and weights as probabilities.
      z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
      '''
      z_vals[..., 1:], z_vals[..., :-1] # (n_rays, n_samples - 1)
      z_vals_mid # (n_rays, n_samples - 1)
      '''
      new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples,
                              perturb=perturb) # (n_rays, n_samples)
      '''
      Why weights[..., 0] is dropped: to remove the influence of the leading 1 produced by the
        exclusive cumprod when the PDF is built (inside sample_pdf a zero is prepended to the CDF,
        which restores the (n_rays, n_samples - 1) shape).
      Why weights[..., -1] is dropped: z_vals_mid is used, and each weight represents the contribution
        accumulated up to that point; z_vals_mid[-1] does not cover part of the region of
        weights[..., -1], so the left values rather than the right values are kept.
      '''
      new_z_samples = new_z_samples.detach()
    
      # Resample points from ray based on PDF.
      z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
      pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]  # [n_rays, n_samples + n_samples, 3]
      return pts, z_vals_combined, new_z_samples
      # pts: (n_rays, n_samples * 2, 3), z_vals_combined: (n_rays, n_samples * 2), new_z_samples: (n_rays, n_samples)

     

    Full Forward Pass

    Here is where we put everything together to compute a single forward pass through our model.

    Due to potential memory issues, the forward pass is computed in "chunks," which are then aggregated across a single batch.

    The gradient propagation is done after the whole batch is processed, hence the distinction between "chunks" and "batches."

    Chunking is especially important for the Google Colab environment, which provides more modest resources than those cited in the original paper.
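
    As a rough back-of-the-envelope check (my numbers, using the defaults set later): one 100x100 image with 64 stratified samples per ray gives 100 * 100 * 64 = 640,000 query points, so with chunksize = 2**15 = 32,768 the coarse pass runs in ceil(640000 / 32768) = 20 chunks.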

    def get_chunks(
      inputs: torch.Tensor,
      chunksize: int = 2**15
    ) -> List[torch.Tensor]:
      r"""
      Divide an input into chunks.
      """
      return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]
    
    def prepare_chunks(
      points: torch.Tensor,
      encoding_function: Callable[[torch.Tensor], torch.Tensor],
      chunksize: int = 2**15
    ) -> List[torch.Tensor]:
      r"""
      Encode and chunkify points to prepare for NeRF model.
      """
      points = points.reshape((-1, 3))
      points = encoding_function(points)
      points = get_chunks(points, chunksize=chunksize)
      return points
    
    def prepare_viewdirs_chunks(
      points: torch.Tensor,
      rays_d: torch.Tensor,
      encoding_function: Callable[[torch.Tensor], torch.Tensor],
      chunksize: int = 2**15
    ) -> List[torch.Tensor]:
      r"""
      Encode and chunkify viewdirs to prepare for NeRF model.
      """
      # Prepare the viewdirs
      viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
      viewdirs = viewdirs[:, None, ...].expand(points.shape).reshape((-1, 3))
      viewdirs = encoding_function(viewdirs)
      viewdirs = get_chunks(viewdirs, chunksize=chunksize)
      return viewdirs

     

    def nerf_forward(
      rays_o: torch.Tensor, # (100, 100, 3)
      rays_d: torch.Tensor, # (100, 100, 3)
      near: float,
      far: float,
      encoding_fn: Callable[[torch.Tensor], torch.Tensor],
      coarse_model: nn.Module, # class NeRF 
      kwargs_sample_stratified: dict = None,
      n_samples_hierarchical: int = 0,
      kwargs_sample_hierarchical: dict = None,
      fine_model = None,
      viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
      chunksize: int = 2**15
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
      r"""
      Compute forward pass through model(s).
      """
    
      # Set no kwargs if none are given.
      if kwargs_sample_stratified is None:
        kwargs_sample_stratified = {}
      if kwargs_sample_hierarchical is None:
        kwargs_sample_hierarchical = {}
    
      # Sample query points along each ray.
      query_points, z_vals = sample_stratified(
          rays_o, rays_d, near, far, **kwargs_sample_stratified)
    
      # Prepare batches.
      batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)
      if viewdirs_encoding_fn is not None:
        batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d,
                                                   viewdirs_encoding_fn,
                                                   chunksize=chunksize)
      else:
        batches_viewdirs = [None] * len(batches)
    
      # Coarse model pass.
      # Split the encoded points into "chunks", run the model on all chunks, and
      # concatenate the results (to avoid out-of-memory issues).
      predictions = []
      for batch, batch_viewdirs in zip(batches, batches_viewdirs):
        predictions.append(coarse_model(batch, viewdirs=batch_viewdirs))
      raw = torch.cat(predictions, dim=0)
      raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])
    
      # Perform differentiable volume rendering to re-synthesize the RGB image.
      rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d)
      # rgb_map, depth_map, acc_map, weights = render_volume_density(raw, rays_o, z_vals)
      outputs = {
          'z_vals_stratified': z_vals
      }
    
      # Fine model pass.
      if n_samples_hierarchical > 0:
        # Save previous outputs to return.
        rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map
    
        # Apply hierarchical sampling for fine query points.
        query_points, z_vals_combined, z_hierarch = sample_hierarchical(
          rays_o, rays_d, z_vals, weights, n_samples_hierarchical,
          **kwargs_sample_hierarchical)
    
        # Prepare inputs as before.
        batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize)
        if viewdirs_encoding_fn is not None:
          batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d,
                                                     viewdirs_encoding_fn,
                                                     chunksize=chunksize)
        else:
          batches_viewdirs = [None] * len(batches)
    
        # Forward pass new samples through fine model.
        fine_model = fine_model if fine_model is not None else coarse_model
        predictions = []
        for batch, batch_viewdirs in zip(batches, batches_viewdirs):
          predictions.append(fine_model(batch, viewdirs=batch_viewdirs))
        raw = torch.cat(predictions, dim=0)
        raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]])
    
        # Perform differentiable volume rendering to re-synthesize the RGB image.
        rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals_combined, rays_d)
    
        # Store outputs.
        outputs['z_vals_hierarchical'] = z_hierarch
        outputs['rgb_map_0'] = rgb_map_0
        outputs['depth_map_0'] = depth_map_0
        outputs['acc_map_0'] = acc_map_0
    
      # Store outputs.
      outputs['rgb_map'] = rgb_map
      outputs['depth_map'] = depth_map
      outputs['acc_map'] = acc_map
      outputs['weights'] = weights
      return outputs

     

    Train

    At long last, we have (almost) everything we need to train the model.

    Now we will do some setup for a simple training procedure, creating hyperparameters and helper functions, then train our model.

    Hyperparameters

    All hyperparameters for training are set here.

    Defaults were taken from the original, unless computational constraints prohibit them.

    In this case, we apply sensible defaults that are well within the resources provided by Google Colab.

    # Encoders
    d_input = 3           # Number of input dimensions
    n_freqs = 10          # Number of encoding functions for samples
    log_space = True      # If set, frequencies scale in log space
    use_viewdirs = True   # If set, use view direction as input
    n_freqs_views = 4     # Number of encoding functions for views
    
    # Stratified sampling
    n_samples = 64         # Number of spatial samples per ray
    perturb = True         # If set, applies noise to sample positions
    inverse_depth = False  # If set, samples points linearly in inverse depth
    
    # Model
    d_filter = 128          # Dimensions of linear layer filters
    n_layers = 2            # Number of layers in network bottleneck
    skip = []               # Layers at which to apply input residual
    use_fine_model = True   # If set, creates a fine model
    d_filter_fine = 128     # Dimensions of linear layer filters of fine network
    n_layers_fine = 6       # Number of layers in fine network bottleneck
    
    # Hierarchical sampling
    n_samples_hierarchical = 64   # Number of samples per ray
    perturb_hierarchical = False  # If set, applies noise to sample positions
    
    # Optimizer
    lr = 5e-4  # Learning rate
    
    # Training
    n_iters = 10000
    batch_size = 2**14          # Number of rays per gradient step (power of 2)
    one_image_per_step = True   # One image per gradient step (disables batching)
    chunksize = 2**14           # Modify as needed to fit in GPU memory
    center_crop = True          # Crop the center of image (when one_image_per_step)
    center_crop_iters = 50      # Stop cropping center after this many epochs
    display_rate = 25          # Display test output every X epochs
    
    # Early Stopping
    warmup_iters = 100          # Number of iterations during warmup phase
    warmup_min_fitness = 10.0   # Min val PSNR to continue training at warmup_iters
    n_restarts = 10             # Number of times to restart if training stalls
    
    # We bundle the kwargs for various functions to pass all at once.
    kwargs_sample_stratified = {
        'n_samples': n_samples,
        'perturb': perturb,
        'inverse_depth': inverse_depth
    }
    kwargs_sample_hierarchical = {
        'perturb': perturb
    }

     

    Training Classes and Functions

    Here we create some helper functions for training.

    NeRF can be prone to local minima, in which training will quickly stall and produce blank outputs. 

    EarlyStopping is used to restart the training when learning stalls, if necessary.
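
    The fitness metric used here is PSNR, computed from the MSE loss as

    $$\mathrm{PSNR} = -10 \log_{10}(\mathrm{MSE}),$$

    which is exactly the `psnr = -10. * torch.log10(loss)` line in the training loop below (pixel values are in $[0, 1]$, so the peak signal is 1).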

    def plot_samples(
      z_vals: torch.Tensor,
      z_hierarch: Optional[torch.Tensor] = None,
      ax: Optional[np.ndarray] = None):
      r"""
      Plot stratified and (optional) hierarchical samples.
      """
      y_vals = 1 + np.zeros_like(z_vals)
    
      if ax is None:
        ax = plt.subplot()
      ax.plot(z_vals, y_vals, 'b-o')
      if z_hierarch is not None:
        y_hierarch = np.zeros_like(z_hierarch)
        ax.plot(z_hierarch, y_hierarch, 'r-o')
      ax.set_ylim([-1, 2])
      ax.set_title('Stratified  Samples (blue) and Hierarchical Samples (red)')
      ax.axes.yaxis.set_visible(False)
      ax.grid(True)
      return ax
    
    def crop_center(
      img: torch.Tensor,
      frac: float = 0.5
    ) -> torch.Tensor:
      r"""
      Crop center square from image.
      """
      h_offset = round(img.shape[0] * (frac / 2))
      w_offset = round(img.shape[1] * (frac / 2))
      return img[h_offset:-h_offset, w_offset:-w_offset]
    
    class EarlyStopping:
      r"""
      Early stopping helper based on fitness criterion.
      """
      def __init__(
        self,
        patience: int = 30,
        margin: float = 1e-4
      ):
        self.best_fitness = 0.0  # In our case PSNR
        self.best_iter = 0
        self.margin = margin
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving to stop
    
      def __call__(
        self,
        iter: int,
        fitness: float
      ):
        r"""
        Check if criterion for stopping is met.
        """
        if (fitness - self.best_fitness) > self.margin:
          self.best_iter = iter
          self.best_fitness = fitness
        delta = iter - self.best_iter
        stop = delta >= self.patience  # stop training if patience exceeded
        return stop
    def init_models():
      r"""
      Initialize models, encoders, and optimizer for NeRF training.
      """
      # Encoders
      encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space)
      encode = lambda x: encoder(x)
    
      # View direction encoders
      if use_viewdirs:
        encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,
                                            log_space=log_space)
        encode_viewdirs = lambda x: encoder_viewdirs(x)
        d_viewdirs = encoder_viewdirs.d_output
      else:
        encode_viewdirs = None
        d_viewdirs = None
    
      # Models
      model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,
                  d_viewdirs=d_viewdirs)
      model.to(device)
      model_params = list(model.parameters())
      if use_fine_model:
        fine_model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,
                          d_viewdirs=d_viewdirs)
        fine_model.to(device)
        model_params = model_params + list(fine_model.parameters())
      else:
        fine_model = None
    
      # Optimizer
      optimizer = torch.optim.Adam(model_params, lr=lr)
    
      # Early Stopping
      warmup_stopper = EarlyStopping(patience=50)
    
      return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper
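
    A tiny demo (mine) of how the EarlyStopping helper behaves: once the PSNR stops improving, it returns True after `patience` iterations.

    stopper = EarlyStopping(patience=3)
    for it, psnr in enumerate([10.0, 11.0, 12.0, 12.0, 12.0, 12.0]):
      if stopper(it, psnr):
        print(f'stop at iteration {it}')  # prints: stop at iteration 5 (no improvement for 3 iterations)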

     

    Training Loop

    def train():
      r"""
      Launch training session for NeRF.
      """
      # Shuffle rays across all images.
      if not one_image_per_step:
        height, width = images.shape[1:3]
        all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0)
                            for p in poses[:n_training]], 0)
        rays_rgb = torch.cat([all_rays, images[:, None]], 1)
        rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = rays_rgb.reshape([-1, 3, 3])
        rays_rgb = rays_rgb.type(torch.float32)
        rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
        i_batch = 0
    
      train_psnrs = []
      val_psnrs = []
      iternums = []
      for i in trange(n_iters):
        model.train()
    
        if one_image_per_step:
          # Randomly pick an image as the target.
          target_img_idx = np.random.randint(images.shape[0])
          target_img = images[target_img_idx].to(device)
          if center_crop and i < center_crop_iters:
            target_img = crop_center(target_img)
          height, width = target_img.shape[:2]
          target_pose = poses[target_img_idx].to(device)
          rays_o, rays_d = get_rays(height, width, focal, target_pose)
          rays_o = rays_o.reshape([-1, 3])
          rays_d = rays_d.reshape([-1, 3])
        else:
          # Random over all images.
          batch = rays_rgb[i_batch:i_batch + batch_size]
          batch = torch.transpose(batch, 0, 1)
          rays_o, rays_d, target_img = batch
          height, width = target_img.shape[:2]
          i_batch += batch_size
          # Shuffle after one epoch
          if i_batch >= rays_rgb.shape[0]:
              rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
              i_batch = 0
        target_img = target_img.reshape([-1, 3])
    
        # Run one iteration of TinyNeRF and get the rendered RGB image.
        outputs = nerf_forward(rays_o, rays_d,
                               near, far, encode, model,
                               kwargs_sample_stratified=kwargs_sample_stratified,
                               n_samples_hierarchical=n_samples_hierarchical,
                               kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                               fine_model=fine_model,
                               viewdirs_encoding_fn=encode_viewdirs,
                               chunksize=chunksize)
    
        # Check for any numerical issues.
        for k, v in outputs.items():
          if torch.isnan(v).any():
            print(f"! [Numerical Alert] {k} contains NaN.")
          if torch.isinf(v).any():
            print(f"! [Numerical Alert] {k} contains Inf.")
    
        # Backprop!
        rgb_predicted = outputs['rgb_map']
        loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        psnr = -10. * torch.log10(loss)
        train_psnrs.append(psnr.item())
    
        # Evaluate testimg at given display rate.
        if i % display_rate == 0:
          model.eval()
          height, width = testimg.shape[:2]
          rays_o, rays_d = get_rays(height, width, focal, testpose)
          rays_o = rays_o.reshape([-1, 3])
          rays_d = rays_d.reshape([-1, 3])
          outputs = nerf_forward(rays_o, rays_d,
                                 near, far, encode, model,
                                 kwargs_sample_stratified=kwargs_sample_stratified,
                                 n_samples_hierarchical=n_samples_hierarchical,
                                 kwargs_sample_hierarchical=kwargs_sample_hierarchical,
                                 fine_model=fine_model,
                                 viewdirs_encoding_fn=encode_viewdirs,
                                 chunksize=chunksize)
    
          rgb_predicted = outputs['rgb_map']
          loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3))
          print("Loss:", loss.item())
          val_psnr = -10. * torch.log10(loss)
          val_psnrs.append(val_psnr.item())
          iternums.append(i)
    
          # Plot example outputs
          fig, ax = plt.subplots(1, 4, figsize=(24,4), gridspec_kw={'width_ratios': [1, 1, 1, 3]})
          ax[0].imshow(rgb_predicted.reshape([height, width, 3]).detach().cpu().numpy())
          ax[0].set_title(f'Iteration: {i}')
          ax[1].imshow(testimg.detach().cpu().numpy())
          ax[1].set_title(f'Target')
          ax[2].plot(range(0, i + 1), train_psnrs, 'r')
          ax[2].plot(iternums, val_psnrs, 'b')
          ax[2].set_title('PSNR (train=red, val=blue)')
          z_vals_strat = outputs['z_vals_stratified'].view((-1, n_samples))
          z_sample_strat = z_vals_strat[z_vals_strat.shape[0] // 2].detach().cpu().numpy()
          if 'z_vals_hierarchical' in outputs:
            z_vals_hierarch = outputs['z_vals_hierarchical'].view((-1, n_samples_hierarchical))
            z_sample_hierarch = z_vals_hierarch[z_vals_hierarch.shape[0] // 2].detach().cpu().numpy()
          else:
            z_sample_hierarch = None
          _ = plot_samples(z_sample_strat, z_sample_hierarch, ax=ax[3])
          ax[3].margins(0)
          plt.show()
    
        # Check PSNR for issues and stop if any are found.
        if i == warmup_iters - 1:
          if val_psnr < warmup_min_fitness:
            print(f'Val PSNR {val_psnr} below warmup_min_fitness {warmup_min_fitness}. Stopping...')
            return False, train_psnrs, val_psnrs
        elif i < warmup_iters:
          if warmup_stopper is not None and warmup_stopper(i, psnr):
            print(f'Train PSNR flatlined at {psnr} for {warmup_stopper.patience} iters. Stopping...')
            return False, train_psnrs, val_psnrs
    
      return True, train_psnrs, val_psnrs
    # Run training session(s)
    for _ in range(n_restarts):
      model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper = init_models()
      success, train_psnrs, val_psnrs = train()
      if success and val_psnrs[-1] >= warmup_min_fitness:
        print('Training successful!')
        break
    
    print('')
    print(f'Done!')
    
    torch.save(model.state_dict(), 'nerf.pth')
    torch.save(fine_model.state_dict(), 'nerf-fine.pth')

     
