# RAG-Chatbot / app.py
import gradio as gr
from datasets import load_dataset, Dataset
# faiss is not imported directly; RAGatouille uses it internally when
# use_faiss=True is passed to RAG.index below
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from ragatouille import RAGPretrainedModel
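# google/gemma-7b-it is gated on the Hub, so an access token is required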
token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)
RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-large-v1")
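# mxbai-colbert is a ColBERT-style late-interaction retriever;
# RAGatouille wraps it with indexing and search utilities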
# prepare data
# the full dataset is too big, so we stream it and keep only the first 3K rows
dataset = load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
)
# init data
data = Dataset.from_dict({})
for i, entry in enumerate(dataset):
    # each entry has the following columns:
    # ['id', 'url', 'title', 'text']
    data = data.add_item(entry)  # add_item returns a new Dataset, so reassign
    if i == 3000:
        break
# free memory
del dataset  # we keep data
# index data
documents = data["text"]
RAG.index(documents, index_name="wikipedia", use_faiss=True)
# free memory
del documents
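# the index is persisted to disk (by default under .ragatouille/), so
# repeated searches only pay the query-encoding cost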
def search(query, k: int = 5):
    results = RAG.search(query, k=k)
    # results are sorted by relevance; each one is a dict with keys:
    #   'content'     : the retrieved passage text
    #   'score'       : relevance score (float)
    #   'rank'        : 1-based position of the result, ordered by score
    #   'document_id' : id RAGatouille assigned to the indexed document
    #   'passage_id'  : row number of the document in the original data
    return [result["passage_id"] for result in results]
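# e.g. search("what is machine learning") might return [42, 7, 1303, 15, 820]:
# five row numbers that prepare_prompt below uses to look up the articles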
def prepare_prompt(query, indexes, data=data):
    prompt = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    titles = []
    urls = []
    for i in indexes:
        title = data["title"][i]
        text = data["text"][i]
        url = data["url"][i]
        titles.append(title)
        urls.append(url)
        prompt += f"Title: {title}, Text: {text}\n"
    return prompt, (titles, urls)
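# the prompt built above looks roughly like:
#   Query: <question>
#   Continue to answer the query by using the Search Results:
#   Title: <article title>, Text: <article text>  (one line per hit)
# and (titles, urls) is kept to render a RESOURCES footer under the answer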
@spaces.GPU
def talk(message, history):
    indexes = search(message)
    message, metadata = prepare_prompt(message, indexes)
    resources = "\nRESOURCES:\n"
    for title, url in zip(*metadata):  # metadata is (titles, urls)
        resources += f"[{title}]({url}), "
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # strip the RESOURCES footer from past replies before re-feeding them
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # Tokenize the messages string
    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,  # input_ids and attention_mask from the tokenizer
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,  # sample instead of greedy decoding
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
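    # model.generate blocks until finished, so it runs in a background thread;
    # the streamer yields decoded tokens as they are produced so the UI can
    # stream the reply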
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
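    # once generation is done, append the RESOURCES footer to the final reply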
    partial_text += resources
    yield partial_text
TITLE = "RAG"
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
* me 😎
## Models
The models used in this Space are:
* google/gemma-7b-it
* mixedbread-ai/mxbai-colbert-large-v1
"""
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()