AlexanderKazakov committed on
Commit
a37b98a
·
1 Parent(s): 0ae385b

fix layout

Browse files
Files changed (2) hide show
  1. gradio_app/app.py +14 -157
  2. gradio_app/backend/query_llm.py +2 -2
gradio_app/app.py CHANGED
@@ -1,48 +1,13 @@
1
- """
2
- Credit to Derek Thomas, [email protected]
3
- """
4
-
5
- # import subprocess
6
- # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
7
-
8
- import logging
9
  from time import perf_counter
10
 
11
  import gradio as gr
12
- import markdown
13
- # import lancedb
14
- from jinja2 import Environment, FileSystemLoader
15
 
16
- from gradio_app.backend.ChatGptInteractor import num_tokens_from_messages
17
- from gradio_app.backend.cross_encoder import rerank_with_cross_encoder
18
  from gradio_app.backend.query_llm import *
19
- from gradio_app.backend.embedders import EmbedderFactory
20
 
21
- from settings import *
22
 
23
- # Setting up the logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
- # Set up the template environment with the templates directory
28
- env = Environment(loader=FileSystemLoader('gradio_app/templates'))
29
-
30
- # Load the templates directly from the environment
31
- context_template = env.get_template('context_template.j2')
32
- context_html_template = env.get_template('context_html_template.j2')
33
-
34
- # db = lancedb.connect(LANCEDB_DIRECTORY)
35
- db = None
36
-
37
- # Examples
38
- examples = [
39
- 'What is BERT?',
40
- 'Tell me about GPT',
41
- 'How to use accelerate in google colab?',
42
- 'What is the capital of China?',
43
- 'Why is the sky blue?',
44
- ]
45
-
46
 
47
  def add_text(history, text):
48
  history = [] if history is None else history
@@ -50,65 +15,14 @@ def add_text(history, text):
50
  return history, gr.Textbox(value="", interactive=False)
51
 
52
 
53
- def find_context(query, cross_enc, chunk, embed):
54
- logger.info('Retrieving documents...')
55
- gr.Info('Start documents retrieval ...')
56
- t = perf_counter()
57
-
58
- table_name = f'{LANCEDB_TABLE_NAME}_{chunk}_{embed}'
59
- table = db.open_table(table_name)
60
-
61
- embedder = EmbedderFactory.get_embedder(embed)
62
-
63
- query_vec = embedder.embed([query])[0]
64
- documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
65
- top_k_rank = TOP_K_RANK if cross_enc is not None else TOP_K_RERANK
66
- documents = documents.limit(top_k_rank).to_list()
67
- thresh_dist = thresh_distances[embed]
68
- thresh_dist = max(thresh_dist, min(d['_distance'] for d in documents))
69
- documents = [d for d in documents if d['_distance'] <= thresh_dist]
70
- documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
71
-
72
- t = perf_counter() - t
73
- logger.info(f'Finished Retrieving documents in {round(t, 2)} seconds...')
74
-
75
- logger.info('Reranking documents...')
76
- gr.Info('Start documents reranking ...')
77
- t = perf_counter()
78
-
79
- documents = rerank_with_cross_encoder(cross_enc, documents, query)
80
-
81
- t = perf_counter() - t
82
- logger.info(f'Finished Reranking documents in {round(t, 2)} seconds...')
83
- return documents
84
-
85
-
86
- def construct_messages(llm, documents, history):
87
- msg_constructor = get_message_constructor(llm)
88
- while len(documents) != 0:
89
- context = context_template.render(documents=documents)
90
- documents_html = [markdown.markdown(d) for d in documents]
91
- context_html = context_html_template.render(documents=documents_html)
92
- messages = msg_constructor(context, history)
93
- num_tokens = num_tokens_from_messages(messages, 'gpt-3.5-turbo') # todo for HF, it is approximation
94
- if num_tokens + 512 < context_lengths[llm]:
95
- break
96
- documents.pop()
97
- else:
98
- raise gr.Error('Model context length exceeded, reload the page')
99
- return documents, context_html, messages
100
-
101
-
102
- def bot(history, llm, cross_enc, chunk, embed):
103
  history[-1][1] = ""
104
  query = history[-1][0]
105
 
106
  if not query:
107
  raise gr.Error("Empty string was submitted")
108
 
109
- # documents = find_context(query, cross_enc, chunk, embed)
110
- # documents, context_html, messages = construct_messages(llm, documents, history)
111
- context_html = ''
112
  messages = get_message_constructor(llm)('', history)
113
 
114
  llm_gen = get_llm_generator(llm)
@@ -116,7 +30,7 @@ def bot(history, llm, cross_enc, chunk, embed):
116
  t = perf_counter()
117
  for part in llm_gen(messages):
118
  history[-1][1] += part
119
- yield history, context_html
120
  else:
121
  t = perf_counter() - t
122
  logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')
@@ -133,79 +47,22 @@ with gr.Blocks() as demo:
133
  bubble_full_width=False,
134
  show_copy_button=True,
135
  show_share_button=True,
136
- height=500,
137
  )
