reab5555 committed
Commit 07ab0e3 · verified · 1 Parent(s): f67326b

Update settings.py

Files changed (1)
  1. settings.py +231 -232
settings.py CHANGED
@@ -1,233 +1,232 @@
import traceback
from datetime import datetime
from pathlib import Path
import os
import random
import string
import tempfile
import re
import io
import PyPDF2
import docx
from reportlab.pdfgen import canvas
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.enums import TA_JUSTIFY
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
- from ai_config import n_of_questions, openai_api_key, load_model

# Initialize settings
n_of_questions = n_of_questions()
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")

# Initialize the model and retrieval chain
try:
    llm = load_model(openai_api_key)
    interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")
except Exception as e:
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")

question_count = 0
interview_history = []
last_audio_path = None # Variable to store the path of the last audio file
initial_audio_path = None # Variable to store the path of the initial audio file
language = None

def generate_random_string(length=5):
    return ''.join(random.choices(string.ascii_letters + string.digits, k=length))

def respond(message, history):
    global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain

    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])

    # Extract the actual message text
    if isinstance(message, list):
        message = message[-1][0] if message and isinstance(message[-1], list) else message[-1]

    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)

    try:
        if knowledge_base_connected:
            if question_count == 1:
                # Capture the language from the first response
                language = message.strip().lower()
                # Reinitialize the interview chain with the new language
                interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
                    llm, language)

            if question_count < n_of_questions:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                result = generate_report(report_retrieval_chain, interview_history, language)
                question = result
                speech_file_path = None # Skip audio generation for the report

            if question:
                random_suffix = generate_random_string()
                speech_file_path = Path(__file__).parent / f"question_{question_count}_{random_suffix}.mp3"
                convert_text_to_speech(question, speech_file_path)
                print(f"Question {question_count} saved as audio at {speech_file_path}")

                # Remove the last audio file if it exists
                if last_audio_path and os.path.exists(last_audio_path):
                    os.remove(last_audio_path)
                last_audio_path = speech_file_path
            else:
                speech_file_path = None # Skip audio generation for the report

        else:
            # Fallback mode without knowledge base
            question = f"Can you elaborate on that? (in {language})"
            if question_count < n_of_questions:
                speech_file_path = Path(__file__).parent / f"question_{question_count}.mp3"
                convert_text_to_speech(question, speech_file_path)
                print(f"Question {question_count} saved as audio at {speech_file_path}")

                if last_audio_path and os.path.exists(last_audio_path):
                    os.remove(last_audio_path)
                last_audio_path = speech_file_path
            else:
                speech_file_path = None

        history[-1][1] = f"{question}"

        # Remove the initial question audio file after the first user response
        if initial_audio_path and os.path.exists(initial_audio_path):
            os.remove(initial_audio_path)
            initial_audio_path = None

        return history, str(speech_file_path) if speech_file_path else None

    except Exception as e:
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None


def reset_interview():
    """Reset the interview state."""
    global question_count, interview_history, last_audio_path, initial_audio_path
    question_count = 0
    interview_history = []
    if last_audio_path and os.path.exists(last_audio_path):
        os.remove(last_audio_path)
    last_audio_path = None
    initial_audio_path = None


def read_file(file):
    if file is None:
        return "No file uploaded"

    if isinstance(file, str):
        with open(file, 'r', encoding='utf-8') as f:
            return f.read()

    if hasattr(file, 'name'): # Check if it's a file-like object
        if file.name.endswith('.txt'):
            return file.content
        elif file.name.endswith('.pdf'):
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(file.content))
            return "\n".join(page.extract_text() for page in pdf_reader.pages)
        elif file.name.endswith('.docx'):
            doc = docx.Document(io.BytesIO(file.content))
            return "\n".join(paragraph.text for paragraph in doc.paragraphs)
        else:
            return "Unsupported file format"

    return "Unable to read file"

def generate_report_from_file(file, language):
    try:
        file_content = read_file(file)
        if file_content == "No file uploaded" or file_content == "Unsupported file format" or file_content == "Unable to read file":
            return file_content

        report_language = language.strip().lower() if language else "english"
        print('preferred language:', report_language)
        print(f"Generating report in language: {report_language}") # For debugging

        # Reinitialize the report chain with the new language
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}", None


def generate_interview_report(interview_history, language):
    try:
        report_language = language.strip().lower() if language else "english"
        print('preferred report_language language:', report_language)
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following interview:",
            "history": "\n".join(interview_history),
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while generating the report: {str(e)}", None

def create_pdf(content):
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.pdf')
    doc = SimpleDocTemplate(temp_file.name, pagesize=letter)
    styles = getSampleStyleSheet()

    # Create a custom style for bold text
    bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)

    # Create a custom style for normal text with justification
    normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)

    flowables = []

    for line in content.split('\n'):
        # Use regex to find words surrounded by **
        parts = re.split(r'(\*\*.*?\*\*)', line)
        paragraph_parts = []

        for part in parts:
            if part.startswith('**') and part.endswith('**'):
                # Bold text
                bold_text = part.strip('**')
                paragraph_parts.append(Paragraph(bold_text, bold_style))
            else:
                # Normal text
                paragraph_parts.append(Paragraph(part, normal_style))

        flowables.extend(paragraph_parts)
        flowables.append(Spacer(1, 12)) # Add space between paragraphs

    doc.build(flowables)
    return temp_file.name
 