Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions dependencies/requirements/base_requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
--extra-index-url https://download.pytorch.org/whl/cpu
absl-py
aqtp
datasets
einops
flax
ftfy
google-cloud-storage
grain
hf_transfer
huggingface_hub
imageio-ffmpeg
imageio
jax
jaxlib
Jinja2
opencv-python-headless
optax
orbax-checkpoint
parameterized
Pillow
pyink
pylint
pytest
ruff
scikit-image
sentencepiece
tensorboard-plugin-profile
tensorboard
tensorboardx
tensorflow-datasets
tensorflow
tokamax
tokenizers
transformers

# pinning torch and torchvision to specific versions to avoid
# installing GPU versions from PyPI when running seed-env
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip

191 changes: 191 additions & 0 deletions dependencies/requirements/generated_requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Generated by seed-env. Do not edit manually.
# If you need to modify dependencies, please do so in the host requirements file and run seed-env again.

absl-py>=2.3.1
aiofiles>=25.1.0
aiohappyeyeballs>=2.6.1
aiohttp>=3.13.3
aiosignal>=1.4.0
aqtp>=0.9.0
array-record>=0.8.3 ; sys_platform != 'win32'
astroid>=4.0.4
astunparse>=1.6.3
attrs>=25.4.0
auditwheel>=6.6.0
black>=25.12.0
build>=1.4.0
certifi>=2026.1.4
cffi>=2.0.0 ; platform_python_implementation != 'PyPy'
charset-normalizer>=3.4.4
cheroot>=11.1.2
chex>=0.1.91
click>=8.3.1
cloudpickle>=3.1.2
colorama>=0.4.6
contourpy>=1.3.3
cryptography>=46.0.5
cycler>=0.12.1
dataclasses-json>=0.6.7
datasets>=2.14.4
decorator>=5.2.1
dill>=0.3.7
dm-tree>=0.1.9
docstring-parser>=0.17.0
einops>=0.8.2
etils>=1.13.0
execnet>=2.1.2
filelock>=3.20.3
flatbuffers>=25.12.19
flax>=0.12.4
fonttools>=4.61.1
frozenlist>=1.8.0
fsspec>=2026.1.0
ftfy>=6.3.1
gast>=0.7.0
gcsfs>=2026.1.0
google-api-core>=2.29.0
google-auth-oauthlib>=1.2.4
google-auth>=2.48.0
google-cloud-core>=2.5.0
google-cloud-storage-control>=1.10.0
google-cloud-storage>=3.9.0
google-crc32c>=1.8.0
google-pasta>=0.2.0
google-resumable-media>=2.8.0
googleapis-common-protos>=1.72.0
grain>=0.2.15
grpc-google-iam-v1>=0.14.3
grpcio-status>=1.76.0
grpcio>=1.76.0
gviz-api>=1.10.0
h5py>=3.15.1
hf-transfer>=0.1.9
hf-xet>=1.2.1 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
huggingface-hub>=0.36.2
humanize>=4.15.0
hypothesis>=6.142.1
idna>=3.11
imageio-ffmpeg>=0.6.0
imageio>=2.37.2
immutabledict>=4.3.0
importlib-resources>=6.5.2
iniconfig>=2.3.0
isort>=7.0.0
jaraco-functools>=4.4.0
jax>=0.9.0
jaxlib>=0.9.0
jaxtyping>=0.3.7
jinja2>=3.1.6
keras>=3.13.1
kiwisolver>=1.4.9
lazy-loader>=0.4
libclang>=18.1.1
libtpu>=0.0.34 ; platform_machine == 'x86_64' and sys_platform == 'linux'
markdown-it-py>=4.0.0
markdown>=3.10.1
markupsafe>=3.0.3
marshmallow>=3.26.2
matplotlib>=3.10.8
mccabe>=0.7.0
mdurl>=0.1.2
ml-dtypes>=0.5.4
more-itertools>=10.8.0
mpmath>=1.3.0
msgpack>=1.1.2
multidict>=6.7.1
multiprocess>=0.70.15
mypy-extensions>=1.1.0
namex>=0.1.0
nest-asyncio>=1.6.0
networkx>=3.6.1
numpy-typing-compat>=20251206.2.0
numpy>=2.0.2
nvidia-cuda-cccl>=13.1.115
oauthlib>=3.3.1
opencv-python-headless>=4.13.0.92
opt-einsum>=3.4.0
optax>=0.2.6
optree>=0.18.0
optype>=0.15.0
orbax-checkpoint>=0.11.32
orbax-export>=0.0.8
packaging>=26.0
pandas>=3.0.0
parameterized>=0.9.0
pathspec>=1.0.4
pillow>=12.1.0
platformdirs>=4.7.1
pluggy>=1.6.0
portpicker>=1.6.0
promise>=2.3
propcache>=0.4.1
proto-plus>=1.27.1
protobuf>=6.33.5
psutil>=7.2.1
pyarrow>=23.0.0
pyasn1-modules>=0.4.2
pyasn1>=0.6.2
pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'
pyelftools>=0.32
pygments>=2.19.2
pyink>=25.12.0
pylint>=4.0.4
pyparsing>=3.3.2
pyproject-hooks>=1.2.0
pytest-xdist>=3.8.0
pytest>=8.4.2
python-dateutil>=2.9.0.post0
pytokens>=0.4.1
pyyaml>=6.0.3
qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
regex>=2026.1.15
requests-oauthlib>=2.0.0
requests>=2.32.5
rich>=14.2.0
rsa>=4.9.1
ruff>=0.15.1
safetensors>=0.7.0
scikit-image>=0.26.0
scipy-stubs>=1.17.0.1
scipy>=1.17.0
sentencepiece>=0.2.1
setuptools>=80.10.1
simple-parsing>=0.1.8
simplejson>=3.20.2
six>=1.17.0
sortedcontainers>=2.4.0
sympy>=1.14.0
tensorboard-data-server>=0.7.2
tensorboard-plugin-profile>=2.21.6
tensorboard>=2.20.0
tensorboardx>=2.6.4
tensorflow-datasets>=4.9.9
tensorflow-metadata>=1.17.3
tensorflow>=2.20.0
tensorstore>=0.1.80
termcolor>=3.3.0
tifffile>=2026.1.28
tokamax>=0.1.0
tokenizers>=0.22.2
toml>=0.10.2
tomlkit>=0.14.0
toolz>=1.1.0
torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
tqdm>=4.67.3
transformers>=4.57.6
treescope>=0.1.10
typing-extensions>=4.15.0
typing-inspect>=0.9.0
tzdata>=2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
urllib3>=2.6.3
wadler-lindig>=0.1.7
wcwidth>=0.6.0
werkzeug>=3.1.5
wheel>=0.46.2
wrapt>=2.1.1
xprof>=2.21.6
xxhash>=3.6.0
yarl>=1.22.0
zipp>=3.23.0
zstandard>=0.25.0
52 changes: 52 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

