ganesh3 committed on
Commit
507c938
·
1 Parent(s): dbd33b2

second modification

Browse files
Dockerfile CHANGED
@@ -22,6 +22,7 @@ COPY app/ ./app/
22
  COPY config/ ./config/
23
  COPY data/ ./data/
24
  COPY grafana/ ./grafana/
 
25
 
26
  # Make port 8501 available to the world outside this container
27
  EXPOSE 8501
 
22
  COPY config/ ./config/
23
  COPY data/ ./data/
24
  COPY grafana/ ./grafana/
25
+ COPY .env ./
26
 
27
  # Make port 8501 available to the world outside this container
28
  EXPOSE 8501
app/data_processor.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from minsearch import Index
2
  from sentence_transformers import SentenceTransformer
3
  import numpy as np
@@ -5,124 +6,198 @@ from sklearn.metrics.pairwise import cosine_similarity
5
  import re
6
  from elasticsearch import Elasticsearch
7
  import os
 
 
 
 
 
8
 
9
  def clean_text(text):
10
- # Remove special characters and extra whitespace
11
- text = re.sub(r'[^\w\s]', '', text)
12
- text = re.sub(r'\s+', ' ', text).strip()
13
- return text
 
 
 
14
 
15
  class DataProcessor:
16
  def __init__(self, text_fields=["content", "title", "description"],
17
- keyword_fields=["video_id", "start_time", "author", "upload_date"],
18
  embedding_model="all-MiniLM-L6-v2"):
19
  self.text_index = Index(text_fields=text_fields, keyword_fields=keyword_fields)
20
  self.embedding_model = SentenceTransformer(embedding_model)
21
  self.documents = []
22
  self.embeddings = []
 
 
23
 
24
- # Use environment variables for Elasticsearch configuration
25
  elasticsearch_host = os.getenv('ELASTICSEARCH_HOST', 'localhost')
26
  elasticsearch_port = int(os.getenv('ELASTICSEARCH_PORT', 9200))
27
 
28
- # Initialize Elasticsearch client with explicit scheme
29
  self.es = Elasticsearch([f'http://{elasticsearch_host}:{elasticsearch_port}'])
 
30
 
31
  def process_transcript(self, video_id, transcript_data):
 
 
 
 
32
  metadata = transcript_data['metadata']
33
  transcript = transcript_data['transcript']
34
 
35
- for i, segment in enumerate(transcript):
36
- cleaned_text = clean_text(segment['text'])
37
- doc = {
38
- "video_id": video_id,
39
- "content": cleaned_text,
40
- "start_time": segment['start'],
41
- "duration": segment['duration'],
42
- "segment_id": f"{video_id}_{i}",
43
- "title": metadata['title'],
44
- "author": metadata['author'],
45
- "upload_date": metadata['upload_date'],
46
- "view_count": metadata['view_count'],
47
- "like_count": metadata['like_count'],
48
- "comment_count": metadata['comment_count'],
49
- "video_duration": metadata['duration']
50
- }
51
- self.documents.append(doc)
52
- self.embeddings.append(self.embedding_model.encode(cleaned_text + " " + metadata['title']))
 
 
 
 
 
 
 
 
 
53
 
54
  def build_index(self, index_name):
55
- self.text_index.fit(self.documents)
 
 
 
 
 
 
 
 
 
 
 
 
56
  self.embeddings = np.array(self.embeddings)
57
 
58
- # Create Elasticsearch index
59
- if not self.es.indices.exists(index=index_name):
60
- self.es.indices.create(index=index_name, body={
61
- "mappings": {
62
- "properties": {
63
- "embedding": {"type": "dense_vector", "dims": self.embeddings.shape[1]},
64
- "content": {"type": "text"},
65
- "video_id": {"type": "keyword"},
66
- "segment_id": {"type": "keyword"},
67
- "start_time": {"type": "float"},
68
- "duration": {"type": "float"},
69
- "title": {"type": "text"},
70
- "author": {"type": "keyword"},
71
- "upload_date": {"type": "date"},
72
- "view_count": {"type": "integer"},
73
- "like_count": {"type": "integer"},
74
- "comment_count": {"type": "integer"},
75
- "video_duration": {"type": "text"}
76
  }
77
- }
78
- })
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # Index documents in Elasticsearch
81
- for doc, embedding in zip(self.documents, self.embeddings):
82
- doc['embedding'] = embedding.tolist()
83
- self.es.index(index=index_name, body=doc, id=doc['segment_id'])
 
 
 
 
 
 
 
 
84
 
85
  def search(self, query, filter_dict={}, boost_dict={}, num_results=10, method='hybrid', index_name=None):
 
 
 
 
 
 
 
 
 
 
86
  if method == 'text':
87
- return self.text_search(query, filter_dict, boost_dict, num_results)
88
  elif method == 'embedding':
89
  return self.embedding_search(query, num_results, index_name)
90
  else: # hybrid search
91
- text_results = self.text_search(query, filter_dict, boost_dict, num_results)
92
  embedding_results = self.embedding_search(query, num_results, index_name)
93
  return self.combine_results(text_results, embedding_results, num_results)
94
 
