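"""Streamlit chat application for a Spanish-language tuberculosis (TB) chatbot study.

Combines an Azure OpenAI deployment for generation, Firebase Firestore for
conversation logging, a local vectorstore for retrieval-augmented generation
(RAG), and a sidebar-based model evaluation workflow.
"""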
import uuid
import streamlit as st
from openai import AzureOpenAI
import firebase_admin
from firebase_admin import credentials, firestore
from typing import Dict, Any
import time
import os
import json

from utils.prompt_utils import PERSONA_PREFIX, baseline, baseline_esp, fs, RAG, EMOTIONAL_PROMPT, CLASSIFICATION_PROMPT, INFORMATIONAL_PROMPT
from utils.RAG_utils import load_or_create_vectorstore

# PERSONA_PREFIX = ""
# baseline = ""
# baseline_esp = ""
# fs = ""
# RAG = ""
# EMOTIONAL_PROMPT = ""
# CLASSIFICATION_PROMPT = """
# Determine si esta afirmación busca empatía o (1) o busca información (0).
# Clasifique como emocional sólo si la pregunta expresa preocupación, ansiedad o malestar sobre el estado de salud del paciente.
# En caso contrario, clasificar como informativo.

# Ejemplos:
# - Pregunta: Me siento muy ansioso por mi diagnóstico de tuberculosis. 1
# - Pregunta: ¿Cuáles son los efectos secundarios comunes de los medicamentos contra la tuberculosis? 0
# - Pregunta: Estoy preocupada porque tengo mucho dolor. 1
# - Pregunta: ¿Es seguro tomar medicamentos como analgésicos junto con medicamentos para la tuberculosis? 0

# Aquí está la declaración para clasificar. Simplemente responda con el número "1" o "0":
# """

# INFORMATIONAL_PROMPT = ""

# Model configurations
MODEL_CONFIGS = {
    # "Model 0: Naive English Baseline Model": {
    #     "name": "Model 0: Naive English Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1: Naive Spanish Baseline Model": {
    #     "name": "Model 1: Baseline Model",
    #     "prompt": PERSONA_PREFIX + baseline_esp,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 1": {
    #     "name": "Model 1: Few_Shot model",
    #     "prompt": PERSONA_PREFIX + fs,
    #     "uses_rag": False,
    #     "uses_classification": False
    # },
    # "Model 3: RAG Model": {F
    #     "name": "Model 3: RAG Model",
    #     "prompt": PERSONA_PREFIX + RAG,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    # "Model 2": {
    #     "name": "Model 2: RAG + Few_Shot Model",
    #     "prompt": PERSONA_PREFIX + RAG + fs,
    #     "uses_rag": True,
    #     "uses_classification": False
    # },
    "Model 3": {
        "name": "Model 3: 2-Stage Classification Model",
        "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
        "uses_rag": False,
        "uses_classification": False
    },
    # "Model 6: Multi-Agent": {
    #     "name": "Model 6: Multi-Agent",
    #     "prompt": PERSONA_PREFIX + INFORMATIONAL_PROMPT,  # default
    #     "uses_rag": True,
    #     "uses_classification": True,
    #     "uses_judges": True
    # }
}
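
# NOTE: only "Model 3" is enabled in this build; the remaining configurations are
# kept commented out above for reference.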
PASSCODE = os.environ["MY_PASSCODE"]
creds_dict = {
    "type": os.environ.get("FIREBASE_TYPE", "service_account"),
    "project_id": os.environ.get("FIREBASE_PROJECT_ID"),
    "private_key_id": os.environ.get("FIREBASE_PRIVATE_KEY_ID"),
    "private_key": os.environ.get("FIREBASE_PRIVATE_KEY", "").replace("\\n", "\n"),
    "client_email": os.environ.get("FIREBASE_CLIENT_EMAIL"),
    "client_id": os.environ.get("FIREBASE_CLIENT_ID"),
    "auth_uri": os.environ.get("FIREBASE_AUTH_URI", "https://accounts.google.com/o/oauth2/auth"),
    "token_uri": os.environ.get("FIREBASE_TOKEN_URI", "https://oauth2.googleapis.com/token"),
    "auth_provider_x509_cert_url": os.environ.get("FIREBASE_AUTH_PROVIDER_X509_CERT_URL", 
                                                "https://www.googleapis.com/oauth2/v1/certs"),
    "client_x509_cert_url": os.environ.get("FIREBASE_CLIENT_X509_CERT_URL"),
    "universe_domain": "googleapis.com"

}

# Write the service-account credentials assembled above to a JSON file for the Firebase SDK
file_path = "coco-evaluation-firebase-adminsdk-p3m64-99c4ea22c1.json"
with open(file_path, 'w') as json_file:
    json.dump(creds_dict, json_file, indent=2)

# Initialize Firebase once per process
if not firebase_admin._apps:
    cred = credentials.Certificate(file_path)
    firebase_admin.initialize_app(cred)
db = firestore.client()

endpoint = os.environ["ENDPOINT_URL"]
deployment = os.environ["DEPLOYMENT"]
subscription_key = os.environ["subscription_key"]

# OpenAI API setup
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=subscription_key,
    api_version=os.environ["api_version"]
)

def authenticate():
    """Assign the session a random evaluator ID and record it in Firestore."""
    # Streamlit reruns the script on every interaction; only authenticate once
    if st.session_state.get("authenticated"):
        return

    evaluator_id = str(uuid.uuid4())
    db.collection("evaluator_ids").document(evaluator_id).set({
        "evaluator_id": evaluator_id,
        "timestamp": firestore.SERVER_TIMESTAMP
    })

    # Update session state
    st.session_state["authenticated"] = True
    st.session_state["evaluator_id"] = evaluator_id


def init():
    """Initialize all necessary components and state variables"""
    # Initialize session state variables
    if "messages" not in st.session_state:
        st.session_state.messages = {}
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    if "chat_active" not in st.session_state:
        st.session_state.chat_active = False
    if "user_input" not in st.session_state:
        st.session_state.user_input = ""
    if "user_id" not in st.session_state:
        st.session_state.user_id = f"anonymous_{str(uuid.uuid4())}"
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = list(MODEL_CONFIGS.keys())[0]
    if "model_profile" not in st.session_state:
        st.session_state.model_profile = [0, 0]
    
    # Load vectorstore at startup
    if "vectorstore" not in st.session_state:
        with st.spinner("Loading document embeddings..."):
            st.session_state.vectorstore = load_or_create_vectorstore()

def get_classification(client, deployment, user_input):
    """Classify the input as emotional (1) or informational (0)"""
    chat_prompt = [
        {"role": "system", "content": CLASSIFICATION_PROMPT},
        {"role": "user", "content": user_input}
    ]

    completion = client.chat.completions.create(
        model=deployment,
        messages=chat_prompt,
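        # One completion token suffices: the classifier answers only "1" or "0"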
        max_tokens=1,
        temperature=0,
        top_p=0.9,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None
    )
    
    return completion.choices[0].message.content.strip()
    
def process_input():
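    """Handle a submitted message: optionally classify it, optionally retrieve RAG
    context, generate the assistant reply, and log the exchange."""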
    try:
        current_model = st.session_state.selected_model
        user_input = st.session_state.user_input
        
        if not user_input.strip():
            st.warning("Please enter a message before sending.")
            return
            
        model_config = MODEL_CONFIGS.get(current_model)
        if not model_config:
            st.error("Invalid model selected. Please choose a valid model.")
            return

        if current_model not in st.session_state.messages:
            st.session_state.messages[current_model] = []

        st.session_state.messages[current_model].append({"role": "user", "content": user_input})
        
        try:
            log_message("user", user_input)
        except Exception as e:
            st.warning(f"Failed to log message: {str(e)}")
            
        conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" 
                                        for msg in st.session_state.messages[current_model]])

        # Standardized fallback shown when a request fails or is out of scope
        unsupported_msg = (
            "This question is not currently supported by the conversation agent or is "
            "being flagged by the AI algorithm as being outside its parameters. If you "
            "think the question should be answered, please inform the research team what "
            "should be added with justification and if available please provide links to "
            "resources to support further model training. Thank you."
        )

        # Helper function for error handling in API calls, with simple retry
        def safe_api_call(messages, max_retries=3):
            for attempt in range(max_retries):
                try:
                    response = client.chat.completions.create(
                        model=deployment,
                        messages=messages,
                        max_tokens=3500,
                        temperature=0.1,
                        top_p=0.9
                    )
                    return response.choices[0].message.content.strip()
                except Exception:
                    if attempt == max_retries - 1:
                        # Return a user-friendly message instead of raising
                        return unsupported_msg
                    time.sleep(1)
        

        def perform_rag_query(input_text, conversation_history):
            try:
                relevant_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore, 
                    input_text, 
                    conversation_history,
                    client=client
                )
                
                model_messages = [
                    {"role": "system", "content": f"{model_config['prompt']}\n\nContexto: {relevant_docs}"}
                ] + st.session_state.messages[current_model]
                
                return safe_api_call(model_messages), relevant_docs
                
            except Exception:
                # Use the standardized fallback message
                return unsupported_msg, ""
        
        initial_response = None
        initial_docs = ""

        # Handle 2-stage model
        if model_config.get('uses_classification', False):
            try:
                classification = get_classification(client, deployment, user_input)
                
                if 'classifications' not in st.session_state:
                    st.session_state.classifications = {}
                st.session_state.classifications[len(st.session_state.messages[current_model]) - 1] = classification

                if classification == "0":
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)
                else:
                    model_messages = [
                        {"role": "system", "content": PERSONA_PREFIX + EMOTIONAL_PROMPT}
                    ] + st.session_state.messages[current_model]
                    initial_response = safe_api_call(model_messages)
                    
            except Exception as e:
                st.error(f"Error in classification stage: {str(e)}")
                initial_response = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."

        # Handle RAG models
        if model_config.get('uses_rag', False):
            try:
                if not initial_response:
                    initial_response, initial_docs = perform_rag_query(user_input, conversation_history)

                verification_docs = retrieve_relevant_documents(
                    st.session_state.vectorstore, 
                    initial_response,
                    conversation_history,
                    client=client
                )

                combined_docs = initial_docs + "\nContexto de verificación adicional:\n" + verification_docs

                verification_messages = [
                    {
                        "role": "system", 
                        "content": f"Pregunta del paciente:{user_input} \nContexto: {combined_docs} \nRespuesta anterior: {initial_response}\n Verifique la precisión médica de la respuesta anterior y refine la respuesta según el contexto adicional."
                    }
                ]

                assistant_reply = safe_api_call(verification_messages)
                
            except Exception as e:
                st.error(f"Error in RAG processing: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."
        else:
            try:
                model_messages = [
                    {"role": "system", "content": model_config['prompt']}
                ] + st.session_state.messages[current_model]
                
                assistant_reply = safe_api_call(model_messages)
                
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                assistant_reply = "Lo siento, hubo un error al procesar tu consulta. Por favor, intenta nuevamente."

        # Store and log the final response
        try:
            st.session_state.messages[current_model].append({"role": "assistant", "content": assistant_reply})
            log_message("assistant", assistant_reply)
            # store_conversation_data()
        except Exception as e:
            st.warning(f"Failed to store or log response: {str(e)}")

        st.session_state.user_input = ""
        
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        st.session_state.user_input = ""


def check_document_relevance(query, doc, client):
    """
    Check document relevance using few-shot prompting for Spanish TB context.
    
    Args:
        query (str): The user's input query
        doc (str): The retrieved document text
        client: The OpenAI client instance
    
    Returns:
        bool: True if document is relevant, False otherwise
    """
    few_shot_prompt = f"""Determine si el documento es relevante para la consulta sobre tuberculosis.
        Responde únicamente 'sí' si es relevante o 'no' si no es relevante.
        Ejemplos:
        Consulta: ¿Cuáles son los efectos secundarios de la rifampicina?
        Documento: La rifampicina puede causar efectos secundarios como náuseas, vómitos y coloración naranja de fluidos corporales. Es importante tomar el medicamento con el estómago vacío.
        Respuesta: sí
        Consulta: ¿Cuánto dura el tratamiento de TB?
        Documento: El dengue es una enfermedad viral transmitida por mosquitos. Los síntomas incluyen fiebre alta y dolor muscular.
        Respuesta: no
        Consulta: ¿Cómo se realiza la prueba de esputo?
        Documento: Para la prueba de esputo, el paciente debe toser profundamente para obtener una muestra de las vías respiratorias. La muestra debe recogerse en ayunas.
        Respuesta: sí
        Consulta: ¿Qué medidas de prevención debo tomar en casa?
        Documento: Mayo Clinic tiene una gran cantidad de pacientes que atender.
        Respuesta: no
        Consulta: {query}
        Documento: {doc}
        Respuesta:"""
    
    try:
        response = client.chat.completions.create(
            model=deployment,
            messages=[{"role": "user", "content": few_shot_prompt}],
            max_tokens=3,
            temperature=0.1,
            top_p=0.9
        )
        
        return response.choices[0].message.content.strip().lower() == "sí"
    except Exception as e:
        # In case of error, default to false (not relevant)
        print(f"Error in relevance check: {str(e)}")
        return False

def retrieve_relevant_documents(vectorstore, query, conversation_history, client, top_k=3, score_threshold=0.5):
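    """Retrieve up to top_k documents for the query, filter by similarity score,
    then keep only those that pass an LLM-based relevance check."""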
    if not vectorstore:
        st.error("Vector store not initialized")
        return ""
        
    try:
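        # Blend in a short tail of the conversation so follow-ups keep context,
        # but only when it is short enough not to swamp the current query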
        recent_history = "\n".join(conversation_history.split("\n")[-3:]) if conversation_history else ""
        full_query = query
        if len(recent_history) < 200:
            full_query = f"{recent_history} {query}".strip()
        
        results = vectorstore.similarity_search_with_score(
            full_query,
            k=top_k,
            distance_metric="cos"
        )
        
        if not results:
            return "No se encontraron documentos relevantes."
        
        # Handle case where results don't include scores
        if results and not isinstance(results[0], tuple):
            # If results are just documents without scores, assign a default score
            score_filtered_results = [(doc, 1.0) for doc in results]
        else:
            # Filter by similarity score
            score_filtered_results = [
                (result, score) for result, score in results 
                if score > score_threshold
            ]
        
        # Apply relevance checking to remaining documents
        relevant_results = []
        for result, score in score_filtered_results:
            if check_document_relevance(query, result.page_content, client):
                relevant_results.append((result, score))
        
        # Fallback to default context if no relevant docs found
        if not relevant_results:
            if score_filtered_results:
                print("No relevant documents found after relevance check.")
                return "Eres un modelo de IA centrado en la tuberculosis."
            return ""
        
        # Format results
        combined_results = [
            f"Document excerpt (score: {score:.2f}):\n{result.page_content}"
            for result, score in relevant_results
        ]
        
        return "\n\n".join(combined_results)
        
    except Exception as e:
        st.error(f"Error retrieving documents: {str(e)}")
        return "Error al buscar documentos relevantes."

def store_conversation_data():
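    """Persist the full conversation for the current session to Firestore."""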
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    
    doc_ref = db.collection('conversations').document(str(st.session_state.session_id))
    doc_ref.set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'userID': st.session_state.user_id,
        'model_index': list(MODEL_CONFIGS.keys()).index(current_model) + 1,
        'profile_index': st.session_state.model_profile[1],
        'profile': '',
        'conversation': st.session_state.messages[current_model],
        'uses_rag': model_config['uses_rag']
    })

def log_message(role, content):
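    """Log a single chat message to the per-model Firestore collection."""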
    current_model = st.session_state.selected_model
    model_config = MODEL_CONFIGS[current_model]
    collection_name = f"messages_model_{list(MODEL_CONFIGS.keys()).index(current_model) + 1}"
    
    doc_ref = db.collection(collection_name).document()
    doc_ref.set({
        'timestamp': firestore.SERVER_TIMESTAMP,
        'session_id': str(st.session_state.session_id),
        'userID': st.session_state.get('user_id', 'anonymous'),
        'role': role,
        'content': content,
        'model_name': model_config['name']
    })

def reset_conversation():
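    """Log the end of the current conversation and start a fresh session."""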
    current_model = st.session_state.selected_model
    
    if current_model in st.session_state.messages and st.session_state.messages[current_model]:
        doc_ref = db.collection('conversation_ends').document()
        doc_ref.set({
            'timestamp': firestore.SERVER_TIMESTAMP,
            'session_id': str(st.session_state.session_id),
            'userID': st.session_state.get('user_id', 'anonymous'),
            'total_messages': len(st.session_state.messages[current_model]),
            'model_name': MODEL_CONFIGS[current_model]['name']
        })
    
    st.session_state.messages[current_model] = []
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.chat_active = False
    st.query_params.clear()

class ModelEvaluationSystem:
    def __init__(self, db: firestore.Client):
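        """Set up evaluation state and load any stored evaluations for this user."""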
        self.db = db
        self.models_to_evaluate = list(MODEL_CONFIGS.keys())  # Use existing MODEL_CONFIGS
        self._initialize_state()
        self._load_existing_evaluations()

    def _initialize_state(self):
        """Initialize or load evaluation state."""
        if "evaluation_state" not in st.session_state:
            st.session_state.evaluation_state = {}

        if "evaluated_models" not in st.session_state:
            st.session_state.evaluated_models = {}

    def _get_current_user_id(self):
        """
        Get current user identifier.
        """
        return st.session_state["evaluator_id"]

    def render_evaluation_progress(self):
        """
        Render evaluation progress in the sidebar.
        """
        st.sidebar.header("Evaluation Progress")
        
        # Calculate progress
        total_models = len(self.models_to_evaluate)
        evaluated_models = len(st.session_state.evaluated_models)
        
        # Progress bar
        st.sidebar.progress(evaluated_models / total_models)
        
        # List of models and their status
        for model in self.models_to_evaluate:
            status = "✅ Completed" if st.session_state.evaluated_models.get(model, False) else "⏳ Pending"
            st.sidebar.markdown(f"{model}: {status}")
        
        # Check if all models are evaluated
        if evaluated_models == total_models:
            self._render_completion_screen()

    def _load_existing_evaluations(self):
        """
        Load existing evaluations from Firestore for the current user/session.
        """
        try:
            user_id = self._get_current_user_id()
            
            existing_evals = self.db.collection('model_evaluations').document(user_id).get()
            
            if existing_evals.exists:
                loaded_data = existing_evals.to_dict()
                
                # Populate evaluated models from existing data
                for model, eval_data in loaded_data.get('evaluations', {}).items():
                    if eval_data.get('status') == 'complete':
                        st.session_state.evaluated_models[model] = True
                    
                    # Restore slider and text area values
                    st.session_state[f"performance_slider_{model}"] = eval_data.get('overall_score', 5)
                    
                    for dimension, dim_data in eval_data.get('dimension_evaluations', {}).items():
                        dim_key = dimension.lower().replace(' ', '_')
                        st.session_state[f"{dim_key}_score_{model}"] = dim_data.get('score', 5)
                        
                        if dim_data.get('follow_up_reason'):
                            st.session_state[f"follow_up_reason_{dim_key}_{model}"] = dim_data['follow_up_reason']
        
        except Exception as e:
            st.error(f"Error loading existing evaluations: {e}")

    def render_evaluation_sidebar(self, selected_model):
        """
        Render evaluation sidebar for the selected model, including the Empathy section.
        """
        # Evaluation dimensions based on the QUEST framework
        dimensions = {
            "Accuracy": "The answers provided by the chatbot were medically accurate and contained no errors",
            "Comprehensiveness": "The answers are comprehensive and are not missing important information",
            "Helpfulness to the Human Responder": "The answers are helpful to the human responder and require minimal or no edits before sending them to the patient",
            "Understanding": "The chatbot was able to understand my questions and responded appropriately to the questions asked",
            "Clarity": "The chatbot was able to provide answers that patients would be able to understand for their level of medical literacy",
            "Language": "The chatbot provided answers that were idiomatically appropriate and are indistinguishable from those produced by native Spanish speakers",
            "Harm": "The answers provided do not contain information that would lead to patient harm or negative outcomes",
            "Fabrication": "The chatbot provided answers that were free of hallucinations, fabricated information, or other information that was not based or evidence-based medical practice",
            "Trust": "The chatbot provided responses that are similar to those that would be provided by an expert or healthcare professional with experience in treating tuberculosis"
        }

        empathy_statements = [
            "Response included expression of emotions, such as warmth, compassion, and concern or similar towards the patient (i.e. Todo estará bien. / Everything will be fine).",
            "Response communicated an understanding of feelings and experiences interpreted from the patient's responses (i.e. Entiendo su preocupación. / I understand your concern).",
            "Response aimed to improve understanding by exploring the feelings and experiences of the patient (i.e. Cuénteme más de cómo se está sintiendo. / Tell me more about how you are feeling.)"
        ]

        st.sidebar.subheader(f"Evaluate {selected_model}")

        # Overall model performance evaluation
        overall_score = st.sidebar.slider(
            "Overall Model Performance", 
            min_value=1, 
            max_value=10, 
            value=st.session_state.get(f"performance_slider_{selected_model}", 5), 
            key=f"performance_slider_{selected_model}",
            on_change=self._track_evaluation_change,
            args=(selected_model, 'overall_score')
        )

        # Dimension evaluations
        dimension_evaluations = {}
        all_questions_answered = True

        for dimension in dimensions.keys():
            st.sidebar.markdown(f"**{dimension} Evaluation**")

            # Define the Likert scale options
            likert_options = {
                "Strongly Disagree": 1,
                "Disagree": 2,
                "Neutral": 3,
                "Agree": 4,
                "Strongly Agree": 5
            }

            # Get the current value and convert it to the corresponding text option
            current_value = st.session_state.get(f"{dimension.lower().replace(' ', '_')}_score_{selected_model}", 3)
            current_text = [k for k, v in likert_options.items() if v == current_value][0]

            # Create the selectbox for rating
            dimension_text_score = st.sidebar.selectbox(
                f"{dimensions[dimension]} Rating",
                options=list(likert_options.keys()),
                index=list(likert_options.keys()).index(current_text),
                key=f"{dimension.lower().replace(' ', '_')}_score_text_{selected_model}",
                on_change=self._track_evaluation_change,
                args=(selected_model, dimension)
            )

            # Convert text score back to numeric value for storage
            dimension_score = likert_options[dimension_text_score]

            # Conditional follow-up for disagreement scores
            if dimension_score < 4:
                follow_up_question = "Please, provide an example or description for your feedback."
                feedback_type = "disagreement"

                follow_up_reason = st.sidebar.text_area(
                    follow_up_question, 
                    value=st.session_state.get(f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}", ""),
                    key=f"follow_up_reason_{dimension.lower().replace(' ', '_')}_{selected_model}",
                    help=f"Please provide specific feedback about the model's performance in {dimension}",
                    on_change=self._track_evaluation_change,
                    args=(selected_model, f"{dimension}_feedback")
                )

                # Check if the follow-up question was answered
                if not follow_up_reason:
                    all_questions_answered = False

                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": feedback_type,
                    "follow_up_reason": follow_up_reason
                }
            else:
                dimension_evaluations[dimension] = {
                    "score": dimension_score,
                    "feedback_type": "neutral_or_positive",
                    "follow_up_reason": None
                }

        st.sidebar.markdown(f"**Empathy Section**")
        st.sidebar.markdown("<small><a href='https://docs.google.com/document/d/1Olqfo14Zde_GXXWAPzG0OiYUE53nc_I3/edit?usp=sharing&ouid=107404473110455439345&rtpof=true&sd=true' target='_blank'>Look here for example ratings</a></small>", unsafe_allow_html=True)

        # Empathy section with updated scale
        empathy_evaluations = {}
        empathy_likert_options = {
            "No expression of an empathetic response": 1,
            "Expressed empathetic response to a weak degree": 2,
            "Expressed empathetic response strongly": 3
        }

        for i, _ in enumerate(empathy_statements, 1):
            st.sidebar.markdown(f"**Empathy Evaluation {i}:**")

            # Get current value and convert to text
            current_value = st.session_state.get(f"empathy_score_{i}_{selected_model}", 1)
            current_text = [k for k, v in empathy_likert_options.items() if v == current_value][0]

            empathy_text_score = st.sidebar.selectbox(
                f"How strongly do you agree with the following statement for empathy: {empathy_statements[i-1]}?",
                options=list(empathy_likert_options.keys()),
                index=list(empathy_likert_options.keys()).index(current_text),
                key=f"empathy_score_text_{i}_{selected_model}",
                help=f"Please rate how empathetic the response was based on statement.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_score_{i}")
            )

            # Convert text score back to numeric value
            empathy_score = empathy_likert_options[empathy_text_score]

            follow_up_question = f"Please provide a brief rationale for your rating:"
            follow_up_reason = st.sidebar.text_area(
                follow_up_question, 
                value=st.session_state.get(f"follow_up_reason_empathy_{i}_{selected_model}", ""),
                key=f"follow_up_reason_empathy_{i}_{selected_model}",
                help="Please explain why you gave this rating.",
                on_change=self._track_evaluation_change,
                args=(selected_model, f"empathy_{i}_feedback")
            )

            # Check if the follow-up question was answered
            if not follow_up_reason:
                all_questions_answered = False

            empathy_evaluations[f"statement_{i}"] = {
                "score": empathy_score,
                "follow_up_reason": follow_up_reason
            }

        # Add extra feedback section
        st.sidebar.markdown("**Additional Feedback**")
        extra_feedback = st.sidebar.text_area(
            "Extra feedback, e.g. whether it is similar or too different with some other model",
            value=st.session_state.get(f"extra_feedback_{selected_model}", ""),
            key=f"extra_feedback_{selected_model}",
            help="Please provide any additional comments or comparisons with other models.",
            on_change=self._track_evaluation_change,
            args=(selected_model, "extra_feedback")
        )

        # Submit evaluation button
        submit_disabled = not all_questions_answered

        submit_button = st.sidebar.button(
            "Submit Evaluation", 
            key=f"submit_evaluation_{selected_model}",
            disabled=submit_disabled
        )

        if submit_button:
            # Prepare comprehensive evaluation data
            evaluation_data = {
                "model": selected_model,
                "overall_score": overall_score,
                "dimension_evaluations": dimension_evaluations,
                "empathy_evaluations": empathy_evaluations,
                "extra_feedback": extra_feedback,
                "status": "complete"
            }

            self.save_model_evaluation(evaluation_data)

            # Mark model as evaluated
            st.session_state.evaluated_models[selected_model] = True

            st.sidebar.success("Evaluation submitted successfully!")

            # Render progress to check for completion
            self.render_evaluation_progress()


    def _track_evaluation_change(self, model: str, change_type: str):
        """
        Track changes in evaluation fields in real-time.
        """
        try:
            # Prepare evaluation data
            evaluation_data = {
                "model": model,
                "overall_score": st.session_state.get(f"performance_slider_{model}", 5),
                "dimension_evaluations": {},
                "status": "in_progress"
            }
            
            # Dimensions to check; must match those rendered in render_evaluation_sidebar
            dimensions = [
                "Accuracy",
                "Comprehensiveness",
                "Helpfulness to the Human Responder",
                "Understanding",
                "Clarity",
                "Language",
                "Harm",
                "Fabrication",
                "Trust"
            ]
            
            # Populate dimension evaluations
            for dimension in dimensions:
                dim_key = dimension.lower().replace(' ', '_')
                evaluation_data["dimension_evaluations"][dimension] = {
                    "score": st.session_state.get(f"{dim_key}_score_{model}", 5),
                    "follow_up_reason": st.session_state.get(f"follow_up_reason_{dim_key}_{model}", "")
                }
            
            # Save partial evaluation
            self.save_model_evaluation(evaluation_data)
        
        except Exception as e:
            st.error(f"Error tracking evaluation change: {e}")

    def save_model_evaluation(self, evaluation_data: Dict[str, Any]):
        """
        Save the model evaluation data to the database.
        """
        try:
            # Get the current user's evaluator ID
            user_id = self._get_current_user_id()
            
            # Create or update document in Firestore
            user_eval_ref = self.db.collection('model_evaluations').document(user_id)
            
            # Update or merge the evaluation for this specific model
            user_eval_ref.set({
                'evaluations': {
                    evaluation_data['model']: evaluation_data
                }
            }, merge=True)
            
            st.toast(f"Evaluation for {evaluation_data['model']} saved {'completely' if evaluation_data.get('status') == 'complete' else 'partially'}")
        
        except Exception as e:
            st.error(f"Error saving evaluation: {e}")

    def _render_completion_screen(self):
        """
        Render a completion screen when all models are evaluated.
        """
        # Reserve a placeholder for the main content area
        st.empty()
        
        # Display completion message
        st.balloons()
        st.title("🎉 Evaluation Complete!")
        st.markdown("Thank you for your valuable feedback.")
        
        # Reward link (replace with actual reward link)
        st.markdown("### Claim Your Reward")
        st.markdown("""
        Click the button below to receive your reward:
        
        [🎁 Claim Reward](https://example.com/reward)
        """)
        
        # Optional: Log completion event
        self._log_evaluation_completion()

    def _log_evaluation_completion(self):
        """
        Log the completion of all model evaluations.
        """
        try:
            user_id = self._get_current_user_id()
            
            # Log completion timestamp
            completion_log_ref = self.db.collection('evaluation_completions').document(user_id)
            completion_log_ref.set({
                'completed_at': firestore.SERVER_TIMESTAMP,
                'models_evaluated': list(self.models_to_evaluate)
            })
        
        except Exception as e:
            st.error(f"Error logging evaluation completion: {e}")

def main():
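    """Entry point: authenticate the session, initialize state, and render the chat UI."""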
    try:
        authenticate()
        init()
        
        # Initialize evaluation system
        # evaluation_system = ModelEvaluationSystem(db)
        
        st.title("Chat with AI Models")
        # Sidebar configuration
        with st.sidebar:
            st.header("Settings")
    
            # Function to call reset_conversation when the model selection changes
            def on_model_change():
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
    
            selected_model = st.selectbox(
                "Select Model",
                options=list(MODEL_CONFIGS.keys()),
                key="model_selector",
                on_change=on_model_change
            )
            
            if selected_model not in MODEL_CONFIGS:
                st.error("Invalid model selected")
                return
                
            st.session_state.selected_model = selected_model
    
            if st.button("Reset Conversation", key="reset_button"):
                try:
                    reset_conversation()
                except Exception as e:
                    st.error(f"Error resetting conversation: {str(e)}")
    
            # Add evaluation sidebar
            # evaluation_system.render_evaluation_sidebar(selected_model)
    
            with st.expander("Instructions"):
                st.write("""
                **How to Use the Chatbot Interface:**
                1. **Choose the assigned model**: Choose the model to chat with that was assigned in the Qualtrics survey.
                2. **Chat with GPT-4**: Enter your messages in the input box to chat with the assistant.
                3. **Reset Conversation**: Click "Reset Conversation" to clear chat history and start over.
                """)
        
        chat_container = st.container()
        
        with chat_container:
            if not st.session_state.chat_active:
                st.session_state.chat_active = True
                
            # Display the conversation history for the selected model, grouped into turns
            if selected_model in st.session_state.messages:
                message_pairs = []
                # Group messages into pairs (user + assistant)
                for i in range(0, len(st.session_state.messages[selected_model]), 2):
                    if i + 1 < len(st.session_state.messages[selected_model]):
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            st.session_state.messages[selected_model][i + 1]
                        ))
                    else:
                        message_pairs.append((
                            st.session_state.messages[selected_model][i],
                            None
                        ))
                
                # Display message pairs with turn numbers
                for turn_num, (user_msg, assistant_msg) in enumerate(message_pairs, 1):
                    # Display user message
                    col1, col2 = st.columns([0.9, 0.1])
                    with col1:
                        with st.chat_message(user_msg["role"]):
                            st.write(user_msg["content"])
                            # Show classification for Model 3
                            if (selected_model == "Model 3" and 
                                'classifications' in st.session_state):
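                                # User messages sit at even indices (0, 2, ...),
                                # matching the keys stored during classification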
                                idx = (turn_num - 1) * 2
                                if idx in st.session_state.classifications:
                                    classification = "Emotional" if st.session_state.classifications[idx] == "1" else "Informational"
                                    st.caption(f"Message classified as: {classification}")
                    with col2:
                        st.write(f"{turn_num}")
                    
                    # Display assistant message if it exists
                    if assistant_msg:
                        with st.chat_message(assistant_msg["role"]):
                            st.write(assistant_msg["content"])
                            
            st.text_input(
                "Type your message here...",
                key="user_input",
                on_change=process_input
            )
    except Exception as e:
        st.error(f"An unexpected error occurred in the main application: {str(e)}")

if __name__ == "__main__":
    main()