|
|
@@ -112,8 +112,8 @@ lable_2_id_content = {
     "中文摘要": 2,
     "中文关键词": 3,
     "英文关键词": 4,
-    "图": 5,
-    "表": 6,
+    "图题": 5,
+    "表题": 6,
     "参考文献": 7
 }

@@ -135,13 +135,13 @@ for i in lable_2_id_title_no_title:

 tokenizer = AutoTokenizer.from_pretrained(
-    "data_zong_shout_3",
+    "data_zong_shout_4",
     use_fast=True,
     revision="main",
     trust_remote_code=False,
 )

-model_name = "data_zong_shout_3"
+model_name = "data_zong_shout_4"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),

@@ -156,7 +156,7 @@ model_roberta_zong = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_zong_no_start_shout_3"
+model_name = "data_zong_no_start_shout_4"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),

@@ -171,7 +171,7 @@ model_roberta_zong_no_start = AutoModelForSequenceClassification.from_pretrained
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_zong_no_end_shout_3"
+model_name = "data_zong_no_end_shout_4"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),

@@ -201,14 +201,14 @@ model_roberta_zong_no_end = AutoModelForSequenceClassification.from_pretrained(
 # ignore_mismatched_sizes=False,
 # ).to(device)

-model_name = "data_content_roberta"
+model_name = "data_title_roberta_4"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_content),
+    num_labels=len(lable_2_id_title),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_title_roberta_cls = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",

@@ -216,14 +216,14 @@ model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_content_roberta_no_end"
+model_name = "data_title_no_title_roberta_4"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_content),
+    num_labels=len(lable_2_id_title_no_title),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
+model_title_roberta_cls_2 = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",

@@ -231,14 +231,15 @@ model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretraine
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_content_roberta_no_start"
+
+model_name = "data_content_roberta_4"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",

@@ -246,14 +247,14 @@ model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrai
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_title_roberta_2"
+model_name = "data_content_roberta_no_end_4"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_title),
+    num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",

@@ -261,14 +262,14 @@ model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)

-model_name = "data_title_no_title_roberta_2"
+model_name = "data_content_roberta_no_start_4"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_title_no_title),
+    num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_data_title_no_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",
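
Note: these load blocks are identical except for the checkpoint directory, the label map, and the variable name. If more heads get added, a small helper would cut the repetition; a minimal sketch (the `load_cls_model` name and this factoring are a suggestion, not part of the patch):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification

def load_cls_model(model_dir: str, label2id: dict, device):
    """Load a sequence-classification checkpoint with the matching label count."""
    config = AutoConfig.from_pretrained(
        model_dir,
        num_labels=len(label2id),
        revision="main",
        trust_remote_code=False,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    )
    return model.eval().to(device)

# e.g. model_title_roberta_cls = load_cls_model("data_title_roberta_4", lable_2_id_title, device)
```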
|
|
@@ -282,7 +283,6 @@ model_data_title_roberta_ner = AutoModelForTokenClassification.from_pretrained(m
 model_data_title_roberta_ner.eval().to(device)

-

 def gen_zong_cls(content_list):

     paper_quanwen_lable_list = []

@@ -322,7 +322,7 @@ def gen_zong_cls(content_list):
                 new_sen_list = new_sen_list + [content_list[right][1][:30]]

             new_sen = "\n".join(new_sen_list)
-            if len(new_sen) > 510 or left_end_bool == False or right_end_bool == False:
+            if len(new_sen) > 510 or (left_end_bool == False and right_end_bool == False):
                 break
             else:
                 old_sen = new_sen
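
Note: the fix above makes the expansion stop only when the window is full or when *both* directions are exhausted; with `or`, hitting either end of the document ended the loop early and threw away usable context on the other side. The same expand-around-the-target pattern appears three times in this file; a minimal standalone sketch of the idea (the `expand_context` name is ours):

```python
def expand_context(lines, start_index, max_chars=510):
    """Grow a window around lines[start_index], one line per side,
    until adding more would exceed max_chars or both sides are exhausted."""
    left, right = start_index, start_index
    window = lines[start_index]
    while True:
        can_left, can_right = left > 0, right < len(lines) - 1
        if not can_left and not can_right:
            return window  # both sides exhausted
        if can_left:
            left -= 1
        if can_right:
            right += 1
        candidate = "\n".join(lines[left:right + 1])
        if len(candidate) > max_chars:
            return window  # keep the last window that still fit
        window = candidate
```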
|
|
@@ -399,7 +399,9 @@ def gen_zong_cls(content_list):

         res_score = {}
         # 目标句子在中间预测结果
-        sentence_list = [sentence_zong_zhong[0]]
+        sentence_zong_zhong_str = sentence_zong_zhong[0].replace("\n", "[SEP]")
+
+        sentence_list = [sentence_zong_zhong_str]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
@@ -410,27 +412,29 @@ def gen_zong_cls(content_list):
         else:
             res_score[predicted_class_idx_zhong] += sentence_zong_zhong[1]

-        sentence_list = [sentence_zong_no_end[0]]
+        sentence_zong_no_end_str = sentence_zong_no_end[0].replace("\n", "[SEP]")
+        sentence_list = [sentence_zong_no_end_str]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
         logits = model_roberta_zong_no_end(**result_on_device)
         predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()
-        if predicted_class_idx_zhong not in res_score:
-            res_score[predicted_class_idx_zhong] = sentence_zong_no_end[1]
+        if predicted_class_idx_qian not in res_score:
+            res_score[predicted_class_idx_qian] = sentence_zong_no_end[1]
         else:
-            res_score[predicted_class_idx_zhong] += sentence_zong_no_end[1]
+            res_score[predicted_class_idx_qian] += sentence_zong_no_end[1]

-        sentence_list = [sentence_zong_no_start[0]]
+        sentence_zong_no_start_str = sentence_zong_no_start[0].replace("\n", "[SEP]")
+        sentence_list = [sentence_zong_no_start_str]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
         logits = model_roberta_zong_no_start(**result_on_device)
         predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()
-        if predicted_class_idx_zhong not in res_score:
-            res_score[predicted_class_idx_zhong] = sentence_zong_no_end[1]
+        if predicted_class_idx_hou not in res_score:
+            res_score[predicted_class_idx_hou] = sentence_zong_no_start[1]
         else:
-            res_score[predicted_class_idx_zhong] += sentence_zong_no_end[1]
+            res_score[predicted_class_idx_hou] += sentence_zong_no_start[1]

         res_score_list = sorted(res_score.items(), key=lambda item: item[1], reverse=True)
         predicted_class_idx = res_score_list[0][0]
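
Note: the `_qian`/`_hou` renames above fix a copy-paste bug. Before, all three predictions were credited to `predicted_class_idx_zhong`, and the no_start branch even added `sentence_zong_no_end`'s weight, so the three-view vote collapsed into a single entry. With the fix, each view contributes its own class and its own weight. The voting scheme itself, as a minimal standalone sketch (names are ours):

```python
from collections import defaultdict

def weighted_vote(predictions):
    """predictions: (class_idx, weight) pairs, one per context view.
    Returns the class with the highest accumulated weight."""
    scores = defaultdict(float)
    for class_idx, weight in predictions:
        scores[class_idx] += weight
    return max(scores.items(), key=lambda item: item[1])[0]

# e.g. weighted_vote([(idx_zhong, w_zhong), (idx_qian, w_qian), (idx_hou, w_hou)])
```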
|
|
@@ -445,7 +449,7 @@ def gen_zong_cls(content_list):

 def gen_title_cls(content_list):
     paper_quanwen_lable_list = []
-    for index, paper_sen in content_list:
+    for index, paper_sen in enumerate(content_list):

         paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[0:index]]
         paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:len(content_list)]]
@@ -453,7 +457,7 @@ def gen_title_cls(content_list):
         # print(len(paper_end_list))
         paper_new_start = "\n".join(paper_start_list)
         paper_new_end = "\n".join(paper_end_list)
-        paper_object_dangqian = "<Start>" + paper_sen + "<End>"
+        paper_object_dangqian = "<Start>" + paper_sen[1] + "<End>"
         paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
         paper_zhong = paper_zhong.strip("\n")
@@ -495,18 +499,84 @@ def gen_title_cls(content_list):
             paper_zhong = old_sen

         # 目标句子在中间预测结果
+        paper_zhong = paper_zhong.replace("\n", "[SEP]")
         sentence_list = [paper_zhong]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
-        logits = model_title_roberta(**result_on_device)
+        logits = model_title_roberta_cls(**result_on_device)
         predicted_class_idx = torch.argmax(logits[0], dim=1).item()
-        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]])
+        paper_quanwen_lable_list.append([paper_sen[0], paper_sen[1], id_2_lable_title[predicted_class_idx]])

+    return paper_quanwen_lable_list
+
+
+def gen_title_cls_2(content_list):
+    paper_quanwen_lable_list = []
+    for index, paper_sen in enumerate(content_list):
+
+        paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[0:index]]
+        paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:len(content_list)]]
+        # print(len(paper_start_list))
+        # print(len(paper_end_list))
+        paper_new_start = "\n".join(paper_start_list)
+        paper_new_end = "\n".join(paper_end_list)
+        paper_object_dangqian = "<Start>" + paper_sen[1] + "<End>"
+        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
+        paper_zhong = paper_zhong.strip("\n")
+
+        if len(paper_zhong) > 510:
+            data_paper_list = str(paper_zhong).split("\n")
+            start_index = 0
+            for i in range(len(data_paper_list)):
+                if "<Start>" in data_paper_list[i]:
+                    start_index = i
+                    break
+            left_end = 0
+            right_end = len(data_paper_list) - 1
+            left = start_index
+            right = start_index
+            left_end_bool = True
+            right_end_bool = True
+            old_sen = data_paper_list[start_index]
+            while True:
+                if left - 1 >= left_end:
+                    left = left - 1
+                else:
+                    left_end_bool = False
+                if right + 1 <= right_end:
+                    right = right + 1
+                else:
+                    right_end_bool = False
+
+                new_sen_list = [old_sen]
+                if left_end_bool == True:
+                    new_sen_list = [data_paper_list[left]] + new_sen_list
+                if right_end_bool == True:
+                    new_sen_list = new_sen_list + [data_paper_list[right]]
+
+                new_sen = "\n".join(new_sen_list)
+                if len(new_sen) > 510:
+                    break
+                else:
+                    old_sen = new_sen
+            paper_zhong = old_sen
+
+        # 目标句子在中间预测结果
+        paper_zhong = paper_zhong.replace("\n", "[SEP]")
+        sentence_list = [paper_zhong]
+        # sentence_list = [data[1][0]]
+        result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
+        result_on_device = {key: value.to(device) for key, value in result.items()}
+        logits = model_title_roberta_cls_2(**result_on_device)
+        predicted_class_idx = torch.argmax(logits[0], dim=1).item()
+        paper_quanwen_lable_list.append([paper_sen[0], paper_sen[1], id_2_lable_title_no_title[predicted_class_idx]])
+
     return paper_quanwen_lable_list


 def gen_content_cls(content_list):
+    content_list = sorted(content_list, key=lambda item: item[0])
     paper_quanwen_lable_list = []
     for index, paper_sen in content_list:
         # 视野前后7句
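
Note: `gen_title_cls_2` is a line-for-line copy of `gen_title_cls` except for the classifier head (`model_title_roberta_cls_2`) and the label map (`id_2_lable_title_no_title`). If a third title head ever shows up, parameterizing the shared body would be cheaper to maintain; a hedged sketch (this refactor is a suggestion, and `build_title_window` is a hypothetical helper wrapping the shared windowing code above):

```python
def gen_title_cls_generic(content_list, model, id_2_lable):
    """Shared body of gen_title_cls / gen_title_cls_2: identical windowing
    and tokenization, only the classifier head and label map differ."""
    results = []
    for index, paper_sen in enumerate(content_list):
        text = build_title_window(content_list, index)  # hypothetical helper
        text = text.replace("\n", "[SEP]")
        batch = tokenizer([text], padding="max_length", max_length=512,
                          truncation=True, return_tensors="pt")
        batch = {key: value.to(device) for key, value in batch.items()}
        logits = model(**batch)
        idx = torch.argmax(logits[0], dim=1).item()
        results.append([paper_sen[0], paper_sen[1], id_2_lable[idx]])
    return results

# gen_title_cls(x)   -> gen_title_cls_generic(x, model_title_roberta_cls, id_2_lable_title)
# gen_title_cls_2(x) -> gen_title_cls_generic(x, model_title_roberta_cls_2, id_2_lable_title_no_title)
```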
|
|
@ -534,6 +604,7 @@ def gen_content_cls(content_list): |
|
|
paper_hou = "\n".join([paper_object_dangqian, paper_new_end]) |
|
|
paper_hou = "\n".join([paper_object_dangqian, paper_new_end]) |
|
|
|
|
|
|
|
|
# 目标句子在中间预测结果 |
|
|
# 目标句子在中间预测结果 |
|
|
|
|
|
paper_zhong = paper_zhong.replace("\n", "[SEP]") |
|
|
sentence_list = [paper_zhong] |
|
|
sentence_list = [paper_zhong] |
|
|
# sentence_list = [data[1][0]] |
|
|
# sentence_list = [data[1][0]] |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
|
|
@ -541,6 +612,7 @@ def gen_content_cls(content_list): |
|
|
logits = model_content_roberta(**result_on_device) |
|
|
logits = model_content_roberta(**result_on_device) |
|
|
predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item() |
|
|
predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item() |
|
|
|
|
|
|
|
|
|
|
|
paper_qian = paper_qian.replace("\n", "[SEP]") |
|
|
sentence_list = [paper_qian] |
|
|
sentence_list = [paper_qian] |
|
|
# sentence_list = [data[1][0]] |
|
|
# sentence_list = [data[1][0]] |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
|
|
@ -548,6 +620,7 @@ def gen_content_cls(content_list): |
|
|
logits = model_content_roberta_no_end(**result_on_device) |
|
|
logits = model_content_roberta_no_end(**result_on_device) |
|
|
predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item() |
|
|
predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item() |
|
|
|
|
|
|
|
|
|
|
|
paper_hou = paper_hou.replace("\n", "[SEP]") |
|
|
sentence_list = [paper_hou] |
|
|
sentence_list = [paper_hou] |
|
|
# sentence_list = [data[1][0]] |
|
|
# sentence_list = [data[1][0]] |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
|
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
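
Note: all three branches now map newlines to `[SEP]` before tokenizing, matching the same preprocessing change in `gen_zong_cls`. For a BERT/RoBERTa-style Chinese vocabulary, `"\n"` is typically not a known token (it is dropped or becomes `[UNK]` during cleanup), whereas `[SEP]` is a registered special token, so the model gets an explicit sentence boundary it was also trained with. A quick way to see the difference, assuming a BERT-style tokenizer:

```python
text = "第一句\n第二句"
print(tokenizer.tokenize(text))                         # newline lost or mapped to [UNK]
print(tokenizer.tokenize(text.replace("\n", "[SEP]")))  # explicit '[SEP]' boundary token
```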
|
|
@ -575,6 +648,7 @@ def gen_content_cls(content_list): |
|
|
paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]]) |
|
|
paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]]) |
|
|
return paper_quanwen_lable_list |
|
|
return paper_quanwen_lable_list |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, result_b=None): |
|
|
def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, result_b=None): |
|
|
""" |
|
|
""" |
|
|
递归地同时分割两个列表,保持一一对应关系 |
|
|
递归地同时分割两个列表,保持一一对应关系 |
|
|
@ -631,11 +705,10 @@ def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, re |
|
|
# 剩余部分 |
|
|
# 剩余部分 |
|
|
# target_size = current_chunk_size |
|
|
# target_size = current_chunk_size |
|
|
while True: |
|
|
while True: |
|
|
if b[target_size_new][0] == "B": |
|
|
if b[target_size_new] == "-100": |
|
|
break |
|
|
break |
|
|
if target_size_new == len(a): |
|
|
if target_size_new == len(a): |
|
|
break |
|
|
break |
|
|
target_size_new += 1 |
|
|
|
|
|
a_remaining = a[target_size_new:] |
|
|
a_remaining = a[target_size_new:] |
|
|
b_remaining = b[target_size_new:] |
|
|
b_remaining = b[target_size_new:] |
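
Note: the chunk boundary now advances to the next "-100" marker instead of the next BIO "B" tag; both tests serve the same goal (don't end a chunk mid-entity), the new one just matches how labels are encoded here. One caveat: `b[target_size_new]` is read before the `target_size_new == len(a)` bounds check, so an IndexError is possible when no marker follows. Checking the bound first would be safer; a hedged sketch of the same loop:

```python
while target_size_new < len(a):
    if b[target_size_new] == "-100":
        break
    target_size_new += 1
```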
|
|
|
|
|
|
|
|
@ -673,28 +746,8 @@ def ner_predict(tokens): |
|
|
|
|
|
|
|
|
return results |
|
|
return results |
|
|
|
|
|
|
|
|
def main(content: str): |
|
|
|
|
|
|
|
|
|
|
|
# 先整理句子,把句子整理成模型需要的格式 [id, sen, lable] |
|
|
|
|
|
paper_content_list = [[i,j] for i,j in enumerate(content.split("\n"))] |
|
|
|
|
|
|
|
|
|
|
|
# 先逐句把每句话是否是标题,是否是正文,是否是无用类别识别出来, |
|
|
|
|
|
print("先逐句把每句话是否是标题,是否是正文,是否是无用类别识别出来") |
|
|
|
|
|
zong_list = gen_zong_cls(paper_content_list) |
|
|
|
|
|
|
|
|
|
|
|
# 把标题数据和正文数据,无用类别数据做区分 |
|
|
|
|
|
title_data = [] |
|
|
|
|
|
content_data = [] |
|
|
|
|
|
|
|
|
|
|
|
for data_dan in zong_list: |
|
|
|
|
|
if data_dan[2] == "标题": |
|
|
|
|
|
title_data.append([data_dan[0], data_dan[1]]) |
|
|
|
|
|
if data_dan[2] == "正文": |
|
|
|
|
|
content_data.append([data_dan[0], data_dan[1]]) |
|
|
|
|
|
|
|
|
|
|
|
# 把所有的标题类型提取出来,对每个标题区分标题级别 |
|
|
|
|
|
print("把所有的标题类型提取出来,对每个标题区分标题级别") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gen_title_ner(title_data): |
|
|
data_dan_sen = [i[1] for i in title_data] |
|
|
data_dan_sen = [i[1] for i in title_data] |
|
|
data_dan_sen_index = [i[0] for i in title_data] |
|
|
data_dan_sen_index = [i[0] for i in title_data] |
|
|
data_dan_sen_index_new = [] |
|
|
data_dan_sen_index_new = [] |
|
|
@ -712,7 +765,8 @@ def main(content: str): |
|
|
data_dan_sen_index_new = data_dan_sen_index_new[:-1] |
|
|
data_dan_sen_index_new = data_dan_sen_index_new[:-1] |
|
|
data_dan_sen_new = data_dan_sen_new[:-1] |
|
|
data_dan_sen_new = data_dan_sen_new[:-1] |
|
|
data_dan_sen_new = ["[SEP]" if item == "\n" else item for item in data_dan_sen_new] |
|
|
data_dan_sen_new = ["[SEP]" if item == "\n" else item for item in data_dan_sen_new] |
|
|
a_return1, b_return1 = split_lists_recursive(data_dan_sen_new, data_dan_sen_index_new, data_dan_sen_new, data_dan_sen_index_new, |
|
|
a_return1, b_return1 = split_lists_recursive(data_dan_sen_new, data_dan_sen_index_new, data_dan_sen_new, |
|
|
|
|
|
data_dan_sen_index_new, |
|
|
target_size=510) |
|
|
target_size=510) |
|
|
|
|
|
|
|
|
data_zong_train_list = [] |
|
|
data_zong_train_list = [] |
|
|
@ -769,20 +823,69 @@ def main(content: str): |
|
|
title_list.append([ii, sen, label]) |
|
|
title_list.append([ii, sen, label]) |
|
|
|
|
|
|
|
|
title_data_dict = {} |
|
|
title_data_dict = {} |
|
|
|
|
|
# TODO 此处需要确定多个标签值的情况怎么办,暂时先以首次出现的标签值为准 |
|
|
for i in title_list: |
|
|
for i in title_list: |
|
|
if i[0] not in title_data_dict: |
|
|
if i[0] not in title_data_dict: |
|
|
title_data_dict[i[0]] = [[i[1], i[2]]] |
|
|
title_data_dict[i[0]] = [i[1], i[2]] |
|
|
else: |
|
|
# else: |
|
|
title_data_dict[i[0]] += [[i[1], i[2]]] |
|
|
# title_data_dict[i[0]] += [[i[1], i[2]]] |
|
|
|
|
|
|
|
|
|
|
|
title_list_new = [] |
|
|
|
|
|
for i, j in title_data_dict.items(): |
|
|
|
|
|
title_list_new.append([i, j[0], j[1]]) |
|
|
|
|
|
|
|
|
print(title_data_dict) |
|
|
print(title_data_dict) |
|
|
# 把所有的标题类型提取出来,对每个标题区分标题级别 |
|
|
title_list = title_list_new |
|
|
print("把所有的标题类型提取出来,对每个标题区分标题级别") |
|
|
|
|
|
|
|
|
return title_list |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(content: str): |
|
|
|
|
|
|
|
|
|
|
|
# 先整理句子,把句子整理成模型需要的格式 [id, sen, lable] |
|
|
|
|
|
paper_content_list = [[i,j] for i,j in enumerate(content.split("\n"))] |
|
|
|
|
|
|
|
|
|
|
|
# 先逐句把每句话是否是标题,是否是正文,是否是无用类别识别出来, |
|
|
|
|
|
print("先逐句把每句话是否是标题,是否是正文,是否是无用类别识别出来") |
|
|
|
|
|
zong_list = gen_zong_cls(paper_content_list) |
|
|
|
|
|
|
|
|
|
|
|
# 把标题数据和正文数据,无用类别数据做区分 |
|
|
|
|
|
title_data = [] |
|
|
|
|
|
content_data = [] |
|
|
|
|
|
|
|
|
|
|
|
for data_dan in zong_list: |
|
|
|
|
|
if data_dan[2] == "标题": |
|
|
|
|
|
title_data.append([data_dan[0], data_dan[1]]) |
|
|
|
|
|
if data_dan[2] == "正文": |
|
|
|
|
|
content_data.append([data_dan[0], data_dan[1]]) |
|
|
|
|
|
|
|
|
# 把所有的标题类型提取出来,对每个标题区分标题级别 |
|
|
# 把所有的标题类型提取出来,对每个标题区分标题级别 |
|
|
print("把所有的标题类型提取出来,对每个标题区分标题级别") |
|
|
print("把所有的标题类型提取出来,对每个标题区分标题级别") |
|
|
# title_list = gen_title_cls(title_data) |
|
|
|
|
|
|
|
|
# 把所有的标题类型提取出来,对每个标题区分标题级别 |
|
|
|
|
|
# print("把所有的标题类型提取出来,对每个标题区分标题级别") |
|
|
|
|
|
title_list_ner = gen_title_ner(title_data) |
|
|
|
|
|
title_list_cls = gen_title_cls(title_data) |
|
|
|
|
|
title_list_cls_2 = gen_title_cls_2(title_data) |
|
|
|
|
|
title_list_ner = sorted(title_list_ner, key=lambda item: item[0]) |
|
|
|
|
|
title_list_cls = sorted(title_list_cls, key=lambda item: item[0]) |
|
|
|
|
|
title_list_cls_2 = sorted(title_list_cls_2, key=lambda item: item[0]) |
|
|
|
|
|
print(title_list_ner) |
|
|
|
|
|
print(title_list_cls) |
|
|
|
|
|
print(title_list_cls_2) |
|
|
|
|
|
title_list = [] |
|
|
|
|
|
for i,j in zip(title_list_ner, title_list_cls): |
|
|
|
|
|
if i[2] == '非标题类型' or j[2] == '非标题类型': |
|
|
|
|
|
title_list.append([i[0], i[1], '非标题类型']) |
|
|
|
|
|
else: |
|
|
|
|
|
title_list.append(i) |
|
|
|
|
|
title_list_new = [] |
|
|
|
|
|
for i in title_list: |
|
|
|
|
|
if i[2] == '非标题类型': |
|
|
|
|
|
content_data.append([i[0], i[1]]) |
|
|
|
|
|
else: |
|
|
|
|
|
title_list_new.append(i) |
|
|
|
|
|
title_list = title_list_new |
|
|
|
|
|
|
|
|
# 把所有的正文类别提取出来,逐个进行打标 |
|
|
# 把所有的正文类别提取出来,逐个进行打标 |
|
|
print("把所有的正文类别提取出来,逐个进行打标") |
|
|
print("把所有的正文类别提取出来,逐个进行打标") |
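
Two notes on this hunk. First, the `zip(title_list_ner, title_list_cls)` merge assumes both lists cover the same sentences in the same order; both are sorted by sentence id just above, so this holds as long as `gen_title_ner`'s chunking neither drops nor duplicates ids. Second, on the TODO (decide how to handle a sentence id that collects several labels; for now the first one wins): if disagreeing chunks turn out to be common, a majority vote is a drop-in alternative; a hedged sketch:

```python
from collections import Counter

def resolve_labels(title_list):
    """title_list: [sentence_id, sentence, label] triples, possibly with
    repeated ids from overlapping chunks. Majority vote per id; ties go
    to the label seen first (Counter.most_common keeps insertion order)."""
    by_id = {}
    for sen_id, sen, label in title_list:
        by_id.setdefault(sen_id, (sen, []))[1].append(label)
    return [[sen_id, sen, Counter(labels).most_common(1)[0][0]]
            for sen_id, (sen, labels) in by_id.items()]
```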
|
|
|