  [Code Review] NeRFusion Code Breakdown
    NeRF 2023. 7. 31. 14:41

    Work in progress

    * This post is based on the official NeRFusion code (PyTorch).

    ▶ NeRFusion paper review: 2023.07.27 - [Papers] - [Paper Review] NeRFusion: Fusing Radiance Fields for Large-Scale Scene Reconstruction

     

    ▶ NeRF code review: 2023.07.03 - [NeRF] - [Code Review] NeRF Code Breakdown

     



    This post focuses on the parts that the paper treats as most important. Because the official code does not reveal how the implementations of the key components (GRU, Local Voxel, Global Voxel) are used during training, there is no code here for the pipeline shown in Figure 2.

    1. NeRFusion (Overall)
    2. Sparse neural volume 
      • 3D Feature Volume (Ui) from feature map (Fi)
      • Local Volume (Vt)
      • Global Volume (Vgt)
    3. Fusion Module - GRU (Gated Recurrent Unit)
      • Update gate (zt)
      • Reset gate (rt)
    4. Voxel Pruning
    5. Sampling
    6. Direct Inference 
    7. Training
      • Pre-training across scenes

    * Among the libraries this project uses, vren is the models/csrc package released as a standalone library, so it must be installed separately before running the project.

    NeRFusion

    
      
    class NeRFusion2(nn.Module):
    def __init__(self, scale, grid_size=128, global_representation=None):
    super().__init__()
    # scene bounding box
    # TODO: this is a temp easy solution
    self.scale = scale
    self.register_buffer('center', torch.zeros(1, 3))
    self.register_buffer('xyz_min', -torch.ones(1, 3)*scale)
    self.register_buffer('xyz_max', torch.ones(1, 3)*scale)
    self.register_buffer('half_size', (self.xyz_max-self.xyz_min)/2)
    self.grid_size = grid_size
    self.cascades = 1
    self.register_buffer('density_bitfield',
    torch.ones(self.grid_size**3//8, dtype=torch.uint8)) # dummy
    self.register_buffer('density_grid',
    torch.zeros(self.cascades, self.grid_size**3))
    self.register_buffer('grid_coords',
    create_meshgrid3d(self.grid_size, self.grid_size, self.grid_size, False, dtype=torch.int32).reshape(-1, 3))
    self.global_representation = global_representation
    if global_representation is not None:
    self.initialize_global_volume(global_representation)
    self.xyz_encoder = \
    tcnn.Network(
    n_input_dims=16, n_output_dims=16,
    network_config={
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": 64,
    "n_hidden_layers": 1,
    }
    )
    else:
    self.xyz_encoder = \
    tcnn.NetworkWithInputEncoding(
    n_input_dims=3, n_output_dims=16,
    encoding_config={
    "otype": "Grid",
    "type": "Dense",
    "n_levels": 3,
    "n_feature_per_level": 2,
    "base_resolution": 128,
    "per_level_scale": 2.0,
    "interpolation": "Linear",
    },
    network_config={
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": 64,
    "n_hidden_layers": 1,
    }
    )
    self.dir_encoder = \
    tcnn.Encoding(
    n_input_dims=3,
    encoding_config={
    "otype": "SphericalHarmonics",
    "degree": 4,
    },
    )
    self.rgb_net = \
    tcnn.Network(
    n_input_dims=32, n_output_dims=3,
    network_config={
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "Sigmoid",
    "n_neurons": 64,
    "n_hidden_layers": 2,
    }
    )
    def density(self, x, return_feat=False):
    """
    Inputs:
    x: (N, 3) xyz in [-scale, scale]
    return_feat: whether to return intermediate feature
    Outputs:
    sigmas: (N)
    """
    x = (x-self.xyz_min)/(self.xyz_max-self.xyz_min)
    h = self.xyz_encoder(x)
    sigmas = TruncExp.apply(h[:, 0])
    if return_feat: return sigmas, h
    return sigmas
    def forward(self, x, d, **kwargs):
    """
    Inputs:
    x: (N, 3) xyz in [-scale, scale]
    d: (N, 3) directions
    Outputs:
    sigmas: (N)
    rgbs: (N, 3)
    """
    if self.global_representation is not None:
    x = self.get_global_feature(x)
    sigmas, h = self.density(x, return_feat=True)
    d = d/torch.norm(d, dim=1, keepdim=True)
    d = self.dir_encoder((d+1)/2)
    rgbs = self.rgb_net(torch.cat([d, h], 1))
    return sigmas, rgbs
    @torch.no_grad()
    def get_all_cells(self):
    """
    Get all cells from the density grid.
    Outputs:
    cells: list (of length self.cascades) of indices and coords
    selected at each cascade
    """
    indices = vren.morton3D(self.grid_coords).long()
    cells = [(indices, self.grid_coords)] * self.cascades
    return cells
    @torch.no_grad()
    def sample_uniform_and_occupied_cells(self, M, density_threshold):
    """
    Sample both M uniform and occupied cells (per cascade)
    occupied cells are sample from cells with density > @density_threshold
    Outputs:
    cells: list (of length self.cascades) of indices and coords
    selected at each cascade
    """
    cells = []
    for c in range(self.cascades):
    # uniform cells
    coords1 = torch.randint(self.grid_size, (M, 3), dtype=torch.int32,
    device=self.density_grid.device)
    indices1 = vren.morton3D(coords1).long()
    # occupied cells
    indices2 = torch.nonzero(self.density_grid[c] > density_threshold)[:, 0]
    if len(indices2) > 0:
    rand_idx = torch.randint(len(indices2), (M,),
    device=self.density_grid.device)
    indices2 = indices2[rand_idx]
    coords2 = vren.morton3D_invert(indices2.int())
    # concatenate
    cells += [(torch.cat([indices1, indices2]), torch.cat([coords1, coords2]))]
    return cells
    @torch.no_grad()
    def prune_cells(self, K, poses, img_wh, chunk=64 ** 3):
    """
    mark the cells that aren't covered by the cameras with density -1
    only executed once before training starts
    Inputs:
    K: (3, 3) camera intrinsics
    poses: (N, 3, 4) camera to world poses
    img_wh: image width and height
    chunk: the chunk size to split the cells (to avoid OOM)
    """
    N_cams = poses.shape[0]
    self.count_grid = torch.zeros_like(self.density_grid)
    w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3)
    w2c_T = -w2c_R @ poses[:, :3, 3:] # (N_cams, 3, 1)
    cells = self.get_all_cells()
    for c in range(self.cascades):
    indices, coords = cells[c]
    for i in range(0, len(indices), chunk):
    xyzs = coords[i:i + chunk] / (self.grid_size - 1) * 2 - 1
    s = min(2 ** (c - 1), self.scale)
    half_grid_size = s / self.grid_size
    xyzs_w = (xyzs * (s - half_grid_size)).T # (3, chunk)
    xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk)
    uvd = K @ xyzs_c # (N_cams, 3, chunk)
    uv = uvd[:, :2] / uvd[:, 2:] # (N_cams, 2, chunk)
    in_image = (uvd[:, 2] >= 0) & \
    (uv[:, 0] >= 0) & (uv[:, 0] < img_wh[0]) & \
    (uv[:, 1] >= 0) & (uv[:, 1] < img_wh[1])
    covered_by_cam = (uvd[:, 2] >= NEAR_DISTANCE) & in_image # (N_cams, chunk)
    # if the cell is visible by at least one camera
    self.count_grid[c, indices[i:i + chunk]] = \
    count = covered_by_cam.sum(0) / N_cams
    too_near_to_cam = (uvd[:, 2] < NEAR_DISTANCE) & in_image # (N, chunk)
    # if the cell is too close (in front) to any camera
    too_near_to_any_cam = too_near_to_cam.any(0)
    # a valid cell should be visible by at least one camera and not too close to any camera
    valid_mask = (count > 0) & (~too_near_to_any_cam)
    self.density_grid[c, indices[i:i + chunk]] = \
    torch.where(valid_mask, 0., -1.)
    @torch.no_grad()
    def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False):
    density_grid_tmp = torch.zeros_like(self.density_grid)
    if warmup: # during the first steps
    cells = self.get_all_cells()
    else:
    cells = self.sample_uniform_and_occupied_cells(self.grid_size ** 3 // 4,
    density_threshold)
    # infer sigmas
    for c in range(self.cascades):
    indices, coords = cells[c]
    s = min(2 ** (c - 1), self.scale)
    half_grid_size = s / self.grid_size
    xyzs_w = (coords / (self.grid_size - 1) * 2 - 1) * (s - half_grid_size)
    # pick random position in the cell by adding noise in [-hgs, hgs]
    xyzs_w += (torch.rand_like(xyzs_w) * 2 - 1) * half_grid_size
    density_grid_tmp[c, indices] = self.density(xyzs_w)
    if erode:
    # My own logic. decay more the cells that are visible to few cameras
    decay = torch.clamp(decay ** (1 / self.count_grid), 0.1, 0.95)
    self.density_grid = \
    torch.where(self.density_grid < 0,
    self.density_grid,
    torch.maximum(self.density_grid * decay, density_grid_tmp))
    mean_density = self.density_grid[self.density_grid > 0].mean().item()
    vren.packbits(self.density_grid, min(mean_density, density_threshold),
    self.density_bitfield)

    Sparse Neural Volume

    The process of building the voxel volume:

    1. a
    
      
    class SPVCNN(nn.Module):
    def __init__(self, **kwargs):
    super().__init__()
    self.dropout = kwargs['dropout']
    cr = kwargs.get('cr', 1.0)
    cs = [32, 64, 128, 96, 96]
    cs = [int(cr * x) for x in cs]
    if 'pres' in kwargs and 'vres' in kwargs:
    self.pres = kwargs['pres']
    self.vres = kwargs['vres']
    self.stem = nn.Sequential(
    spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1),
    spnn.BatchNorm(cs[0]), spnn.ReLU(True)
    )
    self.stage1 = nn.Sequential(
    BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
    ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
    ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
    )
    self.stage2 = nn.Sequential(
    BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
    ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
    ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
    )
    self.up1 = nn.ModuleList([
    BasicDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2),
    nn.Sequential(
    ResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1,
    dilation=1),
    ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
    )
    ])
    self.up2 = nn.ModuleList([
    BasicDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2),
    nn.Sequential(
    ResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1,
    dilation=1),
    ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
    )
    ])
    self.point_transforms = nn.ModuleList([
    nn.Sequential(
    nn.Linear(cs[0], cs[2]),
    nn.BatchNorm1d(cs[2]),
    nn.ReLU(True),
    ),
    nn.Sequential(
    nn.Linear(cs[2], cs[4]),
    nn.BatchNorm1d(cs[4]),
    nn.ReLU(True),
    )
    ])
    self.weight_initialization()
    if self.dropout:
    self.dropout = nn.Dropout(0.3, True)
    def weight_initialization(self):
    for m in self.modules():
    if isinstance(m, nn.BatchNorm1d):
    nn.init.constant_(m.weight, 1)
    nn.init.constant_(m.bias, 0)
    def forward(self, z):
    # x: SparseTensor z: PointTensor
    x0 = initial_voxelize(z, self.pres, self.vres)
    x0 = self.stem(x0)
    z0 = voxel_to_point(x0, z, nearest=False)
    z0.F = z0.F
    x1 = point_to_voxel(x0, z0)
    x1 = self.stage1(x1)
    x2 = self.stage2(x1)
    z1 = voxel_to_point(x2, z0)
    z1.F = z1.F + self.point_transforms[0](z0.F)
    y3 = point_to_voxel(x2, z1)
    if self.dropout:
    y3.F = self.dropout(y3.F)
    y3 = self.up1[0](y3)
    y3 = torchsparse.cat([y3, x1])
    y3 = self.up1[1](y3)
    y4 = self.up2[0](y3)
    y4 = torchsparse.cat([y4, x0])
    y4 = self.up2[1](y4)
    z3 = voxel_to_point(y4, z1)
    z3.F = z3.F + self.point_transforms[1](z1.F)
    return z3.F
    
      
    # z: PointTensor
    # return: SparseTensor
    def initial_voxelize(z, init_res, after_res):
    new_float_coord = torch.cat(
    [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)
    pc_hash = F.sphash(torch.floor(new_float_coord).int())
    sparse_hash = torch.unique(pc_hash)
    idx_query = F.sphashquery(pc_hash, sparse_hash)
    counts = F.spcount(idx_query.int(), len(sparse_hash))
    inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query,
    counts)
    inserted_coords = torch.round(inserted_coords).int()
    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
    new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
    z.additional_features['idx_query'][1] = idx_query
    z.additional_features['counts'][1] = counts
    z.C = new_float_coord
    return new_tensor
    
      
    # x: SparseTensor, z: PointTensor
    # return: SparseTensor
    def point_to_voxel(x, z):
    if z.additional_features is None or z.additional_features.get('idx_query') is None\
    or z.additional_features['idx_query'].get(x.s) is None:
    #pc_hash = hash_gpu(torch.floor(z.C).int())
    pc_hash = F.sphash(
    torch.cat([
    torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
    z.C[:, -1].int().view(-1, 1)
    ], 1))
    sparse_hash = F.sphash(x.C)
    idx_query = F.sphashquery(pc_hash, sparse_hash)
    counts = F.spcount(idx_query.int(), x.C.shape[0])
    z.additional_features['idx_query'][x.s] = idx_query
    z.additional_features['counts'][x.s] = counts
    else:
    idx_query = z.additional_features['idx_query'][x.s]
    counts = z.additional_features['counts'][x.s]
    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
    new_tensor = SparseTensor(inserted_feat, x.C, x.s)
    new_tensor.cmaps = x.cmaps
    new_tensor.kmaps = x.kmaps
    return new_tensor
    
      
    # x: SparseTensor, z: PointTensor
    # return: PointTensor
    def voxel_to_point(x, z, nearest=False):
    if z.idx_query is None or z.weights is None or z.idx_query.get(
    x.s) is None or z.weights.get(x.s) is None:
    off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
    #old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off)
    old_hash = F.sphash(
    torch.cat([
    torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
    z.C[:, -1].int().view(-1, 1)
    ], 1), off)
    pc_hash = F.sphash(x.C.to(z.F.device))
    idx_query = F.sphashquery(old_hash, pc_hash)
    weights = F.calc_ti_weights(z.C, idx_query,
    scale=x.s[0]).transpose(0, 1).contiguous()
    idx_query = idx_query.transpose(0, 1).contiguous()
    if nearest:
    weights[:, 1:] = 0.
    idx_query[:, 1:] = -1
    new_feat = F.spdevoxelize(x.F, idx_query, weights)
    new_tensor = PointTensor(new_feat,
    z.C,
    idx_query=z.idx_query,
    weights=z.weights)
    new_tensor.additional_features = z.additional_features
    new_tensor.idx_query[x.s] = idx_query
    new_tensor.weights[x.s] = weights
    z.idx_query[x.s] = idx_query
    z.weights[x.s] = weights
    else:
    new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s),
    z.weights.get(x.s))
    new_tensor = PointTensor(new_feat,
    z.C,
    idx_query=z.idx_query,
    weights=z.weights)
    new_tensor.additional_features = z.additional_features
    return new_tensor
    
      
    class SparseVoxelGrid(nn.Module):
    def __init__(self, scale, resolution, feat_dim):
    """
    scale: range of xyz. 0.5 -> (-0.5, 0.5)
    resolution: #voxels within each dim. 128 -> 128x128x128
    """
    super().__init__()
    self.scale = scale
    self.resolution = resolution
    self.voxel_size = scale * 2 / resolution
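
    The class listing above is cut off in the post, but the voxel_size relation already fixes how world coordinates map to voxel indices. Below is a minimal sketch of that mapping, assuming a grid centered at the origin that spans [-scale, scale]; the helper name world_to_voxel is hypothetical and not part of the official code.

    import torch

    def world_to_voxel(xyz, scale=0.5, resolution=128):
        # Hypothetical helper: map world coordinates in [-scale, scale] to integer
        # voxel indices in [0, resolution-1], using voxel_size = scale * 2 / resolution.
        voxel_size = scale * 2 / resolution
        idx = torch.floor((xyz + scale) / voxel_size).long()
        return idx.clamp(0, resolution - 1)

    pts = torch.tensor([[-0.5, 0.0, 0.4999]])
    print(world_to_voxel(pts))  # indices: [[0, 64, 127]]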

     

    3D Feature Volume

     

    
      

     

    Fusion Module

    GRU Fusion

    Update the hidden-state features with a ConvGRU.

    • convert2dense:
      1. Convert the sparse features to a dense feature volume.
      2. [updated_coords] Combine the coordinates of the current features with the previous coordinates inside the FBV (fragment bounding volume) taken from the global hidden state to get the new feature coordinates (a small masking sketch follows this list).
      3. Fuse the ground-truth TSDF.
    • update_map: Replace the hidden state / TSDF in the global hidden state / TSDF volume by directly substituting the corresponding voxels.
    • save_mesh: scene & scene_tsdf
    • forward (Incremental fusion)
      • If the fragment comes from a new scene, reinitialize the backend map.
      • Each level has its corresponding voxel size.
      • Get the relative origin in the global volume.
      • Convert to dense:
        1. Convert the sparse features to a dense feature volume.
        2. Combine the current feature coordinates with the previous feature coordinates inside the FBV from the backend map to get the new feature coordinates (updated_coords).
      • Dense to sparse: gather the features at the new feature coordinates.
      • Get the fused ground truth.
      • Feed the result back into the global volume.
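
    To make the FBV masking in convert2dense below concrete, here is a tiny sketch with made-up numbers; it mirrors the valid = ((global_coords < dim) & (global_coords >= 0)).all(dim=-1) check in the code.

    import torch

    # global voxel coords (N, 3), fragment origin and size (all in voxel units)
    global_coords = torch.tensor([[10, 10, 10], [40, 5, 5], [12, 30, 2]])
    relative_origin = torch.tensor([8, 8, 0])   # fragment origin inside the global volume
    dim = torch.tensor([24, 24, 24])            # fragment bounding volume (FBV) size

    local = global_coords - relative_origin             # shift into the fragment frame
    valid = ((local < dim) & (local >= 0)).all(dim=-1)  # inside the FBV on every axis?
    print(valid)         # tensor([ True, False,  True])
    print(local[valid])  # coordinates that get fused with the current fragment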
     
    
      
    class GRUFusion(nn.Module):
    """
    Two functionalities of this class:
    1. GRU Fusion module as in the paper. Update hidden state features with ConvGRU.
    2. Substitute TSDF in the global volume when direct_substitute = True.
    """
    def __init__(self, cfg, ch_in=None, direct_substitute=False):
    super(GRUFusion, self).__init__()
    self.cfg = cfg
    # replace tsdf in global tsdf volume by direct substitute corresponding voxels
    self.direct_substitude = direct_substitute
    if direct_substitute:
    # tsdf
    self.ch_in = [1, 1, 1]
    self.feat_init = 1
    else:
    # features
    self.ch_in = ch_in
    self.feat_init = 0
    self.n_scales = len(cfg.THRESHOLDS) - 1
    self.scene_name = [None, None, None]
    self.global_origin = [None, None, None]
    self.global_volume = [None, None, None]
    self.target_tsdf_volume = [None, None, None]
    if direct_substitute:
    self.fusion_nets = None
    else:
    self.fusion_nets = nn.ModuleList()
    for i, ch in enumerate(ch_in):
    self.fusion_nets.append(ConvGRU(hidden_dim=ch,
    input_dim=ch,
    pres=1,
    vres=self.cfg.VOXEL_SIZE * 2 ** (self.n_scales - i)))
    def reset(self, i):
    self.global_volume[i] = PointTensor(torch.Tensor([]), torch.Tensor([]).view(0, 3).long()).cuda()
    self.target_tsdf_volume[i] = PointTensor(torch.Tensor([]), torch.Tensor([]).view(0, 3).long()).cuda()
    def convert2dense(self, current_coords, current_values, coords_target_global, tsdf_target, relative_origin,
    scale):
    '''
    1. convert sparse feature to dense feature;
    2. combine current feature coordinates and previous coordinates within FBV from global hidden state to get new feature coordinates (updated_coords);
    3. fuse ground truth tsdf.
    :param current_coords: (Tensor), current coordinates, (N, 3)
    :param current_values: (Tensor), current features/tsdf, (N, C)
    :param coords_target_global: (Tensor), ground truth coordinates, (N', 3)
    :param tsdf_target: (Tensor), tsdf ground truth, (N',)
    :param relative_origin: (Tensor), origin in global volume, (3,)
    :param scale:
    :return: updated_coords: (Tensor), coordinates after combination, (N', 3)
    :return: current_volume: (Tensor), current dense feature/tsdf volume, (DIM_X, DIM_Y, DIM_Z, C)
    :return: global_volume: (Tensor), global dense feature/tsdf volume, (DIM_X, DIM_Y, DIM_Z, C)
    :return: target_volume: (Tensor), dense target tsdf volume, (DIM_X, DIM_Y, DIM_Z, 1)
    :return: valid: mask: 1 represent in current FBV (N,)
    :return: valid_target: gt mask: 1 represent in current FBV (N,)
    '''
    # previous frame
    global_coords = self.global_volume[scale].C
    global_value = self.global_volume[scale].F
    global_tsdf_target = self.target_tsdf_volume[scale].F
    global_coords_target = self.target_tsdf_volume[scale].C
    dim = (torch.Tensor(self.cfg.N_VOX).cuda() // 2 ** (self.cfg.N_LAYER - scale - 1)).int()
    dim_list = dim.data.cpu().numpy().tolist()
    # mask voxels that are out of the FBV
    global_coords = global_coords - relative_origin
    valid = ((global_coords < dim) & (global_coords >= 0)).all(dim=-1)
    if self.cfg.FUSION.FULL is False:
    valid_volume = sparse_to_dense_torch(current_coords, 1, dim_list, 0, global_value.device)
    value = valid_volume[global_coords[valid][:, 0], global_coords[valid][:, 1], global_coords[valid][:, 2]]
    all_true = valid[valid]
    all_true[value == 0] = False
    valid[valid] = all_true
    # sparse to dense
    global_volume = sparse_to_dense_channel(global_coords[valid], global_value[valid], dim_list, self.ch_in[scale],
    self.feat_init, global_value.device)
    current_volume = sparse_to_dense_channel(current_coords, current_values, dim_list, self.ch_in[scale],
    self.feat_init, global_value.device)
    if self.cfg.FUSION.FULL is True:
    # change the structure of sparsity, combine current coordinates and previous coordinates from global volume
    if self.direct_substitude:
    updated_coords = torch.nonzero((global_volume.abs() < 1).any(-1) | (current_volume.abs() < 1).any(-1))
    else:
    updated_coords = torch.nonzero((global_volume != 0).any(-1) | (current_volume != 0).any(-1))
    else:
    updated_coords = current_coords
    # fuse ground truth
    if tsdf_target is not None:
    # mask voxels that are out of the FBV
    global_coords_target = global_coords_target - relative_origin
    valid_target = ((global_coords_target < dim) & (global_coords_target >= 0)).all(dim=-1)
    # combine current tsdf and global tsdf
    coords_target = torch.cat([global_coords_target[valid_target], coords_target_global])[:, :3]
    tsdf_target = torch.cat([global_tsdf_target[valid_target], tsdf_target.unsqueeze(-1)])
    # sparse to dense
    target_volume = sparse_to_dense_channel(coords_target, tsdf_target, dim_list, 1, 1,
    tsdf_target.device)
    else:
    target_volume = valid_target = None
    return updated_coords, current_volume, global_volume, target_volume, valid, valid_target
    def update_map(self, value, coords, target_volume, valid, valid_target,
    relative_origin, scale):
    '''
    Replace Hidden state/tsdf in global Hidden state/tsdf volume by direct substitute corresponding voxels
    :param value: (Tensor) fused feature (N, C)
    :param coords: (Tensor) updated coords (N, 3)
    :param target_volume: (Tensor) tsdf volume (DIM_X, DIM_Y, DIM_Z, 1)
    :param valid: (Tensor) mask: 1 represent in current FBV (N,)
    :param valid_target: (Tensor) gt mask: 1 represent in current FBV (N,)
    :param relative_origin: (Tensor), origin in global volume, (3,)
    :param scale:
    :return:
    '''
    # pred
    self.global_volume[scale].F = torch.cat(
    [self.global_volume[scale].F[valid == False], value])
    coords = coords + relative_origin
    self.global_volume[scale].C = torch.cat([self.global_volume[scale].C[valid == False], coords])
    # target
    if target_volume is not None:
    target_volume = target_volume.squeeze()
    self.target_tsdf_volume[scale].F = torch.cat(
    [self.target_tsdf_volume[scale].F[valid_target == False],
    target_volume[target_volume.abs() < 1].unsqueeze(-1)])
    target_coords = torch.nonzero(target_volume.abs() < 1) + relative_origin
    self.target_tsdf_volume[scale].C = torch.cat(
    [self.target_tsdf_volume[scale].C[valid_target == False], target_coords])
    def save_mesh(self, scale, outputs, scene):
    if outputs is None:
    outputs = dict()
    if "scene_name" not in outputs:
    outputs['origin'] = []
    outputs['scene_tsdf'] = []
    outputs['scene_name'] = []
    # only keep the newest result
    if scene in outputs['scene_name']:
    # delete old
    idx = outputs['scene_name'].index(scene)
    del outputs['origin'][idx]
    del outputs['scene_tsdf'][idx]
    del outputs['scene_name'][idx]
    # scene name
    outputs['scene_name'].append(scene)
    fuse_coords = self.global_volume[scale].C
    tsdf = self.global_volume[scale].F.squeeze(-1)
    max_c = torch.max(fuse_coords, dim=0)[0][:3]
    min_c = torch.min(fuse_coords, dim=0)[0][:3]
    outputs['origin'].append(min_c * self.cfg.VOXEL_SIZE * (2 ** (self.cfg.N_LAYER - scale - 1)))
    ind_coords = fuse_coords - min_c
    dim_list = (max_c - min_c + 1).int().data.cpu().numpy().tolist()
    tsdf_volume = sparse_to_dense_torch(ind_coords, tsdf, dim_list, 1, tsdf.device)
    outputs['scene_tsdf'].append(tsdf_volume)
    return outputs
    def forward(self, coords, values_in, inputs, scale=2, outputs=None, save_mesh=False):
    '''
    :param coords: (Tensor), coordinates of voxels, (N, 4) (4 : Batch ind, x, y, z)
    :param values_in: (Tensor), features/tsdf, (N, C)
    :param inputs: dict: meta data from dataloader
    :param scale:
    :param outputs:
    :param save_mesh: a bool to indicate whether or not to save the reconstructed mesh of current sample
    if direct_substitude:
    :return: outputs: dict: {
    'origin': (List), origin of the predicted partial volume,
    [3]
    'scene_tsdf': (List), predicted tsdf volume,
    [(nx, ny, nz)]
    'target': (List), ground truth tsdf volume,
    [(nx', ny', nz')]
    'scene_name': (List), name of each scene in 'scene_tsdf',
    [string]
    }
    else:
    :return: updated_coords_all: (Tensor), updated coordinates, (N', 4) (4 : Batch ind, x, y, z)
    :return: values_all: (Tensor), features after gru fusion, (N', C)
    :return: tsdf_target_all: (Tensor), tsdf ground truth, (N', 1)
    :return: occ_target_all: (Tensor), occupancy ground truth, (N', 1)
    '''
    if self.global_volume[scale] is not None:
    # delete computational graph to save memory
    self.global_volume[scale] = self.global_volume[scale].detach()
    batch_size = len(inputs['fragment'])
    interval = 2 ** (self.cfg.N_LAYER - scale - 1)
    tsdf_target_all = None
    occ_target_all = None
    values_all = None
    updated_coords_all = None
    # ---incremental fusion----
    for i in range(batch_size):
    scene = inputs['scene'][i] # scene name
    global_origin = inputs['vol_origin'][i] # origin of global volume
    origin = inputs['vol_origin_partial'][i] # origin of part volume
    if scene != self.scene_name[scale] and self.scene_name[scale] is not None and self.direct_substitude:
    outputs = self.save_mesh(scale, outputs, self.scene_name[scale])
    # if this fragment is from new scene, we reinitialize backend map
    if self.scene_name[scale] is None or scene != self.scene_name[scale]:
    self.scene_name[scale] = scene
    self.reset(scale)
    self.global_origin[scale] = global_origin
    # each level has its corresponding voxel size
    voxel_size = self.cfg.VOXEL_SIZE * interval
    # relative origin in global volume
    relative_origin = (origin - self.global_origin[scale]) / voxel_size
    relative_origin = relative_origin.cuda().long()
    batch_ind = torch.nonzero(coords[:, 0] == i).squeeze(1)
    if len(batch_ind) == 0:
    continue
    coords_b = coords[batch_ind, 1:].long() // interval
    values = values_in[batch_ind]
    if 'occ_list' in inputs.keys():
    # get partial gt
    occ_target = inputs['occ_list'][self.cfg.N_LAYER - scale - 1][i]
    tsdf_target = inputs['tsdf_list'][self.cfg.N_LAYER - scale - 1][i][occ_target]
    coords_target = torch.nonzero(occ_target)
    else:
    coords_target = tsdf_target = None
    # convert to dense: 1. convert sparse feature to dense feature; 2. combine current feature coordinates and
    # previous feature coordinates within FBV from our backend map to get new feature coordinates (updated_coords)
    updated_coords, current_volume, global_volume, target_volume, valid, valid_target = self.convert2dense(
    coords_b,
    values,
    coords_target,
    tsdf_target,
    relative_origin,
    scale)
    # dense to sparse: get features using new feature coordinates (updated_coords)
    values = current_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
    global_values = global_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
    # get fused gt
    if target_volume is not None:
    tsdf_target = target_volume[updated_coords[:, 0], updated_coords[:, 1], updated_coords[:, 2]]
    occ_target = tsdf_target.abs() < 1
    else:
    tsdf_target = occ_target = None
    if not self.direct_substitude:
    # convert to aligned camera coordinate
    r_coords = updated_coords.detach().clone().float()
    r_coords = r_coords.permute(1, 0).contiguous().float() * voxel_size + origin.unsqueeze(-1).float()
    r_coords = torch.cat((r_coords, torch.ones_like(r_coords[:1])), dim=0)
    r_coords = inputs['world_to_aligned_camera'][i, :3, :] @ r_coords
    r_coords = torch.cat([r_coords, torch.zeros(1, r_coords.shape[-1]).to(r_coords.device)])
    r_coords = r_coords.permute(1, 0).contiguous()
    h = PointTensor(global_values, r_coords)
    x = PointTensor(values, r_coords)
    values = self.fusion_nets[scale](h, x)
    # feed back to global volume (direct substitute)
    self.update_map(values, updated_coords, target_volume, valid, valid_target, relative_origin, scale)
    if updated_coords_all is None:
    updated_coords_all = torch.cat([torch.ones_like(updated_coords[:, :1]) * i, updated_coords * interval],
    dim=1)
    values_all = values
    tsdf_target_all = tsdf_target
    occ_target_all = occ_target
    else:
    updated_coords = torch.cat([torch.ones_like(updated_coords[:, :1]) * i, updated_coords * interval],
    dim=1)
    updated_coords_all = torch.cat([updated_coords_all, updated_coords])
    values_all = torch.cat([values_all, values])
    if tsdf_target_all is not None:
    tsdf_target_all = torch.cat([tsdf_target_all, tsdf_target])
    occ_target_all = torch.cat([occ_target_all, occ_target])
    if self.direct_substitude and save_mesh:
    outputs = self.save_mesh(scale, outputs, self.scene_name[scale])
    if self.direct_substitude:
    return outputs
    else:
    return updated_coords_all, values_all, tsdf_target_all, occ_target_all

    ConvGRU

    $$\begin{aligned} z_t &= M_z([V^g_{t-1}, V_t]) \\ r_t &= M_r([V^g_{t-1}, V_t]) \\ \tilde{V}^g_t &= M_t([r_t \odot V^g_{t-1}, V_t]) \\ V^g_t &= (1-z_t)\odot V^g_{t-1} + z_t \odot \tilde{V}^g_t \end{aligned}$$

    • Mz, Mr (sparse 3D conv layers for the update gate (zt) and the reset gate (rt)) use a sigmoid activation
    • Mt uses a tanh activation
    
      
    class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192 + 128, pres=1, vres=1):
    super(ConvGRU, self).__init__()
    self.convz = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3)
    self.convr = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3)
    self.convq = SConv3d(hidden_dim + input_dim, hidden_dim, pres, vres, 3)
    def forward(self, h, x):
    '''
    :param h: PintTensor
    :param x: PintTensor
    :return: h.F: Tensor (N, C)
    '''
    hx = PointTensor(torch.cat([h.F, x.F], dim=1), h.C)
    z = torch.sigmoid(self.convz(hx).F)
    r = torch.sigmoid(self.convr(hx).F)
    x.F = torch.cat([r * h.F, x.F], dim=1)
    q = torch.tanh(self.convq(x).F)
    h.F = (1 - z) * h.F + z * q
    return h.F
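
    As a sanity check of the gate equations above, here is a dense toy version of the same update, with plain nn.Conv3d layers standing in for the sparse SConv3d blocks; it only illustrates the (1 - z) * h + z * q blending and is not the official module.

    import torch
    import torch.nn as nn

    class DenseConvGRU(nn.Module):
        # Dense stand-in for the sparse ConvGRU above (illustration only).
        def __init__(self, hidden_dim=8, input_dim=8):
            super().__init__()
            self.convz = nn.Conv3d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
            self.convr = nn.Conv3d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
            self.convq = nn.Conv3d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

        def forward(self, h, x):
            hx = torch.cat([h, x], dim=1)
            z = torch.sigmoid(self.convz(hx))                          # update gate
            r = torch.sigmoid(self.convr(hx))                          # reset gate
            q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))   # candidate state
            return (1 - z) * h + z * q                                 # fused hidden state

    gru = DenseConvGRU()
    h = torch.zeros(1, 8, 16, 16, 16)   # global hidden state V^g_{t-1}
    x = torch.randn(1, 8, 16, 16, 16)   # local volume V_t
    print(gru(h, x).shape)              # torch.Size([1, 8, 16, 16, 16])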

    Sparse 3D Conv Layer

    
      
    import torchsparse.nn as spnn
    class SConv3d(nn.Module):
    def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1):
    super().__init__()
    self.net = spnn.Conv3d(inc,
    outc,
    kernel_size=ks,
    dilation=dilation,
    stride=stride)
    self.point_transforms = nn.Sequential(
    nn.Linear(inc, outc),
    )
    self.pres = pres
    self.vres = vres
    def forward(self, z):
    x = initial_voxelize(z, self.pres, self.vres)
    x = self.net(x)
    out = voxel_to_point(x, z, nearest=False)
    out.F = out.F + self.point_transforms(z.F)
    return out
    
      
    def sparse_to_dense_torch(locs, values, dim, default_val, device):
    dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device)
    if locs.shape[0] > 0:
    dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense
    def sparse_to_dense_channel(locs, values, dim, c, default_val, device):
    dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device)
    if locs.shape[0] > 0:
    dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense
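
    A quick usage example of the two helpers above (made-up coordinates, CPU device): sparse (coordinate, value) pairs are scattered into a dense grid initialized with a default value.

    import torch

    locs = torch.tensor([[0, 0, 0], [1, 2, 3]])
    vals = torch.tensor([0.5, -0.25])
    dense = sparse_to_dense_torch(locs, vals, [4, 4, 4], 1.0, 'cpu')
    print(dense[0, 0, 0], dense[1, 2, 3], dense[3, 3, 3])   # 0.5, -0.25, 1.0 (the default)

    feat = torch.randn(2, 8)   # per-voxel feature vectors
    dense_feat = sparse_to_dense_channel(locs, feat, [4, 4, 4], 8, 0.0, 'cpu')
    print(dense_feat.shape)    # torch.Size([4, 4, 4, 8])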

    Update Density Grid

    The densities of neighboring voxels are updated according to the computed voxel results. This is executed every update_interval steps, as configured in training_step.

    
      
    def training_step(self, batch, batch_nb, *args):
    if self.global_step%self.update_interval == 0:
    self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
    warmup=self.global_step<self.warmup_steps,
    erode=self.hparams.dataset_name=='colmap')
    results = self(batch, split='train')
    loss_d = self.loss(results, batch)
    loss = sum(lo.mean() for lo in loss_d.values())
    with torch.no_grad():
    self.train_psnr(results['rgb'], batch['rgb'])
    self.log('lr', self.net_opt.param_groups[0]['lr'])
    self.log('train/loss', loss)
    # ray marching samples per ray (occupied space on the ray)
    self.log('train/rm_s', results['rm_samples']/len(batch['rgb']), True)
    # volume rendering samples per ray (stops marching when transmittance drops below 1e-4)
    self.log('train/vr_s', results['vr_samples']/len(batch['rgb']), True)
    self.log('train/psnr', self.train_psnr, True)
    return loss
    
      
    @torch.no_grad()
    def update_density_grid(self, density_threshold, warmup=False, decay=0.95, erode=False):
    density_grid_tmp = torch.zeros_like(self.density_grid)
    if warmup: # during the first steps
    cells = self.get_all_cells()
    else:
    cells = self.sample_uniform_and_occupied_cells(self.grid_size ** 3 // 4,
    density_threshold)
    # infer sigmas
    for c in range(self.cascades):
    indices, coords = cells[c]
    s = min(2 ** (c - 1), self.scale)
    half_grid_size = s / self.grid_size
    xyzs_w = (coords / (self.grid_size - 1) * 2 - 1) * (s - half_grid_size)
    # pick random position in the cell by adding noise in [-hgs, hgs]
    xyzs_w += (torch.rand_like(xyzs_w) * 2 - 1) * half_grid_size
    density_grid_tmp[c, indices] = self.density(xyzs_w)
    if erode:
    # My own logic. decay more the cells that are visible to few cameras
    decay = torch.clamp(decay ** (1 / self.count_grid), 0.1, 0.95)
    self.density_grid = \
    torch.where(self.density_grid < 0,
    self.density_grid,
    torch.maximum(self.density_grid * decay, density_grid_tmp))
    mean_density = self.density_grid[self.density_grid > 0].mean().item()
    vren.packbits(self.density_grid, min(mean_density, density_threshold),
    self.density_bitfield)
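
    The final torch.where above implements a simple rule: cells marked invalid (density < 0, e.g. pruned by prune_cells) stay untouched, while every other cell keeps the maximum of its decayed old density and the freshly sampled one. A toy illustration with made-up numbers:

    import torch

    density_grid = torch.tensor([-1.0, 0.0, 10.0, 2.0])    # -1: pruned cell
    density_grid_tmp = torch.tensor([5.0, 3.0, 1.0, 2.5])  # freshly sampled sigmas
    decay = 0.95

    updated = torch.where(density_grid < 0,
                          density_grid,   # keep pruned cells at -1
                          torch.maximum(density_grid * decay, density_grid_tmp))
    print(updated)   # tensor([-1.0000, 3.0000, 9.5000, 2.5000])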

    Backbone (MNAS)

    
      
    class MnasMulti(nn.Module):
    def __init__(self, alpha=1.0):
    super(MnasMulti, self).__init__()
    depths = _get_depths(alpha)
    if alpha == 1.0:
    MNASNet = torchvision.models.mnasnet1_0(pretrained=True, progress=True)
    else:
    MNASNet = torchvision.models.MNASNet(alpha=alpha)
    self.conv0 = nn.Sequential(
    MNASNet.layers._modules['0'],
    MNASNet.layers._modules['1'],
    MNASNet.layers._modules['2'],
    MNASNet.layers._modules['3'],
    MNASNet.layers._modules['4'],
    MNASNet.layers._modules['5'],
    MNASNet.layers._modules['6'],
    MNASNet.layers._modules['7'],
    MNASNet.layers._modules['8'],
    )
    self.conv1 = MNASNet.layers._modules['9']
    self.conv2 = MNASNet.layers._modules['10']
    self.out1 = nn.Conv2d(depths[4], depths[4], 1, bias=False)
    self.out_channels = [depths[4]]
    final_chs = depths[4]
    self.inner1 = nn.Conv2d(depths[3], final_chs, 1, bias=True)
    self.inner2 = nn.Conv2d(depths[2], final_chs, 1, bias=True)
    self.out2 = nn.Conv2d(final_chs, depths[3], 3, padding=1, bias=False)
    self.out3 = nn.Conv2d(final_chs, depths[2], 3, padding=1, bias=False)
    self.out_channels.append(depths[3])
    self.out_channels.append(depths[2])
    def forward(self, x):
    conv0 = self.conv0(x)
    conv1 = self.conv1(conv0)
    conv2 = self.conv2(conv1)
    intra_feat = conv2
    outputs = []
    out = self.out1(intra_feat)
    outputs.append(out)
    intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1)
    out = self.out2(intra_feat)
    outputs.append(out)
    intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0)
    out = self.out3(intra_feat)
    outputs.append(out)
    return outputs[::-1]
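
    A usage sketch of the backbone, assuming alpha=1.0 and the standard torchvision MNASNet layer layout (pretrained weights are downloaded on first use, and the repo's _get_depths helper is needed); the outputs are three FPN-style feature maps returned finest-first, at roughly 1/4, 1/8 and 1/16 of the input resolution.

    import torch

    backbone = MnasMulti(alpha=1.0)       # defined above; requires torchvision
    imgs = torch.randn(1, 3, 480, 640)    # dummy image batch
    for f in backbone(imgs):
        print(f.shape)
    # Roughly expected for alpha=1.0:
    # torch.Size([1, 24, 120, 160])   1/4 resolution
    # torch.Size([1, 40, 60, 80])     1/8 resolution
    # torch.Size([1, 80, 30, 40])     1/16 resolution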

     

     

    Voxel Pruning

    A method of the NeRFusion2 class.

    
      
    @torch.no_grad()
    def prune_cells(self, K, poses, img_wh, chunk=64 ** 3):
    """
    mark the cells that aren't covered by the cameras with density -1
    only executed once before training starts
    Inputs:
    K: (3, 3) camera intrinsics
    poses: (N, 3, 4) camera to world poses
    img_wh: image width and height
    chunk: the chunk size to split the cells (to avoid OOM)
    """
    N_cams = poses.shape[0]
    self.count_grid = torch.zeros_like(self.density_grid)
    w2c_R = rearrange(poses[:, :3, :3], 'n a b -> n b a') # (N_cams, 3, 3)
    w2c_T = -w2c_R @ poses[:, :3, 3:] # (N_cams, 3, 1)
    cells = self.get_all_cells()
    for c in range(self.cascades):
    indices, coords = cells[c]
    for i in range(0, len(indices), chunk):
    xyzs = coords[i:i + chunk] / (self.grid_size - 1) * 2 - 1
    s = min(2 ** (c - 1), self.scale)
    half_grid_size = s / self.grid_size
    xyzs_w = (xyzs * (s - half_grid_size)).T # (3, chunk)
    xyzs_c = w2c_R @ xyzs_w + w2c_T # (N_cams, 3, chunk)
    uvd = K @ xyzs_c # (N_cams, 3, chunk)
    uv = uvd[:, :2] / uvd[:, 2:] # (N_cams, 2, chunk)
    in_image = (uvd[:, 2] >= 0) & \
    (uv[:, 0] >= 0) & (uv[:, 0] < img_wh[0]) & \
    (uv[:, 1] >= 0) & (uv[:, 1] < img_wh[1])
    covered_by_cam = (uvd[:, 2] >= NEAR_DISTANCE) & in_image # (N_cams, chunk)
    # if the cell is visible by at least one camera
    self.count_grid[c, indices[i:i + chunk]] = \
    count = covered_by_cam.sum(0) / N_cams
    too_near_to_cam = (uvd[:, 2] < NEAR_DISTANCE) & in_image # (N, chunk)
    # if the cell is too close (in front) to any camera
    too_near_to_any_cam = too_near_to_cam.any(0)
    # a valid cell should be visible by at least one camera and not too close to any camera
    valid_mask = (count > 0) & (~too_near_to_any_cam)
    self.density_grid[c, indices[i:i + chunk]] = \
    torch.where(valid_mask, 0., -1.)

     

    Sampling

    
      
    # https://github.com/zju3dv/NeuralRecon/blob/master/ops/generate_grids.py
    def generate_grid(n_vox, interval):
    with torch.no_grad():
    # Create voxel grid
    grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)]
    grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2])) # 3 dx dy dz
    grid = grid.unsqueeze(0).cuda().float() # 1 3 dx dy dz
    grid = grid.view(1, 3, -1)
    return grid
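
    A small usage sketch (the function above moves the grid to CUDA, so a GPU is assumed; the n_vox and interval values here are illustrative):

    import torch

    grid = generate_grid(n_vox=[96, 96, 96], interval=4)
    print(grid.shape)      # torch.Size([1, 3, 13824])  -> 24 * 24 * 24 coordinates
    print(grid[0, :, 0])   # tensor([0., 0., 0.], device='cuda:0'), the first voxel coordinate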

     

    Direct Inference

    Render (2)

    
      
    @torch.cuda.amp.autocast()
    def render(model, rays_o, rays_d, **kwargs):
    """
    Render rays by
    1. Compute the intersection of the rays with the scene bounding box
    2. Follow the process in @render_func (different for train/test)
    Inputs:
    model:
    rays_o: (N_rays, 3) ray origins
    rays_d: (N_rays, 3) ray directions
    Outputs:
    result: dictionary containing final rgb and depth
    """
    rays_o = rays_o.contiguous(); rays_d = rays_d.contiguous()
    hits_cnt, hits_t, hits_voxel_idx = \
    RayAABBIntersector.apply(rays_o, rays_d, model.center, model.half_size, 1)
    # clamp rays that start closer than NEAR_DISTANCE to NEAR_DISTANCE
    hits_t[(hits_t[:, 0, 0]>=0)&(hits_t[:, 0, 0]<NEAR_DISTANCE), 0, 0] = NEAR_DISTANCE
    # hits_t == -1 if there's no hit
    # (N_rays, max_hits, 2)
    render_func = __render_rays_test if kwargs.get('test_time', False) else __render_rays_train
    results = render_func(model, rays_o, rays_d, hits_t, **kwargs)
    for k, v in results.items():
    if kwargs.get('to_cpu', False):
    v = v.cpu()
    if kwargs.get('to_numpy', False):
    v = v.numpy()
    results[k] = v
    return results

    Compute Intersection (2-1)

    Compute the intersection of the rays with the scene bounding box (axis-aligned voxels).

    
      
    class RayAABBIntersector(torch.autograd.Function):
    """
    Computes the intersections of rays and axis-aligned voxels.
    Inputs:
    rays_o: (N_rays, 3) ray origins
    rays_d: (N_rays, 3) ray directions
    centers: (N_voxels, 3) voxel centers
    half_sizes: (N_voxels, 3) voxel half sizes
    max_hits: maximum number of intersected voxels to keep for one ray
    (for a cubic scene, this is at most 3*N_voxels^(1/3)-2)
    Outputs:
    hits_cnt: (N_rays) number of hits for each ray
    (followings are from near to far)
    hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
    hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit)
    """
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
    return vren.ray_aabb_intersect(rays_o, rays_d, center, half_size, max_hits)

    Render [Train] (2-2)

    1. [RayMarcher.apply()] March the rays along their directions, querying the density bitfield to skip empty space, and get the effective sample points (where there is an object).
    2. [model(xyzs, dirs, **kwargs)] Infer the network at these sample positions and view directions to get their properties (currently sigmas and rgbs).
    3. [VolumeRenderer.apply()] Use volume rendering to combine the results (front-to-back compositing; stop marching a ray early once its transmittance drops below a threshold).
    
      
    def __render_rays_train(model, rays_o, rays_d, hits_t, **kwargs):
    exp_step_factor = kwargs.get('exp_step_factor', 0.)
    results = {}
    (rays_a, xyzs, dirs,
    results['deltas'], results['ts'], results['rm_samples']) = \
    RayMarcher.apply(
    rays_o, rays_d, hits_t[:, 0], model.density_bitfield,
    model.cascades, model.scale,
    exp_step_factor, model.grid_size, MAX_SAMPLES)
    for k, v in kwargs.items(): # supply additional inputs, repeated per ray
    if isinstance(v, torch.Tensor):
    kwargs[k] = torch.repeat_interleave(v[rays_a[:, 0]], rays_a[:, 2], 0)
    # rays_a: (N_rays, 3) ray_idx, start_idx, N_samples
    sigmas, rgbs = model(xyzs, dirs, **kwargs)
    (results['vr_samples'], results['opacity'],
    results['depth'], results['rgb'], results['ws']) = \
    VolumeRenderer.apply(sigmas, rgbs, results['deltas'], results['ts'],
    rays_a, kwargs.get('T_threshold', 1e-4))
    results['rays_a'] = rays_a
    if exp_step_factor==0: # synthetic
    rgb_bg = torch.ones(3, device=rays_o.device)
    else: # real
    if kwargs.get('random_bg', False):
    rgb_bg = torch.rand(3, device=rays_o.device)
    else:
    rgb_bg = torch.zeros(3, device=rays_o.device)
    results['rgb'] = results['rgb'] + \
    rgb_bg*rearrange(1-results['opacity'], 'n -> n 1')
    return results

     

    Render [Test] (2-2)

    
      
    @torch.no_grad()
    def __render_rays_test(model, rays_o, rays_d, hits_t, **kwargs):
    """
    Render rays by
    while (a ray hasn't converged)
    1. Move each ray to its next occupied @N_samples (initially 1) samples
    and evaluate the properties (sigmas, rgbs) there
    2. Composite the result to output; if a ray has transmittance lower
    than a threshold, mark this ray as converged and stop marching it.
    When more rays are dead, we can increase the number of samples
    of each marching (the variable @N_samples)
    """
    exp_step_factor = kwargs.get('exp_step_factor', 0.)
    results = {}
    # output tensors to be filled in
    N_rays = len(rays_o)
    device = rays_o.device
    opacity = torch.zeros(N_rays, device=device)
    depth = torch.zeros(N_rays, device=device)
    rgb = torch.zeros(N_rays, 3, device=device)
    samples = total_samples = 0
    alive_indices = torch.arange(N_rays, device=device)
    # if it's synthetic data, bg is majority so min_samples=1 effectively covers the bg
    # otherwise, 4 is more efficient empirically
    min_samples = 1 if exp_step_factor==0 else 4
    while samples < kwargs.get('max_samples', MAX_SAMPLES):
    N_alive = len(alive_indices)
    if N_alive==0: break
    # the number of samples to add on each ray
    N_samples = max(min(N_rays//N_alive, 64), min_samples)
    samples += N_samples
    xyzs, dirs, deltas, ts, N_eff_samples = \
    vren.raymarching_test(rays_o, rays_d, hits_t[:, 0], alive_indices,
    model.density_bitfield, model.cascades,
    model.scale, exp_step_factor,
    model.grid_size, MAX_SAMPLES, N_samples)
    total_samples += N_eff_samples.sum()
    xyzs = rearrange(xyzs, 'n1 n2 c -> (n1 n2) c')
    dirs = rearrange(dirs, 'n1 n2 c -> (n1 n2) c')
    valid_mask = ~torch.all(dirs==0, dim=1)
    if valid_mask.sum()==0: break
    sigmas = torch.zeros(len(xyzs), device=device)
    rgbs = torch.zeros(len(xyzs), 3, device=device)
    sigmas[valid_mask], _rgbs = model(xyzs[valid_mask], dirs[valid_mask], **kwargs)
    rgbs[valid_mask] = _rgbs.float()
    sigmas = rearrange(sigmas, '(n1 n2) -> n1 n2', n2=N_samples)
    rgbs = rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=N_samples)
    vren.composite_test_fw(
    sigmas, rgbs, deltas, ts,
    hits_t[:, 0], alive_indices, kwargs.get('T_threshold', 1e-4),
    N_eff_samples, opacity, depth, rgb)
    alive_indices = alive_indices[alive_indices>=0] # remove converged rays
    results['opacity'] = opacity
    results['depth'] = depth
    results['rgb'] = rgb
    results['total_samples'] = total_samples # total samples for all rays
    if exp_step_factor==0: # synthetic
    rgb_bg = torch.ones(3, device=device)
    else: # real
    rgb_bg = torch.zeros(3, device=device)
    results['rgb'] += rgb_bg*rearrange(1-opacity, 'n -> n 1')
    return results

     

    Ray Marcher (2-2-1)

    Ray march (vren.raymarching_train) to get the positions and directions of the sampled points.

    Outputs

    • rays_a: (N_rays, 3) - ray_idx, start_idx, N_samples
    • xyzs: (N, 3) position of samples
    • dirs: (N, 3) view direction of samples
    • deltas: (N) dt for integration (the spacing between consecutive sample points t and t+1 inside the occupied voxels)
    • ts: (N) t of samples (the distance t of each sample point inside the occupied voxels)
    
      
    class RayMarcher(torch.autograd.Function):
    """
    March the rays to get sample point positions and directions.
    Inputs:
    rays_o: (N_rays, 3) ray origins
    rays_d: (N_rays, 3) normalized ray directions
    hits_t: (N_rays, 2) near and far bounds from aabb intersection
    density_bitfield: (C*G**3//8)
    cascades: int
    scale: float
    exp_step_factor: the exponential factor to scale the steps
    grid_size: int
    max_samples: int
    Outputs:
    rays_a: (N_rays) ray_idx, start_idx, N_samples
    xyzs: (N, 3) sample positions # (computed as o + t*d from t and d)
    dirs: (N, 3) sample view directions
    deltas: (N) dt for integration # (the spacing between consecutive hit points t and t+1 inside a voxel)
    ts: (N) sample ts # (the distances t of the hit points inside a voxel)
    """
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, hits_t,
    density_bitfield, cascades, scale, exp_step_factor,
    grid_size, max_samples):
    # noise to perturb the first sample of each ray
    noise = torch.rand_like(rays_o[:, 0])
    rays_a, xyzs, dirs, deltas, ts, counter = \
    vren.raymarching_train(
    rays_o, rays_d, hits_t,
    density_bitfield, cascades, scale,
    exp_step_factor, noise, grid_size, max_samples)
    total_samples = counter[0] # total samples for all rays
    # remove redundant output
    xyzs = xyzs[:total_samples]
    dirs = dirs[:total_samples]
    deltas = deltas[:total_samples]
    ts = ts[:total_samples]
    ctx.save_for_backward(rays_a, ts)
    return rays_a, xyzs, dirs, deltas, ts, total_samples
    @staticmethod
    @custom_bwd
    def backward(ctx, dL_drays_a, dL_dxyzs, dL_ddirs,
    dL_ddeltas, dL_dts, dL_dtotal_samples):
    rays_a, ts = ctx.saved_tensors
    segments = torch.cat([rays_a[:, 1], rays_a[-1:, 1]+rays_a[-1:, 2]])
    dL_drays_o = segment_csr(dL_dxyzs, segments)
    dL_drays_d = \
    segment_csr(dL_dxyzs*rearrange(ts, 'n -> n 1')+dL_ddirs, segments)
    return dL_drays_o, dL_drays_d, None, None, None, None, None, None, None

    Links for more information

    
      
    // ray marching utils
    // below code is based on https://github.com/ashawkey/torch-ngp/blob/main/raymarching/src/raymarching.cu
    __global__ void raymarching_train_kernel(
    const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_o,
    const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> rays_d,
    const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> hits_t,
    const uint8_t* __restrict__ density_bitfield,
    const int cascades,
    const int grid_size,
    const float scale,
    const float exp_step_factor,
    const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> noise,
    const int max_samples,
    int* __restrict__ counter,
    torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> xyzs,
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> dirs,
    torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> deltas,
    torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> ts
    ){
    const int r = blockIdx.x * blockDim.x + threadIdx.x;
    if (r >= rays_o.size(0)) return;
    const uint32_t grid_size3 = grid_size*grid_size*grid_size;
    const float grid_size_inv = 1.0f/grid_size;
    const float ox = rays_o[r][0], oy = rays_o[r][1], oz = rays_o[r][2];
    const float dx = rays_d[r][0], dy = rays_d[r][1], dz = rays_d[r][2];
    const float dx_inv = 1.0f/dx, dy_inv = 1.0f/dy, dz_inv = 1.0f/dz;
    float t1 = hits_t[r][0], t2 = hits_t[r][1];
    if (t1>=0) { // only perturb the starting t
    const float dt = calc_dt(t1, exp_step_factor, max_samples, grid_size, scale);
    t1 += dt*noise[r];
    }
    // first pass: compute the number of samples on the ray
    float t = t1; int N_samples = 0;
    // if t1 < 0 (no hit) this loop will be skipped (N_samples will be 0)
    while (0<=t && t<t2 && N_samples<max_samples){
    // calculating the ray-box intersection points
    const float x = ox + (t * dx), y = oy + (t * dy), z = oz + (t * dz);
    const float dt = calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
    const int mip = max(mip_from_pos(x, y, z, cascades),
    mip_from_dt(dt, grid_size, cascades));
    const float mip_bound = fminf(scalbnf(1.0f, mip-1), scale);
    const float mip_bound_inv = 1/mip_bound;
    // round down to nearest grid position
    const int nx = clamp(0.5f*(x*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const int ny = clamp(0.5f*(y*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const int nz = clamp(0.5f*(z*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const uint32_t idx = mip*grid_size3 + __morton3D(nx, ny, nz);
    const bool occ = density_bitfield[idx/8] & (1<<(idx%8));
    if (occ) {
    t += dt; N_samples++;
    } else { // skip until the next voxel
    const float tx = (((nx+0.5f+0.5f*signf(dx))*grid_size_inv*2-1)*mip_bound-x)*dx_inv;
    const float ty = (((ny+0.5f+0.5f*signf(dy))*grid_size_inv*2-1)*mip_bound-y)*dy_inv;
    const float tz = (((nz+0.5f+0.5f*signf(dz))*grid_size_inv*2-1)*mip_bound-z)*dz_inv;
    const float t_target = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
    do {
    t += calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
    } while (t < t_target);
    }
    }
    // second pass: write to output
    const int start_idx = atomicAdd(counter, N_samples);
    const int ray_count = atomicAdd(counter+1, 1);
    rays_a[ray_count][0] = r;
    rays_a[ray_count][1] = start_idx; rays_a[ray_count][2] = N_samples;
    t = t1; int samples = 0;
    while (t<t2 && samples<N_samples){
    const float x = ox+t*dx, y = oy+t*dy, z = oz+t*dz;
    const float dt = calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
    const int mip = max(mip_from_pos(x, y, z, cascades),
    mip_from_dt(dt, grid_size, cascades));
    const float mip_bound = fminf(scalbnf(1.0f, mip-1), scale);
    const float mip_bound_inv = 1/mip_bound;
    // round down to nearest grid position
    const int nx = clamp(0.5f*(x*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const int ny = clamp(0.5f*(y*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const int nz = clamp(0.5f*(z*mip_bound_inv+1)*grid_size, 0.0f, grid_size-1.0f);
    const uint32_t idx = mip*grid_size3 + __morton3D(nx, ny, nz);
    const bool occ = density_bitfield[idx/8] & (1<<(idx%8));
    if (occ) {
    const int s = start_idx + samples;
    xyzs[s][0] = x; xyzs[s][1] = y; xyzs[s][2] = z;
    dirs[s][0] = dx; dirs[s][1] = dy; dirs[s][2] = dz;
    ts[s] = t; deltas[s] = dt;
    t += dt; samples++;
    } else { // skip until the next voxel
    const float tx = (((nx+0.5f+0.5f*signf(dx))*grid_size_inv*2-1)*mip_bound-x)*dx_inv;
    const float ty = (((ny+0.5f+0.5f*signf(dy))*grid_size_inv*2-1)*mip_bound-y)*dy_inv;
    const float tz = (((nz+0.5f+0.5f*signf(dz))*grid_size_inv*2-1)*mip_bound-z)*dz_inv;
    const float t_target = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
    do {
    t += calc_dt(t, exp_step_factor, max_samples, grid_size, scale);
    } while (t < t_target);
    }
    }
    }

    Calculate dt: clamp t * exp_step_factor to the range [SQRT3/max_samples, SQRT3 * 2 * scale / grid_size].

    
      
    inline __host__ __device__ float calc_dt(float t, float exp_step_factor, int max_samples, int grid_size, float scale){
    return clamp(t*exp_step_factor, SQRT3/max_samples, SQRT3*2*scale/grid_size);
    }

    mip from pos:

    • Example input range of |xyz| and return value of this function
      • [0, 0.5) -> 0
      • [0.5, 1) -> 1
      • [1, 2) -> 2
    
      
    inline __device__ int mip_from_pos(const float x, const float y, const float z, const int cascades) {
    const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
    int exponent; frexpf(mx, &exponent);
    return min(cascades-1, max(0, exponent+1));
    }

    mip from dt:

    • Example input range of dt and return value of this function
      • [0, 1/grid_size) -> 0
      • [1/grid_size, 2/grid_size) -> 1
      • [2/grid_size, 4/grid_size) -> 2
    
      
    inline __device__ int mip_from_dt(float dt, int grid_size, int cascades) {
    int exponent; frexpf(dt*grid_size, &exponent);
    return min(cascades-1, max(0, exponent));
    }
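
    To check the ranges listed above for both helpers, here is a small Python port using math.frexp (which, like frexpf, returns the binary exponent of its argument); the port is for illustration and is not part of the repository.

    import math

    def mip_from_pos(x, y, z, cascades):
        mx = max(abs(x), abs(y), abs(z))
        _, exponent = math.frexp(mx)
        return min(cascades - 1, max(0, exponent + 1))

    def mip_from_dt(dt, grid_size, cascades):
        _, exponent = math.frexp(dt * grid_size)
        return min(cascades - 1, max(0, exponent))

    # |xyz| in [0, 0.5) -> 0, [0.5, 1) -> 1, [1, 2) -> 2
    print([mip_from_pos(v, 0, 0, 4) for v in (0.3, 0.7, 1.5)])          # [0, 1, 2]
    # dt in [0, 1/G) -> 0, [1/G, 2/G) -> 1, [2/G, 4/G) -> 2
    G = 128
    print([mip_from_dt(d, G, 4) for d in (0.5 / G, 1.5 / G, 3.0 / G)])  # [0, 1, 2]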

    morton3D: to convert a certain set of integer coordinates to a Morton code, you have to convert the decimal values to binary and interleave the bits of each coordinate.

    • (x,y,z) = (5,9,1) = (0101,1001,0001)
    • Interleaving the bits results in: 010001000111 = 1095 th cell along the Z-curve.
    
      
    inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
    {
    uint32_t xx = __expand_bits(x);
    uint32_t yy = __expand_bits(y);
    uint32_t zz = __expand_bits(z);
    return xx | (yy << 1) | (zz << 2);
    }
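
    A pure-Python sketch of the same bit interleaving, reproducing the (5, 9, 1) example above; __expand_bits is replaced here by a simple loop rather than the bit-twiddling version used in the CUDA code.

    def expand_bits(v: int) -> int:
        # spread the lower 10 bits of v so two zero bits sit between consecutive bits
        out = 0
        for i in range(10):
            out |= ((v >> i) & 1) << (3 * i)
        return out

    def morton3d(x: int, y: int, z: int) -> int:
        return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)

    print(morton3d(5, 9, 1))        # 1095
    print(bin(morton3d(5, 9, 1)))   # 0b10001000111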

    Volume Renderer (2-2-2)

    Volume rendering is performed with a different number of samples per ray. This path is used only during training.

    1. Composite the samples of each ray_idx provided by rays_a, front to back (using composite_train_fw).
    2. Stop a ray once its transmittance (T) falls below the threshold (T_threshold).
    
      
class VolumeRenderer(torch.autograd.Function):
    """
    Volume rendering with different number of samples per ray
    Used in training only

    Inputs:
        sigmas: (N)
        rgbs: (N, 3)
        deltas: (N)
        ts: (N)
        rays_a: (N_rays, 3) ray_idx, start_idx, N_samples
                meaning each entry corresponds to the @ray_idx th ray,
                whose samples are [start_idx:start_idx+N_samples]
        T_threshold: float, stop the ray if the transmittance is below it

    Outputs:
        total_samples: int, total effective samples
        opacity: (N_rays)
        depth: (N_rays)
        rgb: (N_rays, 3)
        ws: (N) sample point weights
    """
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, deltas, ts, rays_a, T_threshold):
        # composite the samples of each ray_idx (front-to-back compositing)
        total_samples, opacity, depth, rgb, ws = \
            vren.composite_train_fw(sigmas, rgbs, deltas, ts,
                                    rays_a, T_threshold)
        ctx.save_for_backward(sigmas, rgbs, deltas, ts, rays_a,
                              opacity, depth, rgb, ws)
        ctx.T_threshold = T_threshold
        return total_samples.sum(), opacity, depth, rgb, ws

    @staticmethod
    @custom_bwd
    def backward(ctx, dL_dtotal_samples, dL_dopacity, dL_ddepth, dL_drgb, dL_dws):
        sigmas, rgbs, deltas, ts, rays_a, \
        opacity, depth, rgb, ws = ctx.saved_tensors
        dL_dsigmas, dL_drgbs = \
            vren.composite_train_bw(dL_dopacity, dL_ddepth, dL_drgb, dL_dws,
                                    sigmas, rgbs, ws, deltas, ts,
                                    rays_a,
                                    opacity, depth, rgb,
                                    ctx.T_threshold)
        return dL_dsigmas, dL_drgbs, None, None, None, None

    composite train fw:

    Outputs

    • total_samples
    • opacity
    • depth
    • rgb
    • ws
    
      
template <typename scalar_t>
__global__ void composite_train_fw_kernel(
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> sigmas,
    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgbs,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> deltas,
    const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ts,
    const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
    const scalar_t T_threshold,
    torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> total_samples,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> opacity,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> depth,
    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgb,
    torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> ws
){
    const int n = blockIdx.x * blockDim.x + threadIdx.x;
    if (n >= opacity.size(0)) return;

    const int ray_idx = rays_a[n][0], start_idx = rays_a[n][1], N_samples = rays_a[n][2];

    // front to back compositing
    int samples = 0; scalar_t T = 1.0f;

    while (samples < N_samples) {
        const int s = start_idx + samples;
        const scalar_t a = 1.0f - __expf(-sigmas[s]*deltas[s]);
        const scalar_t w = a * T; // weight of the sample point

        rgb[ray_idx][0] += w*rgbs[s][0];
        rgb[ray_idx][1] += w*rgbs[s][1];
        rgb[ray_idx][2] += w*rgbs[s][2];
        depth[ray_idx] += w*ts[s];
        opacity[ray_idx] += w;
        ws[s] = w;
        T *= 1.0f-a;

        if (T <= T_threshold) break; // ray has enough opacity
        samples++;
    }

    total_samples[ray_idx] = samples;
    // outputs are written in place; the host-side wrapper returns
    // {total_samples, opacity, depth, rgb, ws}
}
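Per ray, the kernel performs standard front-to-back alpha compositing with early termination. The following plain PyTorch reference for a single ray is an illustrative sketch only; the real implementation launches one CUDA thread per ray through vren.composite_train_fw.

import torch

def composite_single_ray(sigmas, rgbs, deltas, ts, T_threshold=1e-4):
    """Reference front-to-back compositing for one ray's samples."""
    T = 1.0
    rgb = torch.zeros(3)
    depth = opacity = 0.0
    ws = torch.zeros_like(sigmas)
    for s in range(len(sigmas)):
        a = 1.0 - torch.exp(-sigmas[s] * deltas[s])   # alpha of this sample
        w = a * T                                     # compositing weight
        rgb += w * rgbs[s]
        depth += w * ts[s]
        opacity += w
        ws[s] = w
        T *= 1.0 - a
        if T <= T_threshold:                          # ray is effectively opaque; stop early
            break
    return opacity, depth, rgb, ws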

    Unproject

    
      
def back_project(coords, origin, voxel_size, feats, KRcam):
    '''
    Unproject the image features to form a 3D (sparse) feature volume

    Inputs:
        coords: coordinates of voxels # (num of voxels, 4) (4: batch ind, x, y, z)
        origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) # (batch size, 3)
        voxel_size: float specifying the size of a voxel
        feats: image features # (num of views, batch size, C, H, W)
        KRcam: projection matrix # (num of views, batch size, 4, 4)

    Outputs:
        feature_volume_all: 3D feature volumes # (num of voxels, c + 1)
        count: number of times each voxel can be seen # (num of voxels,)
    '''
    n_views, bs, c, h, w = feats.shape

    feature_volume_all = torch.zeros(coords.shape[0], c + 1).cuda()
    count = torch.zeros(coords.shape[0]).cuda()

    for batch in range(bs):
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        coords_batch = coords[batch_ind][:, 1:]

        coords_batch = coords_batch.view(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)
        feats_batch = feats[:, batch]
        proj_batch = KRcam[:, batch]

        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).cuda()], dim=1)

        # Project grid
        im_p = proj_batch @ rs_grid
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
        im_x = im_x / im_z
        im_y = im_y / im_z
        im_grid = torch.stack([2 * im_x / (w - 1) - 1, 2 * im_y / (h - 1) - 1], dim=-1)
        mask = im_grid.abs() <= 1
        mask = (mask.sum(dim=-1) == 2) & (im_z > 0)

        feats_batch = feats_batch.view(n_views, c, h, w)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)

        features = features.view(n_views, c, -1)
        mask = mask.view(n_views, -1)
        im_z = im_z.view(n_views, -1)
        # remove nan
        features[mask.unsqueeze(1).expand(-1, c, -1) == False] = 0
        im_z[mask == False] = 0

        count[batch_ind] = mask.sum(dim=0).float()

        # aggregate multi view
        features = features.sum(dim=0)
        mask = mask.sum(dim=0)
        invalid_mask = mask == 0
        mask[invalid_mask] = 1
        in_scope_mask = mask.unsqueeze(0)
        features /= in_scope_mask
        features = features.permute(1, 0).contiguous()

        # concat normalized depth value
        im_z = im_z.sum(dim=0).unsqueeze(1) / in_scope_mask.permute(1, 0).contiguous()
        im_z_mean = im_z[im_z > 0].mean()
        im_z_std = torch.norm(im_z[im_z > 0] - im_z_mean) + 1e-5
        im_z_norm = (im_z - im_z_mean) / im_z_std
        im_z_norm[im_z <= 0] = 0
        features = torch.cat([features, im_z_norm], dim=1)

        feature_volume_all[batch_ind] = features

    return feature_volume_all, count
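To make the tensor shapes concrete, here is a toy invocation with hypothetical sizes. It requires a CUDA device (back_project allocates with .cuda()), and the identity projection matrix is used purely as a placeholder.

import torch

n_views, bs, c, h, w = 2, 1, 24, 120, 160      # hypothetical sizes
n_voxels = 1000

coords = torch.cat([torch.zeros(n_voxels, 1),                      # batch index
                    torch.randint(1, 96, (n_voxels, 3)).float()],  # x, y, z voxel indices
                   dim=1).cuda()
origin = torch.zeros(bs, 3).cuda()                                 # world xyz of voxel (0, 0, 0)
feats  = torch.randn(n_views, bs, c, h, w).cuda()                  # per-view image features
KRcam  = torch.eye(4).expand(n_views, bs, 4, 4).cuda()             # placeholder projection matrices

volume, count = back_project(coords, origin, 0.04, feats, KRcam)
print(volume.shape)   # (n_voxels, c + 1): averaged features + normalized depth
print(count.shape)    # (n_voxels,): number of views that see each voxel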

     

    Train (1)

    
      
class NeRFSystem(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.warmup_steps = 256
        self.update_interval = 16

        self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w)
        self.train_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
        if self.hparams.eval_lpips:
            self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
            for p in self.val_lpips.net.parameters():
                p.requires_grad = False

        self.model = NeRFusion2(scale=self.hparams.scale)

    def forward(self, batch, split):
        if split=='train':
            poses = self.poses[batch['img_idxs']]
            directions = self.directions[batch['pix_idxs']]
        else:
            poses = batch['pose']
            directions = self.directions

        if self.hparams.optimize_ext:
            dR = axisangle_to_R(self.dR[batch['img_idxs']])
            poses[..., :3] = dR @ poses[..., :3]
            poses[..., 3] += self.dT[batch['img_idxs']]

        rays_o, rays_d = get_rays(directions, poses)

        kwargs = {'test_time': split!='train',
                  'random_bg': self.hparams.random_bg}
        if self.hparams.scale > 0.5:
            kwargs['exp_step_factor'] = 1/256

        return render(self.model, rays_o, rays_d, **kwargs)

    def on_train_start(self):
        self.model.prune_cells(self.train_dataset.K.to(self.device),
                               self.poses,
                               self.train_dataset.img_wh)

    def training_step(self, batch, batch_nb, *args):
        if self.global_step%self.update_interval == 0:
            self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
                                           warmup=self.global_step<self.warmup_steps,
                                           erode=self.hparams.dataset_name=='colmap')

        results = self(batch, split='train')
        loss_d = self.loss(results, batch)
        loss = sum(lo.mean() for lo in loss_d.values())

        with torch.no_grad():
            self.train_psnr(results['rgb'], batch['rgb'])
        self.log('lr', self.net_opt.param_groups[0]['lr'])
        self.log('train/loss', loss)
        # ray marching samples per ray (occupied space on the ray)
        self.log('train/rm_s', results['rm_samples']/len(batch['rgb']), True)
        # volume rendering samples per ray (stops marching when transmittance drops below 1e-4)
        self.log('train/vr_s', results['vr_samples']/len(batch['rgb']), True)
        self.log('train/psnr', self.train_psnr, True)

        return loss

Full training code

    
      
class NeRFSystem(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.warmup_steps = 256
        self.update_interval = 16

        self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w)
        self.train_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
        if self.hparams.eval_lpips:
            self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
            for p in self.val_lpips.net.parameters():
                p.requires_grad = False

        self.model = NeRFusion2(scale=self.hparams.scale) # scale=0.5

    def forward(self, batch, split):
        if split=='train':
            poses = self.poses[batch['img_idxs']]
            directions = self.directions[batch['pix_idxs']]
        else:
            poses = batch['pose']
            directions = self.directions

        if self.hparams.optimize_ext:
            dR = axisangle_to_R(self.dR[batch['img_idxs']])
            poses[..., :3] = dR @ poses[..., :3]
            poses[..., 3] += self.dT[batch['img_idxs']]

        rays_o, rays_d = get_rays(directions, poses)

        kwargs = {'test_time': split!='train',
                  'random_bg': self.hparams.random_bg}
        if self.hparams.scale > 0.5:
            kwargs['exp_step_factor'] = 1/256

        return render(self.model, rays_o, rays_d, **kwargs)

    def setup(self, stage):
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir,
                  'downsample': self.hparams.downsample}
        self.train_dataset = dataset(split=self.hparams.split, **kwargs)
        self.train_dataset.batch_size = self.hparams.batch_size
        self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy

        self.test_dataset = dataset(split='test', **kwargs)

        # define additional parameters
        self.register_buffer('directions', self.train_dataset.directions.to(self.device))
        self.register_buffer('poses', self.train_dataset.poses.to(self.device))

        if self.hparams.optimize_ext:
            N = len(self.train_dataset.poses)
            self.register_parameter('dR',
                nn.Parameter(torch.zeros(N, 3, device=self.device)))
            self.register_parameter('dT',
                nn.Parameter(torch.zeros(N, 3, device=self.device)))

    def configure_optimizers(self):
        load_ckpt(self.model, self.hparams.weight_path)

        net_params = []
        for n, p in self.named_parameters():
            if n not in ['dR', 'dT']: net_params += [p]

        opts = []
        self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15)
        opts += [self.net_opt]
        if self.hparams.optimize_ext:
            opts += [FusedAdam([self.dR, self.dT], 1e-6)] # learning rate is hard-coded
        net_sch = CosineAnnealingLR(self.net_opt,
                                    self.hparams.num_epochs,
                                    self.hparams.lr/30)

        return opts, [net_sch]

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          num_workers=16,
                          persistent_workers=True,
                          batch_size=None,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          num_workers=8,
                          batch_size=None,
                          pin_memory=True)

    def on_train_start(self):
        self.model.prune_cells(self.train_dataset.K.to(self.device),
                               self.poses,
                               self.train_dataset.img_wh)

    def training_step(self, batch, batch_nb, *args):
        if self.global_step%self.update_interval == 0:
            self.model.update_density_grid(0.01*MAX_SAMPLES/3**0.5,
                                           warmup=self.global_step<self.warmup_steps,
                                           erode=self.hparams.dataset_name=='colmap')

        results = self(batch, split='train')
        loss_d = self.loss(results, batch)
        loss = sum(lo.mean() for lo in loss_d.values())

        with torch.no_grad():
            self.train_psnr(results['rgb'], batch['rgb'])
        self.log('lr', self.net_opt.param_groups[0]['lr'])
        self.log('train/loss', loss)
        # ray marching samples per ray (occupied space on the ray)
        self.log('train/rm_s', results['rm_samples']/len(batch['rgb']), True)
        # volume rendering samples per ray (stops marching when transmittance drops below 1e-4)
        self.log('train/vr_s', results['vr_samples']/len(batch['rgb']), True)
        self.log('train/psnr', self.train_psnr, True)

        return loss

    def on_validation_start(self):
        torch.cuda.empty_cache()
        if not self.hparams.no_save_test:
            self.val_dir = f'results/{self.hparams.dataset_name}/{self.hparams.exp_name}'
            os.makedirs(self.val_dir, exist_ok=True)

    def validation_step(self, batch, batch_nb):
        rgb_gt = batch['rgb']
        results = self(batch, split='test')

        logs = {}
        # compute each metric per image
        self.val_psnr(results['rgb'], rgb_gt)
        logs['psnr'] = self.val_psnr.compute()
        self.val_psnr.reset()

        w, h = self.train_dataset.img_wh
        rgb_pred = rearrange(results['rgb'], '(h w) c -> 1 c h w', h=h)
        rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h)
        self.val_ssim(rgb_pred, rgb_gt)
        logs['ssim'] = self.val_ssim.compute()
        self.val_ssim.reset()
        if self.hparams.eval_lpips:
            self.val_lpips(torch.clip(rgb_pred*2-1, -1, 1),
                           torch.clip(rgb_gt*2-1, -1, 1))
            logs['lpips'] = self.val_lpips.compute()
            self.val_lpips.reset()

        if not self.hparams.no_save_test: # save test image to disk
            idx = batch['img_idxs']
            rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h)
            rgb_pred = (rgb_pred*255).astype(np.uint8)
            imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred)

        return logs

    def validation_epoch_end(self, outputs):
        psnrs = torch.stack([x['psnr'] for x in outputs])
        mean_psnr = all_gather_ddp_if_available(psnrs).mean()
        self.log('test/psnr', mean_psnr, True)

        ssims = torch.stack([x['ssim'] for x in outputs])
        mean_ssim = all_gather_ddp_if_available(ssims).mean()
        self.log('test/ssim', mean_ssim)

        if self.hparams.eval_lpips:
            lpipss = torch.stack([x['lpips'] for x in outputs])
            mean_lpips = all_gather_ddp_if_available(lpipss).mean()
            self.log('test/lpips_vgg', mean_lpips)

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items
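Finally, a hypothetical launch sketch for this LightningModule. The Trainer arguments below are assumptions; the hparams names follow the attributes used above, and the PL 1.x-style API matches validation_epoch_end / get_progress_bar_dict. Check the repo's train.py for the actual entry point.

from pytorch_lightning import Trainer

system = NeRFSystem(hparams)               # hparams: parsed arguments (e.g. an argparse Namespace)
trainer = Trainer(max_epochs=hparams.num_epochs,
                  gpus=1,                  # PL 1.x device selection (assumption)
                  precision=16)            # mixed precision pairs with custom_fwd/custom_bwd above
trainer.fit(system)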