Spaces:
Running
Running
Commit
·
82b8aa2
1
Parent(s):
32318b8
changed the bart model from large to base
Browse files- config/config.py +12 -0
- src/utils/conversation_summarizer.py +40 -23
config/config.py
CHANGED
@@ -34,6 +34,18 @@ class Settings:
|
|
34 |
# Better for development purposes.
|
35 |
return os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
# Vector Store Configuration
|
38 |
CHROMA_PATH = os.getenv('CHROMA_PATH', './chroma_db')
|
39 |
|
|
|
34 |
# Better for development purposes.
|
35 |
return os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
|
36 |
|
37 |
+
# New Conversation Summarizer Settings
|
38 |
+
SUMMARIZER_CONFIG = {
|
39 |
+
# 'facebook/bart-large-cnn', for bigger and better model
|
40 |
+
'model_name': os.getenv('SUMMARIZER_MODEL', 'facebook/bart-base'),
|
41 |
+
'max_length': int(os.getenv('SUMMARIZER_MAX_LENGTH', '130')),
|
42 |
+
'min_length': int(os.getenv('SUMMARIZER_MIN_LENGTH', '30')),
|
43 |
+
'device': -1, # CPU
|
44 |
+
'model_kwargs': {
|
45 |
+
'low_cpu_mem_usage': True
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
# Vector Store Configuration
|
50 |
CHROMA_PATH = os.getenv('CHROMA_PATH', './chroma_db')
|
51 |
|
src/utils/conversation_summarizer.py
CHANGED
@@ -3,22 +3,36 @@ from typing import List, Dict
|
|
3 |
from transformers import pipeline
|
4 |
import numpy as np
|
5 |
from datetime import datetime
|
|
|
|
|
6 |
|
7 |
class ConversationSummarizer:
|
8 |
def __init__(
|
9 |
self,
|
10 |
-
model_name: str =
|
11 |
-
max_length: int =
|
12 |
-
min_length: int =
|
13 |
):
|
14 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
self.summarizer = pipeline(
|
16 |
"summarization",
|
17 |
-
model=model_name,
|
18 |
-
device
|
|
|
19 |
)
|
20 |
-
self.max_length = max_length
|
21 |
-
self.min_length = min_length
|
22 |
|
23 |
async def summarize_conversation(
|
24 |
self,
|
@@ -30,7 +44,7 @@ class ConversationSummarizer:
|
|
30 |
"""
|
31 |
# Format conversation for summarization
|
32 |
formatted_convo = self._format_conversation(messages)
|
33 |
-
|
34 |
# Generate summary
|
35 |
summary = self.summarizer(
|
36 |
formatted_convo,
|
@@ -38,13 +52,14 @@ class ConversationSummarizer:
|
|
38 |
min_length=self.min_length,
|
39 |
do_sample=False
|
40 |
)[0]['summary_text']
|
41 |
-
|
42 |
# Extract key insights
|
43 |
insights = self._extract_insights(messages)
|
44 |
-
|
45 |
# Generate metadata if requested
|
46 |
-
metadata = self._generate_metadata(
|
47 |
-
|
|
|
48 |
return {
|
49 |
'summary': summary,
|
50 |
'key_insights': insights,
|
@@ -58,7 +73,7 @@ class ConversationSummarizer:
|
|
58 |
role = msg.get('role', 'unknown')
|
59 |
content = msg.get('content', '')
|
60 |
formatted.append(f"{role}: {content}")
|
61 |
-
|
62 |
return "\n".join(formatted)
|
63 |
|
64 |
def _extract_insights(self, messages: List[Dict]) -> Dict:
|
@@ -68,13 +83,13 @@ class ConversationSummarizer:
|
|
68 |
'user': len([m for m in messages if m.get('role') == 'user']),
|
69 |
'assistant': len([m for m in messages if m.get('role') == 'assistant'])
|
70 |
}
|
71 |
-
|
72 |
# Calculate average message length
|
73 |
avg_length = np.mean([len(m.get('content', '')) for m in messages])
|
74 |
-
|
75 |
# Extract main topics (simplified)
|
76 |
topics = self._extract_topics(messages)
|
77 |
-
|
78 |
return {
|
79 |
'message_distribution': message_counts,
|
80 |
'average_message_length': int(avg_length),
|
@@ -86,7 +101,7 @@ class ConversationSummarizer:
|
|
86 |
"""Extract main topics from conversation"""
|
87 |
# Combine all messages
|
88 |
full_text = " ".join([m.get('content', '') for m in messages])
|
89 |
-
|
90 |
# Use the summarizer to extract main points
|
91 |
topics = self.summarizer(
|
92 |
full_text,
|
@@ -94,14 +109,14 @@ class ConversationSummarizer:
|
|
94 |
min_length=10,
|
95 |
do_sample=False
|
96 |
)[0]['summary_text'].split('. ')
|
97 |
-
|
98 |
return topics
|
99 |
|
100 |
def _generate_metadata(self, messages: List[Dict]) -> Dict:
|
101 |
"""Generate conversation metadata"""
|
102 |
if not messages:
|
103 |
return {}
|
104 |
-
|
105 |
return {
|
106 |
'start_time': messages[0].get('timestamp', None),
|
107 |
'end_time': messages[-1].get('timestamp', None),
|
@@ -112,8 +127,10 @@ class ConversationSummarizer:
|
|
112 |
def _calculate_duration(self, messages: List[Dict]) -> float:
|
113 |
"""Calculate conversation duration in minutes"""
|
114 |
try:
|
115 |
-
start_time = datetime.fromisoformat(
|
116 |
-
|
|
|
|
|
117 |
return (end_time - start_time).total_seconds() / 60
|
118 |
except:
|
119 |
return 0
|
@@ -125,4 +142,4 @@ class ConversationSummarizer:
|
|
125 |
if message.get('sources'):
|
126 |
for source in message['sources']:
|
127 |
sources.add(source.get('filename', ''))
|
128 |
-
return list(sources)
|
|
|
3 |
from transformers import pipeline
|
4 |
import numpy as np
|
5 |
from datetime import datetime
|
6 |
+
from config.config import settings
|
7 |
+
|
8 |
|
9 |
class ConversationSummarizer:
|
10 |
def __init__(
|
11 |
self,
|
12 |
+
model_name: str = None,
|
13 |
+
max_length: int = None,
|
14 |
+
min_length: int = None
|
15 |
):
|
16 |
+
"""
|
17 |
+
Initialize the summarizer
|
18 |
+
|
19 |
+
Args:
|
20 |
+
model_name (str, optional): Override default model from config
|
21 |
+
max_length (int, optional): Override default max_length from config
|
22 |
+
min_length (int, optional): Override default min_length from config
|
23 |
+
"""
|
24 |
+
# Use provided values or fall back to config values
|
25 |
+
self.model_name = model_name or settings.SUMMARIZER_CONFIG['model_name']
|
26 |
+
self.max_length = max_length or settings.SUMMARIZER_CONFIG['max_length']
|
27 |
+
self.min_length = min_length or settings.SUMMARIZER_CONFIG['min_length']
|
28 |
+
|
29 |
+
# Initialize the summarizer with config settings
|
30 |
self.summarizer = pipeline(
|
31 |
"summarization",
|
32 |
+
model=self.model_name,
|
33 |
+
device=settings.SUMMARIZER_CONFIG['device'],
|
34 |
+
model_kwargs=settings.SUMMARIZER_CONFIG['model_kwargs']
|
35 |
)
|
|
|
|
|
36 |
|
37 |
async def summarize_conversation(
|
38 |
self,
|
|
|
44 |
"""
|
45 |
# Format conversation for summarization
|
46 |
formatted_convo = self._format_conversation(messages)
|
47 |
+
|
48 |
# Generate summary
|
49 |
summary = self.summarizer(
|
50 |
formatted_convo,
|
|
|
52 |
min_length=self.min_length,
|
53 |
do_sample=False
|
54 |
)[0]['summary_text']
|
55 |
+
|
56 |
# Extract key insights
|
57 |
insights = self._extract_insights(messages)
|
58 |
+
|
59 |
# Generate metadata if requested
|
60 |
+
metadata = self._generate_metadata(
|
61 |
+
messages) if include_metadata else {}
|
62 |
+
|
63 |
return {
|
64 |
'summary': summary,
|
65 |
'key_insights': insights,
|
|
|
73 |
role = msg.get('role', 'unknown')
|
74 |
content = msg.get('content', '')
|
75 |
formatted.append(f"{role}: {content}")
|
76 |
+
|
77 |
return "\n".join(formatted)
|
78 |
|
79 |
def _extract_insights(self, messages: List[Dict]) -> Dict:
|
|
|
83 |
'user': len([m for m in messages if m.get('role') == 'user']),
|
84 |
'assistant': len([m for m in messages if m.get('role') == 'assistant'])
|
85 |
}
|
86 |
+
|
87 |
# Calculate average message length
|
88 |
avg_length = np.mean([len(m.get('content', '')) for m in messages])
|
89 |
+
|
90 |
# Extract main topics (simplified)
|
91 |
topics = self._extract_topics(messages)
|
92 |
+
|
93 |
return {
|
94 |
'message_distribution': message_counts,
|
95 |
'average_message_length': int(avg_length),
|
|
|
101 |
"""Extract main topics from conversation"""
|
102 |
# Combine all messages
|
103 |
full_text = " ".join([m.get('content', '') for m in messages])
|
104 |
+
|
105 |
# Use the summarizer to extract main points
|
106 |
topics = self.summarizer(
|
107 |
full_text,
|
|
|
109 |
min_length=10,
|
110 |
do_sample=False
|
111 |
)[0]['summary_text'].split('. ')
|
112 |
+
|
113 |
return topics
|
114 |
|
115 |
def _generate_metadata(self, messages: List[Dict]) -> Dict:
|
116 |
"""Generate conversation metadata"""
|
117 |
if not messages:
|
118 |
return {}
|
119 |
+
|
120 |
return {
|
121 |
'start_time': messages[0].get('timestamp', None),
|
122 |
'end_time': messages[-1].get('timestamp', None),
|
|
|
127 |
def _calculate_duration(self, messages: List[Dict]) -> float:
|
128 |
"""Calculate conversation duration in minutes"""
|
129 |
try:
|
130 |
+
start_time = datetime.fromisoformat(
|
131 |
+
messages[0].get('timestamp', ''))
|
132 |
+
end_time = datetime.fromisoformat(
|
133 |
+
messages[-1].get('timestamp', ''))
|
134 |
return (end_time - start_time).total_seconds() / 60
|
135 |
except:
|
136 |
return 0
|
|
|
142 |
if message.get('sources'):
|
143 |
for source in message['sources']:
|
144 |
sources.add(source.get('filename', ''))
|
145 |
+
return list(sources)
|