
alter rewards data type

main
hiyouga, 2 years ago
commit 50d9a20f81
12 changed files (lines changed per file):

  src/cli_demo.py             10
  src/train_ppo.py             2
  src/train_pt.py              2
  src/train_rm.py              2
  src/train_sft.py             2
  src/utils/__init__.py        2
  src/utils/common.py         14
  src/utils/data_collator.py   6
  src/utils/other.py          14
  src/utils/peft_trainer.py    3
  src/utils/ppo.py            11
  src/web_demo.py             10

src/cli_demo.py (10 changed lines)

@@ -4,22 +4,24 @@
 import torch
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser


 def main():

-    parser = HfArgumentParser(ModelArguments)
-    model_args, = parser.parse_args_into_dataclasses()
+    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+    model_args, finetuning_args = parser.parse_args_into_dataclasses()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
-    model, tokenizer = load_pretrained(model_args)
+    model, tokenizer = load_pretrained(model_args, finetuning_args)

     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model, infer_auto_device_map
         device_map = infer_auto_device_map(model)
         model = dispatch_model(model, device_map)
     else:
         model = model.cuda()

     model.eval()


 def format_example(query):
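
Both demos now parse FinetuningArguments next to ModelArguments. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of parsing several dataclasses with HfArgumentParser; the two dataclasses below are simplified stand-ins for illustration, not the repository's real ones:

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DemoModelArguments:
    # stand-in for ModelArguments
    model_name_or_path: str = field(default="facebook/opt-125m")

@dataclass
class DemoFinetuningArguments:
    # stand-in for FinetuningArguments
    finetuning_type: str = field(default="lora")

# HfArgumentParser accepts a tuple of dataclass types and returns one instance per type.
parser = HfArgumentParser((DemoModelArguments, DemoFinetuningArguments))
model_args, finetuning_args = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path, finetuning_args.finetuning_type)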

src/train_ppo.py (2 changed lines)

@@ -70,7 +70,7 @@ def main():
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be after save_model
     if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args, keys=["loss", "reward"])
+        plot_loss(training_args.output_dir, keys=["loss", "reward"])


 def _mp_fn(index):

src/train_pt.py (2 changed lines)

@@ -55,7 +55,7 @@ def main():
         trainer.save_state()
         trainer.save_model()
         if trainer.is_world_process_zero() and model_args.plot_loss:
-            plot_loss(training_args, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:

src/train_rm.py (2 changed lines)

@@ -56,7 +56,7 @@ def main():
         trainer.save_state()
         trainer.save_model()
         if trainer.is_world_process_zero() and model_args.plot_loss:
-            plot_loss(training_args, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:

src/train_sft.py (2 changed lines)

@@ -72,7 +72,7 @@ def main():
         trainer.save_state()
         trainer.save_model()
         if trainer.is_world_process_zero() and model_args.plot_loss:
-            plot_loss(training_args, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

     # Evaluation
     if training_args.do_eval:

src/utils/__init__.py (2 changed lines)

@@ -13,5 +13,5 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer
-from .config import ModelArguments
+from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss

src/utils/common.py (14 changed lines)

@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )

 check_min_version("4.29.1")

@@ -128,7 +127,7 @@ def init_adapter(
 def load_pretrained(
         model_args: ModelArguments,
-        finetuning_args: Optional[FinetuningArguments] = None,
+        finetuning_args: FinetuningArguments,
         is_trainable: Optional[bool] = False,
         stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:

@@ -137,16 +136,9 @@ def load_pretrained(
     Support both training and inference.
     """
-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")

     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."

src/utils/data_collator.py (6 changed lines)

@@ -2,7 +2,7 @@ import torch
 from typing import Dict, Optional, Sequence, Union
-from transformers import DataCollatorWithPadding
+from transformers import DataCollatorWithPadding, BatchEncoding
 from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer

@@ -34,7 +34,7 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         attention_mask = attention_mask.bool()
         return attention_mask

-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
         r"""
         Pads batched data to the longest sequence in the batch.

@@ -64,4 +64,4 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
         batch["input_ids"] = input_ids
         batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
-        return batch
+        return BatchEncoding(batch)
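
Wrapping the result in BatchEncoding rather than returning a plain dict lets callers treat the collated batch like tokenizer output, for example moving every tensor to a device in one call. A minimal sketch with purely illustrative values:

import torch
from transformers import BatchEncoding

batch = BatchEncoding({
    "input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
})
# A plain dict would need a manual per-key loop; BatchEncoding supports .to(device) directly.
batch = batch.to("cuda" if torch.cuda.is_available() else "cpu")
print(type(batch).__name__, batch["input_ids"].device)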

src/utils/other.py (14 changed lines)

@@ -5,7 +5,6 @@ import torch
 import logging
 from typing import Dict, List, Optional
-from transformers import Seq2SeqTrainingArguments
 from transformers.trainer import TRAINER_STATE_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList

@@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))

-def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
+def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
     """
     EMA implementation according to TensorBoard.
     """

@@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     return smoothed

-def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
     import matplotlib.pyplot as plt
-    data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+        data = json.load(f)

     for key in keys:
         steps, metrics = [], []

@@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
         plt.figure()
         plt.plot(steps, metrics, alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), label="smoothed")
-        plt.title("training {} of {}".format(key, training_args.output_dir))
+        plt.title("training {} of {}".format(key, save_dictionary))
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
-        plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
-        print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
+        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
+        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))

src/utils/peft_trainer.py (3 changed lines)

@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
         if hasattr(model, "v_head"): # save valuehead weights
             torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

-        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
+            f.write(self.args.to_json_string() + "\n")
         self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

     def _load_best_model(self):
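
Writing the training arguments with to_json_string() instead of torch.save makes the saved file human-readable JSON rather than a pickled object tied to specific library versions. A rough sketch of the round trip, assuming default TrainingArguments:

import json
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")   # any training-arguments instance
json_text = args.to_json_string()            # what the trainer now writes to disk
restored = json.loads(json_text)             # plain dict, readable without torch or pickle
print(restored["output_dir"], restored["learning_rate"])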

src/utils/ppo.py (11 changed lines)

@@ -75,7 +75,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.finetuning_args = finetuning_args
         self.log_callback = callbacks[0]
         self.state = TrainerState()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
+        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer

     def ppo_train(self, max_target_length: int) -> None:
         r"""

@@ -148,7 +148,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Compute rewards
             replace_model(unwrapped_model, target="reward")
             _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
-            rewards = [reward for reward in values[:, -1]]
+            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
             replace_model(unwrapped_model, target="default") # make sure the model is default at the end

             # Run PPO step
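
This is the change the commit title refers to: the value-head scores come back in the model's compute dtype (often float16 or bfloat16 when the model is loaded in half precision), and casting them to float32 before handing them to the PPO step avoids precision loss when the rewards are accumulated and logged. A minimal illustration of the cast, independent of the trainer and with made-up values:

import torch

# Value-head scores as they might come out of a half-precision model (hypothetical values).
values = torch.tensor([[0.1, 0.7], [0.3, -1.2]], dtype=torch.float16)

# Last-token score per sequence, cast to float32 before further accumulation.
rewards = [reward for reward in values[:, -1].to(torch.float32)]
print([r.dtype for r in rewards])  # [torch.float32, torch.float32]
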
@@ -214,13 +214,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             return response[:, inputs["input_ids"].size(1):]
         return response

-    def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
-        input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
-        input_data = self.data_collator([{"input_ids": ids} for ids in input_ids])
-        input_data = {k: v.to(self.current_device) for k, v in input_data.items() if v is not None}
-        input_data.pop("labels", None) # we don't want to compute LM losses
-        return input_data
-
     @PPODecorators.empty_cuda_cache()
     def batched_forward_pass(
         self,

src/web_demo.py (10 changed lines)

@@ -7,21 +7,23 @@ import torch
 import mdtex2html
 import gradio as gr
-from utils import ModelArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version


 require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems

-parser = HfArgumentParser(ModelArguments)
-model_args, = parser.parse_args_into_dataclasses()
-model, tokenizer = load_pretrained(model_args)
+parser = HfArgumentParser((ModelArguments, FinetuningArguments))
+model_args, finetuning_args = parser.parse_args_into_dataclasses()
+model, tokenizer = load_pretrained(model_args, finetuning_args)

 if torch.cuda.device_count() > 1:
     from accelerate import dispatch_model, infer_auto_device_map
     device_map = infer_auto_device_map(model)
     model = dispatch_model(model, device_map)
 else:
     model = model.cuda()

 model.eval()
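
Both demo entry points keep the same multi-GPU dispatch path: with more than one visible GPU, accelerate plans a per-module placement and splits the model across devices instead of calling .cuda(). A hedged, standalone sketch of that pattern (the checkpoint name below is only an example):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # example checkpoint

if torch.cuda.device_count() > 1:
    from accelerate import dispatch_model, infer_auto_device_map
    device_map = infer_auto_device_map(model)   # plan which module goes on which device
    model = dispatch_model(model, device_map)   # move modules and install forwarding hooks
elif torch.cuda.is_available():
    model = model.cuda()

model.eval()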
