diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..25c28e990731ce4fbdb1a5af212f924d8974e0e9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,153 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# .idea +.idea/ +/idea/ +*.ipr +*.iml +*.iws + +# system +.DS_Store + +# pytorch-lighting logs +lightning_logs/* + +# Edit settings +.editorconfig + +# local results +/workdir/ +.workdir/ + +# dataset +/dataset/ +!/dataset/placeholder.md \ No newline at end of file diff --git a/config/diffsketchedit.yaml b/config/diffsketchedit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89154d6e9aea022c32c34db0adb327a22be99d22 --- /dev/null +++ b/config/diffsketchedit.yaml @@ -0,0 +1,75 @@ +seed: 1 +image_size: 224 +mask_object: False # if the target image contains background, it's better to mask it out +fix_scale: False # if the target image is not squared, it is recommended to fix the scale + +# train +num_iter: 1000 +batch_size: 1 +num_stages: 1 # training stages, you can train x strokes, then freeze them and train another x strokes etc +lr_scheduler: False +lr_decay_rate: 0.1 +decay_steps: [ 1000, 1500 ] +lr: 1 +color_lr: 0.01 +pruning_freq: 50 +color_vars_threshold: 0.1 +width_lr: 0.1 +max_width: 50 # stroke width + +# stroke attrs +num_paths: 96 # number of strokes +width: 1.0 # stroke width +control_points_per_seg: 4 +num_segments: 1 +optim_opacity: True # if True, the stroke opacity is optimized +optim_width: False # if True, the stroke width is optimized +optim_rgba: False # if True, the stroke RGBA is optimized +opacity_delta: 0 # stroke pruning + +# init strokes +attention_init: True # if True, use the attention heads of Dino model to set the location of the initial strokes +xdog_intersec: True # initialize along the edge, mix XDoG and attn up +softmax_temp: 0.5 +cross_attn_res: 16 +self_attn_res: 32 +max_com: 20 # select the number of the self-attn maps +mean_comp: False # the average of the self-attn maps +comp_idx: 0 # if mean_comp==False, indicates the index of the self-attn map +attn_coeff: 1.0 # attn fusion, w * cross-attn + (1-w) * self-attn +log_cross_attn: False # True if cross attn every step +u2net_path: "./checkpoint/u2net/u2net.pth" + +# ldm +model_id: "sd14" +ldm_speed_up: False +enable_xformers: False +gradient_checkpoint: False +#token_ind: 1 # the index of CLIP prompt embedding, start from 1 +use_ddim: True +num_inference_steps: 50 +guidance_scale: 7.5 # sdxl default 5.0 +# ASDS loss +sds: + crop_size: 512 + augmentations: "affine" + guidance_scale: 100 + grad_scale: 1e-5 + t_range: [ 0.05, 0.95 ] + warmup: 0 + +clip: + model_name: "RN101" # RN101, ViT-L/14 + feats_loss_type: "l2" # clip visual loss type, conv layers + feats_loss_weights: [ 0,0,1.0,1.0,0 ] # RN based + # feats_loss_weights: [ 0,0,1.0,1.0,0,0,0,0,0,0,0,0 ] # ViT based + fc_loss_weight: 0.1 # clip visual loss, fc layer weight + augmentations: "affine" # augmentation before clip visual computation + num_aug: 4 # num of augmentation before clip visual computation + vis_loss: 1 # 1 or 0 for use or disable clip visual loss + text_visual_coeff: 0 # cosine similarity between text and img + +perceptual: + name: "lpips" # dists + lpips_net: 'vgg' + coeff: 0.2 \ No newline at end of file diff --git a/docs/figures/refine/ldm_generated_image0.png b/docs/figures/refine/ldm_generated_image0.png new file mode 100644 index 0000000000000000000000000000000000000000..0761be3703fd7c551b82c63995457110f7d0baca --- /dev/null +++ b/docs/figures/refine/ldm_generated_image0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75a0f634d343e08b4c9fe1486b1ff8e2ff330322ff2c8aa6d67f37704425e844 +size 334894 diff --git a/docs/figures/refine/ldm_generated_image1.png b/docs/figures/refine/ldm_generated_image1.png new file mode 100644 index 0000000000000000000000000000000000000000..d8d5c7df12132d05c45659d961ba191db8d0ddb8 --- /dev/null +++ b/docs/figures/refine/ldm_generated_image1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:94d8ccb932a0c088e8014dc89652b65b5325ebfb10b217c6be7c305cd50527a3 +size 332345 diff --git a/docs/figures/refine/ldm_generated_image2.png b/docs/figures/refine/ldm_generated_image2.png new file mode 100644 index 0000000000000000000000000000000000000000..0205dbf0a136ef6179daee5225b1ca6994b0a33d --- /dev/null +++ b/docs/figures/refine/ldm_generated_image2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b7770774aea4ecd71b3daebd81d551ef5ae0bfc457e99f9ace53dad266115dc3 +size 334444 diff --git a/docs/figures/refine/visual_best-rendered0.png b/docs/figures/refine/visual_best-rendered0.png new file mode 100644 index 0000000000000000000000000000000000000000..33e59248c324f8b2fcb7a416397bfd4ae0566e0e --- /dev/null +++ b/docs/figures/refine/visual_best-rendered0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c61fe29bd5d1bb07b7f519587324f8dd50f83dbac2eb6256cd4716a6d36b63c +size 28248 diff --git a/docs/figures/refine/visual_best-rendered1.png b/docs/figures/refine/visual_best-rendered1.png new file mode 100644 index 0000000000000000000000000000000000000000..ed889ff7e1dbcbec18d5bd99e8ec967016616536 --- /dev/null +++ b/docs/figures/refine/visual_best-rendered1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5efbe511ae39e108ae1f1d98431d6eb9c5650ae17706d1270bb839797ac0fb9 +size 29152 diff --git a/docs/figures/refine/visual_best-rendered2.png b/docs/figures/refine/visual_best-rendered2.png new file mode 100644 index 0000000000000000000000000000000000000000..1c76820e59b362f6a11dfa18d483e8b70c94b07b --- /dev/null +++ b/docs/figures/refine/visual_best-rendered2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80871b6669d46533a38c17164ef5a1992ce4cfa3fb1ddb6422b1a451c2b2e0d7 +size 29546 diff --git a/docs/figures/replace/ldm_generated_image0.png b/docs/figures/replace/ldm_generated_image0.png new file mode 100644 index 0000000000000000000000000000000000000000..b3e92620df050f79a0e641ddf6f75ddbf8c1fda7 --- /dev/null +++ b/docs/figures/replace/ldm_generated_image0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12cf20fbcb849f388e004a95d873a141f18cb7645e006438d8e789cd5dd83a4c +size 461954 diff --git a/docs/figures/replace/ldm_generated_image1.png b/docs/figures/replace/ldm_generated_image1.png new file mode 100644 index 0000000000000000000000000000000000000000..c2c0551849a9e4298bff66ea94d2e00de3480fdd --- /dev/null +++ b/docs/figures/replace/ldm_generated_image1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a35c9877e8790665b7ab05f49621241ebdf8b1f3ab0cb3079b5f932cc5337a0d +size 461043 diff --git a/docs/figures/replace/ldm_generated_image2.png b/docs/figures/replace/ldm_generated_image2.png new file mode 100644 index 0000000000000000000000000000000000000000..b6cd5f99c5c9c202449f05b9f2da55471cbdd079 --- /dev/null +++ b/docs/figures/replace/ldm_generated_image2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:464ba1a05a0835904bec99233100e88efea7e9cc751c95817e27cec7e9bd4991 +size 459910 diff --git a/docs/figures/replace/ldm_generated_image3.png b/docs/figures/replace/ldm_generated_image3.png new file mode 100644 index 0000000000000000000000000000000000000000..6d0d464e63b63cb640fd6f4b4138dc498a913a87 --- /dev/null +++ b/docs/figures/replace/ldm_generated_image3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f56a7cd6a71efe5231de48eb74699a91890778179111b0acff7cae0fda073f3 +size 484599 diff --git a/docs/figures/replace/visual_best-rendered0.png b/docs/figures/replace/visual_best-rendered0.png new file mode 100644 index 0000000000000000000000000000000000000000..141842f037886bb3aeb2624ff4e76649e7c6d0cc --- /dev/null +++ b/docs/figures/replace/visual_best-rendered0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bac660e47ecfc7f8b3d30181df9729afdaa55af72ae662cd67fab736c5869f09 +size 44046 diff --git a/docs/figures/replace/visual_best-rendered1.png b/docs/figures/replace/visual_best-rendered1.png new file mode 100644 index 0000000000000000000000000000000000000000..f817ab059e88431209b7ed5cd2526ec063b516f1 --- /dev/null +++ b/docs/figures/replace/visual_best-rendered1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c34d8ab9e3f3c9563ef019b3be48262940c940a75b138635b43f0c54e452de17 +size 50348 diff --git a/docs/figures/replace/visual_best-rendered2.png b/docs/figures/replace/visual_best-rendered2.png new file mode 100644 index 0000000000000000000000000000000000000000..1874dc319b9f32ec8477f707a1bcea00783c0c10 --- /dev/null +++ b/docs/figures/replace/visual_best-rendered2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61c57adbce024e3ae95c214e449aba74144bb8c374075c2ff4a37d0083c2f70e +size 52739 diff --git a/docs/figures/replace/visual_best-rendered3.png b/docs/figures/replace/visual_best-rendered3.png new file mode 100644 index 0000000000000000000000000000000000000000..70f7774c827bda779a0e5e0e95666e4e40d9f267 --- /dev/null +++ b/docs/figures/replace/visual_best-rendered3.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b030cdd81791c3a3bf350d60523a2ee47af4ab1dc0028e22ddf4ec85d3df1a12 +size 59110 diff --git a/docs/figures/reweight/ldm_generated_image0.png b/docs/figures/reweight/ldm_generated_image0.png new file mode 100644 index 0000000000000000000000000000000000000000..79ea67f56d1c0465096eb0829e3e7dd3d5f8169b --- /dev/null +++ b/docs/figures/reweight/ldm_generated_image0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:af8c459a5f27d40ce4db6e3a4bcea442eac50d7d768010cee46db38e384e5af5 +size 466792 diff --git a/docs/figures/reweight/ldm_generated_image1.png b/docs/figures/reweight/ldm_generated_image1.png new file mode 100644 index 0000000000000000000000000000000000000000..3fa4c2c5fd6a46ecf0ce6219c3fe055954d26053 --- /dev/null +++ b/docs/figures/reweight/ldm_generated_image1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3bf5adb91c7881eb2ec46dc9d54f980e6d5b1a0465b84fe0715313fcb9cb2312 +size 499273 diff --git a/docs/figures/reweight/ldm_generated_image2.png b/docs/figures/reweight/ldm_generated_image2.png new file mode 100644 index 0000000000000000000000000000000000000000..f25570976787feb1e5136828ccb5453cd1b8764b --- /dev/null +++ b/docs/figures/reweight/ldm_generated_image2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3516da234863316239902f9058bef26c2c81bc48cbd49ac4060519c9d63797f1 +size 509096 diff --git a/docs/figures/reweight/visual_best-rendered0.png b/docs/figures/reweight/visual_best-rendered0.png new file mode 100644 index 0000000000000000000000000000000000000000..d7fef523c8077fe13cbbbf14d5bd619a1aef5315 --- /dev/null +++ b/docs/figures/reweight/visual_best-rendered0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47648174430fd29712507baaf9bfe9afa9c810258531c7d724c9bfbee82d642b +size 32130 diff --git a/docs/figures/reweight/visual_best-rendered1.png b/docs/figures/reweight/visual_best-rendered1.png new file mode 100644 index 0000000000000000000000000000000000000000..37affab1149575983201cafd1d0e7aff45a37a6c --- /dev/null +++ b/docs/figures/reweight/visual_best-rendered1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83ec8e4b3a5df5229a2503c3b728f8bf1cf7dc7aea09b77570f908ace62951cc +size 30663 diff --git a/docs/figures/reweight/visual_best-rendered2.png b/docs/figures/reweight/visual_best-rendered2.png new file mode 100644 index 0000000000000000000000000000000000000000..f97bb89d73b70532a1fd9935e22cda3b68f84f58 --- /dev/null +++ b/docs/figures/reweight/visual_best-rendered2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3acd23ddb8c6e77f6133cd91ebd0e624fa72101f0c964e11796041ebe5f31e38 +size 35395 diff --git a/libs/__init__.py b/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3906f788e4e3f3f66a0753c5bb851d98d515e1fb --- /dev/null +++ b/libs/__init__.py @@ -0,0 +1,9 @@ +from .utils import lazy + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, + submodules={'engine', 'metric', 'modules', 'solver', 'utils'}, + submod_attrs={} +) + +__version__ = '0.0.1' diff --git a/libs/engine/__init__.py b/libs/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..115dddde393cc8d46c05aa1b18b5e7ae43b4a95c --- /dev/null +++ b/libs/engine/__init__.py @@ -0,0 +1,7 @@ +from .model_state import ModelState +from .config_processor import merge_and_update_config + +__all__ = [ + 'ModelState', + 'merge_and_update_config' +] diff --git a/libs/engine/config_processor.py b/libs/engine/config_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..5a73434f086e2ab23c9cbd592b11b7f1b5d8ebe3 --- /dev/null +++ b/libs/engine/config_processor.py @@ -0,0 +1,151 @@ +import os +from typing import Tuple +from functools import reduce + +from argparse import Namespace +from omegaconf import DictConfig, OmegaConf + + +################################################################################# +# merge yaml and argparse # +################################################################################# + +def register_resolver(): + OmegaConf.register_new_resolver( + "add", lambda *numbers: sum(numbers) + ) + OmegaConf.register_new_resolver( + "multiply", lambda *numbers: reduce(lambda x, y: x * y, numbers) + ) + OmegaConf.register_new_resolver( + "sub", lambda n1, n2: n1 - n2 + ) + + +def _merge_args_and_config( + cmd_args: Namespace, + yaml_config: DictConfig, + read_only: bool = False +) -> Tuple[DictConfig, DictConfig, DictConfig]: + # convert cmd line args to OmegaConf + cmd_args_dict = vars(cmd_args) + cmd_args_list = [] + for k, v in cmd_args_dict.items(): + cmd_args_list.append(f"{k}={v}") + cmd_args_conf = OmegaConf.from_cli(cmd_args_list) + + # The following overrides the previous configuration + # cmd_args_list > configs + args_ = OmegaConf.merge(yaml_config, cmd_args_conf) + + if read_only: + OmegaConf.set_readonly(args_, True) + + return args_, cmd_args_conf, yaml_config + + +def merge_configs(args, method_cfg_path): + """merge command line args (argparse) and config file (OmegaConf)""" + yaml_config_path = os.path.join("./", "config", method_cfg_path) + try: + yaml_config = OmegaConf.load(yaml_config_path) + except FileNotFoundError as e: + print(f"error: {e}") + print(f"input file path: `{method_cfg_path}`") + print(f"config path: `{yaml_config_path}` not found.") + raise FileNotFoundError(e) + return _merge_args_and_config(args, yaml_config, read_only=False) + + +def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True): + """update config file (OmegaConf) with dotlist""" + if update_nodes is None: + return source_args + + update_args_list = str(update_nodes).split() + if len(update_args_list) < 1: + return source_args + + # check update_args + for item in update_args_list: + item_key_ = str(item).split('=')[0] # get key + # item_val_ = str(item).split('=')[1] # get value + + if strict: + # Tests if a key is existing + # assert OmegaConf.select(source_args, item_key_) is not None, f"{item_key_} is not existing." + + # Tests if a value is missing + assert not OmegaConf.is_missing(source_args, item_key_), f"the value of {item_key_} is missing." + + # if keys is None, then add key and set the value + if OmegaConf.select(source_args, item_key_) is None: + source_args.item_key_ = item_key_ + + # update original yaml params + update_nodes = OmegaConf.from_dotlist(update_args_list) + merged_args = OmegaConf.merge(source_args, update_nodes) + + # remove update_args + if remove_update_nodes: + OmegaConf.update(merged_args, 'update', '') + return merged_args + + +def update_if_exist(source_args, update_nodes): + """update config file (OmegaConf) with dotlist""" + if update_nodes is None: + return source_args + + upd_args_list = str(update_nodes).split() + if len(upd_args_list) < 1: + return source_args + + update_args_list = [] + for item in upd_args_list: + item_key_ = str(item).split('=')[0] # get key + + # if a key is existing + # if OmegaConf.select(source_args, item_key_) is not None: + # update_args_list.append(item) + + update_args_list.append(item) + + # update source_args if key be selected + if len(update_args_list) < 1: + merged_args = source_args + else: + update_nodes = OmegaConf.from_dotlist(update_args_list) + merged_args = OmegaConf.merge(source_args, update_nodes) + + return merged_args + + +def merge_and_update_config(args): + register_resolver() + + # if yaml_config is existing, then merge command line args and yaml_config + # if os.path.isfile(args.config) and args.config is not None: + if args.config is not None and str(args.config).endswith('.yaml'): + merged_args, cmd_args, yaml_config = merge_configs(args, args.config) + else: + merged_args, cmd_args, yaml_config = args, args, None + + # update the yaml_config with the cmd '-update' flag + update_nodes = args.update + final_args = update_configs(merged_args, update_nodes) + + # to simplify log output, we empty this + yaml_config_update = update_if_exist(yaml_config, update_nodes) + cmd_args_update = update_if_exist(cmd_args, update_nodes) + cmd_args_update.update = "" # clear update params + + final_args.yaml_config = yaml_config_update + final_args.cmd_args = cmd_args_update + + # update seed + if final_args.seed < 0: + import random + final_args.seed = random.randint(0, 65535) + + return final_args diff --git a/libs/engine/model_state.py b/libs/engine/model_state.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc9ae05b3ead9f138dcbe376c68cc4bcf184ed7 --- /dev/null +++ b/libs/engine/model_state.py @@ -0,0 +1,335 @@ +import os +from functools import partial +from typing import Union, List +from pathlib import Path +from datetime import datetime, timedelta + +from omegaconf import DictConfig +from pprint import pprint +import torch +from accelerate.utils import LoggerType +from accelerate import ( + Accelerator, + GradScalerKwargs, + DistributedDataParallelKwargs, + InitProcessGroupKwargs +) + +from ..modules.ema import EMA +from ..utils.logging import get_logger + + +class ModelState: + """ + Handling logger and `hugging face` accelerate training + + features: + - Mixed Precision + - Gradient Scaler + - Gradient Accumulation + - Optimizer + - EMA + - Logger (default: python print) + - Monitor (default: wandb, tensorboard) + """ + + def __init__( + self, + args, + log_path_suffix: str = None, + ignore_log=False, # whether to create log file or not + ) -> None: + self.args: DictConfig = args + + """check valid""" + mixed_precision = self.args.get("mixed_precision") + # Bug: omegaconf convert 'no' to false + mixed_precision = "no" if type(mixed_precision) == bool else mixed_precision + split_batches = self.args.get("split_batches", False) + gradient_accumulate_step = self.args.get("gradient_accumulate_step", 1) + assert gradient_accumulate_step >= 1, f"except gradient_accumulate_step >= 1, get {gradient_accumulate_step}" + + """create working space""" + # rule: ['./config'. 'method_name', 'exp_name.yaml'] + # -> results_path: ./runs/{method_name}-{exp_name}, as a base folder + # config_prefix, config_name = str(self.args.get("config")).split('/') + # config_name_only = str(config_name).split(".")[0] + + config_name_only = str(self.args.get("config")).split(".")[0] + results_folder = self.args.get("results_path", None) + if results_folder is None: + # self.results_path = Path("./workdir") / f"{config_prefix}-{config_name_only}" + self.results_path = Path("./workdir") + else: + # self.results_path = Path(results_folder) / f"{config_prefix}-{config_name_only}" + self.results_path = Path(os.path.join(results_folder, self.args.get("edit_type"), )) + + # update results_path: ./runs/{method_name}-{exp_name}/{log_path_suffix} + # noting: can be understood as "results dir / methods / ablation study / your result" + if log_path_suffix is not None: + self.results_path = self.results_path / log_path_suffix + + kwargs_handlers = [] + """mixed precision training""" + if args.mixed_precision == "no": + scaler_handler = GradScalerKwargs( + init_scale=args.init_scale, + growth_factor=args.growth_factor, + backoff_factor=args.backoff_factor, + growth_interval=args.growth_interval, + enabled=True + ) + kwargs_handlers.append(scaler_handler) + + """distributed training""" + ddp_handler = DistributedDataParallelKwargs( + dim=0, + broadcast_buffers=True, + static_graph=False, + bucket_cap_mb=25, + find_unused_parameters=False, + check_reduction=False, + gradient_as_bucket_view=False + ) + kwargs_handlers.append(ddp_handler) + + init_handler = InitProcessGroupKwargs(timeout=timedelta(seconds=1200)) + kwargs_handlers.append(init_handler) + + """init visualized tracker""" + log_with = [] + self.args.visual = False + if args.use_wandb: + log_with.append(LoggerType.WANDB) + if args.tensorboard: + log_with.append(LoggerType.TENSORBOARD) + + """hugging face Accelerator""" + self.accelerator = Accelerator( + device_placement=True, + split_batches=split_batches, + mixed_precision=mixed_precision, + gradient_accumulation_steps=args.gradient_accumulate_step, + cpu=True if args.use_cpu else False, + log_with=None if len(log_with) == 0 else log_with, + project_dir=self.results_path / "vis", + kwargs_handlers=kwargs_handlers, + ) + + """logs""" + if self.accelerator.is_local_main_process: + # for logging results in a folder periodically + self.results_path.mkdir(parents=True, exist_ok=True) + if not ignore_log: + now_time = datetime.now().strftime('%Y-%m-%d-%H-%M') + # self.logger = get_logger( + # logs_dir=self.results_path.as_posix(), + # file_name=f"log.txt" + # ) + + print("==> command line args: ") + print(args.cmd_args) + print("==> yaml config args: ") + print(args.yaml_config) + + print("\n***** Model State *****") + if self.accelerator.distributed_type != "NO": + print(f"-> Distributed Type: {self.accelerator.distributed_type}") + print(f"-> Split Batch Size: {split_batches}, Total Batch Size: {self.actual_batch_size}") + print(f"-> Mixed Precision: {mixed_precision}, AMP: {self.accelerator.native_amp}," + f" Gradient Accumulate Step: {gradient_accumulate_step}") + print(f"-> Weight dtype: {self.weight_dtype}") + + if self.accelerator.scaler_handler is not None and self.accelerator.scaler_handler.enabled: + print(f"-> Enabled GradScaler: {self.accelerator.scaler_handler.to_kwargs()}") + + if args.use_wandb: + print(f"-> Init trackers: 'wandb' ") + self.args.visual = True + self.__init_tracker(project_name="my_project", tags=None, entity="") + + print(f"-> Working Space: '{self.results_path}'") + + """EMA""" + self.use_ema = args.get('ema', False) + self.ema_wrapper = self.__build_ema_wrapper() + + """glob step""" + self.step = 0 + + """log process""" + self.accelerator.wait_for_everyone() + print(f'Process {self.accelerator.process_index} using device: {self.accelerator.device}') + + self.print("-> state initialization complete \n") + + def __init_tracker(self, project_name, tags, entity): + self.accelerator.init_trackers( + project_name=project_name, + config=dict(self.args), + init_kwargs={ + "wandb": { + "notes": "accelerate trainer pipeline", + "tags": [ + f"total batch_size: {self.actual_batch_size}" + ], + "entity": entity, + }} + ) + + def __build_ema_wrapper(self): + if self.use_ema: + self.print(f"-> EMA: {self.use_ema}, decay: {self.args.ema_decay}, " + f"update_after_step: {self.args.ema_update_after_step}, " + f"update_every: {self.args.ema_update_every}") + ema_wrapper = partial( + EMA, beta=self.args.ema_decay, + update_after_step=self.args.ema_update_after_step, + update_every=self.args.ema_update_every + ) + else: + ema_wrapper = None + + return ema_wrapper + + @property + def device(self): + return self.accelerator.device + + @property + def weight_dtype(self): + weight_dtype = torch.float32 + if self.accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + return weight_dtype + + @property + def actual_batch_size(self): + if self.accelerator.split_batches is False: + actual_batch_size = self.args.batch_size * self.accelerator.num_processes * self.accelerator.gradient_accumulation_steps + else: + assert self.actual_batch_size % self.accelerator.num_processes == 0 + actual_batch_size = self.args.batch_size + return actual_batch_size + + @property + def n_gpus(self): + return self.accelerator.num_processes + + @property + def no_decay_params_names(self): + no_decay = [ + "bn", "LayerNorm", "GroupNorm", + ] + return no_decay + + def no_decay_params(self, model, weight_decay): + """optimization tricks""" + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() + if not any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() + if any(nd in n for nd in self.no_decay_params_names) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + def optimized_params(self, model: torch.nn.Module, verbose=True) -> List: + """return parameters if `requires_grad` is True + + Args: + model: pytorch models + verbose: log optimized parameters + + Examples: + >>> self.params_optimized = self.optimized_params(uvit, verbose=True) + >>> optimizer = torch.optim.AdamW(self.params_optimized, lr=args.lr) + + Returns: + a list of parameters + """ + params_optimized = [] + for key, value in model.named_parameters(): + if value.requires_grad: + params_optimized.append(value) + if verbose: + self.print("\t {}, {}, {}".format(key, value.numel(), value.shape)) + return params_optimized + + def save_everything(self, fpath: str): + """Saving and loading the model, optimizer, RNG generators, and the GradScaler.""" + if not self.accelerator.is_main_process: + return + self.accelerator.save_state(fpath) + + def load_save_everything(self, fpath: str): + """Loading the model, optimizer, RNG generators, and the GradScaler.""" + self.accelerator.load_state(fpath) + + def save(self, milestone: Union[str, float, int], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, self.results_path / f'model-{milestone}.pt') + + def save_in(self, root: Union[str, Path], checkpoint: object) -> None: + if not self.accelerator.is_main_process: + return + + torch.save(checkpoint, root) + + def load_ckpt_model_only(self, model: torch.nn.Module, path: Union[str, Path], rm_module_prefix: bool = False): + ckpt = torch.load(path, map_location=self.accelerator.device) + + unwrapped_model = self.accelerator.unwrap_model(model) + if rm_module_prefix: + unwrapped_model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()}) + else: + unwrapped_model.load_state_dict(ckpt) + return unwrapped_model + + def load_shared_weights(self, model: torch.nn.Module, path: Union[str, Path]): + ckpt = torch.load(path, map_location=self.accelerator.device) + self.print(f"pretrained_dict len: {len(ckpt)}") + unwrapped_model = self.accelerator.unwrap_model(model) + model_dict = unwrapped_model.state_dict() + pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict} + model_dict.update(pretrained_dict) + unwrapped_model.load_state_dict(model_dict, strict=False) + self.print(f"selected pretrained_dict: {len(model_dict)}") + return unwrapped_model + + def print(self, *args, **kwargs): + """Use in replacement of `print()` to only print once per server.""" + self.accelerator.print(*args, **kwargs) + + def pretty_print(self, msg): + if self.accelerator.is_local_main_process: + pprint(dict(msg)) + + def close_tracker(self): + self.accelerator.end_training() + + def free_memory(self): + self.accelerator.clear() + + def close(self, msg: str = "Training complete."): + """Use in end of training.""" + self.free_memory() + + if torch.cuda.is_available(): + self.print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB') + if self.args.visual: + self.close_tracker() + self.print(msg) diff --git a/libs/metric/__init__.py b/libs/metric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/libs/metric/__init__.py @@ -0,0 +1 @@ + diff --git a/libs/metric/accuracy.py b/libs/metric/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..d852de290a55c86afe22726b46c541890c2a9173 --- /dev/null +++ b/libs/metric/accuracy.py @@ -0,0 +1,25 @@ +def accuracy(output, target, topk=(1,)): + """ + Computes the accuracy over the k top predictions for the specified values of k. + + Args + output: logits or probs (num of batch, num of classes) + target: (num of batch, 1) or (num of batch, ) + topk: list of returned k + + refer: https://github.com/pytorch/examples/blob/master/imagenet/main.py + """ + maxK = max(topk) # get k in top-k + batch_size = target.size(0) + + _, pred = output.topk(k=maxK, dim=1, largest=True, sorted=True) # pred: [num of batch, k] + pred = pred.t() # pred: [k, num of batch] + + # [1, num of batch] -> [k, num_of_batch] : bool + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res # np.shape(res): [k, 1] diff --git a/libs/metric/clip_score/__init__.py b/libs/metric/clip_score/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38b0483c0c2c6702e88e5b5860c0306aaee7e369 --- /dev/null +++ b/libs/metric/clip_score/__init__.py @@ -0,0 +1,3 @@ +from .openaiCLIP_loss import CLIPScoreWrapper + +__all__ = ['CLIPScoreWrapper'] diff --git a/libs/metric/clip_score/openaiCLIP_loss.py b/libs/metric/clip_score/openaiCLIP_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..813b3c5ea09294f6913b50207c8db6f2de000a10 --- /dev/null +++ b/libs/metric/clip_score/openaiCLIP_loss.py @@ -0,0 +1,304 @@ +from typing import Union, List, Tuple +from collections import OrderedDict +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as transforms + + +class CLIPScoreWrapper(nn.Module): + + def __init__(self, + clip_model_name: str, + download_root: str = None, + device: torch.device = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, + # additional params + visual_score: bool = False, + feats_loss_type: str = None, + feats_loss_weights: List[float] = None, + fc_loss_weight: float = None, + context_length: int = 77): + super().__init__() + + import clip # local import + + # check model info + self.clip_model_name = clip_model_name + self.device = device + self.available_models = clip.available_models() + assert clip_model_name in self.available_models, f"A model backbone: {clip_model_name} that does not exist" + + # load CLIP + self.model, self.preprocess = clip.load(clip_model_name, device, jit=jit, download_root=download_root) + self.model.eval() + + # load tokenize + self.tokenize_fn = partial(clip.tokenize, context_length=context_length) + + # load CLIP visual + self.visual_encoder = VisualEncoderWrapper(self.model, clip_model_name).to(device) + self.visual_encoder.eval() + + # check loss weights + self.visual_score = visual_score + if visual_score: + assert feats_loss_type in ["l1", "l2", "cosine"], f"{feats_loss_type} is not exist." + if clip_model_name.startswith("ViT"): assert len(feats_loss_weights) == 12 + if clip_model_name.startswith("RN"): assert len(feats_loss_weights) == 5 + + # load visual loss wrapper + self.visual_loss_fn = CLIPVisualLossWrapper(self.visual_encoder, feats_loss_type, + feats_loss_weights, + fc_loss_weight) + + @property + def input_resolution(self): + return self.model.visual.input_resolution # default: 224 + + @property + def resize(self): # Resize only + return transforms.Compose([self.preprocess.transforms[0]]) + + @property + def normalize(self): + return transforms.Compose([ + self.preprocess.transforms[0], # Resize + self.preprocess.transforms[1], # CenterCrop + self.preprocess.transforms[-1], # Normalize + ]) + + @property + def norm_(self): # Normalize only + return transforms.Compose([self.preprocess.transforms[-1]]) + + def encode_image_layer_wise(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + semantic_vec, feature_maps = self.visual_encoder(x) + return semantic_vec, feature_maps + + def encode_text(self, text: Union[str, List[str]], norm: bool = True) -> torch.Tensor: + tokens = self.tokenize_fn(text).to(self.device) + text_features = self.model.encode_text(tokens) + if norm: + text_features = text_features.mean(axis=0, keepdim=True) + text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True) + return text_features_norm + return text_features + + def encode_image(self, image: torch.Tensor, norm: bool = True) -> torch.Tensor: + image_features = self.model.encode_image(image) + if norm: + image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True) + return image_features_norm + return image_features + + @torch.no_grad() + def predict(self, + image: torch.Tensor, + text: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: + image_features = self.model.encode_image(image) + text_tokenize = self.tokenize_fn(text).to(self.device) + text_features = self.model.encode_text(text_tokenize) + logits_per_image, logits_per_text = self.model(image, text) + probs = logits_per_image.softmax(dim=-1).cpu().numpy() + return image_features, text_features, probs + + def compute_text_visual_distance( + self, image: torch.Tensor, text: Union[str, List[str]] + ) -> torch.Tensor: + image_features = self.model.encode_image(image) + text_tokenize = self.tokenize_fn(text).to(self.device) + with torch.no_grad(): + text_features = self.model.encode_text(text_tokenize) + + image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True) + text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True) + # loss = - (image_features_norm @ text_features_norm.T) + loss = 1 - torch.cosine_similarity(image_features_norm, text_features_norm, dim=1) + return loss.mean() + + def directional_text_visual_distance(self, src_text, src_img, tar_text, tar_img): + src_image_features = self.model.encode_image(src_img).detach() + tar_image_features = self.model.encode_image(tar_img) + src_text_tokenize = self.tokenize_fn(src_text).to(self.device) + tar_text_tokenize = self.tokenize_fn(tar_text).to(self.device) + with torch.no_grad(): + src_text_features = self.model.encode_text(src_text_tokenize) + tar_text_features = self.model.encode_text(tar_text_tokenize) + + delta_image_features = tar_image_features - src_image_features + delta_text_features = tar_text_features - src_text_features + + # # avold zero divisor + # delta_image_features_norm = delta_image_features / delta_image_features.norm(dim=-1, keepdim=True) + # delta_text_features_norm = delta_text_features / delta_text_features.norm(dim=-1, keepdim=True) + + loss = 1 - torch.cosine_similarity(delta_image_features, delta_text_features, dim=1, eps=1e-3) + return loss.mean() + + def compute_visual_distance( + self, x: torch.Tensor, y: torch.Tensor, clip_norm: bool = True, + ) -> Tuple[torch.Tensor, List]: + # return a fc loss and the list of feat loss + assert self.visual_score is True + assert x.shape == y.shape + assert x.shape[-1] == self.input_resolution and x.shape[-2] == self.input_resolution + assert y.shape[-1] == self.input_resolution and y.shape[-2] == self.input_resolution + + if clip_norm: + return self.visual_loss_fn(self.normalize(x), self.normalize(y)) + else: + return self.visual_loss_fn(x, y) + + +class VisualEncoderWrapper(nn.Module): + """ + semantic features and layer by layer feature maps are obtained from CLIP visual encoder. + """ + + def __init__(self, clip_model: nn.Module, clip_model_name: str): + super().__init__() + self.clip_model = clip_model + self.clip_model_name = clip_model_name + + if clip_model_name.startswith("ViT"): + self.feature_maps = OrderedDict() + for i in range(12): # 12 ResBlocks in ViT visual transformer + self.clip_model.visual.transformer.resblocks[i].register_forward_hook( + self.make_hook(i) + ) + + if clip_model_name.startswith("RN"): + layers = list(self.clip_model.visual.children()) + init_layers = torch.nn.Sequential(*layers)[:8] + self.layer1 = layers[8] + self.layer2 = layers[9] + self.layer3 = layers[10] + self.layer4 = layers[11] + self.att_pool2d = layers[12] + + def make_hook(self, name): + def hook(module, input, output): + if len(output.shape) == 3: + # LND -> NLD (B, 77, 768) + self.feature_maps[name] = output.permute(1, 0, 2) + else: + self.feature_maps[name] = output + + return hook + + def _forward_vit(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]: + fc_feature = self.clip_model.encode_image(x).float() + feature_maps = [self.feature_maps[k] for k in range(12)] + + # fc_feature len: 1 ,feature_maps len: 12 + return fc_feature, feature_maps + + def _forward_resnet(self, x: torch.Tensor) -> Tuple[torch.Tensor, List]: + def stem(m, x): + for conv, bn, relu in [(m.conv1, m.bn1, m.relu1), (m.conv2, m.bn2, m.relu2), (m.conv3, m.bn3, m.relu3)]: + x = torch.relu(bn(conv(x))) + x = m.avgpool(x) + return x + + x = x.type(self.clip_model.visual.conv1.weight.dtype) + x = stem(self.clip_model.visual, x) + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + x4 = self.layer4(x3) + y = self.att_pool2d(x4) + + # fc_features len: 1 ,feature_maps len: 5 + return y, [x, x1, x2, x3, x4] + + def forward(self, x) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if self.clip_model_name.startswith("ViT"): + fc_feat, visual_feat_maps = self._forward_vit(x) + if self.clip_model_name.startswith("RN"): + fc_feat, visual_feat_maps = self._forward_resnet(x) + + return fc_feat, visual_feat_maps + + +class CLIPVisualLossWrapper(nn.Module): + """ + Visual Feature Loss + FC loss + """ + + def __init__( + self, + visual_encoder: nn.Module, + feats_loss_type: str = None, + feats_loss_weights: List[float] = None, + fc_loss_weight: float = None, + ): + super().__init__() + self.visual_encoder = visual_encoder + self.feats_loss_weights = feats_loss_weights + self.fc_loss_weight = fc_loss_weight + + self.layer_criterion = layer_wise_distance(feats_loss_type) + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x_fc_feature, x_feat_maps = self.visual_encoder(x) + y_fc_feature, y_feat_maps = self.visual_encoder(y) + + # visual feature loss + if sum(self.feats_loss_weights) == 0: + feats_loss_list = [torch.tensor(0, device=x.device)] + else: + feats_loss = self.layer_criterion(x_feat_maps, y_feat_maps, self.visual_encoder.clip_model_name) + feats_loss_list = [] + for layer, w in enumerate(self.feats_loss_weights): + if w: + feats_loss_list.append(feats_loss[layer] * w) + + # visual fc loss, default: cosine similarity + if self.fc_loss_weight == 0: + fc_loss = torch.tensor(0, device=x.device) + else: + fc_loss = (1 - torch.cosine_similarity(x_fc_feature, y_fc_feature, dim=1)).mean() + fc_loss = fc_loss * self.fc_loss_weight + + return fc_loss, feats_loss_list + + +################################################################################# +# layer wise metric # +################################################################################# + +def layer_wise_distance(metric_name: str): + return { + "l1": l1_layer_wise, + "l2": l2_layer_wise, + "cosine": cosine_layer_wise + }.get(metric_name.lower()) + + +def l2_layer_wise(x_features, y_features, clip_model_name): + return [ + torch.square(x_conv - y_conv).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + + +def l1_layer_wise(x_features, y_features, clip_model_name): + return [ + torch.abs(x_conv - y_conv).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + + +def cosine_layer_wise(x_features, y_features, clip_model_name): + if clip_model_name.startswith("RN"): + return [ + (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] + return [ + (1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() + for x_conv, y_conv in zip(x_features, y_features) + ] diff --git a/libs/metric/lpips_origin/__init__.py b/libs/metric/lpips_origin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb4332c66e4725112fad88db377e3a50471f3ee --- /dev/null +++ b/libs/metric/lpips_origin/__init__.py @@ -0,0 +1,3 @@ +from .lpips import LPIPS + +__all__ = ['LPIPS'] diff --git a/libs/metric/lpips_origin/lpips.py b/libs/metric/lpips_origin/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..fa97aad10e2b18fa5b92b913e9def4f5135a83c0 --- /dev/null +++ b/libs/metric/lpips_origin/lpips.py @@ -0,0 +1,184 @@ +from __future__ import absolute_import + +import os + +import torch +import torch.nn as nn + +from . import pretrained_networks as pretrained_torch_models + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) + + +def upsample(x): + return nn.Upsample(size=x.shape[2:], mode='bilinear', align_corners=False)(x) + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +# Learned perceptual metric +class LPIPS(nn.Module): + + def __init__(self, + pretrained=True, + net='alex', + version='0.1', + lpips=True, + spatial=False, + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + model_path=None, + eval_mode=True, + verbose=True): + """ Initializes a perceptual loss torch.nn.Module + + Parameters (default listed first) + --------------------------------- + lpips : bool + [True] use linear layers on top of base/trunk network + [False] means no linear layers; each layer is averaged together + pretrained : bool + This flag controls the linear layers, which are only in effect when lpips=True above + [True] means linear layers are calibrated with human perceptual judgments + [False] means linear layers are randomly initialized + pnet_rand : bool + [False] means trunk loaded with ImageNet classification weights + [True] means randomly initialized trunk + net : str + ['alex','vgg','squeeze'] are the base/trunk networks available + version : str + ['v0.1'] is the default and latest + ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) + model_path : 'str' + [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 + + The following parameters should only be changed if training the network: + + eval_mode : bool + [True] is for test mode (default) + [False] is for training mode + pnet_tune + [False] keep base/trunk frozen + [True] tune the base/trunk network + use_dropout : bool + [True] to use dropout when training linear layers + [False] for no dropout when training linear layers + """ + super(LPIPS, self).__init__() + if verbose: + print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % + ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) + + self.pnet_type = net + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips # false means baseline of just averaging all layers + self.version = version + self.scaling_layer = ScalingLayer() + + if self.pnet_type in ['vgg', 'vgg16']: + net_type = pretrained_torch_models.vgg16 + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == 'alex': + net_type = pretrained_torch_models.alexnet + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == 'squeeze': + net_type = pretrained_torch_models.squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if lpips: + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == 'squeeze': # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + self.lins = nn.ModuleList(self.lins) + + if pretrained: + if model_path is None: + model_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + f"weights/v{version}/{net}.pth" + ) + if verbose: + print('Loading model from: %s' % model_path) + self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) + + if eval_mode: + self.eval() + + def forward(self, in0, in1, return_per_layer=False, normalize=False): + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, 1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + # Noting: v0.0 - original release had a bug, where input was not scaled + if self.version == '0.1': + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) + else: + in0_input, in1_input = in0, in1 + + # model forward + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + + feats0, feats1, diffs = {}, {}, {} + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if self.lpips: + if self.spatial: + res = [upsample(self.lins[kk](diffs[kk])) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if self.spatial: + res = [upsample(diffs[kk].sum(dim=1, keepdim=True)) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + + loss = sum(res) + + if return_per_layer: + return loss, res + else: + return loss + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) diff --git a/libs/metric/lpips_origin/pretrained_networks.py b/libs/metric/lpips_origin/pretrained_networks.py new file mode 100644 index 0000000000000000000000000000000000000000..484b808da02eecb59c132e63a0fe4ae90b1e4d2e --- /dev/null +++ b/libs/metric/lpips_origin/pretrained_networks.py @@ -0,0 +1,196 @@ +from collections import namedtuple + +import torch +import torchvision.models as tv_models + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv_models.squeezenet1_1(weights=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + weights = tv_models.AlexNet_Weights.IMAGENET1K_V1 if pretrained else None + alexnet_pretrained_features = tv_models.alexnet(weights=weights).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + weights = tv_models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None + vgg_pretrained_features = tv_models.vgg16(weights=weights).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + + if num == 18: + weights = tv_models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None + self.net = tv_models.resnet18(weights=weights) + elif num == 34: + weights = tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None + self.net = tv_models.resnet34(weights=weights) + elif num == 50: + weights = tv_models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet50(weights=weights) + elif num == 101: + weights = tv_models.ResNet101_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet101(weights=weights) + elif num == 152: + weights = tv_models.ResNet152_Weights.IMAGENET1K_V2 if pretrained else None + self.net = tv_models.resnet152(weights=weights) + self.N_slices = 5 + + if not requires_grad: + for param in self.net.parameters(): + param.requires_grad = False + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/libs/metric/lpips_origin/weights/v0.1/alex.pth b/libs/metric/lpips_origin/weights/v0.1/alex.pth new file mode 100644 index 0000000000000000000000000000000000000000..1df9dfe62abb1fc89cc7f82b4e5fe886c979708e Binary files /dev/null and b/libs/metric/lpips_origin/weights/v0.1/alex.pth differ diff --git a/libs/metric/lpips_origin/weights/v0.1/squeeze.pth b/libs/metric/lpips_origin/weights/v0.1/squeeze.pth new file mode 100644 index 0000000000000000000000000000000000000000..a3bd383bc4747bb587d3c3d47c80e43eda2ab536 Binary files /dev/null and b/libs/metric/lpips_origin/weights/v0.1/squeeze.pth differ diff --git a/libs/metric/lpips_origin/weights/v0.1/vgg.pth b/libs/metric/lpips_origin/weights/v0.1/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..47e943cfacabf7040b4af8cf4084ab91177f1b88 Binary files /dev/null and b/libs/metric/lpips_origin/weights/v0.1/vgg.pth differ diff --git a/libs/metric/piq/__init__.py b/libs/metric/piq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3927e69b8b5a28c121b2e0808d5d4e6974ddf2 --- /dev/null +++ b/libs/metric/piq/__init__.py @@ -0,0 +1,2 @@ +# install: pip install piq +# repo: https://github.com/photosynthesis-team/piq diff --git a/libs/metric/piq/functional/__init__.py b/libs/metric/piq/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..131231bd3249723d3c21d98c45109ff99fb19612 --- /dev/null +++ b/libs/metric/piq/functional/__init__.py @@ -0,0 +1,15 @@ +from .base import ifftshift, get_meshgrid, similarity_map, gradient_map, pow_for_complex, crop_patches +from .colour_conversion import rgb2lmn, rgb2xyz, xyz2lab, rgb2lab, rgb2yiq, rgb2lhm +from .filters import haar_filter, hann_filter, scharr_filter, prewitt_filter, gaussian_filter +from .filters import binomial_filter1d, average_filter2d +from .layers import L2Pool2d +from .resize import imresize + +__all__ = [ + 'ifftshift', 'get_meshgrid', 'similarity_map', 'gradient_map', 'pow_for_complex', 'crop_patches', + 'rgb2lmn', 'rgb2xyz', 'xyz2lab', 'rgb2lab', 'rgb2yiq', 'rgb2lhm', + 'haar_filter', 'hann_filter', 'scharr_filter', 'prewitt_filter', 'gaussian_filter', + 'binomial_filter1d', 'average_filter2d', + 'L2Pool2d', + 'imresize', +] diff --git a/libs/metric/piq/functional/base.py b/libs/metric/piq/functional/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d34790ad89b225fd28ade34507c498a315fb6d --- /dev/null +++ b/libs/metric/piq/functional/base.py @@ -0,0 +1,111 @@ +r"""General purpose functions""" +from typing import Tuple, Union, Optional +import torch +from ..utils import _parse_version + + +def ifftshift(x: torch.Tensor) -> torch.Tensor: + r""" Similar to np.fft.ifftshift but applies to PyTorch Tensors""" + shift = [-(ax // 2) for ax in x.size()] + return torch.roll(x, shift, tuple(range(len(shift)))) + + +def get_meshgrid(size: Tuple[int, int], device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Return coordinate grid matrices centered at zero point. + Args: + size: Shape of meshgrid to create + device: device to use for creation + dtype: dtype to use for creation + Returns: + Meshgrid of size on device with dtype values. + """ + if size[0] % 2: + # Odd + x = torch.arange(-(size[0] - 1) / 2, size[0] / 2, device=device, dtype=dtype) / (size[0] - 1) + else: + # Even + x = torch.arange(- size[0] / 2, size[0] / 2, device=device, dtype=dtype) / size[0] + + if size[1] % 2: + # Odd + y = torch.arange(-(size[1] - 1) / 2, size[1] / 2, device=device, dtype=dtype) / (size[1] - 1) + else: + # Even + y = torch.arange(- size[1] / 2, size[1] / 2, device=device, dtype=dtype) / size[1] + # Use indexing param depending on torch version + recommended_torch_version = _parse_version("1.10.0") + torch_version = _parse_version(torch.__version__) + if len(torch_version) > 0 and torch_version >= recommended_torch_version: + return torch.meshgrid(x, y, indexing='ij') + return torch.meshgrid(x, y) + + +def similarity_map(map_x: torch.Tensor, map_y: torch.Tensor, constant: float, alpha: float = 0.0) -> torch.Tensor: + r""" Compute similarity_map between two tensors using Dice-like equation. + + Args: + map_x: Tensor with map to be compared + map_y: Tensor with map to be compared + constant: Used for numerical stability + alpha: Masking coefficient. Subtracts - `alpha` * map_x * map_y from denominator and nominator + """ + return (2.0 * map_x * map_y - alpha * map_x * map_y + constant) / \ + (map_x ** 2 + map_y ** 2 - alpha * map_x * map_y + constant) + + +def gradient_map(x: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor: + r""" Compute gradient map for a given tensor and stack of kernels. + + Args: + x: Tensor with shape (N, C, H, W). + kernels: Stack of tensors for gradient computation with shape (k_N, k_H, k_W) + Returns: + Gradients of x per-channel with shape (N, C, H, W) + """ + padding = kernels.size(-1) // 2 + grads = torch.nn.functional.conv2d(x, kernels, padding=padding) + + return torch.sqrt(torch.sum(grads ** 2, dim=-3, keepdim=True)) + + +def pow_for_complex(base: torch.Tensor, exp: Union[int, float]) -> torch.Tensor: + r""" Takes the power of each element in a 4D tensor with negative values or 5D tensor with complex values. + Complex numbers are represented by modulus and argument: r * \exp(i * \phi). + + It will likely to be redundant with introduction of torch.ComplexTensor. + + Args: + base: Tensor with shape (N, C, H, W) or (N, C, H, W, 2). + exp: Exponent + Returns: + Complex tensor with shape (N, C, H, W, 2). + """ + if base.dim() == 4: + x_complex_r = base.abs() + x_complex_phi = torch.atan2(torch.zeros_like(base), base) + elif base.dim() == 5 and base.size(-1) == 2: + x_complex_r = base.pow(2).sum(dim=-1).sqrt() + x_complex_phi = torch.atan2(base[..., 1], base[..., 0]) + else: + raise ValueError(f'Expected real or complex tensor, got {base.size()}') + + x_complex_pow_r = x_complex_r ** exp + x_complex_pow_phi = x_complex_phi * exp + x_real_pow = x_complex_pow_r * torch.cos(x_complex_pow_phi) + x_imag_pow = x_complex_pow_r * torch.sin(x_complex_pow_phi) + return torch.stack((x_real_pow, x_imag_pow), dim=-1) + + +def crop_patches(x: torch.Tensor, size=64, stride=32) -> torch.Tensor: + r"""Crop tensor with images into small patches + Args: + x: Tensor with shape (N, C, H, W), expected to be images-like entities + size: Size of a square patch + stride: Step between patches + """ + assert (x.shape[2] >= size) and (x.shape[3] >= size), \ + f"Images must be bigger than patch size. Got ({x.shape[2], x.shape[3]}) and ({size}, {size})" + channels = x.shape[1] + patches = x.unfold(1, channels, channels).unfold(2, size, stride).unfold(3, size, stride) + patches = patches.reshape(-1, channels, size, size) + return patches diff --git a/libs/metric/piq/functional/colour_conversion.py b/libs/metric/piq/functional/colour_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..9de6eb031a60aa765a326cb6ef8cf67c37177d97 --- /dev/null +++ b/libs/metric/piq/functional/colour_conversion.py @@ -0,0 +1,136 @@ +r"""Colour space conversion functions""" +from typing import Union, Dict +import torch + + +def rgb2lmn(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LMN images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). LMN colour space. + """ + weights_rgb_to_lmn = torch.tensor([[0.06, 0.63, 0.27], + [0.30, 0.04, -0.35], + [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() + x_lmn = torch.matmul(x.permute(0, 2, 3, 1), weights_rgb_to_lmn).permute(0, 3, 1, 2) + return x_lmn + + +def rgb2xyz(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of XYZ images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). XYZ colour space. + """ + mask_below = (x <= 0.04045).type(x.dtype) + mask_above = (x > 0.04045).type(x.dtype) + + tmp = x / 12.92 * mask_below + torch.pow((x + 0.055) / 1.055, 2.4) * mask_above + + weights_rgb_to_xyz = torch.tensor([[0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041]], dtype=x.dtype, device=x.device) + + x_xyz = torch.matmul(tmp.permute(0, 2, 3, 1), weights_rgb_to_xyz.t()).permute(0, 3, 1, 2) + return x_xyz + + +def xyz2lab(x: torch.Tensor, illuminant: str = 'D50', observer: str = '2') -> torch.Tensor: + r"""Convert a batch of XYZ images to a batch of LAB images + + Args: + x: Batch of images with shape (N, 3, H, W). XYZ colour space. + illuminant: {“A”, “D50”, “D55”, “D65”, “D75”, “E”}, optional. The name of the illuminant. + observer: {“2”, “10”}, optional. The aperture angle of the observer. + + Returns: + Batch of images with shape (N, 3, H, W). LAB colour space. + """ + epsilon = 0.008856 + kappa = 903.3 + illuminants: Dict[str, Dict] = \ + {"A": {'2': (1.098466069456375, 1, 0.3558228003436005), + '10': (1.111420406956693, 1, 0.3519978321919493)}, + "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288), + '10': (0.9672062750333777, 1, 0.8142801513128616)}, + "D55": {'2': (0.956797052643698, 1, 0.9214805860173273), + '10': (0.9579665682254781, 1, 0.9092525159847462)}, + "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white` + '10': (0.94809667673716, 1, 1.0730513595166162)}, + "D75": {'2': (0.9497220898840717, 1, 1.226393520724154), + '10': (0.9441713925645873, 1, 1.2064272211720228)}, + "E": {'2': (1.0, 1.0, 1.0), + '10': (1.0, 1.0, 1.0)}} + + illuminants_to_use = torch.tensor(illuminants[illuminant][observer], + dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + + tmp = x / illuminants_to_use + + mask_below = (tmp <= epsilon).type(x.dtype) + mask_above = (tmp > epsilon).type(x.dtype) + tmp = torch.pow(tmp, 1. / 3.) * mask_above + (kappa * tmp + 16.) / 116. * mask_below + + weights_xyz_to_lab = torch.tensor([[0, 116., 0], + [500., -500., 0], + [0, 200., -200.]], dtype=x.dtype, device=x.device) + bias_xyz_to_lab = torch.tensor([-16., 0., 0.], dtype=x.dtype, device=x.device).view(1, 3, 1, 1) + + x_lab = torch.matmul(tmp.permute(0, 2, 3, 1), weights_xyz_to_lab.t()).permute(0, 3, 1, 2) + bias_xyz_to_lab + return x_lab + + +def rgb2lab(x: torch.Tensor, data_range: Union[int, float] = 255) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LAB images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + data_range: dynamic range of the input image. + + Returns: + Batch of images with shape (N, 3, H, W). LAB colour space. + """ + return xyz2lab(rgb2xyz(x / float(data_range))) + + +def rgb2yiq(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of YIQ images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). YIQ colour space. + """ + yiq_weights = torch.tensor([ + [0.299, 0.587, 0.114], + [0.5959, -0.2746, -0.3213], + [0.2115, -0.5227, 0.3112]], dtype=x.dtype, device=x.device).t() + x_yiq = torch.matmul(x.permute(0, 2, 3, 1), yiq_weights).permute(0, 3, 1, 2) + return x_yiq + + +def rgb2lhm(x: torch.Tensor) -> torch.Tensor: + r"""Convert a batch of RGB images to a batch of LHM images + + Args: + x: Batch of images with shape (N, 3, H, W). RGB colour space. + + Returns: + Batch of images with shape (N, 3, H, W). LHM colour space. + + Reference: + https://arxiv.org/pdf/1608.07433.pdf + """ + lhm_weights = torch.tensor([ + [0.2989, 0.587, 0.114], + [0.3, 0.04, -0.35], + [0.34, -0.6, 0.17]], dtype=x.dtype, device=x.device).t() + x_lhm = torch.matmul(x.permute(0, 2, 3, 1), lhm_weights).permute(0, 3, 1, 2) + return x_lhm diff --git a/libs/metric/piq/functional/filters.py b/libs/metric/piq/functional/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5ef1ac5110fa57b75de7476567a409842c0dfc --- /dev/null +++ b/libs/metric/piq/functional/filters.py @@ -0,0 +1,111 @@ +r"""Filters for gradient computation, bluring, etc.""" +import torch +import numpy as np +from typing import Optional + + +def haar_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates Haar kernel + + Args: + kernel_size: size of the kernel + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + kernel = torch.ones((kernel_size, kernel_size), device=device, dtype=dtype) / kernel_size + kernel[kernel_size // 2:, :] = - kernel[kernel_size // 2:, :] + return kernel.unsqueeze(0) + + +def hann_filter(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates Hann kernel + Args: + kernel_size: size of the kernel + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + # Take bigger window and drop borders + window = torch.hann_window(kernel_size + 2, periodic=False, device=device, dtype=dtype)[1:-1] + kernel = window[:, None] * window[None, :] + # Normalize and reshape kernel + return kernel.view(1, kernel_size, kernel_size) / kernel.sum() + + +def gaussian_filter(kernel_size: int, sigma: float, device: Optional[str] = None, + dtype: Optional[type] = None) -> torch.Tensor: + r"""Returns 2D Gaussian kernel N(0,`sigma`^2) + Args: + size: Size of the kernel + sigma: Std of the distribution + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + coords = torch.arange(kernel_size, dtype=dtype, device=device) + coords -= (kernel_size - 1) / 2. + + g = coords ** 2 + g = (- (g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma ** 2)).exp() + + g /= g.sum() + return g.unsqueeze(0) + + +# Gradient operator kernels +def scharr_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Utility function that returns a normalized 3x3 Scharr kernel in X direction + + Args: + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, 3, 3) + """ + return torch.tensor([[[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]]], device=device, dtype=dtype) / 16 + + +def prewitt_filter(device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Utility function that returns a normalized 3x3 Prewitt kernel in X direction + + Args: + device: target device for kernel generation + dtype: target data type for kernel generation + Returns: + kernel: Tensor with shape (1, 3, 3)""" + return torch.tensor([[[-1., 0., 1.], [-1., 0., 1.], [-1., 0., 1.]]], device=device, dtype=dtype) / 3 + + +def binomial_filter1d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates 1D normalized binomial filter + + Args: + kernel_size (int): kernel size + device: target device for kernel generation + dtype: target data type for kernel generation + + Returns: + Binomial kernel with shape (1, 1, kernel_size) + """ + kernel = np.poly1d([0.5, 0.5]) ** (kernel_size - 1) + return torch.tensor(kernel.c, dtype=dtype, device=device).view(1, 1, kernel_size) + + +def average_filter2d(kernel_size: int, device: Optional[str] = None, dtype: Optional[type] = None) -> torch.Tensor: + r"""Creates 2D normalized average filter + + Args: + kernel_size (int): kernel size + device: target device for kernel generation + dtype: target data type for kernel generation + + Returns: + kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + window = torch.ones(kernel_size, dtype=dtype, device=device) / kernel_size + kernel = window[:, None] * window[None, :] + return kernel.unsqueeze(0) diff --git a/libs/metric/piq/functional/layers.py b/libs/metric/piq/functional/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0701dbc5acfc47e97aba32c0e04aa80c5bc8fc --- /dev/null +++ b/libs/metric/piq/functional/layers.py @@ -0,0 +1,33 @@ +r"""Custom layers used in metrics computations""" +import torch +from typing import Optional + +from .filters import hann_filter + + +class L2Pool2d(torch.nn.Module): + r"""Applies L2 pooling with Hann window of size 3x3 + Args: + x: Tensor with shape (N, C, H, W)""" + EPS = 1e-12 + + def __init__(self, kernel_size: int = 3, stride: int = 2, padding=1) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.kernel: Optional[torch.Tensor] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.kernel is None: + C = x.size(1) + self.kernel = hann_filter(self.kernel_size).repeat((C, 1, 1, 1)).to(x) + + out = torch.nn.functional.conv2d( + x ** 2, self.kernel, + stride=self.stride, + padding=self.padding, + groups=x.shape[1] + ) + return (out + self.EPS).sqrt() diff --git a/libs/metric/piq/functional/resize.py b/libs/metric/piq/functional/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a39b45a10a71cc38ac07a89929cf2fad033239 --- /dev/null +++ b/libs/metric/piq/functional/resize.py @@ -0,0 +1,426 @@ +""" +A standalone PyTorch implementation for fast and efficient bicubic resampling. +The resulting values are the same to MATLAB function imresize('bicubic'). +## Author: Sanghyun Son +## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) +## Version: 1.2.0 +## Last update: July 9th, 2020 (KST) +Dependency: torch +Example:: +>>> import torch +>>> import core +>>> x = torch.arange(16).float().view(1, 1, 4, 4) +>>> y = core.imresize(x, sizes=(3, 3)) +>>> print(y) +tensor([[[[ 0.7506, 2.1004, 3.4503], + [ 6.1505, 7.5000, 8.8499], + [11.5497, 12.8996, 14.2494]]]]) +""" + +import math +import typing + +import torch +from torch.nn import functional as F + +__all__ = ['imresize'] + +_I = typing.Optional[int] +_D = typing.Optional[torch.dtype] + + +def nearest_contribution(x: torch.Tensor) -> torch.Tensor: + range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5)) + cont = range_around_0.to(dtype=x.dtype) + return cont + + +def linear_contribution(x: torch.Tensor) -> torch.Tensor: + ax = x.abs() + range_01 = ax.le(1) + cont = (1 - ax) * range_01.to(dtype=x.dtype) + return cont + + +def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor: + ax = x.abs() + ax2 = ax * ax + ax3 = ax * ax2 + + range_01 = ax.le(1) + range_12 = torch.logical_and(ax.gt(1), ax.le(2)) + + cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 + cont_01 = cont_01 * range_01.to(dtype=x.dtype) + + cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) + cont_12 = cont_12 * range_12.to(dtype=x.dtype) + + cont = cont_01 + cont_12 + return cont + + +def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: + range_3sigma = (x.abs() <= 3 * sigma + 1) + # Normalization will be done after + cont = torch.exp(-x.pow(2) / (2 * sigma ** 2)) + cont = cont * range_3sigma.to(dtype=x.dtype) + return cont + + +def discrete_kernel( + kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor: + ''' + For downsampling with integer scale only. + ''' + downsampling_factor = int(1 / scale) + if kernel == 'cubic': + kernel_size_orig = 4 + else: + raise ValueError('Pass!') + + if antialiasing: + kernel_size = kernel_size_orig * downsampling_factor + else: + kernel_size = kernel_size_orig + + if downsampling_factor % 2 == 0: + a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) + else: + kernel_size -= 1 + a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) + + with torch.no_grad(): + r = torch.linspace(-a, a, steps=kernel_size) + k = cubic_contribution(r).view(-1, 1) + k = torch.matmul(k, k.t()) + k /= k.sum() + + return k + + +def reflect_padding( + x: torch.Tensor, + dim: int, + pad_pre: int, + pad_post: int) -> torch.Tensor: + ''' + Apply reflect padding to the given Tensor. + Note that it is slightly different from the PyTorch functional.pad, + where boundary elements are used only once. + Instead, we follow the MATLAB implementation + which uses boundary elements twice. + For example, + [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, + while our implementation yields [a, a, b, c, d, d]. + ''' + b, c, h, w = x.size() + if dim == 2 or dim == -2: + padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) + padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) + for p in range(pad_pre): + padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) + for p in range(pad_post): + padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) + else: + padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) + padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) + for p in range(pad_pre): + padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) + for p in range(pad_post): + padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) + + return padding_buffer + + +def padding( + x: torch.Tensor, + dim: int, + pad_pre: int, + pad_post: int, + padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: + if padding_type is None: + return x + elif padding_type == 'reflect': + x_pad = reflect_padding(x, dim, pad_pre, pad_post) + else: + raise ValueError('{} padding is not supported!'.format(padding_type)) + + return x_pad + + +def get_padding( + base: torch.Tensor, + kernel_size: int, + x_size: int) -> typing.Tuple[int, int, torch.Tensor]: + base = base.long() + r_min = base.min() + r_max = base.max() + kernel_size - 1 + + if r_min <= 0: + pad_pre = -r_min + pad_pre = pad_pre.item() + base += pad_pre + else: + pad_pre = 0 + + if r_max >= x_size: + pad_post = r_max - x_size + 1 + pad_post = pad_post.item() + else: + pad_post = 0 + + return pad_pre, pad_post, base + + +def get_weight( + dist: torch.Tensor, + kernel_size: int, + kernel: str = 'cubic', + sigma: float = 2.0, + antialiasing_factor: float = 1) -> torch.Tensor: + buffer_pos = dist.new_zeros(kernel_size, len(dist)) + for idx, buffer_sub in enumerate(buffer_pos): + buffer_sub.copy_(dist - idx) + + # Expand (downsampling) / Shrink (upsampling) the receptive field. + buffer_pos *= antialiasing_factor + if kernel == 'cubic': + weight = cubic_contribution(buffer_pos) + elif kernel == 'gaussian': + weight = gaussian_contribution(buffer_pos, sigma=sigma) + else: + raise ValueError('{} kernel is not supported!'.format(kernel)) + + weight /= weight.sum(dim=0, keepdim=True) + return weight + + +def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: + # Resize height + if dim == 2 or dim == -2: + k = (kernel_size, 1) + h_out = x.size(-2) - kernel_size + 1 + w_out = x.size(-1) + # Resize width + else: + k = (1, kernel_size) + h_out = x.size(-2) + w_out = x.size(-1) - kernel_size + 1 + + unfold = F.unfold(x, k) + unfold = unfold.view(unfold.size(0), -1, h_out, w_out) + return unfold + + +def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]: + if x.dim() == 4: + b, c, h, w = x.size() + elif x.dim() == 3: + c, h, w = x.size() + b = None + elif x.dim() == 2: + h, w = x.size() + b = c = None + else: + raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) + + x = x.view(-1, 1, h, w) + return x, b, c, h, w + + +def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor: + rh = x.size(-2) + rw = x.size(-1) + # Back to the original dimension + if b is not None: + x = x.view(b, c, rh, rw) # 4-dim + else: + if c is not None: + x = x.view(c, rh, rw) # 3-dim + else: + x = x.view(rh, rw) # 2-dim + + return x + + +def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: + if x.dtype != torch.float32 or x.dtype != torch.float64: + dtype = x.dtype + x = x.float() + else: + dtype = None + + return x, dtype + + +def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor: + if dtype is not None: + if not dtype.is_floating_point: + x = x.round() + # To prevent over/underflow when converting types + if dtype is torch.uint8: + x = x.clamp(0, 255) + + x = x.to(dtype=dtype) + + return x + + +def resize_1d( + x: torch.Tensor, + dim: int, + size: int, + scale: float, + kernel: str = 'cubic', + sigma: float = 2.0, + padding_type: str = 'reflect', + antialiasing: bool = True) -> torch.Tensor: + ''' + Args: + x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). + dim (int): + scale (float): + size (int): + Return: + ''' + # Identity case + if scale == 1: + return x + + # Default bicubic kernel with antialiasing (only when downsampling) + if kernel == 'cubic': + kernel_size = 4 + else: + kernel_size = math.floor(6 * sigma) + + if antialiasing and (scale < 1): + antialiasing_factor = scale + kernel_size = math.ceil(kernel_size / antialiasing_factor) + else: + antialiasing_factor = 1 + + # We allow margin to both sizes + kernel_size += 2 + + # Weights only depend on the shape of input and output, + # so we do not calculate gradients here. + with torch.no_grad(): + pos = torch.linspace( + 0, size - 1, steps=size, dtype=x.dtype, device=x.device, + ) + pos = (pos + 0.5) / scale - 0.5 + base = pos.floor() - (kernel_size // 2) + 1 + dist = pos - base + weight = get_weight( + dist, + kernel_size, + kernel=kernel, + sigma=sigma, + antialiasing_factor=antialiasing_factor, + ) + pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) + + # To backpropagate through x + x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) + unfold = reshape_tensor(x_pad, dim, kernel_size) + # Subsampling first + if dim == 2 or dim == -2: + sample = unfold[..., base, :] + weight = weight.view(1, kernel_size, sample.size(2), 1) + else: + sample = unfold[..., base] + weight = weight.view(1, kernel_size, 1, sample.size(3)) + + # Apply the kernel + x = sample * weight + x = x.sum(dim=1, keepdim=True) + return x + + +def downsampling_2d( + x: torch.Tensor, + k: torch.Tensor, + scale: int, + padding_type: str = 'reflect') -> torch.Tensor: + c = x.size(1) + k_h = k.size(-2) + k_w = k.size(-1) + + k = k.to(dtype=x.dtype, device=x.device) + k = k.view(1, 1, k_h, k_w) + k = k.repeat(c, c, 1, 1) + e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) + e = e.view(c, c, 1, 1) + k = k * e + + pad_h = (k_h - scale) // 2 + pad_w = (k_w - scale) // 2 + x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) + x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) + y = F.conv2d(x, k, padding=0, stride=scale) + return y + + +def imresize( + x: torch.Tensor, + scale: typing.Optional[float] = None, + sizes: typing.Optional[typing.Tuple[int, int]] = None, + kernel: typing.Union[str, torch.Tensor] = 'cubic', + sigma: float = 2, + rotation_degree: float = 0, + padding_type: str = 'reflect', + antialiasing: bool = True) -> torch.Tensor: + """ + Args: + x (torch.Tensor): + scale (float): + sizes (tuple(int, int)): + kernel (str, default='cubic'): + sigma (float, default=2): + rotation_degree (float, default=0): + padding_type (str, default='reflect'): + antialiasing (bool, default=True): + Return: + torch.Tensor: + """ + if scale is None and sizes is None: + raise ValueError('One of scale or sizes must be specified!') + if scale is not None and sizes is not None: + raise ValueError('Please specify scale or sizes to avoid conflict!') + + x, b, c, h, w = reshape_input(x) + + if sizes is None and scale is not None: + ''' + # Check if we can apply the convolution algorithm + scale_inv = 1 / scale + if isinstance(kernel, str) and scale_inv.is_integer(): + kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) + elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): + raise ValueError( + 'An integer downsampling factor ' + 'should be used with a predefined kernel!' + ) + ''' + # Determine output size + sizes = (math.ceil(h * scale), math.ceil(w * scale)) + scales = (scale, scale) + + if scale is None and sizes is not None: + scales = (sizes[0] / h, sizes[1] / w) + + x, dtype = cast_input(x) + + if isinstance(kernel, str) and sizes is not None: + # Core resizing module + x = resize_1d(x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type, + antialiasing=antialiasing) + x = resize_1d(x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type, + antialiasing=antialiasing) + elif isinstance(kernel, torch.Tensor) and scale is not None: + x = downsampling_2d(x, kernel, scale=int(1 / scale)) + + x = reshape_output(x, b, c) + x = cast_output(x, dtype) + return x diff --git a/libs/metric/piq/perceptual.py b/libs/metric/piq/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..68a704d4f21fa569bcd6d1b4ce7862b780ba8d2a --- /dev/null +++ b/libs/metric/piq/perceptual.py @@ -0,0 +1,496 @@ +""" +Implementation of Content loss, Style loss, LPIPS and DISTS metrics +References: + .. [1] Gatys, Leon and Ecker, Alexander and Bethge, Matthias + (2016). A Neural Algorithm of Artistic Style} + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + .. [2] Zhang, Richard and Isola, Phillip and Efros, et al. + (2018) The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 +""" +from typing import List, Union, Collection + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss +from torchvision.models import vgg16, vgg19, VGG16_Weights, VGG19_Weights + +from .utils import _validate_input, _reduce +from .functional import similarity_map, L2Pool2d + +# Map VGG names to corresponding number in torchvision layer +VGG16_LAYERS = { + "conv1_1": '0', "relu1_1": '1', + "conv1_2": '2', "relu1_2": '3', + "pool1": '4', + "conv2_1": '5', "relu2_1": '6', + "conv2_2": '7', "relu2_2": '8', + "pool2": '9', + "conv3_1": '10', "relu3_1": '11', + "conv3_2": '12', "relu3_2": '13', + "conv3_3": '14', "relu3_3": '15', + "pool3": '16', + "conv4_1": '17', "relu4_1": '18', + "conv4_2": '19', "relu4_2": '20', + "conv4_3": '21', "relu4_3": '22', + "pool4": '23', + "conv5_1": '24', "relu5_1": '25', + "conv5_2": '26', "relu5_2": '27', + "conv5_3": '28', "relu5_3": '29', + "pool5": '30', +} + +VGG19_LAYERS = { + "conv1_1": '0', "relu1_1": '1', + "conv1_2": '2', "relu1_2": '3', + "pool1": '4', + "conv2_1": '5', "relu2_1": '6', + "conv2_2": '7', "relu2_2": '8', + "pool2": '9', + "conv3_1": '10', "relu3_1": '11', + "conv3_2": '12', "relu3_2": '13', + "conv3_3": '14', "relu3_3": '15', + "conv3_4": '16', "relu3_4": '17', + "pool3": '18', + "conv4_1": '19', "relu4_1": '20', + "conv4_2": '21', "relu4_2": '22', + "conv4_3": '23', "relu4_3": '24', + "conv4_4": '25', "relu4_4": '26', + "pool4": '27', + "conv5_1": '28', "relu5_1": '29', + "conv5_2": '30', "relu5_2": '31', + "conv5_3": '32', "relu5_3": '33', + "conv5_4": '34', "relu5_4": '35', + "pool5": '36', +} + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + +# Constant used in feature normalization to avoid zero division +EPS = 1e-10 + + +class ContentLoss(_Loss): + r"""Creates Content loss that can be used for image style transfer or as a measure for image to image tasks. + Uses pretrained VGG models from torchvision. + Expects input to be in range [0, 1] or normalized with ImageNet statistics into range [-1, 1] + + Args: + feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``. + layers: List of strings with layer names. Default: ``'relu3_3'`` + weights: List of float weight to balance different layers + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + normalize_features: If true, unit-normalize each feature in channel dimension before scaling + and computing distance. See references for details. + + Examples: + >>> loss = ContentLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + """ + + def __init__(self, feature_extractor: Union[str, torch.nn.Module] = "vgg16", layers: Collection[str] = ("relu3_3",), + weights: List[Union[float, torch.Tensor]] = [1.], replace_pooling: bool = False, + distance: str = "mse", reduction: str = "mean", mean: List[float] = IMAGENET_MEAN, + std: List[float] = IMAGENET_STD, normalize_features: bool = False, + allow_layers_weights_mismatch: bool = False) -> None: + + assert allow_layers_weights_mismatch or len(layers) == len(weights), \ + f'Lengths of provided layers and weighs mismatch ({len(weights)} weights and {len(layers)} layers), ' \ + f'which will cause incorrect results. Please provide weight for each layer.' + + super().__init__() + + if callable(feature_extractor): + self.model = feature_extractor + self.layers = layers + else: + if feature_extractor == "vgg16": + # self.model = vgg16(pretrained=True, progress=False).features + self.model = vgg16(weights=VGG16_Weights.DEFAULT, progress=False).features + self.layers = [VGG16_LAYERS[l] for l in layers] + elif feature_extractor == "vgg19": + # self.model = vgg19(pretrained=True, progress=False).features + self.model = vgg19(weights=VGG19_Weights.DEFAULT, progress=False).features + self.layers = [VGG19_LAYERS[l] for l in layers] + else: + raise ValueError("Unknown feature extractor") + + if replace_pooling: + self.model = self.replace_pooling(self.model) + + # Disable gradients + for param in self.model.parameters(): + param.requires_grad_(False) + + self.distance = { + "mse": nn.MSELoss, + "mae": nn.L1Loss, + }[distance](reduction='none') + + self.weights = [torch.tensor(w) if not isinstance(w, torch.Tensor) else w for w in weights] + + mean = torch.tensor(mean) + std = torch.tensor(std) + self.mean = mean.view(1, -1, 1, 1) + self.std = std.view(1, -1, 1, 1) + + self.normalize_features = normalize_features + self.reduction = reduction + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r"""Computation of Content loss between feature representations of prediction :math:`x` and + target :math:`y` tensors. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Content loss between feature representations + """ + _validate_input([x, y], dim_range=(4, 4), data_range=(0, -1)) + + self.model.to(x) + x_features = self.get_features(x) + y_features = self.get_features(y) + + distances = self.compute_distance(x_features, y_features) + + # Scale distances, then average in spatial dimensions, then stack and sum in channels dimension + loss = torch.cat([(d * w.to(d)).mean(dim=[2, 3]) for d, w in zip(distances, self.weights)], dim=1).sum(dim=1) + + return _reduce(loss, self.reduction) + + def compute_distance(self, x_features: List[torch.Tensor], y_features: List[torch.Tensor]) -> List[torch.Tensor]: + r"""Take L2 or L1 distance between feature maps depending on ``distance``. + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Distance between feature maps + """ + return [self.distance(x, y) for x, y in zip(x_features, y_features)] + + def get_features(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. + + Returns: + List of features extracted from intermediate layers + """ + # Normalize input + x = (x - self.mean.to(x)) / self.std.to(x) + + features = [] + for name, module in self.model._modules.items(): + x = module(x) + if name in self.layers: + features.append(self.normalize(x) if self.normalize_features else x) + + return features + + @staticmethod + def normalize(x: torch.Tensor) -> torch.Tensor: + r"""Normalize feature maps in channel direction to unit length. + + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Normalized input + """ + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + EPS) + + def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module: + r"""Turn All MaxPool layers into AveragePool + + Args: + module: Module to change MaxPool int AveragePool + + Returns: + Module with AveragePool instead MaxPool + + """ + module_output = module + if isinstance(module, torch.nn.MaxPool2d): + module_output = torch.nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + + for name, child in module.named_children(): + module_output.add_module(name, self.replace_pooling(child)) + return module_output + + +class StyleLoss(ContentLoss): + r"""Creates Style loss that can be used for image style transfer or as a measure in + image to image tasks. Computes distance between Gram matrices of feature maps. + Uses pretrained VGG models from torchvision. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + feature_extractor: Model to extract features or model name: ``'vgg16'`` | ``'vgg19'``. + layers: List of strings with layer names. Default: ``'relu3_3'`` + weights: List of float weight to balance different layers + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + normalize_features: If true, unit-normalize each feature in channel dimension before scaling + and computing distance. See references for details. + + Examples: + >>> loss = StyleLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + """ + + def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor): + r"""Take L2 or L1 distance between Gram matrices of feature maps depending on ``distance``. + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Distance between Gram matrices + """ + x_gram = [self.gram_matrix(x) for x in x_features] + y_gram = [self.gram_matrix(x) for x in y_features] + return [self.distance(x, y) for x, y in zip(x_gram, y_gram)] + + @staticmethod + def gram_matrix(x: torch.Tensor) -> torch.Tensor: + r"""Compute Gram matrix for batch of features. + + Args: + x: Tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Gram matrix for given input + """ + B, C, H, W = x.size() + gram = [] + for i in range(B): + features = x[i].view(C, H * W) + + # Add fake channel dimension + gram.append(torch.mm(features, features.t()).unsqueeze(0)) + + return torch.stack(gram) + + +class LPIPS(ContentLoss): + r"""Learned Perceptual Image Patch Similarity metric. Only VGG16 learned weights are supported. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + replace_pooling: Flag to replace MaxPooling layer with AveragePooling. See references for details. + distance: Method to compute distance between features: ``'mse'`` | ``'mae'``. + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + + Examples: + >>> loss = LPIPS() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Gatys, Leon and Ecker, Alexander and Bethge, Matthias (2016). + A Neural Algorithm of Artistic Style + Association for Research in Vision and Ophthalmology (ARVO) + https://arxiv.org/abs/1508.06576 + + Zhang, Richard and Isola, Phillip and Efros, et al. (2018) + The Unreasonable Effectiveness of Deep Features as a Perceptual Metric + IEEE/CVF Conference on Computer Vision and Pattern Recognition + https://arxiv.org/abs/1801.03924 + https://github.com/richzhang/PerceptualSimilarity + """ + _weights_url = "https://github.com/photosynthesis-team/" + \ + "photosynthesis.metrics/releases/download/v0.4.0/lpips_weights.pt" + + def __init__(self, replace_pooling: bool = False, distance: str = "mse", reduction: str = "mean", + mean: List[float] = IMAGENET_MEAN, std: List[float] = IMAGENET_STD, ) -> None: + lpips_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] + lpips_weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False) + super().__init__("vgg16", layers=lpips_layers, weights=lpips_weights, + replace_pooling=replace_pooling, distance=distance, + reduction=reduction, mean=mean, std=std, + normalize_features=True) + + +class DISTS(ContentLoss): + r"""Deep Image Structure and Texture Similarity metric. + + By default expects input to be in range [0, 1], which is then normalized by ImageNet statistics into range [-1, 1]. + If no normalisation is required, change `mean` and `std` values accordingly. + + Args: + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + mean: List of float values used for data standardization. Default: ImageNet mean. + If there is no need to normalize data, use [0., 0., 0.]. + std: List of float values used for data standardization. Default: ImageNet std. + If there is no need to normalize data, use [1., 1., 1.]. + + Examples: + >>> loss = DISTS() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Keyan Ding, Kede Ma, Shiqi Wang, Eero P. Simoncelli (2020). + Image Quality Assessment: Unifying Structure and Texture Similarity. + https://arxiv.org/abs/2004.07728 + https://github.com/dingkeyan93/DISTS + """ + _weights_url = "https://github.com/photosynthesis-team/piq/releases/download/v0.4.1/dists_weights.pt" + + def __init__(self, reduction: str = "mean", mean: List[float] = IMAGENET_MEAN, + std: List[float] = IMAGENET_STD) -> None: + dists_layers = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'] + channels = [3, 64, 128, 256, 512, 512] + + weights = torch.hub.load_state_dict_from_url(self._weights_url, progress=False) + dists_weights = list(torch.split(weights['alpha'], channels, dim=1)) + dists_weights.extend(torch.split(weights['beta'], channels, dim=1)) + + super().__init__("vgg16", layers=dists_layers, weights=dists_weights, + replace_pooling=True, reduction=reduction, mean=mean, std=std, + normalize_features=False, allow_layers_weights_mismatch=True) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r""" + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + + Returns: + Deep Image Structure and Texture Similarity loss, i.e. ``1-DISTS`` in range [0, 1]. + """ + _, _, H, W = x.shape + + if min(H, W) > 256: + x = torch.nn.functional.interpolate( + x, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear') + y = torch.nn.functional.interpolate( + y, scale_factor=256 / min(H, W), recompute_scale_factor=False, mode='bilinear') + + loss = super().forward(x, y) + return 1 - loss + + def compute_distance(self, x_features: torch.Tensor, y_features: torch.Tensor) -> List[torch.Tensor]: + r"""Compute structure similarity between feature maps + + Args: + x_features: Features of the input tensor. + y_features: Features of the target tensor. + + Returns: + Structural similarity distance between feature maps + """ + structure_distance, texture_distance = [], [] + # Small constant for numerical stability + EPS = 1e-6 + + for x, y in zip(x_features, y_features): + x_mean = x.mean([2, 3], keepdim=True) + y_mean = y.mean([2, 3], keepdim=True) + structure_distance.append(similarity_map(x_mean, y_mean, constant=EPS)) + + x_var = ((x - x_mean) ** 2).mean([2, 3], keepdim=True) + y_var = ((y - y_mean) ** 2).mean([2, 3], keepdim=True) + xy_cov = (x * y).mean([2, 3], keepdim=True) - x_mean * y_mean + texture_distance.append((2 * xy_cov + EPS) / (x_var + y_var + EPS)) + + return structure_distance + texture_distance + + def get_features(self, x: torch.Tensor) -> List[torch.Tensor]: + r""" + + Args: + x: Input tensor + + Returns: + List of features extracted from input tensor + """ + features = super().get_features(x) + + # Add input tensor as an additional feature + features.insert(0, x) + return features + + def replace_pooling(self, module: torch.nn.Module) -> torch.nn.Module: + r"""Turn All MaxPool layers into L2Pool + + Args: + module: Module to change MaxPool into L2Pool + + Returns: + Module with L2Pool instead of MaxPool + """ + module_output = module + if isinstance(module, torch.nn.MaxPool2d): + module_output = L2Pool2d(kernel_size=3, stride=2, padding=1) + + for name, child in module.named_children(): + module_output.add_module(name, self.replace_pooling(child)) + + return module_output diff --git a/libs/metric/piq/utils/__init__.py b/libs/metric/piq/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab6241444c024e5daa7b90190a45f481a66b69b --- /dev/null +++ b/libs/metric/piq/utils/__init__.py @@ -0,0 +1,7 @@ +from .common import _validate_input, _reduce, _parse_version + +__all__ = [ + "_validate_input", + "_reduce", + '_parse_version' +] diff --git a/libs/metric/piq/utils/common.py b/libs/metric/piq/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..1ceb336a52669616ae5609941d90c916997a53eb --- /dev/null +++ b/libs/metric/piq/utils/common.py @@ -0,0 +1,158 @@ +import torch +import re +import warnings + +from typing import Tuple, List, Optional, Union, Dict, Any + +SEMVER_VERSION_PATTERN = re.compile( + r""" + ^ + (?P0|[1-9]\d*) + \. + (?P0|[1-9]\d*) + \. + (?P0|[1-9]\d*) + (?:-(?P + (?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*) + (?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))* + ))? + (?:\+(?P + [0-9a-zA-Z-]+ + (?:\.[0-9a-zA-Z-]+)* + ))? + $ + """, + re.VERBOSE, +) + + +PEP_440_VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
                                          # pre-release