95
- def text_search(self, query, filter_dict={}, boost_dict={}, num_results=10):
96
- return self.text_index.search(query, filter_dict, boost_dict, num_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def embedding_search(self, query, num_results=10, index_name=None):
99
- if index_name:
100
- # Use Elasticsearch for embedding search
101
- query_vector = self.embedding_model.encode(query).tolist()
102
- script_query = {
103
- "script_score": {
104
- "query": {"match_all": {}},
105
- "script": {
106
- "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
107
- "params": {"query_vector": query_vector}
108
- }
 
109
  }
110
  }
111
- response = self.es.search(
112
- index=index_name,
113
- body={
114
- "size": num_results,
115
- "query": script_query,
116
- "_source": {"excludes": ["embedding"]}
117
- }
118
- )
119
- return [hit['_source'] for hit in response['hits']['hits']]
120
- else:
121
- # Use in-memory embedding search
122
- query_embedding = self.embedding_model.encode(query)
123
- similarities = cosine_similarity([query_embedding], self.embeddings)[0]
124
- top_indices = np.argsort(similarities)[::-1][:num_results]
125
- return [self.documents[i] for i in top_indices]
126
 
127
  def combine_results(self, text_results, embedding_results, num_results):
128
  combined = []
@@ -142,4 +217,8 @@ class DataProcessor:
142
  return deduped[:num_results]
143
 
144
  def process_query(self, query):
145
- return clean_text(query)
 
 
 
 
 
1
+ import logging
2
  from minsearch import Index
3
  from sentence_transformers import SentenceTransformer
4
  import numpy as np
 
6
  import re
7
  from elasticsearch import Elasticsearch
8
  import os
9
+ import json
10
+ from transcript_extractor import get_transcript
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
 
15
  def clean_text(text):
16
+ if not isinstance(text, str):
17
+ logger.warning(f"Non-string input to clean_text: {type(text)}")
18
+ return ""
19
+ cleaned = re.sub(r'[^\w\s.,!?]', ' ', text)
20
+ cleaned = re.sub(r'\s+', ' ', cleaned).strip()
21
+ logger.info(f"Cleaned text: '{cleaned[:100]}...'")
22
+ return cleaned
23
 
24
  class DataProcessor:
25
  def __init__(self, text_fields=["content", "title", "description"],
26
+ keyword_fields=["video_id", "author", "upload_date"],
27
  embedding_model="all-MiniLM-L6-v2"):
28
  self.text_index = Index(text_fields=text_fields, keyword_fields=keyword_fields)
29
  self.embedding_model = SentenceTransformer(embedding_model)
30
  self.documents = []
31
  self.embeddings = []
32
+ self.index_built = False
33
+ self.current_index_name = None
34
 
 
35
  elasticsearch_host = os.getenv('ELASTICSEARCH_HOST', 'localhost')
36
  elasticsearch_port = int(os.getenv('ELASTICSEARCH_PORT', 9200))
37
 
 
38
  self.es = Elasticsearch([f'http://{elasticsearch_host}:{elasticsearch_port}'])
39
+ logger.info(f"DataProcessor initialized with Elasticsearch at {elasticsearch_host}:{elasticsearch_port}")
40
 
41
  def process_transcript(self, video_id, transcript_data):
42
+ if not transcript_data or 'metadata' not in transcript_data or 'transcript' not in transcript_data:
43
+ logger.error(f"Invalid transcript data for video {video_id}")
44
+ return None
45
+
46
  metadata = transcript_data['metadata']
47
  transcript = transcript_data['transcript']
48
 
49
+ logger.info(f"Processing transcript for video {video_id}")
50
+ logger.info(f"Number of transcript segments: {len(transcript)}")
51
+
52
+ full_transcript = " ".join([segment.get('text', '') for segment in transcript])
53
+ cleaned_transcript = clean_text(full_transcript)
54
+
55
+ if not cleaned_transcript:
56
+ logger.warning(f"Empty cleaned transcript for video {video_id}")
57
+ return None
58
+
59
+ doc = {
60
+ "video_id": video_id,
61
+ "content": cleaned_transcript,
62
+ "segment_id": f"{video_id}_full",
63
+ "title": clean_text(metadata.get('title', '')),
64
+ "author": metadata.get('author', ''),
65
+ "upload_date": metadata.get('upload_date', ''),
66
+ "view_count": metadata.get('view_count', 0),
67
+ "like_count": metadata.get('like_count', 0),
68
+ "comment_count": metadata.get('comment_count', 0),
69
+ "video_duration": metadata.get('duration', '')
70
+ }
71
+ self.documents.append(doc)
72
+ self.embeddings.append(self.embedding_model.encode(cleaned_transcript + " " + metadata.get('title', '')))
73
+
74
+ logger.info(f"Processed transcript for video {video_id}")
75
+ return f"video_{video_id}_{self.embedding_model.get_sentence_embedding_dimension()}"
76
 
77
  def build_index(self, index_name):
78
+ if not self.documents:
79
+ logger.error("No documents to index")
80
+ return None
81
+
82
+ logger.info(f"Building index with {len(self.documents)} documents")
83
+ try:
84
+ self.text_index.fit(self.documents)
85
+ self.index_built = True
86
+ logger.info("Text index built successfully")
87
+ except Exception as e:
88
+ logger.error(f"Error building text index: {str(e)}")
89
+ raise
90
+
91
  self.embeddings = np.array(self.embeddings)
92
 
93
+ try:
94
+ if not self.es.indices.exists(index=index_name):
95
+ self.es.indices.create(index=index_name, body={
96
+ "mappings": {
97
+ "properties": {
98
+ "embedding": {"type": "dense_vector", "dims": self.embeddings.shape[1]},
99
+ "content": {"type": "text"},
100
+ "video_id": {"type": "keyword"},
101
+ "segment_id": {"type": "keyword"},
102
+ "title": {"type": "text"},
103
+ "author": {"type": "keyword"},
104
+ "upload_date": {"type": "date"},
105
+ "view_count": {"type": "integer"},
106
+ "like_count": {"type": "integer"},
107
+ "comment_count": {"type": "integer"},
108
+ "video_duration": {"type": "text"}
109
+ }
 
110
  }
111
+ })
112
+ logger.info(f"Created Elasticsearch index: {index_name}")
113
+
114
+ for doc, embedding in zip(self.documents, self.embeddings):
115
+ doc_with_embedding = doc.copy()
116
+ doc_with_embedding['embedding'] = embedding.tolist()
117
+ self.es.index(index=index_name, body=doc_with_embedding, id=doc['segment_id'])
118
+
119
+ logger.info(f"Successfully indexed {len(self.documents)} documents in Elasticsearch")
120
+ self.current_index_name = index_name
121
+ return index_name
122
+ except Exception as e:
123
+ logger.error(f"Error building Elasticsearch index: {str(e)}")
124
+ raise
125
 
126
def ensure_index_built(self, video_id, embedding_model):
    """Return the index name for a video, building the index if it is missing.

    The index name is ``video_<video_id>_<embedding_model>`` with ``-``
    replaced by ``_`` in the model name, lower-cased.

    Args:
        video_id: Video identifier passed to ``get_transcript``.
        embedding_model: Embedding model name (a string) used in the
            index name.

    Returns:
        The index name when the index exists or was built, otherwise
        ``None`` (transcript could not be retrieved or processed, or
        ``build_index`` had nothing to index).
    """
    index_name = f"video_{video_id}_{embedding_model.replace('-', '_')}".lower()
    if self.es.indices.exists(index=index_name):
        return index_name

    logger.info(f"Index {index_name} does not exist. Building now...")
    transcript_data = get_transcript(video_id)
    if not transcript_data:
        logger.error(f"Failed to retrieve transcript for video {video_id}")
        return None

    # A previous build may have left self.embeddings as an ndarray, which
    # has no append(); restore a list before processing another transcript.
    if not isinstance(self.embeddings, list):
        self.embeddings = list(self.embeddings)

    # Do not build an index from stale documents when this transcript
    # could not be processed.
    if self.process_transcript(video_id, transcript_data) is None:
        logger.error(f"Failed to process transcript for video {video_id}")
        return None

    return self.build_index(index_name)
138
 
139
  def search(self, query, filter_dict={}, boost_dict={}, num_results=10, method='hybrid', index_name=None):
140
+ if not index_name:
141
+ logger.error("No index name provided for search.")
142
+ raise ValueError("No index name provided for search.")
143
+
144
+ if not self.es.indices.exists(index=index_name):
145
+ logger.error(f"Index {index_name} does not exist.")
146
+ raise ValueError(f"Index {index_name} does not exist.")
147
+
148
+ logger.info(f"Performing {method} search for query: {query} in index: {index_name}")
149
+
150
  if method == 'text':
151
+ return self.text_search(query, filter_dict, boost_dict, num_results, index_name)
152
  elif method == 'embedding':
153
  return self.embedding_search(query, num_results, index_name)
154
  else: # hybrid search
155
+ text_results = self.text_search(query, filter_dict, boost_dict, num_results, index_name)
156
  embedding_results = self.embedding_search(query, num_results, index_name)
157
  return self.combine_results(text_results, embedding_results, num_results)
158
 
159
def text_search(self, query, filter_dict=None, boost_dict=None, num_results=10, index_name=None):
    """Run a ``multi_match`` text query against an Elasticsearch index.

    Args:
        query: The query string, matched against ``content`` and ``title``.
        filter_dict: Optional ``{field: value}`` mapping; each entry
            becomes an exact ``term`` filter. Term filters match exact
            values, so they suit keyword fields such as ``video_id`` or
            ``author``.
        boost_dict: Optional ``{field: boost}`` mapping; a boost for
            ``content`` or ``title`` is applied with the ``field^boost``
            syntax. Entries for other fields are ignored.
        num_results: Maximum number of hits to return.
        index_name: Name of the Elasticsearch index to search (required).

    Returns:
        A list with the ``_source`` of each hit, without the ``embedding``
        vector (consistent with ``embedding_search``).

    Raises:
        ValueError: If ``index_name`` is empty.
    """
    if not index_name:
        logger.error("No index name provided for text search.")
        raise ValueError("No index name provided for text search.")

    filter_dict = filter_dict or {}
    boost_dict = boost_dict or {}

    # Previously the boosts were accepted but silently dropped.
    fields = [
        f"{name}^{boost_dict[name]}" if name in boost_dict else name
        for name in ("content", "title")
    ]
    match_clause = {"multi_match": {"query": query, "fields": fields}}

    if filter_dict:
        # Previously the filters were accepted but silently dropped.
        query_clause = {
            "bool": {
                "must": match_clause,
                "filter": [{"term": {field: value}} for field, value in filter_dict.items()],
            }
        }
    else:
        query_clause = match_clause

    search_body = {
        "query": query_clause,
        "size": num_results,
        "_source": {"excludes": ["embedding"]},
    }
    response = self.es.search(index=index_name, body=search_body)
    return [hit['_source'] for hit in response['hits']['hits']]
176
 
177
  def embedding_search(self, query, num_results=10, index_name=None):
178
+ if not index_name:
179
+ logger.error("No index name provided for embedding search.")
180
+ raise ValueError("No index name provided for embedding search.")
181
+
182
+ query_vector = self.embedding_model.encode(query).tolist()
183
+ script_query = {
184
+ "script_score": {
185
+ "query": {"match_all": {}},
186
+ "script": {
187
+ "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
188
+ "params": {"query_vector": query_vector}
189
  }
190
  }
191
+ }
192
+ response = self.es.search(
193
+ index=index_name,
194
+ body={
195
+ "size": num_results,
196
+ "query": script_query,
197
+ "_source": {"excludes": ["embedding"]}
198
+ }
199
+ )
200
+ return [hit['_source'] for hit in response['hits']['hits']]
 
 
 
 
 
201
 
202
  def combine_results(self, text_results, embedding_results, num_results):
203
  combined = []
 
217
  return deduped[:num_results]
218
 
219
  def process_query(self, query):
220
+ return clean_text(query)
221
+
222
def set_embedding_model(self, model_name):
    """Replace ``self.embedding_model`` with a newly loaded model.

    Args:
        model_name: Model name passed to ``SentenceTransformer``.
    """
    replacement = SentenceTransformer(model_name)
    self.embedding_model = replacement
    logger.info(f"Embedding model set to: {model_name}")
app/database.py CHANGED
@@ -6,6 +6,7 @@ class DatabaseHandler:
6
  self.db_path = db_path
7
  self.conn = None
8
  self.create_tables()
 
9
 
10
  def create_tables(self):
11
  with sqlite3.connect(self.db_path) as conn:
@@ -48,13 +49,44 @@ class DatabaseHandler:
48
  ''')
49
  conn.commit()
50
 
51
- def add_video(self, youtube_id, title, channel_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  with sqlite3.connect(self.db_path) as conn:
53
  cursor = conn.cursor()
54
  cursor.execute('''
55
- INSERT OR IGNORE INTO videos (youtube_id, title, channel_name)
56
- VALUES (?, ?, ?)
57
- ''', (youtube_id, title, channel_name))
 
 
 
 
 
 
 
 
 
 
58
  conn.commit()
59
  return cursor.lastrowid
60
 
@@ -92,12 +124,75 @@ class DatabaseHandler:
92
  cursor.execute('SELECT * FROM videos WHERE youtube_id = ?', (youtube_id,))
93
  return cursor.fetchone()
94
 
95
- def get_elasticsearch_index(self, video_id, embedding_model_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with sqlite3.connect(self.db_path) as conn:
97
  cursor = conn.cursor()
98
  cursor.execute('''
99
- SELECT index_name FROM elasticsearch_indices
100
- WHERE video_id = ? AND embedding_model_id = ?
101
- ''', (video_id, embedding_model_id))
 
 
 
102
  result = cursor.fetchone()
103
  return result[0] if result else None
 
6
  self.db_path = db_path
7
  self.conn = None
8
  self.create_tables()
9
+ self.update_schema()
10
 
11
  def create_tables(self):
12
  with sqlite3.connect(self.db_path) as conn:
 
49
  ''')
50
  conn.commit()
51
 
