[Code Review] NeRF Code Breakdown (2023. 7. 3.)
(Work in progress)
Prepare Dataset
if not os.path.exists('tiny_nerf_data.npz'):
    !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']

print(f'Images shape: {images.shape}')
print(f'Poses shape: {poses.shape}')
print(f'Focal length: {focal}')

height, width = images.shape[1:3]
near, far = 2., 6.

n_training = 100
testimg_idx = 101
testimg, testpose = images[testimg_idx], poses[testimg_idx]

plt.imshow(testimg)
print('Pose')
print(testpose)
Images shape: (106, 100, 100, 3)
Poses shape: (106, 4, 4)
Focal length: 138.88887889922103

Pose (testpose):
[[ 6.8935126e-01  5.3373039e-01 -4.8982298e-01 -1.9745398e+00]
 [-7.2442728e-01  5.0788772e-01 -4.6610624e-01 -1.8789345e+00]
 [ 1.4901163e-08  6.7615211e-01  7.3676193e-01  2.9699826e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  1.0000000e+00]]

Origins and Directions
(a) one of the 106 poses, (b) each pose converted to a viewing direction

dirs = np.stack([np.sum([0, 0, -1] * pose[:3, :3], axis=-1) for pose in poses])  # (106, 3)
origins = poses[:, :3, -1]  # (106, 3)

ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d')
_ = ax.quiver(
    origins[..., 0].flatten(),  # X (pose)
    origins[..., 1].flatten(),  # Y (pose)
    origins[..., 2].flatten(),  # Z (pose)
    dirs[..., 0].flatten(),     # U (dir)
    dirs[..., 1].flatten(),     # V (dir)
    dirs[..., 2].flatten(),     # W (dir)
    length=0.5, normalize=True)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
plt.show()
dir
print(np.stack([[0, 0, -1] * pose[:3, :3] for pose in poses])[0])
# [[-0.  0.  0.01334572]
#  [-0. -0. -0.95394367]
#  [-0.  0. -0.29968831]]
dirs = np.stack([np.sum([0, 0, -1] * pose[:3, :3], axis=-1) for pose in poses])
# equivalent to: dirs = np.stack([[0, 0, -1] @ pose[:3, :3].T for pose in poses])
print(dirs[0])
# [ 0.01334572 -0.95394367 -0.29968831]
origin
print(poses[:, :3, -1][0]) # [-0.05379832 3.8454704 1.2080823 ]
There are 106 poses for a single object.

def get_rays(
    height: int,
    width: int,
    focal_length: float,
    c2w: torch.Tensor  # pose (camera-to-world) of the view whose rays we want to cast
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Find origin and direction of rays through every pixel and camera origin.
    """

    # Apply pinhole camera model to gather directions at each pixel
    i, j = torch.meshgrid(
        torch.arange(width, dtype=torch.float32).to(c2w),
        torch.arange(height, dtype=torch.float32).to(c2w),
        indexing='ij')
    i, j = i.transpose(-1, -2), j.transpose(-1, -2)
    directions = torch.stack([(i - width * .5) / focal_length,
                              -(j - height * .5) / focal_length,
                              -torch.ones_like(i)
                              ], dim=-1)

    # Apply camera pose to directions
    rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)

    # Origin is same for all directions (the optical center)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d
# Gather as torch tensors
images = torch.from_numpy(data['images'][:n_training]).to(device)
poses = torch.from_numpy(data['poses']).to(device)
focal = torch.from_numpy(data['focal']).to(device)
testimg = torch.from_numpy(data['images'][testimg_idx]).to(device)
testpose = torch.from_numpy(data['poses'][testimg_idx]).to(device)

# Grab rays from sample image
height, width = images.shape[1:3]
with torch.no_grad():
    ray_origin, ray_direction = get_rays(height, width, focal, testpose)
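As a quick sanity check (a small sketch, not part of the original code), the ray through the central pixel should point along the camera's -z axis, i.e. the same per-pose direction we computed from [0, 0, -1] above:

# Hypothetical check: the central ray direction equals -R[:, 2] of the test pose.
center_dir = ray_direction[height // 2, width // 2]
print(center_dir / torch.norm(center_dir))  # should match ...
print(-testpose[:3, 2])                     # ... the camera's -z axis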
def sampling_color(dat, step=1):
    if not isinstance(dat, np.ndarray):
        dat = dat.cpu()
    if dat.shape[-1] != 1:
        dat = dat[..., 0]
    if step > 1:
        dat = dat.flatten()[::step]
    size = len(dat)
    sampled = np.linspace(0, 2*np.pi, size)
    u = np.cos(sampled)
    v = np.sin(sampled)
    colors = np.arctan2(u, v)
    norm = Normalize()
    norm.autoscale(colors)
    colormap = cm.Paired
    return colormap(norm(colors))
step = 1 ax = plt.figure(figsize=(12, 8)).add_subplot(projection='3d') size = len(rays_o[..., 0].flatten()[::step]) sampled = rays_o[..., 0].flatten()[::step] / (2 * np.pi) sampled = np.linspace(0, 2*np.pi, size) u = np.cos(sampled) v = np.sin(sampled) colors = np.arctan2(u, v) norm = Normalize() norm.autoscale(colors) # we need to normalize our colors array to match it colormap domain # which is [0, 1] colormap = cm.Paired _ = ax.quiver( rays_o[..., 0].flatten()[::step], rays_o[..., 1].flatten()[::step], rays_o[..., 2].flatten()[::step], rays_d[..., 0].flatten()[::step], rays_d[..., 1].flatten()[::step], rays_d[..., 2].flatten()[::step], color=colormap(norm(colors)), length=0.1, normalize=False, arrow_length_ratio=.05) ax.set_xlabel('X') ax.set_ylabel('Y') ax.set_zlabel('z') plt.show()
(a) location of the test image, (b) rays plotted at the 100*100 image pixels after applying the test image pose

Stratified Sampling
Now that we have these lines, defined as origin and direction vectors, we can begin the process of sampling them. Recall that NeRF takes a coarse-to-fine sampling strategy, starting with the stratified sampling approach.
The stratified sampling approach splits the ray into evenly-spaced bins and randomly samples within each bin.
The perturb setting determines whether to sample points uniformly from each bin or to simply use the bin center as the point.
In most cases, we want to keep perturb = True as it will encourage the network to learn over a continuously sampled space. It may be useful to disable for debugging.
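For reference, this is Eq. 2 of the paper: with N bins between near and far, one depth is drawn uniformly from each bin,

t_i ~ U[ near + (i-1)/N * (far - near), near + i/N * (far - near) ]

The function below implements a close variant of this: it first places n_samples evenly spaced depths and then jitters each one uniformly within the bin bounded by the midpoints of its neighbors.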
def sample_stratified(
    rays_o: torch.Tensor,  # 100x100x3
    rays_d: torch.Tensor,  # 100x100x3
    near: float,
    far: float,
    n_samples: int,
    perturb: Optional[bool] = True,  # regularly-spaced or not
    inverse_depth: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Sample along ray from regularly-spaced bins.
    """

    # Grab samples for space integration along ray
    t_vals = torch.linspace(0., 1., n_samples, device=rays_o.device)
    if not inverse_depth:
        # Sample linearly between `near` and `far` (== mapping [0, 1] onto [near, far])
        z_vals = near * (1. - t_vals) + far * (t_vals)
        # [Out] tensor([2.0000, 2.5714, 3.1429, 3.7143, 4.2857, 4.8571, 5.4286, 6.0000], device='cuda:0')
    else:
        # Sample linearly in inverse depth (disparity)
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * (t_vals))

    # Draw uniform samples from bins along ray
    if perturb:
        mids = .5 * (z_vals[1:] + z_vals[:-1])
        # tensor([2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143], device='cuda:0')
        upper = torch.concat([mids, z_vals[-1:]], dim=-1)  # (mids, z_vals[-1])
        # tensor([2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143, 6.0000], device='cuda:0')
        lower = torch.concat([z_vals[:1], mids], dim=-1)   # (z_vals[0], mids)
        # tensor([2.0000, 2.2857, 2.8571, 3.4286, 4.0000, 4.5714, 5.1429, 5.7143], device='cuda:0')
        t_rand = torch.rand([n_samples], device=z_vals.device)
        z_vals = lower + (upper - lower) * t_rand
        # [Out] tensor([2.2012, 2.6339, 3.1944, 3.9895, 4.4883, 4.9825, 5.3318, 5.8038], device='cuda:0')
        # ((1 - t_rand)*z_vals[0] + t_rand*mids, (1 - t_rand)*mids + z_vals[-1]*t_rand)
        # == the same form as the `(1 - t_vals) * near + t_vals * far` interpolation above
        # == sample linearly between (z_vals[0], mids) / (mids, z_vals[-1])
    z_vals = z_vals.expand(list(rays_o.shape[:-1]) + [n_samples])
    # (8,) -> (1e4, 8): every ray shares the same row of z_vals

    # Apply scale from `rays_d` and offset from `rays_o` to samples
    # pts: (width, height, n_samples, 3)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
    # rays_o/rays_d : (1e4, 3) -> (1e4, 1, 3)
    # z_vals        : (1e4, 8) -> (1e4, 8, 1)
    # pts           : (1e4, 8, 3)
    return pts, z_vals
# Draw stratified samples from example
rays_o = ray_origin.view([-1, 3])
rays_d = ray_direction.view([-1, 3])
n_samples = 8
perturb = True
inverse_depth = False
with torch.no_grad():
    pts, z_vals = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                    perturb=perturb, inverse_depth=inverse_depth)

print('Input Points')
print(pts.shape)     # (10000, n_samples, 3)
print('')
print('Distances Along Ray')
print(z_vals.shape)  # (10000, n_samples)
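A small check on top of this (a sketch): because each depth is drawn from its own bin, the perturbed samples stay sorted along every ray and remain bounded by near and far.

assert torch.all(z_vals >= near) and torch.all(z_vals <= far)
assert torch.all(z_vals[..., 1:] >= z_vals[..., :-1])  # one sample per bin keeps them ordered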
def plot_rays_3d(ax, ray_origin_cpu, ray_direction_cpu, color='black',
                 length=0.1, arrow_ratio=.05, step=9):
    ray_origin_cpu = ray_origin_cpu.cpu()
    ray_direction_cpu = ray_direction_cpu.cpu()
    _ = ax.quiver(
        ray_origin_cpu[..., 0].flatten()[::step],
        ray_origin_cpu[..., 1].flatten()[::step],
        ray_origin_cpu[..., 2].flatten()[::step],
        ray_direction_cpu[..., 0].flatten()[::step],
        ray_direction_cpu[..., 1].flatten()[::step],
        ray_direction_cpu[..., 2].flatten()[::step],
        length=length,
        color=color,
        normalize=False,
        arrow_length_ratio=arrow_ratio
    )
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
n = 8 colors = np.linspace(0, 1, n) colors = cm.Paired(colors) ax = plt.figure(figsize=(16, 12)).add_subplot(projection='3d') margin = 0.5 ax.margins(margin, margin, margin) arrow_ratio=.1 step = 10000 for i in range(n): plot_rays_3d(ax, torch.zeros_like(rays_o.reshape(100, 100, 3)), rays_o.reshape(100, 100, 3), color='black', length=1, arrow_ratio=arrow_ratio, step=step) plot_rays_3d(ax, torch.zeros_like(rays_o.reshape(100, 100, 3)), rays_d.reshape(100, 100, 3), color='gray', length=1, arrow_ratio=arrow_ratio, step=step) plot_rays_3d(ax, rays_o.reshape(100, 100, 3), (pts[:, i, :]).reshape(100, 100, 3) - rays_o.reshape(100, 100, 3), color=colors[i], length=1, arrow_ratio=arrow_ratio, step=step) plot_rays_3d(ax, rays_o.reshape(100, 100, 3), rays_d.reshape(100, 100, 3), color='gray', length=1, arrow_ratio=arrow_ratio, step=step) plt.show()
(1) black: rays_o (2) gray: rays_d (drawn both from the world origin and from rays_o) (3) colors: pts

Perturbed vs. Unperturbed
Now we visualize these sampled points.
The unperturbed blue points are the bin "centers." The red points are a sampling of perturbed points.
Notice how the red points are slightly offset from the blue points above them, but all are constrained between near and far.
(In the plot, the red points are drawn slightly below the blue ones simply for visibility; all samples lie between near = 2. and far = 6.)

y_vals = torch.zeros_like(z_vals)
_, z_vals_unperturbed = sample_stratified(rays_o, rays_d, near, far, n_samples,
                                          perturb=False, inverse_depth=inverse_depth)
plt.plot(z_vals_unperturbed[0].cpu().numpy(), 1 + y_vals[0].cpu().numpy(), 'b-o')
plt.plot(z_vals[0].cpu().numpy(), y_vals[0].cpu().numpy(), 'r-o')
plt.ylim([-1, 2])
plt.title('Stratified Sampling (blue) with Perturbation (red)')
ax = plt.gca()
ax.axes.yaxis.set_visible(False)
plt.grid(True)
Positional Encoder
Much like Transformers, NeRFs make use of positional encoders. In this case, it's to map the inputs to a higher frequency space to compensate for the bias that neural networks have for learning lower-frequency functions.
Here we build a simple torch.nn.Module of our positional encoder. The same encoder implementation can be applied to both input samples and view directions. However, we choose different parameters for these inputs. We use the default settings from the original.

(For background on the sine/cosine positional encoding used in Transformers, see: pongdangstory.tistory.com)
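As a reference for the class below: each input coordinate p is expanded, frequency by frequency, into

γ(p) = (p, sin(2^0 p), cos(2^0 p), sin(2^1 p), cos(2^1 p), ..., sin(2^(L-1) p), cos(2^(L-1) p))

so a 3-vector encoded with L frequencies becomes 3 * (1 + 2L) values. Note that this implementation keeps the identity term and applies the frequencies directly, without the π factor written in the paper's encoding equation.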
class PositionalEncoder(nn.Module): r""" Sine-cosine positional encoder for input points. """ def __init__( self, d_input: int, n_freqs: int, log_space: bool = False ): super().__init__() self.d_input = d_input self.n_freqs = n_freqs self.log_space = log_space self.d_output = d_input * (1 + 2 * self.n_freqs) self.embed_fns = [lambda x: x] # Define frequencies in either linear or log scale if self.log_space: freq_bands = 2.**torch.linspace(0., self.n_freqs - 1, self.n_freqs) else: freq_bands = torch.linspace(2.**0., 2.**(self.n_freqs - 1), self.n_freqs) # Alternate sin and cos for freq in freq_bands: self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq)) self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq)) def forward( self, x ) -> torch.Tensor: r""" Apply positional encoding to input. """ return torch.concat([fn(x) for fn in self.embed_fns], dim=-1)
# Create encoders for points and view directions
encoder = PositionalEncoder(3, 10)
viewdirs_encoder = PositionalEncoder(3, 4)

# Grab flattened points and view directions
pts_flattened = pts.reshape(-1, 3)  # (1e4, 8, 3) -> (8e4, 3)
viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)  # (1e4, 3)
# torch.norm measures the length of a vector, so viewdirs is the unit vector of rays_d.
flattened_viewdirs = viewdirs[:, None, ...].expand(pts.shape).reshape((-1, 3))  # match the shape of pts

# Encode inputs
encoded_points = encoder(pts_flattened)
# (8e4, 63): n_freq(10) * 2 (cos, sin) + identity = 21 per coordinate; x has 3 coordinates -> 21 * 3 = 63
encoded_viewdirs = viewdirs_encoder(flattened_viewdirs)
# (8e4, 27): n_freq(4) * 2 + identity = 9 per coordinate -> 9 * 3 = 27

print('Encoded Points')
print(encoded_points.shape)
print(torch.min(encoded_points), torch.max(encoded_points), torch.mean(encoded_points))
print('')
print(encoded_viewdirs.shape)
print('Encoded Viewdirs')
print(torch.min(encoded_viewdirs), torch.max(encoded_viewdirs), torch.mean(encoded_viewdirs))
print('')
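A quick width check (sketch): the encoder output size is d_input * (1 + 2 * n_freqs).

x = torch.zeros(1, 3)
print(encoder(x).shape)           # torch.Size([1, 63]) = 3 * (1 + 2*10)
print(viewdirs_encoder(x).shape)  # torch.Size([1, 27]) = 3 * (1 + 2*4)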
NeRF Model
Here we define the NeRF model, which consists primarily of a ModuleList of Linear layers, separated by non-linear activation functions and the occasional residual connection.
This model features an optional input for view directions, which will alter the model architecture if provided at instantiation. This implementation is based on Section 3 of the original "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" paper and uses the same defaults.

1. Basic NeRF Model
class NeRF(nn.Module): r""" Neural radiance fields module. """ def __init__( self, d_input: int = 3, n_layers: int = 8, d_filter: int = 256, skip: Tuple[int] = (4,), d_viewdirs: Optional[int] = None # (8e4, 27) ): super().__init__() self.d_input = d_input self.skip = skip self.act = nn.functional.relu self.d_viewdirs = d_viewdirs # Create model layers self.layers = nn.ModuleList( [nn.Linear(self.d_input, d_filter)] + [nn.Linear(d_filter + self.d_input, d_filter) if i + 1 in skip \ else nn.Linear(d_filter, d_filter) for i in range(n_layers - 1)] ) # Bottleneck layers if self.d_viewdirs is not None: # If using viewdirs, split alpha and RGB self.alpha_out = nn.Linear(d_filter, 1) self.rgb_filters = nn.Linear(d_filter, d_filter) self.branch = nn.Linear(d_filter + self.d_viewdirs, d_filter // 2) # (256 + 27, 256 // 2) self.output = nn.Linear(d_filter // 2, 3) # alpha_out 채널과 합쳐져서 depth=4 가 될 예정 else: # If no viewdirs, use simpler output self.output = nn.Linear(d_filter, 4) def forward( self, x: torch.Tensor, viewdirs: Optional[torch.Tensor] = None ) -> torch.Tensor: r""" Forward pass with optional view direction. """ # Cannot use viewdirs if instantiated with d_viewdirs = None if self.d_viewdirs is None and viewdirs is not None: raise ValueError('Cannot input x_direction if d_viewdirs was not given.') # Apply forward pass up to bottleneck x_input = x for i, layer in enumerate(self.layers): x = torch.cat([x, x_input], dim=-1) if i in self.skip else x x = self.act(layer(x)) # Apply bottleneck if self.d_viewdirs is not None: # Split alpha from network output alpha = self.alpha_out(x) # Pass through bottleneck to get RGB x = self.rgb_filters(x) x = torch.concat([x, viewdirs], dim=-1) x = self.act(self.branch(x)) x = self.output(x) # Concatenate alphas to output x = torch.concat([x, alpha], dim=-1) else: # Simple output x = self.output(x) return x
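A minimal shape check (a sketch, not from the original notebook), using the encoded widths from earlier: the model maps a 63-d encoded point plus a 27-d encoded view direction to an (R, G, B, sigma) 4-vector.

model_check = NeRF(d_input=63, d_viewdirs=27)
x = torch.rand(16, 63)   # a batch of encoded points
d = torch.rand(16, 27)   # matching encoded view directions
print(model_check(x, viewdirs=d).shape)  # torch.Size([16, 4])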
2. MipNeRF Model
class MipNerfModel(nn.Module): """Nerf NN Model with both coarse and fine MLPs.""" num_samples: int = 128 # The number of samples per level. num_levels: int = 2 # The number of sampling levels. resample_padding: float = 0.01 # Dirichlet/alpha "padding" on the histogram. stop_level_grad: bool = True # If True, don't backprop across levels') use_viewdirs: bool = True # If True, use view directions as a condition. lindisp: bool = False # If True, sample linearly in disparity, not in depth. ray_shape: str = 'cone' # The shape of cast rays ('cone' or 'cylinder'). min_deg_point: int = 0 # Min degree of positional encoding for 3D points. max_deg_point: int = 16 # Max degree of positional encoding for 3D points. deg_view: int = 4 # Degree of positional encoding for viewdirs. density_activation: Callable[..., Any] = nn.softplus # Density activation. density_noise: float = 0. # Standard deviation of noise added to raw density. density_bias: float = -1. # The shift added to raw densities pre-activation. rgb_activation: Callable[..., Any] = nn.sigmoid # The RGB activation. rgb_padding: float = 0.001 # Padding added to the RGB outputs. disable_integration: bool = False # If True, use PE instead of IPE. @nn.compact def __call__(self, rng, rays, randomized, white_bkgd): """The mip-NeRF Model. Args: rng: jnp.ndarray, random number generator. rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs. randomized: bool, use randomized stratified sampling. white_bkgd: bool, if True, use white as the background (black o.w.). Returns: ret: list, [*(rgb, distance, acc)] """ # Construct the MLP. mlp = MLP() ret = [] for i_level in range(self.num_levels): key, rng = random.split(rng) if i_level == 0: # Stratified sampling along rays t_vals, samples = mip.sample_along_rays( key, rays.origins, rays.directions, rays.radii, self.num_samples, rays.near, rays.far, randomized, self.lindisp, self.ray_shape, ) else: t_vals, samples = mip.resample_along_rays( key, rays.origins, rays.directions, rays.radii, t_vals, weights, randomized, self.ray_shape, self.stop_level_grad, resample_padding=self.resample_padding, ) if self.disable_integration: samples = (samples[0], jnp.zeros_like(samples[1])) samples_enc = mip.integrated_pos_enc( samples, self.min_deg_point, self.max_deg_point, ) # Point attribute predictions if self.use_viewdirs: viewdirs_enc = mip.pos_enc( rays.viewdirs, min_deg=0, max_deg=self.deg_view, append_identity=True, ) raw_rgb, raw_density = mlp(samples_enc, viewdirs_enc) else: raw_rgb, raw_density = mlp(samples_enc) # Add noise to regularize the density predictions if needed. if randomized and (self.density_noise > 0): key, rng = random.split(rng) raw_density += self.density_noise * random.normal( key, raw_density.shape, dtype=raw_density.dtype) # Volumetric rendering. rgb = self.rgb_activation(raw_rgb) rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding density = self.density_activation(raw_density + self.density_bias) comp_rgb, distance, acc, weights = mip.volumetric_rendering( rgb, density, t_vals, rays.directions, white_bkgd=white_bkgd, ) ret.append((comp_rgb, distance, acc)) return ret
Volume Rendering
From the raw NeRF outputs, we still need to convert these into an image.
This is where we apply the volume integration described in Equations 1-3 in Section 4 of the paper. Essentially, we take the weighted sum of all samples along the ray of each pixel to get the estimated color value at that pixel. Each RGB sample is weighted by its alpha value. Higher alpha values indicate higher likelihood that the sampled area is opaque, therefore points further along the ray are likelier to be occluded. The cumulative product ensures that those further points are dampened.
누적곱을 사용하여 그 이후 지점들의 영향을 줄여버린다.def cumprod_exclusive( tensor: torch.Tensor ) -> torch.Tensor: r""" (Courtesy of https://github.com/krrish94/nerf-pytorch) Mimick functionality of tf.math.cumprod(..., exclusive=True), as it isn't available in PyTorch. Args: tensor (torch.Tensor): Tensor whose cumprod (cumulative product, see `torch.cumprod`) along dim=-1 is to be computed. Returns: cumprod (torch.Tensor): cumprod of Tensor along dim=-1, mimiciking the functionality of tf.math.cumprod(..., exclusive=True) (see `tf.math.cumprod` for details). """ # Compute regular cumprod first (this is equivalent to `tf.math.cumprod(..., exclusive=False)`). cumprod = torch.cumprod(tensor, -1) # "Roll" the elements along dimension 'dim' by 1 element. cumprod = torch.roll(cumprod, 1, -1) # Replace the first element by "1" as this is what tf.cumprod(..., exclusive=True) does. cumprod[..., 0] = 1. return cumprod def raw2outputs( raw: torch.Tensor, # (n_rays, n_samples, 4) z_vals: torch.Tensor, # (n_rays, n_samples) rays_d: torch.Tensor, # (n_rays, 3) raw_noise_std: float = 0.0, white_bkgd: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r""" Convert the raw NeRF output into RGB and other maps. """ # Difference between consecutive elements of `z_vals`. dists = z_vals[..., 1:] - z_vals[..., :-1] # (n_rays, n_samples - 1) dists = torch.cat([dists, 1e10 * torch.ones_like(dists[..., :1])], dim=-1) # (n_rays, n_samples) # Multiply each distance by the norm of its corresponding direction ray # to convert to real world distance (accounts for non-unit directions). dists = dists * torch.norm(rays_d[..., None, :], dim=-1) # (n_rays, n_samples) # Add noise to model's predictions for density. Can be used to # regularize network during training (prevents floater artifacts). noise = 0. if raw_noise_std > 0.: noise = torch.randn(raw[..., 3].shape) * raw_noise_std # Predict density of each sample along each ray. Higher values imply # higher likelihood of being absorbed at this point. [n_rays, n_samples] alpha = 1.0 - torch.exp(-nn.functional.relu(raw[..., 3] + noise) * dists) # (n_rays, n_samples) # Compute weight for RGB of each sample along each ray. [n_rays, n_samples] # The higher the alpha, the lower subsequent weights are driven. weights = alpha * cumprod_exclusive(1. - alpha + 1e-10) # (n_rays, n_samples) # Compute weighted RGB map. rgb = torch.sigmoid(raw[..., :3]) # (n_rays, n_samples, 3) rgb_map = torch.sum(weights[..., None] * rgb, dim=-2) # [n_rays, 1, 3] # Estimated depth map is predicted distance. depth_map = torch.sum(weights * z_vals, dim=-1) # Disparity map is inverse depth. disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) # Sum of weights along each ray. In [0, 1] up to numerical error. acc_map = torch.sum(weights, dim=-1) # weight가 alpha값을 의미하니까, [0, 1] 범위를 벗어나는 경우 accuracy X # To composite onto a white background, use the accumulated alpha map. if white_bkgd: rgb_map = rgb_map + (1. - acc_map[..., None]) return rgb_map, depth_map, acc_map, weights
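To make the weighting concrete, here is a toy example (a sketch) with the helpers above: the exclusive cumulative product is the transmittance T_i = prod_{j<i}(1 - alpha_j), and each sample's compositing weight is w_i = alpha_i * T_i.

alpha = torch.tensor([0.1, 0.5, 0.9, 0.9])  # toy opacities along one ray
T = cumprod_exclusive(1. - alpha + 1e-10)   # transmittance reaching each sample
w = alpha * T                               # per-sample compositing weights
print(T)  # ~ [1.0000, 0.9000, 0.4500, 0.0450]
print(w)  # ~ [0.1000, 0.4500, 0.4050, 0.0405]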
Hierarchical Volume Sampling
The 3D space is in fact very sparse with occlusions and so most points don't contribute much to the rendered image.
It is therefore more beneficial to oversample regions with a high likelihood of contributing to the integral. Here we apply learned, normalized weights to the first set of samples to create a PDF across the ray, then apply inverse transform sampling to this PDF to gather a second set of samples, concentrated in the regions that actually contribute to the rendering.
(weights를 정규화된 가중치로 볼 수 있는 이유: [0, 1]의 alpha 값이기 때문에)def sample_pdf( bins: torch.Tensor, # (n_rays, n_samples - 1) # z_vals_mid 를 input으로 넣어주고 있다. weights: torch.Tensor, # (n_rays, n_samples - 2) # weights[:, 1:-1] # raw2outputs 함수를 통해 재구성된 alpha값 n_samples: int, # 64 perturb: bool = False ) -> torch.Tensor: r""" Apply inverse transform sampling to a weighted set of points. """ # Normalize weights to get PDF. pdf = (weights + 1e-5) / torch.sum(weights + 1e-5, -1, keepdims=True) # [n_rays, weights.shape[-1]] # (1e4, 62) # Convert PDF to CDF. cdf = torch.cumsum(pdf, dim=-1) # [n_rays, weights.shape[-1]] # (1e4, 62) cdf = torch.concat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) # (n_rays, weights.shape[-1] + 1) == (n_rays, 62 + 1) # Take sample positions to grab from CDF. Linear when perturb == 0. if not perturb: u = torch.linspace(0., 1., n_samples, device=cdf.device) # [1, n_samples] tar_shape = list(cdf.shape[:-1]) + [n_samples] # = [n_rays, n_samples] u = u.expand(tar_shape) # [n_rays, n_samples] else: u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device) # (n_rays, n_samples) # (1e4, 64) # Find indices along CDF where values in u would be placed. # CDF 상에서 u에 해당 값들이 위치하는 index 찾기. u = u.contiguous() # Returns contiguous tensor with same values. inds = torch.searchsorted(cdf, u, right=True) # (n_rays, n_samples) # cdf 63 개 값 중, u 64 개 위치의 index 반환하기 # => 하나의 bin에 2개의 sampled point가 존재하는 구간이 반드시 하나 이상 존재함. # Clamp indices that are out of bounds. below = torch.clamp(inds - 1, min=0) # (n_rays, n_samples) above = torch.clamp(inds, max=cdf.shape[-1] - 1) # (n_rays, n_samples) inds_g = torch.stack([below, above], dim=-1) # (n_rays, n_samples, 2) # Sample from cdf and the corresponding bin centers. matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]] # (3,) cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g) # (n_rays, n_samples, 2) # cdf.shape = (n_rays, n_samples - 1) # -> cdf.unsqueeze(-2).shape = (n_rays, 1, n_samples - 1) # -> ~.expand(matched_shape) = (n_rays, 1, n_samples) # -> torch.gather(~, dim=-1) = (n_rays, n_samples, 2) ~ shape of inds_g bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), dim=-1, index=inds_g) # (n_rays, n_samples, 2) ~ shape of inds_g # Convert samples to ray length. denom = (cdf_g[..., 1] - cdf_g[..., 0]) # (n_rays, n_samples) denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) # 1e-5보다 작은 값인 경우 1로 값 고정 # (n_rays, n_samples) t = (u - cdf_g[..., 0]) / denom # (n_rays, n_samples) samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) # (n_rays, n_samples) ''' t: below~above 사이에서 u의 위치를 비율로 나타낸 것 (비율 1은 above-below 거리를 의미한다.) below + t * denom 은 실제 u의 위치를 의미한다. bins_below + t * (bins_above - bins_below) 의 의미는 실제 ray 상에서 cdf 위치 (u'...약간 approx.된) 를 찾기 위한 식이다. 위는 sampling 할 위치를 ray 에서 찾는 식이다. ''' return samples # (n_rays, n_samples)
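Here is a toy run of sample_pdf (a sketch with made-up bins and weights): when most of the weight sits in the middle bins, the inverse-transform samples cluster around the corresponding depths.

bins = torch.linspace(2., 6., 8)[None, :]                       # (1, 8) toy bin positions
w = torch.tensor([[0.01, 0.05, 0.80, 0.90, 0.10, 0.02, 0.01]])  # (1, 7) toy weights
samples = sample_pdf(bins, w, n_samples=16, perturb=False)
print(samples)  # most values fall between roughly 3.1 and 4.3, where the weight is concentrated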
def sample_hierarchical(
    rays_o: torch.Tensor,   # (100, 100, 3)
    rays_d: torch.Tensor,   # (100, 100, 3)
    z_vals: torch.Tensor,   # (n_rays, n_samples)
    weights: torch.Tensor,  # (n_rays, n_samples)
    n_samples: int,         # 64
    perturb: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Apply hierarchical sampling to the rays.
    """

    # Draw samples from PDF using z_vals as bins and weights as probabilities.
    z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    # z_vals[..., 1:], z_vals[..., :-1]: (n_rays, n_samples - 1)
    # z_vals_mid:                        (n_rays, n_samples - 1)
    new_z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], n_samples, perturb=perturb)
    # (n_rays, n_samples)
    # weights[..., 1:-1] is used because the bins are the n_samples - 1 midpoints, so sample_pdf
    # expects one weight per interval between consecutive midpoints (n_samples - 2 of them);
    # the first and last weights, which belong to the half-bins at the near and far ends, are dropped.
    new_z_samples = new_z_samples.detach()

    # Resample points from ray based on PDF.
    z_vals_combined, _ = torch.sort(torch.cat([z_vals, new_z_samples], dim=-1), dim=-1)
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals_combined[..., :, None]
    # pts: [n_rays, n_samples + n_samples, 3]
    return pts, z_vals_combined, new_z_samples
    # pts:             (n_rays, n_samples * 2, 3)
    # z_vals_combined: (n_rays, n_samples * 2)
    # new_z_samples:   (n_rays, n_samples)
Full Forward Pass
Here is where we put everything together to compute a single forward pass through our model.
Due to potential memory issues, the forward pass is computed in "chunks," which are then aggregated across a single batch.
The gradient propagation is done after the whole batch is processed, hence the distinction between "chunks" and "batches."
Chunking is especially important for the Google Colab environment, which provides more modest resources than those cited in the original paper.
def get_chunks(
    inputs: torch.Tensor,
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    r"""
    Divide an input into chunks.
    """
    return [inputs[i:i + chunksize] for i in range(0, inputs.shape[0], chunksize)]

def prepare_chunks(
    points: torch.Tensor,
    encoding_function: Callable[[torch.Tensor], torch.Tensor],
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    r"""
    Encode and chunkify points to prepare for NeRF model.
    """
    points = points.reshape((-1, 3))
    points = encoding_function(points)
    points = get_chunks(points, chunksize=chunksize)
    return points

def prepare_viewdirs_chunks(
    points: torch.Tensor,
    rays_d: torch.Tensor,
    encoding_function: Callable[[torch.Tensor], torch.Tensor],
    chunksize: int = 2**15
) -> List[torch.Tensor]:
    r"""
    Encode and chunkify viewdirs to prepare for NeRF model.
    """
    # Prepare the viewdirs
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    viewdirs = viewdirs[:, None, ...].expand(points.shape).reshape((-1, 3))
    viewdirs = encoding_function(viewdirs)
    viewdirs = get_chunks(viewdirs, chunksize=chunksize)
    return viewdirs
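For a sense of scale (a sketch): with 100x100 rays and 64 stratified samples per ray there are 640,000 encoded points, so a chunksize of 2**15 = 32,768 splits one forward pass into 20 chunks.

dummy = torch.zeros(100 * 100 * 64, 63)  # stand-in for encoded points
chunks = get_chunks(dummy, chunksize=2**15)
print(len(chunks), chunks[0].shape)      # 20 torch.Size([32768, 63])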
def nerf_forward( rays_o: torch.Tensor, # (100, 100, 3) rays_d: torch.Tensor, # (100, 100, 3) near: float, far: float, encoding_fn: Callable[[torch.Tensor], torch.Tensor], coarse_model: nn.Module, # class NeRF kwargs_sample_stratified: dict = None, n_samples_hierarchical: int = 0, kwargs_sample_hierarchical: dict = None, fine_model = None, viewdirs_encoding_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, chunksize: int = 2**15 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: r""" Compute forward pass through model(s). """ # Set no kwargs if none are given. if kwargs_sample_stratified is None: kwargs_sample_stratified = {} if kwargs_sample_hierarchical is None: kwargs_sample_hierarchical = {} # Sample query points along each ray. query_points, z_vals = sample_stratified( rays_o, rays_d, near, far, **kwargs_sample_stratified) # Prepare batches. batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize) if viewdirs_encoding_fn is not None: batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d, viewdirs_encoding_fn, chunksize=chunksize) else: batches_viewdirs = [None] * len(batches) # Coarse model pass. # Split the encoded points into "chunks", run the model on all chunks, and # concatenate the results (to avoid out-of-memory issues). predictions = [] for batch, batch_viewdirs in zip(batches, batches_viewdirs): predictions.append(coarse_model(batch, viewdirs=batch_viewdirs)) raw = torch.cat(predictions, dim=0) raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]]) # Perform differentiable volume rendering to re-synthesize the RGB image. rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals, rays_d) # rgb_map, depth_map, acc_map, weights = render_volume_density(raw, rays_o, z_vals) outputs = { 'z_vals_stratified': z_vals } # Fine model pass. if n_samples_hierarchical > 0: # Save previous outputs to return. rgb_map_0, depth_map_0, acc_map_0 = rgb_map, depth_map, acc_map # Apply hierarchical sampling for fine query points. query_points, z_vals_combined, z_hierarch = sample_hierarchical( rays_o, rays_d, z_vals, weights, n_samples_hierarchical, **kwargs_sample_hierarchical) # Prepare inputs as before. batches = prepare_chunks(query_points, encoding_fn, chunksize=chunksize) if viewdirs_encoding_fn is not None: batches_viewdirs = prepare_viewdirs_chunks(query_points, rays_d, viewdirs_encoding_fn, chunksize=chunksize) else: batches_viewdirs = [None] * len(batches) # Forward pass new samples through fine model. fine_model = fine_model if fine_model is not None else coarse_model predictions = [] for batch, batch_viewdirs in zip(batches, batches_viewdirs): predictions.append(fine_model(batch, viewdirs=batch_viewdirs)) raw = torch.cat(predictions, dim=0) raw = raw.reshape(list(query_points.shape[:2]) + [raw.shape[-1]]) # Perform differentiable volume rendering to re-synthesize the RGB image. rgb_map, depth_map, acc_map, weights = raw2outputs(raw, z_vals_combined, rays_d) # Store outputs. outputs['z_vals_hierarchical'] = z_hierarch outputs['rgb_map_0'] = rgb_map_0 outputs['depth_map_0'] = depth_map_0 outputs['acc_map_0'] = acc_map_0 # Store outputs. outputs['rgb_map'] = rgb_map outputs['depth_map'] = depth_map outputs['acc_map'] = acc_map outputs['weights'] = weights return outputs
Train
At long last, we have (almost) everything we need to train the model.
Now we will do some setup for a simple training procedure, creating hyperparameters and helper functions, then train our model.
Hyperparameters
All hyperparameters for training are set here.
Defaults were taken from the original, unless computational constraints prohibit them.
In this case, we apply sensible defaults that are well within the resources provided by Google Colab.
# Encoders
d_input = 3              # Number of input dimensions
n_freqs = 10             # Number of encoding functions for samples
log_space = True         # If set, frequencies scale in log space
use_viewdirs = True      # If set, use view direction as input
n_freqs_views = 4        # Number of encoding functions for views

# Stratified sampling
n_samples = 64           # Number of spatial samples per ray
perturb = True           # If set, applies noise to sample positions
inverse_depth = False    # If set, samples points linearly in inverse depth

# Model
d_filter = 128           # Dimensions of linear layer filters
n_layers = 2             # Number of layers in network bottleneck
skip = []                # Layers at which to apply input residual
use_fine_model = True    # If set, creates a fine model
d_filter_fine = 128      # Dimensions of linear layer filters of fine network
n_layers_fine = 6        # Number of layers in fine network bottleneck

# Hierarchical sampling
n_samples_hierarchical = 64   # Number of samples per ray
perturb_hierarchical = False  # If set, applies noise to sample positions

# Optimizer
lr = 5e-4                # Learning rate

# Training
n_iters = 10000
batch_size = 2**14       # Number of rays per gradient step (power of 2)
one_image_per_step = True  # One image per gradient step (disables batching)
chunksize = 2**14        # Modify as needed to fit in GPU memory
center_crop = True       # Crop the center of image (one_image_per_)
center_crop_iters = 50   # Stop cropping center after this many epochs
display_rate = 25        # Display test output every X epochs

# Early Stopping
warmup_iters = 100       # Number of iterations during warmup phase
warmup_min_fitness = 10.0  # Min val PSNR to continue training at warmup_iters
n_restarts = 10          # Number of times to restart if training stalls

# We bundle the kwargs for various functions to pass all at once.
kwargs_sample_stratified = {
    'n_samples': n_samples,
    'perturb': perturb,
    'inverse_depth': inverse_depth
}
kwargs_sample_hierarchical = {
    'perturb': perturb
}
Training Classes and Functions
Here we create some helper functions for training.
NeRF can be prone to local minima, in which training will quickly stall and produce blank outputs.
EarlyStopping is used to restart the training when learning stalls, if necessary.
def plot_samples( z_vals: torch.Tensor, z_hierarch: Optional[torch.Tensor] = None, ax: Optional[np.ndarray] = None): r""" Plot stratified and (optional) hierarchical samples. """ y_vals = 1 + np.zeros_like(z_vals) if ax is None: ax = plt.subplot() ax.plot(z_vals, y_vals, 'b-o') if z_hierarch is not None: y_hierarch = np.zeros_like(z_hierarch) ax.plot(z_hierarch, y_hierarch, 'r-o') ax.set_ylim([-1, 2]) ax.set_title('Stratified Samples (blue) and Hierarchical Samples (red)') ax.axes.yaxis.set_visible(False) ax.grid(True) return ax def crop_center( img: torch.Tensor, frac: float = 0.5 ) -> torch.Tensor: r""" Crop center square from image. """ h_offset = round(img.shape[0] * (frac / 2)) w_offset = round(img.shape[1] * (frac / 2)) return img[h_offset:-h_offset, w_offset:-w_offset] class EarlyStopping: r""" Early stopping helper based on fitness criterion. """ def __init__( self, patience: int = 30, margin: float = 1e-4 ): self.best_fitness = 0.0 # In our case PSNR self.best_iter = 0 self.margin = margin self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop def __call__( self, iter: int, fitness: float ): r""" Check if criterion for stopping is met. """ if (fitness - self.best_fitness) > self.margin: self.best_iter = iter self.best_fitness = fitness delta = iter - self.best_iter stop = delta >= self.patience # stop training if patience exceeded return stop
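A minimal usage sketch of the EarlyStopping helper above: it flags a stall once the fitness (PSNR here) has not improved by more than margin for patience consecutive iterations.

stopper = EarlyStopping(patience=3)
for it, psnr in enumerate([10.0, 10.5, 10.5, 10.5, 10.5, 10.5]):
    if stopper(it, psnr):
        print(f'Stalled at iteration {it}')  # Stalled at iteration 4
        break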
def init_models(): r""" Initialize models, encoders, and optimizer for NeRF training. """ # Encoders encoder = PositionalEncoder(d_input, n_freqs, log_space=log_space) encode = lambda x: encoder(x) # View direction encoders if use_viewdirs: encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views, log_space=log_space) encode_viewdirs = lambda x: encoder_viewdirs(x) d_viewdirs = encoder_viewdirs.d_output else: encode_viewdirs = None d_viewdirs = None # Models model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip, d_viewdirs=d_viewdirs) model.to(device) model_params = list(model.parameters()) if use_fine_model: fine_model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip, d_viewdirs=d_viewdirs) fine_model.to(device) model_params = model_params + list(fine_model.parameters()) else: fine_model = None # Optimizer optimizer = torch.optim.Adam(model_params, lr=lr) # Early Stopping warmup_stopper = EarlyStopping(patience=50) return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper
Training Loop
def train(): r""" Launch training session for NeRF. """ # Shuffle rays across all images. if not one_image_per_step: height, width = images.shape[1:3] all_rays = torch.stack([torch.stack(get_rays(height, width, focal, p), 0) for p in poses[:n_training]], 0) rays_rgb = torch.cat([all_rays, images[:, None]], 1) rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4]) rays_rgb = rays_rgb.reshape([-1, 3, 3]) rays_rgb = rays_rgb.type(torch.float32) rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])] i_batch = 0 train_psnrs = [] val_psnrs = [] iternums = [] for i in trange(n_iters): model.train() if one_image_per_step: # Randomly pick an image as the target. target_img_idx = np.random.randint(images.shape[0]) target_img = images[target_img_idx].to(device) if center_crop and i < center_crop_iters: target_img = crop_center(target_img) height, width = target_img.shape[:2] target_pose = poses[target_img_idx].to(device) rays_o, rays_d = get_rays(height, width, focal, target_pose) rays_o = rays_o.reshape([-1, 3]) rays_d = rays_d.reshape([-1, 3]) else: # Random over all images. batch = rays_rgb[i_batch:i_batch + batch_size] batch = torch.transpose(batch, 0, 1) rays_o, rays_d, target_img = batch height, width = target_img.shape[:2] i_batch += batch_size # Shuffle after one epoch if i_batch >= rays_rgb.shape[0]: rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])] i_batch = 0 target_img = target_img.reshape([-1, 3]) # Run one iteration of TinyNeRF and get the rendered RGB image. outputs = nerf_forward(rays_o, rays_d, near, far, encode, model, kwargs_sample_stratified=kwargs_sample_stratified, n_samples_hierarchical=n_samples_hierarchical, kwargs_sample_hierarchical=kwargs_sample_hierarchical, fine_model=fine_model, viewdirs_encoding_fn=encode_viewdirs, chunksize=chunksize) # Check for any numerical issues. for k, v in outputs.items(): if torch.isnan(v).any(): print(f"! [Numerical Alert] {k} contains NaN.") if torch.isinf(v).any(): print(f"! [Numerical Alert] {k} contains Inf.") # Backprop! rgb_predicted = outputs['rgb_map'] loss = torch.nn.functional.mse_loss(rgb_predicted, target_img) loss.backward() optimizer.step() optimizer.zero_grad() psnr = -10. * torch.log10(loss) train_psnrs.append(psnr.item()) # Evaluate testimg at given display rate. if i % display_rate == 0: model.eval() height, width = testimg.shape[:2] rays_o, rays_d = get_rays(height, width, focal, testpose) rays_o = rays_o.reshape([-1, 3]) rays_d = rays_d.reshape([-1, 3]) outputs = nerf_forward(rays_o, rays_d, near, far, encode, model, kwargs_sample_stratified=kwargs_sample_stratified, n_samples_hierarchical=n_samples_hierarchical, kwargs_sample_hierarchical=kwargs_sample_hierarchical, fine_model=fine_model, viewdirs_encoding_fn=encode_viewdirs, chunksize=chunksize) rgb_predicted = outputs['rgb_map'] loss = torch.nn.functional.mse_loss(rgb_predicted, testimg.reshape(-1, 3)) print("Loss:", loss.item()) val_psnr = -10. 
* torch.log10(loss) val_psnrs.append(val_psnr.item()) iternums.append(i) # Plot example outputs fig, ax = plt.subplots(1, 4, figsize=(24,4), gridspec_kw={'width_ratios': [1, 1, 1, 3]}) ax[0].imshow(rgb_predicted.reshape([height, width, 3]).detach().cpu().numpy()) ax[0].set_title(f'Iteration: {i}') ax[1].imshow(testimg.detach().cpu().numpy()) ax[1].set_title(f'Target') ax[2].plot(range(0, i + 1), train_psnrs, 'r') ax[2].plot(iternums, val_psnrs, 'b') ax[2].set_title('PSNR (train=red, val=blue') z_vals_strat = outputs['z_vals_stratified'].view((-1, n_samples)) z_sample_strat = z_vals_strat[z_vals_strat.shape[0] // 2].detach().cpu().numpy() if 'z_vals_hierarchical' in outputs: z_vals_hierarch = outputs['z_vals_hierarchical'].view((-1, n_samples_hierarchical)) z_sample_hierarch = z_vals_hierarch[z_vals_hierarch.shape[0] // 2].detach().cpu().numpy() else: z_sample_hierarch = None _ = plot_samples(z_sample_strat, z_sample_hierarch, ax=ax[3]) ax[3].margins(0) plt.show() # Check PSNR for issues and stop if any are found. if i == warmup_iters - 1: if val_psnr < warmup_min_fitness: print(f'Val PSNR {val_psnr} below warmup_min_fitness {warmup_min_fitness}. Stopping...') return False, train_psnrs, val_psnrs elif i < warmup_iters: if warmup_stopper is not None and warmup_stopper(i, psnr): print(f'Train PSNR flatlined at {psnr} for {warmup_stopper.patience} iters. Stopping...') return False, train_psnrs, val_psnrs return True, train_psnrs, val_psnrs
# Run training session(s) for _ in range(n_restarts): model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper = init_models() success, train_psnrs, val_psnrs = train() if success and val_psnrs[-1] >= warmup_min_fitness: print('Training successful!') break print('') print(f'Done!') torch.save(model.state_dict(), 'nerf.pth') torch.save(fine_model.state_dict(), 'nerf-fine.pth')