
support RM metrics, add GeneratingArguments

main · hiyouga committed 2 years ago · commit cec6524d6b
18 changed files:

  1. assets/wechat.jpg (BIN)
  2. data/comparison_gpt4_data_en.json (100846 changes)
  3. data/comparison_gpt4_data_zh.json (73084 changes)
  4. data/dataset_info.json (13 changes)
  5. requirements.txt (8 changes)
  6. src/api_demo.py (17 changes)
  7. src/cli_demo.py (17 changes)
  8. src/train_ppo.py (11 changes)
  9. src/train_pt.py (9 changes)
  10. src/train_rm.py (12 changes)
  11. src/train_sft.py (10 changes)
  12. src/utils/__init__.py (2 changes)
  13. src/utils/common.py (56 changes)
  14. src/utils/config.py (91 changes)
  15. src/utils/data_collator.py (3 changes)
  16. src/utils/pairwise.py (12 changes)
  17. src/utils/template.py (16 changes)
  18. src/web_demo.py (10 changes)

assets/wechat.jpg (BIN)

Binary file not shown. Size before: 141 KiB; after: 146 KiB.

data/comparison_gpt4_data_en.json (100846 changes)

File diff suppressed because it is too large.

data/comparison_gpt4_data_zh.json (73084 changes)

File diff suppressed because it is too large.

data/dataset_info.json (13 changes)

@@ -79,11 +79,11 @@
     },
     "comparison_gpt4_en": {
         "file_name": "comparison_gpt4_data_en.json",
-        "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
+        "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae"
     },
     "comparison_gpt4_zh": {
         "file_name": "comparison_gpt4_data_zh.json",
-        "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
+        "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd"
     },
     "hh_rlhf_en": {
         "script_url": "hh_rlhf_en",
@@ -103,14 +103,5 @@
             "response": "",
             "history": ""
         }
-    },
-    "pretrain_data": {
-        "file_name": "pretrain_data",
-        "columns": {
-            "prompt": "content",
-            "query": "",
-            "response": "",
-            "history": ""
-        }
     }
 }
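The two refreshed file_sha1 values pin the regenerated comparison datasets, and prepare_data() verifies them before loading (see src/utils/common.py below). A minimal stdlib sketch for checking a local copy against the new hash; the path is illustrative:

    import hashlib

    def sha1_of(path: str) -> str:
        # Hash the raw bytes, the same way checksum() in src/utils/common.py reads the file
        with open(path, "rb") as f:
            return hashlib.sha1(f.read()).hexdigest()

    # Expected for the updated English file: 96fa18313544e22444fe20eead7754b17da452ae
    print(sha1_of("data/comparison_gpt4_data_en.json"))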

requirements.txt (8 changes)

@@ -2,11 +2,11 @@ torch>=1.13.1
 protobuf
 cpm_kernels
 sentencepiece
-transformers>=4.27.4
-datasets>=2.10.0
-accelerate>=0.18.0
+transformers>=4.29.1
+datasets>=2.12.0
+accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.1
+trl>=0.4.4
 jieba
 rouge_chinese
 nltk

src/api_demo.py (17 changes)

@@ -42,7 +42,7 @@ app = FastAPI()
 @app.post("/")
 async def create_item(request: Request):
-    global model, tokenizer, prompt_template
+    global model, tokenizer, prompt_template, generating_args

     # Parse the request JSON
     json_post_raw = await request.json()
@@ -56,16 +56,9 @@ async def create_item(request: Request):
     input_ids = input_ids.to(model.device)

     # Generation arguments
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": True,
-        "top_p": 0.7,
-        "temperature": 0.95,
-        "num_beams": 1,
-        "max_new_tokens": 512,
-        "repetition_penalty": 1.0,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["input_ids"] = input_ids
+    gen_kwargs["logits_processor"] = get_logits_processor()

     # Generate response
     with torch.no_grad():
@@ -95,7 +88,7 @@
 if __name__ == "__main__":
-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
     prompt_template = Template(data_args.prompt_template)
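Net effect of the api_demo.py changes: the hard-coded kwargs dict is gone, and decoding defaults come from the shared GeneratingArguments dataclass. A condensed sketch of the resulting request path (FastAPI plumbing and the per-request input_ids elided, assuming the src/utils helpers):

    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    gen_kwargs = generating_args.to_dict()                   # do_sample, top_p, temperature, num_beams, ...
    gen_kwargs["input_ids"] = input_ids                      # request-specific tensor
    gen_kwargs["logits_processor"] = get_logits_processor()

    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)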

src/cli_demo.py (17 changes)

@@ -15,7 +15,7 @@ from transformers import TextIteratorStreamer
 def main():

-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
@@ -25,17 +25,10 @@ def main():
     def predict_and_print(query, history: list):
         input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
-        gen_kwargs = {
-            "input_ids": input_ids,
-            "do_sample": True,
-            "top_p": 0.7,
-            "temperature": 0.95,
-            "num_beams": 1,
-            "max_new_tokens": 512,
-            "repetition_penalty": 1.0,
-            "logits_processor": get_logits_processor(),
-            "streamer": streamer
-        }
+        gen_kwargs = generating_args.to_dict()
+        gen_kwargs["input_ids"] = input_ids
+        gen_kwargs["logits_processor"] = get_logits_processor()
+        gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()
         response = ""
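The streaming pattern itself is unchanged: model.generate() blocks, so it runs on a worker thread while the main thread drains the streamer. A sketch of that loop (assuming the model, tokenizer, and generating_args set up in main(); the skip flags are illustrative, not from this diff):

    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["input_ids"] = input_ids
    gen_kwargs["logits_processor"] = get_logits_processor()
    gen_kwargs["streamer"] = streamer

    Thread(target=model.generate, kwargs=gen_kwargs).start()

    response = ""
    for new_text in streamer:   # yields decoded text pieces as they are generated
        print(new_text, end="", flush=True)
        response += new_text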

src/train_ppo.py (11 changes)

@@ -6,18 +6,17 @@
 import math
 from torch.optim import AdamW
 from transformers.optimization import get_scheduler
 from trl import PPOConfig
 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     PPOPeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
     plot_loss
 )
@@ -29,7 +28,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer)

     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,

src/train_pt.py (9 changes)

@@ -5,14 +5,15 @@
 import math
 from utils import (
+    DynamicDataCollatorWithPadding,
+    PeftTrainer,
+    LogCallback,
     load_pretrained,
     prepare_args,
     prepare_data,
     preprocess_data,
-    DynamicDataCollatorWithPadding,
-    PeftTrainer,
-    LogCallback,
     plot_loss
 )
@@ -24,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

     # Split the dataset
     if training_args.do_train:

src/train_rm.py (12 changes)

@@ -6,13 +6,14 @@
 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     PairwiseDataCollatorWithPadding,
     PairwisePeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
+    compute_accuracy,
     plot_loss
 )
@@ -23,7 +24,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer)

     training_args.remove_unused_columns = False # important for pairwise dataset
@@ -45,6 +46,7 @@
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=[LogCallback()],
+        compute_metrics=compute_accuracy,
         **trainer_kwargs
     )
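With compute_metrics wired in, evaluation runs of the reward model now report ranking accuracy. A toy check of the metric (compute_accuracy is the new export from src/utils/pairwise.py; predictions are the stacked reward pairs shown there):

    import numpy as np
    from utils import compute_accuracy

    # Column 0: reward for the accepted response; column 1: reward for the rejected one
    preds = np.array([
        [1.8, 0.2],   # accepted ranked higher -> correct
        [0.1, 0.9],   # rejected ranked higher -> incorrect
        [2.3, 1.1],   # correct
    ])
    print(compute_accuracy((preds, None)))  # {'accuracy': 0.666...}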

src/train_sft.py (10 changes)

@@ -5,14 +5,14 @@
 from utils import (
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     Seq2SeqPeftTrainer,
     ComputeMetrics,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
     get_logits_processor,
     plot_loss
 )
@@ -25,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \

src/utils/__init__.py (2 changes)

@@ -11,7 +11,7 @@ from .data_collator import DynamicDataCollatorWithPadding
 from .peft_trainer import PeftTrainer, LogCallback
 from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
-from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
+from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
 from .ppo import PPOPeftTrainer
 from .template import Template

src/utils/common.py (56 changes)

@@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead
 from .config import (
     ModelArguments,
     DataTrainingArguments,
-    FinetuningArguments
+    FinetuningArguments,
+    GeneratingArguments
 )

 from .template import Template
@@ -54,7 +55,8 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
+require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")

 logger = get_logger(__name__)
@@ -91,12 +93,10 @@ def _init_adapter(
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(
-                model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
             load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
         else:
-            assert is_mergeable or len(
-                model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
+            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -106,8 +106,7 @@ def _init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

-            if (is_trainable and model_args.resume_lora_training) or (
-                    not is_mergeable): # continually train on the lora weights
+            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
@@ -186,11 +185,9 @@ def load_pretrained(
         )
     elif model_args.quantization_bit == 4:
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        require_version("transformers>=4.30.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/transformers.git")
+        require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
+        require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
         require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        require_version("accelerate>=0.20.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/accelerate.git")
         config_kwargs["load_in_4bit"] = True
         config_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -201,10 +198,10 @@
         else:
             raise NotImplementedError
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    if not is_trainable:
+    if not is_trainable: # `device_map=auto` should be used for inference only
         config_kwargs["device_map"] = "auto"
@@ -221,6 +218,13 @@
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

+        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
+            load_valuehead_params(model, model_args.checkpoint_dir[0])
+            model.v_head.load_state_dict({
+                "summary.weight": getattr(model, "reward_head_weight"),
+                "summary.bias": getattr(model, "reward_head_bias")
+            })
+
         if stage == "ppo": # load reward model
             assert is_trainable, "PPO stage cannot be performed at evaluation."
             assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
@@ -228,11 +232,6 @@
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             load_valuehead_params(model, model_args.reward_model)

-    # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
-    # To meet the compliance requirements of the transformers library
-    if model_args.quantization_bit is not None:
-        model._is_int8_training_enabled = True
-
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
@@ -245,11 +244,11 @@
 def prepare_args(
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
-            json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
@@ -313,13 +312,14 @@
     return model_args, data_args, training_args, finetuning_args


-def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
-        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Quantization is only compatible with the LoRA method.")
@@ -327,13 +327,14 @@
     if data_args.prompt_template == "alpaca":
         logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

-    return model_args, data_args, finetuning_args
+    return model_args, data_args, finetuning_args, generating_args


 def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
 ) -> Dataset:

     def checksum(file_path, hash):
         with open(file_path, "rb") as datafile:
             binary_data = datafile.read()
@@ -358,10 +359,12 @@ def prepare_data(
         elif dataset_attr.load_from == "file":
             data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
             extension = dataset_attr.file_name.split(".")[-1]

             if dataset_attr.file_sha1 is not None:
                 checksum(data_file, dataset_attr.file_sha1)
+            else:
+                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")

             raw_datasets = load_dataset(
                 extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,
@@ -406,6 +409,7 @@ def preprocess_data(
     training_args: Seq2SeqTrainingArguments,
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:

+    column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)
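One behavioral nuance in the quantization hunk above: the device-map fallback switched idioms. Both default to GPU 0 when LOCAL_RANK is unset, but they differ when a launcher exports the variable as an empty string; a standalone comparison:

    import os

    os.environ["LOCAL_RANK"] = ""                        # set but empty

    print(int(os.environ.get("LOCAL_RANK") or 0))        # old form: "" is falsy, falls back to 0
    try:
        print(int(os.environ.get("LOCAL_RANK", "0")))    # new form: the default only applies when unset
    except ValueError:
        print("int('') raises ValueError")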

src/utils/config.py (91 changes)

@@ -1,12 +1,13 @@
 import os
 import json
 import torch
-from typing import List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional
+from dataclasses import asdict, dataclass, field


 @dataclass
 class DatasetAttr:

     load_from: str
     dataset_name: Optional[str] = None
     file_name: Optional[str] = None
@@ -55,11 +56,11 @@ class ModelArguments:
     )
     quantization_type: Optional[Literal["fp4", "nf4"]] = field(
         default="nf4",
-        metadata={"help": "Quantization data type to use."}
+        metadata={"help": "Quantization data type to use in int4 training."}
     )
     double_quantization: Optional[bool] = field(
         default=True,
-        metadata={"help": "Compress the quantization statistics through double quantization."}
+        metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
     compute_dtype: Optional[torch.dtype] = field(
         default=None,
@@ -67,8 +68,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -76,8 +76,7 @@ class ModelArguments:
     )
     resume_lora_training: Optional[bool] = field(
         default=True,
-        metadata={
-            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
     plot_loss: Optional[bool] = field(
         default=False,
@@ -156,41 +155,24 @@ class DataTrainingArguments:
         for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

-            dataset_attrs = []
+            dataset_attr = None
             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
+            else:
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
                     file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                 )
-            else:
-                # Support Directory
-                for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
-                    path = os.path.join(dataset_info[name]["file_name"], file_name)
-                    dataset_attrs.append(DatasetAttr(
-                        "file",
-                        file_name=path,
-                        file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
-                    ))

-            if dataset_attr is not None:
+            if "columns" in dataset_info[name]:
+                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+
+            self.dataset_list.append(dataset_attr)
-            else:
-                for i, dataset_attr in enumerate(dataset_attrs):
-                    if "columns" in dataset_info[name]:
-                        dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                        dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                        dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                        dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
-                    self.dataset_list.append(dataset_attr)
@@ -228,22 +210,20 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
+                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
     )

     def __post_init__(self):
-        if isinstance(self.lora_target, str):
-            self.lora_target = [target.strip() for target in
-                                self.lora_target.split(",")] # support custom target modules of LoRA
+        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
+            self.lora_target = [target.strip() for target in self.lora_target.split(",")]

         if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

-        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
-                                 trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
@@ -259,3 +239,44 @@
     with open(json_path, "r", encoding="utf-8") as f:
         text = f.read()
     return cls(**json.loads(text))
+
+
+@dataclass
+class GeneratingArguments:
+    """
+    Arguments pertaining to specify the decoding parameters.
+    """
+    do_sample: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
+    )
+    temperature: Optional[float] = field(
+        default=0.95,
+        metadata={"help": "The value used to modulate the next token probabilities."}
+    )
+    top_p: Optional[float] = field(
+        default=0.7,
+        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
+    )
+    infer_num_beams: Optional[int] = field(
+        default=1,
+        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
+    )
+    max_new_tokens: Optional[int] = field(
+        default=512,
+        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
+    )
+
+    def to_dict(self) -> Dict[str, Any]:
+        data_dict = asdict(self)
+        num_beams = data_dict.pop("infer_num_beams")
+        data_dict["num_beams"] = num_beams
+        return data_dict
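The only non-obvious part of the new dataclass is the rename inside to_dict(): the field is exposed as infer_num_beams on the command line, while model.generate() expects num_beams. A quick check (assuming the class is imported from src/utils/config.py):

    from utils.config import GeneratingArguments

    args = GeneratingArguments(top_p=0.9, infer_num_beams=4)
    print(args.to_dict())
    # {'do_sample': True, 'temperature': 0.95, 'top_p': 0.9, 'top_k': 50,
    #  'max_new_tokens': 512, 'repetition_penalty': 1.0, 'num_beams': 4}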

src/utils/data_collator.py (3 changes)

@@ -3,7 +3,6 @@
 import torch
 from typing import Dict, Optional, Sequence, Union
 from transformers import DataCollatorWithPadding, BatchEncoding
-from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer

 from .other import IGNORE_INDEX
@@ -16,11 +15,9 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
     def __init__(
             self,
             tokenizer: PreTrainedTokenizer,
-            model: PreTrainedModel,
             ignore_pad_token_for_loss: Optional[bool] = False
     ):
         super().__init__(tokenizer, padding=True)
-        self.model = model
         self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

     def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
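After this change the collator is model-agnostic, which is what lets the training scripts above drop their model argument. Construction now takes only a tokenizer plus the label-padding flag (the checkpoint name below is a placeholder):

    from transformers import AutoTokenizer
    from utils.data_collator import DynamicDataCollatorWithPadding

    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative checkpoint
    data_collator = DynamicDataCollatorWithPadding(tokenizer, ignore_pad_token_for_loss=True)
    # label_pad_token_id is now IGNORE_INDEX, so padded label positions are
    # excluded from the loss instead of being treated as real pad tokens.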

src/utils/pairwise.py (12 changes)

@@ -1,5 +1,6 @@
 import torch
-from typing import Dict, Sequence, Union
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union

 from .data_collator import DynamicDataCollatorWithPadding
@@ -10,6 +11,12 @@ from .other import get_logger
 logger = get_logger(__name__)


+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    preds, _ = eval_preds
+    preds = np.array(preds)
+    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
+
+
 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
@@ -47,5 +54,4 @@ class PairwisePeftTrainer(PeftTrainer):
         _, _, values = model(**inputs)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        outputs = {"r_accept": r_accept, "r_reject": r_reject}
-        return (loss, outputs) if return_outputs else loss
+        return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
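The trainer change pairs with compute_accuracy: compute_loss now returns the two reward columns stacked into a single tensor instead of a dict, so evaluation predictions arrive as a (batch, 2) array. A toy reproduction of the loss and the stacked output:

    import torch

    r_accept = torch.tensor([1.8, 0.1, 2.3])
    r_reject = torch.tensor([0.2, 0.9, 1.1])

    # Pairwise ranking loss: -log(sigmoid(r_accept - r_reject)), averaged over the batch
    loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
    outputs = torch.stack((r_accept, r_reject), dim=-1)

    print(loss.item())    # scalar loss
    print(outputs.shape)  # torch.Size([3, 2]); column 0 feeds preds[:, 0] in compute_accuracy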

src/utils/template.py (16 changes)

@@ -14,27 +14,25 @@ class Template:
         return getattr(self, "_format_{}".format(self.name))(query, history, prefix)

     def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
-        prompt = prefix
-        if history:
-            for old_query, response in history:
-                prompt += old_query + "\n" + response + "\n"
-        prompt += query
-        return prompt
+        r"""
+        Use for language model inference without histories.
+        """
+        return query

     def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
         r"""
         Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
                   https://github.com/ymcui/Chinese-LLaMA-Alpaca
         """
         if prefix:
             prompt = prefix
         else:
             prompt = "Below is an instruction that describes a task. "
             prompt += "Write a response that appropriately completes the request.\n\n"
-            prompt += "Instruction:\n"
         if history:
             for old_query, response in history:
-                prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response)
-        prompt += "Human:\n{}\n\nAssistant:".format(query)
+                prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
+        prompt += "### Instruction:\n{}\n\n### Response:\n".format(query)
         return prompt

     def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
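The reworked Alpaca template emits the standard "### Instruction:" / "### Response:" markers for every turn. An illustrative render with the default prefix and one turn of history (Template usage as in the demo scripts):

    from utils.template import Template

    template = Template("alpaca")
    print(template.get_prompt("And of Italy?", [("Name the capital of France.", "Paris.")]))
    # Below is an instruction that describes a task. Write a response that appropriately completes the request.
    #
    # ### Instruction:
    # Name the capital of France.
    #
    # ### Response:
    # Paris.
    #
    # ### Instruction:
    # And of Italy?
    #
    # ### Response: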

src/web_demo.py (10 changes)

@@ -21,7 +21,7 @@ from transformers.utils.versions import require_version
 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")

-model_args, data_args, finetuning_args = prepare_infer_args()
+model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)

 prompt_template = Template(data_args.prompt_template)
@@ -87,9 +87,9 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": 1,
+        "num_beams": generating_args.infer_num_beams,
         "max_length": max_length,
-        "repetition_penalty": 1.0,
+        "repetition_penalty": generating_args.repetition_penalty,
         "logits_processor": get_logits_processor(),
         "streamer": streamer
     }
@@ -133,8 +133,8 @@
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True)
+            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)

 history = gr.State([])
