# TherapyNote / main.py
# Uploaded by abagherp using huggingface_hub (commit 6830eb0, verified).
from __future__ import annotations
import os
from pathlib import Path
import yaml
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableSequence
from langgraph.prebuilt import ValidationNode
from config.settings import settings
from forms.schemas import ExtractedNotes, SOAPNote, DAPNote, BIRPNote, PIRPNote, GIRPNote, SIRPNote, FAIRFDARPNote, DARENote, PIENote, SOAPIERNote, SOAPIENote, POMRNote, NarrativeNote, CBENote, SBARNote
from utils.youtube import download_transcript
from utils.text_processing import chunk_text
from models.llm_provider import get_llm
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache
# Cache LLM responses in a local SQLite file so identical prompts are not
# re-sent to the model on repeated runs.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

from dotenv import load_dotenv

# Load variables from a local .env file into os.environ.
# NOTE(review): `settings` is imported above, before this runs; if it reads the
# environment at import time it will not see values from .env — confirm.
load_dotenv()

# Set environment for LangSmith tracing/logging.
# Enable tracing by default, but respect an explicit LANGCHAIN_TRACING_V2 value
# already in the environment (e.g. one just loaded from .env) instead of
# unconditionally overwriting it.
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
if settings.LANGCHAIN_API_KEY:
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY
def load_prompt(note_type: str) -> tuple[str, str]:
"""Load the prompt template from YAML for the specified note type."""
prompt_path = Path("langhub/prompts/therapy_extraction_prompt.yaml")
with open(prompt_path, "r") as f:
data = yaml.safe_load(f)
note_prompts = data.get("prompts", {}).get(note_type.lower())
if not note_prompts:
raise ValueError(f"No prompt template found for note type: {note_type}")
return note_prompts["system"], note_prompts["human"]
# Lower-case note type -> Pydantic schema the model's output is parsed into.
# Defined once at module level instead of being rebuilt on every call.
_NOTE_SCHEMAS = {
    "soap": SOAPNote,
    "dap": DAPNote,
    "birp": BIRPNote,
    # Shares the BIRP schema; presumably selects a different prompt entry in
    # the YAML (load_prompt is keyed by this same name) — confirm.
    "birp_raw": BIRPNote,
    "pirp": PIRPNote,
    "girp": GIRPNote,
    "sirp": SIRPNote,
    "fair_fdarp": FAIRFDARPNote,
    "dare": DARENote,
    "pie": PIENote,
    "soapier": SOAPIERNote,
    "soapie": SOAPIENote,
    "pomr": POMRNote,
    "narrative": NarrativeNote,
    "cbe": CBENote,
    "sbar": SBARNote,
}


def create_extraction_chain(note_type: str = "soap") -> RunnableSequence:
    """Create a prompt -> LLM chain that extracts a structured therapy note.

    Args:
        note_type: Note format (case-insensitive). It must be a key of
            ``_NOTE_SCHEMAS`` and have a matching entry in the prompt YAML.

    Returns:
        ``prompt_template | structured_llm``. The callers in this file invoke
        it with ``note_type`` and ``text`` keys; the variables it actually
        requires are whatever placeholders the YAML templates use. Because the
        LLM is wrapped with ``include_raw=True``, invoking the chain yields a
        dict with ``"raw"``, ``"parsed"`` and ``"parsing_error"`` keys, not a
        bare schema instance.

    Raises:
        ValueError: If ``note_type`` is unsupported or has no prompt template.
    """
    print(f"Creating extraction chain for {note_type.upper()} notes...")

    # Resolve the schema and prompts first, so an invalid note type fails fast
    # before any LLM client is constructed.
    print("Setting up schema mapping...")
    schema = _NOTE_SCHEMAS.get(note_type.lower())
    if schema is None:
        raise ValueError(f"Unsupported note type: {note_type}")

    print("Loading system prompt...")
    system_prompt, human_prompt = load_prompt(note_type)

    print("Creating prompt template...")
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", human_prompt),
    ])

    print("Initializing LLM...")
    llm = get_llm()

    print("Creating structured LLM output...")
    # include_raw=True keeps the raw model message next to the parsed object
    # (and any parsing error) instead of raising on a parse failure.
    structured_llm = llm.with_structured_output(schema=schema, include_raw=True)

    print("Building extraction chain...")
    chain = prompt_template | structured_llm
    print("Extraction chain created successfully")
    return chain
def process_session(url: str, note_type: str = "soap") -> dict:
    """Download one session's transcript and extract a structured note from it.

    This function is deliberately best-effort: any failure is printed and
    ``{}`` is returned, so a batch run can continue with the next item.

    Args:
        url: Video URL handed to ``download_transcript``.
        note_type: Note format to extract (see ``create_extraction_chain``).

    Returns:
        The extracted note as a plain dict, or ``{}`` on any failure.
    """
    try:
        print(f"Downloading transcript from {url}...")
        transcript = download_transcript(url)

        chain = create_extraction_chain(note_type)

        print("Extracting structured notes...")
        result = chain.invoke({
            "note_type": note_type.upper(),
            "text": transcript,
        })

        # A chain built with with_structured_output(..., include_raw=True)
        # returns {"raw", "parsed", "parsing_error"} rather than the model
        # itself. Calling .model_dump() on that dict raised AttributeError, so
        # every session silently came back as {}. Unwrap the parsed object.
        if isinstance(result, dict) and "parsed" in result:
            parsing_error = result.get("parsing_error")
            if parsing_error is not None:
                raise ValueError(
                    f"Failed to parse {note_type.upper()} notes: {parsing_error}"
                )
            result = result["parsed"]

        if result is None:
            raise ValueError("Model returned no structured output")
        if isinstance(result, dict):
            return result
        return result.model_dump()
    except Exception as e:
        # Best-effort boundary: report and fall back to an empty result.
        print(f"Error processing session: {str(e)}")
        return {}
def main():
    """Extract SOAP, DAP and BIRP notes for two sample sessions and print each as YAML."""
    # Example YouTube sessions
    sessions = [
        {
            "title": "CBT Role-Play – Complete Session – Part 6",
            "url": "https://www.youtube.com/watch?v=KuHLL2AE-SE"
        },
        {
            "title": "CBT Role-Play – Complete Session – Part 7",
            "url": "https://www.youtube.com/watch?v=jS1KE3_Pqlc"
        }
    ]
    # Note formats extracted for every session; loop-invariant, so defined once.
    note_types = ["soap", "dap", "birp"]

    for session in sessions:
        print(f"\nProcessing session: {session['title']}")

        results = {}
        for note_type in note_types:
            print(f"\nExtracting {note_type.upper()} notes...")
            results[note_type] = process_session(session["url"], note_type)

        print(f"\nResults for '{session['title']}':")
        for note_type, notes in results.items():
            print(f"\n{note_type.upper()} Notes:")
            # sort_keys=False keeps sections in the dict's own order (the note
            # schema's field order) instead of alphabetizing them.
            # allow_unicode=True prints non-ASCII transcript text as-is rather
            # than as escape sequences.
            print(yaml.dump(
                notes,
                default_flow_style=False,
                sort_keys=False,
                allow_unicode=True,
            ))


if __name__ == "__main__":
    main()