File size: 7,212 Bytes
9889643 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os

import chromadb
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Define FastAPI app
app = FastAPI()
# Browser origins allowed to call this API (the local frontend dev server).
# A browser's Origin header always includes the scheme, so the former bare
# "localhost:5173" entry could never match and has been dropped.
origins = [
    "http://localhost:5173",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the dataset, embedding model and vector store once at import time so
# every request reuses them.
df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
# Each 'Symptoms' cell is a comma-separated string; turn it into a list of
# trimmed names. Missing cells become an empty list (NaN is a float, which the
# old per-item comprehension could not iterate), and empty fragments produced
# by stray or trailing commas are dropped.
df['Symptoms'] = df['Symptoms'].fillna('').astype(str).apply(
    lambda raw: [part.strip() for part in raw.split(',') if part.strip()]
)
# The endpoints return 'Treatments' in JSON responses; a missing value would be
# float NaN, which is not valid JSON, so normalise it to an empty string.
df['Treatments'] = df['Treatments'].fillna('')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
client = chromadb.PersistentClient(path='./chromadb')
# NOTE(review): this section only reads the collection; presumably it is
# populated elsewhere — confirm.
collection = client.get_or_create_collection(name="symptomsvector")
class SymptomQuery(BaseModel):
    """Request body carrying the user's free-text symptom input."""

    # Raw text; /find_matching_symptoms splits it on commas, the other
    # endpoints embed it whole.
    symptom: str
# Endpoint to handle symptom query and return matching symptoms
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return stored symptoms similar to each comma-separated query symptom.

    The query text is split on commas. Every non-blank, whitespace-trimmed
    fragment is embedded, and its 3 nearest documents are fetched from the
    ChromaDB collection. Results are concatenated in query order and
    de-duplicated, keeping the first occurrence.

    Returns:
        ``{"matching_symptoms": [...]}``. The list is empty when the query
        contains no non-blank symptom.
    """
    # Skip blank fragments (e.g. "fever,,cough", a trailing comma, or an empty
    # query) instead of embedding and searching for an empty string.
    symptoms = [part.strip() for part in query.symptom.split(',') if part.strip()]
    if not symptoms:
        return {"matching_symptoms": []}
    # Encode every fragment in one batch and issue a single vector query.
    # ChromaDB returns one list of documents per query embedding, in input
    # order, so flattening preserves the original per-symptom ordering.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,  # top 3 similar symptoms for each query symptom
    )
    all_results = [doc for docs in results['documents'] for doc in docs]
    # Remove duplicates while preserving first-seen order.
    return {"matching_symptoms": list(dict.fromkeys(all_results))}
# Endpoint to handle symptom query and return matching diseases
@app.post("/find_matching_diseases")
def find_matching_diseases(query: SymptomQuery):
    """Return names of diseases that share a symptom with the query's neighbours.

    The whole query string is embedded as a single text, and the 5 nearest
    documents are fetched from ChromaDB. Every row of ``df`` whose symptom
    list contains at least one of those documents is selected, in ``df``
    order.
    """
    embedding = model.encode([query.symptom]).tolist()
    search = collection.query(query_embeddings=embedding, n_results=5)
    neighbours = search['documents'][0]

    def shares_symptom(disease_symptoms):
        # True as soon as one of this disease's symptoms is a neighbour.
        for name in disease_symptoms:
            if name in neighbours:
                return True
        return False

    hits = df[df['Symptoms'].apply(shares_symptom)]
    return {"matching_diseases": hits['Name'].tolist()}
# Endpoint to handle symptom query and return detailed disease list
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """Return details of diseases related to the query, plus their symptoms.

    The whole query string is embedded, the 5 nearest documents are taken
    from ChromaDB, and every disease in ``df`` with at least one of them is
    selected.

    Returns:
        ``{"disease_list": [...], "unique_symptoms_list": [...]}``.
        Each disease entry has ``Disease``, ``Symptoms`` and ``Treatments``
        keys. ``unique_symptoms_list`` is the union of those diseases'
        symptoms in first-seen order.
    """
    query_embedding = model.encode([query.symptom])
    results = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=5,  # top 5 similar symptoms
    )
    matching_symptoms = results['documents'][0]
    matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]
    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
        }
        for _, row in matching_diseases.iterrows()
    ]
    # dict.fromkeys de-duplicates in linear time while keeping first-seen
    # order, replacing the previous quadratic list-membership scan.
    unique_symptoms_list = list(dict.fromkeys(
        symptom for disease in disease_list for symptom in disease['Symptoms']
    ))
    return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
class SelectedSymptomsQuery(BaseModel):
    """Request body for ``/find_disease``: the symptoms the user selected."""

    # Symptom names matched against each disease's symptom list. Declared as a
    # plain ``list``, so element types are not validated.
    selected_symptoms: list
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Rank diseases by how many of the selected symptoms they exhibit.

    Every disease in ``df`` with at least one selected symptom is returned,
    sorted by ``MatchCount`` (how many of its symptoms were selected), highest
    first. Ties keep their ``df`` order.

    Returns:
        ``{"disease_list": [...], "max_match_count_disease": ...}``, where the
        second value is the top-ranked entry, or ``None`` when nothing matched.
    """
    selected_symptoms = query.selected_symptoms
    # Keep diseases that have at least one of the selected symptoms.
    matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in x for s in selected_symptoms))]
    # assign() builds a new frame instead of writing a column into a filtered
    # slice of the shared global df (which raised SettingWithCopyWarning).
    matching_diseases = matching_diseases.assign(
        match_count=matching_diseases['Symptoms'].apply(
            lambda x: sum(s in selected_symptoms for s in x)
        )
    )
    # A stable sort keeps ties in dataset order, so the ranking is deterministic.
    matching_diseases = matching_diseases.sort_values(
        by='match_count', ascending=False, kind='stable'
    )
    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
            'MatchCount': row['match_count'],
        }
        for _, row in matching_diseases.iterrows()
    ]
    # After the descending sort, the first entry has the highest match count.
    max_match_count_disease = disease_list[0] if disease_list else None
    return {"disease_list": disease_list, "max_match_count_disease": max_match_count_disease}
