Update README.md
README.md CHANGED

@@ -248,4 +248,78 @@ The following hyperparameters were used during training:
- Transformers 4.35.2
- Pytorch 2.1.0+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0

### Example of usage

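The snippet below loads the processor and the model, builds train/validation/test splits from `Andron00e/CIFAR100-custom`, and trains the classifier with the Hugging Face `Trainer`. The `classes` list and the `compute_metrics` function shown here are minimal placeholders (class names taken from the dataset's label feature, plain accuracy as the metric); adjust them to your setup.
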
```python
import torch
import numpy as np

from datasets import load_dataset, DatasetDict
from transformers import CLIPProcessor, AutoModelForImageClassification
from transformers import TrainingArguments, Trainer

processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
model = AutoModelForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1")

# Build 80/10/10 train/validation/test splits from the custom CIFAR-100 set.
dataset = load_dataset("Andron00e/CIFAR100-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})

# Class names for the text side of the processor; here they are assumed to be
# stored in the dataset's label feature.
classes = dataset["train"].features["labels"].names

def transform(example_batch):
    # Pair every image with the text of its class name.
    inputs = processor(
        text=[classes[x] for x in example_batch["labels"]],
        images=[x for x in example_batch["image"]],
        padding=True,
        return_tensors="pt",
    )
    inputs["labels"] = example_batch["labels"]
    return inputs

def collate_fn(batch):
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }

# A plain accuracy metric, used as a placeholder for compute_metrics.
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=False,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="tensorboard",
    load_best_model_at_end=True,
)

# Apply the processor lazily to every split.
processed_dataset = dataset.with_transform(transform)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(processed_dataset["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
```
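
With the arguments above, checkpoints and metrics are written to `./outputs`, training is logged to TensorBoard, and the best checkpoint is reloaded at the end of training. As a minimal follow-up sketch (not part of the training script; it reuses `trainer`, `processed_dataset`, and `classes` from above and assumes the classification logits are the first array in the prediction output), per-image predictions on the test split can be obtained with `Trainer.predict`:

```python
# Per-image predictions on the held-out test split.
outputs = trainer.predict(processed_dataset["test"])

# `outputs.predictions` holds the logits; if the model returns several arrays,
# the first entry is assumed to be the classification logits.
logits = outputs.predictions[0] if isinstance(outputs.predictions, tuple) else outputs.predictions
predicted_ids = logits.argmax(axis=-1)
print([classes[i] for i in predicted_ids[:5]])
```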