import datetime
import os
import uuid

import firebase_admin
import gradio as gr
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from firebase_admin import credentials, db
from llama_index.core import (
    ChatPromptTemplate,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI

# Load environment variables
load_dotenv()

# Authenticate to Firebase
cred = credentials.Certificate("redfernstech-fd8fe-firebase-adminsdk-g9vcn-0537b4efd6.json")
firebase_admin.initialize_app(
    cred,
    {"databaseURL": "https://redfernstech-fd8fe-default-rtdb.firebaseio.com/"},
)

app = FastAPI()

# Configure the Llama index settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
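# HF_TOKEN must be available in the environment (e.g. via the .env file),
# otherwise the Inference API calls will fail with an authentication error.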
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)

# Define directories
PERSIST_DIR = "db"
PDF_DIRECTORY = 'data'
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

# Variable to store chat history
current_chat_history = []
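# NOTE: this history is process-global, so every user of this process shares
# it; per-user isolation would require keying it by session.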

def data_ingestion_from_directory():
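    """Load every document in PDF_DIRECTORY, embed it, and persist the vector index to PERSIST_DIR."""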
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

def handle_query(query):
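    """Answer a single query with RAG over the persisted index, conditioned on prior chat turns."""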
    # {context_str} and {query_str} are standard LlamaIndex placeholders,
    # filled at synthesis time with the retrieved passages and the question.
    # {chat_history} is our own placeholder, bound below via partial_format().
    chat_text_qa_msgs = [
        (
            "user",
            """
            You are Clara, the Redfernstech chatbot. Provide accurate, professional, and concise answers based on the data. Respond within 10-15 words only.
            Previous conversation:
            {chat_history}
            {context_str}
            Question:
            {query_str}
            """
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    # Reload the index persisted by data_ingestion_from_directory().
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)

    # Serialize prior turns so the model can resolve follow-up questions.
    chat_history_str = ""
    for past_query, past_response in reversed(current_chat_history):
        if past_query.strip():
            chat_history_str += f"User asked: '{past_query}'\nBot answered: '{past_response}'\n"

    # Bind the history into the template; as_query_engine() has no
    # context_str keyword, so passing the history there would have no effect.
    text_qa_template = text_qa_template.partial_format(chat_history=chat_history_str)
    query_engine = index.as_query_engine(text_qa_template=text_qa_template)
    answer = query_engine.query(query)

    response = answer.response if hasattr(answer, "response") else "Sorry, I couldn't find an answer."

    current_chat_history.append((query, response))

    return response

def save_chat_message(username, email, session_id, message_data):
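    """Store one message record under /chat_history/{username}/{email}/{session_id} using an auto-generated push key."""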
    ref = db.reference(f'/chat_history/{username}/{email}/{session_id}')
    ref.push().set(message_data)

def chat_interface(message, history, request: gr.Request):
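    """Gradio handler: answer the message, then log the exchange to Firebase."""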
    try:
        username = request.query_params.get('username')
        email = request.query_params.get('email')
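        # NOTE: a fresh UUID per call stores every message under its own
        # session node in Firebase; reuse one id per conversation to group turns.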
        session_id = str(uuid.uuid4())
        response = handle_query(message)

        message_data = {
            "sender": request.client.host,
            "message": message,
            "response": response,
            "timestamp": datetime.datetime.now().isoformat()
        }

        save_chat_message(username, email, session_id, message_data)

        return response
    except Exception as e:
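        # Surface the error text in the chat rather than crashing the UI;
        # consider logging instead of echoing raw exceptions in production.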
        return str(e)

css = '''
  .circle-logo {
    display: inline-block;
    width: 40px;
    height: 40px;
    border-radius: 50%;
    overflow: hidden;
    margin-right: 10px;
    vertical-align: middle;
  }
  .circle-logo img {
    width: 100%;
    height: 100%;
    object-fit: cover;
  }
  .response-with-logo {
    display: flex;
    align-items: center;
    margin-bottom: 10px;
  }
  footer {
    display: none !important;
    background-color: #F8D7DA;
  }
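  /* NOTE: the svelte-* selectors below target Gradio's internal, generated
     class names and may break when Gradio is upgraded. */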
  .svelte-1ed2p3z p {
    font-size: 24px;
    font-weight: bold;
    line-height: 1.2;
    color: #111;
    margin: 20px 0;
  }
  label.svelte-1b6s6s {display: none}
  div.svelte-rk35yg {display: none;}
  div.progress-text.svelte-z7cif2.meta-text {display: none;}
'''

# Build the Gradio chat UI once and mount it on the FastAPI app. Calling
# launch() inside a request handler would start a new blocking Gradio server
# on every GET; mounting serves the same UI at /chat, and username/email
# still arrive as query parameters (read in chat_interface via gr.Request),
# e.g. http://localhost:8000/chat?username=alice&email=alice@example.com
chat_ui = gr.ChatInterface(
    fn=chat_interface,
    css=css,
    description="Clara",
    clear_btn=None,
    undo_btn=None,
    retry_btn=None,
)
app = gr.mount_gradio_app(app, chat_ui, path="/chat")

if __name__ == "__main__":
    # Ingest documents and build the index before serving. Running uvicorn in
    # a daemon thread would let the main thread exit and kill the server, so
    # it runs in the foreground here.
    data_ingestion_from_directory()
    uvicorn.run(app, host="0.0.0.0", port=8000)