You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
1.3 KiB
31 lines
1.3 KiB
![]()
2 years ago
|
import os
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Defaults carried over from the original hard-coded values; each one can be
# overridden positionally on the command line (see ``main``).
# NOTE(review): the default export name says "baichuan-fp32.flm", but the model
# below is an OpenBuddy LLaMA loaded in float16 and written with
# ``save_pretrained`` (an HF-format directory) — consider a more accurate name.
DEFAULT_EXPORT_PATH = "baichuan-fp32.flm"
DEFAULT_MODEL_PATH = "/home/majiahui/project/models-llm/openbuddy-llama-7b-v1.4-fp16"
DEFAULT_CHECKPOINT_DIR = (
    "/home/majiahui/project/LLaMA-Efficient-Tuning/"
    "path_to_sft_openbuddy_llama_paper_checkpoint_prompt_freeze_checkpoint/"
    "checkpoint-168000"
)
# File name of the fine-tuned weights inside the checkpoint directory.
WEIGHTS_NAME = "adapter_model.bin"


def _parse_args(argv):
    """Return ``(export_path, model_path, checkpoint_dir)`` from *argv*.

    ``argv[0]`` is the program name; any missing positional argument falls
    back to the corresponding module-level default.
    """
    export_path = argv[1] if len(argv) > 1 else DEFAULT_EXPORT_PATH
    model_path = argv[2] if len(argv) > 2 else DEFAULT_MODEL_PATH
    checkpoint_dir = argv[3] if len(argv) > 3 else DEFAULT_CHECKPOINT_DIR
    return export_path, model_path, checkpoint_dir


def _load_checkpoint_into(model, checkpoint_dir):
    """Load ``<checkpoint_dir>/adapter_model.bin`` into *model* non-strictly.

    Keys the checkpoint does not provide keep the model's current values.

    Raises:
        FileNotFoundError: if the weights file does not exist.
        RuntimeError: if none of the checkpoint's keys match the model, i.e.
            exporting would silently produce the unmodified base model.
    """
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(weights_file):
        raise FileNotFoundError(
            f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
        )

    # Load onto the CPU: the model is moved to the CPU before export anyway,
    # and this avoids requiring (and filling) GPU memory just for the copy.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(weights_file, map_location="cpu")

    # strict=False: the checkpoint may cover only part of the model (missing
    # keys are expected). Unexpected keys, however, mean weights were dropped.
    result = model.load_state_dict(state_dict, strict=False)
    unexpected = list(result.unexpected_keys)
    if state_dict and len(unexpected) == len(state_dict):
        raise RuntimeError(
            f"None of the {len(state_dict)} keys in {weights_file} match the model "
            f"(e.g. {unexpected[:3]}); refusing to export an unmodified base model."
        )
    if unexpected:
        print(
            f"warning: {len(unexpected)} checkpoint key(s) did not match the model "
            f"and were ignored, e.g. {unexpected[:5]}",
            file=sys.stderr,
        )


def main(argv=None):
    """Merge fine-tuned weights into a base causal LM and export it.

    Usage: ``script [EXPORT_PATH] [MODEL_PATH] [CHECKPOINT_DIR]``

    The base model is loaded in float16, the checkpoint weights are loaded on
    top, the model is moved to the CPU, and both model and tokenizer are
    written to EXPORT_PATH with ``save_pretrained``.
    """
    export_path, model_path, checkpoint_dir = _parse_args(
        sys.argv if argv is None else argv
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        # Alternative: torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    # Optional: tokenizer.add_special_tokens({"pad_token": "<unk>"})

    _load_checkpoint_into(model, checkpoint_dir)

    model = model.to(torch.device("cpu"))
    model.save_pretrained(export_path)
    # Save the tokenizer alongside the weights so the export is self-contained.
    tokenizer.save_pretrained(export_path)


if __name__ == "__main__":
    main()