gen-ai-project / app.py
Moha782's picture
Update app.py
6a0b151 verified
raw
history blame
3.52 kB
import os
import json
import gradio as gr
import faiss
import fitz # PyMuPDF
import numpy as np
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
# Initialize the SentenceTransformer model used for both document and
# query embeddings (must be the same model for FAISS distances to be meaningful).
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# Extract text from PDF
def extract_text_from_pdf(pdf_path):
    """Extract all text from a PDF and split it into paragraph-like chunks.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        list[str]: Text chunks produced by splitting the concatenated
        page text on blank lines ("\\n\\n").
    """
    text = ""
    # Use the context manager so the document handle is always released
    # (the original left the file open).
    with fitz.open(pdf_path) as doc:
        for page in doc:  # fitz.Document iterates its pages in order
            text += page.get_text()
    return text.split("\n\n")
# Build FAISS index
def build_faiss_index(documents):
    """Embed *documents*, build an exact L2 FAISS index, and persist artifacts.

    Side effects: writes "apexcustoms_index.faiss" and saves the sentence
    transformer to "sentence_transformer_model" so later runs can skip
    re-embedding.

    Args:
        documents: List of text chunks to index.

    Returns:
        faiss.IndexFlatL2: The populated index.
    """
    document_embeddings = model.encode(documents)
    # FAISS requires a contiguous float32 array; cast defensively rather
    # than relying on the encoder's default dtype.
    document_embeddings = np.ascontiguousarray(document_embeddings, dtype="float32")
    index = faiss.IndexFlatL2(document_embeddings.shape[1])
    index.add(document_embeddings)
    faiss.write_index(index, "apexcustoms_index.faiss")
    model.save("sentence_transformer_model")
    return index
# Ensure that text extraction and FAISS index building is done.
# On a fresh run we extract, persist, and index; on later runs we reload
# BOTH the index and the document chunks from disk.
if not os.path.exists("apexcustoms_index.faiss") or not os.path.exists("sentence_transformer_model"):
    documents = extract_text_from_pdf("apexcustoms.pdf")
    with open("apexcustoms.json", "w") as f:
        json.dump(documents, f)
    index = build_faiss_index(documents)
else:
    index = faiss.read_index("apexcustoms_index.faiss")
    # BUG FIX: the cached path previously never defined `documents`,
    # so retrieve_documents() raised NameError on every run after the
    # first. Reload the chunks that the first run serialized.
    with open("apexcustoms.json") as f:
        documents = json.load(f)
# Hugging Face Inference API client used for streamed chat completions.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def retrieve_documents(query, k=5):
    """Return the *k* document chunks nearest to *query* in embedding space."""
    query_embedding = model.encode([query])
    # Only the neighbor ids matter here; the distances are discarded.
    _, neighbor_ids = index.search(query_embedding, k)
    return [documents[doc_id] for doc_id in neighbor_ids[0]]
async def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Stream a chat reply grounded in chunks retrieved from the PDF index.

    Args:
        message: The user's current message.
        history: List of (user, assistant) exchange pairs from Gradio.
        system_message: System prompt from the UI textbox.
        max_tokens / temperature / top_p: Sampling controls from the UI.

    Yields:
        str: The accumulated assistant reply so far (Gradio's ChatInterface
        replaces the displayed message with each yielded value).
    """
    # Retrieve relevant documents and keep only the top 3 to bound prompt size.
    relevant_docs = retrieve_documents(message)
    context = "\n\n".join(relevant_docs[:3])

    messages = [{"role": "system", "content": system_message}]
    # Keep the last 5 exchanges, in chronological order *before* the current
    # message (the original appended history after it, scrambling the turn order).
    for user_turn, assistant_turn in history[-5:]:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    # Send the current message exactly once, fused with the retrieved context
    # (the original also appended the bare message a second time).
    messages.append({"role": "user", "content": f"Context: {context}\n\n{message}"})

    response = ""
    # InferenceClient.chat_completion(stream=True) returns a *synchronous*
    # iterator; the original `async for` would raise TypeError at runtime.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
            response += chunk.choices[0].delta.content
            # Yield the running total, not the lone token, so the UI shows
            # the whole reply instead of only the latest fragment.
            yield response
# gr.ChatInterface takes no `inputs`/`outputs` keyword arguments (the original
# call raised TypeError at startup). Extra controls go in `additional_inputs`,
# and the chatbot component itself renders the streamed response.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful car configuration assistant, specifically you are the assistant for Apex Customs (https://www.apexcustoms.com/). Given the user's input, provide suggestions for car models, colors, and customization options. Be creative and conversational in your responses. You should remember the user car model and tailor your answers accordingly. \n\nUser: ",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, step=1, value=512, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (nucleus sampling)"),
    ],
)
# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()