Michaelj1 committed on
Commit
55b3aef
·
1 Parent(s): eab4a95

download submodule

Browse files
Files changed (1) hide show
  1. app.py +15 -0
app.py CHANGED
@@ -7,9 +7,11 @@ import gradio as gr
7
  import matplotlib.pyplot as plt
8
  import tempfile
9
  import os
 
10
 
11
  class MedicalRAG:
12
  def __init__(self, embed_path, pmids_path, content_path):
 
13
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
  # Load data
15
  self.embeddings = np.load(embed_path)
@@ -18,6 +20,19 @@ class MedicalRAG:
18
  # Setup models
19
  self.encoder, self.tokenizer = self._setup_encoder()
20
  self.generator = self._setup_generator()
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def _create_faiss_index(self, embeddings):
23
  index = faiss.IndexFlatIP(768) # 768 is embedding dimension
 
7
  import matplotlib.pyplot as plt
8
  import tempfile
9
  import os
10
+ import subprocess
11
 
12
  class MedicalRAG:
13
  def __init__(self, embed_path, pmids_path, content_path):
14
+ self.download_files()
15
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  # Load data
17
  self.embeddings = np.load(embed_path)
 
20
  # Setup models
21
  self.encoder, self.tokenizer = self._setup_encoder()
22
  self.generator = self._setup_generator()
23
def download_files(self):
    """Download the three chunk-36 data files into the current directory.

    Each file is saved under the last path component of its URL. A file
    whose name already exists in the working directory is not fetched
    again.

    The transfer is written to ``<name>.part`` and only moved to its final
    name once ``wget`` reports success. Without this, an interrupted
    download would leave a truncated file that later runs would mistake
    for a complete one and skip.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero
            status. The partial file is removed before re-raising.
    """
    urls = [
        "https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/embeds_chunk_36.npy",
        "https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pmids_chunk_36.json",
        "https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pubmed_chunk_36.json",
    ]
    for url in urls:
        file_name = url.split('/')[-1]
        if os.path.exists(file_name):
            print(f"{file_name} already exists. Skipping download.")
            continue

        print(f"Downloading {file_name}...")
        partial_name = f"{file_name}.part"
        try:
            # -O sends the payload to the temporary name, so the final name
            # never refers to an incomplete file.
            subprocess.run(["wget", "-O", partial_name, url], check=True)
        except BaseException:
            # Covers wget failures and interrupts alike: drop whatever was
            # written so the next run starts this file from scratch.
            if os.path.exists(partial_name):
                os.remove(partial_name)
            raise
        # Publish the finished download under its real name in one step.
        os.replace(partial_name, file_name)
36
 
37
  def _create_faiss_index(self, embeddings):
38
  index = faiss.IndexFlatIP(768) # 768 is embedding dimension