Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Tuple, Union, List
import jax
import jax.numpy as jnp
from flax import nnx
from maxdiffusion import common_types

from .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor
from .embeddings_connector_ltx2 import Embeddings1DConnector

Array = common_types.Array
DType = common_types.DType


class LTX2VideoGemmaTextEncoder(nnx.Module):
"""
Encoder for Video-only tasks.
Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output
"""

def __init__(
self,
# Feature Extractor Config
gemma_dim: int = 3840, # Gemma-3-12b
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
projection_dim: int = 4096, # LTX-2 conditioning dim
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change needed 4096 -> 3840

# Connector Config
connector_heads: int = 32,
connector_head_dim: int = 128,
connector_layers: int = 2,
num_thinking_tokens: int = 128,
dtype: DType = jnp.float32,
attention_kernel: str = "flash",
mesh: jax.sharding.Mesh = None,
rngs: nnx.Rngs = None,
):
input_dim = gemma_dim * gemma_layers

self.feature_extractor = LTX2GemmaFeatureExtractor(
input_dim=input_dim,
output_dim=projection_dim,
dtype=dtype,
rngs=rngs,
)

self.embeddings_connector = Embeddings1DConnector(
input_dim=projection_dim,
heads=connector_heads,
head_dim=connector_head_dim,
layers=connector_layers,
num_learnable_registers=num_thinking_tokens,
rope_type="interleaved",
attention_kernel=attention_kernel,
mesh=mesh,
rngs=rngs,
)

def __call__(
self,
hidden_states: Union[Tuple[Array, ...], List[Array]],
attention_mask: Array,
) -> Array:
"""
Args:
hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D])
attention_mask: [B, T]
"""
# 1. Feature Extraction (Stack -> Norm -> Project)
features = self.feature_extractor(hidden_states, attention_mask)

# 2. Connection (Refine + Thinking Tokens)
video_embeds = self.embeddings_connector(features, attention_mask)

return video_embeds


class LTX2AudioVideoGemmaTextEncoder(nnx.Module):
"""
Encoder for Audio-Video tasks.
Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector]
"""

def __init__(
self,
# Feature Extractor Config (Shared)
gemma_dim: int = 3840, # Gemma-3-12b
gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states
projection_dim: int = 4096,
Copy link
Collaborator

@prishajain1 prishajain1 Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4096 -> 3840

# Connector Config
connector_heads: int = 32,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connector_head_dim: int = 128,
connector_layers: int = 2,
num_thinking_tokens: int = 128,
dtype: DType = jnp.float32,
attention_kernel: str = "flash",
mesh: jax.sharding.Mesh = None,
rngs: nnx.Rngs = None,
):
input_dim = gemma_dim * gemma_layers

self.feature_extractor = LTX2GemmaFeatureExtractor(
input_dim=input_dim,
output_dim=projection_dim,
dtype=dtype,
rngs=rngs,
)

# Two independent connectors
self.video_embeddings_connector = Embeddings1DConnector(
input_dim=projection_dim,
heads=connector_heads,
head_dim=connector_head_dim,
layers=connector_layers,
num_learnable_registers=num_thinking_tokens,
rope_type="interleaved",
attention_kernel=attention_kernel,
mesh=mesh,
rngs=rngs,
)

self.audio_embeddings_connector = Embeddings1DConnector(
input_dim=projection_dim,
heads=connector_heads,
head_dim=connector_head_dim,
layers=connector_layers,
num_learnable_registers=num_thinking_tokens,
rope_type="interleaved",
attention_kernel=attention_kernel,
mesh=mesh,
rngs=rngs,
)

def __call__(
self,
hidden_states: Union[Tuple[Array, ...], List[Array]],
attention_mask: Array,
) -> Tuple[Array, Array]:
"""
Returns:
(video_embeds, audio_embeds)
"""
# 1. Shared Feature Extraction
features = self.feature_extractor(hidden_states, attention_mask)

# 2. Parallel Connection
video_embeds = self.video_embeddings_connector(features, attention_mask)
audio_embeds = self.audio_embeddings_connector(features, attention_mask)

return video_embeds, audio_embeds
89 changes: 89 additions & 0 deletions src/maxdiffusion/tests/test_text_encoders_ltx2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
import jax.numpy as jnp
import numpy as np
from flax import nnx
from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder


class LTX2TextEncodersTest(unittest.TestCase):

def setUp(self):
self.rng = nnx.Rngs(0)
self.B = 2
self.T = 16
self.gemma_dim = 32
self.gemma_layers = 3
self.proj_dim = 64

# Mock Gemma hidden states
self.hidden_states = [jnp.array(np.random.randn(self.B, self.T, self.gemma_dim)) for _ in range(self.gemma_layers)]

self.attention_mask = jnp.ones((self.B, self.T))

def test_video_encoder_forward(self):
encoder = LTX2VideoGemmaTextEncoder(
gemma_dim=self.gemma_dim,
gemma_layers=self.gemma_layers,
projection_dim=self.proj_dim,
connector_heads=4,
connector_head_dim=16,
connector_layers=1,
num_thinking_tokens=8,
attention_kernel="dot_product",
mesh=None,
rngs=self.rng,
)

output = encoder(tuple(self.hidden_states), self.attention_mask)

# Expected shape: [B, T, proj_dim]
self.assertEqual(output.shape, (self.B, self.T, self.proj_dim))
print("\n[PASS] Video Encoder Forward Pass Verified.")

def test_av_encoder_forward(self):
encoder = LTX2AudioVideoGemmaTextEncoder(
gemma_dim=self.gemma_dim,
gemma_layers=self.gemma_layers,
projection_dim=self.proj_dim,
connector_heads=4,
connector_head_dim=16,
connector_layers=1,
num_thinking_tokens=8,
attention_kernel="dot_product",
mesh=None,
rngs=self.rng,
)

video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask)

# Expected shapes: Both [B, T, proj_dim]
self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim))
self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim))

# Ensure they are different (different random init for connectors)
# Note: In reality they are initialized differently, so outputs should differ
self.assertFalse(
jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights"
)

print("\n[PASS] Audio-Video Encoder Forward Pass Verified.")


if __name__ == "__main__":
unittest.main()
Loading