Commit 0c0d385
Parent(s): 02b0827

HarmonyView update

Files changed:
- .idea/workspace.xml +30 -17
- app.py +5 -4
- ldm/models/diffusion/sync_dreamer.py +45 -27
.idea/workspace.xml CHANGED

@@ -4,15 +4,10 @@
     <option name="autoReloadType" value="SELECTIVE" />
   </component>
   <component name="ChangeListManager">
-    <list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/cat.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/crab.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/elephant.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/flower.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/forest.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/monkey.png" afterDir="false" />
-    <change afterPath="$PROJECT_DIR$/hf_demo/examples/teapot.png" afterDir="false" />
+    <list default="true" id="a993d736-6297-4164-9c29-6b2ab1055a96" name="변경" comment="change title">
     <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+    <change beforePath="$PROJECT_DIR$/app.py" beforeDir="false" afterPath="$PROJECT_DIR$/app.py" afterDir="false" />
+    <change beforePath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" beforeDir="false" afterPath="$PROJECT_DIR$/ldm/models/diffusion/sync_dreamer.py" afterDir="false" />
     </list>
     <option name="SHOW_DIALOG" value="false" />
     <option name="HIGHLIGHT_CONFLICTS" value="true" />

@@ -34,14 +29,14 @@
     <option name="hideEmptyMiddlePackages" value="true" />
     <option name="showLibraryContents" value="true" />
   </component>
-  <component name="PropertiesComponent"
-
-
-
-
-
+  <component name="PropertiesComponent"><![CDATA[{
+  "keyToString": {
+    "RunOnceActivity.OpenProjectViewOnStart": "true",
+    "RunOnceActivity.ShowReadmeOnStart": "true",
+    "git-widget-placeholder": "main",
+    "last_opened_file_path": "/home/byeongjun/PycharmProjects/cvpr2024"
   }
-  }
+}]]></component>
   <component name="RecentsManager">
     <key name="CopyFile.RECENT_KEYS">
       <recent name="$PROJECT_DIR$" />

@@ -104,7 +99,23 @@
       <option name="project" value="LOCAL" />
       <updated>1703061633630</updated>
     </task>
-    <
+    <task id="LOCAL-00006" summary="add example code">
+      <option name="closed" value="true" />
+      <created>1703069567948</created>
+      <option name="number" value="00006" />
+      <option name="presentableId" value="LOCAL-00006" />
+      <option name="project" value="LOCAL" />
+      <updated>1703069567948</updated>
+    </task>
+    <task id="LOCAL-00007" summary="change title">
+      <option name="closed" value="true" />
+      <created>1703070569206</created>
+      <option name="number" value="00007" />
+      <option name="presentableId" value="LOCAL-00007" />
+      <option name="project" value="LOCAL" />
+      <updated>1703070569206</updated>
+    </task>
+    <option name="localTasksCounter" value="8" />
     <servers />
   </component>
   <component name="Vcs.Log.Tabs.Properties">

@@ -120,6 +131,8 @@
   </component>
   <component name="VcsManagerConfiguration">
     <MESSAGE value="error resolve" />
-    <
+    <MESSAGE value="add example code" />
+    <MESSAGE value="change title" />
+    <option name="LAST_COMMIT_MESSAGE" value="change title" />
   </component>
 </project>
app.py CHANGED

@@ -79,7 +79,7 @@ def resize_inputs(image_input, crop_size):
     results = add_margin(ref_img_, size=256)
     return results

-def generate(model, sample_steps, batch_view_num, sample_num,
+def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, image_input, elevation_input):
     if deployed:
         assert isinstance(model, SyncMultiviewDiffusion)
         seed=int(seed)

@@ -104,7 +104,7 @@ def generate(model, sample_steps, batch_view_num, sample_num, cfg_scale, seed, i

     if deployed:
         sampler = SyncDDIMSampler(model, sample_steps)
-        x_sample = model.sample(sampler, data,
+        x_sample = model.sample(sampler, data, (cfg_scale_1, cfg_scale_2), batch_view_num)
     else:
         x_sample = torch.zeros(sample_num, 16, 3, 256, 256)

@@ -225,7 +225,8 @@ def run_demo():
                 input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
                 elevation.render()
                 with gr.Accordion('Advanced options', open=False):
-
+                    cfg_scale_1 = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
+                    cfg_scale_2 = gr.Slider(0.5, 1.5, 1.0, step=0.1, label='Classifier free guidance', interactive=True)
                     sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=False, info='How many instance (16 images per instance)')
                     sample_steps = gr.Slider(10, 300, 50, step=10, label='Sample steps', interactive=False)
                     batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)

@@ -252,7 +253,7 @@ def run_demo():
         # crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
         #         .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)

-        run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num,
+        run_btn.click(partial(generate, model), inputs=[sample_steps, batch_view_num, sample_num, cfg_scale_1, cfg_scale_2, seed, input_block, elevation], outputs=[output_block], queue=True)\
                .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)

         demo.queue().launch(share=False, max_threads=80) # auth=("admin", os.environ['PASSWD'])
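Note on the app.py changes above: the demo now exposes two guidance sliders (cfg_scale_1, cfg_scale_2) and forwards them to model.sample as a single tuple, together with batch_view_num. A minimal sketch of that wiring, with a stubbed model since the real SyncMultiviewDiffusion needs its checkpoints to run, might look like this (DummyModel is a stand-in, not from the repository):

```python
import torch

class DummyModel:
    # Stand-in for SyncMultiviewDiffusion; only the call signature matters here.
    def sample(self, sampler, data, cfg_scale, batch_view_num):
        # cfg_scale now arrives as a (cfg_scale_1, cfg_scale_2) tuple, matching the
        # updated call: model.sample(sampler, data, (cfg_scale_1, cfg_scale_2), batch_view_num)
        w1, w2 = cfg_scale
        print(f"sampling with guidance weights w1={w1}, w2={w2}, batch_view_num={batch_view_num}")
        return torch.zeros(1, 16, 3, 256, 256)

cfg_scale_1, cfg_scale_2 = 2.0, 1.0  # defaults of the two new sliders
x_sample = DummyModel().sample(sampler=None, data={}, cfg_scale=(cfg_scale_1, cfg_scale_2), batch_view_num=16)
print(x_sample.shape)
```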
ldm/models/diffusion/sync_dreamer.py CHANGED

@@ -110,6 +110,7 @@ class UNetWrapper(nn.Module):
             v_[k] = torch.cat([v, torch.zeros_like(v)], 0)

         x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
+
         if self.use_zero_123:
             # zero123 does not multiply this when encoding, maybe a bug for zero123
             first_stage_scale_factor = 0.18215

@@ -119,6 +120,24 @@ class UNetWrapper(nn.Module):
         s = s_uc + unconditional_scale * (s - s_uc)
         return s

+    def predict_with_decomposed_unconditional_scales(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scales):
+        x_ = torch.cat([x] * 3, 0)
+        t_ = torch.cat([t] * 3, 0)
+        clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed), clip_embed], 0)
+        x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat), x_concat*4], 0)
+
+        v_ = {}
+        for k, v in volume_feats.items():
+            v_[k] = torch.cat([v, v, torch.zeros_like(v)], 0)
+
+        if self.use_zero_123:
+            # zero123 does not multiply this when encoding, maybe a bug for zero123
+            first_stage_scale_factor = 0.18215
+            x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
+        x_ = torch.cat([x_, x_concat_], 1)
+        s, s_uc1, s_uc2 = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(3)
+        s = s + unconditional_scales[0] * (s - s_uc1) + unconditional_scales[1] * (s - s_uc2)
+        return s

 class SpatialVolumeNet(nn.Module):
     def __init__(self, time_dim, view_dim, view_num,
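Note on the hunk above: predict_with_decomposed_unconditional_scales runs the U-Net on three stacked copies of the batch, one fully conditioned, one with the image-derived conditions (CLIP embedding and concatenated input latents) dropped, and one with the spatial-volume features zeroed, then recombines the three predictions with the two guidance weights. A minimal sketch of just that combination rule (the tensor names are illustrative, not from the repository):

```python
import torch

def decomposed_guidance(eps_full, eps_no_image, eps_no_volume, w1, w2):
    # Combination rule from the hunk above:
    #   s = s + w1 * (s - s_uc1) + w2 * (s - s_uc2)
    # where s is the fully conditioned prediction, s_uc1 drops the image-derived
    # conditions and s_uc2 drops the spatial-volume features.
    return eps_full + w1 * (eps_full - eps_no_image) + w2 * (eps_full - eps_no_volume)

eps = torch.randn(2, 4, 32, 32)
out = decomposed_guidance(eps, torch.randn_like(eps), torch.randn_like(eps), w1=2.0, w2=1.0)
print(out.shape)  # torch.Size([2, 4, 32, 32])
```

Note that setting w2 = 0 collapses this to single-condition guidance with an effective scale of w1 + 1, since s + w1 * (s - s_uc1) = s_uc1 + (1 + w1) * (s - s_uc1).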
@@ -156,13 +175,12 @@
         device = x.device

         spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
-        spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
+        spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts, indexing='ij'), -1)
         spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
         spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)

         # encode source features
         t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
-        # v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
         v_embed_ = v_embed
         target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
         target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)

@@ -227,7 +245,8 @@ class SyncMultiviewDiffusion(pl.LightningModule):
                  view_num=16, image_size=256,
                  cfg_scale=3.0, output_num=8, batch_view_num=4,
                  drop_conditions=False, drop_scheme='default',
-                 clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"
+                 clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt",
+                 sample_type='ddim', sample_steps=200):
         super().__init__()

         self.finetune_unet = finetune_unet

@@ -255,7 +274,10 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         self.scheduler_config = scheduler_config

         latent_size = image_size//8
-
+        if sample_type=='ddim':
+            self.sampler = SyncDDIMSampler(self, sample_steps , "uniform", 1.0, latent_size=latent_size)
+        else:
+            raise NotImplementedError

     def _init_clip_projection(self):
         self.cc_projection = nn.Linear(772, 768)

@@ -468,9 +490,9 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
         return x_noisy, noise

-    def sample(self, sampler, batch, cfg_scale,
+    def sample(self, sampler, batch, cfg_scale, return_inter_results=False, inter_interval=50, inter_view_interval=2):
         _, clip_embed, input_info = self.prepare(batch)
-        x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval
+        x_sample, inter = sampler.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval)

         N = x_sample.shape[1]
         x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)

@@ -509,7 +531,7 @@ class SyncMultiviewDiffusion(pl.LightningModule):
         step = self.global_step
         batch_ = {}
         for k, v in batch.items(): batch_[k] = v[:self.output_num]
-        x_sample = self.sample(batch_, self.cfg_scale
+        x_sample = self.sample(self.sampler, batch_, self.cfg_scale)
         output_dir = Path(self.image_dir) / 'images' / 'val'
         output_dir.mkdir(exist_ok=True, parents=True)
         self.log_image(x_sample, batch, step, output_dir=output_dir)

@@ -588,7 +610,7 @@ class SyncDDIMSampler:
         return x_prev

     @torch.no_grad()
-    def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale,
+    def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=False):
         """
         @param x_target_noisy: B,N,4,H,W
         @param input_info:

@@ -596,7 +618,6 @@ class SyncDDIMSampler:
         @param time_steps: B,
         @param index: int
         @param unconditional_scale:
-        @param batch_view_num: int
         @param is_step0: bool
         @return:
         """

@@ -608,37 +629,34 @@
         t_embed = self.model.embed_time(time_steps) # B,t_dim
         spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)

-
-
-
-
-
-
-
-
-            target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
-            clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
-            if unconditional_scale!=1.0:
+        target_indices_ = torch.arange(N).unsqueeze(0).repeat(B, 1)
+        x_target_noisy_ = x_target_noisy.reshape(B*N,C,H,W)
+
+        time_steps_ = repeat_to_batch(time_steps, B, N)
+        clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
+
+        if type(unconditional_scale) == float: ## CFG
+            if unconditional_scale != 1.0:
                 noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
             else:
                 noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
-
+        else: ## DG
+            noise = self.model.model.predict_with_decomposed_unconditional_scales(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)

-
-        x_prev = self.denoise_apply_impl(x_target_noisy, index,
+        noise = noise.reshape(B, N, 4, H, W)
+        x_prev = self.denoise_apply_impl(x_target_noisy, index, noise, is_step0)
         return x_prev

     @torch.no_grad()
-    def sample(self, input_info, clip_embed, unconditional_scale
+    def sample(self, input_info, clip_embed, unconditional_scale, log_every_t=50):
         """
         @param input_info: x, elevation
         @param clip_embed: B,M,768
         @param unconditional_scale:
         @param log_every_t:
-        @param batch_view_num:
         @return:
         """
-
+
         C, H, W = 4, self.latent_size, self.latent_size
         B = clip_embed.shape[0]
         N = self.model.view_num

@@ -654,7 +672,7 @@
         for i, step in enumerate(iterator):
             index = total_steps - i - 1 # index in ddim state
             time_steps = torch.full((B,), step, device=device, dtype=torch.long)
-            x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale,
+            x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, is_step0=index==0)
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(x_target_noisy)
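Note on denoise_apply and sample above: unconditional_scale can now be either a plain float (the original single-scale CFG path is kept unchanged) or a pair of weights, and the sampler dispatches on its type; app.py feeds the pair (cfg_scale_1, cfg_scale_2) through model.sample. A minimal sketch of that dispatch pattern, where predict_cfg and predict_decomposed are illustrative stand-ins rather than functions from the repository:

```python
from typing import Tuple, Union
import torch

def predict_cfg(x: torch.Tensor, scale: float) -> torch.Tensor:
    # stand-in for predict_with_unconditional_scale(...)
    return x * scale

def predict_decomposed(x: torch.Tensor, scales: Tuple[float, float]) -> torch.Tensor:
    # stand-in for predict_with_decomposed_unconditional_scales(...)
    return x * (1.0 + scales[0] + scales[1])

def denoise_step(x: torch.Tensor, unconditional_scale: Union[float, Tuple[float, float]]) -> torch.Tensor:
    # A float keeps the original single-scale CFG behaviour; a (w1, w2) pair
    # routes to the decomposed-guidance path added in this commit.
    if isinstance(unconditional_scale, float):
        if unconditional_scale != 1.0:
            return predict_cfg(x, unconditional_scale)
        return x  # scale 1.0 means no guidance, plain conditional prediction
    return predict_decomposed(x, unconditional_scale)

print(denoise_step(torch.ones(2), 2.0))         # CFG path
print(denoise_step(torch.ones(2), (2.0, 1.0)))  # decomposed-guidance path
```

As a side note, isinstance(unconditional_scale, float) is the more idiomatic check than type(unconditional_scale) == float, although both behave the same for a plain Python float passed in from the sliders.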