52
def update_schema(self):
    """Add any missing columns to the ``videos`` table.

    Reads the current columns with ``PRAGMA table_info(videos)`` and issues
    one ``ALTER TABLE ... ADD COLUMN`` per column that is not present yet,
    so calling it repeatedly is harmless.
    """
    new_columns = [
        ("upload_date", "TEXT"),
        ("view_count", "INTEGER"),
        ("like_count", "INTEGER"),
        ("comment_count", "INTEGER"),
        ("video_duration", "TEXT"),
        # get_transcript_content / add_transcript_content read and write
        # this column; add it here so they do not fail on older databases.
        ("transcript_content", "TEXT"),
    ]

    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(videos)")
        existing = {column[1] for column in cursor.fetchall()}

        for col_name, col_type in new_columns:
            if col_name not in existing:
                # Names and types come from the constant list above, not
                # from user input, so string formatting is safe here.
                cursor.execute(f"ALTER TABLE videos ADD COLUMN {col_name} {col_type}")

        conn.commit()
    finally:
        conn.close()
72
+
73
def add_video(self, video_data):
    """Insert or replace one row in the ``videos`` table.

    Args:
        video_data: Mapping with the keys ``video_id``, ``title``,
            ``author``, ``upload_date``, ``view_count``, ``like_count``,
            ``comment_count`` and ``video_duration``. ``video_id`` is
            stored as ``youtube_id`` and ``author`` as ``channel_name``.

    Returns:
        ``cursor.lastrowid`` after the insert.

    Raises:
        KeyError: If one of the keys above is missing from ``video_data``.
    """
    row = (
        video_data['video_id'],
        video_data['title'],
        video_data['author'],
        video_data['upload_date'],
        video_data['view_count'],
        video_data['like_count'],
        video_data['comment_count'],
        video_data['video_duration'],
    )
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            INSERT OR REPLACE INTO videos
            (youtube_id, title, channel_name, upload_date, view_count, like_count, comment_count, video_duration)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', row)
        conn.commit()
        return cursor.lastrowid
    finally:
        conn.close()
92
 
 
124
  cursor.execute('SELECT * FROM videos WHERE youtube_id = ?', (youtube_id,))
125
  return cursor.fetchone()
126
 
127
def get_elasticsearch_index(self, video_id, embedding_model):
    """Look up the Elasticsearch index name for a video and embedding model.

    Args:
        video_id: Value compared against ``videos.youtube_id``.
        embedding_model: Value compared against
            ``embedding_models.model_name``.

    Returns:
        The ``index_name`` of the first matching row in
        ``elasticsearch_indices``, or ``None`` when nothing matches.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT ei.index_name
            FROM elasticsearch_indices ei
            JOIN embedding_models em ON ei.embedding_model_id = em.id
            JOIN videos v ON ei.video_id = v.id
            WHERE v.youtube_id = ? AND em.model_name = ?
        ''', (video_id, embedding_model))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        conn.close()
139
+
140
def get_all_videos(self):
    """Return all rows of ``videos``, newest ``upload_date`` first.

    Returns:
        A list of ``(youtube_id, title, channel_name, upload_date)``
        tuples ordered by ``upload_date`` descending.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT youtube_id, title, channel_name, upload_date
            FROM videos
            ORDER BY upload_date DESC
        ''')
        return cursor.fetchall()
    finally:
        conn.close()
149
+
150
def get_elasticsearch_index_by_youtube_id(self, youtube_id, embedding_model):
    """Look up the Elasticsearch index name by YouTube ID and model name.

    NOTE(review): a method with this same name and query is defined again
    further down in this class; the later definition is the one Python
    keeps, so one of the two should be removed.

    Args:
        youtube_id: Value compared against ``videos.youtube_id``.
        embedding_model: Value compared against
            ``embedding_models.model_name``.

    Returns:
        The ``index_name`` of the first matching row in
        ``elasticsearch_indices``, or ``None`` when nothing matches.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT ei.index_name
            FROM elasticsearch_indices ei
            JOIN embedding_models em ON ei.embedding_model_id = em.id
            JOIN videos v ON ei.video_id = v.id
            WHERE v.youtube_id = ? AND em.model_name = ?
        ''', (youtube_id, embedding_model))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        conn.close()
