Baweja committed
Commit a397aee · verified · 1 Parent(s): 24c77a7

Upload 2 files

Files changed (2):
  1. app.py +190 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,190 @@
+ # import gradio as gr
+ # from huggingface_hub import InferenceClient
+
+ # """
+ # For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ # """
+ # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ # def respond(
+ #     message,
+ #     history: list[tuple[str, str]],
+ #     system_message,
+ #     max_tokens,
+ #     temperature,
+ #     top_p,
+ # ):
+ #     messages = [{"role": "system", "content": system_message}]
+
+ #     for val in history:
+ #         if val[0]:
+ #             messages.append({"role": "user", "content": val[0]})
+ #         if val[1]:
+ #             messages.append({"role": "assistant", "content": val[1]})
+
+ #     messages.append({"role": "user", "content": message})
+
+ #     response = ""
+
+ #     for message in client.chat_completion(
+ #         messages,
+ #         max_tokens=max_tokens,
+ #         stream=True,
+ #         temperature=temperature,
+ #         top_p=top_p,
+ #     ):
+ #         token = message.choices[0].delta.content
+
+ #         response += token
+ #         yield response
+
+ # """
+ # For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ # """
+ # demo = gr.ChatInterface(
+ #     respond,
+ #     additional_inputs=[
+ #         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+ #         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+ #         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+ #         gr.Slider(
+ #             minimum=0.1,
+ #             maximum=1.0,
+ #             value=0.95,
+ #             step=0.05,
+ #             label="Top-p (nucleus sampling)",
+ #         ),
+ #     ],
+ # )
+
+
+ # if __name__ == "__main__":
+ #     demo.launch()
+
+
+
+
+ import gradio as gr
+ import torch
+ from transformers import RagRetriever, RagSequenceForGeneration, AutoTokenizer
+
+ """
+ For more information on RAG support in `transformers`, please check the docs: https://huggingface.co/docs/transformers/model_doc/rag
+ """
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def strip_title(title):
+     # Remove the quotation marks the retriever stores around document titles
+     if title.startswith('"'):
+         title = title[1:]
+     if title.endswith('"'):
+         title = title[:-1]
+     return title
+
+ def retrieved_info(rag_model, query):
+     # Tokenize the query for the question encoder
+     retriever_input_ids = rag_model.retriever.question_encoder_tokenizer.batch_encode_plus(
+         [query],
+         return_tensors="pt",
+         padding=True,
+         truncation=True,
+     )["input_ids"].to(device)
+
+     # Encode the query and retrieve documents from the FAISS index
+     question_enc_outputs = rag_model.rag.question_encoder(retriever_input_ids)
+     question_enc_pool_output = question_enc_outputs[0]
+
+     result = rag_model.retriever(
+         retriever_input_ids,
+         question_enc_pool_output.cpu().detach().to(torch.float32).numpy(),
+         prefix=rag_model.rag.generator.config.prefix,
+         n_docs=rag_model.config.n_docs,
+         return_tensors="pt",
+     )
+
+     # Collect the retrieved documents (titles included) as "title: text" strings
+     all_docs = rag_model.retriever.index.get_doc_dicts(result.doc_ids)
+     retrieved_context = []
+     for docs in all_docs:
+         titles = [strip_title(title) for title in docs["title"]]
+         texts = docs["text"]
+         for title, text in zip(titles, texts):
+             retrieved_context.append(f"{title}: {text}")
+
+     return retrieved_context
+
+
+ def respond(
+     message,
+     history: list[dict],  # with type="messages", history entries are {"role", "content"} dicts
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     # Note: system_message, max_tokens, temperature, and top_p are accepted for
+     # the UI but unused; the app currently returns retrieved passages without generation.
+     # Loading the model on every request keeps startup light but adds latency;
+     # loading once at module level would be faster.
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     dataset_path = "./sample/my_knowledge_dataset"
+     index_path = "./sample/my_knowledge_dataset_hnsw_index.faiss"
+
+     tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")  # currently unused by the retrieval-only flow
+     retriever = RagRetriever.from_pretrained(
+         "facebook/rag-sequence-nq",
+         index_name="custom",
+         passages_path=dataset_path,
+         index_path=index_path,
+         n_docs=5,
+     )
+     rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+     rag_model.retriever.init_retrieval()
+     rag_model.to(device)
+
+     if message:  # If there's a user query
+         # Get the top passage from the local FAISS index
+         response = retrieved_info(rag_model, message)
+         return response[0]
+
+     # In case there is no message, return an empty string
+     return ""
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ # Custom title and description
+ title = "🧠 Welcome to Your AI Knowledge Assistant"
+ description = """
+ Hi!! I am a chatbot that retrieves relevant information from a custom dataset using RAG. Ask any question, and let me assist you.
+ My capabilities and knowledge are limited right now because of computational resources. Originally I can access more than a million files
+ from my knowledge base but, right now, I am limited to fewer than 1,000 files. LET'S BEGIN!
+ """
+
+ demo = gr.ChatInterface(
+     respond,
+     type="messages",
+     additional_inputs=[
+         gr.Textbox(value="You are a helpful and friendly assistant.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+     title=title,
+     description=description,
+     textbox=gr.Textbox(placeholder="'What is the future of AI?' or 'App Development'"),  # placeholder takes a string, not a list
+     examples=[["✨Future of AI"], ["📱App Development"]],
+     example_icons=["🤖", "📱"],
+     theme="compact",
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
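
Note: app.py hard-codes `./sample/my_knowledge_dataset` and `./sample/my_knowledge_dataset_hnsw_index.faiss`, which are not part of this commit and must already exist in the Space. A minimal sketch of how such a knowledge base could be built, following the `transformers` custom-knowledge RAG example (the input file `my_knowledge.csv` with "title" and "text" columns is a hypothetical stand-in):

import faiss
import torch
from datasets import load_dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

# DPR context encoder produces the 768-dim embeddings RagRetriever expects
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

dataset = load_dataset("csv", data_files="my_knowledge.csv", split="train")

def embed(batch):
    # Encode title/text pairs into dense passage embeddings
    inputs = ctx_tokenizer(batch["title"], batch["text"], truncation=True, padding="longest", return_tensors="pt")
    with torch.no_grad():
        embeddings = ctx_encoder(**inputs).pooler_output
    return {"embeddings": embeddings.numpy()}

dataset = dataset.map(embed, batched=True, batch_size=16)
dataset.save_to_disk("./sample/my_knowledge_dataset")

# Same HNSW index type as the upstream RAG example (inner product over 768 dims)
index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)
dataset.get_index("embeddings").save("./sample/my_knowledge_dataset_hnsw_index.faiss")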
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.4.0
+ git+https://github.com/huggingface/transformers.git
+ faiss-cpu
+ datasets
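
`gradio` is not listed because Gradio Spaces preinstall it; for a local run, install it alongside these requirements. A quick smoke test of `respond` (hypothetical query; assumes the sample dataset and index described above exist):

from app import respond

# The extra arguments mirror the ChatInterface sliders; they are accepted but
# unused by the current retrieval-only respond().
print(respond("What is the future of AI?", [], "You are a helpful and friendly assistant.", 512, 0.7, 0.95))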