class DiseaseListQuery(BaseModel):
    """Request body wrapping a list of disease entries.

    NOTE(review): no endpoint visible in this section accepts this model;
    confirm it is still used before relying on it.
    """

    # Untyped list; elements are not validated.
    disease_list: list
class DiseaseDetail(BaseModel):
    """Request body for ``/pass2llm``: a single disease entry.

    The field names mirror the dicts that ``/find_disease`` returns in
    ``disease_list``.
    """

    Disease: str
    Symptoms: list  # symptom names; element type not validated
    Treatments: str
    MatchCount: int  # number of selected symptoms this disease had
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask an LLM reached through an ngrok tunnel to summarise a disease entry.

    Steps:
      1. Call the ngrok API to list endpoints, and take the first one's
         ``public_url``.
      2. POST a non-streaming ``llama3`` request to ``{public_url}/api/generate``
         with a prompt that embeds ``query``.

    The ngrok API key is read from the ``NGROK_API_KEY`` environment variable.

    Returns:
        ``{"message": "Successfully passed to LLM!", "llm_response": ...}`` on
        success. Otherwise returns ``{"message": ..., "error": ...}``, where the
        message names the step that failed.
    """
    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"

    # SECURITY: an API key used to be hard-coded here. It remains in
    # version-control history and should be revoked; supply the replacement
    # through the environment instead.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_failure, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }
    # Timeouts stop a dead API or tunnel from blocking the worker thread forever.
    try:
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=10)
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}
    try:
        endpoints = response.json().get('endpoints') or []
    except ValueError as exc:  # body was not valid JSON
        return {"message": ngrok_failure, "error": str(exc)}
    if not endpoints:
        # No active tunnel; indexing [0] would otherwise raise an HTTP 500.
        return {"message": ngrok_failure, "error": "No active ngrok endpoints found"}
    public_url = endpoints[0]['public_url']

    # The prompt interpolates the request model's string form.
    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_payload = {
        "model": "llama3",
        "prompt": prompt,
        "stream": False,
    }
    try:
        # Generation can be slow, so allow a long read timeout.
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers={"Content-Type": "application/json"},
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response.json().get("response")}
# To run the FastAPI app with Uvicorn
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
|