# Source: chatbot-backend/src/utils/conversation_summarizer.py
# Last commit 82b8aa2 (TalatMasood): "changed the bart model from large to base"
# src/utils/conversation_summarizer.py
from typing import List, Dict
from transformers import pipeline
import numpy as np
from datetime import datetime
from config.config import settings
class ConversationSummarizer:
    """Summarize chat conversations and derive simple statistics from them.

    Wraps a ``transformers`` "summarization" pipeline. Messages are plain
    dicts; the keys read by this class are ``role``, ``content``,
    ``timestamp`` and ``sources`` (a list of dicts with a ``filename`` key).
    All of them are optional.
    """

    def __init__(
        self,
        model_name: str = None,
        max_length: int = None,
        min_length: int = None
    ):
        """
        Initialize the summarizer

        Any argument that is ``None`` (or otherwise falsy, e.g. ``0`` or
        ``""``) falls back to the value in ``settings.SUMMARIZER_CONFIG``.

        Args:
            model_name (str, optional): Override default model from config
            max_length (int, optional): Override default max_length from config
            min_length (int, optional): Override default min_length from config
        """
        summarizer_config = settings.SUMMARIZER_CONFIG

        # Use provided values or fall back to config values
        self.model_name = model_name or summarizer_config['model_name']
        self.max_length = max_length or summarizer_config['max_length']
        self.min_length = min_length or summarizer_config['min_length']

        # Initialize the summarizer with config settings
        self.summarizer = pipeline(
            "summarization",
            model=self.model_name,
            device=summarizer_config['device'],
            model_kwargs=summarizer_config['model_kwargs']
        )

    async def summarize_conversation(
        self,
        messages: List[Dict],
        include_metadata: bool = True
    ) -> Dict:
        """
        Summarize a conversation and provide key insights

        Args:
            messages (List[Dict]): Conversation messages in order
            include_metadata (bool): Whether to compute the metadata section

        Returns:
            Dict: ``summary`` (str), ``key_insights`` (dict) and ``metadata``
            (dict, empty when ``include_metadata`` is false or there are no
            messages). An empty conversation yields an empty summary without
            invoking the model.
        """
        # NOTE(review): the pipeline call below is synchronous, so it runs on
        # the caller's event-loop thread for as long as the model takes.
        if messages:
            summary = self._run_summarizer(
                self._format_conversation(messages),
                max_length=self.max_length,
                min_length=self.min_length
            )
        else:
            # Nothing to summarize; skip the model call entirely.
            summary = ''

        # Extract key insights
        insights = self._extract_insights(messages)

        # Generate metadata if requested
        metadata = self._generate_metadata(messages) if include_metadata else {}

        return {
            'summary': summary,
            'key_insights': insights,
            'metadata': metadata
        }

    def _run_summarizer(self, text: str, max_length: int, min_length: int) -> str:
        """Run the pipeline on ``text`` and return the generated summary string."""
        result = self.summarizer(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            # Ask the pipeline to truncate over-long input instead of feeding
            # the model more tokens than it accepts.
            truncation=True
        )
        return result[0]['summary_text']

    @staticmethod
    def _message_text(message: Dict) -> str:
        """Return the message ``content`` as a string ('' if missing or None)."""
        content = message.get('content')
        return '' if content is None else str(content)

    def _format_conversation(self, messages: List[Dict]) -> str:
        """Format conversation for summarization, one ``role: content`` line per message"""
        return "\n".join(
            f"{msg.get('role', 'unknown')}: {self._message_text(msg)}"
            for msg in messages
        )

    def _extract_insights(self, messages: List[Dict]) -> Dict:
        """
        Extract key insights from conversation

        Returns:
            Dict: per-role message counts, the average content length in
            characters (truncated to int, 0 for an empty conversation), the
            extracted topics and the total number of messages.
        """
        # Count message types
        roles = [m.get('role') for m in messages]
        message_counts = {
            'user': roles.count('user'),
            'assistant': roles.count('assistant')
        }

        # Calculate average message length; np.mean of an empty list is NaN,
        # which cannot be converted to int, so handle that case explicitly.
        lengths = [len(self._message_text(m)) for m in messages]
        avg_length = int(np.mean(lengths)) if lengths else 0

        # Extract main topics (simplified)
        topics = self._extract_topics(messages)

        return {
            'message_distribution': message_counts,
            'average_message_length': avg_length,
            'main_topics': topics,
            'total_messages': len(messages)
        }

    def _extract_topics(self, messages: List[Dict]) -> List[str]:
        """
        Extract main topics from conversation

        The combined message text is summarized to a short passage which is
        then split into sentences. Returns an empty list when there is no
        text to summarize.
        """
        # Combine all messages
        full_text = " ".join(self._message_text(m) for m in messages)
        if not full_text.strip():
            return []

        # Use the summarizer to extract main points
        short_summary = self._run_summarizer(full_text, max_length=50, min_length=10)

        topics = []
        for piece in short_summary.split('. '):
            piece = piece.strip()
            # The split already removed the period from every sentence but
            # the last one; drop it there too so all topics look alike.
            if piece.endswith('.'):
                piece = piece[:-1].rstrip()
            if piece:
                topics.append(piece)
        return topics

    def _generate_metadata(self, messages: List[Dict]) -> Dict:
        """Generate conversation metadata (empty dict for an empty conversation)"""
        if not messages:
            return {}

        first, last = messages[0], messages[-1]
        return {
            'start_time': first.get('timestamp', None),
            'end_time': last.get('timestamp', None),
            'duration_minutes': self._calculate_duration(messages),
            'sources_used': self._extract_sources(messages)
        }

    @staticmethod
    def _parse_timestamp(value) -> datetime:
        """
        Convert a message timestamp to ``datetime``

        Accepts ``datetime`` objects and ISO 8601 strings (a trailing ``Z``
        is treated as UTC). Raises ``TypeError``/``ValueError`` otherwise.
        """
        if isinstance(value, datetime):
            return value
        if isinstance(value, str) and value.endswith(('Z', 'z')):
            # datetime.fromisoformat only accepts the "Z" suffix on newer
            # Python versions; the explicit offset works everywhere.
            value = value[:-1] + '+00:00'
        return datetime.fromisoformat(value)

    def _calculate_duration(self, messages: List[Dict]) -> float:
        """
        Calculate conversation duration in minutes

        Uses the timestamps of the first and last message. Returns 0 when the
        conversation is empty or the timestamps are missing, unparsable or
        cannot be subtracted (e.g. one is timezone-aware and the other not).
        """
        if not messages:
            return 0.0

        try:
            start_time = self._parse_timestamp(messages[0].get('timestamp'))
            end_time = self._parse_timestamp(messages[-1].get('timestamp'))
            return (end_time - start_time).total_seconds() / 60
        except (TypeError, ValueError):
            # Best effort: duration is informational only.
            return 0.0

    def _extract_sources(self, messages: List[Dict]) -> List[str]:
        """
        Extract unique sources used in conversation

        Returns:
            List[str]: distinct non-empty ``filename`` values in order of
            first appearance.
        """
        # A dict keeps insertion order, giving a deterministic, de-duplicated list.
        seen = {}
        for message in messages:
            for source in message.get('sources') or []:
                filename = source.get('filename', '')
                if filename:
                    seen.setdefault(filename, None)
        return list(seen)