cnmoro committed · Commit 22f2310 · verified · 1 Parent(s): d387356

Update app.py

Files changed (1): app.py (+36 -48)
app.py CHANGED
@@ -8,8 +8,11 @@ import gradio as gr
 
 torch.set_num_threads(2)
 
-openrouter_key = os.environ.get("OPENROUTER_KEY")
 model = EmbeddingModel(use_quantized_onnx_model=True)
+tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+llm = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
+
+prompt_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: $PROMPT ASSISTANT: "
 
 def fetch_links(query, max_results=5):
     with DDGS() as ddgs:
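Note on the new model setup: AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5") loads the weights in float32, roughly 28 GB for a 7B-parameter model. Since ask_open_llm (next hunk) moves the model to CUDA anyway, a half-precision load is the usual memory saver. The torch_dtype argument is standard transformers API, but using it here is a suggestion, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    # fp16 halves the footprint (~14 GB instead of ~28 GB for 7B params) and
    # matches the dtype most GPU kernels expect; the later llm.to(device) call
    # inside the @spaces.GPU function can stay as-is.
    llm = AutoModelForCausalLM.from_pretrained(
        "lmsys/vicuna-7b-v1.5",
        torch_dtype=torch.float16,
    )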
@@ -53,7 +56,34 @@ def retrieval_pipeline(query)
 
     return context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links
 
-async def predict(message, history):
+@spaces.GPU(enable_queue=True)
+def ask_open_llm(prompt):
+    device = torch.device('cuda')
+
+    llm.to(device)
+    model_inputs = tokenizer([
+        prompt
+    ], return_tensors="pt").to(device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=120., skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=512,
+        top_p=0.2,
+        top_k=20,
+        temperature=0.4,
+        repetition_penalty=1.1
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()  # Starting the generation in a separate thread.
+    partial_message = ""
+    for new_token in streamer:
+        partial_message += new_token
+        yield partial_message
+
+def predict(message, history):
     context, websearch_time, webcrawl_time, embedding_time, retrieval_time, links = retrieval_pipeline(message)
 
     if detect_language(message) == Language.ptbr:
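Two things in this hunk look worth flagging. First, Thread(target=model.generate, kwargs=generate_kwargs) points at model, which is the EmbeddingModel; the text generator loaded above is bound to llm, and the embedding model has no generate method. Second, top_p, top_k, and temperature only take effect when do_sample=True is passed; otherwise transformers falls back to greedy decoding and logs a warning. A corrected sketch of the same threaded-streaming pattern (these fixes are mine, not part of the commit):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def ask_open_llm(prompt):
        device = torch.device("cuda")
        llm.to(device)
        model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

        # skip_prompt=True keeps the echoed prompt out of the streamed text.
        streamer = TextIteratorStreamer(
            tokenizer, timeout=120.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            model_inputs,              # expands to input_ids and attention_mask
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,            # required for top_p/top_k/temperature to apply
            top_p=0.2,
            top_k=20,
            temperature=0.4,
            repetition_penalty=1.1,
        )
        # Generation runs in a background thread; the streamer yields decoded
        # tokens as they arrive. Target llm.generate, not model.generate.
        Thread(target=llm.generate, kwargs=generate_kwargs).start()

        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            yield partial_message  # cumulative text so far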
@@ -61,56 +91,14 @@ async def predict(message, history):
     else:
         prompt = f"Context:\n\n{context}\n\nBased on the context, answer: {message}"
 
-    print(prompt)
-
-    url = "https://openrouter.ai/api/v1/chat/completions"
-    headers = { "Content-Type": "application/json",
-                "Authorization": f"Bearer {openrouter_key}" }
-    body = { "stream": True,
-             "models": [
-                 "huggingfaceh4/zephyr-7b-beta:free",
-                 "mistralai/mistral-7b-instruct:free",
-                 "nousresearch/nous-capybara-7b:free",
-                 "openchat/openchat-7b:free"
-             ],
-             "route": "fallback",
-             "max_tokens": 512,
-             "messages": [
-                 {"role": "user", "content": prompt}
-             ] }
+    prompt = prompt_template.replace("$PROMPT", prompt)
 
     full_response = ""
-    async with aiohttp.ClientSession() as session:
-        async with session.post(url, headers=headers, json=body) as response:
-
-            buffer = ""  # A buffer to hold incomplete lines of data
-            async for chunk in response.content.iter_any():
-                buffer += chunk.decode()
-                while "\n" in buffer:  # Process as long as there are complete lines in the buffer
-                    line, buffer = buffer.split("\n", 1)
-
-                    print(line)
-
-                    if line.startswith("data: "):
-                        event_data = line[len("data: "):]
-                        if event_data != '[DONE]':
-                            try:
-                                current_text = json.loads(event_data)['choices'][0]['delta']['content']
-                                full_response += current_text
-                                yield full_response
-                                await asyncio.sleep(0.01)
-                            except Exception as e:
-                                print("Error event 1", e)
-                                try:
-                                    current_text = json.loads(event_data)['choices'][0]['text']
-                                    full_response += current_text
-                                    yield full_response
-                                    await asyncio.sleep(0.01)
-                                except Exception as e:
-                                    print("Error event 2", e)
+    for partial_message in ask_open_llm(prompt):
+        full_response += partial_message
+        yield full_response
 
     final_metadata_block = ""
-
     final_metadata_block += f"Links visited:\n"
     for link in links:
         final_metadata_block += f"{link}\n"
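One consumer-side detail: ask_open_llm already yields the cumulative text on every step, so full_response += partial_message in predict re-appends each snapshot and the displayed answer repeats itself (tokens "a", "b", "c" stream as "a", "ab", "abc", but full_response becomes "a", "aab", "aababc"). If the generator keeps its cumulative contract, assignment is enough; a minimal fix of the changed lines (my correction, not part of the commit):

    full_response = ""
    # ask_open_llm yields the full text-so-far on each iteration, so assign
    # rather than accumulate to avoid duplicating earlier chunks.
    for partial_message in ask_open_llm(prompt):
        full_response = partial_message
        yield full_response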