Spaces:
Build error
Build error
| #!/usr/bin/env python | |
| # -*- coding:utf-8 -*- | |
| # Power by Zongsheng Yue 2022-07-13 16:59:27 | |
| import os | |
| import random | |
| import numpy as np | |
| from math import ceil | |
| from pathlib import Path | |
| from einops import rearrange | |
| from omegaconf import OmegaConf | |
| from skimage import img_as_ubyte | |
| from ResizeRight.resize_right import resize | |
| from utils import util_net | |
| from utils import util_image | |
| from utils import util_common | |
| import torch | |
| import torch.distributed as dist | |
| import torch.multiprocessing as mp | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from basicsr.utils import img2tensor | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.realesrgan_utils import RealESRGANer | |
| from facelib.utils.face_restoration_helper import FaceRestoreHelper | |
class BaseSampler:
    """Base sampler: builds a diffusion process and a denoising network from a config.

    Subclasses may override ``build_model`` to construct additional networks.
    """

    def __init__(self, configs):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/sample/
        '''
        self.configs = configs
        self.display = configs.display
        self.diffusion_cfg = configs.diffusion

        self.setup_dist()   # sets self.device and self.rank
        self.setup_seed()   # seeds the python / numpy / torch RNGs
        self.build_model()

    def setup_seed(self, seed=None):
        """Seed ``random``, ``numpy`` and ``torch`` (CPU and all CUDA devices).

        Args:
            seed: base seed; defaults to ``self.configs.seed``. The value actually
                used is ``seed + (self.rank + 1) * 10000`` so that different ranks
                draw different random streams.
        """
        seed = self.configs.seed if seed is None else seed
        seed += (self.rank + 1) * 10000
        if self.rank == 0 and self.display:
            print(f'Setting random seed {seed}')
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def setup_dist(self):
        """Select the compute device (CUDA when available, else CPU); single process, rank 0."""
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            print('Runing on GPU...')
        else:
            self.device = torch.device('cpu')
            print('Runing on CPU...')
        self.rank = 0

    def build_model(self):
        """Instantiate ``self.diffusion`` and ``self.model`` (eval mode) from the config."""
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**self.configs.diffusion.params)

        obj = util_common.get_obj_from_str(self.configs.model.target)
        model = obj(**self.configs.model.params).to(self.device)
        if self.configs.model.ckpt_path is not None:
            self.load_model(model, self.configs.model.ckpt_path)
        self.model = model
        self.model.eval()

    def load_model(self, model, ckpt_path=None):
        """Load the checkpoint at ``ckpt_path`` into ``model``; no-op when it is None."""
        if ckpt_path is None:
            return
        if self.rank == 0 and self.display:
            print(f'Loading from {ckpt_path}...')
        # Map tensors onto the sampler's device; the previous hard-coded
        # "cuda:{rank}" made loading fail on CPU-only hosts.
        ckpt = torch.load(ckpt_path, map_location=self.device)
        util_net.reload_model(model, ckpt)
        if self.rank == 0 and self.display:
            print('Loaded Done')

    def reset_diffusion(self, diffusion_cfg):
        """Rebuild ``self.diffusion`` with new keyword arguments.

        Uses the configured diffusion target, as ``build_model`` does (the previous
        code called an undefined ``create_gaussian_diffusion``). Assumes
        ``diffusion_cfg`` has the same form as ``configs.diffusion.params``.
        """
        obj = util_common.get_obj_from_str(self.configs.diffusion.target)
        self.diffusion = obj(**diffusion_cfg)
class DifIRSampler(BaseSampler):
    """Face restoration sampler: a restoration network followed by diffusion refinement.

    The restoration model (``self.model_ir``) produces an initial estimate, which is
    diffused to ``start_timesteps`` and then denoised by the diffusion model. For
    unaligned inputs, faces are detected, aligned, restored and pasted back with
    ``FaceRestoreHelper``, with optional Real-ESRGAN background / face upsampling.
    """

    def build_model(self):
        """Build the diffusion model plus the restoration, detection and background models."""
        super().build_model()

        if self.configs.model_ir is not None:
            obj = util_common.get_obj_from_str(self.configs.model_ir.target)
            # Follow self.device (was a hard-coded .cuda(), which breaks on CPU).
            model_ir = obj(**self.configs.model_ir.params).to(self.device)
            if self.configs.model_ir.ckpt_path is not None:
                self.load_model(model_ir, self.configs.model_ir.ckpt_path)
            self.model_ir = model_ir
            self.model_ir.eval()

        if not self.configs.aligned:
            # face detection model
            self.face_helper = FaceRestoreHelper(
                self.configs.detection.upscale,
                face_size=self.configs.im_size,
                crop_ratio=(1, 1),
                det_model=self.configs.detection.det_model,
                save_ext='png',
                use_parse=True,
                device=self.device,
            )

            # background super-resolution
            if self.configs.background_enhance or self.configs.face_upsample:
                bg_model = RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=2,
                )
                on_gpu = self.device.type == 'cuda'
                self.bg_model = RealESRGANer(
                    scale=2,
                    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                    model=bg_model,
                    tile=400,
                    tile_pad=10,
                    pre_pad=0,
                    half=on_gpu,  # half precision must be disabled in CPU mode
                    device=torch.device(f'cuda:{self.rank}') if on_gpu else self.device,
                )

    def sample_func_ir_aligned(
            self,
            y0,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            ):
        '''
        Input:
            y0: n x c x h x w torch tensor, low-quality image, [0, 1], RGB
                or, h x w x c, numpy array, [0, 255], uint8, BGR
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249];
                defaults to num_timesteps-1
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: if False, y0 itself is used as the initial estimate
        Output:
            sample: n x c x h x w, torch tensor, [0,1], RGB
            im_hq: initial estimate clamped to [0, 1], resized to im_size x im_size
        '''
        if not isinstance(y0, torch.Tensor):
            y0 = img2tensor(y0, bgr2rgb=True, float32=True).unsqueeze(0) / 255.  # 1 x c x h x w, [0,1]

        if start_timesteps is None:
            # Largest index within the documented range [0, num_timesteps-1].
            start_timesteps = self.diffusion.num_timesteps - 1

        if post_fun is None:
            post_fun = lambda x: util_image.normalize_th(
                im=x,
                mean=0.5,
                std=0.5,
                reverse=False,
            )

        # basic image restoration
        device = next(self.model.parameters()).device
        y0 = y0.to(device=device, dtype=torch.float32)
        if need_restoration:
            with torch.no_grad():
                if model_kwargs_ir is None:
                    im_hq = self.model_ir(y0)
                else:
                    im_hq = self.model_ir(y0, **model_kwargs_ir)
        else:
            im_hq = y0
        # Out-of-place: with need_restoration=False, im_hq may alias the caller's tensor.
        im_hq = im_hq.clamp(0.0, 1.0)

        im_size = self.configs.im_size
        h_old, w_old = im_hq.shape[2:4]
        needs_resize = not (h_old == im_size and w_old == im_size)
        if needs_resize:
            im_hq = resize(im_hq, out_shape=(im_size,) * 2).to(torch.float32)

        # diffuse for im_hq
        yt = self.diffusion.q_sample(
            x_start=post_fun(im_hq),
            t=torch.tensor([start_timesteps,] * im_hq.shape[0], device=device),
        )
        assert yt.shape[-1] == im_size and yt.shape[-2] == im_size

        loop_kwargs = dict(
            shape=yt.shape,
            noise=yt,
            start_timesteps=start_timesteps,
            clip_denoised=True,
            denoised_fn=None,
            model_kwargs=None,
            device=None,
            progress=False,
        )
        if 'ddim' in self.configs.diffusion.params.timestep_respacing:
            sample = self.diffusion.ddim_sample_loop(self.model, eta=0.0, **loop_kwargs)
        else:
            sample = self.diffusion.p_sample_loop(self.model, **loop_kwargs)

        sample = util_image.normalize_th(sample, reverse=True).clamp(0.0, 1.0)
        if needs_resize:
            sample = resize(sample, out_shape=(h_old, w_old)).clamp(0.0, 1.0)

        return sample, im_hq

    def sample_func_bfr_unaligned(
            self,
            y0,
            bs=16,
            start_timesteps=None,
            post_fun=None,
            model_kwargs_ir=None,
            need_restoration=True,
            only_center_face=False,
            draw_box=False,
            ):
        '''
        Input:
            y0: h x w x c numpy array, uint8, BGR
            bs: batch size (maximum number of faces restored per forward pass)
            start_timesteps: integer, range [0, num_timesteps-1],
                for accelerated sampling (e.g., 'ddim250'), range [0, 249]
            post_fun: post-processing for the enhanced image
            model_kwargs_ir: additional parameters for restoration model
            need_restoration: forwarded to sample_func_ir_aligned
            only_center_face: forwarded to the face detector
            draw_box: draw a box for each face
        Output:
            restored_img: h x w x c, numpy array, uint8, BGR
                (upsampling factor is configs.detection.upscale)
            restored_faces: list, h x w x c, numpy array, uint8, BGR
            cropped_faces: list, h x w x c, numpy array, uint8, BGR
        '''
        def _process_batch(cropped_faces_list):
            # BGR uint8 crops -> RGB float batch in [0, 1] on the sampler's device.
            cropped_face_t = torch.stack(
                img2tensor(cropped_faces_list, bgr2rgb=True, float32=True),
                dim=0,
            ) / 255.
            cropped_face_t = cropped_face_t.to(self.device)
            restored = self.sample_func_ir_aligned(
                cropped_face_t,
                start_timesteps=start_timesteps,
                post_fun=post_fun,
                model_kwargs_ir=model_kwargs_ir,
                need_restoration=need_restoration,
            )[0]  # [0, 1], b x c x h x w
            return restored

        assert not self.configs.aligned

        self.face_helper.clean_all()
        self.face_helper.read_image(y0)
        self.face_helper.get_face_landmarks_5(
            only_center_face=only_center_face,
            resize=640,
            eye_dist_threshold=5,
        )
        # align and warp each face
        self.face_helper.align_warp_face()

        # Restore in chunks of at most `bs` faces. When no face was detected the
        # loop body never runs (previously stacking an empty list crashed).
        detected_faces = self.face_helper.cropped_faces
        restored_faces = []
        for idx_start in range(0, len(detected_faces), bs):
            batch_restored = _process_batch(detected_faces[idx_start:idx_start + bs])
            restored_faces.extend(util_image.tensor2img(
                list(batch_restored.split(1, dim=0)),
                rgb2bgr=True,
                min_max=(0, 1),
                out_type=np.uint8,
            ))

        for xx in restored_faces:
            self.face_helper.add_restored_face(xx)

        # paste_back
        if self.configs.background_enhance:
            bg_img = self.bg_model.enhance(y0, outscale=self.configs.detection.upscale)[0]
        else:
            bg_img = None
        self.face_helper.get_inverse_affine(None)

        # paste each restored face to the input image
        paste_kwargs = dict(upsample_img=bg_img, draw_box=draw_box)
        if self.configs.face_upsample:
            paste_kwargs['face_upsampler'] = self.bg_model
        restored_img = self.face_helper.paste_faces_to_input_image(**paste_kwargs)

        return restored_img, restored_faces, self.face_helper.cropped_faces
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./save_dir",
        help="Folder to save the checkpoints and training log",
    )
    parser.add_argument(
        "--gpu_id",
        type=str,
        default='',
        help="GPU Index, e.g., 025",
    )
    parser.add_argument(
        "--cfg_path",
        type=str,
        default='./configs/sample/iddpm_ffhq256.yaml',
        help="Path of config files",
    )
    parser.add_argument(
        "--bs",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=3000,
        help="Number of sampled images",
    )
    parser.add_argument(
        "--timestep_respacing",
        type=str,
        default='1000',
        help="Sampling steps for accelerate",
    )
    args = parser.parse_args()

    # NOTE(review): this module defines only BaseSampler and DifIRSampler, so a
    # direct reference to DiffusionSampler raised a bare NameError. Look it up
    # and fail fast with a clear message if it is unavailable.
    sampler_cls = globals().get('DiffusionSampler')
    if sampler_cls is None:
        parser.error(
            "DiffusionSampler is not available in this module; "
            "unconditional sampling cannot be run from this script."
        )

    configs = OmegaConf.load(args.cfg_path)
    configs.gpu_id = args.gpu_id
    configs.diffusion.params.timestep_respacing = args.timestep_respacing

    sampler_dist = sampler_cls(configs)
    sampler_dist.sample_func(
        bs=args.bs,
        num_images=args.num_images,
        save_dir=args.save_dir,
    )