File size: 9,863 Bytes
523a4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import traceback
from datetime import datetime
from pathlib import Path
import os
import random
import string
import tempfile
import re
import io
import PyPDF2
import docx
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_JUSTIFY
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
from ai_config import n_of_questions, openai_api_key, load_model

# Initialize settings
# NOTE: this rebinds the imported ``n_of_questions`` callable to the value it
# returns; from here on the name refers to that value, not the function.
n_of_questions = n_of_questions()
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")

# Pre-bind the model and chain names so they always exist at module level.
# Without this, a failed initialization below leaves them undefined and any
# later reference raises NameError instead of hitting the fallback logic.
llm = None
interview_retrieval_chain = None
report_retrieval_chain = None
combined_retriever = None

# Initialize the model and retrieval chain
try:
    llm = load_model(openai_api_key)
    interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")
except Exception as e:
    # Deliberately broad: any start-up failure switches the module to the
    # basic mode instead of aborting the import.
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")

# Mutable interview state shared by the functions below.
question_count = 0  # Number of user messages handled in the current interview
interview_history = []  # "Q<n>: <message>" entries for the current interview
last_audio_path = None  # Variable to store the path of the last audio file
initial_audio_path = None  # Variable to store the path of the initial audio file
language = None  # Interview language, taken from the first user message

