Spaces:
Running
Running
Upload 5 files
Browse files- IA3.ipynb +0 -0
- LoRA.ipynb +713 -0
- P_Tuning.ipynb +685 -0
- Prompt_Tuning.ipynb +692 -0
- prefix_tuning.ipynb +710 -0
IA3.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LoRA.ipynb
ADDED
|
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "a9935ae2",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"\n",
|
| 14 |
+
"===================================BUG REPORT===================================\n",
|
| 15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
| 16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
| 17 |
+
"================================================================================\n",
|
| 18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
| 19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
| 20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
| 21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"import argparse\n",
|
| 27 |
+
"import os\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import torch\n",
|
| 30 |
+
"from torch.optim import AdamW\n",
|
| 31 |
+
"from torch.utils.data import DataLoader\n",
|
| 32 |
+
"from peft import (\n",
|
| 33 |
+
" get_peft_config,\n",
|
| 34 |
+
" get_peft_model,\n",
|
| 35 |
+
" get_peft_model_state_dict,\n",
|
| 36 |
+
" set_peft_model_state_dict,\n",
|
| 37 |
+
" LoraConfig,\n",
|
| 38 |
+
" PeftType,\n",
|
| 39 |
+
" PrefixTuningConfig,\n",
|
| 40 |
+
" PromptEncoderConfig,\n",
|
| 41 |
+
")\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"import evaluate\n",
|
| 44 |
+
"from datasets import load_dataset\n",
|
| 45 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
| 46 |
+
"from tqdm import tqdm"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 2,
|
| 52 |
+
"id": "e3b13308",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"batch_size = 32\n",
|
| 57 |
+
"model_name_or_path = \"roberta-large\"\n",
|
| 58 |
+
"task = \"mrpc\"\n",
|
| 59 |
+
"peft_type = PeftType.LORA\n",
|
| 60 |
+
"device = \"cuda\"\n",
|
| 61 |
+
"num_epochs = 20"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 3,
|
| 67 |
+
"id": "0526f571",
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [],
|
| 70 |
+
"source": [
|
| 71 |
+
"peft_config = LoraConfig(task_type=\"SEQ_CLS\", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)\n",
|
| 72 |
+
"lr = 3e-4"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "code",
|
| 77 |
+
"execution_count": 4,
|
| 78 |
+
"id": "c2697d07",
|
| 79 |
+
"metadata": {},
|
| 80 |
+
"outputs": [
|
| 81 |
+
{
|
| 82 |
+
"data": {
|
| 83 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 84 |
+
"model_id": "0f74797387a941cbb0709487b8808eba",
|
| 85 |
+
"version_major": 2,
|
| 86 |
+
"version_minor": 0
|
| 87 |
+
},
|
| 88 |
+
"text/plain": [
|
| 89 |
+
"Downloading readme: 0%| | 0.00/27.9k [00:00<?, ?B/s]"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
"metadata": {},
|
| 93 |
+
"output_type": "display_data"
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"name": "stderr",
|
| 97 |
+
"output_type": "stream",
|
| 98 |
+
"text": [
|
| 99 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"data": {
|
| 104 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 105 |
+
"model_id": "1a9ecc2f624343c3af8d1824afb66ac5",
|
| 106 |
+
"version_major": 2,
|
| 107 |
+
"version_minor": 0
|
| 108 |
+
},
|
| 109 |
+
"text/plain": [
|
| 110 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
| 111 |
+
]
|
| 112 |
+
},
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"output_type": "display_data"
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"data": {
|
| 118 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 119 |
+
"model_id": "33b071c0e5794cb48b38bbf68f22b49b",
|
| 120 |
+
"version_major": 2,
|
| 121 |
+
"version_minor": 0
|
| 122 |
+
},
|
| 123 |
+
"text/plain": [
|
| 124 |
+
" 0%| | 0/4 [00:00<?, ?ba/s]"
|
| 125 |
+
]
|
| 126 |
+
},
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"output_type": "display_data"
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"data": {
|
| 132 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 133 |
+
"model_id": "a977694036394d5c99adfb13c023e258",
|
| 134 |
+
"version_major": 2,
|
| 135 |
+
"version_minor": 0
|
| 136 |
+
},
|
| 137 |
+
"text/plain": [
|
| 138 |
+
" 0%| | 0/1 [00:00<?, ?ba/s]"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"output_type": "display_data"
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"data": {
|
| 146 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 147 |
+
"model_id": "facc8d9092dc4abe9e553fc8e5b795b8",
|
| 148 |
+
"version_major": 2,
|
| 149 |
+
"version_minor": 0
|
| 150 |
+
},
|
| 151 |
+
"text/plain": [
|
| 152 |
+
" 0%| | 0/2 [00:00<?, ?ba/s]"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
"metadata": {},
|
| 156 |
+
"output_type": "display_data"
|
| 157 |
+
}
|
| 158 |
+
],
|
| 159 |
+
"source": [
|
| 160 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
| 161 |
+
" padding_side = \"left\"\n",
|
| 162 |
+
"else:\n",
|
| 163 |
+
" padding_side = \"right\"\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
| 166 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
| 167 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 168 |
+
"\n",
|
| 169 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
| 170 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"def tokenize_function(examples):\n",
|
| 174 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
| 175 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
| 176 |
+
" return outputs\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"tokenized_datasets = datasets.map(\n",
|
| 180 |
+
" tokenize_function,\n",
|
| 181 |
+
" batched=True,\n",
|
| 182 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
| 183 |
+
")\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
| 186 |
+
"# transformers library\n",
|
| 187 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"def collate_fn(examples):\n",
|
| 191 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"# Instantiate dataloaders.\n",
|
| 195 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
| 196 |
+
"eval_dataloader = DataLoader(\n",
|
| 197 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
| 198 |
+
")"
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
"cell_type": "code",
|
| 203 |
+
"execution_count": null,
|
| 204 |
+
"id": "2ed5ac74",
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [],
|
| 207 |
+
"source": [
|
| 208 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
| 209 |
+
"model = get_peft_model(model, peft_config)\n",
|
| 210 |
+
"model.print_trainable_parameters()\n",
|
| 211 |
+
"model"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "code",
|
| 216 |
+
"execution_count": 6,
|
| 217 |
+
"id": "0d2d0381",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"# Instantiate scheduler\n",
|
| 224 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
| 225 |
+
" optimizer=optimizer,\n",
|
| 226 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
| 227 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
| 228 |
+
")"
|
| 229 |
+
]
|
| 230 |
+
},
|
| 231 |
+
{
|
| 232 |
+
"cell_type": "code",
|
| 233 |
+
"execution_count": 7,
|
| 234 |
+
"id": "fa0e73be",
|
| 235 |
+
"metadata": {},
|
| 236 |
+
"outputs": [
|
| 237 |
+
{
|
| 238 |
+
"name": "stderr",
|
| 239 |
+
"output_type": "stream",
|
| 240 |
+
"text": [
|
| 241 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 242 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:28<00:00, 4.08it/s]\n",
|
| 243 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.68it/s]\n"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"name": "stdout",
|
| 248 |
+
"output_type": "stream",
|
| 249 |
+
"text": [
|
| 250 |
+
"epoch 0: {'accuracy': 0.7009803921568627, 'f1': 0.8189910979228486}\n"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"name": "stderr",
|
| 255 |
+
"output_type": "stream",
|
| 256 |
+
"text": [
|
| 257 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 258 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.64it/s]\n"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"name": "stdout",
|
| 263 |
+
"output_type": "stream",
|
| 264 |
+
"text": [
|
| 265 |
+
"epoch 1: {'accuracy': 0.7622549019607843, 'f1': 0.8482003129890453}\n"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"name": "stderr",
|
| 270 |
+
"output_type": "stream",
|
| 271 |
+
"text": [
|
| 272 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.20it/s]\n",
|
| 273 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.63it/s]\n"
|
| 274 |
+
]
|
| 275 |
+
},
|
| 276 |
+
{
|
| 277 |
+
"name": "stdout",
|
| 278 |
+
"output_type": "stream",
|
| 279 |
+
"text": [
|
| 280 |
+
"epoch 2: {'accuracy': 0.8651960784313726, 'f1': 0.9005424954792043}\n"
|
| 281 |
+
]
|
| 282 |
+
},
|
| 283 |
+
{
|
| 284 |
+
"name": "stderr",
|
| 285 |
+
"output_type": "stream",
|
| 286 |
+
"text": [
|
| 287 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.21it/s]\n",
|
| 288 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
| 289 |
+
]
|
| 290 |
+
},
|
| 291 |
+
{
|
| 292 |
+
"name": "stdout",
|
| 293 |
+
"output_type": "stream",
|
| 294 |
+
"text": [
|
| 295 |
+
"epoch 3: {'accuracy': 0.8921568627450981, 'f1': 0.9228070175438596}\n"
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"name": "stderr",
|
| 300 |
+
"output_type": "stream",
|
| 301 |
+
"text": [
|
| 302 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.20it/s]\n",
|
| 303 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
| 304 |
+
]
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"name": "stdout",
|
| 308 |
+
"output_type": "stream",
|
| 309 |
+
"text": [
|
| 310 |
+
"epoch 4: {'accuracy': 0.8970588235294118, 'f1': 0.9257950530035336}\n"
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
"name": "stderr",
|
| 315 |
+
"output_type": "stream",
|
| 316 |
+
"text": [
|
| 317 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
| 318 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.01it/s]\n"
|
| 319 |
+
]
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"name": "stdout",
|
| 323 |
+
"output_type": "stream",
|
| 324 |
+
"text": [
|
| 325 |
+
"epoch 5: {'accuracy': 0.8823529411764706, 'f1': 0.9169550173010381}\n"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"name": "stderr",
|
| 330 |
+
"output_type": "stream",
|
| 331 |
+
"text": [
|
| 332 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:30<00:00, 3.81it/s]\n",
|
| 333 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.62it/s]\n"
|
| 334 |
+
]
|
| 335 |
+
},
|
| 336 |
+
{
|
| 337 |
+
"name": "stdout",
|
| 338 |
+
"output_type": "stream",
|
| 339 |
+
"text": [
|
| 340 |
+
"epoch 6: {'accuracy': 0.8799019607843137, 'f1': 0.9170896785109983}\n"
|
| 341 |
+
]
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"name": "stderr",
|
| 345 |
+
"output_type": "stream",
|
| 346 |
+
"text": [
|
| 347 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
| 348 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
| 349 |
+
]
|
| 350 |
+
},
|
| 351 |
+
{
|
| 352 |
+
"name": "stdout",
|
| 353 |
+
"output_type": "stream",
|
| 354 |
+
"text": [
|
| 355 |
+
"epoch 7: {'accuracy': 0.8799019607843137, 'f1': 0.9150779896013865}\n"
|
| 356 |
+
]
|
| 357 |
+
},
|
| 358 |
+
{
|
| 359 |
+
"name": "stderr",
|
| 360 |
+
"output_type": "stream",
|
| 361 |
+
"text": [
|
| 362 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 363 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
| 364 |
+
]
|
| 365 |
+
},
|
| 366 |
+
{
|
| 367 |
+
"name": "stdout",
|
| 368 |
+
"output_type": "stream",
|
| 369 |
+
"text": [
|
| 370 |
+
"epoch 8: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
| 371 |
+
]
|
| 372 |
+
},
|
| 373 |
+
{
|
| 374 |
+
"name": "stderr",
|
| 375 |
+
"output_type": "stream",
|
| 376 |
+
"text": [
|
| 377 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 378 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
| 379 |
+
]
|
| 380 |
+
},
|
| 381 |
+
{
|
| 382 |
+
"name": "stdout",
|
| 383 |
+
"output_type": "stream",
|
| 384 |
+
"text": [
|
| 385 |
+
"epoch 9: {'accuracy': 0.8872549019607843, 'f1': 0.9217687074829931}\n"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"name": "stderr",
|
| 390 |
+
"output_type": "stream",
|
| 391 |
+
"text": [
|
| 392 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
| 393 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"name": "stdout",
|
| 398 |
+
"output_type": "stream",
|
| 399 |
+
"text": [
|
| 400 |
+
"epoch 10: {'accuracy': 0.8774509803921569, 'f1': 0.9137931034482758}\n"
|
| 401 |
+
]
|
| 402 |
+
},
|
| 403 |
+
{
|
| 404 |
+
"name": "stderr",
|
| 405 |
+
"output_type": "stream",
|
| 406 |
+
"text": [
|
| 407 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:29<00:00, 3.90it/s]\n",
|
| 408 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.81it/s]\n"
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"name": "stdout",
|
| 413 |
+
"output_type": "stream",
|
| 414 |
+
"text": [
|
| 415 |
+
"epoch 11: {'accuracy': 0.9068627450980392, 'f1': 0.9321428571428573}\n"
|
| 416 |
+
]
|
| 417 |
+
},
|
| 418 |
+
{
|
| 419 |
+
"name": "stderr",
|
| 420 |
+
"output_type": "stream",
|
| 421 |
+
"text": [
|
| 422 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:28<00:00, 4.05it/s]\n",
|
| 423 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
| 424 |
+
]
|
| 425 |
+
},
|
| 426 |
+
{
|
| 427 |
+
"name": "stdout",
|
| 428 |
+
"output_type": "stream",
|
| 429 |
+
"text": [
|
| 430 |
+
"epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.925476603119584}\n"
|
| 431 |
+
]
|
| 432 |
+
},
|
| 433 |
+
{
|
| 434 |
+
"name": "stderr",
|
| 435 |
+
"output_type": "stream",
|
| 436 |
+
"text": [
|
| 437 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.17it/s]\n",
|
| 438 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"name": "stdout",
|
| 443 |
+
"output_type": "stream",
|
| 444 |
+
"text": [
|
| 445 |
+
"epoch 13: {'accuracy': 0.8897058823529411, 'f1': 0.922279792746114}\n"
|
| 446 |
+
]
|
| 447 |
+
},
|
| 448 |
+
{
|
| 449 |
+
"name": "stderr",
|
| 450 |
+
"output_type": "stream",
|
| 451 |
+
"text": [
|
| 452 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 453 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.61it/s]\n"
|
| 454 |
+
]
|
| 455 |
+
},
|
| 456 |
+
{
|
| 457 |
+
"name": "stdout",
|
| 458 |
+
"output_type": "stream",
|
| 459 |
+
"text": [
|
| 460 |
+
"epoch 14: {'accuracy': 0.8970588235294118, 'f1': 0.9265734265734265}\n"
|
| 461 |
+
]
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"name": "stderr",
|
| 465 |
+
"output_type": "stream",
|
| 466 |
+
"text": [
|
| 467 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 468 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.60it/s]\n"
|
| 469 |
+
]
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"name": "stdout",
|
| 473 |
+
"output_type": "stream",
|
| 474 |
+
"text": [
|
| 475 |
+
"epoch 15: {'accuracy': 0.8970588235294118, 'f1': 0.9263157894736843}\n"
|
| 476 |
+
]
|
| 477 |
+
},
|
| 478 |
+
{
|
| 479 |
+
"name": "stderr",
|
| 480 |
+
"output_type": "stream",
|
| 481 |
+
"text": [
|
| 482 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.17it/s]\n",
|
| 483 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.59it/s]\n"
|
| 484 |
+
]
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"name": "stdout",
|
| 488 |
+
"output_type": "stream",
|
| 489 |
+
"text": [
|
| 490 |
+
"epoch 16: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
| 491 |
+
]
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"name": "stderr",
|
| 495 |
+
"output_type": "stream",
|
| 496 |
+
"text": [
|
| 497 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.18it/s]\n",
|
| 498 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
| 499 |
+
]
|
| 500 |
+
},
|
| 501 |
+
{
|
| 502 |
+
"name": "stdout",
|
| 503 |
+
"output_type": "stream",
|
| 504 |
+
"text": [
|
| 505 |
+
"epoch 17: {'accuracy': 0.8897058823529411, 'f1': 0.9220103986135182}\n"
|
| 506 |
+
]
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"name": "stderr",
|
| 510 |
+
"output_type": "stream",
|
| 511 |
+
"text": [
|
| 512 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:30<00:00, 3.78it/s]\n",
|
| 513 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.58it/s]\n"
|
| 514 |
+
]
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"name": "stdout",
|
| 518 |
+
"output_type": "stream",
|
| 519 |
+
"text": [
|
| 520 |
+
"epoch 18: {'accuracy': 0.8921568627450981, 'f1': 0.9233449477351917}\n"
|
| 521 |
+
]
|
| 522 |
+
},
|
| 523 |
+
{
|
| 524 |
+
"name": "stderr",
|
| 525 |
+
"output_type": "stream",
|
| 526 |
+
"text": [
|
| 527 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:27<00:00, 4.16it/s]\n",
|
| 528 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββ| 13/13 [00:01<00:00, 8.60it/s]"
|
| 529 |
+
]
|
| 530 |
+
},
|
| 531 |
+
{
|
| 532 |
+
"name": "stdout",
|
| 533 |
+
"output_type": "stream",
|
| 534 |
+
"text": [
|
| 535 |
+
"epoch 19: {'accuracy': 0.8946078431372549, 'f1': 0.924693520140105}\n"
|
| 536 |
+
]
|
| 537 |
+
},
|
| 538 |
+
{
|
| 539 |
+
"name": "stderr",
|
| 540 |
+
"output_type": "stream",
|
| 541 |
+
"text": [
|
| 542 |
+
"\n"
|
| 543 |
+
]
|
| 544 |
+
}
|
| 545 |
+
],
|
| 546 |
+
"source": [
|
| 547 |
+
"model.to(device)\n",
|
| 548 |
+
"for epoch in range(num_epochs):\n",
|
| 549 |
+
" model.train()\n",
|
| 550 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
| 551 |
+
" batch.to(device)\n",
|
| 552 |
+
" outputs = model(**batch)\n",
|
| 553 |
+
" loss = outputs.loss\n",
|
| 554 |
+
" loss.backward()\n",
|
| 555 |
+
" optimizer.step()\n",
|
| 556 |
+
" lr_scheduler.step()\n",
|
| 557 |
+
" optimizer.zero_grad()\n",
|
| 558 |
+
"\n",
|
| 559 |
+
" model.eval()\n",
|
| 560 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 561 |
+
" batch.to(device)\n",
|
| 562 |
+
" with torch.no_grad():\n",
|
| 563 |
+
" outputs = model(**batch)\n",
|
| 564 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 565 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 566 |
+
" metric.add_batch(\n",
|
| 567 |
+
" predictions=predictions,\n",
|
| 568 |
+
" references=references,\n",
|
| 569 |
+
" )\n",
|
| 570 |
+
"\n",
|
| 571 |
+
" eval_metric = metric.compute()\n",
|
| 572 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
| 573 |
+
]
|
| 574 |
+
},
|
| 575 |
+
{
|
| 576 |
+
"cell_type": "markdown",
|
| 577 |
+
"id": "f2b2caca",
|
| 578 |
+
"metadata": {},
|
| 579 |
+
"source": [
|
| 580 |
+
"## Share adapters on the π€ Hub"
|
| 581 |
+
]
|
| 582 |
+
},
|
| 583 |
+
{
|
| 584 |
+
"cell_type": "code",
|
| 585 |
+
"execution_count": 8,
|
| 586 |
+
"id": "990b3c93",
|
| 587 |
+
"metadata": {},
|
| 588 |
+
"outputs": [
|
| 589 |
+
{
|
| 590 |
+
"data": {
|
| 591 |
+
"text/plain": [
|
| 592 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-lora/commit/c2c661898b8b6a0c68ecd068931e598d0a79686b', commit_message='Upload model', commit_description='', oid='c2c661898b8b6a0c68ecd068931e598d0a79686b', pr_url=None, pr_revision=None, pr_num=None)"
|
| 593 |
+
]
|
| 594 |
+
},
|
| 595 |
+
"execution_count": 8,
|
| 596 |
+
"metadata": {},
|
| 597 |
+
"output_type": "execute_result"
|
| 598 |
+
}
|
| 599 |
+
],
|
| 600 |
+
"source": [
|
| 601 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-lora\", use_auth_token=True)"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "markdown",
|
| 606 |
+
"id": "9d140b26",
|
| 607 |
+
"metadata": {},
|
| 608 |
+
"source": [
|
| 609 |
+
"## Load adapters from the Hub\n",
|
| 610 |
+
"\n",
|
| 611 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
| 612 |
+
]
|
| 613 |
+
},
|
| 614 |
+
{
|
| 615 |
+
"cell_type": "code",
|
| 616 |
+
"execution_count": 11,
|
| 617 |
+
"id": "4d55c87d",
|
| 618 |
+
"metadata": {},
|
| 619 |
+
"outputs": [
|
| 620 |
+
{
|
| 621 |
+
"name": "stderr",
|
| 622 |
+
"output_type": "stream",
|
| 623 |
+
"text": [
|
| 624 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']\n",
|
| 625 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 626 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 627 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight']\n",
|
| 628 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
|
| 629 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 630 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.45it/s]"
|
| 631 |
+
]
|
| 632 |
+
},
|
| 633 |
+
{
|
| 634 |
+
"name": "stdout",
|
| 635 |
+
"output_type": "stream",
|
| 636 |
+
"text": [
|
| 637 |
+
"{'accuracy': 0.8946078431372549, 'f1': 0.924693520140105}\n"
|
| 638 |
+
]
|
| 639 |
+
},
|
| 640 |
+
{
|
| 641 |
+
"name": "stderr",
|
| 642 |
+
"output_type": "stream",
|
| 643 |
+
"text": [
|
| 644 |
+
"\n"
|
| 645 |
+
]
|
| 646 |
+
}
|
| 647 |
+
],
|
| 648 |
+
"source": [
|
| 649 |
+
"import torch\n",
|
| 650 |
+
"from peft import PeftModel, PeftConfig\n",
|
| 651 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 652 |
+
"\n",
|
| 653 |
+
"peft_model_id = \"smangrul/roberta-large-peft-lora\"\n",
|
| 654 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
| 655 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
| 656 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"# Load the Lora model\n",
|
| 659 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"inference_model.to(device)\n",
|
| 662 |
+
"inference_model.eval()\n",
|
| 663 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 664 |
+
" batch.to(device)\n",
|
| 665 |
+
" with torch.no_grad():\n",
|
| 666 |
+
" outputs = inference_model(**batch)\n",
|
| 667 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 668 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 669 |
+
" metric.add_batch(\n",
|
| 670 |
+
" predictions=predictions,\n",
|
| 671 |
+
" references=references,\n",
|
| 672 |
+
" )\n",
|
| 673 |
+
"\n",
|
| 674 |
+
"eval_metric = metric.compute()\n",
|
| 675 |
+
"print(eval_metric)"
|
| 676 |
+
]
|
| 677 |
+
},
|
| 678 |
+
{
|
| 679 |
+
"cell_type": "code",
|
| 680 |
+
"execution_count": null,
|
| 681 |
+
"id": "27c43da1",
|
| 682 |
+
"metadata": {},
|
| 683 |
+
"outputs": [],
|
| 684 |
+
"source": []
|
| 685 |
+
}
|
| 686 |
+
],
|
| 687 |
+
"metadata": {
|
| 688 |
+
"kernelspec": {
|
| 689 |
+
"display_name": "Python 3 (ipykernel)",
|
| 690 |
+
"language": "python",
|
| 691 |
+
"name": "python3"
|
| 692 |
+
},
|
| 693 |
+
"language_info": {
|
| 694 |
+
"codemirror_mode": {
|
| 695 |
+
"name": "ipython",
|
| 696 |
+
"version": 3
|
| 697 |
+
},
|
| 698 |
+
"file_extension": ".py",
|
| 699 |
+
"mimetype": "text/x-python",
|
| 700 |
+
"name": "python",
|
| 701 |
+
"nbconvert_exporter": "python",
|
| 702 |
+
"pygments_lexer": "ipython3",
|
| 703 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
| 704 |
+
},
|
| 705 |
+
"vscode": {
|
| 706 |
+
"interpreter": {
|
| 707 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
| 708 |
+
}
|
| 709 |
+
}
|
| 710 |
+
},
|
| 711 |
+
"nbformat": 4,
|
| 712 |
+
"nbformat_minor": 5
|
| 713 |
+
}
|
P_Tuning.ipynb
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "a825ba6b",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"\n",
|
| 14 |
+
"===================================BUG REPORT===================================\n",
|
| 15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
| 16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
| 17 |
+
"================================================================================\n",
|
| 18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
| 19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
| 20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
| 21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"import argparse\n",
|
| 27 |
+
"import os\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import torch\n",
|
| 30 |
+
"from torch.optim import AdamW\n",
|
| 31 |
+
"from torch.utils.data import DataLoader\n",
|
| 32 |
+
"from peft import (\n",
|
| 33 |
+
" get_peft_config,\n",
|
| 34 |
+
" get_peft_model,\n",
|
| 35 |
+
" get_peft_model_state_dict,\n",
|
| 36 |
+
" set_peft_model_state_dict,\n",
|
| 37 |
+
" PeftType,\n",
|
| 38 |
+
" PrefixTuningConfig,\n",
|
| 39 |
+
" PromptEncoderConfig,\n",
|
| 40 |
+
")\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"import evaluate\n",
|
| 43 |
+
"from datasets import load_dataset\n",
|
| 44 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
| 45 |
+
"from tqdm import tqdm"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 2,
|
| 51 |
+
"id": "2bd7cbb2",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"batch_size = 32\n",
|
| 56 |
+
"model_name_or_path = \"roberta-large\"\n",
|
| 57 |
+
"task = \"mrpc\"\n",
|
| 58 |
+
"peft_type = PeftType.P_TUNING\n",
|
| 59 |
+
"device = \"cuda\"\n",
|
| 60 |
+
"num_epochs = 20"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 3,
|
| 66 |
+
"id": "33d9b62e",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"peft_config = PromptEncoderConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20, encoder_hidden_size=128)\n",
|
| 71 |
+
"lr = 1e-3"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 4,
|
| 77 |
+
"id": "152b6177",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [
|
| 80 |
+
{
|
| 81 |
+
"name": "stderr",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"data": {
|
| 89 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 90 |
+
"model_id": "a451b90675e0451489cc6426465afa32",
|
| 91 |
+
"version_major": 2,
|
| 92 |
+
"version_minor": 0
|
| 93 |
+
},
|
| 94 |
+
"text/plain": [
|
| 95 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"output_type": "display_data"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"name": "stderr",
|
| 103 |
+
"output_type": "stream",
|
| 104 |
+
"text": [
|
| 105 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
|
| 106 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
|
| 107 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
|
| 108 |
+
]
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"source": [
|
| 112 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
| 113 |
+
" padding_side = \"left\"\n",
|
| 114 |
+
"else:\n",
|
| 115 |
+
" padding_side = \"right\"\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
| 118 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
| 119 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
| 122 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"def tokenize_function(examples):\n",
|
| 126 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
| 127 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
| 128 |
+
" return outputs\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"tokenized_datasets = datasets.map(\n",
|
| 132 |
+
" tokenize_function,\n",
|
| 133 |
+
" batched=True,\n",
|
| 134 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
| 135 |
+
")\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
| 138 |
+
"# transformers library\n",
|
| 139 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"def collate_fn(examples):\n",
|
| 143 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"# Instantiate dataloaders.\n",
|
| 147 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
| 148 |
+
"eval_dataloader = DataLoader(\n",
|
| 149 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
| 150 |
+
")"
|
| 151 |
+
]
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"execution_count": null,
|
| 156 |
+
"id": "f6bc8144",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"outputs": [],
|
| 159 |
+
"source": [
|
| 160 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
| 161 |
+
"model = get_peft_model(model, peft_config)\n",
|
| 162 |
+
"model.print_trainable_parameters()\n",
|
| 163 |
+
"model"
|
| 164 |
+
]
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"cell_type": "code",
|
| 168 |
+
"execution_count": 6,
|
| 169 |
+
"id": "af41c571",
|
| 170 |
+
"metadata": {},
|
| 171 |
+
"outputs": [],
|
| 172 |
+
"source": [
|
| 173 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"# Instantiate scheduler\n",
|
| 176 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
| 177 |
+
" optimizer=optimizer,\n",
|
| 178 |
+
" num_warmup_steps=0, # 0.06*(len(train_dataloader) * num_epochs),\n",
|
| 179 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
| 180 |
+
")"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"cell_type": "code",
|
| 185 |
+
"execution_count": 7,
|
| 186 |
+
"id": "90993c93",
|
| 187 |
+
"metadata": {},
|
| 188 |
+
"outputs": [
|
| 189 |
+
{
|
| 190 |
+
"name": "stderr",
|
| 191 |
+
"output_type": "stream",
|
| 192 |
+
"text": [
|
| 193 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 194 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.54it/s]\n",
|
| 195 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.91it/s]\n"
|
| 196 |
+
]
|
| 197 |
+
},
|
| 198 |
+
{
|
| 199 |
+
"name": "stdout",
|
| 200 |
+
"output_type": "stream",
|
| 201 |
+
"text": [
|
| 202 |
+
"epoch 0: {'accuracy': 0.6985294117647058, 'f1': 0.8172362555720655}\n"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"name": "stderr",
|
| 207 |
+
"output_type": "stream",
|
| 208 |
+
"text": [
|
| 209 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
| 210 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
| 211 |
+
]
|
| 212 |
+
},
|
| 213 |
+
{
|
| 214 |
+
"name": "stdout",
|
| 215 |
+
"output_type": "stream",
|
| 216 |
+
"text": [
|
| 217 |
+
"epoch 1: {'accuracy': 0.6936274509803921, 'f1': 0.806201550387597}\n"
|
| 218 |
+
]
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"name": "stderr",
|
| 222 |
+
"output_type": "stream",
|
| 223 |
+
"text": [
|
| 224 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
| 225 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
| 226 |
+
]
|
| 227 |
+
},
|
| 228 |
+
{
|
| 229 |
+
"name": "stdout",
|
| 230 |
+
"output_type": "stream",
|
| 231 |
+
"text": [
|
| 232 |
+
"epoch 2: {'accuracy': 0.7132352941176471, 'f1': 0.8224582701062216}\n"
|
| 233 |
+
]
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"name": "stderr",
|
| 237 |
+
"output_type": "stream",
|
| 238 |
+
"text": [
|
| 239 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
| 240 |
+
"100%|ββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"name": "stdout",
|
| 245 |
+
"output_type": "stream",
|
| 246 |
+
"text": [
|
| 247 |
+
"epoch 3: {'accuracy': 0.7083333333333334, 'f1': 0.8199697428139183}\n"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stderr",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
| 255 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.90it/s]\n"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"name": "stdout",
|
| 260 |
+
"output_type": "stream",
|
| 261 |
+
"text": [
|
| 262 |
+
"epoch 4: {'accuracy': 0.7205882352941176, 'f1': 0.8246153846153846}\n"
|
| 263 |
+
]
|
| 264 |
+
},
|
| 265 |
+
{
|
| 266 |
+
"name": "stderr",
|
| 267 |
+
"output_type": "stream",
|
| 268 |
+
"text": [
|
| 269 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.62it/s]\n",
|
| 270 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.90it/s]\n"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"name": "stdout",
|
| 275 |
+
"output_type": "stream",
|
| 276 |
+
"text": [
|
| 277 |
+
"epoch 5: {'accuracy': 0.7009803921568627, 'f1': 0.8200589970501474}\n"
|
| 278 |
+
]
|
| 279 |
+
},
|
| 280 |
+
{
|
| 281 |
+
"name": "stderr",
|
| 282 |
+
"output_type": "stream",
|
| 283 |
+
"text": [
|
| 284 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
| 285 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.89it/s]\n"
|
| 286 |
+
]
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
"name": "stdout",
|
| 290 |
+
"output_type": "stream",
|
| 291 |
+
"text": [
|
| 292 |
+
"epoch 6: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
|
| 293 |
+
]
|
| 294 |
+
},
|
| 295 |
+
{
|
| 296 |
+
"name": "stderr",
|
| 297 |
+
"output_type": "stream",
|
| 298 |
+
"text": [
|
| 299 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
| 300 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.86it/s]\n"
|
| 301 |
+
]
|
| 302 |
+
},
|
| 303 |
+
{
|
| 304 |
+
"name": "stdout",
|
| 305 |
+
"output_type": "stream",
|
| 306 |
+
"text": [
|
| 307 |
+
"epoch 7: {'accuracy': 0.7230392156862745, 'f1': 0.8269525267993874}\n"
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"name": "stderr",
|
| 312 |
+
"output_type": "stream",
|
| 313 |
+
"text": [
|
| 314 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:34<00:00, 3.34it/s]\n",
|
| 315 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
| 316 |
+
]
|
| 317 |
+
},
|
| 318 |
+
{
|
| 319 |
+
"name": "stdout",
|
| 320 |
+
"output_type": "stream",
|
| 321 |
+
"text": [
|
| 322 |
+
"epoch 8: {'accuracy': 0.7254901960784313, 'f1': 0.8297872340425533}\n"
|
| 323 |
+
]
|
| 324 |
+
},
|
| 325 |
+
{
|
| 326 |
+
"name": "stderr",
|
| 327 |
+
"output_type": "stream",
|
| 328 |
+
"text": [
|
| 329 |
+
"100%|ββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
| 330 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.77it/s]\n"
|
| 331 |
+
]
|
| 332 |
+
},
|
| 333 |
+
{
|
| 334 |
+
"name": "stdout",
|
| 335 |
+
"output_type": "stream",
|
| 336 |
+
"text": [
|
| 337 |
+
"epoch 9: {'accuracy': 0.7230392156862745, 'f1': 0.828006088280061}\n"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"name": "stderr",
|
| 342 |
+
"output_type": "stream",
|
| 343 |
+
"text": [
|
| 344 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.58it/s]\n",
|
| 345 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"name": "stdout",
|
| 350 |
+
"output_type": "stream",
|
| 351 |
+
"text": [
|
| 352 |
+
"epoch 10: {'accuracy': 0.7181372549019608, 'f1': 0.8183254344391785}\n"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"name": "stderr",
|
| 357 |
+
"output_type": "stream",
|
| 358 |
+
"text": [
|
| 359 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
| 360 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
| 361 |
+
]
|
| 362 |
+
},
|
| 363 |
+
{
|
| 364 |
+
"name": "stdout",
|
| 365 |
+
"output_type": "stream",
|
| 366 |
+
"text": [
|
| 367 |
+
"epoch 11: {'accuracy': 0.7132352941176471, 'f1': 0.803361344537815}\n"
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"name": "stderr",
|
| 372 |
+
"output_type": "stream",
|
| 373 |
+
"text": [
|
| 374 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.59it/s]\n",
|
| 375 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.85it/s]\n"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"name": "stdout",
|
| 380 |
+
"output_type": "stream",
|
| 381 |
+
"text": [
|
| 382 |
+
"epoch 12: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"name": "stderr",
|
| 387 |
+
"output_type": "stream",
|
| 388 |
+
"text": [
|
| 389 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
| 390 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.85it/s]\n"
|
| 391 |
+
]
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
"name": "stdout",
|
| 395 |
+
"output_type": "stream",
|
| 396 |
+
"text": [
|
| 397 |
+
"epoch 13: {'accuracy': 0.7181372549019608, 'f1': 0.8254931714719272}\n"
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"name": "stderr",
|
| 402 |
+
"output_type": "stream",
|
| 403 |
+
"text": [
|
| 404 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
| 405 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.87it/s]\n"
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"name": "stdout",
|
| 410 |
+
"output_type": "stream",
|
| 411 |
+
"text": [
|
| 412 |
+
"epoch 14: {'accuracy': 0.7156862745098039, 'f1': 0.8253012048192772}\n"
|
| 413 |
+
]
|
| 414 |
+
},
|
| 415 |
+
{
|
| 416 |
+
"name": "stderr",
|
| 417 |
+
"output_type": "stream",
|
| 418 |
+
"text": [
|
| 419 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.59it/s]\n",
|
| 420 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.84it/s]\n"
|
| 421 |
+
]
|
| 422 |
+
},
|
| 423 |
+
{
|
| 424 |
+
"name": "stdout",
|
| 425 |
+
"output_type": "stream",
|
| 426 |
+
"text": [
|
| 427 |
+
"epoch 15: {'accuracy': 0.7230392156862745, 'f1': 0.8242612752721618}\n"
|
| 428 |
+
]
|
| 429 |
+
},
|
| 430 |
+
{
|
| 431 |
+
"name": "stderr",
|
| 432 |
+
"output_type": "stream",
|
| 433 |
+
"text": [
|
| 434 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.49it/s]\n",
|
| 435 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:02<00:00, 5.84it/s]\n"
|
| 436 |
+
]
|
| 437 |
+
},
|
| 438 |
+
{
|
| 439 |
+
"name": "stdout",
|
| 440 |
+
"output_type": "stream",
|
| 441 |
+
"text": [
|
| 442 |
+
"epoch 16: {'accuracy': 0.7181372549019608, 'f1': 0.8200312989045383}\n"
|
| 443 |
+
]
|
| 444 |
+
},
|
| 445 |
+
{
|
| 446 |
+
"name": "stderr",
|
| 447 |
+
"output_type": "stream",
|
| 448 |
+
"text": [
|
| 449 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:32<00:00, 3.49it/s]\n",
|
| 450 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.84it/s]\n"
|
| 451 |
+
]
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"name": "stdout",
|
| 455 |
+
"output_type": "stream",
|
| 456 |
+
"text": [
|
| 457 |
+
"epoch 17: {'accuracy': 0.7107843137254902, 'f1': 0.8217522658610272}\n"
|
| 458 |
+
]
|
| 459 |
+
},
|
| 460 |
+
{
|
| 461 |
+
"name": "stderr",
|
| 462 |
+
"output_type": "stream",
|
| 463 |
+
"text": [
|
| 464 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.60it/s]\n",
|
| 465 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.88it/s]\n"
|
| 466 |
+
]
|
| 467 |
+
},
|
| 468 |
+
{
|
| 469 |
+
"name": "stdout",
|
| 470 |
+
"output_type": "stream",
|
| 471 |
+
"text": [
|
| 472 |
+
"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.8292682926829268}\n"
|
| 473 |
+
]
|
| 474 |
+
},
|
| 475 |
+
{
|
| 476 |
+
"name": "stderr",
|
| 477 |
+
"output_type": "stream",
|
| 478 |
+
"text": [
|
| 479 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:31<00:00, 3.61it/s]\n",
|
| 480 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 6.89it/s]"
|
| 481 |
+
]
|
| 482 |
+
},
|
| 483 |
+
{
|
| 484 |
+
"name": "stdout",
|
| 485 |
+
"output_type": "stream",
|
| 486 |
+
"text": [
|
| 487 |
+
"epoch 19: {'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
| 488 |
+
]
|
| 489 |
+
},
|
| 490 |
+
{
|
| 491 |
+
"name": "stderr",
|
| 492 |
+
"output_type": "stream",
|
| 493 |
+
"text": [
|
| 494 |
+
"\n"
|
| 495 |
+
]
|
| 496 |
+
}
|
| 497 |
+
],
|
| 498 |
+
"source": [
|
| 499 |
+
"model.to(device)\n",
|
| 500 |
+
"for epoch in range(num_epochs):\n",
|
| 501 |
+
" model.train()\n",
|
| 502 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
| 503 |
+
" batch.to(device)\n",
|
| 504 |
+
" outputs = model(**batch)\n",
|
| 505 |
+
" loss = outputs.loss\n",
|
| 506 |
+
" loss.backward()\n",
|
| 507 |
+
" optimizer.step()\n",
|
| 508 |
+
" lr_scheduler.step()\n",
|
| 509 |
+
" optimizer.zero_grad()\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" model.eval()\n",
|
| 512 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 513 |
+
" batch.to(device)\n",
|
| 514 |
+
" with torch.no_grad():\n",
|
| 515 |
+
" outputs = model(**batch)\n",
|
| 516 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 517 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 518 |
+
" metric.add_batch(\n",
|
| 519 |
+
" predictions=predictions,\n",
|
| 520 |
+
" references=references,\n",
|
| 521 |
+
" )\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" eval_metric = metric.compute()\n",
|
| 524 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
| 525 |
+
]
|
| 526 |
+
},
|
| 527 |
+
{
|
| 528 |
+
"cell_type": "markdown",
|
| 529 |
+
"id": "a43bd9fb",
|
| 530 |
+
"metadata": {},
|
| 531 |
+
"source": [
|
| 532 |
+
"## Share adapters on the π€ Hub"
|
| 533 |
+
]
|
| 534 |
+
},
|
| 535 |
+
{
|
| 536 |
+
"cell_type": "code",
|
| 537 |
+
"execution_count": 8,
|
| 538 |
+
"id": "871b75aa",
|
| 539 |
+
"metadata": {},
|
| 540 |
+
"outputs": [
|
| 541 |
+
{
|
| 542 |
+
"data": {
|
| 543 |
+
"text/plain": [
|
| 544 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-p-tuning/commit/fa7abe613f498c76df5e16c85d9c19c3019587a7', commit_message='Upload model', commit_description='', oid='fa7abe613f498c76df5e16c85d9c19c3019587a7', pr_url=None, pr_revision=None, pr_num=None)"
|
| 545 |
+
]
|
| 546 |
+
},
|
| 547 |
+
"execution_count": 8,
|
| 548 |
+
"metadata": {},
|
| 549 |
+
"output_type": "execute_result"
|
| 550 |
+
}
|
| 551 |
+
],
|
| 552 |
+
"source": [
|
| 553 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-p-tuning\", use_auth_token=True)"
|
| 554 |
+
]
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"cell_type": "markdown",
|
| 558 |
+
"id": "1c6a9036",
|
| 559 |
+
"metadata": {},
|
| 560 |
+
"source": [
|
| 561 |
+
"## Load adapters from the Hub\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
| 564 |
+
]
|
| 565 |
+
},
|
| 566 |
+
{
|
| 567 |
+
"cell_type": "code",
|
| 568 |
+
"execution_count": 9,
|
| 569 |
+
"id": "91b0b8f5",
|
| 570 |
+
"metadata": {},
|
| 571 |
+
"outputs": [
|
| 572 |
+
{
|
| 573 |
+
"name": "stderr",
|
| 574 |
+
"output_type": "stream",
|
| 575 |
+
"text": [
|
| 576 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']\n",
|
| 577 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 578 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 579 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
|
| 580 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 581 |
+
]
|
| 582 |
+
},
|
| 583 |
+
{
|
| 584 |
+
"data": {
|
| 585 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 586 |
+
"model_id": "e650799d58ec4bd1b21b6bc28ddf2069",
|
| 587 |
+
"version_major": 2,
|
| 588 |
+
"version_minor": 0
|
| 589 |
+
},
|
| 590 |
+
"text/plain": [
|
| 591 |
+
"Downloading: 0%| | 0.00/4.29M [00:00<?, ?B/s]"
|
| 592 |
+
]
|
| 593 |
+
},
|
| 594 |
+
"metadata": {},
|
| 595 |
+
"output_type": "display_data"
|
| 596 |
+
},
|
| 597 |
+
{
|
| 598 |
+
"name": "stderr",
|
| 599 |
+
"output_type": "stream",
|
| 600 |
+
"text": [
|
| 601 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 602 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 7.18it/s]"
|
| 603 |
+
]
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"name": "stdout",
|
| 607 |
+
"output_type": "stream",
|
| 608 |
+
"text": [
|
| 609 |
+
"{'accuracy': 0.7107843137254902, 'f1': 0.8206686930091186}\n"
|
| 610 |
+
]
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"name": "stderr",
|
| 614 |
+
"output_type": "stream",
|
| 615 |
+
"text": [
|
| 616 |
+
"\n"
|
| 617 |
+
]
|
| 618 |
+
}
|
| 619 |
+
],
|
| 620 |
+
"source": [
|
| 621 |
+
"import torch\n",
|
| 622 |
+
"from peft import PeftModel, PeftConfig\n",
|
| 623 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 624 |
+
"\n",
|
| 625 |
+
"peft_model_id = \"smangrul/roberta-large-peft-p-tuning\"\n",
|
| 626 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
| 627 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
| 628 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"# Load the Lora model\n",
|
| 631 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
| 632 |
+
"\n",
|
| 633 |
+
"inference_model.to(device)\n",
|
| 634 |
+
"inference_model.eval()\n",
|
| 635 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 636 |
+
" batch.to(device)\n",
|
| 637 |
+
" with torch.no_grad():\n",
|
| 638 |
+
" outputs = inference_model(**batch)\n",
|
| 639 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 640 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 641 |
+
" metric.add_batch(\n",
|
| 642 |
+
" predictions=predictions,\n",
|
| 643 |
+
" references=references,\n",
|
| 644 |
+
" )\n",
|
| 645 |
+
"\n",
|
| 646 |
+
"eval_metric = metric.compute()\n",
|
| 647 |
+
"print(eval_metric)"
|
| 648 |
+
]
|
| 649 |
+
},
|
| 650 |
+
{
|
| 651 |
+
"cell_type": "code",
|
| 652 |
+
"execution_count": null,
|
| 653 |
+
"id": "1a8d69d1",
|
| 654 |
+
"metadata": {},
|
| 655 |
+
"outputs": [],
|
| 656 |
+
"source": []
|
| 657 |
+
}
|
| 658 |
+
],
|
| 659 |
+
"metadata": {
|
| 660 |
+
"kernelspec": {
|
| 661 |
+
"display_name": "Python 3 (ipykernel)",
|
| 662 |
+
"language": "python",
|
| 663 |
+
"name": "python3"
|
| 664 |
+
},
|
| 665 |
+
"language_info": {
|
| 666 |
+
"codemirror_mode": {
|
| 667 |
+
"name": "ipython",
|
| 668 |
+
"version": 3
|
| 669 |
+
},
|
| 670 |
+
"file_extension": ".py",
|
| 671 |
+
"mimetype": "text/x-python",
|
| 672 |
+
"name": "python",
|
| 673 |
+
"nbconvert_exporter": "python",
|
| 674 |
+
"pygments_lexer": "ipython3",
|
| 675 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
| 676 |
+
},
|
| 677 |
+
"vscode": {
|
| 678 |
+
"interpreter": {
|
| 679 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
| 680 |
+
}
|
| 681 |
+
}
|
| 682 |
+
},
|
| 683 |
+
"nbformat": 4,
|
| 684 |
+
"nbformat_minor": 5
|
| 685 |
+
}
|
Prompt_Tuning.ipynb
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "9ff5004e",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"\n",
|
| 14 |
+
"===================================BUG REPORT===================================\n",
|
| 15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
| 16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
| 17 |
+
"================================================================================\n",
|
| 18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
| 19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
| 20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
| 21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"import argparse\n",
|
| 27 |
+
"import os\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import torch\n",
|
| 30 |
+
"from torch.optim import AdamW\n",
|
| 31 |
+
"from torch.utils.data import DataLoader\n",
|
| 32 |
+
"from peft import (\n",
|
| 33 |
+
" get_peft_config,\n",
|
| 34 |
+
" get_peft_model,\n",
|
| 35 |
+
" get_peft_model_state_dict,\n",
|
| 36 |
+
" set_peft_model_state_dict,\n",
|
| 37 |
+
" PeftType,\n",
|
| 38 |
+
" PrefixTuningConfig,\n",
|
| 39 |
+
" PromptEncoderConfig,\n",
|
| 40 |
+
" PromptTuningConfig,\n",
|
| 41 |
+
")\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"import evaluate\n",
|
| 44 |
+
"from datasets import load_dataset\n",
|
| 45 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
| 46 |
+
"from tqdm import tqdm"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": 2,
|
| 52 |
+
"id": "e32c4a9e",
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"batch_size = 32\n",
|
| 57 |
+
"model_name_or_path = \"roberta-large\"\n",
|
| 58 |
+
"task = \"mrpc\"\n",
|
| 59 |
+
"peft_type = PeftType.PROMPT_TUNING\n",
|
| 60 |
+
"device = \"cuda\"\n",
|
| 61 |
+
"num_epochs = 20"
|
| 62 |
+
]
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"cell_type": "code",
|
| 66 |
+
"execution_count": 3,
|
| 67 |
+
"id": "622fe9c8",
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [],
|
| 70 |
+
"source": [
|
| 71 |
+
"peft_config = PromptTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=10)\n",
|
| 72 |
+
"lr = 1e-3"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "code",
|
| 77 |
+
"execution_count": 4,
|
| 78 |
+
"id": "74e9efe0",
|
| 79 |
+
"metadata": {},
|
| 80 |
+
"outputs": [
|
| 81 |
+
{
|
| 82 |
+
"name": "stderr",
|
| 83 |
+
"output_type": "stream",
|
| 84 |
+
"text": [
|
| 85 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"data": {
|
| 90 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 91 |
+
"model_id": "76198cec552441818ff107910275e5be",
|
| 92 |
+
"version_major": 2,
|
| 93 |
+
"version_minor": 0
|
| 94 |
+
},
|
| 95 |
+
"text/plain": [
|
| 96 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
| 97 |
+
]
|
| 98 |
+
},
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"output_type": "display_data"
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"name": "stderr",
|
| 104 |
+
"output_type": "stream",
|
| 105 |
+
"text": [
|
| 106 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
|
| 107 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
|
| 108 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
|
| 109 |
+
]
|
| 110 |
+
}
|
| 111 |
+
],
|
| 112 |
+
"source": [
|
| 113 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
| 114 |
+
" padding_side = \"left\"\n",
|
| 115 |
+
"else:\n",
|
| 116 |
+
" padding_side = \"right\"\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
| 119 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
| 120 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
| 123 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
| 124 |
+
"\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"def tokenize_function(examples):\n",
|
| 127 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
| 128 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
| 129 |
+
" return outputs\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"tokenized_datasets = datasets.map(\n",
|
| 133 |
+
" tokenize_function,\n",
|
| 134 |
+
" batched=True,\n",
|
| 135 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
| 136 |
+
")\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
| 139 |
+
"# transformers library\n",
|
| 140 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"def collate_fn(examples):\n",
|
| 144 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"# Instantiate dataloaders.\n",
|
| 148 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
| 149 |
+
"eval_dataloader = DataLoader(\n",
|
| 150 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
| 151 |
+
")"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": null,
|
| 157 |
+
"id": "a3c15af0",
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"outputs": [],
|
| 160 |
+
"source": [
|
| 161 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
| 162 |
+
"model = get_peft_model(model, peft_config)\n",
|
| 163 |
+
"model.print_trainable_parameters()\n",
|
| 164 |
+
"model"
|
| 165 |
+
]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "code",
|
| 169 |
+
"execution_count": 6,
|
| 170 |
+
"id": "6d3c5edb",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"outputs": [],
|
| 173 |
+
"source": [
|
| 174 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"# Instantiate scheduler\n",
|
| 177 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
| 178 |
+
" optimizer=optimizer,\n",
|
| 179 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
| 180 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
| 181 |
+
")"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"cell_type": "code",
|
| 186 |
+
"execution_count": 7,
|
| 187 |
+
"id": "4d279225",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [
|
| 190 |
+
{
|
| 191 |
+
"name": "stderr",
|
| 192 |
+
"output_type": "stream",
|
| 193 |
+
"text": [
|
| 194 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 195 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [02:09<00:00, 1.13s/it]\n",
|
| 196 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:08<00:00, 1.62it/s]\n"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"name": "stdout",
|
| 201 |
+
"output_type": "stream",
|
| 202 |
+
"text": [
|
| 203 |
+
"epoch 0: {'accuracy': 0.678921568627451, 'f1': 0.7956318252730109}\n"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"name": "stderr",
|
| 208 |
+
"output_type": "stream",
|
| 209 |
+
"text": [
|
| 210 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:50<00:00, 1.04it/s]\n",
|
| 211 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"name": "stdout",
|
| 216 |
+
"output_type": "stream",
|
| 217 |
+
"text": [
|
| 218 |
+
"epoch 1: {'accuracy': 0.696078431372549, 'f1': 0.8171091445427728}\n"
|
| 219 |
+
]
|
| 220 |
+
},
|
| 221 |
+
{
|
| 222 |
+
"name": "stderr",
|
| 223 |
+
"output_type": "stream",
|
| 224 |
+
"text": [
|
| 225 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.19it/s]\n",
|
| 226 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.00it/s]\n"
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"name": "stdout",
|
| 231 |
+
"output_type": "stream",
|
| 232 |
+
"text": [
|
| 233 |
+
"epoch 2: {'accuracy': 0.6985294117647058, 'f1': 0.8161434977578476}\n"
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"name": "stderr",
|
| 238 |
+
"output_type": "stream",
|
| 239 |
+
"text": [
|
| 240 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:37<00:00, 1.18it/s]\n",
|
| 241 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.09it/s]\n"
|
| 242 |
+
]
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"name": "stdout",
|
| 246 |
+
"output_type": "stream",
|
| 247 |
+
"text": [
|
| 248 |
+
"epoch 3: {'accuracy': 0.7058823529411765, 'f1': 0.7979797979797979}\n"
|
| 249 |
+
]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"name": "stderr",
|
| 253 |
+
"output_type": "stream",
|
| 254 |
+
"text": [
|
| 255 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [02:03<00:00, 1.07s/it]\n",
|
| 256 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:07<00:00, 1.71it/s]\n"
|
| 257 |
+
]
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"name": "stdout",
|
| 261 |
+
"output_type": "stream",
|
| 262 |
+
"text": [
|
| 263 |
+
"epoch 4: {'accuracy': 0.696078431372549, 'f1': 0.8132530120481929}\n"
|
| 264 |
+
]
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"name": "stderr",
|
| 268 |
+
"output_type": "stream",
|
| 269 |
+
"text": [
|
| 270 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.01it/s]\n",
|
| 271 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.19it/s]\n"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"name": "stdout",
|
| 276 |
+
"output_type": "stream",
|
| 277 |
+
"text": [
|
| 278 |
+
"epoch 5: {'accuracy': 0.7107843137254902, 'f1': 0.8121019108280254}\n"
|
| 279 |
+
]
|
| 280 |
+
},
|
| 281 |
+
{
|
| 282 |
+
"name": "stderr",
|
| 283 |
+
"output_type": "stream",
|
| 284 |
+
"text": [
|
| 285 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
| 286 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.20it/s]\n"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"name": "stdout",
|
| 291 |
+
"output_type": "stream",
|
| 292 |
+
"text": [
|
| 293 |
+
"epoch 6: {'accuracy': 0.6911764705882353, 'f1': 0.7692307692307693}\n"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"name": "stderr",
|
| 298 |
+
"output_type": "stream",
|
| 299 |
+
"text": [
|
| 300 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.20it/s]\n",
|
| 301 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.18it/s]\n"
|
| 302 |
+
]
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"name": "stdout",
|
| 306 |
+
"output_type": "stream",
|
| 307 |
+
"text": [
|
| 308 |
+
"epoch 7: {'accuracy': 0.7156862745098039, 'f1': 0.8209876543209876}\n"
|
| 309 |
+
]
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"name": "stderr",
|
| 313 |
+
"output_type": "stream",
|
| 314 |
+
"text": [
|
| 315 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
| 316 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
| 317 |
+
]
|
| 318 |
+
},
|
| 319 |
+
{
|
| 320 |
+
"name": "stdout",
|
| 321 |
+
"output_type": "stream",
|
| 322 |
+
"text": [
|
| 323 |
+
"epoch 8: {'accuracy': 0.7205882352941176, 'f1': 0.8240740740740742}\n"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"name": "stderr",
|
| 328 |
+
"output_type": "stream",
|
| 329 |
+
"text": [
|
| 330 |
+
"100%|ββββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.19it/s]\n",
|
| 331 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.21it/s]\n"
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"name": "stdout",
|
| 336 |
+
"output_type": "stream",
|
| 337 |
+
"text": [
|
| 338 |
+
"epoch 9: {'accuracy': 0.7205882352941176, 'f1': 0.8229813664596273}\n"
|
| 339 |
+
]
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"name": "stderr",
|
| 343 |
+
"output_type": "stream",
|
| 344 |
+
"text": [
|
| 345 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:36<00:00, 1.20it/s]\n",
|
| 346 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.35it/s]\n"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"name": "stdout",
|
| 351 |
+
"output_type": "stream",
|
| 352 |
+
"text": [
|
| 353 |
+
"epoch 10: {'accuracy': 0.7156862745098039, 'f1': 0.8164556962025317}\n"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"name": "stderr",
|
| 358 |
+
"output_type": "stream",
|
| 359 |
+
"text": [
|
| 360 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:35<00:00, 1.20it/s]\n",
|
| 361 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
| 362 |
+
]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"name": "stdout",
|
| 366 |
+
"output_type": "stream",
|
| 367 |
+
"text": [
|
| 368 |
+
"epoch 11: {'accuracy': 0.7058823529411765, 'f1': 0.8113207547169811}\n"
|
| 369 |
+
]
|
| 370 |
+
},
|
| 371 |
+
{
|
| 372 |
+
"name": "stderr",
|
| 373 |
+
"output_type": "stream",
|
| 374 |
+
"text": [
|
| 375 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:32<00:00, 1.24it/s]\n",
|
| 376 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.48it/s]\n"
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"name": "stdout",
|
| 381 |
+
"output_type": "stream",
|
| 382 |
+
"text": [
|
| 383 |
+
"epoch 12: {'accuracy': 0.7009803921568627, 'f1': 0.7946127946127945}\n"
|
| 384 |
+
]
|
| 385 |
+
},
|
| 386 |
+
{
|
| 387 |
+
"name": "stderr",
|
| 388 |
+
"output_type": "stream",
|
| 389 |
+
"text": [
|
| 390 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:32<00:00, 1.24it/s]\n",
|
| 391 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.38it/s]\n"
|
| 392 |
+
]
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"name": "stdout",
|
| 396 |
+
"output_type": "stream",
|
| 397 |
+
"text": [
|
| 398 |
+
"epoch 13: {'accuracy': 0.7230392156862745, 'f1': 0.8186195826645265}\n"
|
| 399 |
+
]
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"name": "stderr",
|
| 403 |
+
"output_type": "stream",
|
| 404 |
+
"text": [
|
| 405 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.29it/s]\n",
|
| 406 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.31it/s]\n"
|
| 407 |
+
]
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"name": "stdout",
|
| 411 |
+
"output_type": "stream",
|
| 412 |
+
"text": [
|
| 413 |
+
"epoch 14: {'accuracy': 0.7058823529411765, 'f1': 0.8130841121495327}\n"
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"name": "stderr",
|
| 418 |
+
"output_type": "stream",
|
| 419 |
+
"text": [
|
| 420 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
| 421 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.39it/s]\n"
|
| 422 |
+
]
|
| 423 |
+
},
|
| 424 |
+
{
|
| 425 |
+
"name": "stdout",
|
| 426 |
+
"output_type": "stream",
|
| 427 |
+
"text": [
|
| 428 |
+
"epoch 15: {'accuracy': 0.7181372549019608, 'f1': 0.8194662480376768}\n"
|
| 429 |
+
]
|
| 430 |
+
},
|
| 431 |
+
{
|
| 432 |
+
"name": "stderr",
|
| 433 |
+
"output_type": "stream",
|
| 434 |
+
"text": [
|
| 435 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.29it/s]\n",
|
| 436 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.35it/s]\n"
|
| 437 |
+
]
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"name": "stdout",
|
| 441 |
+
"output_type": "stream",
|
| 442 |
+
"text": [
|
| 443 |
+
"epoch 16: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n"
|
| 444 |
+
]
|
| 445 |
+
},
|
| 446 |
+
{
|
| 447 |
+
"name": "stderr",
|
| 448 |
+
"output_type": "stream",
|
| 449 |
+
"text": [
|
| 450 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
| 451 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.30it/s]\n"
|
| 452 |
+
]
|
| 453 |
+
},
|
| 454 |
+
{
|
| 455 |
+
"name": "stdout",
|
| 456 |
+
"output_type": "stream",
|
| 457 |
+
"text": [
|
| 458 |
+
"epoch 17: {'accuracy': 0.7205882352941176, 'f1': 0.820754716981132}\n"
|
| 459 |
+
]
|
| 460 |
+
},
|
| 461 |
+
{
|
| 462 |
+
"name": "stderr",
|
| 463 |
+
"output_type": "stream",
|
| 464 |
+
"text": [
|
| 465 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:30<00:00, 1.27it/s]\n",
|
| 466 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.36it/s]\n"
|
| 467 |
+
]
|
| 468 |
+
},
|
| 469 |
+
{
|
| 470 |
+
"name": "stdout",
|
| 471 |
+
"output_type": "stream",
|
| 472 |
+
"text": [
|
| 473 |
+
"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.821656050955414}\n"
|
| 474 |
+
]
|
| 475 |
+
},
|
| 476 |
+
{
|
| 477 |
+
"name": "stderr",
|
| 478 |
+
"output_type": "stream",
|
| 479 |
+
"text": [
|
| 480 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.29it/s]\n",
|
| 481 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.43it/s]"
|
| 482 |
+
]
|
| 483 |
+
},
|
| 484 |
+
{
|
| 485 |
+
"name": "stdout",
|
| 486 |
+
"output_type": "stream",
|
| 487 |
+
"text": [
|
| 488 |
+
"epoch 19: {'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
| 489 |
+
]
|
| 490 |
+
},
|
| 491 |
+
{
|
| 492 |
+
"name": "stderr",
|
| 493 |
+
"output_type": "stream",
|
| 494 |
+
"text": [
|
| 495 |
+
"\n"
|
| 496 |
+
]
|
| 497 |
+
}
|
| 498 |
+
],
|
| 499 |
+
"source": [
|
| 500 |
+
"model.to(device)\n",
|
| 501 |
+
"for epoch in range(num_epochs):\n",
|
| 502 |
+
" model.train()\n",
|
| 503 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
| 504 |
+
" batch.to(device)\n",
|
| 505 |
+
" outputs = model(**batch)\n",
|
| 506 |
+
" loss = outputs.loss\n",
|
| 507 |
+
" loss.backward()\n",
|
| 508 |
+
" optimizer.step()\n",
|
| 509 |
+
" lr_scheduler.step()\n",
|
| 510 |
+
" optimizer.zero_grad()\n",
|
| 511 |
+
"\n",
|
| 512 |
+
" model.eval()\n",
|
| 513 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 514 |
+
" batch.to(device)\n",
|
| 515 |
+
" with torch.no_grad():\n",
|
| 516 |
+
" outputs = model(**batch)\n",
|
| 517 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 518 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 519 |
+
" metric.add_batch(\n",
|
| 520 |
+
" predictions=predictions,\n",
|
| 521 |
+
" references=references,\n",
|
| 522 |
+
" )\n",
|
| 523 |
+
"\n",
|
| 524 |
+
" eval_metric = metric.compute()\n",
|
| 525 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
| 526 |
+
]
|
| 527 |
+
},
|
| 528 |
+
{
|
| 529 |
+
"cell_type": "markdown",
|
| 530 |
+
"id": "e1ff3f44",
|
| 531 |
+
"metadata": {},
|
| 532 |
+
"source": [
|
| 533 |
+
"## Share adapters on the π€ Hub"
|
| 534 |
+
]
|
| 535 |
+
},
|
| 536 |
+
{
|
| 537 |
+
"cell_type": "code",
|
| 538 |
+
"execution_count": 8,
|
| 539 |
+
"id": "0bf79cb5",
|
| 540 |
+
"metadata": {},
|
| 541 |
+
"outputs": [
|
| 542 |
+
{
|
| 543 |
+
"data": {
|
| 544 |
+
"text/plain": [
|
| 545 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prompt-tuning/commit/893a909d8499aa8778d58c781d43c3a8d9360de8', commit_message='Upload model', commit_description='', oid='893a909d8499aa8778d58c781d43c3a8d9360de8', pr_url=None, pr_revision=None, pr_num=None)"
|
| 546 |
+
]
|
| 547 |
+
},
|
| 548 |
+
"execution_count": 8,
|
| 549 |
+
"metadata": {},
|
| 550 |
+
"output_type": "execute_result"
|
| 551 |
+
}
|
| 552 |
+
],
|
| 553 |
+
"source": [
|
| 554 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-prompt-tuning\", use_auth_token=True)"
|
| 555 |
+
]
|
| 556 |
+
},
|
| 557 |
+
{
|
| 558 |
+
"cell_type": "markdown",
|
| 559 |
+
"id": "73870ad7",
|
| 560 |
+
"metadata": {},
|
| 561 |
+
"source": [
|
| 562 |
+
"## Load adapters from the Hub\n",
|
| 563 |
+
"\n",
|
| 564 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
| 565 |
+
]
|
| 566 |
+
},
|
| 567 |
+
{
|
| 568 |
+
"cell_type": "code",
|
| 569 |
+
"execution_count": 9,
|
| 570 |
+
"id": "0654a552",
|
| 571 |
+
"metadata": {},
|
| 572 |
+
"outputs": [
|
| 573 |
+
{
|
| 574 |
+
"data": {
|
| 575 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 576 |
+
"model_id": "24581bb98582444ca6114b9fa267847f",
|
| 577 |
+
"version_major": 2,
|
| 578 |
+
"version_minor": 0
|
| 579 |
+
},
|
| 580 |
+
"text/plain": [
|
| 581 |
+
"Downloading: 0%| | 0.00/368 [00:00<?, ?B/s]"
|
| 582 |
+
]
|
| 583 |
+
},
|
| 584 |
+
"metadata": {},
|
| 585 |
+
"output_type": "display_data"
|
| 586 |
+
},
|
| 587 |
+
{
|
| 588 |
+
"name": "stderr",
|
| 589 |
+
"output_type": "stream",
|
| 590 |
+
"text": [
|
| 591 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
|
| 592 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 593 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 594 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
|
| 595 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 596 |
+
]
|
| 597 |
+
},
|
| 598 |
+
{
|
| 599 |
+
"data": {
|
| 600 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 601 |
+
"model_id": "f1584da4d1c54cc3873a515182674980",
|
| 602 |
+
"version_major": 2,
|
| 603 |
+
"version_minor": 0
|
| 604 |
+
},
|
| 605 |
+
"text/plain": [
|
| 606 |
+
"Downloading: 0%| | 0.00/4.25M [00:00<?, ?B/s]"
|
| 607 |
+
]
|
| 608 |
+
},
|
| 609 |
+
"metadata": {},
|
| 610 |
+
"output_type": "display_data"
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"name": "stderr",
|
| 614 |
+
"output_type": "stream",
|
| 615 |
+
"text": [
|
| 616 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 617 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.58it/s]"
|
| 618 |
+
]
|
| 619 |
+
},
|
| 620 |
+
{
|
| 621 |
+
"name": "stdout",
|
| 622 |
+
"output_type": "stream",
|
| 623 |
+
"text": [
|
| 624 |
+
"{'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
| 625 |
+
]
|
| 626 |
+
},
|
| 627 |
+
{
|
| 628 |
+
"name": "stderr",
|
| 629 |
+
"output_type": "stream",
|
| 630 |
+
"text": [
|
| 631 |
+
"\n"
|
| 632 |
+
]
|
| 633 |
+
}
|
| 634 |
+
],
|
| 635 |
+
"source": [
|
| 636 |
+
"import torch\n",
|
| 637 |
+
"from peft import PeftModel, PeftConfig\n",
|
| 638 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"peft_model_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
|
| 641 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
| 642 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
| 643 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"# Load the Lora model\n",
|
| 646 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
| 647 |
+
"\n",
|
| 648 |
+
"inference_model.to(device)\n",
|
| 649 |
+
"inference_model.eval()\n",
|
| 650 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 651 |
+
" batch.to(device)\n",
|
| 652 |
+
" with torch.no_grad():\n",
|
| 653 |
+
" outputs = inference_model(**batch)\n",
|
| 654 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 655 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 656 |
+
" metric.add_batch(\n",
|
| 657 |
+
" predictions=predictions,\n",
|
| 658 |
+
" references=references,\n",
|
| 659 |
+
" )\n",
|
| 660 |
+
"\n",
|
| 661 |
+
"eval_metric = metric.compute()\n",
|
| 662 |
+
"print(eval_metric)"
|
| 663 |
+
]
|
| 664 |
+
}
|
| 665 |
+
],
|
| 666 |
+
"metadata": {
|
| 667 |
+
"kernelspec": {
|
| 668 |
+
"display_name": "Python 3 (ipykernel)",
|
| 669 |
+
"language": "python",
|
| 670 |
+
"name": "python3"
|
| 671 |
+
},
|
| 672 |
+
"language_info": {
|
| 673 |
+
"codemirror_mode": {
|
| 674 |
+
"name": "ipython",
|
| 675 |
+
"version": 3
|
| 676 |
+
},
|
| 677 |
+
"file_extension": ".py",
|
| 678 |
+
"mimetype": "text/x-python",
|
| 679 |
+
"name": "python",
|
| 680 |
+
"nbconvert_exporter": "python",
|
| 681 |
+
"pygments_lexer": "ipython3",
|
| 682 |
+
"version": "3.10.4"
|
| 683 |
+
},
|
| 684 |
+
"vscode": {
|
| 685 |
+
"interpreter": {
|
| 686 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
| 687 |
+
}
|
| 688 |
+
}
|
| 689 |
+
},
|
| 690 |
+
"nbformat": 4,
|
| 691 |
+
"nbformat_minor": 5
|
| 692 |
+
}
|
prefix_tuning.ipynb
ADDED
|
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "a825ba6b",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"\n",
|
| 14 |
+
"===================================BUG REPORT===================================\n",
|
| 15 |
+
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
| 16 |
+
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
| 17 |
+
"================================================================================\n",
|
| 18 |
+
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
| 19 |
+
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
| 20 |
+
"CUDA SETUP: Detected CUDA version 117\n",
|
| 21 |
+
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
| 22 |
+
]
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"source": [
|
| 26 |
+
"import argparse\n",
|
| 27 |
+
"import os\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"import torch\n",
|
| 30 |
+
"from torch.optim import AdamW\n",
|
| 31 |
+
"from torch.utils.data import DataLoader\n",
|
| 32 |
+
"from peft import (\n",
|
| 33 |
+
" get_peft_config,\n",
|
| 34 |
+
" get_peft_model,\n",
|
| 35 |
+
" get_peft_model_state_dict,\n",
|
| 36 |
+
" set_peft_model_state_dict,\n",
|
| 37 |
+
" PeftType,\n",
|
| 38 |
+
" PrefixTuningConfig,\n",
|
| 39 |
+
" PromptEncoderConfig,\n",
|
| 40 |
+
")\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"import evaluate\n",
|
| 43 |
+
"from datasets import load_dataset\n",
|
| 44 |
+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
| 45 |
+
"from tqdm import tqdm"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": 2,
|
| 51 |
+
"id": "2bd7cbb2",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"batch_size = 32\n",
|
| 56 |
+
"model_name_or_path = \"roberta-large\"\n",
|
| 57 |
+
"task = \"mrpc\"\n",
|
| 58 |
+
"peft_type = PeftType.PREFIX_TUNING\n",
|
| 59 |
+
"device = \"cuda\"\n",
|
| 60 |
+
"num_epochs = 20"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"cell_type": "code",
|
| 65 |
+
"execution_count": 3,
|
| 66 |
+
"id": "33d9b62e",
|
| 67 |
+
"metadata": {},
|
| 68 |
+
"outputs": [],
|
| 69 |
+
"source": [
|
| 70 |
+
"peft_config = PrefixTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20)\n",
|
| 71 |
+
"lr = 1e-2"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": 4,
|
| 77 |
+
"id": "152b6177",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"outputs": [
|
| 80 |
+
{
|
| 81 |
+
"name": "stderr",
|
| 82 |
+
"output_type": "stream",
|
| 83 |
+
"text": [
|
| 84 |
+
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"data": {
|
| 89 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 90 |
+
"model_id": "be1eddbb9a7d4e6dae32fd026e167f96",
|
| 91 |
+
"version_major": 2,
|
| 92 |
+
"version_minor": 0
|
| 93 |
+
},
|
| 94 |
+
"text/plain": [
|
| 95 |
+
" 0%| | 0/3 [00:00<?, ?it/s]"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
"metadata": {},
|
| 99 |
+
"output_type": "display_data"
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"name": "stderr",
|
| 103 |
+
"output_type": "stream",
|
| 104 |
+
"text": [
|
| 105 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"data": {
|
| 110 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 111 |
+
"model_id": "b61574844b6c499b8960fd4d78c5e549",
|
| 112 |
+
"version_major": 2,
|
| 113 |
+
"version_minor": 0
|
| 114 |
+
},
|
| 115 |
+
"text/plain": [
|
| 116 |
+
" 0%| | 0/1 [00:00<?, ?ba/s]"
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"output_type": "display_data"
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"name": "stderr",
|
| 124 |
+
"output_type": "stream",
|
| 125 |
+
"text": [
|
| 126 |
+
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7e7eacaa5160936d.arrow\n"
|
| 127 |
+
]
|
| 128 |
+
}
|
| 129 |
+
],
|
| 130 |
+
"source": [
|
| 131 |
+
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
| 132 |
+
" padding_side = \"left\"\n",
|
| 133 |
+
"else:\n",
|
| 134 |
+
" padding_side = \"right\"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
| 137 |
+
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
| 138 |
+
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"datasets = load_dataset(\"glue\", task)\n",
|
| 141 |
+
"metric = evaluate.load(\"glue\", task)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"def tokenize_function(examples):\n",
|
| 145 |
+
" # max_length=None => use the model max length (it's actually the default)\n",
|
| 146 |
+
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
| 147 |
+
" return outputs\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"tokenized_datasets = datasets.map(\n",
|
| 151 |
+
" tokenize_function,\n",
|
| 152 |
+
" batched=True,\n",
|
| 153 |
+
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
| 154 |
+
")\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
| 157 |
+
"# transformers library\n",
|
| 158 |
+
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"def collate_fn(examples):\n",
|
| 162 |
+
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"# Instantiate dataloaders.\n",
|
| 166 |
+
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
| 167 |
+
"eval_dataloader = DataLoader(\n",
|
| 168 |
+
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
| 169 |
+
")"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": null,
|
| 175 |
+
"id": "f6bc8144",
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
| 180 |
+
"model = get_peft_model(model, peft_config)\n",
|
| 181 |
+
"model.print_trainable_parameters()\n",
|
| 182 |
+
"model"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "code",
|
| 187 |
+
"execution_count": 6,
|
| 188 |
+
"id": "af41c571",
|
| 189 |
+
"metadata": {},
|
| 190 |
+
"outputs": [],
|
| 191 |
+
"source": [
|
| 192 |
+
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"# Instantiate scheduler\n",
|
| 195 |
+
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
| 196 |
+
" optimizer=optimizer,\n",
|
| 197 |
+
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
| 198 |
+
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
| 199 |
+
")"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": 7,
|
| 205 |
+
"id": "90993c93",
|
| 206 |
+
"metadata": {},
|
| 207 |
+
"outputs": [
|
| 208 |
+
{
|
| 209 |
+
"name": "stderr",
|
| 210 |
+
"output_type": "stream",
|
| 211 |
+
"text": [
|
| 212 |
+
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 213 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:29<00:00, 3.87it/s]\n",
|
| 214 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.32it/s]\n"
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
{
|
| 218 |
+
"name": "stdout",
|
| 219 |
+
"output_type": "stream",
|
| 220 |
+
"text": [
|
| 221 |
+
"epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.7876588021778584}\n"
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
"name": "stderr",
|
| 226 |
+
"output_type": "stream",
|
| 227 |
+
"text": [
|
| 228 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.42it/s]\n",
|
| 229 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.36it/s]\n"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"name": "stdout",
|
| 234 |
+
"output_type": "stream",
|
| 235 |
+
"text": [
|
| 236 |
+
"epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}\n"
|
| 237 |
+
]
|
| 238 |
+
},
|
| 239 |
+
{
|
| 240 |
+
"name": "stderr",
|
| 241 |
+
"output_type": "stream",
|
| 242 |
+
"text": [
|
| 243 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.41it/s]\n",
|
| 244 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.35it/s]\n"
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"name": "stdout",
|
| 249 |
+
"output_type": "stream",
|
| 250 |
+
"text": [
|
| 251 |
+
"epoch 2: {'accuracy': 0.8088235294117647, 'f1': 0.8717105263157895}\n"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"name": "stderr",
|
| 256 |
+
"output_type": "stream",
|
| 257 |
+
"text": [
|
| 258 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.39it/s]\n",
|
| 259 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.34it/s]\n"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"name": "stdout",
|
| 264 |
+
"output_type": "stream",
|
| 265 |
+
"text": [
|
| 266 |
+
"epoch 3: {'accuracy': 0.7549019607843137, 'f1': 0.8475609756097561}\n"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"name": "stderr",
|
| 271 |
+
"output_type": "stream",
|
| 272 |
+
"text": [
|
| 273 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:26<00:00, 4.37it/s]\n",
|
| 274 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:01<00:00, 8.34it/s]\n"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"name": "stdout",
|
| 279 |
+
"output_type": "stream",
|
| 280 |
+
"text": [
|
| 281 |
+
"epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
|
| 282 |
+
]
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"name": "stderr",
|
| 286 |
+
"output_type": "stream",
|
| 287 |
+
"text": [
|
| 288 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [00:40<00:00, 2.87it/s]\n",
|
| 289 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 1.93it/s]\n"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"name": "stdout",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"epoch 5: {'accuracy': 0.8651960784313726, 'f1': 0.9053356282271946}\n"
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"name": "stderr",
|
| 301 |
+
"output_type": "stream",
|
| 302 |
+
"text": [
|
| 303 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.01it/s]\n",
|
| 304 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:07<00:00, 1.79it/s]\n"
|
| 305 |
+
]
|
| 306 |
+
},
|
| 307 |
+
{
|
| 308 |
+
"name": "stdout",
|
| 309 |
+
"output_type": "stream",
|
| 310 |
+
"text": [
|
| 311 |
+
"epoch 6: {'accuracy': 0.8700980392156863, 'f1': 0.9065255731922399}\n"
|
| 312 |
+
]
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"name": "stderr",
|
| 316 |
+
"output_type": "stream",
|
| 317 |
+
"text": [
|
| 318 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:42<00:00, 1.12it/s]\n",
|
| 319 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.43it/s]\n"
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"name": "stdout",
|
| 324 |
+
"output_type": "stream",
|
| 325 |
+
"text": [
|
| 326 |
+
"epoch 7: {'accuracy': 0.8676470588235294, 'f1': 0.9042553191489361}\n"
|
| 327 |
+
]
|
| 328 |
+
},
|
| 329 |
+
{
|
| 330 |
+
"name": "stderr",
|
| 331 |
+
"output_type": "stream",
|
| 332 |
+
"text": [
|
| 333 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
| 334 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.45it/s]\n"
|
| 335 |
+
]
|
| 336 |
+
},
|
| 337 |
+
{
|
| 338 |
+
"name": "stdout",
|
| 339 |
+
"output_type": "stream",
|
| 340 |
+
"text": [
|
| 341 |
+
"epoch 8: {'accuracy': 0.875, 'f1': 0.9103690685413005}\n"
|
| 342 |
+
]
|
| 343 |
+
},
|
| 344 |
+
{
|
| 345 |
+
"name": "stderr",
|
| 346 |
+
"output_type": "stream",
|
| 347 |
+
"text": [
|
| 348 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.29it/s]\n",
|
| 349 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.48it/s]\n"
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
{
|
| 353 |
+
"name": "stdout",
|
| 354 |
+
"output_type": "stream",
|
| 355 |
+
"text": [
|
| 356 |
+
"epoch 9: {'accuracy': 0.8799019607843137, 'f1': 0.913884007029877}\n"
|
| 357 |
+
]
|
| 358 |
+
},
|
| 359 |
+
{
|
| 360 |
+
"name": "stderr",
|
| 361 |
+
"output_type": "stream",
|
| 362 |
+
"text": [
|
| 363 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:43<00:00, 1.11it/s]\n",
|
| 364 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 1.88it/s]\n"
|
| 365 |
+
]
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"name": "stdout",
|
| 369 |
+
"output_type": "stream",
|
| 370 |
+
"text": [
|
| 371 |
+
"epoch 10: {'accuracy': 0.8725490196078431, 'f1': 0.902621722846442}\n"
|
| 372 |
+
]
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"name": "stderr",
|
| 376 |
+
"output_type": "stream",
|
| 377 |
+
"text": [
|
| 378 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:53<00:00, 1.02it/s]\n",
|
| 379 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.02it/s]\n"
|
| 380 |
+
]
|
| 381 |
+
},
|
| 382 |
+
{
|
| 383 |
+
"name": "stdout",
|
| 384 |
+
"output_type": "stream",
|
| 385 |
+
"text": [
|
| 386 |
+
"epoch 11: {'accuracy': 0.875, 'f1': 0.9090909090909091}\n"
|
| 387 |
+
]
|
| 388 |
+
},
|
| 389 |
+
{
|
| 390 |
+
"name": "stderr",
|
| 391 |
+
"output_type": "stream",
|
| 392 |
+
"text": [
|
| 393 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:29<00:00, 1.28it/s]\n",
|
| 394 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:04<00:00, 2.65it/s]\n"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"name": "stdout",
|
| 399 |
+
"output_type": "stream",
|
| 400 |
+
"text": [
|
| 401 |
+
"epoch 12: {'accuracy': 0.8823529411764706, 'f1': 0.9139784946236559}\n"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"name": "stderr",
|
| 406 |
+
"output_type": "stream",
|
| 407 |
+
"text": [
|
| 408 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
| 409 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.46it/s]\n"
|
| 410 |
+
]
|
| 411 |
+
},
|
| 412 |
+
{
|
| 413 |
+
"name": "stdout",
|
| 414 |
+
"output_type": "stream",
|
| 415 |
+
"text": [
|
| 416 |
+
"epoch 13: {'accuracy': 0.8602941176470589, 'f1': 0.9018932874354562}\n"
|
| 417 |
+
]
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"name": "stderr",
|
| 421 |
+
"output_type": "stream",
|
| 422 |
+
"text": [
|
| 423 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
| 424 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββ| 13/13 [00:05<00:00, 2.47it/s]\n"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"name": "stdout",
|
| 429 |
+
"output_type": "stream",
|
| 430 |
+
"text": [
|
| 431 |
+
"epoch 14: {'accuracy': 0.8700980392156863, 'f1': 0.9075043630017452}\n"
|
| 432 |
+
]
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
"name": "stderr",
|
| 436 |
+
"output_type": "stream",
|
| 437 |
+
"text": [
|
| 438 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
| 439 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.49it/s]\n"
|
| 440 |
+
]
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"name": "stdout",
|
| 444 |
+
"output_type": "stream",
|
| 445 |
+
"text": [
|
| 446 |
+
"epoch 15: {'accuracy': 0.875, 'f1': 0.9087656529516995}\n"
|
| 447 |
+
]
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"name": "stderr",
|
| 451 |
+
"output_type": "stream",
|
| 452 |
+
"text": [
|
| 453 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.32it/s]\n",
|
| 454 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.49it/s]\n"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"name": "stdout",
|
| 459 |
+
"output_type": "stream",
|
| 460 |
+
"text": [
|
| 461 |
+
"epoch 16: {'accuracy': 0.8578431372549019, 'f1': 0.9003436426116839}\n"
|
| 462 |
+
]
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"name": "stderr",
|
| 466 |
+
"output_type": "stream",
|
| 467 |
+
"text": [
|
| 468 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.31it/s]\n",
|
| 469 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.22it/s]\n"
|
| 470 |
+
]
|
| 471 |
+
},
|
| 472 |
+
{
|
| 473 |
+
"name": "stdout",
|
| 474 |
+
"output_type": "stream",
|
| 475 |
+
"text": [
|
| 476 |
+
"epoch 17: {'accuracy': 0.8627450980392157, 'f1': 0.903448275862069}\n"
|
| 477 |
+
]
|
| 478 |
+
},
|
| 479 |
+
{
|
| 480 |
+
"name": "stderr",
|
| 481 |
+
"output_type": "stream",
|
| 482 |
+
"text": [
|
| 483 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:28<00:00, 1.31it/s]\n",
|
| 484 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:04<00:00, 2.65it/s]\n"
|
| 485 |
+
]
|
| 486 |
+
},
|
| 487 |
+
{
|
| 488 |
+
"name": "stdout",
|
| 489 |
+
"output_type": "stream",
|
| 490 |
+
"text": [
|
| 491 |
+
"epoch 18: {'accuracy': 0.8700980392156863, 'f1': 0.9078260869565218}\n"
|
| 492 |
+
]
|
| 493 |
+
},
|
| 494 |
+
{
|
| 495 |
+
"name": "stderr",
|
| 496 |
+
"output_type": "stream",
|
| 497 |
+
"text": [
|
| 498 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 115/115 [01:27<00:00, 1.32it/s]\n",
|
| 499 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:05<00:00, 2.45it/s]"
|
| 500 |
+
]
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"name": "stdout",
|
| 504 |
+
"output_type": "stream",
|
| 505 |
+
"text": [
|
| 506 |
+
"epoch 19: {'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
|
| 507 |
+
]
|
| 508 |
+
},
|
| 509 |
+
{
|
| 510 |
+
"name": "stderr",
|
| 511 |
+
"output_type": "stream",
|
| 512 |
+
"text": [
|
| 513 |
+
"\n"
|
| 514 |
+
]
|
| 515 |
+
}
|
| 516 |
+
],
|
| 517 |
+
"source": [
|
| 518 |
+
"model.to(device)\n",
|
| 519 |
+
"for epoch in range(num_epochs):\n",
|
| 520 |
+
" model.train()\n",
|
| 521 |
+
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
| 522 |
+
" batch.to(device)\n",
|
| 523 |
+
" outputs = model(**batch)\n",
|
| 524 |
+
" loss = outputs.loss\n",
|
| 525 |
+
" loss.backward()\n",
|
| 526 |
+
" optimizer.step()\n",
|
| 527 |
+
" lr_scheduler.step()\n",
|
| 528 |
+
" optimizer.zero_grad()\n",
|
| 529 |
+
"\n",
|
| 530 |
+
" model.eval()\n",
|
| 531 |
+
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 532 |
+
" batch.to(device)\n",
|
| 533 |
+
" with torch.no_grad():\n",
|
| 534 |
+
" outputs = model(**batch)\n",
|
| 535 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 536 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 537 |
+
" metric.add_batch(\n",
|
| 538 |
+
" predictions=predictions,\n",
|
| 539 |
+
" references=references,\n",
|
| 540 |
+
" )\n",
|
| 541 |
+
"\n",
|
| 542 |
+
" eval_metric = metric.compute()\n",
|
| 543 |
+
" print(f\"epoch {epoch}:\", eval_metric)"
|
| 544 |
+
]
|
| 545 |
+
},
|
| 546 |
+
{
|
| 547 |
+
"cell_type": "markdown",
|
| 548 |
+
"id": "7734299c",
|
| 549 |
+
"metadata": {},
|
| 550 |
+
"source": [
|
| 551 |
+
"## Share adapters on the π€ Hub"
|
| 552 |
+
]
|
| 553 |
+
},
|
| 554 |
+
{
|
| 555 |
+
"cell_type": "code",
|
| 556 |
+
"execution_count": 8,
|
| 557 |
+
"id": "afaf42dd",
|
| 558 |
+
"metadata": {},
|
| 559 |
+
"outputs": [
|
| 560 |
+
{
|
| 561 |
+
"data": {
|
| 562 |
+
"text/plain": [
|
| 563 |
+
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prefix-tuning/commit/a00e05a4c9a68e700221784f8e073c2e194637c3', commit_message='Upload model', commit_description='', oid='a00e05a4c9a68e700221784f8e073c2e194637c3', pr_url=None, pr_revision=None, pr_num=None)"
|
| 564 |
+
]
|
| 565 |
+
},
|
| 566 |
+
"execution_count": 8,
|
| 567 |
+
"metadata": {},
|
| 568 |
+
"output_type": "execute_result"
|
| 569 |
+
}
|
| 570 |
+
],
|
| 571 |
+
"source": [
|
| 572 |
+
"model.push_to_hub(\"smangrul/roberta-large-peft-prefix-tuning\", use_auth_token=True)"
|
| 573 |
+
]
|
| 574 |
+
},
|
| 575 |
+
{
|
| 576 |
+
"cell_type": "markdown",
|
| 577 |
+
"id": "42b20e77",
|
| 578 |
+
"metadata": {},
|
| 579 |
+
"source": [
|
| 580 |
+
"## Load adapters from the Hub\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"You can also directly load adapters from the Hub using the commands below:"
|
| 583 |
+
]
|
| 584 |
+
},
|
| 585 |
+
{
|
| 586 |
+
"cell_type": "code",
|
| 587 |
+
"execution_count": 9,
|
| 588 |
+
"id": "868e7580",
|
| 589 |
+
"metadata": {},
|
| 590 |
+
"outputs": [
|
| 591 |
+
{
|
| 592 |
+
"data": {
|
| 593 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 594 |
+
"model_id": "2ce57b4de8ae4f868115733abc2fb883",
|
| 595 |
+
"version_major": 2,
|
| 596 |
+
"version_minor": 0
|
| 597 |
+
},
|
| 598 |
+
"text/plain": [
|
| 599 |
+
"Downloading: 0%| | 0.00/373 [00:00<?, ?B/s]"
|
| 600 |
+
]
|
| 601 |
+
},
|
| 602 |
+
"metadata": {},
|
| 603 |
+
"output_type": "display_data"
|
| 604 |
+
},
|
| 605 |
+
{
|
| 606 |
+
"name": "stderr",
|
| 607 |
+
"output_type": "stream",
|
| 608 |
+
"text": [
|
| 609 |
+
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
|
| 610 |
+
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
| 611 |
+
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
| 612 |
+
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
|
| 613 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
| 614 |
+
]
|
| 615 |
+
},
|
| 616 |
+
{
|
| 617 |
+
"data": {
|
| 618 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 619 |
+
"model_id": "ace158c926a44b31a9b0ea80411bd7a9",
|
| 620 |
+
"version_major": 2,
|
| 621 |
+
"version_minor": 0
|
| 622 |
+
},
|
| 623 |
+
"text/plain": [
|
| 624 |
+
"Downloading: 0%| | 0.00/8.14M [00:00<?, ?B/s]"
|
| 625 |
+
]
|
| 626 |
+
},
|
| 627 |
+
"metadata": {},
|
| 628 |
+
"output_type": "display_data"
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"name": "stderr",
|
| 632 |
+
"output_type": "stream",
|
| 633 |
+
"text": [
|
| 634 |
+
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
| 635 |
+
"100%|ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ| 13/13 [00:06<00:00, 2.04it/s]"
|
| 636 |
+
]
|
| 637 |
+
},
|
| 638 |
+
{
|
| 639 |
+
"name": "stdout",
|
| 640 |
+
"output_type": "stream",
|
| 641 |
+
"text": [
|
| 642 |
+
"{'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"name": "stderr",
|
| 647 |
+
"output_type": "stream",
|
| 648 |
+
"text": [
|
| 649 |
+
"\n"
|
| 650 |
+
]
|
| 651 |
+
}
|
| 652 |
+
],
|
| 653 |
+
"source": [
|
| 654 |
+
"import torch\n",
|
| 655 |
+
"from peft import PeftModel, PeftConfig\n",
|
| 656 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"peft_model_id = \"smangrul/roberta-large-peft-prefix-tuning\"\n",
|
| 659 |
+
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
| 660 |
+
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
| 661 |
+
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
| 662 |
+
"\n",
|
| 663 |
+
"# Load the Lora model\n",
|
| 664 |
+
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
| 665 |
+
"\n",
|
| 666 |
+
"inference_model.to(device)\n",
|
| 667 |
+
"inference_model.eval()\n",
|
| 668 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
| 669 |
+
" batch.to(device)\n",
|
| 670 |
+
" with torch.no_grad():\n",
|
| 671 |
+
" outputs = inference_model(**batch)\n",
|
| 672 |
+
" predictions = outputs.logits.argmax(dim=-1)\n",
|
| 673 |
+
" predictions, references = predictions, batch[\"labels\"]\n",
|
| 674 |
+
" metric.add_batch(\n",
|
| 675 |
+
" predictions=predictions,\n",
|
| 676 |
+
" references=references,\n",
|
| 677 |
+
" )\n",
|
| 678 |
+
"\n",
|
| 679 |
+
"eval_metric = metric.compute()\n",
|
| 680 |
+
"print(eval_metric)"
|
| 681 |
+
]
|
| 682 |
+
}
|
| 683 |
+
],
|
| 684 |
+
"metadata": {
|
| 685 |
+
"kernelspec": {
|
| 686 |
+
"display_name": "Python 3 (ipykernel)",
|
| 687 |
+
"language": "python",
|
| 688 |
+
"name": "python3"
|
| 689 |
+
},
|
| 690 |
+
"language_info": {
|
| 691 |
+
"codemirror_mode": {
|
| 692 |
+
"name": "ipython",
|
| 693 |
+
"version": 3
|
| 694 |
+
},
|
| 695 |
+
"file_extension": ".py",
|
| 696 |
+
"mimetype": "text/x-python",
|
| 697 |
+
"name": "python",
|
| 698 |
+
"nbconvert_exporter": "python",
|
| 699 |
+
"pygments_lexer": "ipython3",
|
| 700 |
+
"version": "3.10.5 (v3.10.5:f377153967, Jun 6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
|
| 701 |
+
},
|
| 702 |
+
"vscode": {
|
| 703 |
+
"interpreter": {
|
| 704 |
+
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
| 705 |
+
}
|
| 706 |
+
}
|
| 707 |
+
},
|
| 708 |
+
"nbformat": 4,
|
| 709 |
+
"nbformat_minor": 5
|
| 710 |
+
}
|