Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -8,8 +8,11 @@ import gradio as gr
|
|
8 |
|
9 |
torch.set_num_threads(2)
|
10 |
|
11 |
-
openrouter_key = os.environ.get("OPENROUTER_KEY")
|
12 |
model = EmbeddingModel(use_quantized_onnx_model=True)
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def fetch_links(query, max_results=5):
|
15 |
with DDGS() as ddgs:
|
@@ -53,7 +56,34 @@ def retrieval_pipeline(query):
|
|
53 |
|
54 |
return context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links
|
55 |
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links = retrieval_pipeline(message)
|
58 |
|
59 |
if detect_language(message) == Language.ptbr:
|
@@ -61,56 +91,14 @@ async def predict(message, history):
|
|
61 |
else:
|
62 |
prompt = f"Context:\n\n{context}\n\nBased on the context, answer: {message}"
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
url = "https://openrouter.ai/api/v1/chat/completions"
|
67 |
-
headers = { "Content-Type": "application/json",
|
68 |
-
"Authorization": f"Bearer {openrouter_key}" }
|
69 |
-
body = { "stream": True,
|
70 |
-
"models": [
|
71 |
-
"huggingfaceh4/zephyr-7b-beta:free",
|
72 |
-
"mistralai/mistral-7b-instruct:free",
|
73 |
-
"nousresearch/nous-capybara-7b:free",
|
74 |
-
"openchat/openchat-7b:free"
|
75 |
-
],
|
76 |
-
"route": "fallback",
|
77 |
-
"max_tokens": 512,
|
78 |
-
"messages": [
|
79 |
-
{"role": "user", "content": prompt}
|
80 |
-
] }
|
81 |
|
82 |
full_response = ""
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
buffer = "" # A buffer to hold incomplete lines of data
|
87 |
-
async for chunk in response.content.iter_any():
|
88 |
-
buffer += chunk.decode()
|
89 |
-
while "\n" in buffer: # Process as long as there are complete lines in the buffer
|
90 |
-
line, buffer = buffer.split("\n", 1)
|
91 |
-
|
92 |
-
print(line)
|
93 |
-
|
94 |
-
if line.startswith("data: "):
|
95 |
-
event_data = line[len("data: "):]
|
96 |
-
if event_data != '[DONE]':
|
97 |
-
try:
|
98 |
-
current_text = json.loads(event_data)['choices'][0]['delta']['content']
|
99 |
-
full_response += current_text
|
100 |
-
yield full_response
|
101 |
-
await asyncio.sleep(0.01)
|
102 |
-
except Exception as e:
|
103 |
-
print("Error event 1", e)
|
104 |
-
try:
|
105 |
-
current_text = json.loads(event_data)['choices'][0]['text']
|
106 |
-
full_response += current_text
|
107 |
-
yield full_response
|
108 |
-
await asyncio.sleep(0.01)
|
109 |
-
except Exception as e:
|
110 |
-
print("Error event 2", e)
|
111 |
|
112 |
final_metadata_block = ""
|
113 |
-
|
114 |
final_metadata_block += f"Links visited:\n"
|
115 |
for link in links:
|
116 |
final_metadata_block += f"{link}\n"
|
|
|
8 |
|
9 |
torch.set_num_threads(2)
|
10 |
|
|
|
11 |
model = EmbeddingModel(use_quantized_onnx_model=True)
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
13 |
+
llm = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
|
14 |
+
|
15 |
+
prompt_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: $PROMPT ASSISTANT: "
|
16 |
|
17 |
def fetch_links(query, max_results=5):
|
18 |
with DDGS() as ddgs:
|
|
|
56 |
|
57 |
return context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links
|
58 |
|
59 |
+
@spaces.GPU(enable_queue=True)
def ask_open_llm(prompt):
    """Stream the causal LLM's reply to ``prompt`` as it is generated.

    The module-level ``llm`` is moved to the CUDA device, ``prompt`` is
    tokenized with the module-level ``tokenizer``, and generation runs in a
    background thread that feeds a ``TextIteratorStreamer``.

    Args:
        prompt: The full, already-templated prompt text to send to the LLM.

    Yields:
        The reply accumulated so far (cumulative text, not just the newest
        fragment), once per fragment received from the streamer. Callers
        should use each yielded value as the whole response rather than
        concatenating successive values.
    """
    device = torch.device('cuda')

    llm.to(device)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # skip_prompt=True keeps the echoed prompt out of the streamed text;
    # timeout bounds how long the consumer loop waits for the next fragment.
    streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)

    # NOTE(review): top_p/top_k/temperature only take effect when sampling is
    # enabled (do_sample=True or via the model's generation config) - confirm.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        top_p=0.2,
        top_k=20,
        temperature=0.4,
        repetition_penalty=1.1
    )
    # Generation must run on the causal LM (`llm`), not on the embedding
    # `model`: `llm` is the object moved to the GPU above and the one that
    # matches `tokenizer`.
    t = Thread(target=llm.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
|
85 |
+
|
86 |
+
def predict(message, history):
|
87 |
context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links = retrieval_pipeline(message)
|
88 |
|
89 |
if detect_language(message) == Language.ptbr:
|
|
|
91 |
else:
|
92 |
prompt = f"Context:\n\n{context}\n\nBased on the context, answer: {message}"
|
93 |
|
94 |
+
prompt = prompt_template.replace("$PROMPT", prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
full_response = ""
|
97 |
+
for partial_message in ask_open_llm(prompt):
|
98 |
+
full_response += partial_message
|
99 |
+
yield full_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
final_metadata_block = ""
|
|
|
102 |
final_metadata_block += f"Links visited:\n"
|
103 |
for link in links:
|
104 |
final_metadata_block += f"{link}\n"
|