from __future__ import annotations

import atexit
import concurrent
import copy
import difflib
import re
import threading
import traceback
import os
import time
import urllib.parse
import uuid
import warnings
from concurrent.futures import Future
from datetime import timedelta
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Callable, Generator, Any, Union, List, Dict, Literal, Tuple
import ast
import inspect
import numpy as np

try:
    from gradio_utils.yield_utils import ReturnType
except (ImportError, ModuleNotFoundError):
    try:
        from yield_utils import ReturnType
    except (ImportError, ModuleNotFoundError):
        try:
            from src.yield_utils import ReturnType
        except (ImportError, ModuleNotFoundError):
            from .src.yield_utils import ReturnType

os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

from huggingface_hub import SpaceStage
from huggingface_hub.utils import (
    build_hf_headers,
)

from gradio_client import utils

from importlib.metadata import distribution, PackageNotFoundError

lock = threading.Lock()

try:
    assert distribution("gradio_client") is not None
    have_gradio_client = True
    from packaging import version

    client_version = distribution("gradio_client").version
    is_gradio_client_version7plus = version.parse(client_version) >= version.parse(
        "0.7.0"
    )
except (PackageNotFoundError, AssertionError):
    have_gradio_client = False
    is_gradio_client_version7plus = False

from gradio_client.client import Job, DEFAULT_TEMP_DIR, Endpoint
from gradio_client import Client


def check_job(job, timeout=0.0, raise_exception=True, verbose=False):
    try:
        e = job.exception(timeout=timeout)
    except concurrent.futures.TimeoutError:
        # not enough time to determine
        if verbose:
            print("not enough time to determine job status: %s" % timeout)
        e = None
    if e:
        # raise before complain about empty response if some error hit
        if raise_exception:
            raise RuntimeError("".join(traceback.format_exception(e)))
        else:
            return e

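# Illustrative sketch of how check_job is typically used together with a
# gradio_client Job (the submit call and api_name mirror usage later in this
# module; the kwargs are placeholders):
#
#   job = client.submit(str(dict(client_kwargs)), api_name="/submit_nochat_api")
#   err = check_job(job, timeout=0.5, raise_exception=False)
#   if err is not None:
#       ...  # inspect or retry; with raise_exception=True a RuntimeError is raised instead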

# Local copy of minimal version from h2oGPT server
class LangChainAction(Enum):
    """LangChain action"""

    QUERY = "Query"
    SUMMARIZE_MAP = "Summarize"
    EXTRACT = "Extract"


pre_prompt_query0 = "Pay attention and remember the information below, which will help to answer the question or imperative after the context ends."
prompt_query0 = "According to only the information in the document sources provided within the context above: "

pre_prompt_summary0 = ""
prompt_summary0 = "Using only the information in the document sources above, write a condensed and concise well-structured Markdown summary of key results."

pre_prompt_extraction0 = (
    """In order to extract information, pay attention to the following text."""
)
prompt_extraction0 = (
    "Using only the information in the document sources above, extract "
)

hyde_llm_prompt0 = "Answer this question with vibrant details in order for some NLP embedding model to use that answer as better query than original question: "

client_version = distribution("gradio_client").version
old_gradio = version.parse(client_version) <= version.parse("0.6.1")


class CommonClient:
    def question(self, instruction, *args, **kwargs) -> str:
        """
        Prompt the LLM directly (instruct prompting is applied as required for instruct models) and get the response
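
        A minimal usage sketch (the construction of client is illustrative; use
        whatever concrete client class this module exposes, pointed at an
        h2oGPT server):

            client = ...  # concrete h2oGPT gradio client instance
            answer = client.question("Why is the sky blue?", max_new_tokens=256)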
        """
        kwargs["instruction"] = kwargs.get("instruction", instruction)
        kwargs["langchain_action"] = LangChainAction.QUERY.value
        kwargs["langchain_mode"] = "LLM"
        ret = ""
        for ret1 in self.query_or_summarize_or_extract(*args, **kwargs):
            ret = ret1.reply
        return ret

    def question_stream(
            self, instruction, *args, **kwargs
    ) -> Generator[ReturnType, None, None]:
        """
        Prompt the LLM directly (instruct prompting is applied as required for instruct models) and get the response
        """
        kwargs["instruction"] = kwargs.get("instruction", instruction)
        kwargs["langchain_action"] = LangChainAction.QUERY.value
        kwargs["langchain_mode"] = "LLM"
        ret = yield from self.query_or_summarize_or_extract(*args, **kwargs)
        return ret

    def query(self, query, *args, **kwargs) -> str:
        """
        Search for documents matching a query, then ask that query to LLM with those documents
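
        A minimal usage sketch (client construction is illustrative; text, file,
        and url each accept a single item or a list):

            reply = client.query(
                "What does the report conclude?",
                text="<full report text>",  # or file="report.pdf", or url="https://..."
            )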
        """
        kwargs["instruction"] = kwargs.get("instruction", query)
        kwargs["langchain_action"] = LangChainAction.QUERY.value
        ret = ""
        for ret1 in self.query_or_summarize_or_extract(*args, **kwargs):
            ret = ret1.reply
        return ret

    def query_stream(self, query, *args, **kwargs) -> Generator[ReturnType, None, None]:
        """
        Search for documents matching a query, then ask that query to LLM with those documents
        """
        kwargs["instruction"] = kwargs.get("instruction", query)
        kwargs["langchain_action"] = LangChainAction.QUERY.value
        ret = yield from self.query_or_summarize_or_extract(*args, **kwargs)
        return ret

    def summarize(self, *args, query=None, focus=None, **kwargs) -> str:
        """
        Search for documents matching a focus, then ask a query to LLM with those documents
        If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used
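
        A minimal usage sketch (the file name and query are placeholders):

            summary = client.summarize(
                file="report.pdf",
                query="Summarize the key findings as bullet points.",
            )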
        """
        kwargs["prompt_summary"] = kwargs.get(
            "prompt_summary", query or prompt_summary0
        )
        kwargs["instruction"] = kwargs.get("instruction", focus)
        kwargs["langchain_action"] = LangChainAction.SUMMARIZE_MAP.value
        ret = ""
        for ret1 in self.query_or_summarize_or_extract(*args, **kwargs):
            ret = ret1.reply
        return ret

    def summarize_stream(
            self, *args, query=None, focus=None, **kwargs
    ) -> Generator[ReturnType, None, None]:
        """
        Search for documents matching a focus, then ask a query to LLM with those documents
        If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used
        """
        kwargs["prompt_summary"] = kwargs.get(
            "prompt_summary", query or prompt_summary0
        )
        kwargs["instruction"] = kwargs.get("instruction", focus)
        kwargs["langchain_action"] = LangChainAction.SUMMARIZE_MAP.value
        ret = yield from self.query_or_summarize_or_extract(*args, **kwargs)
        return ret

    def extract(self, *args, query=None, focus=None, **kwargs) -> list[str]:
        """
        Search for documents matching a focus, then ask a query to LLM with those documents
        If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used
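
        A minimal usage sketch (text and query are placeholders); for extraction
        the reply is a list of strings, one per extracted item:

            items = client.extract(
                text="<document text>",
                query="Extract all company names mentioned.",
            )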
        """
        kwargs["prompt_extraction"] = kwargs.get(
            "prompt_extraction", query or prompt_extraction0
        )
        kwargs["instruction"] = kwargs.get("instruction", focus)
        kwargs["langchain_action"] = LangChainAction.EXTRACT.value
        ret = []
        for ret1 in self.query_or_summarize_or_extract(*args, **kwargs):
            ret = ret1.reply
        return ret

    def extract_stream(
            self, *args, query=None, focus=None, **kwargs
    ) -> Generator[ReturnType, None, None]:
        """
        Search for documents matching a focus, then ask a query to LLM with those documents
        If focus "" or None, no similarity search is done and all documents (up to top_k_docs) are used
        """
        kwargs["prompt_extraction"] = kwargs.get(
            "prompt_extraction", query or prompt_extraction0
        )
        kwargs["instruction"] = kwargs.get("instruction", focus)
        kwargs["langchain_action"] = LangChainAction.EXTRACT.value
        ret = yield from self.query_or_summarize_or_extract(*args, **kwargs)
        return ret

    def get_client_kwargs(self, **kwargs):
        client_kwargs = {}
        try:
            from src.evaluate_params import eval_func_param_names
        except (ImportError, ModuleNotFoundError):
            try:
                from evaluate_params import eval_func_param_names
            except (ImportError, ModuleNotFoundError):
                from .src.evaluate_params import eval_func_param_names

        for k in eval_func_param_names:
            if k in kwargs:
                client_kwargs[k] = kwargs[k]

        if os.getenv("HARD_ASSERTS"):
            fun_kwargs = {
                k: v.default
                for k, v in dict(
                    inspect.signature(self.query_or_summarize_or_extract).parameters
                ).items()
            }
            diff = set(eval_func_param_names).difference(fun_kwargs)
            assert len(diff) == 0, (
                    "Add query_or_summarize_or_extract entries: %s" % diff
            )

            extra_query_params = [
                "file",
                "bad_error_string",
                "print_info",
                "asserts",
                "url",
                "prompt_extraction",
                "model",
                "text",
                "print_error",
                "pre_prompt_extraction",
                "embed",
                "print_warning",
                "sanitize_llm",
            ]
            diff = set(fun_kwargs).difference(
                eval_func_param_names + extra_query_params
            )
            assert len(diff) == 0, "Add eval_func_params entries: %s" % diff

        return client_kwargs

    def get_query_kwargs(self, **kwargs):
        fun_dict = dict(
            inspect.signature(self.query_or_summarize_or_extract).parameters
        ).items()
        fun_kwargs = {k: kwargs.get(k, v.default) for k, v in fun_dict}

        return fun_kwargs

    @staticmethod
    def check_error(res_dict):
        actual_llm = ""
        try:
            actual_llm = res_dict["save_dict"]["display_name"]
        except Exception:
            pass
        if "error" in res_dict and res_dict["error"]:
            raise RuntimeError(f"Error from LLM {actual_llm}: {res_dict['error']}")
        if "error_ex" in res_dict and res_dict["error_ex"]:
            raise RuntimeError(
                f"Error Traceback from LLM {actual_llm}: {res_dict['error_ex']}"
            )
        if "response" not in res_dict:
            raise ValueError(f"No response from LLM {actual_llm}")

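    # Illustrative shape of the res_dict parsed by check_error and by
    # query_or_summarize_or_extract (values are placeholders; keys mirror those
    # accessed in this module):
    #   {"response": "...",
    #    "sources": [{"content": "...", "score": 0.42}],
    #    "prompt_raw": "...",
    #    "save_dict": {"display_name": "llm-name",
    #                  "extra_dict": {"num_prompt_tokens": 123, "ntokens": 45,
    #                                 "tokens_persecond": 10.0}},
    #    "error": "", "error_ex": ""}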
    def query_or_summarize_or_extract(
            self,
            print_error=print,
            print_info=print,
            print_warning=print,
            bad_error_string=None,
            sanitize_llm=None,
            h2ogpt_key: str = None,
            instruction: str = "",
            text: list[str] | str | None = None,
            file: list[str] | str | None = None,
            url: list[str] | str | None = None,
            embed: bool = True,
            chunk: bool = True,
            chunk_size: int = 512,
            langchain_mode: str = None,
            langchain_action: str | None = None,
            langchain_agents: List[str] = [],
            top_k_docs: int = 10,
            document_choice: Union[str, List[str]] = "All",
            document_subset: str = "Relevant",
            document_source_substrings: Union[str, List[str]] = [],
            document_source_substrings_op: str = "and",
            document_content_substrings: Union[str, List[str]] = [],
            document_content_substrings_op: str = "and",
            system_prompt: str | None = "",
            pre_prompt_query: str | None = pre_prompt_query0,
            prompt_query: str | None = prompt_query0,
            pre_prompt_summary: str | None = pre_prompt_summary0,
            prompt_summary: str | None = prompt_summary0,
            pre_prompt_extraction: str | None = pre_prompt_extraction0,
            prompt_extraction: str | None = prompt_extraction0,
            hyde_llm_prompt: str | None = hyde_llm_prompt0,
            all_docs_start_prompt: str | None = None,
            all_docs_finish_prompt: str | None = None,
            user_prompt_for_fake_system_prompt: str = None,
            json_object_prompt: str = None,
            json_object_prompt_simpler: str = None,
            json_code_prompt: str = None,
            json_code_prompt_if_no_schema: str = None,
            json_schema_instruction: str = None,
            json_preserve_system_prompt: bool = False,
            json_object_post_prompt_reminder: str = None,
            json_code_post_prompt_reminder: str = None,
            json_code2_post_prompt_reminder: str = None,
            model: str | int | None = None,
            model_lock: dict | None = None,
            stream_output: bool = False,
            enable_caching: bool = False,
            do_sample: bool = False,
            seed: int | None = 0,
            temperature: float = 0.0,
            top_p: float = 1.0,
            top_k: int = 40,
            # 1.07 still causes issues with more repetition
            repetition_penalty: float = 1.0,
            penalty_alpha: float = 0.0,
            max_time: int = 360,
            max_new_tokens: int = 1024,
            add_search_to_context: bool = False,
            chat_conversation: list[tuple[str, str]] | None = None,
            text_context_list: list[str] | None = None,
            docs_ordering_type: str | None = None,
            min_max_new_tokens: int = 512,
            max_input_tokens: int = -1,
            max_total_input_tokens: int = -1,
            docs_token_handling: str = "split_or_merge",
            docs_joiner: str = "\n\n",
            hyde_level: int = 0,
            hyde_template: str = None,
            hyde_show_only_final: bool = True,
            doc_json_mode: bool = False,
            metadata_in_context: list = [],
            image_file: Union[str, list] = None,
            image_control: str = None,
            images_num_max: int = None,
            image_resolution: tuple = None,
            image_format: str = None,
            rotate_align_resize_image: bool = None,
            video_frame_period: int = None,
            image_batch_image_prompt: str = None,
            image_batch_final_prompt: str = None,
            image_batch_stream: bool = None,
            visible_vision_models: Union[str, int, list] = None,
            video_file: Union[str, list] = None,
            response_format: str = "text",
            guided_json: Union[str, dict] = "",
            guided_regex: str = "",
            guided_choice: List[str] | None = None,
            guided_grammar: str = "",
            guided_whitespace_pattern: str = None,
            prompt_type: Union[int, str] = None,
            prompt_dict: Dict = None,
            chat_template: str = None,
            jq_schema=".[]",
            llava_prompt: str = "auto",
            image_audio_loaders: list = None,
            url_loaders: list = None,
            pdf_loaders: list = None,
            extract_frames: int = 10,
            add_chat_history_to_context: bool = True,
            chatbot_role: str = "None",  # "Female AI Assistant",
            speaker: str = "None",  # "SLT (female)",
            tts_language: str = "autodetect",
            tts_speed: float = 1.0,
            visible_image_models: List[str] = [],
            image_size: str = "1024x1024",
            image_quality: str = 'standard',
            image_guidance_scale: float = 3.0,
            image_num_inference_steps: int = 30,
            visible_models: Union[str, int, list] = None,
            client_metadata: str = '',
            # do not use the parameters below (they are not documented in the docstring)
            num_return_sequences: int = None,
            chat: bool = True,
            min_new_tokens: int = None,
            early_stopping: Union[bool, str] = None,
            iinput: str = "",
            iinput_nochat: str = "",
            instruction_nochat: str = "",
            context: str = "",
            num_beams: int = 1,
            asserts: bool = False,
            do_lock: bool = False,
    ) -> Generator[ReturnType, None, None]:
        """
        Query or Summarize or Extract using h2oGPT
        Args:
            instruction: Query for LLM chat.  Used for similarity search

            For query, prompt template is:
              "{pre_prompt_query}
                \"\"\"
                {content}
                \"\"\"
                {prompt_query}{instruction}"
             If added to summarization, prompt template is
              "{pre_prompt_summary}
                \"\"\"
                {content}
                \"\"\"
                Focusing on {instruction}, {prompt_summary}"
            text: textual content or list of such contents
            file: a local file to upload or files to upload
            url: a url to give or urls to use
            embed: whether to embed content uploaded

            :param langchain_mode: "LLM" to talk to LLM with no docs, "MyData" for personal docs, "UserData" for shared docs, etc.
            :param langchain_action: Action to take, "Query" or "Summarize" or "Extract"
            :param langchain_agents: Which agents to use, if any
            :param top_k_docs: number of document parts.
                        When doing query, number of chunks
                        When doing summarization, unrelated to vectorDB chunks (which are not used)
                        E.g. if PDF, then number of pages
            :param chunk: whether to chunk sources for document Q/A
            :param chunk_size: Size in characters of chunks
            :param document_choice: Which documents ("All" means all) -- need to use upload_api API call to get server's name if want to select
            :param document_subset: Type of query, see src/gen.py
            :param document_source_substrings: See gen.py
            :param document_source_substrings_op: See gen.py
            :param document_content_substrings: See gen.py
            :param document_content_substrings_op: See gen.py

            :param system_prompt: pass system prompt to models that support it.
              If 'auto' or None, then use automatic version
              If '', then use no system prompt (default)
            :param pre_prompt_query: Prompt that comes before document part
            :param prompt_query: Prompt that comes after document part
            :param pre_prompt_summary: Prompt that comes before document part
               None makes h2oGPT internally use its defaults
               E.g. "In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text"
            :param prompt_summary: Prompt that comes after document part
              None makes h2oGPT internally use its defaults
              E.g. "Using only the text above, write a condensed and concise summary of key results (preferably as bullet points):\n"
            i.e. for some internal document part fstring, the template looks like:
                template = "%s
                \"\"\"
                %s
                \"\"\"
                %s" % (pre_prompt_summary, fstring, prompt_summary)
            :param hyde_llm_prompt: hyde prompt for first step when using LLM
            :param all_docs_start_prompt: start of document block
            :param all_docs_finish_prompt: finish of document block

            :param user_prompt_for_fake_system_prompt: user part of pre-conversation if LLM doesn't handle system prompt
            :param json_object_prompt: prompt for getting LLM to do JSON object
            :param json_object_prompt_simpler: simpler version of json_object_prompt, used for MistralAI
            :param json_code_prompt: prompt for getting LLM to do JSON in code block
            :param json_code_prompt_if_no_schema: prompt for getting LLM to do JSON in code block if no schema
            :param json_schema_instruction: prompt for LLM to use schema
            :param json_preserve_system_prompt: Whether to preserve system prompt for json mode
            :param json_object_post_prompt_reminder: json object reminder about JSON
            :param json_code_post_prompt_reminder: json code w/ schema reminder about JSON
            :param json_code2_post_prompt_reminder: json code wo/ schema reminder about JSON

            :param h2ogpt_key: Access Key to h2oGPT server (if not already set in client at init time)
            :param model: base_model name or integer index of model_lock on h2oGPT server
                            None results in use of first (0th index) model in server
                   to get list of models do client.list_models()
            :param model_lock: dict of states or single state, with dict of things like inference server, to use when using dynamic LLM (not from existing model lock on h2oGPT)
            :param pre_prompt_extraction: Same as pre_prompt_summary but for when doing extraction
            :param prompt_extraction: Same as prompt_summary but for when doing extraction
            :param do_sample: see src/gen.py
            :param seed: see src/gen.py
            :param temperature: see src/gen.py
            :param top_p: see src/gen.py
            :param top_k: see src/gen.py
            :param repetition_penalty: see src/gen.py
            :param penalty_alpha: see src/gen.py
            :param max_new_tokens: see src/gen.py
            :param min_max_new_tokens: see src/gen.py
            :param max_input_tokens: see src/gen.py
            :param max_total_input_tokens: see src/gen.py
            :param stream_output: Whether to stream output
            :param enable_caching: Whether to enable caching
            :param max_time: how long to take

            :param add_search_to_context: Whether to do web search and add results to context
            :param chat_conversation: List of tuples for (human, bot) conversation that will be pre-appended to an (instruction, None) case for a query
            :param text_context_list: List of strings to add to context for non-database version of document Q/A for faster handling via API etc.
               Forces LangChain code path and uses as many entries in list as possible given max_seq_len, with first assumed to be most relevant and to go near prompt.
            :param docs_ordering_type: By default uses 'reverse_ucurve_sort' for optimal retrieval
            :param max_input_tokens: Max input tokens to place into model context for each LLM call
                                     -1 means auto, fully fill context for query, and fill by original document chunk for summarization
                                     >=0 means use that to limit context filling to that many tokens
            :param max_total_input_tokens: like max_input_tokens but instead of per LLM call, applies across all LLM calls for single summarization/extraction action
            :param max_new_tokens: Maximum new tokens
            :param min_max_new_tokens: minimum value for max_new_tokens when auto-adjusting for content of prompt, docs, etc.

            :param docs_token_handling: 'chunk' means fill context with top_k_docs (limited by max_input_tokens or model_max_len) chunks for query
                                                                             or top_k_docs original document chunks summarization
                                        None or 'split_or_merge' means same as 'chunk' for query, while for summarization merges documents to fill up to max_input_tokens or model_max_len tokens
            :param docs_joiner: string to join lists of text when doing split_or_merge.  None means '\n\n'
            :param hyde_level: 0-3 for HYDE.
                        0 uses just query to find similarity with docs
                        1 uses query + pure LLM response to find similarity with docs
                        2: uses query + LLM response using docs to find similarity with docs
                        3+: etc.
            :param hyde_template: see src/gen.py
            :param hyde_show_only_final: see src/gen.py
            :param doc_json_mode: see src/gen.py
            :param metadata_in_context: see src/gen.py

            :param image_file: Initial image for UI (or actual image for CLI) Vision Q/A.  Or list of images for some models
            :param image_control: Initial image for UI Image Control
            :param images_num_max: Max. number of images per LLM call
            :param image_resolution: Resolution of any images
            :param image_format: Image format
            :param rotate_align_resize_image: Whether to apply rotation, alignment, resize before giving to LLM
            :param video_frame_period: Period of frames to use from video
            :param image_batch_image_prompt: Prompt used to query image only if doing batching of images
            :param image_batch_final_prompt: Prompt used to query result of batching of images
            :param image_batch_stream: Whether to stream batching of images.
            :param visible_vision_models: Model to use for vision, e.g. if base LLM has no vision
                   If 'auto', then use CLI value, else use model display name given here
            :param video_file: DO NOT USE FOR API, put images, videos, urls, and youtube urls in image_file as list

            :param response_format: text or json_object or json_code
            # https://github.com/vllm-project/vllm/blob/a3c226e7eb19b976a937e745f3867eb05f809278/vllm/entrypoints/openai/protocol.py#L117-L135
            :param guided_json: str or dict of JSON schema
            :param guided_regex:
            :param guided_choice: list of strings to have LLM choose from
            :param guided_grammar:
            :param guided_whitespace_pattern:

            :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
            :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
            :param chat_template: jinja HF transformers chat_template to use.  '' or None means no change to template

            :param jq_schema: control json loader
                   By default '.[]' ingests everything in brute-force way, but better to match your schema
                   See: https://python.langchain.com/docs/modules/data_connection/document_loaders/json#using-jsonloader

            :param extract_frames: How many unique frames to extract from video (if 0, then just do audio if audio type file as well)

            :param llava_prompt: Prompt passed to LLaVa for querying the image

            :param image_audio_loaders: which loaders to use for image and audio parsing (None means default)
            :param url_loaders: which loaders to use for url parsing (None means default)
            :param pdf_loaders: which loaders to use for pdf parsing (None means default)

            :param add_chat_history_to_context: Include chat context when performing action
                   Not supported when using CLI mode

            :param chatbot_role: Default role for coqui models.  If 'None', then don't by default speak when launching h2oGPT for coqui model choice.
            :param speaker: Default speaker for microsoft models  If 'None', then don't by default speak when launching h2oGPT for microsoft model choice.
            :param tts_language: Default language for coqui models
            :param tts_speed: Default speed of TTS, < 1.0 (needs rubberband) for slower than normal, > 1.0 for faster.  Tries to keep fixed pitch.

            :param visible_image_models: Which image gen models to include
            :param image_size: size of generated images, e.g. "1024x1024"
            :param image_quality: quality of generated images, e.g. "standard"
            :param image_guidance_scale: guidance scale for image generation
            :param image_num_inference_steps: number of inference steps for image generation
            :param visible_models: Which models in model_lock list to show by default
                   Takes integers of position in model_lock (model_states) list or strings of base_model names
                   Ignored if model_lock not used
                   For nochat API, this is single item within a list for model by name or by index in model_lock
                                        If None, then just use first model in model_lock list
                                        If model_lock not set, use model selected by CLI --base_model etc.
                   Note that unlike h2ogpt_key, this visible_models only applies to this running h2oGPT server,
                      and the value is not used to access the inference server.
                      If need a visible_models for an inference server, then use --model_lock and group together.
            :param client_metadata:
            :param asserts: whether to do asserts to ensure handling is correct

        Returns: summary/answer: str or extraction List[str]

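        A minimal streaming sketch (client construction and document content are
        illustrative; only a few of the parameters above are shown):

            chunks = []
            for chunk in client.query_or_summarize_or_extract(
                instruction="What is the main conclusion?",
                text="<document text>",
                langchain_action="Query",
                stream_output=True,
            ):
                chunks.append(chunk)
            # intermediate chunks carry incremental text in .reply; the final
            # chunk repeats the full reply and adds metadata (sources, token counts)
            print(chunks[-1].reply)
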
        """
        if self.config is None:
            self.setup()
        if self.persist:
            client = self
        else:
            client = self.clone()
        try:
            h2ogpt_key = h2ogpt_key or self.h2ogpt_key
            client.h2ogpt_key = h2ogpt_key

            if model is not None and visible_models is None:
                visible_models = model
            client.check_model(model)

            # chunking not used here
            # MyData specifies scratch space, only persisted for this individual client call
            langchain_mode = langchain_mode or "MyData"
            loaders = tuple([None, None, None, None, None, None])
            doc_options = tuple([langchain_mode, chunk, chunk_size, embed])
            asserts |= bool(os.getenv("HARD_ASSERTS", False))
            if (
                    text
                    and isinstance(text, list)
                    and not file
                    and not url
                    and not text_context_list
            ):
                # then can do optimized text-only path
                text_context_list = text
                text = None

            res = []
            if text:
                t0 = time.time()
                res = client.predict(
                    text, *doc_options, *loaders, h2ogpt_key, api_name="/add_text"
                )
                t1 = time.time()
                print_info("upload text: %s" % str(timedelta(seconds=t1 - t0)))
                if asserts:
                    assert res[0] is None
                    assert res[1] == langchain_mode
                    assert "user_paste" in res[2]
                    assert res[3] == ""
            if file:
                # upload file(s).  Can be list or single file
                # after below call, "file" replaced with remote location of file
                _, file = client.predict(file, api_name="/upload_api")

                res = client.predict(
                    file, *doc_options, *loaders, h2ogpt_key, api_name="/add_file_api"
                )
                if asserts:
                    assert res[0] is None
                    assert res[1] == langchain_mode
                    assert os.path.basename(file) in res[2]
                    assert res[3] == ""
            if url:
                res = client.predict(
                    url, *doc_options, *loaders, h2ogpt_key, api_name="/add_url"
                )
                if asserts:
                    assert res[0] is None
                    assert res[1] == langchain_mode
                    assert url in res[2]
                    assert res[3] == ""
                    assert res[4]  # should have file name or something similar
            if res and not res[4] and "Exception" in res[2]:
                print_error("Exception: %s" % res[2])

            # ask for summary, need to use same client if using MyData
            api_name = "/submit_nochat_api"  # NOTE: like submit_nochat but stable API for string dict passing

            pre_prompt_summary = (
                pre_prompt_summary
                if langchain_action == LangChainAction.SUMMARIZE_MAP.value
                else pre_prompt_extraction
            )
            prompt_summary = (
                prompt_summary
                if langchain_action == LangChainAction.SUMMARIZE_MAP.value
                else prompt_extraction
            )

            chat_conversation = (
                chat_conversation
                if chat_conversation or not self.persist
                else self.chat_conversation.copy()
            )

            locals_for_client = locals().copy()
            locals_for_client.pop("self", None)
            client_kwargs = self.get_client_kwargs(**locals_for_client)

            # in case server changed, update in case clone()
            if do_lock:
                with lock:
                    self.server_hash = client.server_hash
            else:
                self.server_hash = client.server_hash

            # ensure can fill conversation
            if self.persist:
                self.chat_conversation.append((instruction, None))

            # get result
            actual_llm = visible_models
            response = ""
            texts_out = []
            trials = 3
            # the average number of generation failures for gpt-35-turbo-1106 is 2, but up to 4 in 100 trials, hence 10
            # retries are very quick since failure is basically instant at the start of generation
            trials_generation = 10
            trial = 0
            trial_generation = 0
            t0 = time.time()
            input_tokens = 0
            output_tokens = 0
            tokens_per_second = 0
            vision_visible_model = None
            vision_batch_input_tokens = 0
            vision_batch_output_tokens = 0
            vision_batch_tokens_per_second = 0
            t_taken_s = None
            while True:
                time_to_first_token = None
                t0 = time.time()
                try:
                    if not stream_output:
                        res = client.predict(
                            str(dict(client_kwargs)),
                            api_name=api_name,
                        )
                        if time_to_first_token is None:
                            time_to_first_token = time.time() - t0
                        t_taken_s = time.time() - t0
                        # in case server changed, update in case clone()
                        if do_lock:
                            with lock:
                                self.server_hash = client.server_hash
                        else:
                            self.server_hash = client.server_hash
                        res_dict = ast.literal_eval(res)
                        self.check_error(res_dict)
                        response = res_dict["response"]
                        if langchain_action != LangChainAction.EXTRACT.value:
                            response = response.strip()
                        else:
                            response = [r.strip() for r in ast.literal_eval(response)]
                        sources = res_dict["sources"]
                        scores_out = [x["score"] for x in sources]
                        texts_out = [x["content"] for x in sources]
                        prompt_raw = res_dict.get("prompt_raw", "")
                        try:
                            actual_llm = res_dict["save_dict"][
                                "display_name"
                            ]  # fast path
                        except Exception as e:
                            print_warning(
                                f"Unable to access save_dict to get actual_llm: {str(e)}"
                            )
                        try:
                            extra_dict = res_dict["save_dict"]["extra_dict"]
                            input_tokens = extra_dict["num_prompt_tokens"]
                            output_tokens = extra_dict["ntokens"]
                            tokens_per_second = np.round(
                                extra_dict["tokens_persecond"], decimals=3
                            )
                            vision_visible_model = extra_dict.get(
                                "batch_vision_visible_model"
                            )
                            vision_batch_input_tokens = extra_dict.get(
                                "vision_batch_input_tokens", 0
                            )
                        except Exception:
                            if os.getenv("HARD_ASSERTS"):
                                raise
                        if asserts:
                            if text and not file and not url:
                                assert any(
                                    text[:cutoff] == texts_out
                                    for cutoff in range(len(text))
                                )
                            assert len(texts_out) == len(scores_out)

                        yield ReturnType(
                            reply=response,
                            text_context_list=texts_out,
                            prompt_raw=prompt_raw,
                            actual_llm=actual_llm,
                            input_tokens=input_tokens,
                            output_tokens=output_tokens,
                            tokens_per_second=tokens_per_second,
                            time_to_first_token=time_to_first_token or (time.time() - t0),
                            vision_visible_model=vision_visible_model,
                            vision_batch_input_tokens=vision_batch_input_tokens,
                            vision_batch_output_tokens=vision_batch_output_tokens,
                            vision_batch_tokens_per_second=vision_batch_tokens_per_second,
                        )
                        if self.persist:
                            self.chat_conversation[-1] = (instruction, response)
                    else:
                        job = client.submit(str(dict(client_kwargs)), api_name=api_name)
                        text0 = ""
                        while not job.done():
                            e = check_job(job, timeout=0, raise_exception=False)
                            if e is not None:
                                break
                            outputs_list = job.outputs().copy()
                            if outputs_list:
                                res = outputs_list[-1]
                                res_dict = ast.literal_eval(res)
                                self.check_error(res_dict)
                                response = res_dict["response"]  # keeps growing
                                prompt_raw = res_dict.get(
                                    "prompt_raw", ""
                                )  # only filled at end
                                text_chunk = response[
                                             len(text0):
                                             ]  # only keep new stuff
                                if not text_chunk:
                                    time.sleep(0.001)
                                    continue
                                text0 = response
                                assert text_chunk, "must yield non-empty string"
                                if time_to_first_token is None:
                                    time_to_first_token = time.time() - t0
                                yield ReturnType(
                                    reply=text_chunk,
                                    actual_llm=actual_llm,
                                )  # streaming part
                            time.sleep(0.005)

                        # Get the final response (if anything is left) and the actual references (texts_out); the streaming loop above leaves texts_out empty.
                        res_all = job.outputs().copy()
                        success = job.communicator.job.latest_status.success
                        timeout = 0.1 if success else 10
                        if len(res_all) > 0:
                            try:
                                check_job(job, timeout=timeout, raise_exception=True)
                            except (
                                    Exception
                            ) as e:  # FIXME - except TimeoutError once h2ogpt raises that.
                                if "Abrupt termination of communication" in str(e):
                                    t_taken = "%.4f" % (time.time() - t0)
                                    raise TimeoutError(
                                        f"LLM {actual_llm} timed out after {t_taken} seconds."
                                    )
                                else:
                                    raise

                            res = res_all[-1]
                            res_dict = ast.literal_eval(res)
                            self.check_error(res_dict)
                            response = res_dict["response"]
                            sources = res_dict["sources"]
                            prompt_raw = res_dict["prompt_raw"]
                            save_dict = res_dict.get("save_dict", dict(extra_dict={}))
                            extra_dict = save_dict.get("extra_dict", {})
                            texts_out = [x["content"] for x in sources]
                            t_taken_s = time.time() - t0
                            t_taken = "%.4f" % t_taken_s

                            if langchain_action != LangChainAction.EXTRACT.value:
                                text_chunk = response.strip()
                            else:
                                text_chunk = [
                                    r.strip() for r in ast.literal_eval(response)
                                ]

                            if not text_chunk:
                                raise TimeoutError(
                                    f"No output from LLM {actual_llm} after {t_taken} seconds."
                                )
                            if "error" in save_dict and not prompt_raw:
                                raise RuntimeError(
                                    f"Error from LLM {actual_llm}: {save_dict['error']}"
                                )
                            assert (
                                    prompt_raw or extra_dict
                            ), "LLM response failed to return final metadata."

                            try:
                                extra_dict = res_dict["save_dict"]["extra_dict"]
                                input_tokens = extra_dict["num_prompt_tokens"]
                                output_tokens = extra_dict["ntokens"]
                                vision_visible_model = extra_dict.get(
                                    "batch_vision_visible_model"
                                )
                                vision_batch_input_tokens = extra_dict.get(
                                    "batch_num_prompt_tokens", 0
                                )
                                vision_batch_output_tokens = extra_dict.get(
                                    "batch_ntokens", 0
                                )
                                tokens_per_second = np.round(
                                    extra_dict["tokens_persecond"], decimals=3
                                )
                                vision_batch_tokens_per_second = extra_dict.get(
                                    "batch_tokens_persecond", 0
                                )
                                if vision_batch_tokens_per_second:
                                    vision_batch_tokens_per_second = np.round(
                                        vision_batch_tokens_per_second, decimals=3
                                    )
                            except Exception:
                                if os.getenv("HARD_ASSERTS"):
                                    raise
                            try:
                                actual_llm = res_dict["save_dict"][
                                    "display_name"
                                ]  # fast path
                            except Exception as e:
                                print_warning(
                                    f"Unable to access save_dict to get actual_llm: {str(e)}"
                                )

                            if text_context_list:
                                assert texts_out, "No texts_out 1"

                            if time_to_first_token is None:
                                time_to_first_token = time.time() - t0
                            yield ReturnType(
                                reply=text_chunk,
                                text_context_list=texts_out,
                                prompt_raw=prompt_raw,
                                actual_llm=actual_llm,
                                input_tokens=input_tokens,
                                output_tokens=output_tokens,
                                tokens_per_second=tokens_per_second,
                                time_to_first_token=time_to_first_token,
                                trial=trial,
                                vision_visible_model=vision_visible_model,
                                vision_batch_input_tokens=vision_batch_input_tokens,
                                vision_batch_output_tokens=vision_batch_output_tokens,
                                vision_batch_tokens_per_second=vision_batch_tokens_per_second,
                            )
                            if self.persist:
                                self.chat_conversation[-1] = (
                                    instruction,
                                    text_chunk,
                                )
                        else:
                            assert not success
                            check_job(job, timeout=2.0 * timeout, raise_exception=True)
                    if trial > 0 or trial_generation > 0:
                        print("trial recovered: %s %s" % (trial, trial_generation))
                    break
                except Exception as e:
                    if "No generations" in str(e) or (
                        "'NoneType' object has no attribute 'generations'" in str(e)
                    ):
                        trial_generation += 1
                    else:
                        trial += 1
                    print_error(
                        "h2oGPT predict failed: %s %s"
                        % (str(e), "".join(traceback.format_tb(e.__traceback__))),
                    )
                    if "invalid model" in str(e).lower():
                        raise
                    if bad_error_string and bad_error_string in str(e):
                        # no need to retry if the response contains disallowed content; the LLM is unlikely to change its mind
                        raise
                    if trial == trials or trial_generation == trials_generation:
                        print_error(
                            "trying again failed: %s %s" % (trial, trial_generation)
                        )
                        raise
                    else:
                        # both Anthropic and OpenAI give this kind of error, but h2oGPT only retries for OpenAI
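                        # exponential backoff for overload errors, otherwise a short linear delay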
                        if "Overloaded" in str(traceback.format_tb(e.__traceback__)):
                            sleep_time = 30 + 2 ** (trial + 1)
                        else:
                            sleep_time = 1 * trial
                        print_warning(
                            "trying again: %s in %s seconds" % (trial, sleep_time)
                        )
                        time.sleep(sleep_time)
                finally:
                    # in case the server changed during the call, copy the clone's server_hash back to self
                    if do_lock:
                        with lock:
                            self.server_hash = client.server_hash
                    else:
                        self.server_hash = client.server_hash

            t1 = time.time()
            print_info(
                dict(
                    api="submit_nochat_api",
                    streaming=stream_output,
                    texts_in=len(text or []) + len(text_context_list or []),
                    texts_out=len(texts_out),
                    images=len(image_file)
                    if isinstance(image_file, list)
                    else 1
                    if image_file
                    else 0,
                    response_time=str(timedelta(seconds=t1 - t0)),
                    response_len=len(response),
                    llm=visible_models,
                    actual_llm=actual_llm,
                )
            )
        finally:
            # in case the server changed during the call, copy the clone's server_hash back to self
            if do_lock:
                with lock:
                    self.server_hash = client.server_hash
            else:
                self.server_hash = client.server_hash

    def check_model(self, model):
        if model != 0 and self.check_model_name:
            valid_llms = self.list_models()
            if (
                    (isinstance(model, int) and model >= len(valid_llms))
                    or (isinstance(model, str) and model not in valid_llms)
            ):
                did_you_mean = ""
                if isinstance(model, str):
                    alt = difflib.get_close_matches(model, valid_llms, 1)
                    if alt:
                        did_you_mean = f"\nDid you mean {repr(alt[0])}?"
                raise RuntimeError(
                    f"Invalid llm: {repr(model)}, must be either an integer between "
                    f"0 and {len(valid_llms) - 1} or one of the following values: {valid_llms}.{did_you_mean}"
                )

    @staticmethod
    def _get_ttl_hash(seconds=60):
        """Return the same value within `seconds` time period"""
        return round(time.time() / seconds)

    @lru_cache()
    def _get_models_full(self, ttl_hash=None, do_lock=False) -> List[Dict[str, Any]]:
        """
        Full model info as a list of dicts (cached)
        """
        del ttl_hash  # to emphasize we don't use it and to shut pylint up
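        # ttl_hash (from _get_ttl_hash) is only an lru_cache key that changes every
        # ~60 seconds, which turns this cache into an effective TTL cache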
        if self.config is None:
            self.setup()
        client = self.clone()
        try:
            return ast.literal_eval(client.predict(api_name="/model_names"))
        finally:
            if do_lock:
                with lock:
                    self.server_hash = client.server_hash
            else:
                self.server_hash = client.server_hash

    def get_models_full(self, do_lock=False) -> List[Dict[str, Any]]:
        """
        Full model info as a list of dicts
        """
        return self._get_models_full(ttl_hash=self._get_ttl_hash(), do_lock=do_lock)

    def list_models(self) -> List[str]:
        """
        Model names available from endpoint
        """
        return [x["display_name"] for x in self.get_models_full()]

    def simple_stream(
            self,
            client_kwargs={},
            api_name="/submit_nochat_api",
            prompt="",
            prompter=None,
            sanitize_bot_response=False,
            max_time=300,
            is_public=False,
            raise_exception=True,
            verbose=False,
    ):
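        """
        Minimal streaming helper: submit client_kwargs to api_name and yield partial
        res_dict updates as the server streams, stopping once the job completes or
        max_time is exceeded.
        """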
        job = self.submit(str(dict(client_kwargs)), api_name=api_name)
        sources = []
        res_dict = dict(
            response="",
            sources=sources,
            save_dict={},
            llm_answers={},
            response_no_refs="",
            sources_str="",
            prompt_raw="",
        )
        yield res_dict
        text = ""
        text0 = ""
        strex = ""
        tgen0 = time.time()
        while not job.done():
            e = check_job(job, timeout=0, raise_exception=False)
            if e is not None:
                break
            outputs_list = job.outputs().copy()
            if outputs_list:
                res = outputs_list[-1]
                res_dict = ast.literal_eval(res)
                text = res_dict["response"] if "response" in res_dict else ""
                prompt_and_text = prompt + text
                if prompter:
                    response = prompter.get_response(
                        prompt_and_text,
                        prompt=prompt,
                        sanitize_bot_response=sanitize_bot_response,
                    )
                else:
                    response = text
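                # yield only the newly generated suffix; text0 tracks what was already emitted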
                text_chunk = response[len(text0):]
                if not text_chunk:
                    # just need some sleep for threads to switch
                    time.sleep(0.001)
                    continue
                # save old
                text0 = response
                res_dict.update(
                    dict(
                        response=response,
                        sources=sources,
                        error=strex,
                        response_no_refs=response,
                    )
                )
                yield res_dict
                if time.time() - tgen0 > max_time:
                    if verbose:
                        print(
                            "Took too long for Gradio: %s" % (time.time() - tgen0),
                            flush=True,
                        )
                    break
            time.sleep(0.005)
        # ensure we get the last output to avoid a race
        res_all = job.outputs().copy()
        success = job.communicator.job.latest_status.success
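        # give the server extra time to report the failure reason when the job did not succeed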
        timeout = 0.1 if success else 10
        if len(res_all) > 0:
            # don't raise unless nochat API for now
            e = check_job(job, timeout=timeout, raise_exception=True)
            if e is not None:
                strex = "".join(traceback.format_tb(e.__traceback__))

            res = res_all[-1]
            res_dict = ast.literal_eval(res)
            text = res_dict["response"]
            sources = res_dict.get("sources")
            if sources is None:
                # communication terminated; keep what we have, but report the error
                if is_public:
                    raise ValueError("Abrupt termination of communication")
                else:
                    raise ValueError("Abrupt termination of communication: %s" % strex)
        else:
            # if we got no answer at all, something is probably wrong, so always raise an exception
            # the UI will still put the exception in Chat History under chat exceptions
            e = check_job(job, timeout=2.0 * timeout, raise_exception=True)
            # go with old text if last call didn't work
            if e is not None:
                stre = str(e)
                strex = "".join(traceback.format_tb(e.__traceback__))
            else:
                stre = ""
                strex = ""

            print(
                "Bad final response: %s %s %s: %s %s"
                % (res_all, prompt, text, stre, strex),
                flush=True,
            )
        prompt_and_text = prompt + text
        if prompter:
            response = prompter.get_response(
                prompt_and_text,
                prompt=prompt,
                sanitize_bot_response=sanitize_bot_response,
            )
        else:
            response = text
        res_dict.update(
            dict(
                response=response,
                sources=sources,
                error=strex,
                response_no_refs=response,
            )
        )
        yield res_dict
        return res_dict

    def stream(
            self,
            client_kwargs={},
            api_name="/submit_nochat_api",
            prompt="",
            prompter=None,
            sanitize_bot_response=False,
            max_time=None,
            is_public=False,
            raise_exception=True,
            verbose=False,
    ):
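        """
        Streaming wrapper around _stream() that re-raises exceptions (when
        raise_exception), converts a server-reported timeout into TimeoutError,
        and treats missing sources as an abrupt termination of communication.
        """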
        strex = ""
        e = None
        res_dict = {}
        try:
            res_dict = yield from self._stream(
                client_kwargs,
                api_name=api_name,
                prompt=prompt,
                prompter=prompter,
                sanitize_bot_response=sanitize_bot_response,
                max_time=max_time,
                verbose=verbose,
            )
        except Exception as e1:
            # bind to a separate name so `e` stays defined after this block
            # (Python deletes the `as` target when the except block exits)
            e = e1
            strex = "".join(traceback.format_tb(e1.__traceback__))
            # check validity of final results and check for timeout
            # NOTE: server may have more before its timeout, and res_all will have more if waited a bit
            if raise_exception:
                raise

        if "timeout" in res_dict.get("save_dict", {}).get("extra_dict", {}):
            timeout_time = res_dict["save_dict"]["extra_dict"]["timeout"]
            raise TimeoutError(
                "Timeout from local after %s %s"
                % (timeout_time, ": " + strex if e else "")
            )

        # won't have sources if timed out
        if res_dict.get("sources") is None:
            # communication terminated; keep what we have, but report the error
            if is_public:
                raise ValueError("Abrupt termination of communication")
            else:
                raise ValueError("Abrupt termination of communication: %s" % strex)
        return res_dict

    def _stream(
            self,
            client_kwargs,
            api_name="/submit_nochat_api",
            prompt="",
            prompter=None,
            sanitize_bot_response=False,
            max_time=None,
            verbose=False,
    ):
        job = self.submit(str(dict(client_kwargs)), api_name=api_name)

        text = ""
        sources = []
        save_dict = {}
        save_dict["extra_dict"] = {}
        res_dict = dict(
            response=text,
            sources=sources,
            save_dict=save_dict,
            llm_answers={},
            response_no_refs=text,
            sources_str="",
            prompt_raw="",
        )
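        # emit an initial empty result so consumers always receive the full dict shape first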
        yield res_dict

        text0 = ""
        tgen0 = time.time()
        n = 0
        for res in job:
            res_dict, text0 = yield from self.yield_res(
                res,
                res_dict,
                prompt,
                prompter,
                sanitize_bot_response,
                max_time,
                text0,
                tgen0,
                verbose,
            )
            n += 1
            if "timeout" in res_dict["save_dict"]["extra_dict"]:
                break
        # final res: drain any outputs produced after the streaming loop ended
        outputs = job.outputs().copy()
        all_n = len(outputs)
        for nn in range(n, all_n):
            res = outputs[nn]
            res_dict, text0 = yield from self.yield_res(
                res,
                res_dict,
                prompt,
                prompter,
                sanitize_bot_response,
                max_time,
                text0,
                tgen0,
                verbose,
            )
        return res_dict

    @staticmethod
    def yield_res(
            res,
            res_dict,
            prompt,
            prompter,
            sanitize_bot_response,
            max_time,
            text0,
            tgen0,
            verbose,
    ):
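        """
        Parse one raw server result, merge it into res_dict with the latest response
        text, and yield it unless nothing new arrived or a timeout was hit.
        Returns (res_dict, text0) for the caller's next iteration.
        """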
        do_yield = True
        res_dict_server = ast.literal_eval(res)
        # yield what we have so far
        text = res_dict_server["response"]
        if text is None:
            print("text None", flush=True)
            text = ""
        if prompter:
            response = prompter.get_response(
                prompt + text,
                prompt=prompt,
                sanitize_bot_response=sanitize_bot_response,
            )
        else:
            response = text
        text_chunk = response[len(text0):]
        if not text_chunk:
            # just need some sleep for threads to switch
            time.sleep(0.001)
            do_yield = False
        # save old
        text0 = response
        res_dict.update(res_dict_server)
        res_dict.update(dict(response=response, response_no_refs=response))

        timeout_time_other = (
            res_dict.get("save_dict", {}).get("extra_dict", {}).get("timeout")
        )
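        # a timeout already present in extra_dict (e.g. reported by the server side) means stop without yielding further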
        if timeout_time_other:
            if verbose:
                print(
                    "Took too long for other Gradio: %s" % (time.time() - tgen0),
                    flush=True,
                )
            return res_dict, text0

        timeout_time = time.time() - tgen0
        if max_time is not None and timeout_time > max_time:
            if "save_dict" not in res_dict:
                res_dict["save_dict"] = {}
            if "extra_dict" not in res_dict["save_dict"]:
                res_dict["save_dict"]["extra_dict"] = {}
            res_dict["save_dict"]["extra_dict"]["timeout"] = timeout_time
            yield res_dict
            if verbose:
                print(
                    "Took too long for Gradio: %s" % (time.time() - tgen0), flush=True
                )
            return res_dict, text0
        if do_yield:
            yield res_dict
            time.sleep(0.005)
        return res_dict, text0


class H2OGradioClient(CommonClient, Client):
    """
    Gradio client subclass that automatically refreshes the client
    if it detects the gradio server has changed
    """

    def reset_session(self) -> None:
        self.session_hash = str(uuid.uuid4())
        if hasattr(self, "include_heartbeat") and self.include_heartbeat:
            self._refresh_heartbeat.set()

    def __init__(
            self,
            src: str,
            hf_token: str | None = None,
            max_workers: int = 40,
            serialize: bool | None = None,  # TODO: remove in 1.0
            output_dir: str
                        | Path = DEFAULT_TEMP_DIR,  # Maybe this can be combined with `download_files` in 1.0
            verbose: bool = False,
            auth: tuple[str, str] | None = None,
            *,
            headers: dict[str, str] | None = None,
            upload_files: bool = True,  # TODO: remove and hardcode to False in 1.0
            download_files: bool = True,  # TODO: consider setting to False in 1.0
            _skip_components: bool = True,
            # internal parameter to skip values of certain components (e.g. State) that do not need to be displayed to users.
            ssl_verify: bool = True,
            h2ogpt_key: str = None,
            persist: bool = False,
            check_hash: bool = True,
            check_model_name: bool = False,
            include_heartbeat: bool = False,
    ):
        """
        Parameters:
            Base Class parameters
            +
            h2ogpt_key: h2oGPT key used to gain access to the server
            persist: whether to persist state, so repeated calls are aware of the prior user session.
                     This allows the scratch MyData collection to be reused, etc.
                     It also maintains the chat_conversation history.
            check_hash: whether to check the git hash for consistency between server and client, to ensure the API is always up to date
            check_model_name: whether to check the model name here (adds delay), or just let the server fail (faster)
            include_heartbeat: whether to start the gradio heartbeat thread (only used when the installed gradio_client supports it)
        """
        if serialize is None:
            # True would convert inputs arbitrarily and mutate outputs;
            # False keeps them as-is, which is the normal mode for h2oGPT
            serialize = False
        self.args = tuple([src])
        self.kwargs = dict(
            hf_token=hf_token,
            max_workers=max_workers,
            serialize=serialize,
            output_dir=output_dir,
            verbose=verbose,
            h2ogpt_key=h2ogpt_key,
            persist=persist,
            check_hash=check_hash,
            check_model_name=check_model_name,
            include_heartbeat=include_heartbeat,
        )
        if is_gradio_client_version7plus:
            # 4.18.0:
            # self.kwargs.update(dict(auth=auth, upload_files=upload_files, download_files=download_files))
            # 4.17.0:
            # self.kwargs.update(dict(auth=auth))
            # 4.24.0:
            self._skip_components = _skip_components
            self.ssl_verify = ssl_verify
            self.kwargs.update(
                dict(
                    auth=auth,
                    upload_files=upload_files,
                    download_files=download_files,
                    ssl_verify=ssl_verify,
                )
            )

        self.verbose = verbose
        self.hf_token = hf_token
        if serialize is not None:
            warnings.warn(
                "The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
            )
            upload_files = serialize
        self.serialize = serialize
        self.upload_files = upload_files
        self.download_files = download_files
        self.space_id = None
        self.cookies: dict[str, str] = {}
        if is_gradio_client_version7plus:
            self.output_dir = (
                str(output_dir) if isinstance(output_dir, Path) else output_dir
            )
        else:
            self.output_dir = output_dir
        self.max_workers = max_workers
        self.src = src
        self.auth = auth
        self.headers = headers

        self.config = None
        self.h2ogpt_key = h2ogpt_key
        self.persist = persist
        self.check_hash = check_hash
        self.check_model_name = check_model_name
        self.include_heartbeat = include_heartbeat

        self.chat_conversation = []  # internal for persist=True
        self.server_hash = None  # internal

    def __repr__(self):
        if self.config and False:
            # too slow for guardrails exceptional path
            return self.view_api(print_info=False, return_format="str")
        return "Not setup for %s" % self.src

    def __str__(self):
        if self.config and False:
            # too slow for guardrails exceptional path
            return self.view_api(print_info=False, return_format="str")
        return "Not setup for %s" % self.src

    def setup(self):
        src = self.src

        headers0 = self.headers
        self.headers = build_hf_headers(
            token=self.hf_token,
            library_name="gradio_client",
            library_version=utils.__version__,
        )
        if headers0:
            self.headers.update(headers0)
        if (
                "authorization" in self.headers
                and self.headers["authorization"] == "Bearer "
        ):
            self.headers["authorization"] = "Bearer hf_xx"
        if src.startswith("http://") or src.startswith("https://"):
            _src = src if src.endswith("/") else src + "/"
        else:
            _src = self._space_name_to_src(src)
            if _src is None:
                raise ValueError(
                    f"Could not find Space: {src}. If it is a private Space, please provide an hf_token."
                )
            self.space_id = src
        self.src = _src
        state = self._get_space_state()
        if state == SpaceStage.BUILDING:
            if self.verbose:
                print("Space is still building. Please wait...")
            while self._get_space_state() == SpaceStage.BUILDING:
                time.sleep(2)  # so we don't get rate limited by the API
        if state in utils.INVALID_RUNTIME:
            raise ValueError(
                f"The current space is in the invalid state: {state}. "
                "Please contact the owner to fix this."
            )
        if self.verbose:
            print(f"Loaded as API: {self.src} ✔")

        if is_gradio_client_version7plus:
            if self.auth is not None:
                self._login(self.auth)

        self.config = self._get_config()
        self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
        if is_gradio_client_version7plus:
            self.protocol: Literal[
                "ws", "sse", "sse_v1", "sse_v2", "sse_v2.1"
            ] = self.config.get("protocol", "ws")
            self.sse_url = urllib.parse.urljoin(
                self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
            )
            if hasattr(utils, "HEARTBEAT_URL") and self.include_heartbeat:
                self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL)
            else:
                self.heartbeat_url = None
            self.sse_data_url = urllib.parse.urljoin(
                self.src,
                utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
            )
        self.ws_url = urllib.parse.urljoin(
            self.src.replace("http", "ws", 1), utils.WS_URL
        )
        self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
        self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
        if is_gradio_client_version7plus:
            self.app_version = version.parse(self.config.get("version", "2.0"))
            self._info = self._get_api_info()
        self.session_hash = str(uuid.uuid4())

        self.get_endpoints(self)

        # Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
        # threading.Thread(target=self._telemetry_thread, daemon=True).start()
        if (
                is_gradio_client_version7plus
                and hasattr(utils, "HEARTBEAT_URL")
                and self.include_heartbeat
        ):
            self._refresh_heartbeat = threading.Event()
            self._kill_heartbeat = threading.Event()

            self.heartbeat = threading.Thread(
                target=self._stream_heartbeat, daemon=True
            )
            self.heartbeat.start()

        self.server_hash = self.get_server_hash()

        return self

    @staticmethod
    def get_endpoints(client, verbose=False):
        t0 = time.time()
        # Create a pool of threads to handle the requests
        client.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=client.max_workers
        )
        if is_gradio_client_version7plus:
            from gradio_client.client import EndpointV3Compatibility

            endpoint_class = (
                Endpoint
                if client.protocol.startswith("sse")
                else EndpointV3Compatibility
            )
        else:
            endpoint_class = Endpoint

        if is_gradio_client_version7plus:
            client.endpoints = [
                endpoint_class(client, fn_index, dependency, client.protocol)
                for fn_index, dependency in enumerate(client.config["dependencies"])
            ]
        else:
            client.endpoints = [
                endpoint_class(client, fn_index, dependency)
                for fn_index, dependency in enumerate(client.config["dependencies"])
            ]
        if is_gradio_client_version7plus:
            client.stream_open = False
            client.streaming_future = None
            from gradio_client.utils import Message

            client.pending_messages_per_event = {}
            client.pending_event_ids = set()
        if verbose:
            print("duration endpoints: %s" % (time.time() - t0), flush=True)

    @staticmethod
    def is_full_git_hash(s):
        # This regex checks for exactly 40 hexadecimal characters.
        return bool(re.fullmatch(r"[0-9a-f]{40}", s))

    def get_server_hash(self) -> str:
        return self._get_server_hash(ttl_hash=self._get_ttl_hash())

    def _get_server_hash(self, ttl_hash=None) -> str:
        """
        Get the server hash using super() without triggering any refresh action
        Returns: git hash of the gradio server
        """
        del ttl_hash  # to emphasize we don't use it and to shut pylint up
        t0 = time.time()
        if self.config is None:
            self.setup()
        t1 = time.time()
        ret = "GET_GITHASH_UNSET"
        try:
            if self.check_hash:
                ret = super().submit(api_name="/system_hash").result()
                assert self.is_full_git_hash(ret), f"ret is not a full git hash: {ret}"
            return ret
        finally:
            if self.verbose:
                print(
                    "duration server_hash: %s full time: %s system_hash time: %s"
                    % (ret, time.time() - t0, time.time() - t1),
                    flush=True,
                )

    def refresh_client_if_should(self):
        if self.config is None:
            self.setup()
        # get current hash in order to update api_name -> fn_index map in case gradio server changed
        # FIXME: Could add cli api as hash
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            if self.verbose:
                print(
                    "server hash changed: %s %s" % (self.server_hash, server_hash),
                    flush=True,
                )
            if self.server_hash is not None and self.persist:
                if self.verbose:
                    print(
                        "Failed to persist due to server hash change, only kept chat_conversation not user session hash",
                        flush=True,
                    )
            # risky to persist if hash changed
            self.refresh_client()
            self.server_hash = server_hash

    def refresh_client(self):
        """
        Ensure every client call is independent.
        Also ensure the map between api_name and fn_index is updated in case the server changed (e.g. restarted with new code).
        """
        if self.config is None:
            self.setup()

        kwargs = self.kwargs.copy()
        kwargs.pop("h2ogpt_key", None)
        kwargs.pop("persist", None)
        kwargs.pop("check_hash", None)
        kwargs.pop("check_model_name", None)
        kwargs.pop("include_heartbeat", None)
        ntrials = 3
        client = None
        for trial in range(0, ntrials):
            try:
                client = Client(*self.args, **kwargs)
                break
            except ValueError as e:
                if trial >= ntrials - 1:
                    # out of retries; let the last error propagate
                    raise
                if self.verbose:
                    print("Trying refresh %d/%d %s" % (trial, ntrials - 1, str(e)))
                time.sleep(10)
        if client is None:
            raise RuntimeError("Failed to get new client")
        session_hash0 = self.session_hash if self.persist else None
        for k, v in client.__dict__.items():
            setattr(self, k, v)
        if session_hash0:
            # keep the same session hash in case only the server API changed and the server was not restarted
            self.session_hash = session_hash0
        if self.verbose:
            print("Hit refresh_client(): %s %s" % (self.session_hash, session_hash0))
        # ensure server hash also updated
        self.server_hash = self.get_server_hash()

    def clone(self, do_lock=False):
        if do_lock:
            with lock:
                return self._clone()
        else:
            return self._clone()

    def _clone(self):
        if self.config is None:
            self.setup()
        client = self.__class__("")
        for k, v in self.__dict__.items():
            setattr(client, k, v)
        client.reset_session()

        self.get_endpoints(client)

        # transfer internals in case used
        client.server_hash = self.server_hash
        client.chat_conversation = self.chat_conversation
        return client

    def submit(
            self,
            *args,
            api_name: str | None = None,
            fn_index: int | None = None,
            result_callbacks: Callable | list[Callable] | None = None,
            exception_handling=True,  # new_stream = True, can make False, doesn't matter.
    ) -> Job:
        if self.config is None:
            self.setup()
        # Note predict calls submit
        try:
            self.refresh_client_if_should()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)
        except Exception as e:
            ex = traceback.format_exc()
            print(
                "Hit e=%s\n\n%s\n\n%s"
                % (str(e), ex, self.__dict__),
                flush=True,
            )
            # force reconfig in case only that
            self.refresh_client()
            job = super().submit(*args, api_name=api_name, fn_index=fn_index)

        if exception_handling:  # can be set False for debugging if this causes issues
            # see if immediately failed
            e = check_job(job, timeout=0.01, raise_exception=False)
            if e is not None:
                print(
                    "GR job failed: %s %s"
                    % (str(e), "".join(traceback.format_tb(e.__traceback__))),
                    flush=True,
                )
                # force reconfig in case only that
                self.refresh_client()
                job = super().submit(*args, api_name=api_name, fn_index=fn_index)
                e2 = check_job(job, timeout=0.1, raise_exception=False)
                if e2 is not None:
                    print(
                        "GR job failed again: %s\n%s"
                        % (str(e2), "".join(traceback.format_tb(e2.__traceback__))),
                        flush=True,
                    )

        return job


class CloneableGradioClient(CommonClient, Client):
    def __init__(self, *args, **kwargs):
        self._original_config = None
        self._original_info = None
        self._original_endpoints = None
        self._original_executor = None
        self._original_heartbeat = None
        self._quiet = kwargs.pop('quiet', False)
        super().__init__(*args, **kwargs)
        self._initialize_session_specific()
        self._initialize_shared_info()
        atexit.register(self.cleanup)
        self.auth = kwargs.get('auth')

    def _initialize_session_specific(self):
        """Initialize or reset session-specific attributes."""
        self.session_hash = str(uuid.uuid4())
        self._refresh_heartbeat = threading.Event()
        self._kill_heartbeat = threading.Event()
        self.stream_open = False
        self.streaming_future = None
        self.pending_messages_per_event = {}
        self.pending_event_ids = set()

    def _initialize_shared_info(self):
        """Initialize information that can be shared across clones."""
        if self._original_config is None:
            self._original_config = super().config
        if self._original_info is None:
            self._original_info = super()._info
        if self._original_endpoints is None:
            self._original_endpoints = super().endpoints
        if self._original_executor is None:
            self._original_executor = super().executor
        if self._original_heartbeat is None:
            self._original_heartbeat = super().heartbeat
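
    # The properties below route config, _info, endpoints, executor, and heartbeat
    # through the _original_* attributes so that copy.copy()-based clones
    # (see clone()) share these objects with the original client.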

    @property
    def config(self):
        return self._original_config

    @config.setter
    def config(self, value):
        self._original_config = value

    @property
    def _info(self):
        return self._original_info

    @_info.setter
    def _info(self, value):
        self._original_info = value

    @property
    def endpoints(self):
        return self._original_endpoints

    @endpoints.setter
    def endpoints(self, value):
        self._original_endpoints = value

    @property
    def executor(self):
        return self._original_executor

    @executor.setter
    def executor(self, value):
        self._original_executor = value

    @property
    def heartbeat(self):
        return self._original_heartbeat

    @heartbeat.setter
    def heartbeat(self, value):
        self._original_heartbeat = value

    def setup(self):
        # no-op
        pass

    @staticmethod
    def _get_ttl_hash(seconds=60):
        """Return the same value within `seconds` time period"""
        return round(time.time() / seconds)

    def get_server_hash(self) -> str:
        return self._get_server_hash(ttl_hash=self._get_ttl_hash())

    def _get_server_hash(self, ttl_hash=None):
        del ttl_hash  # to emphasize we don't use it and to shut pylint up
        return self.predict(api_name="/system_hash")

    def clone(self):
        """Create a new CloneableGradioClient instance with the same configuration but a new session."""
        new_client = copy.copy(self)
        new_client._initialize_session_specific()
        new_client._quiet = True  # Set the cloned client to quiet mode
        atexit.register(new_client.cleanup)
        return new_client

    def __repr__(self):
        if self._quiet:
            return f"<CloneableGradioClient (quiet) connected to {self.src}>"
        return super().__repr__()

    def __str__(self):
        if self._quiet:
            return f"CloneableGradioClient (quiet) connected to {self.src}"
        return super().__str__()

    def cleanup(self):
        """Clean up resources used by this client."""
        if self._original_executor:
            self._original_executor.shutdown(wait=False)
        if self._kill_heartbeat:
            self._kill_heartbeat.set()
        if self._original_heartbeat:
            self._original_heartbeat.join(timeout=1)
        atexit.unregister(self.cleanup)


if old_gradio:
    GradioClient = H2OGradioClient
else:
    GradioClient = CloneableGradioClient
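
# A minimal usage sketch (illustrative only, not executed): assuming an h2oGPT
# server is reachable at a hypothetical http://localhost:7860, the helpers above
# could be used roughly as follows:
#
#   client = GradioClient("http://localhost:7860")
#   client.setup()                   # no-op for CloneableGradioClient
#   print(client.list_models())      # display names via /model_names
#   print(client.get_server_hash())  # git hash via /system_hash
#
# The URL is an assumption for illustration; actual endpoints and arguments are
# whatever the running h2oGPT server exposes.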