Fix model namings
Browse files
app.py
CHANGED
|
@@ -41,9 +41,9 @@ tfms = transforms.Compose([
|
|
| 41 |
transforms.ToTensor(),
|
| 42 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
| 43 |
])
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
|
| 49 |
depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
|
|
@@ -51,12 +51,12 @@ cfg = yaml.load(open(depth_config_path), Loader=yaml.SafeLoader)
|
|
| 51 |
tddfa = TDDFA(gpu_mode=False, **cfg)
|
| 52 |
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
| 56 |
weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
|
| 57 |
-
|
| 58 |
-
optimizer = optim.Adam(
|
| 59 |
-
|
| 60 |
|
| 61 |
|
| 62 |
class Normaliztion_valtest(object):
|
|
@@ -119,7 +119,7 @@ def inference(img, model_name):
|
|
| 119 |
faceRegion = faceRegion.unsqueeze(0)
|
| 120 |
|
| 121 |
if model_name == 'DeePixBiS':
|
| 122 |
-
mask, binary =
|
| 123 |
res = torch.mean(mask).item()
|
| 124 |
cls = 'Real' if res >= pix_threshhold else 'Spoof'
|
| 125 |
res = 1 - res
|
|
@@ -144,7 +144,7 @@ def inference(img, model_name):
|
|
| 144 |
|
| 145 |
map_score = 0.0
|
| 146 |
for frame_t in range(inputs.shape[1]):
|
| 147 |
-
mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input =
|
| 148 |
|
| 149 |
score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
|
| 150 |
map_score += score_norm
|
|
|
|
| 41 |
transforms.ToTensor(),
|
| 42 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
| 43 |
])
|
| 44 |
+
deepix_model = DeePixBiS(pretrained=False)
|
| 45 |
+
deepix_model.load_state_dict(torch.load('./DeePixBiS/DeePixBiS.pth'))
|
| 46 |
+
deepix_model.eval()
|
| 47 |
|
| 48 |
|
| 49 |
depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
|
|
|
|
| 51 |
tddfa = TDDFA(gpu_mode=False, **cfg)
|
| 52 |
|
| 53 |
|
| 54 |
+
cdcn_model = CDCN_u(basic_conv=Conv2d_cd, theta=0.7)
|
| 55 |
+
cdcn_model = cdcn_model.to(device)
|
| 56 |
weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
|
| 57 |
+
cdcn_model.load_state_dict(weights)
|
| 58 |
+
optimizer = optim.Adam(cdcn_model.parameters(), lr=0.001, weight_decay=0.00005)
|
| 59 |
+
cdcn_model.eval()
|
| 60 |
|
| 61 |
|
| 62 |
class Normaliztion_valtest(object):
|
|
|
|
| 119 |
faceRegion = faceRegion.unsqueeze(0)
|
| 120 |
|
| 121 |
if model_name == 'DeePixBiS':
|
| 122 |
+
mask, binary = deepix_model.forward(faceRegion)
|
| 123 |
res = torch.mean(mask).item()
|
| 124 |
cls = 'Real' if res >= pix_threshhold else 'Spoof'
|
| 125 |
res = 1 - res
|
|
|
|
| 144 |
|
| 145 |
map_score = 0.0
|
| 146 |
for frame_t in range(inputs.shape[1]):
|
| 147 |
+
mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input = cdcn_model(inputs[:, frame_t, :, :, :])
|
| 148 |
|
| 149 |
score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
|
| 150 |
map_score += score_norm
|