Royrotem100 committed on
Commit
384005b
1 Parent(s): b4fc999

Set pad_token to eos_token and exclude user query from response

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import threading
7
  import torch
8
  import os
 
9
 
10
  # Define the API URL to use the internal server
11
  API_URL = "http://localhost:5000/chat"
@@ -43,7 +44,7 @@ def chat():
43
  inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
44
  input_ids = inputs['input_ids']
45
  attention_mask = inputs['attention_mask']
46
- outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=100,pad_token_id=tokenizer.eos_token_id)
47
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
48
 
49
  return jsonify({"response": response_text})
@@ -72,6 +73,9 @@ def messages_to_history(messages: Messages) -> History:
72
  history.append((q['content'], r['content']))
73
  return history
74
 
 
 
 
75
  def model_chat(query: str, history: History) -> Tuple[str, History]:
76
  if not query.strip():
77
  return '', history
@@ -163,7 +167,10 @@ with gr.Blocks(css='''
163
  print(f"Query: {query}") # Debug print statement
164
  response, history = model_chat(query, history)
165
  print(f"Response: {response}") # Debug print statement
166
- return history, gr.update(value="", interactive=True), history # Ensure correct return format
 
 
 
167
 
168
  demo_state = gr.State([])
169
 
 
6
  import threading
7
  import torch
8
  import os
9
+ import re
10
 
11
  # Define the API URL to use the internal server
12
  API_URL = "http://localhost:5000/chat"
 
44
  inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
45
  input_ids = inputs['input_ids']
46
  attention_mask = inputs['attention_mask']
47
+ outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
48
  response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
49
 
50
  return jsonify({"response": response_text})
 
73
  history.append((q['content'], r['content']))
74
  return history
75
 
76
+ def is_hebrew(text: str) -> bool:
77
+ return bool(re.search(r'[\u0590-\u05FF]', text))
78
+
79
  def model_chat(query: str, history: History) -> Tuple[str, History]:
80
  if not query.strip():
81
  return '', history
 
167
  print(f"Query: {query}") # Debug print statement
168
  response, history = model_chat(query, history)
169
  print(f"Response: {response}") # Debug print statement
170
+ if is_hebrew(response):
171
+ return history, gr.update(value="", interactive=True, lines=2, rtl=True), history
172
+ else:
173
+ return history, gr.update(value="", interactive=True, lines=2, rtl=False), history
174
 
175
  demo_state = gr.State([])
176