TalatMasood committed on
Commit
82b8aa2
·
1 Parent(s): 32318b8

Changed the BART summarizer model from large to base

Browse files
config/config.py CHANGED
@@ -34,6 +34,18 @@ class Settings:
34
  # Better for development purposes.
35
  return os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Vector Store Configuration
38
  CHROMA_PATH = os.getenv('CHROMA_PATH', './chroma_db')
39
 
 
34
  # Better for development purposes.
35
  return os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
36
 
37
# Conversation summarizer settings (attribute of the Settings class in config/config.py).
# Every value can be overridden through an environment variable so deployments
# can change behavior without a code change.
SUMMARIZER_CONFIG = {
    # 'facebook/bart-large-cnn' is the bigger, higher-quality alternative.
    'model_name': os.getenv('SUMMARIZER_MODEL', 'facebook/bart-base'),
    # Upper/lower bounds (in tokens) for generated summaries.
    'max_length': int(os.getenv('SUMMARIZER_MAX_LENGTH', '130')),
    'min_length': int(os.getenv('SUMMARIZER_MIN_LENGTH', '30')),
    # -1 selects CPU; now env-overridable so a GPU index can be chosen
    # at deploy time (previously hard-coded to -1).
    'device': int(os.getenv('SUMMARIZER_DEVICE', '-1')),
    'model_kwargs': {
        # Stream weights during load to reduce peak memory usage.
        'low_cpu_mem_usage': True
    }
}
48
+
49
  # Vector Store Configuration
50
  CHROMA_PATH = os.getenv('CHROMA_PATH', './chroma_db')
51
 
src/utils/conversation_summarizer.py CHANGED
@@ -3,22 +3,36 @@ from typing import List, Dict
3
  from transformers import pipeline
4
  import numpy as np
5
  from datetime import datetime
 
 
6
 
7
  class ConversationSummarizer:
8
  def __init__(
9
  self,
10
- model_name: str = "facebook/bart-large-cnn",
11
- max_length: int = 130,
12
- min_length: int = 30
13
  ):
14
- """Initialize the summarizer"""
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  self.summarizer = pipeline(
16
  "summarization",
17
- model=model_name,
18
- device=-1 # CPU
 
19
  )
20
- self.max_length = max_length
21
- self.min_length = min_length
22
 
