File size: 5,043 Bytes
6830eb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations

import os
from functools import lru_cache
from pathlib import Path

import yaml
from langchain.cache import SQLiteCache
from langchain.globals import set_llm_cache
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSequence
from langgraph.prebuilt import ValidationNode

from config.settings import settings
from forms.schemas import (
    ExtractedNotes,
    SOAPNote,
    DAPNote,
    BIRPNote,
    PIRPNote,
    GIRPNote,
    SIRPNote,
    FAIRFDARPNote,
    DARENote,
    PIENote,
    SOAPIERNote,
    SOAPIENote,
    POMRNote,
    NarrativeNote,
    CBENote,
    SBARNote,
)
from models.llm_provider import get_llm
from utils.text_processing import chunk_text
from utils.youtube import download_transcript

# Install a process-wide LangChain LLM cache backed by the SQLite file
# ".langchain.db" (relative to the current working directory).
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

from dotenv import load_dotenv

load_dotenv()

# Set environment for LangSmith tracing/logging.
# setdefault (instead of plain assignment) keeps any LANGCHAIN_TRACING_V2
# value already present in the environment or in the .env file loaded above.
# Transcripts are sent to the tracing backend when this is on, so an explicit
# opt-out such as LANGCHAIN_TRACING_V2=false must not be silently overridden.
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
if settings.LANGCHAIN_API_KEY:
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY

# Default location of the note-type prompt templates (relative to the CWD).
_DEFAULT_PROMPT_PATH = Path("langhub/prompts/therapy_extraction_prompt.yaml")


