import os
import time
import uuid  # for generating unique IDs
import datetime
import json
import re

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
from llama_index.llms.huggingface import HuggingFaceInferenceAPI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from huggingface_hub import InferenceClient
from gradio_client import Client
# Pydantic model for the incoming chat request body
class MessageRequest(BaseModel):
    message: str
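# Illustrative payload this model accepts (example only, not from the original file):
#   {"message": "What foods are high in protein?"}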
repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
llm_client = InferenceClient(
    model=repo_id,
    token=os.getenv("HF_TOKEN"),
)
app = FastAPI()

# Relax framing restrictions so the app can be embedded in an iframe
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    response = await call_next(request)
    response.headers["Content-Security-Policy"] = "frame-ancestors *; frame-src *; object-src *;"
    response.headers["X-Frame-Options"] = "ALLOWALL"
    return response
# Allow CORS requests from any domain
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/favicon.ico")
async def favicon():
    return HTMLResponse("")  # or serve a real favicon if you have one

app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="static")
# Configure LlamaIndex settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    tokenizer_name="meta-llama/Meta-Llama-3-8B-Instruct",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
PERSIST_DIR = "db"
PDF_DIRECTORY = "data"

# Ensure directories exist
os.makedirs(PDF_DIRECTORY, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

chat_history = []
current_chat_history = []
def data_ingestion_from_directory():
    # Load every document in the PDF directory, build a vector index,
    # and persist it so later requests can reload it from disk
    documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
    storage_context = StorageContext.from_defaults()
    index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
def initialize():
    start_time = time.time()
    data_ingestion_from_directory()  # Process PDF ingestion at startup
    print(f"Data ingestion time: {time.time() - start_time:.2f} seconds")
def split_name(full_name):
    # Split the name on whitespace
    words = full_name.strip().split()
    # Determine first name and last name
    if len(words) == 1:
        first_name = ''
        last_name = words[0]
    elif len(words) == 2:
        first_name, last_name = words
    else:
        first_name = words[0]
        last_name = ' '.join(words[1:])
    return first_name, last_name
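# Illustrative usage (examples not in the original file):
#   split_name("Jane Doe")        -> ("Jane", "Doe")
#   split_name("Cher")            -> ("", "Cher")
#   split_name("Ana de la Cruz")  -> ("Ana", "de la Cruz")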
initialize()  # Run initialization tasks

# The chatbot should act as a personal health assistant, giving fitness,
# dietary, and mental-health guidance
def handle_query(query):
    chat_text_qa_msgs = [
        (
            "user",
            """
            You are a personal health assistant. Provide concise, accurate, and professional health-related answers. Keep responses clear, limited to 10-15 words. Maintain a helpful tone and focus solely on health topics.
            {context_str}
            Question:
            {query_str}
            """
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)

    # Reload the persisted index from disk
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)

    # Fold prior turns into a context string for the prompt
    context_str = ""
    for past_query, response in reversed(current_chat_history):
        if past_query.strip():
            context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"

    query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
    answer = query_engine.query(query)

    # Normalize the response object into a plain string
    if hasattr(answer, 'response'):
        response = answer.response
    elif isinstance(answer, dict) and 'response' in answer:
        response = answer['response']
    else:
        response = "Sorry, I couldn't find an answer."

    current_chat_history.append((query, response))
    return response
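# Illustrative call (not part of the original file):
#   handle_query("How much water should I drink daily?")  # returns a short answer string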
@app.get("/chat/{id}", response_class=HTMLResponse)  # route path assumed
async def load_chat(request: Request, id: str):
    return templates.TemplateResponse("index.html", {"request": request, "user_id": id})

@app.get("/voice/{id}", response_class=HTMLResponse)  # route path assumed
async def load_voice_chat(request: Request, id: str):
    return templates.TemplateResponse("voice.html", {"request": request, "user_id": id})
# Route to save chat history
@app.post("/hist/")  # route path assumed
async def save_chat_history(history: dict):
    # Reject requests that do not identify the user
    user_id = history.get('userId')
    print(user_id)
    if user_id is None:
        return JSONResponse({"error": "userId is required"}, status_code=400)

    # Construct the chat history string
    hist = ''.join([f"'{entry['sender']}: {entry['message']}'\n" for entry in history['history']])
    hist = "You are a Redfernstech summarize model. Your aim is to use this conversation to identify user interests solely based on that conversation: " + hist
    print(hist)

    # Placeholder: the raw history is returned as the "summary";
    # no summarization model is actually called here
    result = hist
    return {"summary": result, "message": "Chat history saved"}
@app.post("/webhook")  # route path assumed
async def receive_form_data(request: Request):
    form_data = await request.json()
    # Generate a unique ID for tracking the user
    unique_id = str(uuid.uuid4())
    # Here you could persist form_data, e.g. to a database
    print("Received form data:", form_data)
    # Send the unique id back to the frontend
    return JSONResponse({"id": unique_id})
@app.post("/chat/")  # route path assumed
async def chat(request: MessageRequest):
    message = request.message  # Access the message from the request body
    response = handle_query(message)  # Process the message
    message_data = {
        "sender": "User",
        "message": message,
        "response": response,
        "timestamp": datetime.datetime.now().isoformat()
    }
    chat_history.append(message_data)
    return {"response": response}
@app.get("/", response_class=HTMLResponse)
def read_root(request: Request):
    return templates.TemplateResponse("home.html", {"request": request})
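# Typical entry point for a FastAPI app on Hugging Face Spaces (a sketch; the
# Space may instead launch uvicorn from its run command -- port 7860 is the
# Spaces default, assumed here):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)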