Spaces:
Running
Running
# src/utils/conversation_summarizer.py | |
from typing import List, Dict | |
from transformers import pipeline | |
import numpy as np | |
from datetime import datetime | |
from config.config import settings | |
class ConversationSummarizer:
    """Summarize chat conversations with a transformers summarization pipeline.

    Besides the model-generated summary, this reports simple statistics about
    the conversation (per-role message counts, average message length, rough
    topics) and, optionally, metadata (first/last timestamps, duration, sources).

    Messages are plain dicts. The keys read here are ``role``, ``content``,
    ``timestamp`` and ``sources`` (a list of dicts with a ``filename`` key);
    every key is treated as optional.
    """

    def __init__(
        self,
        model_name: str = None,
        max_length: int = None,
        min_length: int = None
    ):
        """
        Initialize the summarizer.

        Args:
            model_name (str, optional): Override default model from config.
            max_length (int, optional): Override default max_length from config.
            min_length (int, optional): Override default min_length from config.
                Only ``None`` falls back to the config value, so an explicit
                ``0`` is honoured.
        """
        config = settings.SUMMARIZER_CONFIG
        # An empty model name is meaningless, so any falsy value falls back.
        self.model_name = model_name or config['model_name']
        # Compare against None: `x or default` would silently replace a
        # legitimate 0 with the config value.
        self.max_length = config['max_length'] if max_length is None else max_length
        self.min_length = config['min_length'] if min_length is None else min_length
        # Build the pipeline once; it is reused for every summary/topic call.
        self.summarizer = pipeline(
            "summarization",
            model=self.model_name,
            device=config['device'],
            model_kwargs=config['model_kwargs']
        )

    async def summarize_conversation(
        self,
        messages: List[Dict],
        include_metadata: bool = True
    ) -> Dict:
        """
        Summarize a conversation and provide key insights.

        Args:
            messages: Conversation messages; the first and last entries are
                used as start/end for the metadata, so chronological order
                is expected.
            include_metadata: When False, ``metadata`` is an empty dict.

        Returns:
            Dict with ``summary`` (str; empty when the conversation has no
            text), ``key_insights`` (see ``_extract_insights``) and
            ``metadata`` (see ``_generate_metadata``).
        """
        # NOTE(review): the pipeline calls below are synchronous, so this
        # coroutine blocks the event loop while the model runs. Consider
        # offloading them (e.g. loop.run_in_executor) once concurrent use of
        # the pipeline is confirmed to be safe.
        summary = self._run_summarizer(
            self._format_conversation(messages),
            max_length=self.max_length,
            min_length=self.min_length
        )
        insights = self._extract_insights(messages)
        metadata = self._generate_metadata(messages) if include_metadata else {}
        return {
            'summary': summary,
            'key_insights': insights,
            'metadata': metadata
        }

    def _run_summarizer(self, text: str, max_length: int, min_length: int) -> str:
        """Run the summarization pipeline on ``text`` and return the summary text.

        Blank input returns ``''`` without invoking the model. ``truncation=True``
        is passed so inputs longer than the model's input limit are truncated
        (the pipeline does not truncate by default).
        """
        if not text.strip():
            return ''
        result = self.summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            truncation=True
        )
        return result[0]['summary_text']

    @staticmethod
    def _message_text(message: Dict) -> str:
        """Return a message's content, treating a missing or None value as ''."""
        return message.get('content') or ''

    def _format_conversation(self, messages: List[Dict]) -> str:
        """Render messages as ``"role: content"`` lines joined by newlines.

        A missing role is rendered as ``unknown``.
        """
        return "\n".join(
            f"{msg.get('role', 'unknown')}: {self._message_text(msg)}"
            for msg in messages
        )

    def _extract_insights(self, messages: List[Dict]) -> Dict:
        """Compute simple statistics and rough topics for the conversation.

        Returns:
            Dict with ``message_distribution`` (counts of 'user' and
            'assistant' messages), ``average_message_length`` (mean content
            length in characters, truncated to int; 0 when there are no
            messages), ``main_topics`` and ``total_messages``.
        """
        roles = [m.get('role') for m in messages]
        message_counts = {
            'user': roles.count('user'),
            'assistant': roles.count('assistant')
        }
        lengths = [len(self._message_text(m)) for m in messages]
        # np.mean([]) is nan and int(nan) raises ValueError, so guard it.
        avg_length = int(np.mean(lengths)) if lengths else 0
        return {
            'message_distribution': message_counts,
            'average_message_length': avg_length,
            'main_topics': self._extract_topics(messages),
            'total_messages': len(messages)
        }

    def _extract_topics(self, messages: List[Dict]) -> List[str]:
        """Approximate main topics from a short model summary.

        The message contents are concatenated, summarized with
        ``max_length=50``/``min_length=10`` and split on ``'. '``; blank
        fragments are dropped. Returns ``[]`` when there is no text.
        """
        full_text = " ".join(self._message_text(m) for m in messages)
        summary = self._run_summarizer(full_text, max_length=50, min_length=10)
        return [part.strip() for part in summary.split('. ') if part.strip()]

    def _generate_metadata(self, messages: List[Dict]) -> Dict:
        """Generate conversation metadata.

        Returns an empty dict for an empty conversation; otherwise the first
        and last ``timestamp`` values (or None), the duration in minutes and
        the unique source filenames.
        """
        if not messages:
            return {}
        return {
            'start_time': messages[0].get('timestamp', None),
            'end_time': messages[-1].get('timestamp', None),
            'duration_minutes': self._calculate_duration(messages),
            'sources_used': self._extract_sources(messages)
        }

    @staticmethod
    def _parse_timestamp(value) -> datetime:
        """Return ``value`` as a datetime, parsing ISO-8601 strings.

        Raises TypeError/ValueError for missing or unparseable values.
        """
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    def _calculate_duration(self, messages: List[Dict]) -> float:
        """Calculate conversation duration in minutes.

        Uses the first and last messages' ``timestamp`` (ISO-8601 string or
        datetime). Returns 0.0 when there are no messages or the timestamps
        are missing, unparseable, or not comparable (e.g. naive vs aware).
        """
        if not messages:
            return 0.0
        try:
            start_time = self._parse_timestamp(messages[0].get('timestamp'))
            end_time = self._parse_timestamp(messages[-1].get('timestamp'))
            return (end_time - start_time).total_seconds() / 60
        except (TypeError, ValueError):
            # Best-effort metadata: a bad timestamp should not fail the summary.
            return 0.0

    def _extract_sources(self, messages: List[Dict]) -> List[str]:
        """Extract unique source filenames used in the conversation.

        Filenames are returned in first-seen order; sources without a
        filename are skipped.
        """
        # A dict keeps insertion order, giving a deterministic de-duplication.
        seen: Dict[str, None] = {}
        for message in messages:
            for source in message.get('sources') or []:
                filename = source.get('filename')
                if filename:
                    seen[filename] = None
        return list(seen)