# Copyright (c) OpenMMLab. All rights reserved.
# Partly adopted from https://github.com/GT-Vision-Lab/VQA
# Copyright (c) 2014, Aishwarya Agrawal
import re
from vlmeval.smp import *
from typing import Optional
from functools import partial

def _process_digit_article(inText):
    outText = []
    tempText = inText.lower().split()
    articles = ['a', 'an', 'the']
    manualMap = {
        'none': '0',
        'zero': '0',
        'one': '1',
        'two': '2',
        'three': '3',
        'four': '4',
        'five': '5',
        'six': '6',
        'seven': '7',
        'eight': '8',
        'nine': '9',
        'ten': '10',
    }
    # NB: the keys containing capital letters ("Id've", "I'dve", 'Im', 'Ive')
    # can never match, since the input is lower-cased above; they are kept for
    # fidelity with the original VQA evaluation code.
    contractions = {
        'aint': "ain't",
        'arent': "aren't",
        'cant': "can't",
        'couldve': "could've",
        'couldnt': "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        'didnt': "didn't",
        'doesnt': "doesn't",
        'dont': "don't",
        'hadnt': "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        'hasnt': "hasn't",
        'havent': "haven't",
        'hed': "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        'hes': "he's",
        'howd': "how'd",
        'howll': "how'll",
        'hows': "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        'Im': "I'm",
        'Ive': "I've",
        'isnt': "isn't",
        'itd': "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        'itll': "it'll",
        'lets': "let's",
        'maam': "ma'am",
        'mightnt': "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        'mightve': "might've",
        'mustnt': "mustn't",
        'mustve': "must've",
        'neednt': "needn't",
        'notve': "not've",
        'oclock': "o'clock",
        'oughtnt': "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        'shant': "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        'shes': "she's",
        'shouldve': "should've",
        'shouldnt': "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        'somebodyd': "somebody'd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        'somebodyll': "somebody'll",
        'somebodys': "somebody's",
        'someoned': "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        'someonell': "someone'll",
        'someones': "someone's",
        'somethingd': "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        'somethingll': "something'll",
        'thats': "that's",
        'thered': "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        'therere': "there're",
        'theres': "there's",
        'theyd': "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        'theyll': "they'll",
        'theyre': "they're",
        'theyve': "they've",
        'twas': "'twas",
        'wasnt': "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        'weve': "we've",
        'werent': "weren't",
        'whatll': "what'll",
        'whatre': "what're",
        'whats': "what's",
        'whatve': "what've",
        'whens': "when's",
        'whered': "where'd",
        'wheres': "where's",
        'whereve': "where've",
        'whod': "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        'wholl': "who'll",
        'whos': "who's",
        'whove': "who've",
        'whyll': "why'll",
        'whyre': "why're",
        'whys': "why's",
        'wont': "won't",
        'wouldve': "would've",
        'wouldnt': "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        'yall': "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        'youd': "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        'youll': "you'll",
        'youre': "you're",
        'youve': "you've",
    }
    for word in tempText:
        word = manualMap.setdefault(word, word)
        if word not in articles:
            outText.append(word)
    for wordId, word in enumerate(outText):
        if word in contractions:
            outText[wordId] = contractions[word]
    outText = ' '.join(outText)
    return outText
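# Illustrative behavior of the normalization above (inputs chosen for
# demonstration, not taken from the datasets):
#   _process_digit_article('the two dogs')  -> '2 dogs'        (number word mapped, article dropped)
#   _process_digit_article('he cant see')   -> "he can't see"  (contraction restored)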

def hit_calculate(result, dataset_name, anls_threshold=0.5):
    if listinstr(['TextVQA'], dataset_name):
        return [np.mean(x['match']) for x in result]
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        # For ANLS, x['match'] stores normalized edit distances; the sample
        # score is 1 - min(distance), floored to 0 below the threshold.
        return [0.0 if 1 - np.min(x['match']) < anls_threshold else 1 - np.min(x['match']) for x in result]
    elif listinstr(['ChartQA', 'OCRVQA'], dataset_name):
        return [np.max(x['match']) for x in result]
    else:  # default: use vqa_score
        return [np.mean(x['match']) for x in result]
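# A small worked example of the aggregation above: for a TextVQA-style record
# whose per-answer matches are [1.0, 1.0, 0.0], the sample score is their mean
# (~0.67); for ChartQA/OCRVQA the max over reference answers is taken instead,
# so a single exact match among the references yields 1.0.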

# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(target: str,
                        prediction: str,
                        max_relative_change: float = 0.05) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    "Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct."

    Args:
        target: Target string.
        prediction: Predicted string.
        max_relative_change: Maximum relative change.

    Returns:
        Whether the prediction was correct given the specified tolerance.
    """
    def _to_float(text: str) -> Optional[float]:
        try:
            if text.endswith('%'):
                # Convert percentages to floats.
                return float(text.rstrip('%')) / 100.0
            else:
                return float(text)
        except ValueError:
            return None
    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    # The truthiness check on target_float is deliberate: a zero (or
    # non-numeric) target would make the relative change undefined, so those
    # cases fall back to case-insensitive string matching.
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float - target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction.lower() == target.lower()
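# Worked examples for the 5% tolerance (values chosen for illustration):
#   relaxed_correctness('100', '96')  -> True   (|96 - 100| / 100 = 0.04 <= 0.05)
#   relaxed_correctness('100', '94')  -> False  (0.06 > 0.05)
#   relaxed_correctness('cat', 'Cat') -> True   (non-numeric: case-insensitive exact match)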

def levenshtein_distance(s1, s2):
    # Single-row dynamic-programming edit distance.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
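# Sanity check (the classic example): levenshtein_distance('kitten', 'sitting') == 3
# (substitute k->s, substitute e->i, append g).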

def anls_compute(groundtruth, prediction):
    gt_answer = ' '.join(groundtruth.strip().lower().split())
    det_answer = ' '.join(prediction.strip().lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    # Normalize by the length of the longer of the two raw strings.
    length = max(len(groundtruth), len(prediction))
    value = 0.0 if length == 0 else float(dist) / float(length)
    return value
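# Worked example: anls_compute('forty two', 'fourty two') gives an edit
# distance of 1 over a max length of 10, i.e. a normalized distance of 0.1;
# hit_calculate later converts this into a similarity of 1 - 0.1 = 0.9.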

def process_answer(answer):
    answer = answer.replace('\n', ' ')
    answer = answer.replace('\t', ' ')
    answer = answer.strip()
    answer = process_punctuation(answer)
    answer = _process_digit_article(answer)
    return answer
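# Sketch of the full pipeline, assuming process_punctuation (imported from
# vlmeval.smp) strips the trailing period per the standard VQA rules:
#   'The\tTwo Dogs.' -> 'The Two Dogs.' -> 'The Two Dogs' -> '2 dogs'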

def process_line(line, method='vqa_score'):
    ret = {}
    if istype(line['answer'], list):
        answers = eval(line['answer'])
    else:
        answers = [line['answer']]
    if method == 'vqa_score':
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = []
        for current_idx, gtAnsDatum in enumerate(ret['gt']):
            otherGTAns = [
                item for ret_gt_idx, item in enumerate(ret['gt'])
                if ret_gt_idx != current_idx
            ]
            matchingAns = [
                item for item in otherGTAns if item == ret['pred']
            ]
            # Standard VQA accuracy: fully correct if at least 3 of the
            # remaining annotators gave the same answer.
            acc = min(1, float(len(matchingAns)) / 3)
            ret['match'].append(acc)
    elif method == 'anls':
        ret['gt'] = answers
        ret['pred'] = line['prediction']
        ret['match'] = [anls_compute(x, ret['pred']) for x in ret['gt']]
    elif method == 'relaxed_accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        # relaxed_correctness takes (target, prediction); pass the gold answer
        # as target so the 5% tolerance is measured relative to it.
        ret['match'] = [relaxed_correctness(x, ret['pred']) for x in ret['gt']]
    elif method == 'accuracy':
        ret['gt'] = answers
        ret['pred'] = line['prediction'].strip()
        ret['match'] = [(1.0 if x.strip().lower() == ret['pred'].strip().lower() else 0.0) for x in ret['gt']]
    else:  # default: use vqa_score
        ret['gt'] = [process_answer(x) for x in answers]
        ret['pred'] = process_answer(line['prediction'])
        ret['match'] = [x == ret['pred'] for x in ret['gt']]
    return ret
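# Worked example of the vqa_score branch (numbers chosen for illustration):
# with 10 reference answers of which 2 are 'cat' and a prediction of 'cat',
# each 'cat' reference sees 1 matching answer among the other 9 (acc 1/3) and
# each other reference sees 2 (acc 2/3), so the mean match is
# (2 * 1/3 + 8 * 2/3) / 10 = 0.6.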

def VQAEval(eval_file, dataset_name, **kwargs):
    logger = get_logger('Evaluation')
    data = load(eval_file)
    assert 'answer' in data and 'prediction' in data
    data['prediction'] = [str(x) for x in data['prediction']]
    data['answer'] = [str(x) for x in data['answer']]
    lt = len(data)
    pool = mp.Pool(16)
    lines = [data.iloc[i] for i in range(lt)]
    if listinstr(['TextVQA'], dataset_name):
        res = pool.map(partial(process_line, method='vqa_score'), lines)
    elif listinstr(['ChartQA'], dataset_name):
        res = pool.map(partial(process_line, method='relaxed_accuracy'), lines)
    elif listinstr(['OCRVQA'], dataset_name):
        res = pool.map(partial(process_line, method='accuracy'), lines)
    elif listinstr(['DocVQA', 'InfoVQA'], dataset_name):
        res = pool.map(partial(process_line, method='anls'), lines)
    else:  # default: use vqa_score
        res = pool.map(process_line, lines)
    pool.close()
    pool.join()
    hit = hit_calculate(res, dataset_name)
    ret = dict()
    if 'split' in data:
        splits = set(data['split'])
        for sp in splits:
            sub = [r for l, r in zip(lines, res) if l['split'] == sp]
            hit = hit_calculate(sub, dataset_name)
            ret[sp] = np.mean(hit) * 100
        hit = hit_calculate(res, dataset_name)
        ret['Overall'] = np.mean(hit) * 100
    else:
        ret['Overall'] = np.mean(hit) * 100
        if 'category' in data:
            cates = list(set(data['category']))
            cates.sort()
            for c in cates:
                sub = [r for l, r in zip(lines, res) if l['category'] == c]
                hit = hit_calculate(sub, dataset_name)
                ret[c] = np.mean(hit) * 100
    ret = d2df(ret)
    ret = ret.round(2)
    suffix = eval_file.split('.')[-1]
    result_file = eval_file.replace(f'.{suffix}', '_acc.csv')
    dump(ret, result_file)
    logger.info(f'VQA Eval Finished. Saved to {result_file}.')
    logger.info(ret)
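# Minimal usage sketch (the file name below is hypothetical): the eval file
# must carry 'answer' and 'prediction' columns; per-split / per-category
# accuracies are written next to it as '<name>_acc.csv'.
#   VQAEval('TextVQA_predictions.xlsx', 'TextVQA')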