
support RM metrics, add generating Args

main
hiyouga 2 years ago
parent
commit
cec6524d6b
18 changed files (changed line counts in parentheses):
  1. assets/wechat.jpg (BIN)
  2. data/comparison_gpt4_data_en.json (100846)
  3. data/comparison_gpt4_data_zh.json (73084)
  4. data/dataset_info.json (13)
  5. requirements.txt (8)
  6. src/api_demo.py (17)
  7. src/cli_demo.py (17)
  8. src/train_ppo.py (11)
  9. src/train_pt.py (9)
  10. src/train_rm.py (12)
  11. src/train_sft.py (10)
  12. src/utils/__init__.py (2)
  13. src/utils/common.py (98)
  14. src/utils/config.py (111)
  15. src/utils/data_collator.py (3)
  16. src/utils/pairwise.py (12)
  17. src/utils/template.py (16)
  18. src/web_demo.py (10)

BIN
assets/wechat.jpg

Binary file not shown. Size: 141 KiB before, 146 KiB after.

100846
data/comparison_gpt4_data_en.json

File diff suppressed because it is too large

73084
data/comparison_gpt4_data_zh.json

File diff suppressed because it is too large

13
data/dataset_info.json

@ -79,11 +79,11 @@
}, },
"comparison_gpt4_en": { "comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json", "file_name": "comparison_gpt4_data_en.json",
"file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f" "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae"
}, },
"comparison_gpt4_zh": { "comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json", "file_name": "comparison_gpt4_data_zh.json",
"file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0" "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd"
}, },
"hh_rlhf_en": { "hh_rlhf_en": {
"script_url": "hh_rlhf_en", "script_url": "hh_rlhf_en",
@ -103,14 +103,5 @@
"response": "", "response": "",
"history": "" "history": ""
} }
},
"pretrain_data": {
"file_name": "pretrain_data",
"columns": {
"prompt": "content",
"query": "",
"response": "",
"history": ""
}
} }
} }
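
The updated file_sha1 values are plain SHA-1 digests of the regenerated comparison files. As a reference, a minimal sketch (the path is illustrative) for recomputing a digest locally and checking it against dataset_info.json:

# Recompute the SHA-1 recorded in dataset_info.json for a local data file.
# The path below is illustrative; point it at your dataset_dir.
import hashlib

def file_sha1(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.sha1(f.read()).hexdigest()

print(file_sha1("data/comparison_gpt4_data_en.json"))
# expected: 96fa18313544e22444fe20eead7754b17da452ae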

8
requirements.txt

@ -2,11 +2,11 @@ torch>=1.13.1
protobuf protobuf
cpm_kernels cpm_kernels
sentencepiece sentencepiece
transformers>=4.27.4 transformers>=4.29.1
datasets>=2.10.0 datasets>=2.12.0
accelerate>=0.18.0 accelerate>=0.19.0
peft>=0.3.0 peft>=0.3.0
trl>=0.4.1 trl>=0.4.4
jieba jieba
rouge_chinese rouge_chinese
nltk nltk

17
src/api_demo.py

@ -42,7 +42,7 @@ app = FastAPI()
@app.post("/") @app.post("/")
async def create_item(request: Request): async def create_item(request: Request):
global model, tokenizer, prompt_template global model, tokenizer, prompt_template, generating_args
# Parse the request JSON # Parse the request JSON
json_post_raw = await request.json() json_post_raw = await request.json()
@ -56,16 +56,9 @@ async def create_item(request: Request):
input_ids = input_ids.to(model.device) input_ids = input_ids.to(model.device)
# Generation arguments # Generation arguments
gen_kwargs = { gen_kwargs = generating_args.to_dict()
"input_ids": input_ids, gen_kwargs["input_ids"] = input_ids
"do_sample": True, gen_kwargs["logits_processor"] = get_logits_processor()
"top_p": 0.7,
"temperature": 0.95,
"num_beams": 1,
"max_new_tokens": 512,
"repetition_penalty": 1.0,
"logits_processor": get_logits_processor()
}
# Generate response # Generate response
with torch.no_grad(): with torch.no_grad():
@ -95,7 +88,7 @@ async def create_item(request: Request):
if __name__ == "__main__": if __name__ == "__main__":
model_args, data_args, finetuning_args = prepare_infer_args() model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args) model, tokenizer = load_pretrained(model_args, finetuning_args)
prompt_template = Template(data_args.prompt_template) prompt_template = Template(data_args.prompt_template)
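
The API demo now takes its decoding parameters from the new GeneratingArguments dataclass instead of a hard-coded dict. A minimal sketch of the resulting generation path, assuming the repository's utils package is importable and the model/checkpoint flags are supplied on the command line:

# Sketch only: decoding parameters come from GeneratingArguments via prepare_infer_args().
import torch
from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor

model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)
prompt_template = Template(data_args.prompt_template)

query = "Hello!"
input_ids = tokenizer([prompt_template.get_prompt(query, [])], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)

gen_kwargs = generating_args.to_dict()          # do_sample, temperature, top_p, top_k, num_beams, max_new_tokens, ...
gen_kwargs["input_ids"] = input_ids
gen_kwargs["logits_processor"] = get_logits_processor()

with torch.no_grad():
    generation_output = model.generate(**gen_kwargs)
response = tokenizer.decode(generation_output[0][len(input_ids[0]):], skip_special_tokens=True)
print(response)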

17
src/cli_demo.py

@ -15,7 +15,7 @@ from transformers import TextIteratorStreamer
def main(): def main():
model_args, data_args, finetuning_args = prepare_infer_args() model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA" model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
model, tokenizer = load_pretrained(model_args, finetuning_args) model, tokenizer = load_pretrained(model_args, finetuning_args)
@ -25,17 +25,10 @@ def main():
def predict_and_print(query, history: list): def predict_and_print(query, history: list):
input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"] input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device) input_ids = input_ids.to(model.device)
gen_kwargs = { gen_kwargs = generating_args.to_dict()
"input_ids": input_ids, gen_kwargs["input_ids"] = input_ids
"do_sample": True, gen_kwargs["logits_processor"] = get_logits_processor()
"top_p": 0.7, gen_kwargs["streamer"] = streamer
"temperature": 0.95,
"num_beams": 1,
"max_new_tokens": 512,
"repetition_penalty": 1.0,
"logits_processor": get_logits_processor(),
"streamer": streamer
}
thread = Thread(target=model.generate, kwargs=gen_kwargs) thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start() thread.start()
response = "" response = ""
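
The CLI demo follows the same pattern, adding a TextIteratorStreamer and a worker thread so tokens are printed as they arrive. A hedged sketch of that loop, reusing model, tokenizer, input_ids and generating_args from the snippet above:

# Streaming variant (sketch): generate in a worker thread and consume the streamer.
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

gen_kwargs = generating_args.to_dict()
gen_kwargs["input_ids"] = input_ids
gen_kwargs["logits_processor"] = get_logits_processor()
gen_kwargs["streamer"] = streamer

thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

response = ""
for new_text in streamer:
    print(new_text, end="", flush=True)
    response += new_text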

11
src/train_ppo.py

@ -6,18 +6,17 @@
import math import math
from torch.optim import AdamW from torch.optim import AdamW
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from trl import PPOConfig from trl import PPOConfig
from utils import ( from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
DynamicDataCollatorWithPadding, DynamicDataCollatorWithPadding,
PPOPeftTrainer, PPOPeftTrainer,
LogCallback, LogCallback,
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
plot_loss plot_loss
) )
@ -29,7 +28,7 @@ def main():
dataset = prepare_data(model_args, data_args) dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo") model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model) data_collator = DynamicDataCollatorWithPadding(tokenizer)
ppo_config = PPOConfig( ppo_config = PPOConfig(
model_name=model_args.model_name_or_path, model_name=model_args.model_name_or_path,

9
src/train_pt.py

@ -5,14 +5,15 @@
import math import math
from utils import ( from utils import (
DynamicDataCollatorWithPadding,
PeftTrainer,
LogCallback,
load_pretrained, load_pretrained,
prepare_args, prepare_args,
prepare_data, prepare_data,
preprocess_data, preprocess_data,
DynamicDataCollatorWithPadding,
PeftTrainer,
LogCallback,
plot_loss plot_loss
) )
@ -24,7 +25,7 @@ def main():
dataset = prepare_data(model_args, data_args) dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt") model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss) data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
# Split the dataset # Split the dataset
if training_args.do_train: if training_args.do_train:

12
src/train_rm.py

@ -6,13 +6,14 @@
from utils import ( from utils import (
prepare_args,
prepare_data,
load_pretrained,
preprocess_data,
PairwiseDataCollatorWithPadding, PairwiseDataCollatorWithPadding,
PairwisePeftTrainer, PairwisePeftTrainer,
LogCallback, LogCallback,
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
compute_accuracy,
plot_loss plot_loss
) )
@ -23,7 +24,7 @@ def main():
dataset = prepare_data(model_args, data_args) dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm") model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model) data_collator = PairwiseDataCollatorWithPadding(tokenizer)
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for pairwise dataset
@ -45,6 +46,7 @@ def main():
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=[LogCallback()], callbacks=[LogCallback()],
compute_metrics=compute_accuracy,
**trainer_kwargs **trainer_kwargs
) )
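
With compute_metrics=compute_accuracy attached to the pairwise trainer, evaluation now reports the reward model's ranking accuracy. A short sketch of how the metric surfaces, assuming the trainer built above and standard Trainer behavior:

# Sketch: evaluation of the reward model now reports pairwise ranking accuracy.
if training_args.do_eval:
    metrics = trainer.evaluate(metric_key_prefix="eval")
    print(metrics["eval_accuracy"])   # fraction of pairs where the chosen response outscores the rejected one
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)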

10
src/train_sft.py

@ -5,14 +5,14 @@
from utils import ( from utils import (
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
DynamicDataCollatorWithPadding, DynamicDataCollatorWithPadding,
Seq2SeqPeftTrainer, Seq2SeqPeftTrainer,
ComputeMetrics, ComputeMetrics,
LogCallback, LogCallback,
load_pretrained,
prepare_args,
prepare_data,
preprocess_data,
get_logits_processor, get_logits_processor,
plot_loss plot_loss
) )
@ -25,7 +25,7 @@ def main():
dataset = prepare_data(model_args, data_args) dataset = prepare_data(model_args, data_args)
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss) data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \ training_args.generation_max_length = training_args.generation_max_length if \

2
src/utils/__init__.py

@ -11,7 +11,7 @@ from .data_collator import DynamicDataCollatorWithPadding
from .peft_trainer import PeftTrainer, LogCallback from .peft_trainer import PeftTrainer, LogCallback
from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
from .ppo import PPOPeftTrainer from .ppo import PPOPeftTrainer
from .template import Template from .template import Template

98
src/utils/common.py

@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead
from .config import ( from .config import (
ModelArguments, ModelArguments,
DataTrainingArguments, DataTrainingArguments,
FinetuningArguments FinetuningArguments,
GeneratingArguments
) )
from .template import Template from .template import Template
@ -54,7 +55,8 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
logger = get_logger(__name__) logger = get_logger(__name__)
@ -91,12 +93,10 @@ def _init_adapter(
if model_args.checkpoint_dir is not None: if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
assert is_mergeable and len( assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
else: else:
assert is_mergeable or len( assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
if finetuning_args.finetuning_type == "lora": if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA") logger.info("Fine-tuning method: LoRA")
@ -106,8 +106,7 @@ def _init_adapter(
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if (is_trainable and model_args.resume_lora_training) or ( if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
not is_mergeable): # continually train on the lora weights
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else: else:
checkpoints_to_merge = model_args.checkpoint_dir checkpoints_to_merge = model_args.checkpoint_dir
@ -119,10 +118,10 @@ def _init_adapter(
if len(checkpoints_to_merge) > 0: if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if lastest_checkpoint is not None: # resume lora training or quantized inference if lastest_checkpoint is not None: # resume lora training or quantized inference
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable) model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
if is_trainable and lastest_checkpoint is None: # create new lora weights while training if is_trainable and lastest_checkpoint is None: # create new lora weights while training
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
inference_mode=False, inference_mode=False,
@ -170,7 +169,7 @@ def load_pretrained(
padding_side="left", padding_side="left",
**config_kwargs **config_kwargs
) )
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
is_mergeable = True is_mergeable = True
@ -186,11 +185,9 @@ def load_pretrained(
) )
elif model_args.quantization_bit == 4: elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.0.dev0", require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
"To fix: pip install git+https://github.com/huggingface/transformers.git") require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
require_version("accelerate>=0.20.0.dev0",
"To fix: pip install git+https://github.com/huggingface/accelerate.git")
config_kwargs["load_in_4bit"] = True config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig( config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True, load_in_4bit=True,
@ -201,10 +198,10 @@ def load_pretrained(
else: else:
raise NotImplementedError raise NotImplementedError
is_mergeable = False is_mergeable = False
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)} config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if not is_trainable: if not is_trainable: # `device_map=auto` should be used for inference only
config_kwargs["device_map"] = "auto" config_kwargs["device_map"] = "auto"
# Load and prepare pretrained models (without valuehead). # Load and prepare pretrained models (without valuehead).
@ -218,24 +215,26 @@ def load_pretrained(
model = prepare_model_for_training(model) if is_trainable else model model = prepare_model_for_training(model) if is_trainable else model
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
if stage == "rm" or stage == "ppo": # add value head if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
if stage == "ppo": # load reward model if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
load_valuehead_params(model, model_args.checkpoint_dir[0])
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
assert is_trainable, "PPO stage cannot be performed at evaluation." assert is_trainable, "PPO stage cannot be performed at evaluation."
assert model_args.reward_model is not None, "Reward model is necessary for PPO training." assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model)) logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
load_valuehead_params(model, model_args.reward_model) load_valuehead_params(model, model_args.reward_model)
# Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
# To meet the compliance requirements of the transformers library
if model_args.quantization_bit is not None:
model._is_int8_training_enabled = True
if not is_trainable: if not is_trainable:
model.requires_grad_(False) # fix all model params model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
print_trainable_params(model) print_trainable_params(model)
@ -245,11 +244,11 @@ def load_pretrained(
def prepare_args( def prepare_args(
stage: Literal["pt", "sft", "rm", "ppo"] stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, training_args, finetuning_args = parser.parse_json_file( model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
json_file=os.path.abspath(sys.argv[1]))
else: else:
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
@ -290,7 +289,7 @@ def prepare_args(
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
training_args.ddp_find_unused_parameters = False training_args.ddp_find_unused_parameters = False
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
if training_args.fp16: if training_args.fp16:
@ -313,13 +312,14 @@ def prepare_args(
return model_args, data_args, training_args, finetuning_args return model_args, data_args, training_args, finetuning_args
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]: def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else: else:
model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses() model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
@ -327,13 +327,14 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
if data_args.prompt_template == "alpaca": if data_args.prompt_template == "alpaca":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
return model_args, data_args, finetuning_args return model_args, data_args, finetuning_args, generating_args
def prepare_data( def prepare_data(
model_args: ModelArguments, model_args: ModelArguments,
data_args: DataTrainingArguments data_args: DataTrainingArguments
) -> Dataset: ) -> Dataset:
def checksum(file_path, hash): def checksum(file_path, hash):
with open(file_path, "rb") as datafile: with open(file_path, "rb") as datafile:
binary_data = datafile.read() binary_data = datafile.read()
@ -342,7 +343,7 @@ def prepare_data(
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
max_samples = data_args.max_samples max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets all_datasets: List[Dataset] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list: for dataset_attr in data_args.dataset_list:
@ -358,10 +359,12 @@ def prepare_data(
elif dataset_attr.load_from == "file": elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
extension = dataset_attr.file_name.split(".")[-1] extension = dataset_attr.file_name.split(".")[-1]
if dataset_attr.file_sha1 is not None: if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1) checksum(data_file, dataset_attr.file_sha1)
else: else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raw_datasets = load_dataset( raw_datasets = load_dataset(
extension if extension in ["csv", "json"] else "text", extension if extension in ["csv", "json"] else "text",
data_files=data_file, data_files=data_file,
@ -383,11 +386,11 @@ def prepare_data(
("query_column", "query"), ("query_column", "query"),
("response_column", "response"), ("response_column", "response"),
("history_column", "history") ("history_column", "history")
]: # every dataset will have 4 columns same as each other ]: # every dataset will have 4 columns same as each other
if getattr(dataset_attr, column_name) != target_name: if getattr(dataset_attr, column_name) != target_name:
if getattr(dataset_attr, column_name): if getattr(dataset_attr, column_name):
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data) dataset = dataset.add_column(target_name, dummy_data)
all_datasets.append(dataset) all_datasets.append(dataset)
@ -406,6 +409,7 @@ def preprocess_data(
training_args: Seq2SeqTrainingArguments, training_args: Seq2SeqTrainingArguments,
stage: Literal["pt", "sft", "rm", "ppo"] stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset: ) -> Dataset:
column_names = list(dataset.column_names) column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
prompt_template = Template(data_args.prompt_template) prompt_template = Template(data_args.prompt_template)
@ -442,9 +446,9 @@ def preprocess_data(
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1] source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # eos token if len(target_ids) > data_args.max_target_length - 1: # eos token
target_ids = target_ids[:data_args.max_target_length - 1] target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id] input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
@ -461,9 +465,9 @@ def preprocess_data(
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1] source_ids = source_ids[:data_args.max_source_length - 1]
if len(target_ids) > data_args.max_target_length - 1: # bos token if len(target_ids) > data_args.max_target_length - 1: # bos token
target_ids = target_ids[:data_args.max_target_length - 1] target_ids = target_ids[:data_args.max_target_length - 1]
input_ids = source_ids + [tokenizer.bos_token_id] input_ids = source_ids + [tokenizer.bos_token_id]
@ -481,11 +485,11 @@ def preprocess_data(
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length - 1: # bos token if len(source_ids) > data_args.max_source_length - 1: # bos token
source_ids = source_ids[:data_args.max_source_length - 1] source_ids = source_ids[:data_args.max_source_length - 1]
if len(accept_ids) > data_args.max_target_length - 1: # eos token if len(accept_ids) > data_args.max_target_length - 1: # eos token
accept_ids = accept_ids[:data_args.max_target_length - 1] accept_ids = accept_ids[:data_args.max_target_length - 1]
if len(reject_ids) > data_args.max_target_length - 1: # eos token if len(reject_ids) > data_args.max_target_length - 1: # eos token
reject_ids = reject_ids[:data_args.max_target_length - 1] reject_ids = reject_ids[:data_args.max_target_length - 1]
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id] accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
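
The new reward-model branch in load_pretrained makes it possible to evaluate a trained reward model: when the stage is "rm" and a checkpoint directory is given, the saved value-head weights are copied back into v_head. A hedged sketch of that call, assuming prepare_args is invoked with stage="rm" as in train_rm.py:

# Sketch: reload a trained reward model for evaluation. The checkpoint directory is
# passed on the command line (e.g. --checkpoint_dir path_to_rm_checkpoint --do_eval).
from utils import load_pretrained, prepare_args

model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
# With stage="rm" and checkpoint_dir set, load_pretrained copies the saved
# reward_head_weight / reward_head_bias back into model.v_head before evaluation.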

111
src/utils/config.py

@ -1,12 +1,13 @@
import os import os
import json import json
import torch import torch
from typing import List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
@dataclass @dataclass
class DatasetAttr: class DatasetAttr:
load_from: str load_from: str
dataset_name: Optional[str] = None dataset_name: Optional[str] = None
file_name: Optional[str] = None file_name: Optional[str] = None
@ -55,11 +56,11 @@ class ModelArguments:
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field( quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4", default="nf4",
metadata={"help": "Quantization data type to use."} metadata={"help": "Quantization data type to use in int4 training."}
) )
double_quantization: Optional[bool] = field( double_quantization: Optional[bool] = field(
default=True, default=True,
metadata={"help": "Compress the quantization statistics through double quantization."} metadata={"help": "Whether to use double quantization in int4 training or not."}
) )
compute_dtype: Optional[torch.dtype] = field( compute_dtype: Optional[torch.dtype] = field(
default=None, default=None,
@ -67,8 +68,7 @@ class ModelArguments:
) )
checkpoint_dir: Optional[str] = field( checkpoint_dir: Optional[str] = field(
default=None, default=None,
metadata={ metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
) )
reward_model: Optional[str] = field( reward_model: Optional[str] = field(
default=None, default=None,
@ -76,8 +76,7 @@ class ModelArguments:
) )
resume_lora_training: Optional[bool] = field( resume_lora_training: Optional[bool] = field(
default=True, default=True,
metadata={ metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
) )
plot_loss: Optional[bool] = field( plot_loss: Optional[bool] = field(
default=False, default=False,
@ -85,7 +84,7 @@ class ModelArguments:
) )
def __post_init__(self): def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@ -147,7 +146,7 @@ class DataTrainingArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."} metadata={"help": "Which template to use for constructing prompts in training and inference."}
) )
def __post_init__(self): # support mixing multiple datasets def __post_init__(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")] dataset_names = [ds.strip() for ds in self.dataset.split(",")]
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
@ -156,42 +155,25 @@ class DataTrainingArguments:
for name in dataset_names: for name in dataset_names:
if name not in dataset_info: if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
dataset_attrs = []
dataset_attr = None
if "hf_hub_url" in dataset_info[name]: if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])): else:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr(
"file", "file",
file_name=dataset_info[name]["file_name"], file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
) )
else:
# Support Directory if "columns" in dataset_info[name]:
for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])): dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
path = os.path.join(dataset_info[name]["file_name"], file_name) dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attrs.append(DatasetAttr( dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
"file", dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
file_name=path,
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None self.dataset_list.append(dataset_attr)
))
if dataset_attr is not None:
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
else:
for i, dataset_attr in enumerate(dataset_attrs):
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
@dataclass @dataclass
@ -228,22 +210,20 @@ class FinetuningArguments:
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default="q_proj,v_proj", default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \ metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"} BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
) )
def __post_init__(self): def __post_init__(self):
if isinstance(self.lora_target, str): if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target = [target.strip() for target in self.lora_target.split(",")]
self.lora_target.split(",")] # support custom target modules of LoRA
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)] trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0 else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
@ -259,3 +239,44 @@ class FinetuningArguments:
with open(json_path, "r", encoding="utf-8") as f: with open(json_path, "r", encoding="utf-8") as f:
text = f.read() text = f.read()
return cls(**json.loads(text)) return cls(**json.loads(text))
@dataclass
class GeneratingArguments:
"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
temperature: Optional[float] = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."}
)
top_p: Optional[float] = field(
default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
)
top_k: Optional[int] = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
)
infer_num_beams: Optional[int] = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_new_tokens: Optional[int] = field(
default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
)
def to_dict(self) -> Dict[str, Any]:
data_dict = asdict(self)
num_beams = data_dict.pop("infer_num_beams")
data_dict["num_beams"] = num_beams
return data_dict
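
Because GeneratingArguments is an ordinary dataclass parsed by HfArgumentParser, every field becomes a command-line flag, and to_dict() renames infer_num_beams to num_beams so the result can be passed straight to model.generate(). A small self-contained sketch, assuming src/ is on the Python path; the flag values are illustrative:

# Parse only GeneratingArguments to show the CLI flags and the to_dict() renaming.
from transformers import HfArgumentParser
from utils.config import GeneratingArguments

parser = HfArgumentParser(GeneratingArguments)
(generating_args,) = parser.parse_args_into_dataclasses(
    args=["--temperature", "0.5", "--top_p", "0.9", "--max_new_tokens", "256"]
)
print(generating_args.to_dict())
# {'do_sample': True, 'temperature': 0.5, 'top_p': 0.9, 'top_k': 50,
#  'max_new_tokens': 256, 'repetition_penalty': 1.0, 'num_beams': 1}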

3
src/utils/data_collator.py

@ -3,7 +3,6 @@ import torch
from typing import Dict, Optional, Sequence, Union from typing import Dict, Optional, Sequence, Union
from transformers import DataCollatorWithPadding, BatchEncoding from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
from .other import IGNORE_INDEX from .other import IGNORE_INDEX
@ -16,11 +15,9 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
def __init__( def __init__(
self, self,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
ignore_pad_token_for_loss: Optional[bool] = False ignore_pad_token_for_loss: Optional[bool] = False
): ):
super().__init__(tokenizer, padding=True) super().__init__(tokenizer, padding=True)
self.model = model
self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
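
Since the collator no longer needs a model reference, the training scripts above construct it from the tokenizer alone, plus the label-padding flag for the pt/sft stages. A short sketch of the new call sites, assuming tokenizer and data_args from those scripts:

# New collator signatures (sketch): no model argument.
from utils import DynamicDataCollatorWithPadding, PairwiseDataCollatorWithPadding

data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)  # pt / sft
ppo_collator = DynamicDataCollatorWithPadding(tokenizer)                                        # ppo
rm_collator = PairwiseDataCollatorWithPadding(tokenizer)                                        # rm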

12
src/utils/pairwise.py

@ -1,5 +1,6 @@
import torch import torch
from typing import Dict, Sequence, Union import numpy as np
from typing import Dict, Sequence, Tuple, Union
from .data_collator import DynamicDataCollatorWithPadding from .data_collator import DynamicDataCollatorWithPadding
@ -10,6 +11,12 @@ from .other import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds
preds = np.array(preds)
return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding): class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
r""" r"""
Data collator for pairwise data. Data collator for pairwise data.
@ -47,5 +54,4 @@ class PairwisePeftTrainer(PeftTrainer):
_, _, values = model(**inputs) _, _, values = model(**inputs)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0) r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
outputs = {"r_accept": r_accept, "r_reject": r_reject} return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
return (loss, outputs) if return_outputs else loss
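
compute_accuracy expects the predictions stacked by the trainer above, with the accepted response's reward in column 0 and the rejected one's in column 1. A toy check, assuming the utils package is importable:

# Toy check of compute_accuracy on stacked (r_accept, r_reject) scores.
import numpy as np
from utils import compute_accuracy

preds = np.array([[1.2, 0.3],   # accept scored higher: counted as correct
                  [0.1, 0.9],   # reject scored higher: counted as wrong
                  [2.0, 1.5]])  # correct
print(compute_accuracy((preds, None)))   # {'accuracy': 0.666...}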

16
src/utils/template.py

@ -14,27 +14,25 @@ class Template:
return getattr(self, "_format_{}".format(self.name))(query, history, prefix) return getattr(self, "_format_{}".format(self.name))(query, history, prefix)
def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
prompt = prefix r"""
if history: Use for language model inference without histories.
for old_query, response in history: """
prompt += old_query + "\n" + response + "\n" return query
prompt += query
return prompt
def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
r""" r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
""" """
if prefix: if prefix:
prompt = prefix prompt = prefix
else: else:
prompt = "Below is an instruction that describes a task. " prompt = "Below is an instruction that describes a task. "
prompt += "Write a response that appropriately completes the request.\n\n" prompt += "Write a response that appropriately completes the request.\n\n"
prompt += "Instruction:\n"
if history: if history:
for old_query, response in history: for old_query, response in history:
prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response) prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
prompt += "Human:\n{}\n\nAssistant:".format(query) prompt += "### Instruction:\n{}\n\n### Response:\n".format(query)
return prompt return prompt
def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
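
The reworked _format_alpaca renders history turns with the standard "### Instruction / ### Response" markers instead of "Human / Assistant". A small sketch of the resulting prompt; the expected output in the comments is reconstructed from the code above:

from utils import Template

prompt_template = Template("alpaca")
print(prompt_template.get_prompt("How are you?", [("Hello!", "Hi, nice to meet you.")]))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
# ### Instruction:
# Hello!
#
# ### Response:
# Hi, nice to meet you.
#
# ### Instruction:
# How are you?
#
# ### Response: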

10
src/web_demo.py

@ -21,7 +21,7 @@ from transformers.utils.versions import require_version
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
model_args, data_args, finetuning_args = prepare_infer_args() model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args) model, tokenizer = load_pretrained(model_args, finetuning_args)
prompt_template = Template(data_args.prompt_template) prompt_template = Template(data_args.prompt_template)
@ -87,9 +87,9 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
"do_sample": True, "do_sample": True,
"top_p": top_p, "top_p": top_p,
"temperature": temperature, "temperature": temperature,
"num_beams": 1, "num_beams": generating_args.infer_num_beams,
"max_length": max_length, "max_length": max_length,
"repetition_penalty": 1.0, "repetition_penalty": generating_args.repetition_penalty,
"logits_processor": get_logits_processor(), "logits_processor": get_logits_processor(),
"streamer": streamer "streamer": streamer
} }
@ -133,8 +133,8 @@ with gr.Blocks() as demo:
with gr.Column(scale=1): with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History") emptyBtn = gr.Button("Clear History")
max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True) max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True) temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
history = gr.State([]) history = gr.State([])
