You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

61 lines
1.8 KiB

# -*- coding: utf-8 -*-
"""
@Time : 2023/3/10 18:53
@Author :
@FileName:
@Software:
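@Describe: Encode the first column of a CSV file with ClassifyModel and save the stacked outputs as save_x.npy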
"""
import os
# Enumerate GPUs in PCI bus order and expose only device 0 to this process.
# NOTE(review): these are set ahead of the keras import below, presumably so
# the framework sees the restriction when it initialises CUDA — confirm
# before moving them.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# import pickle
# import redis
# from redis import ConnectionPool
# app = Flask(__name__)
import numpy as np
import pandas as pd
import json
from keras.layers import *
from tqdm import tqdm
import time
from src.basemodel import ClassifyModel
def first_column_values(csv_path):
    """Return the first field of every data row in the UTF-8 CSV at *csv_path*.

    Args:
        csv_path: Path of the CSV file, read with ``pandas.read_csv``.

    Returns:
        A list with one entry per row: the value found in the first column.
    """
    frame = pd.read_csv(csv_path, encoding="utf-8")
    return [row[0] for row in frame.values.tolist()]


def stack_batches(batches, dim=768):
    """Concatenate 2-D per-batch arrays along axis 0 in a single pass.

    Concatenating once avoids re-copying the growing result for every batch,
    which is what calling ``np.concatenate`` inside the loop did.

    Args:
        batches: Iterable of array-likes, each of shape ``(rows, dim)``.
        dim: Size of the second axis. The default of 768 is the width the
            original script hard-coded.

    Returns:
        One array holding all rows in iteration order; shape ``(0, dim)``
        when *batches* is empty.
    """
    # The integer-typed empty seed is kept on purpose: it fixes the shape of
    # the empty result and takes part in NumPy's dtype promotion exactly as
    # the original loop's initial array did, so the saved dtype is unchanged.
    seed = np.empty((0, dim), dtype=int)
    return np.concatenate([seed, *batches])


if __name__ == '__main__':
    batch_size = 32
    # BERT configuration
    config_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
    # Directory that holds both the input CSV and the saved output array.
    data_dir = "data/10235513_大型商业建筑人员疏散设计研究_沈福禹"

    # Sample sentences left over from an earlier smoke test; only echoed here.
    texts = ["我们有个好朋友"] * 34
    print(texts)

    classifymodel = ClassifyModel(
        config_path,
        checkpoint_path,
        dict_path,
        is_train=False,
        load_weights_path=None,
    )

    csv_path = os.path.join(data_dir, "查重.csv")
    sentences = first_column_values(csv_path)
    if not sentences:
        # Fail with a clear message instead of an IndexError on sentences[0].
        raise ValueError(f"no rows found in {csv_path}")
    print(sentences[0])
    print(len(sentences))

    # data[0] and data[1] are iterated in lockstep below as the token and
    # segment inputs of each batch.
    data = classifymodel.data_generator(sentences, batch_size)
    # Presumably the size of the final (possibly short) batch — TODO confirm.
    print(len(data[0][-1]))

    embeddings = stack_batches(
        classifymodel.predict(token, segment)
        for token, segment in zip(data[0], data[1])
    )
    print(embeddings.shape)
    # np.save appends the ".npy" extension to the given path.
    np.save(os.path.join(data_dir, "save_x"), embeddings)