Duplicate from cooelf/Multimodal-CoT
Co-authored-by: Zhuosheng Zhang <[email protected]>
This view is limited to 50 files because it contains too many changes.
- .gitattributes +36 -0
- README.md +13 -0
- __pycache__/model.cpython-37.pyc +0 -0
- __pycache__/model.cpython-38.pyc +0 -0
- api/61.png +0 -0
- app.py +150 -0
- model.py +515 -0
- requirements.txt +6 -0
- timm/__init__.py +4 -0
- timm/__pycache__/__init__.cpython-37.pyc +0 -0
- timm/__pycache__/__init__.cpython-38.pyc +0 -0
- timm/__pycache__/version.cpython-37.pyc +0 -0
- timm/__pycache__/version.cpython-38.pyc +0 -0
- timm/data/__init__.py +12 -0
- timm/data/__pycache__/__init__.cpython-37.pyc +0 -0
- timm/data/__pycache__/__init__.cpython-38.pyc +0 -0
- timm/data/__pycache__/auto_augment.cpython-37.pyc +0 -0
- timm/data/__pycache__/auto_augment.cpython-38.pyc +0 -0
- timm/data/__pycache__/config.cpython-37.pyc +0 -0
- timm/data/__pycache__/config.cpython-38.pyc +0 -0
- timm/data/__pycache__/constants.cpython-37.pyc +0 -0
- timm/data/__pycache__/constants.cpython-38.pyc +0 -0
- timm/data/__pycache__/dataset.cpython-37.pyc +0 -0
- timm/data/__pycache__/dataset.cpython-38.pyc +0 -0
- timm/data/__pycache__/dataset_factory.cpython-37.pyc +0 -0
- timm/data/__pycache__/dataset_factory.cpython-38.pyc +0 -0
- timm/data/__pycache__/distributed_sampler.cpython-37.pyc +0 -0
- timm/data/__pycache__/distributed_sampler.cpython-38.pyc +0 -0
- timm/data/__pycache__/loader.cpython-37.pyc +0 -0
- timm/data/__pycache__/loader.cpython-38.pyc +0 -0
- timm/data/__pycache__/mixup.cpython-37.pyc +0 -0
- timm/data/__pycache__/mixup.cpython-38.pyc +0 -0
- timm/data/__pycache__/random_erasing.cpython-37.pyc +0 -0
- timm/data/__pycache__/random_erasing.cpython-38.pyc +0 -0
- timm/data/__pycache__/real_labels.cpython-37.pyc +0 -0
- timm/data/__pycache__/real_labels.cpython-38.pyc +0 -0
- timm/data/__pycache__/transforms.cpython-37.pyc +0 -0
- timm/data/__pycache__/transforms.cpython-38.pyc +0 -0
- timm/data/__pycache__/transforms_factory.cpython-37.pyc +0 -0
- timm/data/__pycache__/transforms_factory.cpython-38.pyc +0 -0
- timm/data/auto_augment.py +822 -0
- timm/data/config.py +78 -0
- timm/data/constants.py +7 -0
- timm/data/dataset.py +146 -0
- timm/data/dataset_factory.py +30 -0
- timm/data/distributed_sampler.py +51 -0
- timm/data/loader.py +262 -0
- timm/data/mixup.py +316 -0
- timm/data/parsers/__init__.py +1 -0
- timm/data/parsers/__pycache__/__init__.cpython-37.pyc +0 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+qa9.jpg filter=lfs diff=lfs merge=lfs -text
+upload4.jpg filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: Multimodal-CoT
+emoji: 🏖️
+colorFrom: red
+colorTo: blue
+sdk: gradio
+app_file: app.py
+pinned: false
+license: openrail
+duplicated_from: cooelf/Multimodal-CoT
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/model.cpython-37.pyc
ADDED
Binary file (12.1 kB)
__pycache__/model.cpython-38.pyc
ADDED
Binary file (12.2 kB)
api/61.png
ADDED
app.py
ADDED
@@ -0,0 +1,150 @@
+import string
+import gradio as gr
+import requests
+import torch
+from transformers import T5Tokenizer
+from model import T5ForMultimodalGeneration
+from PIL import Image
+import timm
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
+
+rationale_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Rationale-Joint"
+answer_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Answer-Joint"
+
+vit_model = timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=0)
+vit_model.eval()
+config = resolve_data_config({}, model=vit_model)
+transform = create_transform(**config)
+tokenizer = T5Tokenizer.from_pretrained(rationale_model_dir)
+r_model = T5ForMultimodalGeneration.from_pretrained(rationale_model_dir, patch_size=(577, 768))
+a_model = T5ForMultimodalGeneration.from_pretrained(answer_model_dir, patch_size=(577, 768))
+
+def inference_chat(input_image, input_text):
+    with torch.no_grad():
+        # print(input_image)
+        # img = Image.open(input_image).convert("RGB")
+        input = transform(input_image).unsqueeze(0)
+        out = vit_model.forward_features(input)
+        image_features = out.detach()
+
+        source = tokenizer.batch_encode_plus(
+            [input_text],
+            max_length=512,
+            pad_to_max_length=True,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt",
+        )
+        source_ids = source["input_ids"]
+        source_mask = source["attention_mask"]
+        rationale = r_model.generate(
+            input_ids=source_ids,
+            attention_mask=source_mask,
+            image_ids=image_features,
+            max_length=512,
+            num_beams=1,
+            do_sample=False
+        )
+        rationale = tokenizer.batch_decode(rationale, skip_special_tokens=True)[0]
+        print(rationale)
+
+        input_text = input_text + "\n" + rationale + "\nAnswer:"
+        print(input_text)
+
+        source = tokenizer.batch_encode_plus(
+            [input_text],
+            max_length=512,
+            pad_to_max_length=True,
+            truncation=True,
+            padding="max_length",
+            return_tensors="pt",
+        )
+        source_ids = source["input_ids"]
+        source_mask = source["attention_mask"]
+        answer = a_model.generate(
+            input_ids=source_ids,
+            attention_mask=source_mask,
+            image_ids=image_features,
+            max_length=64,
+            num_beams=1,
+            do_sample=False
+        )
+
+        answer = tokenizer.batch_decode(answer, skip_special_tokens=True)[0]
+    return rationale, answer
+
+
+title = """# Multimodal-CoT"""
+# description = """**VLE** (Visual-Language Encoder) is an image-text multimodal understanding model built on the pre-trained text and image encoders. See https://github.com/iflytek/VLE for more details.
+# We demonstrate visual question answering systems built with VLE and LLM."""
+# description1 = """**VQA**: The image and the question are fed to a VQA model (VLEForVQA) and the model predicts the answer.
+
+# **VQA+LLM**: We feed the caption, question, and answers predicted by the VQA model to the LLM and ask the LLM to generate the final answer. The outputs from VQA+LLM may vary due to the decoding strategy of the LLM."""
+
+with gr.Blocks(
+    css="""
+    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
+    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
+    """
+) as iface:
+    state = gr.State([])
+    # caption_output = None
+    gr.Markdown(title)
+    # gr.Markdown(description)
+    # gr.Markdown(article)
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Image")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    chat_input = gr.Textbox(lines=1, label="Question")
+                    with gr.Row():
+                        clear_button = gr.Button(value="Clear", interactive=True, width=30)
+                        submit_button = gr.Button(
+                            value="Submit", interactive=True, variant="primary"
+                        )
+                    '''
+                    cap_submit_button = gr.Button(
+                        value="Submit_CAP", interactive=True, variant="primary"
+                    )
+                    gpt3_submit_button = gr.Button(
+                        value="Submit_GPT3", interactive=True, variant="primary"
+                    )
+                    '''
+        with gr.Column():
+            # gr.Markdown(description1)
+            rationale = gr.Textbox(lines=0, label="Rationale")
+            answer = gr.Textbox(lines=0, label="Answer")
+
+    chat_input.submit(
+        inference_chat,
+        [
+            image_input,
+            chat_input,
+        ],
+        [rationale, answer],
+    )
+    clear_button.click(
+        lambda: ("", [], "", ""),
+        [],
+        [chat_input, state, rationale, answer],
+        queue=False,
+    )
+    submit_button.click(
+        inference_chat,
+        [
+            image_input,
+            chat_input,
+        ],
+        [rationale, answer],
+    )
+    examples = [['api/61.png', "Question: Think about the magnetic force between the magnets in each pair. Which of the following statements is true?\nContext: The images below show two pairs of magnets. The magnets in different pairs do not affect each other. All the magnets shown are made of the same material, but some of them are different sizes and shapes.\nOptions: (A) The magnitude of the magnetic force is the same in both pairs. (B) The magnitude of the magnetic force is smaller in Pair 1. (C) The magnitude of the magnetic force is smaller in Pair 2.\nSolution:", "Magnet sizes affect the magnitude of the magnetic force. Imagine magnets that are the same shape and made of the same material. The smaller the magnets, the smaller the magnitude of the magnetic force between them.\nMagnet A is the same size in both pairs. But Magnet B is smaller in Pair 2 than in Pair 1. So, the magnitude of the magnetic force is smaller in Pair 2 than in Pair 1.", "The answer is (C)."],
+    ]
+    examples = gr.Examples(
+        examples=examples, inputs=[image_input, chat_input, rationale, answer],
+    )
+
+iface.queue(concurrency_count=1, api_open=False, max_size=10)
+iface.launch(enable_queue=True)
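
Note: `app.py` wires a two-stage pipeline into Gradio — stage one generates a rationale from the question plus ViT image features, stage two appends that rationale to the prompt and generates the short answer. As a minimal sketch (not part of this commit), assuming the module-level objects above are already loaded and the bundled `api/61.png` exists, the pipeline can be exercised without the UI:

```python
# Minimal sketch: call the two-stage pipeline directly, bypassing Gradio.
# Assumes app.py's globals (vit_model, transform, tokenizer, r_model, a_model)
# are loaded; the question string is an illustrative placeholder.
from PIL import Image

image = Image.open("api/61.png").convert("RGB")
question = "Question: Which statement is true?\nOptions: (A) ... (B) ... (C) ...\nSolution:"

rationale, answer = inference_chat(image, question)  # stage 1, then stage 2
print(rationale)
print(answer)
```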
model.py
ADDED
@@ -0,0 +1,515 @@
+'''
+Adapted from https://github.com/huggingface/transformers
+'''
+
+from transformers import T5Config, T5ForConditionalGeneration
+from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5Block, T5LayerNorm
+import copy
+from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
+import math
+import os
+import warnings
+from typing import Optional, Tuple, Union
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    Seq2SeqLMOutput,
+)
+from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+from torch.utils.checkpoint import checkpoint
+
+class JointEncoder(T5Stack):
+    def __init__(self, config, embed_tokens=None, patch_size=None):
+        super().__init__(config)
+
+        self.embed_tokens = embed_tokens
+        self.is_decoder = config.is_decoder
+
+        self.patch_num, self.patch_dim = patch_size
+        self.image_dense = nn.Linear(self.patch_dim, config.d_model)
+        self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
+        self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
+        self.sigmoid = nn.Sigmoid()
+
+        self.block = nn.ModuleList(
+            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+        )
+        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+    def parallelize(self, device_map=None):
+        warnings.warn(
+            "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
+            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
+            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
+            " 'block.1': 1, ...}",
+            FutureWarning,
+        )
+        # Check validity of device_map
+        self.device_map = (
+            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
+        )
+        assert_device_map(self.device_map, len(self.block))
+        self.model_parallel = True
+        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
+        self.last_device = "cuda:" + str(max(self.device_map.keys()))
+        # Load onto devices
+        for k, v in self.device_map.items():
+            for layer in v:
+                cuda_device = "cuda:" + str(k)
+                self.block[layer] = self.block[layer].to(cuda_device)
+
+        # Set embed_tokens to first layer
+        self.embed_tokens = self.embed_tokens.to(self.first_device)
+        # Set final layer norm to last device
+        self.final_layer_norm = self.final_layer_norm.to(self.last_device)
+
+    def deparallelize(self):
+        warnings.warn(
+            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
+        self.model_parallel = False
+        self.device_map = None
+        self.first_device = "cpu"
+        self.last_device = "cpu"
+        for i in range(len(self.block)):
+            self.block[i] = self.block[i].to("cpu")
+        self.embed_tokens = self.embed_tokens.to("cpu")
+        self.final_layer_norm = self.final_layer_norm.to("cpu")
+        torch.cuda.empty_cache()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, new_embeddings):
+        self.embed_tokens = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        inputs_embeds=None,
+        image_ids=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        # Model parallel
+        if self.model_parallel:
+            torch.cuda.set_device(self.first_device)
+            self.embed_tokens = self.embed_tokens.to(self.first_device)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
+
+        if inputs_embeds is None:
+            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        batch_size, seq_length = input_shape
+
+        # required mask seq length can be calculated via length of past
+        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
+
+        if use_cache is True:
+            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
+            encoder_seq_length = encoder_hidden_states.shape[1]
+            encoder_attention_mask = torch.ones(
+                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+            )
+
+        # initialize past_key_values with `None` if past does not exist
+        if past_key_values is None:
+            past_key_values = [None] * len(self.block)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
+        present_key_value_states = () if use_cache else None
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+        position_bias = None
+        encoder_decoder_position_bias = None
+
+        hidden_states = self.dropout(inputs_embeds)
+
+        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
+            layer_head_mask = head_mask[i]
+            cross_attn_layer_head_mask = cross_attn_head_mask[i]
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if position_bias is not None:
+                    position_bias = position_bias.to(hidden_states.device)
+                if encoder_hidden_states is not None:
+                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
+                if encoder_extended_attention_mask is not None:
+                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
+                if encoder_decoder_position_bias is not None:
+                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
+                if layer_head_mask is not None:
+                    layer_head_mask = layer_head_mask.to(hidden_states.device)
+                if cross_attn_layer_head_mask is not None:
+                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                if use_cache:
+                    logger.warning_once(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return tuple(module(*inputs, use_cache, output_attentions))
+
+                    return custom_forward
+
+                layer_outputs = checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            # layer_outputs is a tuple with:
+            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+            hidden_states, present_key_value_state = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer store them
+            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+            # (cross-attention position bias), (cross-attention weights)
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
+            # append next layer key value states
+            if use_cache:
+                present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[3],)
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        image_embedding = self.image_dense(image_ids)
+        image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
+        merge = torch.cat([hidden_states, image_att], dim=-1)
+        gate = self.sigmoid(self.gate_dense(merge))
+        hidden_states = (1 - gate) * hidden_states + gate * image_att
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    present_key_value_states,
+                    all_hidden_states,
+                    all_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class T5ForMultimodalGeneration(T5ForConditionalGeneration):
+    _keys_to_ignore_on_load_missing = [
+        r"encoder.embed_tokens.weight",
+        r"decoder.embed_tokens.weight",
+        r"lm_head.weight",
+    ]
+    _keys_to_ignore_on_load_unexpected = [
+        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+    ]
+
+    def __init__(self, config: T5Config, patch_size):
+        super().__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+        # self.encoder = T5Stack(encoder_config, self.shared)
+        self.encoder = JointEncoder(encoder_config, self.shared, patch_size)
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = T5Stack(decoder_config, self.shared)
+
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_ids=None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+        if head_mask is not None and decoder_head_mask is None:
+            if self.config.num_layers == self.config.num_decoder_layers:
+                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
+                decoder_head_mask = head_mask
+
+        # Encode if needed (training, first prediction pass)
+        if encoder_outputs is None:
+            # Convert encoder inputs in embeddings if needed
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                inputs_embeds=inputs_embeds,
+                image_ids=image_ids,
+                head_mask=head_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+
+        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+            hidden_states = hidden_states.to(self.decoder.first_device)
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(self.decoder.first_device)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.encoder.first_device)
+            self.lm_head = self.lm_head.to(self.encoder.first_device)
+            sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+
+        lm_logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+        if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
+        output = {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+        if "image_ids" in kwargs:
+            output["image_ids"] = kwargs['image_ids']
+
+        return output
+
+    def test_step(self, tokenizer, batch, **kwargs):
+        device = next(self.parameters()).device
+        input_ids = batch['input_ids'].to(device)
+        image_ids = batch['image_ids'].to(device)
+
+        output = self.generate(
+            input_ids=input_ids,
+            image_ids=image_ids,
+            **kwargs
+        )
+
+        generated_sents = tokenizer.batch_decode(output, skip_special_tokens=True)
+        targets = tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)
+
+        result = {}
+        result['preds'] = generated_sents
+        result['targets'] = targets
+
+        return result
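
Note: the multimodal piece of `model.py` is small. `JointEncoder` runs a standard T5 encoder, then fuses the text hidden states with projected ViT patch features via single-head cross-attention and a sigmoid gate, near the end of `JointEncoder.forward`. A standalone sketch of that fusion step (not part of this commit), with shapes matching the app's `patch_size=(577, 768)` and T5-base's `d_model=768`, and random tensors standing in for real features:

```python
# Standalone sketch of JointEncoder's gated cross-attention fusion.
# Dimensions follow the app: 577 ViT patches x 768 dims; T5-base d_model = 768.
import torch
from torch import nn

d_model, n_patches, seq_len = 768, 577, 512

image_dense = nn.Linear(d_model, d_model)  # patch_dim happens to equal d_model here
mha_layer = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, batch_first=True)
gate_dense = nn.Linear(2 * d_model, d_model)

hidden_states = torch.randn(1, seq_len, d_model)  # T5 encoder output (queries)
image_ids = torch.randn(1, n_patches, d_model)    # ViT patch features (keys/values)

image_embedding = image_dense(image_ids)
image_att, _ = mha_layer(hidden_states, image_embedding, image_embedding)
gate = torch.sigmoid(gate_dense(torch.cat([hidden_states, image_att], dim=-1)))
fused = (1 - gate) * hidden_states + gate * image_att  # per-position, per-channel mix
print(fused.shape)  # torch.Size([1, 512, 768])
```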
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+huggingface-hub>=0.4.0
+git+https://github.com/huggingface/transformers.git
+torch
+torchvision
+sentencepiece
+numpy
timm/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .version import __version__
+from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
+    get_model_default_value, is_model_pretrained
timm/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (537 Bytes)
timm/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (541 Bytes)
timm/__pycache__/version.cpython-37.pyc
ADDED
Binary file (156 Bytes)
timm/__pycache__/version.cpython-38.pyc
ADDED
Binary file (160 Bytes)
timm/data/__init__.py
ADDED
@@ -0,0 +1,12 @@
+from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
+    rand_augment_transform, auto_augment_transform
+from .config import resolve_data_config
+from .constants import *
+from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
+from .dataset_factory import create_dataset
+from .loader import create_loader
+from .mixup import Mixup, FastCollateMixup
+from .parsers import create_parser
+from .real_labels import RealLabelsImagenet
+from .transforms import *
+from .transforms_factory import create_transform
timm/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (848 Bytes)
timm/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (852 Bytes)
timm/data/__pycache__/auto_augment.cpython-37.pyc
ADDED
Binary file (25.2 kB)
timm/data/__pycache__/auto_augment.cpython-38.pyc
ADDED
Binary file (23.3 kB)
timm/data/__pycache__/config.cpython-37.pyc
ADDED
Binary file (1.59 kB)
timm/data/__pycache__/config.cpython-38.pyc
ADDED
Binary file (1.6 kB)
timm/data/__pycache__/constants.cpython-37.pyc
ADDED
Binary file (483 Bytes)
timm/data/__pycache__/constants.cpython-38.pyc
ADDED
Binary file (479 Bytes)
timm/data/__pycache__/dataset.cpython-37.pyc
ADDED
Binary file (5 kB)
timm/data/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (5.06 kB)
timm/data/__pycache__/dataset_factory.cpython-37.pyc
ADDED
Binary file (938 Bytes)
timm/data/__pycache__/dataset_factory.cpython-38.pyc
ADDED
Binary file (966 Bytes)
timm/data/__pycache__/distributed_sampler.cpython-37.pyc
ADDED
Binary file (2.06 kB)
timm/data/__pycache__/distributed_sampler.cpython-38.pyc
ADDED
Binary file (2.08 kB)
timm/data/__pycache__/loader.cpython-37.pyc
ADDED
Binary file (7.09 kB)
timm/data/__pycache__/loader.cpython-38.pyc
ADDED
Binary file (7.13 kB)
timm/data/__pycache__/mixup.cpython-37.pyc
ADDED
Binary file (11.5 kB)
timm/data/__pycache__/mixup.cpython-38.pyc
ADDED
Binary file (11.4 kB)
timm/data/__pycache__/random_erasing.cpython-37.pyc
ADDED
Binary file (3.66 kB)
timm/data/__pycache__/random_erasing.cpython-38.pyc
ADDED
Binary file (3.69 kB)
timm/data/__pycache__/real_labels.cpython-37.pyc
ADDED
Binary file (2.37 kB)
timm/data/__pycache__/real_labels.cpython-38.pyc
ADDED
Binary file (2.4 kB)
timm/data/__pycache__/transforms.cpython-37.pyc
ADDED
Binary file (5.66 kB)
timm/data/__pycache__/transforms.cpython-38.pyc
ADDED
Binary file (5.7 kB)
timm/data/__pycache__/transforms_factory.cpython-37.pyc
ADDED
Binary file (5.01 kB)
timm/data/__pycache__/transforms_factory.cpython-38.pyc
ADDED
Binary file (5.06 kB)
timm/data/auto_augment.py
ADDED
@@ -0,0 +1,822 @@
+""" AutoAugment, RandAugment, and AugMix for PyTorch
+
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+
+AA and RA Implementation adapted from:
+    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+
+AugMix adapted from:
+    https://github.com/google-research/augmix
+
+Papers:
+    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
+    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
+    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import random
+import math
+import re
+from PIL import Image, ImageOps, ImageEnhance, ImageChops
+import PIL
+import numpy as np
+
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+_HPARAMS_DEFAULT = dict(
+    translate_const=250,
+    img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+    interpolation = kwargs.pop('resample', Image.BILINEAR)
+    if isinstance(interpolation, (list, tuple)):
+        return random.choice(interpolation)
+    else:
+        return interpolation
+
+
+def _check_args_tf(kwargs):
+    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+        kwargs.pop('fillcolor')
+    kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+    pixels = pct * img.size[0]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+    pixels = pct * img.size[1]
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+    _check_args_tf(kwargs)
+    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+    _check_args_tf(kwargs)
+    if _PIL_VER >= (5, 2):
+        return img.rotate(degrees, **kwargs)
+    elif _PIL_VER >= (5, 0):
+        w, h = img.size
+        post_trans = (0, 0)
+        rotn_center = (w / 2.0, h / 2.0)
+        angle = -math.radians(degrees)
+        matrix = [
+            round(math.cos(angle), 15),
+            round(math.sin(angle), 15),
+            0.0,
+            round(-math.sin(angle), 15),
+            round(math.cos(angle), 15),
+            0.0,
+        ]
+
+        def transform(x, y, matrix):
+            (a, b, c, d, e, f) = matrix
+            return a * x + b * y + c, d * x + e * y + f
+
+        matrix[2], matrix[5] = transform(
+            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+        )
+        matrix[2] += rotn_center[0]
+        matrix[5] += rotn_center[1]
+        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+    else:
+        return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+    return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+    return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+    return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+    return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+    lut = []
+    for i in range(256):
+        if i < thresh:
+            lut.append(min(255, i + add))
+        else:
+            lut.append(i)
+    if img.mode in ("L", "RGB"):
+        if img.mode == "RGB" and len(lut) == 256:
+            lut = lut + lut + lut
+        return img.point(lut)
+    else:
+        return img
+
+
+def posterize(img, bits_to_keep, **__):
+    if bits_to_keep >= 8:
+        return img
+    return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+    return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+    return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+    return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+    return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+    """With 50% prob, negate the value"""
+    return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+    # range [-30, 30]
+    level = (level / _MAX_LEVEL) * 30.
+    level = _randomly_negate(level)
+    return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+    # range [0.1, 1.9]
+    return (level / _MAX_LEVEL) * 1.8 + 0.1,
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+    # range [0.1, 1.9]
+    level = (level / _MAX_LEVEL) * .9
+    level = 1.0 + _randomly_negate(level)
+    return level,
+
+
+def _shear_level_to_arg(level, _hparams):
+    # range [-0.3, 0.3]
+    level = (level / _MAX_LEVEL) * 0.3
+    level = _randomly_negate(level)
+    return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+    translate_const = hparams['translate_const']
+    level = (level / _MAX_LEVEL) * float(translate_const)
+    level = _randomly_negate(level)
+    return level,
+
+
+def _translate_rel_level_to_arg(level, hparams):
+    # default range [-0.45, 0.45]
+    translate_pct = hparams.get('translate_pct', 0.45)
+    level = (level / _MAX_LEVEL) * translate_pct
+    level = _randomly_negate(level)
+    return level,
+
+
+def _posterize_level_to_arg(level, _hparams):
+    # As per Tensorflow TPU EfficientNet impl
+    # range [0, 4], 'keep 0 up to 4 MSB of original image'
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+    # As per Tensorflow models research and UDA impl
+    # range [4, 0], 'keep 4 down to 0 MSB of original image',
+    # intensity/severity of augmentation increases with level
+    return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+    # As per original AutoAugment paper description
+    # range [4, 8], 'keep 4 up to 8 MSB of image'
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 4) + 4,
+
+
+def _solarize_level_to_arg(level, _hparams):
+    # range [0, 256]
+    # intensity/severity of augmentation decreases with level
+    return int((level / _MAX_LEVEL) * 256),
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+    # range [0, 256]
+    # intensity/severity of augmentation increases with level
+    return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+    # range [0, 110]
+    return int((level / _MAX_LEVEL) * 110),
+
+
+LEVEL_TO_ARG = {
+    'AutoContrast': None,
+    'Equalize': None,
+    'Invert': None,
+    'Rotate': _rotate_level_to_arg,
+    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
|
| 266 |
+
'Posterize': _posterize_level_to_arg,
|
| 267 |
+
'PosterizeIncreasing': _posterize_increasing_level_to_arg,
|
| 268 |
+
'PosterizeOriginal': _posterize_original_level_to_arg,
|
| 269 |
+
'Solarize': _solarize_level_to_arg,
|
| 270 |
+
'SolarizeIncreasing': _solarize_increasing_level_to_arg,
|
| 271 |
+
'SolarizeAdd': _solarize_add_level_to_arg,
|
| 272 |
+
'Color': _enhance_level_to_arg,
|
| 273 |
+
'ColorIncreasing': _enhance_increasing_level_to_arg,
|
| 274 |
+
'Contrast': _enhance_level_to_arg,
|
| 275 |
+
'ContrastIncreasing': _enhance_increasing_level_to_arg,
|
| 276 |
+
'Brightness': _enhance_level_to_arg,
|
| 277 |
+
'BrightnessIncreasing': _enhance_increasing_level_to_arg,
|
| 278 |
+
'Sharpness': _enhance_level_to_arg,
|
| 279 |
+
'SharpnessIncreasing': _enhance_increasing_level_to_arg,
|
| 280 |
+
'ShearX': _shear_level_to_arg,
|
| 281 |
+
'ShearY': _shear_level_to_arg,
|
| 282 |
+
'TranslateX': _translate_abs_level_to_arg,
|
| 283 |
+
'TranslateY': _translate_abs_level_to_arg,
|
| 284 |
+
'TranslateXRel': _translate_rel_level_to_arg,
|
| 285 |
+
'TranslateYRel': _translate_rel_level_to_arg,
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
NAME_TO_OP = {
|
| 290 |
+
'AutoContrast': auto_contrast,
|
| 291 |
+
'Equalize': equalize,
|
| 292 |
+
'Invert': invert,
|
| 293 |
+
'Rotate': rotate,
|
| 294 |
+
'Posterize': posterize,
|
| 295 |
+
'PosterizeIncreasing': posterize,
|
| 296 |
+
'PosterizeOriginal': posterize,
|
| 297 |
+
'Solarize': solarize,
|
| 298 |
+
'SolarizeIncreasing': solarize,
|
| 299 |
+
'SolarizeAdd': solarize_add,
|
| 300 |
+
'Color': color,
|
| 301 |
+
'ColorIncreasing': color,
|
| 302 |
+
'Contrast': contrast,
|
| 303 |
+
'ContrastIncreasing': contrast,
|
| 304 |
+
'Brightness': brightness,
|
| 305 |
+
'BrightnessIncreasing': brightness,
|
| 306 |
+
'Sharpness': sharpness,
|
| 307 |
+
'SharpnessIncreasing': sharpness,
|
| 308 |
+
'ShearX': shear_x,
|
| 309 |
+
'ShearY': shear_y,
|
| 310 |
+
'TranslateX': translate_x_abs,
|
| 311 |
+
'TranslateY': translate_y_abs,
|
| 312 |
+
'TranslateXRel': translate_x_rel,
|
| 313 |
+
'TranslateYRel': translate_y_rel,
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class AugmentOp:
|
| 318 |
+
|
| 319 |
+
def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
|
| 320 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
| 321 |
+
self.aug_fn = NAME_TO_OP[name]
|
| 322 |
+
self.level_fn = LEVEL_TO_ARG[name]
|
| 323 |
+
self.prob = prob
|
| 324 |
+
self.magnitude = magnitude
|
| 325 |
+
self.hparams = hparams.copy()
|
| 326 |
+
self.kwargs = dict(
|
| 327 |
+
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
|
| 328 |
+
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# If magnitude_std is > 0, we introduce some randomness
|
| 332 |
+
# in the usually fixed policy and sample magnitude from a normal distribution
|
| 333 |
+
# with mean `magnitude` and std-dev of `magnitude_std`.
|
| 334 |
+
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
| 335 |
+
# If magnitude_std is inf, we sample magnitude from a uniform distribution
|
| 336 |
+
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
| 337 |
+
|
| 338 |
+
def __call__(self, img):
|
| 339 |
+
if self.prob < 1.0 and random.random() > self.prob:
|
| 340 |
+
return img
|
| 341 |
+
magnitude = self.magnitude
|
| 342 |
+
if self.magnitude_std:
|
| 343 |
+
if self.magnitude_std == float('inf'):
|
| 344 |
+
magnitude = random.uniform(0, magnitude)
|
| 345 |
+
elif self.magnitude_std > 0:
|
| 346 |
+
magnitude = random.gauss(magnitude, self.magnitude_std)
|
| 347 |
+
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
| 348 |
+
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
| 349 |
+
return self.aug_fn(img, *level_args, **self.kwargs)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def auto_augment_policy_v0(hparams):
|
| 353 |
+
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
|
| 354 |
+
policy = [
|
| 355 |
+
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
| 356 |
+
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
| 357 |
+
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
| 358 |
+
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
| 359 |
+
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
| 360 |
+
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
| 361 |
+
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
| 362 |
+
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
| 363 |
+
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
| 364 |
+
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
| 365 |
+
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
| 366 |
+
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
| 367 |
+
[('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
| 368 |
+
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
| 369 |
+
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
| 370 |
+
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
| 371 |
+
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
| 372 |
+
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
| 373 |
+
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
| 374 |
+
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
| 375 |
+
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
| 376 |
+
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
| 377 |
+
[('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
|
| 378 |
+
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
| 379 |
+
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
| 380 |
+
]
|
| 381 |
+
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
| 382 |
+
return pc
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def auto_augment_policy_v0r(hparams):
|
| 386 |
+
# ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
|
| 387 |
+
# in Google research implementation (number of bits discarded increases with magnitude)
|
| 388 |
+
policy = [
|
| 389 |
+
[('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
|
| 390 |
+
[('Color', 0.4, 9), ('Equalize', 0.6, 3)],
|
| 391 |
+
[('Color', 0.4, 1), ('Rotate', 0.6, 8)],
|
| 392 |
+
[('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
|
| 393 |
+
[('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
|
| 394 |
+
[('Color', 0.2, 0), ('Equalize', 0.8, 8)],
|
| 395 |
+
[('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
|
| 396 |
+
[('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
|
| 397 |
+
[('Color', 0.6, 1), ('Equalize', 1.0, 2)],
|
| 398 |
+
[('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
|
| 399 |
+
[('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
|
| 400 |
+
[('Color', 0.4, 7), ('Equalize', 0.6, 0)],
|
| 401 |
+
[('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
|
| 402 |
+
[('Solarize', 0.6, 8), ('Color', 0.6, 9)],
|
| 403 |
+
[('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
|
| 404 |
+
[('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
|
| 405 |
+
[('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
|
| 406 |
+
[('ShearY', 0.8, 0), ('Color', 0.6, 4)],
|
| 407 |
+
[('Color', 1.0, 0), ('Rotate', 0.6, 2)],
|
| 408 |
+
[('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
|
| 409 |
+
[('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
|
| 410 |
+
[('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
|
| 411 |
+
[('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
|
| 412 |
+
[('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
|
| 413 |
+
[('Color', 0.8, 6), ('Rotate', 0.4, 5)],
|
| 414 |
+
]
|
| 415 |
+
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
| 416 |
+
return pc
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def auto_augment_policy_original(hparams):
|
| 420 |
+
# ImageNet policy from https://arxiv.org/abs/1805.09501
|
| 421 |
+
policy = [
|
| 422 |
+
[('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
|
| 423 |
+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
| 424 |
+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
| 425 |
+
[('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
|
| 426 |
+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
| 427 |
+
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
| 428 |
+
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
| 429 |
+
[('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
|
| 430 |
+
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
| 431 |
+
[('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
|
| 432 |
+
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
| 433 |
+
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
| 434 |
+
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
| 435 |
+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
| 436 |
+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
| 437 |
+
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
| 438 |
+
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
| 439 |
+
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
| 440 |
+
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
| 441 |
+
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
| 442 |
+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
| 443 |
+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
| 444 |
+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
| 445 |
+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
| 446 |
+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
| 447 |
+
]
|
| 448 |
+
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
| 449 |
+
return pc
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def auto_augment_policy_originalr(hparams):
|
| 453 |
+
# ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
|
| 454 |
+
policy = [
|
| 455 |
+
[('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
|
| 456 |
+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
| 457 |
+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
| 458 |
+
[('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
|
| 459 |
+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
| 460 |
+
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
|
| 461 |
+
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
|
| 462 |
+
[('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
|
| 463 |
+
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
|
| 464 |
+
[('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
|
| 465 |
+
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
|
| 466 |
+
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
|
| 467 |
+
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
|
| 468 |
+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
| 469 |
+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
| 470 |
+
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
|
| 471 |
+
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
|
| 472 |
+
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
|
| 473 |
+
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
|
| 474 |
+
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
|
| 475 |
+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
|
| 476 |
+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
|
| 477 |
+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
|
| 478 |
+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
|
| 479 |
+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
|
| 480 |
+
]
|
| 481 |
+
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
| 482 |
+
return pc
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def auto_augment_policy(name='v0', hparams=None):
|
| 486 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
| 487 |
+
if name == 'original':
|
| 488 |
+
return auto_augment_policy_original(hparams)
|
| 489 |
+
elif name == 'originalr':
|
| 490 |
+
return auto_augment_policy_originalr(hparams)
|
| 491 |
+
elif name == 'v0':
|
| 492 |
+
return auto_augment_policy_v0(hparams)
|
| 493 |
+
elif name == 'v0r':
|
| 494 |
+
return auto_augment_policy_v0r(hparams)
|
| 495 |
+
else:
|
| 496 |
+
assert False, 'Unknown AA policy (%s)' % name
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class AutoAugment:
|
| 500 |
+
|
| 501 |
+
def __init__(self, policy):
|
| 502 |
+
self.policy = policy
|
| 503 |
+
|
| 504 |
+
def __call__(self, img):
|
| 505 |
+
sub_policy = random.choice(self.policy)
|
| 506 |
+
for op in sub_policy:
|
| 507 |
+
img = op(img)
|
| 508 |
+
return img
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def auto_augment_transform(config_str, hparams):
|
| 512 |
+
"""
|
| 513 |
+
Create a AutoAugment transform
|
| 514 |
+
|
| 515 |
+
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
| 516 |
+
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
| 517 |
+
The remaining sections, not order sepecific determine
|
| 518 |
+
'mstd' - float std deviation of magnitude noise applied
|
| 519 |
+
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
| 520 |
+
|
| 521 |
+
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
| 522 |
+
|
| 523 |
+
:return: A PyTorch compatible Transform
|
| 524 |
+
"""
|
| 525 |
+
config = config_str.split('-')
|
| 526 |
+
policy_name = config[0]
|
| 527 |
+
config = config[1:]
|
| 528 |
+
for c in config:
|
| 529 |
+
cs = re.split(r'(\d.*)', c)
|
| 530 |
+
if len(cs) < 2:
|
| 531 |
+
continue
|
| 532 |
+
key, val = cs[:2]
|
| 533 |
+
if key == 'mstd':
|
| 534 |
+
# noise param injected via hparams for now
|
| 535 |
+
hparams.setdefault('magnitude_std', float(val))
|
| 536 |
+
else:
|
| 537 |
+
assert False, 'Unknown AutoAugment config section'
|
| 538 |
+
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
|
| 539 |
+
return AutoAugment(aa_policy)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
_RAND_TRANSFORMS = [
|
| 543 |
+
'AutoContrast',
|
| 544 |
+
'Equalize',
|
| 545 |
+
'Invert',
|
| 546 |
+
'Rotate',
|
| 547 |
+
'Posterize',
|
| 548 |
+
'Solarize',
|
| 549 |
+
'SolarizeAdd',
|
| 550 |
+
'Color',
|
| 551 |
+
'Contrast',
|
| 552 |
+
'Brightness',
|
| 553 |
+
'Sharpness',
|
| 554 |
+
'ShearX',
|
| 555 |
+
'ShearY',
|
| 556 |
+
'TranslateXRel',
|
| 557 |
+
'TranslateYRel',
|
| 558 |
+
#'Cutout' # NOTE I've implement this as random erasing separately
|
| 559 |
+
]
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
_RAND_INCREASING_TRANSFORMS = [
|
| 563 |
+
'AutoContrast',
|
| 564 |
+
'Equalize',
|
| 565 |
+
'Invert',
|
| 566 |
+
'Rotate',
|
| 567 |
+
'PosterizeIncreasing',
|
| 568 |
+
'SolarizeIncreasing',
|
| 569 |
+
'SolarizeAdd',
|
| 570 |
+
'ColorIncreasing',
|
| 571 |
+
'ContrastIncreasing',
|
| 572 |
+
'BrightnessIncreasing',
|
| 573 |
+
'SharpnessIncreasing',
|
| 574 |
+
'ShearX',
|
| 575 |
+
'ShearY',
|
| 576 |
+
'TranslateXRel',
|
| 577 |
+
'TranslateYRel',
|
| 578 |
+
#'Cutout' # NOTE I've implement this as random erasing separately
|
| 579 |
+
]
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
| 584 |
+
# They may not result in increased performance, but could likely be tuned to so.
|
| 585 |
+
_RAND_CHOICE_WEIGHTS_0 = {
|
| 586 |
+
'Rotate': 0.3,
|
| 587 |
+
'ShearX': 0.2,
|
| 588 |
+
'ShearY': 0.2,
|
| 589 |
+
'TranslateXRel': 0.1,
|
| 590 |
+
'TranslateYRel': 0.1,
|
| 591 |
+
'Color': .025,
|
| 592 |
+
'Sharpness': 0.025,
|
| 593 |
+
'AutoContrast': 0.025,
|
| 594 |
+
'Solarize': .005,
|
| 595 |
+
'SolarizeAdd': .005,
|
| 596 |
+
'Contrast': .005,
|
| 597 |
+
'Brightness': .005,
|
| 598 |
+
'Equalize': .005,
|
| 599 |
+
'Posterize': 0,
|
| 600 |
+
'Invert': 0,
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def _select_rand_weights(weight_idx=0, transforms=None):
|
| 605 |
+
transforms = transforms or _RAND_TRANSFORMS
|
| 606 |
+
assert weight_idx == 0 # only one set of weights currently
|
| 607 |
+
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
| 608 |
+
probs = [rand_weights[k] for k in transforms]
|
| 609 |
+
probs /= np.sum(probs)
|
| 610 |
+
return probs
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
| 614 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
| 615 |
+
transforms = transforms or _RAND_TRANSFORMS
|
| 616 |
+
return [AugmentOp(
|
| 617 |
+
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class RandAugment:
|
| 621 |
+
def __init__(self, ops, num_layers=2, choice_weights=None):
|
| 622 |
+
self.ops = ops
|
| 623 |
+
self.num_layers = num_layers
|
| 624 |
+
self.choice_weights = choice_weights
|
| 625 |
+
|
| 626 |
+
def __call__(self, img):
|
| 627 |
+
# no replacement when using weighted choice
|
| 628 |
+
ops = np.random.choice(
|
| 629 |
+
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
| 630 |
+
for op in ops:
|
| 631 |
+
img = op(img)
|
| 632 |
+
return img
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def rand_augment_transform(config_str, hparams):
|
| 636 |
+
"""
|
| 637 |
+
Create a RandAugment transform
|
| 638 |
+
|
| 639 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
| 640 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
| 641 |
+
sections, not order sepecific determine
|
| 642 |
+
'm' - integer magnitude of rand augment
|
| 643 |
+
'n' - integer num layers (number of transform ops selected per image)
|
| 644 |
+
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
| 645 |
+
'mstd' - float std deviation of magnitude noise applied
|
| 646 |
+
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
| 647 |
+
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
| 648 |
+
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
| 649 |
+
|
| 650 |
+
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
| 651 |
+
|
| 652 |
+
:return: A PyTorch compatible Transform
|
| 653 |
+
"""
|
| 654 |
+
magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
|
| 655 |
+
num_layers = 2 # default to 2 ops per image
|
| 656 |
+
weight_idx = None # default to no probability weights for op choice
|
| 657 |
+
transforms = _RAND_TRANSFORMS
|
| 658 |
+
config = config_str.split('-')
|
| 659 |
+
assert config[0] == 'rand'
|
| 660 |
+
config = config[1:]
|
| 661 |
+
for c in config:
|
| 662 |
+
cs = re.split(r'(\d.*)', c)
|
| 663 |
+
if len(cs) < 2:
|
| 664 |
+
continue
|
| 665 |
+
key, val = cs[:2]
|
| 666 |
+
if key == 'mstd':
|
| 667 |
+
# noise param injected via hparams for now
|
| 668 |
+
hparams.setdefault('magnitude_std', float(val))
|
| 669 |
+
elif key == 'inc':
|
| 670 |
+
if bool(val):
|
| 671 |
+
transforms = _RAND_INCREASING_TRANSFORMS
|
| 672 |
+
elif key == 'm':
|
| 673 |
+
magnitude = int(val)
|
| 674 |
+
elif key == 'n':
|
| 675 |
+
num_layers = int(val)
|
| 676 |
+
elif key == 'w':
|
| 677 |
+
weight_idx = int(val)
|
| 678 |
+
else:
|
| 679 |
+
assert False, 'Unknown RandAugment config section'
|
| 680 |
+
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
| 681 |
+
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
| 682 |
+
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
_AUGMIX_TRANSFORMS = [
|
| 686 |
+
'AutoContrast',
|
| 687 |
+
'ColorIncreasing', # not in paper
|
| 688 |
+
'ContrastIncreasing', # not in paper
|
| 689 |
+
'BrightnessIncreasing', # not in paper
|
| 690 |
+
'SharpnessIncreasing', # not in paper
|
| 691 |
+
'Equalize',
|
| 692 |
+
'Rotate',
|
| 693 |
+
'PosterizeIncreasing',
|
| 694 |
+
'SolarizeIncreasing',
|
| 695 |
+
'ShearX',
|
| 696 |
+
'ShearY',
|
| 697 |
+
'TranslateXRel',
|
| 698 |
+
'TranslateYRel',
|
| 699 |
+
]
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
| 703 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
| 704 |
+
transforms = transforms or _AUGMIX_TRANSFORMS
|
| 705 |
+
return [AugmentOp(
|
| 706 |
+
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
class AugMixAugment:
|
| 710 |
+
""" AugMix Transform
|
| 711 |
+
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
| 712 |
+
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
| 713 |
+
https://arxiv.org/abs/1912.02781
|
| 714 |
+
"""
|
| 715 |
+
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
| 716 |
+
self.ops = ops
|
| 717 |
+
self.alpha = alpha
|
| 718 |
+
self.width = width
|
| 719 |
+
self.depth = depth
|
| 720 |
+
self.blended = blended # blended mode is faster but not well tested
|
| 721 |
+
|
| 722 |
+
def _calc_blended_weights(self, ws, m):
|
| 723 |
+
ws = ws * m
|
| 724 |
+
cump = 1.
|
| 725 |
+
rws = []
|
| 726 |
+
for w in ws[::-1]:
|
| 727 |
+
alpha = w / cump
|
| 728 |
+
cump *= (1 - alpha)
|
| 729 |
+
rws.append(alpha)
|
| 730 |
+
return np.array(rws[::-1], dtype=np.float32)
|
| 731 |
+
|
| 732 |
+
def _apply_blended(self, img, mixing_weights, m):
|
| 733 |
+
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
|
| 734 |
+
# of accumulating the mix for each chain in a Numpy array and then blending with original,
|
| 735 |
+
# it recomputes the blending coefficients and applies one PIL image blend per chain.
|
| 736 |
+
# TODO the results appear in the right ballpark but they differ by more than rounding.
|
| 737 |
+
img_orig = img.copy()
|
| 738 |
+
ws = self._calc_blended_weights(mixing_weights, m)
|
| 739 |
+
for w in ws:
|
| 740 |
+
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
| 741 |
+
ops = np.random.choice(self.ops, depth, replace=True)
|
| 742 |
+
img_aug = img_orig # no ops are in-place, deep copy not necessary
|
| 743 |
+
for op in ops:
|
| 744 |
+
img_aug = op(img_aug)
|
| 745 |
+
img = Image.blend(img, img_aug, w)
|
| 746 |
+
return img
|
| 747 |
+
|
| 748 |
+
def _apply_basic(self, img, mixing_weights, m):
|
| 749 |
+
# This is a literal adaptation of the paper/official implementation without normalizations and
|
| 750 |
+
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
|
| 751 |
+
# typical augmentation transforms, could use a GPU / Kornia implementation.
|
| 752 |
+
img_shape = img.size[0], img.size[1], len(img.getbands())
|
| 753 |
+
mixed = np.zeros(img_shape, dtype=np.float32)
|
| 754 |
+
for mw in mixing_weights:
|
| 755 |
+
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
| 756 |
+
ops = np.random.choice(self.ops, depth, replace=True)
|
| 757 |
+
img_aug = img # no ops are in-place, deep copy not necessary
|
| 758 |
+
for op in ops:
|
| 759 |
+
img_aug = op(img_aug)
|
| 760 |
+
mixed += mw * np.asarray(img_aug, dtype=np.float32)
|
| 761 |
+
np.clip(mixed, 0, 255., out=mixed)
|
| 762 |
+
mixed = Image.fromarray(mixed.astype(np.uint8))
|
| 763 |
+
return Image.blend(img, mixed, m)
|
| 764 |
+
|
| 765 |
+
def __call__(self, img):
|
| 766 |
+
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
|
| 767 |
+
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
| 768 |
+
if self.blended:
|
| 769 |
+
mixed = self._apply_blended(img, mixing_weights, m)
|
| 770 |
+
else:
|
| 771 |
+
mixed = self._apply_basic(img, mixing_weights, m)
|
| 772 |
+
return mixed
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def augment_and_mix_transform(config_str, hparams):
|
| 776 |
+
""" Create AugMix PyTorch transform
|
| 777 |
+
|
| 778 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
| 779 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
| 780 |
+
sections, not order sepecific determine
|
| 781 |
+
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
| 782 |
+
'w' - integer width of augmentation chain (default: 3)
|
| 783 |
+
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
| 784 |
+
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
| 785 |
+
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
| 786 |
+
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
| 787 |
+
|
| 788 |
+
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
| 789 |
+
|
| 790 |
+
:return: A PyTorch compatible Transform
|
| 791 |
+
"""
|
| 792 |
+
magnitude = 3
|
| 793 |
+
width = 3
|
| 794 |
+
depth = -1
|
| 795 |
+
alpha = 1.
|
| 796 |
+
blended = False
|
| 797 |
+
hparams['magnitude_std'] = float('inf')
|
| 798 |
+
config = config_str.split('-')
|
| 799 |
+
assert config[0] == 'augmix'
|
| 800 |
+
config = config[1:]
|
| 801 |
+
for c in config:
|
| 802 |
+
cs = re.split(r'(\d.*)', c)
|
| 803 |
+
if len(cs) < 2:
|
| 804 |
+
continue
|
| 805 |
+
key, val = cs[:2]
|
| 806 |
+
if key == 'mstd':
|
| 807 |
+
# noise param injected via hparams for now
|
| 808 |
+
hparams.setdefault('magnitude_std', float(val))
|
| 809 |
+
elif key == 'm':
|
| 810 |
+
magnitude = int(val)
|
| 811 |
+
elif key == 'w':
|
| 812 |
+
width = int(val)
|
| 813 |
+
elif key == 'd':
|
| 814 |
+
depth = int(val)
|
| 815 |
+
elif key == 'a':
|
| 816 |
+
alpha = float(val)
|
| 817 |
+
elif key == 'b':
|
| 818 |
+
blended = bool(val)
|
| 819 |
+
else:
|
| 820 |
+
assert False, 'Unknown AugMix config section'
|
| 821 |
+
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
|
| 822 |
+
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
timm/data/config.py
ADDED
@@ -0,0 +1,78 @@
import logging
from .constants import *


_logger = logging.getLogger(__name__)


def resolve_data_config(args, default_cfg=None, model=None, use_test_size=False, verbose=False):
    new_config = {}
    default_cfg = default_cfg or {}  # avoid a mutable default argument
    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
        default_cfg = model.default_cfg

    # Resolve input/image size
    in_chans = 3
    if 'chans' in args and args['chans'] is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if 'input_size' in args and args['input_size'] is not None:
        assert isinstance(args['input_size'], (tuple, list))
        assert len(args['input_size']) == 3
        input_size = tuple(args['input_size'])
        in_chans = input_size[0]  # input_size overrides in_chans
    elif 'img_size' in args and args['img_size'] is not None:
        assert isinstance(args['img_size'], int)
        input_size = (in_chans, args['img_size'], args['img_size'])
    else:
        if use_test_size and 'test_input_size' in default_cfg:
            input_size = default_cfg['test_input_size']
        elif 'input_size' in default_cfg:
            input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # resolve interpolation method
    new_config['interpolation'] = 'bicubic'
    if 'interpolation' in args and args['interpolation']:
        new_config['interpolation'] = args['interpolation']
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    # resolve dataset + model mean for normalization
    new_config['mean'] = IMAGENET_DEFAULT_MEAN
    if 'mean' in args and args['mean'] is not None:
        mean = tuple(args['mean'])
        if len(mean) == 1:
            mean = tuple(list(mean) * in_chans)
        else:
            assert len(mean) == in_chans
        new_config['mean'] = mean
    elif 'mean' in default_cfg:
        new_config['mean'] = default_cfg['mean']

    # resolve dataset + model std deviation for normalization
    new_config['std'] = IMAGENET_DEFAULT_STD
    if 'std' in args and args['std'] is not None:
        std = tuple(args['std'])
        if len(std) == 1:
            std = tuple(list(std) * in_chans)
        else:
            assert len(std) == in_chans
        new_config['std'] = std
    elif 'std' in default_cfg:
        new_config['std'] = default_cfg['std']

    # resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in args and args['crop_pct'] is not None:
        new_config['crop_pct'] = args['crop_pct']
    elif 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            _logger.info('\t%s: %s' % (n, str(v)))

    return new_config
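
A short sketch of resolve_data_config in isolation; the args dict keys mirror those probed above, and the values are illustrative:

from timm.data.config import resolve_data_config

args = {'img_size': 224, 'interpolation': '', 'mean': None, 'std': None, 'crop_pct': None}
cfg = resolve_data_config(args, verbose=True)
# cfg == {'input_size': (3, 224, 224), 'interpolation': 'bicubic',
#         'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 0.875}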
timm/data/constants.py
ADDED
@@ -0,0 +1,7 @@
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
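
As a quick note on the arithmetic, crop_pct relates the final crop to the preceding resize in timm's eval transforms (an assumption here, since that code lives outside this file): the short side is resized to crop_size / crop_pct before cropping.

from timm.data.constants import DEFAULT_CROP_PCT

crop_size = 224
resize_size = int(crop_size / DEFAULT_CROP_PCT)  # -> 256, the conventional 256-resize / 224-crop pair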
timm/data/dataset.py
ADDED
@@ -0,0 +1,146 @@
""" Quick n Simple Image Folder, Tarfile based DataSet

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import logging

from PIL import Image

from .parsers import create_parser

_logger = logging.getLogger(__name__)


_ERROR_RETRY = 50


class ImageDataset(data.Dataset):

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class IterableImageDataset(data.IterableDataset):

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            repeats=0,
            transform=None,
    ):
        assert parser is not None
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
        else:
            self.parser = parser
        self.transform = transform
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if target is None:
                target = torch.tensor(-1, dtype=torch.long)
            yield img, target

    def __len__(self):
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
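
A minimal sketch of the AugMixDataset wrapper above: its transform must be set to a 3-tuple (base, augmentation, normalize), as _set_transforms asserts. The dataset root is hypothetical, and ColorJitter stands in for a real AugMix op chain:

from torchvision import transforms
from timm.data.dataset import ImageDataset, AugMixDataset

base = transforms.RandomResizedCrop(224)               # shared pre-augmentation transform
augmentation = transforms.ColorJitter(0.4, 0.4, 0.4)   # stand-in for an augmentation split
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

ds = AugMixDataset(ImageDataset('/path/to/train'), num_splits=2)  # hypothetical root
ds.transform = (base, augmentation, normalize)  # routed through the transform setter
(x_clean, x_aug), y = ds[0]  # first split is clean, remaining splits are augmented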
timm/data/dataset_factory.py
ADDED
@@ -0,0 +1,30 @@
import os

from .dataset import IterableImageDataset, ImageDataset


def _search_split(root, split):
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root
    if split_name == 'validation':
        try_root = os.path.join(root, 'val')
        if os.path.exists(try_root):
            return try_root
    return root


def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
    name = name.lower()
    if name.startswith('tfds'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset supports the repeat multiplier
        if search_split and os.path.isdir(root):
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, **kwargs)
    return ds
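
A usage sketch of create_dataset; the directory layout is an assumption (_search_split above falls back from 'validation' to a 'val' subfolder when present):

from timm.data.dataset_factory import create_dataset

train_ds = create_dataset('', root='/data/imagenet', split='train', is_training=True)
val_ds = create_dataset('', root='/data/imagenet', split='validation')
# A name starting with 'tfds' instead routes through IterableImageDataset, e.g.
# create_dataset('tfds/imagenette', root='/data/tfds', split='train', batch_size=32)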
timm/data/distributed_sampler.py
ADDED
@@ -0,0 +1,51 @@
import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist


class OrderedDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
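
A self-contained sketch of the subsampling arithmetic above; passing num_replicas and rank explicitly avoids needing an initialized process group:

import torch
from torch.utils.data import DataLoader, TensorDataset
from timm.data.distributed_sampler import OrderedDistributedSampler

ds = TensorDataset(torch.arange(10).float(), torch.arange(10))
sampler = OrderedDistributedSampler(ds, num_replicas=4, rank=1)
print(list(iter(sampler)))  # [1, 5, 9]: indices padded to 12, then every 4th taken from rank 1
loader = DataLoader(ds, batch_size=2, sampler=sampler, shuffle=False)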
timm/data/loader.py
ADDED
@@ -0,0 +1,262 @@
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch.utils.data
import numpy as np

from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False


class PrefetchLoader:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
):
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader

    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        persistent_workers=persistent_workers)
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            fp16=fp16,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader


class MultiEpochsDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
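
Tying the pieces together, a minimal create_loader sketch. Paths and sizes are illustrative, and use_prefetcher=False keeps it CPU-only, since the PrefetchLoader path above calls .cuda():

from timm.data.dataset_factory import create_dataset
from timm.data.loader import create_loader

train_ds = create_dataset('', root='/data/imagenet', split='train', is_training=True)
loader = create_loader(
    train_ds,
    input_size=(3, 224, 224),
    batch_size=32,
    is_training=True,
    use_prefetcher=False,            # avoid the CUDA-only PrefetchLoader path
    auto_augment='rand-m9-mstd0.5',  # forwarded through create_transform
    interpolation='bicubic',
    num_workers=4,
)
images, targets = next(iter(loader))  # (32, 3, 224, 224) float tensor, (32,) int64 targets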
timm/data/mixup.py
ADDED
@@ -0,0 +1,316 @@
""" Mixup and Cutmix

Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
import torch


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)


def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    # Label smoothing spreads `smoothing` mass uniformly over all classes, then the
    # smoothed targets of the batch and of its flipped copy are blended by lam.
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
    return y1 * lam + y2 * (1. - lam)
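To make the target construction concrete, a small worked example (hypothetical values, run on CPU):

import torch

targets = torch.tensor([0, 2])
mixed = mixup_target(targets, num_classes=4, lam=0.7, smoothing=0.1, device='cpu')
# off_value = 0.1 / 4 = 0.025, on_value = 1 - 0.1 + 0.025 = 0.925
# Row 0 blends class 0 (weight 0.7) with class 2 from the flipped batch (weight 0.3):
# [0.7*0.925 + 0.3*0.025, 0.025, 0.3*0.925 + 0.7*0.025, 0.025] = [0.655, 0.025, 0.295, 0.025]
print(mixed[0])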
def rand_bbox(img_shape, lam, margin=0., count=None):
    """ Standard CutMix bounding-box
    Generates a random square bbox based on lambda value. This impl includes
    support for enforcing a border margin as percent of bbox dimensions.

    Args:
        img_shape (tuple): Image shape as tuple
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate
    """
    ratio = np.sqrt(1 - lam)
    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


def rand_bbox_minmax(img_shape, minmax, count=None):
    """ Min-Max CutMix bounding-box
    Inspired by Darknet cutmix impl, generates a random rectangular bbox
    based on min/max percent values applied to each dimension of the input image.

    Typical minmax values are in the .2-.3 range for min and the .8-.9 range for max.

    Args:
        img_shape (tuple): Image shape as tuple
        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
        count (int): Number of bbox to generate
    """
    assert len(minmax) == 2
    img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    yu = yl + cut_h
    xu = xl + cut_w
    return yl, yu, xl, xu


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    """ Generate bbox and apply lambda correction.
    """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
    if correct_lam or ratio_minmax is not None:
        bbox_area = (yu - yl) * (xu - xl)
        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
    return (yl, yu, xl, xu), lam
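The lambda correction matters when the sampled box is clipped at the image border: lam is recomputed from the box area that actually lands inside the image, so the mixed-target weights stay consistent with the pixels really pasted. A small illustrative check (shape and lam value are hypothetical):

import numpy as np

img_shape = (3, 224, 224)
(yl, yu, xl, xu), lam = cutmix_bbox_and_lam(img_shape, lam=0.6, correct_lam=True)
box_area = (yu - yl) * (xu - xl)
# The corrected lam matches the clipped box area exactly.
assert abs(lam - (1. - box_area / (224. * 224.))) < 1e-6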
class Mixup:
    """ Mixup/Cutmix that applies different params to each element or whole batch

    Args:
        mixup_alpha (float): mixup alpha value, mixup is active if > 0.
        cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
        cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
        prob (float): probability of applying mixup or cutmix per batch or element
        switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
        correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
        label_smoothing (float): apply label smoothing to the mixed target tensor
        num_classes (int): number of classes for target
    """
    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.cutmix_minmax = cutmix_minmax
        if self.cutmix_minmax is not None:
            assert len(self.cutmix_minmax) == 2
            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
            self.cutmix_alpha = 1.0
        self.mix_prob = prob
        self.switch_prob = switch_prob
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = mode
        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

    def _params_per_elem(self, batch_size):
        lam = np.ones(batch_size, dtype=np.float32)
        use_cutmix = np.zeros(batch_size, dtype=bool)  # was np.bool, which is removed in NumPy 1.24+
        if self.mixup_enabled:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand(batch_size) < self.switch_prob
                lam_mix = np.where(
                    use_cutmix,
                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
            elif self.cutmix_alpha > 0.:
                use_cutmix = np.ones(batch_size, dtype=bool)
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
        return lam, use_cutmix

    def _params_per_batch(self):
        lam = 1.
        use_cutmix = False
        if self.mixup_enabled and np.random.rand() < self.mix_prob:
            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
                use_cutmix = np.random.rand() < self.switch_prob
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.mixup_alpha > 0.:
                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            elif self.cutmix_alpha > 0.:
                use_cutmix = True
                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
            else:
                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
            lam = float(lam_mix)
        return lam, use_cutmix

    def _mix_elem(self, x):
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size)
        x_orig = x.clone()  # need to keep an unmodified original for mixing source
        for i in range(batch_size):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

    def _mix_pair(self, x):
        batch_size = len(x)
        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        x_orig = x.clone()  # need to keep an unmodified original for mixing source
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = lam_batch[i]
            if lam != 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)

    def _mix_batch(self, x):
        lam, use_cutmix = self._params_per_batch()
        if lam == 1.:
            return 1.
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
        else:
            x_flipped = x.flip(0).mul_(1. - lam)
            x.mul_(lam).add_(x_flipped)
        return lam

    def __call__(self, x, target):
        assert len(x) % 2 == 0, 'Batch size should be even when using this'
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
        return x, target
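A hypothetical GPU training-step sketch showing the intended call pattern; the model, data, and hyperparameter values are placeholders, and the soft-target loss is written inline rather than taken from this repo:

import torch
import torch.nn.functional as F

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5,
                 mode='batch', label_smoothing=0.1, num_classes=1000)

def soft_cross_entropy(logits, soft_targets):
    # Cross entropy against the dense soft targets produced by mixup_target.
    return torch.sum(-soft_targets * F.log_softmax(logits, dim=-1), dim=-1).mean()

# Inside the train loop (batch size must be even, and tensors belong on CUDA
# because Mixup.__call__ leaves mixup_target at its device='cuda' default):
#   images, labels = images.cuda(), labels.cuda()
#   images, soft_targets = mixup_fn(images, labels)
#   loss = soft_cross_entropy(model(images), soft_targets)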
class FastCollateMixup(Mixup):
    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch

    A Mixup impl that's performed while collating the batches.
    """

    def _mix_elem_collate(self, output, batch, half=False):
        batch_size = len(batch)
        num_elem = batch_size // 2 if half else batch_size
        assert len(output) == num_elem
        lam_batch, use_cutmix = self._params_per_elem(num_elem)
        for i in range(num_elem):
            j = batch_size - i - 1
            lam = lam_batch[i]
            mixed = batch[i][0]
            if lam != 1.:
                if use_cutmix[i]:
                    if not half:
                        mixed = mixed.copy()
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                    lam_batch[i] = lam
                else:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    np.rint(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        if half:
            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
        return torch.tensor(lam_batch).unsqueeze(1)

    def _mix_pair_collate(self, output, batch):
        batch_size = len(batch)
        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = lam_batch[i]
            mixed_i = batch[i][0]
            mixed_j = batch[j][0]
            assert 0 <= lam <= 1.0
            if lam < 1.:
                if use_cutmix[i]:
                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                    mixed_j[:, yl:yh, xl:xh] = patch_i
                    lam_batch[i] = lam
                else:
                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
                    mixed_i = mixed_temp
                    np.rint(mixed_j, out=mixed_j)
                    np.rint(mixed_i, out=mixed_i)
            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
        return torch.tensor(lam_batch).unsqueeze(1)

    def _mix_batch_collate(self, output, batch):
        batch_size = len(batch)
        lam, use_cutmix = self._params_per_batch()
        if use_cutmix:
            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
        for i in range(batch_size):
            j = batch_size - i - 1
            mixed = batch[i][0]
            if lam != 1.:
                if use_cutmix:
                    mixed = mixed.copy()  # don't want to modify the original while iterating
                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                else:
                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                    np.rint(mixed, out=mixed)
            output[i] += torch.from_numpy(mixed.astype(np.uint8))
        return lam

    def __call__(self, batch, _=None):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
        half = 'half' in self.mode
        if half:
            batch_size //= 2
        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        if self.mode == 'elem' or self.mode == 'half':
            lam = self._mix_elem_collate(output, batch, half=half)
        elif self.mode == 'pair':
            lam = self._mix_pair_collate(output, batch)
        else:
            lam = self._mix_batch_collate(output, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
        target = target[:batch_size]
        return output, target
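FastCollateMixup consumes the un-collated batch format of timm's prefetcher path: a list of (uint8 CHW numpy image, int label) samples, and in that pipeline it is passed to the loader as its collate_fn. A hypothetical standalone check with made-up shapes and class count:

import numpy as np
import torch

# Each sample is a (uint8 CHW numpy image, int label) pair, as fast-collate expects.
batch = [(np.random.randint(0, 256, (3, 32, 32), dtype=np.uint8), i % 10) for i in range(8)]

collate_fn = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0, mode='batch', num_classes=10)
images, soft_targets = collate_fn(batch)
print(images.shape, images.dtype)  # torch.Size([8, 3, 32, 32]) torch.uint8
print(soft_targets.shape)          # torch.Size([8, 10]); mixed + smoothed, built on CPU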
timm/data/parsers/__init__.py
ADDED
@@ -0,0 +1 @@
from .parser_factory import create_parser
timm/data/parsers/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (200 Bytes).