{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4YErqpfH9jVI"
   },
   "source": [
    "---\n",
    "title: RAG Evaluation\n",
    "---\n",
    "_Authored by: [Aymeric Roucher](https://huggingface.co/m-ric)_\n",
    "\n",
    "This notebook demonstrates how you can evaluate your RAG (Retrieval Augmented Generation) system by building a synthetic evaluation dataset and using LLM-as-a-judge to compute the accuracy of your system.\n",
    "\n",
    "For an introduction to RAG, you can check [this other cookbook](rag_zephyr_langchain)!\n",
    "\n",
    "RAG systems are complex: here is a RAG diagram, with all the possibilities for system enhancement marked in blue:\n",
    "\n",
    "<img src=\"https://huggingface.co/datasets/huggingface/cookbook-images/resolve/main/RAG_workflow.png\" height=\"700\">\n",
    "\n",
    "Implementing any of these improvements can bring a huge performance boost, but changing anything is useless if you cannot monitor the impact of your changes on the system's performance!\n",
    "So let's see how to evaluate our RAG system.\n",
    "\n",
    "### Evaluating RAG performance\n",
    "\n",
    "Since there are so many moving parts to tune, each with a big impact on performance, benchmarking the RAG system is crucial.\n",
    "\n",
    "For our evaluation pipeline, we will need:\n",
    "1. An evaluation dataset with question-answer couples (QA couples)\n",
    "2. An evaluator to compute the accuracy of our system on the above evaluation dataset.\n",
    "\n",
    "➡️ It turns out, we can use LLMs to help us all along the way!\n",
    "1. The evaluation dataset will be synthetically generated by an LLM 🤖, and low-quality questions will be filtered out by other LLMs 🤖\n",
    "2. An [LLM-as-a-judge](https://huggingface.co/papers/2306.05685) agent 🤖 will then perform the evaluation on this synthetic dataset.\n",
    "\n",
    "__Let's dig into it and start building our evaluation pipeline!__ First, we install the required dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bCKBvOcp9jVK"
   },
   "outputs": [],
   "source": [
    "!pip install -q torch transformers langchain sentence-transformers faiss-gpu openpyxl openai datasets python-dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "k_lJFbYm9jVL"
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%reload_ext dotenv\n",
    "%dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oIlNZ1Mn9jVL"
   },
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "import pandas as pd\n",
    "from typing import Optional, List, Tuple\n",
    "from langchain_core.language_models import BaseChatModel\n",
    "import json\n",
    "import datasets\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zeW8P62J9jVM"
   },
   "source": [
    "### Load your knowledge base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YRbm5tNF9jVM"
   },
   "outputs": [],
   "source": [
    "ds = datasets.load_dataset(\"m-ric/huggingface_doc\", split=\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wy9CKj0M9jVM"
   },
   "source": [
    "# 1. Build a synthetic dataset for evaluation\n",
    "We first build a synthetic dataset of questions and associated contexts. The method is to sample elements from our knowledge base and ask an LLM to generate questions based on these documents.\n",
    "\n",
    "Then we set up other LLM agents to act as quality filters for the generated QA couples: each of them will act as the filter for a specific flaw."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QkoEgiDg9jVM"
   },
   "source": [
    "### 1.1. Prepare source documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3gTOlRKO9jVM"
   },
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.docstore.document import Document as LangchainDocument\n",
    "\n",
    "langchain_docs = [\n",
    "    LangchainDocument(page_content=doc[\"text\"], metadata={\"source\": doc[\"source\"]})\n",
    "    for doc in tqdm(ds)\n",
    "]\n",
    "\n",
    "\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=2000,\n",
    "    chunk_overlap=200,\n",
    "    add_start_index=True,\n",
    "    separators=[\"\\n\\n\", \"\\n\", \".\", \" \", \"\"],\n",
    ")\n",
    "\n",
    "docs_processed = []\n",
    "for doc in langchain_docs:\n",
    "    docs_processed += text_splitter.split_documents([doc])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WjrNhcCh9jVN"
   },
   "source": [
    "### 1.2. Set up agents for question generation\n",
    "\n",
    "We use [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) for QA couple generation because it has excellent performance on leaderboards such as [Chatbot Arena](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GoRySj3Q9jVN"
   },
   "outputs": [],
   "source": [
    "from langchain_community.llms import HuggingFaceHub\n",
    "\n",
    "repo_id = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n",
    "\n",
    "llm = HuggingFaceHub(\n",
    "    repo_id=repo_id,\n",
    "    task=\"text-generation\",\n",
    "    model_kwargs={\n",
    "        \"max_new_tokens\": 512,\n",
    "        \"top_k\": 30,\n",
    "        \"temperature\": 0.1,\n",
    "        \"repetition_penalty\": 1.03,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wubTNTaV9jVN"
   },
   "outputs": [],
   "source": [
    "from langchain_community.chat_models import ChatHuggingFace\n",
    "\n",
    "chat_model = ChatHuggingFace(llm=llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hIM_DJRo9jVN"
   },
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate\n",
    "\n",
    "QA_generation_prompt = \"\"\"\n",
    "Your task is to write a factoid question and an answer given a context.\n",
    "Your factoid question should be answerable with a specific, concise piece of factual information from the context.\n",
    "Your factoid question should be formulated in the same style as questions users could ask in a search engine.\n",
    "This means that your factoid question MUST NOT mention something like \"according to the passage\" or \"context\".\n",
    "\n",
    "Provide your answer as follows:\n",
    "\n",
    "Output:::\n",
    "Factoid question: (your factoid question)\n",
    "Answer: (your answer to the factoid question)\n",
    "\n",
    "Now here is the context.\n",
    "\n",
    "Context: {context}\\n\n",
    "Output:::\"\"\"\n",
    "\n",
    "QA_generation_prompt = ChatPromptTemplate.from_template(QA_generation_prompt)\n",
    "QA_generation_agent = QA_generation_prompt | chat_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lVFc-lVy9jVN"
   },
   "source": [
    "Now let's generate our QA couples.\n",
    "For this example, we generate only 10 QA couples and will load the rest from the Hub.\n",
    "\n",
    "But for your specific knowledge base, given that you want at least ~100 test samples, and that our critique agents will later filter out around half of the generated couples, you should generate many more: over 200 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8fteqDDD9jVN"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "N_GENERATIONS = (\n",
    "    10  # We intentionally generate only 10 QA couples here for cost and time considerations\n",
    ")\n",
    "\n",
    "print(f\"Generating {N_GENERATIONS} QA couples...\")\n",
    "outputs = []\n",
    "# Sample from the chunked documents prepared in section 1.1\n",
    "for context in tqdm(random.sample(docs_processed, N_GENERATIONS)):\n",
    "    # Generate QA couple\n",
    "    output_QA_couple = QA_generation_agent.invoke({\"context\": context.page_content}).content\n",
    "    try:\n",
    "        # Parse the \"Factoid question: ... Answer: ...\" format requested in the prompt\n",
    "        question = output_QA_couple.split(\"Factoid question: \")[1].split(\"Answer: \")[0]\n",
    "        answer = output_QA_couple.split(\"Answer: \")[1]\n",
    "        outputs.append(\n",
    "            {\n",
    "                \"context\": context.page_content,\n",
    "                \"question\": question,\n",
    "                \"answer\": answer,\n",
    "                \"source_doc\": context.metadata[\"source\"],\n",
    "            }\n",
    "        )\n",
    "    except Exception:\n",
    "        # Skip generations that do not follow the expected output format\n",
    "        continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aUlOUDv59jVN",
    "outputId": "c9634fdb-2a7f-43a6-c4eb-e60b166b8238"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>context</th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "      <th>source_doc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>!--Copyright 2023 The HuggingFace Team. All rights reserved.\\n\\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\\nthe License. You may obtain a copy of the License at\\n\\nhttp://www.apache.org/licenses/LICENSE-2.0\\n\\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\\nspecific language governing permissions and limitations under the License.\\n--&gt;\\n\\n# Schedulers\\n\\nπŸ€— Diffusers provides many scheduler functions for the diffusion process. A scheduler takes a model's output (the sample which the diffusion process is iterating on) and a timestep to return a denoised sample. The timestep is important because it dictates where in the diffusion process the step is; data is generated by iterating forward *n* timesteps and inference occurs by propagating backward through the timesteps. Based on the timestep, a scheduler may be *discrete* in which case the timestep is an `int` or *continuous* in which case the timestep is a `float`.\\n\\nDepending on the context, a scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output:\\n\\n- during *training*, a scheduler adds noise (there are different algorithms for how to add noise) to a sample to train a diffusion model\\n- during *inference*, a scheduler defines how to update a sample based on a pretrained model's output\\n\\nMany schedulers are implemented from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library by [Katherine Crowson](https://github.com/crowsonkb/), and they're also widely used in A1111. 
To help you map the schedulers from k-diffusion and A1111 to the schedulers in πŸ€— Diffusers, take a look at the table below:\\n\\n| A1111/k-diffusion    | πŸ€— Diffusers                         | Usage                                                                                                         |\\n|---------------------|-------------------------------------|---------------------------------------------------------------------------------------------------------------|\\n| DPM++ 2M            | [`DPMSolverMultistepScheduler`]     |                                                                                                               |\\n| DPM++ 2M Karras     | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True`                                                                            |\\n| DPM++ 2M SDE        | [`DPMSolverMultistepScheduler`]     | init with `algorithm_type=\"sde-dpmsolver++\"`                                                                  |\\n| DPM++ 2M SDE Karras | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True` and `algorithm_type=\"sde-dpmsolver++\"`                                     |\\n| DPM++ 2S a          | N/A                                 | very similar to  `DPMSolverSinglestepScheduler`                         |\\n| DPM++ 2S a Karras   | N/A                                 | very similar to  `DPMSolverSinglestepScheduler(use_karras_sigmas=True, ...)` |\\n| DPM++ SDE           | [`DPMSolverSinglestepScheduler`]    |                                                                                                               |\\n| DPM++ SDE Karras    | [`DPMSolverSinglestepScheduler`]    | init with `use_karras_sigmas=True`                                                                            |\\n| DPM2                | [`KDPM2DiscreteScheduler`]          |                                                                                                              
 |\\n| DPM2 Karras         | [`KDPM2DiscreteScheduler`]          | init with `use_karras_sigmas=True`                                                                            |\\n| DPM2 a              | [`KDPM2AncestralDiscreteScheduler`] |                                                                                                               |\\n| DPM2 a Karras       | [`KDPM2AncestralDiscreteScheduler`] | init with `use_karras_sigmas=True`                                                                            |\\n| DPM adaptive        | N/A                                 |                                                                                                               |\\n| DPM fast            | N/A                                 |                                                                                                               |\\n| Euler               | [`EulerDiscreteScheduler`]          |                                                                                                               |\\n| Euler a             | [`EulerAncestralDiscreteScheduler`] |                                                                                                               |\\n| Heun                | [`HeunDiscreteScheduler`]           |                                                                                                               |\\n| LMS                 | [`LMSDiscreteScheduler`]            |                                                                                                               |\\n| LMS Karras          | [`LMSDiscreteScheduler`]            | init with `use_karras_sigmas=True`                                                                            |\\n| N/A                 | [`DEISMultistepScheduler`]          |                                                                                                               |\\n| N/A                 | [`UniPCMultistepScheduler`]        
 |                                                                                                               |\\n\\nAll schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.\\n\\n## SchedulerMixin\\n[[autodoc]] SchedulerMixin\\n\\n## SchedulerOutput\\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\\n\\n## KarrasDiffusionSchedulers\\n\\n[`KarrasDiffusionSchedulers`] are a broad generalization of schedulers in πŸ€— Diffusers. The schedulers in this class are distinguished at a high level by their noise sampling strategy, the type of network and scaling, the training strategy, and how the loss is weighed.\\n\\nThe different schedulers in this class, depending on the ordinary differential equations (ODE) solver type, fall into the above taxonomy and provide a good abstraction for the design of the main schedulers implemented in πŸ€— Diffusers. The schedulers in this class are given [here](https://github.com/huggingface/diffusers/blob/a69754bb879ed55b9b6dc9dd0b3cf4fa4124c765/src/diffusers/schedulers/scheduling_utils.py#L32).\\n\\n## PushToHubMixin\\n\\n[[autodoc]] utils.PushToHubMixin\\n</td>\n",
       "      <td>What is the class of schedulers in πŸ€— Diffusers that are distinguished by their noise sampling strategy, type of network and scaling, training strategy, and loss weighing?\\n</td>\n",
       "      <td>[`KarrasDiffusionSchedulers`]</td>\n",
       "      <td>huggingface/diffusers/blob/main/docs/source/en/api/schedulers/overview.md</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  context  \\\n",
       "0  !--Copyright 2023 The HuggingFace Team. All rights reserved.\\n\\nLicensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with\\nthe License. You may obtain a copy of the License at\\n\\nhttp://www.apache.org/licenses/LICENSE-2.0\\n\\nUnless required by applicable law or agreed to in writing, software distributed under the License is distributed on\\nan \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the\\nspecific language governing permissions and limitations under the License.\\n-->\\n\\n# Schedulers\\n\\nπŸ€— Diffusers provides many scheduler functions for the diffusion process. A scheduler takes a model's output (the sample which the diffusion process is iterating on) and a timestep to return a denoised sample. The timestep is important because it dictates where in the diffusion process the step is; data is generated by iterating forward *n* timesteps and inference occurs by propagating backward through the timesteps. Based on the timestep, a scheduler may be *discrete* in which case the timestep is an `int` or *continuous* in which case the timestep is a `float`.\\n\\nDepending on the context, a scheduler defines how to iteratively add noise to an image or how to update a sample based on a model's output:\\n\\n- during *training*, a scheduler adds noise (there are different algorithms for how to add noise) to a sample to train a diffusion model\\n- during *inference*, a scheduler defines how to update a sample based on a pretrained model's output\\n\\nMany schedulers are implemented from the [k-diffusion](https://github.com/crowsonkb/k-diffusion) library by [Katherine Crowson](https://github.com/crowsonkb/), and they're also widely used in A1111. 
To help you map the schedulers from k-diffusion and A1111 to the schedulers in πŸ€— Diffusers, take a look at the table below:\\n\\n| A1111/k-diffusion    | πŸ€— Diffusers                         | Usage                                                                                                         |\\n|---------------------|-------------------------------------|---------------------------------------------------------------------------------------------------------------|\\n| DPM++ 2M            | [`DPMSolverMultistepScheduler`]     |                                                                                                               |\\n| DPM++ 2M Karras     | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True`                                                                            |\\n| DPM++ 2M SDE        | [`DPMSolverMultistepScheduler`]     | init with `algorithm_type=\"sde-dpmsolver++\"`                                                                  |\\n| DPM++ 2M SDE Karras | [`DPMSolverMultistepScheduler`]     | init with `use_karras_sigmas=True` and `algorithm_type=\"sde-dpmsolver++\"`                                     |\\n| DPM++ 2S a          | N/A                                 | very similar to  `DPMSolverSinglestepScheduler`                         |\\n| DPM++ 2S a Karras   | N/A                                 | very similar to  `DPMSolverSinglestepScheduler(use_karras_sigmas=True, ...)` |\\n| DPM++ SDE           | [`DPMSolverSinglestepScheduler`]    |                                                                                                               |\\n| DPM++ SDE Karras    | [`DPMSolverSinglestepScheduler`]    | init with `use_karras_sigmas=True`                                                                            |\\n| DPM2                | [`KDPM2DiscreteScheduler`]          |                                                                                                              
 |\\n| DPM2 Karras         | [`KDPM2DiscreteScheduler`]          | init with `use_karras_sigmas=True`                                                                            |\\n| DPM2 a              | [`KDPM2AncestralDiscreteScheduler`] |                                                                                                               |\\n| DPM2 a Karras       | [`KDPM2AncestralDiscreteScheduler`] | init with `use_karras_sigmas=True`                                                                            |\\n| DPM adaptive        | N/A                                 |                                                                                                               |\\n| DPM fast            | N/A                                 |                                                                                                               |\\n| Euler               | [`EulerDiscreteScheduler`]          |                                                                                                               |\\n| Euler a             | [`EulerAncestralDiscreteScheduler`] |                                                                                                               |\\n| Heun                | [`HeunDiscreteScheduler`]           |                                                                                                               |\\n| LMS                 | [`LMSDiscreteScheduler`]            |                                                                                                               |\\n| LMS Karras          | [`LMSDiscreteScheduler`]            | init with `use_karras_sigmas=True`                                                                            |\\n| N/A                 | [`DEISMultistepScheduler`]          |                                                                                                               |\\n| N/A                 | [`UniPCMultistepScheduler`]        
 |                                                                                                               |\\n\\nAll schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.\\n\\n## SchedulerMixin\\n[[autodoc]] SchedulerMixin\\n\\n## SchedulerOutput\\n[[autodoc]] schedulers.scheduling_utils.SchedulerOutput\\n\\n## KarrasDiffusionSchedulers\\n\\n[`KarrasDiffusionSchedulers`] are a broad generalization of schedulers in πŸ€— Diffusers. The schedulers in this class are distinguished at a high level by their noise sampling strategy, the type of network and scaling, the training strategy, and how the loss is weighed.\\n\\nThe different schedulers in this class, depending on the ordinary differential equations (ODE) solver type, fall into the above taxonomy and provide a good abstraction for the design of the main schedulers implemented in πŸ€— Diffusers. The schedulers in this class are given [here](https://github.com/huggingface/diffusers/blob/a69754bb879ed55b9b6dc9dd0b3cf4fa4124c765/src/diffusers/schedulers/scheduling_utils.py#L32).\\n\\n## PushToHubMixin\\n\\n[[autodoc]] utils.PushToHubMixin\\n   \n",
       "\n",
       "                                                                                                                                                                       question  \\\n",
       "0  What is the class of schedulers in πŸ€— Diffusers that are distinguished by their noise sampling strategy, type of network and scaling, training strategy, and loss weighing?\\n   \n",
       "\n",
       "                          answer  \\\n",
       "0  [`KarrasDiffusionSchedulers`]   \n",
       "\n",
       "                                                                  source_doc  \n",
       "0  huggingface/diffusers/blob/main/docs/source/en/api/schedulers/overview.md  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(pd.DataFrame(outputs).head(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0KG4dNtg9jVN"
   },
   "source": [
    "### 1.3. Setup critique agents\n",
    "\n",
    "The questions generated by the previous agent can have many flaws: we should do a quality check before validating these questions.\n",
    "\n",
    "We thus build critique agents that will rate each question on several criteria, given in [this paper](https://huggingface.co/papers/2312.10003):\n",
    "- **Groundedness:** can the question be answered from the given context?\n",
    "- **Relevance:** is the question relevant to users? For instance, `\"What is the date when transformers 4.29.1 was released?\"` is not relevant for ML practicioners.\n",
    "\n",
    "One last failure case we've noticed is when a function is tailored for the particular setting where the question was generated, but undecipherable by itself, like `\"What is the name of the function used in this guide?\"`.\n",
    "We also build a critique agent for this criteria:\n",
    "- **Stand-alone**: is the question understandable free of any context, for someone with domain knowledge/Internet access? The opposite of this would be `What is the function used in this article?` for a question generated from a specific blog article.\n",
    "\n",
    "We systematically score functions with all these agents, and whenever the score is too low for any one of the agents, we eliminate the question from our eval dataset.\n",
    "\n",
    "πŸ’‘ ___When asking the agents to output a score, we first ask them to produce its rationale. This will help us verify scores, but most importantly, asking it to first output rationale gives the model more tokens to think and elaborate an answer before summarizing it into a single score token.___\n",
    "\n",
    "We now build and run these critique agents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "05aSgTGs9jVO"
   },
   "outputs": [],
   "source": [
    "question_groundedness_critique_prompt = \"\"\"\n",
    "You will be given a context and a question.\n",
    "Your task is to provide a 'total rating' scoring how well one can answer the given question unambiguously with the given context.\n",
    "Give your answer on a scale of 1 to 5, where 1 means that the question is not answerable at all given the context, and 5 means that the question is clearly and unambiguously answerable with the context.\n",
    "\n",
    "Provide your answer as follows:\n",
    "\n",
    "Answer:::\n",
    "Evaluation: (your rationale for the rating)\n",
    "Total rating: (your rating)\n",
    "\n",
    "Now here are the question and context.\n",
    "\n",
    "Question: {question}\\n\n",
    "Context: {context}\\n\n",
    "Answer::: \"\"\"\n",
    "\n",
    "question_relevance_critique_prompt = \"\"\"\n",
    "You will be given a question.\n",
    "Your task is to provide a 'total rating' representing how useful this question can be to machine learning developers building NLP applications with the Hugging Face ecosystem.\n",
    "Give your answer on a scale of 1 to 5, where 1 means that the question is not useful at all, and 5 means that the question is extremely useful.\n",
    "\n",
    "Provide your answer as follows:\n",
    "\n",
    "Answer:::\n",
    "Evaluation: (your rationale for the rating)\n",
    "Total rating: (your rating)\n",
    "\n",
    "Now here is the question.\n",
    "\n",
    "Question: {question}\\n\n",
    "Answer::: \"\"\"\n",
    "\n",
    "question_standalone_critique_prompt = \"\"\"\n",
    "You will be given a question.\n",
    "Your task is to provide a 'total rating' representing how context-independant this question is.\n",
    "Give your answer on a scale of 1 to 5, where 1 means that the question only makes sense in a specific context, and 5 means that the question makes sense by itself.\n",
    "For instance, if the question refers to a particular setting, like 'in the context' or 'in the document', the rating must be 1.\n",
    "The questions can contain obscure technical nouns or acronyms like Gradio, Hub, Hugging Face or Space and still be a 5: it must simply be clear to an operator with access to documentation what the question is about.\n",
    "\n",
    "Provide your answer as follows:\n",
    "\n",
    "Answer:::\n",
    "Evaluation: (your rationale for the rating)\n",
    "Total rating: (your rating)\n",
    "\n",
    "Now here is the question.\n",
    "\n",
    "Question: {question}\\n\n",
    "Answer::: \"\"\"\n",
    "\n",
    "question_groundedness_critique_prompt = ChatPromptTemplate.from_template(\n",
    "    question_groundedness_critique_prompt\n",
    ")\n",
    "question_groundedness_critique_agent = question_groundedness_critique_prompt | chat_model\n",
    "\n",
    "question_relevance_critique_prompt = ChatPromptTemplate.from_template(\n",
    "    question_relevance_critique_prompt\n",
    ")\n",
    "question_relevance_critique_agent = question_relevance_critique_prompt | chat_model\n",
    "\n",
    "question_standalone_critique_prompt = ChatPromptTemplate.from_template(\n",
    "    question_standalone_critique_prompt\n",
    ")\n",
    "question_standalone_critique_agent = question_standalone_critique_prompt | chat_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b9tbk7ME9jVO"
   },
   "outputs": [],
   "source": [
    "print(\"Generating critique for each QA couple...\")\n",
    "for output in tqdm(outputs):\n",
    "    # Critique the generated QA couple\n",
    "    question_groundedness_evaluation = question_groundedness_critique_agent.invoke(\n",
    "        {\"context\": output[\"context\"], \"question\": output[\"question\"]}\n",
    "    ).content\n",
    "    question_relevance_evaluation = question_relevance_critique_agent.invoke(\n",
    "        {\"question\": output[\"question\"]}\n",
    "    ).content\n",
    "    question_standalone_evaluation = question_standalone_critique_agent.invoke(\n",
    "        {\"question\": output[\"question\"]}\n",
    "    ).content\n",
    "\n",
    "    try:\n",
    "        groundedness_score = int(question_groundedness_evaluation.split(\"Total rating: \")[1][0])\n",
    "        groundedness_eval = question_groundedness_evaluation.split(\"Total rating: \")[0].split(\n",
    "            \"Evaluation: \"\n",
    "        )[1]\n",
    "        relevance_score = int(question_relevance_evaluation.split(\"Total rating: \")[1][0])\n",
    "        relevance_eval = question_relevance_evaluation.split(\"Total rating: \")[0].split(\n",
    "            \"Evaluation: \"\n",
    "        )[1]\n",
    "        standalone_score = int(question_standalone_evaluation.split(\"Total rating: \")[1][0])\n",
    "        standalone_eval = question_standalone_evaluation.split(\"Total rating: \")[0].split(\n",
    "            \"Evaluation: \"\n",
    "        )[1]\n",
    "        output.update(\n",
    "            {\n",
    "                \"groundedness_score\": groundedness_score,\n",
    "                \"groundedness_eval\": groundedness_eval,\n",
    "                \"relevance_score\": relevance_score,\n",
    "                \"relevance_eval\": relevance_eval,\n",
    "                \"standalone_score\": standalone_score,\n",
    "                \"standalone_eval\": standalone_eval,\n",
    "            }\n",
    "        )\n",
    "    except:\n",
    "        continue"
   ]
  },
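  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Splitting on `\"Total rating: \"` fails as soon as a reply deviates slightly from the requested format, for instance when the space after the colon is missing. A more tolerant parser based on regular expressions could look like the minimal sketch below. The names `parse_critique`, `RATING_PATTERN` and `EVALUATION_PATTERN` are hypothetical helpers introduced for illustration; they are not part of this notebook or of LangChain.\n",
    "\n",
    "```python\n",
    "import re\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "# Hypothetical helpers, not part of the original notebook.\n",
    "RATING_PATTERN = re.compile(r'Total rating:\\s*([1-5])')\n",
    "EVALUATION_PATTERN = re.compile(r'Evaluation:\\s*(.*?)\\s*Total rating:', re.DOTALL)\n",
    "\n",
    "\n",
    "def parse_critique(text: str) -> Optional[Tuple[int, str]]:\n",
    "    # Return (score, rationale), or None when the reply does not follow the expected format\n",
    "    rating = RATING_PATTERN.search(text)\n",
    "    evaluation = EVALUATION_PATTERN.search(text)\n",
    "    if rating is None or evaluation is None:\n",
    "        return None\n",
    "    return int(rating.group(1)), evaluation.group(1)\n",
    "\n",
    "\n",
    "sample_reply = 'Evaluation: The context states the answer explicitly.\\nTotal rating: 5'\n",
    "print(parse_critique(sample_reply))\n",
    "print(parse_critique('I cannot rate this question.'))\n",
    "```\n",
    "\n",
    "Replies for which `parse_critique` returns `None` can be skipped, as the loop above already does for critiques it cannot parse."
   ]
  },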
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IQv36Y_f9jVO"
   },
   "source": [
    "Now let us filter out bad questions based on our critique agent scores:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oBWuOu1b9jVO",
    "outputId": "b32bacea-52f8-486a-96fe-5c188605c5a2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation dataset before filtering:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "      <th>groundedness_score</th>\n",
       "      <th>relevance_score</th>\n",
       "      <th>standalone_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>What is the class of schedulers in πŸ€— Diffusers that are distinguished by their noise sampling strategy, type of network and scaling, training strategy, and loss weighing?\\n</td>\n",
       "      <td>[`KarrasDiffusionSchedulers`]</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are some utility functions provided by the Hugging Face library for pipelines?\\n</td>\n",
       "      <td>The Hugging Face library provides several utility functions for pipelines, including `ArgumentHandler`, `ZeroShotClassificationArgumentHandler`, `QuestionAnsweringArgumentHandler` for argument handling, `PipelineDataFormat`, `CsvPipelineDataFormat`, `JsonPipelineDataFormat`, `PipedPipelineDataFormat` for data format, and `PipelineException` for exceptions.</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>What is the default name used in the Gradio demo if no name is provided?\\n</td>\n",
       "      <td>User\\n\\nExplanation: The factoid question asks for the default name used in the Gradio demo if no name is provided. The answer to this question can be found in the `argparse.ArgumentParser()` function, where a default value of \"User\" is set for the `--name` argument.</td>\n",
       "      <td>5.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>What is the function used to load a pre-trained Resnet-18 model in the provided context?\\n</td>\n",
       "      <td>The function used to load a pre-trained Resnet-18 model in the provided context is `torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()`.</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>What is the name of the component used for creating a button in the given code?\\n</td>\n",
       "      <td>The name of the component used for creating a button in the given code is `BaseButton`.</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>What is the command to get the example ONNX file for Bart model?\\n</td>\n",
       "      <td>The command is `python run_onnx_exporter.py --model_name_or_path facebook/bart-base`.</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>What will be covered in the next unit of the course?\\n</td>\n",
       "      <td>The next unit of the course will cover learning more about Unity MLAgents and training agents in Unity environments. It will also prepare students for AI vs AI challenges where they will train their agents to compete against other agents in a snowball fight and a soccer game.</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>What is the purpose of the `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` parameters in SDXL?\\n</td>\n",
       "      <td>These parameters allow SDXL to negatively condition the model on image resolution and cropping parameters.</td>\n",
       "      <td>2.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>How are transformers models tested in the Hugging Face repository?\\n</td>\n",
       "      <td>Transformers models are tested in the Hugging Face repository using two test suites: `tests` for the general API and `examples` for various applications that aren't part of the API. These tests are run on CircleCI and GitHub Actions, with different jobs and configurations for each. The tests can be run in various ways, including running all tests, getting the list of all tests, running a specific test module, and running specific tests by name or keyword expression. Additionally, there are options for running tests in parallel, repeating tests, and running tests on a specific GPU or CPU.</td>\n",
       "      <td>3.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>What command is used to create a virtual environment in the given context?\\n</td>\n",
       "      <td>The command used to create a virtual environment in the given context is `python -m venv &lt;env_name&gt;`.</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                                                                                                       question  \\\n",
       "0  What is the class of schedulers in πŸ€— Diffusers that are distinguished by their noise sampling strategy, type of network and scaling, training strategy, and loss weighing?\\n   \n",
       "1                                                                                         What are some utility functions provided by the Hugging Face library for pipelines?\\n   \n",
       "2                                                                                                    What is the default name used in the Gradio demo if no name is provided?\\n   \n",
       "3                                                                                    What is the function used to load a pre-trained Resnet-18 model in the provided context?\\n   \n",
       "4                                                                                             What is the name of the component used for creating a button in the given code?\\n   \n",
       "5                                                                                                            What is the command to get the example ONNX file for Bart model?\\n   \n",
       "6                                                                                                                        What will be covered in the next unit of the course?\\n   \n",
       "7                                       What is the purpose of the `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` parameters in SDXL?\\n   \n",
       "8                                                                                                          How are transformers models tested in the Hugging Face repository?\\n   \n",
       "9                                                                                                  What command is used to create a virtual environment in the given context?\\n   \n",
       "\n",
       "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               answer  \\\n",
       "0                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       [`KarrasDiffusionSchedulers`]   \n",
       "1                                                                                                                                                                                                                                              The Hugging Face library provides several utility functions for pipelines, including `ArgumentHandler`, `ZeroShotClassificationArgumentHandler`, `QuestionAnsweringArgumentHandler` for argument handling, `PipelineDataFormat`, `CsvPipelineDataFormat`, `JsonPipelineDataFormat`, `PipedPipelineDataFormat` for data format, and `PipelineException` for exceptions.   \n",
       "2                                                                                                                                                                                                                                                                                                                                         User\\n\\nExplanation: The factoid question asks for the default name used in the Gradio demo if no name is provided. The answer to this question can be found in the `argparse.ArgumentParser()` function, where a default value of \"User\" is set for the `--name` argument.   \n",
       "3                                                                                                                                                                                                                                                                                                                                                                                                                                                   The function used to load a pre-trained Resnet-18 model in the provided context is `torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()`.   \n",
       "4                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             The name of the component used for creating a button in the given code is `BaseButton`.   \n",
       "5                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               The command is `python run_onnx_exporter.py --model_name_or_path facebook/bart-base`.   \n",
       "6                                                                                                                                                                                                                                                                                                                                The next unit of the course will cover learning more about Unity MLAgents and training agents in Unity environments. It will also prepare students for AI vs AI challenges where they will train their agents to compete against other agents in a snowball fight and a soccer game.   \n",
       "7                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          These parameters allow SDXL to negatively condition the model on image resolution and cropping parameters.   \n",
       "8  Transformers models are tested in the Hugging Face repository using two test suites: `tests` for the general API and `examples` for various applications that aren't part of the API. These tests are run on CircleCI and GitHub Actions, with different jobs and configurations for each. The tests can be run in various ways, including running all tests, getting the list of all tests, running a specific test module, and running specific tests by name or keyword expression. Additionally, there are options for running tests in parallel, repeating tests, and running tests on a specific GPU or CPU.   \n",
       "9                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               The command used to create a virtual environment in the given context is `python -m venv <env_name>`.   \n",
       "\n",
       "   groundedness_score  relevance_score  standalone_score  \n",
       "0                 3.0              1.0               4.0  \n",
       "1                 5.0              4.0               5.0  \n",
       "2                 5.0              3.0               5.0  \n",
       "3                 NaN              NaN               NaN  \n",
       "4                 5.0              1.0               5.0  \n",
       "5                 NaN              NaN               NaN  \n",
       "6                 5.0              1.0               5.0  \n",
       "7                 2.0              4.0               2.0  \n",
       "8                 3.0              4.0               4.0  \n",
       "9                 NaN              NaN               NaN  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================\n",
      "Final evaluation dataset:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>question</th>\n",
       "      <th>answer</th>\n",
       "      <th>groundedness_score</th>\n",
       "      <th>relevance_score</th>\n",
       "      <th>standalone_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>What are some utility functions provided by the Hugging Face library for pipelines?\\n</td>\n",
       "      <td>The Hugging Face library provides several utility functions for pipelines, including `ArgumentHandler`, `ZeroShotClassificationArgumentHandler`, `QuestionAnsweringArgumentHandler` for argument handling, `PipelineDataFormat`, `CsvPipelineDataFormat`, `JsonPipelineDataFormat`, `PipedPipelineDataFormat` for data format, and `PipelineException` for exceptions.</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                                                question  \\\n",
       "1  What are some utility functions provided by the Hugging Face library for pipelines?\\n   \n",
       "\n",
       "                                                                                                                                                                                                                                                                                                                                                                   answer  \\\n",
       "1  The Hugging Face library provides several utility functions for pipelines, including `ArgumentHandler`, `ZeroShotClassificationArgumentHandler`, `QuestionAnsweringArgumentHandler` for argument handling, `PipelineDataFormat`, `CsvPipelineDataFormat`, `JsonPipelineDataFormat`, `PipedPipelineDataFormat` for data format, and `PipelineException` for exceptions.   \n",
       "\n",
       "   groundedness_score  relevance_score  standalone_score  \n",
       "1                 5.0              4.0               5.0  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", None)\n",
    "\n",
    "generated_questions = pd.DataFrame.from_dict(outputs)\n",
    "\n",
    "print(\"Evaluation dataset before filtering:\")\n",
    "display(\n",
    "    generated_questions[\n",
    "        [\"question\", \"answer\", \"groundedness_score\", \"relevance_score\", \"standalone_score\"]\n",
    "    ]\n",
    ")\n",
    "generated_questions = generated_questions.loc[\n",
    "    (generated_questions[\"groundedness_score\"] >= 4)\n",
    "    & (generated_questions[\"relevance_score\"] >= 4)\n",
    "    & (generated_questions[\"standalone_score\"] >= 4)\n",
    "]\n",
    "print(\"============================================\")\n",
    "print(\"Final evaluation dataset:\")\n",
    "display(\n",
    "    generated_questions[\n",
    "        [\"question\", \"answer\", \"groundedness_score\", \"relevance_score\", \"standalone_score\"]\n",
    "    ]\n",
    ")\n",
    "\n",
    "eval_dataset = datasets.Dataset.from_pandas(\n",
    "    generated_questions, split=\"train\", preserve_index=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HaOMZyu69jVO"
   },
   "source": [
    "Now our synthetic evaluation dataset is complete! We can evaluate different RAG systems on this evaluation dataset.\n",
    "\n",
    "We have generated only a few QA couples here to reduce time and cost. But let's kick start the next part by loading a pre-generated dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q3RRz4W79jVO"
   },
   "outputs": [],
   "source": [
    "eval_dataset = datasets.load_dataset(\"m-ric/huggingface_doc_qa_eval\", split=\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K5s19uTd9jVO"
   },
   "source": [
    "# 2. Build our RAG System"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z-mET8Dy9jVO"
   },
   "source": [
    "### 2.1. Preprocessing documents to build our vector database\n",
    "\n",
    "- In this part, __we split the documents from our knowledge base into smaller chunks__: these will be the snippets that are picked by the Retriever, to then be ingested by the Reader LLM as supporting elements for its answer.\n",
    "- The goal is to build semantically relevant snippets: not too small to be sufficient for supporting an answer, and not too large too avoid diluting individual ideas.\n",
    "\n",
    "Many options exist for text splitting:\n",
    "- split every `n` words / characters, but this has the risk of cutting in half paragraphs or even sentences\n",
    "- split after `n` words / character, but only on sentence boundaries\n",
    "- **recursive split** tries to preserve even more of the document structure, by processing it tree-like way, splitting first on the largest units (chapters) then recursively splitting on smaller units (paragraphs, sentences).\n",
    "\n",
    "To learn more about chunking, I recommend you read [this great notebook](https://github.com/FullStackRetrieval-com/RetrievalTutorials/blob/main/5_Levels_Of_Text_Splitting.ipynb) by Greg Kamradt.\n",
    "\n",
    "[This space](https://huggingface.co/spaces/m-ric/chunk_visualizer) lets you visualize how different splitting options affect the chunks you get.\n",
    "\n",
    "> In the following, we use Langchain's `RecursiveCharacterTextSplitter`.\n",
    "\n",
    "πŸ’‘ _To measure chunk length in our Text Splitter, our length function will not be the count of characters, but the count of tokens in the tokenized text: indeed, for subsequent embedder that processes token, measuring length in tokens is more relevant and empirically performs better._"
   ]
  },
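  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the recursive idea concrete, here is a minimal pure-Python sketch. It is not LangChain's actual implementation: `count_tokens` and `recursive_split` are hypothetical helpers introduced for illustration, `count_tokens` only counts whitespace-separated words as a stand-in for a real tokenizer, and chunk overlap is left out.\n",
    "\n",
    "```python\n",
    "from typing import Callable, List\n",
    "\n",
    "\n",
    "def count_tokens(text: str) -> int:\n",
    "    # Hypothetical stand-in for a tokenizer: counts whitespace-separated words\n",
    "    return len(text.split())\n",
    "\n",
    "\n",
    "def recursive_split(\n",
    "    text: str,\n",
    "    chunk_size: int,\n",
    "    separators: List[str],\n",
    "    length_fn: Callable[[str], int] = count_tokens,\n",
    ") -> List[str]:\n",
    "    # Short enough (or nothing left to split on): keep the text as a single chunk\n",
    "    if length_fn(text) <= chunk_size or not separators:\n",
    "        return [text.strip()] if text.strip() else []\n",
    "\n",
    "    head, *rest = separators\n",
    "    chunks: List[str] = []\n",
    "    current = \"\"\n",
    "    for piece in text.split(head):\n",
    "        candidate = f\"{current}{head}{piece}\" if current else piece\n",
    "        if length_fn(candidate) <= chunk_size:\n",
    "            current = candidate  # keep merging small neighbouring pieces\n",
    "            continue\n",
    "        if current.strip():\n",
    "            chunks.append(current.strip())\n",
    "        if length_fn(piece) > chunk_size:\n",
    "            # This piece alone is too long: recurse with the finer separators\n",
    "            chunks.extend(recursive_split(piece, chunk_size, rest, length_fn))\n",
    "            current = \"\"\n",
    "        else:\n",
    "            current = piece\n",
    "    if current.strip():\n",
    "        chunks.append(current.strip())\n",
    "    return chunks\n",
    "\n",
    "\n",
    "text = \"First paragraph about chunking.\\n\\nSecond paragraph. It has two sentences that are longer.\"\n",
    "chunks = recursive_split(text, chunk_size=6, separators=[\"\\n\\n\", \". \", \" \"])\n",
    "```\n",
    "\n",
    "With `chunk_size=6`, the first paragraph stays whole while the second is split on sentence boundaries and then on words. Passing a tokenizer-based `length_fn` would measure length in model tokens instead of words."
   ]
  },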
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "H4fhm55Q9jVO"
   },
   "outputs": [],
   "source": [
    "from langchain.docstore.document import Document as LangchainDocument\n",
    "\n",
    "RAW_KNOWLEDGE_BASE = [\n",
    "    LangchainDocument(page_content=doc[\"text\"], metadata={\"source\": doc[\"source\"]})\n",
    "    for doc in tqdm(ds)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sz9Jw2_q9jVO"
   },
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "\n",
    "def split_documents(\n",
    "    chunk_size: int,\n",
    "    knowledge_base: List[LangchainDocument],\n",
    "    tokenizer_name: str,\n",
    ") -> List[LangchainDocument]:\n",
    "    \"\"\"\n",
    "    Split documents into chunks of size `chunk_size` characters and return a list of documents.\n",
    "    \"\"\"\n",
    "    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(\n",
    "        AutoTokenizer.from_pretrained(tokenizer_name),\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=int(chunk_size / 10),\n",
    "        add_start_index=True,\n",
    "        strip_whitespace=True,\n",
    "        separators=[\"\\n\\n\", \"\\n\", \".\", \" \", \"\"],\n",
    "    )\n",
    "\n",
    "    docs_processed = []\n",
    "    for doc in knowledge_base:\n",
    "        docs_processed += text_splitter.split_documents([doc])\n",
    "\n",
    "    # Remove duplicates\n",
    "    unique_texts = {}\n",
    "    docs_processed_unique = []\n",
    "    for doc in docs_processed:\n",
    "        if doc.page_content not in unique_texts:\n",
    "            unique_texts[doc.page_content] = True\n",
    "            docs_processed_unique.append(doc)\n",
    "\n",
    "    return docs_processed_unique"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QzBYfNG79jVO"
   },
   "source": [
    "### 2.2. Retriever - embeddings πŸ—‚οΈ\n",
    "The __retriever acts like an internal search engine__: given the user query, it returns the most relevant documents from your knowledge base.\n",
    "\n",
    "> For the knowledge base, we use Langchain vector databases since __it offers a convenient [FAISS](https://github.com/facebookresearch/faiss) index and allows us to keep document metadata throughout the processing__.\n",
    "\n",
    "πŸ› οΈ __Options included:__\n",
    "\n",
    "- Tune the chunking method:\n",
    "    - Size of the chunks\n",
    "    - Method: split on different separators, use [semantic chunking](https://python.langchain.com/docs/modules/data_connection/document_transformers/semantic-chunker)...\n",
    "- Change the embedding model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LqJlIDZR9jVO"
   },
   "outputs": [],
   "source": [
    "from langchain.vectorstores import FAISS\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from langchain_community.vectorstores.utils import DistanceStrategy\n",
    "import os\n",
    "\n",
    "\n",
    "def load_embeddings(\n",
    "    langchain_docs: List[LangchainDocument],\n",
    "    chunk_size: int,\n",
    "    embedding_model_name: Optional[str] = \"thenlper/gte-small\",\n",
    ") -> FAISS:\n",
    "    \"\"\"\n",
    "    Creates a FAISS index from the given embedding model and documents. Loads the index directly if it already exists.\n",
    "\n",
    "    Args:\n",
    "        langchain_docs: list of documents\n",
    "        chunk_size: size of the chunks to split the documents into\n",
    "        embedding_model_name: name of the embedding model to use\n",
    "\n",
    "    Returns:\n",
    "        FAISS index\n",
    "    \"\"\"\n",
    "    # load embedding_model\n",
    "    embedding_model = HuggingFaceEmbeddings(\n",
    "        model_name=embedding_model_name,\n",
    "        multi_process=True,\n",
    "        model_kwargs={\"device\": \"cuda\"},\n",
    "        encode_kwargs={\"normalize_embeddings\": True},  # set True to compute cosine similarity\n",
    "    )\n",
    "\n",
    "    # Check if embeddings already exist on disk\n",
    "    index_name = f\"index_chunk:{chunk_size}_embeddings:{embedding_model_name.replace('/', '~')}\"\n",
    "    index_folder_path = f\"./data/indexes/{index_name}/\"\n",
    "    if os.path.isdir(index_folder_path):\n",
    "        return FAISS.load_local(\n",
    "            index_folder_path,\n",
    "            embedding_model,\n",
    "            distance_strategy=DistanceStrategy.COSINE,\n",
    "        )\n",
    "\n",
    "    else:\n",
    "        print(\"Index not found, generating it...\")\n",
    "        docs_processed = split_documents(\n",
    "            chunk_size,\n",
    "            langchain_docs,\n",
    "            embedding_model_name,\n",
    "        )\n",
    "        knowledge_index = FAISS.from_documents(\n",
    "            docs_processed, embedding_model, distance_strategy=DistanceStrategy.COSINE\n",
    "        )\n",
    "        knowledge_index.save_local(index_folder_path)\n",
    "        return knowledge_index"
   ]
  },
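  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a side note on `normalize_embeddings=True`: once vectors have unit length, their dot product equals their cosine similarity. The brute-force sketch below illustrates this on made-up 2-dimensional vectors. It is only an illustration, not how FAISS searches its index, and `normalize` and `top_k` are hypothetical helper names.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import List, Sequence, Tuple\n",
    "\n",
    "\n",
    "def normalize(vector: Sequence[float]) -> List[float]:\n",
    "    # Scale a vector to unit length\n",
    "    norm = math.sqrt(sum(x * x for x in vector))\n",
    "    return [x / norm for x in vector]\n",
    "\n",
    "\n",
    "def top_k(query: Sequence[float], docs: List[Sequence[float]], k: int) -> List[Tuple[int, float]]:\n",
    "    # On unit vectors, the dot product is the cosine similarity\n",
    "    q = normalize(query)\n",
    "    scores = [(i, sum(a * b for a, b in zip(q, normalize(d)))) for i, d in enumerate(docs)]\n",
    "    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]\n",
    "\n",
    "\n",
    "# Toy \"embeddings\", made up for illustration\n",
    "doc_vectors = [[1.0, 0.0], [0.6, 0.8], [0.0, 2.0]]\n",
    "results = top_k(query=[2.0, 0.1], docs=doc_vectors, k=2)\n",
    "print(results)  # list of (document index, cosine similarity), best match first\n",
    "```\n",
    "\n",
    "The vector lengths do not affect the ranking: only the direction of each embedding matters."
   ]
  },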
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b6y1mQJX9jVO"
   },
   "source": [
    "### 2.3. Reader - LLM πŸ’¬\n",
    "\n",
    "In this part, the __LLM Reader reads the retrieved documents to formulate its answer.__\n",
    "\n",
    "πŸ› οΈ Here we tried the following options to improve results:\n",
    "- Switch reranking on/off\n",
    "- Change the reader model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9PdpuWyP9jVP"
   },
   "outputs": [],
   "source": [
    "RAG_PROMPT_TEMPLATE = \"\"\"\n",
    "<|system|>\n",
    "Using the information contained in the context,\n",
    "give a comprehensive answer to the question.\n",
    "Respond only to the question asked, response should be concise and relevant to the question.\n",
    "Provide the number of the source document when relevant.\n",
    "If the answer cannot be deduced from the context, do not give an answer.</s>\n",
    "<|user|>\n",
    "Context:\n",
    "{context}\n",
    "---\n",
    "Now here is the question you need to answer.\n",
    "\n",
    "Question: {question}\n",
    "</s>\n",
    "<|assistant|>\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9SDqenld9jVP"
   },
   "outputs": [],
   "source": [
    "from langchain_community.llms import HuggingFaceHub\n",
    "\n",
    "repo_id = \"HuggingFaceH4/zephyr-7b-beta\"\n",
    "READER_MODEL_NAME = \"zephyr-7b-beta\"\n",
    "\n",
    "READER_LLM = HuggingFaceHub(\n",
    "    repo_id=repo_id,\n",
    "    task=\"text-generation\",\n",
    "    model_kwargs={\n",
    "        \"max_new_tokens\": 512,\n",
    "        \"top_k\": 30,\n",
    "        \"temperature\": 0.1,\n",
    "        \"repetition_penalty\": 1.03,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QZ62CbcZ9jVP"
   },
   "outputs": [],
   "source": [
    "from ragatouille import RAGPretrainedModel\n",
    "from langchain_core.vectorstores import VectorStore\n",
    "from langchain_core.language_models.llms import LLM\n",
    "\n",
    "\n",
    "def answer_with_rag(\n",
    "    question: str,\n",
    "    llm: LLM,\n",
    "    knowledge_index: VectorStore,\n",
    "    reranker: Optional[RAGPretrainedModel] = None,\n",
    "    num_retrieved_docs: int = 30,\n",
    "    num_docs_final: int = 7,\n",
    ") -> Tuple[str, List[LangchainDocument]]:\n",
    "    \"\"\"Answer a question using RAG with the given knowledge index.\"\"\"\n",
    "    # Gather documents with retriever\n",
    "    relevant_docs = knowledge_index.similarity_search(query=question, k=num_retrieved_docs)\n",
    "    relevant_docs = [doc.page_content for doc in relevant_docs]  # keep only the text\n",
    "\n",
    "    # Optionally rerank results\n",
    "    if reranker:\n",
    "        relevant_docs = reranker.rerank(question, relevant_docs, k=num_docs_final)\n",
    "        relevant_docs = [doc[\"content\"] for doc in relevant_docs]\n",
    "\n",
    "    relevant_docs = relevant_docs[:num_docs_final]\n",
    "\n",
    "    # Build the final prompt\n",
    "    context = \"\\nExtracted documents:\\n\"\n",
    "    context += \"\".join([f\"Document {str(i)}:::\\n\" + doc for i, doc in enumerate(relevant_docs)])\n",
    "\n",
    "    final_prompt = RAG_PROMPT_TEMPLATE.format(question=question, context=context)\n",
    "\n",
    "    # Redact an answer\n",
    "    answer = llm(final_prompt)\n",
    "\n",
    "    return answer, relevant_docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hiygbqfT9jVP"
   },
   "source": [
    "# 3. Benchmarking the RAG system\n",
    "\n",
    "The RAG system and the evaluation datasets are now ready. The last step is to judge the RAG system's output on this evlauation dataset.\n",
    "\n",
    "To this end, __we setup a judge agent__. βš–οΈπŸ€–\n",
    "\n",
    "Out of [the different RAG evaluation metrics](https://docs.ragas.io/en/latest/concepts/metrics/index.html), we choose to focus only on faithfulness since it the best end-to-end metric of our system's performance.\n",
    "\n",
    "> We use GPT4 as a judge for its empirically good performance, but you could try with other models such as [kaist-ai/prometheus-13b-v1.0](https://huggingface.co/kaist-ai/prometheus-13b-v1.0) or [BAAI/JudgeLM-33B-v1.0](https://huggingface.co/BAAI/JudgeLM-33B-v1.0).\n",
    "\n",
    "πŸ’‘ _In the evaluation prompt, we give a detailed description each metric on the scale 1-5, as is done in [Prometheus's prompt template](https://huggingface.co/kaist-ai/prometheus-13b-v1.0): this helps the model ground its metric precisely. If instead you give the judge LLM a vague scale to work with, the outputs will not be consistent enough between different examples._\n",
    "\n",
    "πŸ’‘ _Again, prompting the LLM to output rationale before giving its final score gives it more tokens to help it formalize and elaborate a judgement._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VrlMh_ZI9jVP"
   },
   "outputs": [],
   "source": [
    "def run_rag_tests(\n",
    "    eval_dataset: datasets.Dataset,\n",
    "    llm: BaseChatModel,\n",
    "    knowledge_index: VectorStore,\n",
    "    output_file: str,\n",
    "    reranker: Optional[RAGPretrainedModel] = None,\n",
    "    verbose: Optional[bool] = True,\n",
    "    test_settings: Optional[str] = None,  # To document the test settings used\n",
    "):\n",
    "    \"\"\"Runs RAG tests on the given dataset and saves the results to the given output file.\"\"\"\n",
    "    try:  # load previous generations if they exist\n",
    "        with open(output_file, \"r\") as f:\n",
    "            outputs = json.load(f)\n",
    "    except:\n",
    "        outputs = []\n",
    "\n",
    "    for example in tqdm(eval_dataset):\n",
    "        question = example[\"question\"]\n",
    "        if question in [output[\"question\"] for output in outputs]:\n",
    "            continue\n",
    "\n",
    "        answer, relevant_docs = answer_with_rag(question, llm, knowledge_index, reranker=reranker)\n",
    "        if verbose:\n",
    "            print(\"=======================================================\")\n",
    "            print(f\"Question: {question}\")\n",
    "            print(f\"Answer: {answer}\")\n",
    "            print(f'True answer: {example[\"answer\"]}')\n",
    "        result = {\n",
    "            \"question\": question,\n",
    "            \"true_answer\": example[\"answer\"],\n",
    "            \"source_doc\": example[\"source_doc\"],\n",
    "            \"generated_answer\": answer,\n",
    "            \"retrieved_docs\": [doc for doc in relevant_docs],\n",
    "        }\n",
    "        if test_settings:\n",
    "            result[\"test_settings\"] = test_settings\n",
    "        outputs.append(result)\n",
    "\n",
    "        with open(output_file, \"w\") as f:\n",
    "            json.dump(outputs, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ae-3KWzK9jVP"
   },
   "outputs": [],
   "source": [
    "EVALUATION_PROMPT = \"\"\"###Task Description:\n",
    "An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, and a score rubric representing a evaluation criteria are given.\n",
    "1. Write a detailed feedback that assess the quality of the response strictly based on the given score rubric, not evaluating in general.\n",
    "2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.\n",
    "3. The output format should look as follows: \\\"Feedback: {{write a feedback for criteria}} [RESULT] {{an integer number between 1 and 5}}\\\"\n",
    "4. Please do not generate any other opening, closing, and explanations. Be sure to include [RESULT] in your output.\n",
    "\n",
    "###The instruction to evaluate:\n",
    "{instruction}\n",
    "\n",
    "###Response to evaluate:\n",
    "{response}\n",
    "\n",
    "###Reference Answer (Score 5):\n",
    "{reference_answer}\n",
    "\n",
    "###Score Rubrics:\n",
    "[Is the response correct, accurate, and factual based on the reference answer?]\n",
    "Score 1: The response is completely incorrect, inaccurate, and/or not factual.\n",
    "Score 2: The response is mostly incorrect, inaccurate, and/or not factual.\n",
    "Score 3: The response is somewhat correct, accurate, and/or factual.\n",
    "Score 4: The response is mostly correct, accurate, and factual.\n",
    "Score 5: The response is completely correct, accurate, and factual.\n",
    "\n",
    "###Feedback:\"\"\"\n",
    "\n",
    "from langchain.prompts.chat import (\n",
    "    ChatPromptTemplate,\n",
    "    HumanMessagePromptTemplate,\n",
    ")\n",
    "from langchain.schema import SystemMessage\n",
    "\n",
    "\n",
    "evaluation_prompt_template = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        SystemMessage(content=\"You are a fair evaluator language model.\"),\n",
    "        HumanMessagePromptTemplate.from_template(EVALUATION_PROMPT),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ia9Mvn859jVP"
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "\n",
    "eval_chat_model = ChatOpenAI(model=\"gpt-4-1106-preview\", temperature=0)\n",
    "evaluator_name = \"GPT4\"\n",
    "\n",
    "\n",
    "def evaluate_answers(\n",
    "    answer_path: str,\n",
    "    eval_chat_model: BaseChatModel,\n",
    "    evaluator_name: str,\n",
    "    evaluation_prompt_template: ChatPromptTemplate,\n",
    ") -> None:\n",
    "    \"\"\"Evaluates generated answers. Modifies the given answer file in place for better checkpointing.\"\"\"\n",
    "    answers = []\n",
    "    if os.path.isfile(answer_path):  # load previous generations if they exist\n",
    "        answers = json.load(open(answer_path, \"r\"))\n",
    "\n",
    "    for experiment in tqdm(answers):\n",
    "        if f\"eval_score_{evaluator_name}\" in experiment:\n",
    "            continue\n",
    "\n",
    "        eval_prompt = evaluation_prompt_template.format_messages(\n",
    "            instruction=experiment[\"question\"],\n",
    "            response=experiment[\"generated_answer\"],\n",
    "            reference_answer=experiment[\"true_answer\"],\n",
    "        )\n",
    "        eval_result = eval_chat_model.invoke(eval_prompt)\n",
    "        feedback, score = [item.strip() for item in eval_result.content.split(\"[RESULT]\")]\n",
    "        experiment[f\"eval_score_{evaluator_name}\"] = score\n",
    "        experiment[f\"eval_feedback_{evaluator_name}\"] = feedback\n",
    "\n",
    "        with open(answer_path, \"w\") as f:\n",
    "            json.dump(answers, f)"
   ]
  },
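  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The judge is asked to answer in the form `Feedback: ... [RESULT] <score>`. If the judge model sometimes drifts from that format, a more defensive parser can help. The sketch below is one possible variant, not the code used above: `parse_judge_output` and `normalize_score` are hypothetical helpers that return `None` when no score is found and map the 1-5 rubric onto the 0-1 range.\n",
    "\n",
    "```python\n",
    "import re\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "\n",
    "def parse_judge_output(text: str) -> Tuple[str, Optional[int]]:\n",
    "    # Split on the last [RESULT] marker; the score is None when it cannot be parsed\n",
    "    feedback, marker, tail = text.rpartition(\"[RESULT]\")\n",
    "    if not marker:\n",
    "        return text.strip(), None\n",
    "    match = re.search(r\"\\b[1-5]\\b\", tail)\n",
    "    score = int(match.group()) if match else None\n",
    "    return feedback.strip(), score\n",
    "\n",
    "\n",
    "def normalize_score(score: int) -> float:\n",
    "    # Map the 1-5 rubric onto the [0, 1] range\n",
    "    return (score - 1) / 4\n",
    "\n",
    "\n",
    "feedback, score = parse_judge_output(\"Feedback: The response matches the reference. [RESULT] 4\")\n",
    "print(feedback, score, normalize_score(score))\n",
    "```\n",
    "\n",
    "Returning `None` rather than raising lets the caller decide how to treat unparseable judgements, for example by skipping them or assigning the lowest score."
   ]
  },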
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EXH-szLe9jVP"
   },
   "source": [
    "πŸš€ Let's run the tests and evaluate answers!πŸ‘‡"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jW2nnvUT9jVQ"
   },
   "outputs": [],
   "source": [
    "if not os.path.exists(\"./output\"):\n",
    "    os.mkdir(\"./output\")\n",
    "\n",
    "for chunk_size in [200]:  # Add other chunk sizes (in tokens) as needed\n",
    "    for embeddings in [\"thenlper/gte-small\"]:  # Add other embeddings as needed\n",
    "        for rerank in [True, False]:\n",
    "            settings_name = f\"chunk:{chunk_size}_embeddings:{embeddings.replace('/', '~')}_rerank:{rerank}_reader-model:{READER_MODEL_NAME}\"\n",
    "            output_file_name = f\"./output/rag_{settings_name}.json\"\n",
    "\n",
    "            print(f\"Running evaluation for {settings_name}:\")\n",
    "\n",
    "            print(\"Loading knowledge base embeddings...\")\n",
    "            knowledge_index = load_embeddings(\n",
    "                RAW_KNOWLEDGE_BASE,\n",
    "                chunk_size=chunk_size,\n",
    "                embedding_model_name=embeddings,\n",
    "            )\n",
    "\n",
    "            print(\"Running RAG...\")\n",
    "            reranker = (\n",
    "                RAGPretrainedModel.from_pretrained(\"colbert-ir/colbertv2.0\") if rerank else None\n",
    "            )\n",
    "            run_rag_tests(\n",
    "                eval_dataset=eval_dataset,\n",
    "                llm=READER_LLM,\n",
    "                knowledge_index=knowledge_index,\n",
    "                output_file=output_file_name,\n",
    "                reranker=reranker,\n",
    "                verbose=False,\n",
    "                test_settings=settings_name,\n",
    "            )\n",
    "\n",
    "            print(\"Running evaluation...\")\n",
    "            evaluate_answers(\n",
    "                output_file_name,\n",
    "                eval_chat_model,\n",
    "                evaluator_name,\n",
    "                evaluation_prompt_template,\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tytXV5-h9jVT"
   },
   "source": [
    "### Inspect results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D4YDSfmr9jVT"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "outputs = []\n",
    "for file in glob.glob(\"./output/*.json\"):\n",
    "    output = pd.DataFrame(json.load(open(file, \"r\")))\n",
    "    output[\"settings\"] = file\n",
    "    outputs.append(output)\n",
    "result = pd.concat(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CdkXMNvS9jVT"
   },
   "outputs": [],
   "source": [
    "result[\"eval_score_GPT4\"] = result[\"eval_score_GPT4\"].apply(\n",
    "    lambda x: int(x) if isinstance(x, str) else 1\n",
    ")\n",
    "result[\"eval_score_GPT4\"] = (result[\"eval_score_GPT4\"] - 1) / 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lgxBpid29jVT",
    "outputId": "9a3bcf32-4b0c-4df1-c76c-3ebbca82929d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "settings\n",
       "./output/rag_chunk:200_embeddings:thenlper~gte-small_rerank:False_reader-model:zephyr-7b-beta.json       0.884328\n",
       "./output/rag_chunk:200_embeddings:BAAI~bge-base-en-v1.5_rerank:False_reader-model:zephyr-7b-beta.json    0.906716\n",
       "./output/rag_chunk:200_embeddings:BAAI~bge-base-en-v1.5_rerank:True_reader-model:zephyr-7b-beta.json     0.906716\n",
       "./output/rag_chunk:200_embeddings:thenlper~gte-small_rerank:True_reader-model:mixtral.json               0.906716\n",
       "./output/rag_chunk:200_embeddings:thenlper~gte-small_rerank:True_reader-model:zephyr-7b-beta.json        0.921642\n",
       "./output/rag_chunk:200_embeddings:thenlper~gte-small_rerank:True_reader-model:mixtral0.json              0.947761\n",
       "Name: eval_score_GPT4, dtype: float64"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "average_scores = result.groupby(\"settings\")[\"eval_score_GPT4\"].mean()\n",
    "average_scores.sort_values()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pSPH9DYI9jVT"
   },
   "source": [
    "## Example results\n",
    "\n",
    "Let us load the results that I obtained by tweaking the different options available in this notebook.\n",
    "For more detail on why these options could work on not, see the notebook on [advanced_RAG](advanced_rag).\n",
    "\n",
    "As you can see in the graph below, some tweaks do not bring any improvement, some give huge performance boosts.\n",
    "\n",
    "➑️ ___There is no single good recipe: you should try several different directions when tuning your RAG systems.___\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RVOxatv99jVT"
   },
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "scores = datasets.load_dataset(\"m-ric/rag_scores_cookbook\", split=\"train\")\n",
    "scores = pd.Series(scores[\"score\"], index=scores[\"settings\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vqK0Dg2Q9jVT"
   },
   "outputs": [],
   "source": [
    "fig = px.bar(\n",
    "    scores,\n",
    "    color=scores,\n",
    "    labels={\n",
    "        \"value\": \"Accuracy\",\n",
    "        \"settings\": \"Configuration\",\n",
    "    },\n",
    "    color_continuous_scale=\"bluered\",\n",
    ")\n",
    "fig.update_layout(w\n",
    "    width=1000,\n",
    "    height=600,\n",
    "    barmode=\"group\",\n",
    "    yaxis_range=[0, 100],\n",
    "    title=\"<b>Accuracy of different RAG configurations</b>\",\n",
    "    xaxis_title=\"RAG settings\",\n",
    "    font=dict(size=15),\n",
    ")\n",
    "fig.layout.yaxis.ticksuffix = \"%\"\n",
    "fig.update_coloraxes(showscale=False)\n",
    "fig.update_traces(texttemplate=\"%{y:.1f}\", textposition=\"outside\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dPUOMWGk9jVT"
   },
   "source": [
    "<img src=\"https://huggingface.co/datasets/huggingface/cookbook-images/resolve/main/RAG_settings_accuracy.png\" height=\"500\" width=\"800\">\n",
    "\n",
    "As you can see, these had varying impact on performance. In particular, tuning the chunk size is both easy and very impactful.\n",
    "\n",
    "But this is our case: your results could be very different: now that you have a robust evaluation pipeline, you can set on to explore other options! πŸ—ΊοΈ"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "ml2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}