+            [-_\.]?
+            (?P(a|b|c|rc|alpha|beta|pre|preview))
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+        (?P                                         # post release
+            (?:-(?P[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?Ppost|rev|r)
+                [-_\.]?
+                (?P[0-9]+)?
+            )
+        )?
+        (?P                                          # dev release
+            [-_\.]?
+            (?Pdev)
+            [-_\.]?
+            (?P[0-9]+)?
+        )?
+    )
+    (?:\+(?P[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+
+def _validate_input(
+        tensors: List[torch.Tensor],
+        dim_range: Tuple[int, int] = (0, -1),
+        data_range: Tuple[float, float] = (0., -1.),
+        # size_dim_range: Tuple[float, float] = (0., -1.),
+        size_range: Optional[Tuple[int, int]] = None,
+) -> None:
+    r"""Check that input(-s)  satisfies the requirements
+    Args:
+        tensors: Tensors to check
+        dim_range: Allowed number of dimensions. (min, max)
+        data_range: Allowed range of values in tensors. (min, max)
+        size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
+    """
+
+    if not __debug__:
+        return
+
+    x = tensors[0]
+
+    for t in tensors:
+        assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
+        assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
+
+        if size_range is None:
+            assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
+        else:
+            assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
+                f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
+
+        if dim_range[0] == dim_range[1]:
+            assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
+        elif dim_range[0] < dim_range[1]:
+            assert dim_range[0] <= t.dim() <= dim_range[1], \
+                f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
+
+        if data_range[0] < data_range[1]:
+            assert data_range[0] <= t.min(), \
+                f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
+            assert t.max() <= data_range[1], \
+                f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
+
+
+def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+    r"""Reduce input in batch dimension if needed.
+
+    Args:
+        x: Tensor with shape (N, *).
+        reduction: Specifies the reduction type:
+            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
+    """
+    if reduction == 'none':
+        return x
+    elif reduction == 'mean':
+        return x.mean(dim=0)
+    elif reduction == 'sum':
+        return x.sum(dim=0)
+    else:
+        raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
+
+
+def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
+    """ Parses valid Python versions according to Semver and PEP 440 specifications.
+    For more on Semver check: https://semver.org/
+    For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
+
+    Implementation is inspired by:
+    - https://github.com/python-semver
+    - https://github.com/pypa/packaging
+
+    Args:
+        version: unparsed information about the library of interest.
+
+    Returns:
+        parsed information about the library of interest.
+    """
+    if isinstance(version, bytes):
+        version = version.decode("UTF-8")
+    elif not isinstance(version, str) and not isinstance(version, bytes):
+        raise TypeError(f"not expecting type {type(version)}")
+
+    # Semver processing
+    match = SEMVER_VERSION_PATTERN.match(version)
+    if match:
+        matched_version_parts: Dict[str, Any] = match.groupdict()
+        release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
+        return release
+
+    # PEP 440 processing
+    regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    match = regex.search(version)
+
+    if match is None:
+        warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
+        return tuple()
+
+    release = tuple(int(i) for i in match.group("release").split("."))
+    return release
diff --git a/libs/metric/pytorch_fid/__init__.py b/libs/metric/pytorch_fid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..782e20db5f2769a78f8783031cbcd327437ece9a
--- /dev/null
+++ b/libs/metric/pytorch_fid/__init__.py
@@ -0,0 +1,54 @@
+__version__ = '0.3.0'
+
+import torch
+from einops import rearrange, repeat
+
+from .inception import InceptionV3
+from .fid_score import calculate_frechet_distance
+
+
+class PytorchFIDFactory(torch.nn.Module):
+    """
+
+   Args:
+       channels:
+       inception_block_idx:
+
+    Examples:
+    >>> fid_factory =  PytorchFIDFactory()
+    >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
+    >>> print(fid_score)
+   """
+
+    def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
+        super().__init__()
+        self.channels = channels
+
+        # load models
+        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
+        self.inception_v3 = InceptionV3([block_idx])
+
+    @torch.no_grad()
+    def calculate_activation_statistics(self, samples):
+        features = self.inception_v3(samples)[0]
+        features = rearrange(features, '... 1 1 -> ...')
+
+        mu = torch.mean(features, dim=0).cpu()
+        sigma = torch.cov(features).cpu()
+        return mu, sigma
+
+    def score(self, real_samples, fake_samples):
+        if self.channels == 1:
+            real_samples, fake_samples = map(
+                lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
+            )
+
+        min_batch = min(real_samples.shape[0], fake_samples.shape[0])
+        real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
+
+        m1, s1 = self.calculate_activation_statistics(real_samples)
+        m2, s2 = self.calculate_activation_statistics(fake_samples)
+
+        fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+        return fid_value
diff --git a/libs/metric/pytorch_fid/fid_score.py b/libs/metric/pytorch_fid/fid_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..117e0c77d25afded5e63429bb0a27a71967530f5
--- /dev/null
+++ b/libs/metric/pytorch_fid/fid_score.py
@@ -0,0 +1,322 @@
+"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
+
+The FID metric calculates the distance between two distributions of images.
+Typically, we have summary statistics (mean & covariance matrix) of one
+of these distributions, while the 2nd distribution is given by a GAN.
+
+When run as a stand-alone program, it compares the distribution of
+images that are stored as PNG/JPEG at a specified location with a
+distribution given by summary statistics (in pickle format).
+
+The FID is calculated by assuming that X_1 and X_2 are the activations of
+the pool_3 layer of the inception net for generated samples and real world
+samples respectively.
+
+See --help to see further details.
+
+Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import pathlib
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+
+import numpy as np
+import torch
+import torchvision.transforms as TF
+from PIL import Image
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    # If tqdm is not available, provide a mock version of it
+    def tqdm(x):
+        return x
+
+from .inception import InceptionV3
+
+parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+parser.add_argument('--batch-size', type=int, default=50,
+                    help='Batch size to use')
+parser.add_argument('--num-workers', type=int,
+                    help=('Number of processes to use for data loading. '
+                          'Defaults to `min(8, num_cpus)`'))
+parser.add_argument('--device', type=str, default=None,
+                    help='Device to use. Like cuda, cuda:0 or cpu')
+parser.add_argument('--dims', type=int, default=2048,
+                    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+                    help=('Dimensionality of Inception features to use. '
+                          'By default, uses pool3 features'))
+parser.add_argument('--save-stats', action='store_true',
+                    help=('Generate an npz archive from a directory of samples. '
+                          'The first path is used as input and the second as output.'))
+parser.add_argument('path', type=str, nargs=2,
+                    help=('Paths to the generated images or '
+                          'to .npz statistic files'))
+
+IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
+                    'tif', 'tiff', 'webp'}
+
+
+class ImagePathDataset(torch.utils.data.Dataset):
+    def __init__(self, files, transforms=None):
+        self.files = files
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, i):
+        path = self.files[i]
+        img = Image.open(path).convert('RGB')
+        if self.transforms is not None:
+            img = self.transforms(img)
+        return img
+
+
+def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
+                    num_workers=1):
+    """Calculates the activations of the pool_3 layer for all images.
+
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : Batch size of images for the model to process at once.
+                     Make sure that the number of samples is a multiple of
+                     the batch size, otherwise some samples are ignored. This
+                     behavior is retained to match the original FID score
+                     implementation.
+    -- dims        : Dimensionality of features returned by Inception
+    -- device      : Device to run calculations
+    -- num_workers : Number of parallel dataloader workers
+
+    Returns:
+    -- A numpy array of dimension (num images, dims) that contains the
+       activations of the given tensor when feeding inception with the
+       query tensor.
+    """
+    model.eval()
+
+    if batch_size > len(files):
+        print(('Warning: batch size is bigger than the data size. '
+               'Setting batch size to data size'))
+        batch_size = len(files)
+
+    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             shuffle=False,
+                                             drop_last=False,
+                                             num_workers=num_workers)
+
+    pred_arr = np.empty((len(files), dims))
+
+    start_idx = 0
+
+    for batch in tqdm(dataloader):
+        batch = batch.to(device)
+
+        with torch.no_grad():
+            pred = model(batch)[0]
+
+        # If model output is not scalar, apply global spatial average pooling.
+        # This happens if you choose a dimensionality not equal 2048.
+        if pred.size(2) != 1 or pred.size(3) != 1:
+            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
+
+        start_idx = start_idx + pred.shape[0]
+
+    return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on an
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on an
+               representative data set.
+
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert mu1.shape == mu2.shape, \
+        'Training and test mean vectors have different lengths'
+    assert sigma1.shape == sigma2.shape, \
+        'Training and test covariances have different dimensions'
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = ('fid calculation produces singular product; '
+               'adding %s to diagonal of cov estimates') % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError('Imaginary component {}'.format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return (diff.dot(diff) + np.trace(sigma1)
+            + np.trace(sigma2) - 2 * tr_covmean)
+
+
+def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
+                                    device='cpu', num_workers=1):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : The images numpy array is split into batches with
+                     batch size batch_size. A reasonable batch size
+                     depends on the hardware.
+    -- dims        : Dimensionality of features returned by Inception
+    -- device      : Device to run calculations
+    -- num_workers : Number of parallel dataloader workers
+
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
+    """
+    act = get_activations(files, model, batch_size, dims, device, num_workers)
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+
+def compute_statistics_of_path(path, model, batch_size, dims, device,
+                               num_workers=1):
+    if path.endswith('.npz'):
+        with np.load(path) as f:
+            m, s = f['mu'][:], f['sigma'][:]
+    else:
+        path = pathlib.Path(path)
+        files = sorted([file for ext in IMAGE_EXTENSIONS
+                        for file in path.glob('*.{}'.format(ext))])
+        m, s = calculate_activation_statistics(files, model, batch_size,
+                                               dims, device, num_workers)
+
+    return m, s
+
+
+def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1):
+    """Calculates the FID of two paths"""
+    for p in paths:
+        if not os.path.exists(p):
+            raise RuntimeError('Invalid path: %s' % p)
+
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+    model = InceptionV3([block_idx]).to(device)
+
+    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
+                                        dims, device, num_workers)
+    m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
+                                        dims, device, num_workers)
+    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+
+    return fid_value
+
+
+def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
+    """Calculates the FID of two paths"""
+    if not os.path.exists(paths[0]):
+        raise RuntimeError('Invalid path: %s' % paths[0])
+
+    if os.path.exists(paths[1]):
+        raise RuntimeError('Existing output file: %s' % paths[1])
+
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+    model = InceptionV3([block_idx]).to(device)
+
+    print(f"Saving statistics for {paths[0]}")
+
+    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
+                                        dims, device, num_workers)
+
+    np.savez_compressed(paths[1], mu=m1, sigma=s1)
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.device is None:
+        device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
+    else:
+        device = torch.device(args.device)
+
+    if args.num_workers is None:
+        try:
+            num_cpus = len(os.sched_getaffinity(0))
+        except AttributeError:
+            # os.sched_getaffinity is not available under Windows, use
+            # os.cpu_count instead (which may not return the *available* number
+            # of CPUs).
+            num_cpus = os.cpu_count()
+
+        num_workers = min(num_cpus, 8) if num_cpus is not None else 0
+    else:
+        num_workers = args.num_workers
+
+    if args.save_stats:
+        save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers)
+        return
+
+    fid_value = calculate_fid_given_paths(args.path,
+                                          args.batch_size,
+                                          device,
+                                          args.dims,
+                                          num_workers)
+    print('FID: ', fid_value)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/libs/metric/pytorch_fid/inception.py b/libs/metric/pytorch_fid/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..8898a20c0609f5bb31df3641e783ea95db45b95f
--- /dev/null
+++ b/libs/metric/pytorch_fid/inception.py
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,   # First max pooling features
+        192: 1,  # Second max pooling featurs
+        768: 2,  # Pre-aux classifier features
+        2048: 3  # Final average pooling features
+    }
+
+    def __init__(self,
+                 output_blocks=(DEFAULT_BLOCK_INDEX,),
+                 resize_input=True,
+                 normalize_input=True,
+                 requires_grad=False,
+                 use_fid_inception=True):
+        """Build pretrained InceptionV3
+
+        Parameters
+        ----------
+        output_blocks : list of int
+            Indices of blocks to return features of. Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+        resize_input : bool
+            If true, bilinearly resizes input to width and height 299 before
+            feeding input to model. As the network without fully connected
+            layers is fully convolutional, it should be able to handle inputs
+            of arbitrary size, so resizing might not be strictly needed
+        normalize_input : bool
+            If true, scales the input from range (0, 1) to the range the
+            pretrained Inception network expects, namely (-1, 1)
+        requires_grad : bool
+            If true, parameters of the model require gradients. Possibly useful
+            for finetuning the network
+        use_fid_inception : bool
+            If true, uses the pretrained Inception model used in Tensorflow's
+            FID implementation. If false, uses the pretrained Inception model
+            available in torchvision. The FID Inception model has different
+            weights and a slightly different structure from torchvision's
+            Inception model. If you want to compute FID scores, you are
+            strongly advised to set this parameter to true to get comparable
+            results.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        assert self.last_needed_block <= 3, \
+            'Last possible output block index is 3'
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            inception = _inception_v3(weights='DEFAULT')
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3,
+            inception.Conv2d_2a_3x3,
+            inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2)
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [
+                inception.Conv2d_3b_1x1,
+                inception.Conv2d_4a_3x3,
+                nn.MaxPool2d(kernel_size=3, stride=2)
+            ]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a,
+                inception.Mixed_7b,
+                inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, inp):
+        """Get Inception feature maps
+
+        Parameters
+        ----------
+        inp : torch.autograd.Variable
+            Input tensor of shape Bx3xHxW. Values are expected to be in
+            range (0, 1)
+
+        Returns
+        -------
+        List of torch.autograd.Variable, corresponding to the selected output
+        block, sorted ascending by index
+        """
+        outp = []
+        x = inp
+
+        if self.resize_input:
+            x = F.interpolate(x,
+                              size=(299, 299),
+                              mode='bilinear',
+                              align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                outp.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        return outp
+
+
+def _inception_v3(*args, **kwargs):
+    """Wraps `torchvision.models.inception_v3`"""
+    try:
+        version = tuple(map(int, torchvision.__version__.split('.')[:2]))
+    except ValueError:
+        # Just a caution against weird version strings
+        version = (0,)
+
+    # Skips default weight inititialization if supported by torchvision
+    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
+    if version >= (0, 6):
+        kwargs['init_weights'] = False
+
+    # Backwards compatibility: `weights` argument was handled by `pretrained`
+    # argument prior to version 0.13.
+    if version < (0, 13) and 'weights' in kwargs:
+        if kwargs['weights'] == 'DEFAULT':
+            kwargs['pretrained'] = True
+        elif kwargs['weights'] is None:
+            kwargs['pretrained'] = False
+        else:
+            raise ValueError(
+                'weights=={} not supported in torchvision {}'.format(
+                    kwargs['weights'], torchvision.__version__
+                )
+            )
+        del kwargs['weights']
+
+    return torchvision.models.inception_v3(*args, **kwargs)
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    inception = _inception_v3(num_classes=1008,
+                              aux_logits=False,
+                              weights=None)
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(torchvision.models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(torchvision.models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(torchvision.models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zero's in
+        # its average calculation
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=False)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(torchvision.models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
diff --git a/libs/modules/__init__.py b/libs/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/libs/modules/__init__.py
@@ -0,0 +1 @@
+
diff --git a/libs/modules/edge_map/DoG/XDoG.py b/libs/modules/edge_map/DoG/XDoG.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e577bfa632390a54c149f20729e122eed8fe5ca
--- /dev/null
+++ b/libs/modules/edge_map/DoG/XDoG.py
@@ -0,0 +1,73 @@
+import numpy as np
+import cv2
+from scipy import ndimage as ndi
+from skimage import filters
+
+
+class XDoG:
+
+    def __init__(self,
+                 gamma=0.98,
+                 phi=200,
+                 eps=-0.1,
+                 sigma=0.8,
+                 k=10,
+                 binarize: bool = True):
+        """
+        XDoG algorithm.
+
+        Args:
+            gamma: Control the size of the Gaussian filter
+            phi: Control changes in edge strength
+            eps: Threshold for controlling edge strength
+            sigma: The standard deviation of the Gaussian filter controls the degree of smoothness
+            k: Control the size ratio of Gaussian filter, (k=10 or k=1.6)
+            binarize(bool): Whether to binarize the output
+        """
+
+        super(XDoG, self).__init__()
+
+        self.gamma = gamma
+        assert 0 <= self.gamma <= 1
+
+        self.phi = phi
+        assert 0 <= self.phi <= 1500
+
+        self.eps = eps
+        assert -1 <= self.eps <= 1
+
+        self.sigma = sigma
+        assert 0.1 <= self.sigma <= 10
+
+        self.k = k
+        assert 1 <= self.k <= 100
+
+        self.binarize = binarize
+
+    def __call__(self, img):
+        # to gray if image is not already grayscale
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        elif len(img.shape) == 3 and img.shape[2] == 4:
+            img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
+
+        if np.isnan(img).any():
+            img[np.isnan(img)] = np.mean(img[~np.isnan(img)])
+
+        # gaussian filter
+        imf1 = ndi.gaussian_filter(img, self.sigma)
+        imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
+        imdiff = imf1 - self.gamma * imf2
+
+        # XDoG
+        imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
+
+        # normalize
+        imdiff -= imdiff.min()
+        imdiff /= imdiff.max()
+
+        if self.binarize:
+            th = filters.threshold_otsu(imdiff)
+            imdiff = (imdiff >= th).astype('float32')
+
+        return imdiff
diff --git a/libs/modules/edge_map/DoG/__init__.py b/libs/modules/edge_map/DoG/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e878f55fff143f505d2fc04eabac259ef75dc4a
--- /dev/null
+++ b/libs/modules/edge_map/DoG/__init__.py
@@ -0,0 +1,3 @@
+from .XDoG import XDoG
+
+__all__ = ['XDoG']
diff --git a/libs/modules/edge_map/__init__.py b/libs/modules/edge_map/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/libs/modules/edge_map/__init__.py
@@ -0,0 +1 @@
+
diff --git a/libs/modules/edge_map/canny/__init__.py b/libs/modules/edge_map/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..15dbe167ee42376c8082babaa22f908cdf62eaaa
--- /dev/null
+++ b/libs/modules/edge_map/canny/__init__.py
@@ -0,0 +1,10 @@
+import cv2
+
+
+class CannyDetector:
+
+    def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
+        return cv2.Canny(img, low_threshold, high_threshold, L2gradient)
+
+
+__all__ = ['CannyDetector']
diff --git a/libs/modules/edge_map/image_grads/__init__.py b/libs/modules/edge_map/image_grads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e22e312a3eb4ae30cce5c312dcbddb15b8d1a2d
--- /dev/null
+++ b/libs/modules/edge_map/image_grads/__init__.py
@@ -0,0 +1,3 @@
+from .laplacian import LaplacianDetector
+
+__all__ = ['LaplacianDetector']
diff --git a/libs/modules/edge_map/image_grads/laplacian.py b/libs/modules/edge_map/image_grads/laplacian.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9c8465344039aa688d01f49a088a7018cf7c6c
--- /dev/null
+++ b/libs/modules/edge_map/image_grads/laplacian.py
@@ -0,0 +1,7 @@
+import cv2
+
+
+class LaplacianDetector:
+
+    def __call__(self, img):
+        return cv2.Laplacian(img, cv2.CV_64F)
diff --git a/libs/modules/ema.py b/libs/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae0aa5b11ef33fe5beba415f75294f7addcb3874
--- /dev/null
+++ b/libs/modules/ema.py
@@ -0,0 +1,194 @@
+import copy
+
+import torch
+import torch.nn as nn
+
+__all__ = ['EMA']
+
+
+class EMA(nn.Module):
+    """
+    Implements exponential moving average shadowing for your model.
+    Utilizes an inverse decay schedule to manage longer term training runs.
+    By adjusting the power, you can control how fast EMA will ramp up to your specified beta.
+    @crowsonkb's notes on EMA Warmup:
+    If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
+    good values for models you plan to train for a million or more steps (reaches decay
+    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models
+    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
+    215.4k steps).
+    Args:
+        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+        power (float): Exponential factor of EMA warmup. Default: 1.
+        min_value (float): The minimum EMA decay rate. Default: 0.
+    """
+
+    def __init__(
+            self,
+            model,
+            # if your model has lazylinears or other types of non-deepcopyable modules,
+            # you can pass in your own ema model
+            ema_model=None,
+            beta=0.9999,
+            update_after_step=100,
+            update_every=10,
+            inv_gamma=1.0,
+            power=2 / 3,
+            min_value=0.0,
+            param_or_buffer_names_no_ema=set(),
+            ignore_names=set(),
+            ignore_startswith_names=set(),
+            # set this to False if you do not wish for the online model to be
+            # saved along with the ema model (managed externally)
+            include_online_model=True
+    ):
+        super().__init__()
+        self.beta = beta
+
+        # whether to include the online model within the module tree, so that state_dict also saves it
+        self.include_online_model = include_online_model
+
+        if include_online_model:
+            self.online_model = model
+        else:
+            self.online_model = [model]  # hack
+
+        # ema model
+        self.ema_model = ema_model
+
+        if not exists(self.ema_model):
+            try:
+                self.ema_model = copy.deepcopy(model)
+            except:
+                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
+                exit()
+
+        self.ema_model.requires_grad_(False)
+
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float}
+
+        self.update_every = update_every
+        self.update_after_step = update_after_step
+
+        self.inv_gamma = inv_gamma
+        self.power = power
+        self.min_value = min_value
+
+        assert isinstance(param_or_buffer_names_no_ema, (set, list))
+        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema  # parameter or buffer
+
+        self.ignore_names = ignore_names
+        self.ignore_startswith_names = ignore_startswith_names
+
+        self.register_buffer('initted', torch.Tensor([False]))
+        self.register_buffer('step', torch.tensor([0]))
+
+    @property
+    def model(self):
+        return self.online_model if self.include_online_model else self.online_model[0]
+
+    def restore_ema_model_device(self):
+        device = self.initted.device
+        self.ema_model.to(device)
+
+    def get_params_iter(self, model):
+        for name, param in model.named_parameters():
+            if name not in self.parameter_names:
+                continue
+            yield name, param
+
+    def get_buffers_iter(self, model):
+        for name, buffer in model.named_buffers():
+            if name not in self.buffer_names:
+                continue
+            yield name, buffer
+
+    def copy_params_from_model_to_ema(self):
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
+                                                       self.get_params_iter(self.model)):
+            ma_params.data.copy_(current_params.data)
+
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
+                                                         self.get_buffers_iter(self.model)):
+            ma_buffers.data.copy_(current_buffers.data)
+
+    def get_current_decay(self):
+        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)
+        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power
+
+        if epoch <= 0:
+            return 0.
+
+        return clamp(value, min_value=self.min_value, max_value=self.beta)
+
+    def update(self):
+        step = self.step.item()
+        self.step += 1
+
+        if (step % self.update_every) != 0:
+            return
+
+        if step <= self.update_after_step:
+            self.copy_params_from_model_to_ema()
+            return
+
+        if not self.initted.item():
+            self.copy_params_from_model_to_ema()
+            self.initted.data.copy_(torch.Tensor([True]))
+
+        self.update_moving_average(self.ema_model, self.model)
+
+    @torch.no_grad()
+    def update_moving_average(self, ma_model, current_model):
+        current_decay = self.get_current_decay()
+
+        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
+                                                          self.get_params_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                ma_params.data.copy_(current_params.data)
+                continue
+
+            ma_params.data.lerp_(current_params.data, 1. - current_decay)
+
+        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
+                                                          self.get_buffers_iter(ma_model)):
+            if name in self.ignore_names:
+                continue
+
+            if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]):
+                continue
+
+            if name in self.param_or_buffer_names_no_ema:
+                ma_buffer.data.copy_(current_buffer.data)
+                continue
+
+            ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
+
+    def __call__(self, *args, **kwargs):
+        return self.ema_model(*args, **kwargs)
+
+
+def exists(val):
+    return val is not None
+
+
+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
+
+def clamp(value, min_value=None, max_value=None):
+    assert exists(min_value) or exists(max_value)
+    if exists(min_value):
+        value = max(value, min_value)
+
+    if exists(max_value):
+        value = min(value, max_value)
+
+    return value
diff --git a/libs/modules/vision/__init__.py b/libs/modules/vision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a789bf479514d831b590f37e85c2af44a2605017
--- /dev/null
+++ b/libs/modules/vision/__init__.py
@@ -0,0 +1,7 @@
+from .inception import inception_v3
+from .vgg import VGG
+
+__all__ = [
+    'inception_v3',
+    'VGG'
+]
diff --git a/libs/modules/vision/inception.py b/libs/modules/vision/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce032a3985b06b4854523665e9e48099b7a87b2
--- /dev/null
+++ b/libs/modules/vision/inception.py
@@ -0,0 +1,477 @@
+#from collections import namedtuple
+import warnings
+from typing import Callable, Any, Optional, Tuple, List
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
+
+model_urls = {
+    # Inception v3 ported from TensorFlow
+    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
+}
+
+InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
+InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
+
+# Script annotations failed with _GoogleNetOutputs = namedtuple ...
+# _InceptionOutputs set here for backwards compat
+_InceptionOutputs = InceptionOutputs
+
+
+def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
+    r"""Inception v3 model architecture from
+    `"Rethinking the Inception Architecture for Computer Vision" `_.
+
+    .. note::
+        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
+        N x 3 x 299 x 299, so ensure your images are sized accordingly.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        aux_logits (bool): If True, add an auxiliary branch that can improve training.
+            Default: *True*
+        transform_input (bool): If True, preprocesses the input according to the method with which it
+            was trained on ImageNet. Default: *False*
+    """
+    if pretrained:
+        if 'transform_input' not in kwargs:
+            kwargs['transform_input'] = True
+        if 'aux_logits' in kwargs:
+            original_aux_logits = kwargs['aux_logits']
+            kwargs['aux_logits'] = True
+        else:
+            original_aux_logits = True
+        kwargs['init_weights'] = False  # we are loading weights from a pretrained model
+        model = Inception3(**kwargs)
+        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.AuxLogits = None
+        return model
+
+    return Inception3(**kwargs)
+
+
+class Inception3(nn.Module):
+
+    def __init__(
+            self,
+            num_classes: int = 1000,
+            aux_logits: bool = True,
+            transform_input: bool = False,
+            inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
+            init_weights: Optional[bool] = None
+    ) -> None:
+        super(Inception3, self).__init__()
+        if inception_blocks is None:
+            inception_blocks = [
+                BasicConv2d, InceptionA, InceptionB, InceptionC,
+                InceptionD, InceptionE, InceptionAux
+            ]
+        if init_weights is None:
+            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
+                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
+                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
+            init_weights = True
+        assert len(inception_blocks) == 7
+        conv_block = inception_blocks[0]
+        inception_a = inception_blocks[1]
+        inception_b = inception_blocks[2]
+        inception_c = inception_blocks[3]
+        inception_d = inception_blocks[4]
+        inception_e = inception_blocks[5]
+        inception_aux = inception_blocks[6]
+
+        self.aux_logits = aux_logits
+        self.transform_input = transform_input
+        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
+        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
+        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
+        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
+        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
+        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.Mixed_5b = inception_a(192, pool_features=32)
+        self.Mixed_5c = inception_a(256, pool_features=64)
+        self.Mixed_5d = inception_a(288, pool_features=64)
+        self.Mixed_6a = inception_b(288)
+        self.Mixed_6b = inception_c(768, channels_7x7=128)
+        self.Mixed_6c = inception_c(768, channels_7x7=160)
+        self.Mixed_6d = inception_c(768, channels_7x7=160)
+        self.Mixed_6e = inception_c(768, channels_7x7=192)
+        self.AuxLogits: Optional[nn.Module] = None
+        if aux_logits:
+            self.AuxLogits = inception_aux(768, num_classes)
+        self.Mixed_7a = inception_d(768)
+        self.Mixed_7b = inception_e(1280)
+        self.Mixed_7c = inception_e(2048)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.dropout = nn.Dropout()
+        self.fc = nn.Linear(2048, num_classes)
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    import scipy.stats as stats
+                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
+                    X = stats.truncnorm(-2, 2, scale=stddev)
+                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
+                    values = values.view(m.weight.size())
+                    with torch.no_grad():
+                        m.weight.copy_(values)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+
+    def _transform_input(self, x: Tensor) -> Tensor:
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+        return x
+
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
+        # N x 3 x 299 x 299
+        x = self.Conv2d_1a_3x3(x)
+        # N x 32 x 149 x 149
+        x = self.Conv2d_2a_3x3(x)
+        # N x 32 x 147 x 147
+        x = self.Conv2d_2b_3x3(x)
+        # N x 64 x 147 x 147
+        feat = self.maxpool1(x)
+        # N x 64 x 73 x 73
+        x = self.Conv2d_3b_1x1(feat)
+        # N x 80 x 73 x 73
+        x = self.Conv2d_4a_3x3(x)
+        # N x 192 x 71 x 71
+        x = self.maxpool2(x)
+        # N x 192 x 35 x 35
+        x = self.Mixed_5b(x)
+        # N x 256 x 35 x 35
+        x = self.Mixed_5c(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_5d(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_6a(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6b(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6c(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6d(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6e(x)
+        # N x 768 x 17 x 17
+        aux: Optional[Tensor] = None
+        if self.AuxLogits is not None:
+            if self.training:
+                aux = self.AuxLogits(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_7a(x)
+        # N x 1280 x 8 x 8
+        x = self.Mixed_7b(x)
+        # N x 2048 x 8 x 8
+        x = self.Mixed_7c(x)
+        # N x 2048 x 8 x 8
+        # Adaptive average pooling
+        x = self.avgpool(x)
+        # N x 2048 x 1 x 1
+        x = self.dropout(x)
+        # N x 2048 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 2048
+        x = self.fc(x)
+        # N x 1000 (num_classes)
+        return feat, x, aux
+
+    @torch.jit.unused
+    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
+        if self.training and self.aux_logits:
+            return InceptionOutputs(x, aux)
+        else:
+            return x  # type: ignore[return-value]
+
+    def forward(self, x: Tensor) -> InceptionOutputs:
+        x = self._transform_input(x)
+        feat, x, aux = self._forward(x)
+        aux_defined = self.training and self.aux_logits
+        if torch.jit.is_scripting():
+            if not aux_defined:
+                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
+            return feat, InceptionOutputs(x, aux)
+        else:
+            return feat, self.eager_outputs(x, aux)
+
+
+class InceptionA(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            pool_features: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionA, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
+
+        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
+        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
+
+        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionB, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3(x)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+        outputs = [branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            channels_7x7: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionC, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
+
+        c7 = channels_7x7
+        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionD, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
+
+        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch7x7x3 = self.branch7x7x3_1(x)
+        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+        outputs = [branch3x3, branch7x7x3, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionE, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
+
+        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
+        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
+        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            num_classes: int,
+            conv_block: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(InceptionAux, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
+        self.conv1 = conv_block(128, 768, kernel_size=5)
+        self.conv1.stddev = 0.01  # type: ignore[assignment]
+        self.fc = nn.Linear(768, num_classes)
+        self.fc.stddev = 0.001  # type: ignore[assignment]
+
+    def forward(self, x: Tensor) -> Tensor:
+        # N x 768 x 17 x 17
+        x = F.avg_pool2d(x, kernel_size=5, stride=3)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x 1000
+        return x
+
+
+class BasicConv2d(nn.Module):
+
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            **kwargs: Any
+    ) -> None:
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
diff --git a/libs/modules/vision/vgg.py b/libs/modules/vision/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..17a8a6e01da00c614185106cedb2ef85c21150b9
--- /dev/null
+++ b/libs/modules/vision/vgg.py
@@ -0,0 +1,189 @@
+from typing import Union, List, Dict, Any, cast
+
+import torch
+import torch.nn as nn
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+__all__ = [
+    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
+    'vgg19_bn', 'vgg19',
+]
+
+model_urls = {
+    'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth',
+    'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth',
+    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
+    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+}
+
+
+class VGG(nn.Module):
+
+    def __init__(
+            self,
+            features: nn.Module,
+            num_classes: int = 1000,
+            init_weights: bool = True
+    ) -> None:
+        super(VGG, self).__init__()
+        self.features = features
+        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+        self.classifier = nn.Sequential(
+            nn.Linear(512 * 7 * 7, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(True),
+            nn.Dropout(),
+            nn.Linear(4096, num_classes),
+        )
+        if init_weights:
+            self._initialize_weights()
+
+    def forward(self, x: torch.Tensor):
+        feat = self.features(x)
+        x = self.avgpool(feat)
+        x = torch.flatten(x, 1)
+        x = self.classifier(x)
+        return feat, x
+
+    def _initialize_weights(self) -> None:
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.constant_(m.bias, 0)
+
+
+def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
+    layers: List[nn.Module] = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+        else:
+            v = cast(int, v)
+            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+            else:
+                layers += [conv2d, nn.ReLU(inplace=True)]
+            in_channels = v
+    return nn.Sequential(*layers)
+
+
+cfgs: Dict[str, List[Union[str, int]]] = {
+    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
+    if pretrained:
+        kwargs['init_weights'] = False
+    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 11-layer model (configuration "A") from
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
+
+
+def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 11-layer model (configuration "A") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
+
+
+def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 13-layer model (configuration "B")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
+
+
+def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 13-layer model (configuration "B") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
+
+
+def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 16-layer model (configuration "D")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
+
+
+def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 16-layer model (configuration "D") with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
+
+
+def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 19-layer model (configuration "E")
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
+
+
+def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
+    r"""VGG 19-layer model (configuration 'E') with batch normalization
+    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
diff --git a/libs/modules/visual/__init__.py b/libs/modules/visual/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/libs/modules/visual/__init__.py
@@ -0,0 +1 @@
+
diff --git a/libs/modules/visual/imshow.py b/libs/modules/visual/imshow.py
new file mode 100644
index 0000000000000000000000000000000000000000..4901356778c547bfbea1543fa1109efd284f9c94
--- /dev/null
+++ b/libs/modules/visual/imshow.py
@@ -0,0 +1,172 @@
+import pathlib
+from typing import Union, List, Text, BinaryIO, AnyStr
+
+import matplotlib.pyplot as plt
+import torch
+import torchvision.transforms as transforms
+from torchvision.utils import make_grid
+
+__all__ = [
+    'sample2pil_transforms',
+    'pt2numpy_transforms',
+    'plt_pt_img',
+    'save_grid_images_and_labels',
+    'save_grid_images_and_captions',
+]
+
+# generate sample to PIL images
+sample2pil_transforms = transforms.Compose([
+    # unnormalizing to [0,1]
+    transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
+    # Add 0.5 after unnormalizing to [0, 255]
+    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
+    # CHW to HWC
+    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
+    # to numpy ndarray, dtype int8
+    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
+    # Converts a numpy ndarray of shape H x W x C to a PIL Image
+    transforms.ToPILImage(),
+])
+
+# generate sample to PIL images
+pt2numpy_transforms = transforms.Compose([
+    # Add 0.5 after unnormalizing to [0, 255]
+    transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
+    # CHW to HWC
+    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
+    # to numpy ndarray, dtype int8
+    transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
+])
+
+
+def plt_pt_img(
+        pt_img: torch.Tensor,
+        save_path: AnyStr = None,
+        title: AnyStr = None,
+        dpi: int = 300
+):
+    grid = make_grid(pt_img, normalize=True, pad_value=2)
+    ndarr = pt2numpy_transforms(grid)
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    if title is not None:
+        plt.title(f"{title}")
+
+    plt.show()
+    if save_path is not None:
+        plt.savefig(save_path, dpi=dpi)
+
+    plt.close()
+
+
+@torch.no_grad()
+def save_grid_images_and_labels(
+        images: Union[torch.Tensor, List[torch.Tensor]],
+        probs: Union[torch.Tensor, List[torch.Tensor]],
+        labels: Union[torch.Tensor, List[torch.Tensor]],
+        classes: Union[torch.Tensor, List[torch.Tensor]],
+        fp: Union[Text, pathlib.Path, BinaryIO],
+        nrow: int = 4,
+        normalize: bool = True
+) -> None:
+    """Save a given Tensor into an image file.
+    """
+    num_images = len(images)
+    num_rows, num_cols = _get_subplot_shape(num_images, nrow)
+
+    fig = plt.figure(figsize=(25, 20))
+
+    for i in range(num_images):
+        ax = fig.add_subplot(num_rows, num_cols, i + 1)
+
+        image, true_label, prob = images[i], labels[i], probs[i]
+
+        true_prob = prob[true_label]
+        incorrect_prob, incorrect_label = torch.max(prob, dim=0)
+        true_class = classes[true_label]
+
+        incorrect_class = classes[incorrect_label]
+
+        if normalize:
+            image = sample2pil_transforms(image)
+
+        ax.imshow(image)
+        title = f'true label: {true_class} ({true_prob:.3f})\n ' \
+                f'pred label: {incorrect_class} ({incorrect_prob:.3f})'
+        ax.set_title(title, fontsize=20)
+        ax.axis('off')
+
+    fig.subplots_adjust(hspace=0.3)
+
+    plt.savefig(fp)
+    plt.close()
+
+
+@torch.no_grad()
+def save_grid_images_and_captions(
+        images: Union[torch.Tensor, List[torch.Tensor]],
+        captions: List,
+        fp: Union[Text, pathlib.Path, BinaryIO],
+        nrow: int = 4,
+        normalize: bool = True
+) -> None:
+    """
+    Save a grid of images and their captions into an image file.
+
+    Args:
+        images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
+        captions (List): A list of captions for each image.
+        fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
+        nrow (int, optional): The number of images to display in each row. Defaults to 4.
+        normalize (bool, optional): Whether to normalize the image or not. Defaults to False.
+    """
+    num_images = len(images)
+    num_rows, num_cols = _get_subplot_shape(num_images, nrow)
+
+    fig = plt.figure(figsize=(25, 20))
+
+    for i in range(num_images):
+        ax = fig.add_subplot(num_rows, num_cols, i + 1)
+        image, caption = images[i], captions[i]
+
+        if normalize:
+            image = sample2pil_transforms(image)
+
+        ax.imshow(image)
+        title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
+        title = _insert_newline(title)
+        ax.set_title(title, fontsize=20)
+        ax.axis('off')
+
+    fig.subplots_adjust(hspace=0.3)
+
+    plt.savefig(fp)
+    plt.close()
+
+
+def _get_subplot_shape(num_images, nrow):
+    """
+    Calculate the number of rows and columns required to display images in a grid.
+
+    Args:
+        num_images (int): The total number of images to display.
+        nrow (int): The maximum number of images to display in each row.
+
+    Returns:
+        Tuple[int, int]: The number of rows and columns required to display images in a grid.
+    """
+    num_cols = min(num_images, nrow)
+    num_rows = (num_images + num_cols - 1) // num_cols
+    return num_rows, num_cols
+
+
+def _insert_newline(string, point=9):
+    # split by blank
+    words = string.split()
+    if len(words) <= point:
+        return string
+
+    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
+    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
+    return new_string
diff --git a/libs/modules/visual/video.py b/libs/modules/visual/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30dc31ec6fe1e10d22523901198c6ecabc755ec
--- /dev/null
+++ b/libs/modules/visual/video.py
@@ -0,0 +1,34 @@
+from typing import Any, Union
+import pathlib
+
+import cv2
+
+
+def create_video(num_iter: int,
+                 save_dir: Union[Any, pathlib.Path],
+                 video_frame_freq: int = 1,
+                 fname: str = "rendering_process",
+                 verbose: bool = True):
+    if not isinstance(save_dir, pathlib.Path):
+        save_dir = pathlib.Path(save_dir)
+
+    img_array = []
+    for i in range(0, num_iter):
+        if i % video_frame_freq == 0 or i == num_iter - 1:
+            filename = save_dir / f"iter{i}.png"
+            img = cv2.imread(filename.as_posix())
+            img_array.append(img)
+
+    video_name = save_dir / f"{fname}.mp4"
+    out = cv2.VideoWriter(
+        video_name.as_posix(),
+        cv2.VideoWriter_fourcc(*'mp4v'),
+        30.0,  # fps
+        (600, 600)  # video size
+    )
+    for iii in range(len(img_array)):
+        out.write(img_array[iii])
+    out.release()
+
+    if verbose:
+        print(f"video saved in '{video_name}'.")
diff --git a/libs/solver/__init__.py b/libs/solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/libs/solver/__init__.py
@@ -0,0 +1 @@
+
diff --git a/libs/solver/lr_scheduler.py b/libs/solver/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9bb82e96c5cb05d1ce296fb859709ac44c3eaed
--- /dev/null
+++ b/libs/solver/lr_scheduler.py
@@ -0,0 +1,350 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class SchedulerType(Enum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    PIECEWISE_CONSTANT = "piecewise_constant"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        step_rules (`string`):
+            The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
+            if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
+            steps and multiple 0.005 for the other steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = step_rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr_multiple = float(rule_list[-1])
+
+    def create_rules_function(rules_dict, last_lr_multiple):
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr_multiple
+
+        return rule_func
+
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
+
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+        )
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5,
+        last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_periods (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+        optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+    linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`int`, *optional*, defaults to 1):
+            The number of hard restarts to use.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        if progress >= 1.0:
+            return 0.0
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+        optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+    """
+    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        lr_end (`float`, *optional*, defaults to 1e-7):
+            The end LR.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+    implementation at
+    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+    """
+
+    lr_init = optimizer.defaults["lr"]
+    if not (lr_init > lr_end):
+        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        elif current_step > num_training_steps:
+            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
+        else:
+            lr_range = lr_init - lr_end
+            decay_steps = num_training_steps - num_warmup_steps
+            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+            decay = lr_range * pct_remaining ** power + lr_end
+            return decay / lr_init  # as LambdaLR multiplies by lr_init
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
+}
+
+
+def get_scheduler(
+        name: Union[str, SchedulerType],
+        optimizer: Optimizer,
+        step_rules: Optional[str] = None,
+        num_warmup_steps: Optional[int] = None,
+        num_training_steps: Optional[int] = None,
+        num_cycles: int = 1,
+        power: float = 1.0,
+        last_epoch: int = -1,
+):
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int``, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_cycles (`int`, *optional*):
+            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor. See `POLYNOMIAL` scheduler
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer, last_epoch=last_epoch)
+
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    if name == SchedulerType.COSINE_WITH_RESTARTS:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles,
+            last_epoch=last_epoch,
+        )
+
+    if name == SchedulerType.POLYNOMIAL:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            power=power,
+            last_epoch=last_epoch,
+        )
+
+    return schedule_func(
+        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch
+    )
diff --git a/libs/solver/optim.py b/libs/solver/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..65027e0381465cfb5f01dfbc35f1b132078cdf10
--- /dev/null
+++ b/libs/solver/optim.py
@@ -0,0 +1,53 @@
+from functools import partial
+
+import torch
+from omegaconf import DictConfig
+
+
+def get_optimizer(optimizer_name, parameters, lr=None, config: DictConfig = None):
+    param_dict = {}
+    if optimizer_name == "adam":
+        optimizer = partial(torch.optim.Adam, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.Adam, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('eps'):
+            param_dict['eps'] = config.eps
+    elif optimizer_name == "adamw":
+        optimizer = partial(torch.optim.AdamW, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.AdamW, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('eps'):
+            param_dict['eps'] = config.eps
+    elif optimizer_name == "radam":
+        optimizer = partial(torch.optim.RAdam, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.RAdam, params=parameters, lr=lr)
+        if config.get('betas'):
+            param_dict['betas'] = config.betas
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+    elif optimizer_name == "sgd":
+        optimizer = partial(torch.optim.SGD, params=parameters)
+        if lr is not None:
+            optimizer = partial(torch.optim.SGD, params=parameters, lr=lr)
+        if config.get('momentum'):
+            param_dict['momentum'] = config.momentum
+        if config.get('weight_decay'):
+            param_dict['weight_decay'] = config.weight_decay
+        if config.get('nesterov'):
+            param_dict['nesterov'] = config.nesterov
+    else:
+        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
+
+    if len(param_dict.keys()) > 0:
+        return optimizer(**param_dict)
+    else:
+        return optimizer()
diff --git a/libs/utils/__init__.py b/libs/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b01a9924d29974e9d1a967a835fe08fdff1f9d0
--- /dev/null
+++ b/libs/utils/__init__.py
@@ -0,0 +1,26 @@
+from . import lazy
+
+# __getattr__, __dir__, __all__ = lazy.attach(
+#     __name__,
+#     submodules={},
+#     submod_attrs={
+#         'misc': ['identity', 'exists', 'default', 'has_int_squareroot', 'sum_params', 'cycle', 'num_to_groups',
+#                  'extract', 'normalize', 'unnormalize'],
+#         'tqdm': ['tqdm_decorator'],
+#         'lazy': ['load']
+#     }
+# )
+
+from .misc import (
+    identity,
+    exists,
+    default,
+    has_int_squareroot,
+    sum_params,
+    cycle,
+    num_to_groups,
+    extract,
+    normalize,
+    unnormalize
+)
+from .tqdm import tqdm_decorator
diff --git a/libs/utils/argparse.py b/libs/utils/argparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..395559bc624c5fff929aadf2c5a673665b79b0b4
--- /dev/null
+++ b/libs/utils/argparse.py
@@ -0,0 +1,112 @@
+import argparse
+
+
+#################################################################################
+#                            practical argparse utils                           #
+#################################################################################
+
+def accelerate_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+
+    # Device
+    parser.add_argument("-cpu", "--use_cpu", action="store_true",
+                        help="Whether or not disable cuda")
+
+    # Gradient Accumulation
+    parser.add_argument("-cumgard", "--gradient-accumulate-step",
+                        type=int, default=1)
+    parser.add_argument("--split-batches", action="store_true",
+                        help="Whether or not the accelerator should split the batches "
+                             "yielded by the dataloaders across the devices.")
+
+    # Nvidia-Apex and GradScaler
+    parser.add_argument("-mprec", "--mixed-precision",
+                        type=str, default='no', choices=['no', 'fp16', 'bf16'],
+                        help="Whether to use mixed precision. Choose"
+                             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+                             "and an Nvidia Ampere GPU.")
+    parser.add_argument("--init-scale",
+                        type=float, default=65536.0,
+                        help="Default value: `2.**16 = 65536.0` ,"
+                             "For ImageNet experiments, '2.**20 = 1048576.0' was a good default value."
+                             "the others: `2.**17 = 131072.0` ")
+    parser.add_argument("--growth-factor", type=float, default=2.0)
+    parser.add_argument("--backoff-factor", type=float, default=0.5)
+    parser.add_argument("--growth-interval", type=int, default=2000)
+
+    # Gradient Normalization
+    parser.add_argument("-gard_norm", "--max_grad_norm", type=float, default=-1)
+
+    # Trackers
+    parser.add_argument("--use-wandb", action="store_true")
+    parser.add_argument("--project-name", type=str, default="SketchGeneration")
+    parser.add_argument("--entity", type=str, default="ximinng")
+    parser.add_argument("--tensorboard", action="store_true")
+
+    # timing
+    parser.add_argument("-log_step", "--log_step", default=1000, type=int,
+                        help="can be use to control log.")
+    parser.add_argument("-eval_step", "--eval_step", default=10, type=int,
+                        help="can be use to calculate some metrics.")
+    parser.add_argument("-save_step", "--save_step", default=10, type=int,
+                        help="can be use to control saving checkpoint.")
+
+    # update configuration interface
+    # example: python main.py -c main.yaml -update "nnet.depth=16 batch_size=16"
+    parser.add_argument("-update",
+                        type=str, default="sds.warmup=1000",
+                        help="modified hyper-parameters of config file. ")
+    return parser
+
+
+def ema_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument('--ema', action='store_true', help='enable EMA model')
+    parser.add_argument("--ema_decay", type=float, default=0.9999)
+    parser.add_argument("--ema_update_after_step", type=int, default=100)
+    parser.add_argument("--ema_update_every", type=int, default=10)
+    return parser
+
+
+def base_data_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-spl", "--split",
+                        default='test', type=str,
+                        choices=['train', 'val', 'test', 'all'],
+                        help="which part of the data set, 'all' means combine training and test sets.")
+    parser.add_argument("-j", "--num_workers",
+                        default=6, type=int,
+                        help="how many subprocesses to use for data loading.")
+    parser.add_argument("--shuffle",
+                        action='store_true',
+                        help="how many subprocesses to use for data loading.")
+    parser.add_argument("--drop_last",
+                        action='store_true',
+                        help="how many subprocesses to use for data loading.")
+    return parser
+
+
+def base_training_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-tbz", "--train_batch_size",
+                        default=32, type=int,
+                        help="how many images to sample during training.")
+    parser.add_argument("-wd", "--weight_decay", default=0, type=float)
+    return parser
+
+
+def base_sampling_parser():
+    parser = argparse.ArgumentParser(add_help=False)
+    parser.add_argument("-vbz", "--valid_batch_size",
+                        default=1, type=int,
+                        help="how many images to sample during evaluation")
+    parser.add_argument("-ts", "--total_samples",
+                        default=2000, type=int,
+                        help="the total number of samples, can be used to calculate FID.")
+    parser.add_argument("-ns", "--num_samples",
+                        default=4, type=int,
+                        help="number of samples taken at a time, "
+                             "can be used to repeatedly induce samples from a generation model "
+                             "from a fixed guided information, "
+                             "eg: `one latent to ns samples` (1 latent to 5 photo generation) ")
+    return parser
diff --git a/libs/utils/lazy.py b/libs/utils/lazy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b7270c663853c3ebdaee90bab55de50d1a9742c
--- /dev/null
+++ b/libs/utils/lazy.py
@@ -0,0 +1,139 @@
+import importlib
+import importlib.util
+import os
+import sys
+
+
+def attach(package_name, submodules=None, submod_attrs=None):
+    """Attach lazily loaded submodules, functions, or other attributes.
+
+    Typically, modules import submodules and attributes as follows::
+
+      import mysubmodule
+      import anothersubmodule
+
+      from .foo import someattr
+
+    The idea is to replace a package's `__getattr__`, `__dir__`, and
+    `__all__`, such that all imports work exactly the way they did
+    before, except that they are only imported when used.
+
+    The typical way to call this function, replacing the above imports, is::
+
+      __getattr__, __lazy_dir__, __all__ = lazy.attach(
+        __name__,
+        ['mysubmodule', 'anothersubmodule'],
+        {'foo': 'someattr'}
+      )
+
+    This functionality requires Python 3.7 or higher.
+
+    Parameters
+    ----------
+    package_name : str
+        Typically use ``__name__``.
+    submodules : set
+        List of submodules to attach.
+    submod_attrs : dict
+        Dictionary of submodule -> list of attributes / functions.
+        These attributes are imported as they are used.
+
+    Returns
+    -------
+    __getattr__, __dir__, __all__
+
+    """
+    if submod_attrs is None:
+        submod_attrs = {}
+
+    if submodules is None:
+        submodules = set()
+    else:
+        submodules = set(submodules)
+
+    attr_to_modules = {
+        attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
+    }
+
+    __all__ = list(submodules | attr_to_modules.keys())
+
+    def __getattr__(name):
+        if name in submodules:
+            return importlib.import_module(f'{package_name}.{name}')
+        elif name in attr_to_modules:
+            submod = importlib.import_module(
+                f'{package_name}.{attr_to_modules[name]}'
+            )
+            return getattr(submod, name)
+        else:
+            raise AttributeError(f'No {package_name} attribute {name}')
+
+    def __dir__():
+        return __all__
+
+    eager_import = os.environ.get('EAGER_IMPORT', '')
+    if eager_import not in ['', '0', 'false']:
+        for attr in set(attr_to_modules.keys()) | submodules:
+            __getattr__(attr)
+
+    return __getattr__, __dir__, list(__all__)
+
+
+def load(fullname):
+    """Return a lazily imported proxy for a module.
+
+    We often see the following pattern::
+
+      def myfunc():
+          import scipy as sp
+          sp.argmin(...)
+          ....
+
+    This is to prevent a module, in this case `scipy`, from being
+    imported at function definition time, since that can be slow.
+
+    This function provides a proxy module that, upon access, imports
+    the actual module.  So the idiom equivalent to the above example is::
+
+      sp = lazy.load("scipy")
+
+      def myfunc():
+          sp.argmin(...)
+          ....
+
+    The initial import time is fast because the actual import is delayed
+    until the first attribute is requested. The overall import time may
+    decrease as well for users that don't make use of large portions
+    of the library.
+
+    Parameters
+    ----------
+    fullname : str
+        The full name of the module or submodule to import.  For example::
+
+          sp = lazy.load('scipy')  # import scipy as sp
+          spla = lazy.load('scipy.linalg')  # import scipy.linalg as spla
+
+    Returns
+    -------
+    pm : importlib.util._LazyModule
+        Proxy module.  Can be used like any regularly imported module.
+        Actual loading of the module occurs upon first attribute request.
+
+    """
+    try:
+        return sys.modules[fullname]
+    except KeyError:
+        pass
+
+    spec = importlib.util.find_spec(fullname)
+    if spec is None:
+        raise ModuleNotFoundError(f"No module name '{fullname}'")
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[fullname] = module
+
+    loader = importlib.util.LazyLoader(spec.loader)
+    loader.exec_module(module)
+
+    return module
diff --git a/libs/utils/logging.py b/libs/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa14053111f5db9a1822f47d3d12d21e6be60f1
--- /dev/null
+++ b/libs/utils/logging.py
@@ -0,0 +1,60 @@
+import os
+import sys
+import errno
+
+
+def get_logger(logs_dir: str, file_name: str = "log.txt"):
+    logger = PrintLogger(os.path.join(logs_dir, file_name))
+    sys.stdout = logger  # record all python print
+    return logger
+
+
+class PrintLogger(object):
+
+    def __init__(self, fpath=None):
+        """
+        python standard input/output records
+        """
+        self.console = sys.stdout
+        self.file = None
+        if fpath is not None:
+            mkdir_if_missing(os.path.dirname(fpath))
+            self.file = open(fpath, 'w')
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, *args):
+        self.close()
+
+    def write(self, msg):
+        self.console.write(msg)
+        if self.file is not None:
+            self.file.write(msg)
+
+    def write_in(self, msg):
+        """write in log only, not console"""
+        if self.file is not None:
+            self.file.write(msg)
+
+    def flush(self):
+        self.console.flush()
+        if self.file is not None:
+            self.file.flush()
+            os.fsync(self.file.fileno())
+
+    def close(self):
+        self.console.close()
+        if self.file is not None:
+            self.file.close()
+
+
+def mkdir_if_missing(dir_path):
+    try:
+        os.makedirs(dir_path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise
diff --git a/libs/utils/meter.py b/libs/utils/meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52ad30ff9c3273eb1aab59d7b7feaf90b798fac
--- /dev/null
+++ b/libs/utils/meter.py
@@ -0,0 +1,65 @@
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+
+
+class Summary(Enum):
+    NONE = 0
+    AVERAGE = 1
+    SUM = 2
+    COUNT = 3
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+        self.name = name
+        self.fmt = fmt
+        self.summary_type = summary_type
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def all_reduce(self):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif torch.backends.mps.is_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+
+        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+        self.sum, self.count = total.tolist()
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
+
+    def summary(self):
+        fmtstr = ''
+        if self.summary_type is Summary.NONE:
+            fmtstr = ''
+        elif self.summary_type is Summary.AVERAGE:
+            fmtstr = '{name} {avg:.3f}'
+        elif self.summary_type is Summary.SUM:
+            fmtstr = '{name} {sum:.3f}'
+        elif self.summary_type is Summary.COUNT:
+            fmtstr = '{name} {count:.3f}'
+        else:
+            raise ValueError('invalid summary type %r' % self.summary_type)
+
+        return fmtstr.format(**self.__dict__)
diff --git a/libs/utils/misc.py b/libs/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..57e68f50cff30c24c9e143b1717c0f915f6cca2b
--- /dev/null
+++ b/libs/utils/misc.py
@@ -0,0 +1,74 @@
+import math
+
+import torch
+
+
+def identity(t, *args, **kwargs):
+    """return t"""
+    return t
+
+
+def exists(x):
+    """whether x is None or not"""
+    return x is not None
+
+
+def default(val, d):
+    """ternary judgment: val != None ? val : d"""
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+def has_int_squareroot(num):
+    return (math.sqrt(num) ** 2) == num
+
+
+def num_to_groups(num, divisor):
+    groups = num // divisor
+    remainder = num % divisor
+    arr = [divisor] * groups
+    if remainder > 0:
+        arr.append(remainder)
+    return arr
+
+
+#################################################################################
+#                             Model Utils                                       #
+#################################################################################
+
+def sum_params(model: torch.nn.Module, eps: float = 1e6):
+    return sum(p.numel() for p in model.parameters()) / eps
+
+
+#################################################################################
+#                            DataLoader Utils                                   #
+#################################################################################
+
+def cycle(dl):
+    while True:
+        for data in dl:
+            yield data
+
+
+#################################################################################
+#                            Diffusion Model Utils                              #
+#################################################################################
+
+def extract(a, t, x_shape):
+    b, *_ = t.shape
+    assert x_shape[0] == b
+    out = a.gather(-1, t)  # 1-D tensor, shape: (b,)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))  # shape: [b, 1, 1, 1]
+
+
+def unnormalize(x):
+    """unnormalize_to_zero_to_one"""
+    x = (x + 1) * 0.5  # Map the data interval to [0, 1]
+    return torch.clamp(x, 0.0, 1.0)
+
+
+def normalize(x):
+    """normalize_to_neg_one_to_one"""
+    x = x * 2 - 1  # Map the data interval to [-1, 1]
+    return torch.clamp(x, -1.0, 1.0)
diff --git a/libs/utils/model_summary.py b/libs/utils/model_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..07fa1ce1cbb972f1ee77923d900bccf6cc546f45
--- /dev/null
+++ b/libs/utils/model_summary.py
@@ -0,0 +1,123 @@
+import sys
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+layer_modules = (torch.nn.MultiheadAttention,)
+
+
+def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
+            batch_size=-1,
+            *args, **kwargs):
+    """
+    give example input data as least one way like below:
+    ① input_data ---> model.forward(input_data)
+    ② input_data_args ---> model.forward(*input_data_args)
+    ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
+    """
+
+    hooks = []
+    summary = OrderedDict()
+
+    def register_hook(module):
+        def hook(module, inputs, outputs):
+
+            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+            module_idx = len(summary)
+
+            key = "%s-%i" % (class_name, module_idx + 1)
+
+            info = OrderedDict()
+            info["id"] = id(module)
+            if isinstance(outputs, (list, tuple)):
+                try:
+                    info["out"] = [batch_size] + list(outputs[0].size())[1:]
+                except AttributeError:
+                    # pack_padded_seq and pad_packed_seq store feature into data attribute
+                    info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
+            else:
+                info["out"] = [batch_size] + list(outputs.size())[1:]
+
+            info["params_nt"], info["params"] = 0, 0
+            for name, param in module.named_parameters():
+                info["params"] += param.nelement() * param.requires_grad
+                info["params_nt"] += param.nelement() * (not param.requires_grad)
+
+            summary[key] = info
+
+        # ignore Sequential and ModuleList and other containers
+        if isinstance(module, layer_modules) or not module._modules:
+            hooks.append(module.register_forward_hook(hook))
+
+    model.apply(register_hook)
+
+    # multiple inputs to the network
+    if isinstance(input_shape, tuple):
+        input_shape = [input_shape]
+
+    if input_data is not None:
+        x = [input_data]
+    elif input_shape is not None:
+        # batch_size of 2 for batchnorm
+        x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
+    elif input_data_args is not None:
+        x = input_data_args
+    else:
+        x = []
+    try:
+        with torch.no_grad():
+            model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
+    except Exception:
+        # This can be usefull for debugging
+        print("Failed to run summary...")
+        raise
+    finally:
+        for hook in hooks:
+            hook.remove()
+    summary_logs = []
+    summary_logs.append("--------------------------------------------------------------------------")
+    line_new = "{:<30}  {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
+    summary_logs.append(line_new)
+    summary_logs.append("==========================================================================")
+    total_params = 0
+    total_output = 0
+    trainable_params = 0
+    for layer in summary:
+        # layer, output_shape, params
+        line_new = "{:<30}  {:>20} {:>20}".format(
+            layer,
+            str(summary[layer]["out"]),
+            "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
+        )
+        total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
+        total_output += np.prod(summary[layer]["out"])
+        trainable_params += summary[layer]["params"]
+        summary_logs.append(line_new)
+
+    # assume 4 bytes/number
+    if input_data is not None:
+        total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
+    elif input_shape is not None:
+        total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
+    else:
+        total_input_size = 0.0
+    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
+    total_params_size = abs(total_params * 4. / (1024 ** 2.))
+    total_size = total_params_size + total_output_size + total_input_size
+
+    summary_logs.append("==========================================================================")
+    summary_logs.append("Total params: {0:,}".format(total_params))
+    summary_logs.append("Trainable params: {0:,}".format(trainable_params))
+    summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
+    summary_logs.append("--------------------------------------------------------------------------")
+    summary_logs.append("Input size (MB): %0.6f" % total_input_size)
+    summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
+    summary_logs.append("Params size (MB): %0.6f" % total_params_size)
+    summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
+    summary_logs.append("--------------------------------------------------------------------------")
+
+    summary_info = "\n".join(summary_logs)
+
+    print(summary_info)
+    return summary_info
diff --git a/libs/utils/tqdm.py b/libs/utils/tqdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db458f20263a6b617a845ec8645540133c4a29e
--- /dev/null
+++ b/libs/utils/tqdm.py
@@ -0,0 +1,25 @@
+from typing import Callable
+from tqdm.auto import tqdm
+
+
+def tqdm_decorator(func: Callable):
+    """A decorator function called tqdm_decorator that takes a function as an argument and
+    returns a new function that wraps the input function with a tqdm progress bar.
+
+    Noting: **The input function is assumed to have an object self as its first argument**, which contains a step attribute,
+    an args attribute with a train_num_steps attribute, and an accelerator attribute with an is_main_process attribute.
+
+    Args:
+        func: tqdm_decorator
+
+    Returns:
+            a new function that wraps the input function with a tqdm progress bar.
+    """
+
+    def wrapper(*args, **kwargs):
+        with tqdm(initial=args[0].step,
+                  total=args[0].args.train_num_steps,
+                  disable=not args[0].accelerator.is_main_process) as pbar:
+            func(*args, **kwargs, pbar=pbar)
+
+    return wrapper
diff --git a/methods/__init__.py b/methods/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/methods/__init__.py
@@ -0,0 +1 @@
+
diff --git a/methods/diffusers_warp/__init__.py b/methods/diffusers_warp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b87df6ff1569f3829cfd66febc30913ef6bdfc8
--- /dev/null
+++ b/methods/diffusers_warp/__init__.py
@@ -0,0 +1,134 @@
+from typing import AnyStr
+import pathlib
+from collections import OrderedDict
+from packaging import version
+
+import torch
+from diffusers import StableDiffusionPipeline, SchedulerMixin, DiffusionPipeline
+from diffusers.utils import is_torch_version, is_xformers_available
+
+huggingface_model_dict = OrderedDict({
+    "sd14": "/nfs/StableDiffusionModels/CompVis/stable-diffusion-v1-4",  # resolution: 512
+    "sd15": "/nfs/StableDiffusionModels/runwayml/stable-diffusion-v1-5",  # resolution: 512
+    "sd21b": "stabilityai/stable-diffusion-2-1-base",  # resolution: 512
+    "sd21": "stabilityai/stable-diffusion-2-1",  # resolution: 768
+    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",  # resolution: 1024
+})
+
+_model2resolution = {
+    "sd14": 512,
+    "sd15": 512,
+    "sd21b": 512,
+    "sd21": 768,
+    "sdxl": 1024,
+}
+
+
+def model2res(model_id: str):
+    return _model2resolution.get(model_id, 512)
+
+
+def init_diffusion_pipeline(model_id: AnyStr,
+                            custom_pipeline: StableDiffusionPipeline,
+                            custom_scheduler: SchedulerMixin = None,
+                            device: torch.device = "cuda",
+                            torch_dtype: torch.dtype = torch.float32,
+                            local_files_only: bool = True,
+                            force_download: bool = False,
+                            resume_download: bool = False,
+                            ldm_speed_up: bool = False,
+                            enable_xformers: bool = True,
+                            gradient_checkpoint: bool = False,
+                            lora_path: AnyStr = None,
+                            unet_path: AnyStr = None) -> StableDiffusionPipeline:
+    """
+    A tool for initial diffusers model.
+
+    Args:
+        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path
+        custom_pipeline: any StableDiffusionPipeline pipeline
+        custom_scheduler: any scheduler
+        device: set device
+        local_files_only: prohibited download model
+        force_download: forced download model
+        resume_download: re-download model
+        ldm_speed_up: use the `torch.compile` api to speed up unet
+        enable_xformers: enable memory efficient attention from [xFormers]
+        gradient_checkpoint: activates gradient checkpointing for the current model
+        lora_path: load LoRA checkpoint
+        unet_path: load unet checkpoint
+
+    Returns:
+            diffusers.StableDiffusionPipeline
+    """
+
+    # get model id
+    model_id = huggingface_model_dict.get(model_id, model_id)
+
+    # process diffusion model
+    if custom_scheduler is not None:
+        pipeline = custom_pipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            resume_download=resume_download,
+            scheduler=custom_scheduler.from_pretrained(model_id,
+                                                       subfolder="scheduler",
+                                                       local_files_only=local_files_only)
+        ).to(device)
+    else:
+        pipeline = custom_pipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            resume_download=resume_download,
+        ).to(device)
+
+    # process unet model if exist
+    if unet_path is not None and pathlib.Path(unet_path).exists():
+        print(f"=> load u-net from {unet_path}")
+        pipeline.unet.from_pretrained(model_id, subfolder="unet")
+
+    # process lora layers if exist
+    if lora_path is not None and pathlib.Path(lora_path).exists():
+        pipeline.unet.load_attn_procs(lora_path)
+        print(f"=> load lora layers into U-Net from {lora_path} ...")
+
+    # torch.compile
+    if ldm_speed_up:
+        if is_torch_version(">=", "2.0.0"):
+            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+            print(f"=> enable torch.compile on U-Net")
+        else:
+            print(f"=> warning: calling torch.compile speed-up failed, since torch version <= 2.0.0")
+
+    # Meta xformers
+    if enable_xformers:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                print(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
+                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
+                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            print(f"=> enable xformers")
+            pipeline.unet.enable_xformers_memory_efficient_attention()
+        else:
+            print(f"=> warning: calling xformers failed")
+
+    # gradient checkpointing
+    if gradient_checkpoint:
+        if pipeline.unet.is_gradient_checkpointing:
+            print(f"=> enable gradient checkpointing")
+            pipeline.unet.enable_gradient_checkpointing()
+        else:
+            print("=> waring: gradient checkpointing is not activated for this model.")
+
+    print(f"Diffusion Model: {model_id}")
+    print(pipeline.scheduler)
+    return pipeline
diff --git a/methods/diffvg_warp/__init__.py b/methods/diffvg_warp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a8d578f1bece1bcc8a2e7380b102c6e09d690a
--- /dev/null
+++ b/methods/diffvg_warp/__init__.py
@@ -0,0 +1,6 @@
+from .diffvg_state import DiffVGState, init_diffvg
+
+__all__ = [
+    'DiffVGState',
+    'init_diffvg'
+]
diff --git a/methods/diffvg_warp/diffvg_state.py b/methods/diffvg_warp/diffvg_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53c83b5723e27efe9e2d5ae67d4eeda64b8efdf
--- /dev/null
+++ b/methods/diffvg_warp/diffvg_state.py
@@ -0,0 +1,247 @@
+import pathlib
+from typing import AnyStr, List, Union
+import xml.etree.ElementTree as etree
+
+import torch
+import pydiffvg
+
+
+def init_diffvg(device: torch.device,
+                use_gpu: bool = torch.cuda.is_available(),
+                print_timing: bool = False):
+    pydiffvg.set_device(device)
+    pydiffvg.set_use_gpu(use_gpu)
+    pydiffvg.set_print_timing(print_timing)
+
+
+class DiffVGState(torch.nn.Module):
+
+    def __init__(self,
+                 device: torch.device,
+                 use_gpu: bool = torch.cuda.is_available(),
+                 print_timing: bool = False,
+                 canvas_width: int = True,
+                 canvas_height: int = True):
+        super(DiffVGState, self).__init__()
+        # pydiffvg device setting
+        self.device = device
+        init_diffvg(device, use_gpu, print_timing)
+
+        self.canvas_width = canvas_width
+        self.canvas_height = canvas_height
+
+        # record all paths
+        self.shapes = []
+        self.shape_groups = []
+        # record the current optimized path
+        self.cur_shapes = []
+        self.cur_shape_groups = []
+
+        self.point_vars = []
+        self.color_vars = []
+
+        self.strokes_counter = 0  # counts the number of calls to "get_path"
+
+    def load_svg(self, path_svg):
+        canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(path_svg)
+        return canvas_width, canvas_height, shapes, shape_groups
+
+    def _save_svg(self,
+                  filename: Union[AnyStr, pathlib.Path],
+                  width: int = None,
+                  height: int = None,
+                  shapes: List = None,
+                  shape_groups: List = None,
+                  use_gamma: bool = False,
+                  background: str = None):
+        """
+        Save an SVG file with specified parameters and shapes.
+        Noting: New version of SVG saving function that is an adaptation of pydiffvg.save_svg.
+        The original version saved words resulting in incomplete glyphs.
+
+        Args:
+            filename (str): The path to save the SVG file.
+            width (int): The width of the SVG canvas.
+            height (int): The height of the SVG canvas.
+            shapes (list): A list of shapes to be included in the SVG.
+            shape_groups (list): A list of shape groups.
+            use_gamma (bool): Flag indicating whether to apply gamma correction.
+            background (str, optional): The background color of the SVG.
+
+        Returns:
+            None
+        """
+        root = etree.Element('svg')
+        root.set('version', '1.1')
+        root.set('xmlns', 'http://www.w3.org/2000/svg')
+        root.set('width', str(width))
+        root.set('height', str(height))
+
+        if background is not None:
+            print(f"setting background to {background}")
+            root.set('style', str(background))
+
+        defs = etree.SubElement(root, 'defs')
+        g = etree.SubElement(root, 'g')
+
+        if use_gamma:
+            f = etree.SubElement(defs, 'filter')
+            f.set('id', 'gamma')
+            f.set('x', '0')
+            f.set('y', '0')
+            f.set('width', '100%')
+            f.set('height', '100%')
+            gamma = etree.SubElement(f, 'feComponentTransfer')
+            gamma.set('color-interpolation-filters', 'sRGB')
+            feFuncR = etree.SubElement(gamma, 'feFuncR')
+            feFuncR.set('type', 'gamma')
+            feFuncR.set('amplitude', str(1))
+            feFuncR.set('exponent', str(1 / 2.2))
+            feFuncG = etree.SubElement(gamma, 'feFuncG')
+            feFuncG.set('type', 'gamma')
+            feFuncG.set('amplitude', str(1))
+            feFuncG.set('exponent', str(1 / 2.2))
+            feFuncB = etree.SubElement(gamma, 'feFuncB')
+            feFuncB.set('type', 'gamma')
+            feFuncB.set('amplitude', str(1))
+            feFuncB.set('exponent', str(1 / 2.2))
+            feFuncA = etree.SubElement(gamma, 'feFuncA')
+            feFuncA.set('type', 'gamma')
+            feFuncA.set('amplitude', str(1))
+            feFuncA.set('exponent', str(1 / 2.2))
+            g.set('style', 'filter:url(#gamma)')
+
+        # Store color
+        for i, shape_group in enumerate(shape_groups):
+            def add_color(shape_color, name):
+                if isinstance(shape_color, pydiffvg.LinearGradient):
+                    lg = shape_color
+                    color = etree.SubElement(defs, 'linearGradient')
+                    color.set('id', name)
+                    color.set('x1', str(lg.begin[0].item()))
+                    color.set('y1', str(lg.begin[1].item()))
+                    color.set('x2', str(lg.end[0].item()))
+                    color.set('y2', str(lg.end[1].item()))
+                    offsets = lg.offsets.data.cpu().numpy()
+                    stop_colors = lg.stop_colors.data.cpu().numpy()
+                    for j in range(offsets.shape[0]):
+                        stop = etree.SubElement(color, 'stop')
+                        stop.set('offset', str(offsets[j]))
+                        c = lg.stop_colors[j, :]
+                        stop.set('stop-color', 'rgb({}, {}, {})'.format(
+                            int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
+                        ))
+                        stop.set('stop-opacity', '{}'.format(c[3]))
+                if isinstance(shape_color, pydiffvg.RadialGradient):
+                    lg = shape_color
+                    color = etree.SubElement(defs, 'radialGradient')
+                    color.set('id', name)
+                    color.set('cx', str(lg.center[0].item() / width))
+                    color.set('cy', str(lg.center[1].item() / height))
+                    # this only support width=height
+                    color.set('r', str(lg.radius[0].item() / width))
+                    offsets = lg.offsets.data.cpu().numpy()
+                    stop_colors = lg.stop_colors.data.cpu().numpy()
+                    for j in range(offsets.shape[0]):
+                        stop = etree.SubElement(color, 'stop')
+                        stop.set('offset', str(offsets[j]))
+                        c = lg.stop_colors[j, :]
+                        stop.set('stop-color', 'rgb({}, {}, {})'.format(
+                            int(255 * c[0]), int(255 * c[1]), int(255 * c[2])
+                        ))
+                        stop.set('stop-opacity', '{}'.format(c[3]))
+
+            if shape_group.fill_color is not None:
+                add_color(shape_group.fill_color, 'shape_{}_fill'.format(i))
+            if shape_group.stroke_color is not None:
+                add_color(shape_group.stroke_color, 'shape_{}_stroke'.format(i))
+
+        for i, shape_group in enumerate(shape_groups):
+            shape = shapes[shape_group.shape_ids[0]]
+            if isinstance(shape, pydiffvg.Circle):
+                shape_node = etree.SubElement(g, 'circle')
+                shape_node.set('r', str(shape.radius.item()))
+                shape_node.set('cx', str(shape.center[0].item()))
+                shape_node.set('cy', str(shape.center[1].item()))
+            elif isinstance(shape, pydiffvg.Polygon):
+                shape_node = etree.SubElement(g, 'polygon')
+                points = shape.points.data.cpu().numpy()
+                path_str = ''
+                for j in range(0, shape.points.shape[0]):
+                    path_str += '{} {}'.format(points[j, 0], points[j, 1])
+                    if j != shape.points.shape[0] - 1:
+                        path_str += ' '
+                shape_node.set('points', path_str)
+            elif isinstance(shape, pydiffvg.Path):
+                for j, id in enumerate(shape_group.shape_ids):
+                    shape = shapes[id]
+                    if isinstance(shape, pydiffvg.Path):
+                        if j == 0:
+                            shape_node = etree.SubElement(g, 'path')
+                            path_str = ''
+
+                        num_segments = shape.num_control_points.shape[0]
+                        num_control_points = shape.num_control_points.data.cpu().numpy()
+                        points = shape.points.data.cpu().numpy()
+                        num_points = shape.points.shape[0]
+                        path_str += 'M {} {}'.format(points[0, 0], points[0, 1])
+                        point_id = 1
+                        for j in range(0, num_segments):
+                            if num_control_points[j] == 0:
+                                p = point_id % num_points
+                                path_str += ' L {} {}'.format(
+                                    points[p, 0], points[p, 1])
+                                point_id += 1
+                            elif num_control_points[j] == 1:
+                                p1 = (point_id + 1) % num_points
+                                path_str += ' Q {} {} {} {}'.format(
+                                    points[point_id, 0], points[point_id, 1],
+                                    points[p1, 0], points[p1, 1])
+                                point_id += 2
+                            elif num_control_points[j] == 2:
+                                p2 = (point_id + 2) % num_points
+                                path_str += ' C {} {} {} {} {} {}'.format(
+                                    points[point_id, 0], points[point_id, 1],
+                                    points[point_id + 1, 0], points[point_id + 1, 1],
+                                    points[p2, 0], points[p2, 1])
+                                point_id += 3
+                shape_node.set('d', path_str)
+            elif isinstance(shape, pydiffvg.Rect):
+                shape_node = etree.SubElement(g, 'rect')
+                shape_node.set('x', str(shape.p_min[0].item()))
+                shape_node.set('y', str(shape.p_min[1].item()))
+                shape_node.set('width', str(shape.p_max[0].item() - shape.p_min[0].item()))
+                shape_node.set('height', str(shape.p_max[1].item() - shape.p_min[1].item()))
+            elif isinstance(shape, pydiffvg.Ellipse):
+                shape_node = etree.SubElement(g, 'ellipse')
+                shape_node.set('cx', str(shape.center[0].item()))
+                shape_node.set('cy', str(shape.center[1].item()))
+                shape_node.set('rx', str(shape.radius[0].item()))
+                shape_node.set('ry', str(shape.radius[1].item()))
+            else:
+                raise NotImplementedError(f'shape type: {type(shape)} is not involved in pydiffvg.')
+
+            shape_node.set('stroke-width', str(2 * shape.stroke_width.data.cpu().item()))
+            if shape_group.fill_color is not None:
+                if isinstance(shape_group.fill_color, pydiffvg.LinearGradient):
+                    shape_node.set('fill', 'url(#shape_{}_fill)'.format(i))
+                else:
+                    c = shape_group.fill_color.data.cpu().numpy()
+                    shape_node.set('fill', 'rgb({}, {}, {})'.format(
+                        int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
+                    shape_node.set('opacity', str(c[3]))
+            else:
+                shape_node.set('fill', 'none')
+            if shape_group.stroke_color is not None:
+                if isinstance(shape_group.stroke_color, pydiffvg.LinearGradient):
+                    shape_node.set('stroke', 'url(#shape_{}_stroke)'.format(i))
+                else:
+                    c = shape_group.stroke_color.data.cpu().numpy()
+                    shape_node.set('stroke', 'rgb({}, {}, {})'.format(
+                        int(255 * c[0]), int(255 * c[1]), int(255 * c[2])))
+                    shape_node.set('stroke-opacity', str(c[3]))
+                shape_node.set('stroke-linecap', 'round')
+                shape_node.set('stroke-linejoin', 'round')
+
+        with open(filename, "w") as f:
+            f.write(pydiffvg.prettify(root))
diff --git a/methods/diffvg_warp/parse_svg.py b/methods/diffvg_warp/parse_svg.py
new file mode 100644
index 0000000000000000000000000000000000000000..de5e372ed09644734a929912a49baf41df250878
--- /dev/null
+++ b/methods/diffvg_warp/parse_svg.py
@@ -0,0 +1,585 @@
+import torch
+import xml.etree.ElementTree as etree
+import numpy as np
+import diffvg
+import os
+import pydiffvg
+import svgpathtools
+import svgpathtools.parser
+import re
+import warnings
+import cssutils
+import logging
+import matplotlib.colors
+cssutils.log.setLevel(logging.ERROR)
+
+def remove_namespaces(s):
+    """
+        {...} ... -> ...
+    """
+    return re.sub('{.*}', '', s)
+
+def parse_style(s, defs):
+    style_dict = {}
+    for e in s.split(';'):
+        key_value = e.split(':')
+        if len(key_value) == 2:
+            key = key_value[0].strip()
+            value = key_value[1].strip()
+            if key == 'fill' or key == 'stroke':
+                # Special case: convert colors into tensor in definitions so
+                # that different shapes can share the same color
+                value = parse_color(value, defs)
+            style_dict[key] = value
+    return style_dict
+
+def parse_hex(s):
+    """
+        Hex to tuple
+    """
+    s = s.lstrip('#')
+    if len(s) == 3:
+        s = s[0] + s[0] + s[1] + s[1] + s[2] + s[2]
+    rgb = tuple(int(s[i:i+2], 16) for i in (0, 2, 4))
+    # sRGB to RGB
+    # return torch.pow(torch.tensor([rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0]), 2.2)
+    return torch.pow(torch.tensor([rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0]), 1.0)
+
+def parse_int(s):
+    """
+        trim alphabets
+    """
+    return int(float(''.join(i for i in s if (not i.isalpha()))))
+
+def parse_color(s, defs):
+    if s is None:
+        return None
+    if isinstance(s, torch.Tensor):
+        return s
+    s = s.lstrip(' ')
+    color = torch.tensor([0.0, 0.0, 0.0, 1.0])
+    if s[0] == '#':
+        color[:3] = parse_hex(s)
+    elif s[:3] == 'url':
+        # url(#id)
+        color = defs[s[4:-1].lstrip('#')]
+    elif s == 'none':
+        color = None
+    elif s[:4] == 'rgb(':
+        rgb = s[4:-1].split(',')
+        color = torch.tensor([int(rgb[0]) / 255.0, int(rgb[1]) / 255.0, int(rgb[2]) / 255.0, 1.0])
+    elif s == 'none':
+        return None
+    else:
+        try :
+            rgba = matplotlib.colors.to_rgba(s)
+            color = torch.tensor(rgba)
+        except ValueError :
+            warnings.warn('Unknown color command ' + s)
+    return color
+
+# https://github.com/mathandy/svgpathtools/blob/7ebc56a831357379ff22216bec07e2c12e8c5bc6/svgpathtools/parser.py
+def _parse_transform_substr(transform_substr):
+    type_str, value_str = transform_substr.split('(')
+    value_str = value_str.replace(',', ' ')
+    values = list(map(float, filter(None, value_str.split(' '))))
+
+    transform = np.identity(3)
+    if 'matrix' in type_str:
+        transform[0:2, 0:3] = np.array([values[0:6:2], values[1:6:2]])
+    elif 'translate' in transform_substr:
+        transform[0, 2] = values[0]
+        if len(values) > 1:
+            transform[1, 2] = values[1]
+    elif 'scale' in transform_substr:
+        x_scale = values[0]
+        y_scale = values[1] if (len(values) > 1) else x_scale
+        transform[0, 0] = x_scale
+        transform[1, 1] = y_scale
+    elif 'rotate' in transform_substr:
+        angle = values[0] * np.pi / 180.0
+        if len(values) == 3:
+            offset = values[1:3]
+        else:
+            offset = (0, 0)
+        tf_offset = np.identity(3)
+        tf_offset[0:2, 2:3] = np.array([[offset[0]], [offset[1]]])
+        tf_rotate = np.identity(3)
+        tf_rotate[0:2, 0:2] = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
+        tf_offset_neg = np.identity(3)
+        tf_offset_neg[0:2, 2:3] = np.array([[-offset[0]], [-offset[1]]])
+
+        transform = tf_offset.dot(tf_rotate).dot(tf_offset_neg)
+    elif 'skewX' in transform_substr:
+        transform[0, 1] = np.tan(values[0] * np.pi / 180.0)
+    elif 'skewY' in transform_substr:
+        transform[1, 0] = np.tan(values[0] * np.pi / 180.0)
+    else:
+        # Return an identity matrix if the type of transform is unknown, and warn the user
+        warnings.warn('Unknown SVG transform type: {0}'.format(type_str))
+    return transform
+
+def parse_transform(transform_str):
+    """
+        Converts a valid SVG transformation string into a 3x3 matrix.
+        If the string is empty or null, this returns a 3x3 identity matrix
+    """
+    if not transform_str:
+        return np.identity(3)
+    elif not isinstance(transform_str, str):
+        raise TypeError('Must provide a string to parse')
+
+    total_transform = np.identity(3)
+    transform_substrs = transform_str.split(')')[:-1]  # Skip the last element, because it should be empty
+    for substr in transform_substrs:
+        total_transform = total_transform.dot(_parse_transform_substr(substr))
+
+    return torch.from_numpy(total_transform).type(torch.float32)
+
+def parse_linear_gradient(node, transform, defs):
+    begin = torch.tensor([0.0, 0.0])
+    end = torch.tensor([0.0, 0.0])
+    offsets = []
+    stop_colors = []
+    # Inherit from parent
+    for key in node.attrib:
+        if remove_namespaces(key) == 'href':
+            value = node.attrib[key]
+            parent = defs[value.lstrip('#')]
+            begin = parent.begin
+            end = parent.end
+            offsets = parent.offsets
+            stop_colors = parent.stop_colors
+
+    for attrib in node.attrib:
+        attrib = remove_namespaces(attrib)
+        if attrib == 'x1':
+            begin[0] = float(node.attrib['x1'])
+        elif attrib == 'y1':
+            begin[1] = float(node.attrib['y1'])
+        elif attrib == 'x2':
+            end[0] = float(node.attrib['x2'])
+        elif attrib == 'y2':
+            end[1] = float(node.attrib['y2'])
+        elif attrib == 'gradientTransform':
+            transform = transform @ parse_transform(node.attrib['gradientTransform'])
+
+    begin = transform @ torch.cat((begin, torch.ones([1])))
+    begin = begin / begin[2]
+    begin = begin[:2]
+    end = transform @ torch.cat((end, torch.ones([1])))
+    end = end / end[2]
+    end = end[:2]
+
+    for child in node:
+        tag = remove_namespaces(child.tag)
+        if tag == 'stop':
+            offset = float(child.attrib['offset'])
+            color = [0.0, 0.0, 0.0, 1.0]
+            if 'stop-color' in child.attrib:
+                c = parse_color(child.attrib['stop-color'], defs)
+                color[:3] = [c[0], c[1], c[2]]
+            if 'stop-opacity' in child.attrib:
+                color[3] = float(child.attrib['stop-opacity'])
+            if 'style' in child.attrib:
+                style = parse_style(child.attrib['style'], defs)
+                if 'stop-color' in style:
+                    c = parse_color(style['stop-color'], defs)
+                    color[:3] = [c[0], c[1], c[2]]
+                if 'stop-opacity' in style:
+                    color[3] = float(style['stop-opacity'])
+            offsets.append(offset)
+            stop_colors.append(color)
+    if isinstance(offsets, list):
+        offsets = torch.tensor(offsets)
+    if isinstance(stop_colors, list):
+        stop_colors = torch.tensor(stop_colors)
+
+    return pydiffvg.LinearGradient(begin, end, offsets, stop_colors)
+
+
+def parse_radial_gradient(node, transform, defs):
+    begin = torch.tensor([0.0, 0.0])
+    end = torch.tensor([0.0, 0.0])
+    center = torch.tensor([0.0, 0.0])
+    radius = torch.tensor([0.0, 0.0])
+    offsets = []
+    stop_colors = []
+    # Inherit from parent
+    for key in node.attrib:
+        if remove_namespaces(key) == 'href':
+            value = node.attrib[key]
+            parent = defs[value.lstrip('#')]
+            begin = parent.begin
+            end = parent.end
+            offsets = parent.offsets
+            stop_colors = parent.stop_colors
+
+    for attrib in node.attrib:
+        attrib = remove_namespaces(attrib)
+        if attrib == 'cx':
+            center[0] = float(node.attrib['cx'])
+        elif attrib == 'cy':
+            center[1] = float(node.attrib['cy'])
+        elif attrib == 'fx':
+            radius[0] = float(node.attrib['fx'])
+        elif attrib == 'fy':
+            radius[1] = float(node.attrib['fy'])
+        elif attrib == 'fr':
+            radius[0] = float(node.attrib['fr'])
+            radius[1] = float(node.attrib['fr'])
+        elif attrib == 'gradientTransform':
+            transform = transform @ parse_transform(node.attrib['gradientTransform'])
+
+    # TODO: this is incorrect
+    center = transform @ torch.cat((center, torch.ones([1])))
+    center = center / center[2]
+    center = center[:2]
+
+    for child in node:
+        tag = remove_namespaces(child.tag)
+        if tag == 'stop':
+            offset = float(child.attrib['offset'])
+            color = [0.0, 0.0, 0.0, 1.0]
+            if 'stop-color' in child.attrib:
+                c = parse_color(child.attrib['stop-color'], defs)
+                color[:3] = [c[0], c[1], c[2]]
+            if 'stop-opacity' in child.attrib:
+                color[3] = float(child.attrib['stop-opacity'])
+            if 'style' in child.attrib:
+                style = parse_style(child.attrib['style'], defs)
+                if 'stop-color' in style:
+                    c = parse_color(style['stop-color'], defs)
+                    color[:3] = [c[0], c[1], c[2]]
+                if 'stop-opacity' in style:
+                    color[3] = float(style['stop-opacity'])
+            offsets.append(offset)
+            stop_colors.append(color)
+    if isinstance(offsets, list):
+        offsets = torch.tensor(offsets)
+    if isinstance(stop_colors, list):
+        stop_colors = torch.tensor(stop_colors)
+
+    return pydiffvg.RadialGradient(begin, end, offsets, stop_colors)
+
+def parse_stylesheet(node, transform, defs):
+    # collect CSS classes
+    sheet = cssutils.parseString(node.text)
+    for rule in sheet:
+        if hasattr(rule, 'selectorText') and hasattr(rule, 'style'):
+            name = rule.selectorText
+            if len(name) >= 2 and name[0] == '.':
+                defs[name[1:]] = parse_style(rule.style.getCssText(), defs)
+    return defs
+
+def parse_defs(node, transform, defs):
+    for child in node:
+        tag = remove_namespaces(child.tag)
+        if tag == 'linearGradient':
+            if 'id' in child.attrib:
+                defs[child.attrib['id']] = parse_linear_gradient(child, transform, defs)
+        elif tag == 'radialGradient':
+            if 'id' in child.attrib:
+                defs[child.attrib['id']] = parse_radial_gradient(child, transform, defs)
+        elif tag == 'style':
+            defs = parse_stylesheet(child, transform, defs)
+    return defs
+
+def parse_common_attrib(node, transform, fill_color, defs):
+    attribs = {}
+    if 'class' in node.attrib:
+        attribs.update(defs[node.attrib['class']])
+    attribs.update(node.attrib)
+
+    name = ''
+    if 'id' in node.attrib:
+        name = node.attrib['id']
+
+    stroke_color = None
+    stroke_width = torch.tensor(0.5)
+    use_even_odd_rule = False
+
+    new_transform = transform
+    if 'transform' in attribs:
+        new_transform = transform @ parse_transform(attribs['transform'])
+    if 'fill' in attribs:
+        fill_color = parse_color(attribs['fill'], defs)
+    fill_opacity = 1.0
+    if 'fill-opacity' in attribs:
+        fill_opacity *= float(attribs['fill-opacity'])
+    if 'opacity' in attribs:
+        fill_opacity *= float(attribs['opacity'])
+    # Ignore opacity if the color is a gradient
+    if isinstance(fill_color, torch.Tensor):
+        fill_color[3] = fill_opacity
+
+    if 'fill-rule' in attribs:
+        if attribs['fill-rule'] == "evenodd":
+            use_even_odd_rule = True
+        elif attribs['fill-rule'] == "nonzero":
+            use_even_odd_rule = False
+        else:
+            warnings.warn('Unknown fill-rule: {}'.format(attribs['fill-rule']))
+
+    if 'stroke' in attribs:
+        stroke_color = parse_color(attribs['stroke'], defs)
+        if 'stroke-opacity' in attribs:
+            stroke_color[3] = float(attribs['stroke-opacity'])
+
+    if 'stroke-width' in attribs:
+        stroke_width = attribs['stroke-width']
+        if stroke_width[-2:] == 'px':
+            stroke_width = stroke_width[:-2]
+        stroke_width = torch.tensor(float(stroke_width) / 2.0)
+
+    if 'style' in attribs:
+        style = parse_style(attribs['style'], defs)
+        if 'fill' in style:
+            fill_color = parse_color(style['fill'], defs)
+        fill_opacity = 1.0
+        if 'fill-opacity' in style:
+            fill_opacity *= float(style['fill-opacity'])
+        if 'opacity' in style:
+            fill_opacity *= float(style['opacity'])
+        if 'fill-rule' in style:
+            if style['fill-rule'] == "evenodd":
+                use_even_odd_rule = True
+            elif style['fill-rule'] == "nonzero":
+                use_even_odd_rule = False
+            else:
+                warnings.warn('Unknown fill-rule: {}'.format(style['fill-rule']))
+        # Ignore opacity if the color is a gradient
+        if isinstance(fill_color, torch.Tensor):
+            fill_color[3] = fill_opacity
+        if 'stroke' in style:
+            if style['stroke'] != 'none':
+                stroke_color = parse_color(style['stroke'], defs)
+                # Ignore opacity if the color is a gradient
+                if isinstance(stroke_color, torch.Tensor):
+                    if 'stroke-opacity' in style:
+                        stroke_color[3] = float(style['stroke-opacity'])
+                    if 'opacity' in style:
+                        stroke_color[3] *= float(style['opacity'])
+                if 'stroke-width' in style:
+                    stroke_width = style['stroke-width']
+                    if stroke_width[-2:] == 'px':
+                        stroke_width = stroke_width[:-2]
+                    stroke_width = torch.tensor(float(stroke_width) / 2.0)
+
+        if isinstance(fill_color, pydiffvg.LinearGradient):
+            fill_color.begin = new_transform @ torch.cat((fill_color.begin, torch.ones([1])))
+            fill_color.begin = fill_color.begin / fill_color.begin[2]
+            fill_color.begin = fill_color.begin[:2]
+            fill_color.end = new_transform @ torch.cat((fill_color.end, torch.ones([1])))
+            fill_color.end = fill_color.end / fill_color.end[2]
+            fill_color.end = fill_color.end[:2]
+        if isinstance(stroke_color, pydiffvg.LinearGradient):
+            stroke_color.begin = new_transform @ torch.cat((stroke_color.begin, torch.ones([1])))
+            stroke_color.begin = stroke_color.begin / stroke_color.begin[2]
+            stroke_color.begin = stroke_color.begin[:2]
+            stroke_color.end = new_transform @ torch.cat((stroke_color.end, torch.ones([1])))
+            stroke_color.end = stroke_color.end / stroke_color.end[2]
+            stroke_color.end = stroke_color.end[:2]
+        if 'filter' in style:
+            print('*** WARNING ***: Ignoring filter for path with id "{}"'.format(name))
+
+    return new_transform, fill_color, stroke_color, stroke_width, use_even_odd_rule
+
+def is_shape(tag):
+    return tag == 'path' or tag == 'polygon' or tag == 'line' or tag == 'circle' or tag == 'rect'
+
+def parse_shape(node, transform, fill_color, shapes, shape_groups, defs):
+    tag = remove_namespaces(node.tag)
+    new_transform, new_fill_color, stroke_color, stroke_width, use_even_odd_rule = \
+        parse_common_attrib(node, transform, fill_color, defs)
+    if tag == 'path':
+        d = node.attrib['d']
+        name = ''
+        if 'id' in node.attrib:
+            name = node.attrib['id']
+        force_closing = new_fill_color is not None
+        paths = pydiffvg.from_svg_path(d, new_transform, force_closing)
+        for idx, path in enumerate(paths):
+            assert(path.points.shape[1] == 2)
+            path.stroke_width = stroke_width
+            path.source_id = name
+            path.id = "{}-{}".format(name,idx) if len(paths)>1 else name
+        prev_shapes_size = len(shapes)
+        shapes = shapes + paths
+        shape_ids = torch.tensor(list(range(prev_shapes_size, len(shapes))))
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            id = name))
+    elif tag == 'polygon':
+        name = ''
+        if 'id' in node.attrib:
+            name = node.attrib['id']
+        force_closing = new_fill_color is not None
+        pts = node.attrib['points'].strip()
+        pts = pts.split(' ')
+        # import ipdb; ipdb.set_trace()
+        pts = [[float(y) for y in re.split(',| ', x)] for x in pts if x]
+        pts = torch.tensor(pts, dtype=torch.float32).view(-1, 2)
+        polygon = pydiffvg.Polygon(pts, force_closing)
+        polygon.stroke_width = stroke_width
+        shape_ids = torch.tensor([len(shapes)])
+        shapes.append(polygon)
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            shape_to_canvas = new_transform,
+            id = name))
+    elif tag == 'line':
+        x1 = float(node.attrib['x1'])
+        y1 = float(node.attrib['y1'])
+        x2 = float(node.attrib['x2'])
+        y2 = float(node.attrib['y2'])
+        p1 = torch.tensor([x1, y1])
+        p2 = torch.tensor([x2, y2])
+        points = torch.stack((p1, p2))
+        line = pydiffvg.Polygon(points, False)
+        line.stroke_width = stroke_width
+        shape_ids = torch.tensor([len(shapes)])
+        shapes.append(line)
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            shape_to_canvas = new_transform))
+    elif tag == 'circle':
+        radius = float(node.attrib['r'])
+        cx = float(node.attrib['cx'])
+        cy = float(node.attrib['cy'])
+        name = ''
+        if 'id' in node.attrib:
+            name = node.attrib['id']
+        center = torch.tensor([cx, cy])
+        circle = pydiffvg.Circle(radius = torch.tensor(radius),
+                                 center = center)
+        circle.stroke_width = stroke_width
+        shape_ids = torch.tensor([len(shapes)])
+        shapes.append(circle)
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            shape_to_canvas = new_transform))
+    elif tag == 'ellipse':
+        rx = float(node.attrib['rx'])
+        ry = float(node.attrib['ry'])
+        cx = float(node.attrib['cx'])
+        cy = float(node.attrib['cy'])
+        name = ''
+        if 'id' in node.attrib:
+            name = node.attrib['id']
+        center = torch.tensor([cx, cy])
+        circle = pydiffvg.Circle(radius = torch.tensor(radius),
+                                 center = center)
+        circle.stroke_width = stroke_width
+        shape_ids = torch.tensor([len(shapes)])
+        shapes.append(circle)
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            shape_to_canvas = new_transform))
+    elif tag == 'rect':
+        x = 0.0
+        y = 0.0
+        if x in node.attrib:
+            x = float(node.attrib['x'])
+        if y in node.attrib:
+            y = float(node.attrib['y'])
+        w = float(node.attrib['width'])
+        h = float(node.attrib['height'])
+        p_min = torch.tensor([x, y])
+        p_max = torch.tensor([x + w, x + h])
+        rect = pydiffvg.Rect(p_min = p_min, p_max = p_max)
+        rect.stroke_width = stroke_width
+        shape_ids = torch.tensor([len(shapes)])
+        shapes.append(rect)
+        shape_groups.append(pydiffvg.ShapeGroup(\
+            shape_ids = shape_ids,
+            fill_color = new_fill_color,
+            stroke_color = stroke_color,
+            use_even_odd_rule = use_even_odd_rule,
+            shape_to_canvas = new_transform))
+    return shapes, shape_groups
+
+def parse_group(node, transform, fill_color, shapes, shape_groups, defs):
+    if 'transform' in node.attrib:
+        transform = transform @ parse_transform(node.attrib['transform'])
+    if 'fill' in node.attrib:
+        fill_color = parse_color(node.attrib['fill'], defs)
+    for child in node:
+        tag = remove_namespaces(child.tag)
+        if is_shape(tag):
+            shapes, shape_groups = parse_shape(\
+                child, transform, fill_color, shapes, shape_groups, defs)
+        elif tag == 'g':
+            shapes, shape_groups = parse_group(\
+                child, transform, fill_color, shapes, shape_groups, defs)
+    return shapes, shape_groups
+
+def parse_scene(node):
+    canvas_width = -1
+    canvas_height = -1
+    defs = {}
+    shapes = []
+    shape_groups = []
+    fill_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
+    transform = torch.eye(3)
+    if 'viewBox' in node.attrib:
+        view_box_array = node.attrib['viewBox'].split()
+        canvas_width = parse_int(view_box_array[2])
+        canvas_height = parse_int(view_box_array[3])
+    else:
+        if 'width' in node.attrib:
+            canvas_width = parse_int(node.attrib['width'])
+        else:
+            print('Warning: Can\'t find canvas width.')
+        if 'height' in node.attrib:
+            canvas_height = parse_int(node.attrib['height'])
+        else:
+            print('Warning: Can\'t find canvas height.')
+    for child in node:
+        tag = remove_namespaces(child.tag)
+        if tag == 'defs':
+            defs = parse_defs(child, transform, defs)
+        elif tag == 'style':
+            defs = parse_stylesheet(child, transform, defs)
+        elif tag == 'linearGradient':
+            if 'id' in child.attrib:
+                defs[child.attrib['id']] = parse_linear_gradient(child, transform, defs)
+        elif tag == 'radialGradient':
+            if 'id' in child.attrib:
+                defs[child.attrib['id']] = parse_radial_gradient(child, transform, defs)
+        elif is_shape(tag):
+            shapes, shape_groups = parse_shape(\
+                child, transform, fill_color, shapes, shape_groups, defs)
+        elif tag == 'g':
+            shapes, shape_groups = parse_group(\
+                child, transform, fill_color, shapes, shape_groups, defs)
+    return canvas_width, canvas_height, shapes, shape_groups
+
+def svg_to_scene(filename):
+    """
+        Load from a SVG file and convert to PyTorch tensors.
+    """
+
+    tree = etree.parse(filename)
+    root = tree.getroot()
+    cwd = os.getcwd()
+    if (os.path.dirname(filename) != ''):
+        os.chdir(os.path.dirname(filename))
+    ret = parse_scene(root)
+    os.chdir(cwd)
+    return ret
diff --git a/methods/painter/__init__.py b/methods/painter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/methods/painter/__init__.py
@@ -0,0 +1 @@
+
diff --git a/methods/painter/diffsketchedit/ASDS_SDXL_pipeline.py b/methods/painter/diffsketchedit/ASDS_SDXL_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..56cbfe321c4171d3b5589ab9c0a2998b8cdb9df4
--- /dev/null
+++ b/methods/painter/diffsketchedit/ASDS_SDXL_pipeline.py
@@ -0,0 +1,668 @@
+import PIL
+from PIL import Image
+from typing import Callable, List, Optional, Union, Tuple, AnyStr
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torchvision import transforms
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
+
+from methods.token2attn.attn_control import AttentionStore
+from methods.token2attn.ptp_utils import text_under_image, view_images
+
+
+class Token2AttnMixinASDSSDXLPipeline(StableDiffusionXLPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion XL.
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            prompt: Union[str, List[str]],
+            prompt_2: Optional[Union[str, List[str]]] = None,
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            controller: AttentionStore = None,  # feed attention_store as control of ptp
+            num_inference_steps: int = 50,
+            denoising_end: Optional[float] = None,
+            guidance_scale: float = 5.0,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+            callback_steps: Optional[int] = 1,
+            original_size: Optional[Tuple[int, int]] = None,
+            crops_coords_top_left: Tuple[int, int] = (0, 0),
+            target_size: Optional[Tuple[int, int]] = None,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        self.register_attention_control(controller)  # add attention controller
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, prompt_2, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        (
+            text_embeddings,
+            negative_text_embeddings,
+            pooled_text_embeddings,
+            negative_pooled_text_embeddings,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        try:
+            num_channels_latents = self.unet.config.in_channels
+        except Exception or Warning:
+            num_channels_latents = self.unet.in_channels
+
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. inherit TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeddings = pooled_text_embeddings
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype
+        )
+
+        if do_classifier_free_guidance:
+            text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
+            add_text_embeddings = torch.cat([negative_pooled_text_embeddings, add_text_embeddings], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        text_embeddings = text_embeddings.to(device)
+        add_text_embeddings = add_text_embeddings.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+
+        # 8.1 Apply denoising_end
+        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeddings, "time_ids": add_time_ids}
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=text_embeddings,
+                    added_cond_kwargs=added_cond_kwargs
+                ).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # step callback
+                latents = controller.step_callback(latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # 9. Post-processing
+
+        # The decode_latents method is deprecated and has been removed in sdxl
+        # image = self.decode_latents(latents)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+            return StableDiffusionXLPipelineOutput(images=image)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
+
+    def encode2latents(self,
+                       image,
+                       batch_size,
+                       num_images_per_prompt,
+                       dtype,
+                       device,
+                       generator=None):
+        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+            raise ValueError(
+                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+            )
+
+        # Offload text encoder if `enable_model_cpu_offload` was enabled
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.text_encoder_2.to("cpu")
+            torch.cuda.empty_cache()
+
+        image = image.to(device=device, dtype=dtype)
+
+        batch_size = batch_size * num_images_per_prompt
+
+        if image.shape[1] == 4:
+            init_latents = image
+        else:
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            if self.vae.config.force_upcast:
+                image = image.float()
+                self.vae.to(dtype=torch.float32)
+
+            if isinstance(generator, list) and len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
+
+            elif isinstance(generator, list):
+                init_latents = [
+                    self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                ]
+                init_latents = torch.cat(init_latents, dim=0)
+            else:
+                init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
+            if self.vae.config.force_upcast:
+                self.vae.to(dtype)
+
+            init_latents = init_latents.to(dtype)
+            init_latents = self.vae.config.scaling_factor * init_latents
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
+        latents = init_latents
+
+        return latents
+
+    @staticmethod
+    def S_aug(sketch: torch.Tensor,
+              im_res: int = 1024,
+              augments: str = "affine_contrast"):
+        # init augmentations
+        augment_list = []
+        if "affine" in augments:
+            augment_list.append(
+                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
+            )
+            augment_list.append(
+                transforms.RandomResizedCrop(im_res, scale=(0.8, 0.8), ratio=(1.0, 1.0))
+            )
+        if "contrast" in augments:
+            # 2: increases the sharpness by a factor of 2.
+            augment_list.append(
+                transforms.RandomAdjustSharpness(sharpness_factor=2)
+            )
+        augment_compose = transforms.Compose(augment_list)
+
+        return augment_compose(sketch)
+
+    def score_distillation_sampling(self,
+                                    pred_rgb: torch.Tensor,
+                                    crop_size: int,
+                                    augments: str,
+                                    prompt: Union[List, str],
+                                    prompt_2: Optional[Union[List, str]] = None,
+                                    height: Optional[int] = None,
+                                    width: Optional[int] = None,
+                                    negative_prompt: Union[List, str] = None,
+                                    negative_prompt_2: Optional[Union[List, str]] = None,
+                                    guidance_scale: float = 100,
+                                    as_latent: bool = False,
+                                    grad_scale: float = 1,
+                                    t_range: Union[List[float], Tuple[float]] = (0.05, 0.95),
+                                    original_size: Optional[Tuple[int, int]] = None,
+                                    crops_coords_top_left: Tuple[int, int] = (0, 0),
+                                    target_size: Optional[Tuple[int, int]] = None):
+
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+
+        num_train_timesteps = self.scheduler.config.num_train_timesteps
+        min_step = int(num_train_timesteps * t_range[0])
+        max_step = int(num_train_timesteps * t_range[1])
+        alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
+
+        num_images_per_prompt = 1  # the number of images to generate per prompt
+
+        #  Encode input prompt
+        do_classifier_free_guidance = guidance_scale > 1.0
+        (
+            text_embeddings,
+            negative_text_embeddings,
+            pooled_text_embeddings,
+            negative_pooled_text_embeddings,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=self.device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+        )
+
+        # sketch augmentation
+        pred_rgb_a = self.S_aug(pred_rgb, crop_size, augments)
+
+        # interp to 512x512 to be fed into vae.
+        if as_latent:
+            latents = F.interpolate(pred_rgb_a, (128, 128), mode='bilinear', align_corners=False) * 2 - 1
+        else:
+            # encode image into latents via vae, requires grad!
+            latents = self.encode2latents(
+                pred_rgb_a,
+                batch_size,
+                num_images_per_prompt,
+                text_embeddings.dtype,
+                self.device
+            )
+
+        # timestep ~ U(0.05, 0.95) to avoid very high/low noise level
+        t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
+
+        # 7. Prepare added time ids & embeddings
+        add_text_embeddings = pooled_text_embeddings
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype
+        )
+
+        if do_classifier_free_guidance:
+            text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
+            add_text_embeddings = torch.cat([negative_pooled_text_embeddings, add_text_embeddings], dim=0)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        text_embeddings = text_embeddings.to(self.device)
+        add_text_embeddings = add_text_embeddings.to(self.device)
+        add_time_ids = add_time_ids.to(self.device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # predict the noise residual with unet, stop gradient
+        with torch.no_grad():
+            # add noise
+            noise = torch.randn_like(latents)
+            latents_noisy = self.scheduler.add_noise(latents, noise, t)
+            # pred noise
+            latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
+            # predict the noise residual
+            added_cond_kwargs = {"text_embeds": add_text_embeddings, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                latent_model_input,
+                t,
+                encoder_hidden_states=text_embeddings,
+                added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+        # perform guidance (high scale from paper!)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - alphas[t])
+        grad = grad_scale * w * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # since we omitted an item in grad, we need to use the custom function to specify the gradient
+        loss = SpecifyGradient.apply(latents, grad)
+
+        return loss, grad.mean()
+
+    def register_attention_control(self, controller):
+        attn_procs = {}
+        cross_att_count = 0
+        for name in self.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.unet.config.block_out_channels[-1]
+                place_in_unet = "mid"
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                place_in_unet = "up"
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.unet.config.block_out_channels[block_id]
+                place_in_unet = "down"
+            else:
+                continue
+            cross_att_count += 1
+            attn_procs[name] = P2PCrossAttnProcessor(
+                controller=controller, place_in_unet=place_in_unet
+            )
+
+        self.unet.set_attn_processor(attn_procs)
+        controller.num_att_layers = cross_att_count
+
+    @staticmethod
+    def aggregate_attention(prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            is_cross: bool,
+                            select: int):
+        if isinstance(prompts, str):
+            prompts = [prompts]
+        assert isinstance(prompts, list)
+
+        out = []
+        attention_maps = attention_store.get_average_attention()
+        num_pixels = res ** 2
+        for location in from_where:
+            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+                if item.shape[1] == num_pixels:
+                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                    out.append(cross_maps)
+        out = torch.cat(out, dim=0)
+        out = out.sum(0) / out.shape[0]
+        return out.cpu()
+
+    def get_cross_attention(self,
+                            prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            select: int = 0,
+                            save_path=None):
+        tokens = self.tokenizer.encode(prompts[select])
+        decoder = self.tokenizer.decode
+        # shape: [res ** 2, res ** 2, seq_len]
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
+
+        images = []
+        for i in range(len(tokens)):
+            image = attention_maps[:, :, i]
+            image = 255 * image / image.max()
+            image = image.unsqueeze(-1).expand(*image.shape, 3)
+            image = image.numpy().astype(np.uint8)
+            image = np.array(Image.fromarray(image).resize((256, 256)))
+            image = text_under_image(image, decoder(int(tokens[i])))
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, save_image=True, fp=save_path)
+
+        return attention_maps, tokens
+
+    def get_self_attention_comp(self,
+                                prompts,
+                                attention_store: AttentionStore,
+                                res: int,
+                                from_where: List[str],
+                                img_size: int = 224,
+                                max_com=10,
+                                select: int = 0,
+                                save_path: AnyStr = None):
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
+        attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
+        # shape: [res ** 2, res ** 2]
+        u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
+        print(f"self-attention_maps: {attention_maps.shape}, "
+              f"u: {u.shape}, "
+              f"s: {s.shape}, "
+              f"vh: {vh.shape}")
+
+        images = []
+        vh_returns = []
+        for i in range(max_com):
+            image = vh[i].reshape(res, res)
+            image = (image - image.min()) / (image.max() - image.min())
+            image = 255 * image
+
+            ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
+            vh_returns.append(np.array(ret_))
+
+            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
+            image = Image.fromarray(image).resize((256, 256))
+            image = np.array(image)
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
+                    save_image=True, fp=save_path / "self-attn-vh.png")
+
+        return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)
+
+
+class P2PCrossAttnProcessor:
+
+    def __init__(self, controller, place_in_unet):
+        super().__init__()
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        is_cross = encoder_hidden_states is not None
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+        # one line change
+        self.controller(attention_probs, is_cross, self.place_in_unet)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SpecifyGradient(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input_tensor, gt_grad):
+        ctx.save_for_backward(gt_grad)
+        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_scale):
+        gt_grad, = ctx.saved_tensors
+        gt_grad = gt_grad * grad_scale
+        return gt_grad, None
diff --git a/methods/painter/diffsketchedit/ASDS_pipeline.py b/methods/painter/diffsketchedit/ASDS_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905078917dc2c52f425406dd1f0ebea61739c2f
--- /dev/null
+++ b/methods/painter/diffsketchedit/ASDS_pipeline.py
@@ -0,0 +1,507 @@
+import PIL
+from PIL import Image
+from typing import Callable, List, Optional, Union, Tuple, AnyStr
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torchvision import transforms
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+
+from methods.token2attn.attn_control import AttentionStore
+from methods.token2attn.ptp_utils import text_under_image, view_images
+
+
+class Token2AttnMixinASDSPipeline(StableDiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+    _optional_components = ["safety_checker", "feature_extractor"]
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            prompt: Union[str, List[str]],
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            controller: AttentionStore = None,  # feed attention_store as control of ptp
+            num_inference_steps: int = 50,
+            guidance_scale: float = 7.5,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+            callback_steps: Optional[int] = 1,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`):
+                The prompt or prompts to guide the image generation.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+
+        self.register_attention_control(controller)  # add attention controller
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, height, width, callback_steps)
+
+        # 2. Define call parameters
+        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_embeddings = self._encode_prompt(
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        try:
+            num_channels_latents = self.unet.config.in_channels
+        except Exception or Warning:
+            num_channels_latents = self.unet.in_channels
+
+        latents = self.prepare_latents2(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            text_embeddings.dtype,
+            torch.device("cpu"),
+            generator,
+            latents,
+        )
+        latents = latents.to(device)
+
+        # 6. Prepare extra step kwargs. inherit TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # step callback
+                latents = controller.step_callback(latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # image = self.decode_latents(latents)
+
+        # 8. Post-processing
+        # 9. Run safety checker
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            # image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+            has_nsfw_concept = None
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        # 10. Convert to output_type
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    def prepare_latents2(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (1, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+        latents = latents.repeat(batch_size, 1, 1, 1)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def encode_(self, images):
+        images = (2 * images - 1).clamp(-1.0, 1.0)  # images: [B, 3, H, W]
+
+        # encode images
+        latents = self.vae.encode(images).latent_dist.sample()
+        latents = self.vae.config.scaling_factor * latents
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+
+        return latents
+
+    @staticmethod
+    def S_aug(sketch: torch.Tensor,
+              crop_size: int = 512,
+              augments: str = "affine_contrast"):
+        # init augmentations
+        augment_list = []
+        if "affine" in augments:
+            augment_list.append(
+                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
+            )
+            augment_list.append(
+                transforms.RandomResizedCrop(crop_size, scale=(0.8, 0.8), ratio=(1.0, 1.0))
+            )
+        if "contrast" in augments:
+            # 2: increases the sharpness by a factor of 2.
+            augment_list.append(
+                transforms.RandomAdjustSharpness(sharpness_factor=2)
+            )
+        augment_compose = transforms.Compose(augment_list)
+
+        return augment_compose(sketch)
+
+    def score_distillation_sampling(self,
+                                    pred_rgb: torch.Tensor,
+                                    crop_size: int,
+                                    augments: str,
+                                    prompt: Union[List, str],
+                                    negative_prompt: Union[List, str] = None,
+                                    guidance_scale: float = 100,
+                                    as_latent: bool = False,
+                                    grad_scale: float = 1,
+                                    t_range: Union[List[float], Tuple[float]] = (0.02, 0.98)):
+        num_train_timesteps = self.scheduler.config.num_train_timesteps
+        min_step = int(num_train_timesteps * t_range[0])
+        max_step = int(num_train_timesteps * t_range[1])
+        alphas = self.scheduler.alphas_cumprod.to(self.device)  # for convenience
+
+        # sketch augmentation
+        pred_rgb_a = self.S_aug(pred_rgb, crop_size, augments)
+
+        # interp to crop_size x crop_size to be fed into vae.
+        if as_latent:
+            latents = F.interpolate(pred_rgb_a, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
+        else:
+            # encode image into latents with vae, requires grad!
+            latents = self.encode_(pred_rgb_a)
+
+        #  Encode input prompt
+        num_images_per_prompt = 1  # the number of images to generate per prompt
+        do_classifier_free_guidance = guidance_scale > 1.0
+        text_embeddings = self._encode_prompt(
+            prompt, self.device, num_images_per_prompt, do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+        )
+
+        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
+        t = torch.randint(min_step, max_step + 1, [1], dtype=torch.long, device=self.device)
+
+        # predict the noise residual with unet, stop gradient
+        with torch.no_grad():
+            # add noise
+            noise = torch.randn_like(latents)
+            latents_noisy = self.scheduler.add_noise(latents, noise, t)
+            # pred noise
+            latent_model_input = torch.cat([latents_noisy] * 2) if do_classifier_free_guidance else latents_noisy
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+        # perform guidance (high scale from paper!)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
+
+        # w(t), sigma_t^2
+        w = (1 - alphas[t])
+        grad = grad_scale * w * (noise_pred - noise)
+        grad = torch.nan_to_num(grad)
+
+        # since we omitted an item in grad, we need to use the custom function to specify the gradient
+        loss = SpecifyGradient.apply(latents, grad)
+
+        return loss, grad.mean()
+
+    def register_attention_control(self, controller):
+        attn_procs = {}
+        cross_att_count = 0
+        for name in self.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.unet.config.block_out_channels[-1]
+                place_in_unet = "mid"
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+                place_in_unet = "up"
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.unet.config.block_out_channels[block_id]
+                place_in_unet = "down"
+            else:
+                continue
+            cross_att_count += 1
+            attn_procs[name] = P2PCrossAttnProcessor(
+                controller=controller, place_in_unet=place_in_unet
+            )
+
+        self.unet.set_attn_processor(attn_procs)
+        controller.num_att_layers = cross_att_count
+
+    @staticmethod
+    def aggregate_attention(prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            is_cross: bool,
+                            select: int):
+        if isinstance(prompts, str):
+            prompts = [prompts]
+        assert isinstance(prompts, list)
+
+        out = []
+        attention_maps = attention_store.get_average_attention()
+        num_pixels = res ** 2
+        for location in from_where:
+            for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+                if item.shape[1] == num_pixels:
+                    cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                    out.append(cross_maps)
+        out = torch.cat(out, dim=0)
+        out = out.sum(0) / out.shape[0]
+        return out.cpu()
+
+    def get_cross_attention(self,
+                            prompts,
+                            attention_store: AttentionStore,
+                            res: int,
+                            from_where: List[str],
+                            select: int = 0,
+                            save_path=None):
+        tokens = self.tokenizer.encode(prompts[select])
+        decoder = self.tokenizer.decode
+        # shape: [res ** 2, res ** 2, seq_len]
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
+
+        images = []
+        for i in range(len(tokens)):
+            image = attention_maps[:, :, i]
+            image = 255 * image / image.max()
+            image = image.unsqueeze(-1).expand(*image.shape, 3)
+            image = image.numpy().astype(np.uint8)
+            image = np.array(Image.fromarray(image).resize((256, 256)))
+            image = text_under_image(image, decoder(int(tokens[i])))
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, save_image=True, fp=save_path)
+
+        return attention_maps, tokens
+
+    def get_cross_attention2(self,
+                             prompts,
+                             attention_store: AttentionStore,
+                             res: int,
+                             from_where: List[str],
+                             select: int = 0):
+        # shape: [res ** 2, res ** 2, seq_len]
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, True, select)
+
+        return attention_maps
+
+    def get_self_attention_comp(self,
+                                prompts,
+                                attention_store: AttentionStore,
+                                res: int,
+                                from_where: List[str],
+                                img_size: int = 224,
+                                max_com=10,
+                                select: int = 0,
+                                save_path: AnyStr = None):
+        attention_maps = self.aggregate_attention(prompts, attention_store, res, from_where, False, select)
+        attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
+        # shape: [res ** 2, res ** 2]
+        u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
+        print(f"self-attention_maps: {attention_maps.shape}, "
+              f"u: {u.shape}, "
+              f"s: {s.shape}, "
+              f"vh: {vh.shape}")
+
+        images = []
+        vh_returns = []
+        for i in range(max_com):
+            image = vh[i].reshape(res, res)
+            image = (image - image.min()) / (image.max() - image.min())
+            image = 255 * image
+
+            ret_ = Image.fromarray(image).resize((img_size, img_size), resample=PIL.Image.Resampling.BILINEAR)
+            vh_returns.append(np.array(ret_))
+
+            image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
+            image = Image.fromarray(image).resize((256, 256))
+            image = np.array(image)
+            images.append(image)
+        image_array = np.stack(images, axis=0)
+        view_images(image_array, num_rows=max_com // 10, offset_ratio=0,
+                    save_image=True, fp=save_path / "self-attn-vh.png")
+
+        return attention_maps, (u, s, vh), np.stack(vh_returns, axis=0)
+
+
+class P2PCrossAttnProcessor:
+
+    def __init__(self, controller, place_in_unet):
+        super().__init__()
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size=batch_size)
+
+        query = attn.to_q(hidden_states)
+
+        is_cross = encoder_hidden_states is not None
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+        # one line change
+        self.controller(attention_probs, is_cross, self.place_in_unet)
+
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class SpecifyGradient(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, input_tensor, gt_grad):
+        ctx.save_for_backward(gt_grad)
+        # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
+        return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_scale):
+        gt_grad, = ctx.saved_tensors
+        gt_grad = gt_grad * grad_scale
+        return gt_grad, None
diff --git a/methods/painter/diffsketchedit/__init__.py b/methods/painter/diffsketchedit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56db90401c3e0a544d8bc3cd21858a61f1dc47df
--- /dev/null
+++ b/methods/painter/diffsketchedit/__init__.py
@@ -0,0 +1,9 @@
+from .painter_params import Painter, SketchPainterOptimizer
+from .ASDS_pipeline import Token2AttnMixinASDSPipeline
+from .ASDS_SDXL_pipeline import Token2AttnMixinASDSSDXLPipeline
+
+__all__ = [
+    'Painter', 'SketchPainterOptimizer',
+    'Token2AttnMixinASDSPipeline',
+    'Token2AttnMixinASDSSDXLPipeline'
+]
diff --git a/methods/painter/diffsketchedit/mask_utils.py b/methods/painter/diffsketchedit/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..427af90bd079df83fadc67f6d7c4d73f66f8f175
--- /dev/null
+++ b/methods/painter/diffsketchedit/mask_utils.py
@@ -0,0 +1,59 @@
+from PIL import Image
+
+import numpy as np
+import torch
+from torchvision import transforms
+from skimage.transform import resize
+
+from .u2net import U2NET
+
+
+def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
+    # input preprocess
+    w, h = pil_im.size[0], pil_im.size[1]
+    im_size = min(w, h)
+    data_transforms = transforms.Compose([
+        transforms.Resize(min(320, im_size), interpolation=transforms.InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
+                             std=(0.26862954, 0.26130258, 0.27577711)),
+    ])
+    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)
+
+    # load U^2 Net model
+    net = U2NET(in_ch=3, out_ch=1)
+    net.load_state_dict(torch.load(u2net_path))
+    net.to(device)
+    net.eval()
+
+    # get mask
+    with torch.no_grad():
+        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())
+    pred = d1[:, 0, :, :]
+    pred = (pred - pred.min()) / (pred.max() - pred.min())
+    predict = pred
+    predict[predict < 0.5] = 0
+    predict[predict >= 0.5] = 1
+    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
+    mask = mask.cpu().numpy()
+    mask = resize(mask, (h, w), anti_aliasing=False)
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+
+    # predict_np = predict.clone().cpu().data.numpy()
+    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
+    save_path_ = output_dir / "mask.png"
+    im.save(save_path_)
+
+    im_np = np.array(pil_im)
+    im_np = im_np / im_np.max()
+    im_np = mask * im_np
+    im_np[mask == 0] = 1
+    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
+    im_final = Image.fromarray(im_final)
+
+    # free u2net
+    del net
+    torch.cuda.empty_cache()
+
+    return im_final, predict
diff --git a/methods/painter/diffsketchedit/painter_params.py b/methods/painter/diffsketchedit/painter_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..125e085a1df989692b83dd3b0cd830754bb964d0
--- /dev/null
+++ b/methods/painter/diffsketchedit/painter_params.py
@@ -0,0 +1,408 @@
+import os
+import random
+import pathlib
+
+import numpy as np
+from PIL import Image
+import pydiffvg
+import torch
+import torch.nn as nn
+
+from libs.modules.edge_map.DoG import XDoG
+from methods.diffvg_warp.parse_svg import svg_to_scene
+
+
+class Painter(nn.Module):
+
+    def __init__(
+            self,
+            args,
+            num_strokes=4,
+            num_segments=4,
+            imsize=224,
+            device=None,
+            target_im=None,
+            attention_map=None,
+            mask=None,
+            results_base=None,
+    ):
+        super(Painter, self).__init__()
+
+        self.args = args
+        self.device = device
+
+        self.num_paths = num_strokes
+        self.num_segments = num_segments
+        self.width = args.width
+        self.max_width = args.max_width
+        self.optim_width = args.optim_width
+        self.control_points_per_seg = args.control_points_per_seg
+        self.optim_rgba = args.optim_rgba
+        self.optim_alpha = args.optim_opacity
+        self.num_stages = args.num_stages
+        self.softmax_temp = args.softmax_temp
+
+        self.shapes = []
+        self.shape_groups = []
+        self.num_control_points = 0
+        self.canvas_width, self.canvas_height = imsize, imsize
+        self.points_vars = []
+        self.points_vars_gt = []
+        self.stroke_width_vars = []
+        self.color_vars = []
+        self.color_vars_threshold = args.color_vars_threshold
+
+        self.results_base = results_base[:results_base.find('stage=' + str(self.args.run_stage))] + 'stage=0'
+        self.path_svg = args.path_svg
+        self.strokes_per_stage = self.num_paths
+        self.optimize_flag = []
+
+        # attention related for strokes initialisation
+        self.attention_init = args.attention_init
+        self.xdog_intersec = args.xdog_intersec
+
+        self.image2clip_input = target_im
+        self.mask = mask
+        self.attention_map = attention_map if self.attention_init else None
+
+        self.thresh = self.set_attention_threshold_map() if self.attention_init else None
+        self.strokes_counter = 0  # counts the number of calls to "get_path"
+
+    def init_image(self, stage=0):
+        if stage > 0:
+            # Noting: if multi stages training than add new strokes on existing ones
+            # don't optimize on previous strokes
+            self.optimize_flag = [False for i in range(len(self.shapes))]
+            for i in range(self.strokes_per_stage):
+                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
+                path = self.get_path()
+                self.shapes.append(path)
+                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
+                                                 fill_color=None,
+                                                 stroke_color=stroke_color)
+                self.shape_groups.append(path_group)
+                self.optimize_flag.append(True)
+        else:
+            num_paths_exists = 0
+            if self.args.run_stage > 0:
+                assert self.path_svg != "" and self.path_svg is not None and pathlib.Path(self.path_svg).exists(), self.path_svg
+                print(f"-> init svg from `{self.path_svg}` ...")
+
+                self.canvas_width, self.canvas_height, self.shapes, self.shape_groups = self.load_svg(self.path_svg)
+                # if you want to add more strokes to existing ones and optimize on all of them
+                num_paths_exists = len(self.shapes)
+            else:
+                assert self.path_svg == "" or self.path_svg is None
+
+            for i in range(num_paths_exists, self.num_paths):
+                stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])
+                path = self.get_path()
+                self.shapes.append(path)
+                path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([len(self.shapes) - 1]),
+                                                 fill_color=None,
+                                                 stroke_color=stroke_color)
+                self.shape_groups.append(path_group)
+
+            if self.args.run_stage > 0 and self.args.vector_local_edit:
+                self.optimize_flag = self.set_local_edit_strokes()
+            else:
+                self.optimize_flag = [True for i in range(len(self.shapes))]
+
+        img = self.render_warp()
+        img = img[:, :, 3:4] * img[:, :, :3] + \
+              torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - img[:, :, 3:4])
+        img = img[:, :, :3]
+        img = img.unsqueeze(0)  # convert img from HWC to NCHW
+        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
+
+        return img
+
+    def set_local_edit_strokes(self):
+        local_edit_mask_img_path = os.path.join(self.results_base,
+                                                'cross_attn_local_edit_' + str(self.args.vector_local_edit_attn_res) + "-" + str(self.args.run_stage) + '.png')
+        local_edit_mask = Image.open(local_edit_mask_img_path).convert('RGB')
+        local_edit_mask = np.array(local_edit_mask, dtype=np.float32)[:, :, 0]
+        local_edit_mask /= 255.0  # (224, 224), [0-BG, 1-FG]
+
+        optimize_flag = [False for _ in range(len(self.shapes))]
+
+        stroke_imgs = self.render_warp2()
+        stroke_imgs = torch.stack(stroke_imgs, dim=0)  # (N_strokes, H, W, 4)
+
+        opacity = stroke_imgs[:, :, :, 3:4]  # (N_strokes, H, W, 1)
+        stroke_imgs = opacity * stroke_imgs[:, :, :, :3] + \
+              (1 - opacity) * torch.ones(stroke_imgs.shape[0], stroke_imgs.shape[1], stroke_imgs.shape[2], 3, device=self.device)
+        stroke_imgs = stroke_imgs.cpu().data.numpy()[:, :, :, 0]  # (N_strokes, H, W), [0.0, 1.0]
+
+        for si in range(len(stroke_imgs)):
+            stroke_img = 1. - stroke_imgs[si]  # (H, W), [0.0-BG, 1.0-stroke]
+            stroke_mask = stroke_img > 0  # (H, W), [0.0-BG, 1.0-stroke]
+            union = stroke_mask * local_edit_mask
+
+            ## version-1
+            # valid = np.sum(union) > 0
+
+            ## version-2
+            valid = False
+            if np.sum(stroke_mask) > 0 and (np.sum(union) / np.sum(stroke_mask)) >= 0.5:
+                valid = True
+
+            if valid:
+                optimize_flag[si] = True
+
+        return optimize_flag
+
+    def get_image(self):
+        img = self.render_warp()
+
+        opacity = img[:, :, 3:4]
+        img = opacity * img[:, :, :3] + torch.ones(img.shape[0], img.shape[1], 3, device=self.device) * (1 - opacity)
+        img = img[:, :, :3]
+        img = img.unsqueeze(0)  # convert img from HWC to NCHW
+        img = img.permute(0, 3, 1, 2).to(self.device)  # NHWC -> NCHW
+        return img
+
+    def get_path(self):
+        self.num_control_points = torch.zeros(self.num_segments, dtype=torch.int32) + (self.control_points_per_seg - 2)
+        points = []
+        p0 = self.inds_normalised[self.strokes_counter] if self.attention_init else (random.random(), random.random())
+        points.append(p0)
+
+        for j in range(self.num_segments):
+            radius = 0.05
+            for k in range(self.control_points_per_seg - 1):
+                p1 = (p0[0] + radius * (random.random() - 0.5), p0[1] + radius * (random.random() - 0.5))
+                points.append(p1)
+                p0 = p1
+        points = torch.tensor(points).to(self.device)
+        points[:, 0] *= self.canvas_width
+        points[:, 1] *= self.canvas_height
+
+        path = pydiffvg.Path(num_control_points=self.num_control_points,
+                             points=points,
+                             stroke_width=torch.tensor(self.width),
+                             is_closed=False)
+        self.strokes_counter += 1
+        return path
+
+    def clip_curve_shape(self):
+        if self.optim_width:
+            for path in self.shapes:
+                path.stroke_width.data.clamp_(1.0, self.max_width)
+        if self.optim_rgba:
+            for group in self.shape_groups:
+                group.stroke_color.data.clamp_(0.0, 1.0)
+        else:
+            if self.optim_alpha:
+                for group in self.shape_groups:
+                    # group.stroke_color.data: RGBA
+                    group.stroke_color.data[:3].clamp_(0., 0.)  # to force black stroke
+                    group.stroke_color.data[-1].clamp_(0., 1.)  # opacity
+
+    def path_pruning(self):
+        # stroke pruning
+        for group in self.shape_groups:
+            group.stroke_color.data[-1] = (group.stroke_color.data[-1] >= self.color_vars_threshold).float()
+
+    def render_warp(self):
+        self.clip_curve_shape()
+
+        scene_args = pydiffvg.RenderFunction.serialize_scene(
+            self.canvas_width, self.canvas_height, self.shapes, self.shape_groups
+        )
+        _render = pydiffvg.RenderFunction.apply
+        img = _render(self.canvas_width,  # width
+                      self.canvas_height,  # height
+                      2,  # num_samples_x
+                      2,  # num_samples_y
+                      0,  # seed
+                      None,
+                      *scene_args)
+        return img
+
+    def render_warp2(self):
+        self.clip_curve_shape()
+
+        stroke_imgs = []
+        for si, shape_stroke in enumerate(self.shapes):
+            shapes = [shape_stroke]
+
+            path_group = pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
+                                             fill_color=None,
+                                             stroke_color=self.shape_groups[si].stroke_color)
+            shape_groups = [path_group]
+
+            scene_args = pydiffvg.RenderFunction.serialize_scene(
+                self.canvas_width, self.canvas_height, shapes, shape_groups
+            )
+            _render = pydiffvg.RenderFunction.apply
+            img = _render(self.canvas_width,  # width
+                          self.canvas_height,  # height
+                          2,  # num_samples_x
+                          2,  # num_samples_y
+                          0,  # seed
+                          None,
+                          *scene_args)
+            stroke_imgs.append(img)
+        return stroke_imgs
+
+    def set_points_parameters(self):
+        # stoke`s location optimization
+        self.points_vars = []
+        self.points_vars_gt = []
+        for i, path in enumerate(self.shapes):
+            path_points_gt = torch.clone(path.points)
+            path_points_gt.requires_grad = False
+            self.points_vars_gt.append(path_points_gt)
+            if self.optimize_flag[i]:
+                path.points.requires_grad = True
+                self.points_vars.append(path.points)
+            else:
+                path.points.requires_grad = False
+
+    def get_points_params(self):
+        return self.points_vars
+
+    def get_points_params_gt(self):
+        return self.points_vars_gt
+
+    def set_width_parameters(self):
+        # stroke`s  width optimization
+        self.stroke_width_vars = []
+        for i, path in enumerate(self.shapes):
+            if self.optimize_flag[i]:
+                path.stroke_width.requires_grad = True
+                self.stroke_width_vars.append(path.stroke_width)
+
+    def get_width_parameters(self):
+        return self.stroke_width_vars
+
+    def set_color_parameters(self):
+        # for storkes' color optimization (opacity)
+        self.color_vars = []
+        for i, group in enumerate(self.shape_groups):
+            if self.optimize_flag[i]:
+                group.stroke_color.requires_grad = True
+                self.color_vars.append(group.stroke_color)
+            else:
+                group.stroke_color.requires_grad = False
+
+    def get_color_parameters(self):
+        return self.color_vars
+
+    def save_svg(self, output_dir, fname):
+        pydiffvg.save_svg(f'{output_dir}/{fname}.svg',
+                          self.canvas_width,
+                          self.canvas_height,
+                          self.shapes,
+                          self.shape_groups)
+
+    def load_svg(self, path_svg):
+        canvas_width, canvas_height, shapes, shape_groups = svg_to_scene(path_svg)
+        return canvas_width, canvas_height, shapes, shape_groups
+
+    @staticmethod
+    def softmax(x, tau=0.2):
+        e_x = np.exp(x / tau)
+        return e_x / e_x.sum()
+
+    def set_inds_ldm(self):
+        attn_map = (self.attention_map - self.attention_map.min()) / \
+                   (self.attention_map.max() - self.attention_map.min())
+
+        if self.xdog_intersec:
+            xdog = XDoG(k=10)
+            im_xdog = xdog(self.image2clip_input[0].permute(1, 2, 0).cpu().numpy())
+            print(f"use XDoG, shape: {im_xdog.shape}")
+            intersec_map = (1 - im_xdog) * attn_map
+            attn_map = intersec_map
+
+        attn_map_soft = np.copy(attn_map)
+        attn_map_soft[attn_map > 0] = self.softmax(attn_map[attn_map > 0], tau=self.softmax_temp)
+
+        # select points
+        k = self.num_stages * self.num_paths
+        self.inds = np.random.choice(range(attn_map.flatten().shape[0]),
+                                     size=k,
+                                     replace=False,
+                                     p=attn_map_soft.flatten())
+        self.inds = np.array(np.unravel_index(self.inds, attn_map.shape)).T
+
+        self.inds_normalised = np.zeros(self.inds.shape)
+        self.inds_normalised[:, 0] = self.inds[:, 1] / self.canvas_width
+        self.inds_normalised[:, 1] = self.inds[:, 0] / self.canvas_height
+        self.inds_normalised = self.inds_normalised.tolist()
+        return attn_map_soft
+
+    def set_attention_threshold_map(self):
+        return self.set_inds_ldm()
+
+    def get_attn(self):
+        return self.attention_map
+
+    def get_thresh(self):
+        return self.thresh
+
+    def get_inds(self):
+        return self.inds
+
+    def get_mask(self):
+        return self.mask
+
+
+class SketchPainterOptimizer:
+
+    def __init__(
+            self,
+            renderer: nn.Module,
+            points_lr: float,
+            optim_alpha: bool,
+            optim_rgba: bool,
+            color_lr: float,
+            optim_width: bool,
+            width_lr: float
+    ):
+        self.renderer = renderer
+
+        self.points_lr = points_lr
+        self.optim_color = optim_alpha or optim_rgba
+        self.color_lr = color_lr
+        self.optim_width = optim_width
+        self.width_lr = width_lr
+
+        self.points_optimizer, self.width_optimizer, self.color_optimizer = None, None, None
+
+    def init_optimizers(self):
+        self.renderer.set_points_parameters()
+        self.points_optimizer = torch.optim.Adam(self.renderer.get_points_params(), lr=self.points_lr)
+        if self.optim_color:
+            self.renderer.set_color_parameters()
+            self.color_optimizer = torch.optim.Adam(self.renderer.get_color_parameters(), lr=self.color_lr)
+        if self.optim_width:
+            self.renderer.set_width_parameters()
+            self.width_optimizer = torch.optim.Adam(self.renderer.get_width_parameters(), lr=self.width_lr)
+
+    def update_lr(self, step, base_lr, decay_steps=(500, 750)):
+        if step % decay_steps[0] == 0 and step > 0:
+            for param_group in self.points_optimizer.param_groups:
+                param_group['lr'] = base_lr * 0.4
+        if step % decay_steps[1] == 0 and step > 0:
+            for param_group in self.points_optimizer.param_groups:
+                param_group['lr'] = base_lr * 0.1
+
+    def zero_grad_(self):
+        self.points_optimizer.zero_grad()
+        if self.optim_color:
+            self.color_optimizer.zero_grad()
+        if self.optim_width:
+            self.width_optimizer.zero_grad()
+
+    def step_(self):
+        self.points_optimizer.step()
+        if self.optim_color:
+            self.color_optimizer.step()
+        if self.optim_width:
+            self.width_optimizer.step()
+
+    def get_lr(self):
+        return self.points_optimizer.param_groups[0]['lr']
diff --git a/methods/painter/diffsketchedit/process_svg.py b/methods/painter/diffsketchedit/process_svg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8652c478636976e7c8887032fe4daba27ef10dba
--- /dev/null
+++ b/methods/painter/diffsketchedit/process_svg.py
@@ -0,0 +1,66 @@
+import xml.etree.ElementTree as ET
+import statistics
+
+import argparse
+
+
+def remove_low_opacity_paths(svg_file_path, output_file_path, opacity_delta=0.2):
+    try:
+        # Parse the SVG file
+        tree = ET.parse(svg_file_path)
+        namespace = "http://www.w3.org/2000/svg"
+        ET.register_namespace("", namespace)
+
+        root = tree.getroot()
+        root.set('version', '1.1')
+
+        paths = root.findall('.//{http://www.w3.org/2000/svg}path')
+        # Collect stroke-opacity attribute values
+        opacity_values = []
+        for path in paths:
+            opacity = path.get("stroke-opacity")
+            if opacity is not None:
+                opacity_values.append(float(opacity))
+
+        # Calculate median opacity
+        median_opacity = statistics.median(opacity_values) + opacity_delta
+
+        # Create a temporary list to store paths to be removed
+        paths_to_remove = []
+        for path in paths:
+            opacity = path.get('stroke-opacity')
+            if opacity is not None and float(opacity) < median_opacity:
+                paths_to_remove.append(path)
+
+        # Remove paths from the root element
+        for path in paths_to_remove:
+            path.set('stroke-opacity', '0')
+
+        print(f"n_path: {len(paths)}, "
+              f"opacity_thresh: {median_opacity}, "
+              f"n_path_to_remove: {len(set(paths_to_remove))}.")
+
+        # Save the modified SVG to the specified path
+        tree.write(output_file_path, encoding='utf-8', xml_declaration=True, default_namespace="")
+        # print("SVG file saved successfully.")
+        # print(f"file has been saved in: {output_file_path}")
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+
+
+if __name__ == '__main__':
+    """
+    python process_svg.py -save ./workdir/xx.svg -tar ./workdir/xx.svg
+    """
+    parser = argparse.ArgumentParser(description="vary style painterly rendering")
+    parser.add_argument("-tar", "--target_file",
+                        default="", type=str,
+                        help="the path of SVG file place.")
+    parser.add_argument("-save", "--save_path",
+                        default="", type=str,
+                        help="the path of processed SVG file place.")
+    parser.add_argument("-od", "--opacity_delta",
+                        default=0.1, type=float)
+    args = parser.parse_args()
+
+    remove_low_opacity_paths(args.target_file, args.save_path, float(args.opacity_delta))
diff --git a/methods/painter/diffsketchedit/sketch_utils.py b/methods/painter/diffsketchedit/sketch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb99f38a0b1c764c6e2bc563f8cebe62f3f3714
--- /dev/null
+++ b/methods/painter/diffsketchedit/sketch_utils.py
@@ -0,0 +1,168 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+
+import torch
+from torchvision.utils import make_grid
+
+
+def plt_batch(
+        photos: torch.Tensor,
+        sketch: torch.Tensor,
+        step: int,
+        prompt: str,
+        save_path: str,
+        name: str,
+        dpi: int = 300
+):
+    if photos.shape != sketch.shape:
+        raise ValueError("photos and sketch must have the same dimensions")
+
+    plt.figure()
+    plt.subplot(1, 2, 1)  # nrows=1, ncols=2, index=1
+    grid = make_grid(photos, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title("Generated sample")
+
+    plt.subplot(1, 2, 2)  # nrows=1, ncols=2, index=2
+    grid = make_grid(sketch, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Rendering result - {step} steps")
+
+    plt.suptitle(insert_newline(prompt), fontsize=10)
+
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def plt_triplet(
+        photos: torch.Tensor,
+        sketch: torch.Tensor,
+        style: torch.Tensor,
+        step: int,
+        prompt: str,
+        save_path: str,
+        name: str,
+        dpi: int = 300
+):
+    if photos.shape != sketch.shape:
+        raise ValueError("photos and sketch must have the same dimensions")
+
+    plt.figure()
+    plt.subplot(1, 3, 1)  # nrows=1, ncols=3, index=1
+    grid = make_grid(photos, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title("Generated sample")
+
+    plt.subplot(1, 3, 2)  # nrows=1, ncols=3, index=2
+    # style = (style + 1) / 2
+    grid = make_grid(style, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Style")
+
+    plt.subplot(1, 3, 3)  # nrows=1, ncols=3, index=2
+    # sketch = (sketch + 1) / 2
+    grid = make_grid(sketch, normalize=False, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"Rendering result - {step} steps")
+
+    plt.suptitle(insert_newline(prompt), fontsize=10)
+
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def insert_newline(string, point=9):
+    # split by blank
+    words = string.split()
+    if len(words) <= point:
+        return string
+
+    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
+    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
+    return new_string
+
+
+def log_tensor_img(inputs, output_dir, output_prefix="input", norm=False, dpi=300):
+    grid = make_grid(inputs, normalize=norm, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    plt.savefig(f"{output_dir}/{output_prefix}.png", dpi=dpi)
+    plt.close()
+
+
+def plt_tensor_img(tensor, title, save_path, name, dpi=500):
+    grid = make_grid(tensor, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.title(f"{title}")
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def save_tensor_img(tensor, save_path, name, dpi=500):
+    grid = make_grid(tensor, normalize=True, pad_value=2)
+    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    plt.imshow(ndarr)
+    plt.axis("off")
+    plt.tight_layout()
+    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
+    plt.close()
+
+
+def plt_attn(attn, threshold_map, inputs, inds, output_path):
+    # currently supports one image (and not a batch)
+    plt.figure(figsize=(10, 5))
+
+    plt.subplot(1, 3, 1)
+    main_im = make_grid(inputs, normalize=True, pad_value=2)
+    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
+    plt.imshow(main_im, interpolation='nearest')
+    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
+    plt.title("input img")
+    plt.axis("off")
+
+    plt.subplot(1, 3, 2)
+    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
+    plt.title("attn map")
+    plt.axis("off")
+
+    plt.subplot(1, 3, 3)
+    threshold_map_ = (threshold_map - threshold_map.min()) / \
+                     (threshold_map.max() - threshold_map.min())
+    plt.imshow(np.nan_to_num(threshold_map_), interpolation='nearest', vmin=0, vmax=1)
+    plt.title("prob softmax")
+    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
+    plt.axis("off")
+
+    plt.tight_layout()
+    plt.savefig(output_path)
+    plt.close()
+
+
+def fix_image_scale(im):
+    im_np = np.array(im) / 255
+    height, width = im_np.shape[0], im_np.shape[1]
+    max_len = max(height, width) + 20
+    new_background = np.ones((max_len, max_len, 3))
+    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
+    new_background[y: y + height, x: x + width] = im_np
+    new_background = (new_background / new_background.max()
+                      * 255).astype(np.uint8)
+    new_im = Image.fromarray(new_background)
+    return new_im
diff --git a/methods/painter/diffsketchedit/u2net.py b/methods/painter/diffsketchedit/u2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcedd43ece5537921eb68a4715a076f7d4d0f7cd
--- /dev/null
+++ b/methods/painter/diffsketchedit/u2net.py
@@ -0,0 +1,524 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class REBNCONV(nn.Module):
+    def __init__(self, in_ch=3, out_ch=3, dirate=1):
+        super(REBNCONV, self).__init__()
+
+        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
+        self.bn_s1 = nn.BatchNorm2d(out_ch)
+        self.relu_s1 = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        hx = x
+        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+
+        return xout
+
+
+## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+def _upsample_like(src, tar):
+    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
+
+    return src
+
+
+### RSU-7 ###
+class RSU7(nn.Module):  # UNet07DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU7, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+        hx = self.pool5(hx5)
+
+        hx6 = self.rebnconv6(hx)
+
+        hx7 = self.rebnconv7(hx6)
+
+        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+        hx6dup = _upsample_like(hx6d, hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-6 ###
+class RSU6(nn.Module):  # UNet06DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU6, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+        hx = self.pool4(hx4)
+
+        hx5 = self.rebnconv5(hx)
+
+        hx6 = self.rebnconv6(hx5)
+
+        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-5 ###
+class RSU5(nn.Module):  # UNet05DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU5, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+        hx = self.pool3(hx3)
+
+        hx4 = self.rebnconv4(hx)
+
+        hx5 = self.rebnconv5(hx4)
+
+        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4 ###
+class RSU4(nn.Module):  # UNet04DRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx = self.pool1(hx1)
+
+        hx2 = self.rebnconv2(hx)
+        hx = self.pool2(hx2)
+
+        hx3 = self.rebnconv3(hx)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+
+        return hx1d + hxin
+
+
+### RSU-4F ###
+class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
+
+    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+        super(RSU4F, self).__init__()
+
+        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+
+        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+
+        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+
+        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+
+    def forward(self, x):
+        hx = x
+
+        hxin = self.rebnconvin(hx)
+
+        hx1 = self.rebnconv1(hxin)
+        hx2 = self.rebnconv2(hx1)
+        hx3 = self.rebnconv3(hx2)
+
+        hx4 = self.rebnconv4(hx3)
+
+        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+
+        return hx1d + hxin
+
+
+##### U^2-Net ####
+class U2NET(nn.Module):
+
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NET, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 32, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 32, 128)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(128, 64, 256)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(256, 128, 512)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(512, 256, 512)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(512, 256, 512)
+
+        # decoder
+        self.stage5d = RSU4F(1024, 256, 512)
+        self.stage4d = RSU4(1024, 128, 256)
+        self.stage3d = RSU5(512, 64, 128)
+        self.stage2d = RSU6(256, 32, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+    def forward(self, x):
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # -------------------- decoder --------------------
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), \
+               torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), \
+               torch.sigmoid(d6)
+
+
+### U^2-Net small ###
+class U2NETP(nn.Module):
+
+    def __init__(self, in_ch=3, out_ch=1):
+        super(U2NETP, self).__init__()
+
+        self.stage1 = RSU7(in_ch, 16, 64)
+        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage2 = RSU6(64, 16, 64)
+        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage3 = RSU5(64, 16, 64)
+        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage4 = RSU4(64, 16, 64)
+        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage5 = RSU4F(64, 16, 64)
+        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.stage6 = RSU4F(64, 16, 64)
+
+        # decoder
+        self.stage5d = RSU4F(128, 16, 64)
+        self.stage4d = RSU4(128, 16, 64)
+        self.stage3d = RSU5(128, 16, 64)
+        self.stage2d = RSU6(128, 16, 64)
+        self.stage1d = RSU7(128, 16, 64)
+
+        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
+        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
+
+        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
+
+    def forward(self, x):
+        hx = x
+
+        # stage 1
+        hx1 = self.stage1(hx)
+        hx = self.pool12(hx1)
+
+        # stage 2
+        hx2 = self.stage2(hx)
+        hx = self.pool23(hx2)
+
+        # stage 3
+        hx3 = self.stage3(hx)
+        hx = self.pool34(hx3)
+
+        # stage 4
+        hx4 = self.stage4(hx)
+        hx = self.pool45(hx4)
+
+        # stage 5
+        hx5 = self.stage5(hx)
+        hx = self.pool56(hx5)
+
+        # stage 6
+        hx6 = self.stage6(hx)
+        hx6up = _upsample_like(hx6, hx5)
+
+        # decoder
+        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+        hx5dup = _upsample_like(hx5d, hx4)
+
+        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+        hx4dup = _upsample_like(hx4d, hx3)
+
+        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+        hx3dup = _upsample_like(hx3d, hx2)
+
+        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+        hx2dup = _upsample_like(hx2d, hx1)
+
+        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+
+        # side output
+        d1 = self.side1(hx1d)
+
+        d2 = self.side2(hx2d)
+        d2 = _upsample_like(d2, d1)
+
+        d3 = self.side3(hx3d)
+        d3 = _upsample_like(d3, d1)
+
+        d4 = self.side4(hx4d)
+        d4 = _upsample_like(d4, d1)
+
+        d5 = self.side5(hx5d)
+        d5 = _upsample_like(d5, d1)
+
+        d6 = self.side6(hx6)
+        d6 = _upsample_like(d6, d1)
+
+        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
+
+        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), \
+               torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), \
+               torch.sigmoid(d6)
diff --git a/methods/token2attn/__init__.py b/methods/token2attn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/methods/token2attn/__init__.py
@@ -0,0 +1 @@
+
diff --git a/methods/token2attn/attn_control.py b/methods/token2attn/attn_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b687b558552a52b73cb120824047b63342f9228
--- /dev/null
+++ b/methods/token2attn/attn_control.py
@@ -0,0 +1,323 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Union, Tuple, List, Dict
+
+import torch
+import torch.nn.functional as F
+
+from .ptp_utils import (get_word_inds, get_time_words_attention_alpha)
+from .seq_aligner import (get_replacement_mapper, get_refinement_mapper)
+
+
+class AttentionControl(ABC):
+
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @property
+    def num_uncond_att_layers(self):
+        return 0
+
+    @abstractmethod
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            h = attn.shape[0]
+            attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            self.between_steps()
+        return attn
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        return attn
+
+
+class AttentionStore(AttentionControl):
+
+    def __init__(self):
+        super(AttentionStore, self).__init__()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+    @staticmethod
+    def get_empty_store():
+        return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                "down_self": [], "mid_self": [], "up_self": []}
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+            self.step_store[key].append(attn)
+        return attn
+
+    def between_steps(self):
+        if len(self.attention_store) == 0:
+            self.attention_store = self.step_store
+        else:
+            for key in self.attention_store:
+                for i in range(len(self.attention_store[key])):
+                    self.attention_store[key][i] += self.step_store[key][i]
+        self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        average_attention = {
+            key: [item / self.cur_step for item in self.attention_store[key]]
+            for key in self.attention_store
+        }
+        return average_attention
+
+    def reset(self):
+        super(AttentionStore, self).reset()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+
+class LocalBlend:
+
+    def __init__(self,
+                 prompts: List[str],
+                 words: [List[List[str]]],
+                 tokenizer,
+                 device,
+                 threshold=.3,
+                 max_num_words=77):
+        self.max_num_words = max_num_words
+
+        alpha_layers1 = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
+        alpha_layers2 = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
+        for i, (prompt, words_) in enumerate(zip(prompts, words)):
+            assert type(words_) is list and len(words_) == 2
+            ind1 = get_word_inds(prompt, words_[0], tokenizer)
+            alpha_layers1[i, :, :, :, :, ind1] = 1
+            ind2 = get_word_inds(prompt, words_[1], tokenizer)
+            alpha_layers2[i, :, :, :, :, ind2] = 1
+
+        self.alpha_layers1 = alpha_layers1.to(device)
+        self.alpha_layers2 = alpha_layers2.to(device)
+        self.threshold = threshold
+
+    def __call__(self, x_t, attention_store):
+        k = 1
+        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+        maps = [item.reshape(self.alpha_layers1.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
+        maps = torch.cat(maps, dim=1)
+
+        maps1 = (maps * self.alpha_layers1).sum(-1).mean(1)
+        mask1 = F.max_pool2d(maps1, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask1 = F.interpolate(mask1, size=(x_t.shape[2:]))
+        mask1 = mask1 / mask1.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask1 = mask1.gt(self.threshold)
+
+        maps2 = (maps * self.alpha_layers2).sum(-1).mean(1)
+        mask2 = F.max_pool2d(maps2, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask2 = F.interpolate(mask2, size=(x_t.shape[2:]))
+        mask2 = mask2 / mask2.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask2 = mask2.gt(self.threshold)
+
+        mask = (mask1 + mask2).float()
+
+        prev_x_t = torch.cat([x_t[:1], x_t[:-1]], dim=0)
+        x_t = (1 - mask) * prev_x_t + mask * x_t
+        return x_t
+
+
+class AttentionControlEdit(AttentionStore, ABC):
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                 self_replace_steps: Union[float, Tuple[float, float]],
+                 local_blend: Optional[LocalBlend],
+                 tokenizer,
+                 device):
+        super(AttentionControlEdit, self).__init__()
+        self.tokenizer = tokenizer
+        self.device = device
+
+        self.batch_size = len(prompts)
+        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
+                                                                  self.tokenizer).to(self.device)
+        if type(self_replace_steps) is float:
+            self_replace_steps = 0, self_replace_steps
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend  # define outside
+
+    def step_callback(self, x_t):
+        if self.local_blend is not None:
+            x_t = self.local_blend(x_t, self.attention_store)
+        return x_t
+
+    def replace_self_attention(self, attn_base, att_replace):
+        if att_replace.shape[2] <= 16 ** 2:
+            return attn_base.clone()
+        else:
+            return att_replace
+
+    @abstractmethod
+    def replace_cross_attention(self, attn_base, att_replace):
+        raise NotImplementedError
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+        # FIXME not replace correctly
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+            h = attn.shape[0] // (self.batch_size)
+            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_repalce = attn[:-1], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
+                        1 - alpha_words) * attn_repalce
+                attn[1:] = attn_repalce_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
+            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+        return attn
+
+
+class AttentionReplace(AttentionControlEdit):
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                               local_blend, tokenizer, device)
+        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        # attn_base/att_replace: (len(prompts)-1, 8, 4096, 77)
+        # self.mapper: (len(prompts)-1, 77, 77)
+        version = 'v2'
+
+        if version == 'v1':
+            return torch.einsum('bhpw,bwn->bhpn', attn_base, self.mapper)
+        else:
+            bsz = attn_base.size()[0]
+            attn_base_replace = []
+            for batch_i in range(bsz):
+                if batch_i == 0:
+                    attn_base_i = attn_base[batch_i]  # (8, 4096, 77)
+                else:
+                    attn_base_i = attn_base_replace[-1]
+                mapper_i = self.mapper[batch_i:batch_i + 1, :, :]  # (1, 77, 77)
+                attn_base_replace_i = torch.einsum('hpw,bwn->bhpn', attn_base_i, mapper_i)  # (1, 8, 4096, 77)
+                attn_base_replace.append(attn_base_replace_i[0])
+            attn_base_replace = torch.stack(attn_base_replace, dim=0)  # (len(prompts)-1, 8, 4096, 77)
+            return attn_base_replace
+
+
+class AttentionRefine(AttentionControlEdit):
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                              local_blend, tokenizer, device)
+        self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
+        self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
+        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        # attn_base/att_replace: (len(prompts)-1, 8, 4096, 77)
+        version = 'v2'
+
+        bsz = attn_base.size()[0]
+        attn_base_replace = []
+        for batch_i in range(bsz):
+            if version == 'v1':
+                attn_base_i = attn_base[batch_i]  # (8, 4096, 77)
+            else:
+                if batch_i == 0:
+                    attn_base_i = attn_base[batch_i]
+                else:
+                    attn_base_i = attn_base_replace[-1]
+            mapper_i = self.mapper[batch_i:batch_i + 1, :]  # (1, 77)
+            attn_base_replace_i = attn_base_i[:, :, mapper_i].permute(2, 0, 1, 3)  # (1, 8, 4096, 77)
+            attn_base_replace.append(attn_base_replace_i[0])
+        attn_base_replace = torch.stack(attn_base_replace, dim=0)  # (len(prompts)-1, 8, 4096, 77)
+        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+        return attn_replace
+
+
+class AttentionReweight(AttentionControlEdit):
+
+    def __init__(self,
+                 prompts,
+                 num_steps: int,
+                 cross_replace_steps: float,
+                 self_replace_steps: float,
+                 equalizer,
+                 local_blend: Optional[LocalBlend] = None,
+                 controller: Optional[AttentionControlEdit] = None,
+                 tokenizer=None,
+                 device=None):
+        super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
+                                                local_blend, tokenizer, device)
+        self.equalizer = equalizer.to(self.device)
+        self.prev_controller = controller
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        # attn_base/att_replace: (len(prompts)-1, 8, 4096, 77)
+        # self.equalizer: (len(prompts)-1, 77)
+        if self.prev_controller is not None:
+            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
+
+        version = 'v2'
+
+        if version == 'v1':
+            attn_replace = attn_base[:, :, :, :] * self.equalizer[:, None, None, :]
+            return attn_replace
+        else:
+            bsz = attn_base.size()[0]
+            attn_replace_rst_all = []
+            for bi in range(bsz):
+                if bi == 0:
+                    attn_replace_rst = attn_base[bi, :, :, :] * self.equalizer[bi, None, None, :]
+                else:
+                    attn_replace_rst = attn_replace_rst_all[-1] * self.equalizer[bi, None, None, :]
+                attn_replace_rst_all.append(attn_replace_rst)
+            attn_replace_rst_all = torch.stack(attn_replace_rst_all, dim=0)
+            return attn_replace_rst_all
+
+
+def get_equalizer(tokenizer, texts: List[str],
+                  word_select: List[str],
+                  values: List[float]):
+    equalizer = torch.ones(len(values), 77)
+    values = torch.tensor(values, dtype=torch.float32)
+    for wi, word in enumerate(word_select):
+        text = texts[wi]
+        value = values[wi]
+        inds = get_word_inds(text, word, tokenizer)
+        equalizer[wi, inds] = value
+    return equalizer
diff --git a/methods/token2attn/ptp_utils.py b/methods/token2attn/ptp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4557c4cd929222f7580e444f392c0260b0fbe1dd
--- /dev/null
+++ b/methods/token2attn/ptp_utils.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+import pathlib
+from typing import Union, Optional, List, Tuple, Dict, Text, BinaryIO
+from PIL import Image
+
+import torch
+import cv2
+import numpy as np
+
+from .seq_aligner import get_word_inds
+
+
+def text_under_image(image: np.ndarray,
+                     text: str,
+                     text_color: Tuple[int, int, int] = (0, 0, 0)) -> np.ndarray:
+    h, w, c = image.shape
+    offset = int(h * .2)
+    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    img[:h] = image
+    textsize = cv2.getTextSize(text, font, 1, 2)[0]
+    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
+    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
+    return img
+
+
+def view_images(images: Union[np.ndarray, List],
+                num_rows: int = 1,
+                offset_ratio: float = 0.02,
+                save_image: bool = False,
+                fp: Union[Text, pathlib.Path, BinaryIO] = None) -> np.ndarray:
+    if save_image:
+        assert fp is not None
+
+    if isinstance(images, np.ndarray) and images.ndim == 4:
+        num_empty = images.shape[0] % num_rows
+    else:
+        images = [images] if not isinstance(images, list) else images
+        num_empty = len(images) % num_rows
+
+    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+    num_items = len(images)
+
+    # Calculate the composite image
+    h, w, c = images[0].shape
+    offset = int(h * offset_ratio)
+    num_cols = int(np.ceil(num_items / num_rows))  # count the number of columns
+    image_h = h * num_rows + offset * (num_rows - 1)
+    image_w = w * num_cols + offset * (num_cols - 1)
+    assert image_h > 0, "Invalid image height: {} (num_rows={}, offset_ratio={}, num_items={})".format(
+        image_h, num_rows, offset_ratio, num_items)
+    assert image_w > 0, "Invalid image width: {} (num_cols={}, offset_ratio={}, num_items={})".format(
+        image_w, num_cols, offset_ratio, num_items)
+    image_ = np.ones((image_h, image_w, 3), dtype=np.uint8) * 255
+
+    # Ensure that the last row is filled with empty images if necessary
+    if len(images) % num_cols > 0:
+        empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+        num_empty = num_cols - len(images) % num_cols
+        images += [empty_images] * num_empty
+
+    for i in range(num_rows):
+        for j in range(num_cols):
+            k = i * num_cols + j
+            if k >= num_items:
+                break
+            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[k]
+
+    pil_img = Image.fromarray(image_)
+    if save_image:
+        pil_img.save(fp)
+    return pil_img
+
+
+def update_alpha_time_word(alpha,
+                           bounds: Union[float, Tuple[float, float]],
+                           prompt_ind: int,
+                           word_inds: Optional[torch.Tensor] = None):
+    if isinstance(bounds, float):
+        bounds = 0, bounds
+    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+    if word_inds is None:
+        word_inds = torch.arange(alpha.shape[2])
+    alpha[: start, prompt_ind, word_inds] = 0
+    alpha[start: end, prompt_ind, word_inds] = 1
+    alpha[end:, prompt_ind, word_inds] = 0
+    return alpha
+
+
+def get_time_words_attention_alpha(prompts, num_steps,
+                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+                                   tokenizer,
+                                   max_num_words=77):
+    if type(cross_replace_steps) is not dict:
+        cross_replace_steps = {"default_": cross_replace_steps}
+    if "default_" not in cross_replace_steps:
+        cross_replace_steps["default_"] = (0., 1.)
+    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+    for i in range(len(prompts) - 1):
+        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                  i)
+    for key, item in cross_replace_steps.items():
+        if key != "default_":
+            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+            for i, ind in enumerate(inds):
+                if len(ind) > 0:
+                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+    return alpha_time_words
diff --git a/methods/token2attn/seq_aligner.py b/methods/token2attn/seq_aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..b077acde60cde06197d9f6369c2d13566ccddeae
--- /dev/null
+++ b/methods/token2attn/seq_aligner.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+import torch
+import numpy as np
+
+
+class ScoreParams:
+
+    def __init__(self, gap, match, mismatch):
+        self.gap = gap
+        self.match = match
+        self.mismatch = mismatch
+
+    def mis_match_char(self, x, y):
+        if x != y:
+            return self.mismatch
+        else:
+            return self.match
+
+
+def get_matrix(size_x, size_y, gap):
+    matrix = []
+    for i in range(len(size_x) + 1):
+        sub_matrix = []
+        for j in range(len(size_y) + 1):
+            sub_matrix.append(0)
+        matrix.append(sub_matrix)
+    for j in range(1, len(size_y) + 1):
+        matrix[0][j] = j * gap
+    for i in range(1, len(size_x) + 1):
+        matrix[i][0] = i * gap
+    return matrix
+
+
+def get_matrix(size_x, size_y, gap):
+    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
+    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
+    return matrix
+
+
+def get_traceback_matrix(size_x, size_y):
+    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
+    matrix[0, 1:] = 1
+    matrix[1:, 0] = 2
+    matrix[0, 0] = 4
+    return matrix
+
+
+def global_align(x, y, score):
+    matrix = get_matrix(len(x), len(y), score.gap)
+    trace_back = get_traceback_matrix(len(x), len(y))
+    for i in range(1, len(x) + 1):
+        for j in range(1, len(y) + 1):
+            left = matrix[i, j - 1] + score.gap
+            up = matrix[i - 1, j] + score.gap
+            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
+            matrix[i, j] = max(left, up, diag)
+            if matrix[i, j] == left:
+                trace_back[i, j] = 1
+            elif matrix[i, j] == up:
+                trace_back[i, j] = 2
+            else:
+                trace_back[i, j] = 3
+    return matrix, trace_back
+
+
+def get_aligned_sequences(x, y, trace_back):
+    x_seq = []
+    y_seq = []
+    i = len(x)
+    j = len(y)
+    mapper_y_to_x = []
+    while i > 0 or j > 0:
+        if trace_back[i, j] == 3:
+            x_seq.append(x[i - 1])
+            y_seq.append(y[j - 1])
+            i = i - 1
+            j = j - 1
+            mapper_y_to_x.append((j, i))
+        elif trace_back[i][j] == 1:
+            x_seq.append('-')
+            y_seq.append(y[j - 1])
+            j = j - 1
+            mapper_y_to_x.append((j, -1))
+        elif trace_back[i][j] == 2:
+            x_seq.append(x[i - 1])
+            y_seq.append('-')
+            i = i - 1
+        elif trace_back[i][j] == 4:
+            break
+    mapper_y_to_x.reverse()
+    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
+
+
+def get_mapper(x: str, y: str, tokenizer, max_len=77):
+    x_seq = tokenizer.encode(x)
+    y_seq = tokenizer.encode(y)
+    score = ScoreParams(0, 1, -1)
+    matrix, trace_back = global_align(x_seq, y_seq, score)
+    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
+    alphas = torch.ones(max_len)
+    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
+    mapper = torch.zeros(max_len, dtype=torch.int64)
+    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
+    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
+    return mapper, alphas
+
+
+def get_refinement_mapper(prompts, tokenizer, max_len=77):
+    mappers, alphas = [], []
+    for i in range(1, len(prompts)):
+        mapper, alpha = get_mapper(prompts[i-1], prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+        alphas.append(alpha)
+    return torch.stack(mappers), torch.stack(alphas)
+
+
+def get_word_inds(text: str, word_place: int, tokenizer):
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
+
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+    words_x = x.split(' ')
+    words_y = y.split(' ')
+    if len(words_x) != len(words_y):
+        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
+                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
+    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+    mapper = np.zeros((max_len, max_len))
+    i = j = 0
+    cur_inds = 0
+    while i < max_len and j < max_len:
+        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+            if len(inds_source_) == len(inds_target_):
+                mapper[inds_source_, inds_target_] = 1
+            else:
+                ratio = 1 / len(inds_target_)
+                for i_t in inds_target_:
+                    mapper[inds_source_, i_t] = ratio
+            cur_inds += 1
+            i += len(inds_source_)
+            j += len(inds_target_)
+        elif cur_inds < len(inds_source):
+            mapper[i, j] = 1
+            i += 1
+            j += 1
+        else:
+            mapper[j, j] = 1
+            i += 1
+            j += 1
+
+    return torch.from_numpy(mapper).float()
+
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+    # x_seq = prompts[0]
+    mappers = []
+    for i in range(1, len(prompts)):
+        mapper = get_replacement_mapper_(prompts[i-1], prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+    return torch.stack(mappers)
diff --git a/pipelines/__init__.py b/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/pipelines/__init__.py
@@ -0,0 +1 @@
+
diff --git a/pipelines/painter/__init__.py b/pipelines/painter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pipelines/painter/diffsketchedit_pipeline.py b/pipelines/painter/diffsketchedit_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac77096fa4e1d31bb0fc33afac0c3168708595e1
--- /dev/null
+++ b/pipelines/painter/diffsketchedit_pipeline.py
@@ -0,0 +1,673 @@
+import os
+import pathlib
+from PIL import Image
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torchvision import transforms
+from torchvision.datasets.folder import is_image_file
+from tqdm.auto import tqdm
+import numpy as np
+from skimage.color import rgb2gray
+import diffusers
+
+from libs.engine import ModelState
+from libs.metric.lpips_origin import LPIPS
+from libs.metric.piq.perceptual import DISTS as DISTS_PIQ
+from libs.metric.clip_score import CLIPScoreWrapper
+from methods.painter.diffsketchedit import (
+    Painter, SketchPainterOptimizer, Token2AttnMixinASDSPipeline, Token2AttnMixinASDSSDXLPipeline)
+from methods.painter.diffsketchedit.sketch_utils import (
+    log_tensor_img, plt_batch, plt_attn, save_tensor_img, fix_image_scale)
+from methods.painter.diffsketchedit.mask_utils import get_mask_u2net
+from methods.token2attn.attn_control import AttentionStore, EmptyControl, \
+    LocalBlend, AttentionReplace, AttentionRefine, AttentionReweight, get_equalizer
+from methods.token2attn.ptp_utils import view_images, get_word_inds
+from methods.diffusers_warp import init_diffusion_pipeline, model2res
+from methods.diffvg_warp import init_diffvg
+from methods.painter.diffsketchedit.process_svg import remove_low_opacity_paths
+
+
+class DiffSketchEditPipeline(ModelState):
+    def __init__(self, args):
+        super().__init__(args, ignore_log=True)
+
+        init_diffvg(self.device, True, args.print_timing)
+
+        if args.model_id == "sdxl":
+            # default LSDSSDXLPipeline scheduler is EulerDiscreteScheduler
+            # when LSDSSDXLPipeline calls, scheduler.timesteps will change in step 4
+            # which causes problem in sds add_noise() function
+            # because the random t may not in scheduler.timesteps
+            custom_pipeline = Token2AttnMixinASDSSDXLPipeline
+            custom_scheduler = diffusers.DPMSolverMultistepScheduler
+            self.args.cross_attn_res = self.args.cross_attn_res * 2
+        elif args.model_id == 'sd21':
+            custom_pipeline = Token2AttnMixinASDSPipeline
+            custom_scheduler = diffusers.DDIMScheduler
+        elif args.model_id == 'sd15':
+            custom_pipeline = Token2AttnMixinASDSPipeline
+            custom_scheduler = diffusers.DDIMScheduler
+        else:  # sd14
+            custom_pipeline = Token2AttnMixinASDSPipeline
+            custom_scheduler = None
+
+        self.diffusion = init_diffusion_pipeline(
+            self.args.model_id,
+            custom_pipeline=custom_pipeline,
+            custom_scheduler=custom_scheduler,
+            device=self.device,
+            local_files_only=not args.download,
+            force_download=args.force_download,
+            resume_download=args.resume_download,
+            ldm_speed_up=args.ldm_speed_up,
+            enable_xformers=args.enable_xformers,
+            gradient_checkpoint=args.gradient_checkpoint,
+        )
+
+        # init clip model and clip score wrapper
+        self.cargs = self.args.clip
+        self.clip_score_fn = CLIPScoreWrapper(self.cargs.model_name,
+                                              device=self.device,
+                                              visual_score=True,
+                                              feats_loss_type=self.cargs.feats_loss_type,
+                                              feats_loss_weights=self.cargs.feats_loss_weights,
+                                              fc_loss_weight=self.cargs.fc_loss_weight)
+
+    def update_info(self, seed, token_ind, prompt_input):
+        prompt_dir_name = prompt_input.split(' ')
+        prompt_dir_name = '_'.join(prompt_dir_name)
+
+        attn_log_ = f"-tk{token_ind}"
+        logdir_ = f"seed{seed}" \
+                  f"{attn_log_}" \
+                  f"-stage={self.args.run_stage}"
+        logdir_sec_ = f""
+        self.args.path_svg = ""
+
+        if self.args.run_stage > 0:
+            logdir_sec_ = f"{logdir_sec_}-local={self.args.vector_local_edit}"
+            last_svg_base = os.path.join(self.args.results_path, self.args.edit_type, prompt_dir_name, logdir_[:-1] + str(self.args.run_stage - 1))
+            if self.args.run_stage != 1:
+                last_svg_base += logdir_sec_
+            self.args.path_svg = os.path.join(last_svg_base, "visual_best.svg")
+            self.args.attention_init = False
+
+        logdir_ = f"{prompt_dir_name}" + f"/" + logdir_ + logdir_sec_
+        super().__init__(self.args, log_path_suffix=logdir_)
+
+        # create log dir
+        self.png_logs_dir = self.results_path / "png_logs"
+        self.svg_logs_dir = self.results_path / "svg_logs"
+        self.attn_logs_dir = self.results_path / "attn_logs"
+        if self.accelerator.is_main_process:
+            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
+            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
+            self.attn_logs_dir.mkdir(parents=True, exist_ok=True)
+
+        self.g_device = torch.Generator().manual_seed(seed)
+
+    def load_render(self, target_img, attention_map, mask=None):
+        renderer = Painter(self.args,
+                           num_strokes=self.args.num_paths,
+                           num_segments=self.args.num_segments,
+                           imsize=self.args.image_size,
+                           device=self.device,
+                           target_im=target_img,
+                           attention_map=attention_map,
+                           mask=mask)
+        return renderer
+
+    def attn_map_normalizing(self, cross_attn_map):
+        cross_attn_map = 255 * cross_attn_map / cross_attn_map.max()
+        # [res, res, 3]
+        cross_attn_map = cross_attn_map.unsqueeze(-1).expand(*cross_attn_map.shape, 3)
+        # [3, res, res]
+        cross_attn_map = cross_attn_map.permute(2, 0, 1).unsqueeze(0)
+        # [3, clip_size, clip_size]
+        cross_attn_map = F.interpolate(cross_attn_map, size=self.args.image_size, mode='bicubic')
+        cross_attn_map = torch.clamp(cross_attn_map, min=0, max=255)
+        # rgb to gray
+        cross_attn_map = rgb2gray(cross_attn_map.squeeze(0).permute(1, 2, 0)).astype(np.float32)
+        # torch to numpy
+        if cross_attn_map.shape[-1] != self.args.image_size and cross_attn_map.shape[-2] != self.args.image_size:
+            cross_attn_map = cross_attn_map.reshape(self.args.image_size, self.args.image_size)
+        # to [0, 1]
+        cross_attn_map = (cross_attn_map - cross_attn_map.min()) / (cross_attn_map.max() - cross_attn_map.min())
+        return cross_attn_map
+
+    def compute_local_edit_maps(self, cross_attn_maps_src_tar, prompts, words, save_path, threshold=0.3):
+        """
+        cross_attn_maps_src_tar: [(res, res, 77), (res, res, 77)]
+        """
+        local_edit_region = np.zeros(shape=(self.args.image_size, self.args.image_size), dtype=np.float32)
+        for i, (prompt, word) in enumerate(zip(prompts, words)):
+            ind = get_word_inds(prompt, word, self.diffusion.tokenizer)  # list
+            assert len(ind) == 1
+            ind = ind[0]
+
+            cross_attn_map = cross_attn_maps_src_tar[i][:, :, ind]  # (res, res)
+            cross_attn_map = self.attn_map_normalizing(cross_attn_map)  # (image_size, image_size), [0.0, 1.0]
+            cross_attn_map_bin = cross_attn_map >= threshold
+            local_edit_region += cross_attn_map_bin
+        local_edit_region = (np.clip(local_edit_region, 0, 1) * 255).astype(np.uint8)
+        local_edit_region = Image.fromarray(local_edit_region, 'L')
+        local_edit_region.save(save_path, 'PNG')
+
+    def extract_ldm_attn(self, prompts, token_ind, changing_region_words, reweight_word, reweight_weight):
+        ######################### Change here for editing methods #########################
+        ## init controller
+        if not self.args.attention_init:
+            controller = EmptyControl()
+        else:
+            lb = LocalBlend(prompts=prompts,
+                            words=changing_region_words, tokenizer=self.diffusion.tokenizer,
+                            device=self.device)  # changing region
+            # if self.args.edit_type == "none":
+            #     controller = AttentionStore()
+            if self.args.edit_type == "replace":
+                controller = AttentionReplace(prompts=prompts,
+                                              num_steps=self.args.num_inference_steps,
+                                              cross_replace_steps=0.4,  # larger is more similar shape
+                                              self_replace_steps=0.4,
+                                              local_blend=lb,
+                                              tokenizer=self.diffusion.tokenizer,
+                                              device=self.device)
+            elif self.args.edit_type == "refine":
+                controller = AttentionRefine(prompts=prompts,
+                                             num_steps=self.args.num_inference_steps,
+                                             cross_replace_steps=0.8,  # larger is more similar shape
+                                             self_replace_steps=0.4,
+                                             local_blend=lb,
+                                             tokenizer=self.diffusion.tokenizer,
+                                             device=self.device)
+            elif self.args.edit_type == "reweight":
+                equalizer = get_equalizer(self.diffusion.tokenizer, prompts[1:],
+                                          reweight_word, reweight_weight)
+                controller = AttentionReweight(prompts=prompts,
+                                               num_steps=self.args.num_inference_steps,
+                                               cross_replace_steps=0.8,  # larger is more similar shape
+                                               self_replace_steps=0.4,
+                                               local_blend=lb,
+                                               equalizer=equalizer,
+                                               # controller=controller_a,
+                                               tokenizer=self.diffusion.tokenizer,
+                                               device=self.device)
+            else:
+                raise Exception('Unknown edit_type:', self.args.edit_type)
+
+        ######################### Change here for editing methods (end) #########################
+
+        height = width = model2res(self.args.model_id)
+        outputs = self.diffusion(prompt=prompts,
+                                 negative_prompt=[self.args.negative_prompt] * len(prompts),
+                                 height=height,
+                                 width=width,
+                                 controller=controller,
+                                 num_inference_steps=self.args.num_inference_steps,
+                                 guidance_scale=self.args.guidance_scale,
+                                 generator=self.g_device)
+
+        print('outputs.images', len(outputs.images))
+        for ii, img in enumerate(outputs.images):
+            if ii == 0:
+                filename = "ldm_generated_image.png"
+                target_file = self.results_path / filename
+            else:
+                filename = "ldm_generated_image" + str(ii) + ".png"
+
+            target_file_tmp = self.results_path / filename
+            view_images([np.array(img)], save_image=True, fp=target_file_tmp)
+
+        if self.args.attention_init:
+            """ldm cross-attention map"""
+            cross_attention_maps, tokens = \
+                self.diffusion.get_cross_attention(prompts,
+                                                   controller,
+                                                   res=self.args.cross_attn_res,
+                                                   from_where=("up", "down"),
+                                                   save_path=self.results_path / "cross_attn.png",
+                                                   select=0)
+            for ii in range(1, len(outputs.images)):
+                cross_attn_png_name = "cross_attn" + str(ii) + ".png"
+                cross_attention_maps_i, tokens_i = \
+                    self.diffusion.get_cross_attention(prompts,
+                                                       controller,
+                                                       res=self.args.cross_attn_res,
+                                                       from_where=("up", "down"),
+                                                       save_path=self.results_path / cross_attn_png_name,
+                                                       select=ii)
+
+            self.print(f"the length of tokens is {len(tokens)}, select {token_ind}-th token")
+            # [res, res, seq_len]
+            self.print(f"origin cross_attn_map shape: {cross_attention_maps.shape}")
+            # [res, res]
+            cross_attn_map = cross_attention_maps[:, :, token_ind]
+            self.print(f"select cross_attn_map shape: {cross_attn_map.shape}\n")
+            cross_attn_map = self.attn_map_normalizing(cross_attn_map)
+
+            ######################### ldm cross-attention map (for vector local editing) #########################
+            cross_attention_maps_local_list = []
+            for ii in range(len(outputs.images)):
+                cross_attention_maps_local = \
+                    self.diffusion.get_cross_attention2(prompts,
+                                                        controller,
+                                                        res=self.args.vector_local_edit_attn_res,
+                                                        from_where=("up", "down"),
+                                                        select=ii)  # (res, res, 77)
+                cross_attention_maps_local_list.append(cross_attention_maps_local)
+
+                if ii == 0:
+                    continue
+
+                save_name = "cross_attn_local_edit_" + str(self.args.vector_local_edit_attn_res) + "-" + str(ii) + ".png"
+
+                if self.args.edit_type == "replace":
+                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii-1]], [prompts[ii-1]], [changing_region_words[ii][0]],
+                                                 save_path=self.results_path / save_name,
+                                                 threshold=self.args.vector_local_edit_bin_threshold_replace)
+                elif self.args.edit_type == "refine":
+                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii]], [prompts[ii]], [changing_region_words[ii][1]],
+                                                 save_path=self.results_path / save_name,
+                                                 threshold=self.args.vector_local_edit_bin_threshold_refine)
+                elif self.args.edit_type == "reweight":
+                    self.compute_local_edit_maps([cross_attention_maps_local_list[ii-1]], [prompts[ii-1]], [changing_region_words[ii][0]],
+                                                 save_path=self.results_path / save_name,
+                                                 threshold=self.args.vector_local_edit_bin_threshold_reweight)
+
+            if self.args.sd_image_only:
+                return target_file.as_posix(), None
+
+            #########################  #########################
+
+            """ldm self-attention map"""
+            self_attention_maps, svd, vh_ = \
+                self.diffusion.get_self_attention_comp(prompts,
+                                                       controller,
+                                                       res=self.args.self_attn_res,
+                                                       from_where=("up", "down"),
+                                                       img_size=self.args.image_size,
+                                                       max_com=self.args.max_com,
+                                                       save_path=self.results_path)
+
+            # comp self-attention map
+            if self.args.mean_comp:
+                self_attn = np.mean(vh_, axis=0)
+                self.print(f"use the mean of {self.args.max_com} comps.")
+            else:
+                self_attn = vh_[self.args.comp_idx]
+                self.print(f"select {self.args.comp_idx}-th comp.")
+            # to [0, 1]
+            self_attn = (self_attn - self_attn.min()) / (self_attn.max() - self_attn.min())
+            # visual final self-attention
+            self_attn_vis = np.copy(self_attn)
+            self_attn_vis = self_attn_vis * 255
+            self_attn_vis = np.repeat(np.expand_dims(self_attn_vis, axis=2), 3, axis=2).astype(np.uint8)
+            view_images(self_attn_vis, save_image=True, fp=self.results_path / "self-attn-final.png")
+
+            """attention map fusion"""
+            attn_map = self.args.attn_coeff * cross_attn_map + (1 - self.args.attn_coeff) * self_attn
+            # to [0, 1]
+            attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())
+
+            self.print(f"-> fusion attn_map: {attn_map.shape}")
+        else:
+            attn_map = None
+
+        return target_file.as_posix(), attn_map
+
+    @property
+    def clip_norm_(self):
+        return transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+    def clip_pair_augment(self,
+                          x: torch.Tensor,
+                          y: torch.Tensor,
+                          im_res: int,
+                          augments: str = "affine_norm",
+                          num_aug: int = 4):
+        # init augmentations
+        augment_list = []
+        if "affine" in augments:
+            augment_list.append(
+                transforms.RandomPerspective(fill=0, p=1.0, distortion_scale=0.5)
+            )
+            augment_list.append(
+                transforms.RandomResizedCrop(im_res, scale=(0.8, 0.8), ratio=(1.0, 1.0))
+            )
+        augment_list.append(self.clip_norm_)  # CLIP Normalize
+
+        # compose augmentations
+        augment_compose = transforms.Compose(augment_list)
+        # make augmentation pairs
+        x_augs, y_augs = [self.clip_score_fn.normalize(x)], [self.clip_score_fn.normalize(y)]
+        # repeat N times
+        for n in range(num_aug):
+            augmented_pair = augment_compose(torch.cat([x, y]))
+            x_augs.append(augmented_pair[0].unsqueeze(0))
+            y_augs.append(augmented_pair[1].unsqueeze(0))
+        xs = torch.cat(x_augs, dim=0)
+        ys = torch.cat(y_augs, dim=0)
+        return xs, ys
+
+    def painterly_rendering(self, prompts, token_ind, changing_region_words, reweight_word, reweight_weight):
+        # log prompts
+        self.print(f"prompts: {prompts}")
+        self.print(f"negative_prompt: {self.args.negative_prompt}")
+        self.print(f"token_ind: {token_ind}")
+        self.print(f"changing_region_words: {changing_region_words}")
+        self.print(f"reweight_word: {reweight_word}")
+        self.print(f"reweight_weight: {reweight_weight}\n")
+        if self.args.negative_prompt is None:
+            self.args.negative_prompt = ""
+
+        log_path = os.path.join(self.results_path.as_posix(), 'log.txt')
+        with open(log_path, "w") as f:
+            f.write("prompts: " + str(prompts) + "\n")
+            f.write("negative_prompt: " + self.args.negative_prompt + "\n")
+            f.write("token_ind: " + str(token_ind) + "\n")
+            f.write("changing_region_words: " + str(changing_region_words) + "\n")
+            f.write("reweight_word: " + str(reweight_word) + "\n")
+            f.write("reweight_weight: " + str(reweight_weight) + "\n")
+            f.close()
+
+        # init attention
+        if self.args.run_stage == 0:
+            target_file, attention_map = self.extract_ldm_attn(prompts, token_ind, changing_region_words,
+                                                               reweight_word, reweight_weight)
+        else:
+            results_base = self.results_path.as_posix()
+            target_file = os.path.join(results_base[:results_base.find('stage=' + str(self.args.run_stage))] + 'stage=0', "ldm_generated_image" + str(self.args.run_stage) + ".png")
+            attention_map = None
+
+        if not self.args.sd_image_only:
+            # timesteps_ = self.diffusion.scheduler.timesteps.cpu().numpy().tolist()
+            # self.print(f"{len(timesteps_)} denoising steps, {timesteps_}")
+
+            perceptual_loss_fn = None
+            if self.args.perceptual.coeff > 0:
+                if self.args.perceptual.name == "lpips":
+                    lpips_loss_fn = LPIPS(net=self.args.perceptual.lpips_net).to(self.device)
+                    perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)
+                elif self.args.perceptual.name == "dists":
+                    perceptual_loss_fn = DISTS_PIQ()
+
+            inputs, mask = self.get_target(target_file,
+                                           self.args.image_size,
+                                           self.results_path,
+                                           self.args.u2net_path,
+                                           self.args.mask_object,
+                                           self.args.fix_scale,
+                                           self.device)
+            inputs = inputs.detach()  # inputs as GT
+            self.print("inputs shape: ", inputs.shape)
+
+            # load renderer
+            renderer = Painter(self.args,
+                               num_strokes=self.args.num_paths,
+                               num_segments=self.args.num_segments,
+                               imsize=self.args.image_size,
+                               device=self.device,
+                               target_im=inputs,
+                               attention_map=attention_map,
+                               mask=mask,
+                               results_base=self.results_path.as_posix())
+
+            # init img
+            img = renderer.init_image(stage=0)
+            self.print("init_image shape: ", img.shape)
+            log_tensor_img(img, self.results_path, output_prefix="init_sketch")
+            # load optimizer
+            optimizer = SketchPainterOptimizer(renderer,
+                                               self.args.lr,
+                                               self.args.optim_opacity,
+                                               self.args.optim_rgba,
+                                               self.args.color_lr,
+                                               self.args.optim_width,
+                                               self.args.width_lr)
+            optimizer.init_optimizers()
+
+            # log params
+            self.print(f"-> Painter points Params: {len(renderer.get_points_params())}")
+            self.print(f"-> Painter width Params: {len(renderer.get_width_parameters())}")
+            self.print(f"-> Painter opacity Params: {len(renderer.get_color_parameters())}")
+
+            best_visual_loss, best_semantic_loss = 100, 100
+            best_iter_v, best_iter_s = 0, 0
+            min_delta = 1e-6
+            vid_idx = 1
+
+            self.print(f"\ntotal optimization steps: {self.args.num_iter}")
+            with tqdm(initial=self.step, total=self.args.num_iter, disable=not self.accelerator.is_main_process) as pbar:
+                while self.step < self.args.num_iter:
+                    raster_sketch = renderer.get_image().to(self.device)
+
+                    target_prompt = prompts[self.args.run_stage]
+
+                    # ASDS loss
+                    sds_loss, grad = torch.tensor(0), torch.tensor(0)
+                    if self.step >= self.args.sds.warmup:
+                        grad_scale = self.args.sds.grad_scale if self.step > self.args.sds.warmup else 0
+                        sds_loss, grad = self.diffusion.score_distillation_sampling(
+                            raster_sketch,
+                            crop_size=self.args.sds.crop_size,
+                            augments=self.args.sds.augmentations,
+                            prompt=[target_prompt],
+                            negative_prompt=[self.args.negative_prompt],
+                            guidance_scale=self.args.sds.guidance_scale,
+                            grad_scale=grad_scale,
+                            t_range=list(self.args.sds.t_range),
+                        )
+
+                    # CLIP data augmentation
+                    raster_sketch_aug, inputs_aug = self.clip_pair_augment(
+                        raster_sketch, inputs,
+                        im_res=224,
+                        augments=self.cargs.augmentations,
+                        num_aug=self.cargs.num_aug
+                    )
+                    # raster_sketch: (1, 3, 224, 224), [0, 1]
+                    # inputs: (1, 3, 224, 224), [0, 1]
+                    # raster_sketch_aug: (5, 3, 224, 224), [2+, -1.7]
+                    # inputs_aug: (5, 3, 224, 224), [2+, -1.7]
+
+                    # clip visual loss
+                    total_visual_loss = torch.tensor(0)
+                    l_clip_fc, l_clip_conv, clip_conv_loss_sum = torch.tensor(0), [], torch.tensor(0)
+                    if self.args.clip.vis_loss > 0:
+                        l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
+                            raster_sketch_aug, inputs_aug, clip_norm=False
+                        )
+                        clip_conv_loss_sum = sum(l_clip_conv)
+                        total_visual_loss = self.args.clip.vis_loss * (clip_conv_loss_sum + l_clip_fc)
+
+                    # perceptual loss
+                    l_percep = torch.tensor(0.)
+                    if perceptual_loss_fn is not None:
+                        l_perceptual = perceptual_loss_fn(raster_sketch, inputs).mean()
+                        l_percep = l_perceptual * self.args.perceptual.coeff
+
+                    # text-visual loss
+                    l_tvd = torch.tensor(0.)
+                    if self.cargs.text_visual_coeff > 0:
+                        l_tvd = self.clip_score_fn.compute_text_visual_distance(
+                            raster_sketch_aug, target_prompt
+                        ) * self.cargs.text_visual_coeff
+
+                    # total loss
+                    loss = sds_loss + total_visual_loss + l_percep + l_tvd
+
+                    # optimization
+                    optimizer.zero_grad_()
+                    loss.backward()
+                    optimizer.step_()
+
+                    # if self.step % self.args.pruning_freq == 0:
+                    #     renderer.path_pruning()
+
+                    # update lr
+                    if self.args.lr_scheduler:
+                        optimizer.update_lr(self.step, self.args.lr, self.args.decay_steps)
+
+                    # records
+                    pbar.set_description(
+                        f"lr: {optimizer.get_lr():.2f}, "
+                        f"l_total: {loss.item():.4f}, "
+                        f"l_clip_fc: {l_clip_fc.item():.4f}, "
+                        f"l_clip_conv({len(l_clip_conv)}): {clip_conv_loss_sum.item():.4f}, "
+                        f"l_tvd: {l_tvd.item():.4f}, "
+                        f"l_percep: {l_percep.item():.4f}, "
+                        f"sds: {grad.item():.4e}"
+                    )
+
+                    # log video
+                    if self.args.make_video and (self.step % self.args.video_frame_freq == 0) \
+                            and self.accelerator.is_main_process:
+                        log_tensor_img(raster_sketch, output_dir=self.png_logs_dir,
+                                       output_prefix=f'frame{vid_idx}', dpi=100)
+                        vid_idx += 1
+
+                    # log raster and svg
+                    if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
+                        # log png
+                        plt_batch(inputs,
+                                  raster_sketch,
+                                  self.step,
+                                  target_prompt,
+                                  save_path=self.png_logs_dir.as_posix(),
+                                  name=f"iter{self.step}")
+                        # log svg
+                        renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")
+
+                        # log cross attn
+                        if self.args.log_cross_attn:
+                            controller = AttentionStore()
+                            _, _ = self.diffusion.get_cross_attention([target_prompt],
+                                                                      controller,
+                                                                      res=self.args.cross_attn_res,
+                                                                      from_where=("up", "down"),
+                                                                      save_path=self.attn_logs_dir / f"iter{self.step}.png")
+
+                    # logging the best raster images and SVG
+                    if self.step % self.args.eval_step == 0 and self.accelerator.is_main_process:
+                        with torch.no_grad():
+                            # visual metric
+                            l_clip_fc, l_clip_conv = self.clip_score_fn.compute_visual_distance(
+                                raster_sketch_aug, inputs_aug, clip_norm=False
+                            )
+                            loss_eval = sum(l_clip_conv) + l_clip_fc
+
+                            cur_delta = loss_eval.item() - best_visual_loss
+                            if abs(cur_delta) > min_delta and cur_delta < 0:
+                                best_visual_loss = loss_eval.item()
+                                best_iter_v = self.step
+                                plt_batch(inputs,
+                                          raster_sketch,
+                                          best_iter_v,
+                                          target_prompt,
+                                          save_path=self.results_path.as_posix(),
+                                          name="visual_best")
+                                renderer.save_svg(self.results_path.as_posix(), "visual_best")
+
+                            # semantic metric
+                            loss_eval = self.clip_score_fn.compute_text_visual_distance(
+                                raster_sketch_aug, target_prompt
+                            )
+                            cur_delta = loss_eval.item() - best_semantic_loss
+                            if abs(cur_delta) > min_delta and cur_delta < 0:
+                                best_semantic_loss = loss_eval.item()
+                                best_iter_s = self.step
+                                plt_batch(inputs,
+                                          raster_sketch,
+                                          best_iter_s,
+                                          target_prompt,
+                                          save_path=self.results_path.as_posix(),
+                                          name="semantic_best")
+                                renderer.save_svg(self.results_path.as_posix(), "semantic_best")
+
+                    # log attention
+                    if self.step == 0 and self.args.attention_init and self.accelerator.is_main_process:
+                        plt_attn(renderer.get_attn(),
+                                 renderer.get_thresh(),
+                                 inputs,
+                                 renderer.get_inds(),
+                                 (self.results_path / "attention_map.jpg").as_posix())
+
+                    self.step += 1
+                    pbar.update(1)
+
+            # saving final svg
+            renderer.save_svg(self.svg_logs_dir.as_posix(), "final_svg_tmp")
+            # stroke pruning
+            if self.args.opacity_delta != 0:
+                remove_low_opacity_paths(self.svg_logs_dir / "final_svg_tmp.svg",
+                                         self.results_path / "final_svg.svg",
+                                         self.args.opacity_delta)
+
+            # save raster img
+            final_raster_sketch = renderer.get_image().to(self.device)
+            save_tensor_img(final_raster_sketch,
+                            save_path=self.results_path,
+                            name='final_render')
+
+            # convert the intermediate renderings to a video
+            if self.args.make_video:
+                from subprocess import call
+                call([
+                    "ffmpeg",
+                    "-framerate", 24,
+                    "-i", (self.png_logs_dir / "frame%d.png").as_posix(),
+                    "-vb", "20M",
+                    (self.results_path / "out.mp4").as_posix()
+                ])
+
+        # self.close(msg="painterly rendering complete.")
+
+    def get_target(self,
+                   target_file,
+                   image_size,
+                   output_dir,
+                   u2net_path,
+                   mask_object,
+                   fix_scale,
+                   device):
+        if not is_image_file(target_file):
+            raise TypeError(f"{target_file} is not image file.")
+
+        target = Image.open(target_file)
+
+        if target.mode == "RGBA":
+            # Create a white rgba background
+            new_image = Image.new("RGBA", target.size, "WHITE")
+            # Paste the image on the background.
+            new_image.paste(target, (0, 0), target)
+            target = new_image
+        target = target.convert("RGB")
+
+        # U2Net mask
+        mask = target
+        if mask_object:
+            if pathlib.Path(u2net_path).exists():
+                masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, device)
+                target = masked_im
+            else:
+                self.print(f"'{u2net_path}' is not exist, disable mask target")
+
+        if fix_scale:
+            target = fix_image_scale(target)
+
+        # define image transforms
+        transforms_ = []
+        if target.size[0] != target.size[1]:
+            transforms_.append(transforms.Resize((image_size, image_size)))
+        else:
+            transforms_.append(transforms.Resize(image_size))
+            transforms_.append(transforms.CenterCrop(image_size))
+        transforms_.append(transforms.ToTensor())
+
+        # preprocess
+        data_transforms = transforms.Compose(transforms_)
+        target_ = data_transforms(target).unsqueeze(0).to(self.device)
+
+        return target_, mask
diff --git a/run_painterly_render.py b/run_painterly_render.py
new file mode 100644
index 0000000000000000000000000000000000000000..a441c3581af8fbb8dbcd7f01ade1af1428224f9a
--- /dev/null
+++ b/run_painterly_render.py
@@ -0,0 +1,141 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import sys
+import argparse
+
+from accelerate.utils import set_seed
+
+sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
+
+from libs.engine import merge_and_update_config
+from libs.utils.argparse import accelerate_parser, base_data_parser
+from pipelines.painter.diffsketchedit_pipeline import DiffSketchEditPipeline
+
+
+class PromptInfo:
+    def __init__(self, prompts, token_ind, changing_region_words, reweight_word=None, reweight_weight=None):
+        self.prompts = prompts
+        self.token_ind = token_ind
+        self.changing_region_words = changing_region_words
+        self.reweight_word = reweight_word
+        self.reweight_weight = reweight_weight
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="vary style and content painterly rendering",
+        parents=[accelerate_parser(), base_data_parser()]
+    )
+    # config
+    parser.add_argument("-c", "--config",
+                        type=str,
+                        default="diffsketchedit.yaml",
+                        help="YAML/YML file for configuration.")
+
+    parser.add_argument("-style", "--style_file",
+                        default="", type=str,
+                        help="the path of style img place.")
+
+    # result path
+    parser.add_argument("-respath", "--results_path",
+                        type=str, default="./workdir",
+                        help="If it is None, it is automatically generated.")
+    parser.add_argument("-npt", "--negative_prompt", default="", type=str)
+
+    parser.add_argument("--sd_image_only", default=0, type=int,
+                        help="1 for generating the SD images only; 0 for generating the subsequent vector sketches.")
+
+    parser.add_argument("--vector_local_edit", default=1, type=int)
+    parser.add_argument("--vector_local_edit_bin_threshold_replace", default=0.3, type=float)
+    parser.add_argument("--vector_local_edit_bin_threshold_refine", default=0.3, type=float)
+    parser.add_argument("--vector_local_edit_bin_threshold_reweight", default=0.3, type=float)
+    parser.add_argument("--vector_local_edit_attn_res", default=16, choices=[16, 32, 64], type=int)
+
+    # DiffSVG
+    parser.add_argument("--print_timing", "-timing", action="store_true",
+                        help="set print svg rendering timing.")
+    # diffuser
+    parser.add_argument("--download", default=0, type=int,
+                        help="download models from huggingface automatically.")
+    parser.add_argument("--force_download", "-download", action="store_true",
+                        help="force the models to be downloaded from huggingface.")
+    parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
+                        help="download the models again from the breakpoint.")
+    # rendering quantity
+    # like: python main.py -rdbz -srange 100 200
+    parser.add_argument("--render_batch", "-rdbz", action="store_true")
+    parser.add_argument("-srange", "--seed_range",
+                        required=False, nargs='+',
+                        help="Sampling quantity.")
+    # visual rendering process
+    parser.add_argument("-mv", "--make_video", action="store_true",
+                        help="make a video of the rendering process.")
+    parser.add_argument("-frame_freq", "--video_frame_freq",
+                        default=1, type=int,
+                        help="video frame control.")
+    args = parser.parse_args()
+
+    args = merge_and_update_config(args)
+
+    ############################### main parameters ###############################
+
+    seeds_list = [25760]
+    # seeds_list = [random.randint(1, 65536) for _ in range(100)]
+    args.edit_type = "replace"  # ["replace", "refine", "reweight"]
+    prompt_infos = [
+        ## "replace" examples
+        PromptInfo(prompts=["A painting of a squirrel eating a burger",
+                            "A painting of a rabbit eating a burger",
+                            "A painting of a rabbit eating a pumpkin",
+                            "A painting of a owl eating a pumpkin"],
+                   token_ind=5,
+                   changing_region_words=[["", ""], ["squirrel", "rabbit"], ["burger", "pumpkin"], ["rabbit", "owl"]]),
+
+        # PromptInfo(prompts=["A boy wearing a cap",
+        #                     "A boy wearing a beanie"],
+        #            token_ind=2,
+        #            changing_region_words=[["", ""], ["cap", "beanie"]]),
+
+        # PromptInfo(prompts=["A desk near the bookshelf",
+        #                     "A chair near the bookshelf"],
+        #            token_ind=2,
+        #            changing_region_words=[["", ""], ["desk", "chair"]]),
+
+        ## "refine" examples
+        # PromptInfo(prompts=["An evening dress",
+        #                     "An evening dress with sleeves",
+        #                     "An evening dress with sleeves and a belt"],
+        #            token_ind=3,
+        #            changing_region_words=[["", ""], ["", "sleeves"], ["", "belt"]]),
+
+        ## "reweight" examples
+        # PromptInfo(prompts=["A face with moustache and smile"] * 3,
+        #            token_ind=2,
+        #            changing_region_words=[["", ""], ["moustache", "moustache"], ["smile", "smile"]],
+        #            reweight_word=["moustache", "smile"],
+        #            reweight_weight=[-1.0, 3.0]),
+
+        # PromptInfo(prompts=["A photo of a birthday cake with candles"] * 2,
+        #            token_ind=6,
+        #            changing_region_words=[["", ""], ["candles", "candles"]],
+        #            reweight_word=["candles"],
+        #            reweight_weight=[-5.0])
+    ]
+
+    ############################### main parameters (end) ###############################
+
+    args.batch_size = 1  # rendering one SVG at a time
+    pipe = DiffSketchEditPipeline(args)
+
+    for seed in seeds_list:
+        for prompt_info in prompt_infos:
+            run_stages = len(prompt_info.prompts)
+            for run_stage in range(run_stages):
+                args.run_stage = run_stage
+                set_seed(seed)
+                pipe.update_info(seed, prompt_info.token_ind, prompt_info.prompts[0])
+                pipe.painterly_rendering(prompt_info.prompts,
+                                         prompt_info.token_ind, prompt_info.changing_region_words,
+                                         reweight_word=prompt_info.reweight_word, reweight_weight=prompt_info.reweight_weight)
+                pipe.close(msg="painterly rendering complete.")