# -*- coding: utf-8 -*-
"""
@Time : 2023/3/10 18:53
@Author :
@FileName:
@Software:
@Describe: Read the first column of a CSV file, encode it in batches with a
           Chinese RoBERTa-wwm-ext BERT checkpoint via src.basemodel.ClassifyModel,
           and save the stacked 768-column output with numpy.save.
"""
|
|
import os
|
|
# GPU selection must happen before keras (imported below) initializes its
# backend: enumerate devices in PCI bus order and expose only device "0".
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")
|
|
# import pickle
|
|
# import redis
|
|
# from redis import ConnectionPool
|
|
# app = Flask(__name__)
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
import json
|
|
from keras.layers import *
|
|
from tqdm import tqdm
|
|
import time
|
|
from src.basemodel import ClassifyModel
|
|
|
|
|
|
def first_column_values(df):
    """Return the values of the first column of *df* as a plain Python list.

    Equivalent to ``[row[0] for row in df.values.tolist()]`` but does not
    materialize every column of every row just to keep the first one.
    """
    return df.iloc[:, 0].tolist()


def encode_in_batches(model, token_batches, segment_batches, dim=768):
    """Run ``model.predict`` on each (tokens, segments) batch and stack the rows.

    Args:
        model: object with a ``predict(tokens, segments)`` method that is
            expected to return a 2-D array with ``dim`` columns.
        token_batches: iterable of token-id batches.
        segment_batches: iterable of segment-id batches, paired with
            ``token_batches`` by position (as with ``zip``, extra items in the
            longer iterable are ignored).
        dim: number of columns every ``predict`` output must have.

    Returns:
        A float64 array of shape ``(total_rows, dim)``; ``(0, dim)`` when there
        are no batches.

    Raises:
        ValueError: if a ``predict`` output does not have ``dim`` columns.
    """
    outputs = [
        model.predict(tokens, segments)
        for tokens, segments in zip(token_batches, segment_batches)
    ]
    # A single concatenate at the end instead of growing an array inside the
    # loop (which re-copies all accumulated rows per batch). The empty float64
    # seed pins the column count (concatenate raises ValueError on mismatch),
    # keeps the no-batch case well-shaped, and promotes float32 outputs to
    # float64.
    return np.concatenate([np.empty((0, dim), dtype=np.float64), *outputs])


if __name__ == '__main__':
    # NOTE(review): maxlen is not referenced below; confirm whether it should
    # be passed to ClassifyModel / data_generator.
    maxlen = 256
    batch_size = 32

    # BERT configuration
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'

    # The input CSV and the output .npy share one per-document directory.
    data_dir = 'data/10235513_大型商业建筑人员疏散设计研究_沈福禹'
    csv_path = data_dir + '/查重.csv'

    # Load the texts before the (slow) model load so a bad input fails fast.
    texts = first_column_values(pd.read_csv(csv_path, encoding="utf-8"))
    if not texts:
        raise SystemExit('No rows found in ' + csv_path)
    print(texts[0])
    print(len(texts))

    model = ClassifyModel(config_path, checkpoint_path, dict_path,
                          is_train=False, load_weights_path=None)

    # data[0] holds token batches and data[1] the matching segment batches.
    data = model.data_generator(texts, batch_size)
    print(len(data[0][-1]))  # number of items in the last token batch

    features = encode_in_batches(model, data[0], data[1])
    print(features.shape)

    # np.save appends the ".npy" extension to this path.
    np.save(data_dir + '/save_x', features)
|