From c32bca42c9f34d47752ac6a2e51c5893c8af35f0 Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Mon, 27 Feb 2023 11:20:54 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 144 +++++++++ README.md | 2 + config_json/label_threshold.json | 674 +++++++++++++++++++++++++++++++++++++++ demo06_class_roformer_pred.py | 105 ++++++ 4 files changed, 925 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 config_json/label_threshold.json create mode 100644 demo06_class_roformer_pred.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0f84355 --- /dev/null +++ b/.gitignore @@ -0,0 +1,144 @@ +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +/chinese_roberta_wwm_ext_L-12_H-768_A-12/ +/data/ +/output_models/ + diff --git a/README.md b/README.md new file mode 100644 index 0000000..48b4401 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# drop_classify + diff --git a/config_json/label_threshold.json b/config_json/label_threshold.json new file mode 100644 index 0000000..32007f1 --- /dev/null +++ b/config_json/label_threshold.json @@ -0,0 +1,674 @@ +{ + "矿业工程": [ + 0, + -0.5 + ], + "汽车工业": [ + 1, + -1.3 + ], + "哲学": [ + 2, + -0.6 + ], + "新能源": [ + 3, + -3.8 + ], + "机械工业": [ + 4, + -2.5 + ], + "贸易经济": [ + 5, + -1.7 + ], + "刑法": [ + 6, + 0 + ], + "中国古代史": [ + 7, + -2.9 + ], + "投资": [ + 8, + -1.5 + ], + "肿瘤学": [ + 9, + 0 + ], + "预防医学与卫生学": [ + 10, + -2.4 + ], + "水利水电工程": [ + 11, + -1.5 + ], + "中国近现代史": [ + 12, + -2.5 + ], + "中药学": [ + 13, + -0.3 + ], + "管理学": [ + 14, + -5.5 + ], + "公安": [ + 15, + -2.0 + ], + "国际法": [ + 16, + -0.3 + ], + "医药卫生方针政策与法律法规研究": [ + 17, + -2.1 + ], + "社会科学理论与方法": [ + 18, + -11.4 + ], + "高等教育": [ + 19, + -0.4 + ], + "经济统计": [ + 20, + -7.4 + ], + "天文学": [ + 21, + -1.3 + ], + "中国文学": [ + 22, + 0 + ], + "史学理论": [ + 23, + -3.7 + ], + "燃料化工": [ + 24, + -2.2 + ], + "农作物": [ + 25, + -0.0 + ], + "军事医学与卫生": [ + 26, + -2.8 + ], + "行政法及地方法制": [ + 27, + -0.9 + ], + "无机化工": [ + 28, + -1.9 + ], + "社会学及统计学": [ + 29, + -3.3 + ], + "保险": [ + 30, + 0 + ], + "金属学及金属工艺": [ + 31, + -0.2 + ], + "旅游": [ + 32, + -1.6 + ], + "仪器仪表工业": [ + 33, + -3.0 + ], + "中医学": [ + 34, + 0 + ], + "领导学与决策学": [ + 35, + -5.4 + ], + "企业经济": [ + 36, + -1.0 + ], + "急救医学": [ + 37, + -1.0 + ], + "美术书法雕塑与摄影": [ + 38, + 0 + ], + "自然地理学和测绘学": [ + 39, + -2.9 + ], + "园艺": [ + 40, + -0.9 + ], + "出版": [ + 41, + -2.4 + ], + "经济体制改革": [ + 42, + -1.0 + ], + "自动化技术": [ + 43, + -1.9 + ], + "神经病学": [ + 44, + -0.1 + ], + "海洋学": [ + 45, + -2.3 + ], + "人口学与计划生育": [ + 46, + -0.4 + ], + "轻工业手工业": [ + 47, + -1.2 + ], + "会计": [ + 48, + -1.0 + ], + "化学": [ + 49, + -1.7 + ], + "农业基础科学": [ + 50, + -1.8 + ], + "学前教育": [ + 51, + 0 + ], + "中等教育": [ + 52, + 0 + ], + "世界文学": [ + 53, + 0 + ], + "中国语言文字": [ + 54, + 0 + ], + "中国民族与地方史志": [ + 55, + -3.1 + ], + "新闻与传媒": [ + 56, + -0.2 + ], + "工业通用技术及设备": [ + 57, + -3.4 + ], + "文艺理论": [ + 58, + -0.9 + ], + "市场研究与信息": [ + 59, + -2.9 + ], + "呼吸系统疾病": [ + 60, + -1.3 + ], + "心血管系统疾病": [ + 61, + -0.8 + ], + "考古": [ + 62, + -1.4 + ], + "戏剧电影与电视艺术": [ + 63, + 0 + ], + "畜牧与动物医学": [ + 64, + 0 + ], + "体育": [ + 65, + 0 + ], + "伦理学": [ + 66, + -2.3 + ], + "材料科学": [ + 67, + -1.3 + ], + "外科学": [ + 68, + -0.6 + ], + "民族学": [ + 69, + -2.6 + ], + "交通运输经济": [ + 70, + -2.4 + ], + "世界历史": [ + 71, + -1.5 + ], + "音乐舞蹈": [ + 72, + 0 + ], + "铁路运输": [ + 73, + -0.4 + ], + "心理学": [ + 74, + -0.5 + ], + "诉讼法与司法制度": [ + 75, + 0 + ], + "物理学": [ + 76, + -3.3 + ], + "初等教育": [ + 77, + 0 + ], + "一般化学工业": [ + 78, + -2.2 + ], + "政党及群众组织": [ + 79, + -1.9 + ], + "自然科学理论与方法": [ + 80, + -7.9 + ], + "计算机软件及计算机应用": [ + 81, + -1.8 + ], + "蚕蜂与野生动物保护": [ + 82, + -2.2 + ], + "水产和渔业": [ + 83, + -0.9 + ], + "航空航天科学与工程": [ + 84, + -1.1 + ], + "内分泌腺及全身性疾病": [ + 85, + -0.8 + ], + "武器工业与军事技术": [ + 86, + -2.7 + ], + "无线电电子学": [ + 87, + -2.8 + ], + "临床医学": [ + 88, + -2.1 + ], + "资源科学": [ + 89, + -1.3 + ], + "经济理论及经济思想史": [ + 90, + -2.9 + ], + "民商法": [ + 91, + 0 + ], + "服务业经济": [ + 92, + -3.8 + ], + "皮肤病与性病": [ + 93, + -1.7 + ], + "特种医学": [ + 94, + -3.2 + ], + "逻辑学": [ + 95, + -3.3 + ], + "工业经济": [ + 96, + -0.9 + ], + "数学": [ + 97, + -3.9 + ], + "宪法": [ + 98, + -1.7 + ], + "电力工业": [ + 99, + -1.0 + ], + "农艺学": [ + 100, + -2.3 + ], + "美学": [ + 101, + -3.2 + ], + "消化系统疾病": [ + 102, + -0.2 + ], + "军事": [ + 103, + -2.6 + ], + "感染性疾病及传染病": [ + 104, + -1.1 + ], + "公路与水路运输": [ + 105, + -1.7 + ], + "一般服务业": [ + 106, + -3.7 + ], + "动力工程": [ + 107, + -3.2 + ], + "计算机硬件技术": [ + 108, + -3.0 + ], + "核科学技术": [ + 109, + -1.8 + ], + "中国共产党": [ + 110, + -0.9 + ], + "成人教育与特殊教育": [ + 111, + -2.1 + ], + "船舶工业": [ + 112, + -0.9 + ], + "财政与税收": [ + 113, + 0 + ], + "政治学": [ + 114, + -3.5 + ], + "农业经济": [ + 115, + -0.4 + ], + "审计": [ + 116, + 0 + ], + "建筑科学与工程": [ + 117, + -1.2 + ], + "气象学": [ + 118, + -1.4 + ], + "教育理论与教育管理": [ + 119, + -1.3 + ], + "档案及博物馆": [ + 120, + -0.9 + ], + "环境科学与资源利用": [ + 121, + -0.9 + ], + "人才学与劳动科学": [ + 122, + -1.9 + ], + "植物保护": [ + 123, + -0.8 + ], + "中西医结合": [ + 124, + -3.1 + ], + "互联网技术": [ + 125, + -2.0 + ], + "药学": [ + 126, + -2.2 + ], + "思想政治教育": [ + 127, + -3.2 + ], + "儿科学": [ + 128, + 0 + ], + "生物医学工程": [ + 129, + -3.1 + ], + "经济法": [ + 130, + -0.5 + ], + "法理、法史": [ + 131, + -2.4 + ], + "口腔科学": [ + 132, + 0 + ], + "行政学及国家行政管理": [ + 133, + -2.8 + ], + "地质学": [ + 134, + -0.3 + ], + "非线性科学与系统科学": [ + 135, + -5.0 + ], + "宏观经济管理与可持续发展": [ + 136, + -2.2 + ], + "宗教": [ + 137, + -1.8 + ], + "精神病学": [ + 138, + -1.5 + ], + "眼科与耳鼻咽喉科": [ + 139, + 0 + ], + "生物学": [ + 140, + -2.5 + ], + "外国语言文字": [ + 141, + 0 + ], + "农业工程": [ + 142, + -1.8 + ], + "安全科学与灾害防治": [ + 143, + -2.7 + ], + "图书情报与数字图书馆": [ + 144, + -2.6 + ], + "泌尿科学": [ + 145, + -0.6 + ], + "力学": [ + 146, + -4.7 + ], + "文化": [ + 147, + -2.0 + ], + "地理": [ + 148, + -5.9 + ], + "中国通史": [ + 149, + -5.7 + ], + "妇产科学": [ + 150, + 0 + ], + "信息经济与邮政经济": [ + 151, + -2.1 + ], + "金融": [ + 152, + -0.3 + ], + "医学教育与医学边缘学科": [ + 153, + -2.1 + ], + "文化经济": [ + 154, + -2.5 + ], + "基础医学": [ + 155, + -3.7 + ], + "职业教育": [ + 156, + 0 + ], + "地球物理学": [ + 157, + -2.2 + ], + "林业": [ + 158, + -0.7 + ], + "石油天然气工业": [ + 159, + -0.4 + ], + "马克思主义": [ + 160, + -1.6 + ], + "中国政治与国际政治": [ + 161, + -2.7 + ], + "电信技术": [ + 162, + -1.0 + ], + "冶金工业": [ + 163, + -2.0 + ], + "有机化工": [ + 164, + -2.0 + ], + "科学研究管理": [ + 165, + -2.9 + ], + "证券": [ + 166, + -0.9 + ], + "人物传记": [ + 167, + -5.9 + ] +} \ No newline at end of file diff --git a/demo06_class_roformer_pred.py b/demo06_class_roformer_pred.py new file mode 100644 index 0000000..9fdc57b --- /dev/null +++ b/demo06_class_roformer_pred.py @@ -0,0 +1,105 @@ +# -*- coding:utf-8 -*- + +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "1" +import json +import random +import keras +import numpy as np +import pandas as pd +from bert4keras.backend import multilabel_categorical_crossentropy +from bert4keras.models import build_transformer_model +from bert4keras.optimizers import Adam +from bert4keras.snippets import DataGenerator, sequence_padding +from keras.layers import Lambda, Dense +from keras.models import Model +from bert4keras.tokenizers import Tokenizer +from tqdm import tqdm + +import tensorflow as tf +from keras.backend import set_session +config = tf.ConfigProto() +config.gpu_options.allow_growth = True +set_session(tf.Session(config=config)) # 此处不同 + +config_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json' +checkpoint_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt' +dict_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt' + + +class_nums = 168 +batch_size = 16 +max_len = 512 + +config_lable = './config_json/label_threshold.json' +weight_path = './output_models/best_model.weights' + + +tokenizer = Tokenizer(token_dict=dict_path) + +roformer = build_transformer_model( + config_path=config_path, + checkpoint_path=checkpoint_path, + model='roformer_v2', + return_keras_model=False +) + +output = Lambda(lambda x: x[:, 0])(roformer.model.output) + +output = Dense( + units=class_nums, + kernel_initializer=roformer.initializer +)(output) + +model = Model(roformer.model.input, output) +model.load_weights(weight_path) +model.summary() + + +def load_label1(): + with open(config_lable, 'r', + encoding='utf-8') as f: + labels_dict = json.load(f) + + id2label1 = {j[0]: i for i, j in labels_dict.items()} + label2id1 = {i: j[0] for i, j in labels_dict.items()} + label_threshold1 = np.array([j[1] for i, j in labels_dict.items()]) + + return id2label1, label2id1, label_threshold1 + +id2label, label2id, label_threshold = load_label1() + +def predict(text): + text = text[0] + sent_token_id, sent_segment_id = [], [] + token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len) + y_pred = model.predict([[token_ids], [segment_ids]]) + idx = np.where(y_pred[0] > label_threshold, 1, 0) + label_pre = [] + for i in range(len(idx)): + if idx[i] == 1: + label_pre.append(id2label[i]) + return label_pre + + +if __name__ == '__main__': + # text_list = ["你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你"] + # y_pred = predict(text_list) + # idx = np.where(y_pred[0] > label_threshold, 1, 0) + # label_pre = [] + # for i in range(len(idx)): + # if idx[i] == 1: + # label_pre.append(id2label[i]) + # print(label_pre) + + #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + data = pd.read_csv("data/yy改写相似度.csv").values.tolist() + + data_new = [] + for data_dan in tqdm(data): + label_pre = predict([data_dan[0]]) + label_pre = ",".join(label_pre) + data_new.append(data_dan + [label_pre]) + df = pd.DataFrame(data_new) + print(df) + df.to_csv("./data/yy改写相似度含文章类别.csv", index=None) \ No newline at end of file