CosmickVisions committed on
Commit
645418b
·
verified ·
1 Parent(s): 8b2cff8

Update app.py

Browse files
Files changed (1)
  1. app.py +42 -8
app.py CHANGED
@@ -5,7 +5,7 @@ import tempfile
5
  import uuid
6
  from dotenv import load_dotenv
7
  from langchain_community.vectorstores import FAISS
8
- from langchain_community.embeddings import HuggingFaceEmbeddings
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  import fitz # PyMuPDF
11
  import base64
@@ -16,11 +16,26 @@ import json
16
  import re
17
  from datetime import datetime, timedelta
18
  from pathlib import Path
 
19
 
20
  # Load environment variables
21
  load_dotenv()
22
  client = groq.Client(api_key=os.getenv("GROQ_TECH_API_KEY"))
23
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Directory to store FAISS indexes
26
  FAISS_INDEX_DIR = "faiss_indexes_tech"
@@ -116,11 +131,14 @@ def generate_response(message, session_id, model_name, history):
116
  return history
117
  try:
118
  context = ""
119
- if session_id and session_id in user_vectorstores:
120
- vectorstore = user_vectorstores[session_id]
121
- docs = vectorstore.similarity_search(message, k=3)
122
- if docs:
123
- context = "\n\nRelevant information from uploaded PDF:\n" + "\n".join(f"- {doc.page_content}" for doc in docs)
 
 
 
124
 
125
  # Check if it's a GitHub repo search
126
  if re.match(r'^/github\s+.+', message, re.IGNORECASE):
@@ -447,7 +465,23 @@ def process_code_file(file_obj):
447
  # Calculate metrics
448
  metrics = calculate_complexity_metrics(content, language)
449
 
450
- return str(uuid.uuid4()), f"✅ Successfully analyzed {file_obj.name}", metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  except Exception as e:
452
  return None, f"Error processing file: {str(e)}", {}
453
 
 
5
  import uuid
6
  from dotenv import load_dotenv
7
  from langchain_community.vectorstores import FAISS
8
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
10
  import fitz # PyMuPDF
11
  import base64
 
16
  import re
17
  from datetime import datetime, timedelta
18
  from pathlib import Path
19
+ import torch
20
 
21
  # Load environment variables
22
  load_dotenv()
23
  client = groq.Client(api_key=os.getenv("GROQ_TECH_API_KEY"))
24
+
25
+ # Replace the embeddings initialization
26
+ try:
27
+ # Initialize embeddings with a simpler, more reliable model
28
+ embeddings = HuggingFaceInstructEmbeddings(
29
+ model_name="hkunlp/instructor-base",
30
+ model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
31
+ )
32
+ except Exception as e:
33
+ print(f"Warning: Failed to load primary embeddings model: {e}")
34
+ # Fallback to a basic model
35
+ embeddings = HuggingFaceInstructEmbeddings(
36
+ model_name="all-MiniLM-L6-v2",
37
+ model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
38
+ )
39
 
40
  # Directory to store FAISS indexes
41
  FAISS_INDEX_DIR = "faiss_indexes_tech"
 
131
  return history
132
  try:
133
  context = ""
134
+ if embeddings and session_id and session_id in user_vectorstores: # Check if embeddings exist
135
+ try:
136
+ vectorstore = user_vectorstores[session_id]
137
+ docs = vectorstore.similarity_search(message, k=3)
138
+ if docs:
139
+ context = "\n\nRelevant information from uploaded code:\n" + "\n".join(f"- {doc.page_content}" for doc in docs)
140
+ except Exception as e:
141
+ print(f"Warning: Failed to perform similarity search: {e}")
142
 
143
  # Check if it's a GitHub repo search
144
  if re.match(r'^/github\s+.+', message, re.IGNORECASE):
 
465
  # Calculate metrics
466
  metrics = calculate_complexity_metrics(content, language)
467
 
468
+ # Only create vectorstore if embeddings are available
469
+ if embeddings:
470
+ try:
471
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
472
+ chunks = text_splitter.create_documents([content])
473
+ vectorstore = FAISS.from_documents(chunks, embeddings)
474
+ session_id = str(uuid.uuid4())
475
+ index_path = os.path.join(FAISS_INDEX_DIR, session_id)
476
+ vectorstore.save_local(index_path)
477
+ user_vectorstores[session_id] = vectorstore
478
+ except Exception as e:
479
+ print(f"Warning: Failed to create vectorstore: {e}")
480
+ session_id = None
481
+ else:
482
+ session_id = None
483
+
484
+ return session_id, f"✅ Successfully analyzed {file_obj.name}", metrics
485
  except Exception as e:
486
  return None, f"Error processing file: {str(e)}", {}
487