|
import traceback
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
import os
|
|
import random
|
|
import string
|
|
import tempfile
|
|
import re
|
|
import io
|
|
import PyPDF2
|
|
import docx
|
|
from reportlab.pdfgen import canvas
|
|
from reportlab.lib.pagesizes import letter
|
|
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
|
|
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
|
from reportlab.lib.enums import TA_JUSTIFY
|
|
from ai_config import n_of_questions, load_model, openai_api_key, convert_text_to_speech
|
|
from knowledge_retrieval import setup_knowledge_retrieval, generate_report
|
|
from ai_config import n_of_questions, openai_api_key, load_model
|
|
|
|
|
|
# Resolve the configured number of interview questions once at import time.
# NOTE: this deliberately rebinds the imported ``n_of_questions`` callable to
# its return value; the rest of the module compares it with ``question_count``.
n_of_questions = n_of_questions()

# Timestamp values captured once when the module is imported.
current_datetime = datetime.now()
human_readable_datetime = current_datetime.strftime("%B %d, %Y at %H:%M")
current_date = current_datetime.strftime("%Y-%m-%d")

# Pre-bind everything the setup below is supposed to provide, so that a failed
# setup leaves these names defined (as None) instead of unbound; otherwise any
# later reference would raise NameError rather than a meaningful error.
llm = None
interview_retrieval_chain = None
report_retrieval_chain = None
combined_retriever = None
knowledge_base_connected = False

try:
    llm = load_model(openai_api_key)
    interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(llm)
    knowledge_base_connected = True
    print("Successfully connected to the knowledge base.")
except Exception as e:
    # Best-effort start-up: any failure switches the module to fallback mode.
    print(f"Error initializing the model or retrieval chain: {str(e)}")
    knowledge_base_connected = False
    print("Falling back to basic mode without knowledge base.")

