Marcus Posey committed
Commit e69ee5c · 1 parent: c5da24c

Add asynchronous responses

Files changed (1)
  1. app.py +28 -27
app.py CHANGED
@@ -8,6 +8,8 @@ import gspread
 from oauth2client.service_account import ServiceAccountCredentials
 import gradio as gr
 from huggingface_hub import login
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 
 
 BOOK_MAPPING = {
@@ -39,26 +41,25 @@ class ModelManager:
         self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
         self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
         self.template_text = self._load_template()
-
+
     def _load_template(self):
         with open("prompt_template.txt", "r", encoding="utf-8") as file:
             return file.read()
-
-    def get_model_response(self, model_name, prompt):
+
+    async def get_model_response_async(self, model_name, client, prompt):
         try:
             formatted_prompt = self.template_text.format(prompt)
 
-            model_clients = {
-                "Model_A": self.model_A,
-                "Model_B": self.model_B,
-                "Model_C": self.model_C
-            }
+            loop = asyncio.get_running_loop()
+            with ThreadPoolExecutor() as executor:
+                response = await loop.run_in_executor(
+                    executor,
+                    client.predict,
+                    formatted_prompt,
+                    "/predict"
+                )
 
-            client = model_clients[model_name]
-            response = client.predict(
-                prompt=formatted_prompt,
-                api_name="/predict"
-            )
+            model_responses[model_name] = response
             return response
         except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
             print(f"Timeout while getting response from {model_name}: {str(e)}")
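Note on the hunk above: loop.run_in_executor hands the blocking gradio_client predict call to a worker thread and returns an awaitable, so the event loop stays free while the request is in flight. A minimal runnable sketch of the same pattern, where slow_predict is a hypothetical stand-in for client.predict:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

def slow_predict(prompt):
    # Hypothetical stand-in for a blocking client.predict call.
    time.sleep(1)
    return f"response to {prompt!r}"

async def predict_async(prompt):
    # Offload the blocking call to a worker thread and await its
    # result without stalling the event loop.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        return await loop.run_in_executor(executor, slow_predict, prompt)

print(asyncio.run(predict_async("hello")))

On Python 3.9+, asyncio.to_thread(slow_predict, prompt) does the same job without managing an executor by hand.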
@@ -66,19 +67,15 @@ class ModelManager:
         except Exception as e:
             print(f"Error getting response from {model_name}: {str(e)}")
             return f"Error: Could not get response from {model_name}. Please try again."
-
-    def get_all_model_responses(self, prompt):
-        responses = []
-        model_responses.clear()
 
-        self.model_A = Client("mep296/llama-3-8b-rephrase-quality")
-        self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
-        self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
-
-        for model in ['A', 'B', 'C']:
-            response = self.get_model_response(f"Model_{model}", prompt)
-            model_responses[f"Model_{model}"] = response
-            responses.append(response)
+    async def get_all_model_responses_async(self, prompt):
+        tasks = [
+            self.get_model_response_async("Model_A", self.model_A, prompt),
+            self.get_model_response_async("Model_B", self.model_B, prompt),
+            self.get_model_response_async("Model_C", self.model_C, prompt)
+        ]
+
+        responses = await asyncio.gather(*tasks)
         return responses
 
     def get_book_model_mapping(self, book):
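asyncio.gather in this hunk is what buys the concurrency: the three model calls now overlap instead of running back to back, and results come back in the order the tasks were passed, not in completion order. A small runnable sketch under those assumptions, with fake_model as a hypothetical stand-in for get_model_response_async:

import asyncio

async def fake_model(name, delay):
    # Hypothetical stand-in for get_model_response_async.
    await asyncio.sleep(delay)
    return f"{name} done"

async def main():
    # gather schedules all three concurrently; total wall time is the
    # slowest call (~0.3s here), not the sum, and the result order
    # matches the argument order.
    return await asyncio.gather(
        fake_model("Model_A", 0.3),
        fake_model("Model_B", 0.1),
        fake_model("Model_C", 0.2),
    )

print(asyncio.run(main()))  # ['Model_A done', 'Model_B done', 'Model_C done']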
@@ -128,6 +125,7 @@ class ModelComparisonApp:
         self.chat_history_B = []
         self.chat_history_C = []
         self.state = gr.State(value="")
+        self.loop = asyncio.get_event_loop()
 
     def create_interface(self):
         text_size = gr.themes.sizes.text_lg
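One caveat on the added __init__ line: asyncio.get_event_loop() is deprecated in newer Python versions when called with no loop running, so creating and owning a loop explicitly is the more future-proof variant. A sketch of that alternative, with App as an illustrative stand-in for ModelComparisonApp:

import asyncio

class App:
    def __init__(self):
        # Own a dedicated loop rather than asking for the "current"
        # one, which newer Pythons deprecate outside a running loop.
        self.loop = asyncio.new_event_loop()

    def run_sync(self, coro):
        # Drive a coroutine to completion from synchronous code.
        return self.loop.run_until_complete(coro)

    def close(self):
        self.loop.close()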
@@ -191,7 +189,10 @@ class ModelComparisonApp:
         mapping_dict = self.model_manager.book_model_assignments[self.selected_book]
         model_order = ["rephrase", "rephrase_summarize", "entigraph"]
         model_to_index = {model: i for i, model in enumerate(model_order)}
-        responses = self.model_manager.get_all_model_responses(prompt)
+
+        responses = self.loop.run_until_complete(
+            self.model_manager.get_all_model_responses_async(prompt)
+        )
 
         chats = []
         for response in responses:
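If the app did not need to keep a long-lived loop on the instance, asyncio.run would be the simpler bridge from a synchronous Gradio callback: it creates a fresh loop, runs the coroutine to completion, and closes the loop. A hypothetical sketch, with gather_responses standing in for get_all_model_responses_async:

import asyncio

async def gather_responses(prompt):
    # Hypothetical stand-in for get_all_model_responses_async.
    await asyncio.sleep(0)
    return [f"A: {prompt}", f"B: {prompt}", f"C: {prompt}"]

def handle_prompt(prompt):
    # Synchronous entry point, e.g. a Gradio event handler. Note that
    # asyncio.run raises if a loop is already running in this thread,
    # which is one reason to keep a persistent self.loop instead.
    return asyncio.run(gather_responses(prompt))

print(handle_prompt("What happens in chapter 3?"))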
@@ -202,7 +203,7 @@ class ModelComparisonApp:
 
         reordered_chats = [chats[model_to_index[mapping_dict[model]]] for model in ["Model A", "Model B", "Model C"]]
         return reordered_chats
-
+
     def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
         model_mapping = self.model_manager.get_book_model_mapping(book)
         rating_data = {
 