162
+
163
def get_transcript_content(self, youtube_id):
    """Return the stored transcript text for a video.

    NOTE(review): this assumes the ``videos`` table has a
    ``transcript_content`` column; confirm the schema provides it, or
    retrieve the transcript from wherever it is actually stored.

    Args:
        youtube_id: Value compared against ``videos.youtube_id``.

    Returns:
        The ``transcript_content`` value of the first matching row, or
        ``None`` when no row matches.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT transcript_content
            FROM videos
            WHERE youtube_id = ?
        ''', (youtube_id,))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        conn.close()
175
+
176
+ # If you're not already storing the transcript content, you'll need to add a method to do so:
177
def add_transcript_content(self, youtube_id, transcript_content):
    """Store transcript text on the ``videos`` row of a video.

    Runs an ``UPDATE``, so nothing is written when no row has the given
    ``youtube_id``.

    NOTE(review): this assumes the ``videos`` table has a
    ``transcript_content`` column; confirm the schema provides it.

    Args:
        youtube_id: Value compared against ``videos.youtube_id``.
        transcript_content: Text written to ``transcript_content``.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            UPDATE videos
            SET transcript_content = ?
            WHERE youtube_id = ?
        ''', (transcript_content, youtube_id))
        conn.commit()
    finally:
        conn.close()
186
+
187
def get_elasticsearch_index_by_youtube_id(self, youtube_id, embedding_model):
    """Look up the Elasticsearch index name by YouTube ID and model name.

    NOTE(review): a method with this same name and query is also defined
    earlier in this class; this later definition overrides it, so one of
    the two should be removed.

    Args:
        youtube_id: Value compared against ``videos.youtube_id``.
        embedding_model: Value compared against
            ``embedding_models.model_name``.

    Returns:
        The ``index_name`` of the first matching row in
        ``elasticsearch_indices``, or ``None`` when nothing matches.
    """
    # sqlite3's connection context manager does not close the connection,
    # so close it explicitly.
    conn = sqlite3.connect(self.db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('''
            SELECT ei.index_name
            FROM elasticsearch_indices ei
            JOIN embedding_models em ON ei.embedding_model_id = em.id
            JOIN videos v ON ei.video_id = v.id
            WHERE v.youtube_id = ? AND em.model_name = ?
        ''', (youtube_id, embedding_model))
        result = cursor.fetchone()
        return result[0] if result else None
    finally:
        conn.close()
app/main.py CHANGED
@@ -12,28 +12,43 @@ import json
12
  import requests
13
  from tqdm import tqdm
14
  import sqlite3
 
 
 
 
 
15
 
16
- # Initialize components
17
  @st.cache_resource
18
  def init_components():
19
- db_handler = DatabaseHandler()
20
- data_processor = DataProcessor()
21
- rag_system = RAGSystem(data_processor)
22
- query_rewriter = QueryRewriter()
23
- evaluation_system = EvaluationSystem(data_processor, db_handler)
24
- return db_handler, data_processor, rag_system, query_rewriter, evaluation_system
25
-
26
- db_handler, data_processor, rag_system, query_rewriter, evaluation_system = init_components()
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Ground Truth Generation
 
29
  def generate_questions(transcript):
30
- OLLAMA_HOST = os.getenv('OLLAMA_HOST', 'localhost')
31
- OLLAMA_PORT = os.getenv('OLLAMA_PORT', '11434')
32
  prompt_template = """
33
  You are an AI assistant tasked with generating questions based on a YouTube video transcript.
34
- Formulate 10 questions that a user might ask based on the provided transcript.
35
  Make the questions specific to the content of the transcript.
36
  The questions should be complete and not too short. Use as few words as possible from the transcript.
 
37
 
38
  The transcript:
39
 
@@ -47,34 +62,44 @@ def generate_questions(transcript):
47
  prompt = prompt_template.format(transcript=transcript)
48
 
49
  try:
50
- response = requests.post(f'http://{OLLAMA_HOST}:{OLLAMA_PORT}/api/generate', json={
51
- 'model': 'phi3.5',
52
- 'prompt': prompt
53
- })
54
- response.raise_for_status()
55
- return json.loads(response.json()['response'])
56
- except requests.RequestException as e:
57
- st.error(f"Error generating questions: {str(e)}")
 
 
 
 
 
58
  return None
59
 
60
- def generate_ground_truth(video_id):
61
- transcript_data = get_transcript(video_id)
 
 
 
 
 
 
 
 
 
 
62
 
63
- if transcript_data and 'transcript' in transcript_data:
64
- full_transcript = " ".join([entry['text'] for entry in transcript_data['transcript']])
65
- questions = generate_questions(full_transcript)
66
 
67
- if questions and 'questions' in questions:
68
- df = pd.DataFrame([(video_id, q) for q in questions['questions']], columns=['video_id', 'question'])
69
-
70
- os.makedirs('data', exist_ok=True)
71
- df.to_csv('data/ground-truth-retrieval.csv', index=False)
72
- st.success("Ground truth data generated and saved to data/ground-truth-retrieval.csv")
73
- return df
74
- else:
75
- st.error("Failed to generate questions.")
76
  else:
77
- st.error("Failed to generate ground truth data due to transcript retrieval error.")
 
78
  return None
79
 
80
  # RAG Evaluation
@@ -82,6 +107,7 @@ def evaluate_rag(sample_size=200):
82
  try:
83
  ground_truth = pd.read_csv('data/ground-truth-retrieval.csv')
84
  except FileNotFoundError:
 
85
  st.error("Ground truth file not found. Please generate ground truth data first.")
86
  return None
87
 
@@ -111,14 +137,38 @@ def evaluate_rag(sample_size=200):
111
  progress_bar = st.progress(0)
112
  for i, (_, row) in enumerate(sample.iterrows()):
113
  question = row['question']
114
- answer_llm = rag_system.query(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  prompt = prompt_template.format(question=question, answer_llm=answer_llm)
116
- evaluation = rag_system.query(prompt) # Assuming rag_system can handle this type of query
117
  try:
118
- evaluation_json = json.loads(evaluation)
119
- evaluations.append((row['video_id'], question, answer_llm, evaluation_json['Relevance'], evaluation_json['Explanation']))
120
- except json.JSONDecodeError:
121
- st.warning(f"Failed to parse evaluation for question: {question}")
 
 
 
 
 
 
 
 
 
 
 
122
  progress_bar.progress((i + 1) / len(sample))
123
 
124
  # Store RAG evaluations in the database
@@ -140,39 +190,163 @@ def evaluate_rag(sample_size=200):
140
  conn.commit()
141
  conn.close()
142
 
 
143
  st.success("Evaluation complete. Results stored in the database.")
144
  return evaluations
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def main():
147
  st.title("YouTube Transcript RAG System")
148
 
 
 
 
 
 
 
 
149
  tab1, tab2, tab3 = st.tabs(["RAG System", "Ground Truth Generation", "Evaluation"])
150
 
151
  with tab1:
152
  st.header("RAG System")
153
- # Input section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  input_type = st.radio("Select input type:", ["Video URL", "Channel URL", "YouTube ID"])
155
  input_value = st.text_input("Enter the URL or ID:")
156
- embedding_model = st.selectbox("Select embedding model:", ["all-MiniLM-L6-v2", "all-mpnet-base-v2"])
157
-
158
  if st.button("Process"):
159
  with st.spinner("Processing..."):
160
  data_processor.embedding_model = SentenceTransformer(embedding_model)
161
  if input_type == "Video URL":
162
  video_id = extract_video_id(input_value)
163
  if video_id:
164
- process_single_video(video_id, embedding_model)
 
 
 
 
165
  else:
166
  st.error("Failed to extract video ID from the URL")
167
  elif input_type == "Channel URL":
168
  channel_videos = get_channel_videos(input_value)
169
  if channel_videos:
170
- process_multiple_videos([video['video_id'] for video in channel_videos], embedding_model)
 
 
 
 
171
  else:
172
  st.error("Failed to retrieve videos from the channel")
173
  else:
174
- process_single_video(input_value, embedding_model)
175
-
 
 
 
 
176
  # Query section
177
  st.subheader("Query the RAG System")
178
  query = st.text_input("Enter your query:")
@@ -180,108 +354,147 @@ def main():
180
  search_method = st.radio("Search method:", ["Hybrid", "Text-only", "Embedding-only"])
181
 
182
  if st.button("Search"):
183
- with st.spinner("Searching..."):
184
- if rewrite_method == "Chain of Thought":
185
- query = query_rewriter.rewrite_cot(query)
186
- elif rewrite_method == "ReAct":
187
- query = query_rewriter.rewrite_react(query)
188
-
189
- search_method_map = {"Hybrid": "hybrid", "Text-only": "text", "Embedding-only": "embedding"}
190
- response = rag_system.query(query, search_method=search_method_map[search_method])
191
- st.write("Response:", response)
192
-
193
- # Feedback
194
- feedback = st.radio("Provide feedback:", ["+1", "-1"])
195
- if st.button("Submit Feedback"):
196
- db_handler.add_user_feedback("all_videos", query, 1 if feedback == "+1" else -1)
197
- st.success("Feedback submitted successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  with tab2:
200
  st.header("Ground Truth Generation")
201
- video_id = st.text_input("Enter YouTube Video ID for ground truth generation:")
202
- if st.button("Generate Ground Truth"):
203
- with st.spinner("Generating ground truth..."):
204
- ground_truth_df = generate_ground_truth(video_id)
205
- if ground_truth_df is not None:
206
- st.dataframe(ground_truth_df)
207
- csv = ground_truth_df.to_csv(index=False)
208
- st.download_button(
209
- label="Download Ground Truth CSV",
210
- data=csv,
211
- file_name="ground_truth.csv",
212
- mime="text/csv",
213
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  with tab3:
216
  st.header("RAG Evaluation")
217
- sample_size = st.number_input("Enter sample size for evaluation:", min_value=1, max_value=1000, value=200)
218
- if st.button("Run Evaluation"):
219
- with st.spinner("Running evaluation..."):
220
- evaluation_results = evaluate_rag(sample_size)
221
- if evaluation_results:
222
- st.write("Evaluation Results:")
223
- st.dataframe(pd.DataFrame(evaluation_results, columns=['Video ID', 'Question', 'Answer', 'Relevance', 'Explanation']))
224
-
225
- @st.cache_data
226
- def process_single_video(video_id, embedding_model):
227
- # Check if the video has already been processed with the current embedding model
228
- existing_index = db_handler.get_elasticsearch_index(video_id, embedding_model)
229
- if existing_index:
230
- st.info(f"Video {video_id} has already been processed with {embedding_model}. Using existing index: {existing_index}")
231
- return existing_index
232
 
233
- transcript_data = get_transcript(video_id)
234
- if transcript_data:
235
- # Store video metadata in the database
236
- video_data = {
237
- 'video_id': video_id,
238
- 'title': transcript_data['metadata'].get('title', 'Unknown Title'),
239
- 'author': transcript_data['metadata'].get('author', 'Unknown Author'),
240
- 'upload_date': transcript_data['metadata'].get('upload_date', 'Unknown Date'),
241
- 'view_count': int(transcript_data['metadata'].get('view_count', 0)),
242
- 'like_count': int(transcript_data['metadata'].get('like_count', 0)),
243
- 'comment_count': int(transcript_data['metadata'].get('comment_count', 0)),
244
- 'video_duration': transcript_data['metadata'].get('duration', 'Unknown Duration')
245
- }
246
- db_handler.add_video(video_data)
247
 
248
- # Store transcript segments in the database
249
- for i, segment in enumerate(transcript_data['transcript']):
250
- segment_data = {
251
- 'segment_id': f"{video_id}_{i}",
252
- 'video_id': video_id,
253
- 'content': segment.get('text', ''),
254
- 'start_time': segment.get('start', 0),
255
- 'duration': segment.get('duration', 0)
256
- }
257
- db_handler.add_transcript_segment(segment_data)
258
-
259
- # Process transcript for RAG system
260
- data_processor.process_transcript(video_id, transcript_data)
261
-
262
- # Create Elasticsearch index
263
- index_name = f"video_{video_id}_{embedding_model}"
264
- data_processor.build_index(index_name)
265
-
266
- # Store Elasticsearch index information
267
- db_handler.add_elasticsearch_index(video_id, index_name, embedding_model)
268
-
269
- st.success(f"Processed and indexed transcript for video {video_id}")
270
- st.write("Metadata:", transcript_data['metadata'])
271
- return index_name
272
- else:
273
- st.error(f"Failed to retrieve transcript for video {video_id}")
274
- return None
275
 
276
- @st.cache_data
277
- def process_multiple_videos(video_ids, embedding_model):
278
- indices = []
279
- for video_id in video_ids:
280
- index = process_single_video(video_id, embedding_model)
281
- if index:
282
- indices.append(index)
283
- st.success(f"Processed and indexed transcripts for {len(indices)} videos")
284
- return indices
 
 
 
 
 
 
 
 
 
 
285
 
286
  if __name__ == "__main__":
287
  main()
 
12
  import requests
13
  from tqdm import tqdm
14
  import sqlite3
15
+ import logging
16
+ import ollama
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
 
 
21
@st.cache_resource
def init_components():
    """Construct the shared application components.

    Returns:
        A 5-tuple ``(db_handler, data_processor, rag_system, query_rewriter,
        evaluation_system)``. If any constructor raises, the error is logged
        (with traceback) and shown in the UI, and a tuple of five ``None``
        values is returned instead. That tuple is still truthy, so callers
        must inspect its members rather than the tuple itself.
    """
    # NOTE(review): st.cache_resource presumably memoizes the all-None failure
    # tuple as well, so a transient start-up failure may persist until the
    # cache is cleared -- confirm against the Streamlit caching docs.
    try:
        db_handler = DatabaseHandler()
        data_processor = DataProcessor()
        rag_system = RAGSystem(data_processor)
        query_rewriter = QueryRewriter()
        evaluation_system = EvaluationSystem(data_processor, db_handler)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error initializing components: %s", e)
        st.error(f"Error initializing components: {str(e)}")
        st.error("Please check your configuration and ensure all services are running.")
        return None, None, None, None, None

    logger.info("Components initialized successfully")
    return db_handler, data_processor, rag_system, query_rewriter, evaluation_system
36
+
37
components = init_components()
# init_components() reports failure as a tuple of five None values. A non-empty
# tuple is always truthy, so check the members instead of the tuple itself;
# otherwise the app would carry on with None components.
if components is None or any(component is None for component in components):
    st.stop()
db_handler, data_processor, rag_system, query_rewriter, evaluation_system = components
42
 
43
  # Ground Truth Generation
44
+
45
  def generate_questions(transcript):
 
 
46
  prompt_template = """
47
  You are an AI assistant tasked with generating questions based on a YouTube video transcript.
48
+ Formulate atleast 10 questions that a user might ask based on the provided transcript.
49
  Make the questions specific to the content of the transcript.
50
  The questions should be complete and not too short. Use as few words as possible from the transcript.
51
+ It is important that the questions are relevant to the content of the transcript and are atleast 10 in number.
52
 
53
  The transcript:
54
 
 
62
  prompt = prompt_template.format(transcript=transcript)
63
 
64
  try:
65
+ response = ollama.chat(
66
+ model='phi3.5',
67
+ messages=[{"role": "user", "content": prompt}]
68
+ )
69
+ print("Printing the response from OLLAMA: " + response['message']['content'])
70
+ return json.loads(response['message']['content'])
71
+ except Exception as e:
72
+ logger.error(f"Error generating questions: {str(e)}")
73
+ return None
74
+
75
def generate_ground_truth(video_id=None, existing_transcript=None):
    """Generate ground-truth questions for a transcript and save them as CSV.

    Args:
        video_id: YouTube video ID. Used to fetch the transcript when
            ``existing_transcript`` is not supplied, and as the ``video_id``
            column value of the result ("custom" when empty).
        existing_transcript: Full transcript text. When provided it is used
            directly and no transcript is fetched.

    Returns:
        A DataFrame with columns ``video_id`` and ``question`` (also written to
        ``data/ground-truth-retrieval.csv``), or ``None`` if the input is
        missing, the transcript cannot be retrieved, or question generation
        fails.
    """
    # Treat empty strings like missing values: the UI text input yields ""
    # rather than None when nothing has been typed.
    if not video_id and not existing_transcript:
        st.error("Please provide either a video ID or an existing transcript.")
        return None

    if existing_transcript:
        # The caller already has the text; do not fetch it a second time.
        full_transcript = existing_transcript
    else:
        transcript_data = get_transcript(video_id)
        if not transcript_data or 'transcript' not in transcript_data:
            logger.error("Failed to retrieve transcript for the provided video ID.")
            st.error("Failed to retrieve transcript for the provided video ID.")
            return None
        full_transcript = " ".join(entry['text'] for entry in transcript_data['transcript'])

    questions = generate_questions(full_transcript)

    # The model reply is parsed JSON and may be a list or scalar, not a dict.
    if not isinstance(questions, dict) or 'questions' not in questions:
        logger.error("Failed to generate questions.")
        st.error("Failed to generate questions.")
        return None

    label = video_id if video_id else "custom"
    df = pd.DataFrame(
        [(label, q) for q in questions['questions']],
        columns=['video_id', 'question'],
    )

    os.makedirs('data', exist_ok=True)
    df.to_csv('data/ground-truth-retrieval.csv', index=False)
    st.success("Ground truth data generated and saved to data/ground-truth-retrieval.csv")
    return df
104
 
105
  # RAG Evaluation
 
107
  try:
108
  ground_truth = pd.read_csv('data/ground-truth-retrieval.csv')
109
  except FileNotFoundError:
110
+ logger.error("Ground truth file not found. Please generate ground truth data first.")
111
  st.error("Ground truth file not found. Please generate ground truth data first.")
112
  return None
113
 
 
137
  progress_bar = st.progress(0)
138
  for i, (_, row) in enumerate(sample.iterrows()):
139
  question = row['question']
140
+ video_id = row['video_id']
141
+
142
+ # Get the index name for the video (you might need to adjust this based on your setup)
143
+ index_name = db_handler.get_elasticsearch_index_by_youtube_id(video_id, "all-MiniLM-L6-v2") # Assuming you're using this embedding model
144
+
145
+ if not index_name:
146
+ logger.warning(f"No index found for video {video_id}. Skipping this question.")
147
+ continue
148
+
149
+ try:
150
+ answer_llm, _ = rag_system.query(question, index_name=index_name)
151
+ except ValueError as e:
152
+ logger.error(f"Error querying RAG system: {str(e)}")
153
+ continue
154
+
155
  prompt = prompt_template.format(question=question, answer_llm=answer_llm)
 
156
  try:
157
+ response = ollama.chat(
158
+ model='phi3.5',
159
+ messages=[{"role": "user", "content": prompt}]
160
+ )
161
+ evaluation_json = json.loads(response['message']['content'])
162
+ evaluations.append((
163
+ str(video_id),
164
+ str(question),
165
+ str(answer_llm),
166
+ str(evaluation_json.get('Relevance', 'UNKNOWN')),
167
+ str(evaluation_json.get('Explanation', 'No explanation provided'))
168
+ ))
169
+ except Exception as e:
170
+ logger.warning(f"Failed to evaluate question: {question}. Error: {str(e)}")
171
+ st.warning(f"Failed to evaluate question: {question}")
172
  progress_bar.progress((i + 1) / len(sample))
173
 
174
  # Store RAG evaluations in the database
 
190
  conn.commit()
191
  conn.close()
192
 
193
+ logger.info("Evaluation complete. Results stored in the database.")
194
  st.success("Evaluation complete. Results stored in the database.")
195
  return evaluations
196
 
197
def _safe_int(value, default=0):
    """Convert *value* to int, returning *default* for missing or non-numeric input."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def process_single_video(video_id, embedding_model):
    """Fetch, store and index one video's transcript.

    The function is deliberately not wrapped in ``st.cache_data``: the lookup
    of an existing index below already makes repeated calls cheap, while a
    cached ``None`` would stop a failed video from being retried with the same
    arguments.

    Args:
        video_id: YouTube video ID.
        embedding_model: Name of the embedding model the index is built for.

    Returns:
        The Elasticsearch index name, or ``None`` if any step fails (the
        failure is logged).
    """
    # Reuse an index that was already built for this video/model pair.
    existing_index = db_handler.get_elasticsearch_index_by_youtube_id(video_id, embedding_model)
    if existing_index:
        logger.info(f"Video {video_id} has already been processed with {embedding_model}. Using existing index: {existing_index}")
        return existing_index

    transcript_data = get_transcript(video_id)
    if transcript_data is None:
        logger.error(f"Failed to retrieve transcript for video {video_id}")
        return None

    # Metadata may be absent when the metadata lookup fails; fall back to
    # defaults instead of raising on None.
    metadata = transcript_data.get('metadata') or {}

    # Store video metadata in the database. Counts are converted defensively
    # because they may arrive as non-numeric strings.
    video_data = {
        'video_id': video_id,
        'title': metadata.get('title', 'Unknown Title'),
        'author': metadata.get('author', 'Unknown Author'),
        'upload_date': metadata.get('upload_date', 'Unknown Date'),
        'view_count': _safe_int(metadata.get('view_count', 0)),
        'like_count': _safe_int(metadata.get('like_count', 0)),
        'comment_count': _safe_int(metadata.get('comment_count', 0)),
        'video_duration': metadata.get('duration', 'Unknown Duration'),
    }
    try:
        db_handler.add_video(video_data)
    except Exception as e:
        logger.error(f"Error adding video to database: {str(e)}")
        return None

    # Process transcript for RAG system
    try:
        data_processor.process_transcript(video_id, transcript_data)
    except Exception as e:
        logger.error(f"Error processing transcript: {str(e)}")
        return None

    # Create Elasticsearch index; build_index returns the name actually used.
    index_name = f"video_{video_id}_{embedding_model}".lower()
    try:
        index_name = data_processor.build_index(index_name)
        logger.info(f"Successfully built index: {index_name}")
    except Exception as e:
        logger.error(f"Error building index: {str(e)}")
        return None

    # Register the model and the index in the database, guarded like the
    # earlier steps so a database error yields None instead of an exception.
    try:
        embedding_model_id = db_handler.add_embedding_model(embedding_model, "Description of the model")

        video_db_record = db_handler.get_video_by_youtube_id(video_id)
        if video_db_record is None:
            logger.error(f"Failed to retrieve video record from database for video {video_id}")
            return None
        video_db_id = video_db_record[0]  # assumes the ID is the first column -- TODO confirm

        db_handler.add_elasticsearch_index(video_db_id, index_name, embedding_model_id)
    except Exception as e:
        logger.error(f"Error storing index information: {str(e)}")
        return None

    logger.info(f"Processed and indexed transcript for video {video_id}")
    return index_name
258
+
259
def process_multiple_videos(video_ids, embedding_model):
    """Process several videos and collect the indices that were built.

    Not wrapped in ``st.cache_data``: a cached result would pin partial
    failures, so calling again with the same ID list would never retry the
    videos that failed the first time.

    Args:
        video_ids: Iterable of YouTube video IDs.
        embedding_model: Name of the embedding model to index with.

    Returns:
        List of index names for the videos that were processed successfully
        (possibly empty).
    """
    indices = []
    for video_id in video_ids:
        index = process_single_video(video_id, embedding_model)
        if index:  # failed videos return None and are skipped
            indices.append(index)
    logger.info(f"Processed and indexed transcripts for {len(indices)} videos")
    st.success(f"Processed and indexed transcripts for {len(indices)} videos")
    return indices
269
+
270
  def main():
271
  st.title("YouTube Transcript RAG System")
272
 
273
+ components = init_components()
274
+ if not all(components):
275
+ st.error("Failed to initialize one or more components. Please check the logs and your configuration.")
276
+ return
277
+
278
+ db_handler, data_processor, rag_system, query_rewriter, evaluation_system = components
279
+
280
  tab1, tab2, tab3 = st.tabs(["RAG System", "Ground Truth Generation", "Evaluation"])
281
 
282
  with tab1:
283
  st.header("RAG System")
284
+
285
+ # Video selection section
286
+ st.subheader("Select a Video")
287
+ videos = db_handler.get_all_videos()
288
+ if not videos:
289
+ st.warning("No videos available. Please process some videos first.")
290
+ else:
291
+ video_df = pd.DataFrame(videos, columns=['youtube_id', 'title', 'channel_name', 'upload_date'])
292
+
293
+ # Allow filtering by channel name
294
+ channels = sorted(video_df['channel_name'].unique())
295
+ selected_channel = st.selectbox("Filter by Channel", ["All"] + channels)
296
+
297
+ if selected_channel != "All":
298
+ video_df = video_df[video_df['channel_name'] == selected_channel]
299
+
300
+ # Display videos and allow selection
301
+ st.dataframe(video_df)
302
+ selected_video_id = st.selectbox("Select a Video", video_df['youtube_id'].tolist(), format_func=lambda x: video_df[video_df['youtube_id'] == x]['title'].iloc[0])
303
+
304
+ # Embedding model selection
305
+ embedding_model = st.selectbox("Select embedding model:", ["all-MiniLM-L6-v2", "all-mpnet-base-v2"])
306
+
307
+ # Get the index name for the selected video and embedding model
308
+ index_name = db_handler.get_elasticsearch_index_by_youtube_id(selected_video_id, embedding_model)
309
+
310
+ if index_name:
311
+ st.success(f"Using index: {index_name}")
312
+ else:
313
+ st.warning("No index found for the selected video and embedding model. The index will be built when you search.")
314
+
315
+ # Process new video section
316
+ st.subheader("Process New Video")
317
  input_type = st.radio("Select input type:", ["Video URL", "Channel URL", "YouTube ID"])
318
  input_value = st.text_input("Enter the URL or ID:")
319
+
 
320
  if st.button("Process"):
321
  with st.spinner("Processing..."):
322
  data_processor.embedding_model = SentenceTransformer(embedding_model)
323
  if input_type == "Video URL":
324
  video_id = extract_video_id(input_value)
325
  if video_id:
326
+ index_name = process_single_video(video_id, embedding_model)
327
+ if index_name is None:
328
+ st.error(f"Failed to process video {video_id}")
329
+ else:
330
+ st.success(f"Successfully processed video {video_id}")
331
  else:
332
  st.error("Failed to extract video ID from the URL")
333
  elif input_type == "Channel URL":
334
  channel_videos = get_channel_videos(input_value)
335
  if channel_videos:
336
+ index_names = process_multiple_videos([video['video_id'] for video in channel_videos], embedding_model)
337
+ if not index_names:
338
+ st.error("Failed to process any videos from the channel")
339
+ else:
340
+ st.success(f"Successfully processed {len(index_names)} videos from the channel")
341
  else:
342
  st.error("Failed to retrieve videos from the channel")
343
  else:
344
+ index_name = process_single_video(input_value, embedding_model)
345
+ if index_name is None:
346
+ st.error(f"Failed to process video {input_value}")
347
+ else:
348
+ st.success(f"Successfully processed video {input_value}")
349
+
350
  # Query section
351
  st.subheader("Query the RAG System")
352
  query = st.text_input("Enter your query:")
 
354
  search_method = st.radio("Search method:", ["Hybrid", "Text-only", "Embedding-only"])
355
 
356
  if st.button("Search"):
357
+ if not selected_video_id:
358
+ st.error("Please select a video before searching.")
359
+ else:
360
+ with st.spinner("Searching..."):
361
+ rewritten_query = query
362
+ rewrite_prompt = ""
363
+ if rewrite_method == "Chain of Thought":
364
+ rewritten_query, rewrite_prompt = query_rewriter.rewrite_cot(query)
365
+ elif rewrite_method == "ReAct":
366
+ rewritten_query, rewrite_prompt = query_rewriter.rewrite_react(query)
367
+
368
+ st.subheader("Query Processing")
369
+ st.write("Original query:", query)
370
+ if rewrite_method != "None":
371
+ st.write("Rewritten query:", rewritten_query)
372
+ st.text_area("Query rewriting prompt:", rewrite_prompt, height=100)
373
+ if rewritten_query == query:
374
+ st.warning("Query rewriting failed. Using original query.")
375
+
376
+ search_method_map = {"Hybrid": "hybrid", "Text-only": "text", "Embedding-only": "embedding"}
377
+ try:
378
+ # Ensure index is built before searching
379
+ if not index_name:
380
+ st.info("Building index for the selected video...")
381
+ index_name = process_single_video(selected_video_id, embedding_model)
382
+ if not index_name:
383
+ st.error("Failed to build index for the selected video.")
384
+ return
385
+
386
+ response, final_prompt = rag_system.query(rewritten_query, search_method=search_method_map[search_method], index_name=index_name)
387
+
388
+ st.subheader("RAG System Prompt")
389
+ if final_prompt:
390
+ st.text_area("Prompt sent to LLM:", final_prompt, height=300)
391
+ else:
392
+ st.warning("No prompt was generated. This might indicate an issue with the RAG system.")
393
+
394
+ st.subheader("Response")
395
+ if response:
396
+ st.write(response)
397
+ else:
398
+ st.error("No response generated. Please try again or check the system logs for errors.")
399
+ except ValueError as e:
400
+ logger.error(f"Error during search: {str(e)}")
401
+ st.error(f"Error during search: {str(e)}")
402
+ except Exception as e:
403
+ logger.error(f"An unexpected error occurred: {str(e)}")
404
+ st.error(f"An unexpected error occurred: {str(e)}")
405
 
406
  with tab2:
407
  st.header("Ground Truth Generation")
408
+ use_existing_transcript = st.checkbox("Use existing transcript")
409
+
410
+ if use_existing_transcript:
411
+ # Get all available videos (assuming all videos have transcripts)
412
+ videos = db_handler.get_all_videos()
413
+ if not videos:
414
+ st.warning("No videos available. Please process some videos first.")
415
+ else:
416
+ video_df = pd.DataFrame(videos, columns=['youtube_id', 'title', 'channel_name', 'upload_date'])
417
+
418
+ # Allow filtering by channel name
419
+ channels = sorted(video_df['channel_name'].unique())
420
+ selected_channel = st.selectbox("Filter by Channel", ["All"] + channels, key="gt_channel_filter")
421
+
422
+ if selected_channel != "All":
423
+ video_df = video_df[video_df['channel_name'] == selected_channel]
424
+
425
+ # Display videos and allow selection
426
+ st.dataframe(video_df)
427
+ selected_video_id = st.selectbox("Select a Video", video_df['youtube_id'].tolist(),
428
+ format_func=lambda x: video_df[video_df['youtube_id'] == x]['title'].iloc[0],
429
+ key="gt_video_select")
430
+
431
+ if st.button("Generate Ground Truth from Existing Transcript"):
432
+ with st.spinner("Generating ground truth..."):
433
+ # Retrieve the transcript content (you'll need to implement this method)
434
+ transcript_data = get_transcript(selected_video_id)
435
+ if transcript_data and 'transcript' in transcript_data:
436
+ full_transcript = " ".join([entry['text'] for entry in transcript_data['transcript']])
437
+ ground_truth_df = generate_ground_truth(video_id=selected_video_id, existing_transcript=full_transcript)
438
+ if ground_truth_df is not None:
439
+ st.dataframe(ground_truth_df)
440
+ csv = ground_truth_df.to_csv(index=False)
441
+ st.download_button(
442
+ label="Download Ground Truth CSV",
443
+ data=csv,
444
+ file_name=f"ground_truth_{selected_video_id}.csv",
445
+ mime="text/csv",
446
+ )
447
+ else:
448
+ st.error("Failed to retrieve transcript content.")
449
+ else:
450
+ video_id = st.text_input("Enter YouTube Video ID for ground truth generation:")
451
+ if st.button("Generate Ground Truth"):
452
+ with st.spinner("Generating ground truth..."):
453
+ ground_truth_df = generate_ground_truth(video_id=video_id)
454
+ if ground_truth_df is not None:
455
+ st.dataframe(ground_truth_df)
456
+ csv = ground_truth_df.to_csv(index=False)
457
+ st.download_button(
458
+ label="Download Ground Truth CSV",
459
+ data=csv,
460
+ file_name=f"ground_truth_{video_id}.csv",
461
+ mime="text/csv",
462
+ )
463
 
464
  with tab3:
465
  st.header("RAG Evaluation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
+ # Load ground truth data
468
+ try:
469
+ ground_truth_df = pd.read_csv('data/ground-truth-retrieval.csv')
470
+ ground_truth_available = True
471
+ except FileNotFoundError:
472
+ ground_truth_available = False
 
 
 
 
 
 
 
 
473
 
474
+ if ground_truth_available:
475
+ st.write("Evaluation will be run on the following ground truth data:")
476
+ st.dataframe(ground_truth_df)
477
+ st.info("The evaluation will use this ground truth data to assess the performance of the RAG system.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
+ sample_size = st.number_input("Enter sample size for evaluation:", min_value=1, max_value=len(ground_truth_df), value=min(200, len(ground_truth_df)))
480
+
481
+ if st.button("Run Evaluation"):
482
+ with st.spinner("Running evaluation..."):
483
+ evaluation_results = evaluate_rag(sample_size)
484
+ if evaluation_results:
485
+ st.write("Evaluation Results:")
486
+ st.dataframe(pd.DataFrame(evaluation_results, columns=['Video ID', 'Question', 'Answer', 'Relevance', 'Explanation']))
487
+ else:
488
+ st.warning("No ground truth data available. Please generate ground truth data first.")
489
+ st.button("Run Evaluation", disabled=True)
490
+
491
+ # Add a section to generate ground truth if it's not available
492
+ if not ground_truth_available:
493
+ st.subheader("Generate Ground Truth")
494
+ st.write("You need to generate ground truth data before running the evaluation.")
495
+ if st.button("Go to Ground Truth Generation"):
496
+ st.session_state.active_tab = "Ground Truth Generation"
497
+ st.experimental_rerun()
498
 
499
  if __name__ == "__main__":
500
  main()
app/query_rewriter.py CHANGED
@@ -1,8 +1,24 @@
 
1
  import ollama
 
 
 
2
 
3
  class QueryRewriter:
4
  def __init__(self):
5
- self.model = "phi" # Using Phi-3.5 model
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def rewrite_cot(self, query):
8
  prompt = f"""
@@ -11,8 +27,11 @@ class QueryRewriter:
11
 
12
  Rewritten query:
13
  """
14
- response = ollama.generate(model=self.model, prompt=prompt)
15
- return response['response'].strip()
 
 
 
16
 
17
  def rewrite_react(self, query):
18
  prompt = f"""
@@ -29,5 +48,8 @@ class QueryRewriter:
29
 
30
  Final rewritten query:
31
  """
32
- response = ollama.generate(model=self.model, prompt=prompt)
33
- return response['response'].strip()
 
 
 
 
1
+ import os
2
  import ollama
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
 
7
  class QueryRewriter:
8
  def __init__(self):
9
+ self.model = os.getenv('OLLAMA_MODEL', "phi3")
10
+ self.ollama_host = os.getenv('OLLAMA_HOST', 'http://ollama:11434')
11
+
12
+ def generate(self, prompt):
13
+ try:
14
+ response = ollama.chat(
15
+ model=self.model,
16
+ messages=[{"role": "user", "content": prompt}]
17
+ )
18
+ return response['message']['content']
19
+ except Exception as e:
20
+ logger.error(f"Error generating response: {e}")
21
+ return f"Error: {str(e)}"
22
 
23
  def rewrite_cot(self, query):
24
  prompt = f"""
 
27
 
28
  Rewritten query:
29
  """
30
+ rewritten_query = self.generate(prompt)
31
+ if rewritten_query.startswith("Error:"):
32
+ logger.error(f"Error in CoT rewriting: {rewritten_query}")
33
+ return query, prompt # Return original query if rewriting fails
34
+ return rewritten_query, prompt
35
 
36
  def rewrite_react(self, query):
37
  prompt = f"""
 
48
 
49
  Final rewritten query:
50
  """
51
+ rewritten_query = self.generate(prompt)
52
+ if rewritten_query.startswith("Error:"):
53
+ logger.error(f"Error in ReAct rewriting: {rewritten_query}")
54
+ return query, prompt # Return original query if rewriting fails
55
+ return rewritten_query, prompt
app/rag.py CHANGED
@@ -1,31 +1,108 @@
 
 
1
  import ollama
 
 
 
 
 
 
2
 
3
  class RAGSystem:
4
  def __init__(self, data_processor):
5
  self.data_processor = data_processor
6
- self.model = "phi3.5" # Using Phi-3.5 model
7
-
8
- def query(self, user_query, top_k=3, search_method='hybrid'):
9
- # Retrieve relevant documents using the specified search method
10
- relevant_docs = self.data_processor.search(user_query, num_results=top_k, method=search_method)
11
 
12
- # Construct the prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  context = "\n".join([doc['content'] for doc in relevant_docs])
14
- prompt = f"Context: {context}\n\nQuestion: {user_query}\n\nAnswer:"
15
-
16
- # Generate response using Ollama
17
- response = ollama.generate(model=self.model, prompt=prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- return response['response']
20
-
21
- def rerank_documents(self, documents, query):
22
- # Implement a simple re-ranking strategy
23
- # This could be improved with more sophisticated methods
24
- reranked = sorted(documents, key=lambda doc: self.calculate_relevance(doc['content'], query), reverse=True)
25
- return reranked
26
-
27
- def calculate_relevance(self, document, query):
28
- # Simple relevance calculation based on word overlap
29
- doc_words = set(document.lower().split())
30
- query_words = set(query.lower().split())
31
- return len(doc_words.intersection(query_words)) / len(query_words)
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  import ollama
4
+ import logging
5
+ import time
6
+
7
+ load_dotenv()
8
+
9
+ logger = logging.getLogger(__name__)
10
 
11
class RAGSystem:
    """Answer questions about indexed transcripts with an Ollama-hosted model.

    Retrieval is delegated to the ``data_processor`` passed to the constructor;
    generation goes through :meth:`generate`, which retries with exponential
    backoff.
    """

    def __init__(self, data_processor):
        """Store the retriever, read settings from the environment and probe Ollama."""
        self.data_processor = data_processor
        self.model = os.getenv('OLLAMA_MODEL', 'phi3')
        # NOTE(review): ollama_host and timeout are stored but not passed to
        # the module-level ollama calls used below -- confirm whether a
        # dedicated client should be configured with them.
        self.ollama_host = os.getenv('OLLAMA_HOST', 'http://ollama:11434')
        self.timeout = int(os.getenv('OLLAMA_TIMEOUT', 240))
        self.max_retries = int(os.getenv('OLLAMA_MAX_RETRIES', 3))

        self.check_ollama_service()

    def check_ollama_service(self):
        """Check that Ollama responds and pull the model; failures are only logged."""
        try:
            ollama.list()
            logger.info("Ollama service is accessible.")
            self.pull_model()
        except Exception as e:
            logger.error(f"An error occurred while connecting to Ollama: {e}")
            logger.error(f"Please ensure Ollama is running and accessible at {self.ollama_host}")

    def pull_model(self):
        """Pull ``self.model`` through Ollama; failures are only logged."""
        try:
            ollama.pull(self.model)
            logger.info(f"Successfully pulled model {self.model}.")
        except Exception as e:
            logger.error(f"Error pulling model {self.model}: {e}")

    def generate(self, prompt):
        """Return the model's reply to *prompt*, or ``None`` once all retries fail.

        Up to ``self.max_retries`` attempts are made, sleeping ``2 ** attempt``
        seconds between failed attempts.
        """
        for attempt in range(self.max_retries):
            try:
                response = ollama.chat(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                )
                answer = response['message']['content']
            except Exception as e:
                logger.error(f"Error generating response on attempt {attempt + 1}: {e}")
                if attempt == self.max_retries - 1:
                    logger.error("All retries exhausted. Unable to generate response.")
                    return None
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                # Log instead of printing so output follows the logging config.
                logger.debug("Response from Ollama: %s", answer)
                return answer
        return None

    def get_prompt(self, user_query, relevant_docs):
        """Build the LLM prompt from the query and the retrieved documents.

        Each document must be a mapping with a ``'content'`` entry; the
        contents are joined with newlines to form the context.
        """
        context = "\n".join([doc['content'] for doc in relevant_docs])
        return (
            "You are AI Youtube transcript assistant that analyses youtube "
            "transcripts and responds back to the user query based on the "
            "Context shared with you. Please ensure that the answers are "
            "correct, meaningful, and help in answering the query.\n"
            "\n"
            f"Context: {context}\n"
            "\n"
            f"Question: {user_query}\n"
            "\n"
            "Answer:"
        )

    def query(self, user_query, search_method='hybrid', index_name=None, top_k=3):
        """Retrieve context for *user_query* and generate an answer.

        Args:
            user_query: The question to answer.
            search_method: Retrieval method passed to the data processor.
            index_name: Index to search; required.
            top_k: Number of documents to retrieve (default 3).

        Returns:
            ``(answer, prompt)``. If retrieval or generation fails, ``answer``
            is an explanatory message and ``prompt`` is the prompt built so
            far, or ``""`` when none was built.

        Raises:
            ValueError: If *index_name* is empty. Raised outside the broad
                handler below so callers that catch ``ValueError`` see it.
        """
        if not index_name:
            raise ValueError("No index name provided. Please select a video and ensure it has been processed.")

        prompt = ""
        try:
            relevant_docs = self.data_processor.search(
                user_query, num_results=top_k, method=search_method, index_name=index_name
            )

            if not relevant_docs:
                logger.warning("No relevant documents found for the query.")
                return "I couldn't find any relevant information to answer your query.", ""

            prompt = self.get_prompt(user_query, relevant_docs)

            # Go through generate() so queries use the same retry/backoff
            # policy as the rewrite helpers.
            answer = self.generate(prompt)
            if answer is None:
                return "An error occurred: the language model did not return a response.", prompt
            return answer, prompt
        except Exception as e:
            logger.exception("An error occurred in the RAG system: %s", e)
            return f"An error occurred: {str(e)}", ""

    def rewrite_cot(self, query):
        """Rewrite *query* with a chain-of-thought prompt.

        Returns ``(rewritten_query, prompt)``; the original query is returned
        when generation fails.
        """
        prompt = (
            "Rewrite the following query using chain-of-thought reasoning:\n"
            "\n"
            f"Query: {query}\n"
            "\n"
            "Rewritten query:"
        )
        response = self.generate(prompt)
        if response:
            return response, prompt
        return query, prompt  # Return original query if rewriting fails

    def rewrite_react(self, query):
        """Rewrite *query* with a ReAct-style prompt.

        Returns ``(rewritten_query, prompt)``; the original query is returned
        when generation fails.
        """
        prompt = (
            "Rewrite the following query using ReAct (Reasoning and Acting) approach:\n"
            "\n"
            f"Query: {query}\n"
            "\n"
            "Rewritten query:"
        )
        response = self.generate(prompt)
        if response:
            return response, prompt
        return query, prompt  # Return original query if rewriting fails
app/transcript_extractor.py CHANGED
@@ -1,15 +1,35 @@
 
 
1
  from youtube_transcript_api import YouTubeTranscriptApi
2
  from googleapiclient.discovery import build
3
  from googleapiclient.errors import HttpError
4
  import re
5
- import os
6
 
7
- # Replace with your actual API key
8
- API_KEY = os.environ.get('YOUTUBE_API_KEY', 'YOUR_API_KEY_HERE')
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- youtube = build('youtube', 'v3', developerKey=API_KEY)
 
 
 
 
 
 
11
 
12
  def extract_video_id(url):
 
 
13
  video_id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", url)
14
  if video_id_match:
15
  return video_id_match.group(1)
@@ -30,21 +50,53 @@ def get_video_metadata(video_id):
30
  'title': snippet['title'],
31
  'author': snippet['channelTitle'],
32
  'upload_date': snippet['publishedAt'],
33
- 'view_count': video['statistics']['viewCount'],
34
- 'like_count': video['statistics'].get('likeCount', 'N/A'),
35
- 'comment_count': video['statistics'].get('commentCount', 'N/A'),
36
  'duration': video['contentDetails']['duration']
37
  }
38
  else:
 
39
  return None
40
  except HttpError as e:
41
  print(f"An HTTP error {e.resp.status} occurred: {e.content}")
42
  return None
 
 
 
43
 
44
  def get_transcript(video_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  try:
46
  transcript = YouTubeTranscriptApi.get_transcript(video_id)
47
  metadata = get_video_metadata(video_id)
 
 
 
 
48
  return {
49
  'transcript': transcript,
50
  'metadata': metadata
@@ -53,7 +105,11 @@ def get_transcript(video_id):
53
  print(f"Error extracting transcript for video {video_id}: {str(e)}")
54
  return None
55
 
56
- def get_channel_videos(channel_id):
 
 
 
 
57
  try:
58
  request = youtube.search().list(
59
  part="id,snippet",
@@ -75,6 +131,15 @@ def get_channel_videos(channel_id):
75
  except HttpError as e:
76
  print(f"An HTTP error {e.resp.status} occurred: {e.content}")
77
  return []
 
 
 
 
 
 
 
 
 
78
 
79
  def process_videos(video_ids):
80
  transcripts = {}
 
1
+ import os
2
+ from dotenv import load_dotenv
3
  from youtube_transcript_api import YouTubeTranscriptApi
4
  from googleapiclient.discovery import build
5
  from googleapiclient.errors import HttpError
6
  import re
 
7
 
8
# Locate the .env file: it is expected one directory above this script
# (i.e. next to the 'app' folder).
current_dir = os.path.dirname(os.path.abspath(__file__))
dotenv_path = os.path.join(os.path.dirname(current_dir), '.env')
print("the .env path is :" + dotenv_path)
# Load environment variables from the .env file
load_dotenv(dotenv_path)

# Read the API key and validate it BEFORE using it anywhere: os.getenv returns
# None when the variable is unset, and concatenating None into a string would
# raise TypeError ahead of the intended ValueError below.
API_KEY = os.getenv('YOUTUBE_API_KEY')
if not API_KEY:
    raise ValueError("YouTube API key not found. Make sure it's set in your .env file in the parent directory of the 'app' folder.")

# NOTE: the key is deliberately never printed (not even partially) so the
# secret does not end up in stdout or container logs.

# Build the shared YouTube Data API client used by the functions below;
# report the failure and re-raise so the import fails loudly.
try:
    youtube = build('youtube', 'v3', developerKey=API_KEY)
except Exception as e:
    print(f"Error initializing YouTube API client: {str(e)}")
    raise
29
 
30
  def extract_video_id(url):
31
+ if not url:
32
+ return None
33
  video_id_match = re.search(r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", url)
34
  if video_id_match:
35
  return video_id_match.group(1)
 
50
  'title': snippet['title'],
51
  'author': snippet['channelTitle'],
52
  'upload_date': snippet['publishedAt'],
53
+ 'view_count': video['statistics'].get('viewCount', '0'),
54
+ 'like_count': video['statistics'].get('likeCount', '0'),
55
+ 'comment_count': video['statistics'].get('commentCount', '0'),
56
  'duration': video['contentDetails']['duration']
57
  }
58
  else:
59
+ print(f"No video found with ID: {video_id}")
60
  return None
61
  except HttpError as e:
62
  print(f"An HTTP error {e.resp.status} occurred: {e.content}")
63
  return None
64
+ except Exception as e:
65
+ print(f"An error occurred while fetching video metadata: {str(e)}")
66
+ return None
67
 
68
  def get_transcript(video_id):
69
+ # Get the directory of the current script
70
+ current_dir = os.path.dirname(os.path.abspath(__file__))
71
+ # Construct the path to the .env file (one directory up from the current script)
72
+ dotenv_path = os.path.join(os.path.dirname(current_dir), '.env')
73
+ print("the .env path is :" + dotenv_path)
74
+ # Load environment variables from .env file
75
+ load_dotenv(dotenv_path)
76
+
77
+ # Get API key from environment variable
78
+ API_KEY = os.getenv('YOUTUBE_API_KEY')
79
+ print("the api key is :" + API_KEY)
80
+ if not API_KEY:
81
+ raise ValueError("YouTube API key not found. Make sure it's set in your .env file in the parent directory of the 'app' folder.")
82
+
83
+ print(f"API_KEY: {API_KEY[:5]}...{API_KEY[-5:]}") # Print first and last 5 characters for verification
84
+
85
+ try:
86
+ youtube = build('youtube', 'v3', developerKey=API_KEY)
87
+ except Exception as e:
88
+ print(f"Error initializing YouTube API client: {str(e)}")
89
+ raise
90
+
91
+ if not video_id:
92
+ return None
93
  try:
94
  transcript = YouTubeTranscriptApi.get_transcript(video_id)
95
  metadata = get_video_metadata(video_id)
96
+ print(f"Metadata for video {video_id}: {metadata}")
97
+ print(f"Transcript length for video {video_id}: {len(transcript)}")
98
+ if not metadata:
99
+ return None
100
  return {
101
  'transcript': transcript,
102
  'metadata': metadata
 
105
  print(f"Error extracting transcript for video {video_id}: {str(e)}")
106
  return None
107
 
108
+ def get_channel_videos(channel_url):
109
+ channel_id = extract_channel_id(channel_url)
110
+ if not channel_id:
111
+ print(f"Invalid channel URL: {channel_url}")
112
+ return []
113
  try:
114
  request = youtube.search().list(
115
  part="id,snippet",
 
131
  except HttpError as e:
132
  print(f"An HTTP error {e.resp.status} occurred: {e.content}")
133
  return []
134
+ except Exception as e:
135
+ print(f"An error occurred while fetching channel videos: {str(e)}")
136
+ return []
137
+
138
def extract_channel_id(url):
    """Extract the channel identifier segment from a YouTube channel URL.

    Recognises the path forms ``channel/<id>``, ``c/<name>`` and ``@<handle>``
    and returns the run of letters, digits, ``-`` and ``_`` that follows.

    Args:
        url: The channel URL. ``None`` or an empty string is tolerated.

    Returns:
        The captured segment, or ``None`` when ``url`` is empty or no
        recognised form is present.
    """
    # Guard first, mirroring extract_video_id: re.search raises TypeError on None.
    if not url:
        return None
    # NOTE(review): for 'c/' and '@' URLs the captured text is a custom name or
    # handle, not a 'UC...' channel ID -- confirm the caller resolves it before
    # passing it to an API parameter that expects a channel ID.
    channel_id_match = re.search(r"(?:channel\/|c\/|@)([a-zA-Z0-9-_]+)", url)
    if channel_id_match:
        return channel_id_match.group(1)
    return None
143
 
144
  def process_videos(video_ids):
145
  transcripts = {}
data/ground-truth-retrieval.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ video_id,question
2
+ zjkBMFhNj_g,What are Google Apps Script and its relation to user data security within a domain?
3
+ zjkBMFhNj_g,"How can prompt injection attacks manipulate language models' outputs using shared documents like those managed by Gmail users or Microsoft Office files (Word, Excel)?"
4
+ zjkBMFhNj_g,"In the context of AI-based systems such as large language models (LLMs), how might an attacker exploit these tools to exfiltrate sensitive user data from a Google Doc? Please provide details."
5
+ zjkBMFhNj_g,"Can you explain prompt injection attacks and their potential impact on LLM predictions, including any specific examples provided in the discussion like using 'James Bond' as a trigger phrase for threat detection tasks or title generation?"
6
+ zjkBMFhNj_g,Are there defenses against these types of language model (LLM) security threats similar to traditional cybersecurity measures such as prompt injection attacks and data poisoning? Please elaborate.
7
+ zjkBMFhNj_g,"What does the future hold for LLMs considering their benefits, potential risks including adversarial exploitation like those discussed here, regulatory oversight needs due to privacy concerns (GDPR), mitigation of harmful outputs by these models in various applications?"
data/sqlite.db CHANGED
Binary files a/data/sqlite.db and b/data/sqlite.db differ
 
docker-compose.yaml CHANGED
@@ -7,10 +7,14 @@ services:
7
  - "8501:8501"
8
  depends_on:
9
  - elasticsearch
 
10
  environment:
11
  - ELASTICSEARCH_HOST=elasticsearch
12
  - ELASTICSEARCH_PORT=9200
13
  - YOUTUBE_API_KEY=${YOUTUBE_API_KEY}
 
 
 
14
  env_file:
15
  - .env
16
  volumes:
@@ -37,6 +41,14 @@ services:
37
  depends_on:
38
  - elasticsearch
39
 
 
 
 
 
 
 
 
40
  volumes:
41
  esdata:
42
- grafana-storage:
 
 
7
  - "8501:8501"
8
  depends_on:
9
  - elasticsearch
10
+ - ollama
11
  environment:
12
  - ELASTICSEARCH_HOST=elasticsearch
13
  - ELASTICSEARCH_PORT=9200
14
  - YOUTUBE_API_KEY=${YOUTUBE_API_KEY}
15
+ - OLLAMA_HOST=http://ollama:11434
16
+ - OLLAMA_TIMEOUT=${OLLAMA_TIMEOUT:-120}
17
+ - OLLAMA_MAX_RETRIES=${OLLAMA_MAX_RETRIES:-3}
18
  env_file:
19
  - .env
20
  volumes:
 
41
  depends_on:
42
  - elasticsearch
43
 
44
+ ollama:
45
+ image: ollama/ollama:latest
46
+ ports:
47
+ - "11434:11434"
48
+ volumes:
49
+ - ollama_data:/root/.ollama
50
+
51
  volumes:
52
  esdata:
53
+ grafana-storage:
54
+ ollama_data:
requirements.txt CHANGED
@@ -11,4 +11,5 @@ elasticsearch
11
  ollama
12
  requests
13
  matplotlib
14
- tqdm
 
 
11
  ollama
12
  requests
13
  matplotlib
14
+ tqdm
15
+ python-dotenv