@@ -5,7 +5,6 @@ import torch
 import logging
 from typing import Dict, List, Optional
 
-from transformers import Seq2SeqTrainingArguments
 from transformers.trainer import TRAINER_STATE_NAME
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
@@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 
 
-def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
+def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
     """
     EMA implementation according to TensorBoard.
     """
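
For context, the body of smooth() is untouched here; only the default weight drops from 0.95 to 0.9. A minimal sketch of the TensorBoard-style exponential moving average it implements (assumed, since the diff elides the body):

    # Assumed body of smooth(): TensorBoard-style EMA over the logged scalars.
    # Each output point blends the previous smoothed value with the raw value;
    # a larger `weight` keeps more history and yields a flatter curve.
    last = scalars[0]
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point  # EMA update
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed

Lowering the default weight to 0.9 simply makes the smoothed curve track the raw loss more closely.
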
@@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     return smoothed
 
 
-def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
     import matplotlib.pyplot as plt
-    data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+        data = json.load(f)
 
     for key in keys:
         steps, metrics = [], []
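
For reference, TRAINER_STATE_NAME resolves to the trainer_state.json that the transformers Trainer writes into the output directory; plot_loss scans its log_history entries for the requested keys. An abbreviated sketch of that structure (values hypothetical):

    # Shape of the data plot_loss reads (values are made up for illustration):
    data = {
        "log_history": [
            {"step": 10, "loss": 2.31, "learning_rate": 5e-5, "epoch": 0.1},
            {"step": 20, "loss": 1.98, "learning_rate": 4.9e-5, "epoch": 0.2},
            {"step": 20, "eval_loss": 2.05, "epoch": 0.2},  # eval entries lack "loss"
        ]
    }
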
@@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
         plt.figure()
         plt.plot(steps, metrics, alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), label="smoothed")
-        plt.title("training {} of {}".format(key, training_args.output_dir))
+        plt.title("training {} of {}".format(key, save_dictionary))
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
-        plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
-        print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
+        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
+        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
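
With the new signature, callers pass the save directory directly instead of the full Seq2SeqTrainingArguments; a usage sketch (the path is hypothetical):

    # Plot training loss (and eval loss, if logged) from a finished run;
    # "output/sft-checkpoint" stands in for wherever trainer_state.json was saved.
    plot_loss("output/sft-checkpoint", keys=["loss", "eval_loss"])

Decoupling plot_loss from the training arguments lets it be run after the fact on any directory that contains a trainer state file, which is also why the Seq2SeqTrainingArguments import above becomes unused and is removed.
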