From 2c1bfc03a29baad5e0cf138a27308e789d28e6be Mon Sep 17 00:00:00 2001 From: James Huang Date: Thu, 26 Feb 2026 00:00:56 +0000 Subject: [PATCH 1/2] [Text Pipeline] Implement Text Encoders Wrappers with mesh support --- .../ltx2/text_encoders/text_encoders_ltx2.py | 164 ++++++++++++++++++ .../tests/test_text_encoders_ltx2.py | 90 ++++++++++ 2 files changed, 254 insertions(+) create mode 100644 src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py create mode 100644 src/maxdiffusion/tests/test_text_encoders_ltx2.py diff --git a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py new file mode 100644 index 00000000..f043ff4e --- /dev/null +++ b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py @@ -0,0 +1,164 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Optional, Tuple, Union, List +import jax +import jax.numpy as jnp +from flax import nnx +from maxdiffusion import common_types + +from .feature_extractor_ltx2 import LTX2GemmaFeatureExtractor +from .embeddings_connector_ltx2 import Embeddings1DConnector + +Array = common_types.Array +DType = common_types.DType + + +class LTX2VideoGemmaTextEncoder(nnx.Module): + """ + Encoder for Video-only tasks. + Pipeline: Gemma Hidden States -> Feature Extractor -> Video Connector -> Output + """ + + def __init__( + self, + # Feature Extractor Config + gemma_dim: int = 3840, # Gemma-3-12b + gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states + projection_dim: int = 4096, # LTX-2 conditioning dim + # Connector Config + connector_heads: int = 32, + connector_head_dim: int = 128, + connector_layers: int = 2, + num_thinking_tokens: int = 128, + dtype: DType = jnp.float32, + attention_kernel: str = "flash", + mesh: jax.sharding.Mesh = None, + rngs: nnx.Rngs = None, + ): + input_dim = gemma_dim * gemma_layers + + self.feature_extractor = LTX2GemmaFeatureExtractor( + input_dim=input_dim, + output_dim=projection_dim, + dtype=dtype, + rngs=rngs, + ) + + self.embeddings_connector = Embeddings1DConnector( + input_dim=projection_dim, + heads=connector_heads, + head_dim=connector_head_dim, + layers=connector_layers, + num_learnable_registers=num_thinking_tokens, + rope_type="interleaved", + attention_kernel=attention_kernel, + mesh=mesh, + rngs=rngs, + ) + + def __call__( + self, + hidden_states: Union[Tuple[Array, ...], List[Array]], + attention_mask: Array, + ) -> Array: + """ + Args: + hidden_states: From Gemma output.hidden_states (Tuple of [B, T, D]) + attention_mask: [B, T] + """ + # 1. Feature Extraction (Stack -> Norm -> Project) + features = self.feature_extractor(hidden_states, attention_mask) + + # 2. Connection (Refine + Thinking Tokens) + video_embeds = self.embeddings_connector(features, attention_mask) + + return video_embeds + + +class LTX2AudioVideoGemmaTextEncoder(nnx.Module): + """ + Encoder for Audio-Video tasks. + Pipeline: Gemma Hidden States -> Feature Extractor -> [Video Connector, Audio Connector] + """ + + def __init__( + self, + # Feature Extractor Config (Shared) + gemma_dim: int = 3840, # Gemma-3-12b + gemma_layers: int = 49, # Gemma-3 has 48 layers + 1 embedding layer output = 49 hidden states + projection_dim: int = 4096, + # Connector Config + connector_heads: int = 32, + connector_head_dim: int = 128, + connector_layers: int = 2, + num_thinking_tokens: int = 128, + dtype: DType = jnp.float32, + attention_kernel: str = "flash", + mesh: jax.sharding.Mesh = None, + rngs: nnx.Rngs = None, + ): + input_dim = gemma_dim * gemma_layers + + self.feature_extractor = LTX2GemmaFeatureExtractor( + input_dim=input_dim, + output_dim=projection_dim, + dtype=dtype, + rngs=rngs, + ) + + # Two independent connectors + self.video_embeddings_connector = Embeddings1DConnector( + input_dim=projection_dim, + heads=connector_heads, + head_dim=connector_head_dim, + layers=connector_layers, + num_learnable_registers=num_thinking_tokens, + rope_type="interleaved", + attention_kernel=attention_kernel, + mesh=mesh, + rngs=rngs, + ) + + self.audio_embeddings_connector = Embeddings1DConnector( + input_dim=projection_dim, + heads=connector_heads, + head_dim=connector_head_dim, + layers=connector_layers, + num_learnable_registers=num_thinking_tokens, + rope_type="interleaved", + attention_kernel=attention_kernel, + mesh=mesh, + rngs=rngs, + ) + + def __call__( + self, + hidden_states: Union[Tuple[Array, ...], List[Array]], + attention_mask: Array, + ) -> Tuple[Array, Array]: + """ + Returns: + (video_embeds, audio_embeds) + """ + # 1. Shared Feature Extraction + features = self.feature_extractor(hidden_states, attention_mask) + + # 2. Parallel Connection + video_embeds = self.video_embeddings_connector(features, attention_mask) + audio_embeds = self.audio_embeddings_connector(features, attention_mask) + + return video_embeds, audio_embeds diff --git a/src/maxdiffusion/tests/test_text_encoders_ltx2.py b/src/maxdiffusion/tests/test_text_encoders_ltx2.py new file mode 100644 index 00000000..1332365d --- /dev/null +++ b/src/maxdiffusion/tests/test_text_encoders_ltx2.py @@ -0,0 +1,90 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from ..models.ltx2.text_encoders.text_encoders_ltx2 import LTX2VideoGemmaTextEncoder, LTX2AudioVideoGemmaTextEncoder + + +class LTX2TextEncodersTest(unittest.TestCase): + + def setUp(self): + self.rng = nnx.Rngs(0) + self.B = 2 + self.T = 16 + self.gemma_dim = 32 + self.gemma_layers = 3 + self.proj_dim = 64 + + # Mock Gemma hidden states + self.hidden_states = [jnp.array(np.random.randn(self.B, self.T, self.gemma_dim)) for _ in range(self.gemma_layers)] + + self.attention_mask = jnp.ones((self.B, self.T)) + + def test_video_encoder_forward(self): + encoder = LTX2VideoGemmaTextEncoder( + gemma_dim=self.gemma_dim, + gemma_layers=self.gemma_layers, + projection_dim=self.proj_dim, + connector_heads=4, + connector_head_dim=16, + connector_layers=1, + num_thinking_tokens=8, + attention_kernel="dot_product", + mesh=None, + rngs=self.rng, + ) + + output = encoder(tuple(self.hidden_states), self.attention_mask) + + # Expected shape: [B, T, proj_dim] + self.assertEqual(output.shape, (self.B, self.T, self.proj_dim)) + print("\n[PASS] Video Encoder Forward Pass Verified.") + + def test_av_encoder_forward(self): + encoder = LTX2AudioVideoGemmaTextEncoder( + gemma_dim=self.gemma_dim, + gemma_layers=self.gemma_layers, + projection_dim=self.proj_dim, + connector_heads=4, + connector_head_dim=16, + connector_layers=1, + num_thinking_tokens=8, + attention_kernel="dot_product", + mesh=None, + rngs=self.rng, + ) + + video_out, audio_out = encoder(tuple(self.hidden_states), self.attention_mask) + + # Expected shapes: Both [B, T, proj_dim] + self.assertEqual(video_out.shape, (self.B, self.T, self.proj_dim)) + self.assertEqual(audio_out.shape, (self.B, self.T, self.proj_dim)) + + # Ensure they are different (different random init for connectors) + # Note: In reality they are initialized differently, so outputs should differ + self.assertFalse( + jnp.allclose(video_out, audio_out), "Video and Audio outputs should differ due to different connector weights" + ) + + print("\n[PASS] Audio-Video Encoder Forward Pass Verified.") + + +if __name__ == "__main__": + unittest.main() From 8ffec22aed515e67be3c2452b7169c7bc325eea0 Mon Sep 17 00:00:00 2001 From: James Huang Date: Thu, 26 Feb 2026 01:29:41 +0000 Subject: [PATCH 2/2] ci fix Signed-off-by: James Huang --- .../models/ltx2/text_encoders/text_encoders_ltx2.py | 2 +- src/maxdiffusion/tests/test_text_encoders_ltx2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py index f043ff4e..0728139a 100644 --- a/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py +++ b/src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py @@ -14,7 +14,7 @@ limitations under the License. """ -from typing import Optional, Tuple, Union, List +from typing import Tuple, Union, List import jax import jax.numpy as jnp from flax import nnx diff --git a/src/maxdiffusion/tests/test_text_encoders_ltx2.py b/src/maxdiffusion/tests/test_text_encoders_ltx2.py index 1332365d..e7e22500 100644 --- a/src/maxdiffusion/tests/test_text_encoders_ltx2.py +++ b/src/maxdiffusion/tests/test_text_encoders_ltx2.py @@ -15,7 +15,6 @@ """ import unittest -import jax import jax.numpy as jnp import numpy as np from flax import nnx