Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@ After installation completes, run the training script.
- For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance.
- Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis.
- For benchmarking training performance across inputs with different data dimensions, without downloading or re-processing the dataset, a synthetic data iterator is supported.
  - Set dataset_type='synthetic' to enable the synthetic data iterator; optionally set synthetic_num_samples=null to generate an unbounded (infinite) stream of samples.
- The following overrides on data dimensions are supported:
- synthetic_override_height: 720
- synthetic_override_width: 1280
- synthetic_override_num_frames: 85
- synthetic_override_max_sequence_length: 512
- synthetic_override_text_embed_dim: 4096
- synthetic_override_num_channels_latents: 16
- synthetic_override_vae_scale_factor_spatial: 8
- synthetic_override_vae_scale_factor_temporal: 4

You should eventually see a training run as:

Expand Down
15 changes: 14 additions & 1 deletion src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,20 @@ allow_split_physical_axes: False
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
# ==============================================================================
# Synthetic Data Configuration (only used when dataset_type='synthetic')
# ==============================================================================
# To use synthetic data for testing/debugging without real datasets:
# 1. Set dataset_type: 'synthetic' above
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override data dimensions (e.g. the resolution override shown below)
#
# synthetic_num_samples: null # null for infinite, or set a number
#
# Optional dimension overrides:
# resolution: 512
# ==============================================================================
cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
# only apply to small dataset that fits in memory
Expand Down
23 changes: 22 additions & 1 deletion src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,28 @@ allow_split_physical_axes: False
# Replace with dataset path or train_data_dir. One has to be set.
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tfrecord'
dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic'
# ==============================================================================
# Synthetic Data Configuration (only used when dataset_type='synthetic')
# ==============================================================================
# To use synthetic data for testing/debugging without real datasets:
# 1. Set dataset_type: 'synthetic' above
# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000)
# 3. Optionally override dimensions with synthetic_override_* flags below
#
# synthetic_num_samples: null # null for infinite, or set a number
#
# Optional dimension overrides (comment out to use pipeline/config values):
# synthetic_override_height: 720
# synthetic_override_width: 1280
# synthetic_override_num_frames: 121
# synthetic_override_max_sequence_length: 512
# synthetic_override_text_embed_dim: 4096
# synthetic_override_num_channels_latents: 16
# synthetic_override_vae_scale_factor_spatial: 8
# synthetic_override_vae_scale_factor_temporal: 4
# ==============================================================================

cache_latents_text_encoder_outputs: True
# cache_latents_text_encoder_outputs only apply to dataset_type="tf",
# only apply to small dataset that fits in memory
Expand Down
14 changes: 12 additions & 2 deletions src/maxdiffusion/input_pipeline/input_pipeline_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from maxdiffusion.input_pipeline import _hf_data_processing
from maxdiffusion.input_pipeline import _grain_data_processing
from maxdiffusion.input_pipeline import _tfds_data_processing
from maxdiffusion.input_pipeline import synthetic_data_iterator
from maxdiffusion import multihost_dataloading
from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply
from maxdiffusion.dreambooth.dreambooth_constants import (
Expand Down Expand Up @@ -54,8 +55,9 @@ def make_data_iterator(
feature_description=None,
prepare_sample_fn=None,
is_training=True,
pipeline=None,
):
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)"""

if config.dataset_type == "hf" or config.dataset_type == "tf":
if tokenize_fn is None or image_transforms_fn is None:
Expand Down Expand Up @@ -110,8 +112,16 @@ def make_data_iterator(
prepare_sample_fn,
is_training,
)
elif config.dataset_type == "synthetic":
return synthetic_data_iterator.make_synthetic_iterator(
config=config,
mesh=mesh,
global_batch_size=global_batch_size,
pipeline=pipeline,
is_training=is_training,
)
else:
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)"


def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params):
Expand Down
Loading