You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.6 KiB
46 lines
1.6 KiB
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
@Time : 2023/3/27 10:23
|
|
@Author :
|
|
@FileName:
|
|
@Software:
|
|
@Describe:
|
|
"""
|
|
import sys
|
|
import os
|
|
|
|
# Pretrained-model directories, keyed first by model family ("t5") and then
# by the ``sys.platform`` value of the host ("linux", "win32").
# Platforms without an entry here (e.g. macOS, where sys.platform is
# "darwin") are not configured.
# The Linux entry is a relative path, so it is interpreted relative to the
# current working directory of the process that uses it.
pre_model_path = dict(
    t5=dict(
        linux="mt5/mt5_base",
        win32=r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base",
    ),
)
|
|
|
|
|
|
class DropT5Config:
    """File paths and settings for the drop-T5 (mT5-base) model.

    By default the pretrained-model directory is looked up in the
    module-level ``pre_model_path["t5"]`` table using ``sys.platform``.
    Pass ``premodel_path`` to use a specific directory instead, for
    example on a platform that has no entry in that table.

    Args:
        premodel_path: Optional directory holding the pretrained mT5 files.
            When ``None`` (the default), the per-platform table is used.

    Raises:
        KeyError: If ``premodel_path`` is not given and the current
            ``sys.platform`` has no entry in ``pre_model_path["t5"]``.
    """

    def __init__(self, premodel_path=None):
        # Platform string used to select the default model directory.
        self.sys_platform = sys.platform

        if premodel_path is None:
            platform_paths = pre_model_path["t5"]
            if self.sys_platform not in platform_paths:
                # Same exception type as the bare dict lookup this replaces,
                # but with a message that says what to fix.
                raise KeyError(
                    "no pretrained T5 path configured for sys.platform %r "
                    "(configured: %s); pass premodel_path explicitly"
                    % (self.sys_platform, ", ".join(sorted(platform_paths)))
                )
            premodel_path = platform_paths[self.sys_platform]

        # Directory containing the pretrained mT5 files below.
        self.premodel_path = premodel_path
        # Model config JSON inside the pretrained directory.
        self.config_path = os.path.join(self.premodel_path, 'mt5_base_config.json')
        # Checkpoint path/prefix for the pretrained weights.
        self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
        # Tokenizer model file (a SentencePiece model, judging by its name).
        self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
        # JSON of tokens to keep; presumably a reduced-vocabulary id list — TODO confirm.
        self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
        # Weights file for the trained model (relative to the working directory).
        self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
        # Presumably the maximum sequence length in tokens — verify against callers.
        self.maxlen = 256
        # GPU id as a string; presumably fed to CUDA_VISIBLE_DEVICES — confirm.
        self.cuda_id = "0"
|
|
|
|
|
|
class MultipleResultsDropT5Config:
    """File paths and settings for the multiple-results drop-T5 setup.

    Field values currently match those of ``DropT5Config``.
    By default the pretrained-model directory is looked up in the
    module-level ``pre_model_path["t5"]`` table using ``sys.platform``.
    Pass ``premodel_path`` to use a specific directory instead, for
    example on a platform that has no entry in that table.

    Args:
        premodel_path: Optional directory holding the pretrained mT5 files.
            When ``None`` (the default), the per-platform table is used.

    Raises:
        KeyError: If ``premodel_path`` is not given and the current
            ``sys.platform`` has no entry in ``pre_model_path["t5"]``.
    """

    def __init__(self, premodel_path=None):
        # Platform string used to select the default model directory.
        self.sys_platform = sys.platform

        if premodel_path is None:
            platform_paths = pre_model_path["t5"]
            if self.sys_platform not in platform_paths:
                # Same exception type as the bare dict lookup this replaces,
                # but with a message that says what to fix.
                raise KeyError(
                    "no pretrained T5 path configured for sys.platform %r "
                    "(configured: %s); pass premodel_path explicitly"
                    % (self.sys_platform, ", ".join(sorted(platform_paths)))
                )
            premodel_path = platform_paths[self.sys_platform]

        # Directory containing the pretrained mT5 files below.
        self.premodel_path = premodel_path
        # Model config JSON inside the pretrained directory.
        self.config_path = os.path.join(self.premodel_path, 'mt5_base_config.json')
        # Checkpoint path/prefix for the pretrained weights.
        self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
        # Tokenizer model file (a SentencePiece model, judging by its name).
        self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
        # JSON of tokens to keep; presumably a reduced-vocabulary id list — TODO confirm.
        self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
        # NOTE(review): same weights file as DropT5Config.savemodel_path —
        # confirm the multiple-results setup is meant to share it.
        self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
        # Presumably the maximum sequence length in tokens — verify against callers.
        self.maxlen = 256
        # GPU id as a string; presumably fed to CUDA_VISIBLE_DEVICES — confirm.
        self.cuda_id = "0"
|