You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
30 lines
1.3 KiB
30 lines
1.3 KiB
import os
|
|
import sys
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# Default locations; each may be overridden positionally on the command line:
#   python <script> [EXPORT_PATH] [MODEL_PATH] [CHECKPOINT_DIR]
# NOTE(review): the default export name looks inherited from a "baichuan" .flm
# export script, but this script writes a Hugging Face `save_pretrained`
# directory of an fp16 model. Kept unchanged for backward compatibility.
DEFAULT_EXPORT_PATH = "baichuan-fp32.flm"
DEFAULT_MODEL_PATH = '/home/majiahui/project/models-llm/openbuddy-llama-7b-v1.4-fp16'
DEFAULT_CHECKPOINT_DIR = "/home/majiahui/project/LLaMA-Efficient-Tuning/path_to_sft_openbuddy_llama_paper_checkpoint_prompt_freeze_checkpoint/checkpoint-168000"

# File name of the fine-tuned weights inside the checkpoint directory.
WEIGHTS_NAME = "adapter_model.bin"


def _parse_args(argv):
    """Return ``(export_path, model_path, checkpoint_dir)`` taken from *argv*.

    ``argv[0]`` is the program name. Up to three positional arguments override
    the defaults in that order; missing ones fall back to the ``DEFAULT_*``
    constants and any further arguments are ignored.
    """
    defaults = (DEFAULT_EXPORT_PATH, DEFAULT_MODEL_PATH, DEFAULT_CHECKPOINT_DIR)
    overrides = tuple(argv[1:1 + len(defaults)])
    return overrides + defaults[len(overrides):]


def _load_checkpoint_weights(model, weights_file):
    """Copy the tensors stored in *weights_file* into *model* in place.

    The checkpoint may hold only a subset of the model's parameters (those
    that were trained), so parameters absent from it are left untouched.

    Returns:
        The ``(missing_keys, unexpected_keys)`` result of ``load_state_dict``.

    Raises:
        RuntimeError: if the checkpoint is empty or none of its keys match a
            model parameter -- otherwise the unmodified base model would be
            exported without any error.
    """
    # Stage on the CPU: load_state_dict copies values into the model's own
    # parameters, so loading onto a GPU first only costs memory and needs CUDA.
    # torch.load unpickles the file; only use it on trusted checkpoints.
    state_dict = torch.load(weights_file, map_location="cpu")
    result = model.load_state_dict(state_dict, strict=False)  # partial checkpoint: missing keys are expected
    unexpected = list(result.unexpected_keys)
    if not state_dict or len(unexpected) == len(state_dict):
        raise RuntimeError(
            f"No weights from {weights_file} matched the model's parameters; "
            f"unmatched keys (first 5): {unexpected[:5]}"
        )
    if unexpected:
        print(
            f"Warning: ignored {len(unexpected)} checkpoint keys with no matching "
            f"model parameter, e.g. {unexpected[:5]}",
            file=sys.stderr,
        )
    return result


def main(argv=None):
    """Merge fine-tuned checkpoint weights into the base model and export it.

    Loads the base causal LM and its tokenizer from the model path, overwrites
    the parameters contained in ``<checkpoint_dir>/adapter_model.bin``, and
    writes both model and tokenizer to the export path via ``save_pretrained``.

    Raises:
        FileNotFoundError: if the checkpoint weights file does not exist.
    """
    export_path, model_path, checkpoint_dir = _parse_args(sys.argv if argv is None else argv)

    # Fail fast, before spending minutes loading the base model.
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if not os.path.isfile(weights_file):
        raise FileNotFoundError(
            f"Provided path ({checkpoint_dir}) does not contain the pretrained weights ({WEIGHTS_NAME})."
        )

    # Load directly onto the CPU instead of device_map="auto": a model dispatched
    # or offloaded by accelerate is not reliably movable with .to("cpu") later,
    # and the export only needs the weights in host memory anyway.
    # low_cpu_mem_usage avoids building a second, randomly initialised copy.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,  # torch.bfloat16 is the alternative that was tried
        low_cpu_mem_usage=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    _load_checkpoint_weights(model, weights_file)

    model.save_pretrained(export_path)
    # Save the tokenizer as well so the exported directory is self-contained.
    tokenizer.save_pretrained(export_path)


if __name__ == "__main__":
    main()
|
|
|