import torch
import transformers
from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
import gradio as gr
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset_path = "./5k_index_data/my_knowledge_dataset"
index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"
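# NOTE (assumption about the pre-built index): RagRetriever with index_name="custom"
# expects `dataset_path` to point at a `datasets.Dataset` saved with `save_to_disk()`
# containing "title", "text" and "embeddings" columns, and `index_path` to be a FAISS
# index built over the "embeddings" column (e.g. via `dataset.add_faiss_index(column="embeddings")`
# followed by `dataset.get_index("embeddings").save(index_path)`). Both files are
# assumed to already exist at the paths above.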
tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path=dataset_path,
    index_path=index_path,
    n_docs=5,
)
rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
rag_model.retriever.init_retrieval()
rag_model.to(device)
# load the Gemma tokenizer once at startup instead of on every request
gemma_tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it')
model = AutoModelForCausalLM.from_pretrained(
    'google/gemma-2-9b-it',
    device_map='auto',
    torch_dtype=torch.bfloat16,
)
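# Note: google/gemma-2-9b-it is a gated checkpoint on the Hugging Face Hub; loading it
# here assumes the runtime is already authenticated (e.g. via an HF_TOKEN secret) and
# has enough GPU memory for the bf16 weights.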
def strip_title(title):
    if title.startswith('"'):
        title = title[1:]
    if title.endswith('"'):
        title = title[:-1]
    return title
# build the prompt in the chat format expected by the Gemma model
def input_format(query, context):
    sys_instruction = f'Context:\n {context} \n Given the following information, generate an answer to the question. Provide links from the information in the answer to increase credibility.'
    message = f'Question: {query}'
    return f'<bos><start_of_turn>user\n{sys_instruction} {message}<end_of_turn>\n<start_of_turn>model\n'
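# Alternative (sketch, not used above): the same prompt could be produced with the
# tokenizer's built-in chat template instead of hand-writing Gemma's control tokens:
#   gemma_tokenizer.apply_chat_template(
#       [{"role": "user", "content": f"{sys_instruction} {message}"}],
#       tokenize=False, add_generation_prompt=True,
#   )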
# retrieving and generating an answer in one call
def retrieved_info(query, rag_model=rag_model, generating_model=model):
    # Tokenize query with the retriever's question-encoder tokenizer
    retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
        [query],
        return_tensors='pt',
        padding=True,
        truncation=True,
    )['input_ids'].to(device)
    # Retrieve documents
    question_encoder_output = rag_model.rag.question_encoder(retriever_input_ids)
    question_encoder_pool_output = question_encoder_output[0]
    result = rag_model.retriever(
        retriever_input_ids,
        question_encoder_pool_output.cpu().detach().to(torch.float32).numpy(),
        prefix=rag_model.rag.generator.config.prefix,
        n_docs=rag_model.config.n_docs,
        return_tensors='pt',
    )
    # Prepare the query and retrieved docs for the generation model
    all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
    retrieved_context = []
    for docs in all_docs:
        titles = [strip_title(title) for title in docs['title']]
        texts = docs['text']
        for title, text in zip(titles, texts):
            retrieved_context.append(f'{title}: {text}')
    # join the passages into a single context string rather than interpolating a Python list
    generation_model_input = input_format(query, '\n'.join(retrieved_context))
    # Generate the answer with the Gemma model
    inputs = gemma_tokenizer(generation_model_input, return_tensors='pt').to(device)
    output = generating_model.generate(**inputs, max_new_tokens=512)
    # decode only the newly generated tokens so the prompt is not echoed back to the user
    return gemma_tokenizer.decode(output[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
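# Quick local sanity check (illustrative; left commented out so the Gradio app below
# remains the single entry point):
# print(retrieved_info("What is machine learning?"))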
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Note: the additional inputs (system message, max tokens, temperature, top-p) are
    # exposed in the UI but not yet wired into generation.
    if message:  # If there's a user query
        response = retrieved_info(message)  # Get the answer from the local FAISS index and generation model
        return response
    # In case there is no message, return an empty string
    return ""
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Custom title and description
title = "🧠 Welcome to Your AI Knowledge Assistant"
description = """
Hi!! I am your loyal assistant. My functionality is based on a RAG model: I retrieve relevant information and provide answers based on it. Ask me any question, and let me assist you.
My capabilities are limited because I am still in the development phase. I will do my best to assist you. SOOO LET'S BEGIN......
"""
demo = gr.ChatInterface(
    respond,
    type='messages',
    additional_inputs=[
        gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title=title,
    description=description,
    textbox=gr.Textbox(placeholder="'What is the future of AI?' or 'App Development'"),
    examples=[["✨Future of AI"], ["📱App Development"]],
    example_icons=["🤖", "📱"],
    theme="compact",
)
if __name__ == "__main__":
    demo.launch(share=True)