def load_prompt(
    note_type: str,
    prompt_path: str | Path = _DEFAULT_PROMPT_PATH,
) -> tuple[str, str]:
    """Load the system and human prompt templates for a note type from YAML.

    The file must contain a top-level ``prompts`` mapping keyed by lower-case
    note type. Each entry is a mapping with ``system`` and ``human`` strings.

    Args:
        note_type: Note format name (case-insensitive), e.g. ``"soap"``.
        prompt_path: YAML file to read. Defaults to the therapy extraction
            prompt file; relative paths resolve against the working directory.

    Returns:
        A ``(system_prompt, human_prompt)`` tuple.

    Raises:
        FileNotFoundError: If ``prompt_path`` does not exist.
        ValueError: If the file has no template for ``note_type``, or the
            template lacks a ``system`` or ``human`` entry.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(prompt_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat it as no prompts.
        data = yaml.safe_load(f) or {}

    # Guard each level: a null or non-mapping value must not crash with
    # AttributeError before the clear ValueError below can be raised.
    prompts = data.get("prompts") if isinstance(data, dict) else None
    note_prompts = (
        prompts.get(note_type.lower()) if isinstance(prompts, dict) else None
    )
    if not note_prompts or not isinstance(note_prompts, dict):
        raise ValueError(f"No prompt template found for note type: {note_type}")

    missing = [key for key in ("system", "human") if key not in note_prompts]
    if missing:
        raise ValueError(
            f"Prompt template for note type {note_type!r} is missing: "
            f"{', '.join(missing)}"
        )

    return note_prompts["system"], note_prompts["human"]

# Structured-output schema for each supported note type (lower-case keys).
# Built once at import time rather than on every chain construction.
_SCHEMA_MAP = {
    "soap": SOAPNote,
    "dap": DAPNote,
    "birp": BIRPNote,
    "birp_raw": BIRPNote,
    "pirp": PIRPNote,
    "girp": GIRPNote,
    "sirp": SIRPNote,
    "fair_fdarp": FAIRFDARPNote,
    "dare": DARENote,
    "pie": PIENote,
    "soapier": SOAPIERNote,
    "soapie": SOAPIENote,
    "pomr": POMRNote,
    "narrative": NarrativeNote,
    "cbe": CBENote,
    "sbar": SBARNote,
}


def create_extraction_chain(note_type: str = "soap") -> RunnableSequence:
    """Create a prompt -> structured-LLM chain for one note format.

    Args:
        note_type: Note format name (case-insensitive). It must be a key of
            ``_SCHEMA_MAP`` and have a template in the prompt YAML.

    Returns:
        A runnable whose input supplies the variables used by the YAML prompts
        (``process_session`` passes ``note_type`` and ``text``). The LLM is
        wrapped with ``with_structured_output(..., include_raw=True)``, so the
        chain yields a dict with ``raw``, ``parsed`` and ``parsing_error`` keys,
        not the schema model itself.

    Raises:
        ValueError: If the note type is unsupported or has no prompt template.
    """
    print(f"Creating extraction chain for {note_type.upper()} notes...")

    # Validate the note type and load its prompts before calling get_llm(),
    # so bad input is reported without paying for LLM initialisation.
    print("Setting up schema mapping...")
    schema = _SCHEMA_MAP.get(note_type.lower())
    if schema is None:
        raise ValueError(f"Unsupported note type: {note_type}")

    print("Loading system prompt...")
    system_prompt, human_prompt = load_prompt(note_type)

    print("Initializing LLM...")
    llm = get_llm()

    print("Creating structured LLM output...")
    # include_raw=True keeps the raw message and any parsing error alongside
    # the parsed model in the chain's output dict.
    structured_llm = llm.with_structured_output(schema=schema, include_raw=True)

    print("Creating prompt template...")
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", human_prompt)
    ])

    print("Building extraction chain...")
    chain = prompt_template | structured_llm

    print("Extraction chain created successfully")
    return chain

@lru_cache(maxsize=32)
def _fetch_transcript(url: str):
    """Return the transcript for *url*, keeping recent results in memory.

    ``main`` extracts several note formats from the same video. Without this
    cache, the identical transcript would be downloaded once per format.
    A download that raises is not cached, so a later call retries it.
    """
    return download_transcript(url)


def process_session(url: str, note_type: str = "soap") -> dict:
    """Download a session transcript and extract one structured note from it.

    This helper is deliberately best-effort. Any failure (download, chain
    construction, LLM call, or output parsing) is printed and an empty dict
    is returned, so a batch run can move on to the next item.

    Args:
        url: Video URL passed to ``download_transcript``.
        note_type: Note format name (case-insensitive), e.g. ``"soap"``.

    Returns:
        The extracted note as a plain dict (from ``model_dump()``), or ``{}``
        on any error.
    """
    try:
        # Download transcript (cached per URL)
        print(f"Downloading transcript from {url}...")
        transcript = _fetch_transcript(url)

        # Create extraction chain
        chain = create_extraction_chain(note_type)

        # Process transcript
        print("Extracting structured notes...")
        result = chain.invoke({
            "note_type": note_type.upper(),
            "text": transcript
        })

        # with_structured_output(include_raw=True) returns a dict
        # {"raw", "parsed", "parsing_error"}, not the model. A dict has no
        # .model_dump(), so unwrap "parsed" first. A bare model is still
        # accepted in case the chain is built without include_raw.
        if isinstance(result, dict) and "parsed" in result:
            if result.get("parsing_error") is not None:
                raise ValueError(
                    f"Could not parse structured output: {result['parsing_error']}"
                )
            result = result["parsed"]
        if result is None:
            raise ValueError("LLM returned no structured output")

        return result.model_dump()

    except Exception as e:
        # Top-level best-effort boundary: report and let the caller continue.
        print(f"Error processing session: {str(e)}")
        return {}

def main() -> None:
    """Demo: extract several note formats for example sessions and print them.

    For every example YouTube session, one note is extracted per format in
    ``note_types`` via ``process_session``. The results are printed as YAML.
    """
    # Example YouTube sessions
    sessions = [
        {
            "title": "CBT Role-Play – Complete Session – Part 6",
            "url": "https://www.youtube.com/watch?v=KuHLL2AE-SE"
        },
        {
            "title": "CBT Role-Play – Complete Session – Part 7",
            "url": "https://www.youtube.com/watch?v=jS1KE3_Pqlc"
        }
    ]

    # Note formats to extract for every session (the same for all sessions).
    note_types = ["soap", "dap", "birp"]

    for session in sessions:
        print(f"\nProcessing session: {session['title']}")

        results = {}
        for note_type in note_types:
            print(f"\nExtracting {note_type.upper()} notes...")
            results[note_type] = process_session(session["url"], note_type)

        # Print results
        print(f"\nResults for '{session['title']}':")
        for note_type, notes in results.items():
            print(f"\n{note_type.upper()} Notes:")
            # sort_keys=False keeps keys in the order model_dump() produced
            # them (e.g. S, O, A, P) instead of alphabetising them;
            # allow_unicode=True prints non-ASCII text instead of escaping it.
            print(yaml.dump(
                notes,
                default_flow_style=False,
                sort_keys=False,
                allow_unicode=True,
            ))

# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()