# This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction
import argparse
import json
import os
question_ids_to_exclude = []
# answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'}
answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span'}
evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout', 'figure': 'Figure', 'map': 'Map'}
reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting':'Counting'}
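
# Expected input structure (a minimal sketch inferred from the fields this script reads;
# the official files may contain additional keys):
#   ground truth: {"dataset_name": ..., "data": [{"questionId": ..., "question": ..., "answers": [...],
#                  and, for the per-type breakdown, "answer_type", "evidence", "operation/reasoning"}, ...]}
#   submission:   [{"questionId": ..., "answer": "..."}, ...]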

def save_json(file_path, data):
    with open(file_path, 'w+') as json_file:
        json.dump(data, json_file)

def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
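
# Illustrative example only: levenshtein_distance("kitten", "sitting") == 3, so the normalized
# similarity computed in evaluate_method would be 1 - 3 / max(len("kitten"), len("sitting"))
# = 1 - 3/7 ≈ 0.571, which is kept because it is above the default ANLS threshold of 0.5.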

def validate_data(gtFilePath, submFilePath):
    """
    Method validate_data: validates that the ground truth and submission files are correct
    (have the expected keys and contents) and that no questions are missing from the submission.
    If an error is detected, the method raises an exception.
    """
    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    if 'data' not in gtJson:
        raise Exception("The GT file is not valid (no data key)")

    if 'dataset_name' not in gtJson:
        raise Exception("The GT file is not valid (no dataset_name key)")

    if not isinstance(submJson, list):
        raise Exception("The Det file is not valid (root item must be an array)")

    if len(submJson) != len(gtJson['data']):
        raise Exception("The Det file is not valid (invalid number of answers. Expected:" + str(len(gtJson['data'])) + " Found:" + str(len(submJson)) + ")")

    gtQuestions = sorted([r['questionId'] for r in gtJson['data']])
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
    detQuestions = sorted([r['questionId'] for r in submJson])

    if gtQuestions != detQuestions:
        raise Exception("The Det file is not valid. Question IDs must match the GT")

    for gtObject in gtJson['data']:
        try:
            q_id = int(gtObject['questionId'])
            res_ix = res_id_to_index[q_id]
        except (ValueError, KeyError):
            raise Exception("The Det file is not valid. Question " + str(gtObject['questionId']) + " not present")
        else:
            detObject = submJson[res_ix]

            # if detObject['questionId'] != gtObject['questionId']:
            #     raise Exception("Answer #" + str(i) + " not valid (invalid question ID. Expected:" + str(gtObject['questionId']) + " Found:" + detObject['questionId'] + ")")

            if 'answer' not in detObject:
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (no answer key)")

            if isinstance(detObject['answer'], list):
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (answer key has to be a single string)")

def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluates the method and returns the results.
    Results. Dictionary with the following values:
    - method (required)  Global method metrics. Ex: { 'Precision': 0.8, 'Recall': 0.9 }
    - samples (optional) Per sample metrics. Ex: { 'sample1': { 'Precision': 0.8, 'Recall': 0.9 }, 'sample2': { 'Precision': 0.8, 'Recall': 0.9 } }
    """
    show_scores_per_answer_type = evaluationParams.answer_types

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}

    perSampleMetrics = {}

    totalScore = 0
    row = 0

    if show_scores_per_answer_type:
        answerTypeTotalScore = {x: 0 for x in answer_types.keys()}
        answerTypeNumQuestions = {x: 0 for x in answer_types.keys()}

        evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()}
        evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()}

        reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()}
        reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()}

    for gtObject in gtJson['data']:
        q_id = int(gtObject['questionId'])
        res_ix = res_id_to_index[q_id]
        detObject = submJson[res_ix]

        if q_id in question_ids_to_exclude:
            question_result = 0
            info = 'Question EXCLUDED from the result'
        else:
            info = ''
            values = []
            for answer in gtObject['answers']:
                # preprocess both answers - gt and prediction
                gt_answer = ' '.join(answer.strip().lower().split())
                det_answer = ' '.join(detObject['answer'].strip().lower().split())

                # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
                dist = levenshtein_distance(gt_answer, det_answer)
                length = max(len(answer.upper()), len(detObject['answer'].upper()))
                values.append(0.0 if length == 0 else float(dist) / float(length))

            question_result = 1 - min(values)

            if question_result < evaluationParams.anls_threshold:
                question_result = 0

            totalScore += question_result

            if show_scores_per_answer_type:
                for q_type in gtObject["answer_type"]:
                    answerTypeTotalScore[q_type] += question_result
                    answerTypeNumQuestions[q_type] += 1

                for q_type in gtObject["evidence"]:
                    evidenceTypeTotalScore[q_type] += question_result
                    evidenceTypeNumQuestions[q_type] += 1

                for q_type in gtObject["operation/reasoning"]:
                    reasoningTypeTotalScore[q_type] += question_result
                    reasoningTypeNumQuestions[q_type] += 1

        perSampleMetrics[str(gtObject['questionId'])] = {
            'score': question_result,
            'question': gtObject['question'],
            'gt': gtObject['answers'],
            'det': detObject['answer'],
            'info': info
        }
        row = row + 1

    methodMetrics = {
        'score': 0 if len(gtJson['data']) == 0 else totalScore / (len(gtJson['data']) - len(question_ids_to_exclude))
    }

    answer_types_scores = {}
    evidence_types_scores = {}
    operation_types_scores = {}

    if show_scores_per_answer_type:
        # Guard against empty categories to avoid division by zero.
        for a_type, ref in answer_types.items():
            answer_types_scores[ref] = 0 if answerTypeNumQuestions[a_type] == 0 else answerTypeTotalScore[a_type] / answerTypeNumQuestions[a_type]

        for e_type, ref in evidence_types.items():
            evidence_types_scores[ref] = 0 if evidenceTypeNumQuestions[e_type] == 0 else evidenceTypeTotalScore[e_type] / evidenceTypeNumQuestions[e_type]

        for r_type, ref in reasoning_requirements.items():
            operation_types_scores[ref] = 0 if reasoningTypeNumQuestions[r_type] == 0 else reasoningTypeTotalScore[r_type] / reasoningTypeNumQuestions[r_type]

    resDict = {
        'result': methodMetrics,
        'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores, 'operation_types': operation_types_scores},
        'per_sample_result': perSampleMetrics
    }

    return resDict
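
# Shape of the returned dictionary, as built above:
#   {"result": {"score": <overall ANLS>},
#    "scores_by_types": {"answer_types": {...}, "evidence_types": {...}, "operation_types": {...}},
#    "per_sample_result": {"<questionId>": {"score": ..., "question": ..., "gt": ..., "det": ..., "info": ...}, ...}}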

def display_results(results, show_answer_types):
    print("\nOverall ANLS: {:2.4f}".format(results['result']['score']))

    if show_answer_types:
        print("\nAnswer types:")
        for a_type in answer_types.values():
            print("\t{:12s} {:2.4f}".format(a_type, results['scores_by_types']['answer_types'][a_type]))

        print("\nEvidence types:")
        for e_type in evidence_types.values():
            print("\t{:12s} {:2.4f}".format(e_type, results['scores_by_types']['evidence_types'][e_type]))

        print("\nOperation required:")
        for r_type in reasoning_requirements.values():
            print("\t{:12s} {:2.4f}".format(r_type, results['scores_by_types']['operation_types'][r_type]))

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="InfographicVQA evaluation script.")
    parser.add_argument('-g', '--ground_truth', type=str, help="Path of the Ground Truth file.", required=True)
    parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True)

    parser.add_argument('-t', '--anls_threshold', type=float, default=0.5, help="ANLS threshold to use (see the Scene-Text VQA paper for more info).", required=False)
    parser.add_argument('-a', '--answer_types', action='store_true', help="Score breakdown by answer types (special GT file required).", required=False)
    parser.add_argument('-o', '--output', type=str, help="Path to a directory where to copy the file 'results.json' that contains per-sample results.", required=False)

    args = parser.parse_args()

    # Validate the format of ground truth and submission files.
    validate_data(args.ground_truth, args.submission_file)

    # Evaluate method
    results = evaluate_method(args.ground_truth, args.submission_file, args)

    display_results(results, args.answer_types)

    if args.output:
        output_dir = args.output

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        resultsOutputname = os.path.join(output_dir, 'results.json')
        save_json(resultsOutputname, results)

        print("All results, including per-sample results, have been correctly saved!")