Skip to content

[Text Pipeline] Implement Embedding Connector#338

Open
syhuang22 wants to merge 2 commits intoAI-Hypercomputer:mainfrom
syhuang22:feat/ltx2-embedding-connector
Open

[Text Pipeline] Implement Embedding Connector#338
syhuang22 wants to merge 2 commits intoAI-Hypercomputer:mainfrom
syhuang22:feat/ltx2-embedding-connector

Conversation

@syhuang22
Copy link
Collaborator

This module acts as the crucial bridge in the LTX-2 text pipeline, responsible for processing and aligning the text embeddings (after feature extraction) before they are fed into the main diffusion model.

@syhuang22 syhuang22 requested a review from entrpn as a code owner February 26, 2026 00:29
@google-cla
Copy link

google-cla bot commented Feb 26, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@syhuang22 syhuang22 force-pushed the feat/ltx2-embedding-connector branch from e656273 to 099551b Compare February 26, 2026 00:48
Signed-off-by: James Huang <syhuang1201@gmail.com>
@syhuang22 syhuang22 force-pushed the feat/ltx2-embedding-connector branch from 099551b to 7747a33 Compare February 26, 2026 01:08
Signed-off-by: James Huang <syhuang1201@gmail.com>
rngs=rngs,
)
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
Copy link
Collaborator

@prishajain1 prishajain1 Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diffusers uses elementwise_affine = False
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/connectors.py#L112
We should set use_scale = False

)
self.ff = NNXSimpleFeedForward(rngs=rngs, dim=dim, dim_out=dim)
self.norm1 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
self.norm2 = nnx.RMSNorm(dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_scale = False

)

self.final_norm = nnx.RMSNorm(
self.dim, epsilon=1e-6, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
Copy link
Collaborator

@prishajain1 prishajain1 Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants