
fix checkpoint loading

main
hiyouga 2 years ago
commit c0e5df92d6
  1. src/cli_demo.py (8 lines changed)
  2. src/utils/common.py (57 lines changed)
  3. src/utils/config.py (3 lines changed)
  4. src/web_demo.py (11 lines changed)

src/cli_demo.py (8 lines changed)

@@ -21,8 +21,14 @@ def main():
         model = model.cuda()
     model.eval()
 
+    def format_example(query):
+        prompt = "Below is an instruction that describes a task. "
+        prompt += "Write a response that appropriately completes the request.\n"
+        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+        return prompt
+
     def predict(query, history: list):
-        inputs = tokenizer([query], return_tensors="pt")
+        inputs = tokenizer([format_example(query)], return_tensors="pt")
         inputs = inputs.to(model.device)
         gen_kwargs = {
             "do_sample": True,

src/utils/common.py (57 lines changed)

@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import hashlib
+from itertools import chain
 from typing import List, Literal, Optional, Tuple
 
 import transformers
@@ -84,6 +85,8 @@ def init_adapter(
             param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        if len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
@@ -117,6 +120,9 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)
 
+    if model_args.checkpoint_dir is not None:
+        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
     return model
@@ -131,19 +137,14 @@ def load_pretrained(
     Support both training and inference.
     """
 
-    if (not is_trainable) and (model_args.checkpoint_dir is None):
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
-
-    if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
-        for checkpoint_dir in model_args.checkpoint_dir:
-            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
-                raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
-        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
-        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
-
-    if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
-        logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+    if finetuning_args is None: # load the fine-tuning arguments
+        if model_args.checkpoint_dir is None:
+            logger.warning("Checkpoint is not found at evaluation, load the original model.")
+            finetuning_args = FinetuningArguments(finetuning_type="none")
+        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
+            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+        else:
+            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
 
     assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
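Read together, the rewritten block resolves FinetuningArguments in a fixed order: an explicitly passed object wins, a missing checkpoint falls back to finetuning_type="none", otherwise the JSON saved next to the last checkpoint is loaded, and anything else raises. A minimal sketch of that order, reusing the repo's FinetuningArguments and FINETUNING_ARGS_NAME but with a hypothetical helper name and simplified signature:

import os

def resolve_finetuning_args(finetuning_args, checkpoint_dir):
    # 1. arguments passed in explicitly (e.g. during training) take precedence
    if finetuning_args is not None:
        return finetuning_args
    # 2. no checkpoint at evaluation: fall back to the original model
    if checkpoint_dir is None:
        return FinetuningArguments(finetuning_type="none")
    # 3. otherwise the arguments must be stored next to the last checkpoint
    args_file = os.path.join(checkpoint_dir[-1], FINETUNING_ARGS_NAME)
    if os.path.exists(args_file):
        return FinetuningArguments.load_from_json(args_file)
    raise ValueError("Missing fine-tuning arguments in the provided dictionary.")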
@@ -350,7 +351,7 @@ def preprocess_data(
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
-                    query += examples["query"][i]
+                    query += "\n" + examples["query"][i]
                 prompt = "Below is an instruction that describes a task. "
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
@@ -361,6 +362,20 @@ def preprocess_data(
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer
 
+    def preprocess_pretrain_dataset(examples):
+        # build grouped texts with format `<s> X1 X2 X3 ...`
+        text_ids = tokenizer(examples["prompt"])["input_ids"]
+        concatenated_ids = list(chain(*text_ids))
+        total_length = len(concatenated_ids)
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        # split by chunks of max_source_length
+        result = [concatenated_ids[i: i + data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        return {
+            "input_ids": result,
+            "labels": result.copy()
+        }
+
     def preprocess_supervised_dataset(examples):
         # build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
         model_inputs = {"input_ids": [], "labels": []}
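The grouping logic is easier to see on concrete numbers. A small self-contained sketch of the same chunking scheme, with a toy block size and made-up token ids standing in for data_args.max_source_length and the tokenizer output:

from itertools import chain

block_size = 4  # stands in for data_args.max_source_length
text_ids = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]  # made-up per-example token ids

concatenated_ids = list(chain(*text_ids))  # [1, 2, ..., 10], length 10
total_length = (len(concatenated_ids) // block_size) * block_size  # 8, the remainder is dropped
chunks = [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]
print(chunks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]; tokens 9 and 10 are discarded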
@@ -425,7 +440,9 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
+        print("labels:\n{}".format(
+            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
+        )
 
     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
@@ -437,11 +454,11 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
 
-    if stage == "sft":
-        if (not training_args.do_train) and training_args.predict_with_generate: # with generation
-            preprocess_function = preprocess_evaluation_dataset
-        else: # without generation
-            preprocess_function = preprocess_supervised_dataset
+    if stage == "pt":
+        preprocess_function = preprocess_pretrain_dataset
+    elif stage == "sft":
+        preprocess_function = preprocess_evaluation_dataset \
+            if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":

src/utils/config.py (3 lines changed)

@@ -194,7 +194,8 @@ class FinetuningArguments:
         if self.name_module_trainable == "mlp":
             self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
         elif self.name_module_trainable == "qkv":
-            self.trainable_layers = ["layers.{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]
+            self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
+                for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
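To see what the new pattern produces, the comprehension can be evaluated on example layer ids (the ids below are arbitrary; the module path follows the LLaMA-style self_attn projections named in the diff):

trainable_layer_ids = [30, 31]  # example ids only
trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj)
                    for proj in ["k_proj", "q_proj", "v_proj", "o_proj"]
                    for idx in trainable_layer_ids]
print(trainable_layers)
# ['layers.30.self_attn.k_proj', 'layers.31.self_attn.k_proj',
#  'layers.30.self_attn.q_proj', 'layers.31.self_attn.q_proj',
#  'layers.30.self_attn.v_proj', 'layers.31.self_attn.v_proj',
#  'layers.30.self_attn.o_proj', 'layers.31.self_attn.o_proj']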

src/web_demo.py (11 lines changed)

@@ -9,8 +9,10 @@ import gradio as gr
 from utils import ModelArguments, auto_configure_device_map, load_pretrained
 from transformers import HfArgumentParser
+from transformers.utils.versions import require_version
 
+require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems
 
 parser = HfArgumentParser(ModelArguments)
 model_args, = parser.parse_args_into_dataclasses()
 model, tokenizer = load_pretrained(model_args)
@@ -71,10 +73,17 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text
 
 
+def format_example(query):
+    prompt = "Below is an instruction that describes a task. "
+    prompt += "Write a response that appropriately completes the request.\n"
+    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+    return prompt
+
+
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
-    inputs = tokenizer([input], return_tensors="pt")
+    inputs = tokenizer([format_example(input)], return_tensors="pt")
     inputs = inputs.to(model.device)
     gen_kwargs = {
         "do_sample": True,
