
13 changed files with 993 additions and 83 deletions
@@ -1,47 +1,29 @@ (project README)

# Novel Rewriting Project (小说改写项目)

A generative rewriting task based on the UniLM model, implemented with the Keras framework. The data-processing scripts live in the data_do folder.

Training data: train_cat_data_4.txt

## Training

Training (quality checking included): `bash train.sh`

Training (quality checking included): `bash train_sim.sh`

## Prediction

With quality checking: `python predict_tf_sim.py`

Without quality checking: `python predict_tf.py`

## API serve

Current way to start the service: `bash run_app.sh`

One-command start: `bash run_app_gunicorn.sh`

## Request example

```python
requests.post(
    "http://192.168.1.17:14000",
    json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"]},
    timeout=1000
)
```

## Response

```python
{'probabilities': None, 'texts': ['张三要上了巅峰,他就可以为所欲为了。']}
```
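Putting the request and response sections together, a complete client call might look like the sketch below; the address, payload, and response shape are taken from the examples above, and the service is assumed to be running at that address.

```python
# Minimal client sketch combining the request and response examples above.
# Assumes the rewriting service is already running at this address.
import requests

response = requests.post(
    "http://192.168.1.17:14000",
    json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"]},
    timeout=1000,
)
result = response.json()
print(result["texts"])          # rewritten sentences, e.g. ['张三要上了巅峰,他就可以为所欲为了。']
print(result["probabilities"])  # None in the example above
```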
@ -0,0 +1,28 @@ |
|||||
|
# -*- coding: utf-8 -*- |
||||
|
|
||||
|
""" |
||||
|
@Time : 2023/2/27 18:24 |
||||
|
@Author : |
||||
|
@FileName: |
||||
|
@Software: |
||||
|
@Describe: |
||||
|
""" |
||||
|
import pandas as pd |
||||
|
import difflib |
||||
|
|
||||
|
path = "../data/11篇_yy.xlsx" |
||||
|
data = pd.read_excel( |
||||
|
path |
||||
|
).values.tolist() |
||||
|
|
||||
|
|
||||
|
data_new = [] |
||||
|
for i in data: |
||||
|
data_1 = i[0] |
||||
|
data_2 = i[1] |
||||
|
str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio() |
||||
|
data_new.append(i + [str_sim_value]) |
||||
|
|
||||
|
data_new = sorted(data_new, key= lambda x:x[2], reverse=True) |
||||
|
df = pd.DataFrame(data_new) |
||||
|
df.to_excel("../data/11篇_yy_strsim.xlsx", index=None) |
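As a side note on the metric used by this script, difflib's quick_ratio() returns a similarity in [0, 1] computed from character overlap (an upper bound on ratio()). A minimal illustration, reusing the sentence pair from the README example:

```python
# Illustration only: how quick_ratio() scores an original/rewritten pair
# (the sentences are the ones from the README request/response example).
import difflib

original = "张三要爬上高位的,才能够翻云覆雨。"
rewritten = "张三要上了巅峰,他就可以为所欲为了。"

score = difflib.SequenceMatcher(None, original, rewritten).quick_ratio()
print(round(score, 3))  # value in [0, 1]; higher means the rewrite stays closer to the original
```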
@@ -0,0 +1,33 @@ (new data-processing script: the same similarity ranking for a tab-separated T5 prediction file)

```python
# -*- coding: utf-8 -*-
"""
@Time : 2023/2/27 18:24
@Author :
@FileName:
@Software:
@Describe:
"""
import pandas as pd
import difflib

file = "../data/11篇汇总txt_new_predict_t5.txt"
try:
    with open(file, 'r', encoding="utf-8") as f:
        lines = [x.strip() for x in f if x.strip() != '']
except UnicodeDecodeError:
    # Fall back to GBK when the file is not UTF-8 encoded.
    with open(file, 'r', encoding="gbk") as f:
        lines = [x.strip() for x in f if x.strip() != '']

data_new = []
for i in lines:
    # Each line holds "original<TAB>rewrite"; skip malformed lines.
    data_dan = i.split("\t")
    if len(data_dan) != 2:
        continue
    data_1 = data_dan[0]
    data_2 = data_dan[1]
    str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
    data_new.append(data_dan + [str_sim_value])

print(data_new)
data_new = sorted(data_new, key=lambda x: x[2], reverse=True)
df = pd.DataFrame(data_new)
df.to_excel("../data/11篇_t5_strsim.xlsx", index=False)
```
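The script assumes each non-empty line of the prediction file contains an original sentence and its T5 rewrite separated by a single tab. A minimal sketch of that format (the file name and sentences below are made up):

```python
# Sketch of the expected input format; the file name and contents are invented.
sample_lines = [
    "原始句子一。\t改写后的句子一。",
    "原始句子二。\t改写后的句子二。",
]
with open("sample_predict_t5.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(sample_lines))

# Each line splits into exactly two fields, as the loop above requires.
for line in sample_lines:
    source, rewrite = line.split("\t")
    print(source, "->", rewrite)
```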
@@ -0,0 +1,63 @@ (new data-processing script: align source/result sentences from two HTML reduction reports and export the pairs to Excel)

```python
# -*- coding: utf-8 -*-
"""
@Time : 2022/12/20 10:35
@Author :
@FileName:
@Software:
@Describe:
"""
import os
from bs4 import BeautifulSoup
import pandas as pd
import re
# iterate over the folder


yuanshi = "../data/11篇yy/paperyyreduce20230221120936.html"
soup_source = BeautifulSoup(open(yuanshi, encoding='utf-8'),
                            "html.parser")

yyshuju = "../data/11篇yy/paperyyreduce_result20230221120936"
soup_result = BeautifulSoup(open(yyshuju, encoding='utf-8'),
                            "html.parser")

# Sentences sit in <p><em id=...> elements in both reports.
source_sentence_list = soup_source.select('p > em')
result_sentence_list = soup_result.select('p > em')

data = []
for sentence_index in range(len(source_sentence_list)):
    try:
        print(source_sentence_list[sentence_index]["id"])
        print(result_sentence_list[sentence_index]["id"])
        print(result_sentence_list[sentence_index]["class"])
        # Keep only pairs whose ids match and whose result sentence is flagged as similar.
        if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"] \
                and (result_sentence_list[sentence_index]["class"] == ['similar', 'red']
                     or result_sentence_list[sentence_index]["class"] == ['similar']):
            # if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"]:
            source_text = source_sentence_list[sentence_index].string
            result_text = result_sentence_list[sentence_index].string
            # .string is None when the <em> contains nested markup; check before stripping.
            if source_text is not None and result_text is not None:
                source_text = source_text.strip("\n")
                result_text = result_text.strip("\n")
                data.append([source_text, result_text])
    except Exception:
        # Missing id/class attributes or mismatched element counts.
        print(sentence_index)

# print(data)


def data_clean(text):
    # Strip characters that are illegal in Excel cells: rare non-printable
    # control characters such as backspace, bell, etc.
    ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
    text = ILLEGAL_CHARACTERS_RE.sub(r'', text)
    return text


print(data)
df = pd.DataFrame(data, columns=["原文", "yy降重"])
for col in df.columns:
    df[col] = df[col].apply(lambda x: data_clean(x))

df.to_excel("../data/11篇_yy.xlsx", index=False)
```
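The selectors above rely on an assumed report structure: each sentence sits in a `<p><em>` element whose `id` matches across the two files and whose `class` marks flagged sentences. A self-contained sketch with invented markup shows how those lookups behave:

```python
# Self-contained sketch of the assumed HTML structure (the markup below is
# invented for illustration; the real report files are not included here).
from bs4 import BeautifulSoup

source_html = '<p><em id="s1">这是原句。</em></p>'
result_html = '<p><em id="s1" class="similar red">这是降重后的句子。</em></p>'

source_em = BeautifulSoup(source_html, "html.parser").select('p > em')[0]
result_em = BeautifulSoup(result_html, "html.parser").select('p > em')[0]

print(source_em["id"] == result_em["id"])   # True: sentences are aligned by id
print(result_em["class"])                   # ['similar', 'red'], the flagged case kept above
print(source_em.string, result_em.string)   # the paired texts that get exported to Excel
```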
@@ -0,0 +1,186 @@ (new analysis script: bucket original/rewritten training pairs by difflib similarity; also defines BERT, ROUGE/BLEU and word2vec helpers)

```python
# -*- coding: utf-8 -*-
"""
@Time : 2023/1/31 19:02
@Author :
@FileName:
@Software:
@Describe:
"""
import os
# os.environ["TF_KERAS"] = "1"
import pandas as pd

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda, Dense
import tensorflow as tf
from keras.backend import set_session
from sklearn.metrics.pairwise import cosine_similarity
from rouge import Rouge  # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from tqdm import tqdm
import jieba
from gensim.models import KeyedVectors, word2vec, Word2Vec
import random
import difflib

# Allocate GPU memory on demand instead of grabbing it all at start-up.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))


class Word2vecModel:
    def __init__(self):
        # Windows path to a locally trained word2vec model (raw string keeps the backslashes literal).
        self.path = r"E:\pycharm_workspace\查重分析\word2vec_model\word2vec_add_new_18.model"
        self.model = Word2Vec.load(self.path)

    def word2vec_res(self, seg_0_list, seg_1_list):
        sentence_0_list = []
        sentence_1_list = []
        for i in seg_0_list:
            a = self.model.wv[i]
            sentence_0_list.append(a)

        for i in seg_1_list:
            a = self.model.wv[i]
            sentence_1_list.append(a)

        return sentence_0_list, sentence_1_list


class Evaluator(keras.callbacks.Callback):
    """Evaluation and checkpointing."""

    def __init__(self):
        self.rouge = Rouge()
        self.smooth = SmoothingFunction().method1
        self.best_bleu = 0.

    # def on_epoch_end(self, epoch, logs=None):
    #     metrics = self.evaluate(valid_data)  # evaluate the model
    #     if metrics['bleu'] > self.best_bleu:
    #         self.best_bleu = metrics['bleu']
    #         model.save_weights('./best_model.weights')  # save the model
    #     metrics['best_bleu'] = self.best_bleu
    #     print('valid_data:', metrics)

    def evaluate_t(self, data_1, data_2, topk=1):
        total = 0
        rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0

        scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2])
        rouge_1 += scores[0]['rouge-1']['f']
        rouge_2 += scores[0]['rouge-2']['f']
        rouge_l += scores[0]['rouge-l']['f']
        bleu += sentence_bleu(
            references=[data_1.split(' ')],
            hypothesis=data_2.split(' '),
            smoothing_function=self.smooth
        )
        # rouge_1 /= total
        # rouge_2 /= total
        # rouge_l /= total
        # bleu /= total
        return [rouge_1, rouge_2, rouge_l, bleu]


class bertModel:
    def __init__(self):
        self.config_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
        self.checkpoint_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
        self.dict_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
        self.token_dict, self.keep_tokens = load_vocab(
            dict_path=self.dict_path,
            simplified=True,
            startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
        )
        self.tokenizer = Tokenizer(self.token_dict, do_lower_case=True)
        self.buildmodel()

    def buildmodel(self):
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )

        # Use the [CLS] vector as the sentence representation.
        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        self.model = keras.models.Model(bert.model.input, output)
        self.model.summary()

    def predict(self, text):
        batch_token_ids, batch_segment_ids = [], []
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        return self.model.predict([batch_token_ids, batch_segment_ids])


def simbert(data_1, data_2):
    pass


def word2vec():
    pass


def bleu():
    pass


if __name__ == '__main__':
    file = "../data/train_yy_zong.txt"
    sim_value = [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0]
    model = bertModel()
    eval_class = Evaluator()
    # word2vecmodel = Word2vecModel()
    try:
        with open(file, 'r', encoding="utf-8") as f:
            lines = [x.strip() for x in f if x.strip() != '']
    except Exception:
        # Fall back to GBK when the file is not UTF-8 encoded.
        with open(file, 'r', encoding="gbk") as f:
            lines = [x.strip() for x in f if x.strip() != '']

    bertsim_list = []
    bleusim_list = []
    word2vecsim_list = []
    data_train_text = []

    random.shuffle(lines)
    print(len(lines))
    for txt in tqdm(lines):
        text = txt.split('\t')
        if len(text) == 3:
            data_1 = text[0]
            data_2 = text[2]
            str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
            # if len(data_2) - len(data_1) < 0 and len(data_2) / len(data_1) > 0.8:
            #     num_yu = 1 - len(data_2) / len(data_1)
            #     str_sim_value = 1 - str_sim_value * num_yu

            # Bucket each pair into a similarity band.
            if 1 >= str_sim_value > 0.95:
                data_train_text.append([data_1, data_2, str(str_sim_value), "1-0.95"])
            elif 0.95 >= str_sim_value > 0.9:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.95-0.9"])
            elif 0.9 >= str_sim_value > 0.85:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.9-0.85"])
            elif 0.85 >= str_sim_value > 0.8:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.85-0.8"])
            elif 0.8 >= str_sim_value > 0.75:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.8-0.75"])
            elif 0.75 >= str_sim_value > 0.7:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.75-0.7"])
            else:
                data_train_text.append([data_1, data_2, str(str_sim_value), "0.7 - 0"])

    # Sort numerically on the similarity column (it is stored as a string).
    data_train_text = sorted(data_train_text, key=lambda x: float(x[2]), reverse=True)
    df = pd.DataFrame(data_train_text)
    print(df)
    df.to_csv("../data/yy改写相似度.csv", index=False)
    df.to_excel("../data/yy改写相似度.xlsx", index=False)
```
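The bertModel class above exposes the [CLS] vector of a sentence, and cosine_similarity is imported but never called. A hedged sketch of how the two could be combined into a semantic similarity score, assuming it is appended to the script above so that bertModel and cosine_similarity are in scope:

```python
# Sketch: a semantic similarity score built from the [CLS] vectors above.
# Not wired into the script; intended to be appended to it, so that
# bertModel and cosine_similarity are already defined.
def bert_cls_similarity(model, text_1, text_2):
    vec_1 = model.predict(text_1)  # CLS embedding, shape (1, hidden_size)
    vec_2 = model.predict(text_2)
    return float(cosine_similarity(vec_1, vec_2)[0][0])

# Usage (requires the checkpoint paths configured in bertModel to exist):
# model = bertModel()
# print(bert_cls_similarity(model, "张三要爬上高位的,才能够翻云覆雨。",
#                                  "张三要上了巅峰,他就可以为所欲为了。"))
```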
@@ -0,0 +1,277 @@ (new Flask service: splits input paragraphs into sentences, rewrites them with the predict_t5 model, and reassembles the result)

```python
import os
# os.environ["TF_KERAS"] = "1"
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from flask import Flask, jsonify
from flask import request
# from linshi import autotitle
import requests
from predict_t5 import autotitle

import re
import logging

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

config = {
    "batch_szie": 1000
}


def get_dialogs_index(line: str):
    """
    Find the dialogue spans in a line of text.
    :param line: the text
    :return: dialogs_text   the dialogue contents
             dialogs_index  character indices covered by dialogue
             other_index    character indices of everything else
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        all_ = [i for i in range(dialog.start(), dialog.end())]
        dialogs_index.extend(all_)
    other_index = [i for i in range(len(line)) if i not in dialogs_index]

    return dialogs_text, dialogs_index, other_index


def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
    # Recursively split a long sentence into chunks of at most 120 characters,
    # cutting at the last punctuation mark inside each 120-character window;
    # short_num records the chunk order within the original sentence.
    fuhao = [",", "?", "!", "…"]
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
            if text_2 != "":
                if i + 1 != 120:
                    text_2 = text_1[i + 1:] + text_2
            break
    # else:
    #     chulipangban_return_list.append(text_1)
    if text_1_new == "":
        # No punctuation found: keep the hard 120-character cut.
        chulipangban_return_list.append([text_1, snetence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list


def chulipangban_test_1(text, snetence_id):
    # Split a paragraph into sentences on "。"; long sentences are further
    # split into <=120-character chunks so the model input stays short.
    sentence_list = text.split("。")
    # sentence_list_new = []
    # for i in sentence_list:
    #     if i != "":
    #         sentence_list_new.append(i)
    # sentence_list = sentence_list_new
    sentence_batch_list = []
    sentence_batch_one = []
    sentence_batch_length = 0
    return_list = []
    for sentence in sentence_list:
        if len(sentence) < 120:
            sentence_batch_length += len(sentence)
            sentence_batch_list.append([sentence, snetence_id, 0])
            # sentence_pre = autotitle.gen_synonyms_short(sentence)
            # return_list.append(sentence_pre)
        else:
            sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
            for sentence_short in sentence_split_list:
                sentence_batch_list.append(sentence_short)
    return sentence_batch_list


def paragraph_test_(text: list, text_new: list):

    for i in range(len(text)):
        text = chulipangban_test_1(text, i)
        text = "。".join(text)
        text_new.append(text)

    # text_new_str = "".join(text_new)
    return text_new


def paragraph_test(text: list):
    # Flatten every input paragraph into [chunk, paragraph_id, chunk_id] triples.
    text_new = []
    for i in range(len(text)):
        text_list = chulipangban_test_1(text[i], i)
        text_new.extend(text_list)

    # text_new_str = "".join(text_new)
    return text_new


def batch_data_process(text_list):
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []

    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            sentence_batch_length = 0
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = []
            sentence_batch_one.append(sentence_)
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list


def batch_predict(batch_data_list):
    """Predict one batch of data."""
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_snetence_id_list = []
    for i in batch_data_list:
        batch_data_text_list.append(i[0])
        batch_data_snetence_id_list.append(i[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list
    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
        batch_data_list_new.append([text] + sentence_id)

    return batch_data_list_new


def one_predict(data_text):
    """Predict a single chunk; empty chunks are passed through unchanged."""
    if data_text[0] != "":
        pre_data = autotitle.generate(data_text[0])
    else:
        pre_data = ""
    data_new = [pre_data] + data_text[1:]
    return data_new


def predict_data_post_processing(text_list):
    # Re-join chunks that came from the same long sentence (chunk_id != 0),
    # then re-join sentences that belong to the same paragraph.
    text_list_sentence = []
    # text_list_sentence.append([text_list[0][0], text_list[0][1]])
    for i in range(len(text_list)):
        if text_list[i][2] != 0:
            text_list_sentence[-1][0] += text_list[i][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])

    return_list = []
    sentence_one = []
    sentence_id = 0
    for i in text_list_sentence:
        if i[1] == sentence_id:
            sentence_one.append(i[0])
        else:
            sentence_id = i[1]
            return_list.append("。".join(sentence_one))
            sentence_one = []
            sentence_one.append(i[0])
    if sentence_one != []:
        return_list.append("。".join(sentence_one))
    return return_list


# def main(text:list):
#     # text_list = paragraph_test(text)
#     # batch_data = batch_data_process(text_list)
#     # text_list = []
#     # for i in batch_data:
#     #     text_list.extend(i)
#     # return_list = predict_data_post_processing(text_list)
#     # return return_list


def main(text: list):
    text_list = paragraph_test(text)
    text_list_new = []
    for i in text_list:
        pre = one_predict(i)
        text_list_new.append(pre)
    return_list = predict_data_post_processing(text_list_new)
    return return_list


@app.route('/droprepeat/', methods=['POST'])
def sentence():
    print(request.remote_addr)
    texts = request.json["texts"]
    text_type = request.json["text_type"]
    print("原始语句" + str(texts))
    # question = question.strip('。、!??')

    if isinstance(texts, list):
        texts_list = []
        y_pred_label_list = []
        position_list = []

        # texts = texts.replace('\'', '\"')
        if texts is None:
            return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
            return jsonify(return_text)
        else:
            assert text_type in ['focus', 'chapter']
            if text_type == 'focus':
                texts_list = main(texts)
            if text_type == 'chapter':
                texts_list = main(texts)

            return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
    else:
        return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False}
    return jsonify(return_text)


# @app.route('/chapter/', methods=['POST'])
# def chapter():
#     texts = request.json["texts"]
#
#     print("原始语句" + str(texts))
#     # question = question.strip('。、!??')
#
#     if isinstance(texts, str):
#         texts_list = []
#         y_pred_label_list = []
#         position_list = []
#
#         # texts = texts.replace('\'', '\"')
#         if texts is None:
#             return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
#             return jsonify(return_text)
#         else:
#             texts = texts.split("\n")
#             for text in texts:
#                 text = text.strip()
#                 return_str = autotitle.generate_random_shortest(text)
#                 texts_list.append(return_str)
#             texts_str = "\n".join(texts_list)
#             return_text = {"texts": texts_str, "probabilities": None, "status_code": True}
#     else:
#         return_text = {"texts": "输入格式应该为str", "probabilities": None, "status_code": False}
#     return jsonify(return_text)


if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)
```
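For a quick local check without run_app.sh or gunicorn, Flask's built-in test client can exercise the /droprepeat/ route directly. The sketch below assumes the module above is importable (the name flask_predict_no_batch is an assumption) and that predict_t5.autotitle loads, which requires the trained model files:

```python
# Sketch: smoke-testing the /droprepeat/ route with Flask's test client.
# The module name below is an assumption; predict_t5.autotitle must load,
# which requires the trained model files to be present.
# from flask_predict_no_batch import app

def smoke_test(app):
    client = app.test_client()
    resp = client.post(
        "/droprepeat/",
        json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"], "text_type": "focus"},
    )
    payload = resp.get_json()
    # Expected shape per the handler above:
    # {"texts": [...], "probabilities": None, "status_code": True}
    print(payload["status_code"], payload["texts"])
```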