Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import os | |
from pathlib import Path | |
import yaml | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.messages import HumanMessage, SystemMessage | |
from langchain_core.runnables import RunnableSequence | |
from langgraph.prebuilt import ValidationNode | |
from config.settings import settings | |
from forms.schemas import ExtractedNotes, SOAPNote, DAPNote, BIRPNote, PIRPNote, GIRPNote, SIRPNote, FAIRFDARPNote, DARENote, PIENote, SOAPIERNote, SOAPIENote, POMRNote, NarrativeNote, CBENote, SBARNote | |
from utils.youtube import download_transcript | |
from utils.text_processing import chunk_text | |
from models.llm_provider import get_llm | |
from langchain.globals import set_llm_cache
from langchain.cache import SQLiteCache

# Cache LLM responses in a local SQLite file so repeated runs with identical
# prompts do not re-query the model.
# NOTE(review): the cache presumably stores the prompt text, i.e. the full
# session transcript, in plaintext — confirm this is acceptable for the data.
set_llm_cache(SQLiteCache(database_path=".langchain.db"))

from dotenv import load_dotenv

# Populate os.environ from a .env file, if present.
# NOTE(review): ``config.settings`` is imported before this call; if
# ``settings`` reads environment variables at import time, values that only
# exist in .env will not be visible to it — confirm, and consider loading the
# .env before that import.
load_dotenv()

# Set environment for LangSmith tracing/logging.
# Tracing defaults to on, but an explicit value already in the environment
# (e.g. LANGCHAIN_TRACING_V2=false from the shell or .env) is respected so
# operators can keep transcripts out of the tracing backend.
os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")
if settings.LANGCHAIN_API_KEY:
    os.environ["LANGCHAIN_API_KEY"] = settings.LANGCHAIN_API_KEY
def load_prompt(
    note_type: str,
    prompt_path: str | Path = "langhub/prompts/therapy_extraction_prompt.yaml",
) -> tuple[str, str]:
    """Load the system and human prompt templates for a note type from YAML.

    The file is read as::

        prompts:
          <note_type>:
            system: "..."
            human: "..."

    Args:
        note_type: Note format key, matched case-insensitively (``"SOAP"`` and
            ``"soap"`` select the same entry).
        prompt_path: YAML file to read. Relative paths are resolved against the
            current working directory. Defaults to the path this module has
            always used.

    Returns:
        A ``(system_prompt, human_prompt)`` tuple of template strings.

    Raises:
        FileNotFoundError: If ``prompt_path`` does not exist.
        ValueError: If the file has no (non-empty) entry for ``note_type``.
        KeyError: If the entry lacks a ``system`` or ``human`` key.
    """
    # Explicit UTF-8: prompt text may contain non-ASCII punctuation, and the
    # platform default encoding (e.g. cp1252 on Windows) would garble it.
    with open(Path(prompt_path), "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # An empty file loads as None, and ``prompts:`` with no value loads as None
    # too; treat both as "no templates" so we raise the ValueError below rather
    # than an AttributeError.
    if not isinstance(data, dict):
        data = {}
    note_prompts = (data.get("prompts") or {}).get(note_type.lower())
    if not note_prompts:
        raise ValueError(f"No prompt template found for note type: {note_type}")
    return note_prompts["system"], note_prompts["human"]
def create_extraction_chain(note_type: str = "soap") -> RunnableSequence:
    """Build a ``prompt | structured LLM`` chain for one clinical note format.

    Args:
        note_type: Note format key, matched case-insensitively. It must be a
            key of the schema map below and must also have an entry in the
            prompt YAML read by ``load_prompt``.

    Returns:
        ``prompt_template | structured_llm``. The LLM is wrapped with
        ``include_raw=True``, so invoking the chain returns a dict with
        ``"raw"``, ``"parsed"`` and ``"parsing_error"`` keys, not a bare schema
        instance.

    Raises:
        ValueError: If ``note_type`` has no schema or no prompt template.
    """
    print(f"Creating extraction chain for {note_type.upper()} notes...")

    print("Setting up schema mapping...")
    # Schema the structured output must conform to, per note format.
    # "birp_raw" reuses the BIRP schema; presumably it differs only in prompt.
    schema_map = {
        "soap": SOAPNote,
        "dap": DAPNote,
        "birp": BIRPNote,
        "birp_raw": BIRPNote,
        "pirp": PIRPNote,
        "girp": GIRPNote,
        "sirp": SIRPNote,
        "fair_fdarp": FAIRFDARPNote,
        "dare": DARENote,
        "pie": PIENote,
        "soapier": SOAPIERNote,
        "soapie": SOAPIENote,
        "pomr": POMRNote,
        "narrative": NarrativeNote,
        "cbe": CBENote,
        "sbar": SBARNote,
    }
    schema = schema_map.get(note_type.lower())
    if not schema:
        raise ValueError(f"Unsupported note type: {note_type}")

    print("Loading system prompt...")
    # Resolve prompts before initializing the LLM so an invalid note type
    # (or a missing prompt entry) fails fast without paying for model setup.
    system_prompt, human_prompt = load_prompt(note_type)

    print("Initializing LLM...")
    llm = get_llm()

    print("Creating structured LLM output...")
    structured_llm = llm.with_structured_output(schema=schema, include_raw=True)

    print("Creating prompt template...")
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", human_prompt),
    ])

    print("Building extraction chain...")
    chain = prompt_template | structured_llm
    print("Extraction chain created successfully")
    return chain
