import torch
from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
import gradio as gr

device = 'cuda' if torch.cuda.is_available() else 'cpu'


dataset_path = "./5k_index_data/my_knowledge_dataset"
index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"
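# The paths above are assumed to point to a knowledge dataset saved with the
# `datasets` library (columns: title, text, embeddings) and a FAISS HNSW index
# built over the embeddings beforehand (e.g. with `dataset.add_faiss_index(...)`
# and `dataset.save_faiss_index(...)`); adjust them to wherever your index lives.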

tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
                                            passages_path = dataset_path,
                                            index_path = index_path,
                                            n_docs = 5)
rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
rag_model.retriever.init_retrieval()
rag_model.to(device)
model = AutoModelForCausalLM.from_pretrained('google/gemma-2-9b-it',
                              device_map = 'auto',
                              torch_dtype = torch.bfloat16,
                             )
# Load the Gemma tokenizer once at start-up instead of on every query
gemma_tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-9b-it')
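# Division of labor: the RAG model above is used only for question encoding and
# document retrieval; the Gemma model and tokenizer handle answer generation.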



def strip_title(title):
    if title.startswith('"'):
        title = title[1:]
    if title.endswith('"'):
        title = title[:-1]
    
    return title

# Build the prompt in Gemma's chat format (Gemma has no separate system role,
# so the instruction and the retrieved context are folded into the user turn)
def input_format(query, context):
    sys_instruction = f'Context:\n {context} \n Given the above information, generate an answer to the question. Provide links from the information in the answer to increase credibility.'
    message = f'Question: {query}'

    return f'<bos><start_of_turn>user\n{sys_instruction} {message}<end_of_turn>\n<start_of_turn>model\n'

# Retrieve documents and generate an answer in one call
def retrieved_info(query, rag_model = rag_model, generating_model = model):
    # Tokenize Query
    retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
        [query],
        return_tensors = 'pt',
        padding = True,
        truncation = True,
    )['input_ids'].to(device)

    # Retrieve Documents
    question_encoder_output = rag_model.rag.question_encoder(retriever_input_ids)
    question_encoder_pool_output = question_encoder_output[0]

    result = rag_model.retriever(
        retriever_input_ids,
        question_encoder_pool_output.cpu().detach().to(torch.float32).numpy(),
        prefix = rag_model.rag.generator.config.prefix,
        n_docs = rag_model.config.n_docs,
        return_tensors = 'pt',
    )
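    # result.doc_ids holds the ids of the top n_docs retrieved passages for each
    # query; their title/text entries are fetched from the index below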

    # Preparing query and retrieved docs for model
    all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
    retrieved_context = []
    for docs in all_docs:
        titles = [strip_title(title) for title in docs['title']]
        texts = docs['text']
        for title, text in zip(titles, texts):
            retrieved_context.append(f'{title}: {text}')

    generation_model_input = input_format(query, '\n'.join(retrieved_context))

    # Generating answer using gemma model
    inputs = gemma_tokenizer(generation_model_input, return_tensors = 'pt').to(generating_model.device)
    output = generating_model.generate(**inputs, max_new_tokens = 512)

    # Decode only the newly generated tokens, not the prompt that was fed in
    return gemma_tokenizer.decode(output[0][inputs['input_ids'].shape[-1]:], skip_special_tokens = True)
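# Rough usage sketch (assuming the dataset and index paths above exist):
#   answer = retrieved_info("What is retrieval augmented generation?")
#   print(answer)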






def respond(
    message,
    history: list[dict],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Note: the system message and sampling controls from the UI are not wired
    # into generation yet; retrieved_info uses its own defaults.
    if message:  # If there's a user query
        response = retrieved_info(message)  # Get the answer from the local FAISS index and the Gemma model
        return response

    # In case there is no message, return an empty string
    return ""



"""

For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface

"""
# Custom title and description
title = "🧠 Welcome to Your AI Knowledge Assistant"
description = """

HI!!, I am your loyal assistant, y functionality is based on RAG model, I retrieves relevant information and provide answers based on that. Ask me any question, and let me assist you.

My capabilities are limited because I am still in development phase. I will do my best to assist you. SOOO LET'S BEGGINNNN......

"""

demo = gr.ChatInterface(
    respond,
    type = 'messages',
    additional_inputs=[
        gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title=title,
    description=description,
    textbox=gr.Textbox(placeholder="Ask something like 'What is the future of AI?' or 'App Development'"),
    examples=[["✨Future of AI"], ["📱App Development"]],
    example_icons=["🤖", "📱"],
    theme="compact",
)


if __name__ == "__main__":
    demo.launch(share=True)