[build-system]
requires = ["hatchling", "hatch-requirements-txt"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "src/maxdiffusion/__init__.py"

[project]
name = "maxdiffusion"
dynamic = ["version", "optional-dependencies"]
requires-python = ">=3.12"
readme = "README.md"
license = "Apache-2.0"
classifiers = [
"Programming Language :: Python",
]
dependencies = []

[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies]
tpu = ["dependencies/requirements/generated_requirements/tpu-requirements.txt"]
cuda12 = ["dependencies/requirements/generated_requirements/cuda12-requirements.txt"]

[project.urls]
Repository = "https://github.com/AI-Hypercomputer/maxdiffusion.git"
"Bug Tracker" = "https://github.com/AI-Hypercomputer/maxdiffusion/issues"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/maxdiffusion", "src/install_maxdiffusion_extra_deps"]

[tool.hatch.build.targets.wheel.hooks.custom]
path = "build_hooks.py"

[project.scripts]
install_maxdiffusion_github_deps = "install_maxdiffusion_extra_deps.install_github_deps:main"

[tool.ruff]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823", "E402", "I001"]
Expand Down
41 changes: 22 additions & 19 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
if [[ $REPLY =~ ^[Yy]$ ]]; then
# Check if uv is installed first; if not, install uv
if ! command -v uv &> /dev/null; then
echo -e "\n'uv' command not found. Installing it now via the official installer..."
curl -LsSf https://astral.sh/uv/install.sh | sh
# echo -e "\n'uv' command not found. Installing it now via the official installer..."
# curl -LsSf https://astral.sh/uv/install.sh | sh

echo -e "\n\e[33m'uv' has been installed.\e[0m"
echo "The installer likely printed instructions to update your shell's PATH."
echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
exit 1
# echo -e "\n\e[33m'uv' has been installed.\e[0m"
# echo "The installer likely printed instructions to update your shell's PATH."
# echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script."
# exit 1
pip install uv
fi
maxdiffusion_dir=$(pwd)
cd
Expand All @@ -53,7 +54,7 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th
echo "No name provided. Using default name: '$venv_name'"
fi
echo "Creating virtual environment '$venv_name' with Python 3.12..."
uv venv --python 3.12 "$venv_name" --seed
python3 -m uv venv --python 3.12 "$venv_name" --seed
printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created
echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m"
echo "To activate it, run the following command:"
Expand Down Expand Up @@ -81,6 +82,8 @@ apt update -y && apt -y install gcsfuse
rm -rf /var/lib/apt/lists/*
EOF

python3 -m pip install -U setuptools wheel uv

# Set environment variables from command line arguments
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
Expand All @@ -104,7 +107,7 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then
fi

# Install dependencies from requirements.txt first
pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2
python3 -m uv pip install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2

# Install JAX and JAXlib based on the specified mode
if [[ "$MODE" == "stable" || ! -v MODE ]]; then
Expand All @@ -113,23 +116,23 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then
echo "Installing stable jax, jaxlib for tpu"
if [[ -n "$JAX_VERSION" ]]; then
echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}"
pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m uv pip install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
else
echo "Installing stable jax, jaxlib, libtpu
for tpu"
pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m uv pip install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
fi
elif [[ $DEVICE == "gpu" ]]; then
echo "Installing stable jax, jaxlib for NVIDIA gpu"
if [[ -n "$JAX_VERSION" ]]; then
echo "Installing stable jax, jaxlib ${JAX_VERSION}"
pip3 install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 -m uv pip install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
else
echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu"
pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 -m uv pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
fi
export NVTE_FRAMEWORK=jax
pip3 install transformer_engine[jax]==2.1.0
python3 -m uv pip install transformer_engine[jax]==2.1.0
fi

elif [[ $MODE == "nightly" ]]; then
Expand All @@ -140,22 +143,22 @@ elif [[ $MODE == "nightly" ]]; then
pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# Install Transformer Engine
export NVTE_FRAMEWORK=jax
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
elif [[ $DEVICE == "tpu" ]]; then
echo "Installing jax-nightly,jaxlib-nightly"
# Install jax-nightly
pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
python3 -m uv pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# Install jaxlib-nightly
pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
python3 -m uv pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Install libtpu-nightly
pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python3 -m uv pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
fi
echo "Installing nightly tensorboard plugin profile"
pip3 install tbp-nightly --upgrade
python3 -m uv pip install tbp-nightly --upgrade
else
echo -e "\n\nError: You can only set MODE to [stable,nightly].\n\n"
exit 1
fi

# Install maxdiffusion
pip3 install -U . || echo "Failed to install maxdiffusion" >&2
python3 -m uv pip install -U . || echo "Failed to install maxdiffusion" >&2
Loading
Loading