diff --git a/chatgpt_detector_model_predict.py b/chatgpt_detector_model_predict.py
index 75f01dc..dbe9789 100644
--- a/chatgpt_detector_model_predict.py
+++ b/chatgpt_detector_model_predict.py
@@ -30,9 +30,13 @@ db_key_querying = 'querying'
db_key_queryset = 'queryset'
batch_size = 32
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-tokenizer = AutoTokenizer.from_pretrained("chatgpt-detector-roberta-chinese")
+# tokenizer = AutoTokenizer.from_pretrained("chatgpt-detector-roberta-chinese")
# model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cuda()
-model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cpu()
+# model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cpu()
+model_name = "AIGC_detector_zhv2"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name).cpu()
def model_preidct(text):
tokenized_text = tokenizer.encode_plus(text, max_length=512, add_special_tokens=True,
@@ -80,28 +84,42 @@ def main(content_list: list):
sim_word = 0
sim_word_5_9 = 0
total_words = 0
+    # print(content_list)  # NOTE(review): leftover debug output; keep disabled in production
total_paragraph = len(content_list)
- for i in range(len(content_list)):
- total_words += len(content_list[i])
- res = model_preidct(content_list[i])
+
+ for i in range(0, len(content_list), 3):
+ if i + 2 <= len(content_list)-1:
+ sen_nums = 3
+ content_str = "。".join([content_list[i], content_list[i+1], content_list[i+2]])
+ elif i + 1 <= len(content_list)-1:
+ sen_nums = 2
+ content_str = "。".join([content_list[i], content_list[i + 1]])
+ else:
+ sen_nums = 1
+ content_str = content_list[i]
+ total_words += len(content_str)
+ res = model_preidct(content_str)
# return_list = {
# "humen": output[0][0],
# "robot": output[0][1]
# }
if res["robot"] > 0.9:
- gpt_score_list.append(res["robot"])
- sim_word += len(content_list[i])
- gpt_content.append(
- "".format(str(i)) + content_list[i] + "。\n" + "")
+ for ci in range(sen_nums):
+ gpt_score_list.append(res["robot"])
+ sim_word += len(content_list[i + ci])
+ gpt_content.append(
+ "".format(str(i + ci)) + content_list[i + ci] + "。\n" + "")
elif 0.9 > res["robot"] > 0.5:
- gpt_score_list.append(res["robot"])
- sim_word_5_9 += len(content_list[i])
- gpt_content.append(
- "".format(str(i)) + content_list[i] + "。\n" + "")
+ for ci in range(sen_nums):
+ gpt_score_list.append(res["robot"])
+ sim_word_5_9 += len(content_list[i + ci])
+ gpt_content.append(
+ "".format(str(i + ci)) + content_list[i + ci] + "。\n" + "")
else:
- gpt_score_list.append(0)
- gpt_content.append(content_list[i] + "。\n")
+ for ci in range(sen_nums):
+ gpt_score_list.append(0)
+ gpt_content.append(content_list[i + ci] + "。\n")
return_list["gpt_content"] = "".join(gpt_content)
return_list["gpt_score_list"] = str(gpt_score_list)
@@ -114,7 +132,6 @@ def main(content_list: list):
def classify(): # 调用模型,设置最大batch_size
while True:
- try:
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
time.sleep(3)
continue
@@ -172,8 +189,7 @@ def classify(): # 调用模型,设置最大batch_size
json.dump(return_text, f2, ensure_ascii=False, indent=4)
redis_.set(queue_uuid, load_result_path, 86400)
redis_.srem(db_key_querying, queue_uuid)
- except:
- continue
+
if __name__ == '__main__':