# This file can be downloaded from: https://www.docvqa.org/datasets/infographicvqa and https://rrc.cvc.uab.es/?ch=17&com=introduction
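
# A sketch of the submission format this script expects, inferred from the checks in
# validate_data() below (the questionId values and answers here are made up for
# illustration): the file is a single JSON array with one entry per ground-truth
# question, and each 'answer' must be a single string, not a list.
#
# [
#     {"questionId": 98317, "answer": "21%"},
#     {"questionId": 98318, "answer": "2019"}
# ]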

import os
import json
import argparse

question_ids_to_exclude = []

# answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span', 'list': 'List'}
answer_types = {'image span': 'Image-Span', 'question span': 'Question-Span', 'multiple spans': 'Multi-Span', 'non span': 'None span'}
evidence_types = {'table/list': 'Table/list', 'textual': 'Text', 'photo/pciture/visual_objects': 'Visual/Layout', 'figure': 'Figure', 'map': 'Map'}
reasoning_requirements = {'comparison': 'Sorting', 'arithmetic': 'Arithmetic', 'counting': 'Counting'}


def save_json(file_path, data):
    with open(file_path, 'w+') as json_file:
        json.dump(data, json_file)


def levenshtein_distance(s1, s2):
    # Standard dynamic-programming Levenshtein (edit) distance between two strings.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
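
# Worked example (illustrative strings, not taken from the dataset): for a ground-truth
# answer "report" and a predicted answer "reprot",
#     levenshtein_distance("report", "reprot") == 2
# so the normalised distance is 2 / 6 ~ 0.33 and the per-question score used in
# evaluate_method() below is 1 - 0.33 ~ 0.67, which is kept because it is above the
# default ANLS threshold of 0.5 (scores below the threshold are set to 0).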


def validate_data(gtFilePath, submFilePath):
    """
    Method validate_data: validates that the ground truth and submission files are correct
                            (have the expected keys and contents) and that no answers are
                            missing from the submission.
                            If an error is detected, the method raises an exception.
    """

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    if 'data' not in gtJson:
        raise Exception("The GT file is not valid (no data key)")

    if 'dataset_name' not in gtJson:
        raise Exception("The GT file is not valid (no dataset_name key)")

    if not isinstance(submJson, list):
        raise Exception("The Det file is not valid (root item must be an array)")

    if len(submJson) != len(gtJson['data']):
        raise Exception("The Det file is not valid (invalid number of answers. Expected:" + str(len(gtJson['data'])) + " Found:" + str(len(submJson)) + ")")

    gtQuestions = sorted([r['questionId'] for r in gtJson['data']])
    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}
    detQuestions = sorted([r['questionId'] for r in submJson])

    if gtQuestions != detQuestions:
        raise Exception("The Det file is not valid. Question IDs must match the GT")

    for gtObject in gtJson['data']:

        try:
            q_id = int(gtObject['questionId'])
            res_ix = res_id_to_index[q_id]

        except (KeyError, ValueError):
            raise Exception("The Det file is not valid. Question " + str(gtObject['questionId']) + " not present")

        else:
            detObject = submJson[res_ix]

#            if detObject['questionId'] != gtObject['questionId'] :
#                raise Exception("Answer #" + str(i) + " not valid (invalid question ID. Expected:" + str(gtObject['questionId']) + "Found:" + detObject['questionId'] + ")")

            if 'answer' not in detObject:
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (no answer key)")

            if isinstance(detObject['answer'], list):
                raise Exception("Question " + str(gtObject['questionId']) + " not valid (answer key has to be a single string)")
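

# A sketch of the ground-truth layout this script reads. The field names come from the
# accesses in validate_data() and evaluate_method(); the concrete values below are
# illustrative only. The 'answer_type', 'evidence' and 'operation/reasoning' lists are
# only needed when the per-type breakdown (-a / --answer_types) is requested:
#
# {
#     "dataset_name": "infographicVQA",
#     "data": [
#         {
#             "questionId": 98317,
#             "question": "What percentage is shown in the first chart?",
#             "answers": ["21%"],
#             "answer_type": ["image span"],
#             "evidence": ["figure"],
#             "operation/reasoning": []
#         }
#     ]
# }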


def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluates the method and returns the results
        Results. Dictionary with the following values:
        - method (required)  Global method metrics. Ex: { 'Precision':0.8, 'Recall':0.9 }
        - samples (optional) Per sample metrics. Ex: { 'sample1': { 'Precision':0.8, 'Recall':0.9 }, 'sample2': { 'Precision':0.8, 'Recall':0.9 } }
    """

    show_scores_per_answer_type = evaluationParams.answer_types

    gtJson = json.load(open(gtFilePath, 'rb'))
    submJson = json.load(open(submFilePath, 'rb'))

    res_id_to_index = {int(r['questionId']): ix for ix, r in enumerate(submJson)}

    perSampleMetrics = {}

    totalScore = 0

    if show_scores_per_answer_type:
        answerTypeTotalScore = {x: 0 for x in answer_types.keys()}
        answerTypeNumQuestions = {x: 0 for x in answer_types.keys()}

        evidenceTypeTotalScore = {x: 0 for x in evidence_types.keys()}
        evidenceTypeNumQuestions = {x: 0 for x in evidence_types.keys()}

        reasoningTypeTotalScore = {x: 0 for x in reasoning_requirements.keys()}
        reasoningTypeNumQuestions = {x: 0 for x in reasoning_requirements.keys()}

    for gtObject in gtJson['data']:

        q_id = int(gtObject['questionId'])
        res_ix = res_id_to_index[q_id]
        detObject = submJson[res_ix]

        if q_id in question_ids_to_exclude:
            question_result = 0
            info = 'Question EXCLUDED from the result'

        else:
            info = ''
            values = []
            for answer in gtObject['answers']:
                # Preprocess both answers (GT and prediction): lowercase and collapse whitespace.
                gt_answer = ' '.join(answer.strip().lower().split())
                det_answer = ' '.join(detObject['answer'].strip().lower().split())

                # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
                dist = levenshtein_distance(gt_answer, det_answer)
                length = max(len(answer.upper()), len(detObject['answer'].upper()))
                values.append(0.0 if length == 0 else float(dist) / float(length))

            # Per-question ANLS: 1 minus the smallest normalised edit distance over all
            # ground-truth answers, zeroed when it falls below the threshold.
            question_result = 1 - min(values)

            if question_result < evaluationParams.anls_threshold:
                question_result = 0

            totalScore += question_result

            if show_scores_per_answer_type:
                for q_type in gtObject["answer_type"]:
                    answerTypeTotalScore[q_type] += question_result
                    answerTypeNumQuestions[q_type] += 1

                for q_type in gtObject["evidence"]:
                    evidenceTypeTotalScore[q_type] += question_result
                    evidenceTypeNumQuestions[q_type] += 1

                for q_type in gtObject["operation/reasoning"]:
                    reasoningTypeTotalScore[q_type] += question_result
                    reasoningTypeNumQuestions[q_type] += 1

        perSampleMetrics[str(gtObject['questionId'])] = {
                                'score': question_result,
                                'question': gtObject['question'],
                                'gt': gtObject['answers'],
                                'det': detObject['answer'],
                                'info': info
                                }

    methodMetrics = {
        'score': 0 if len(gtJson['data']) == 0 else totalScore / (len(gtJson['data']) - len(question_ids_to_exclude))
    }

    answer_types_scores = {}
    evidence_types_scores = {}
    operation_types_scores = {}

    if show_scores_per_answer_type:
        for a_type, ref in answer_types.items():
            answer_types_scores[ref] = 0 if len(gtJson['data']) == 0 else answerTypeTotalScore[a_type] / answerTypeNumQuestions[a_type]

        for e_type, ref in evidence_types.items():
            evidence_types_scores[ref] = 0 if len(gtJson['data']) == 0 else evidenceTypeTotalScore[e_type] / evidenceTypeNumQuestions[e_type]

        for r_type, ref in reasoning_requirements.items():
            operation_types_scores[ref] = 0 if len(gtJson['data']) == 0 else reasoningTypeTotalScore[r_type] / reasoningTypeNumQuestions[r_type]

    resDict = {
            'result': methodMetrics,
            'scores_by_types': {'answer_types': answer_types_scores, 'evidence_types': evidence_types_scores, 'operation_types': operation_types_scores},
            'per_sample_result': perSampleMetrics
            }

    return resDict
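
# Programmatic use (a minimal sketch, not part of the original competition interface):
# evaluate_method() only reads .anls_threshold and .answer_types from its third argument,
# so any object exposing those attributes works, e.g. a types.SimpleNamespace instead of
# the argparse Namespace built in __main__ below (the file paths here are placeholders):
#
#     from types import SimpleNamespace
#     params = SimpleNamespace(anls_threshold=0.5, answer_types=False)
#     results = evaluate_method('gt.json', 'submission.json', params)
#     print(results['result']['score'])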


def display_results(results, show_answer_types):
    print("\nOverall ANLS: {:2.4f}".format(results['result']['score']))

    if show_answer_types:
        print("\nAnswer types:")
        for a_type in answer_types.values():
            print("\t{:12s} {:2.4f}".format(a_type, results['scores_by_types']['answer_types'][a_type]))

        print("\nEvidence types:")
        for e_type in evidence_types.values():
            print("\t{:12s} {:2.4f}".format(e_type, results['scores_by_types']['evidence_types'][e_type]))

        print("\nOperation required:")
        for r_type in reasoning_requirements.values():
            print("\t{:12s} {:2.4f}".format(r_type, results['scores_by_types']['operation_types'][r_type]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="InfographicVQA evaluation script.")

    parser.add_argument('-g', '--ground_truth', type=str, help="Path of the Ground Truth file.", required=True)
    parser.add_argument('-s', '--submission_file', type=str, help="Path of your method's results file.", required=True)

    parser.add_argument('-t', '--anls_threshold', type=float, default=0.5, help="ANLS threshold to use (see the Scene-Text VQA paper for more info).", required=False)
    # store_true avoids the argparse type=bool pitfall, where any non-empty string (even "False") parses as True.
    parser.add_argument('-a', '--answer_types', action='store_true', help="Break the score down by answer types (special GT file required).", required=False)
    parser.add_argument('-o', '--output', type=str, help="Path to a directory where the file 'results.json' with the per-sample results will be copied.", required=False)

    args = parser.parse_args()

    # Validate the format of the ground truth and submission files.
    validate_data(args.ground_truth, args.submission_file)

    # Evaluate the method.
    results = evaluate_method(args.ground_truth, args.submission_file, args)

    display_results(results, args.answer_types)

    if args.output:
        output_dir = args.output

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        resultsOutputname = os.path.join(output_dir, 'results.json')
        save_json(resultsOutputname, results)

        print("All results, including the per-sample results, have been saved correctly!")
 |