def process_session(url: str, note_type: str = "soap") -> dict:
    """Download one session transcript and extract a structured note from it.

    Args:
        url: Video URL passed to ``download_transcript``.
        note_type: Note format key forwarded to ``create_extraction_chain``.

    Returns:
        The extracted note as a plain dict (via ``model_dump()``), or ``{}`` if
        any step fails. Failures are reported with a printed message.
    """
    try:
        print(f"Downloading transcript from {url}...")
        transcript = download_transcript(url)

        chain = create_extraction_chain(note_type)

        print("Extracting structured notes...")
        result = chain.invoke({
            "note_type": note_type.upper(),
            "text": transcript,
        })

        # The chain is built with ``include_raw=True``, so ``invoke`` returns a
        # dict {"raw", "parsed", "parsing_error"} rather than the schema model.
        # Calling ``model_dump()`` on that dict would always fail. Unwrap it,
        # while still accepting a bare model in case the chain is built
        # without ``include_raw``.
        if isinstance(result, dict):
            parsing_error = result.get("parsing_error")
            if parsing_error is not None:
                raise ValueError(f"Structured output failed to parse: {parsing_error}")
            parsed = result.get("parsed")
        else:
            parsed = result

        if parsed is None:
            raise ValueError("LLM returned no structured output")
        # A dict-typed schema yields a plain dict; Pydantic schemas yield a model.
        return parsed.model_dump() if hasattr(parsed, "model_dump") else dict(parsed)
    except Exception as e:
        # Best-effort by design: one failed session must not abort the batch.
        print(f"Error processing session: {str(e)}")
        return {}
def main():
    """Run the extraction demo over a fixed list of YouTube therapy sessions.

    For each session, extracts SOAP, DAP and BIRP notes via ``process_session``
    and prints each result as YAML. A failed extraction prints as ``{}``,
    because ``process_session`` returns an empty dict on error.
    """
    # Example YouTube sessions
    sessions = [
        {
            "title": "CBT Role-Play — Complete Session — Part 6",
            "url": "https://www.youtube.com/watch?v=KuHLL2AE-SE",
        },
        {
            "title": "CBT Role-Play — Complete Session — Part 7",
            "url": "https://www.youtube.com/watch?v=jS1KE3_Pqlc",
        },
    ]
    # Note formats extracted for every session (loop-invariant, built once).
    note_types = ["soap", "dap", "birp"]

    for session in sessions:
        print(f"\nProcessing session: {session['title']}")

        results = {}
        for note_type in note_types:
            print(f"\nExtracting {note_type.upper()} notes...")
            # NOTE: process_session downloads the transcript itself, so each
            # note type currently fetches the same URL again.
            results[note_type] = process_session(session["url"], note_type)

        print(f"\nResults for '{session['title']}':")
        for note_type, notes in results.items():
            print(f"\n{note_type.upper()} Notes:")
            # sort_keys=False keeps the key order from model_dump() (presumably
            # the schema's section order) instead of alphabetizing it;
            # allow_unicode prints non-ASCII text as-is rather than escaped.
            print(yaml.dump(notes, default_flow_style=False, sort_keys=False, allow_unicode=True))


if __name__ == "__main__":
    main()