diff --git a/demo06_class_roformer_pred.py b/demo06_class_roformer_pred.py index 9fdc57b..1bbf282 100644 --- a/demo06_class_roformer_pred.py +++ b/demo06_class_roformer_pred.py @@ -1,7 +1,7 @@ # -*- coding:utf-8 -*- import os -os.environ["CUDA_VISIBLE_DEVICES"] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" import json import random import keras @@ -74,12 +74,25 @@ def predict(text): sent_token_id, sent_segment_id = [], [] token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len) y_pred = model.predict([[token_ids], [segment_ids]]) - idx = np.where(y_pred[0] > label_threshold, 1, 0) + # print(len(y_pred[0])) + label_probability = {} + for index, threshold in zip(range(len(y_pred[0])), label_threshold): + label_probability[index] = y_pred[0][index] - threshold + label_probability = sorted(label_probability.items(), + key=lambda x: x[1], reverse=True) label_pre = [] - for i in range(len(idx)): - if idx[i] == 1: - label_pre.append(id2label[i]) + for i in label_probability: + if i[1] > 0: + label_pre.append(id2label[i[0]]) + if label_pre == []: + label_pre.append(id2label[label_probability[0][0]]) return label_pre + # idx = np.where(y_pred[0] > label_threshold, 1, 0) + # label_pre = [] + # for i in range(len(idx)): + # if idx[i] == 1: + # label_pre.append(id2label[i]) + # return label_pre if __name__ == '__main__': @@ -101,5 +114,4 @@ if __name__ == '__main__': label_pre = ",".join(label_pre) data_new.append(data_dan + [label_pre]) df = pd.DataFrame(data_new) - print(df) df.to_csv("./data/yy改写相似度含文章类别.csv", index=None) \ No newline at end of file