|
|
@ -4,15 +4,15 @@ |
|
|
|
|
|
|
|
|
|
|
|
from transformers import HfArgumentParser, TrainingArguments |
|
|
|
from utils import ModelArguments, load_pretrained |
|
|
|
from utils import ModelArguments, FinetuningArguments, load_pretrained |
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
|
|
|
parser = HfArgumentParser((ModelArguments, TrainingArguments)) |
|
|
|
model_args, training_args = parser.parse_args_into_dataclasses() |
|
|
|
parser = HfArgumentParser((ModelArguments, TrainingArguments, FinetuningArguments)) |
|
|
|
model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() |
|
|
|
|
|
|
|
model, tokenizer = load_pretrained(model_args) |
|
|
|
model, tokenizer = load_pretrained(model_args, finetuning_args) |
|
|
|
model.save_pretrained(training_args.output_dir, max_shard_size="1GB") |
|
|
|
tokenizer.save_pretrained(training_args.output_dir) |
|
|
|
|
|
|
|