From 75434baf145f0950c61e98dde0edc2c8c4f85ce7 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Fri, 7 Jul 2023 03:51:07 +0000 Subject: [PATCH 01/12] Add visuolinguistic analysis notebooks --- notebooks/0-setup.ipynb | 220 + notebooks/1-visuosyntactic-analyses.ipynb | 12799 ++++++++++++++++++++ notebooks/2-visuosemantic-analyses.ipynb | 333 + 3 files changed, 13352 insertions(+) create mode 100644 notebooks/0-setup.ipynb create mode 100644 notebooks/1-visuosyntactic-analyses.ipynb create mode 100644 notebooks/2-visuosemantic-analyses.ipynb diff --git a/notebooks/0-setup.ipynb b/notebooks/0-setup.ipynb new file mode 100644 index 0000000..f8f05b7 --- /dev/null +++ b/notebooks/0-setup.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "65bf851a", + "metadata": {}, + "source": [ + "# Installation of Prerequisites" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4616b907", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2023-07-07 00:30:51-- https://nlp.stanford.edu/software/stanford-corenlp-4.5.4.zip\n", + "Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140\n", + "Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.\n", + "HTTP request sent, awaiting response... 302 FOUND\n", + "Location: https://downloads.cs.stanford.edu/nlp/software/stanford-corenlp-4.5.4.zip [following]\n", + "--2023-07-07 00:30:52-- https://downloads.cs.stanford.edu/nlp/software/stanford-corenlp-4.5.4.zip\n", + "Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22\n", + "Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 506470124 (483M) [application/zip]\n", + "Saving to: ‘stanford-corenlp-4.5.4.zip’\n", + "\n", + "orenlp-4.5.4.zip 18%[==> ] 88.40M 5.08MB/s eta 60s ^C\n", + "Archive: stanford-corenlp-4.5.4.zip\n", + " End-of-central-directory signature not found. Either this file is not\n", + " a zipfile, or it constitutes one disk of a multi-part archive. In the\n", + " latter case the central directory and zipfile comment will be found on\n", + " the last disk(s) of this archive.\n", + "unzip: cannot find zipfile directory in one of stanford-corenlp-4.5.4.zip or\n", + " stanford-corenlp-4.5.4.zip.zip, and cannot find stanford-corenlp-4.5.4.zip.ZIP, period.\n" + ] + } + ], + "source": [ + "!wget https://nlp.stanford.edu/software/stanford-corenlp-4.5.4.zip\n", + "!unzip stanford-corenlp-4.5.4.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "956dcdf1", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting stanza\n", + " Downloading stanza-1.5.0-py3-none-any.whl (802 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m802.5/802.5 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting emoji (from stanza)\n", + " Downloading emoji-2.6.0.tar.gz (356 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m356.6/356.6 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: numpy in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (1.24.3)\n", + "Collecting protobuf (from stanza)\n", + " Downloading protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl (304 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m304.5/304.5 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (2.29.0)\n", + "Requirement already satisfied: six in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (1.16.0)\n", + "Requirement already satisfied: torch>=1.3.0 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (2.0.1)\n", + "Requirement already satisfied: tqdm in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from stanza) (4.65.0)\n", + "Requirement already satisfied: filelock in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (3.9.0)\n", + "Requirement already satisfied: typing-extensions in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (4.6.3)\n", + "Requirement already satisfied: sympy in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (1.11.1)\n", + "Requirement already satisfied: networkx in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (2.8.4)\n", + "Requirement already satisfied: jinja2 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from torch>=1.3.0->stanza) (3.1.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (2.0.4)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (3.4)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (1.26.16)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from requests->stanza) (2023.5.7)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from jinja2->torch>=1.3.0->stanza) (2.1.1)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/ralph/miniconda3/envs/daam/lib/python3.8/site-packages (from sympy->torch>=1.3.0->stanza) (1.2.1)\n", + "Building wheels for collected packages: emoji\n", + " Building wheel for emoji (setup.py) ... 
\u001b[?25ldone\n", + "\u001b[?25h Created wheel for emoji: filename=emoji-2.6.0-py2.py3-none-any.whl size=351312 sha256=438edf73fbaa4e062879aa454062e0f1172c51d84c7df39f189be5af4a37dd92\n", + " Stored in directory: /home/ralph/.cache/pip/wheels/65/d8/90/e78a11fccc67c1983e5496ee1b6c831bce3185ed9dec4cd2c2\n", + "Successfully built emoji\n", + "Installing collected packages: protobuf, emoji, stanza\n", + "Successfully installed emoji-2.6.0 protobuf-4.23.4 stanza-1.5.0\n" + ] + } + ], + "source": [ + "!pip install stanza" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8bf1406f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-07-07 00:32:22 WARNING: Directory stanford-corenlp-4.5.4 already exists. Please install CoreNLP to a new directory.\n" + ] + } + ], + "source": [ + "import stanza\n", + "stanza.install_corenlp(dir='stanford-corenlp-4.5.4')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "321801f0", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install daam==0.1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ec55a0b3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2023-07-07 00:33:02-- http://images.cocodataset.org/annotations/annotations_trainval2014.zip\n", + "Resolving images.cocodataset.org (images.cocodataset.org)... 16.182.64.121, 52.216.220.249, 54.231.170.209, ...\n", + "Connecting to images.cocodataset.org (images.cocodataset.org)|16.182.64.121|:80... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 252872794 (241M) [application/zip]\n", + "Saving to: ‘annotations_trainval2014.zip’\n", + "\n", + "annotations_trainva 100%[===================>] 241.16M 34.5MB/s in 9.6s \n", + "\n", + "2023-07-07 00:33:12 (25.1 MB/s) - ‘annotations_trainval2014.zip’ saved [252872794/252872794]\n", + "\n" + ] + } + ], + "source": [ + "!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1ed5edd7", + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p coco\n", + "!mv annotations_* coco" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2f6ed111", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Errno 2] No such file or directory: 'coco'\n", + "/home/ralph/programming/daam/notebooks/coco\n", + "Archive: annotations_trainval2014.zip\n", + " inflating: annotations/instances_train2014.json \n", + " inflating: annotations/instances_val2014.json \n", + " inflating: annotations/person_keypoints_train2014.json \n", + " inflating: annotations/person_keypoints_val2014.json \n", + " inflating: annotations/captions_train2014.json \n", + " inflating: annotations/captions_val2014.json \n" + ] + } + ], + "source": [ + "%cd coco\n", + "!unzip annotations_*" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/1-visuosyntactic-analyses.ipynb b/notebooks/1-visuosyntactic-analyses.ipynb new file mode 100644 index 0000000..e0b1414 
--- /dev/null +++ b/notebooks/1-visuosyntactic-analyses.ipynb @@ -0,0 +1,12799 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "collapsed": true + }, + "source": [ + "# Visuosyntactic Analyses" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: CORENLP_HOME=stanford-corenlp-4.5.4\n" + ] + } + ], + "source": [ + "%env CORENLP_HOME=stanford-corenlp-4.5.4" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-07-07 00:32:43 WARNING: Directory stanford-corenlp-4.5.4 already exists. Please install CoreNLP to a new directory.\n", + "2023-07-07 00:32:43 INFO: Writing properties to tmp file: corenlp_server-8e8c4b3e34ad4e6a.props\n" + ] + } + ], + "source": [ + "from stanza.server import CoreNLPClient\n", + "import stanza\n", + "\n", + "stanza.install_corenlp(dir='stanford-corenlp-4.5.4')\n", + "client = CoreNLPClient(annotators=['tokenize', 'ssplit', 'pos', 'lemma', 'ner', 'parse', 'depparse','coref'], timeout=30000, memory='6G')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generate DAAM Maps" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import json\n", + "\n", + "annotations = json.load(Path('coco/annotations/captions_val2014.json').open())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['info', 'images', 'licenses', 'annotations'])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "annotations.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame(annotations['annotations'])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p experiments/visuosyntax" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "df = df.sample(1500, replace=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "torch.cuda.amp.autocast().__enter__()\n", + "torch.set_grad_enabled(False);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import StableDiffusionPipeline\n", + "from daam import set_seed, trace\n", + "\n", + "pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "pipe.to('cuda:0');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "\n", + "for _, row in tqdm(df.iterrows(), total=len(df)):\n", + " image_id, caption = row.image_id, row.caption\n", + " gen = set_seed(image_id)\n", + " output_folder = Path('experiments/visuosyntax')\n", + " \n", + " with trace(pipe) as tc:\n", + " out = pipe(caption, num_inference_steps=30, generator=gen)\n", + " exp = tc.to_experiment(output_folder, id=str(image_id), 
seed=image_id)\n", + " exp.save(output_folder, heat_maps=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Parse and Analyze" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/1488 [00:00 float:\n", + " i = ((a > t) & (b > t)).float().sum()\n", + " u = ((a > t) | (b > t)).float().sum()\n", + " \n", + " if u < 1e-6:\n", + " return 0.0\n", + " else:\n", + " return (i / u).item()\n", + "\n", + "def ioa(a, b, t: float = 0.15) -> float:\n", + " i = ((a > t) & (b > t)).float().sum()\n", + " a = (a > t).float().sum()\n", + " \n", + " if a < 1e-6:\n", + " return 0.0\n", + " else:\n", + " return (i / a).item()\n", + "\n", + "stats = []\n", + "\n", + "for path in tqdm(list(Path('experiments/visuosyntax').iterdir())):\n", + " exp = GenerationExperiment.load(path)\n", + " sent = client.annotate(exp.prompt).sentence[0]\n", + " heat_map = exp.heat_map() \n", + " word_maps = dict()\n", + " \n", + " for tok in sent.token:\n", + " try:\n", + " word_maps[tok.word] = heat_map.compute_word_heat_map(tok.word).value.cuda()\n", + " except ValueError:\n", + " pass \n", + " \n", + " for edge in sent.enhancedDependencies.edge:\n", + " head = sent.token[edge.source - 1].word\n", + " rel = edge.dep\n", + " dep = sent.token[edge.target - 1].word\n", + " \n", + " try:\n", + " head_heat_map = word_maps[head]\n", + " dep_heat_map = word_maps[dep]\n", + " except KeyError:\n", + " continue\n", + " \n", + " stats.append(dict(\n", + " rel=rel,\n", + " iou=iou(head_heat_map, dep_heat_map),\n", + " iod=ioa(dep_heat_map, head_heat_map),\n", + " ioh=ioa(head_heat_map, dep_heat_map)\n", + " ))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Results" + ] + }, + { + "cell_type": "code", + "execution_count": 149, + "metadata": {}, + "outputs": [], + "source": [ + "stats_df = pd.DataFrame(stats)\n", + "res_df = stats_df.groupby('rel').agg(count=('rel', len), mIoU=('iou', 'mean'), mIoD=('iod', 'mean'), mIoH=('ioh', 'mean'))\n", + "res_df = res_df.sort_values('count', ascending=False).iloc[:10]\n", + "res_df['delta'] = (res_df['mIoH'] - res_df['mIoD']).abs()" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
mIoUmIoDmIoHdelta
rel
punct0.0998572.4484100.1032952.345114
nmod:of8.65707412.85535821.9878569.132498
compound33.43411359.13079549.9851709.145626
nsubj5.02722710.69213322.71029312.018160
case3.83195218.0880065.89582912.192177
det0.44781113.0129750.65780812.355168
conj:and28.43592855.50186739.64988315.851984
acl6.45200928.69241511.10118417.591231
obj6.64195210.56667336.44249625.875823
amod14.69087845.06272019.05172026.011000
\n", + "
" + ], + "text/plain": [ + " mIoU mIoD mIoH delta\n", + "rel \n", + "punct 0.099857 2.448410 0.103295 2.345114\n", + "nmod:of 8.657074 12.855358 21.987856 9.132498\n", + "compound 33.434113 59.130795 49.985170 9.145626\n", + "nsubj 5.027227 10.692133 22.710293 12.018160\n", + "case 3.831952 18.088006 5.895829 12.192177\n", + "det 0.447811 13.012975 0.657808 12.355168\n", + "conj:and 28.435928 55.501867 39.649883 15.851984\n", + "acl 6.452009 28.692415 11.101184 17.591231\n", + "obj 6.641952 10.566673 36.442496 25.875823\n", + "amod 14.690878 45.062720 19.051720 26.011000" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res_df.drop(columns=['count'], inplace=True)\n", + "res_df = res_df.transform(lambda x: x * 100)\n", + "res_df.sort_values('delta')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/notebooks/2-visuosemantic-analyses.ipynb b/notebooks/2-visuosemantic-analyses.ipynb new file mode 100644 index 0000000..ecad308 --- /dev/null +++ b/notebooks/2-visuosemantic-analyses.ipynb @@ -0,0 +1,333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0cbd9358", + "metadata": {}, + "source": [ + "# Visuosemantic Analyses" + ] + }, + { + "cell_type": "markdown", + "id": "10418469", + "metadata": {}, + "source": [ + "## Adjectival Entanglement" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df263c0e", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import StableDiffusionPipeline\n", + "from matplotlib import pyplot as plt\n", + "import numpy as np\n", + "import time\n", + "import torch\n", + "import random\n", + "import daam\n", + "\n", + "def set_seed(seed):\n", + " gen = torch.Generator(device='cuda')\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " torch.cuda.manual_seed(seed)\n", + " torch.manual_seed(seed)\n", + "\n", + " return gen.manual_seed(s)\n", + "\n", + "\n", + "model = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-base')\n", + "model = model.to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5f24052d", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0bec2b27ce4f4ea19b7f768f42510a55", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/20 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.rcParams['figure.figsize'] = (8, 8)\n", + "fig, ax = make_im_subplots(2, 2)\n", + "\n", + "# Original images\n", + "ax[0, 1].imshow(blue_image)\n", + "ax[1, 0].imshow(green_image)\n", + "ax[1, 1].imshow(red_image)\n", + "\n", + "# Heat map\n", + "blue_map.plot_overlay(blue_image, ax=ax[0, 0])\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "cd2d3c03", + "metadata": {}, + "source": [ + "## Cohyponym Entanglement" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "09eee34f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ce88426a80e2426ba6a2b4fe10dc04bd", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/20 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.rcParams['figure.figsize'] = (8, 8)\n", + "fig, ax = make_im_subplots(2, 3)\n", + "\n", + "with daam.trace(model, save_heads=True) as trc:\n", + " im1 = model('a zebra and a giraffe', num_inference_steps=20).images[0]\n", + " heat_map = trc.compute_global_heat_map()\n", + " zebra_map = heat_map.compute_word_heat_map('zebra')\n", + " giraffe_map = heat_map.compute_word_heat_map('giraffe')\n", + "\n", + "with daam.trace(model, save_heads=True) as trc:\n", + " im2 = model('a crab and a lobster', num_inference_steps=20).images[0]\n", + " heat_map = trc.compute_global_heat_map()\n", + " crab_map = heat_map.compute_word_heat_map('crab')\n", + " lobster_map = heat_map.compute_word_heat_map('lobster')\n", + "\n", + "ax[0, 0].imshow(im1)\n", + "ax[1, 0].imshow(im2)\n", + "zebra_map.plot_overlay(im1, ax=ax[0, 1])\n", + "giraffe_map.plot_overlay(im1, ax=ax[0, 2])\n", + "\n", + "crab_map.plot_overlay(im2, ax=ax[1, 1])\n", + "lobster_map.plot_overlay(im2, ax=ax[1, 2])\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c3637649", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e8076b19247e4ac6bc54dca3516f27cb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/20 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.rcParams['figure.figsize'] = (8, 8)\n", + "fig, ax = make_im_subplots(2, 3)\n", + "\n", + "with daam.trace(model, save_heads=True) as trc:\n", + " im1 = model('a zebra and a fridge', num_inference_steps=20).images[0]\n", + " heat_map = trc.compute_global_heat_map()\n", + " zebra_map = heat_map.compute_word_heat_map('zebra')\n", + " giraffe_map = heat_map.compute_word_heat_map('fridge')\n", + "\n", + "with daam.trace(model, save_heads=True) as trc:\n", + " im2 = model('a crab and a beach ball', num_inference_steps=20).images[0]\n", + " heat_map = trc.compute_global_heat_map()\n", + " crab_map = heat_map.compute_word_heat_map('crab')\n", + " lobster_map = heat_map.compute_word_heat_map('ball')\n", + "\n", + "ax[0, 0].imshow(im1)\n", + "ax[1, 0].imshow(im2)\n", + "zebra_map.plot_overlay(im1, ax=ax[0, 1])\n", + "giraffe_map.plot_overlay(im1, ax=ax[0, 2])\n", + "\n", + "crab_map.plot_overlay(im2, ax=ax[1, 1])\n", + "lobster_map.plot_overlay(im2, ax=ax[1, 2])\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From f2290b739643ffe5f47913ebbab4d3d5c6ac270d Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Fri, 7 Jul 2023 03:52:22 +0000 Subject: [PATCH 02/12] Remove drivel --- notebooks/1-visuosyntactic-analyses.ipynb | 12382 +------------------- 1 file changed, 2 insertions(+), 12380 deletions(-) diff --git a/notebooks/1-visuosyntactic-analyses.ipynb b/notebooks/1-visuosyntactic-analyses.ipynb index e0b1414..9549497 100644 --- a/notebooks/1-visuosyntactic-analyses.ipynb +++ b/notebooks/1-visuosyntactic-analyses.ipynb @@ -179,12387 +179,9 @@ }, { "cell_type": 
"code", - "execution_count": 146, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/1488 [00:00 Date: Sat, 8 Jul 2023 01:56:35 -0400 Subject: [PATCH 03/12] Update README.md - Add dataset link. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0a86f04..e6827e8 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ exp = GenerationExperiment.load('experiment-dir') # load the experiment We'll continue adding docs. In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts. -Our datasets are here: https://git.uwaterloo.ca/r33tang/daam-data +Our datasets are here: https://git.uwaterloo.ca/r33tang/daam-data. +You can also download the COCO-Gen dataset from the paper [here](http://ralphtang.com/coco-gen.tar.gz). ## See Also - [DAAM-i2i](https://github.com/RishiDarkDevil/daam-i2i), an extension of DAAM to image-to-image attribution. From c99e5b1cae28e3b99103d045e5d0abf59589b300 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sat, 8 Jul 2023 01:57:40 -0400 Subject: [PATCH 04/12] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e6827e8..c9e7a73 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ exp = GenerationExperiment.load('experiment-dir') # load the experiment We'll continue adding docs. In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts. Our datasets are here: https://git.uwaterloo.ca/r33tang/daam-data. -You can also download the COCO-Gen dataset from the paper [here](http://ralphtang.com/coco-gen.tar.gz). +You can also download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. ## See Also - [DAAM-i2i](https://github.com/RishiDarkDevil/daam-i2i), an extension of DAAM to image-to-image attribution. From 5dabd0f87c6ee7f4d08324027d9fa8d58fcdd8df Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sat, 8 Jul 2023 01:59:41 -0400 Subject: [PATCH 05/12] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c9e7a73..476373d 100644 --- a/README.md +++ b/README.md @@ -78,8 +78,8 @@ exp = GenerationExperiment.load('experiment-dir') # load the experiment We'll continue adding docs. In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts. -Our datasets are here: https://git.uwaterloo.ca/r33tang/daam-data. -You can also download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. +You download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. +If click the link doesn't work on your browser, copy and paste the link in a new tab, or use a CLI utility such as `wget`. ## See Also - [DAAM-i2i](https://github.com/RishiDarkDevil/daam-i2i), an extension of DAAM to image-to-image attribution. 
From 9928126340d9bc91d115975425f2c7828ec1841a Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sat, 8 Jul 2023 02:00:14 -0400 Subject: [PATCH 06/12] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 476373d..d61176a 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ exp = GenerationExperiment.load('experiment-dir') # load the experiment We'll continue adding docs. In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts. You download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. -If click the link doesn't work on your browser, copy and paste the link in a new tab, or use a CLI utility such as `wget`. +If clicking the link doesn't work on your browser, copy and paste it in a new tab, or use a CLI utility such as `wget`. ## See Also - [DAAM-i2i](https://github.com/RishiDarkDevil/daam-i2i), an extension of DAAM to image-to-image attribution. From 95599b9f137e6017efea3d8796e153eb8ba6bd02 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Tue, 11 Jul 2023 01:45:29 -0400 Subject: [PATCH 07/12] Add ACL citation to readme --- README.md | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index d61176a..aaa8d84 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # What the DAAM: Interpreting Stable Diffusion Using Cross Attention -[![HF Spaces](https://img.shields.io/badge/HuggingFace%20Space-online-green.svg)](https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps) [![Citation](https://img.shields.io/badge/Citation-arXiv-orange.svg)](https://gist.githubusercontent.com/daemon/c526f4f9ab2d5e946e6bae90a9a02571/raw/02dcc6cb09a39559b39449a7d27d3b950bec39bd/daam-citation.bib) [![PyPi version](https://badgen.net/pypi/v/daam?color=blue)](https://pypi.org/project/daam) [![Downloads](https://static.pepy.tech/badge/daam)](https://pepy.tech/project/daam) +[![HF Spaces](https://img.shields.io/badge/HuggingFace%20Space-online-green.svg)](https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps) [![Citation](https://img.shields.io/badge/Citation-ACL-orange.svg)](https://gist.github.com/daemon/639de6fea584d7df1a62f04a2ea0cdad) [![PyPi version](https://badgen.net/pypi/v/daam?color=blue)](https://pypi.org/project/daam) [![Downloads](https://static.pepy.tech/badge/daam)](https://pepy.tech/project/daam) ![example image](example.jpg) @@ -8,7 +8,7 @@ I regularly update this codebase. Please submit an issue if you have any questions. -In [our paper](https://arxiv.org/abs/2210.04885), we propose diffusion attentive attribution maps (DAAM), a cross attention-based approach for interpreting Stable Diffusion. +In [our paper](https://aclanthology.org/2023.acl-long.310), we propose diffusion attentive attribution maps (DAAM), a cross attention-based approach for interpreting Stable Diffusion. Check out our demo: https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps. See our [documentation](https://castorini.github.io/daam/), hosted by GitHub pages, and [our Colab notebook](https://colab.research.google.com/drive/1miGauqa07uHnDoe81NmbmtTtnupmlipv?usp=sharing), updated for v0.0.11. 
@@ -90,10 +90,19 @@ If clicking the link doesn't work on your browser, copy and paste it in a new ta ## Citation ``` -@article{tang2022daam, - title={What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention}, - author={Tang, Raphael and Liu, Linqing and Pandey, Akshat and Jiang, Zhiying and Yang, Gefei and Kumar, Karun and Stenetorp, Pontus and Lin, Jimmy and Ture, Ferhan}, - journal={arXiv:2210.04885}, - year={2022} +@inproceedings{tang2023daam, + title = "What the {DAAM}: Interpreting Stable Diffusion Using Cross Attention", + author = "Tang, Raphael and + Liu, Linqing and + Pandey, Akshat and + Jiang, Zhiying and + Yang, Gefei and + Kumar, Karun and + Stenetorp, Pontus and + Lin, Jimmy and + Ture, Ferhan", + booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", + year = "2023", + url = "https://aclanthology.org/2023.acl-long.310", } ``` From 2d16393ce73ed5df6f9607f9ce7291bb9585d016 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Wed, 12 Jul 2023 10:45:36 -0400 Subject: [PATCH 08/12] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index aaa8d84..b2e1523 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ exp = GenerationExperiment.load('experiment-dir') # load the experiment We'll continue adding docs. In the meantime, check out the `GenerationExperiment`, `GlobalHeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts. -You download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. +You can download the COCO-Gen dataset from the paper at http://ralphtang.com/coco-gen.tar.gz. If clicking the link doesn't work on your browser, copy and paste it in a new tab, or use a CLI utility such as `wget`. ## See Also From d4311db7faeade5a838f5ae9d4c6a1bcef5b32f4 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Tue, 18 Jul 2023 12:16:05 -0400 Subject: [PATCH 09/12] Fix downblocks list length (#45) --- daam/hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daam/hook.py b/daam/hook.py index ffea6bc..320ba6f 100644 --- a/daam/hook.py +++ b/daam/hook.py @@ -101,7 +101,7 @@ def locate(self, model: UNet2DConditionModel) -> List[Attention]: self.layer_names.clear() blocks_list = [] up_names = ['up'] * len(model.up_blocks) - down_names = ['down'] * len(model.up_blocks) + down_names = ['down'] * len(model.down_blocks) for unet_block, name in itertools.chain( zip(model.up_blocks, up_names), From a3231297bc3a6780a5ac7e56e3c5080badde4003 Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sun, 23 Jul 2023 13:03:16 -0400 Subject: [PATCH 10/12] Fix Colab notebook Closes #46 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b2e1523..8eb1bba 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ I regularly update this codebase. Please submit an issue if you have any questio In [our paper](https://aclanthology.org/2023.acl-long.310), we propose diffusion attentive attribution maps (DAAM), a cross attention-based approach for interpreting Stable Diffusion. Check out our demo: https://huggingface.co/spaces/tetrisd/Diffusion-Attentive-Attribution-Maps. -See our [documentation](https://castorini.github.io/daam/), hosted by GitHub pages, and [our Colab notebook](https://colab.research.google.com/drive/1miGauqa07uHnDoe81NmbmtTtnupmlipv?usp=sharing), updated for v0.0.11. 
+See our [documentation](https://castorini.github.io/daam/), hosted by GitHub pages, and [our Colab notebook](https://colab.research.google.com/drive/1miGauqa07uHnDoe81NmbmtTtnupmlipv?usp=sharing), updated for v0.1.0. ## Getting Started First, install [PyTorch](https://pytorch.org) for your platform. From b32e13310721a860bb21e74d409a210d4ded9112 Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Wed, 29 Nov 2023 19:08:57 -0500 Subject: [PATCH 11/12] Unhook cross attention processor (#53) --- daam/trace.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/daam/trace.py b/daam/trace.py index 6932a4a..54b9f71 100644 --- a/daam/trace.py +++ b/daam/trace.py @@ -279,8 +279,12 @@ def __call__( return hidden_states def _hook_impl(self): + self.original_processor = self.module.processor self.module.set_processor(self) + def _unhook_impl(self): + self.module.set_processor(self.original_processor) + @property def num_heat_maps(self): return len(next(iter(self.heat_maps.values()))) From 09564e048cf5ffc1523aa0605736e7f178ca505a Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sun, 7 Jan 2024 15:46:19 -0500 Subject: [PATCH 12/12] Add Stable Diffusion XL support (#56) --- README.md | 13 ++++++------ daam/_version.py | 2 +- daam/hook.py | 10 +++++++--- daam/run/generate.py | 19 ++++++++++++++---- daam/trace.py | 47 +++++++++++++++++++++++++++++++++----------- daam/utils.py | 6 +++++- requirements.txt | 6 +++--- 7 files changed, 74 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 8eb1bba..73de368 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![example image](example.jpg) -### Updated to support Diffusers 0.16.1! +### Updated to support Stable Diffusion XL (SDXL) and Diffusers 0.21.1! I regularly update this codebase. Please submit an issue if you have any questions. @@ -33,6 +33,7 @@ dog.heat_map.png running.heat_map.png prompt.txt ``` Your current working directory will now contain the generated image as `output.png` and a DAAM map for every word, as well as some auxiliary data. You can see more options for `daam` by running `daam -h`. +To use Stable Diffusion XL as the backend, run `daam --model xl-base-1.0 "Dog jumping"`. 
### Using DAAM as a Library @@ -40,23 +41,23 @@ Import and use DAAM as follows: ```python from daam import trace, set_seed -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline from matplotlib import pyplot as plt import torch -model_id = 'stabilityai/stable-diffusion-2-base' +model_id = 'stabilityai/stable-diffusion-xl-base-1.0' device = 'cuda' -pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) +pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16') pipe = pipe.to(device) prompt = 'A dog runs across the field' gen = set_seed(0) # for reproducibility -with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): +with torch.no_grad(): with trace(pipe) as tc: - out = pipe(prompt, num_inference_steps=30, generator=gen) + out = pipe(prompt, num_inference_steps=50, generator=gen) heat_map = tc.compute_global_heat_map() heat_map = heat_map.compute_word_heat_map('dog') heat_map.plot_overlay(out.images[0]) diff --git a/daam/_version.py b/daam/_version.py index b794fd4..7fd229a 100644 --- a/daam/_version.py +++ b/daam/_version.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = '0.2.0' diff --git a/daam/hook.py b/daam/hook.py index 320ba6f..f82762c 100644 --- a/daam/hook.py +++ b/daam/hook.py @@ -55,9 +55,13 @@ def unhook(self): return self - def monkey_patch(self, fn_name, fn): - self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) - setattr(self.module, fn_name, functools.partial(fn, self.module)) + def monkey_patch(self, fn_name, fn, strict: bool = True): + try: + self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) + setattr(self.module, fn_name, functools.partial(fn, self.module)) + except AttributeError: + if strict: + raise def monkey_super(self, fn_name, *args, **kwargs): return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs) diff --git a/daam/run/generate.py b/daam/run/generate.py index 0ad063d..c191b51 100644 --- a/daam/run/generate.py +++ b/daam/run/generate.py @@ -7,7 +7,7 @@ import time import pandas as pd -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, DiffusionPipeline from tqdm import tqdm import inflect import numpy as np @@ -25,7 +25,8 @@ def main(): 'v2-base': 'stabilityai/stable-diffusion-2-base', 'v2-large': 'stabilityai/stable-diffusion-2', 'v2-1-base': 'stabilityai/stable-diffusion-2-1-base', - 'v2-1-large': 'stabilityai/stable-diffusion-2-1' + 'v2-1-large': 'stabilityai/stable-diffusion-2-1', + 'xl-base-1.0': 'stabilityai/stable-diffusion-xl-base-1.0', } parser = argparse.ArgumentParser() @@ -192,10 +193,20 @@ def main(): prompts = new_prompts prompts = prompts[:args.gen_limit] - pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) + + if 'xl' in model_id: + pipe = DiffusionPipeline.from_pretrained( + model_id, + use_auth_token=True, + torch_dtype=torch.float16, + use_safetensors=True, variant='fp16' + ) + else: + pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) + pipe = auto_device(pipe) - with auto_autocast(dtype=torch.float16), torch.no_grad(): + with torch.no_grad(): for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)): seed = int(time.time()) if args.random_seed else args.seed prompt = prompt.replace(',', ' ,').replace('.', ' .').strip() diff --git a/daam/trace.py b/daam/trace.py index 54b9f71..bffdbbb 100644 --- a/daam/trace.py +++ b/daam/trace.py @@ -2,7 +2,8 @@ from typing 
import List, Type, Any, Dict, Tuple, Union import math -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.models.attention_processor import Attention import numpy as np import PIL.Image as Image @@ -21,8 +22,7 @@ class DiffusionHeatMapHooker(AggregateHooker): def __init__( self, - pipeline: - StableDiffusionPipeline, + pipeline: Union[StableDiffusionPipeline, StableDiffusionXLPipeline], low_memory: bool = False, load_heads: bool = False, save_heads: bool = False, @@ -30,7 +30,7 @@ def __init__( ): self.all_heat_maps = RawHeatMapCollection() h = (pipeline.unet.config.sample_size * pipeline.vae_scale_factor) - self.latent_hw = 4096 if h == 512 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0 + self.latent_hw = 4096 if h == 512 or h == 1024 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0 locate_middle = load_heads or save_heads self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle) self.last_prompt: str = '' @@ -52,6 +52,9 @@ def __init__( modules.append(PipelineHooker(pipeline, self)) + if type(pipeline) == StableDiffusionXLPipeline: + modules.append(ImageProcessorHooker(pipeline.image_processor, self)) + super().__init__(modules) self.pipe = pipeline @@ -129,6 +132,21 @@ def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, laye return GlobalHeatMap(self.pipe.tokenizer, prompt, maps) +class ImageProcessorHooker(ObjectHooker[VaeImageProcessor]): + def __init__(self, processor: VaeImageProcessor, parent_trace: 'trace'): + super().__init__(processor) + self.parent_trace = parent_trace + + def _hooked_postprocess(hk_self, _: VaeImageProcessor, *args, **kwargs): + images = hk_self.monkey_super('postprocess', *args, **kwargs) + hk_self.parent_trace.last_image = images[0] + + return images + + def _hook_impl(self): + self.monkey_patch('postprocess', self._hooked_postprocess) + + class PipelineHooker(ObjectHooker[StableDiffusionPipeline]): def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): super().__init__(pipeline) @@ -137,12 +155,20 @@ def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs): image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs) - pil_image = self.numpy_to_pil(image) - hk_self.parent_trace.last_image = pil_image[0] + + if self.image_processor: + if torch.is_tensor(image): + images = self.image_processor.postprocess(image, output_type='pil') + else: + images = self.image_processor.numpy_to_pil(image) + else: + images = self.numpy_to_pil(image) + + hk_self.parent_trace.last_image = images[len(images)-1] return image, has_nsfw - def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): + def _hooked_check_inputs(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): if not isinstance(prompt, str) and len(prompt) > 1: raise ValueError('Only single prompt generation is supported for heat map computation.') elif not isinstance(prompt, str): @@ -152,13 +178,12 @@ def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str hk_self.heat_maps.clear() hk_self.parent_trace.last_prompt = last_prompt - ret = hk_self.monkey_super('_encode_prompt', prompt, *args, **kwargs) - 
return ret + return hk_self.monkey_super('check_inputs', prompt, *args, **kwargs) def _hook_impl(self): - self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker) - self.monkey_patch('_encode_prompt', self._hooked_encode_prompt) + self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker, strict=False) # not present in SDXL + self.monkey_patch('check_inputs', self._hooked_check_inputs) class UNetCrossAttentionHooker(ObjectHooker[Attention]): diff --git a/daam/utils.py b/daam/utils.py index 6b26761..8cfde13 100644 --- a/daam/utils.py +++ b/daam/utils.py @@ -73,12 +73,16 @@ def cache_dir() -> Path: def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0): merge_idxs = [] tokens = tokenizer.tokenize(prompt.lower()) + tokens = [x.replace('', '') for x in tokens] # New tokenizer uses wordpiece markers. + if word_idx is None: word = word.lower() - search_tokens = tokenizer.tokenize(word) + search_tokens = [x.replace('', '') for x in tokenizer.tokenize(word)] # New tokenizer uses wordpiece markers. start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens] + for indice in start_indices: merge_idxs += [i + indice for i in range(0, len(search_tokens))] + if not merge_idxs: raise ValueError(f'Search word {word} not found in prompt!') else: diff --git a/requirements.txt b/requirements.txt index 8448528..aa15461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ scikit-image -diffusers==0.16.1 +diffusers==0.21.2 spacy gradio ftfy -transformers==4.27.4 +transformers==4.30.2 pandas numba nltk inflect joblib -accelerate==0.18.0 +accelerate==0.23.0
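PATCH 11 above makes the trace context restore the module's original attention processor on exit, and PATCH 12 reworks the hooked pipeline entry points around it. A minimal sketch of the behavior the unhook guarantees — the model id, prompt, and step count are illustrative, not taken from the patches:

```python
# Hedged sketch: a traced generation followed by an untraced one.
from daam import trace
from diffusers import StableDiffusionPipeline
import torch

# Illustrative model id; the patches support both SD and SDXL pipelines.
pipe = StableDiffusionPipeline.from_pretrained('stabilityai/stable-diffusion-2-1-base')
pipe = pipe.to('cuda')

with torch.no_grad():
    with trace(pipe) as tc:
        out = pipe('A dog runs across the field', num_inference_steps=30)
        dog_map = tc.compute_global_heat_map().compute_word_heat_map('dog')

    # After the context exits, _unhook_impl (added in PATCH 11) has called
    # set_processor with the saved original processor, so this generation
    # runs through stock diffusers with no DAAM instrumentation attached.
    out2 = pipe('A dog runs across the field', num_inference_steps=30)
```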