diff --git a/Disco_Diffusion.ipynb b/Disco_Diffusion.ipynb
index 5221aae6..c6809fa8 100644
--- a/Disco_Diffusion.ipynb
+++ b/Disco_Diffusion.ipynb
@@ -3,8 +3,8 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "view-in-github",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "view-in-github"
},
"source": [
"
"
@@ -377,15 +377,21 @@
" # !git clone https://github.com/facebookresearch/SLIP.git\n",
" !git clone https://github.com/crowsonkb/guided-diffusion\n",
" !git clone https://github.com/assafshocher/ResizeRight.git\n",
- " !pip install -e ./CLIP\n",
+ " #!pip install -e ./CLIP\n",
+ " !pip install ./CLIP\n",
" !pip install -e ./guided-diffusion\n",
" !pip install lpips datetime timm\n",
" !apt install imagemagick\n",
" !git clone https://github.com/isl-org/MiDaS.git\n",
- " !git clone https://github.com/alembics/disco-diffusion.git\n",
+ " #!git clone https://github.com/alembics/disco-diffusion.git\n",
+ "\n",
+ " !git clone --branch pytti-disco https://github.com/pytti-tools/disco-diffusion.git\n",
+ " !pip install ./disco-diffusion\n",
+ " \n",
" # Rename a file to avoid a name conflict..\n",
- " !mv MiDaS/utils.py MiDaS/midas_utils.py\n",
- " !cp disco-diffusion/disco_xform_utils.py disco_xform_utils.py\n",
+ " !mv MiDaS/utils.py MiDaS/midas_utils.py # oof\n",
+ " !cp disco-diffusion/disco_xform_utils.py disco_xform_utils.py # uhh\n",
+ " !pip install loguru\n",
"\n",
"!mkdir model\n",
"if not path_exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
@@ -393,6 +399,7 @@
"\n",
"import sys\n",
"import torch\n",
+ "from loguru import logger\n",
"\n",
"#Install pytorch3d\n",
"if is_colab:\n",
@@ -428,7 +435,8 @@
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"from tqdm.notebook import tqdm\n",
- "sys.path.append('./CLIP')\n",
+ "#sys.path.append('./CLIP')\n",
+ "#!pip install ./CLIP\n",
"sys.path.append('./guided-diffusion')\n",
"import clip\n",
"from resize_right import resize\n",
@@ -444,17 +452,21 @@
"#SuperRes\n",
"if is_colab:\n",
" !git clone https://github.com/CompVis/latent-diffusion.git\n",
- " !git clone https://github.com/CompVis/taming-transformers\n",
- " !pip install -e ./taming-transformers\n",
+ " #!git clone https://github.com/CompVis/taming-transformers\n",
+ " !git clone https://github.com/pytti-tools/taming-transformers\n",
+ " #!pip install -e ./taming-transformers\n",
+ " !pip install ./taming-transformers\n",
" !pip install ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\n",
"\n",
"#SuperRes\n",
"import ipywidgets as widgets\n",
"import os\n",
- "sys.path.append(\".\")\n",
- "sys.path.append('./taming-transformers')\n",
+ "sys.path.append(\".\") \n",
+ "#sys.path.append('./taming-transformers')\n",
"from taming.models import vqgan # checking correct import from taming\n",
"from torchvision.datasets.utils import download_url\n",
+ "\n",
+ "# oy vey\n",
"if is_colab:\n",
" %cd '/content/latent-diffusion'\n",
"else:\n",
@@ -469,6 +481,8 @@
" from google.colab import files\n",
"else:\n",
" %cd $PROJECT_DIR\n",
+ "\n",
+ "\n",
"from IPython.display import Image as ipyimg\n",
"from numpy import asarray\n",
"from einops import rearrange, repeat\n",
@@ -481,31 +495,33 @@
"# AdaBins stuff\n",
"if USE_ADABINS:\n",
" if is_colab:\n",
- " !git clone https://github.com/shariqfarooq123/AdaBins.git\n",
- " if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n",
- " !wget https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt -P {model_path}\n",
- " !mkdir pretrained\n",
- " !cp -P {model_path}/AdaBins_nyu.pt pretrained/AdaBins_nyu.pt\n",
- " sys.path.append('./AdaBins')\n",
- " from infer import InferenceHelper\n",
+ " #!git clone https://github.com/shariqfarooq123/AdaBins.git\n",
+ " !git clone https://github.com/pytti-tools/AdaBins\n",
+ " !pip install ./AdaBins\n",
+ " #if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n",
+ " # !wget https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt -P {model_path}\n",
+ " #!mkdir pretrained\n",
+ " #!cp -P {model_path}/AdaBins_nyu.pt pretrained/AdaBins_nyu.pt\n",
+ " #sys.path.append('./AdaBins')\n",
+ " #from infer import InferenceHelper\n",
+ " from adabins.infer import InferenceHelper\n",
" MAX_ADABINS_AREA = 500000\n",
"\n",
"import torch\n",
"DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
- "print('Using device:', DEVICE)\n",
+ "logger.debug('Using device:', DEVICE)\n",
"device = DEVICE # At least one of the modules expects this name..\n",
"\n",
- "if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n",
- " print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
- " torch.backends.cudnn.enabled = False"
+ "from disco.common import a100_cudnn_fix\n",
+ "a100_cudnn_fix(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "BLk3J0h3MtON",
- "cellView": "form"
+ "cellView": "form",
+ "id": "BLk3J0h3MtON"
},
"outputs": [],
"source": [
@@ -624,291 +640,30 @@
"source": [
"#@title 1.5 Define necessary functions\n",
"\n",
- "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
- "\n",
- "import pytorch3d.transforms as p3dT\n",
- "import disco_xform_utils as dxf\n",
- "\n",
- "def interp(t):\n",
- " return 3 * t**2 - 2 * t ** 3\n",
- "\n",
- "def perlin(width, height, scale=10, device=None):\n",
- " gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
- " xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
- " ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
- " wx = 1 - interp(xs)\n",
- " wy = 1 - interp(ys)\n",
- " dots = 0\n",
- " dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
- " dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
- " dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
- " dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
- " return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
- "\n",
- "def perlin_ms(octaves, width, height, grayscale, device=device):\n",
- " out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n",
- " # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n",
- " for i in range(1 if grayscale else 3):\n",
- " scale = 2 ** len(octaves)\n",
- " oct_width = width\n",
- " oct_height = height\n",
- " for oct in octaves:\n",
- " p = perlin(oct_width, oct_height, scale, device)\n",
- " out_array[i] += p * oct\n",
- " scale //= 2\n",
- " oct_width *= 2\n",
- " oct_height *= 2\n",
- " return torch.cat(out_array)\n",
- "\n",
- "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
- " out = perlin_ms(octaves, width, height, grayscale)\n",
- " if grayscale:\n",
- " out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n",
- " out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n",
- " else:\n",
- " out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n",
- " out = TF.resize(size=(side_y, side_x), img=out)\n",
- " out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n",
- "\n",
- " out = ImageOps.autocontrast(out)\n",
- " return out\n",
- "\n",
- "def regen_perlin():\n",
- " if perlin_mode == 'color':\n",
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
- " elif perlin_mode == 'gray':\n",
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
- " else:\n",
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
- "\n",
- " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
- " del init2\n",
- " return init.expand(batch_size, -1, -1, -1)\n",
- "\n",
- "def fetch(url_or_path):\n",
- " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
- " r = requests.get(url_or_path)\n",
- " r.raise_for_status()\n",
- " fd = io.BytesIO()\n",
- " fd.write(r.content)\n",
- " fd.seek(0)\n",
- " return fd\n",
- " return open(url_or_path, 'rb')\n",
- "\n",
- "def read_image_workaround(path):\n",
- " \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n",
- " this incompatibility to avoid colour inversions.\"\"\"\n",
- " im_tmp = cv2.imread(path)\n",
- " return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n",
- "\n",
- "def parse_prompt(prompt):\n",
- " if prompt.startswith('http://') or prompt.startswith('https://'):\n",
- " vals = prompt.rsplit(':', 2)\n",
- " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
- " else:\n",
- " vals = prompt.rsplit(':', 1)\n",
- " vals = vals + ['', '1'][len(vals):]\n",
- " return vals[0], float(vals[1])\n",
- "\n",
- "def sinc(x):\n",
- " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
- "\n",
- "def lanczos(x, a):\n",
- " cond = torch.logical_and(-a < x, x < a)\n",
- " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
- " return out / out.sum()\n",
- "\n",
- "def ramp(ratio, width):\n",
- " n = math.ceil(width / ratio + 1)\n",
- " out = torch.empty([n])\n",
- " cur = 0\n",
- " for i in range(out.shape[0]):\n",
- " out[i] = cur\n",
- " cur += ratio\n",
- " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
- "\n",
- "def resample(input, size, align_corners=True):\n",
- " n, c, h, w = input.shape\n",
- " dh, dw = size\n",
- "\n",
- " input = input.reshape([n * c, 1, h, w])\n",
- "\n",
- " if dh < h:\n",
- " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n",
- " pad_h = (kernel_h.shape[0] - 1) // 2\n",
- " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n",
- " input = F.conv2d(input, kernel_h[None, None, :, None])\n",
- "\n",
- " if dw < w:\n",
- " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n",
- " pad_w = (kernel_w.shape[0] - 1) // 2\n",
- " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n",
- " input = F.conv2d(input, kernel_w[None, None, None, :])\n",
- "\n",
- " input = input.reshape([n, c, h, w])\n",
- " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n",
- "\n",
- "class MakeCutouts(nn.Module):\n",
- " def __init__(self, cut_size, cutn, skip_augs=False):\n",
- " super().__init__()\n",
- " self.cut_size = cut_size\n",
- " self.cutn = cutn\n",
- " self.skip_augs = skip_augs\n",
- " self.augs = T.Compose([\n",
- " T.RandomHorizontalFlip(p=0.5),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomGrayscale(p=0.15),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
- " ])\n",
- "\n",
- " def forward(self, input):\n",
- " input = T.Pad(input.shape[2]//4, fill=0)(input)\n",
- " sideY, sideX = input.shape[2:4]\n",
- " max_size = min(sideX, sideY)\n",
- "\n",
- " cutouts = []\n",
- " for ch in range(self.cutn):\n",
- " if ch > self.cutn - self.cutn//4:\n",
- " cutout = input.clone()\n",
- " else:\n",
- " size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n",
- " offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
- " offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
- " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
- "\n",
- " if not self.skip_augs:\n",
- " cutout = self.augs(cutout)\n",
- " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n",
- " del cutout\n",
- "\n",
- " cutouts = torch.cat(cutouts, dim=0)\n",
- " return cutouts\n",
+ "from disco.common import (\n",
+ " interp,\n",
+ " perlin,\n",
+ " perlin_ms,\n",
+ " create_perlin_noise,\n",
+ " regen_perlin,\n",
+ " fetch,\n",
+ " read_image_workaround,\n",
+ " parse_prompt,\n",
+ " sinc,\n",
+ " lanczos,\n",
+ " ramp,\n",
+ " resample,\n",
+ " MakeCutouts,\n",
+ " MakeCutoutsDango,\n",
+ " spherical_dist_loss,\n",
+ " tv_loss,\n",
+ " range_loss,\n",
+ ")\n",
+ "\n",
+ "from disco.models.modules import alpha_sigma_to_t\n",
"\n",
"cutout_debug = False\n",
"padargs = {}\n",
- "\n",
- "class MakeCutoutsDango(nn.Module):\n",
- " def __init__(self, cut_size,\n",
- " Overview=4, \n",
- " InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n",
- " ):\n",
- " super().__init__()\n",
- " self.cut_size = cut_size\n",
- " self.Overview = Overview\n",
- " self.InnerCrop = InnerCrop\n",
- " self.IC_Size_Pow = IC_Size_Pow\n",
- " self.IC_Grey_P = IC_Grey_P\n",
- " if args.animation_mode == 'None':\n",
- " self.augs = T.Compose([\n",
- " T.RandomHorizontalFlip(p=0.5),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomGrayscale(p=0.1),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
- " ])\n",
- " elif args.animation_mode == 'Video Input':\n",
- " self.augs = T.Compose([\n",
- " T.RandomHorizontalFlip(p=0.5),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomGrayscale(p=0.15),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
- " ])\n",
- " elif args.animation_mode == '2D' or args.animation_mode == '3D':\n",
- " self.augs = T.Compose([\n",
- " T.RandomHorizontalFlip(p=0.4),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.RandomGrayscale(p=0.1),\n",
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
- " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n",
- " ])\n",
- " \n",
- "\n",
- " def forward(self, input):\n",
- " cutouts = []\n",
- " gray = T.Grayscale(3)\n",
- " sideY, sideX = input.shape[2:4]\n",
- " max_size = min(sideX, sideY)\n",
- " min_size = min(sideX, sideY, self.cut_size)\n",
- " l_size = max(sideX, sideY)\n",
- " output_shape = [1,3,self.cut_size,self.cut_size] \n",
- " output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n",
- " pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n",
- " cutout = resize(pad_input, out_shape=output_shape)\n",
- "\n",
- " if self.Overview>0:\n",
- " if self.Overview<=4:\n",
- " if self.Overview>=1:\n",
- " cutouts.append(cutout)\n",
- " if self.Overview>=2:\n",
- " cutouts.append(gray(cutout))\n",
- " if self.Overview>=3:\n",
- " cutouts.append(TF.hflip(cutout))\n",
- " if self.Overview==4:\n",
- " cutouts.append(gray(TF.hflip(cutout)))\n",
- " else:\n",
- " cutout = resize(pad_input, out_shape=output_shape)\n",
- " for _ in range(self.Overview):\n",
- " cutouts.append(cutout)\n",
- "\n",
- " if cutout_debug:\n",
- " if is_colab:\n",
- " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n",
- " else:\n",
- " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n",
- "\n",
- " \n",
- " if self.InnerCrop >0:\n",
- " for i in range(self.InnerCrop):\n",
- " size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n",
- " offsetx = torch.randint(0, sideX - size + 1, ())\n",
- " offsety = torch.randint(0, sideY - size + 1, ())\n",
- " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
- " if i <= int(self.IC_Grey_P * self.InnerCrop):\n",
- " cutout = gray(cutout)\n",
- " cutout = resize(cutout, out_shape=output_shape)\n",
- " cutouts.append(cutout)\n",
- " if cutout_debug:\n",
- " if is_colab:\n",
- " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n",
- " else:\n",
- " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n",
- " cutouts = torch.cat(cutouts)\n",
- " if skip_augs is not True: cutouts=self.augs(cutouts)\n",
- " return cutouts\n",
- "\n",
- "def spherical_dist_loss(x, y):\n",
- " x = F.normalize(x, dim=-1)\n",
- " y = F.normalize(y, dim=-1)\n",
- " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n",
- "\n",
- "def tv_loss(input):\n",
- " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
- " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
- " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
- " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
- " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
- "\n",
- "\n",
- "def range_loss(input):\n",
- " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
- "\n",
"stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n",
"\n",
"def do_run():\n",
@@ -1120,6 +875,7 @@
" \n",
" cur_t = None\n",
" \n",
+ " # this prob doesn't need to be a closure...\n",
" def cond_fn(x, t, y=None):\n",
" with torch.enable_grad():\n",
" x_is_NaN = False\n",
@@ -1150,7 +906,14 @@
"\n",
" cuts = MakeCutoutsDango(input_resolution,\n",
" Overview= args.cut_overview[1000-t_int], \n",
- " InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n",
+ " InnerCrop = args.cut_innercut[1000-t_int], \n",
+ " IC_Size_Pow=args.cut_ic_pow, \n",
+ " IC_Grey_P = args.cut_icgray_p[1000-t_int],\n",
+ " cutout_debug = cutout_debug,\n",
+ " padargs = padargs,\n",
+ " animation_mode = animation_mode, # args.animation_mode \n",
+ " debug_outpath = \"./\", # to do: this should be conditional on `is_colab``\n",
+ " skip_augs=skip_augs,\n",
" )\n",
" clip_in = normalize(cuts(x_in.add(1).div(2)))\n",
" image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n",
@@ -1384,167 +1147,7 @@
"source": [
"#@title 1.6 Define the secondary diffusion model\n",
"\n",
- "def append_dims(x, n):\n",
- " return x[(Ellipsis, *(None,) * (n - x.ndim))]\n",
- "\n",
- "\n",
- "def expand_to_planes(x, shape):\n",
- " return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n",
- "\n",
- "\n",
- "def alpha_sigma_to_t(alpha, sigma):\n",
- " return torch.atan2(sigma, alpha) * 2 / math.pi\n",
- "\n",
- "\n",
- "def t_to_alpha_sigma(t):\n",
- " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
- "\n",
- "\n",
- "@dataclass\n",
- "class DiffusionOutput:\n",
- " v: torch.Tensor\n",
- " pred: torch.Tensor\n",
- " eps: torch.Tensor\n",
- "\n",
- "\n",
- "class ConvBlock(nn.Sequential):\n",
- " def __init__(self, c_in, c_out):\n",
- " super().__init__(\n",
- " nn.Conv2d(c_in, c_out, 3, padding=1),\n",
- " nn.ReLU(inplace=True),\n",
- " )\n",
- "\n",
- "\n",
- "class SkipBlock(nn.Module):\n",
- " def __init__(self, main, skip=None):\n",
- " super().__init__()\n",
- " self.main = nn.Sequential(*main)\n",
- " self.skip = skip if skip else nn.Identity()\n",
- "\n",
- " def forward(self, input):\n",
- " return torch.cat([self.main(input), self.skip(input)], dim=1)\n",
- "\n",
- "\n",
- "class FourierFeatures(nn.Module):\n",
- " def __init__(self, in_features, out_features, std=1.):\n",
- " super().__init__()\n",
- " assert out_features % 2 == 0\n",
- " self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n",
- "\n",
- " def forward(self, input):\n",
- " f = 2 * math.pi * input @ self.weight.T\n",
- " return torch.cat([f.cos(), f.sin()], dim=-1)\n",
- "\n",
- "\n",
- "class SecondaryDiffusionImageNet(nn.Module):\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- " c = 64 # The base channel count\n",
- "\n",
- " self.timestep_embed = FourierFeatures(1, 16)\n",
- "\n",
- " self.net = nn.Sequential(\n",
- " ConvBlock(3 + 16, c),\n",
- " ConvBlock(c, c),\n",
- " SkipBlock([\n",
- " nn.AvgPool2d(2),\n",
- " ConvBlock(c, c * 2),\n",
- " ConvBlock(c * 2, c * 2),\n",
- " SkipBlock([\n",
- " nn.AvgPool2d(2),\n",
- " ConvBlock(c * 2, c * 4),\n",
- " ConvBlock(c * 4, c * 4),\n",
- " SkipBlock([\n",
- " nn.AvgPool2d(2),\n",
- " ConvBlock(c * 4, c * 8),\n",
- " ConvBlock(c * 8, c * 4),\n",
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
- " ]),\n",
- " ConvBlock(c * 8, c * 4),\n",
- " ConvBlock(c * 4, c * 2),\n",
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
- " ]),\n",
- " ConvBlock(c * 4, c * 2),\n",
- " ConvBlock(c * 2, c),\n",
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
- " ]),\n",
- " ConvBlock(c * 2, c),\n",
- " nn.Conv2d(c, 3, 3, padding=1),\n",
- " )\n",
- "\n",
- " def forward(self, input, t):\n",
- " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
- " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
- " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
- " pred = input * alphas - v * sigmas\n",
- " eps = input * sigmas + v * alphas\n",
- " return DiffusionOutput(v, pred, eps)\n",
- "\n",
- "\n",
- "class SecondaryDiffusionImageNet2(nn.Module):\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- " c = 64 # The base channel count\n",
- " cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n",
- "\n",
- " self.timestep_embed = FourierFeatures(1, 16)\n",
- " self.down = nn.AvgPool2d(2)\n",
- " self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
- "\n",
- " self.net = nn.Sequential(\n",
- " ConvBlock(3 + 16, cs[0]),\n",
- " ConvBlock(cs[0], cs[0]),\n",
- " SkipBlock([\n",
- " self.down,\n",
- " ConvBlock(cs[0], cs[1]),\n",
- " ConvBlock(cs[1], cs[1]),\n",
- " SkipBlock([\n",
- " self.down,\n",
- " ConvBlock(cs[1], cs[2]),\n",
- " ConvBlock(cs[2], cs[2]),\n",
- " SkipBlock([\n",
- " self.down,\n",
- " ConvBlock(cs[2], cs[3]),\n",
- " ConvBlock(cs[3], cs[3]),\n",
- " SkipBlock([\n",
- " self.down,\n",
- " ConvBlock(cs[3], cs[4]),\n",
- " ConvBlock(cs[4], cs[4]),\n",
- " SkipBlock([\n",
- " self.down,\n",
- " ConvBlock(cs[4], cs[5]),\n",
- " ConvBlock(cs[5], cs[5]),\n",
- " ConvBlock(cs[5], cs[5]),\n",
- " ConvBlock(cs[5], cs[4]),\n",
- " self.up,\n",
- " ]),\n",
- " ConvBlock(cs[4] * 2, cs[4]),\n",
- " ConvBlock(cs[4], cs[3]),\n",
- " self.up,\n",
- " ]),\n",
- " ConvBlock(cs[3] * 2, cs[3]),\n",
- " ConvBlock(cs[3], cs[2]),\n",
- " self.up,\n",
- " ]),\n",
- " ConvBlock(cs[2] * 2, cs[2]),\n",
- " ConvBlock(cs[2], cs[1]),\n",
- " self.up,\n",
- " ]),\n",
- " ConvBlock(cs[1] * 2, cs[1]),\n",
- " ConvBlock(cs[1], cs[0]),\n",
- " self.up,\n",
- " ]),\n",
- " ConvBlock(cs[0] * 2, cs[0]),\n",
- " nn.Conv2d(cs[0], 3, 3, padding=1),\n",
- " )\n",
- "\n",
- " def forward(self, input, t):\n",
- " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
- " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
- " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
- " pred = input * alphas - v * sigmas\n",
- " eps = input * sigmas + v * alphas\n",
- " return DiffusionOutput(v, pred, eps)\n"
+ "from disco.models import SecondaryDiffusionImageNet, SecondaryDiffusionImageNet2"
]
},
{
@@ -1557,222 +1160,14 @@
"outputs": [],
"source": [
"#@title 1.7 SuperRes Define\n",
- "class DDIMSampler(object):\n",
- " def __init__(self, model, schedule=\"linear\", **kwargs):\n",
- " super().__init__()\n",
- " self.model = model\n",
- " self.ddpm_num_timesteps = model.num_timesteps\n",
- " self.schedule = schedule\n",
- "\n",
- " def register_buffer(self, name, attr):\n",
- " if type(attr) == torch.Tensor:\n",
- " if attr.device != torch.device(\"cuda\"):\n",
- " attr = attr.to(torch.device(\"cuda\"))\n",
- " setattr(self, name, attr)\n",
- "\n",
- " def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
- " self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
- " num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
- " alphas_cumprod = self.model.alphas_cumprod\n",
- " assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
- " to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
- "\n",
- " self.register_buffer('betas', to_torch(self.model.betas))\n",
- " self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
- " self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
- "\n",
- " # calculations for diffusion q(x_t | x_{t-1}) and others\n",
- " self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
- " self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
- " self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
- " self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
- " self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
- "\n",
- " # ddim sampling parameters\n",
- " ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
- " ddim_timesteps=self.ddim_timesteps,\n",
- " eta=ddim_eta,verbose=verbose)\n",
- " self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
- " self.register_buffer('ddim_alphas', ddim_alphas)\n",
- " self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
- " self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
- " sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
- " (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
- " 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
- " self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
- "\n",
- " @torch.no_grad()\n",
- " def sample(self,\n",
- " S,\n",
- " batch_size,\n",
- " shape,\n",
- " conditioning=None,\n",
- " callback=None,\n",
- " normals_sequence=None,\n",
- " img_callback=None,\n",
- " quantize_x0=False,\n",
- " eta=0.,\n",
- " mask=None,\n",
- " x0=None,\n",
- " temperature=1.,\n",
- " noise_dropout=0.,\n",
- " score_corrector=None,\n",
- " corrector_kwargs=None,\n",
- " verbose=True,\n",
- " x_T=None,\n",
- " log_every_t=100,\n",
- " **kwargs\n",
- " ):\n",
- " if conditioning is not None:\n",
- " if isinstance(conditioning, dict):\n",
- " cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
- " if cbs != batch_size:\n",
- " print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
- " else:\n",
- " if conditioning.shape[0] != batch_size:\n",
- " print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
- "\n",
- " self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
- " # sampling\n",
- " C, H, W = shape\n",
- " size = (batch_size, C, H, W)\n",
- " # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
- "\n",
- " samples, intermediates = self.ddim_sampling(conditioning, size,\n",
- " callback=callback,\n",
- " img_callback=img_callback,\n",
- " quantize_denoised=quantize_x0,\n",
- " mask=mask, x0=x0,\n",
- " ddim_use_original_steps=False,\n",
- " noise_dropout=noise_dropout,\n",
- " temperature=temperature,\n",
- " score_corrector=score_corrector,\n",
- " corrector_kwargs=corrector_kwargs,\n",
- " x_T=x_T,\n",
- " log_every_t=log_every_t\n",
- " )\n",
- " return samples, intermediates\n",
- "\n",
- " @torch.no_grad()\n",
- " def ddim_sampling(self, cond, shape,\n",
- " x_T=None, ddim_use_original_steps=False,\n",
- " callback=None, timesteps=None, quantize_denoised=False,\n",
- " mask=None, x0=None, img_callback=None, log_every_t=100,\n",
- " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
- " device = self.model.betas.device\n",
- " b = shape[0]\n",
- " if x_T is None:\n",
- " img = torch.randn(shape, device=device)\n",
- " else:\n",
- " img = x_T\n",
- "\n",
- " if timesteps is None:\n",
- " timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
- " elif timesteps is not None and not ddim_use_original_steps:\n",
- " subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
- " timesteps = self.ddim_timesteps[:subset_end]\n",
- "\n",
- " intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
- " time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
- " total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
- " print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
- "\n",
- " iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
- "\n",
- " for i, step in enumerate(iterator):\n",
- " index = total_steps - i - 1\n",
- " ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
- "\n",
- " if mask is not None:\n",
- " assert x0 is not None\n",
- " img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
- " img = img_orig * mask + (1. - mask) * img\n",
- "\n",
- " outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
- " quantize_denoised=quantize_denoised, temperature=temperature,\n",
- " noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
- " corrector_kwargs=corrector_kwargs)\n",
- " img, pred_x0 = outs\n",
- " if callback: callback(i)\n",
- " if img_callback: img_callback(pred_x0, i)\n",
- "\n",
- " if index % log_every_t == 0 or index == total_steps - 1:\n",
- " intermediates['x_inter'].append(img)\n",
- " intermediates['pred_x0'].append(pred_x0)\n",
- "\n",
- " return img, intermediates\n",
- "\n",
- " @torch.no_grad()\n",
- " def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
- " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
- " b, *_, device = *x.shape, x.device\n",
- " e_t = self.model.apply_model(x, t, c)\n",
- " if score_corrector is not None:\n",
- " assert self.model.parameterization == \"eps\"\n",
- " e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
- "\n",
- " alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
- " alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
- " sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
- " sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
- " # select parameters corresponding to the currently considered timestep\n",
- " a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
- " a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
- " sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
- " sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
- "\n",
- " # current prediction for x_0\n",
- " pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
- " if quantize_denoised:\n",
- " pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
- " # direction pointing to x_t\n",
- " dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
- " noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
- " if noise_dropout > 0.:\n",
- " noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
- " x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
- " return x_prev, pred_x0\n",
- "\n",
- "\n",
- "def download_models(mode):\n",
- "\n",
- " if mode == \"superresolution\":\n",
- " # this is the small bsr light model\n",
- " url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
- " url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
- "\n",
- " path_conf = f'{model_path}/superres/project.yaml'\n",
- " path_ckpt = f'{model_path}/superres/last.ckpt'\n",
- "\n",
- " download_url(url_conf, path_conf)\n",
- " download_url(url_ckpt, path_ckpt)\n",
- "\n",
- " path_conf = path_conf + '/?dl=1' # fix it\n",
- " path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
- " return path_conf, path_ckpt\n",
- "\n",
- " else:\n",
- " raise NotImplementedError\n",
- "\n",
- "\n",
- "def load_model_from_config(config, ckpt):\n",
- " print(f\"Loading model from {ckpt}\")\n",
- " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
- " global_step = pl_sd[\"global_step\"]\n",
- " sd = pl_sd[\"state_dict\"]\n",
- " model = instantiate_from_config(config.model)\n",
- " m, u = model.load_state_dict(sd, strict=False)\n",
- " model.cuda()\n",
- " model.eval()\n",
- " return {\"model\": model}, global_step\n",
- "\n",
- "\n",
- "def get_model(mode):\n",
- " path_conf, path_ckpt = download_models(mode)\n",
- " config = OmegaConf.load(path_conf)\n",
- " model, step = load_model_from_config(config, path_ckpt)\n",
- " return model\n",
"\n",
+ "from disco.models import DDIMSampler\n",
+ "from disco.models.ddimsampler import (\n",
+ " download_models,\n",
+ " load_model_from_config,\n",
+ " get_model,\n",
+ " \n",
+ ")\n",
"\n",
"def get_custom_cond(mode):\n",
" dest = \"data/example_conditioning\"\n",
@@ -1989,7 +1384,7 @@
" return log\n",
"\n",
"sr_diffMode = 'superresolution'\n",
- "sr_model = get_model('superresolution')\n",
+ "sr_model = get_model('superresolution', model_path)\n",
"\n",
"\n",
"\n",
@@ -2419,238 +1814,46 @@
"#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n",
"frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n",
"\n",
- "\n",
- "def parse_key_frames(string, prompt_parser=None):\n",
- " \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n",
- " return a dictionary with the frame numbers as keys and the parameter values as the values.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " string: string\n",
- " Frame numbers paired with parameter values at that frame number, in the format\n",
- " 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n",
- " prompt_parser: function or None, optional\n",
- " If provided, prompt_parser will be applied to each string of parameter values.\n",
- " \n",
- " Returns\n",
- " -------\n",
- " dict\n",
- " Frame numbers as keys, parameter values at that frame number as values\n",
- "\n",
- " Raises\n",
- " ------\n",
- " RuntimeError\n",
- " If the input string does not match the expected format.\n",
- " \n",
- " Examples\n",
- " --------\n",
- " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n",
- " {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n",
- "\n",
- " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n",
- " {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n",
- " \"\"\"\n",
- " import re\n",
- " pattern = r'((?P[0-9]+):[\\s]*[\\(](?P[\\S\\s]*?)[\\)])'\n",
- " frames = dict()\n",
- " for match_object in re.finditer(pattern, string):\n",
- " frame = int(match_object.groupdict()['frame'])\n",
- " param = match_object.groupdict()['param']\n",
- " if prompt_parser:\n",
- " frames[frame] = prompt_parser(param)\n",
- " else:\n",
- " frames[frame] = param\n",
- "\n",
- " if frames == {} and len(string) != 0:\n",
- " raise RuntimeError('Key Frame string not correctly formatted')\n",
- " return frames\n",
- "\n",
- "def get_inbetweens(key_frames, integer=False):\n",
- " \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n",
- " return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n",
- " Any values not provided in the input dict are calculated by linear interpolation between\n",
- " the values of the previous and next provided frames. If there is no previous provided frame, then\n",
- " the value is equal to the value of the next provided frame, or if there is no next provided frame,\n",
- " then the value is equal to the value of the previous provided frame. If no frames are provided,\n",
- " all frame values are NaN.\n",
- "\n",
- " Parameters\n",
- " ----------\n",
- " key_frames: dict\n",
- " A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n",
- " integer: Bool, optional\n",
- " If True, the values of the output series are converted to integers.\n",
- " Otherwise, the values are floats.\n",
- " \n",
- " Returns\n",
- " -------\n",
- " pd.Series\n",
- " A Series with length max_frames representing the parameter values for each frame.\n",
- " \n",
- " Examples\n",
- " --------\n",
- " >>> max_frames = 5\n",
- " >>> get_inbetweens({1: 5, 3: 6})\n",
- " 0 5.0\n",
- " 1 5.0\n",
- " 2 5.5\n",
- " 3 6.0\n",
- " 4 6.0\n",
- " dtype: float64\n",
- "\n",
- " >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n",
- " 0 5\n",
- " 1 5\n",
- " 2 5\n",
- " 3 6\n",
- " 4 6\n",
- " dtype: int64\n",
- " \"\"\"\n",
- " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n",
- "\n",
- " for i, value in key_frames.items():\n",
- " key_frame_series[i] = value\n",
- " key_frame_series = key_frame_series.astype(float)\n",
- " \n",
- " interp_method = interp_spline\n",
- "\n",
- " if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
- " interp_method = 'Quadratic'\n",
- " \n",
- " if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
- " interp_method = 'Linear'\n",
- " \n",
- " \n",
- " key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
- " key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
- " # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n",
- " key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
- " if integer:\n",
- " return key_frame_series.astype(int)\n",
- " return key_frame_series\n",
- "\n",
- "def split_prompts(prompts):\n",
- " prompt_series = pd.Series([np.nan for a in range(max_frames)])\n",
- " for i, prompt in prompts.items():\n",
- " prompt_series[i] = prompt\n",
- " # prompt_series = prompt_series.astype(str)\n",
- " prompt_series = prompt_series.ffill().bfill()\n",
- " return prompt_series\n",
+ "from disco.animation import (\n",
+ " parse_key_frames,\n",
+ " get_inbetweens,\n",
+ " split_prompts,\n",
+ " process_keyframe_animation,\n",
+ ")\n",
"\n",
"if key_frames:\n",
- " try:\n",
- " angle_series = get_inbetweens(parse_key_frames(angle))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `angle` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `angle` as \"\n",
- " f'\"0: ({angle})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " angle = f\"0: ({angle})\"\n",
- " angle_series = get_inbetweens(parse_key_frames(angle))\n",
- "\n",
- " try:\n",
- " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `zoom` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `zoom` as \"\n",
- " f'\"0: ({zoom})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " zoom = f\"0: ({zoom})\"\n",
- " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
- "\n",
- " try:\n",
- " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `translation_x` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `translation_x` as \"\n",
- " f'\"0: ({translation_x})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " translation_x = f\"0: ({translation_x})\"\n",
- " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
- "\n",
- " try:\n",
- " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `translation_y` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `translation_y` as \"\n",
- " f'\"0: ({translation_y})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " translation_y = f\"0: ({translation_y})\"\n",
- " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
- "\n",
- " try:\n",
- " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `translation_z` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `translation_z` as \"\n",
- " f'\"0: ({translation_z})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " translation_z = f\"0: ({translation_z})\"\n",
- " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
- "\n",
- " try:\n",
- " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `rotation_3d_x` as \"\n",
- " f'\"0: ({rotation_3d_x})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " rotation_3d_x = f\"0: ({rotation_3d_x})\"\n",
- " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
- "\n",
- " try:\n",
- " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `rotation_3d_y` as \"\n",
- " f'\"0: ({rotation_3d_y})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " rotation_3d_y = f\"0: ({rotation_3d_y})\"\n",
- " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
- "\n",
- " try:\n",
- " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
- " except RuntimeError as e:\n",
- " print(\n",
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
- " \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n",
- " \"Attempting to interpret `rotation_3d_z` as \"\n",
- " f'\"0: ({rotation_3d_z})\"\\n'\n",
- " \"Please read the instructions to find out how to use key frames \"\n",
- " \"correctly.\\n\"\n",
- " )\n",
- " rotation_3d_z = f\"0: ({rotation_3d_z})\"\n",
- " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
+ " processed = process_keyframe_animation(\n",
+ " angle,\n",
+ " zoom,\n",
+ " translation_x,\n",
+ " translation_y,\n",
+ " translation_z,\n",
+ " rotation_3d_x,\n",
+ " rotation_3d_y,\n",
+ " rotation_3d_z,\n",
+ " key_frames,\n",
+ " interp_spline,\n",
+ " max_frames,\n",
+ " )\n",
"\n",
- "else:\n",
+ " (angle_series,\n",
+ " zoom_series,\n",
+ " translation_x_series,\n",
+ " translation_y_series,\n",
+ " translation_z_series,\n",
+ " rotation_3d_x_series,\n",
+ " rotation_3d_y_series,\n",
+ " rotation_3d_z_series) = (\n",
+ " processed['angle_series'],\n",
+ " processed['zoom_series'],\n",
+ " processed['translation_x_series'],\n",
+ " processed['translation_y_series'],\n",
+ " processed['translation_z_series'],\n",
+ " processed['rotation_3d_x_series'],\n",
+ " processed['rotation_3d_y_series'],\n",
+ " processed['rotation_3d_z_series'],\n",
+ " )\n",
+ "else: \n",
" angle = float(angle)\n",
" zoom = float(zoom)\n",
" translation_x = float(translation_x)\n",
@@ -2658,7 +1861,7 @@
" translation_z = float(translation_z)\n",
" rotation_3d_x = float(rotation_3d_x)\n",
" rotation_3d_y = float(rotation_3d_y)\n",
- " rotation_3d_z = float(rotation_3d_z)"
+ " rotation_3d_z = float(rotation_3d_z)\n"
]
},
{
@@ -2865,8 +2068,8 @@
"\n",
"args = {\n",
" 'batchNum': batchNum,\n",
- " 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n",
- " 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n",
+ " 'prompts_series':split_prompts(text_prompts, max_frames) if text_prompts else None,\n",
+ " 'image_prompts_series':split_prompts(image_prompts, max_frames) if image_prompts else None,\n",
" 'seed': seed,\n",
" 'display_rate':display_rate,\n",
" 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
@@ -3073,11 +2276,11 @@
"CnkTNXJAPzL2",
"u1VHzHvNx5fd"
],
+ "include_colab_link": true,
"machine_shape": "hm",
"name": "Disco Diffusion v5 [w/ 3D animation]",
"private_outputs": true,
- "provenance": [],
- "include_colab_link": true
+ "provenance": []
},
"kernelspec": {
"display_name": "Python 3",
@@ -3098,4 +2301,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
\ No newline at end of file
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..7fd26b97
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..13f9fc09
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,7 @@
+[metadata]
+name = pyttitools-disco
+version = 0.0.1
+
+[options]
+install_requires =
+ torch >= 1.10
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..02d47a56
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,16 @@
+from setuptools import setup, find_packages, find_namespace_packages
+import logging
+
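+# src layout: package code lives under src/ and installs as the `disco` import package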
+p0 = find_packages(where="src")
+p2 = find_namespace_packages(
+ where="src",
+ #include=["hydra_plugins.*"],
+)
+
+setup(
+ packages=p0 + p2,
+ package_dir={
+ "": "src",
+ },
+ #install_requires=["pyttitools-adabins", "pyttitools-gma"],
+)
\ No newline at end of file
diff --git a/src/disco/__init__.py b/src/disco/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/disco/animation.py b/src/disco/animation.py
new file mode 100644
index 00000000..453be07f
--- /dev/null
+++ b/src/disco/animation.py
@@ -0,0 +1,214 @@
+"""
+Animation utilities isolated from the Disco notebook; these will probably be ported into pytti at some point.
+"""
+
+from loguru import logger
+import numpy as np
+import pandas as pd
+
+def parse_key_frames(string, prompt_parser=None):
+ """Given a string representing frame numbers paired with parameter values at that frame,
+ return a dictionary with the frame numbers as keys and the parameter values as the values.
+
+ Parameters
+ ----------
+ string: string
+ Frame numbers paired with parameter values at that frame number, in the format
+ 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
+ prompt_parser: function or None, optional
+ If provided, prompt_parser will be applied to each string of parameter values.
+
+ Returns
+ -------
+ dict
+ Frame numbers as keys, parameter values at that frame number as values
+
+ Raises
+ ------
+ RuntimeError
+ If the input string does not match the expected format.
+
+ Examples
+ --------
+ >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
+ {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}
+
+ >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower()))
+ {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
+ """
+ import re
+ pattern = r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])'
+ frames = dict()
+ for match_object in re.finditer(pattern, string):
+ frame = int(match_object.groupdict()['frame'])
+ param = match_object.groupdict()['param']
+ if prompt_parser:
+ frames[frame] = prompt_parser(param)
+ else:
+ frames[frame] = param
+
+ if frames == {} and len(string) != 0:
+ raise RuntimeError('Key Frame string not correctly formatted')
+ return frames
+
+def get_inbetweens(
+ key_frames,
+ integer=False,
+ # new args
+ interp_spline=None,
+ max_frames=None,
+ ):
+ """Given a dict with frame numbers as keys and a parameter value as values,
+ return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
+ Any values not provided in the input dict are calculated by linear interpolation between
+ the values of the previous and next provided frames. If there is no previous provided frame, then
+ the value is equal to the value of the next provided frame, or if there is no next provided frame,
+ then the value is equal to the value of the previous provided frame. If no frames are provided,
+ all frame values are NaN.
+
+ Parameters
+ ----------
+ key_frames: dict
+ A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
+ integer: Bool, optional
+ If True, the values of the output series are converted to integers.
+ Otherwise, the values are floats.
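+ interp_spline: str, optional
+ Interpolation method to use: 'Linear', 'Quadratic', or 'Cubic'. 'Cubic' and
+ 'Quadratic' fall back to simpler methods when too few key frames are provided.
+ max_frames: int, optional
+ Total number of frames to generate values for.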
+
+ Returns
+ -------
+ pd.Series
+ A Series with length max_frames representing the parameter values for each frame.
+
+ Examples
+ --------
+ >>> get_inbetweens({1: 5, 3: 6}, interp_spline='Linear', max_frames=5)
+ 0 5.0
+ 1 5.0
+ 2 5.5
+ 3 6.0
+ 4 6.0
+ dtype: float64
+
+ >>> get_inbetweens({1: 5, 3: 6}, integer=True, interp_spline='Linear', max_frames=5)
+ 0 5
+ 1 5
+ 2 5
+ 3 6
+ 4 6
+ dtype: int64
+ """
+ key_frame_series = pd.Series([np.nan for a in range(max_frames)])
+
+ for i, value in key_frames.items():
+ key_frame_series[i] = value
+ key_frame_series = key_frame_series.astype(float)
+
+ interp_method = interp_spline
+
+ if interp_method == 'Cubic' and len(key_frames.items()) <=3:
+ interp_method = 'Quadratic'
+
+ if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:
+ interp_method = 'Linear'
+
+
+ key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
+ key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
+ # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')
+ key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')
+ if integer:
+ return key_frame_series.astype(int)
+ return key_frame_series
+
+def split_prompts(
+ prompts,
+ # new args
+ max_frames,
+ ):
+ """
+ Builds a per-frame prompt series from a dict mapping frame numbers to prompts.
+ Creates a pandas Series of length max_frames filled with NaN, writes each prompt at its
+ frame index, then forward-fills and back-fills so that every frame has a prompt.
+
+ :param prompts: a dictionary of frame_number -> prompt_text
+ :param max_frames: total number of frames to generate prompts for
+ :return: A pandas Series with the prompt for each frame.
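+
+ Examples
+ --------
+ Illustrative values; any prompt strings behave the same way.
+
+ >>> split_prompts({0: 'a cat', 2: 'a dog'}, max_frames=4)
+ 0 a cat
+ 1 a cat
+ 2 a dog
+ 3 a dog
+ dtype: object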
+ """
+ prompt_series = pd.Series([np.nan for a in range(max_frames)])
+ for i, prompt in prompts.items():
+ prompt_series[i] = prompt
+ # prompt_series = prompt_series.astype(str)
+ prompt_series = prompt_series.ffill().bfill()
+ return prompt_series
+
+def process_keyframe_animation(
+ angle,
+ zoom,
+ translation_x,
+ translation_y,
+ translation_z,
+ rotation_3d_x,
+ rotation_3d_y,
+ rotation_3d_z,
+ key_frames=None,
+ # new args
+ interp_spline=None,
+ max_frames=None,
+):
+ """
+ Given keyframe strings for each animation parameter, return a dictionary of interpolated per-frame series
+
+ :param angle: The angle of the camera in degrees
+ :param zoom: The zoom level of the camera
+ :param translation_x: The x-axis translation of the camera
+ :param translation_y: The y-axis translation of the camera
+ :param translation_z: The distance between the camera and the object
+ :param rotation_3d_x: The rotation of the camera around the x axis
+ :param rotation_3d_y: The rotation of the camera around the y axis
+ :param rotation_3d_z: The rotation of the camera around the z axis
+ :param key_frames: key frame flag from the notebook UI; must not be None
+ :param interp_spline: interpolation method ('Linear', 'Quadratic', or 'Cubic')
+ :param max_frames: total number of frames to generate values for
+ :return: A dictionary mapping '<parameter>_series' keys to interpolated pandas Series.
+ """
+ assert key_frames is not None
+ outv = {}
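+ # interpolate each keyframed parameter string into a per-frame pandas Series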
+ for k,v in {
+ 'angle':angle,
+ 'zoom':zoom,
+ 'translation_x':translation_x,
+ 'translation_y':translation_y,
+ 'translation_z':translation_z,
+ 'rotation_3d_x':rotation_3d_x,
+ 'rotation_3d_y':rotation_3d_y,
+ 'rotation_3d_z':rotation_3d_z,
+ }.items():
+ try:
+ outv[k+'_series'] = get_inbetweens(
+ key_frames=parse_key_frames(v),
+ integer=False,
+ # new args
+ interp_spline=interp_spline,
+ max_frames=max_frames,
+ )
+ except RuntimeError:
+ logger.warning(
+ "WARNING: You have selected to use key frames, but you have not "
+ f"formatted `{k}` correctly for key frames.\n"
+ "Attempting to interpret `k` as "
+ f'"0: ({v})"\n'
+ "Please read the instructions to find out how to use key frames "
+ "correctly.\n"
+ )
+ v = f"0: ({v})"
+ outv[k+'_series'] = get_inbetweens(
+ key_frames=parse_key_frames(v),
+ integer=False,
+ # new args
+ interp_spline=interp_spline,
+ max_frames=max_frames,
+
+ )
+ return outv
+
diff --git a/src/disco/common.py b/src/disco/common.py
new file mode 100644
index 00000000..b8d91f0c
--- /dev/null
+++ b/src/disco/common.py
@@ -0,0 +1,335 @@
+"""
+Utilities that are not specific to Disco.
+At some point this should be imported from pytti-core instead.
+"""
+import cv2
+import io
+import math
+import requests
+import sys
+
+from PIL import ImageOps
+from resize_right import resize
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
+import pytorch3d.transforms as p3dT
+import disco.disco_xform_utils as dxf
+
+def interp(t):
+ return 3 * t**2 - 2 * t ** 3
+
+def perlin(width, height, scale=10, device=None):
+ gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
+ xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
+ ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
+ wx = 1 - interp(xs)
+ wy = 1 - interp(ys)
+ dots = 0
+ dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
+ dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
+ dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
+ dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
+ return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
+
+def perlin_ms(octaves, width, height, grayscale, device):
+ out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
+ # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
+ for i in range(1 if grayscale else 3):
+ scale = 2 ** len(octaves)
+ oct_width = width
+ oct_height = height
+ for oct in octaves:
+ p = perlin(oct_width, oct_height, scale, device)
+ out_array[i] += p * oct
+ scale //= 2
+ oct_width *= 2
+ oct_height *= 2
+ return torch.cat(out_array)
+
+def create_perlin_noise(
+ octaves=[1, 1, 1, 1],
+ width=2,
+ height=2,
+ grayscale=True,
+ # new args
+ side_y=None,
+ side_x=None,
+ device=None,
+ ):
+ out = perlin_ms(octaves, width, height, grayscale, device)
+ if grayscale:
+ out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
+ out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
+ else:
+ out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])
+ out = TF.resize(size=(side_y, side_x), img=out)
+ out = TF.to_pil_image(out.clamp(0, 1).squeeze())
+
+ out = ImageOps.autocontrast(out)
+ return out
+
+def regen_perlin(
+    # new args
+    perlin_mode,
+    batch_size,
+    device,
+    side_y=None,
+    side_x=None,
+):
+    perlin_kwargs = dict(side_y=side_y, side_x=side_x, device=device)
+    if perlin_mode == 'color':
+        init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False, **perlin_kwargs)
+        init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False, **perlin_kwargs)
+    elif perlin_mode == 'gray':
+        init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True, **perlin_kwargs)
+        init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True, **perlin_kwargs)
+    else:
+        init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False, **perlin_kwargs)
+        init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True, **perlin_kwargs)
+
+ init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
+ del init2
+ return init.expand(batch_size, -1, -1, -1)
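+
+# Rough usage sketch (assuming the notebook's usual side_x/side_y output dimensions):
+#   init = regen_perlin('mixed', batch_size, device, side_y=side_y, side_x=side_x)
+#   # init: (batch_size, 3, side_y, side_x), values roughly in [-1, 1]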
+
+def fetch(url_or_path):
+ if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+ r = requests.get(url_or_path)
+ r.raise_for_status()
+ fd = io.BytesIO()
+ fd.write(r.content)
+ fd.seek(0)
+ return fd
+ return open(url_or_path, 'rb')
+
+def read_image_workaround(path):
+ """OpenCV reads images as BGR, Pillow saves them as RGB. Work around
+ this incompatibility to avoid colour inversions."""
+ im_tmp = cv2.imread(path)
+ return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)
+
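+# Examples: the weight after the final colon defaults to 1, and the "://" in URL prompts
+# is not mistaken for a weight separator:
+#   parse_prompt("a cat:2")                    -> ("a cat", 2.0)
+#   parse_prompt("a cat")                      -> ("a cat", 1.0)
+#   parse_prompt("https://x.org/img.png:1.5")  -> ("https://x.org/img.png", 1.5)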
+def parse_prompt(prompt):
+ if prompt.startswith('http://') or prompt.startswith('https://'):
+ vals = prompt.rsplit(':', 2)
+ vals = [vals[0] + ':' + vals[1], *vals[2:]]
+ else:
+ vals = prompt.rsplit(':', 1)
+ vals = vals + ['', '1'][len(vals):]
+ return vals[0], float(vals[1])
+
+def sinc(x):
+ return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+def lanczos(x, a):
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
+ return out / out.sum()
+
+def ramp(ratio, width):
+ n = math.ceil(width / ratio + 1)
+ out = torch.empty([n])
+ cur = 0
+ for i in range(out.shape[0]):
+ out[i] = cur
+ cur += ratio
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+def resample(input, size, align_corners=True):
+ n, c, h, w = input.shape
+ dh, dw = size
+
+ input = input.reshape([n * c, 1, h, w])
+
+ if dh < h:
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+ pad_h = (kernel_h.shape[0] - 1) // 2
+ input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
+ input = F.conv2d(input, kernel_h[None, None, :, None])
+
+ if dw < w:
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+ pad_w = (kernel_w.shape[0] - 1) // 2
+ input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
+ input = F.conv2d(input, kernel_w[None, None, None, :])
+
+ input = input.reshape([n, c, h, w])
+ return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
+
+class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cutn, skip_augs=False):
+ super().__init__()
+ self.cut_size = cut_size
+ self.cutn = cutn
+ self.skip_augs = skip_augs
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.15),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ])
+
+ def forward(self, input):
+ input = T.Pad(input.shape[2]//4, fill=0)(input)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+
+ cutouts = []
+ for ch in range(self.cutn):
+ if ch > self.cutn - self.cutn//4:
+ cutout = input.clone()
+ else:
+ size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
+ offsetx = torch.randint(0, abs(sideX - size + 1), ())
+ offsety = torch.randint(0, abs(sideY - size + 1), ())
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+
+ if not self.skip_augs:
+ cutout = self.augs(cutout)
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+ del cutout
+
+ cutouts = torch.cat(cutouts, dim=0)
+ return cutouts
+
+#cutout_debug = False
+#padargs = {}
+
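+# Rough usage sketch (names and sizes here are illustrative, not fixed by this module):
+#   make_cutouts = MakeCutoutsDango(cut_size=224, Overview=4, InnerCrop=2, animation_mode="None")
+#   cuts = make_cutouts(img)   # img: (1, 3, H, W); cuts: (Overview + InnerCrop, 3, 224, 224)
+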
+class MakeCutoutsDango(nn.Module):
+ def __init__(self, cut_size,
+ Overview=4,
+ InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2,
+ # new args
+ cutout_debug = False,
+ padargs = None,
+ animation_mode = None, # args.animation_mode
+ debug_outpath = "./",
+ skip_augs=False,
+ ):
+ super().__init__()
+ self.cut_size = cut_size
+ self.Overview = Overview
+ self.InnerCrop = InnerCrop
+ self.IC_Size_Pow = IC_Size_Pow
+ self.IC_Grey_P = IC_Grey_P
+
+ # Augs should be an argument that defaults to nn.Identity
+ # rather than requiring an "animation mode" which only makes sense in
+ # our particular use context
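+        # e.g. something like `def __init__(self, cut_size, ..., augs: nn.Module = nn.Identity())`
+        # with `self.augs = augs`; the animation_mode branches below are kept to match the notebook.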
+ animation_mode = str(animation_mode)
+ #self.augs = nn.Identity() # don't reassign module
+ if (not skip_augs) and (animation_mode == 'None'):
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ])
+ elif (not skip_augs) and (animation_mode == 'Video Input'):
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.15),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ])
+ elif (not skip_augs) and (animation_mode in ('2D','3D')):
+ self.augs = T.Compose([
+ T.RandomHorizontalFlip(p=0.4),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),
+ ])
+ else:
+ self.augs = nn.Identity()
+
+ self.cutout_debug = cutout_debug
+ self.debug_outpath = debug_outpath
+ if padargs is None:
+ padargs = {}
+ self.padargs = padargs
+
+ def forward(self, input):
+ cutouts = []
+ gray = T.Grayscale(3)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ l_size = max(sideX, sideY)
+ output_shape = [1,3,self.cut_size,self.cut_size]
+ output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]
+ pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **self.padargs)
+ cutout = resize(pad_input, out_shape=output_shape)
+
+ if self.Overview>0:
+ if self.Overview<=4:
+ if self.Overview>=1:
+ cutouts.append(cutout)
+ if self.Overview>=2:
+ cutouts.append(gray(cutout))
+ if self.Overview>=3:
+ cutouts.append(TF.hflip(cutout))
+ if self.Overview==4:
+ cutouts.append(gray(TF.hflip(cutout)))
+ else:
+ cutout = resize(pad_input, out_shape=output_shape)
+ for _ in range(self.Overview):
+ cutouts.append(cutout)
+
+ if self.cutout_debug:
+ TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(self.debug_outpath + "cutout_overview0.jpg",quality=99)
+
+ if self.InnerCrop >0:
+ for i in range(self.InnerCrop):
+ size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+ if i <= int(self.IC_Grey_P * self.InnerCrop):
+ cutout = gray(cutout)
+ cutout = resize(cutout, out_shape=output_shape)
+ cutouts.append(cutout)
+ if self.cutout_debug:
+ TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(self.debug_outpath + "cutout_InnerCrop.jpg",quality=99)
+ cutouts = torch.cat(cutouts)
+ return cutouts
+
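+# For unit vectors the chord length ||x_hat - y_hat|| equals 2*sin(theta/2), so the expression
+# below evaluates to 2*(theta/2)**2 = theta**2 / 2, i.e. half the squared angular distance
+# between the normalized embeddings.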
+def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+def tv_loss(input):
+ """L2 total variation loss, as in Mahendran et al."""
+ input = F.pad(input, (0, 1, 0, 1), 'replicate')
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+ return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+
+def range_loss(input):
+ return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+def disable_cudnn(DEVICE):
+    if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad
+        logger.debug('Disabling CUDNN for A100 gpu')
+ torch.backends.cudnn.enabled = False
+
+def a100_cudnn_fix(DEVICE):
+ disable_cudnn(DEVICE)
\ No newline at end of file
diff --git a/disco_xform_utils.py b/src/disco/disco_xform_utils.py
similarity index 99%
rename from disco_xform_utils.py
rename to src/disco/disco_xform_utils.py
index 9d4f1b84..4fc02725 100644
--- a/disco_xform_utils.py
+++ b/src/disco/disco_xform_utils.py
@@ -6,7 +6,7 @@
import sys, math
try:
- from infer import InferenceHelper
+ from adabins.infer import InferenceHelper
except:
print("disco_xform_utils.py failed to import InferenceHelper. Please ensure that AdaBins directory is in the path (i.e. via sys.path.append('./AdaBins') or other means).")
sys.exit()
diff --git a/src/disco/models/__init__.py b/src/disco/models/__init__.py
new file mode 100644
index 00000000..882f8a52
--- /dev/null
+++ b/src/disco/models/__init__.py
@@ -0,0 +1,3 @@
+from .secondarydiffusionimagenet import SecondaryDiffusionImageNet
+from .secondarydiffusionimagenet2 import SecondaryDiffusionImageNet2
+from .ddimsampler import DDIMSampler
\ No newline at end of file
diff --git a/src/disco/models/ddimsampler.py b/src/disco/models/ddimsampler.py
new file mode 100644
index 00000000..f874c156
--- /dev/null
+++ b/src/disco/models/ddimsampler.py
@@ -0,0 +1,243 @@
+import numpy as np
+import torch
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tqdm.notebook import tqdm # let's just assume notebook for now...
+from torchvision.datasets.utils import download_url
+
+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+# from ldm.models.diffusion.ddim import DDIMSampler
+
+def download_models(mode, model_path):
+    """
+    Downloads the models from the internet and saves them to the specified path
+
+    :param mode: the model to download; currently only "superresolution" is supported
+    :param model_path: directory under which the config and checkpoint are saved
+    :return: The paths to the config and checkpoint files.
+ """
+ if mode == "superresolution":
+ # this is the small bsr light model
+ url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
+ url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
+
+ path_conf = f'{model_path}/superres/project.yaml'
+ path_ckpt = f'{model_path}/superres/last.ckpt'
+
+ download_url(url_conf, path_conf)
+ download_url(url_ckpt, path_ckpt)
+
+ path_conf = path_conf + '/?dl=1' # fix it
+ path_ckpt = path_ckpt + '/?dl=1' # fix it
+ return path_conf, path_ckpt
+
+ else:
+ raise NotImplementedError
+
+
+def load_model_from_config(config, ckpt):
+ """
+ Loads a model from a checkpoint
+
+ :param config: The config object that contains the model's parameters
+ :param ckpt: the checkpoint file to load
+    :return: A dict holding the model (under the "model" key) and the global step.
+ """
+ logger.debug(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ global_step = pl_sd["global_step"]
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ model.cuda()
+ model.eval()
+ return {"model": model}, global_step
+
+
+def get_model(mode, model_path):
+ path_conf, path_ckpt = download_models(mode, model_path)
+ config = OmegaConf.load(path_conf)
+ model, step = load_model_from_config(config, path_ckpt)
+ return model
+
+
+#@title 1.7 SuperRes Define
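+# Rough usage sketch (variable names and sizes are illustrative):
+#   sampler = DDIMSampler(ldm_model)
+#   samples, intermediates = sampler.sample(
+#       S=100, batch_size=1, shape=(C, H, W), conditioning=cond, eta=1.0)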
+class DDIMSampler(object):
+ """
+ Sampler for super resolution models
+ """
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ logger.warning(f"Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ logger.warning(f"Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sharpening with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ e_t = self.model.apply_model(x, t, c)
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
\ No newline at end of file
diff --git a/src/disco/models/modules.py b/src/disco/models/modules.py
new file mode 100644
index 00000000..a2336718
--- /dev/null
+++ b/src/disco/models/modules.py
@@ -0,0 +1,88 @@
+from dataclasses import dataclass
+import math
+
+import torch
+from torch import nn
+
+def append_dims(x, n):
+    """
+    Append singleton dimensions to the end of `x` until it has `n` dimensions in total
+
+    :param x: The tensor to be reshaped
+    :param n: The total number of dimensions the returned view should have
+    :return: a view of `x` with shape (*x.shape, 1, ..., 1), suitable for broadcasting against an
+    n-dimensional tensor.
+    """
+ return x[(Ellipsis, *(None,) * (n - x.ndim))]
+
+
+def expand_to_planes(x, shape):
+    """
+    Broadcast a per-sample tensor `x` of shape (B, C) across the spatial dimensions of `shape`
+
+    :param x: The input tensor, e.g. a timestep embedding of shape (B, C)
+    :param shape: the target shape whose trailing (spatial) dimensions are tiled over
+    :return: a tensor of shape (B, C, *shape[2:]) obtained by appending singleton dims to `x` and repeating.
+    """
+ return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
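+
+# e.g. for t of shape (B,) and x of shape (B, 3, H, W):
+#   append_dims(t, x.ndim)            -> shape (B, 1, 1, 1)
+#   expand_to_planes(emb, x.shape)    -> shape (B, C, H, W) for emb of shape (B, C)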
+
+
+def alpha_sigma_to_t(alpha, sigma):
+    """
+    Given the diffusion schedule coefficients (alpha, sigma), return the corresponding timestep t
+
+    :param alpha: the signal coefficient of the schedule
+    :param sigma: the noise coefficient of the schedule
+    :return: t in [0, 1] such that alpha = cos(t * pi / 2) and sigma = sin(t * pi / 2).
+    """
+ return torch.atan2(sigma, alpha) * 2 / math.pi
+
+
+def t_to_alpha_sigma(t):
+    """
+    Given diffusion timesteps t in [0, 1], return the schedule coefficients (alpha, sigma)
+
+    :param t: the time parameter
+    :return: the pair (cos(t * pi / 2), sin(t * pi / 2)); note that alpha**2 + sigma**2 == 1.
+    """
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
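+
+# Note: the two maps invert each other, i.e. alpha_sigma_to_t(*t_to_alpha_sigma(t)) == t for t in [0, 1].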
+
+
+@dataclass
+class DiffusionOutput:
+ v: torch.Tensor
+ pred: torch.Tensor
+ eps: torch.Tensor
+
+
+class ConvBlock(nn.Sequential):
+ """
+ 3x3 conv + ReLU
+ """
+ def __init__(self, c_in, c_out):
+ super().__init__(
+ nn.Conv2d(c_in, c_out, 3, padding=1),
+ nn.ReLU(inplace=True),
+ )
+
+
+class SkipBlock(nn.Module):
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return torch.cat([self.main(input), self.skip(input)], dim=1)
+
+
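+# e.g. FourierFeatures(1, 16) maps timesteps of shape (B, 1) to embeddings of shape (B, 16)
+# via a randomly initialized (learnable) Fourier projection: out_features // 2 cosines and sines.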
+class FourierFeatures(nn.Module):
+ def __init__(self, in_features, out_features, std=1.):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
diff --git a/src/disco/models/secondarydiffusionimagenet.py b/src/disco/models/secondarydiffusionimagenet.py
new file mode 100644
index 00000000..a7e986ac
--- /dev/null
+++ b/src/disco/models/secondarydiffusionimagenet.py
@@ -0,0 +1,70 @@
+from functools import partial
+
+import torch
+from torch import nn
+from .modules import (
+ FourierFeatures,
+ ConvBlock,
+ SkipBlock,
+ expand_to_planes,
+ append_dims,
+ t_to_alpha_sigma,
+ DiffusionOutput,
+)
+
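+# Rough usage sketch (tensor shapes are illustrative; H and W should be multiples of 8 so the
+# pooling/upsampling path lines up):
+#   model = SecondaryDiffusionImageNet()
+#   out = model(x, t)   # x: (B, 3, H, W), t: (B,) in [0, 1]
+#   out.pred            # predicted denoised image, same shape as x
+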
+class SecondaryDiffusionImageNet(nn.Module):
+ """
+ Secondary diffusion model trained on Imagenet.
+ """
+ def __init__(self):
+ super().__init__()
+ c = 64 # The base channel count
+
+ self.timestep_embed = FourierFeatures(1, 16)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, c),
+ ConvBlock(c, c),
+ SkipBlock([
+ nn.AvgPool2d(2),
+ ConvBlock(c, c * 2),
+ ConvBlock(c * 2, c * 2),
+ SkipBlock([
+ nn.AvgPool2d(2),
+ ConvBlock(c * 2, c * 4),
+ ConvBlock(c * 4, c * 4),
+ SkipBlock([
+ nn.AvgPool2d(2),
+ ConvBlock(c * 4, c * 8),
+ ConvBlock(c * 8, c * 4),
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ ]),
+ ConvBlock(c * 8, c * 4),
+ ConvBlock(c * 4, c * 2),
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ ]),
+ ConvBlock(c * 4, c * 2),
+ ConvBlock(c * 2, c),
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ ]),
+ ConvBlock(c * 2, c),
+ nn.Conv2d(c, 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+        """
+        Run the denoising network on a (noised) input at timestep t
+
+        :param input: the image batch fed to the network, shape (B, 3, H, W)
+        :param t: the diffusion timesteps, one per batch element, in [0, 1]
+        :return: a DiffusionOutput holding the raw network output `v`, the predicted
+        denoised image `pred`, and the predicted noise `eps`.
+        """
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
+ v = self.net(torch.cat([input, timestep_embed], dim=1))
+ alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return DiffusionOutput(v, pred, eps)
diff --git a/src/disco/models/secondarydiffusionimagenet2.py b/src/disco/models/secondarydiffusionimagenet2.py
new file mode 100644
index 00000000..15aab22d
--- /dev/null
+++ b/src/disco/models/secondarydiffusionimagenet2.py
@@ -0,0 +1,93 @@
+from functools import partial
+
+import torch
+from torch import nn
+from .modules import (
+ FourierFeatures,
+ ConvBlock,
+ SkipBlock,
+ expand_to_planes,
+ append_dims,
+ t_to_alpha_sigma,
+ DiffusionOutput,
+)
+
+# I think this is functionally identical to the other model,
+# just with the down/up sampling functions aliased to a class attribute for legibility
+class SecondaryDiffusionImageNet2(nn.Module):
+ """
+ Secondary diffusion model trained on Imagenet.
+ """
+ def __init__(self):
+ super().__init__()
+ c = 64 # The base channel count
+ cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
+
+ self.timestep_embed = FourierFeatures(1, 16)
+ self.down = nn.AvgPool2d(2)
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, cs[0]),
+ ConvBlock(cs[0], cs[0]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[0], cs[1]),
+ ConvBlock(cs[1], cs[1]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[1], cs[2]),
+ ConvBlock(cs[2], cs[2]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[2], cs[3]),
+ ConvBlock(cs[3], cs[3]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[3], cs[4]),
+ ConvBlock(cs[4], cs[4]),
+ SkipBlock([
+ self.down,
+ ConvBlock(cs[4], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[4]),
+ self.up,
+ ]),
+ ConvBlock(cs[4] * 2, cs[4]),
+ ConvBlock(cs[4], cs[3]),
+ self.up,
+ ]),
+ ConvBlock(cs[3] * 2, cs[3]),
+ ConvBlock(cs[3], cs[2]),
+ self.up,
+ ]),
+ ConvBlock(cs[2] * 2, cs[2]),
+ ConvBlock(cs[2], cs[1]),
+ self.up,
+ ]),
+ ConvBlock(cs[1] * 2, cs[1]),
+ ConvBlock(cs[1], cs[0]),
+ self.up,
+ ]),
+ ConvBlock(cs[0] * 2, cs[0]),
+ nn.Conv2d(cs[0], 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+        """
+        Run the denoising network on a (noised) input at timestep t
+
+        :param input: the image batch fed to the network, shape (B, 3, H, W)
+        :param t: the diffusion timesteps, one per batch element, in [0, 1]
+        :return: a DiffusionOutput holding the raw network output `v`, the predicted
+        denoised image `pred`, and the predicted noise `eps`.
+        """
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
+ v = self.net(torch.cat([input, timestep_embed], dim=1))
+ alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return DiffusionOutput(v, pred, eps)
\ No newline at end of file