@@ -157,8 +157,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
             stats = self.step(queries, responses, rewards)

-            loss_meter.update(stats["ppo/loss/total"])
-            reward_meter.update(torch.tensor(rewards).sum().item(), n=len(rewards))
+            loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
+            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

             if steps_trained == len_dataloader:
                 dataiter = iter(self.dataloader)
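For context on the two `+` lines: `loss_meter` is now weighted by batch size via `n=len(rewards)`, and `reward_meter` logs the batch mean instead of the batch sum (`torch.stack(rewards).mean()` also assumes `rewards` is a list of tensors rather than plain floats). Below is a minimal sketch of the running-average semantics these calls rely on; the stand-in meter is illustrative only, not the repository's actual `AverageMeter` implementation.

    import torch

    class AverageMeter:
        """Running average that weights each update by n samples."""

        def __init__(self):
            self.sum, self.count, self.avg = 0.0, 0, 0.0

        def update(self, val, n=1):
            # Weight val by n so batches of different sizes contribute
            # proportionally to the running per-sample average.
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    reward_meter = AverageMeter()
    # Hypothetical per-sample rewards for illustration.
    rewards = [torch.tensor(1.0), torch.tensor(3.0)]
    reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
    print(reward_meter.avg)  # 2.0: mean reward per sample, not the batch sum

With n-weighted updates, batches of different sizes contribute in proportion to their size, so the logged loss and reward stay per-sample averages rather than drifting with the batch size.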