download submodule
Browse files
app.py
CHANGED
@@ -7,9 +7,11 @@ import gradio as gr
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import tempfile
|
9 |
import os
|
|
|
10 |
|
11 |
class MedicalRAG:
|
12 |
def __init__(self, embed_path, pmids_path, content_path):
|
|
|
13 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
# Load data
|
15 |
self.embeddings = np.load(embed_path)
|
@@ -18,6 +20,19 @@ class MedicalRAG:
|
|
18 |
# Setup models
|
19 |
self.encoder, self.tokenizer = self._setup_encoder()
|
20 |
self.generator = self._setup_generator()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def _create_faiss_index(self, embeddings):
|
23 |
index = faiss.IndexFlatIP(768) # 768 is embedding dimension
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import tempfile
|
9 |
import os
|
10 |
+
import subprocess
|
11 |
|
12 |
class MedicalRAG:
|
13 |
def __init__(self, embed_path, pmids_path, content_path):
|
14 |
+
self.download_files()
|
15 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
# Load data
|
17 |
self.embeddings = np.load(embed_path)
|
|
|
20 |
# Setup models
|
21 |
self.encoder, self.tokenizer = self._setup_encoder()
|
22 |
self.generator = self._setup_generator()
|
23 |
+
def download_files(self):
|
24 |
+
urls = [
|
25 |
+
"https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/embeds_chunk_36.npy",
|
26 |
+
"https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pmids_chunk_36.json",
|
27 |
+
"https://ftp.ncbi.nlm.nih.gov/pub/lu/MedCPT/pubmed_embeddings/pubmed_chunk_36.json"
|
28 |
+
]
|
29 |
+
for url in urls:
|
30 |
+
file_name = url.split('/')[-1]
|
31 |
+
if not os.path.exists(file_name):
|
32 |
+
print(f"Downloading {file_name}...")
|
33 |
+
subprocess.run(["wget", url], check=True)
|
34 |
+
else:
|
35 |
+
print(f"{file_name} already exists. Skipping download.")
|
36 |
|
37 |
def _create_faiss_index(self, embeddings):
|
38 |
index = faiss.IndexFlatIP(768) # 768 is embedding dimension
|