Commit e69ee5c by Marcus Posey
Parent: c5da24c
Add asynchronous responses

app.py CHANGED
@@ -8,6 +8,8 @@ import gspread
 from oauth2client.service_account import ServiceAccountCredentials
 import gradio as gr
 from huggingface_hub import login
+import asyncio
+from concurrent.futures import ThreadPoolExecutor


 BOOK_MAPPING = {
@@ -39,26 +41,25 @@ class ModelManager:
         self.model_B = Client("mep296/llama-3-8b-rephrase-summarize-quality")
         self.model_C = Client("mep296/llama-3-8b-entigraph-quality")
         self.template_text = self._load_template()
-
+
     def _load_template(self):
         with open("prompt_template.txt", "r", encoding="utf-8") as file:
             return file.read()
-
-    def get_model_response(self, model_name, client, prompt):
+
+    async def get_model_response_async(self, model_name, client, prompt):
         try:
             formatted_prompt = self.template_text.format(prompt)

-
-
-
-
-
+            loop = asyncio.get_running_loop()
+            with ThreadPoolExecutor() as executor:
+                response = await loop.run_in_executor(
+                    executor,
+                    client.predict,
+                    formatted_prompt,
+                    "/predict"
+                )

-
-            response = client.predict(
-                prompt=formatted_prompt,
-                api_name="/predict"
-            )
+            model_responses[model_name] = response
             return response
         except (httpx.ReadTimeout, httpx.ConnectTimeout) as e:
             print(f"Timeout while getting response from {model_name}: {str(e)}")
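Two details in the hunk above are worth flagging. First, "with ThreadPoolExecutor() as executor:" creates and tears down a fresh thread pool on every request; passing None as the executor argument reuses asyncio's default pool instead. Second, loop.run_in_executor forwards positional arguments only, while gradio_client's Client.predict takes the endpoint as the keyword-only api_name parameter, so the positional "/predict" most likely reaches the Space as an extra input rather than selecting the endpoint. A minimal sketch of the usual workaround, binding the keyword with functools.partial (the call_predict helper is illustrative; client and formatted_prompt are as in the diff):

import asyncio
from functools import partial

async def call_predict(client, formatted_prompt):
    # run_in_executor passes positional args only, so bind the
    # keyword-only api_name up front with functools.partial;
    # executor=None reuses asyncio's default thread pool
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        partial(client.predict, formatted_prompt, api_name="/predict")
    )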
@@ -66,19 +67,15 @@ class ModelManager:
         except Exception as e:
             print(f"Error getting response from {model_name}: {str(e)}")
             return f"Error: Could not get response from {model_name}. Please try again."
-
-    def get_all_model_responses(self, prompt):
-        responses = []
-        model_responses.clear()

-
-
-
-
-
-
-
-
+    async def get_all_model_responses_async(self, prompt):
+        tasks = [
+            self.get_model_response_async("Model_A", self.model_A, prompt),
+            self.get_model_response_async("Model_B", self.model_B, prompt),
+            self.get_model_response_async("Model_C", self.model_C, prompt)
+        ]
+
+        responses = await asyncio.gather(*tasks)
         return responses

     def get_book_model_mapping(self, book):
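asyncio.gather in the hunk above starts the three coroutines concurrently and returns their results in the order the awaitables were passed, not the order they finish, so responses stays aligned with Model A, B and C. Because get_model_response_async catches its own exceptions and returns an error string, one slow or failing Space does not cancel the other two, and since each coroutine writes a distinct key into the shared model_responses dict on the event-loop thread, those writes do not race. A small runnable sketch of the ordering guarantee (the names and delays are made up for illustration):

import asyncio

async def fake_model(name, delay):
    # stand-in for a Space call that takes delay seconds
    await asyncio.sleep(delay)
    return name

async def main():
    # Model_B finishes first, but gather keeps submission order
    results = await asyncio.gather(
        fake_model("Model_A", 0.3),
        fake_model("Model_B", 0.1),
        fake_model("Model_C", 0.2),
    )
    print(results)  # ['Model_A', 'Model_B', 'Model_C']

asyncio.run(main())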
@@ -128,6 +125,7 @@ class ModelComparisonApp:
         self.chat_history_B = []
         self.chat_history_C = []
         self.state = gr.State(value="")
+        self.loop = asyncio.get_event_loop()

     def create_interface(self):
         text_size = gr.themes.sizes.text_lg
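A caveat on asyncio.get_event_loop() in the hunk above: newer Python deprecates calling it when no event loop is current (Python 3.12 emits a DeprecationWarning), so creating and owning a loop explicitly expresses the same intent without the warning. A hedged sketch of that variant; only the loop handling differs from the diff, and the run_sync helper is illustrative:

import asyncio

class ModelComparisonApp:
    def __init__(self):
        # own a private loop instead of asyncio.get_event_loop(),
        # which is deprecated when no loop is current
        self.loop = asyncio.new_event_loop()

    def run_sync(self, coro):
        # drive a coroutine to completion from synchronous code
        return self.loop.run_until_complete(coro)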
@@ -191,7 +189,10 @@ class ModelComparisonApp:
         mapping_dict = self.model_manager.book_model_assignments[self.selected_book]
         model_order = ["rephrase", "rephrase_summarize", "entigraph"]
         model_to_index = {model: i for i, model in enumerate(model_order)}
-
+
+        responses = self.loop.run_until_complete(
+            self.model_manager.get_all_model_responses_async(prompt)
+        )

         chats = []
         for response in responses:
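run_until_complete in the hunk above works as long as only one callback drives self.loop at a time; if two users trigger the handler concurrently, the second call can fail with "RuntimeError: This event loop is already running". Gradio also accepts coroutine functions as event handlers and awaits them on its own loop, which sidesteps the owned loop entirely. A self-contained sketch of that alternative, assuming a reasonably recent Gradio; the query helper stands in for get_model_response_async and the wiring is illustrative:

import asyncio
import gradio as gr

async def query(name, prompt):
    # stand-in for get_model_response_async; the real code calls a Space
    await asyncio.sleep(0)
    return f"{name}: {prompt}"

async def respond(prompt):
    # Gradio awaits async handlers on its own event loop,
    # so no manually owned loop or run_until_complete is needed
    return await asyncio.gather(
        query("Model_A", prompt),
        query("Model_B", prompt),
        query("Model_C", prompt),
    )

with gr.Blocks() as demo:
    box = gr.Textbox(label="Question")
    out = gr.JSON()
    box.submit(respond, inputs=box, outputs=out)

# demo.launch()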
@@ -202,7 +203,7 @@ class ModelComparisonApp:

         reordered_chats = [chats[model_to_index[mapping_dict[model]]] for model in ["Model A", "Model B", "Model C"]]
         return reordered_chats
-
+
     def get_votes(self, book, category, question, rating_A, rating_B, rating_C):
         model_mapping = self.model_manager.get_book_model_mapping(book)
         rating_data = {