dataloader/data_v2.py: 10 changes (5 additions & 5 deletions)
@@ -106,8 +106,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
         data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return {"pixel_values": tensor_data, "labels": tensor_label}
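Note: with the .cuda() calls removed, the wrapper returns whatever the DALI iterator yields and leaves device placement to the caller. When the DALI pipeline runs on the GPU its outputs are already device-resident, so the dropped .cuda() was redundant there; for CPU pipelines it no longer forces tensors onto a GPU inside the loader. A minimal sketch of the assumed consumer (loop and variable names are hypothetical, not part of this PR):

import torch

device = torch.device("cuda", 0)
for batch in loader:  # `loader` wraps the DALI iterator above
    # device placement is assumed to happen here now, not in __next__
    pixel_values = batch["pixel_values"].to(device, non_blocking=True)
    labels = batch["labels"].to(device, non_blocking=True)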
@@ -142,7 +142,7 @@ def dali_dataloader(
     num_shards=None,
     shard_id=None,
 ):
-    local_rank = int(os.environ.get("LOCAL_RANK", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
 
     if num_shards is None:
         num_shards = int(os.environ.get("WORLD_SIZE", "1"))
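The default-value fix matters for single-process runs: without torchrun, LOCAL_RANK is unset, and the old default "1" silently targeted GPU 1. A quick check of how the defaults now resolve (standalone, assuming neither variable is set):

import os

local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # was "1"
num_shards = int(os.environ.get("WORLD_SIZE", "1"))
print(local_rank % 8, num_shards)  # -> 0 1: GPU 0, single shard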
@@ -161,9 +161,9 @@ def dali_dataloader(
 
     pipe = Pipeline(
         batch_size=batch_size,
-        num_threads=2,
+        num_threads=max(1, workers),
         device_id=local_rank % 8,
-        prefetch_queue_depth=1,
+        prefetch_queue_depth=3,
         seed=seed,
     )
     device_memory_padding = 211025920
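Two knobs change here: num_threads now follows the configured worker count (with a floor of 1 so a workers=0 config stays valid, since DALI needs at least one thread), and prefetch_queue_depth=3 lets the pipeline keep batches in flight ahead of the training step. A standalone sketch of the same construction (assumes nvidia.dali is installed; argument values are placeholders):

from nvidia.dali.pipeline import Pipeline

def make_pipe(batch_size, workers, local_rank, seed=0):
    return Pipeline(
        batch_size=batch_size,
        num_threads=max(1, workers),  # CPU threads scale with the config
        device_id=local_rank % 8,     # assumes 8 GPUs per node
        prefetch_queue_depth=3,       # overlap data prep with compute
        seed=seed,
    )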
dataloader/data_v2_multi_res.py: 4 changes (2 additions & 2 deletions)
@@ -100,8 +100,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
         data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return tensor_data, tensor_label
dataloader/data_v2_ocr.py: 4 changes (2 additions & 2 deletions)
@@ -100,8 +100,8 @@ def __init__(self, dali_iter, label_select=None):
 
     def __next__(self):
        data_dict = self.iter.__next__()[0]
-        tensor_data = data_dict["data"].cuda()
-        tensor_label = data_dict["label"].long().cuda()
+        tensor_data = data_dict["data"]
+        tensor_label = data_dict["label"].long()
 
         if self.label_select is None:
             return {"pixel_values": tensor_data, "labels": tensor_label}
onevision_encoder/modeling_onevision_encoder.py: 45 changes (32 additions & 13 deletions)
@@ -140,11 +140,18 @@ def __init__(self, config: OneVisionEncoderConfig):
             1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
             persistent=False,
         )
+        self._freqs_cache = {}
+        self._max_cache_size = 8
 
     def forward(self, t: int, h: int, w: int, device=None):
         if device is None:
             device = self.inv_freq_t.device
 
+        cache_key = (int(t), int(h), int(w), device.type, -1 if device.index is None else int(device.index))
+        cached = self._freqs_cache.get(cache_key)
+        if cached is not None:
+            return cached
+
         inv_t = self.inv_freq_t.to(device=device)
         inv_h = self.inv_freq_h.to(device=device)
         inv_w = self.inv_freq_w.to(device=device)
@@ -158,6 +165,9 @@ def forward(self, t: int, h: int, w: int, device=None):
         w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
 
         freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
+        if len(self._freqs_cache) >= self._max_cache_size:
+            self._freqs_cache.pop(next(iter(self._freqs_cache)))
+        self._freqs_cache[cache_key] = freqs
         return freqs
 
     def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
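The cache here is a plain dict used as a small bounded FIFO: Python dicts preserve insertion order (guaranteed since 3.7), so next(iter(cache)) is always the oldest key. Note that a cache hit does not refresh an entry's position, so this is FIFO eviction rather than LRU. A self-contained sketch of the pattern:

cache, max_size = {}, 2
for key in ("a", "b", "c"):
    if len(cache) >= max_size:
        cache.pop(next(iter(cache)))  # evict the oldest insertion
    cache[key] = key.upper()
print(list(cache))  # -> ['b', 'c']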
@@ -611,23 +621,32 @@ def forward(
         hidden_states = self.embeddings(pixel_values)
         batch_size, total_patches, _ = hidden_states.shape
 
-        # 2. Visible Indices Handling
-        if visible_indices is None:
-            visible_indices = (
-                torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
-            )
-
-        # 3. RoPE Construction
-        if patch_positions is not None:
-            freqs_visible = self.video_rope.forward_from_positions(patch_positions)
-        else:
+        # 2. Visible Indices / RoPE Construction
+        dense_full_path = visible_indices is None and patch_positions is None
+        if dense_full_path:
             freqs_full = self.video_rope(
                 t=t_frames,
                 h=height // self.config.patch_size,
                 w=width // self.config.patch_size,
                 device=pixel_values.device,
             )
-            freqs_visible = freqs_full[visible_indices]
+            freqs_visible = freqs_full.unsqueeze(0).expand(batch_size, -1, -1)
+        else:
+            if visible_indices is None:
+                visible_indices = (
+                    torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
+                )
+
+            if patch_positions is not None:
+                freqs_visible = self.video_rope.forward_from_positions(patch_positions)
+            else:
+                freqs_full = self.video_rope(
+                    t=t_frames,
+                    h=height // self.config.patch_size,
+                    w=width // self.config.patch_size,
+                    device=pixel_values.device,
+                )
+                freqs_visible = freqs_full[visible_indices]
 
         # Concatenate D/2 + D/2 -> D for applying rope
         freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
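The dense fast path produces the same values the old index-based gather did, but as a broadcasted view instead of a materialized copy. A toy check of the equivalence (shapes assumed: freqs_full is (N, D/2), indices are (B, N)):

import torch

B, N, D2 = 2, 6, 4
freqs_full = torch.randn(N, D2)
dense_ids = torch.arange(N).unsqueeze(0).expand(B, -1)

old = freqs_full[dense_ids]                      # (B, N, D2), copies
new = freqs_full.unsqueeze(0).expand(B, -1, -1)  # (B, N, D2), view
assert torch.equal(old, new)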
@@ -636,8 +655,8 @@ def forward(
         hidden_states = self.layernorm_pre(hidden_states)
 
         # fix: gather hidden_states to match freqs_visible when using sparse visible_indices
-        num_visible = visible_indices.shape[1]
-        if num_visible != total_patches:
+        num_visible = total_patches if dense_full_path else visible_indices.shape[1]
+        if not dense_full_path and num_visible != total_patches:
             # sparse mode: select only visible patches
             hidden_states = hidden_states.gather(
                 1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
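In the sparse branch, gather selects a per-sample subset of patches so hidden_states stays aligned with freqs_visible. A toy illustration with assumed shapes:

import torch

B, N, C = 2, 5, 3
hidden_states = torch.randn(B, N, C)
visible_indices = torch.tensor([[0, 2], [1, 4]])  # (B, N_vis)

gathered = hidden_states.gather(
    1, visible_indices.unsqueeze(-1).expand(-1, -1, C)
)
print(gathered.shape)  # -> torch.Size([2, 2, 3])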