Update app.py
Browse files
app.py
CHANGED
|
@@ -20,10 +20,10 @@ import models_vit
|
|
| 20 |
|
| 21 |
def prepare_model(chkpt_dir, arch='vit_large_patch14'):
|
| 22 |
# build model
|
| 23 |
-
model = getattr(models_vit, arch)(global_pool=
|
| 24 |
# load model
|
| 25 |
checkpoint = torch.load(chkpt_dir, map_location='cpu')
|
| 26 |
-
msg = model.load_state_dict(checkpoint['model'], strict=
|
| 27 |
print(msg)
|
| 28 |
return model
|
| 29 |
|
|
|
|
| 20 |
|
| 21 |
def prepare_model(chkpt_dir, arch='vit_large_patch14'):
|
| 22 |
# build model
|
| 23 |
+
model = getattr(models_vit, arch)(global_pool=True)
|
| 24 |
# load model
|
| 25 |
checkpoint = torch.load(chkpt_dir, map_location='cpu')
|
| 26 |
+
msg = model.load_state_dict(checkpoint['model'], strict=True)
|
| 27 |
print(msg)
|
| 28 |
return model
|
| 29 |
|