Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from flax import nnx
from maxdiffusion import common_types
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward

Array = common_types.Array
DType = common_types.DType


class _BasicTransformerBlock1D(nnx.Module):
  """Pre-norm 1D transformer block: self-attention followed by a feed-forward MLP.

  Each sub-layer is wrapped as ``x + sublayer(RMSNorm(x))`` (residual around a
  normalized input, with un-scaled RMSNorm as used by LTX-2).

  Args:
    dim: Hidden size of the token embeddings.
    heads: Number of attention heads.
    dim_head: Per-head dimension.
    rope_type: Rotary-embedding layout passed to ``LTX2Attention``
      (default ``"interleaved"``).
    attention_kernel: Attention backend selector (e.g. ``"flash"`` or
      ``"dot_product"``) forwarded to ``LTX2Attention``.
    mesh: Optional device mesh for sharded attention; ``None`` disables sharding.
    rngs: NNX random streams used to initialize parameters.
  """

  def __init__(
      self,
      dim: int,
      heads: int,
      dim_head: int,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: Optional[jax.sharding.Mesh] = None,
      rngs: Optional[nnx.Rngs] = None,
  ):
    self.attn1 = LTX2Attention(
        query_dim=dim,
        heads=heads,
        dim_head=dim_head,
        rope_type=rope_type,
        bias=True,  # LTX-2 default
        out_bias=True,
        attention_kernel=attention_kernel,
        mesh=mesh,
        rngs=rngs,
    )
    self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
    # use_scale=False: plain RMS normalization with no learned gain, computed in fp32.
    self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)
    self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs)

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
      rotary_emb: Optional[Tuple[Array, Array]] = None,
  ) -> Array:
    """Run the block on ``hidden_states``.

    Args:
      hidden_states: Input sequence tensor.
      attention_mask: Optional mask forwarded to the attention layer.
      rotary_emb: Optional ``(cos, sin)`` rotary embedding pair.

    Returns:
      Tensor of the same shape as ``hidden_states``.
    """
    # 1. Norm -> Attention -> residual add
    normed = self.norm1(hidden_states)
    attn_output = self.attn1(normed, attention_mask=attention_mask, rotary_emb=rotary_emb)
    hidden_states = hidden_states + attn_output

    # 2. Norm -> FeedForward -> residual add
    normed = self.norm2(hidden_states)
    ff_output = self.ff(normed)
    hidden_states = hidden_states + ff_output

    return hidden_states


