|
@ -477,13 +477,13 @@ def preprocess_data( |
|
|
desc="Running tokenizer on dataset" |
|
|
desc="Running tokenizer on dataset" |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
if stage == "pt": |
|
|
if stage == "pt": |
|
|
print_unsupervised_dataset_example(dataset[0]) |
|
|
print_unsupervised_dataset_example(dataset[0]) |
|
|
elif stage == "sft": |
|
|
elif stage == "sft": |
|
|
print_supervised_dataset_example(dataset[0]) |
|
|
print_supervised_dataset_example(dataset[0]) |
|
|
elif stage == "rm": |
|
|
elif stage == "rm": |
|
|
print_pairwise_dataset_example(dataset[0]) |
|
|
print_pairwise_dataset_example(dataset[0]) |
|
|
elif stage == "ppo": |
|
|
elif stage == "ppo": |
|
|
print_unsupervised_dataset_example(dataset[0]) |
|
|
print_unsupervised_dataset_example(dataset[0]) |
|
|
|
|
|
|
|
|
return dataset |
|
|
return dataset |
|
|