def generate_random_string(length=5):
    """Return ``length`` random characters drawn from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)

def _extract_message_text(message):
    """Return the text of the latest user message.

    ``message`` may be a plain value or a list of turns. For a list, the last
    turn is used, or its first element when that turn is itself a list. Empty
    lists yield an empty string instead of raising ``IndexError``.
    """
    if not isinstance(message, list):
        return message
    if not message:
        return ""
    last_turn = message[-1]
    if isinstance(last_turn, list):
        return last_turn[0] if last_turn else ""
    return last_turn


def _synthesize_question_audio(question, file_name, question_number):
    """Write ``question`` as speech to ``file_name`` next to this module.

    The previously generated clip (``last_audio_path``) is deleted once the
    new one exists, and ``last_audio_path`` is updated to the new file.
    Returns the path of the new audio file.
    """
    global last_audio_path

    speech_file_path = Path(__file__).parent / file_name
    convert_text_to_speech(question, speech_file_path)
    print(f"Question {question_number} saved as audio at {speech_file_path}")

    # Remove the last audio file if it exists
    if last_audio_path and os.path.exists(last_audio_path):
        os.remove(last_audio_path)
    last_audio_path = speech_file_path
    return speech_file_path


def respond(message, history):
    """Record the user's message and produce the interviewer's next turn.

    The global question counter is incremented and the message is appended to
    ``interview_history``. The first message of an interview is taken as the
    interview language (and, when the knowledge base is connected, the
    retrieval chains are rebuilt for that language).

    While fewer than ``n_of_questions`` messages have been handled, the reply
    is the next question: from the interview chain when the knowledge base is
    connected, otherwise a canned prompt. That question is also synthesized to
    an MP3. From then on no audio is produced; in knowledge-base mode the
    reply is the report returned by ``generate_report``.

    Args:
        message: The user's message, or a list of turns whose last entry
            holds it.
        history: List of ``[user, bot]`` pairs. A non-list is replaced by a
            fresh list; the bot slot of the last pair is overwritten with
            the reply.

    Returns:
        A ``(history, audio_path)`` tuple where ``audio_path`` is the audio
        file path as a string, or ``None`` when no audio was generated or an
        error occurred.
    """
    global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain

    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])

    # Extract the actual message text
    message = _extract_message_text(message)

    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)

    try:
        if question_count == 1:
            # Capture the language from the first response. This is done in
            # both modes so the fallback prompt does not read "(in None)".
            language = message.strip().lower()
            if knowledge_base_connected:
                # Reinitialize the interview chain with the new language
                interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
                    llm, language)

        is_report_turn = question_count >= n_of_questions
        speech_file_path = None

        if knowledge_base_connected:
            if not is_report_turn:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                question = generate_report(report_retrieval_chain, interview_history, language)
        else:
            # Fallback mode without knowledge base
            question = f"Can you elaborate on that? (in {language})"

        # Audio is only generated for interview questions. The report turn is
        # skipped, which is what the original "skip audio" assignment intended
        # before it was overwritten by the audio branch.
        if question and not is_report_turn:
            if knowledge_base_connected:
                file_name = f"question_{question_count}_{generate_random_string()}.mp3"
            else:
                file_name = f"question_{question_count}.mp3"
            speech_file_path = _synthesize_question_audio(question, file_name, question_count)

        history[-1][1] = f"{question}"

        # Remove the initial question audio file after the first user response
        if initial_audio_path and os.path.exists(initial_audio_path):
            os.remove(initial_audio_path)
        initial_audio_path = None

        return history, (str(speech_file_path) if speech_file_path else None)

    except Exception as e:
        # Top-level boundary for the chat callback: log and return the
        # history unchanged rather than crash the UI.
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None


def reset_interview():
    """Clear the interview counters and history and drop the last audio clip.

    The file referenced by ``last_audio_path`` is deleted when it exists;
    ``initial_audio_path`` is only cleared, its file is not touched.
    """
    global question_count, interview_history, last_audio_path, initial_audio_path

    question_count, interview_history = 0, []

    stale_clip = last_audio_path
    if stale_clip and os.path.exists(stale_clip):
        os.remove(stale_clip)

    last_audio_path = initial_audio_path = None


def _extract_document_text(name, data):
    """Extract text from raw PDF or DOCX bytes.

    ``name`` is the lower-cased file name; a ``.pdf`` suffix selects the PDF
    reader, anything else is parsed as DOCX.
    """
    stream = io.BytesIO(data)
    if name.endswith('.pdf'):
        pdf_reader = PyPDF2.PdfReader(stream)
        # ``or ""`` guards the join in case a page yields no text object.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    doc = docx.Document(stream)
    return "\n".join(paragraph.text for paragraph in doc.paragraphs)


def read_file(file):
    """Return the text content of an uploaded file.

    Args:
        file: ``None``, a filesystem path string, or an object exposing
            ``name`` and ``content`` attributes.

    Returns:
        The extracted text, or one of the sentinel strings
        ``"No file uploaded"``, ``"Unsupported file format"`` or
        ``"Unable to read file"``.

    Path strings ending in ``.pdf`` or ``.docx`` are parsed as such; any
    other path is read as UTF-8 text. For objects, ``.txt``, ``.pdf`` and
    ``.docx`` names are supported. Extensions are matched case-insensitively.
    """
    if file is None:
        return "No file uploaded"

    if isinstance(file, str):
        lowered = file.lower()
        if lowered.endswith(('.pdf', '.docx')):
            # Binary formats must not be decoded as text.
            with open(file, 'rb') as f:
                return _extract_document_text(lowered, f.read())
        with open(file, 'r', encoding='utf-8') as f:
            return f.read()

    if hasattr(file, 'name'):  # Check if it's a file-like object
        lowered = file.name.lower()
        if lowered.endswith('.txt'):
            content = file.content
            # The PDF/DOCX branches treat ``content`` as bytes, so decode it
            # here to always hand text back to the caller.
            if isinstance(content, (bytes, bytearray)):
                return content.decode('utf-8')
            return content
        if lowered.endswith(('.pdf', '.docx')):
            return _extract_document_text(lowered, file.content)
        return "Unsupported file format"

    return "Unable to read file"

def generate_report_from_file(file, language):
    """Generate a clinical report (text and PDF) from an uploaded file.

    Args:
        file: Upload accepted by ``read_file``.
        language: Desired report language; blank or ``None`` means "english".

    Returns:
        A ``(report_text, pdf_path)`` tuple. When the file cannot be read or
        any step fails, the first element is the explanatory message and
        ``pdf_path`` is ``None``, so callers can always unpack two values.
    """
    try:
        file_content = read_file(file)
        if file_content in ("No file uploaded", "Unsupported file format", "Unable to read file"):
            # Keep the 2-tuple shape used by every other return path.
            return file_content, None

        report_language = language.strip().lower() if language else "english"
        print('preferred language:', report_language)
        print(f"Generating report in language: {report_language}")  # For debugging

        # Reinitialize the report chain with the new language
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        # Boundary for the UI callback: report the failure as text.
        return f"An error occurred while processing the file: {str(e)}", None


def generate_interview_report(interview_history, language):
    """Build a clinical report (text and PDF) from interview entries.

    ``interview_history`` is joined with newlines and passed to a report
    chain created for the requested language ("english" when ``language`` is
    empty). Returns ``(report_text, pdf_path)``; on any failure returns an
    error message and ``None``.
    """
    try:
        if language:
            report_language = language.strip().lower()
        else:
            report_language = "english"
        print('preferred report_language language:', report_language)

        # Only the report chain of the three returned objects is needed here.
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        request = {
            "input": "Please provide a clinical report based on the following interview:",
            "history": "\n".join(interview_history),
            "language": report_language
        }
        response = report_retrieval_chain.invoke(request)

        report_content = response.get("answer", "Unable to generate report due to insufficient information.")
        return report_content, create_pdf(report_content)
    except Exception as e:
        return f"An error occurred while generating the report: {str(e)}", None

def create_pdf(content):
    """Render ``content`` into a temporary PDF file and return its path.

    Each line of ``content`` is split around ``**...**`` spans. Every span
    becomes its own bold paragraph, every remaining piece a justified normal
    paragraph, and a spacer follows each line. The file is created with
    ``delete=False``, so removing it is left to the caller.

    Args:
        content: Report text, optionally using ``**bold**`` markers.

    Returns:
        Path of the generated PDF file as a string.
    """
    # Only a unique path is needed. Close the handle right away so it is not
    # leaked and the file is not held open while the PDF is written.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
        pdf_path = temp_file.name

    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()

    # Create a custom style for bold text
    bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)

    # Create a custom style for normal text with justification
    normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)

    # The capturing group keeps the **...** spans in the split result.
    bold_pattern = re.compile(r'(\*\*.*?\*\*)')

    flowables = []

    for line in content.split('\n'):
        for part in bold_pattern.split(line):
            if part.startswith('**') and part.endswith('**'):
                # Bold text: drop the surrounding asterisks. str.strip takes
                # a character set, so this removes every leading/trailing '*'.
                flowables.append(Paragraph(part.strip('*'), bold_style))
            else:
                # Normal text
                flowables.append(Paragraph(part, normal_style))

        flowables.append(Spacer(1, 12))  # Add space between paragraphs

    doc.build(flowables)
    return pdf_path