class Embeddings1DConnector(nnx.Module):
  """
  Applies 1D transformer processing with Thinking Tokens (Learnable Registers).
  Uses nnx.scan for efficient JAX-idiomatic layer execution.

  Pipeline in ``__call__``: (1) replace padded positions with tiled learnable
  register vectors and reset the mask to all-ones, (2) compute 1D rotary
  embeddings, (3) run ``layers`` stacked transformer blocks via ``nnx.scan``,
  (4) apply a final un-scaled RMSNorm.
  """

  def __init__(
      self,
      input_dim: int,
      heads: int = 30,
      head_dim: int = 128,
      layers: int = 2,
      theta: float = 10000.0,
      num_learnable_registers: int = 128,
      rope_type: str = "interleaved",
      attention_kernel: str = "flash",
      mesh: Optional[jax.sharding.Mesh] = None,
      rngs: Optional[nnx.Rngs] = None,
  ):
    self.dim = input_dim
    self.theta = theta
    self.num_learnable_registers = num_learnable_registers
    self.num_layers = layers

    # 1. Initialize Stacked Layers using vmap
    # This creates a single module where parameters have an extra leading dimension [layers, ...]
    # We need to ensure rngs are split for each layer
    @nnx.split_rngs(splits=layers)
    @nnx.vmap(in_axes=0, out_axes=0, axis_size=layers)
    def create_block(rngs):
      return _BasicTransformerBlock1D(
          dim=input_dim,
          heads=heads,
          dim_head=head_dim,
          rope_type=rope_type,
          attention_kernel=attention_kernel,
          mesh=mesh,
          rngs=rngs,
      )

    # Call the vmapped constructor
    self.stacked_blocks = create_block(rngs)

    # 2. Thinking Tokens: learnable register vectors initialized uniformly in [-1, 1).
    # NOTE(review): when num_learnable_registers == 0 this attribute is never created;
    # __call__ guards on the same condition, so it is only safe to keep that invariant.
    if num_learnable_registers > 0:
      key = rngs.params()
      self.learnable_registers = nnx.Param(
          jax.random.uniform(key, (num_learnable_registers, self.dim), dtype=jnp.bfloat16) * 2.0 - 1.0
      )

    self.final_norm = nnx.RMSNorm(
        self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=False, rngs=rngs
    )

  def _replace_padded_with_learnable_registers(self, hidden_states: Array, attention_mask: Array) -> Tuple[Array, Array]:
    """Overwrite padded token positions with tiled learnable registers.

    Args:
      hidden_states: [batch, seq, dim] token embeddings.
      attention_mask: [batch, seq] (or [batch, seq, 1]) mask; positions with
        value > 0.5 are treated as valid and kept.

    Returns:
      Tuple of (tokens with padding replaced by register values,
      an all-ones mask of the same shape as ``attention_mask``).

    Raises:
      ValueError: if ``seq`` is not a multiple of ``num_learnable_registers``,
        since the registers are tiled to exactly cover the sequence.
    """
    b, t, d = hidden_states.shape
    if t % self.num_learnable_registers != 0:
      raise ValueError(f"Sequence length {t} must be divisible by {self.num_learnable_registers}")

    # Tile the [R, D] register bank along the sequence axis to [T, D], then add batch dim.
    num_duplications = t // self.num_learnable_registers
    registers = jnp.tile(self.learnable_registers[...], (num_duplications, 1))
    registers = jnp.expand_dims(registers, 0)

    # Broadcast the mask over the feature dimension if it is [B, T].
    if attention_mask.ndim == 2:
      mask = attention_mask[:, :, None]
    else:
      mask = attention_mask

    # Keep valid tokens; substitute the register value at padded positions.
    output = jnp.where(mask > 0.5, hidden_states, registers)
    # After substitution every position carries content, so attention sees a full mask.
    new_mask = jnp.ones_like(attention_mask)
    return output, new_mask

  def _compute_1d_rope(self, seq_len: int, dtype: DType) -> Tuple[Array, Array]:
    """Compute (cos, sin) rotary embeddings for positions [0, seq_len).

    Frequencies follow the standard RoPE schedule theta^(-2i/dim); each
    frequency is repeated twice along the last axis (pairwise layout).
    Returns arrays of shape [1, seq_len, dim].

    NOTE(review): the ``dtype`` parameter is currently unused — cos/sin are
    returned in float32; confirm whether a cast to ``dtype`` was intended.
    """
    t = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = 1.0 / (self.theta ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
    emb = jnp.outer(t, freqs)
    cos = jnp.cos(emb)
    sin = jnp.sin(emb)
    # Duplicate each frequency so the last axis matches ``dim``.
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    return cos[None, ...], sin[None, ...]

  def __call__(
      self,
      hidden_states: Array,
      attention_mask: Optional[Array] = None,
  ) -> Array:
    """Process [batch, seq, dim] embeddings; returns the same shape.

    If ``attention_mask`` is provided (and registers are enabled), padded
    positions are replaced by learnable registers before the transformer stack.
    """
    # 1. Thinking Tokens
    if self.num_learnable_registers > 0 and attention_mask is not None:
      hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

    # 2. RoPE
    seq_len = hidden_states.shape[1]
    rotary_emb = self._compute_1d_rope(seq_len, hidden_states.dtype)

    # 3. Transformer Blocks (Scan)

    # Scan function signature: (carry, x) -> (carry, y)
    def block_scan_fn(carry, block_module):
      hidden_states = carry
      # block_module is a sliced view of the vmapped module
      hidden_states = block_module(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
      return hidden_states, None

    # Execute scan
    hidden_states, _ = nnx.scan(
        block_scan_fn,
        length=self.num_layers,
        in_axes=(nnx.Carry, 0),  # Scan over the layers dimension (0) of block_module
        out_axes=(nnx.Carry, 0),
    )(hidden_states, self.stacked_blocks)

    # 4. Final Norm
    hidden_states = self.final_norm(hidden_states)

    return hidden_states
114 changes: 114 additions & 0 deletions src/maxdiffusion/tests/test_embeddings_connector_ltx2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest
import jax.numpy as jnp
import numpy as np
from flax import nnx
from ..models.ltx2.text_encoders.embeddings_connector_ltx2 import Embeddings1DConnector


class Embeddings1DConnectorTest(unittest.TestCase):
  """Unit tests for Embeddings1DConnector (register substitution + forward pass)."""

  def setUp(self):
    self.rng = nnx.Rngs(0)
    self.B = 2
    self.T = 16  # Must be divisible by num_learnable_registers if we want tiling to work simply
    self.D = 64  # inner_dim

    # Test config
    self.num_learnable_registers = 8
    self.heads = 4
    self.head_dim = 16

    # input dim = heads * head_dim = 64

  def _make_connector(self, num_layers, kernel=None):
    """Build a connector with the shared test configuration."""
    kwargs = dict(
        input_dim=self.D,
        heads=self.heads,
        head_dim=self.head_dim,
        layers=num_layers,
        num_learnable_registers=self.num_learnable_registers,
        mesh=None,
        rngs=self.rng,
    )
    if kernel is not None:
      kwargs["attention_kernel"] = kernel
    return Embeddings1DConnector(**kwargs)

  def test_thinking_tokens_replacement(self):
    connector = self._make_connector(num_layers=1)

    # All-zero input [B, T, D] makes valid positions trivially recognizable.
    tokens = jnp.zeros((self.B, self.T, self.D))

    # Mask [B, T]: batch 0 has 4 valid tokens, batch 1 has 8; the rest is padding.
    row0 = [1] * 4 + [0] * (self.T - 4)
    row1 = [1] * 8 + [0] * (self.T - 8)
    mask = jnp.array(np.asarray([row0, row1], dtype=np.int32))

    # Explicitly run replacement method
    output, new_mask = connector._replace_padded_with_learnable_registers(tokens, mask)

    # 1. Mask must be reset to all-valid.
    self.assertTrue(jnp.all(new_mask == 1.0), "New mask should be all 1s")

    # 2. Valid positions (batch 0, tokens 0-3) keep their original (zero) values.
    self.assertTrue(jnp.all(output[0, :4, :] == 0.0), "Valid tokens should remain unchanged")

    # 3. Padded positions (batch 0, tokens 4-15) must hold register values.
    # Registers [8, 64] are tiled twice along the sequence to cover T=16;
    # padded slots should match the corresponding slice of that tiling.
    bank = connector.learnable_registers[...]  # [8, 64]
    tiled = jnp.concatenate([bank, bank], axis=0)  # [16, 64]

    np.testing.assert_allclose(
        output[0, 4:, :], tiled[4:, :], err_msg="Padding should be replaced by corresponding register values"
    )
    print("\n[PASS] Thinking Tokens Replacement Logic Verified.")

  def test_forward_shape_and_run(self):
    # Use dot_product for testing on CPU
    connector = self._make_connector(num_layers=2, kernel="dot_product")

    inputs = jnp.array(np.random.randn(self.B, self.T, self.D))
    full_mask = jnp.ones((self.B, self.T))  # All valid

    result = connector(inputs, full_mask)

    self.assertEqual(result.shape, (self.B, self.T, self.D))
    self.assertFalse(jnp.isnan(result).any(), "Output should not contain NaNs")
    print("\n[PASS] Embeddings1DConnector Forward Pass Verified.")

# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
  unittest.main()
Loading