Hammad712 committed on
Commit
a347f56
·
verified ·
1 Parent(s): ff08c00

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +33 -30
main.py CHANGED
@@ -25,36 +25,41 @@ class QueryRequest(BaseModel):
25
  question: str
26
 
27
 
28
- def _unpack_faiss(src: str, dest_dir: str) -> str:
29
  """
30
- If src is a .zip, unzip it into dest_dir and return
31
- the path to the extracted FAISS folder. Otherwise
32
- assume src is already a folder and return it.
33
  """
34
- if zipfile.is_zipfile(src):
35
- with zipfile.ZipFile(src, "r") as zf:
36
- zf.extractall(dest_dir)
37
- # if there’s exactly one subfolder, use it
38
- items = os.listdir(dest_dir)
39
- if len(items) == 1 and os.path.isdir(os.path.join(dest_dir, items[0])):
40
- return os.path.join(dest_dir, items[0])
41
- return dest_dir
42
- else:
43
- # src is already a directory
44
- return src
 
 
 
 
 
 
 
45
 
46
 
47
  @app.on_event("startup")
48
  def load_components():
49
  global llm, embeddings, vectorstore, retriever, chain
50
 
51
- # --- 1) Initialize LLM & Embeddings ---
52
- api_key = os.getenv("api_key")
53
  llm = ChatGroq(
54
  model="meta-llama/llama-4-scout-17b-16e-instruct",
55
  temperature=0,
56
  max_tokens=1024,
57
- api_key=api_key,
58
  )
59
 
60
  embeddings = HuggingFaceEmbeddings(
@@ -64,25 +69,23 @@ def load_components():
64
  )
65
 
66
  # --- 2) Load & merge two FAISS indexes ---
67
- # Paths to your two vectorstores (could be .zip or folders)
68
  src1 = "faiss_index.zip"
69
  src2 = "faiss_index_extra.zip"
70
 
71
- # Temporary dirs for extraction
72
- tmp1 = tempfile.mkdtemp()
73
- tmp2 = tempfile.mkdtemp()
74
 
75
- # Unpack and load each
76
- path1 = _unpack_faiss(src1, tmp1)
77
- vs1 = FAISS.load_local(path1, embeddings, allow_dangerous_deserialization=True)
78
 
79
- path2 = _unpack_faiss(src2, tmp2)
80
- vs2 = FAISS.load_local(path2, embeddings, allow_dangerous_deserialization=True)
 
81
 
82
  # Merge vs2 into vs1
83
  vs1.merge_from(vs2)
84
-
85
- # Assign the merged store to our global
86
  vectorstore = vs1
87
 
88
  # --- 3) Build retriever & QA chain ---
@@ -114,7 +117,7 @@ Your response:
114
  chain_type_kwargs={"prompt": prompt},
115
  )
116
 
117
- print("✅ Loaded and merged both FAISS indexes, QA chain is ready.")
118
 
119
 
120
  @app.get("/")
 
25
  question: str
26
 
27
 
28
def _unpack_faiss(src_path: str, extract_to: str) -> str:
    """
    Resolve a FAISS index source to the directory that holds the index.

    If src_path ends with ".zip" (case-insensitive), the archive is
    extracted into extract_to and the first directory found that contains
    a real ``.faiss`` file is returned. If src_path is already a directory,
    it is returned unchanged.

    Args:
        src_path: Path to a ``.zip`` archive or to an existing directory.
        extract_to: Directory the archive is extracted into (zip case only).

    Returns:
        The directory containing the ``.faiss`` file (zip case), or
        src_path itself (directory case).

    Raises:
        FileNotFoundError: src_path ends with ".zip" but is not a file.
        RuntimeError: The archive contains no usable ``.faiss`` file, or
            src_path is neither a ``.zip`` nor a directory.
    """
    # 1) ZIP case
    if src_path.lower().endswith(".zip"):
        if not os.path.isfile(src_path):
            raise FileNotFoundError(f"Could not find zip file: {src_path}")
        with zipfile.ZipFile(src_path, "r") as zf:
            zf.extractall(extract_to)

        # Walk until we find a real .faiss file.
        for root, dirs, files in os.walk(extract_to):
            # Archives created on macOS carry a "__MACOSX" tree full of
            # "._<name>" resource-fork stubs (e.g. "._index.faiss") that are
            # not indexes. Prune that tree in place so os.walk never descends
            # into it, and sort the rest so the result does not depend on
            # filesystem listing order.
            dirs[:] = sorted(d for d in dirs if d != "__MACOSX")
            if any(
                fn.endswith(".faiss") and not fn.startswith("._")
                for fn in files
            ):
                return root
        raise RuntimeError(f"No .faiss index found inside {src_path}")

    # 2) Directory case
    if os.path.isdir(src_path):
        return src_path

    raise RuntimeError(f"Path is neither a .zip nor a directory: {src_path}")
51
 
52
 
53
  @app.on_event("startup")
54
  def load_components():
55
  global llm, embeddings, vectorstore, retriever, chain
56
 
57
+ # --- 1) Init LLM & Embeddings ---
 
58
  llm = ChatGroq(
59
  model="meta-llama/llama-4-scout-17b-16e-instruct",
60
  temperature=0,
61
  max_tokens=1024,
62
+ api_key=os.getenv("api_key"),
63
  )
64
 
65
  embeddings = HuggingFaceEmbeddings(
 
69
  )
70
 
71
  # --- 2) Load & merge two FAISS indexes ---
 
72
  src1 = "faiss_index.zip"
73
  src2 = "faiss_index_extra.zip"
74
 
75
+ # Use TemporaryDirectory objects so they stick around until program exit
76
+ tmp1 = tempfile.TemporaryDirectory()
77
+ tmp2 = tempfile.TemporaryDirectory()
78
 
79
+ # Unpack & locate
80
+ dir1 = _unpack_faiss(src1, tmp1.name)
81
+ dir2 = _unpack_faiss(src2, tmp2.name)
82
 
83
+ # Load them
84
+ vs1 = FAISS.load_local(dir1, embeddings, allow_dangerous_deserialization=True)
85
+ vs2 = FAISS.load_local(dir2, embeddings, allow_dangerous_deserialization=True)
86
 
87
  # Merge vs2 into vs1
88
  vs1.merge_from(vs2)
 
 
89
  vectorstore = vs1
90
 
91
  # --- 3) Build retriever & QA chain ---
 
117
  chain_type_kwargs={"prompt": prompt},
118
  )
119
 
120
+ print("✅ Loaded & merged both FAISS indexes, QA chain ready.")
121
 
122
 
123
  @app.get("/")