# This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction

import os, json
import argparse

question_ids_to_exclude = []

# answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'}
answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span'}
evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout', 'figure': 'Figure', 'map': 'Map'}
reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting': 'Counting'}


def save_json(file_path, data):
    with open(file_path, 'w+') as json_file:
        json.dump(data, json_file)


def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
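
# Illustrative note (added; not in the original script): levenshtein_distance counts the
# single-character edits needed to turn one string into the other, e.g. "kitten" -> "sitting"
# needs 3 edits. evaluate_method below divides this count by the length of the longer answer
# to obtain the normalized Levenshtein distance used by ANLS.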


def validate_data(gtFilePath, submFilePath):
    """
    Method validate_data: validates that the ground truth and submission files are correct
    (have the required keys and contents) and that no answers are missing.
    If some error is detected, the method raises an exception.
    """

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    if 'data' not in gtJson:
        raise Exception("The GT file is not valid (no data key)")

    if 'dataset_name' not in gtJson:
        raise Exception("The GT file is not valid (no dataset_name key)")

    if not isinstance(submJson, list):
        raise Exception("The Det file is not valid (root item must be an array)")

    if len(submJson) != len(gtJson['data']):
        raise Exception("The Det file is not valid (invalid number of answers. Expected:" + str(len(gtJson['data'])) + " Found:" + str(len(submJson)) + ")")

    gtQuestions = sorted([r['questionId'] for r in gtJson['data']])
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
    detQuestions = sorted([r['questionId'] for r in submJson])

    if gtQuestions != detQuestions:
        raise Exception("The Det file is not valid. Question IDs must match GT")

    for gtObject in gtJson['data']:

        try:
            q_id = int(gtObject['questionId'])
            res_ix = res_id_to_index[q_id]

        except:
            raise Exception("The Det file is not valid. Question " + str(gtObject['questionId']) + " not present")

        else:
            detObject = submJson[res_ix]

            # if detObject['questionId'] != gtObject['questionId'] :
            #     raise Exception("Answer #" + str(i) + " not valid (invalid question ID. Expected:" + str(gtObject['questionId']) + " Found:" + detObject['questionId'] + ")")

            if 'answer' not in detObject:
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (no answer key)")

            if isinstance(detObject['answer'], list):
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (answer key has to be a single string)")


def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluates the method and returns the results
    Results. Dictionary with the following values:
    - method (required)  Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
    - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 }, 'sample2' : { 'Precision':0.8,'Recall':0.9 }}
    """

    show_scores_per_answer_type = evaluationParams.answer_types

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}

    perSampleMetrics = {}

    totalScore = 0
    row = 0

    if show_scores_per_answer_type:
        answerTypeTotalScore = {x: 0 for x in answer_types.keys()}
        answerTypeNumQuestions = {x: 0 for x in answer_types.keys()}

        evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()}
        evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()}

        reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()}
        reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()}

    for gtObject in gtJson['data']:

        q_id = int(gtObject['questionId'])
        res_ix = res_id_to_index[q_id]
        detObject = submJson[res_ix]

        if q_id in question_ids_to_exclude:
            question_result = 0
            info = 'Question EXCLUDED from the result'

        else:
            info = ''
            values = []
            for answer in gtObject['answers']:
                # preprocess both the answers - gt and prediction
                gt_answer = ' '.join(answer.strip().lower().split())
                det_answer = ' '.join(detObject['answer'].strip().lower().split())

                # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
                dist = levenshtein_distance(gt_answer, det_answer)
                length = max(len(answer.upper()), len(detObject['answer'].upper()))
                values.append(0.0 if length == 0 else float(dist) / float(length))

            question_result = 1 - min(values)
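
            # (Note added for clarity, not in the original script.) question_result is the ANLS
            # value for this question: 1 minus the smallest normalized Levenshtein distance
            # between the prediction and any of the ground-truth answers. The check below zeroes
            # scores under --anls_threshold (0.5 by default) before they are accumulated.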

            if question_result < evaluationParams.anls_threshold:
                question_result = 0

            totalScore += question_result

            if show_scores_per_answer_type:
                for q_type in gtObject["answer_type"]:
                    answerTypeTotalScore[q_type] += question_result
                    answerTypeNumQuestions[q_type] += 1

                for q_type in gtObject["evidence"]:
                    evidenceTypeTotalScore[q_type] += question_result
                    evidenceTypeNumQuestions[q_type] += 1

                for q_type in gtObject["operation/reasoning"]:
                    reasoningTypeTotalScore[q_type] += question_result
                    reasoningTypeNumQuestions[q_type] += 1

        perSampleMetrics[str(gtObject['questionId'])] = {
            'score': question_result,
            'question': gtObject['question'],
            'gt': gtObject['answers'],
            'det': detObject['answer'],
            'info': info
        }
        row = row + 1

    methodMetrics = {
        'score': 0 if len(gtJson['data']) == 0 else totalScore / (len(gtJson['data']) - len(question_ids_to_exclude))
    }

    answer_types_scores = {}
    evidence_types_scores = {}
    operation_types_scores = {}

    if show_scores_per_answer_type:
        for a_type, ref in answer_types.items():
            answer_types_scores[ref] = 0 if len(gtJson['data']) == 0 else answerTypeTotalScore[a_type] / answerTypeNumQuestions[a_type]

        for e_type, ref in evidence_types.items():
            evidence_types_scores[ref] = 0 if len(gtJson['data']) == 0 else evidenceTypeTotalScore[e_type] / evidenceTypeNumQuestions[e_type]

        for r_type, ref in reasoning_requirements.items():
            operation_types_scores[ref] = 0 if len(gtJson['data']) == 0 else reasoningTypeTotalScore[r_type] / reasoningTypeNumQuestions[r_type]

    resDict = {
        'result': methodMetrics,
        'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores, 'operation_types': operation_types_scores},
        'per_sample_result': perSampleMetrics
    }

    return resDict
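

# Programmatic use (sketch added here, not part of the original interface): evaluate_method
# only reads .anls_threshold and .answer_types from its third argument, so a call like
#   from types import SimpleNamespace
#   params = SimpleNamespace(anls_threshold=0.5, answer_types=False)
#   results = evaluate_method('gt.json', 'submission.json', params)
# returns the same dictionary that the CLI below prints and saves ('gt.json' and
# 'submission.json' are placeholder paths).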


def display_results(results, show_answer_types):
    print("\nOverall ANLS: {:2.4f}".format(results['result']['score']))

    if show_answer_types:
        print("\nAnswer types:")
        for a_type in answer_types.values():
            print("\t{:12s} {:2.4f}".format(a_type, results['scores_by_types']['answer_types'][a_type]))

        print("\nEvidence types:")
        for e_type in evidence_types.values():
            print("\t{:12s} {:2.4f}".format(e_type, results['scores_by_types']['evidence_types'][e_type]))

        print("\nOperation required:")
        for r_type in reasoning_requirements.values():
            print("\t{:12s} {:2.4f}".format(r_type, results['scores_by_types']['operation_types'][r_type]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="InfographicVQA evaluation script.")

    parser.add_argument('-g', '--ground_truth', type=str, help="Path of the Ground Truth file.", required=True)
    parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True)

    parser.add_argument('-t', '--anls_threshold', type=float, default=0.5, help="ANLS threshold to use (See Scene-Text VQA paper for more info.).", required=False)
    parser.add_argument('-a', '--answer_types', type=bool, default=False, help="Score breakdown by answer types (special gt file required).", required=False)
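    # Note (added): argparse's type=bool simply calls bool() on the string, so "-a False" still
    # evaluates to True; pass "-a ''" or omit the flag entirely to keep the breakdown disabled.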
    parser.add_argument('-o', '--output', type=str, help="Path of the directory where the file 'results.json', containing per-sample results, will be saved.", required=False)

    args = parser.parse_args()

    # Validate the format of ground truth and submission files.
    validate_data(args.ground_truth, args.submission_file)

    # Evaluate method
    results = evaluate_method(args.ground_truth, args.submission_file, args)

    display_results(results, args.answer_types)

    if args.output:
        output_dir = args.output

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        resultsOutputname = os.path.join(output_dir, 'results.json')
        save_json(resultsOutputname, results)

        print("All results, including per-sample results, have been correctly saved!")
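

# Example invocation (sketch; the script and data file names here are placeholders, not fixed
# by the original code):
#   python infographicsvqa_eval.py -g infographicVQA_val_gt.json -s my_predictions.json -o results/
# This prints the overall ANLS and, because -o is given, also writes results/results.json with
# the per-sample scores.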