# Mutable interview state shared by the functions below.
question_count = 0          # number of user messages handled so far
interview_history = []      # "Q<n>: <message>" entries, in order
last_audio_path = None      # most recent question audio file, if any
initial_audio_path = None   # audio of the opening question, if any
language = None             # interview language, taken from the first message
|
|
|
|
def generate_random_string(length=5):
    """Return ``length`` random characters drawn from ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return ''.join(picked)
|
|
|
|
def _publish_question_audio(question, speech_file_path):
    """Write ``question`` as speech to ``speech_file_path`` and drop the previous clip.

    Updates the module-level ``last_audio_path`` so that only the most recent
    question audio file is kept on disk.
    """
    global last_audio_path

    convert_text_to_speech(question, speech_file_path)
    print(f"Question {question_count} saved as audio at {speech_file_path}")

    if last_audio_path and os.path.exists(last_audio_path):
        os.remove(last_audio_path)
    last_audio_path = speech_file_path


def respond(message, history):
    """Handle one user message and produce the next question (or the report).

    Args:
        message: The user's message. May be a string, or a list whose last
            element is either the message or a list whose first item is the
            message.
        history: List of ``[user, bot]`` pairs; anything that is not a list is
            replaced by a fresh list. The last pair's bot slot is filled in.

    Returns:
        A ``(history, audio_path)`` tuple where ``audio_path`` is the path of
        the generated speech file as a string, or ``None`` when no audio was
        produced or an error occurred.
    """
    global question_count, interview_history, combined_retriever, last_audio_path, initial_audio_path, language, interview_retrieval_chain, report_retrieval_chain

    if not isinstance(history, list):
        history = []
    if not history or not history[-1]:
        history.append(["", ""])

    # Unwrap list-shaped messages. Empty lists are treated as an empty message
    # instead of raising IndexError before the error handler is in place.
    if isinstance(message, list):
        latest = message[-1] if message else ""
        if isinstance(latest, list):
            latest = latest[0] if latest else ""
        message = latest

    question_count += 1
    interview_history.append(f"Q{question_count}: {message}")
    history_str = "\n".join(interview_history)

    try:
        speech_file_path = None

        if knowledge_base_connected:
            if question_count == 1:
                # The first answer names the interview language; rebuild the
                # retrieval chains for it.
                language = message.strip().lower()
                interview_retrieval_chain, report_retrieval_chain, combined_retriever = setup_knowledge_retrieval(
                    llm, language)

            if question_count < n_of_questions:
                result = interview_retrieval_chain.invoke({
                    "input": f"Based on the patient's statement: '{message}', what should be the next question?",
                    "history": history_str,
                    "question_number": question_count + 1,
                    "language": language
                })
                question = result.get("answer", f"Can you tell me more about that? (in {language})")
            else:
                # Question budget exhausted: the reply is the generated report.
                question = generate_report(report_retrieval_chain, interview_history, language)

            if question:
                random_suffix = generate_random_string()
                speech_file_path = Path(__file__).parent / f"question_{question_count}_{random_suffix}.mp3"
                _publish_question_audio(question, speech_file_path)
        else:
            # Fallback mode without the knowledge base. Capture the language
            # here too, otherwise the canned question reads "(in None)".
            if question_count == 1 and isinstance(message, str):
                language = message.strip().lower()

            question = f"Can you elaborate on that? (in {language})"
            if question_count < n_of_questions:
                speech_file_path = Path(__file__).parent / f"question_{question_count}.mp3"
                _publish_question_audio(question, speech_file_path)

        history[-1][1] = f"{question}"

        # The opening question's audio is no longer needed once the user has
        # answered.
        if initial_audio_path and os.path.exists(initial_audio_path):
            os.remove(initial_audio_path)
            initial_audio_path = None

        return history, str(speech_file_path) if speech_file_path else None

    except Exception as e:
        print(f"Error in retrieval chain: {str(e)}")
        print(traceback.format_exc())
        return history, None
|
|
|
|
|
|
def reset_interview():
    """Clear the interview counters and history and discard the last audio clip."""
    global question_count, interview_history, last_audio_path, initial_audio_path

    question_count, interview_history = 0, []

    # Delete the most recent question audio if it is still on disk.
    stale_clip = last_audio_path
    if stale_clip and os.path.exists(stale_clip):
        os.remove(stale_clip)

    last_audio_path = initial_audio_path = None
|
|
|
|
|
|
def _extract_binary_document_text(extension, payload):
    """Return the text of a PDF (``'.pdf'``) or DOCX (any other value) given as bytes."""
    stream = io.BytesIO(payload)
    if extension == '.pdf':
        pdf_reader = PyPDF2.PdfReader(stream)
        # Guard against pages for which extract_text() yields nothing, so the
        # join never sees a non-string value.
        return "\n".join((page.extract_text() or "") for page in pdf_reader.pages)
    doc = docx.Document(stream)
    return "\n".join(paragraph.text for paragraph in doc.paragraphs)


def read_file(file):
    """Return the text content of an uploaded file.

    Args:
        file: ``None``, a filesystem path (``str``), or an object exposing
            ``name`` and ``content`` attributes.

    Returns:
        The extracted text, or one of the sentinel strings
        ``"No file uploaded"``, ``"Unsupported file format"`` or
        ``"Unable to read file"``.

    Extensions are matched case-insensitively. A string path ending in
    ``.pdf`` or ``.docx`` is parsed as that format; any other path is read as
    UTF-8 text.
    """
    if file is None:
        return "No file uploaded"

    if isinstance(file, str):
        lowered_path = file.lower()
        for extension in ('.pdf', '.docx'):
            if lowered_path.endswith(extension):
                # Binary formats must not be decoded as text.
                with open(file, 'rb') as f:
                    return _extract_binary_document_text(extension, f.read())
        with open(file, 'r', encoding='utf-8') as f:
            return f.read()

    if hasattr(file, 'name'):
        lowered_name = str(file.name).lower()
        if lowered_name.endswith('.txt'):
            content = file.content
            # Hand back text even when the upload carries raw bytes.
            if isinstance(content, (bytes, bytearray)):
                return bytes(content).decode('utf-8', errors='replace')
            return content
        elif lowered_name.endswith('.pdf'):
            return _extract_binary_document_text('.pdf', file.content)
        elif lowered_name.endswith('.docx'):
            return _extract_binary_document_text('.docx', file.content)
        else:
            return "Unsupported file format"

    return "Unable to read file"
|
|
|
|
def generate_report_from_file(file, language):
    """Generate a clinical report (text and PDF) from an uploaded file.

    Args:
        file: Upload handed to ``read_file``.
        language: Desired report language; falls back to ``"english"`` when
            empty or ``None``.

    Returns:
        A ``(report_text, pdf_path)`` tuple. When the file cannot be read or an
        error occurs, the first item is the message and ``pdf_path`` is
        ``None``.
    """
    try:
        file_content = read_file(file)
        if file_content in ("No file uploaded", "Unsupported file format", "Unable to read file"):
            # Return the same two-item shape as every other path so callers
            # can always unpack the result.
            return file_content, None

        report_language = language.strip().lower() if language else "english"
        print('preferred language:', report_language)
        print(f"Generating report in language: {report_language}")

        # Build a report chain for the requested language.
        _, report_retrieval_chain, _ = setup_knowledge_retrieval(llm, report_language)

        result = report_retrieval_chain.invoke({
            "input": "Please provide a clinical report based on the following content:",
            "history": file_content,
            "language": report_language
        })
        report_content = result.get("answer", "Unable to generate report due to insufficient information.")
        pdf_path = create_pdf(report_content)
        return report_content, pdf_path
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}", None
|
|
|
|
|
|
def generate_interview_report(interview_history, language):
    """Produce a clinical report (text and PDF path) from interview entries.

    ``language`` defaults to ``"english"`` when empty. Returns a
    ``(report_text, pdf_path)`` tuple; on any error the first item is an error
    message and the second is ``None``.
    """
    try:
        if language:
            report_language = language.strip().lower()
        else:
            report_language = "english"
        print('preferred report_language language:', report_language)

        chains = setup_knowledge_retrieval(llm, report_language)
        _, report_retrieval_chain, _ = chains

        request = {
            "input": "Please provide a clinical report based on the following interview:",
            "history": "\n".join(interview_history),
            "language": report_language,
        }
        response = report_retrieval_chain.invoke(request)

        report_content = response.get("answer", "Unable to generate report due to insufficient information.")
        return report_content, create_pdf(report_content)
    except Exception as e:
        return f"An error occurred while generating the report: {str(e)}", None
|
|
|
|
def create_pdf(content):
    """Render ``content`` to a temporary PDF file and return its path.

    Each input line is split on ``**bold**`` markers; bold segments and plain
    segments become separate paragraphs, followed by a spacer per line.
    ``None`` content is treated as empty text. The file is created with
    ``delete=False``, so the caller is responsible for removing it.
    """
    from xml.sax.saxutils import escape  # local import: keeps the block self-contained

    text = "" if content is None else str(content)

    # Only the name is needed; close the handle right away so it is not
    # leaked and does not block the writer on platforms that lock open files.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_file:
        pdf_path = temp_file.name

    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()

    bold_style = ParagraphStyle('Bold', parent=styles['Normal'], fontName='Helvetica-Bold', fontSize=10)
    normal_style = ParagraphStyle('Normal', parent=styles['Normal'], alignment=TA_JUSTIFY)

    flowables = []
    for line in text.split('\n'):
        for part in re.split(r'(\*\*.*?\*\*)', line):
            if part.startswith('**') and part.endswith('**'):
                # Slice off exactly the two-character markers; str.strip('**')
                # would also eat asterisks that belong to the text itself.
                segment, style = part[2:-2], bold_style
            else:
                segment, style = part, normal_style
            # Paragraph parses its text as markup, so escape &, < and >.
            flowables.append(Paragraph(escape(segment), style))
        flowables.append(Spacer(1, 12))

    doc.build(flowables)
    return pdf_path