AlexanderKazakov committed
Commit · a37b98a
Parent(s): 0ae385b

fix layout

- gradio_app/app.py +14 -157
- gradio_app/backend/query_llm.py +2 -2
gradio_app/app.py
CHANGED
@@ -1,48 +1,13 @@
-"""
-Credit to Derek Thomas, [email protected]
-"""
-
-# import subprocess
-# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
-
-import logging
 from time import perf_counter
 
 import gradio as gr
-import markdown
-# import lancedb
-from jinja2 import Environment, FileSystemLoader
 
-from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
-from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
 from gradio_app.backend.query_llm import *
-from gradio_app.backend.embedders import EmbedderFactory
 
-from settings import *
 
-# Setting up the logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Set up the template environment with the templates directory
-env = Environment(loader=FileSystemLoader('gradio_app/templates'))
-
-# Load the templates directly from the environment
-context_template = env.get_template('context_template.j2')
-context_html_template = env.get_template('context_html_template.j2')
-
-# db = lancedb.connect(LANCEDB_DIRECTORY)
-db = None
-
-# Examples
-examples = [
-    'What is BERT?',
-    'Tell me about GPT',
-    'How to use accelerate in google colab?',
-    'What is the capital of China?',
-    'Why is the sky blue?',
-]
-
 
 def add_text(history, text):
     history = [] if history is None else history
@@ -50,65 +15,14 @@ def add_text(history, text):
     return history, gr.Textbox(value="", interactive=False)
 
 
-def …
-    logger.info('Retrieving documents...')
-    gr.Info('Start documents retrieval ...')
-    t = perf_counter()
-
-    table_name = f'{LANCEDB_TABLE_NAME}_{chunk}_{embed}'
-    table = db.open_table(table_name)
-
-    embedder = EmbedderFactory.get_embedder(embed)
-
-    query_vec = embedder.embed([query])[0]
-    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
-    top_k_rank = TOP_K_RANK if cross_enc is not None else TOP_K_RERANK
-    documents = documents.limit(top_k_rank).to_list()
-    thresh_dist = thresh_distances[embed]
-    thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
-    documents = [d for d in documents if d['_distance'] <= thresh_dist]
-    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
-
-    t = perf_counter() - t
-    logger.info(f'Finished Retrieving documents in {round(t, 2)} seconds...')
-
-    logger.info('Reranking documents...')
-    gr.Info('Start documents reranking ...')
-    t = perf_counter()
-
-    documents = rerank_with_cross_encoder(cross_enc, documents, query)
-
-    t = perf_counter() - t
-    logger.info(f'Finished Reranking documents in {round(t, 2)} seconds...')
-    return documents
-
-
-def construct_messages(llm, documents, history):
-    msg_constructor = get_message_constructor(llm)
-    while len(documents) != 0:
-        context = context_template.render(documents=documents)
-        documents_html = [markdown.markdown(d) for d in documents]
-        context_html = context_html_template.render(documents=documents_html)
-        messages = msg_constructor(context, history)
-        num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo')  # todo for HF, it is approximation
-        if num_tokens + 512 < context_lengths[llm]:
-            break
-        documents.pop()
-    else:
-        raise gr.Error('Model context length exceeded, reload the page')
-    return documents, context_html, messages
-
-
-def bot(history, llm, cross_enc, chunk, embed):
+def bot(history):
     history[-1][1] = ""
     query = history[-1][0]
 
     if not query:
         raise gr.Error("Empty string was submitted")
 
-
-    # documents, context_html, messages = construct_messages(llm, documents, history)
-    context_html = ''
+    llm = 'gpt-4-turbo-preview'
     messages = get_message_constructor(llm)('', history)
 
     llm_gen = get_llm_generator(llm)
@@ -116,7 +30,7 @@ def bot(history, llm, cross_enc, chunk, embed):
     t = perf_counter()
     for part in llm_gen(messages):
         history[-1][1] += part
-        yield history
+        yield history
     else:
         t = perf_counter() - t
         logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')
@@ -133,79 +47,22 @@ with gr.Blocks() as demo:
         bubble_full_width=False,
         show_copy_button=True,
         show_share_button=True,
-        height=…
+        height=800
     )
-
-    with gr.Row():
-        input_textbox = gr.Textbox(
-            scale=3,
-            show_label=False,
-            placeholder="Enter text and press enter",
-            container=False,
-        )
-        txt_btn = gr.Button(value="Submit text", scale=1)
-
-        chunk_name = gr.Radio(
-            choices=[
-                "md",
-                "txt",
-            ],
-            value="md",
-            label='Chunking policy'
-        )
-
-        embed_name = gr.Radio(
-            choices=[
-                "text-embedding-ada-002",
-                "sentence-transformers/all-MiniLM-L6-v2",
-            ],
-            value="text-embedding-ada-002",
-            label='Embedder'
-        )
-
-        cross_enc_name = gr.Radio(
-            choices=[
-                None,
-                "cross-encoder/ms-marco-TinyBERT-L-2-v2",
-                "cross-encoder/ms-marco-MiniLM-L-12-v2",
-            ],
-            value=None,
-            label='Cross-Encoder'
-        )
-
-        llm_name = gr.Radio(
-            choices=[
-                "gpt-4-1106-preview",
-                "gpt-4",
-                "gpt-3.5-turbo-1106",
-                "gpt-3.5-turbo",
-                "mistralai/Mistral-7B-Instruct-v0.1",
-                "tiiuae/falcon-180B-chat",
-                # "GeneZC/MiniChat-3B",
-            ],
-            value="gpt-4-1106-preview",
-            label='LLM'
-        )
-
-    # Examples
-    gr.Examples(examples, input_textbox)
-
     with gr.Column():
-
-
-
-
-
-
-
-
-
-    # Turn it back on
-    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
+        input_textbox = gr.Textbox(
+            interactive=True,
+            show_label=False,
+            placeholder="Enter text and press enter",
+            container=False,
+            autofocus=True,
+            lines=40,
+            max_lines=100,
+        )
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
-        bot, [chatbot…
+        bot, [chatbot], [chatbot])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
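For orientation, here is a minimal, self-contained sketch of the pattern app.py is reduced to by this commit: a streaming Chatbot above a tall Textbox, wired as submit → add_text → bot → re-enable. fake_llm_stream is a hypothetical stand-in for the generator returned by get_llm_generator; the component names and layout values mirror the diff.

import time

import gradio as gr


def fake_llm_stream(messages):
    # Hypothetical stand-in for the streaming generators in
    # gradio_app/backend/query_llm.py: yields the answer piece by piece.
    for token in "This is a streamed answer.".split():
        time.sleep(0.05)
        yield token + " "


def add_text(history, text):
    # Append the user turn and disable the textbox while generating.
    history = [] if history is None else history
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)


def bot(history):
    if not history[-1][0]:
        raise gr.Error("Empty string was submitted")
    history[-1][1] = ""
    for part in fake_llm_stream(history):
        history[-1][1] += part
        yield history  # each yield repaints the Chatbot with the partial answer


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True, height=800)
    with gr.Column():
        input_textbox = gr.Textbox(
            interactive=True,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
            autofocus=True,
            lines=40,
            max_lines=100,
        )

    # Chain: add the user message, stream the reply, re-enable the textbox.
    txt_msg = input_textbox.submit(
        add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False,
    ).then(bot, [chatbot], [chatbot])
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)

demo.queue()
demo.launch()

Because bot is a generator, Gradio streams each yielded history into the Chatbot, which produces the token-by-token effect; queue=False on the bookkeeping steps keeps only the generation itself on the queue.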
gradio_app/backend/query_llm.py
CHANGED
@@ -111,7 +111,7 @@ def construct_openai_messages(context, history):
 
 
 def get_message_constructor(llm_name):
-    if …
+    if 'gpt' in llm_name:
         return construct_openai_messages
     if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "tiiuae/falcon-180B-chat", "GeneZC/MiniChat-3B"]:
         return construct_mistral_messages
@@ -119,7 +119,7 @@ def get_message_constructor(llm_name):
 
 
 def get_llm_generator(llm_name):
-    if …
+    if 'gpt' in llm_name:
         cgi = ChatGptInteractor(
             model_name=llm_name, stream=True,
             # max_tokens=None, temperature=0,
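On the query_llm.py side, the old conditions are truncated in this view, but the new 'gpt' in llm_name checks route any OpenAI-style model id to the OpenAI message constructor and chat generator. A sketch of the resulting dispatch, with stub bodies so it runs standalone (the final raise is an assumption; the real constructors live in gradio_app/backend/query_llm.py):

def construct_openai_messages(context, history):
    ...  # stub: builds the OpenAI chat-completion message list


def construct_mistral_messages(context, history):
    ...  # stub: builds the instruction-format prompt for the HF-hosted models


def get_message_constructor(llm_name):
    # Every OpenAI chat model id contains 'gpt' ('gpt-4', 'gpt-3.5-turbo',
    # 'gpt-4-turbo-preview', ...), so a substring check covers newly added
    # OpenAI models without maintaining an explicit list.
    if 'gpt' in llm_name:
        return construct_openai_messages
    if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "tiiuae/falcon-180B-chat", "GeneZC/MiniChat-3B"]:
        return construct_mistral_messages
    raise ValueError(f'Unknown LLM name: {llm_name}')  # assumed fallback, not shown in the diff

The tradeoff is that the substring check would also match a non-OpenAI model whose name happens to contain 'gpt'; with the models wired up in this app, that case does not arise.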