138
-
139
- with gr.Row():
140
- input_textbox = gr.Textbox(
141
- scale=3,
142
- show_label=False,
143
- placeholder="Enter text and press enter",
144
- container=False,
145
- )
146
- txt_btn = gr.Button(value="Submit text", scale=1)
147
-
148
- chunk_name = gr.Radio(
149
- choices=[
150
- "md",
151
- "txt",
152
- ],
153
- value="md",
154
- label='Chunking policy'
155
- )
156
-
157
- embed_name = gr.Radio(
158
- choices=[
159
- "text-embedding-ada-002",
160
- "sentence-transformers/all-MiniLM-L6-v2",
161
- ],
162
- value="text-embedding-ada-002",
163
- label='Embedder'
164
- )
165
-
166
- cross_enc_name = gr.Radio(
167
- choices=[
168
- None,
169
- "cross-encoder/ms-marco-TinyBERT-L-2-v2",
170
- "cross-encoder/ms-marco-MiniLM-L-12-v2",
171
- ],
172
- value=None,
173
- label='Cross-Encoder'
174
- )
175
-
176
- llm_name = gr.Radio(
177
- choices=[
178
- "gpt-4-1106-preview",
179
- "gpt-4",
180
- "gpt-3.5-turbo-1106",
181
- "gpt-3.5-turbo",
182
- "mistralai/Mistral-7B-Instruct-v0.1",
183
- "tiiuae/falcon-180B-chat",
184
- # "GeneZC/MiniChat-3B",
185
- ],
186
- value="gpt-4-1106-preview",
187
- label='LLM'
188
- )
189
-
190
- # Examples
191
- gr.Examples(examples, input_textbox)
192
-
193
  with gr.Column():
194
- context_html = gr.HTML()
195
-
196
- # Turn off interactivity while generating if you click
197
- txt_msg = txt_btn.click(
198
- add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False
199
- ).then(
200
- bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html]
201
- )
202
-
203
- # Turn it back on
204
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
205
 
206
  # Turn off interactivity while generating if you hit enter
207
  txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
208
- bot, [chatbot, llm_name, cross_enc_name, chunk_name, embed_name], [chatbot, context_html])
209
 
210
  # Turn it back on
211
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
 
 
 
 
 
 
 
 
 
1
  from time import perf_counter
2
 
3
  import gradio as gr
 
 
 
4
 
 
 
5
  from gradio_app.backend.query_llm import *
 
6
 
 
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def add_text(history, text):
13
  history = [] if history is None else history
 
15
  return history, gr.Textbox(value="", interactive=False)
16
 
17
 
18
+ def bot(history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  history[-1][1] = ""
20
  query = history[-1][0]
21
 
22
  if not query:
23
  raise gr.Error("Empty string was submitted")
24
 
25
+ llm = 'gpt-4-turbo-preview'
 
 
26
  messages = get_message_constructor(llm)('', history)
27
 
28
  llm_gen = get_llm_generator(llm)
 
30
  t = perf_counter()
31
  for part in llm_gen(messages):
32
  history[-1][1] += part
33
+ yield history
34
  else:
35
  t = perf_counter() - t
36
  logger.info(f'Finished Generating answer in {round(t, 2)} seconds...')
 
47
  bubble_full_width=False,
48
  show_copy_button=True,
49
  show_share_button=True,
50
+ height=800
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  with gr.Column():
53
+ input_textbox = gr.Textbox(
54
+ interactive=True,
55
+ show_label=False,
56
+ placeholder="Enter text and press enter",
57
+ container=False,
58
+ autofocus=True,
59
+ lines=40,
60
+ max_lines=100,
61
+ )
 
 
62
 
63
  # Turn off interactivity while generating if you hit enter
64
  txt_msg = input_textbox.submit(add_text, [chatbot, input_textbox], [chatbot, input_textbox], queue=False).then(
65
+ bot, [chatbot], [chatbot])
66
 
67
  # Turn it back on
68
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [input_textbox], queue=False)
gradio_app/backend/query_llm.py CHANGED
@@ -111,7 +111,7 @@ def construct_openai_messages(context, history):
111
 
112
 
113
  def get_message_constructor(llm_name):
114
- if llm_name in ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-1106"]:
115
  return construct_openai_messages
116
  if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "tiiuae/falcon-180B-chat", "GeneZC/MiniChat-3B"]:
117
  return construct_mistral_messages
@@ -119,7 +119,7 @@ def get_message_constructor(llm_name):
119
 
120
 
121
  def get_llm_generator(llm_name):
122
- if llm_name in ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-1106"]:
123
  cgi = ChatGptInteractor(
124
  model_name=llm_name, stream=True,
125
  # max_tokens=None, temperature=0,
 
111
 
112
 
113
  def get_message_constructor(llm_name):
114
+ if 'gpt' in llm_name:
115
  return construct_openai_messages
116
  if llm_name in ['mistralai/Mistral-7B-Instruct-v0.1', "tiiuae/falcon-180B-chat", "GeneZC/MiniChat-3B"]:
117
  return construct_mistral_messages
 
119
 
120
 
121
  def get_llm_generator(llm_name):
122
+ if 'gpt' in llm_name:
123
  cgi = ChatGptInteractor(
124
  model_name=llm_name, stream=True,
125
  # max_tokens=None, temperature=0,