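"""
Gradio chat Space for Apex Customs: a RAG model grounds answers in the
apexcustoms.pdf catalog, and HuggingFaceH4/zephyr-7b-beta (via the Hugging Face
Inference API) writes the conversational reply.
"""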
import re
from typing import List, Tuple

import gradio as gr
import torch
from huggingface_hub import InferenceClient
from pypdf import PdfReader
from transformers import RagTokenizer, RagTokenForGeneration
# Load the RAG model and tokenizer
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
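# facebook/rag-token-nq bundles a DPR question encoder with a BART generator
# and is normally paired with a RagRetriever over a Wikipedia index; this app
# supplies its own PDF chunks as the context documents instead.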
# Load your PDF document | |
pdf_path = "apexcustoms.pdf" | |
with open(pdf_path, 'rb') as f: | |
pdf_text = f.read().decode('utf-8', errors='ignore') | |
# Split the PDF text into paragraph-sized chunks and drop empty pieces
doc_chunks = [chunk.strip() for chunk in re.split(r"\n\n+", pdf_text) if chunk.strip()]
# RAG scores a fixed number of context documents per question. The model needs
# n_docs documents, not a sequence length divisible by n_docs, so we pad the
# chunk list (with empty strings) if the PDF is very short.
n_docs = rag_model.config.n_docs
if len(doc_chunks) < n_docs:
    doc_chunks += [""] * (n_docs - len(doc_chunks))
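# rag_model.generate expects context_input_ids of shape
# (batch_size * n_docs, seq_len) together with doc_scores of shape
# (batch_size, n_docs), so chunks are tokenized per question in respond() below.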
""" | |
For more information on huggingface_hub Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
""" | |
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") | |
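# With stream=True, chat_completion yields OpenAI-style chunks whose
# choices[0].delta.content holds the next token(s); no local weights are
# needed for the chat step.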
def respond(
    message,
    history: List[Tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # The current user message is appended below, once the retrieved context
    # is available to fold into it
    # Naive lexical retrieval: rank chunks by word overlap with the question
    # (a stand-in for the RagRetriever this app does not load)
    q_words = set(message.lower().split())
    top_chunks = sorted(doc_chunks, key=lambda c: len(q_words & set(c.lower().split())), reverse=True)[:n_docs]
    # Format each context as "document // question", the same layout the
    # retriever would produce, and tokenize with the generator's (BART) tokenizer
    contexts = [f"{chunk} // {message}" for chunk in top_chunks]
    enc = rag_tokenizer.generator(contexts, return_tensors="pt", padding=True, truncation=True, max_length=300)
    # Generate with the RAG model. Without a retriever, the context tensors and
    # document scores must be passed by hand; uniform scores are a simplification
    doc_scores = torch.ones(1, n_docs)
    output_ids = rag_model.generate(
        context_input_ids=enc.input_ids,
        context_attention_mask=enc.attention_mask,
        doc_scores=doc_scores,
        max_length=min(max_tokens, 256),  # stay well inside BART's 1024-token limit
        do_sample=True,
        top_p=top_p,
        num_beams=2,
    )
    # This is the RAG model's document-grounded answer (not a raw passage);
    # it is handed to the chat model as background context
    rag_context = rag_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
response = "" | |
for message in client.chat_completion( | |
messages, | |
max_tokens=max_tokens, | |
stream=True, | |
temperature=temperature, | |
top_p=top_p, | |
context=retrieved_context, # Include the retrieved context | |
): | |
token = message.choices[0].delta.content | |
response += token | |
yield response | |
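# Because respond() yields the accumulated string, gr.ChatInterface re-renders
# the partial answer on every chunk, which produces the streaming effect.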
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=(
                "You are a helpful car configuration assistant, specifically you are "
                "the assistant for Apex Customs (https://www.apexcustoms.com/). Given "
                "the user's input, provide suggestions for car models, colors, and "
                "customization options. Be creative and conversational in your "
                "responses. You should remember the user's car model and tailor your "
                "answers accordingly."
            ),
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch()