import argparse
import itertools
import json
import os
import re
from functools import partial

import torch
from torchvision.ops.boxes import box_area
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
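
# Mapping from evaluation split name to its JSONL annotation file. The paths
# follow this repository's data layout; adjust them if your data lives elsewhere.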
ds_collections = {
    'refcoco_val': 'data/refcoco/refcoco_val.jsonl',
    'refcoco_testA': 'data/refcoco/refcoco_testA.jsonl',
    'refcoco_testB': 'data/refcoco/refcoco_testB.jsonl',
    'refcoco+_val': 'data/refcoco+/refcoco+_val.jsonl',
    'refcoco+_testA': 'data/refcoco+/refcoco+_testA.jsonl',
    'refcoco+_testB': 'data/refcoco+/refcoco+_testB.jsonl',
    'refcocog_val': 'data/refcocog/refcocog_val.jsonl',
    'refcocog_test': 'data/refcocog/refcocog_test.jsonl',
}
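
# Pairwise IoU between two sets of boxes in (x1, y1, x2, y2) format, mirroring
# torchvision's box_iou but also returning the union areas.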
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter
    iou = inter / union
    return iou, union
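
# Collate text prompts, ground-truth boxes, and image sizes into a batch;
# prompts are padded to the longest sequence in the batch (left padding is set
# on the tokenizer below) so generation starts from aligned positions.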
def collate_fn(batches, tokenizer):
    texts = [_['text'] for _ in batches]
    bboxes = [_['bbox'] for _ in batches]
    hws = [_['hw'] for _ in batches]

    input_ids = tokenizer(texts, return_tensors='pt', padding='longest')

    return input_ids.input_ids, input_ids.attention_mask, bboxes, hws
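
# Each JSONL line holds an image path, a referring expression ('sent'), the
# ground-truth box, and the image size; __getitem__ formats the grounding
# prompt and keeps the box and (height, width) for evaluation.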
class RefCOCODataset(torch.utils.data.Dataset):

    def __init__(self, test, tokenizer, prompt):
        self.datas = open(test).readlines()
        self.tokenizer = tokenizer
        self.prompt = prompt

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        data = json.loads(self.datas[idx].strip())
        image = data['image']
        text = data['sent']
        bbox = data['bbox']

        w, h = data['width'], data['height']

        return {
            'text': self.prompt.format(image, text),
            'bbox': bbox,
            'hw': (h, w),
        }
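
# Shards the dataset indices contiguously across distributed ranks so every
# process runs inference on a disjoint slice; the per-rank results are merged
# later with all_gather_object.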
class InferenceSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
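
# Main entry point: runs distributed greedy decoding over the chosen RefCOCO
# split and reports Precision@1, counting a prediction as correct when its IoU
# with the ground-truth box is at least 0.5.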
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='')
    parser.add_argument('--dataset', type=str, default='')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    args = parser.parse_args()

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint, device_map='cuda', trust_remote_code=True).eval()

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint,
                                              trust_remote_code=True)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token_id = tokenizer.eod_id

    prompt = '<img>{}</img><ref>{}</ref><box>'
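
    # The prompt uses Qwen-VL's grounding tags: the image path goes inside
    # <img>...</img>, the referring expression inside <ref>...</ref>, and the
    # model is expected to continue after <box> with coordinates of the form
    # (x1,y1),(x2,y2); the /999 rescaling below assumes they lie on a 0-999 grid.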
    dataset = RefCOCODataset(test=ds_collections[args.dataset],
                             tokenizer=tokenizer,
                             prompt=prompt)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=InferenceSampler(len(dataset)),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )
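
    # Greedy decoding (no sampling, single beam); only the tokens generated
    # after the prompt are decoded into the answer string.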
    outputs = []
    for input_ids, attention_mask, bboxes, hws in tqdm(dataloader):
        pred = model.generate(
            input_ids=input_ids.cuda(),
            attention_mask=attention_mask.cuda(),
            do_sample=False,
            num_beams=1,
            max_new_tokens=28,
            min_new_tokens=10,
            length_penalty=1,
            num_return_sequences=1,
            use_cache=True,
            pad_token_id=tokenizer.eod_id,
            eos_token_id=tokenizer.eod_id,
        )
        answers = [
            tokenizer.decode(_[input_ids.size(1):].cpu(),
                             skip_special_tokens=True) for _ in pred
        ]

        for bbox, hw, answer in zip(bboxes, hws, answers):
            outputs.append({
                'answer': answer,
                'gt_bbox': bbox,
                'hw': hw,
            })
    torch.distributed.barrier()

    world_size = torch.distributed.get_world_size()
    merged_outputs = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(merged_outputs, outputs)

    merged_outputs = list(itertools.chain.from_iterable(merged_outputs))
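
    # The answer is expected to contain '(x1,y1),(x2,y2)'; the regex captures
    # the two coordinate pairs, and anything unparsable falls back to an
    # all-zero box, which is counted as a miss.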
    PATTERN = re.compile(r'\((.*?)\),\((.*?)\)')

    if torch.distributed.get_rank() == 0:
        correct = total_cnt = 0
        for output in merged_outputs:
            predict_bbox = re.findall(PATTERN, output['answer'])
            try:
                if ',' not in predict_bbox[0][0] or ',' not in predict_bbox[0][1]:
                    predict_bbox = (0., 0., 0., 0.)
                else:
                    x1, y1 = [float(tmp) for tmp in predict_bbox[0][0].split(',')]
                    x2, y2 = [float(tmp) for tmp in predict_bbox[0][1].split(',')]
                    predict_bbox = (x1, y1, x2, y2)
            except (IndexError, ValueError):
                predict_bbox = (0., 0., 0., 0.)
            target_bbox = torch.tensor(output['gt_bbox'],
                                       dtype=torch.float32).view(-1, 4)
            predict_bbox = torch.tensor(predict_bbox,
                                        dtype=torch.float32).view(-1, 4) / 999
            # Rescale the 0-999 normalised prediction to pixel coordinates;
            # 'hw' is stored as (height, width).
            predict_bbox[:, 0::2] *= output['hw'][1]
            predict_bbox[:, 1::2] *= output['hw'][0]
            iou, _ = box_iou(predict_bbox, target_bbox)
            iou = iou.item()
            total_cnt += 1
            if iou >= 0.5:
                correct += 1

        print(f'Evaluating {args.dataset} ...')
        print(f'Precision @ 1: {correct / total_cnt}\n')

    torch.distributed.barrier()
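
# Example launch with torchrun (the script name, GPU count, and checkpoint path
# below are placeholders; adjust them to your setup):
#   torchrun --nproc_per_node=8 evaluate_grounding.py \
#       --checkpoint Qwen/Qwen-VL --dataset refcoco_val --batch-size 8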