You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
11 KiB
331 lines
11 KiB
![]()
2 years ago
|
"""Copyright (c) 2022, salesforce.com, inc.
|
||
|
|
||
|
All rights reserved.
|
||
|
SPDX-License-Identifier: BSD-3-Clause
|
||
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||
|
"""
|
||
|
|
||
|
# coding=utf-8
|
||
|
|
||
|
__author__ = 'aagrawal'
|
||
|
|
||
|
import re
|
||
|
# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
|
||
|
# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
|
||
|
import sys
|
||
|
|
||
|
|
||
|
class VQAEval:
|
||
|
|
||
|
def __init__(self, vqa=None, vqaRes=None, n=2):
|
||
|
self.n = n
|
||
|
self.accuracy = {}
|
||
|
self.evalQA = {}
|
||
|
self.evalQuesType = {}
|
||
|
self.evalAnsType = {}
|
||
|
self.vqa = vqa
|
||
|
self.vqaRes = vqaRes
|
||
|
if vqa is not None:
|
||
|
self.params = {'question_id': vqa.getQuesIds()}
|
||
|
self.contractions = {
|
||
|
'aint': "ain't",
|
||
|
'arent': "aren't",
|
||
|
'cant': "can't",
|
||
|
'couldve': "could've",
|
||
|
'couldnt': "couldn't",
|
||
|
"couldn'tve": "couldn't've",
|
||
|
"couldnt've": "couldn't've",
|
||
|
'didnt': "didn't",
|
||
|
'doesnt': "doesn't",
|
||
|
'dont': "don't",
|
||
|
'hadnt': "hadn't",
|
||
|
"hadnt've": "hadn't've",
|
||
|
"hadn'tve": "hadn't've",
|
||
|
'hasnt': "hasn't",
|
||
|
'havent': "haven't",
|
||
|
'hed': "he'd",
|
||
|
"hed've": "he'd've",
|
||
|
"he'dve": "he'd've",
|
||
|
'hes': "he's",
|
||
|
'howd': "how'd",
|
||
|
'howll': "how'll",
|
||
|
'hows': "how's",
|
||
|
"Id've": "I'd've",
|
||
|
"I'dve": "I'd've",
|
||
|
'Im': "I'm",
|
||
|
'Ive': "I've",
|
||
|
'isnt': "isn't",
|
||
|
'itd': "it'd",
|
||
|
"itd've": "it'd've",
|
||
|
"it'dve": "it'd've",
|
||
|
'itll': "it'll",
|
||
|
"let's": "let's",
|
||
|
'maam': "ma'am",
|
||
|
'mightnt': "mightn't",
|
||
|
"mightnt've": "mightn't've",
|
||
|
"mightn'tve": "mightn't've",
|
||
|
'mightve': "might've",
|
||
|
'mustnt': "mustn't",
|
||
|
'mustve': "must've",
|
||
|
'neednt': "needn't",
|
||
|
'notve': "not've",
|
||
|
'oclock': "o'clock",
|
||
|
'oughtnt': "oughtn't",
|
||
|
"ow's'at": "'ow's'at",
|
||
|
"'ows'at": "'ow's'at",
|
||
|
"'ow'sat": "'ow's'at",
|
||
|
'shant': "shan't",
|
||
|
"shed've": "she'd've",
|
||
|
"she'dve": "she'd've",
|
||
|
"she's": "she's",
|
||
|
'shouldve': "should've",
|
||
|
'shouldnt': "shouldn't",
|
||
|
"shouldnt've": "shouldn't've",
|
||
|
"shouldn'tve": "shouldn't've",
|
||
|
"somebody'd": 'somebodyd',
|
||
|
"somebodyd've": "somebody'd've",
|
||
|
"somebody'dve": "somebody'd've",
|
||
|
'somebodyll': "somebody'll",
|
||
|
'somebodys': "somebody's",
|
||
|
'someoned': "someone'd",
|
||
|
"someoned've": "someone'd've",
|
||
|
"someone'dve": "someone'd've",
|
||
|
'someonell': "someone'll",
|
||
|
'someones': "someone's",
|
||
|
'somethingd': "something'd",
|
||
|
"somethingd've": "something'd've",
|
||
|
"something'dve": "something'd've",
|
||
|
'somethingll': "something'll",
|
||
|
'thats': "that's",
|
||
|
'thered': "there'd",
|
||
|
"thered've": "there'd've",
|
||
|
"there'dve": "there'd've",
|
||
|
'therere': "there're",
|
||
|
'theres': "there's",
|
||
|
'theyd': "they'd",
|
||
|
"theyd've": "they'd've",
|
||
|
"they'dve": "they'd've",
|
||
|
'theyll': "they'll",
|
||
|
'theyre': "they're",
|
||
|
'theyve': "they've",
|
||
|
'twas': "'twas",
|
||
|
'wasnt': "wasn't",
|
||
|
"wed've": "we'd've",
|
||
|
"we'dve": "we'd've",
|
||
|
'weve': "we've",
|
||
|
'werent': "weren't",
|
||
|
'whatll': "what'll",
|
||
|
'whatre': "what're",
|
||
|
'whats': "what's",
|
||
|
'whatve': "what've",
|
||
|
'whens': "when's",
|
||
|
'whered': "where'd",
|
||
|
'wheres': "where's",
|
||
|
'whereve': "where've",
|
||
|
'whod': "who'd",
|
||
|
"whod've": "who'd've",
|
||
|
"who'dve": "who'd've",
|
||
|
'wholl': "who'll",
|
||
|
'whos': "who's",
|
||
|
'whove': "who've",
|
||
|
'whyll': "why'll",
|
||
|
'whyre': "why're",
|
||
|
'whys': "why's",
|
||
|
'wont': "won't",
|
||
|
'wouldve': "would've",
|
||
|
'wouldnt': "wouldn't",
|
||
|
"wouldnt've": "wouldn't've",
|
||
|
"wouldn'tve": "wouldn't've",
|
||
|
'yall': "y'all",
|
||
|
"yall'll": "y'all'll",
|
||
|
"y'allll": "y'all'll",
|
||
|
"yall'd've": "y'all'd've",
|
||
|
"y'alld've": "y'all'd've",
|
||
|
"y'all'dve": "y'all'd've",
|
||
|
'youd': "you'd",
|
||
|
"youd've": "you'd've",
|
||
|
"you'dve": "you'd've",
|
||
|
'youll': "you'll",
|
||
|
'youre': "you're",
|
||
|
'youve': "you've",
|
||
|
}
|
||
|
self.manualMap = {
|
||
|
'none': '0',
|
||
|
'zero': '0',
|
||
|
'one': '1',
|
||
|
'two': '2',
|
||
|
'three': '3',
|
||
|
'four': '4',
|
||
|
'five': '5',
|
||
|
'six': '6',
|
||
|
'seven': '7',
|
||
|
'eight': '8',
|
||
|
'nine': '9',
|
||
|
'ten': '10',
|
||
|
}
|
||
|
self.articles = ['a', 'an', 'the']
|
||
|
|
||
|
self.periodStrip = re.compile('(?!<=\d)(\.)(?!\d)')
|
||
|
self.commaStrip = re.compile('(\d)(,)(\d)')
|
||
|
self.punct = [
|
||
|
';',
|
||
|
r'/',
|
||
|
'[',
|
||
|
']',
|
||
|
'"',
|
||
|
'{',
|
||
|
'}',
|
||
|
'(',
|
||
|
')',
|
||
|
'=',
|
||
|
'+',
|
||
|
'\\',
|
||
|
'_',
|
||
|
'-',
|
||
|
'>',
|
||
|
'<',
|
||
|
'@',
|
||
|
'`',
|
||
|
',',
|
||
|
'?',
|
||
|
'!',
|
||
|
]
|
||
|
|
||
|
def evaluate(self, quesIds=None):
|
||
|
if quesIds == None:
|
||
|
quesIds = [quesId for quesId in self.params['question_id']]
|
||
|
gts = {}
|
||
|
res = {}
|
||
|
for quesId in quesIds:
|
||
|
gts[quesId] = self.vqa.qa[quesId]
|
||
|
res[quesId] = self.vqaRes.qa[quesId]
|
||
|
|
||
|
# =================================================
|
||
|
# Compute accuracy
|
||
|
# =================================================
|
||
|
accQA = []
|
||
|
accQuesType = {}
|
||
|
accAnsType = {}
|
||
|
print('computing accuracy')
|
||
|
step = 0
|
||
|
for quesId in quesIds:
|
||
|
resAns = res[quesId]['answer']
|
||
|
resAns = resAns.replace('\n', ' ')
|
||
|
resAns = resAns.replace('\t', ' ')
|
||
|
resAns = resAns.strip()
|
||
|
resAns = self.processPunctuation(resAns)
|
||
|
resAns = self.processDigitArticle(resAns)
|
||
|
gtAcc = []
|
||
|
gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
|
||
|
if len(set(gtAnswers)) > 1:
|
||
|
for ansDic in gts[quesId]['answers']:
|
||
|
ansDic['answer'] = self.processPunctuation(
|
||
|
ansDic['answer'])
|
||
|
for gtAnsDatum in gts[quesId]['answers']:
|
||
|
otherGTAns = [
|
||
|
item for item in gts[quesId]['answers']
|
||
|
if item != gtAnsDatum
|
||
|
]
|
||
|
matchingAns = [
|
||
|
item for item in otherGTAns if item['answer'] == resAns
|
||
|
]
|
||
|
acc = min(1, float(len(matchingAns)) / 3)
|
||
|
gtAcc.append(acc)
|
||
|
quesType = gts[quesId]['question_type']
|
||
|
ansType = gts[quesId]['answer_type']
|
||
|
avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
|
||
|
accQA.append(avgGTAcc)
|
||
|
if quesType not in accQuesType:
|
||
|
accQuesType[quesType] = []
|
||
|
accQuesType[quesType].append(avgGTAcc)
|
||
|
if ansType not in accAnsType:
|
||
|
accAnsType[ansType] = []
|
||
|
accAnsType[ansType].append(avgGTAcc)
|
||
|
self.setEvalQA(quesId, avgGTAcc)
|
||
|
self.setEvalQuesType(quesId, quesType, avgGTAcc)
|
||
|
self.setEvalAnsType(quesId, ansType, avgGTAcc)
|
||
|
if step % 100 == 0:
|
||
|
self.updateProgress(step / float(len(quesIds)))
|
||
|
step = step + 1
|
||
|
|
||
|
self.setAccuracy(accQA, accQuesType, accAnsType)
|
||
|
print('Done computing accuracy')
|
||
|
|
||
|
def processPunctuation(self, inText):
|
||
|
outText = inText
|
||
|
for p in self.punct:
|
||
|
if (p + ' ' in inText or ' ' + p
|
||
|
in inText) or (re.search(self.commaStrip, inText) != None):
|
||
|
outText = outText.replace(p, '')
|
||
|
else:
|
||
|
outText = outText.replace(p, ' ')
|
||
|
outText = self.periodStrip.sub('', outText, re.UNICODE)
|
||
|
return outText
|
||
|
|
||
|
def processDigitArticle(self, inText):
|
||
|
outText = []
|
||
|
tempText = inText.lower().split()
|
||
|
for word in tempText:
|
||
|
word = self.manualMap.setdefault(word, word)
|
||
|
if word not in self.articles:
|
||
|
outText.append(word)
|
||
|
else:
|
||
|
pass
|
||
|
for wordId, word in enumerate(outText):
|
||
|
if word in self.contractions:
|
||
|
outText[wordId] = self.contractions[word]
|
||
|
outText = ' '.join(outText)
|
||
|
return outText
|
||
|
|
||
|
def setAccuracy(self, accQA, accQuesType, accAnsType):
|
||
|
self.accuracy['overall'] = round(100 * float(sum(accQA)) / len(accQA),
|
||
|
self.n)
|
||
|
self.accuracy['perQuestionType'] = {
|
||
|
quesType: round(
|
||
|
100 * float(sum(accQuesType[quesType])) /
|
||
|
len(accQuesType[quesType]),
|
||
|
self.n,
|
||
|
)
|
||
|
for quesType in accQuesType
|
||
|
}
|
||
|
self.accuracy['perAnswerType'] = {
|
||
|
ansType: round(
|
||
|
100 * float(sum(accAnsType[ansType])) /
|
||
|
len(accAnsType[ansType]), self.n)
|
||
|
for ansType in accAnsType
|
||
|
}
|
||
|
|
||
|
def setEvalQA(self, quesId, acc):
|
||
|
self.evalQA[quesId] = round(100 * acc, self.n)
|
||
|
|
||
|
def setEvalQuesType(self, quesId, quesType, acc):
|
||
|
if quesType not in self.evalQuesType:
|
||
|
self.evalQuesType[quesType] = {}
|
||
|
self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
|
||
|
|
||
|
def setEvalAnsType(self, quesId, ansType, acc):
|
||
|
if ansType not in self.evalAnsType:
|
||
|
self.evalAnsType[ansType] = {}
|
||
|
self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
|
||
|
|
||
|
def updateProgress(self, progress):
|
||
|
barLength = 20
|
||
|
status = ''
|
||
|
if isinstance(progress, int):
|
||
|
progress = float(progress)
|
||
|
if not isinstance(progress, float):
|
||
|
progress = 0
|
||
|
status = 'error: progress var must be float\r\n'
|
||
|
if progress < 0:
|
||
|
progress = 0
|
||
|
status = 'Halt...\r\n'
|
||
|
if progress >= 1:
|
||
|
progress = 1
|
||
|
status = 'Done...\r\n'
|
||
|
block = int(round(barLength * progress))
|
||
|
text = '\rFinshed Percent: [{0}] {1}% {2}'.format(
|
||
|
'#' * block + '-' * (barLength - block), int(progress * 100),
|
||
|
status)
|
||
|
sys.stdout.write(text)
|
||
|
sys.stdout.flush()
|