Spaces:
Runtime error
Runtime error
Royrotem100
commited on
Commit
•
384005b
1
Parent(s):
b4fc999
Set pad_token to eos_token and exclude user query from response
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
6 |
import threading
|
7 |
import torch
|
8 |
import os
|
|
|
9 |
|
10 |
# Define the API URL to use the internal server
|
11 |
API_URL = "http://localhost:5000/chat"
|
@@ -43,7 +44,7 @@ def chat():
|
|
43 |
inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
|
44 |
input_ids = inputs['input_ids']
|
45 |
attention_mask = inputs['attention_mask']
|
46 |
-
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=
|
47 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
|
48 |
|
49 |
return jsonify({"response": response_text})
|
@@ -72,6 +73,9 @@ def messages_to_history(messages: Messages) -> History:
|
|
72 |
history.append((q['content'], r['content']))
|
73 |
return history
|
74 |
|
|
|
|
|
|
|
75 |
def model_chat(query: str, history: History) -> Tuple[str, History]:
|
76 |
if not query.strip():
|
77 |
return '', history
|
@@ -163,7 +167,10 @@ with gr.Blocks(css='''
|
|
163 |
print(f"Query: {query}") # Debug print statement
|
164 |
response, history = model_chat(query, history)
|
165 |
print(f"Response: {response}") # Debug print statement
|
166 |
-
|
|
|
|
|
|
|
167 |
|
168 |
demo_state = gr.State([])
|
169 |
|
|
|
6 |
import threading
|
7 |
import torch
|
8 |
import os
|
9 |
+
import re
|
10 |
|
11 |
# Define the API URL to use the internal server
|
12 |
API_URL = "http://localhost:5000/chat"
|
|
|
44 |
inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
|
45 |
input_ids = inputs['input_ids']
|
46 |
attention_mask = inputs['attention_mask']
|
47 |
+
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
|
48 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
|
49 |
|
50 |
return jsonify({"response": response_text})
|
|
|
73 |
history.append((q['content'], r['content']))
|
74 |
return history
|
75 |
|
76 |
+
def is_hebrew(text: str) -> bool:
|
77 |
+
return bool(re.search(r'[\u0590-\u05FF]', text))
|
78 |
+
|
79 |
def model_chat(query: str, history: History) -> Tuple[str, History]:
|
80 |
if not query.strip():
|
81 |
return '', history
|
|
|
167 |
print(f"Query: {query}") # Debug print statement
|
168 |
response, history = model_chat(query, history)
|
169 |
print(f"Response: {response}") # Debug print statement
|
170 |
+
if is_hebrew(response):
|
171 |
+
return history, gr.update(value="", interactive=True, lines=2, rtl=True), history
|
172 |
+
else:
|
173 |
+
return history, gr.update(value="", interactive=True, lines=2, rtl=False), history
|
174 |
|
175 |
demo_state = gr.State([])
|
176 |
|