|
|
|
|
|
import streamlit as st |
|
import os |
|
import re |
|
import pandas as pd |
|
from langchain_pinecone import PineconeVectorStore |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain.docstore.document import Document |
|
from dotenv import load_dotenv |
|
from pinecone import Pinecone |
|
from openai import OpenAI |
|
|
|
|
|
# Load environment variables (e.g. API keys, org/project IDs) from a local .env file.
load_dotenv()
|
|
|
|
|
def get_openai_client():
    """Build an OpenAI client bound to the configured organization/project.

    Organization and project IDs come from the OPENAI_ORG_ID and
    OPENAI_PROJECT_ID environment variables; the API key itself is read
    from the environment by the OpenAI SDK's own defaults.
    """
    org_id = os.getenv('OPENAI_ORG_ID')
    project_id = os.getenv('OPENAI_PROJECT_ID')
    return OpenAI(organization=org_id, project=project_id)
|
|
|
|
|
@st.cache_resource
def initialize_embeddings(model_name: str = "all-mpnet-base-v2"):
    """Load a HuggingFace sentence-embedding model, cached across Streamlit reruns."""
    return HuggingFaceEmbeddings(model_name=model_name)
|
|
|
|
|
@st.cache_resource
def initialize_vector_store(pinecone_api_key: str, index_name: str):
    """Connect to a Pinecone index and wrap it in a LangChain vector store.

    Cached across Streamlit reruns so the connection and embedding model
    are created only once per session.

    Returns:
        tuple: (PineconeVectorStore, embeddings model) for the named index.
    """
    client = Pinecone(api_key=pinecone_api_key)
    pinecone_index = client.Index(index_name)
    embedding_model = initialize_embeddings()
    store = PineconeVectorStore(
        index=pinecone_index,
        embedding=embedding_model,
        text_key='content',
    )
    return store, embedding_model
|
|
|
|
|
def get_docs(vector_store, embeddings, query, country=None, vulnerability_cat=None):
    """Run a filtered similarity search and return the results as Documents.

    Args:
        vector_store: PineconeVectorStore to search.
        embeddings: embedding model used to embed ``query``.
        query: free-text search query.
        country: optional list of country names to filter on; empty/None
            means no country filter ("All Countries").
        vulnerability_cat: optional list of vulnerability categories to
            filter on; empty/None means no category filter.

    Returns:
        list[Document]: top-20 matches, each carrying its source metadata
        plus a 1-based ``ref_id`` (so the LLM can cite "ref. N") and the
        similarity ``score``.
    """
    # None defaults instead of mutable [] defaults (shared-between-calls pitfall);
    # both are only truth-tested below, so callers see identical behavior.
    if not country:
        country = "All Countries"

    # Build the Pinecone metadata filter from whichever facets were supplied.
    if not vulnerability_cat:
        filters = None if country == "All Countries" else {'country': {'$in': country}}
    elif country == "All Countries":
        filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
    else:
        filters = {
            'country': {'$in': country},
            'vulnerability_cat': {'$in': vulnerability_cat},
        }

    docs = vector_store.similarity_search_by_vector_with_score(
        embeddings.embed_query(query),
        k=20,
        filter=filters,
    )

    # Flatten (Document, score) pairs, then number the rows 1..n so answers
    # can reference them as "ref. N".
    docs_dict = [
        {**doc.metadata, "score": score, "content": doc.page_content}
        for doc, score in docs
    ]
    df_docs = pd.DataFrame(docs_dict).reset_index()
    df_docs['ref_id'] = df_docs.index + 1

    return [
        Document(
            page_content=row['content'],
            metadata={
                'country': row['country'],
                'document': row['document'],
                'page': row['page'],
                'file_name': row['file_name'],
                'ref_id': row['ref_id'],
                'vulnerability_cat': row['vulnerability_cat'],
                'score': row['score'],
            },
        )
        for _, row in df_docs.iterrows()
    ]
|
|
|
|
|
def get_refs(docs, res):
    """Render the source passages cited in a model answer as markdown.

    Scans ``res`` case-insensitively for citations of the form "ref. N",
    then emits one markdown entry per cited document, including country,
    document type, page, vulnerability categories, and the quoted passage.
    """
    cited = {int(m) for m in re.findall(r'ref\. (\d+)', res.lower())}
    entries = []
    for doc in docs:
        meta = doc.metadata
        if meta['ref_id'] not in cited:
            continue
        # Supplementary documents additionally show the underlying file name.
        if meta['document'] == "Supplementary":
            entries.append(
                f"**Ref. {meta['ref_id']} [{meta['country']} {meta['document']}: {meta['file_name']} p{meta['page']}; "
                f"vulnerabilities: {meta['vulnerability_cat']}]:** *'{doc.page_content}'*<br><br>"
            )
        else:
            entries.append(
                f"**Ref. {meta['ref_id']} [{meta['country']} {meta['document']} p{meta['page']}; "
                f"vulnerabilities: {meta['vulnerability_cat']}]:** *'{doc.page_content}'*<br><br>"
            )
    return "".join(entries)
|
|
|
|
|
def get_prompt(prompt_template, docs, input_query):
    """Assemble the final LLM prompt: template, delimited context passages, question.

    Each passage is tagged with its ref_id and document type inside
    "&&& ... &&&" markers so the model can cite sources as "ref. N".
    """
    segments = (
        f"&&& [ref. {doc.metadata['ref_id']}] {doc.metadata['document']} &&&: {doc.page_content}"
        for doc in docs
    )
    context = ' - '.join(segments)
    return f"{prompt_template}; Context: {context}; Question: {input_query}; Answer:"
|
|
|
|
|
def run_query(client, prompt, docs, res_box):
    """Stream a chat completion, display the answer, and return its references.

    Args:
        client: OpenAI client (see get_openai_client).
        prompt: fully assembled prompt string (see get_prompt).
        docs: context Documents, searched for the ref_ids the answer cites.
        res_box: Streamlit element that displays the finished answer.

    Returns:
        str: markdown block of cited references (see get_refs).
    """
    stream = client.chat.completions.create(
        model="gpt-4o-mini-2024-07-18",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )

    # Accumulate streamed deltas; some chunks carry no content (None).
    pieces = []
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            pieces.append(delta)
    answer = "".join(pieces).strip()
    res_box.success(answer)

    return get_refs(docs, answer)