|
@ -1,7 +1,7 @@ |
|
|
# -*- coding:utf-8 -*- |
|
|
# -*- coding:utf-8 -*- |
|
|
|
|
|
|
|
|
import os |
|
|
import os |
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1" |
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0" |
|
|
import json |
|
|
import json |
|
|
import random |
|
|
import random |
|
|
import keras |
|
|
import keras |
|
@ -74,12 +74,25 @@ def predict(text): |
|
|
sent_token_id, sent_segment_id = [], [] |
|
|
sent_token_id, sent_segment_id = [], [] |
|
|
token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len) |
|
|
token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len) |
|
|
y_pred = model.predict([[token_ids], [segment_ids]]) |
|
|
y_pred = model.predict([[token_ids], [segment_ids]]) |
|
|
idx = np.where(y_pred[0] > label_threshold, 1, 0) |
|
|
# print(len(y_pred[0])) |
|
|
|
|
|
label_probability = {} |
|
|
|
|
|
for index, threshold in zip(range(len(y_pred[0])), label_threshold): |
|
|
|
|
|
label_probability[index] = y_pred[0][index] - threshold |
|
|
|
|
|
label_probability = sorted(label_probability.items(), |
|
|
|
|
|
key=lambda x: x[1], reverse=True) |
|
|
label_pre = [] |
|
|
label_pre = [] |
|
|
for i in range(len(idx)): |
|
|
for i in label_probability: |
|
|
if idx[i] == 1: |
|
|
if i[1] > 0: |
|
|
label_pre.append(id2label[i]) |
|
|
label_pre.append(id2label[i[0]]) |
|
|
|
|
|
if label_pre == []: |
|
|
|
|
|
label_pre.append(id2label[label_probability[0][0]]) |
|
|
return label_pre |
|
|
return label_pre |
|
|
|
|
|
# idx = np.where(y_pred[0] > label_threshold, 1, 0) |
|
|
|
|
|
# label_pre = [] |
|
|
|
|
|
# for i in range(len(idx)): |
|
|
|
|
|
# if idx[i] == 1: |
|
|
|
|
|
# label_pre.append(id2label[i]) |
|
|
|
|
|
# return label_pre |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
@ -101,5 +114,4 @@ if __name__ == '__main__': |
|
|
label_pre = ",".join(label_pre) |
|
|
label_pre = ",".join(label_pre) |
|
|
data_new.append(data_dan + [label_pre]) |
|
|
data_new.append(data_dan + [label_pre]) |
|
|
df = pd.DataFrame(data_new) |
|
|
df = pd.DataFrame(data_new) |
|
|
print(df) |
|
|
|
|
|
df.to_csv("./data/yy改写相似度含文章类别.csv", index=None) |
|
|
df.to_csv("./data/yy改写相似度含文章类别.csv", index=None) |