Browse Source

第一次提交

master
majiahui@haimaqingfan.com 2 years ago
commit
c32bca42c9
  1. 144
      .gitignore
  2. 2
      README.md
  3. 674
      config_json/label_threshold.json
  4. 105
      demo06_class_roformer_pred.py

144
.gitignore

@ -0,0 +1,144 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
/chinese_roberta_wwm_ext_L-12_H-768_A-12/
/data/
/output_models/

2
README.md

@ -0,0 +1,2 @@
# drop_classify

674
config_json/label_threshold.json

@ -0,0 +1,674 @@
{
"矿业工程": [
0,
-0.5
],
"汽车工业": [
1,
-1.3
],
"哲学": [
2,
-0.6
],
"新能源": [
3,
-3.8
],
"机械工业": [
4,
-2.5
],
"贸易经济": [
5,
-1.7
],
"刑法": [
6,
0
],
"中国古代史": [
7,
-2.9
],
"投资": [
8,
-1.5
],
"肿瘤学": [
9,
0
],
"预防医学与卫生学": [
10,
-2.4
],
"水利水电工程": [
11,
-1.5
],
"中国近现代史": [
12,
-2.5
],
"中药学": [
13,
-0.3
],
"管理学": [
14,
-5.5
],
"公安": [
15,
-2.0
],
"国际法": [
16,
-0.3
],
"医药卫生方针政策与法律法规研究": [
17,
-2.1
],
"社会科学理论与方法": [
18,
-11.4
],
"高等教育": [
19,
-0.4
],
"经济统计": [
20,
-7.4
],
"天文学": [
21,
-1.3
],
"中国文学": [
22,
0
],
"史学理论": [
23,
-3.7
],
"燃料化工": [
24,
-2.2
],
"农作物": [
25,
-0.0
],
"军事医学与卫生": [
26,
-2.8
],
"行政法及地方法制": [
27,
-0.9
],
"无机化工": [
28,
-1.9
],
"社会学及统计学": [
29,
-3.3
],
"保险": [
30,
0
],
"金属学及金属工艺": [
31,
-0.2
],
"旅游": [
32,
-1.6
],
"仪器仪表工业": [
33,
-3.0
],
"中医学": [
34,
0
],
"领导学与决策学": [
35,
-5.4
],
"企业经济": [
36,
-1.0
],
"急救医学": [
37,
-1.0
],
"美术书法雕塑与摄影": [
38,
0
],
"自然地理学和测绘学": [
39,
-2.9
],
"园艺": [
40,
-0.9
],
"出版": [
41,
-2.4
],
"经济体制改革": [
42,
-1.0
],
"自动化技术": [
43,
-1.9
],
"神经病学": [
44,
-0.1
],
"海洋学": [
45,
-2.3
],
"人口学与计划生育": [
46,
-0.4
],
"轻工业手工业": [
47,
-1.2
],
"会计": [
48,
-1.0
],
"化学": [
49,
-1.7
],
"农业基础科学": [
50,
-1.8
],
"学前教育": [
51,
0
],
"中等教育": [
52,
0
],
"世界文学": [
53,
0
],
"中国语言文字": [
54,
0
],
"中国民族与地方史志": [
55,
-3.1
],
"新闻与传媒": [
56,
-0.2
],
"工业通用技术及设备": [
57,
-3.4
],
"文艺理论": [
58,
-0.9
],
"市场研究与信息": [
59,
-2.9
],
"呼吸系统疾病": [
60,
-1.3
],
"心血管系统疾病": [
61,
-0.8
],
"考古": [
62,
-1.4
],
"戏剧电影与电视艺术": [
63,
0
],
"畜牧与动物医学": [
64,
0
],
"体育": [
65,
0
],
"伦理学": [
66,
-2.3
],
"材料科学": [
67,
-1.3
],
"外科学": [
68,
-0.6
],
"民族学": [
69,
-2.6
],
"交通运输经济": [
70,
-2.4
],
"世界历史": [
71,
-1.5
],
"音乐舞蹈": [
72,
0
],
"铁路运输": [
73,
-0.4
],
"心理学": [
74,
-0.5
],
"诉讼法与司法制度": [
75,
0
],
"物理学": [
76,
-3.3
],
"初等教育": [
77,
0
],
"一般化学工业": [
78,
-2.2
],
"政党及群众组织": [
79,
-1.9
],
"自然科学理论与方法": [
80,
-7.9
],
"计算机软件及计算机应用": [
81,
-1.8
],
"蚕蜂与野生动物保护": [
82,
-2.2
],
"水产和渔业": [
83,
-0.9
],
"航空航天科学与工程": [
84,
-1.1
],
"内分泌腺及全身性疾病": [
85,
-0.8
],
"武器工业与军事技术": [
86,
-2.7
],
"无线电电子学": [
87,
-2.8
],
"临床医学": [
88,
-2.1
],
"资源科学": [
89,
-1.3
],
"经济理论及经济思想史": [
90,
-2.9
],
"民商法": [
91,
0
],
"服务业经济": [
92,
-3.8
],
"皮肤病与性病": [
93,
-1.7
],
"特种医学": [
94,
-3.2
],
"逻辑学": [
95,
-3.3
],
"工业经济": [
96,
-0.9
],
"数学": [
97,
-3.9
],
"宪法": [
98,
-1.7
],
"电力工业": [
99,
-1.0
],
"农艺学": [
100,
-2.3
],
"美学": [
101,
-3.2
],
"消化系统疾病": [
102,
-0.2
],
"军事": [
103,
-2.6
],
"感染性疾病及传染病": [
104,
-1.1
],
"公路与水路运输": [
105,
-1.7
],
"一般服务业": [
106,
-3.7
],
"动力工程": [
107,
-3.2
],
"计算机硬件技术": [
108,
-3.0
],
"核科学技术": [
109,
-1.8
],
"中国共产党": [
110,
-0.9
],
"成人教育与特殊教育": [
111,
-2.1
],
"船舶工业": [
112,
-0.9
],
"财政与税收": [
113,
0
],
"政治学": [
114,
-3.5
],
"农业经济": [
115,
-0.4
],
"审计": [
116,
0
],
"建筑科学与工程": [
117,
-1.2
],
"气象学": [
118,
-1.4
],
"教育理论与教育管理": [
119,
-1.3
],
"档案及博物馆": [
120,
-0.9
],
"环境科学与资源利用": [
121,
-0.9
],
"人才学与劳动科学": [
122,
-1.9
],
"植物保护": [
123,
-0.8
],
"中西医结合": [
124,
-3.1
],
"互联网技术": [
125,
-2.0
],
"药学": [
126,
-2.2
],
"思想政治教育": [
127,
-3.2
],
"儿科学": [
128,
0
],
"生物医学工程": [
129,
-3.1
],
"经济法": [
130,
-0.5
],
"法理、法史": [
131,
-2.4
],
"口腔科学": [
132,
0
],
"行政学及国家行政管理": [
133,
-2.8
],
"地质学": [
134,
-0.3
],
"非线性科学与系统科学": [
135,
-5.0
],
"宏观经济管理与可持续发展": [
136,
-2.2
],
"宗教": [
137,
-1.8
],
"精神病学": [
138,
-1.5
],
"眼科与耳鼻咽喉科": [
139,
0
],
"生物学": [
140,
-2.5
],
"外国语言文字": [
141,
0
],
"农业工程": [
142,
-1.8
],
"安全科学与灾害防治": [
143,
-2.7
],
"图书情报与数字图书馆": [
144,
-2.6
],
"泌尿科学": [
145,
-0.6
],
"力学": [
146,
-4.7
],
"文化": [
147,
-2.0
],
"地理": [
148,
-5.9
],
"中国通史": [
149,
-5.7
],
"妇产科学": [
150,
0
],
"信息经济与邮政经济": [
151,
-2.1
],
"金融": [
152,
-0.3
],
"医学教育与医学边缘学科": [
153,
-2.1
],
"文化经济": [
154,
-2.5
],
"基础医学": [
155,
-3.7
],
"职业教育": [
156,
0
],
"地球物理学": [
157,
-2.2
],
"林业": [
158,
-0.7
],
"石油天然气工业": [
159,
-0.4
],
"马克思主义": [
160,
-1.6
],
"中国政治与国际政治": [
161,
-2.7
],
"电信技术": [
162,
-1.0
],
"冶金工业": [
163,
-2.0
],
"有机化工": [
164,
-2.0
],
"科学研究管理": [
165,
-2.9
],
"证券": [
166,
-0.9
],
"人物传记": [
167,
-5.9
]
}