23
  async def summarize_conversation(
24
  self,
@@ -30,7 +44,7 @@ class ConversationSummarizer:
30
  """
31
  # Format conversation for summarization
32
  formatted_convo = self._format_conversation(messages)
33
-
34
  # Generate summary
35
  summary = self.summarizer(
36
  formatted_convo,
@@ -38,13 +52,14 @@ class ConversationSummarizer:
38
  min_length=self.min_length,
39
  do_sample=False
40
  )[0]['summary_text']
41
-
42
  # Extract key insights
43
  insights = self._extract_insights(messages)
44
-
45
  # Generate metadata if requested
46
- metadata = self._generate_metadata(messages) if include_metadata else {}
47
-
 
48
  return {
49
  'summary': summary,
50
  'key_insights': insights,
@@ -58,7 +73,7 @@ class ConversationSummarizer:
58
  role = msg.get('role', 'unknown')
59
  content = msg.get('content', '')
60
  formatted.append(f"{role}: {content}")
61
-
62
  return "\n".join(formatted)
63
 
64
  def _extract_insights(self, messages: List[Dict]) -> Dict:
@@ -68,13 +83,13 @@ class ConversationSummarizer:
68
  'user': len([m for m in messages if m.get('role') == 'user']),
69
  'assistant': len([m for m in messages if m.get('role') == 'assistant'])
70
  }
71
-
72
  # Calculate average message length
73
  avg_length = np.mean([len(m.get('content', '')) for m in messages])
74
-
75
  # Extract main topics (simplified)
76
  topics = self._extract_topics(messages)
77
-
78
  return {
79
  'message_distribution': message_counts,
80
  'average_message_length': int(avg_length),
@@ -86,7 +101,7 @@ class ConversationSummarizer:
86
  """Extract main topics from conversation"""
87
  # Combine all messages
88
  full_text = " ".join([m.get('content', '') for m in messages])
89
-
90
  # Use the summarizer to extract main points
91
  topics = self.summarizer(
92
  full_text,
@@ -94,14 +109,14 @@ class ConversationSummarizer:
94
  min_length=10,
95
  do_sample=False
96
  )[0]['summary_text'].split('. ')
97
-
98
  return topics
99
 
100
  def _generate_metadata(self, messages: List[Dict]) -> Dict:
101
  """Generate conversation metadata"""
102
  if not messages:
103
  return {}
104
-
105
  return {
106
  'start_time': messages[0].get('timestamp', None),
107
  'end_time': messages[-1].get('timestamp', None),
@@ -112,8 +127,10 @@ class ConversationSummarizer:
112
  def _calculate_duration(self, messages: List[Dict]) -> float:
113
  """Calculate conversation duration in minutes"""
114
  try:
115
- start_time = datetime.fromisoformat(messages[0].get('timestamp', ''))
116
- end_time = datetime.fromisoformat(messages[-1].get('timestamp', ''))
 
 
117
  return (end_time - start_time).total_seconds() / 60
118
  except:
119
  return 0
@@ -125,4 +142,4 @@ class ConversationSummarizer:
125
  if message.get('sources'):
126
  for source in message['sources']:
127
  sources.add(source.get('filename', ''))
128
- return list(sources)
 
3
  from transformers import pipeline
4
  import numpy as np
5
  from datetime import datetime
6
+ from config.config import settings
7
+
8
 
9
  class ConversationSummarizer:
10
  def __init__(
11
  self,
12
+ model_name: str = None,
13
+ max_length: int = None,
14
+ min_length: int = None
15
  ):
16
+ """
17
+ Initialize the summarizer
18
+
19
+ Args:
20
+ model_name (str, optional): Override default model from config
21
+ max_length (int, optional): Override default max_length from config
22
+ min_length (int, optional): Override default min_length from config
23
+ """
24
+ # Use provided values or fall back to config values
25
+ self.model_name = model_name or settings.SUMMARIZER_CONFIG['model_name']
26
+ self.max_length = max_length or settings.SUMMARIZER_CONFIG['max_length']
27
+ self.min_length = min_length or settings.SUMMARIZER_CONFIG['min_length']
28
+
29
+ # Initialize the summarizer with config settings
30
  self.summarizer = pipeline(
31
  "summarization",
32
+ model=self.model_name,
33
+ device=settings.SUMMARIZER_CONFIG['device'],
34
+ model_kwargs=settings.SUMMARIZER_CONFIG['model_kwargs']
35
  )
 
 
36
 
37
  async def summarize_conversation(
38
  self,
 
44
  """
45
  # Format conversation for summarization
46
  formatted_convo = self._format_conversation(messages)
47
+
48
  # Generate summary
49
  summary = self.summarizer(
50
  formatted_convo,
 
52
  min_length=self.min_length,
53
  do_sample=False
54
  )[0]['summary_text']
55
+
56
  # Extract key insights
57
  insights = self._extract_insights(messages)
58
+
59
  # Generate metadata if requested
60
+ metadata = self._generate_metadata(
61
+ messages) if include_metadata else {}
62
+
63
  return {
64
  'summary': summary,
65
  'key_insights': insights,
 
73
  role = msg.get('role', 'unknown')
74
  content = msg.get('content', '')
75
  formatted.append(f"{role}: {content}")
76
+
77
  return "\n".join(formatted)
78
 
79
  def _extract_insights(self, messages: List[Dict]) -> Dict:
 
83
  'user': len([m for m in messages if m.get('role') == 'user']),
84
  'assistant': len([m for m in messages if m.get('role') == 'assistant'])
85
  }
86
+
87
  # Calculate average message length
88
  avg_length = np.mean([len(m.get('content', '')) for m in messages])
89
+
90
  # Extract main topics (simplified)
91
  topics = self._extract_topics(messages)
92
+
93
  return {
94
  'message_distribution': message_counts,
95
  'average_message_length': int(avg_length),
 
101
  """Extract main topics from conversation"""
102
  # Combine all messages
103
  full_text = " ".join([m.get('content', '') for m in messages])
104
+
105
  # Use the summarizer to extract main points
106
  topics = self.summarizer(
107
  full_text,
 
109
  min_length=10,
110
  do_sample=False
111
  )[0]['summary_text'].split('. ')
112
+
113
  return topics
114
 
115
  def _generate_metadata(self, messages: List[Dict]) -> Dict:
116
  """Generate conversation metadata"""
117
  if not messages:
118
  return {}
119
+
120
  return {
121
  'start_time': messages[0].get('timestamp', None),
122
  'end_time': messages[-1].get('timestamp', None),
 
127
  def _calculate_duration(self, messages: List[Dict]) -> float:
128
  """Calculate conversation duration in minutes"""
129
  try:
130
+ start_time = datetime.fromisoformat(
131
+ messages[0].get('timestamp', ''))
132
+ end_time = datetime.fromisoformat(
133
+ messages[-1].get('timestamp', ''))
134
  return (end_time - start_time).total_seconds() / 60
135
  except:
136
  return 0
 
142
  if message.get('sources'):
143
  for source in message['sources']:
144
  sources.add(source.get('filename', ''))
145
+ return list(sources)