# utils.py
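"""Utility helpers for a Streamlit RAG app: embed the user's query, retrieve
matching chunks from a Pinecone index, build a citation-friendly prompt, and
stream the answer from the OpenAI chat API."""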

import streamlit as st
import os
import re
import pandas as pd
from langchain_pinecone import PineconeVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from dotenv import load_dotenv
from pinecone import Pinecone
from openai import OpenAI

# Load environment variables
load_dotenv()

# Initialize OpenAI client (the SDK reads the API key from the OPENAI_API_KEY
# environment variable; org and project IDs are passed explicitly from .env)
def get_openai_client():
    return OpenAI(
        organization=os.getenv('OPENAI_ORG_ID'),
        project=os.getenv('OPENAI_PROJECT_ID')
    )

# Initialize embeddings (st.cache_resource loads the model once per process)
@st.cache_resource
def initialize_embeddings(model_name: str = "all-mpnet-base-v2"):
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return embeddings

# Initialize vector store
@st.cache_resource
def initialize_vector_store(pinecone_api_key: str, index_name: str):
    pc = Pinecone(api_key=pinecone_api_key)
    index = pc.Index(index_name)
    embeddings = initialize_embeddings()
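    # text_key='content' tells the store which metadata field holds the raw chunk text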
    vector_store = PineconeVectorStore(index=index, embedding=embeddings, text_key='content')
    return vector_store, embeddings

# Fetch documents based on query and filters
def get_docs(vector_store, embeddings, query, country=None, vulnerability_cat=None):
    # Build the Pinecone metadata filter from the selected facets: an empty
    # selection means "no filter" on that field, and no selections at all
    # means an unfiltered search (filters=None)
    filters = {}
    if country:
        filters['country'] = {'$in': country}
    if vulnerability_cat:
        filters['vulnerability_cat'] = {'$in': vulnerability_cat}
    filters = filters or None

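    # Embed the query once and retrieve the 20 nearest chunks with their scores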
    docs = vector_store.similarity_search_by_vector_with_score(
        embeddings.embed_query(query),
        k=20,
        filter=filters,
    )

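    # Flatten the (Document, score) pairs into rows and assign a 1-based
    # ref_id that the model's citations ("Ref. N") will point back to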
    docs_dict = [{**x[0].metadata, "score": x[1], "content": x[0].page_content} for x in docs]
    df_docs = pd.DataFrame(docs_dict).reset_index()
    df_docs['ref_id'] = df_docs.index + 1

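    # Re-wrap each row as a LangChain Document, keeping ref_id and score in
    # the metadata so citations can be traced back to their source chunks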
    ls_dict = [
        Document(
            page_content=row['content'],
            metadata={
                'country': row['country'],
                'document': row['document'],
                'page': row['page'],
                'file_name': row['file_name'],
                'ref_id': row['ref_id'],
                'vulnerability_cat': row['vulnerability_cat'],
                'score': row['score']
            }
        )
        for _, row in df_docs.iterrows()
    ]

    return ls_dict

# Extract references from the response
def get_refs(docs, res):
    # Lowercase the response so markers like "Ref. 12" and "ref. 12" both match
    res = res.lower()
    pattern = r'ref\. (\d+)'
    ref_ids = [int(match) for match in re.findall(pattern, res)]
    result_str = ""
    for doc in docs:
        ref_id = doc.metadata['ref_id']
        if ref_id in ref_ids:
            metadata = doc.metadata
            # Supplementary documents also carry their file name in the citation
            if metadata['document'] == "Supplementary":
                source = f"{metadata['country']} {metadata['document']}: {metadata['file_name']}"
            else:
                source = f"{metadata['country']} {metadata['document']}"
            result_str += (
                f"**Ref. {ref_id} [{source} p{metadata['page']}; "
                f"vulnerabilities: {metadata['vulnerability_cat']}]:** *'{doc.page_content}'*<br><br>"
            )
    return result_str

# Construct the prompt for the model
def get_prompt(prompt_template, docs, input_query):
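    # Concatenate the chunks with visible '&&&' delimiters and their ref ids
    # so the model can ground its answer and cite sources as 'ref. N'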
    context = ' - '.join([
        f"&&& [ref. {d.metadata['ref_id']}] {d.metadata['document']} &&&: {d.page_content}"
        for d in docs
    ])
    prompt = f"{prompt_template}; Context: {context}; Question: {input_query}; Answer:"
    return prompt

# Execute the query, stream the response into the Streamlit container,
# and return the formatted references the answer cites
def run_query(client, prompt, docs, res_box):
    stream = client.chat.completions.create(
        model="gpt-4o-mini-2024-07-18",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    report = []
    result = ""  # guard: stays defined even if the stream yields no content
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            report.append(chunk.choices[0].delta.content)
            result = "".join(report).strip()
            res_box.success(result)  # re-render the partial answer as it streams

    references = get_refs(docs, result)
    return references
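

# Minimal usage sketch (illustrative only, not part of the app): wires the
# helpers together end to end. The index name "policy-docs" and the queries
# below are hypothetical; assumes PINECONE_API_KEY, OPENAI_API_KEY,
# OPENAI_ORG_ID, and OPENAI_PROJECT_ID are set in the environment or .env.
if __name__ == "__main__":
    client = get_openai_client()
    vector_store, embeddings = initialize_vector_store(
        os.getenv('PINECONE_API_KEY'), "policy-docs"  # hypothetical index name
    )
    query = "What adaptation measures are planned for the water sector?"
    docs = get_docs(vector_store, embeddings, query, country=["Kenya"])
    prompt = get_prompt(
        "Answer using only the context and cite passages as 'Ref. N'",
        docs, query,
    )
    res_box = st.empty()  # placeholder the streamed answer renders into
    print(run_query(client, prompt, docs, res_box))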