diff --git a/chatgpt_detector_model_predict.py b/chatgpt_detector_model_predict.py
index 75f01dc..dbe9789 100644
--- a/chatgpt_detector_model_predict.py
+++ b/chatgpt_detector_model_predict.py
@@ -30,9 +30,13 @@ db_key_querying = 'querying'
 db_key_queryset = 'queryset'
 batch_size = 32
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-tokenizer = AutoTokenizer.from_pretrained("chatgpt-detector-roberta-chinese")
+# tokenizer = AutoTokenizer.from_pretrained("chatgpt-detector-roberta-chinese")
 # model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cuda()
-model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cpu()
+# model = AutoModelForSequenceClassification.from_pretrained("chatgpt-detector-roberta-chinese").cpu()
+model_name = "AIGC_detector_zhv2"
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name).cpu()
 
 def model_preidct(text):
     tokenized_text = tokenizer.encode_plus(text, max_length=512, add_special_tokens=True,
@@ -80,28 +84,42 @@ def main(content_list: list):
     sim_word = 0
     sim_word_5_9 = 0
     total_words = 0
+    print(content_list)
     total_paragraph = len(content_list)
-    for i in range(len(content_list)):
-        total_words += len(content_list[i])
-        res = model_preidct(content_list[i])
+
+    for i in range(0, len(content_list), 3):
+        if i + 2 <= len(content_list)-1:
+            sen_nums = 3
+            content_str = "。".join([content_list[i], content_list[i+1], content_list[i+2]])
+        elif i + 1 <= len(content_list)-1:
+            sen_nums = 2
+            content_str = "。".join([content_list[i], content_list[i + 1]])
+        else:
+            sen_nums = 1
+            content_str = content_list[i]
+        total_words += len(content_str)
+        res = model_preidct(content_str)
         # return_list = {
         #     "humen": output[0][0],
         #     "robot": output[0][1]
         # }
         if res["robot"] > 0.9:
-            gpt_score_list.append(res["robot"])
-            sim_word += len(content_list[i])
-            gpt_content.append(
-                "".format(str(i)) + content_list[i] + "。\n" + "")
+            for ci in range(sen_nums):
+                gpt_score_list.append(res["robot"])
+                sim_word += len(content_list[i + ci])
+                gpt_content.append(
+                    "".format(str(i + ci)) + content_list[i + ci] + "。\n" + "")
         elif 0.9 > res["robot"] > 0.5:
-            gpt_score_list.append(res["robot"])
-            sim_word_5_9 += len(content_list[i])
-            gpt_content.append(
-                "".format(str(i)) + content_list[i] + "。\n" + "")
+            for ci in range(sen_nums):
+                gpt_score_list.append(res["robot"])
+                sim_word_5_9 += len(content_list[i + ci])
+                gpt_content.append(
+                    "".format(str(i + ci)) + content_list[i + ci] + "。\n" + "")
         else:
-            gpt_score_list.append(0)
-            gpt_content.append(content_list[i] + "。\n")
+            for ci in range(sen_nums):
+                gpt_score_list.append(0)
+                gpt_content.append(content_list[i + ci] + "。\n")
 
     return_list["gpt_content"] = "".join(gpt_content)
     return_list["gpt_score_list"] = str(gpt_score_list)
 
@@ -114,7 +132,6 @@
 
 def classify():  # 调用模型,设置最大batch_size
     while True:
-        try:
         if redis_.llen(db_key_query) == 0:  # 若队列中没有元素就继续获取
             time.sleep(3)
             continue
@@ -172,8 +189,7 @@ def classify():  # 调用模型,设置最大batch_size
             json.dump(return_text, f2, ensure_ascii=False, indent=4)
         redis_.set(queue_uuid, load_result_path, 86400)
         redis_.srem(db_key_querying, queue_uuid)
-        except:
-            continue
+
 
 
 if __name__ == '__main__':