victor010 and cooelf committed on
Commit 160d3b1 · 0 Parent(s)

Duplicate from cooelf/Multimodal-CoT

Co-authored-by: Zhuosheng Zhang <[email protected]>

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +36 -0
  2. README.md +13 -0
  3. __pycache__/model.cpython-37.pyc +0 -0
  4. __pycache__/model.cpython-38.pyc +0 -0
  5. api/61.png +0 -0
  6. app.py +150 -0
  7. model.py +515 -0
  8. requirements.txt +6 -0
  9. timm/__init__.py +4 -0
  10. timm/__pycache__/__init__.cpython-37.pyc +0 -0
  11. timm/__pycache__/__init__.cpython-38.pyc +0 -0
  12. timm/__pycache__/version.cpython-37.pyc +0 -0
  13. timm/__pycache__/version.cpython-38.pyc +0 -0
  14. timm/data/__init__.py +12 -0
  15. timm/data/__pycache__/__init__.cpython-37.pyc +0 -0
  16. timm/data/__pycache__/__init__.cpython-38.pyc +0 -0
  17. timm/data/__pycache__/auto_augment.cpython-37.pyc +0 -0
  18. timm/data/__pycache__/auto_augment.cpython-38.pyc +0 -0
  19. timm/data/__pycache__/config.cpython-37.pyc +0 -0
  20. timm/data/__pycache__/config.cpython-38.pyc +0 -0
  21. timm/data/__pycache__/constants.cpython-37.pyc +0 -0
  22. timm/data/__pycache__/constants.cpython-38.pyc +0 -0
  23. timm/data/__pycache__/dataset.cpython-37.pyc +0 -0
  24. timm/data/__pycache__/dataset.cpython-38.pyc +0 -0
  25. timm/data/__pycache__/dataset_factory.cpython-37.pyc +0 -0
  26. timm/data/__pycache__/dataset_factory.cpython-38.pyc +0 -0
  27. timm/data/__pycache__/distributed_sampler.cpython-37.pyc +0 -0
  28. timm/data/__pycache__/distributed_sampler.cpython-38.pyc +0 -0
  29. timm/data/__pycache__/loader.cpython-37.pyc +0 -0
  30. timm/data/__pycache__/loader.cpython-38.pyc +0 -0
  31. timm/data/__pycache__/mixup.cpython-37.pyc +0 -0
  32. timm/data/__pycache__/mixup.cpython-38.pyc +0 -0
  33. timm/data/__pycache__/random_erasing.cpython-37.pyc +0 -0
  34. timm/data/__pycache__/random_erasing.cpython-38.pyc +0 -0
  35. timm/data/__pycache__/real_labels.cpython-37.pyc +0 -0
  36. timm/data/__pycache__/real_labels.cpython-38.pyc +0 -0
  37. timm/data/__pycache__/transforms.cpython-37.pyc +0 -0
  38. timm/data/__pycache__/transforms.cpython-38.pyc +0 -0
  39. timm/data/__pycache__/transforms_factory.cpython-37.pyc +0 -0
  40. timm/data/__pycache__/transforms_factory.cpython-38.pyc +0 -0
  41. timm/data/auto_augment.py +822 -0
  42. timm/data/config.py +78 -0
  43. timm/data/constants.py +7 -0
  44. timm/data/dataset.py +146 -0
  45. timm/data/dataset_factory.py +30 -0
  46. timm/data/distributed_sampler.py +51 -0
  47. timm/data/loader.py +262 -0
  48. timm/data/mixup.py +316 -0
  49. timm/data/parsers/__init__.py +1 -0
  50. timm/data/parsers/__pycache__/__init__.cpython-37.pyc +0 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ qa9.jpg filter=lfs diff=lfs merge=lfs -text
+ upload4.jpg filter=lfs diff=lfs merge=lfs -text
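
For illustration only (not part of the commit), a small Python sketch of what these entries mean in practice: files whose names match the patterns are stored through Git LFS rather than as regular git blobs. The pattern list below is a hypothetical subset of the lines above, and fnmatch only approximates gitattributes glob semantics.

from fnmatch import fnmatch

# Hypothetical subset of the LFS patterns listed above.
LFS_PATTERNS = ["*.bin", "*.ckpt", "*.pt", "*.pth", "*.safetensors", "*.zip", "qa9.jpg", "upload4.jpg"]

def stored_via_lfs(path):
    # The entries above are name-based, so match against the final path component.
    name = path.rsplit("/", 1)[-1]
    return any(fnmatch(name, pattern) for pattern in LFS_PATTERNS)

for candidate in ["api/61.png", "upload4.jpg", "pytorch_model.bin", "app.py"]:
    print(candidate, "->", "Git LFS" if stored_via_lfs(candidate) else "regular git object")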
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Multimodal-CoT
+ emoji: 🏖️
+ colorFrom: red
+ colorTo: blue
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ license: openrail
+ duplicated_from: cooelf/Multimodal-CoT
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/model.cpython-37.pyc ADDED
Binary file (12.1 kB).
__pycache__/model.cpython-38.pyc ADDED
Binary file (12.2 kB).
api/61.png ADDED
app.py ADDED
@@ -0,0 +1,150 @@
+ import string
+ import gradio as gr
+ import requests
+ import torch
+ from transformers import T5Tokenizer
+ from model import T5ForMultimodalGeneration
+ from PIL import Image
+ import timm
+ from timm.data import resolve_data_config
+ from timm.data.transforms_factory import create_transform
+
+ rationale_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Rationale-Joint"
+ answer_model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Answer-Joint"
+
+ vit_model = timm.create_model("vit_base_patch16_384", pretrained=True, num_classes=0)
+ vit_model.eval()
+ config = resolve_data_config({}, model=vit_model)
+ transform = create_transform(**config)
+ tokenizer = T5Tokenizer.from_pretrained(rationale_model_dir)
+ r_model = T5ForMultimodalGeneration.from_pretrained(rationale_model_dir, patch_size=(577, 768))
+ a_model = T5ForMultimodalGeneration.from_pretrained(answer_model_dir, patch_size=(577, 768))
+
+ def inference_chat(input_image,input_text):
+     with torch.no_grad():
+         # print(input_image)
+         # img = Image.open(input_image).convert("RGB")
+         input = transform(input_image).unsqueeze(0)
+         out = vit_model.forward_features(input)
+         image_features = out.detach()
+
+         source = tokenizer.batch_encode_plus(
+             [input_text],
+             max_length=512,
+             pad_to_max_length=True,
+             truncation=True,
+             padding="max_length",
+             return_tensors="pt",
+         )
+         source_ids = source["input_ids"]
+         source_mask = source["attention_mask"]
+         rationale = r_model.generate(
+             input_ids=source_ids,
+             attention_mask=source_mask,
+             image_ids=image_features,
+             max_length=512,
+             num_beams=1,
+             do_sample=False
+         )
+         rationale = tokenizer.batch_decode(rationale, skip_special_tokens=True)[0]
+         print(rationale)
+
+         input_text = input_text + "\n" + rationale +"\nAnswer:"
+         print(input_text)
+
+         source = tokenizer.batch_encode_plus(
+             [input_text],
+             max_length=512,
+             pad_to_max_length=True,
+             truncation=True,
+             padding="max_length",
+             return_tensors="pt",
+         )
+         source_ids = source["input_ids"]
+         source_mask = source["attention_mask"]
+         answer = a_model.generate(
+             input_ids=source_ids,
+             attention_mask=source_mask,
+             image_ids=image_features,
+             max_length=64,
+             num_beams=1,
+             do_sample=False
+         )
+
+         answer = tokenizer.batch_decode(answer, skip_special_tokens=True)[0]
+     return rationale, answer
+
+
+ title = """# Multimodal-CoT"""
+ # description = """**VLE** (Visual-Language Encoder) is an image-text multimodal understanding model built on the pre-trained text and image encoders. See https://github.com/iflytek/VLE for more details.
+ # We demonstrate visual question answering systems built with VLE and LLM."""
+ # description1 = """**VQA**: The image and the question are fed to a VQA model (VLEForVQA) and the model predicts the answer.
+
+ # **VQA+LLM**: We feed the caption, question, and answers predicted by the VQA model to the LLM and ask the LLM to generate the final answer. The outptus from VQA+LLM may vary due to the decoding strategy of the LLM."""
+
+ with gr.Blocks(
+     css="""
+     .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
+     #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
+     """
+ ) as iface:
+     state = gr.State([])
+     #caption_output = None
+     gr.Markdown(title)
+     # gr.Markdown(description)
+     #gr.Markdown(article)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_input = gr.Image(type="pil",label="Image")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     chat_input = gr.Textbox(lines=1, label="Question")
+                     with gr.Row():
+                         clear_button = gr.Button(value="Clear", interactive=True,width=30)
+                         submit_button = gr.Button(
+                             value="Submit", interactive=True, variant="primary"
+                         )
+                     '''
+                     cap_submit_button = gr.Button(
+                         value="Submit_CAP", interactive=True, variant="primary"
+                     )
+                     gpt3_submit_button = gr.Button(
+                         value="Submit_GPT3", interactive=True, variant="primary"
+                     )
+                     '''
+         with gr.Column():
+             # gr.Markdown(description1)
+             rationale = gr.Textbox(lines=0, label="Rationale")
+             answer = gr.Textbox(lines=0, label="Answer")
+
+     chat_input.submit(
+         inference_chat,
+         [
+             image_input,
+             chat_input,
+         ],
+         [rationale, answer],
+     )
+     clear_button.click(
+         lambda: ("", [],"",""),
+         [],
+         [chat_input, state, rationale, answer],
+         queue=False,
+     )
+     submit_button.click(
+         inference_chat,
+         [
+             image_input,
+             chat_input,
+         ],
+         [rationale, answer],
+     )
+     examples=[['api/61.png',"Question: Think about the magnetic force between the magnets in each pair. Which of the following statements is true?\nContext: The images below show two pairs of magnets. The magnets in different pairs do not affect each other. All the magnets shown are made of the same material, but some of them are different sizes and shapes.\nOptions: (A) The magnitude of the magnetic force is the same in both pairs. (B) The magnitude of the magnetic force is smaller in Pair 1. (C) The magnitude of the magnetic force is smaller in Pair 2.\nSolution:","Magnet sizes affect the magnitude of the magnetic force. Imagine magnets that are the same shape and made of the same material. The smaller the magnets, the smaller the magnitude of the magnetic force between them.nMagnet A is the same size in both pairs. But Magnet B is smaller in Pair 2 than in Pair 1. So, the magnitude of the magnetic force is smaller in Pair 2 than in Pair 1.","The answer is (C)."],
+     ]
+     examples = gr.Examples(
+         examples=examples,inputs=[image_input, chat_input, rationale, answer],
+     )
+
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
+ iface.launch(enable_queue=True)
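
For reference, a minimal local-driver sketch for the two-stage pipeline above (rationale generation followed by answer generation). This is not part of the commit: it assumes app.py is importable from the same directory, that the trailing iface.queue/iface.launch lines are guarded or commented out so importing does not start the Gradio server, and the image path and question text are placeholders.

# Hypothetical headless driver for inference_chat defined in app.py above.
from PIL import Image
from app import inference_chat  # importing app.py also builds the ViT, tokenizer and both T5 models

if __name__ == "__main__":
    image = Image.open("api/61.png").convert("RGB")   # example image shipped with the Space
    question = "Question: ...\nContext: ...\nOptions: (A) ... (B) ... (C) ...\nSolution:"  # placeholder prompt
    rationale, answer = inference_chat(image, question)  # stage 1: rationale, stage 2: answer
    print("Rationale:", rationale)
    print("Answer:", answer)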
model.py ADDED
@@ -0,0 +1,515 @@
1
+ '''
2
+ Adapted from https://github.com/huggingface/transformers
3
+ '''
4
+
5
+ from transformers import T5Config, T5ForConditionalGeneration
6
+ from transformers.models.t5.modeling_t5 import T5Stack, __HEAD_MASK_WARNING_MSG, T5Block, T5LayerNorm
7
+ import copy
8
+ from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
10
+ import math
11
+ import os
12
+ import warnings
13
+ from typing import Optional, Tuple, Union
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import CrossEntropyLoss
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutput,
19
+ Seq2SeqLMOutput,
20
+ )
21
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
22
+ from torch.utils.checkpoint import checkpoint
23
+
24
+ class JointEncoder(T5Stack):
25
+ def __init__(self, config, embed_tokens=None, patch_size=None):
26
+ super().__init__(config)
27
+
28
+ self.embed_tokens = embed_tokens
29
+ self.is_decoder = config.is_decoder
30
+
31
+ self.patch_num, self.patch_dim = patch_size
32
+ self.image_dense = nn.Linear(self.patch_dim, config.d_model)
33
+ self.mha_layer = torch.nn.MultiheadAttention(embed_dim=config.hidden_size, kdim=config.hidden_size, vdim=config.hidden_size, num_heads=1, batch_first=True)
34
+ self.gate_dense = nn.Linear(2*config.hidden_size, config.hidden_size)
35
+ self.sigmoid = nn.Sigmoid()
36
+
37
+ self.block = nn.ModuleList(
38
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
39
+ )
40
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
41
+ self.dropout = nn.Dropout(config.dropout_rate)
42
+
43
+ # Initialize weights and apply final processing
44
+ self.post_init()
45
+ # Model parallel
46
+ self.model_parallel = False
47
+ self.device_map = None
48
+ self.gradient_checkpointing = False
49
+
50
+ def parallelize(self, device_map=None):
51
+ warnings.warn(
52
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
53
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
54
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
55
+ " 'block.1': 1, ...}",
56
+ FutureWarning,
57
+ )
58
+ # Check validity of device_map
59
+ self.device_map = (
60
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
61
+ )
62
+ assert_device_map(self.device_map, len(self.block))
63
+ self.model_parallel = True
64
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
65
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
66
+ # Load onto devices
67
+ for k, v in self.device_map.items():
68
+ for layer in v:
69
+ cuda_device = "cuda:" + str(k)
70
+ self.block[layer] = self.block[layer].to(cuda_device)
71
+
72
+ # Set embed_tokens to first layer
73
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
74
+ # Set final layer norm to last device
75
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
76
+
77
+ def deparallelize(self):
78
+ warnings.warn(
79
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
80
+ FutureWarning,
81
+ )
82
+ self.model_parallel = False
83
+ self.device_map = None
84
+ self.first_device = "cpu"
85
+ self.last_device = "cpu"
86
+ for i in range(len(self.block)):
87
+ self.block[i] = self.block[i].to("cpu")
88
+ self.embed_tokens = self.embed_tokens.to("cpu")
89
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
90
+ torch.cuda.empty_cache()
91
+
92
+ def get_input_embeddings(self):
93
+ return self.embed_tokens
94
+
95
+ def set_input_embeddings(self, new_embeddings):
96
+ self.embed_tokens = new_embeddings
97
+
98
+ def forward(
99
+ self,
100
+ input_ids=None,
101
+ attention_mask=None,
102
+ encoder_hidden_states=None,
103
+ encoder_attention_mask=None,
104
+ inputs_embeds=None,
105
+ image_ids=None,
106
+ head_mask=None,
107
+ cross_attn_head_mask=None,
108
+ past_key_values=None,
109
+ use_cache=None,
110
+ output_attentions=None,
111
+ output_hidden_states=None,
112
+ return_dict=None,
113
+ ):
114
+ # Model parallel
115
+ if self.model_parallel:
116
+ torch.cuda.set_device(self.first_device)
117
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
118
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
119
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
120
+ output_hidden_states = (
121
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
122
+ )
123
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
124
+
125
+ if input_ids is not None and inputs_embeds is not None:
126
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
127
+ raise ValueError(
128
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
129
+ )
130
+ elif input_ids is not None:
131
+ input_shape = input_ids.size()
132
+ input_ids = input_ids.view(-1, input_shape[-1])
133
+ elif inputs_embeds is not None:
134
+ input_shape = inputs_embeds.size()[:-1]
135
+ else:
136
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
137
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
138
+
139
+ if inputs_embeds is None:
140
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
141
+ inputs_embeds = self.embed_tokens(input_ids)
142
+
143
+ batch_size, seq_length = input_shape
144
+
145
+ # required mask seq length can be calculated via length of past
146
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
147
+
148
+ if use_cache is True:
149
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
150
+
151
+ if attention_mask is None:
152
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
153
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
154
+ encoder_seq_length = encoder_hidden_states.shape[1]
155
+ encoder_attention_mask = torch.ones(
156
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
157
+ )
158
+
159
+ # initialize past_key_values with `None` if past does not exist
160
+ if past_key_values is None:
161
+ past_key_values = [None] * len(self.block)
162
+
163
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
164
+ # ourselves in which case we just need to make it broadcastable to all heads.
165
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
166
+
167
+ # If a 2D or 3D attention mask is provided for the cross-attention
168
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
169
+ if self.is_decoder and encoder_hidden_states is not None:
170
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
171
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
172
+ if encoder_attention_mask is None:
173
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
174
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
175
+ else:
176
+ encoder_extended_attention_mask = None
177
+
178
+ # Prepare head mask if needed
179
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
180
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
181
+ present_key_value_states = () if use_cache else None
182
+ all_hidden_states = () if output_hidden_states else None
183
+ all_attentions = () if output_attentions else None
184
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
185
+ position_bias = None
186
+ encoder_decoder_position_bias = None
187
+
188
+ hidden_states = self.dropout(inputs_embeds)
189
+
190
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
191
+ layer_head_mask = head_mask[i]
192
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
193
+ # Model parallel
194
+ if self.model_parallel:
195
+ torch.cuda.set_device(hidden_states.device)
196
+ # Ensure that attention_mask is always on the same device as hidden_states
197
+ if attention_mask is not None:
198
+ attention_mask = attention_mask.to(hidden_states.device)
199
+ if position_bias is not None:
200
+ position_bias = position_bias.to(hidden_states.device)
201
+ if encoder_hidden_states is not None:
202
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
203
+ if encoder_extended_attention_mask is not None:
204
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
205
+ if encoder_decoder_position_bias is not None:
206
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
207
+ if layer_head_mask is not None:
208
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
209
+ if cross_attn_layer_head_mask is not None:
210
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
211
+ if output_hidden_states:
212
+ all_hidden_states = all_hidden_states + (hidden_states,)
213
+
214
+ if self.gradient_checkpointing and self.training:
215
+ if use_cache:
216
+ logger.warning_once(
217
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
218
+ )
219
+ use_cache = False
220
+
221
+ def create_custom_forward(module):
222
+ def custom_forward(*inputs):
223
+ return tuple(module(*inputs, use_cache, output_attentions))
224
+
225
+ return custom_forward
226
+
227
+ layer_outputs = checkpoint(
228
+ create_custom_forward(layer_module),
229
+ hidden_states,
230
+ extended_attention_mask,
231
+ position_bias,
232
+ encoder_hidden_states,
233
+ encoder_extended_attention_mask,
234
+ encoder_decoder_position_bias,
235
+ layer_head_mask,
236
+ cross_attn_layer_head_mask,
237
+ None, # past_key_value is always None with gradient checkpointing
238
+ )
239
+ else:
240
+ layer_outputs = layer_module(
241
+ hidden_states,
242
+ attention_mask=extended_attention_mask,
243
+ position_bias=position_bias,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ encoder_attention_mask=encoder_extended_attention_mask,
246
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
247
+ layer_head_mask=layer_head_mask,
248
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
249
+ past_key_value=past_key_value,
250
+ use_cache=use_cache,
251
+ output_attentions=output_attentions,
252
+ )
253
+
254
+ # layer_outputs is a tuple with:
255
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
256
+ if use_cache is False:
257
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
258
+
259
+ hidden_states, present_key_value_state = layer_outputs[:2]
260
+
261
+ # We share the position biases between the layers - the first layer store them
262
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
263
+ # (cross-attention position bias), (cross-attention weights)
264
+ position_bias = layer_outputs[2]
265
+ if self.is_decoder and encoder_hidden_states is not None:
266
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
267
+ # append next layer key value states
268
+ if use_cache:
269
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
270
+
271
+ if output_attentions:
272
+ all_attentions = all_attentions + (layer_outputs[3],)
273
+ if self.is_decoder:
274
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
275
+
276
+ # Model Parallel: If it's the last layer for that device, put things on the next device
277
+ if self.model_parallel:
278
+ for k, v in self.device_map.items():
279
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
280
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
281
+
282
+ hidden_states = self.final_layer_norm(hidden_states)
283
+ hidden_states = self.dropout(hidden_states)
284
+
285
+ # Add last layer
286
+ if output_hidden_states:
287
+ all_hidden_states = all_hidden_states + (hidden_states,)
288
+
289
+ image_embedding = self.image_dense(image_ids)
290
+ image_att, _ = self.mha_layer(hidden_states, image_embedding, image_embedding)
291
+ merge = torch.cat([hidden_states, image_att], dim=-1)
292
+ gate = self.sigmoid(self.gate_dense(merge))
293
+ hidden_states = (1 - gate) * hidden_states + gate * image_att
294
+
295
+ if not return_dict:
296
+ return tuple(
297
+ v
298
+ for v in [
299
+ hidden_states,
300
+ present_key_value_states,
301
+ all_hidden_states,
302
+ all_attentions,
303
+ all_cross_attentions,
304
+ ]
305
+ if v is not None
306
+ )
307
+ return BaseModelOutputWithPastAndCrossAttentions(
308
+ last_hidden_state=hidden_states,
309
+ past_key_values=present_key_value_states,
310
+ hidden_states=all_hidden_states,
311
+ attentions=all_attentions,
312
+ cross_attentions=all_cross_attentions,
313
+ )
314
+
315
+
316
+ class T5ForMultimodalGeneration(T5ForConditionalGeneration):
317
+ _keys_to_ignore_on_load_missing = [
318
+ r"encoder.embed_tokens.weight",
319
+ r"decoder.embed_tokens.weight",
320
+ r"lm_head.weight",
321
+ ]
322
+ _keys_to_ignore_on_load_unexpected = [
323
+ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
324
+ ]
325
+
326
+ def __init__(self, config: T5Config, patch_size):
327
+ super().__init__(config)
328
+ self.model_dim = config.d_model
329
+
330
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
331
+
332
+ encoder_config = copy.deepcopy(config)
333
+ encoder_config.is_decoder = False
334
+ encoder_config.use_cache = False
335
+ encoder_config.is_encoder_decoder = False
336
+ # self.encoder = T5Stack(encoder_config, self.shared)
337
+ self.encoder = JointEncoder(encoder_config, self.shared, patch_size)
338
+ decoder_config = copy.deepcopy(config)
339
+ decoder_config.is_decoder = True
340
+ decoder_config.is_encoder_decoder = False
341
+ decoder_config.num_layers = config.num_decoder_layers
342
+ self.decoder = T5Stack(decoder_config, self.shared)
343
+
344
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
345
+
346
+ # Initialize weights and apply final processing
347
+ self.post_init()
348
+
349
+ # Model parallel
350
+ self.model_parallel = False
351
+ self.device_map = None
352
+
353
+ def forward(
354
+ self,
355
+ input_ids: Optional[torch.LongTensor] = None,
356
+ image_ids=None,
357
+ attention_mask: Optional[torch.FloatTensor] = None,
358
+ decoder_input_ids: Optional[torch.LongTensor] = None,
359
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
360
+ head_mask: Optional[torch.FloatTensor] = None,
361
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
362
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
363
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
364
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
365
+ inputs_embeds: Optional[torch.FloatTensor] = None,
366
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
367
+ labels: Optional[torch.LongTensor] = None,
368
+ use_cache: Optional[bool] = None,
369
+ output_attentions: Optional[bool] = None,
370
+ output_hidden_states: Optional[bool] = None,
371
+ return_dict: Optional[bool] = None,
372
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
373
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
374
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
375
+
376
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
377
+ if head_mask is not None and decoder_head_mask is None:
378
+ if self.config.num_layers == self.config.num_decoder_layers:
379
+ warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
380
+ decoder_head_mask = head_mask
381
+
382
+ # Encode if needed (training, first prediction pass)
383
+ if encoder_outputs is None:
384
+ # Convert encoder inputs in embeddings if needed
385
+ encoder_outputs = self.encoder(
386
+ input_ids=input_ids,
387
+ attention_mask=attention_mask,
388
+ inputs_embeds=inputs_embeds,
389
+ image_ids=image_ids,
390
+ head_mask=head_mask,
391
+ output_attentions=output_attentions,
392
+ output_hidden_states=output_hidden_states,
393
+ return_dict=return_dict,
394
+ )
395
+
396
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
397
+ encoder_outputs = BaseModelOutput(
398
+ last_hidden_state=encoder_outputs[0],
399
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
400
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
401
+ )
402
+
403
+ hidden_states = encoder_outputs[0]
404
+
405
+ if self.model_parallel:
406
+ torch.cuda.set_device(self.decoder.first_device)
407
+
408
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
409
+ # get decoder inputs from shifting lm labels to the right
410
+ decoder_input_ids = self._shift_right(labels)
411
+
412
+ # Set device for model parallelism
413
+ if self.model_parallel:
414
+ torch.cuda.set_device(self.decoder.first_device)
415
+ hidden_states = hidden_states.to(self.decoder.first_device)
416
+ if decoder_input_ids is not None:
417
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
418
+ if attention_mask is not None:
419
+ attention_mask = attention_mask.to(self.decoder.first_device)
420
+ if decoder_attention_mask is not None:
421
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
422
+
423
+ # Decode
424
+ decoder_outputs = self.decoder(
425
+ input_ids=decoder_input_ids,
426
+ attention_mask=decoder_attention_mask,
427
+ inputs_embeds=decoder_inputs_embeds,
428
+ past_key_values=past_key_values,
429
+ encoder_hidden_states=hidden_states,
430
+ encoder_attention_mask=attention_mask,
431
+ head_mask=decoder_head_mask,
432
+ cross_attn_head_mask=cross_attn_head_mask,
433
+ use_cache=use_cache,
434
+ output_attentions=output_attentions,
435
+ output_hidden_states=output_hidden_states,
436
+ return_dict=return_dict,
437
+ )
438
+
439
+ sequence_output = decoder_outputs[0]
440
+
441
+ # Set device for model parallelism
442
+ if self.model_parallel:
443
+ torch.cuda.set_device(self.encoder.first_device)
444
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
445
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
446
+
447
+ if self.config.tie_word_embeddings:
448
+ # Rescale output before projecting on vocab
449
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
450
+ sequence_output = sequence_output * (self.model_dim**-0.5)
451
+
452
+ lm_logits = self.lm_head(sequence_output)
453
+
454
+ loss = None
455
+ if labels is not None:
456
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
457
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
458
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
459
+
460
+ if not return_dict:
461
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
462
+ return ((loss,) + output) if loss is not None else output
463
+
464
+ return Seq2SeqLMOutput(
465
+ loss=loss,
466
+ logits=lm_logits,
467
+ past_key_values=decoder_outputs.past_key_values,
468
+ decoder_hidden_states=decoder_outputs.hidden_states,
469
+ decoder_attentions=decoder_outputs.attentions,
470
+ cross_attentions=decoder_outputs.cross_attentions,
471
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
472
+ encoder_hidden_states=encoder_outputs.hidden_states,
473
+ encoder_attentions=encoder_outputs.attentions,
474
+ )
475
+
476
+ def prepare_inputs_for_generation(
477
+ self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
478
+ ):
479
+ # cut decoder_input_ids if past is used
480
+ if past is not None:
481
+ decoder_input_ids = decoder_input_ids[:, -1:]
482
+
483
+ output = {
484
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
485
+ "encoder_outputs": encoder_outputs,
486
+ "past_key_values": past,
487
+ "decoder_input_ids": decoder_input_ids,
488
+ "attention_mask": attention_mask,
489
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
490
+ }
491
+
492
+ if "image_ids" in kwargs:
493
+ output["image_ids"] = kwargs['image_ids']
494
+
495
+ return output
496
+
497
+ def test_step(self, tokenizer, batch, **kwargs):
498
+ device = next(self.parameters()).device
499
+ input_ids = batch['input_ids'].to(device)
500
+ image_ids = batch['image_ids'].to(device)
501
+
502
+ output = self.generate(
503
+ input_ids=input_ids,
504
+ image_ids=image_ids,
505
+ **kwargs
506
+ )
507
+
508
+ generated_sents = tokenizer.batch_decode(output, skip_special_tokens=True)
509
+ targets = tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)
510
+
511
+ result = {}
512
+ result['preds'] = generated_sents
513
+ result['targets'] = targets
514
+
515
+ return result
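
As a usage sketch (not part of the commit), the multimodal T5 defined in model.py is driven the same way app.py uses it: patch features of shape (1, 577, 768) from vit_base_patch16_384 are passed as image_ids alongside the tokenized prompt. The checkpoint name and patch_size below mirror app.py; the zero tensor and prompt are placeholders standing in for real ViT features and a real question.

import torch
from transformers import T5Tokenizer
from model import T5ForMultimodalGeneration

# Checkpoint and patch_size mirror the values used in app.py.
model_dir = "cooelf/MM-CoT-UnifiedQA-Base-Rationale-Joint"
tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForMultimodalGeneration.from_pretrained(model_dir, patch_size=(577, 768)).eval()

prompt = "Question: ...\nContext: ...\nOptions: (A) ... (B) ...\nSolution:"  # placeholder prompt
enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

# Stand-in for vit_model.forward_features(...): 577 patch tokens x 768 channels.
image_features = torch.zeros(1, 577, 768)

with torch.no_grad():
    ids = model.generate(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        image_ids=image_features,  # routed through to JointEncoder.forward via generate kwargs
        max_length=512,
    )
print(tokenizer.batch_decode(ids, skip_special_tokens=True)[0])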
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ huggingface-hub>=0.4.0
+ git+https://github.com/huggingface/transformers.git
+ torch
+ torchvision
+ sentencepiece
+ numpy
timm/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .version import __version__
+ from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+     is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
+     get_model_default_value, is_model_pretrained
timm/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (537 Bytes).
timm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (541 Bytes).
timm/__pycache__/version.cpython-37.pyc ADDED
Binary file (156 Bytes).
timm/__pycache__/version.cpython-38.pyc ADDED
Binary file (160 Bytes).
timm/data/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
+     rand_augment_transform, auto_augment_transform
+ from .config import resolve_data_config
+ from .constants import *
+ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
+ from .dataset_factory import create_dataset
+ from .loader import create_loader
+ from .mixup import Mixup, FastCollateMixup
+ from .parsers import create_parser
+ from .real_labels import RealLabelsImagenet
+ from .transforms import *
+ from .transforms_factory import create_transform
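
A short sketch (not part of the commit) of the preprocessing entry points this vendored timm.data package exposes, mirroring how app.py builds its image transform; the model name matches the one used there, and pretrained=False avoids a weight download.

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

vit = timm.create_model("vit_base_patch16_384", pretrained=False, num_classes=0)
cfg = resolve_data_config({}, model=vit)   # input size, interpolation, mean/std for this backbone
preprocess = create_transform(**cfg)       # torchvision-style pipeline: resize, crop, to-tensor, normalize
print(cfg["input_size"])                   # expected: (3, 384, 384)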
timm/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (848 Bytes).
timm/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (852 Bytes).
timm/data/__pycache__/auto_augment.cpython-37.pyc ADDED
Binary file (25.2 kB).
timm/data/__pycache__/auto_augment.cpython-38.pyc ADDED
Binary file (23.3 kB).
timm/data/__pycache__/config.cpython-37.pyc ADDED
Binary file (1.59 kB).
timm/data/__pycache__/config.cpython-38.pyc ADDED
Binary file (1.6 kB).
timm/data/__pycache__/constants.cpython-37.pyc ADDED
Binary file (483 Bytes).
timm/data/__pycache__/constants.cpython-38.pyc ADDED
Binary file (479 Bytes).
timm/data/__pycache__/dataset.cpython-37.pyc ADDED
Binary file (5 kB).
timm/data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (5.06 kB).
timm/data/__pycache__/dataset_factory.cpython-37.pyc ADDED
Binary file (938 Bytes).
timm/data/__pycache__/dataset_factory.cpython-38.pyc ADDED
Binary file (966 Bytes).
timm/data/__pycache__/distributed_sampler.cpython-37.pyc ADDED
Binary file (2.06 kB).
timm/data/__pycache__/distributed_sampler.cpython-38.pyc ADDED
Binary file (2.08 kB).
timm/data/__pycache__/loader.cpython-37.pyc ADDED
Binary file (7.09 kB).
timm/data/__pycache__/loader.cpython-38.pyc ADDED
Binary file (7.13 kB).
timm/data/__pycache__/mixup.cpython-37.pyc ADDED
Binary file (11.5 kB).
timm/data/__pycache__/mixup.cpython-38.pyc ADDED
Binary file (11.4 kB).
timm/data/__pycache__/random_erasing.cpython-37.pyc ADDED
Binary file (3.66 kB).
timm/data/__pycache__/random_erasing.cpython-38.pyc ADDED
Binary file (3.69 kB).
timm/data/__pycache__/real_labels.cpython-37.pyc ADDED
Binary file (2.37 kB).
timm/data/__pycache__/real_labels.cpython-38.pyc ADDED
Binary file (2.4 kB).
timm/data/__pycache__/transforms.cpython-37.pyc ADDED
Binary file (5.66 kB).
timm/data/__pycache__/transforms.cpython-38.pyc ADDED
Binary file (5.7 kB).
timm/data/__pycache__/transforms_factory.cpython-37.pyc ADDED
Binary file (5.01 kB).
timm/data/__pycache__/transforms_factory.cpython-38.pyc ADDED
Binary file (5.06 kB).
timm/data/auto_augment.py ADDED
@@ -0,0 +1,822 @@
1
+ """ AutoAugment, RandAugment, and AugMix for PyTorch
2
+
3
+ This code implements the searched ImageNet policies with various tweaks and improvements and
4
+ does not include any of the search code.
5
+
6
+ AA and RA Implementation adapted from:
7
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
8
+
9
+ AugMix adapted from:
10
+ https://github.com/google-research/augmix
11
+
12
+ Papers:
13
+ AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
14
+ Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
15
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
16
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+ """
20
+ import random
21
+ import math
22
+ import re
23
+ from PIL import Image, ImageOps, ImageEnhance, ImageChops
24
+ import PIL
25
+ import numpy as np
26
+
27
+
28
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
29
+
30
+ _FILL = (128, 128, 128)
31
+
32
+ # This signifies the max integer that the controller RNN could predict for the
33
+ # augmentation scheme.
34
+ _MAX_LEVEL = 10.
35
+
36
+ _HPARAMS_DEFAULT = dict(
37
+ translate_const=250,
38
+ img_mean=_FILL,
39
+ )
40
+
41
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
42
+
43
+
44
+ def _interpolation(kwargs):
45
+ interpolation = kwargs.pop('resample', Image.BILINEAR)
46
+ if isinstance(interpolation, (list, tuple)):
47
+ return random.choice(interpolation)
48
+ else:
49
+ return interpolation
50
+
51
+
52
+ def _check_args_tf(kwargs):
53
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
54
+ kwargs.pop('fillcolor')
55
+ kwargs['resample'] = _interpolation(kwargs)
56
+
57
+
58
+ def shear_x(img, factor, **kwargs):
59
+ _check_args_tf(kwargs)
60
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
61
+
62
+
63
+ def shear_y(img, factor, **kwargs):
64
+ _check_args_tf(kwargs)
65
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
66
+
67
+
68
+ def translate_x_rel(img, pct, **kwargs):
69
+ pixels = pct * img.size[0]
70
+ _check_args_tf(kwargs)
71
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
72
+
73
+
74
+ def translate_y_rel(img, pct, **kwargs):
75
+ pixels = pct * img.size[1]
76
+ _check_args_tf(kwargs)
77
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
78
+
79
+
80
+ def translate_x_abs(img, pixels, **kwargs):
81
+ _check_args_tf(kwargs)
82
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
83
+
84
+
85
+ def translate_y_abs(img, pixels, **kwargs):
86
+ _check_args_tf(kwargs)
87
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
88
+
89
+
90
+ def rotate(img, degrees, **kwargs):
91
+ _check_args_tf(kwargs)
92
+ if _PIL_VER >= (5, 2):
93
+ return img.rotate(degrees, **kwargs)
94
+ elif _PIL_VER >= (5, 0):
95
+ w, h = img.size
96
+ post_trans = (0, 0)
97
+ rotn_center = (w / 2.0, h / 2.0)
98
+ angle = -math.radians(degrees)
99
+ matrix = [
100
+ round(math.cos(angle), 15),
101
+ round(math.sin(angle), 15),
102
+ 0.0,
103
+ round(-math.sin(angle), 15),
104
+ round(math.cos(angle), 15),
105
+ 0.0,
106
+ ]
107
+
108
+ def transform(x, y, matrix):
109
+ (a, b, c, d, e, f) = matrix
110
+ return a * x + b * y + c, d * x + e * y + f
111
+
112
+ matrix[2], matrix[5] = transform(
113
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
114
+ )
115
+ matrix[2] += rotn_center[0]
116
+ matrix[5] += rotn_center[1]
117
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
118
+ else:
119
+ return img.rotate(degrees, resample=kwargs['resample'])
120
+
121
+
122
+ def auto_contrast(img, **__):
123
+ return ImageOps.autocontrast(img)
124
+
125
+
126
+ def invert(img, **__):
127
+ return ImageOps.invert(img)
128
+
129
+
130
+ def equalize(img, **__):
131
+ return ImageOps.equalize(img)
132
+
133
+
134
+ def solarize(img, thresh, **__):
135
+ return ImageOps.solarize(img, thresh)
136
+
137
+
138
+ def solarize_add(img, add, thresh=128, **__):
139
+ lut = []
140
+ for i in range(256):
141
+ if i < thresh:
142
+ lut.append(min(255, i + add))
143
+ else:
144
+ lut.append(i)
145
+ if img.mode in ("L", "RGB"):
146
+ if img.mode == "RGB" and len(lut) == 256:
147
+ lut = lut + lut + lut
148
+ return img.point(lut)
149
+ else:
150
+ return img
151
+
152
+
153
+ def posterize(img, bits_to_keep, **__):
154
+ if bits_to_keep >= 8:
155
+ return img
156
+ return ImageOps.posterize(img, bits_to_keep)
157
+
158
+
159
+ def contrast(img, factor, **__):
160
+ return ImageEnhance.Contrast(img).enhance(factor)
161
+
162
+
163
+ def color(img, factor, **__):
164
+ return ImageEnhance.Color(img).enhance(factor)
165
+
166
+
167
+ def brightness(img, factor, **__):
168
+ return ImageEnhance.Brightness(img).enhance(factor)
169
+
170
+
171
+ def sharpness(img, factor, **__):
172
+ return ImageEnhance.Sharpness(img).enhance(factor)
173
+
174
+
175
+ def _randomly_negate(v):
176
+ """With 50% prob, negate the value"""
177
+ return -v if random.random() > 0.5 else v
178
+
179
+
180
+ def _rotate_level_to_arg(level, _hparams):
181
+ # range [-30, 30]
182
+ level = (level / _MAX_LEVEL) * 30.
183
+ level = _randomly_negate(level)
184
+ return level,
185
+
186
+
187
+ def _enhance_level_to_arg(level, _hparams):
188
+ # range [0.1, 1.9]
189
+ return (level / _MAX_LEVEL) * 1.8 + 0.1,
190
+
191
+
192
+ def _enhance_increasing_level_to_arg(level, _hparams):
193
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
194
+ # range [0.1, 1.9]
195
+ level = (level / _MAX_LEVEL) * .9
196
+ level = 1.0 + _randomly_negate(level)
197
+ return level,
198
+
199
+
200
+ def _shear_level_to_arg(level, _hparams):
201
+ # range [-0.3, 0.3]
202
+ level = (level / _MAX_LEVEL) * 0.3
203
+ level = _randomly_negate(level)
204
+ return level,
205
+
206
+
207
+ def _translate_abs_level_to_arg(level, hparams):
208
+ translate_const = hparams['translate_const']
209
+ level = (level / _MAX_LEVEL) * float(translate_const)
210
+ level = _randomly_negate(level)
211
+ return level,
212
+
213
+
214
+ def _translate_rel_level_to_arg(level, hparams):
215
+ # default range [-0.45, 0.45]
216
+ translate_pct = hparams.get('translate_pct', 0.45)
217
+ level = (level / _MAX_LEVEL) * translate_pct
218
+ level = _randomly_negate(level)
219
+ return level,
220
+
221
+
222
+ def _posterize_level_to_arg(level, _hparams):
223
+ # As per Tensorflow TPU EfficientNet impl
224
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
225
+ # intensity/severity of augmentation decreases with level
226
+ return int((level / _MAX_LEVEL) * 4),
227
+
228
+
229
+ def _posterize_increasing_level_to_arg(level, hparams):
230
+ # As per Tensorflow models research and UDA impl
231
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
232
+ # intensity/severity of augmentation increases with level
233
+ return 4 - _posterize_level_to_arg(level, hparams)[0],
234
+
235
+
236
+ def _posterize_original_level_to_arg(level, _hparams):
237
+ # As per original AutoAugment paper description
238
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
239
+ # intensity/severity of augmentation decreases with level
240
+ return int((level / _MAX_LEVEL) * 4) + 4,
241
+
242
+
243
+ def _solarize_level_to_arg(level, _hparams):
244
+ # range [0, 256]
245
+ # intensity/severity of augmentation decreases with level
246
+ return int((level / _MAX_LEVEL) * 256),
247
+
248
+
249
+ def _solarize_increasing_level_to_arg(level, _hparams):
250
+ # range [0, 256]
251
+ # intensity/severity of augmentation increases with level
252
+ return 256 - _solarize_level_to_arg(level, _hparams)[0],
253
+
254
+
255
+ def _solarize_add_level_to_arg(level, _hparams):
256
+ # range [0, 110]
257
+ return int((level / _MAX_LEVEL) * 110),
258
+
259
+
260
+ LEVEL_TO_ARG = {
261
+ 'AutoContrast': None,
262
+ 'Equalize': None,
263
+ 'Invert': None,
264
+ 'Rotate': _rotate_level_to_arg,
265
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
266
+ 'Posterize': _posterize_level_to_arg,
267
+ 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
268
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
269
+ 'Solarize': _solarize_level_to_arg,
270
+ 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
271
+ 'SolarizeAdd': _solarize_add_level_to_arg,
272
+ 'Color': _enhance_level_to_arg,
273
+ 'ColorIncreasing': _enhance_increasing_level_to_arg,
274
+ 'Contrast': _enhance_level_to_arg,
275
+ 'ContrastIncreasing': _enhance_increasing_level_to_arg,
276
+ 'Brightness': _enhance_level_to_arg,
277
+ 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
278
+ 'Sharpness': _enhance_level_to_arg,
279
+ 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
280
+ 'ShearX': _shear_level_to_arg,
281
+ 'ShearY': _shear_level_to_arg,
282
+ 'TranslateX': _translate_abs_level_to_arg,
283
+ 'TranslateY': _translate_abs_level_to_arg,
284
+ 'TranslateXRel': _translate_rel_level_to_arg,
285
+ 'TranslateYRel': _translate_rel_level_to_arg,
286
+ }
287
+
288
+
289
+ NAME_TO_OP = {
290
+ 'AutoContrast': auto_contrast,
291
+ 'Equalize': equalize,
292
+ 'Invert': invert,
293
+ 'Rotate': rotate,
294
+ 'Posterize': posterize,
295
+ 'PosterizeIncreasing': posterize,
296
+ 'PosterizeOriginal': posterize,
297
+ 'Solarize': solarize,
298
+ 'SolarizeIncreasing': solarize,
299
+ 'SolarizeAdd': solarize_add,
300
+ 'Color': color,
301
+ 'ColorIncreasing': color,
302
+ 'Contrast': contrast,
303
+ 'ContrastIncreasing': contrast,
304
+ 'Brightness': brightness,
305
+ 'BrightnessIncreasing': brightness,
306
+ 'Sharpness': sharpness,
307
+ 'SharpnessIncreasing': sharpness,
308
+ 'ShearX': shear_x,
309
+ 'ShearY': shear_y,
310
+ 'TranslateX': translate_x_abs,
311
+ 'TranslateY': translate_y_abs,
312
+ 'TranslateXRel': translate_x_rel,
313
+ 'TranslateYRel': translate_y_rel,
314
+ }
315
+
316
+
317
+ class AugmentOp:
318
+
319
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
320
+ hparams = hparams or _HPARAMS_DEFAULT
321
+ self.aug_fn = NAME_TO_OP[name]
322
+ self.level_fn = LEVEL_TO_ARG[name]
323
+ self.prob = prob
324
+ self.magnitude = magnitude
325
+ self.hparams = hparams.copy()
326
+ self.kwargs = dict(
327
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
328
+ resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
329
+ )
330
+
331
+ # If magnitude_std is > 0, we introduce some randomness
332
+ # in the usually fixed policy and sample magnitude from a normal distribution
333
+ # with mean `magnitude` and std-dev of `magnitude_std`.
334
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
335
+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
336
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
337
+
338
+ def __call__(self, img):
339
+ if self.prob < 1.0 and random.random() > self.prob:
340
+ return img
341
+ magnitude = self.magnitude
342
+ if self.magnitude_std:
343
+ if self.magnitude_std == float('inf'):
344
+ magnitude = random.uniform(0, magnitude)
345
+ elif self.magnitude_std > 0:
346
+ magnitude = random.gauss(magnitude, self.magnitude_std)
347
+ magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
348
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
349
+ return self.aug_fn(img, *level_args, **self.kwargs)
350
+
351
+
352
+ def auto_augment_policy_v0(hparams):
353
+ # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
354
+ policy = [
355
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
356
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
357
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
358
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
359
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
360
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
361
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
362
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
363
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
364
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
365
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
366
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
367
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
368
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
369
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
370
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
371
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
372
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
373
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
374
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
375
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
376
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
377
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
378
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
379
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
380
+ ]
381
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
382
+ return pc
383
+
384
+
385
+ def auto_augment_policy_v0r(hparams):
386
+ # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
387
+ # in Google research implementation (number of bits discarded increases with magnitude)
388
+ policy = [
389
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
390
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
391
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
392
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
393
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
394
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
395
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
396
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
397
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
398
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
399
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
400
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
401
+ [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
402
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
403
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
404
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
405
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
406
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
407
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
408
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
409
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
410
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
411
+ [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
412
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
413
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
414
+ ]
415
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
416
+ return pc
417
+
418
+
419
+ def auto_augment_policy_original(hparams):
420
+ # ImageNet policy from https://arxiv.org/abs/1805.09501
421
+ policy = [
422
+ [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
423
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
424
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
425
+ [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
426
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
427
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
428
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
429
+ [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
430
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
431
+ [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
432
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
433
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
434
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
435
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
436
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
437
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
438
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
439
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
440
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
441
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
442
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
443
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
444
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
445
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
446
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
447
+ ]
448
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
449
+ return pc
450
+
451
+
452
+ def auto_augment_policy_originalr(hparams):
453
+ # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
454
+ policy = [
455
+ [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
456
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
457
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
458
+ [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
459
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
460
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
461
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
462
+ [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
463
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
464
+ [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
465
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
466
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
467
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
468
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
469
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
470
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
471
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
472
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
473
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
474
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
475
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
476
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
477
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
478
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
479
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
480
+ ]
481
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
482
+ return pc
483
+
484
+
485
+ def auto_augment_policy(name='v0', hparams=None):
486
+ hparams = hparams or _HPARAMS_DEFAULT
487
+ if name == 'original':
488
+ return auto_augment_policy_original(hparams)
489
+ elif name == 'originalr':
490
+ return auto_augment_policy_originalr(hparams)
491
+ elif name == 'v0':
492
+ return auto_augment_policy_v0(hparams)
493
+ elif name == 'v0r':
494
+ return auto_augment_policy_v0r(hparams)
495
+ else:
496
+ assert False, 'Unknown AA policy (%s)' % name
497
+
498
+
499
+ class AutoAugment:
500
+
501
+ def __init__(self, policy):
502
+ self.policy = policy
503
+
504
+ def __call__(self, img):
505
+ sub_policy = random.choice(self.policy)
506
+ for op in sub_policy:
507
+ img = op(img)
508
+ return img
509
+
510
+
511
+ def auto_augment_transform(config_str, hparams):
512
+ """
513
+     Create an AutoAugment transform
514
+
515
+ :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
516
+ dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
517
+     The remaining sections, which are not order specific, determine
518
+ 'mstd' - float std deviation of magnitude noise applied
519
+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
520
+
521
+ :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
522
+
523
+ :return: A PyTorch compatible Transform
524
+ """
525
+ config = config_str.split('-')
526
+ policy_name = config[0]
527
+ config = config[1:]
528
+ for c in config:
529
+ cs = re.split(r'(\d.*)', c)
530
+ if len(cs) < 2:
531
+ continue
532
+ key, val = cs[:2]
533
+ if key == 'mstd':
534
+ # noise param injected via hparams for now
535
+ hparams.setdefault('magnitude_std', float(val))
536
+ else:
537
+ assert False, 'Unknown AutoAugment config section'
538
+ aa_policy = auto_augment_policy(policy_name, hparams=hparams)
539
+ return AutoAugment(aa_policy)
540
+
541
+
542
+ _RAND_TRANSFORMS = [
543
+ 'AutoContrast',
544
+ 'Equalize',
545
+ 'Invert',
546
+ 'Rotate',
547
+ 'Posterize',
548
+ 'Solarize',
549
+ 'SolarizeAdd',
550
+ 'Color',
551
+ 'Contrast',
552
+ 'Brightness',
553
+ 'Sharpness',
554
+ 'ShearX',
555
+ 'ShearY',
556
+ 'TranslateXRel',
557
+ 'TranslateYRel',
558
+     #'Cutout'  # NOTE I've implemented this as random erasing separately
559
+ ]
560
+
561
+
562
+ _RAND_INCREASING_TRANSFORMS = [
563
+ 'AutoContrast',
564
+ 'Equalize',
565
+ 'Invert',
566
+ 'Rotate',
567
+ 'PosterizeIncreasing',
568
+ 'SolarizeIncreasing',
569
+ 'SolarizeAdd',
570
+ 'ColorIncreasing',
571
+ 'ContrastIncreasing',
572
+ 'BrightnessIncreasing',
573
+ 'SharpnessIncreasing',
574
+ 'ShearX',
575
+ 'ShearY',
576
+ 'TranslateXRel',
577
+ 'TranslateYRel',
578
+     #'Cutout'  # NOTE I've implemented this as random erasing separately
579
+ ]
580
+
581
+
582
+
583
+ # These experimental weights are based loosely on the relative improvements mentioned in the paper.
584
+ # They may not result in increased performance, but could likely be tuned to do so.
585
+ _RAND_CHOICE_WEIGHTS_0 = {
586
+ 'Rotate': 0.3,
587
+ 'ShearX': 0.2,
588
+ 'ShearY': 0.2,
589
+ 'TranslateXRel': 0.1,
590
+ 'TranslateYRel': 0.1,
591
+ 'Color': .025,
592
+ 'Sharpness': 0.025,
593
+ 'AutoContrast': 0.025,
594
+ 'Solarize': .005,
595
+ 'SolarizeAdd': .005,
596
+ 'Contrast': .005,
597
+ 'Brightness': .005,
598
+ 'Equalize': .005,
599
+ 'Posterize': 0,
600
+ 'Invert': 0,
601
+ }
602
+
603
+
604
+ def _select_rand_weights(weight_idx=0, transforms=None):
605
+ transforms = transforms or _RAND_TRANSFORMS
606
+ assert weight_idx == 0 # only one set of weights currently
607
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
608
+ probs = [rand_weights[k] for k in transforms]
609
+ probs /= np.sum(probs)
610
+ return probs
611
+
612
+
613
+ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
614
+ hparams = hparams or _HPARAMS_DEFAULT
615
+ transforms = transforms or _RAND_TRANSFORMS
616
+ return [AugmentOp(
617
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
618
+
619
+
620
+ class RandAugment:
621
+ def __init__(self, ops, num_layers=2, choice_weights=None):
622
+ self.ops = ops
623
+ self.num_layers = num_layers
624
+ self.choice_weights = choice_weights
625
+
626
+ def __call__(self, img):
627
+ # no replacement when using weighted choice
628
+ ops = np.random.choice(
629
+ self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
630
+ for op in ops:
631
+ img = op(img)
632
+ return img
633
+
634
+
635
+ def rand_augment_transform(config_str, hparams):
636
+ """
637
+ Create a RandAugment transform
638
+
639
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
640
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
641
+     sections, which are not order specific, determine
642
+ 'm' - integer magnitude of rand augment
643
+ 'n' - integer num layers (number of transform ops selected per image)
644
+         'w' - integer probability weight index (index of a set of weights to influence choice of op)
645
+ 'mstd' - float std deviation of magnitude noise applied
646
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
647
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
648
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
649
+
650
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
651
+
652
+ :return: A PyTorch compatible Transform
653
+ """
654
+ magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
655
+ num_layers = 2 # default to 2 ops per image
656
+ weight_idx = None # default to no probability weights for op choice
657
+ transforms = _RAND_TRANSFORMS
658
+ config = config_str.split('-')
659
+ assert config[0] == 'rand'
660
+ config = config[1:]
661
+ for c in config:
662
+ cs = re.split(r'(\d.*)', c)
663
+ if len(cs) < 2:
664
+ continue
665
+ key, val = cs[:2]
666
+ if key == 'mstd':
667
+ # noise param injected via hparams for now
668
+ hparams.setdefault('magnitude_std', float(val))
669
+ elif key == 'inc':
670
+ if bool(val):
671
+ transforms = _RAND_INCREASING_TRANSFORMS
672
+ elif key == 'm':
673
+ magnitude = int(val)
674
+ elif key == 'n':
675
+ num_layers = int(val)
676
+ elif key == 'w':
677
+ weight_idx = int(val)
678
+ else:
679
+ assert False, 'Unknown RandAugment config section'
680
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
681
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
682
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
683
+
684
+
685
+ _AUGMIX_TRANSFORMS = [
686
+ 'AutoContrast',
687
+ 'ColorIncreasing', # not in paper
688
+ 'ContrastIncreasing', # not in paper
689
+ 'BrightnessIncreasing', # not in paper
690
+ 'SharpnessIncreasing', # not in paper
691
+ 'Equalize',
692
+ 'Rotate',
693
+ 'PosterizeIncreasing',
694
+ 'SolarizeIncreasing',
695
+ 'ShearX',
696
+ 'ShearY',
697
+ 'TranslateXRel',
698
+ 'TranslateYRel',
699
+ ]
700
+
701
+
702
+ def augmix_ops(magnitude=10, hparams=None, transforms=None):
703
+ hparams = hparams or _HPARAMS_DEFAULT
704
+ transforms = transforms or _AUGMIX_TRANSFORMS
705
+ return [AugmentOp(
706
+ name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
707
+
708
+
709
+ class AugMixAugment:
710
+ """ AugMix Transform
711
+ Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
712
+ From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
713
+ https://arxiv.org/abs/1912.02781
714
+ """
715
+ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
716
+ self.ops = ops
717
+ self.alpha = alpha
718
+ self.width = width
719
+ self.depth = depth
720
+ self.blended = blended # blended mode is faster but not well tested
721
+
722
+ def _calc_blended_weights(self, ws, m):
723
+ ws = ws * m
724
+ cump = 1.
725
+ rws = []
726
+ for w in ws[::-1]:
727
+ alpha = w / cump
728
+ cump *= (1 - alpha)
729
+ rws.append(alpha)
730
+ return np.array(rws[::-1], dtype=np.float32)
731
+
732
+ def _apply_blended(self, img, mixing_weights, m):
733
+         # This is my first crack at implementing a slightly faster mixed augmentation. Instead
734
+ # of accumulating the mix for each chain in a Numpy array and then blending with original,
735
+ # it recomputes the blending coefficients and applies one PIL image blend per chain.
736
+ # TODO the results appear in the right ballpark but they differ by more than rounding.
737
+ img_orig = img.copy()
738
+ ws = self._calc_blended_weights(mixing_weights, m)
739
+ for w in ws:
740
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
741
+ ops = np.random.choice(self.ops, depth, replace=True)
742
+ img_aug = img_orig # no ops are in-place, deep copy not necessary
743
+ for op in ops:
744
+ img_aug = op(img_aug)
745
+ img = Image.blend(img, img_aug, w)
746
+ return img
747
+
748
+ def _apply_basic(self, img, mixing_weights, m):
749
+ # This is a literal adaptation of the paper/official implementation without normalizations and
750
+ # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
751
+ # typical augmentation transforms, could use a GPU / Kornia implementation.
752
+ img_shape = img.size[0], img.size[1], len(img.getbands())
753
+ mixed = np.zeros(img_shape, dtype=np.float32)
754
+ for mw in mixing_weights:
755
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
756
+ ops = np.random.choice(self.ops, depth, replace=True)
757
+ img_aug = img # no ops are in-place, deep copy not necessary
758
+ for op in ops:
759
+ img_aug = op(img_aug)
760
+ mixed += mw * np.asarray(img_aug, dtype=np.float32)
761
+ np.clip(mixed, 0, 255., out=mixed)
762
+ mixed = Image.fromarray(mixed.astype(np.uint8))
763
+ return Image.blend(img, mixed, m)
764
+
765
+ def __call__(self, img):
766
+ mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
767
+ m = np.float32(np.random.beta(self.alpha, self.alpha))
768
+ if self.blended:
769
+ mixed = self._apply_blended(img, mixing_weights, m)
770
+ else:
771
+ mixed = self._apply_basic(img, mixing_weights, m)
772
+ return mixed
773
+
774
+
775
+ def augment_and_mix_transform(config_str, hparams):
776
+ """ Create AugMix PyTorch transform
777
+
778
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
779
+     dashes ('-'). The first section defines the specific variant of AugMix (currently only 'augmix'). The remaining
780
+     sections, which are not order specific, determine
781
+ 'm' - integer magnitude (severity) of augmentation mix (default: 3)
782
+ 'w' - integer width of augmentation chain (default: 3)
783
+ 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
784
+ 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
785
+ 'mstd' - float std deviation of magnitude noise applied (default: 0)
786
+ Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
787
+
788
+ :param hparams: Other hparams (kwargs) for the Augmentation transforms
789
+
790
+ :return: A PyTorch compatible Transform
791
+ """
792
+ magnitude = 3
793
+ width = 3
794
+ depth = -1
795
+ alpha = 1.
796
+ blended = False
797
+ hparams['magnitude_std'] = float('inf')
798
+ config = config_str.split('-')
799
+ assert config[0] == 'augmix'
800
+ config = config[1:]
801
+ for c in config:
802
+ cs = re.split(r'(\d.*)', c)
803
+ if len(cs) < 2:
804
+ continue
805
+ key, val = cs[:2]
806
+ if key == 'mstd':
807
+ # noise param injected via hparams for now
808
+ hparams.setdefault('magnitude_std', float(val))
809
+ elif key == 'm':
810
+ magnitude = int(val)
811
+ elif key == 'w':
812
+ width = int(val)
813
+ elif key == 'd':
814
+ depth = int(val)
815
+ elif key == 'a':
816
+ alpha = float(val)
817
+ elif key == 'b':
818
+ blended = bool(val)
819
+ else:
820
+ assert False, 'Unknown AugMix config section'
821
+ ops = augmix_ops(magnitude=magnitude, hparams=hparams)
822
+ return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
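
The three factory functions in this file (auto_augment_transform, rand_augment_transform, augment_and_mix_transform) all take a dash-separated config string plus an hparams dict. A minimal usage sketch, with an illustrative hparams dict and a placeholder image path (translate_const and img_mean mirror the kind of values the ops above expect):

    from PIL import Image
    from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform

    # hparams feed the fill colour / translate constant into the individual ops
    aa_params = dict(
        translate_const=100,        # ~0.45 * target image size is the usual heuristic
        img_mean=(128, 128, 128),   # fill colour used by the geometric ops
    )
    ra = rand_augment_transform('rand-m9-n2-mstd0.5-inc1', aa_params)  # RandAugment, magnitude 9
    am = augment_and_mix_transform('augmix-m5-w4-d2', aa_params)       # AugMix, severity 5

    img = Image.open('example.jpg').convert('RGB')   # placeholder path
    img = ra(img)                                    # returns an augmented PIL image
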
timm/data/config.py ADDED
@@ -0,0 +1,78 @@
1
+ import logging
2
+ from .constants import *
3
+
4
+
5
+ _logger = logging.getLogger(__name__)
6
+
7
+
8
+ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
9
+ new_config = {}
10
+ default_cfg = default_cfg
11
+ if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
12
+ default_cfg = model.default_cfg
13
+
14
+ # Resolve input/image size
15
+ in_chans = 3
16
+ if 'chans' in args and args['chans'] is not None:
17
+ in_chans = args['chans']
18
+
19
+ input_size = (in_chans, 224, 224)
20
+ if 'input_size' in args and args['input_size'] is not None:
21
+ assert isinstance(args['input_size'], (tuple, list))
22
+ assert len(args['input_size']) == 3
23
+ input_size = tuple(args['input_size'])
24
+ in_chans = input_size[0] # input_size overrides in_chans
25
+ elif 'img_size' in args and args['img_size'] is not None:
26
+ assert isinstance(args['img_size'], int)
27
+ input_size = (in_chans, args['img_size'], args['img_size'])
28
+ else:
29
+ if use_test_size and 'test_input_size' in default_cfg:
30
+ input_size = default_cfg['test_input_size']
31
+ elif 'input_size' in default_cfg:
32
+ input_size = default_cfg['input_size']
33
+ new_config['input_size'] = input_size
34
+
35
+ # resolve interpolation method
36
+ new_config['interpolation'] = 'bicubic'
37
+ if 'interpolation' in args and args['interpolation']:
38
+ new_config['interpolation'] = args['interpolation']
39
+ elif 'interpolation' in default_cfg:
40
+ new_config['interpolation'] = default_cfg['interpolation']
41
+
42
+ # resolve dataset + model mean for normalization
43
+ new_config['mean'] = IMAGENET_DEFAULT_MEAN
44
+ if 'mean' in args and args['mean'] is not None:
45
+ mean = tuple(args['mean'])
46
+ if len(mean) == 1:
47
+ mean = tuple(list(mean) * in_chans)
48
+ else:
49
+ assert len(mean) == in_chans
50
+ new_config['mean'] = mean
51
+ elif 'mean' in default_cfg:
52
+ new_config['mean'] = default_cfg['mean']
53
+
54
+ # resolve dataset + model std deviation for normalization
55
+ new_config['std'] = IMAGENET_DEFAULT_STD
56
+ if 'std' in args and args['std'] is not None:
57
+ std = tuple(args['std'])
58
+ if len(std) == 1:
59
+ std = tuple(list(std) * in_chans)
60
+ else:
61
+ assert len(std) == in_chans
62
+ new_config['std'] = std
63
+ elif 'std' in default_cfg:
64
+ new_config['std'] = default_cfg['std']
65
+
66
+ # resolve default crop percentage
67
+ new_config['crop_pct'] = DEFAULT_CROP_PCT
68
+ if 'crop_pct' in args and args['crop_pct'] is not None:
69
+ new_config['crop_pct'] = args['crop_pct']
70
+ elif 'crop_pct' in default_cfg:
71
+ new_config['crop_pct'] = default_cfg['crop_pct']
72
+
73
+ if verbose:
74
+ _logger.info('Data processing configuration for current model + dataset:')
75
+ for n, v in new_config.items():
76
+ _logger.info('\t%s: %s' % (n, str(v)))
77
+
78
+ return new_config
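
resolve_data_config above merges explicit arguments with a model's default_cfg and otherwise falls back to the ImageNet constants. A small sketch with illustrative values (no model object is needed when the args dict is explicit):

    from timm.data.config import resolve_data_config

    # explicit args win; a 1-element mean/std is broadcast across the input channels
    cfg = resolve_data_config({'img_size': 288, 'mean': (0.5,), 'std': (0.5,), 'crop_pct': 0.95})
    # -> {'input_size': (3, 288, 288), 'interpolation': 'bicubic',
    #     'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'crop_pct': 0.95}

    # with an empty args dict, values come from default_cfg (or the ImageNet defaults)
    cfg = resolve_data_config({}, default_cfg={'input_size': (3, 224, 224), 'interpolation': 'bilinear'})
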
timm/data/constants.py ADDED
@@ -0,0 +1,7 @@
1
+ DEFAULT_CROP_PCT = 0.875
2
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
3
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
4
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
5
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
6
+ IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
7
+ IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
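
These constants are the per-channel statistics the loaders below use for normalization; a minimal sketch of the same normalization applied to a [0, 1] image tensor:

    import torch
    from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    img = torch.rand(3, 224, 224)                             # stand-in image in [0, 1]
    mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_DEFAULT_STD).view(3, 1, 1)
    img = (img - mean) / std                                  # same maths PrefetchLoader applies at 0-255 scale
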
timm/data/dataset.py ADDED
@@ -0,0 +1,146 @@
1
+ """ Quick n Simple Image Folder, Tarfile based DataSet
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import torch.utils.data as data
6
+ import os
7
+ import torch
8
+ import logging
9
+
10
+ from PIL import Image
11
+
12
+ from .parsers import create_parser
13
+
14
+ _logger = logging.getLogger(__name__)
15
+
16
+
17
+ _ERROR_RETRY = 50
18
+
19
+
20
+ class ImageDataset(data.Dataset):
21
+
22
+ def __init__(
23
+ self,
24
+ root,
25
+ parser=None,
26
+ class_map='',
27
+ load_bytes=False,
28
+ transform=None,
29
+ ):
30
+ if parser is None or isinstance(parser, str):
31
+ parser = create_parser(parser or '', root=root, class_map=class_map)
32
+ self.parser = parser
33
+ self.load_bytes = load_bytes
34
+ self.transform = transform
35
+ self._consecutive_errors = 0
36
+
37
+ def __getitem__(self, index):
38
+ img, target = self.parser[index]
39
+ try:
40
+ img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
41
+ except Exception as e:
42
+ _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
43
+ self._consecutive_errors += 1
44
+ if self._consecutive_errors < _ERROR_RETRY:
45
+ return self.__getitem__((index + 1) % len(self.parser))
46
+ else:
47
+ raise e
48
+ self._consecutive_errors = 0
49
+ if self.transform is not None:
50
+ img = self.transform(img)
51
+ if target is None:
52
+ target = torch.tensor(-1, dtype=torch.long)
53
+ return img, target
54
+
55
+ def __len__(self):
56
+ return len(self.parser)
57
+
58
+ def filename(self, index, basename=False, absolute=False):
59
+ return self.parser.filename(index, basename, absolute)
60
+
61
+ def filenames(self, basename=False, absolute=False):
62
+ return self.parser.filenames(basename, absolute)
63
+
64
+
65
+ class IterableImageDataset(data.IterableDataset):
66
+
67
+ def __init__(
68
+ self,
69
+ root,
70
+ parser=None,
71
+ split='train',
72
+ is_training=False,
73
+ batch_size=None,
74
+ class_map='',
75
+ load_bytes=False,
76
+ repeats=0,
77
+ transform=None,
78
+ ):
79
+ assert parser is not None
80
+ if isinstance(parser, str):
81
+ self.parser = create_parser(
82
+ parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
83
+ else:
84
+ self.parser = parser
85
+ self.transform = transform
86
+ self._consecutive_errors = 0
87
+
88
+ def __iter__(self):
89
+ for img, target in self.parser:
90
+ if self.transform is not None:
91
+ img = self.transform(img)
92
+ if target is None:
93
+ target = torch.tensor(-1, dtype=torch.long)
94
+ yield img, target
95
+
96
+ def __len__(self):
97
+ if hasattr(self.parser, '__len__'):
98
+ return len(self.parser)
99
+ else:
100
+ return 0
101
+
102
+ def filename(self, index, basename=False, absolute=False):
103
+ assert False, 'Filename lookup by index not supported, use filenames().'
104
+
105
+ def filenames(self, basename=False, absolute=False):
106
+ return self.parser.filenames(basename, absolute)
107
+
108
+
109
+ class AugMixDataset(torch.utils.data.Dataset):
110
+ """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
111
+
112
+ def __init__(self, dataset, num_splits=2):
113
+ self.augmentation = None
114
+ self.normalize = None
115
+ self.dataset = dataset
116
+ if self.dataset.transform is not None:
117
+ self._set_transforms(self.dataset.transform)
118
+ self.num_splits = num_splits
119
+
120
+ def _set_transforms(self, x):
121
+ assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
122
+ self.dataset.transform = x[0]
123
+ self.augmentation = x[1]
124
+ self.normalize = x[2]
125
+
126
+ @property
127
+ def transform(self):
128
+ return self.dataset.transform
129
+
130
+ @transform.setter
131
+ def transform(self, x):
132
+ self._set_transforms(x)
133
+
134
+ def _normalize(self, x):
135
+ return x if self.normalize is None else self.normalize(x)
136
+
137
+ def __getitem__(self, i):
138
+ x, y = self.dataset[i] # all splits share the same dataset base transform
139
+ x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
140
+ # run the full augmentation on the remaining splits
141
+ for _ in range(self.num_splits - 1):
142
+ x_list.append(self._normalize(self.augmentation(x)))
143
+ return tuple(x_list), y
144
+
145
+ def __len__(self):
146
+ return len(self.dataset)
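
A usage sketch for the two dataset flavours, assuming an ImageFolder-style directory of class sub-folders at a placeholder path and torchvision for the transforms. AugMixDataset expects its transform to be set as a 3-tuple of (shared base transform, per-split augmentation, final normalize step):

    from torchvision import transforms
    from timm.data.dataset import ImageDataset, AugMixDataset

    ds = ImageDataset('/path/to/train', transform=transforms.ToTensor())   # placeholder path
    img, target = ds[0]

    aug_ds = AugMixDataset(ImageDataset('/path/to/train'), num_splits=2)
    aug_ds.transform = (
        transforms.RandomResizedCrop(224),       # base transform shared by every split
        transforms.ColorJitter(0.4, 0.4, 0.4),   # extra augmentation for the non-clean splits
        transforms.ToTensor(),                   # 'normalize' step applied to all splits
    )
    (clean, augmented), target = aug_ds[0]
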
timm/data/dataset_factory.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+
3
+ from .dataset import IterableImageDataset, ImageDataset
4
+
5
+
6
+ def _search_split(root, split):
7
+ # look for sub-folder with name of split in root and use that if it exists
8
+ split_name = split.split('[')[0]
9
+ try_root = os.path.join(root, split_name)
10
+ if os.path.exists(try_root):
11
+ return try_root
12
+ if split_name == 'validation':
13
+ try_root = os.path.join(root, 'val')
14
+ if os.path.exists(try_root):
15
+ return try_root
16
+ return root
17
+
18
+
19
+ def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
20
+ name = name.lower()
21
+ if name.startswith('tfds'):
22
+ ds = IterableImageDataset(
23
+ root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
24
+ else:
25
+ # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
26
+ kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier
27
+ if search_split and os.path.isdir(root):
28
+ root = _search_split(root, split)
29
+ ds = ImageDataset(root, parser=name, **kwargs)
30
+ return ds
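
A sketch of the factory with placeholder paths: an empty name falls through to the folder/tar ImageDataset branch, while a name starting with 'tfds' would build an IterableImageDataset instead:

    from timm.data.dataset_factory import create_dataset

    train_ds = create_dataset('', root='/data/imagenet', split='train', is_training=True)
    val_ds = create_dataset('', root='/data/imagenet', split='validation')  # _search_split also tries 'val'
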
timm/data/distributed_sampler.py ADDED
@@ -0,0 +1,51 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data import Sampler
4
+ import torch.distributed as dist
5
+
6
+
7
+ class OrderedDistributedSampler(Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset.
9
+ It is especially useful in conjunction with
10
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
11
+ process can pass a DistributedSampler instance as a DataLoader sampler,
12
+ and load a subset of the original dataset that is exclusive to it.
13
+ .. note::
14
+ Dataset is assumed to be of constant size.
15
+ Arguments:
16
+ dataset: Dataset used for sampling.
17
+ num_replicas (optional): Number of processes participating in
18
+ distributed training.
19
+ rank (optional): Rank of the current process within num_replicas.
20
+ """
21
+
22
+ def __init__(self, dataset, num_replicas=None, rank=None):
23
+ if num_replicas is None:
24
+ if not dist.is_available():
25
+ raise RuntimeError("Requires distributed package to be available")
26
+ num_replicas = dist.get_world_size()
27
+ if rank is None:
28
+ if not dist.is_available():
29
+ raise RuntimeError("Requires distributed package to be available")
30
+ rank = dist.get_rank()
31
+ self.dataset = dataset
32
+ self.num_replicas = num_replicas
33
+ self.rank = rank
34
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35
+ self.total_size = self.num_samples * self.num_replicas
36
+
37
+ def __iter__(self):
38
+ indices = list(range(len(self.dataset)))
39
+
40
+ # add extra samples to make it evenly divisible
41
+ indices += indices[:(self.total_size - len(indices))]
42
+ assert len(indices) == self.total_size
43
+
44
+ # subsample
45
+ indices = indices[self.rank:self.total_size:self.num_replicas]
46
+ assert len(indices) == self.num_samples
47
+
48
+ return iter(indices)
49
+
50
+ def __len__(self):
51
+ return self.num_samples
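
The padding-then-striding logic is easiest to see on a toy example with num_replicas and rank passed explicitly, so no distributed initialization is needed:

    from timm.data.distributed_sampler import OrderedDistributedSampler

    dataset = list(range(10))                                      # stand-in dataset of length 10
    sampler = OrderedDistributedSampler(dataset, num_replicas=4, rank=1)
    print(list(sampler))   # [1, 5, 9] -> indices padded to 12, then every 4th index starting at rank 1
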
timm/data/loader.py ADDED
@@ -0,0 +1,262 @@
1
+ """ Loader Factory, Fast Collate, CUDA Prefetcher
2
+
3
+ Prefetcher and Fast Collate inspired by NVIDIA APEX example at
4
+ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
5
+
6
+ Hacked together by / Copyright 2020 Ross Wightman
7
+ """
8
+
9
+ import torch.utils.data
10
+ import numpy as np
11
+
12
+ from .transforms_factory import create_transform
13
+ from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14
+ from .distributed_sampler import OrderedDistributedSampler
15
+ from .random_erasing import RandomErasing
16
+ from .mixup import FastCollateMixup
17
+
18
+
19
+ def fast_collate(batch):
20
+ """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
21
+ assert isinstance(batch[0], tuple)
22
+ batch_size = len(batch)
23
+ if isinstance(batch[0][0], tuple):
24
+ # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
25
+         # such that all tuples at position n will end up in the nth position of a torch.split(tensor, batch_size)
26
+ inner_tuple_size = len(batch[0][0])
27
+ flattened_batch_size = batch_size * inner_tuple_size
28
+ targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
29
+ tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
30
+ for i in range(batch_size):
31
+ assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
32
+ for j in range(inner_tuple_size):
33
+ targets[i + j * batch_size] = batch[i][1]
34
+ tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
35
+ return tensor, targets
36
+ elif isinstance(batch[0][0], np.ndarray):
37
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
38
+ assert len(targets) == batch_size
39
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
40
+ for i in range(batch_size):
41
+ tensor[i] += torch.from_numpy(batch[i][0])
42
+ return tensor, targets
43
+ elif isinstance(batch[0][0], torch.Tensor):
44
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
45
+ assert len(targets) == batch_size
46
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
47
+ for i in range(batch_size):
48
+ tensor[i].copy_(batch[i][0])
49
+ return tensor, targets
50
+ else:
51
+ assert False
52
+
53
+
54
+ class PrefetchLoader:
55
+
56
+ def __init__(self,
57
+ loader,
58
+ mean=IMAGENET_DEFAULT_MEAN,
59
+ std=IMAGENET_DEFAULT_STD,
60
+ fp16=False,
61
+ re_prob=0.,
62
+ re_mode='const',
63
+ re_count=1,
64
+ re_num_splits=0):
65
+ self.loader = loader
66
+ self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
67
+ self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
68
+ self.fp16 = fp16
69
+ if fp16:
70
+ self.mean = self.mean.half()
71
+ self.std = self.std.half()
72
+ if re_prob > 0.:
73
+ self.random_erasing = RandomErasing(
74
+ probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
75
+ else:
76
+ self.random_erasing = None
77
+
78
+ def __iter__(self):
79
+ stream = torch.cuda.Stream()
80
+ first = True
81
+
82
+ for next_input, next_target in self.loader:
83
+ with torch.cuda.stream(stream):
84
+ next_input = next_input.cuda(non_blocking=True)
85
+ next_target = next_target.cuda(non_blocking=True)
86
+ if self.fp16:
87
+ next_input = next_input.half().sub_(self.mean).div_(self.std)
88
+ else:
89
+ next_input = next_input.float().sub_(self.mean).div_(self.std)
90
+ if self.random_erasing is not None:
91
+ next_input = self.random_erasing(next_input)
92
+
93
+ if not first:
94
+ yield input, target
95
+ else:
96
+ first = False
97
+
98
+ torch.cuda.current_stream().wait_stream(stream)
99
+ input = next_input
100
+ target = next_target
101
+
102
+ yield input, target
103
+
104
+ def __len__(self):
105
+ return len(self.loader)
106
+
107
+ @property
108
+ def sampler(self):
109
+ return self.loader.sampler
110
+
111
+ @property
112
+ def dataset(self):
113
+ return self.loader.dataset
114
+
115
+ @property
116
+ def mixup_enabled(self):
117
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
118
+ return self.loader.collate_fn.mixup_enabled
119
+ else:
120
+ return False
121
+
122
+ @mixup_enabled.setter
123
+ def mixup_enabled(self, x):
124
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
125
+ self.loader.collate_fn.mixup_enabled = x
126
+
127
+
128
+ def create_loader(
129
+ dataset,
130
+ input_size,
131
+ batch_size,
132
+ is_training=False,
133
+ use_prefetcher=True,
134
+ no_aug=False,
135
+ re_prob=0.,
136
+ re_mode='const',
137
+ re_count=1,
138
+ re_split=False,
139
+ scale=None,
140
+ ratio=None,
141
+ hflip=0.5,
142
+ vflip=0.,
143
+ color_jitter=0.4,
144
+ auto_augment=None,
145
+ num_aug_splits=0,
146
+ interpolation='bilinear',
147
+ mean=IMAGENET_DEFAULT_MEAN,
148
+ std=IMAGENET_DEFAULT_STD,
149
+ num_workers=1,
150
+ distributed=False,
151
+ crop_pct=None,
152
+ collate_fn=None,
153
+ pin_memory=False,
154
+ fp16=False,
155
+ tf_preprocessing=False,
156
+ use_multi_epochs_loader=False,
157
+ persistent_workers=True,
158
+ ):
159
+ re_num_splits = 0
160
+ if re_split:
161
+ # apply RE to second half of batch if no aug split otherwise line up with aug split
162
+ re_num_splits = num_aug_splits or 2
163
+ dataset.transform = create_transform(
164
+ input_size,
165
+ is_training=is_training,
166
+ use_prefetcher=use_prefetcher,
167
+ no_aug=no_aug,
168
+ scale=scale,
169
+ ratio=ratio,
170
+ hflip=hflip,
171
+ vflip=vflip,
172
+ color_jitter=color_jitter,
173
+ auto_augment=auto_augment,
174
+ interpolation=interpolation,
175
+ mean=mean,
176
+ std=std,
177
+ crop_pct=crop_pct,
178
+ tf_preprocessing=tf_preprocessing,
179
+ re_prob=re_prob,
180
+ re_mode=re_mode,
181
+ re_count=re_count,
182
+ re_num_splits=re_num_splits,
183
+ separate=num_aug_splits > 0,
184
+ )
185
+
186
+ sampler = None
187
+ if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
188
+ if is_training:
189
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
190
+ else:
191
+ # This will add extra duplicate entries to result in equal num
192
+ # of samples per-process, will slightly alter validation results
193
+ sampler = OrderedDistributedSampler(dataset)
194
+
195
+ if collate_fn is None:
196
+ collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
197
+
198
+ loader_class = torch.utils.data.DataLoader
199
+
200
+ if use_multi_epochs_loader:
201
+ loader_class = MultiEpochsDataLoader
202
+
203
+ loader_args = dict(
204
+ batch_size=batch_size,
205
+ shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
206
+ num_workers=num_workers,
207
+ sampler=sampler,
208
+ collate_fn=collate_fn,
209
+ pin_memory=pin_memory,
210
+ drop_last=is_training,
211
+ persistent_workers=persistent_workers)
212
+ try:
213
+ loader = loader_class(dataset, **loader_args)
214
+ except TypeError as e:
215
+ loader_args.pop('persistent_workers') # only in Pytorch 1.7+
216
+ loader = loader_class(dataset, **loader_args)
217
+ if use_prefetcher:
218
+ prefetch_re_prob = re_prob if is_training and not no_aug else 0.
219
+ loader = PrefetchLoader(
220
+ loader,
221
+ mean=mean,
222
+ std=std,
223
+ fp16=fp16,
224
+ re_prob=prefetch_re_prob,
225
+ re_mode=re_mode,
226
+ re_count=re_count,
227
+ re_num_splits=re_num_splits
228
+ )
229
+
230
+ return loader
231
+
232
+
233
+ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
234
+
235
+ def __init__(self, *args, **kwargs):
236
+ super().__init__(*args, **kwargs)
237
+ self._DataLoader__initialized = False
238
+ self.batch_sampler = _RepeatSampler(self.batch_sampler)
239
+ self._DataLoader__initialized = True
240
+ self.iterator = super().__iter__()
241
+
242
+ def __len__(self):
243
+ return len(self.batch_sampler.sampler)
244
+
245
+ def __iter__(self):
246
+ for i in range(len(self)):
247
+ yield next(self.iterator)
248
+
249
+
250
+ class _RepeatSampler(object):
251
+ """ Sampler that repeats forever.
252
+
253
+ Args:
254
+ sampler (Sampler)
255
+ """
256
+
257
+ def __init__(self, sampler):
258
+ self.sampler = sampler
259
+
260
+ def __iter__(self):
261
+ while True:
262
+ yield from iter(self.sampler)
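
A minimal sketch wiring the pieces of this file together, with a placeholder dataset path; use_prefetcher=True assumes a CUDA device is available since PrefetchLoader normalizes on the GPU:

    from timm.data.dataset import ImageDataset
    from timm.data.loader import create_loader

    dataset = ImageDataset('/path/to/train')        # placeholder path
    loader = create_loader(
        dataset,
        input_size=(3, 224, 224),
        batch_size=32,
        is_training=True,
        use_prefetcher=True,                        # fast_collate + PrefetchLoader (needs CUDA)
        auto_augment='rand-m9-mstd0.5',
        interpolation='bicubic',
        re_prob=0.25,                               # random erasing, applied inside the prefetcher
        num_workers=4,
        distributed=False,
    )
    for images, targets in loader:                  # images arrive normalized, on GPU
        break
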
timm/data/mixup.py ADDED
@@ -0,0 +1,316 @@
1
+ """ Mixup and Cutmix
2
+
3
+ Papers:
4
+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5
+
6
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7
+
8
+ Code Reference:
9
+ CutMix: https://github.com/clovaai/CutMix-PyTorch
10
+
11
+ Hacked together by / Copyright 2020 Ross Wightman
12
+ """
13
+ import numpy as np
14
+ import torch
15
+
16
+
17
+ def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18
+ x = x.long().view(-1, 1)
19
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20
+
21
+
22
+ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23
+ off_value = smoothing / num_classes
24
+ on_value = 1. - smoothing + off_value
25
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27
+ return y1 * lam + y2 * (1. - lam)
28
+
29
+
30
+ def rand_bbox(img_shape, lam, margin=0., count=None):
31
+ """ Standard CutMix bounding-box
32
+ Generates a random square bbox based on lambda value. This impl includes
33
+ support for enforcing a border margin as percent of bbox dimensions.
34
+
35
+ Args:
36
+ img_shape (tuple): Image shape as tuple
37
+ lam (float): Cutmix lambda value
38
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39
+ count (int): Number of bbox to generate
40
+ """
41
+ ratio = np.sqrt(1 - lam)
42
+ img_h, img_w = img_shape[-2:]
43
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
48
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
49
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
50
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
51
+ return yl, yh, xl, xh
52
+
53
+
54
+ def rand_bbox_minmax(img_shape, minmax, count=None):
55
+ """ Min-Max CutMix bounding-box
56
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
57
+ based on min/max percent values applied to each dimension of the input image.
58
+
59
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60
+
61
+ Args:
62
+ img_shape (tuple): Image shape as tuple
63
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64
+ count (int): Number of bbox to generate
65
+ """
66
+ assert len(minmax) == 2
67
+ img_h, img_w = img_shape[-2:]
68
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70
+ yl = np.random.randint(0, img_h - cut_h, size=count)
71
+ xl = np.random.randint(0, img_w - cut_w, size=count)
72
+ yu = yl + cut_h
73
+ xu = xl + cut_w
74
+ return yl, yu, xl, xu
75
+
76
+
77
+ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78
+ """ Generate bbox and apply lambda correction.
79
+ """
80
+ if ratio_minmax is not None:
81
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82
+ else:
83
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84
+ if correct_lam or ratio_minmax is not None:
85
+ bbox_area = (yu - yl) * (xu - xl)
86
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87
+ return (yl, yu, xl, xu), lam
88
+
89
+
90
+ class Mixup:
91
+ """ Mixup/Cutmix that applies different params to each element or whole batch
92
+
93
+ Args:
94
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97
+ prob (float): probability of applying mixup or cutmix per batch or element
98
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
100
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101
+ label_smoothing (float): apply label smoothing to the mixed target tensor
102
+ num_classes (int): number of classes for target
103
+ """
104
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106
+ self.mixup_alpha = mixup_alpha
107
+ self.cutmix_alpha = cutmix_alpha
108
+ self.cutmix_minmax = cutmix_minmax
109
+ if self.cutmix_minmax is not None:
110
+ assert len(self.cutmix_minmax) == 2
111
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112
+ self.cutmix_alpha = 1.0
113
+ self.mix_prob = prob
114
+ self.switch_prob = switch_prob
115
+ self.label_smoothing = label_smoothing
116
+ self.num_classes = num_classes
117
+ self.mode = mode
118
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119
+ self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
120
+
121
+ def _params_per_elem(self, batch_size):
122
+ lam = np.ones(batch_size, dtype=np.float32)
123
+ use_cutmix = np.zeros(batch_size, dtype=np.bool)
124
+ if self.mixup_enabled:
125
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
127
+ lam_mix = np.where(
128
+ use_cutmix,
129
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131
+ elif self.mixup_alpha > 0.:
132
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133
+ elif self.cutmix_alpha > 0.:
134
+ use_cutmix = np.ones(batch_size, dtype=np.bool)
135
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136
+ else:
137
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139
+ return lam, use_cutmix
140
+
141
+ def _params_per_batch(self):
142
+ lam = 1.
143
+ use_cutmix = False
144
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
145
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146
+ use_cutmix = np.random.rand() < self.switch_prob
147
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
149
+ elif self.mixup_alpha > 0.:
150
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151
+ elif self.cutmix_alpha > 0.:
152
+ use_cutmix = True
153
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154
+ else:
155
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156
+ lam = float(lam_mix)
157
+ return lam, use_cutmix
158
+
159
+ def _mix_elem(self, x):
160
+ batch_size = len(x)
161
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
162
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
163
+ for i in range(batch_size):
164
+ j = batch_size - i - 1
165
+ lam = lam_batch[i]
166
+ if lam != 1.:
167
+ if use_cutmix[i]:
168
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
171
+ lam_batch[i] = lam
172
+ else:
173
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175
+
176
+ def _mix_pair(self, x):
177
+ batch_size = len(x)
178
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
180
+ for i in range(batch_size // 2):
181
+ j = batch_size - i - 1
182
+ lam = lam_batch[i]
183
+ if lam != 1.:
184
+ if use_cutmix[i]:
185
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189
+ lam_batch[i] = lam
190
+ else:
191
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195
+
196
+ def _mix_batch(self, x):
197
+ lam, use_cutmix = self._params_per_batch()
198
+ if lam == 1.:
199
+ return 1.
200
+ if use_cutmix:
201
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203
+ x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
204
+ else:
205
+ x_flipped = x.flip(0).mul_(1. - lam)
206
+ x.mul_(lam).add_(x_flipped)
207
+ return lam
208
+
209
+ def __call__(self, x, target):
210
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
211
+ if self.mode == 'elem':
212
+ lam = self._mix_elem(x)
213
+ elif self.mode == 'pair':
214
+ lam = self._mix_pair(x)
215
+ else:
216
+ lam = self._mix_batch(x)
217
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
218
+ return x, target
219
+
220
+
221
+ class FastCollateMixup(Mixup):
222
+ """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223
+
224
+ A Mixup impl that's performed while collating the batches.
225
+ """
226
+
227
+ def _mix_elem_collate(self, output, batch, half=False):
228
+ batch_size = len(batch)
229
+ num_elem = batch_size // 2 if half else batch_size
230
+ assert len(output) == num_elem
231
+ lam_batch, use_cutmix = self._params_per_elem(num_elem)
232
+ for i in range(num_elem):
233
+ j = batch_size - i - 1
234
+ lam = lam_batch[i]
235
+ mixed = batch[i][0]
236
+ if lam != 1.:
237
+ if use_cutmix[i]:
238
+ if not half:
239
+ mixed = mixed.copy()
240
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243
+ lam_batch[i] = lam
244
+ else:
245
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246
+ np.rint(mixed, out=mixed)
247
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
248
+ if half:
249
+ lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250
+ return torch.tensor(lam_batch).unsqueeze(1)
251
+
252
+ def _mix_pair_collate(self, output, batch):
253
+ batch_size = len(batch)
254
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255
+ for i in range(batch_size // 2):
256
+ j = batch_size - i - 1
257
+ lam = lam_batch[i]
258
+ mixed_i = batch[i][0]
259
+ mixed_j = batch[j][0]
260
+ assert 0 <= lam <= 1.0
261
+ if lam < 1.:
262
+ if use_cutmix[i]:
263
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265
+ patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266
+ mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267
+ mixed_j[:, yl:yh, xl:xh] = patch_i
268
+ lam_batch[i] = lam
269
+ else:
270
+ mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271
+ mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272
+ mixed_i = mixed_temp
273
+ np.rint(mixed_j, out=mixed_j)
274
+ np.rint(mixed_i, out=mixed_i)
275
+ output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276
+ output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278
+ return torch.tensor(lam_batch).unsqueeze(1)
279
+
280
+ def _mix_batch_collate(self, output, batch):
281
+ batch_size = len(batch)
282
+ lam, use_cutmix = self._params_per_batch()
283
+ if use_cutmix:
284
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286
+ for i in range(batch_size):
287
+ j = batch_size - i - 1
288
+ mixed = batch[i][0]
289
+ if lam != 1.:
290
+ if use_cutmix:
291
+ mixed = mixed.copy() # don't want to modify the original while iterating
292
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
293
+ else:
294
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295
+ np.rint(mixed, out=mixed)
296
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
297
+ return lam
298
+
299
+ def __call__(self, batch, _=None):
300
+ batch_size = len(batch)
301
+ assert batch_size % 2 == 0, 'Batch size should be even when using this'
302
+ half = 'half' in self.mode
303
+ if half:
304
+ batch_size //= 2
305
+ output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306
+ if self.mode == 'elem' or self.mode == 'half':
307
+ lam = self._mix_elem_collate(output, batch, half=half)
308
+ elif self.mode == 'pair':
309
+ lam = self._mix_pair_collate(output, batch)
310
+ else:
311
+ lam = self._mix_batch_collate(output, batch)
312
+ target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314
+ target = target[:batch_size]
315
+ return output, target
316
+
timm/data/parsers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .parser_factory import create_parser
timm/data/parsers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (200 Bytes).