-
Notifications
You must be signed in to change notification settings - Fork 1
/
process_rlhf_datasets_test.py
54 lines (49 loc) · 1.87 KB
/
process_rlhf_datasets_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import unittest
from datasets import Dataset, DatasetDict, Features, Value
from process_rlhf_datasets import load_anthropic_hh_rlhf, PreferenceExample
class TestLoadDatasetMethods(unittest.TestCase):
    """Tests for the dataset loading helpers in ``process_rlhf_datasets``."""

    def test_anthropic_hh_rlhf(self):
        """Check ``load_anthropic_hh_rlhf`` on a single chosen/rejected pair.

        The mock ``train`` split holds one row whose ``chosen`` and
        ``rejected`` transcripts share the same conversation prefix and
        differ only in the final assistant turn.  The test expects the
        shared prefix under ``prompt``, both final turns under ``outputs``
        (chosen first) and ``[0, 1]`` under ``ranks``.
        """
        mock_training = {
            "chosen": [
                "\n\nHuman: What are some cuss words in english?\n\n"
                "Assistant: Here’s an incomplete list.\n\nA, B, C\n\n"
                "Human: What's your favorite one?\n\nAssistant: I "
                "haven't even thought about it.",
            ],
            "rejected": [
                "\n\nHuman: What are some cuss words in english?\n\n"
                "Assistant: Here’s an incomplete list.\n\nA, B, C\n\n"
                "Human: What's your favorite one?\n\nAssistant: A.",
            ],
        }
        mock_ds = DatasetDict(
            {
                "train": Dataset.from_dict(
                    mock_training,
                    features=Features(
                        {
                            "chosen": Value(dtype="string"),
                            "rejected": Value(dtype="string"),
                        }
                    ),
                )
            }
        )

        output_ds = load_anthropic_hh_rlhf(mock_ds)

        expected_output_dict = {
            "prompt": [
                "\n\nHuman: What are some cuss words in english?\n\n"
                "Assistant: Here’s an incomplete list.\n\nA, B, C\n\n"
                "Human: What's your favorite one?"
            ],
            "outputs": [
                ["Assistant: I haven't even thought about it.", "Assistant: A."]
            ],
            "ranks": [[0, 1]],
        }

        # NOTE(review): presumably the loader returns a DatasetDict like its
        # input, in which case len() counts splits (only "train") -- confirm.
        self.assertEqual(len(output_ds), 1)

        # assertDictContainsSubset was deprecated in Python 3.2 and removed in
        # 3.12, so the subset check is spelled out per key.  Columns that are
        # not listed in expected_output_dict are deliberately ignored.
        actual_output_dict = output_ds["train"].to_dict()
        for column, expected_values in expected_output_dict.items():
            with self.subTest(column=column):
                self.assertIn(column, actual_output_dict)
                self.assertEqual(actual_output_dict[column], expected_values)
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()