105
demo06_class_roformer_pred.py

@ -0,0 +1,105 @@
# -*- coding:utf-8 -*-
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json
import random
import keras
import numpy as np
import pandas as pd
from bert4keras.backend import multilabel_categorical_crossentropy
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import DataGenerator, sequence_padding
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer
from tqdm import tqdm
import tensorflow as tf
from keras.backend import set_session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config)) # 此处不同
config_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/bert_model.ckpt'
dict_path = 'chinese_roformer-v2-char_L-12_H-768_A-12/vocab.txt'
class_nums = 168
batch_size = 16
max_len = 512
config_lable = './config_json/label_threshold.json'
weight_path = './output_models/best_model.weights'
tokenizer = Tokenizer(token_dict=dict_path)
roformer = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
model='roformer_v2',
return_keras_model=False
)
output = Lambda(lambda x: x[:, 0])(roformer.model.output)
output = Dense(
units=class_nums,
kernel_initializer=roformer.initializer
)(output)
model = Model(roformer.model.input, output)
model.load_weights(weight_path)
model.summary()
def load_label1():
with open(config_lable, 'r',
encoding='utf-8') as f:
labels_dict = json.load(f)
id2label1 = {j[0]: i for i, j in labels_dict.items()}
label2id1 = {i: j[0] for i, j in labels_dict.items()}
label_threshold1 = np.array([j[1] for i, j in labels_dict.items()])
return id2label1, label2id1, label_threshold1
id2label, label2id, label_threshold = load_label1()
def predict(text):
text = text[0]
sent_token_id, sent_segment_id = [], []
token_ids, segment_ids = tokenizer.encode(text, maxlen=max_len)
y_pred = model.predict([[token_ids], [segment_ids]])
idx = np.where(y_pred[0] > label_threshold, 1, 0)
label_pre = []
for i in range(len(idx)):
if idx[i] == 1:
label_pre.append(id2label[i])
return label_pre
if __name__ == '__main__':
# text_list = ["你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你你"]
# y_pred = predict(text_list)
# idx = np.where(y_pred[0] > label_threshold, 1, 0)
# label_pre = []
# for i in range(len(idx)):
# if idx[i] == 1:
# label_pre.append(id2label[i])
# print(label_pre)
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
data = pd.read_csv("data/yy改写相似度.csv").values.tolist()
data_new = []
for data_dan in tqdm(data):
label_pre = predict([data_dan[0]])
label_pre = "".join(label_pre)
data_new.append(data_dan + [label_pre])
df = pd.DataFrame(data_new)
print(df)
df.to_csv("./data/yy改写相似度含文章类别.csv", index=None)
Loading…
Cancel
Save