Update views/rag_med.py
views/rag_med.py  CHANGED  +42 -59
@@ -9,24 +9,6 @@ from sentence_transformers import SentenceTransformer
 import pickle
 from prompt_template import prompt
 
-# Function to find the file
-def find_file(filename):
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    root_dir = os.path.dirname(current_dir)
-
-    possible_paths = [
-        os.path.join(current_dir, filename),
-        os.path.join(root_dir, filename),
-        os.path.join("/content", filename),  # for environments like Colab
-        os.path.join("/home/user/app/", filename)  # common path on Hugging Face
-    ]
-
-    for path in possible_paths:
-        if os.path.exists(path):
-            return path
-
-    raise FileNotFoundError(f"Não foi possível encontrar o arquivo: {filename}")
-
 # Initialize the messages list in the session state
 if "messages" not in st.session_state:
     st.session_state.messages = []
@@ -54,39 +36,41 @@ with st.sidebar:
         st.session_state.messages = []
         st.rerun()
 
-# Loading the FAISS index and the stored data
-@st.cache_resource
-def load_faiss_index():
-    index_file = find_file("faiss_index.bin")
-    data_file = find_file("stored_data.pkl")
-
-    index = faiss.read_index(index_file)
-    with open(data_file, 'rb') as f:
-        stored_data = pickle.load(f)
-    return index, stored_data
-
-try:
-    index, stored_data = load_faiss_index()
-except Exception as e:
-    st.error(f"Erro ao carregar o índice FAISS: {e}")
-    st.stop()
-
-# Initialize the embedding model
 model = SentenceTransformer("all-MiniLM-L6-v2")
 
+# FAISS index
+json_file = "scr/faiss/chunks.json"
+embeddings_file = "scr/faiss/embeddings.npy"
+index_file = "scr/faiss/faiss_index.index"
+text_chunks_file = "scr/faiss/text_chunks.npy"
+
+with open(json_file, "r") as file:
+    chunks = json.load(file)
+
+embeddings = np.load(embeddings_file)
+index = faiss.read_index(index_file)
+text_chunks = np.load(text_chunks_file)
+
+with st.sidebar:
+    openai_api_key = st.text_input("OpenAI API Key",
+                                   key="chatbot_api_key",
+                                   type="password")
+    st.markdown(
+        "[Pegue aqui sua chave OpenAI API](https://platform.openai.com/account/api-keys)"
+    )
+    if st.sidebar.button("Limpar Conversa"):
+        st.session_state.messages = []
+        st.rerun()
+
+
+def retrieve(query, k=3):
+    """
+    Retrieves the most similar document chunks to a given query.
+    """
-
     query_embedding = model.encode([query])
-    query_embedding = query_embedding.astype(np.float32).reshape(1, -1)
     distances, indices = index.search(query_embedding, k)
-
-
-    for i, idx in enumerate(indices[0]):
-        if idx != -1 and idx < len(stored_data):
-            result = stored_data[idx]
-            result['score'] = float(1 - distances[0][i]/2)  # converting distance to similarity
-            results.append(result)
-
-    return results
+    return [chunks[idx] for idx in indices[0]]
+
 
 def generate_response(query, context):
     """
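The rewritten loading step assumes four prebuilt artifacts under scr/faiss/. They are not created in this file, so below is a minimal offline indexing sketch that could produce them; the corpus, chunking, and script name are assumptions, while the file names and the all-MiniLM-L6-v2 model come from the diff.

# build_index.py -- hypothetical offline step; the repo's real build script is not shown in this diff.
import json

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

chunks = ["first document chunk ...", "second document chunk ..."]  # placeholder corpus

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(chunks)  # float32 array of shape (n_chunks, 384)

index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 search over the chunk embeddings
index.add(embeddings)

with open("scr/faiss/chunks.json", "w") as f:
    json.dump(chunks, f)
np.save("scr/faiss/embeddings.npy", embeddings)
faiss.write_index(index, "scr/faiss/faiss_index.index")
np.save("scr/faiss/text_chunks.npy", np.array(chunks))

Dropping the old astype(np.float32).reshape(1, -1) call is safe here because SentenceTransformer.encode([query]) already returns a float32 array of shape (1, dim), which index.search accepts directly.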
@@ -105,8 +89,10 @@ def generate_response(query, context):
         model="gpt-4o-mini",
         messages=[
             {
-                "role":
-                "
+                "role":
+                "system",
+                "content":
+                "Você é um assistente especializado em análise de encaminhamentos médicos.",
             },
             {
                 "role": "user",
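The hunk above completes the previously truncated system message. For reference, the call shape matches the OpenAI v1 Python client implied by response.choices[0].message.content elsewhere in the file; a self-contained sketch follows, in which the client construction and the user-message layout are assumptions while the model name and system prompt are taken from the diff.

from openai import OpenAI

client = OpenAI(api_key=openai_api_key)  # key comes from the sidebar text input

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[
        {
            "role": "system",
            "content": "Você é um assistente especializado em análise de encaminhamentos médicos.",
        },
        {
            # Hypothetical layout: the real user message lives outside this hunk.
            "role": "user",
            "content": f"Contexto:\n{context}\n\nPergunta: {query}",
        },
    ],
)
answer = response.choices[0].message.content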
@@ -116,19 +102,19 @@ def generate_response(query, context):
     )
     return response.choices[0].message.content
 
+
 co1, co2, co3 = st.columns([1.5, 0.4, 3])
 with co1:
     st.write('')
 with co2:
-    st.
-    #st.image('icons/icon.svg', width=80)
+    st.image('icons/icon.svg', width=80)
 
 with co3:
     st.title("Encaminhamento Médico")
 
 col1, col2, col3, col4, col5 = st.columns([1, 3, 0.5, 3, 1])
 
-
+fem = 'icons/iconF.svg'
 with col1:
     st.write('')
 
@@ -139,8 +125,7 @@ with col2:
     cl1, cl2, cl3 = st.columns([2, 1, 2])
     with cl1:
         idade = st.text_input("###### IDADE", value="")
-        sexo_opcao = st.radio("Sexo", ("Masculino", "Feminino"),
-                              index=0)
+        sexo_opcao = st.radio("Sexo", ("Masculino", "Feminino"), index=0)
         sexo = (sexo_opcao == "Masculino")
     with cl2:
         st.write("")
@@ -253,7 +238,6 @@ with col4:
     mas adaptando o conteúdo para ser mais fluido e coeso.
     """
     st.text_area("Prompt gerado:", user_prompt, height=600)
-
 if not openai_api_key:
     st.info("Insira sua chave API OpenAI para continuar.")
     st.stop()
@@ -291,9 +275,8 @@ if prompt := st.chat_input("Digite sua pergunta"):
         st.markdown(prompt)
 
     # Process the query and generate the response
-
-
-    response = generate_response(prompt, context)
+    retrieved_context = retrieve(prompt)
+    response = generate_response(prompt, retrieved_context)
 
     # Add the bot's reply
     st.session_state.messages.append({
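With this hunk the chat handler composes the two pieces added earlier: retrieve() pulls the top-k chunks from the FAISS index and generate_response() hands them to the model. A short usage sketch outside Streamlit, assuming the diff's definitions of both functions; the question text is a placeholder.

# Hypothetical standalone run of the new retrieval flow (Streamlit session omitted).
question = "Paciente com suspeita de dengue: qual especialidade indicar?"

retrieved_context = retrieve(question)           # top-3 chunks by L2 distance
answer = generate_response(question, retrieved_context)
print(answer)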
@@ -303,4 +286,4 @@ if prompt := st.chat_input("Digite sua pergunta"):
     with st.chat_message("assistant"):
         st.markdown(response)
 with col5:
-    st.write('')
+    st.write('')