diff --git a/dataloader/data_v2.py b/dataloader/data_v2.py
index 47b14d9f..279c51f6 100644
--- a/dataloader/data_v2.py
+++ b/dataloader/data_v2.py
@@ -106,8 +106,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
         data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return {"pixel_values": tensor_data, "labels": tensor_label}
@@ -142,7 +142,7 @@ def dali_dataloader(
     num_shards=None,
     shard_id=None,
 ):
-    local_rank = int(os.environ.get("LOCAL_RANK", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
 
     if num_shards is None:
         num_shards = int(os.environ.get("WORLD_SIZE", "1"))
@@ -161,9 +161,9 @@ def dali_dataloader(
 
     pipe = Pipeline(
         batch_size=batch_size,
-        num_threads=2,
+        num_threads=max(1, workers),
         device_id=local_rank % 8,
-        prefetch_queue_depth=1,
+        prefetch_queue_depth=3,
         seed=seed,
     )
     device_memory_padding = 211025920
diff --git a/dataloader/data_v2_multi_res.py b/dataloader/data_v2_multi_res.py
index 49484ea3..2a3df7b5 100644
--- a/dataloader/data_v2_multi_res.py
+++ b/dataloader/data_v2_multi_res.py
@@ -100,8 +100,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
         data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return tensor_data, tensor_label
diff --git a/dataloader/data_v2_ocr.py b/dataloader/data_v2_ocr.py
index db04d3e1..da404590 100644
--- a/dataloader/data_v2_ocr.py
+++ b/dataloader/data_v2_ocr.py
@@ -100,8 +100,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
         data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return {"pixel_values": tensor_data, "labels": tensor_label}
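The three iterator changes above drop the implicit `.cuda()`, so the wrapper no longer decides device placement. A minimal sketch of how a consumer might move a returned batch onto its local device; the helper name and the `loader` variable are illustrative assumptions, not part of this patch:

import os
import torch

def batch_to_device(batch: dict, device: torch.device) -> dict:
    # Move every tensor in the batch dict to the target device; non_blocking=True
    # lets the copy overlap with compute when the source memory is pinned.
    return {key: value.to(device, non_blocking=True) for key, value in batch.items()}

# Illustrative usage inside a training loop:
# device = torch.device(f"cuda:{int(os.environ.get('LOCAL_RANK', '0'))}")
# batch = batch_to_device(next(loader), device)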
diff --git a/onevision_encoder/modeling_onevision_encoder.py b/onevision_encoder/modeling_onevision_encoder.py
index 13561134..365b8735 100644
--- a/onevision_encoder/modeling_onevision_encoder.py
+++ b/onevision_encoder/modeling_onevision_encoder.py
@@ -140,11 +140,18 @@ def __init__(self, config: OneVisionEncoderConfig):
             1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
             persistent=False,
         )
+        self._freqs_cache = {}
+        self._max_cache_size = 8
 
     def forward(self, t: int, h: int, w: int, device=None):
        if device is None:
             device = self.inv_freq_t.device
 
+        cache_key = (int(t), int(h), int(w), device.type, -1 if device.index is None else int(device.index))
+        cached = self._freqs_cache.get(cache_key)
+        if cached is not None:
+            return cached
+
         inv_t = self.inv_freq_t.to(device=device)
         inv_h = self.inv_freq_h.to(device=device)
         inv_w = self.inv_freq_w.to(device=device)
@@ -158,6 +165,9 @@ def forward(self, t: int, h: int, w: int, device=None):
         w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
 
         freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
+        if len(self._freqs_cache) >= self._max_cache_size:
+            self._freqs_cache.pop(next(iter(self._freqs_cache)))
+        self._freqs_cache[cache_key] = freqs
         return freqs
 
     def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
@@ -611,23 +621,32 @@ def forward(
         hidden_states = self.embeddings(pixel_values)
         batch_size, total_patches, _ = hidden_states.shape
 
-        # 2. Visible Indices Handling
-        if visible_indices is None:
-            visible_indices = (
-                torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
-            )
-
-        # 3. RoPE Construction
-        if patch_positions is not None:
-            freqs_visible = self.video_rope.forward_from_positions(patch_positions)
-        else:
+        # 2. Visible Indices / RoPE Construction
+        dense_full_path = visible_indices is None and patch_positions is None
+        if dense_full_path:
             freqs_full = self.video_rope(
                 t=t_frames,
                 h=height // self.config.patch_size,
                 w=width // self.config.patch_size,
                 device=pixel_values.device,
             )
-            freqs_visible = freqs_full[visible_indices]
+            freqs_visible = freqs_full.unsqueeze(0).expand(batch_size, -1, -1)
+        else:
+            if visible_indices is None:
+                visible_indices = (
+                    torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
+                )
+
+            if patch_positions is not None:
+                freqs_visible = self.video_rope.forward_from_positions(patch_positions)
+            else:
+                freqs_full = self.video_rope(
+                    t=t_frames,
+                    h=height // self.config.patch_size,
+                    w=width // self.config.patch_size,
+                    device=pixel_values.device,
+                )
+                freqs_visible = freqs_full[visible_indices]
 
         # Concatenate D/2 + D/2 -> D for applying rope
         freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
@@ -636,8 +655,8 @@ def forward(
         hidden_states = self.layernorm_pre(hidden_states)
 
         # fix: gather hidden_states to match freqs_visible when using sparse visible_indices
-        num_visible = visible_indices.shape[1]
-        if num_visible != total_patches:
+        num_visible = total_patches if dense_full_path else visible_indices.shape[1]
+        if not dense_full_path and num_visible != total_patches:
             # sparse mode: select only visible patches
             hidden_states = hidden_states.gather(
                 1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
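The new `_freqs_cache` keys precomputed rotary frequencies by grid size and device, and evicts the oldest entry once `_max_cache_size` entries are stored (plain dicts preserve insertion order). A standalone sketch of the same bounded-FIFO pattern; the class and method names here are illustrative, not the module's API:

import torch

class BoundedFreqCache:
    """FIFO-bounded cache keyed by (t, h, w, device), mirroring the pattern above."""

    def __init__(self, max_size: int = 8):
        self.max_size = max_size
        self._cache: dict[tuple, torch.Tensor] = {}

    def get_or_compute(self, t: int, h: int, w: int, device: torch.device, compute_fn):
        key = (t, h, w, device.type, -1 if device.index is None else device.index)
        hit = self._cache.get(key)
        if hit is not None:
            return hit
        value = compute_fn(t, h, w, device)
        if len(self._cache) >= self.max_size:
            # dict preserves insertion order, so this drops the oldest entry.
            self._cache.pop(next(iter(self._cache)))
        self._cache[key] = value
        return value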
diff --git a/training/train.py b/training/train.py
index 6cb82b2a..d14e90e2 100644
--- a/training/train.py
+++ b/training/train.py
@@ -84,6 +84,13 @@
 parser.add_argument("--backward_passes_per_step", type=int, default=1, help="Gradient accumulation steps")
 parser.add_argument("--repeat_pfc", type=int, default=0, help="Repeat factor for PFC ops or rebuild cycles")
 parser.add_argument("--save_pfc", type=int, default=1, help="Save PFC weights in checkpoints (0/1)")
+parser.add_argument(
+    "--compile_backend",
+    type=str,
+    default="auto",
+    choices=["auto", "none", "inductor", "aot_eager", "eager"],
+    help="torch.compile backend for backbone DDP module",
+)
 
 # Initialization / Resume
 parser.add_argument("--init_backbone", default="NULL", help="Backbone init path or 'NULL'")
@@ -381,8 +388,22 @@ def wrap_ddp(model):
     )
 
     backbone_ddp = wrap_ddp(backbone)
-    # backbone_ddp_compiled = backbone_ddp
-    backbone_ddp_compiled = torch.compile(backbone_ddp)
+    dali_types = {dataset_config.dali_type for dataset_config in args.list_datasets}
+    compile_backend = args.compile_backend
+    selected_backend = compile_backend
+    if compile_backend == "auto":
+        selected_backend = "inductor"
+        if len(dali_types) > 1:
+            # Mixed input signatures tend to retrace heavily under torch.compile.
+            logger.info(f"Mixed dali_type inputs detected {sorted(dali_types)}; disabling torch.compile to avoid retracing.")
+            selected_backend = "none"
+
+    if selected_backend != "none":
+        backbone_ddp_compiled = torch.compile(backbone_ddp, backend=selected_backend)
+        logger.info(f"Backbone DDP torch.compile enabled (backend={selected_backend}).")
+    else:
+        backbone_ddp_compiled = backbone_ddp
+        logger.info("Backbone DDP torch.compile disabled.")
 
     # Get patch_size from backbone config (outside of training loop for efficiency)
     backbone_module = unwrap_module(backbone)
@@ -485,10 +506,19 @@ def wrap_ddp(model):
         list_iter.append(iter(list_dali_dataloader[i]))
         list_next_data_batch.append(next(list_iter[i]))
 
-    if global_step > args.total_steps:
-        logger.info("global_step > total_steps")
+    if global_step >= args.total_steps:
+        logger.info("global_step >= total_steps")
         exit()
 
+    sampling_bin = max(1, 64 // args.num_frames)
+    frame_bin_starts = torch.arange(args.num_frames, device=f"cuda:{local_rank}") * sampling_bin
+    token_offsets = torch.arange(args.num_tokens_per_frame, device=f"cuda:{local_rank}")
+    frame_token_offsets = torch.arange(64, device=f"cuda:{local_rank}").view(-1, 1) * args.num_tokens_per_frame + token_offsets.view(1, -1)  # [64, num_tokens_per_frame]: flat patch indices per absolute frame
+    idx_range_by_head = [
+        torch.arange(args.list_batch_sizes_adjusted[head_id], device=f"cuda:{local_rank}")
+        for head_id in range(args.num_heads)
+    ]
+
     num_samples = 0
     end_of_batch = False
     while not end_of_batch:
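The lookup table added above encodes the flat patch index for every (frame, token) pair, i.e. patch_idx = frame * num_tokens_per_frame + token, so sampled frame ids can be turned into patch indices with a single indexing operation. A small sketch under assumed sizes (64 frames, 256 tokens per frame); the function name is illustrative:

import torch

def build_patch_index_table(total_frames: int = 64, tokens_per_frame: int = 256) -> torch.Tensor:
    # table[f, p] = f * tokens_per_frame + p, shape [total_frames, tokens_per_frame];
    # indexing rows by sampled frame ids yields flat patch indices for those frames.
    frames = torch.arange(total_frames).view(-1, 1)
    tokens = torch.arange(tokens_per_frame).view(1, -1)
    return frames * tokens_per_frame + tokens

# Illustrative usage: sampled frames [[0, 9, 17]] map to 3 * 256 patch indices.
# table = build_patch_index_table()
# patch_idx = table[torch.tensor([[0, 9, 17]])].reshape(1, -1)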
@@ -512,29 +542,32 @@ def wrap_ddp(model):
             n1 = int(bs * args.residual_ratio)  # n1 controls residual samples
             n2 = n1 + int(bs * args.frame_sampling_ratio)  # n2 controls frame_sampling samples
             # n3 (collage) is implicit: bs - n2
+            has_residual = n1 > 0
+            has_frame_sampling = n2 > n1
+            has_collage = n2 < bs
+            has_combined = n2 > 0
 
-            idx_range = torch.arange(bs).cuda()  # [8]
+            idx_range = idx_range_by_head[head_id][:bs]  # [8]
             mask_residual = idx_range < n1  # first n1 samples use residual strategy
             mask_frame_sampling = (idx_range >= n1) & (idx_range < n2)  # samples [n1, n2) use frame sampling
             mask_collage = idx_range >= n2  # samples [n2, bs) use collage strategy
 
             # mask_residual: select first args.target_num patches
-            if mask_residual.any():
+            if has_residual:
                 out[mask_residual] = visible_indices[mask_residual, :]  # [4, 2048]
 
             # mask_frame_sampling: sample 8 frames from 64, get all patches per frame
             FRAMES = 64
-            if mask_frame_sampling.any():
-                nB = mask_frame_sampling.sum().item()  # 3
-                # frames: sample 1 frame from each of 8 bins (each bin has 8 frames)
-                frames = (
-                    torch.arange(args.num_frames).cuda() * (FRAMES // args.num_frames) + torch.randint(FRAMES // args.num_frames, (nB, args.num_frames)).cuda()
-                )  # [3, 8], values in [0,7], [8,15], ..., [56,63]
-                # sel_b: for each frame, get all 256 patches
-                out[mask_frame_sampling] = (frames.unsqueeze(-1) * args.num_tokens_per_frame + torch.arange(args.num_tokens_per_frame).cuda()).reshape(nB, -1)  # [3, 8*256] = [3, 2048]
+            if has_frame_sampling:
+                frame_sampling_idx = torch.nonzero(mask_frame_sampling, as_tuple=False).squeeze(1)  # [nB]
+                nB = frame_sampling_idx.numel()
+                frame_offsets = torch.randint(sampling_bin, (nB, args.num_frames), device=head_input.device)
+                frames = (frame_bin_starts.unsqueeze(0) + frame_offsets).clamp(max=FRAMES - 1)  # [nB, 8], values in [0, 63]
+                patch_indices = frame_token_offsets[frames].reshape(nB, -1)  # [nB, 2048]
+                out[frame_sampling_idx] = patch_indices
 
             combined_mask = mask_residual | mask_frame_sampling  # [8], first 7 samples are True
-            if combined_mask.any():
+            if has_combined:
                 combined_idx = combined_mask.nonzero(as_tuple=False).squeeze(1)  # [7]
                 video = head_input[combined_idx]  # [7, 3, 64, 224, 224]
                 vis_idx = out[combined_idx]  # [7, 2048]
@@ -582,7 +615,7 @@ def wrap_ddp(model):
                 combined_head_output = backbone_ddp_compiled(combined_head_input, visible_indices=vis_idx)  # input: [7, 3, 8, 224, 224], vis_idx: [7, 2048]
                 combined_head_output = (combined_head_output.pooler_output if hasattr(combined_head_output, "pooler_output") else combined_head_output["head_output"]).float()  # [7, D]
 
-            if mask_collage.any():
+            if has_collage:
                 coll_idx = torch.nonzero(mask_collage, as_tuple=False).squeeze(1)  # [1]
                 nC = coll_idx.numel()  # 1
                 FRAMES = 64
@@ -596,10 +629,8 @@ def wrap_ddp(model):
                 Cf = head_subset.size(1)  # 3
                 Hf = head_subset.size(3)  # 224
                 Wf = head_subset.size(4)  # 224
-                avg = FRAMES // args.num_frames  # 64 // 8 = 8
-                base = torch.arange(args.num_frames).cuda() * avg  # [0, 8, 16, 24, 32, 40, 48, 56]
-                offs = torch.randint(avg, (nC, args.num_frames)).cuda()  # [1, 8], values in [0, 7]
-                frames_idx = (base.unsqueeze(0) + offs).long().clamp(max=FRAMES - 1)  # [1, 8], values in [0, 63]
+                frame_offsets = torch.randint(sampling_bin, (nC, args.num_frames), device=head_subset.device)
+                frames_idx = (frame_bin_starts.unsqueeze(0) + frame_offsets).long().clamp(max=FRAMES - 1)  # [1, 8], values in [0, 63]
                 idx_expand = frames_idx.view(nC, 1, args.num_frames, 1, 1).expand(-1, Cf, -1, Hf, Wf)  # [1, 3, 8, 224, 224]
                 sel_frames = torch.gather(head_subset, 2, idx_expand)  # [1, 3, 8, 224, 224]
                 sel_frames = sel_frames.permute(0, 2, 1, 3, 4)  # [1, 8, 3, 224, 224]
@@ -615,10 +646,10 @@ def wrap_ddp(model):
             D = combined_head_output.size(1)  # embedding dimension
 
-            head_embedding_full = torch.zeros(bs, D, dtype=torch.float32).cuda()  # [8, D]
-            if combined_mask.any():
+            head_embedding_full = torch.zeros(bs, D, dtype=torch.float32, device=head_input.device)  # [8, D]
+            if has_combined:
                 head_embedding_full[combined_idx] = combined_head_output  # head_embedding_full[0:7] = [7, D]
-            if mask_collage.any():
+            if has_collage:
                 head_embedding_full[coll_idx] = collage_head_output  # head_embedding_full[7] = [1, D]
 
             list_embedding.append(head_embedding_full)  # [8, D]
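Each head's batch is split by ratio into residual, frame-sampling, and collage subsets, and the per-subset embeddings are scattered back into one full-batch tensor by row index. A simplified sketch of that assembly step; names and shapes are illustrative:

import torch

def assemble_head_embedding(bs: int, dim: int, combined_idx: torch.Tensor, combined_out: torch.Tensor,
                            collage_idx: torch.Tensor, collage_out: torch.Tensor, device) -> torch.Tensor:
    # Allocate the full per-head embedding and write each subset back at its original rows.
    full = torch.zeros(bs, dim, dtype=torch.float32, device=device)
    if combined_idx.numel() > 0:
        full[combined_idx] = combined_out
    if collage_idx.numel() > 0:
        full[collage_idx] = collage_out
    return full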
@@ -639,19 +670,19 @@ def wrap_ddp(model):
                 raise ValueError(f"Unsupported DALI type: {dataset_config.dali_type}")
 
         list_loss = []
-        list_loss_float = []
+        list_loss_detached = []
         for head_id, pfc in enumerate(list_module_pfc):
             dataset_config = args.list_datasets[head_id]
             head_embedding = list_embedding[head_id]
-            head_label = list_data_batch[head_id]["labels"].long().cuda()
+            head_label = list_data_batch[head_id]["labels"].long()
             label_select = dataset_config.label_select
             random_diff = dataset_config.random_diff
             loss_weight = args.list_loss_weights[head_id]
 
             head_label = head_label[:, label_select : label_select + random_diff]
             head_loss = pfc(head_embedding, head_label, random_diff) * loss_weight
             list_loss.append(head_loss)
-            list_loss_float.append(head_loss.item())
+            list_loss_detached.append(head_loss.detach())
 
         is_accumulation_step = (global_step + 1) % args.backward_passes_per_step != 0
         scaled_loss = sum(list_loss) / args.backward_passes_per_step
@@ -671,7 +702,7 @@ def wrap_ddp(model):
             batch_end_callback(
                 global_step=global_step,
                 lr_scheduler=lr_scheduler,
-                list_loss_float=list_loss_float,
+                list_loss=list_loss_detached,
                 batch_size=args.batch_size,
                 num_samples=num_samples,
             )
@@ -693,7 +724,7 @@ def wrap_ddp(model):
                     keep_num=20,
                 )
 
-            if global_step > args.total_steps:
+            if global_step >= args.total_steps:
                 save_checkpoint(
                     args.output,
                     backbone,
@@ -738,7 +769,8 @@ def __init__(
         self.num_head = len(self.list_head_names)
         self.time_start = time.time()
-        self.list_loss_metric = [ScalaMetric() for _ in self.list_head_names]
+        self.loss_sum = torch.zeros(self.num_head, dtype=torch.float32, device=f"cuda:{local_rank}")
+        self.loss_count = 0
 
         self.init = False
         self.tic = 0
@@ -757,12 +789,13 @@ def __call__(
         self,
         global_step: int,
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        list_loss_float: list[float],
+        list_loss: list[torch.Tensor],
         batch_size: int,
         num_samples=None,
     ):
-        for i in range(self.num_head):
-            self.list_loss_metric[i].update(list_loss_float[i])
+        for head_id, head_loss in enumerate(list_loss):
+            self.loss_sum[head_id] += head_loss.float()
+        self.loss_count += 1
 
         if global_step > 0 and global_step % self.frequent == 0:
             if self.init:
@@ -791,21 +824,22 @@ def __call__(
                 progress = f"step: {global_step}/{self.total_steps} ({global_step / self.total_steps * 100:.2f}%) "
                 time_info = f"remain: {remaining_time_hours:.2f} hours"
 
+                loss_avg_list = (self.loss_sum / max(1, self.loss_count)).detach().cpu().tolist()
                 loss_str_format = ""
                 for head_id, name in enumerate(self.list_head_names):
+                    head_loss_avg = float(loss_avg_list[head_id])
                     if rank == 0 and self.tb_writer:
-                        self.tb_writer.add_scalar(f"loss/{name}", self.list_loss_metric[head_id].avg, global_step)
+                        self.tb_writer.add_scalar(f"loss/{name}", head_loss_avg, global_step)
                         self.tb_writer.add_scalar(f"lr/{name}", lr_scheduler.get_last_lr()[head_id + 1], global_step)
                         self.tb_writer.add_scalar(
                             f"samples vs. loss/{name}",
-                            self.list_loss_metric[head_id].avg,
+                            head_loss_avg,
                             num_samples,
                         )
                     loss_str_format += f"\n{f'name: {name}':<50}{f'lr: {lr_scheduler.get_last_lr()[head_id + 1]:.8f}':<20}"
-                    loss_str_format += f"{f'loss: {self.list_loss_metric[head_id].avg:.4f}':<20}"
-                    self.list_loss_metric[head_id].reset()
+                    loss_str_format += f"{f'loss: {head_loss_avg:.4f}':<20}"
 
                 examples_info = f"samples: {num_samples}"
                 msg = f"{header}{progress}{time_info} {examples_info}{loss_str_format}"
@@ -815,32 +849,13 @@ def __call__(
                 # Flush TensorBoard writer
                 if self.tb_writer:
                     self.tb_writer.flush()
+                self.loss_sum.zero_()
+                self.loss_count = 0
             else:
                 self.init = True
                 self.tic = time.time()
 
 
-class ScalaMetric(object):
-    def __init__(self):
-        self.val = None
-        self.avg = None
-        self.sum = None
-        self.count = None
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
 def log_args(args, logger, writer: SummaryWriter = None, save_dir: str = None, rank: int = 0):
     if rank != 0:
         return
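ScalaMetric is replaced by a per-head running sum kept on the GPU plus a single batch counter, so losses are only synchronized to the host at logging time. A minimal sketch of that accumulation pattern; the class name is illustrative:

import torch

class RunningHeadLoss:
    """Per-head loss accumulator: sums stay on device, one host sync at log time."""

    def __init__(self, num_heads: int, device: torch.device):
        self.loss_sum = torch.zeros(num_heads, dtype=torch.float32, device=device)
        self.count = 0

    def update(self, losses: list) -> None:
        # losses: one detached scalar tensor per head for the current step.
        for head_id, head_loss in enumerate(losses):
            self.loss_sum[head_id] += head_loss.detach().float()
        self.count += 1

    def averages_and_reset(self) -> list:
        avgs = (self.loss_sum / max(1, self.count)).cpu().tolist()
        self.loss_sum.zero_()
        self.count = 0
        return avgs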