diff --git a/Disco_Diffusion.ipynb b/Disco_Diffusion.ipynb index 5221aae6..c6809fa8 100644 --- a/Disco_Diffusion.ipynb +++ b/Disco_Diffusion.ipynb @@ -3,8 +3,8 @@ { "cell_type": "markdown", "metadata": { - "id": "view-in-github", - "colab_type": "text" + "colab_type": "text", + "id": "view-in-github" }, "source": [ "\"Open" @@ -377,15 +377,21 @@ " # !git clone https://github.com/facebookresearch/SLIP.git\n", " !git clone https://github.com/crowsonkb/guided-diffusion\n", " !git clone https://github.com/assafshocher/ResizeRight.git\n", - " !pip install -e ./CLIP\n", + " #!pip install -e ./CLIP\n", + " !pip install ./CLIP\n", " !pip install -e ./guided-diffusion\n", " !pip install lpips datetime timm\n", " !apt install imagemagick\n", " !git clone https://github.com/isl-org/MiDaS.git\n", - " !git clone https://github.com/alembics/disco-diffusion.git\n", + " #!git clone https://github.com/alembics/disco-diffusion.git\n", + "\n", + " !git clone --branch pytti-disco https://github.com/pytti-tools/disco-diffusion.git\n", + " !pip install ./disco-diffusion\n", + " \n", " # Rename a file to avoid a name conflict..\n", - " !mv MiDaS/utils.py MiDaS/midas_utils.py\n", - " !cp disco-diffusion/disco_xform_utils.py disco_xform_utils.py\n", + " !mv MiDaS/utils.py MiDaS/midas_utils.py # oof\n", + " !cp disco-diffusion/disco_xform_utils.py disco_xform_utils.py # uhh\n", + " !pip install loguru\n", "\n", "!mkdir model\n", "if not path_exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n", @@ -393,6 +399,7 @@ "\n", "import sys\n", "import torch\n", + "from loguru import logger\n", "\n", "#Install pytorch3d\n", "if is_colab:\n", @@ -428,7 +435,8 @@ "import torchvision.transforms as T\n", "import torchvision.transforms.functional as TF\n", "from tqdm.notebook import tqdm\n", - "sys.path.append('./CLIP')\n", + "#sys.path.append('./CLIP')\n", + "#!pip install ./CLIP\n", "sys.path.append('./guided-diffusion')\n", "import clip\n", "from resize_right import resize\n", @@ -444,17 +452,21 @@ "#SuperRes\n", "if is_colab:\n", " !git clone https://github.com/CompVis/latent-diffusion.git\n", - " !git clone https://github.com/CompVis/taming-transformers\n", - " !pip install -e ./taming-transformers\n", + " #!git clone https://github.com/CompVis/taming-transformers\n", + " !git clone https://github.com/pytti-tools/taming-transformers\n", + " #!pip install -e ./taming-transformers\n", + " !pip install ./taming-transformers\n", " !pip install ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\n", "\n", "#SuperRes\n", "import ipywidgets as widgets\n", "import os\n", - "sys.path.append(\".\")\n", - "sys.path.append('./taming-transformers')\n", + "sys.path.append(\".\") \n", + "#sys.path.append('./taming-transformers')\n", "from taming.models import vqgan # checking correct import from taming\n", "from torchvision.datasets.utils import download_url\n", + "\n", + "# oy vey\n", "if is_colab:\n", " %cd '/content/latent-diffusion'\n", "else:\n", @@ -469,6 +481,8 @@ " from google.colab import files\n", "else:\n", " %cd $PROJECT_DIR\n", + "\n", + "\n", "from IPython.display import Image as ipyimg\n", "from numpy import asarray\n", "from einops import rearrange, repeat\n", @@ -481,31 +495,33 @@ "# AdaBins stuff\n", "if USE_ADABINS:\n", " if is_colab:\n", - " !git clone https://github.com/shariqfarooq123/AdaBins.git\n", - " if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n", - " !wget https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt -P {model_path}\n", - " 
!mkdir pretrained\n", - " !cp -P {model_path}/AdaBins_nyu.pt pretrained/AdaBins_nyu.pt\n", - " sys.path.append('./AdaBins')\n", - " from infer import InferenceHelper\n", + " #!git clone https://github.com/shariqfarooq123/AdaBins.git\n", + " !git clone https://github.com/pytti-tools/AdaBins\n", + " !pip install ./AdaBins\n", + " #if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n", + " # !wget https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt -P {model_path}\n", + " #!mkdir pretrained\n", + " #!cp -P {model_path}/AdaBins_nyu.pt pretrained/AdaBins_nyu.pt\n", + " #sys.path.append('./AdaBins')\n", + " #from infer import InferenceHelper\n", + " from adabins.infer import InferenceHelper\n", " MAX_ADABINS_AREA = 500000\n", "\n", "import torch\n", "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", - "print('Using device:', DEVICE)\n", + "logger.debug('Using device:', DEVICE)\n", "device = DEVICE # At least one of the modules expects this name..\n", "\n", - "if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n", - " print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n", - " torch.backends.cudnn.enabled = False" + "from disco.common import a100_cudnn_fix\n", + "a100_cudnn_fix(DEVICE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "BLk3J0h3MtON", - "cellView": "form" + "cellView": "form", + "id": "BLk3J0h3MtON" }, "outputs": [], "source": [ @@ -624,291 +640,30 @@ "source": [ "#@title 1.5 Define necessary functions\n", "\n", - "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n", - "\n", - "import pytorch3d.transforms as p3dT\n", - "import disco_xform_utils as dxf\n", - "\n", - "def interp(t):\n", - " return 3 * t**2 - 2 * t ** 3\n", - "\n", - "def perlin(width, height, scale=10, device=None):\n", - " gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n", - " xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n", - " ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n", - " wx = 1 - interp(xs)\n", - " wy = 1 - interp(ys)\n", - " dots = 0\n", - " dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n", - " dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n", - " dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n", - " dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n", - " return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n", - "\n", - "def perlin_ms(octaves, width, height, grayscale, device=device):\n", - " out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n", - " # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n", - " for i in range(1 if grayscale else 3):\n", - " scale = 2 ** len(octaves)\n", - " oct_width = width\n", - " oct_height = height\n", - " for oct in octaves:\n", - " p = perlin(oct_width, oct_height, scale, device)\n", - " out_array[i] += p * oct\n", - " scale //= 2\n", - " oct_width *= 2\n", - " oct_height *= 2\n", - " return torch.cat(out_array)\n", - "\n", - "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n", - " out = perlin_ms(octaves, width, height, grayscale)\n", - " if grayscale:\n", - " out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n", - " out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n", - " else:\n", - " out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n", - " out = TF.resize(size=(side_y, side_x), img=out)\n", - " 
out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n", - "\n", - " out = ImageOps.autocontrast(out)\n", - " return out\n", - "\n", - "def regen_perlin():\n", - " if perlin_mode == 'color':\n", - " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", - " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n", - " elif perlin_mode == 'gray':\n", - " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n", - " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", - " else:\n", - " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n", - " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n", - "\n", - " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n", - " del init2\n", - " return init.expand(batch_size, -1, -1, -1)\n", - "\n", - "def fetch(url_or_path):\n", - " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n", - " r = requests.get(url_or_path)\n", - " r.raise_for_status()\n", - " fd = io.BytesIO()\n", - " fd.write(r.content)\n", - " fd.seek(0)\n", - " return fd\n", - " return open(url_or_path, 'rb')\n", - "\n", - "def read_image_workaround(path):\n", - " \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n", - " this incompatibility to avoid colour inversions.\"\"\"\n", - " im_tmp = cv2.imread(path)\n", - " return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n", - "\n", - "def parse_prompt(prompt):\n", - " if prompt.startswith('http://') or prompt.startswith('https://'):\n", - " vals = prompt.rsplit(':', 2)\n", - " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n", - " else:\n", - " vals = prompt.rsplit(':', 1)\n", - " vals = vals + ['', '1'][len(vals):]\n", - " return vals[0], float(vals[1])\n", - "\n", - "def sinc(x):\n", - " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n", - "\n", - "def lanczos(x, a):\n", - " cond = torch.logical_and(-a < x, x < a)\n", - " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n", - " return out / out.sum()\n", - "\n", - "def ramp(ratio, width):\n", - " n = math.ceil(width / ratio + 1)\n", - " out = torch.empty([n])\n", - " cur = 0\n", - " for i in range(out.shape[0]):\n", - " out[i] = cur\n", - " cur += ratio\n", - " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n", - "\n", - "def resample(input, size, align_corners=True):\n", - " n, c, h, w = input.shape\n", - " dh, dw = size\n", - "\n", - " input = input.reshape([n * c, 1, h, w])\n", - "\n", - " if dh < h:\n", - " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n", - " pad_h = (kernel_h.shape[0] - 1) // 2\n", - " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n", - " input = F.conv2d(input, kernel_h[None, None, :, None])\n", - "\n", - " if dw < w:\n", - " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n", - " pad_w = (kernel_w.shape[0] - 1) // 2\n", - " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n", - " input = F.conv2d(input, kernel_w[None, None, None, :])\n", - "\n", - " input = input.reshape([n, c, h, w])\n", - " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n", - "\n", - "class MakeCutouts(nn.Module):\n", - " def __init__(self, cut_size, cutn, skip_augs=False):\n", - " super().__init__()\n", - " self.cut_size = cut_size\n", - " self.cutn = cutn\n", - " self.skip_augs = skip_augs\n", - " self.augs = T.Compose([\n", - " 
T.RandomHorizontalFlip(p=0.5),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomGrayscale(p=0.15),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", - " ])\n", - "\n", - " def forward(self, input):\n", - " input = T.Pad(input.shape[2]//4, fill=0)(input)\n", - " sideY, sideX = input.shape[2:4]\n", - " max_size = min(sideX, sideY)\n", - "\n", - " cutouts = []\n", - " for ch in range(self.cutn):\n", - " if ch > self.cutn - self.cutn//4:\n", - " cutout = input.clone()\n", - " else:\n", - " size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n", - " offsetx = torch.randint(0, abs(sideX - size + 1), ())\n", - " offsety = torch.randint(0, abs(sideY - size + 1), ())\n", - " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", - "\n", - " if not self.skip_augs:\n", - " cutout = self.augs(cutout)\n", - " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n", - " del cutout\n", - "\n", - " cutouts = torch.cat(cutouts, dim=0)\n", - " return cutouts\n", + "from disco.common import (\n", + " interp,\n", + " perlin,\n", + " perlin_ms,\n", + " create_perlin_noise,\n", + " regen_perlin,\n", + " fetch,\n", + " read_image_workaround,\n", + " parse_prompt,\n", + " sinc,\n", + " lanczos,\n", + " ramp,\n", + " resample,\n", + " MakeCutouts,\n", + " MakeCutoutsDango,\n", + " spherical_dist_loss,\n", + " tv_loss,\n", + " range_loss,\n", + ")\n", + "\n", + "from disco.models.modules import alpha_sigma_to_t\n", "\n", "cutout_debug = False\n", "padargs = {}\n", - "\n", - "class MakeCutoutsDango(nn.Module):\n", - " def __init__(self, cut_size,\n", - " Overview=4, \n", - " InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n", - " ):\n", - " super().__init__()\n", - " self.cut_size = cut_size\n", - " self.Overview = Overview\n", - " self.InnerCrop = InnerCrop\n", - " self.IC_Size_Pow = IC_Size_Pow\n", - " self.IC_Grey_P = IC_Grey_P\n", - " if args.animation_mode == 'None':\n", - " self.augs = T.Compose([\n", - " T.RandomHorizontalFlip(p=0.5),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomGrayscale(p=0.1),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", - " ])\n", - " elif args.animation_mode == 'Video Input':\n", - " self.augs = T.Compose([\n", - " T.RandomHorizontalFlip(p=0.5),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomGrayscale(p=0.15),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n", - " ])\n", - " elif args.animation_mode == '2D' or args.animation_mode == '3D':\n", - " self.augs = T.Compose([\n", - " T.RandomHorizontalFlip(p=0.4),\n", - " 
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.RandomGrayscale(p=0.1),\n", - " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n", - " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n", - " ])\n", - " \n", - "\n", - " def forward(self, input):\n", - " cutouts = []\n", - " gray = T.Grayscale(3)\n", - " sideY, sideX = input.shape[2:4]\n", - " max_size = min(sideX, sideY)\n", - " min_size = min(sideX, sideY, self.cut_size)\n", - " l_size = max(sideX, sideY)\n", - " output_shape = [1,3,self.cut_size,self.cut_size] \n", - " output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n", - " pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n", - " cutout = resize(pad_input, out_shape=output_shape)\n", - "\n", - " if self.Overview>0:\n", - " if self.Overview<=4:\n", - " if self.Overview>=1:\n", - " cutouts.append(cutout)\n", - " if self.Overview>=2:\n", - " cutouts.append(gray(cutout))\n", - " if self.Overview>=3:\n", - " cutouts.append(TF.hflip(cutout))\n", - " if self.Overview==4:\n", - " cutouts.append(gray(TF.hflip(cutout)))\n", - " else:\n", - " cutout = resize(pad_input, out_shape=output_shape)\n", - " for _ in range(self.Overview):\n", - " cutouts.append(cutout)\n", - "\n", - " if cutout_debug:\n", - " if is_colab:\n", - " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n", - " else:\n", - " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n", - "\n", - " \n", - " if self.InnerCrop >0:\n", - " for i in range(self.InnerCrop):\n", - " size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n", - " offsetx = torch.randint(0, sideX - size + 1, ())\n", - " offsety = torch.randint(0, sideY - size + 1, ())\n", - " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n", - " if i <= int(self.IC_Grey_P * self.InnerCrop):\n", - " cutout = gray(cutout)\n", - " cutout = resize(cutout, out_shape=output_shape)\n", - " cutouts.append(cutout)\n", - " if cutout_debug:\n", - " if is_colab:\n", - " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n", - " else:\n", - " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n", - " cutouts = torch.cat(cutouts)\n", - " if skip_augs is not True: cutouts=self.augs(cutouts)\n", - " return cutouts\n", - "\n", - "def spherical_dist_loss(x, y):\n", - " x = F.normalize(x, dim=-1)\n", - " y = F.normalize(y, dim=-1)\n", - " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n", - "\n", - "def tv_loss(input):\n", - " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n", - " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n", - " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n", - " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n", - " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n", - "\n", - "\n", - "def range_loss(input):\n", - " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n", - "\n", "stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n", "\n", "def do_run():\n", @@ -1120,6 +875,7 @@ " \n", " cur_t = None\n", " \n", + " # this prob 
doesn't need to be a closure...\n", " def cond_fn(x, t, y=None):\n", " with torch.enable_grad():\n", " x_is_NaN = False\n", @@ -1150,7 +906,14 @@ "\n", " cuts = MakeCutoutsDango(input_resolution,\n", " Overview= args.cut_overview[1000-t_int], \n", - " InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n", + " InnerCrop = args.cut_innercut[1000-t_int], \n", + " IC_Size_Pow=args.cut_ic_pow, \n", + " IC_Grey_P = args.cut_icgray_p[1000-t_int],\n", + " cutout_debug = cutout_debug,\n", + " padargs = padargs,\n", + " animation_mode = animation_mode, # args.animation_mode \n", + " debug_outpath = \"./\", # to do: this should be conditional on `is_colab``\n", + " skip_augs=skip_augs,\n", " )\n", " clip_in = normalize(cuts(x_in.add(1).div(2)))\n", " image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n", @@ -1384,167 +1147,7 @@ "source": [ "#@title 1.6 Define the secondary diffusion model\n", "\n", - "def append_dims(x, n):\n", - " return x[(Ellipsis, *(None,) * (n - x.ndim))]\n", - "\n", - "\n", - "def expand_to_planes(x, shape):\n", - " return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n", - "\n", - "\n", - "def alpha_sigma_to_t(alpha, sigma):\n", - " return torch.atan2(sigma, alpha) * 2 / math.pi\n", - "\n", - "\n", - "def t_to_alpha_sigma(t):\n", - " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n", - "\n", - "\n", - "@dataclass\n", - "class DiffusionOutput:\n", - " v: torch.Tensor\n", - " pred: torch.Tensor\n", - " eps: torch.Tensor\n", - "\n", - "\n", - "class ConvBlock(nn.Sequential):\n", - " def __init__(self, c_in, c_out):\n", - " super().__init__(\n", - " nn.Conv2d(c_in, c_out, 3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " )\n", - "\n", - "\n", - "class SkipBlock(nn.Module):\n", - " def __init__(self, main, skip=None):\n", - " super().__init__()\n", - " self.main = nn.Sequential(*main)\n", - " self.skip = skip if skip else nn.Identity()\n", - "\n", - " def forward(self, input):\n", - " return torch.cat([self.main(input), self.skip(input)], dim=1)\n", - "\n", - "\n", - "class FourierFeatures(nn.Module):\n", - " def __init__(self, in_features, out_features, std=1.):\n", - " super().__init__()\n", - " assert out_features % 2 == 0\n", - " self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n", - "\n", - " def forward(self, input):\n", - " f = 2 * math.pi * input @ self.weight.T\n", - " return torch.cat([f.cos(), f.sin()], dim=-1)\n", - "\n", - "\n", - "class SecondaryDiffusionImageNet(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " c = 64 # The base channel count\n", - "\n", - " self.timestep_embed = FourierFeatures(1, 16)\n", - "\n", - " self.net = nn.Sequential(\n", - " ConvBlock(3 + 16, c),\n", - " ConvBlock(c, c),\n", - " SkipBlock([\n", - " nn.AvgPool2d(2),\n", - " ConvBlock(c, c * 2),\n", - " ConvBlock(c * 2, c * 2),\n", - " SkipBlock([\n", - " nn.AvgPool2d(2),\n", - " ConvBlock(c * 2, c * 4),\n", - " ConvBlock(c * 4, c * 4),\n", - " SkipBlock([\n", - " nn.AvgPool2d(2),\n", - " ConvBlock(c * 4, c * 8),\n", - " ConvBlock(c * 8, c * 4),\n", - " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", - " ]),\n", - " ConvBlock(c * 8, c * 4),\n", - " ConvBlock(c * 4, c * 2),\n", - " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", - " ]),\n", - " ConvBlock(c * 4, c * 2),\n", - " ConvBlock(c * 2, c),\n", - " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n", - " 
]),\n", - " ConvBlock(c * 2, c),\n", - " nn.Conv2d(c, 3, 3, padding=1),\n", - " )\n", - "\n", - " def forward(self, input, t):\n", - " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n", - " v = self.net(torch.cat([input, timestep_embed], dim=1))\n", - " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n", - " pred = input * alphas - v * sigmas\n", - " eps = input * sigmas + v * alphas\n", - " return DiffusionOutput(v, pred, eps)\n", - "\n", - "\n", - "class SecondaryDiffusionImageNet2(nn.Module):\n", - " def __init__(self):\n", - " super().__init__()\n", - " c = 64 # The base channel count\n", - " cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n", - "\n", - " self.timestep_embed = FourierFeatures(1, 16)\n", - " self.down = nn.AvgPool2d(2)\n", - " self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n", - "\n", - " self.net = nn.Sequential(\n", - " ConvBlock(3 + 16, cs[0]),\n", - " ConvBlock(cs[0], cs[0]),\n", - " SkipBlock([\n", - " self.down,\n", - " ConvBlock(cs[0], cs[1]),\n", - " ConvBlock(cs[1], cs[1]),\n", - " SkipBlock([\n", - " self.down,\n", - " ConvBlock(cs[1], cs[2]),\n", - " ConvBlock(cs[2], cs[2]),\n", - " SkipBlock([\n", - " self.down,\n", - " ConvBlock(cs[2], cs[3]),\n", - " ConvBlock(cs[3], cs[3]),\n", - " SkipBlock([\n", - " self.down,\n", - " ConvBlock(cs[3], cs[4]),\n", - " ConvBlock(cs[4], cs[4]),\n", - " SkipBlock([\n", - " self.down,\n", - " ConvBlock(cs[4], cs[5]),\n", - " ConvBlock(cs[5], cs[5]),\n", - " ConvBlock(cs[5], cs[5]),\n", - " ConvBlock(cs[5], cs[4]),\n", - " self.up,\n", - " ]),\n", - " ConvBlock(cs[4] * 2, cs[4]),\n", - " ConvBlock(cs[4], cs[3]),\n", - " self.up,\n", - " ]),\n", - " ConvBlock(cs[3] * 2, cs[3]),\n", - " ConvBlock(cs[3], cs[2]),\n", - " self.up,\n", - " ]),\n", - " ConvBlock(cs[2] * 2, cs[2]),\n", - " ConvBlock(cs[2], cs[1]),\n", - " self.up,\n", - " ]),\n", - " ConvBlock(cs[1] * 2, cs[1]),\n", - " ConvBlock(cs[1], cs[0]),\n", - " self.up,\n", - " ]),\n", - " ConvBlock(cs[0] * 2, cs[0]),\n", - " nn.Conv2d(cs[0], 3, 3, padding=1),\n", - " )\n", - "\n", - " def forward(self, input, t):\n", - " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n", - " v = self.net(torch.cat([input, timestep_embed], dim=1))\n", - " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n", - " pred = input * alphas - v * sigmas\n", - " eps = input * sigmas + v * alphas\n", - " return DiffusionOutput(v, pred, eps)\n" + "from disco.models import SecondaryDiffusionImageNet, SecondaryDiffusionImageNet2" ] }, { @@ -1557,222 +1160,14 @@ "outputs": [], "source": [ "#@title 1.7 SuperRes Define\n", - "class DDIMSampler(object):\n", - " def __init__(self, model, schedule=\"linear\", **kwargs):\n", - " super().__init__()\n", - " self.model = model\n", - " self.ddpm_num_timesteps = model.num_timesteps\n", - " self.schedule = schedule\n", - "\n", - " def register_buffer(self, name, attr):\n", - " if type(attr) == torch.Tensor:\n", - " if attr.device != torch.device(\"cuda\"):\n", - " attr = attr.to(torch.device(\"cuda\"))\n", - " setattr(self, name, attr)\n", - "\n", - " def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n", - " self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n", - " num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n", - " alphas_cumprod = self.model.alphas_cumprod\n", - " assert alphas_cumprod.shape[0] == 
self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n", - " to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n", - "\n", - " self.register_buffer('betas', to_torch(self.model.betas))\n", - " self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n", - " self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n", - "\n", - " # calculations for diffusion q(x_t | x_{t-1}) and others\n", - " self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n", - " self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n", - " self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n", - "\n", - " # ddim sampling parameters\n", - " ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n", - " ddim_timesteps=self.ddim_timesteps,\n", - " eta=ddim_eta,verbose=verbose)\n", - " self.register_buffer('ddim_sigmas', ddim_sigmas)\n", - " self.register_buffer('ddim_alphas', ddim_alphas)\n", - " self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n", - " self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n", - " sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n", - " (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n", - " 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n", - " self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n", - "\n", - " @torch.no_grad()\n", - " def sample(self,\n", - " S,\n", - " batch_size,\n", - " shape,\n", - " conditioning=None,\n", - " callback=None,\n", - " normals_sequence=None,\n", - " img_callback=None,\n", - " quantize_x0=False,\n", - " eta=0.,\n", - " mask=None,\n", - " x0=None,\n", - " temperature=1.,\n", - " noise_dropout=0.,\n", - " score_corrector=None,\n", - " corrector_kwargs=None,\n", - " verbose=True,\n", - " x_T=None,\n", - " log_every_t=100,\n", - " **kwargs\n", - " ):\n", - " if conditioning is not None:\n", - " if isinstance(conditioning, dict):\n", - " cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n", - " if cbs != batch_size:\n", - " print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n", - " else:\n", - " if conditioning.shape[0] != batch_size:\n", - " print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n", - "\n", - " self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n", - " # sampling\n", - " C, H, W = shape\n", - " size = (batch_size, C, H, W)\n", - " # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n", - "\n", - " samples, intermediates = self.ddim_sampling(conditioning, size,\n", - " callback=callback,\n", - " img_callback=img_callback,\n", - " quantize_denoised=quantize_x0,\n", - " mask=mask, x0=x0,\n", - " ddim_use_original_steps=False,\n", - " noise_dropout=noise_dropout,\n", - " temperature=temperature,\n", - " score_corrector=score_corrector,\n", - " corrector_kwargs=corrector_kwargs,\n", - " x_T=x_T,\n", - " log_every_t=log_every_t\n", - " )\n", - " return samples, intermediates\n", - "\n", - " @torch.no_grad()\n", - " def ddim_sampling(self, cond, shape,\n", - " x_T=None, 
ddim_use_original_steps=False,\n", - " callback=None, timesteps=None, quantize_denoised=False,\n", - " mask=None, x0=None, img_callback=None, log_every_t=100,\n", - " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", - " device = self.model.betas.device\n", - " b = shape[0]\n", - " if x_T is None:\n", - " img = torch.randn(shape, device=device)\n", - " else:\n", - " img = x_T\n", - "\n", - " if timesteps is None:\n", - " timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n", - " elif timesteps is not None and not ddim_use_original_steps:\n", - " subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n", - " timesteps = self.ddim_timesteps[:subset_end]\n", - "\n", - " intermediates = {'x_inter': [img], 'pred_x0': [img]}\n", - " time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n", - " total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n", - " print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n", - "\n", - " iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n", - "\n", - " for i, step in enumerate(iterator):\n", - " index = total_steps - i - 1\n", - " ts = torch.full((b,), step, device=device, dtype=torch.long)\n", - "\n", - " if mask is not None:\n", - " assert x0 is not None\n", - " img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n", - " img = img_orig * mask + (1. - mask) * img\n", - "\n", - " outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n", - " quantize_denoised=quantize_denoised, temperature=temperature,\n", - " noise_dropout=noise_dropout, score_corrector=score_corrector,\n", - " corrector_kwargs=corrector_kwargs)\n", - " img, pred_x0 = outs\n", - " if callback: callback(i)\n", - " if img_callback: img_callback(pred_x0, i)\n", - "\n", - " if index % log_every_t == 0 or index == total_steps - 1:\n", - " intermediates['x_inter'].append(img)\n", - " intermediates['pred_x0'].append(pred_x0)\n", - "\n", - " return img, intermediates\n", - "\n", - " @torch.no_grad()\n", - " def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n", - " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n", - " b, *_, device = *x.shape, x.device\n", - " e_t = self.model.apply_model(x, t, c)\n", - " if score_corrector is not None:\n", - " assert self.model.parameterization == \"eps\"\n", - " e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n", - "\n", - " alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n", - " alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n", - " sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n", - " sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n", - " # select parameters corresponding to the currently considered timestep\n", - " a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n", - " a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n", - " sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n", - " sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n", - "\n", - " # current prediction for 
x_0\n", - " pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n", - " if quantize_denoised:\n", - " pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n", - " # direction pointing to x_t\n", - " dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n", - " noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n", - " if noise_dropout > 0.:\n", - " noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n", - " x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n", - " return x_prev, pred_x0\n", - "\n", - "\n", - "def download_models(mode):\n", - "\n", - " if mode == \"superresolution\":\n", - " # this is the small bsr light model\n", - " url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n", - " url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n", - "\n", - " path_conf = f'{model_path}/superres/project.yaml'\n", - " path_ckpt = f'{model_path}/superres/last.ckpt'\n", - "\n", - " download_url(url_conf, path_conf)\n", - " download_url(url_ckpt, path_ckpt)\n", - "\n", - " path_conf = path_conf + '/?dl=1' # fix it\n", - " path_ckpt = path_ckpt + '/?dl=1' # fix it\n", - " return path_conf, path_ckpt\n", - "\n", - " else:\n", - " raise NotImplementedError\n", - "\n", - "\n", - "def load_model_from_config(config, ckpt):\n", - " print(f\"Loading model from {ckpt}\")\n", - " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n", - " global_step = pl_sd[\"global_step\"]\n", - " sd = pl_sd[\"state_dict\"]\n", - " model = instantiate_from_config(config.model)\n", - " m, u = model.load_state_dict(sd, strict=False)\n", - " model.cuda()\n", - " model.eval()\n", - " return {\"model\": model}, global_step\n", - "\n", - "\n", - "def get_model(mode):\n", - " path_conf, path_ckpt = download_models(mode)\n", - " config = OmegaConf.load(path_conf)\n", - " model, step = load_model_from_config(config, path_ckpt)\n", - " return model\n", "\n", + "from disco.models import DDIMSampler\n", + "from disco.models.ddimsampler import (\n", + " download_models,\n", + " load_model_from_config,\n", + " get_model,\n", + " \n", + ")\n", "\n", "def get_custom_cond(mode):\n", " dest = \"data/example_conditioning\"\n", @@ -1989,7 +1384,7 @@ " return log\n", "\n", "sr_diffMode = 'superresolution'\n", - "sr_model = get_model('superresolution')\n", + "sr_model = get_model('superresolution', model_path)\n", "\n", "\n", "\n", @@ -2419,238 +1814,46 @@ "#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n", "frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n", "\n", - "\n", - "def parse_key_frames(string, prompt_parser=None):\n", - " \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n", - " return a dictionary with the frame numbers as keys and the parameter values as the values.\n", - "\n", - " Parameters\n", - " ----------\n", - " string: string\n", - " Frame numbers paired with parameter values at that frame number, in the format\n", - " 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n", - " prompt_parser: function or None, optional\n", - " If provided, prompt_parser will be applied to each string of parameter values.\n", - " \n", - " Returns\n", - " -------\n", - " dict\n", - " Frame numbers as keys, parameter values at that frame number as values\n", - "\n", - " Raises\n", - " ------\n", - " RuntimeError\n", - " If the input string does not match the 
expected format.\n", - " \n", - " Examples\n", - " --------\n", - " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n", - " {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n", - "\n", - " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n", - " {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n", - " \"\"\"\n", - " import re\n", - " pattern = r'((?P[0-9]+):[\\s]*[\\(](?P[\\S\\s]*?)[\\)])'\n", - " frames = dict()\n", - " for match_object in re.finditer(pattern, string):\n", - " frame = int(match_object.groupdict()['frame'])\n", - " param = match_object.groupdict()['param']\n", - " if prompt_parser:\n", - " frames[frame] = prompt_parser(param)\n", - " else:\n", - " frames[frame] = param\n", - "\n", - " if frames == {} and len(string) != 0:\n", - " raise RuntimeError('Key Frame string not correctly formatted')\n", - " return frames\n", - "\n", - "def get_inbetweens(key_frames, integer=False):\n", - " \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n", - " return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n", - " Any values not provided in the input dict are calculated by linear interpolation between\n", - " the values of the previous and next provided frames. If there is no previous provided frame, then\n", - " the value is equal to the value of the next provided frame, or if there is no next provided frame,\n", - " then the value is equal to the value of the previous provided frame. If no frames are provided,\n", - " all frame values are NaN.\n", - "\n", - " Parameters\n", - " ----------\n", - " key_frames: dict\n", - " A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n", - " integer: Bool, optional\n", - " If True, the values of the output series are converted to integers.\n", - " Otherwise, the values are floats.\n", - " \n", - " Returns\n", - " -------\n", - " pd.Series\n", - " A Series with length max_frames representing the parameter values for each frame.\n", - " \n", - " Examples\n", - " --------\n", - " >>> max_frames = 5\n", - " >>> get_inbetweens({1: 5, 3: 6})\n", - " 0 5.0\n", - " 1 5.0\n", - " 2 5.5\n", - " 3 6.0\n", - " 4 6.0\n", - " dtype: float64\n", - "\n", - " >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n", - " 0 5\n", - " 1 5\n", - " 2 5\n", - " 3 6\n", - " 4 6\n", - " dtype: int64\n", - " \"\"\"\n", - " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n", - "\n", - " for i, value in key_frames.items():\n", - " key_frame_series[i] = value\n", - " key_frame_series = key_frame_series.astype(float)\n", - " \n", - " interp_method = interp_spline\n", - "\n", - " if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n", - " interp_method = 'Quadratic'\n", - " \n", - " if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n", - " interp_method = 'Linear'\n", - " \n", - " \n", - " key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n", - " key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n", - " # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n", - " key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n", - " if integer:\n", - " return key_frame_series.astype(int)\n", - " return key_frame_series\n", - "\n", 
- "def split_prompts(prompts):\n", - " prompt_series = pd.Series([np.nan for a in range(max_frames)])\n", - " for i, prompt in prompts.items():\n", - " prompt_series[i] = prompt\n", - " # prompt_series = prompt_series.astype(str)\n", - " prompt_series = prompt_series.ffill().bfill()\n", - " return prompt_series\n", + "from disco.animation import (\n", + " parse_key_frames,\n", + " get_inbetweens,\n", + " split_prompts,\n", + " process_keyframe_animation,\n", + ")\n", "\n", "if key_frames:\n", - " try:\n", - " angle_series = get_inbetweens(parse_key_frames(angle))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `angle` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `angle` as \"\n", - " f'\"0: ({angle})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " angle = f\"0: ({angle})\"\n", - " angle_series = get_inbetweens(parse_key_frames(angle))\n", - "\n", - " try:\n", - " zoom_series = get_inbetweens(parse_key_frames(zoom))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `zoom` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `zoom` as \"\n", - " f'\"0: ({zoom})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " zoom = f\"0: ({zoom})\"\n", - " zoom_series = get_inbetweens(parse_key_frames(zoom))\n", - "\n", - " try:\n", - " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `translation_x` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `translation_x` as \"\n", - " f'\"0: ({translation_x})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " translation_x = f\"0: ({translation_x})\"\n", - " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n", - "\n", - " try:\n", - " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `translation_y` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `translation_y` as \"\n", - " f'\"0: ({translation_y})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " translation_y = f\"0: ({translation_y})\"\n", - " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n", - "\n", - " try:\n", - " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `translation_z` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `translation_z` as \"\n", - " f'\"0: ({translation_z})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " translation_z = f\"0: ({translation_z})\"\n", - " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n", - "\n", - " try:\n", - " rotation_3d_x_series = 
get_inbetweens(parse_key_frames(rotation_3d_x))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `rotation_3d_x` as \"\n", - " f'\"0: ({rotation_3d_x})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " rotation_3d_x = f\"0: ({rotation_3d_x})\"\n", - " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n", - "\n", - " try:\n", - " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `rotation_3d_y` as \"\n", - " f'\"0: ({rotation_3d_y})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " rotation_3d_y = f\"0: ({rotation_3d_y})\"\n", - " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n", - "\n", - " try:\n", - " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n", - " except RuntimeError as e:\n", - " print(\n", - " \"WARNING: You have selected to use key frames, but you have not \"\n", - " \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n", - " \"Attempting to interpret `rotation_3d_z` as \"\n", - " f'\"0: ({rotation_3d_z})\"\\n'\n", - " \"Please read the instructions to find out how to use key frames \"\n", - " \"correctly.\\n\"\n", - " )\n", - " rotation_3d_z = f\"0: ({rotation_3d_z})\"\n", - " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n", + " processed = process_keyframe_animation(\n", + " angle,\n", + " zoom,\n", + " translation_x,\n", + " translation_y,\n", + " translation_z,\n", + " rotation_3d_x,\n", + " rotation_3d_y,\n", + " rotation_3d_z,\n", + " key_frames,\n", + " interp_spline,\n", + " max_frames,\n", + " )\n", "\n", - "else:\n", + " (angle_series,\n", + " zoom_series,\n", + " translation_x_series,\n", + " translation_y_series,\n", + " translation_z_series,\n", + " rotation_3d_x_series,\n", + " rotation_3d_y_series,\n", + " rotation_3d_z_series) = (\n", + " processed['angle_series'],\n", + " processed['zoom_series'],\n", + " processed['translation_x_series'],\n", + " processed['translation_y_series'],\n", + " processed['translation_z_series'],\n", + " processed['rotation_3d_x_series'],\n", + " processed['rotation_3d_y_series'],\n", + " processed['rotation_3d_z_series'],\n", + " )\n", + "else: \n", " angle = float(angle)\n", " zoom = float(zoom)\n", " translation_x = float(translation_x)\n", @@ -2658,7 +1861,7 @@ " translation_z = float(translation_z)\n", " rotation_3d_x = float(rotation_3d_x)\n", " rotation_3d_y = float(rotation_3d_y)\n", - " rotation_3d_z = float(rotation_3d_z)" + " rotation_3d_z = float(rotation_3d_z)\n" ] }, { @@ -2865,8 +2068,8 @@ "\n", "args = {\n", " 'batchNum': batchNum,\n", - " 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n", - " 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n", + " 'prompts_series':split_prompts(text_prompts, max_frames) if text_prompts else None,\n", + " 'image_prompts_series':split_prompts(image_prompts, max_frames) if image_prompts else None,\n", " 'seed': seed,\n", " 'display_rate':display_rate,\n", " 
'n_batches':n_batches if animation_mode == 'None' else 1,\n", @@ -3073,11 +2276,11 @@ "CnkTNXJAPzL2", "u1VHzHvNx5fd" ], + "include_colab_link": true, "machine_shape": "hm", "name": "Disco Diffusion v5 [w/ 3D animation]", "private_outputs": true, - "provenance": [], - "include_colab_link": true + "provenance": [] }, "kernelspec": { "display_name": "Python 3", @@ -3098,4 +2301,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..7fd26b97 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..13f9fc09 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[metadata] +name = pyttitools-disco +version = 0.0.1 + +[options] +install_requires = + torch >= 1.10 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..02d47a56 --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages, find_namespace_packages +import logging + +p0 = find_packages(where="src") +p2 = find_namespace_packages( + where="src", + #include=["hydra_plugins.*"], +) + +setup( + packages=p0 + p2, + package_dir={ + "": "src", + }, + #install_requires=["pyttitools-adabins", "pyttitools-gma"], +) \ No newline at end of file diff --git a/src/disco/__init__.py b/src/disco/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/disco/animation.py b/src/disco/animation.py new file mode 100644 index 00000000..453be07f --- /dev/null +++ b/src/disco/animation.py @@ -0,0 +1,214 @@ +""" +Isolating animation stuff from disco, probably port it into pytti at some point +""" + +from loguru import logger +import numpy as np +import pandas as pd + +def parse_key_frames(string, prompt_parser=None): + """Given a string representing frame numbers paired with parameter values at that frame, + return a dictionary with the frame numbers as keys and the parameter values as the values. + + Parameters + ---------- + string: string + Frame numbers paired with parameter values at that frame number, in the format + 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...' + prompt_parser: function or None, optional + If provided, prompt_parser will be applied to each string of parameter values. + + Returns + ------- + dict + Frame numbers as keys, parameter values at that frame number as values + + Raises + ------ + RuntimeError + If the input string does not match the expected format. 
+ + Examples + -------- + >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)") + {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'} + + >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())) + {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'} + """ + import re + pattern = r'((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])' + frames = dict() + for match_object in re.finditer(pattern, string): + frame = int(match_object.groupdict()['frame']) + param = match_object.groupdict()['param'] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + + if frames == {} and len(string) != 0: + raise RuntimeError('Key Frame string not correctly formatted') + return frames + +def get_inbetweens( + key_frames, + integer=False, + # new args + interp_spline=None, + max_frames=None, + ): + """Given a dict with frame numbers as keys and a parameter value as values, + return a pandas Series containing the value of the parameter at every frame from 0 to max_frames. + Any values not provided in the input dict are calculated by linear interpolation between + the values of the previous and next provided frames. If there is no previous provided frame, then + the value is equal to the value of the next provided frame, or if there is no next provided frame, + then the value is equal to the value of the previous provided frame. If no frames are provided, + all frame values are NaN. + + Parameters + ---------- + key_frames: dict + A dict with integer frame numbers as keys and numerical values of a particular parameter as values. + integer: Bool, optional + If True, the values of the output series are converted to integers. + Otherwise, the values are floats. + + Returns + ------- + pd.Series + A Series with length max_frames representing the parameter values for each frame. + + Examples + -------- + >>> max_frames = 5 + >>> get_inbetweens({1: 5, 3: 6}) + 0 5.0 + 1 5.0 + 2 5.5 + 3 6.0 + 4 6.0 + dtype: float64 + + >>> get_inbetweens({1: 5, 3: 6}, integer=True) + 0 5 + 1 5 + 2 5 + 3 6 + 4 6 + dtype: int64 + """ + key_frame_series = pd.Series([np.nan for a in range(max_frames)]) + + for i, value in key_frames.items(): + key_frame_series[i] = value + key_frame_series = key_frame_series.astype(float) + + interp_method = interp_spline + + if interp_method == 'Cubic' and len(key_frames.items()) <=3: + interp_method = 'Quadratic' + + if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: + interp_method = 'Linear' + + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] + # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both') + key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both') + if integer: + return key_frame_series.astype(int) + return key_frame_series + +def split_prompts( + prompts, + # new args + max_frames, + ): + """ + The function takes in a dictionary of prompt indices and their corresponding prompts. + It then creates a pandas series of nans of length max_frames. + It then iterates through the dictionary and fills in the series with the prompts corresponding to + the indices. + It then backfills and forwards fills to ensure that the series is complete. 
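+
+    A minimal doctest-style sketch of the fill behaviour (illustrative only;
+    it assumes plain string prompts and max_frames=4, whereas the notebook may
+    pass lists of strings as prompt values):
+
+    >>> split_prompts({0: 'a cat', 2: 'a dog'}, max_frames=4)
+    0    a cat
+    1    a cat
+    2    a dog
+    3    a dog
+    dtype: object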
+ + :param prompts: a dictionary of prompt_id -> prompt_text + :return: A pandas series with the prompts for each frame. + """ + prompt_series = pd.Series([np.nan for a in range(max_frames)]) + for i, prompt in prompts.items(): + prompt_series[i] = prompt + # prompt_series = prompt_series.astype(str) + prompt_series = prompt_series.ffill().bfill() + return prompt_series + +def process_keyframe_animation( + angle, + zoom, + translation_x, + translation_y, + translation_z, + rotation_3d_x, + rotation_3d_y, + rotation_3d_z, + key_frames=None, + # new args + interp_spline=None, + max_frames=None, +): + """ + Given a dictionary of keyframes, return a dictionary of interpolated values + + :param angle: The angle of the camera in degrees + :param zoom: The zoom level of the camera + :param translation_x: The x-axis translation of the camera + :param translation_y: "0: (0), 1: (0), 2: (0), 3: (0), 4: (0), 5: (0), 6: (0), 7: (0), 8: (0), 9: + (0), 10: (0), 11: (0 + :param translation_z: The distance between the camera and the object + :param rotation_3d_x: The rotation of the camera around the x axis + :param rotation_3d_y: The rotation of the camera around the y axis + :param rotation_3d_z: The rotation of the 3D plot around the z axis + :param key_frames: A list of key frames + :return: A dictionary of keyframes. + """ + assert key_frames is not None + outv = {} + for k,v in { + 'angle':angle, + 'zoom':zoom, + 'translation_x':translation_x, + 'translation_y':translation_y, + 'translation_z':translation_z, + 'rotation_3d_x':rotation_3d_x, + 'rotation_3d_y':rotation_3d_y, + 'rotation_3d_z':rotation_3d_z, + }.items(): + try: + outv[k+'_series'] = get_inbetweens( + key_frames=parse_key_frames(v), + integer=False, + # new args + interp_spline=interp_spline, + max_frames=max_frames, + ) + except RuntimeError: + logger.warning( + "WARNING: You have selected to use key frames, but you have not " + f"formatted `{k}` correctly for key frames.\n" + "Attempting to interpret `k` as " + f'"0: ({v})"\n' + "Please read the instructions to find out how to use key frames " + "correctly.\n" + ) + v = f"0: ({v})" + outv[k+'_series'] = get_inbetweens( + key_frames=parse_key_frames(v), + integer=False, + # new args + interp_spline=interp_spline, + max_frames=max_frames, + + ) + return outv + diff --git a/src/disco/common.py b/src/disco/common.py new file mode 100644 index 00000000..b8d91f0c --- /dev/null +++ b/src/disco/common.py @@ -0,0 +1,335 @@ +""" +isolating stuff that's not specific to disco. 
+at some point, this stuff should get imported from pytti-core +""" +import cv2 +import io +import math +import requests +import sys + +from PIL import ImageOps +from resize_right import resize +from torch import nn +from torch.nn import functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869 +import pytorch3d.transforms as p3dT +import disco.disco_xform_utils as dxf + +def interp(t): + return 3 * t**2 - 2 * t ** 3 + +def perlin(width, height, scale=10, device=None): + gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device) + xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device) + ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device) + wx = 1 - interp(xs) + wy = 1 - interp(ys) + dots = 0 + dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) + dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) + dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) + dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)) + return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale) + +def perlin_ms(octaves, width, height, grayscale, device): + out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] + # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0] + for i in range(1 if grayscale else 3): + scale = 2 ** len(octaves) + oct_width = width + oct_height = height + for oct in octaves: + p = perlin(oct_width, oct_height, scale, device) + out_array[i] += p * oct + scale //= 2 + oct_width *= 2 + oct_height *= 2 + return torch.cat(out_array) + +def create_perlin_noise( + octaves=[1, 1, 1, 1], + width=2, + height=2, + grayscale=True, + # new args + side_y=None, + side_x=None, + ): + out = perlin_ms(octaves, width, height, grayscale) + if grayscale: + out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0)) + out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB') + else: + out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1]) + out = TF.resize(size=(side_y, side_x), img=out) + out = TF.to_pil_image(out.clamp(0, 1).squeeze()) + + out = ImageOps.autocontrast(out) + return out + +def regen_perlin( + # new args + perlin_mode, + batch_size, + device, +): + if perlin_mode == 'color': + init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) + init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False) + elif perlin_mode == 'gray': + init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True) + init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) + else: + init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False) + init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True) + + init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1) + del init2 + return init.expand(batch_size, -1, -1, -1) + +def fetch(url_or_path): + if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'): + r = requests.get(url_or_path) + r.raise_for_status() + fd = io.BytesIO() + fd.write(r.content) + fd.seek(0) + return fd + return open(url_or_path, 'rb') + +def read_image_workaround(path): + """OpenCV reads images as BGR, Pillow saves them as RGB. 
Work around + this incompatibility to avoid colour inversions.""" + im_tmp = cv2.imread(path) + return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB) + +def parse_prompt(prompt): + if prompt.startswith('http://') or prompt.startswith('https://'): + vals = prompt.rsplit(':', 2) + vals = [vals[0] + ':' + vals[1], *vals[2:]] + else: + vals = prompt.rsplit(':', 1) + vals = vals + ['', '1'][len(vals):] + return vals[0], float(vals[1]) + +def sinc(x): + return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) + +def lanczos(x, a): + cond = torch.logical_and(-a < x, x < a) + out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) + return out / out.sum() + +def ramp(ratio, width): + n = math.ceil(width / ratio + 1) + out = torch.empty([n]) + cur = 0 + for i in range(out.shape[0]): + out[i] = cur + cur += ratio + return torch.cat([-out[1:].flip([0]), out])[1:-1] + +def resample(input, size, align_corners=True): + n, c, h, w = input.shape + dh, dw = size + + input = input.reshape([n * c, 1, h, w]) + + if dh < h: + kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) + pad_h = (kernel_h.shape[0] - 1) // 2 + input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') + input = F.conv2d(input, kernel_h[None, None, :, None]) + + if dw < w: + kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) + pad_w = (kernel_w.shape[0] - 1) // 2 + input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') + input = F.conv2d(input, kernel_w[None, None, None, :]) + + input = input.reshape([n, c, h, w]) + return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, skip_augs=False): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.skip_augs = skip_augs + self.augs = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=15, translate=(0.1, 0.1)), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomPerspective(distortion_scale=0.4, p=0.7), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.15), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ]) + + def forward(self, input): + input = T.Pad(input.shape[2]//4, fill=0)(input) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + + cutouts = [] + for ch in range(self.cutn): + if ch > self.cutn - self.cutn//4: + cutout = input.clone() + else: + size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.)) + offsetx = torch.randint(0, abs(sideX - size + 1), ()) + offsety = torch.randint(0, abs(sideY - size + 1), ()) + cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] + + if not self.skip_augs: + cutout = self.augs(cutout) + cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) + del cutout + + cutouts = torch.cat(cutouts, dim=0) + return cutouts + +#cutout_debug = False +#padargs = {} + +class MakeCutoutsDango(nn.Module): + def __init__(self, cut_size, + Overview=4, + InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2, + # new args + cutout_debug = False, + padargs = None, + animation_mode = None, # args.animation_mode + debug_outpath = "./", + skip_augs=False, + ): + super().__init__() + self.cut_size = cut_size + self.Overview = Overview + self.InnerCrop = InnerCrop + self.IC_Size_Pow = IC_Size_Pow + self.IC_Grey_P = IC_Grey_P + + # Augs should be an 
argument that defaults to nn.Identity + # rather than requiring an "animation mode" which only makes sense in + # our particular use context + animation_mode = str(animation_mode) + #self.augs = nn.Identity() # don't reassign module + if (not skip_augs) and (animation_mode == 'None'): + self.augs = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ]) + elif (not skip_augs) and (animation_mode == 'Video Input'): + self.augs = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=15, translate=(0.1, 0.1)), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomPerspective(distortion_scale=0.4, p=0.7), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.15), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ]) + elif (not skip_augs) and (animation_mode in ('2D','3D')): + self.augs = T.Compose([ + T.RandomHorizontalFlip(p=0.4), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3), + ]) + else: + self.augs = nn.Identity() + + self.cutout_debug = cutout_debug + self.debug_outpath = debug_outpath + if padargs is None: + padargs = {} + self.padargs = padargs + + def forward(self, input): + cutouts = [] + gray = T.Grayscale(3) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + l_size = max(sideX, sideY) + output_shape = [1,3,self.cut_size,self.cut_size] + output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2] + pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **self.padargs) + cutout = resize(pad_input, out_shape=output_shape) + + if self.Overview>0: + if self.Overview<=4: + if self.Overview>=1: + cutouts.append(cutout) + if self.Overview>=2: + cutouts.append(gray(cutout)) + if self.Overview>=3: + cutouts.append(TF.hflip(cutout)) + if self.Overview==4: + cutouts.append(gray(TF.hflip(cutout))) + else: + cutout = resize(pad_input, out_shape=output_shape) + for _ in range(self.Overview): + cutouts.append(cutout) + + if self.cutout_debug: + TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(self.debug_outpath + "cutout_overview0.jpg",quality=99) + + if self.InnerCrop >0: + for i in range(self.InnerCrop): + size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] + if i <= int(self.IC_Grey_P * self.InnerCrop): + cutout = gray(cutout) + cutout = resize(cutout, out_shape=output_shape) + cutouts.append(cutout) + if self.cutout_debug: + TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(self.debug_outpath + "cutout_InnerCrop.jpg",quality=99) + cutouts 
= torch.cat(cutouts)
+        return cutouts
+
+def spherical_dist_loss(x, y):
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
+
+def tv_loss(input):
+    """L2 total variation loss, as in Mahendran et al."""
+    input = F.pad(input, (0, 1, 0, 1), 'replicate')
+    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+    return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+
+def range_loss(input):
+    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+
+def disable_cudnn(DEVICE):
+    if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad
+        logger.debug('Disabling CUDNN for A100 gpu')
+        torch.backends.cudnn.enabled = False
+
+def a100_cudnn_fix(DEVICE):
+    disable_cudnn(DEVICE)
\ No newline at end of file
diff --git a/disco_xform_utils.py b/src/disco/disco_xform_utils.py
similarity index 99%
rename from disco_xform_utils.py
rename to src/disco/disco_xform_utils.py
index 9d4f1b84..4fc02725 100644
--- a/disco_xform_utils.py
+++ b/src/disco/disco_xform_utils.py
@@ -6,7 +6,7 @@
 import sys, math
 
 try:
-    from infer import InferenceHelper
+    from adabins.infer import InferenceHelper
 except:
     print("disco_xform_utils.py failed to import InferenceHelper. Please ensure that AdaBins directory is in the path (i.e. via sys.path.append('./AdaBins') or other means).")
     sys.exit()
diff --git a/src/disco/models/__init__.py b/src/disco/models/__init__.py
new file mode 100644
index 00000000..882f8a52
--- /dev/null
+++ b/src/disco/models/__init__.py
@@ -0,0 +1,3 @@
+from .secondarydiffusionimagenet import SecondaryDiffusionImageNet
+from .secondarydiffusionimagenet2 import SecondaryDiffusionImageNet2
+from .ddimsampler import DDIMSampler
\ No newline at end of file
diff --git a/src/disco/models/ddimsampler.py b/src/disco/models/ddimsampler.py
new file mode 100644
index 00000000..f874c156
--- /dev/null
+++ b/src/disco/models/ddimsampler.py
@@ -0,0 +1,243 @@
+import torch
+import numpy as np  # used by make_schedule / ddim_sampling below
+from loguru import logger
+from omegaconf import OmegaConf
+
+from tqdm.notebook import tqdm # let's just assume notebook for now...
+from torchvision.datasets.utils import download_url
+
+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+# from ldm.models.diffusion.ddim import DDIMSampler
+
+# model_path is now an explicit argument (it was a notebook global)
+def download_models(mode, model_path):
+    """
+    Downloads the models from the internet and saves them to the specified path
+
+    :param mode: the model you want to download
+    :param model_path: directory prefix under which the config and checkpoint are saved
+    :return: The paths to the config and checkpoint files.
+    """
+    if mode == "superresolution":
+        # this is the small bsr light model
+        url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
+        url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
+
+        path_conf = f'{model_path}/superres/project.yaml'
+        path_ckpt = f'{model_path}/superres/last.ckpt'
+
+        download_url(url_conf, path_conf)
+        download_url(url_ckpt, path_ckpt)
+
+        path_conf = path_conf + '/?dl=1' # fix it
+        path_ckpt = path_ckpt + '/?dl=1' # fix it
+        return path_conf, path_ckpt
+
+    else:
+        raise NotImplementedError
+
+
+def load_model_from_config(config, ckpt):
+    """
+    Loads a model from a checkpoint
+
+    :param config: The config object that contains the model's parameters
+    :param ckpt: the checkpoint file to load
+    :return: The model and the global step.
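+
+    Illustrative usage only (the paths follow the layout used by download_models
+    above; model_path is whatever directory the caller passes in):
+
+        config = OmegaConf.load(f"{model_path}/superres/project.yaml")
+        wrapped, step = load_model_from_config(config, f"{model_path}/superres/last.ckpt")
+        model = wrapped["model"]  # note: the model comes back wrapped in a dict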
+ """ + logger.debug(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + model.cuda() + model.eval() + return {"model": model}, global_step + + +def get_model(mode, model_path): + path_conf, path_ckpt = download_models(mode, model_path) + config = OmegaConf.load(path_conf) + model, step = load_model_from_config(config, path_ckpt) + return model + + +#@title 1.7 SuperRes Define +class DDIMSampler(object): + """ + Sampler for super resolution models + """ + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + logger.warning(f"Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + logger.warning(f"Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sharpening with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+        b, *_, device = *x.shape, x.device
+        e_t = self.model.apply_model(x, t, c)
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps"
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+        # current prediction for x_0
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
\ No newline at end of file
diff --git a/src/disco/models/modules.py b/src/disco/models/modules.py
new file mode 100644
index 00000000..a2336718
--- /dev/null
+++ b/src/disco/models/modules.py
@@ -0,0 +1,88 @@
+from dataclasses import dataclass
+import math
+
+import torch
+from torch import nn
+
+def append_dims(x, n):
+    """
+    Append trailing singleton dimensions to `x` until it has `n` dimensions
+
+    :param x: The tensor to be reshaped
+    :param n: the total number of dimensions the output should have
+    :return: a view of x with (n - x.ndim) size-1 dimensions appended at the end.
+    """
+    return x[(Ellipsis, *(None,) * (n - x.ndim))]
+
+
+def expand_to_planes(x, shape):
+    """
+    Expand a per-sample feature tensor (e.g. a timestep embedding of shape (B, C))
+    into feature planes so it can be concatenated with an image tensor
+
+    :param x: The input tensor of shape (B, C)
+    :param shape: the target shape whose trailing (spatial) dimensions are tiled
+    :return: x broadcast to shape (B, C, *shape[2:]).
+    """
+    return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
+
+
+def alpha_sigma_to_t(alpha, sigma):
+    """
+    Given the signal scale alpha and noise scale sigma, return the corresponding
+    diffusion time t
+
+    :param alpha: the signal (cosine) coefficient
+    :param sigma: the noise (sine) coefficient
+    :return: the diffusion time t in [0, 1].
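+
+    Equivalent to t = atan2(sigma, alpha) * 2 / pi, e.g. (alpha=1, sigma=0) gives
+    t=0 (no noise) and (alpha=0, sigma=1) gives t=1 (pure noise).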
+ """ + return torch.atan2(sigma, alpha) * 2 / math.pi + + +def t_to_alpha_sigma(t): + """ + Given a tensor of angles, return a tuple of two tensors of the same shape, one containing the cosine + of the angles and the other containing the sine of the angles + + :param t: the time parameter + :return: the cosine and sine of the input t multiplied by pi/2. + """ + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + +@dataclass +class DiffusionOutput: + v: torch.Tensor + pred: torch.Tensor + eps: torch.Tensor + + +class ConvBlock(nn.Sequential): + """ + 3x3 conv + ReLU + """ + def __init__(self, c_in, c_out): + super().__init__( + nn.Conv2d(c_in, c_out, 3, padding=1), + nn.ReLU(inplace=True), + ) + + +class SkipBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return torch.cat([self.main(input), self.skip(input)], dim=1) + + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) diff --git a/src/disco/models/secondarydiffusionimagenet.py b/src/disco/models/secondarydiffusionimagenet.py new file mode 100644 index 00000000..a7e986ac --- /dev/null +++ b/src/disco/models/secondarydiffusionimagenet.py @@ -0,0 +1,70 @@ +from functools import partial + +import torch +from torch import nn +from .modules import ( + FourierFeatures, + ConvBlock, + SkipBlock, + expand_to_planes, + append_dims, + t_to_alpha_sigma, + DiffusionOutput, +) + +class SecondaryDiffusionImageNet(nn.Module): + """ + Secondary diffusion model trained on Imagenet. + """ + def __init__(self): + super().__init__() + c = 64 # The base channel count + + self.timestep_embed = FourierFeatures(1, 16) + + self.net = nn.Sequential( + ConvBlock(3 + 16, c), + ConvBlock(c, c), + SkipBlock([ + nn.AvgPool2d(2), + ConvBlock(c, c * 2), + ConvBlock(c * 2, c * 2), + SkipBlock([ + nn.AvgPool2d(2), + ConvBlock(c * 2, c * 4), + ConvBlock(c * 4, c * 4), + SkipBlock([ + nn.AvgPool2d(2), + ConvBlock(c * 4, c * 8), + ConvBlock(c * 8, c * 4), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ]), + ConvBlock(c * 8, c * 4), + ConvBlock(c * 4, c * 2), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ]), + ConvBlock(c * 4, c * 2), + ConvBlock(c * 2, c), + nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), + ]), + ConvBlock(c * 2, c), + nn.Conv2d(c, 3, 3, padding=1), + ) + + def forward(self, input, t): + """ + Given an input, a time step, and a diffusion network, + compute the diffusion network's output, + the predicted value, and the diffusion noise + + :param input: the input to the diffusion network + :param t: The time step + :return: The diffusion output object contains the diffusion parameters, the predicted value, and + the residuals. 
+ """ + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape) + v = self.net(torch.cat([input, timestep_embed], dim=1)) + alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) + pred = input * alphas - v * sigmas + eps = input * sigmas + v * alphas + return DiffusionOutput(v, pred, eps) diff --git a/src/disco/models/secondarydiffusionimagenet2.py b/src/disco/models/secondarydiffusionimagenet2.py new file mode 100644 index 00000000..15aab22d --- /dev/null +++ b/src/disco/models/secondarydiffusionimagenet2.py @@ -0,0 +1,93 @@ +from functools import partial + +import torch +from torch import nn +from .modules import ( + FourierFeatures, + ConvBlock, + SkipBlock, + expand_to_planes, + append_dims, + t_to_alpha_sigma, + DiffusionOutput, +) + +# I think this is functionally identical to the other model, +# just with the down/up sampling functions aliased to a class attribute for legibility +class SecondaryDiffusionImageNet2(nn.Module): + """ + Secondary diffusion model trained on Imagenet. + """ + def __init__(self): + super().__init__() + c = 64 # The base channel count + cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8] + + self.timestep_embed = FourierFeatures(1, 16) + self.down = nn.AvgPool2d(2) + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) + + self.net = nn.Sequential( + ConvBlock(3 + 16, cs[0]), + ConvBlock(cs[0], cs[0]), + SkipBlock([ + self.down, + ConvBlock(cs[0], cs[1]), + ConvBlock(cs[1], cs[1]), + SkipBlock([ + self.down, + ConvBlock(cs[1], cs[2]), + ConvBlock(cs[2], cs[2]), + SkipBlock([ + self.down, + ConvBlock(cs[2], cs[3]), + ConvBlock(cs[3], cs[3]), + SkipBlock([ + self.down, + ConvBlock(cs[3], cs[4]), + ConvBlock(cs[4], cs[4]), + SkipBlock([ + self.down, + ConvBlock(cs[4], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[4]), + self.up, + ]), + ConvBlock(cs[4] * 2, cs[4]), + ConvBlock(cs[4], cs[3]), + self.up, + ]), + ConvBlock(cs[3] * 2, cs[3]), + ConvBlock(cs[3], cs[2]), + self.up, + ]), + ConvBlock(cs[2] * 2, cs[2]), + ConvBlock(cs[2], cs[1]), + self.up, + ]), + ConvBlock(cs[1] * 2, cs[1]), + ConvBlock(cs[1], cs[0]), + self.up, + ]), + ConvBlock(cs[0] * 2, cs[0]), + nn.Conv2d(cs[0], 3, 3, padding=1), + ) + + def forward(self, input, t): + """ + Given an input, a time step, and a diffusion network, + return a DiffusionOutput object containing the diffusion network's output, + prediction, and diffusion noise + + :param input: the input to the diffusion network + :param t: the time step of the diffusion process + :return: The diffusion output object contains the diffusion parameters, the predicted value, and + the residuals. + """ + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape) + v = self.net(torch.cat([input, timestep_embed], dim=1)) + alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) + pred = input * alphas - v * sigmas + eps = input * sigmas + v * alphas + return DiffusionOutput(v, pred, eps) \ No newline at end of file