Baweja committed on
Commit 74a5bdf · verified · 1 Parent(s): 25ed140

Upload app_new.py

Files changed (1)
  1. app_new.py +140 -0
app_new.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ dataset_path = "./5k_index_data/my_knowledge_dataset"
+ index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"
+
+ # RAG question encoder + retriever backed by the custom FAISS index
+ tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ retriever = RagRetriever.from_pretrained(
+     "facebook/rag-sequence-nq",
+     index_name="custom",
+     passages_path=dataset_path,
+     index_path=index_path,
+     n_docs=5,
+ )
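+
+ # The dataset and index above are assumed to have been built beforehand,
+ # roughly as in the transformers "use your own knowledge dataset" RAG
+ # example. A minimal sketch (not run by this app; the column names and
+ # HNSW parameters are assumptions, not taken from this repo):
+ #
+ #   import faiss
+ #   from datasets import load_from_disk
+ #   ds = load_from_disk(dataset_path)  # columns: 'title', 'text', 'embeddings'
+ #   ds.add_faiss_index('embeddings', custom_index=faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT))
+ #   ds.get_index('embeddings').save(index_path)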
+ rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+ rag_model.retriever.init_retrieval()
+ rag_model.to(device)
+
+ # Gemma model for answer generation; its tokenizer is loaded once here
+ # instead of on every request
+ model = AutoModelForCausalLM.from_pretrained(
+     "google/gemma-2-9b-it",
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
+
+
+ # Strip surrounding quotation marks from a retrieved document title
+ def strip_title(title):
+     if title.startswith('"'):
+         title = title[1:]
+     if title.endswith('"'):
+         title = title[:-1]
+     return title
+
+ # Build a Gemma-format prompt from the query and retrieved context
+ def input_format(query, context):
+     sys_instruction = (
+         f'Context:\n{context}\n'
+         'Given the following information, generate an answer to the question. '
+         'Provide links from the information in the answer to increase credibility.'
+     )
+     message = f'Question: {query}'
+     return f'<bos><start_of_turn>user\n{sys_instruction} {message}<end_of_turn>\n<start_of_turn>model\n'
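+
+ # For a query like "What is RAG?" this yields a single-turn Gemma prompt
+ # roughly shaped like (illustrative, not captured from a run):
+ #   <bos><start_of_turn>user
+ #   Context:
+ #   <retrieved passages> Given the following information, ... Question: What is RAG?<end_of_turn>
+ #   <start_of_turn>model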
+
+ # Retrieve supporting documents and generate an answer in one call
+ def retrieved_info(query, rag_model=rag_model, generating_model=model):
+     # Tokenize the query for the RAG question encoder
+     retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
+         [query],
+         return_tensors='pt',
+         padding=True,
+         truncation=True,
+     )['input_ids'].to(device)
+
+     # Embed the query and retrieve documents from the FAISS index
+     question_encoder_output = rag_model.rag.question_encoder(retriever_input_ids)
+     question_encoder_pool_output = question_encoder_output[0]
+
+     result = rag_model.retriever(
+         retriever_input_ids,
+         question_encoder_pool_output.cpu().detach().to(torch.float32).numpy(),
+         prefix=rag_model.rag.generator.config.prefix,
+         n_docs=rag_model.config.n_docs,
+         return_tensors='pt',
+     )
+
+     # Prepare the query and the retrieved docs for the generation model
+     all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
+     retrieved_context = []
+     for docs in all_docs:
+         titles = [strip_title(title) for title in docs['title']]
+         texts = docs['text']
+         for title, text in zip(titles, texts):
+             retrieved_context.append(f'{title}: {text}')
+
+     generation_model_input = input_format(query, '\n'.join(retrieved_context))
+
+     # Generate the answer with the Gemma model
+     inputs = gemma_tokenizer(generation_model_input, return_tensors='pt').to(device)
+     output = generating_model.generate(**inputs, max_new_tokens=512)
+
+     # Return only the newly generated tokens, not the echoed prompt
+     return gemma_tokenizer.decode(output[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
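+
+ # Example usage (assumes the FAISS index files above exist; the query is
+ # illustrative):
+ #   answer = retrieved_info("Who proposed retrieval-augmented generation?")
+ #   print(answer)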
+
+
+ def respond(
+     message,
+     history: list[dict],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # The extra ChatInterface inputs (history, system message, sliders) are
+     # accepted but not used yet; retrieval and generation happen entirely
+     # inside retrieved_info()
+     if message:  # if there is a user query
+         return retrieved_info(message)  # answer from the local FAISS index and Gemma
+
+     # If there is no message, return an empty string
+     return ""
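+
+ # A sketch of how the unused controls could be wired in (hypothetical; it
+ # assumes retrieved_info is refactored to accept generation kwargs):
+ #   output = generating_model.generate(**inputs, max_new_tokens=max_tokens,
+ #                                      do_sample=True, temperature=temperature, top_p=top_p)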
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ # Custom title and description
+ title = "🧠 Welcome to Your AI Knowledge Assistant"
+ description = """
+ Hi!! I am your loyal assistant. My functionality is based on a RAG model: I retrieve relevant information and provide answers based on it. Ask me any question, and let me assist you.
+ My capabilities are limited because I am still in the development phase, but I will do my best to assist you. SOOO LET'S BEGIN......
+ """
+
+ demo = gr.ChatInterface(
+     respond,
+     type='messages',
+     additional_inputs=[
+         gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+     title=title,
+     description=description,
+     # placeholder expects a string, not a list
+     textbox=gr.Textbox(placeholder="'What is the future of AI?' or 'App Development'"),
+     examples=[["✨Future of AI"], ["📱App Development"]],
+     # 'example_icons' is not a ChatInterface argument and was dropped;
+     # 'compact' is not a current built-in theme name, so 'soft' is used
+     theme="soft",
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)