Baweja committed
Commit 56b53e7 · verified · 1 Parent(s): 5cb77b0

Update app.py

Files changed (1)
  1. app.py +140 -140
app.py CHANGED
@@ -1,140 +1,140 @@
- import torch
- import transformers
- from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-
- dataset_path = "./5k_index_data/my_knowledge_dataset"
- index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"
-
- tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
- retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
-                                          passages_path = dataset_path,
-                                          index_path = index_path,
-                                          n_docs = 5)
- rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
- rag_model.retriever.init_retrieval()
- rag_model.to(device)
- model = AutoModelForCausalLM.from_pretrained('google/gemma-2-9b-it',
-                                              device_map = 'auto',
-                                              torch_dtype = torch.bfloat16,
-                                              )
-
-
-
- def strip_title(title):
-     if title.startswith('"'):
-         title = title[1:]
-     if title.endswith('"'):
-         title = title[:-1]
-
-     return title
-
- # getting the correct format to input in gemma model
- def input_format(query, context):
-     sys_instruction = f'Context:\n {context} \n Given the following information, generate answer to the question. Provide links in the answer from the information to increase credebility.'
-     message = f'Question: {query}'
-
-     return f'<bos><start_of_turn>\n{sys_instruction}' + f' {message}<end_of_turn>\n'
-
- # retrieving and generating answer in one call
- def retrieved_info(query, rag_model = rag_model, generating_model = model):
-     # Tokenize Query
-     retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
-         [query],
-         return_tensors = 'pt',
-         padding = True,
-         truncation = True,
-     )['input_ids'].to(device)
-
-     # Retrieve Documents
-     question_encoder_output = rag_model.rag.question_encoder(retriever_input_ids)
-     question_encoder_pool_output = question_encoder_output[0]
-
-     result = rag_model.retriever(
-         retriever_input_ids,
-         question_encoder_pool_output.cpu().detach().to(torch.float32).numpy(),
-         prefix = rag_model.rag.generator.config.prefix,
-         n_docs = rag_model.config.n_docs,
-         return_tensors = 'pt',
-     )
-
-     # Preparing query and retrieved docs for model
-     all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
-     retrieved_context = []
-     for docs in all_docs:
-         titles = [strip_title(title) for title in docs['title']]
-         texts = docs['text']
-         for title, text in zip(titles, texts):
-             retrieved_context.append(f'{title}: {text}')
-
-     generation_model_input = input_format(query, retrieved_context)
-
-     # Generating answer using gemma model
-     tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
-     input_ids = tokenizer(generation_model_input, return_tensors='pt').to(device)
-     output = generating_model.generate(input_ids, max_new_tokens = 512)
-
-     return tokenizer.decode(output[0])
-
-
-
-
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens ,
-     temperature,
-     top_p,
- ):
-     if message:  # If there's a user query
-         response = retrieved_info(message)  # Get the answer from your local FAISS and Q&A model
-         return response
-
-     # In case no message, return an empty string
-     return ""
-
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- # Custom title and description
- title = "🧠 Welcome to Your AI Knowledge Assistant"
- description = """
- HI!!, I am your loyal assistant, y functionality is based on RAG model, I retrieves relevant information and provide answers based on that. Ask me any question, and let me assist you.
- My capabilities are limited because I am still in development phase. I will do my best to assist you. SOOO LET'S BEGGINNNN......
- """
-
- demo = gr.ChatInterface(
-     respond,
-     type = 'messages',
-     additional_inputs=[
-         gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
-     title=title,
-     description=description,
-     textbox=gr.Textbox(placeholder=["'What is the future of AI?' or 'App Development'"]),
-     examples=[["✨Future of AI"], ["📱App Development"]],
-     example_icons=["🤖", "📱"],
-     theme="compact",
- )
-
-
- if __name__ == "__main__":
-     demo.launch(share = True )
-
 
+ import torch
+ import transformers
+ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ # Local knowledge dataset and its prebuilt FAISS HNSW index
+ dataset_path = "./5k_index_data/my_knowledge_dataset"
+ index_path = "./5k_index_data/my_knowledge_dataset_hnsw_index.faiss"
+
+ # (not used below; the retriever loads its own question-encoder tokenizer)
+ tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
+ retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom",
+                                          passages_path=dataset_path,
+                                          index_path=index_path,
+                                          n_docs=5)
+ rag_model = RagSequenceForGeneration.from_pretrained('facebook/rag-sequence-nq', retriever=retriever)
+ rag_model.retriever.init_retrieval()
+ rag_model.to(device)
+ # Generator: instruction-tuned Gemma, loaded in bfloat16 and sharded across available devices
+ model = AutoModelForCausalLM.from_pretrained('google/gemma-7b-it',
+                                              device_map='auto',
+                                              torch_dtype=torch.bfloat16,
+                                              )
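+ # NOTE: a sketch (an assumption, not part of this app) of how the dataset and
+ # index above would have been built; a custom RagRetriever index expects the
+ # columns 'title', 'text' and 'embeddings':
+ #
+ #     from datasets import load_from_disk
+ #     ds = load_from_disk(dataset_path)            # columns: title, text, embeddings
+ #     ds.add_faiss_index(column='embeddings', string_factory='HNSW32')
+ #     ds.save_faiss_index('embeddings', index_path)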
+
+
+
+ # Strip surrounding quotation marks that retrieved passage titles sometimes carry
+ def strip_title(title):
+     if title.startswith('"'):
+         title = title[1:]
+     if title.endswith('"'):
+         title = title[:-1]
+
+     return title
+
+ # Build the input prompt for the Gemma model in its chat format
+ def input_format(query, context):
+     sys_instruction = f'Context:\n {context} \n Given the following information, generate an answer to the question. Provide links from the information in the answer to increase credibility.'
+     message = f'Question: {query}'
+
+     # Gemma chat turns are <start_of_turn>{role} ... <end_of_turn>; ending with an
+     # open model turn cues the model to generate the answer
+     return f'<bos><start_of_turn>user\n{sys_instruction} {message}<end_of_turn>\n<start_of_turn>model\n'
+
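+ # For example, input_format("What is RAG?", context) yields roughly:
+ #     <bos><start_of_turn>user
+ #     Context:
+ #      {context}
+ #      Given the following information, ... Question: What is RAG?<end_of_turn>
+ #     <start_of_turn>model
+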
+ # Retrieve documents and generate an answer in one call
+ def retrieved_info(query, rag_model=rag_model, generating_model=model):
+     # Tokenize the query with the RAG question-encoder tokenizer
+     retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
+         [query],
+         return_tensors='pt',
+         padding=True,
+         truncation=True,
+     )['input_ids'].to(device)
+
+     # Embed the query, then search the FAISS index for the nearest passages
+     question_encoder_output = rag_model.rag.question_encoder(retriever_input_ids)
+     question_encoder_pool_output = question_encoder_output[0]
+
+     result = rag_model.retriever(
+         retriever_input_ids,
+         question_encoder_pool_output.cpu().detach().to(torch.float32).numpy(),
+         prefix=rag_model.rag.generator.config.prefix,
+         n_docs=rag_model.config.n_docs,
+         return_tensors='pt',
+     )
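+     # Besides doc_ids, `result` carries context_input_ids / context_attention_mask
+     # (query+passage pairs ready for the RAG generator) and retrieved_doc_embeds;
+     # only the doc ids are used here, since Gemma generates the final answer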
+
+     # Format the query and retrieved passages as "title: text" context strings
+     all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
+     retrieved_context = []
+     for docs in all_docs:
+         titles = [strip_title(title) for title in docs['title']]
+         texts = docs['text']
+         for title, text in zip(titles, texts):
+             retrieved_context.append(f'{title}: {text}')
+
+     generation_model_input = input_format(query, retrieved_context)
+
+     # Generate the answer with the Gemma model
+     # (the tokenizer is reloaded on every call; it could be hoisted to module level)
+     tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
+     inputs = tokenizer(generation_model_input, return_tensors='pt').to(device)
+     output = generating_model.generate(**inputs, max_new_tokens=512)
+
+     # Return only the newly generated tokens, without echoing the prompt
+     answer_ids = output[0][inputs['input_ids'].shape[-1]:]
+     return tokenizer.decode(answer_ids, skip_special_tokens=True)
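+ # Example usage (sketch): retrieved_info("What is the future of AI?") returns a
+ # Gemma-generated answer grounded in the five retrieved passages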
+
+
+
+
+
+
+ def respond(
+     message,
+     history: list[dict],  # with type='messages', Gradio passes history as chat-message dicts
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     if message:  # If there's a user query
+         response = retrieved_info(message)  # Answer from the local FAISS index + Gemma
+         return response
+
+     # In case of no message, return an empty string
+     return ""
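+ # NOTE: system_message, max_tokens, temperature and top_p from the UI are currently
+ # unused. Threading them through would look roughly like this (a sketch; it assumes
+ # retrieved_info grows matching parameters):
+ #
+ #     response = retrieved_info(message, max_new_tokens=max_tokens,
+ #                               temperature=temperature, top_p=top_p)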
+
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ # Custom title and description
+ title = "🧠 Welcome to Your AI Knowledge Assistant"
+ description = """
+ Hi!! I am your loyal assistant. My functionality is based on the RAG model. I retrieve relevant information and provide answers based on that. Ask me any questions, and let me assist you.
+ My capabilities are limited because I am still in the development phase. I will do my best to assist you. SOOO LET'S BEGINNNN......
+ """
+
+ demo = gr.ChatInterface(
+     respond,
+     type='messages',
+     additional_inputs=[
+         gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+     title=title,
+     description=description,
+     # placeholder must be a plain string, not a list
+     textbox=gr.Textbox(placeholder="Ask 'What is the future of AI?' or 'App Development'"),
+     examples=[["✨Future of AI"], ["📱App Development"]],
+     example_icons=["🤖", "📱"],
+     theme="compact",
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
+