From f3de68efb8bf3741c6b036a96ad63801f4618902 Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Mon, 27 Feb 2023 12:01:25 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E4=B8=BA=E6=8C=89=E7=85=A7?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E7=BD=AE=E4=BF=A1=E5=BA=A6=E6=8E=92=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- demo06_class_roformer_pred.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/demo06_class_roformer_pred.py b/demo06_class_roformer_pred.py index 9fdc57b..1bbf282 100644 --- a/demo06_class_roformer_pred.py +++ b/demo06_class_roformer_pred.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- import os -os.environ["CUDA_VISIBLE_DEVICES"] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" import json import random import keras @@ -74,12 +74,25 @@ def predict(text): sent_token_id, sent_segment_id = [], [] token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len) y_pred = model.predict([[token_ids], [segment_ids]]) - idx = np.where(y_pred[0] > label_threshold, 1, 0) + # print(len(y_pred[0])) + label_probability = {} + for index, threshold in zip(range(len(y_pred[0])), label_threshold): + label_probability[index] = y_pred[0][index] - threshold + label_probability = sorted(label_probability.items(), + key=lambda x: x[1], reverse=True) label_pre = [] - for i in range(len(idx)): - if idx[i] == 1: - label_pre.append(id2label[i]) + for i in label_probability: + if i[1] > 0: + label_pre.append(id2label[i[0]]) + if label_pre == []: + label_pre.append(id2label[label_probability[0][0]]) return label_pre + # idx = np.where(y_pred[0] > label_threshold, 1, 0) + # label_pre = [] + # for i in range(len(idx)): + # if idx[i] == 1: + # label_pre.append(id2label[i]) + # return label_pre if __name__ == '__main__': @@ -101,5 +114,4 @@ if __name__ == '__main__': label_pre = ",".join(label_pre) data_new.append(data_dan + [label_pre]) df = pd.DataFrame(data_new) - print(df) df.to_csv("./data/yy改写相似度含文章类别.csv", index=None) \ No newline at end of file