Spaces:
Runtime error
Runtime error
Commit
·
c7a40ad
1
Parent(s):
6f88700
stylestudio
Browse files
ip_adapter/attention_processor.py
CHANGED
|
@@ -866,7 +866,8 @@ class AttnProcessor2_0_hijack(torch.nn.Module):
|
|
| 866 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 867 |
|
| 868 |
if self.denoise_step == self.num_inference_step:
|
| 869 |
-
self.denoise_step
|
|
|
|
| 870 |
return hidden_states
|
| 871 |
|
| 872 |
class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
|
|
@@ -1031,6 +1032,6 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
|
|
| 1031 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 1032 |
|
| 1033 |
if self.denoise_step == self.num_inference_step:
|
| 1034 |
-
self.denoise_step
|
| 1035 |
|
| 1036 |
return hidden_states
|
|
|
|
| 866 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 867 |
|
| 868 |
if self.denoise_step == self.num_inference_step:
|
| 869 |
+
self.denoise_step = 0
|
| 870 |
+
|
| 871 |
return hidden_states
|
| 872 |
|
| 873 |
class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
|
|
|
|
| 1032 |
hidden_states = hidden_states / attn.rescale_output_factor
|
| 1033 |
|
| 1034 |
if self.denoise_step == self.num_inference_step:
|
| 1035 |
+
self.denoise_step = 0
|
| 1036 |
|
| 1037 |
return hidden_states
|
ip_adapter/ip_adapter.py
CHANGED
|
@@ -965,9 +965,7 @@ class StyleStudio_Adapter(CSGO):
|
|
| 965 |
if cross_attention_dim is None:
|
| 966 |
attn_procs[name] = AttnProcessor_hijack(
|
| 967 |
fuSAttn=self.fuSAttn,
|
| 968 |
-
|
| 969 |
-
end_fusion=self.end_fusion,
|
| 970 |
-
attn_name=name)
|
| 971 |
else:
|
| 972 |
# layername_id += 1
|
| 973 |
selected = False
|
|
@@ -984,9 +982,7 @@ class StyleStudio_Adapter(CSGO):
|
|
| 984 |
fuAttn=self.fuAttn,
|
| 985 |
fuIPAttn=self.fuIPAttn,
|
| 986 |
adainIP=self.adainIP,
|
| 987 |
-
fuScale=self.fuScale,
|
| 988 |
end_fusion=self.end_fusion,
|
| 989 |
-
attn_name=name,
|
| 990 |
)
|
| 991 |
if selected is False:
|
| 992 |
attn_procs[name] = IPAttnProcessor_cross_modal(
|
|
@@ -995,12 +991,9 @@ class StyleStudio_Adapter(CSGO):
|
|
| 995 |
num_style_tokens=self.num_style_tokens,
|
| 996 |
skip=True,
|
| 997 |
fuAttn=self.fuAttn,
|
| 998 |
-
|
| 999 |
fuIPAttn=self.fuIPAttn,
|
| 1000 |
adainIP=self.adainIP,
|
| 1001 |
-
fuScale=self.fuScale,
|
| 1002 |
end_fusion=self.end_fusion,
|
| 1003 |
-
attn_name=name,
|
| 1004 |
)
|
| 1005 |
|
| 1006 |
attn_procs[name].to(self.device, dtype=torch.float16)
|
|
|
|
| 965 |
if cross_attention_dim is None:
|
| 966 |
attn_procs[name] = AttnProcessor_hijack(
|
| 967 |
fuSAttn=self.fuSAttn,
|
| 968 |
+
end_fusion=self.end_fusion,)
|
|
|
|
|
|
|
| 969 |
else:
|
| 970 |
# layername_id += 1
|
| 971 |
selected = False
|
|
|
|
| 982 |
fuAttn=self.fuAttn,
|
| 983 |
fuIPAttn=self.fuIPAttn,
|
| 984 |
adainIP=self.adainIP,
|
|
|
|
| 985 |
end_fusion=self.end_fusion,
|
|
|
|
| 986 |
)
|
| 987 |
if selected is False:
|
| 988 |
attn_procs[name] = IPAttnProcessor_cross_modal(
|
|
|
|
| 991 |
num_style_tokens=self.num_style_tokens,
|
| 992 |
skip=True,
|
| 993 |
fuAttn=self.fuAttn,
|
|
|
|
| 994 |
fuIPAttn=self.fuIPAttn,
|
| 995 |
adainIP=self.adainIP,
|
|
|
|
| 996 |
end_fusion=self.end_fusion,
|
|
|
|
| 997 |
)
|
| 998 |
|
| 999 |
attn_procs[name].to(self.device, dtype=torch.float16)
|