# demo_1 / app.py
# Last commit by Royrotem100 (384005b): "Set pad_token to eos_token and exclude user query from response"
import gradio as gr
import requests
from typing import List, Dict, Tuple
from flask import Flask, request, jsonify, send_from_directory
from transformers import AutoTokenizer, AutoModelForCausalLM
import threading
import torch
import os
import re
# Endpoint of the Flask chat API that this same process starts in a background
# thread; the Gradio callbacks post conversations here.
API_URL = "http://localhost:5000/chat"

# Chat history as (user text, assistant text) pairs, and the role/content
# message dicts exchanged with the chat API.
History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

app = Flask(__name__)

# Load the tokenizer and the model weights (in bfloat16) for the instruct model.
model_name = "dicta-il/dictalm2.0-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Some tokenizers ship without a pad token; reuse EOS so padding always has one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Serve files (e.g. images) from the local "static" directory through Flask.
@app.route('/static/<path:path>')
def send_static(path):
    """Return the file at *path* inside the ``static`` directory.

    ``send_from_directory`` refuses paths that would escape that directory.
    NOTE(review): ``Flask(__name__)`` also registers its own default
    ``/static/<filename>`` rule, which may take precedence over this one — confirm
    which handler actually serves these requests.
    """
    static_dir = 'static'
    return send_from_directory(static_dir, path)
@app.route('/chat', methods=['POST'])
def chat():
    """Generate the assistant's reply to a chat conversation.

    Request body (JSON): ``{"messages": [{"role": ..., "content": ...}, ...]}`` with
    string roles/contents, as built by ``history_to_messages`` on the Gradio side
    (alternating user/assistant turns, ending with the new user turn).

    Returns:
        ``{"response": <reply text>}`` with status 200, or the same shape carrying
        an error message with status 400 when the body has no usable messages.
    """
    # silent=True: a missing or unparsable JSON body yields None instead of
    # raising, so every malformed request gets the same 400 JSON reply below.
    data = request.get_json(silent=True)
    messages = data.get('messages') if isinstance(data, dict) else None
    if not messages:
        return jsonify({"response": "No messages provided"}), 400
    if not isinstance(messages, list) or not all(
        isinstance(msg, dict)
        and isinstance(msg.get('role'), str)
        and isinstance(msg.get('content'), str)
        for msg in messages
    ):
        return jsonify({"response": "Invalid messages format"}), 400

    # Render the whole conversation (user AND assistant turns) with the tokenizer's
    # chat template - the format instruct models are fine-tuned on - instead of
    # space-joining only the user turns.
    # NOTE(review): assumes this tokenizer ships a chat template whose role rules
    # accept the alternating user/assistant list sent by the UI — confirm.
    # NOTE(review): no truncation is applied; very long histories may exceed the
    # model's context window.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Chat templates normally emit their own BOS/control tokens, so tokenize the
    # rendered string without adding special tokens again (avoids a double BOS).
    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
    input_ids = inputs['input_ids']
    outputs = model.generate(
        input_ids,
        attention_mask=inputs['attention_mask'],
        max_new_tokens=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the continuation.
    # String-replacing the prompt out of the decoded text failed whenever the
    # prompt did not round-trip exactly through decode, and it also deleted any
    # legitimate repeat of the query inside the reply.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return jsonify({"response": response_text})
# Function to run the Flask app
def run_flask(host: str = '127.0.0.1', port: int = 5000) -> None:
    """Serve the Flask chat API; blocks, so it is meant to run in a background thread.

    Args:
        host: Interface to bind. Defaults to loopback: the only client is the
            Gradio UI in this same process, which calls ``API_URL`` on localhost.
            Binding ``0.0.0.0`` also exposed the unauthenticated model endpoint to
            every network the machine is on.
        port: TCP port to listen on; must match the port in ``API_URL``.
    """
    # Werkzeug's reloader needs the main thread (it installs signal handlers), so
    # keep it off explicitly even if FLASK_DEBUG switches debug mode on.
    app.run(host=host, port=port, use_reloader=False)
# Start the Flask app in a separate thread so the Gradio UI can call it.
# daemon=True: this thread must not keep the process alive once the Gradio server
# (demo.launch() at the end of the file) stops; a non-daemon thread would block
# interpreter shutdown, e.g. after Ctrl+C.
threading.Thread(target=run_flask, name="flask-api", daemon=True).start()
# Gradio interface functions
def clear_session() -> History:
    """Return a brand-new, empty chat history."""
    return list()
def history_to_messages(history: History) -> Messages:
    """Flatten (user, assistant) history pairs into a role/content message list.

    Every pair contributes a ``user`` message followed by an ``assistant``
    message, each with surrounding whitespace stripped.
    """
    return [
        {'role': role, 'content': turn[idx].strip()}
        for turn in history
        for idx, role in ((0, 'user'), (1, 'assistant'))
    ]
def messages_to_history(messages: Messages) -> History:
    """Pair consecutive messages back into (first content, second content) tuples.

    Pairing is purely positional (0 with 1, 2 with 3, ...); roles are not
    checked, and a trailing unpaired message is dropped.
    """
    # Passing the same iterator twice makes zip() consume two items per pair.
    stream = iter(messages)
    return [(first['content'], second['content']) for first, second in zip(stream, stream)]
def is_hebrew(text: str) -> bool:
    """Return True if *text* has at least one character in the Hebrew Unicode
    block (U+0590-U+05FF: letters, points and cantillation marks)."""
    return any('\u0590' <= ch <= '\u05FF' for ch in text)
def _post_chat_request(messages: Messages) -> str:
    """POST *messages* to the chat API and return the reply text.

    Expected failures never raise: network/HTTP errors, a non-JSON body, or a
    body without a string ``"response"`` field are turned into a readable error
    string (and printed for debugging) so the chat UI can display them.
    """
    try:
        response = requests.post(API_URL, json={"messages": messages})
        response.raise_for_status()  # Turns 4xx/5xx status codes into HTTPError
    except requests.exceptions.HTTPError as e:
        print(f"HTTPError: {e.response.text}")  # Detailed error message
        return f"HTTPError: {str(e)}"
    except requests.exceptions.RequestException as e:
        print(f"RequestException: {e}")  # Debug print statement
        return f"RequestException: {str(e)}"
    except Exception as e:  # UI boundary: report instead of crashing the chat
        print(f"General Exception: {e}")  # Debug print statement
        return f"Exception: {str(e)}"

    # Parse in a separate step: since requests 2.27 a JSON decode error is also a
    # RequestException, so inside the try above it would be reported by that
    # branch and never reach the ValueError handling.
    try:
        payload = response.json()
    except ValueError as e:
        print(f"ValueError: {e}")  # Debug print statement
        return "ValueError: Invalid JSON response"

    # Guard against non-dict payloads and null/non-string replies, which would
    # otherwise blow up on .strip() in the caller.
    reply = payload.get("response") if isinstance(payload, dict) else None
    if not isinstance(reply, str):
        return "Error: Response format is incorrect"
    return reply


def model_chat(query: str, history: History) -> Tuple[str, History]:
    """Send *query* plus the prior conversation to the chat API.

    Args:
        query: The new user message; surrounding whitespace is ignored.
        history: Previous (user, assistant) pairs. It is appended to in place.

    Returns:
        ``(reply, history)``. The reply is stripped and may be an error
        description. For a blank query, returns ``('', history)`` without
        calling the API.
    """
    query = query.strip()
    if not query:
        return '', history
    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query})
    reply = _post_chat_request(messages).strip()
    history.append((query, reply))
    return reply, history
# Gradio interface setup
with gr.Blocks(css='''
.gr-group {direction: rtl;}
.chatbot{text-align:right;}
.dicta-header {
background-color: var(--input-background-fill);
border-radius: 10px;
padding: 20px;
text-align: center;
display: flex;
flex-direction: row;
align-items: center;
box-shadow: var(--block-shadow);
border-color: var(--block-border-color);
border-width: 1px;
}
@media (max-width: 768px) {
.dicta-header {
flex-direction: column;
}
}
.chatbot.prose {
font-size: 1.2em;
}
.dicta-logo {
width: 150px;
height: auto;
margin-bottom: 20px;
}
.dicta-intro-text {
margin-bottom: 20px;
text-align: center;
display: flex;
flex-direction: column;
align-items: center;
width: 100%;
font-size: 1.1em;
}
textarea {
font-size: 1.2em;
}
''', js=None) as demo:
    # Header: logo plus Hebrew title and intro text. The HTML lines stay flush-left
    # on purpose: indenting them 4+ spaces could make Markdown render them as code.
    # NOTE(review): this page is served by Gradio, not by the Flask app on port
    # 5000, so /static/logo111.png may not resolve here — confirm the logo loads.
    gr.Markdown("""
<div class="dicta-header">
<a href="/static/logo111.png">
<img src="/static/logo111.png" alt="Logo" class="dicta-logo">
</a>
<div class="dicta-intro-text">
<h1>צ'אט מערכי - הדגמה ראשונית</h1>
<span dir='rtl'>ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם</span><br/>
<span dir='rtl'>הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
</div>
</div>
""")
    chatbot = gr.Chatbot(height=600)
    query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True)
    clear_btn = gr.Button("נקה שיחה")

    def respond(user_text, history):
        """Send *user_text* to the model and refresh the chat log, input box and state.

        Returns (chat log, textbox update, state). The textbox is cleared, and its
        text direction follows the reply: right-to-left when the reply has Hebrew.
        """
        print(f"Query: {user_text}")  # Debug print statement
        response, history = model_chat(user_text, history)
        print(f"Response: {response}")  # Debug print statement
        textbox_update = gr.update(
            value="", interactive=True, lines=2, rtl=is_hebrew(response)
        )
        return history, textbox_update, history

    def clear_chat():
        """Reset both the displayed chat log and the stored history to empty."""
        # This event has two outputs (chatbot, demo_state), so Gradio needs one
        # value per output; wiring clear_session() directly returned a single []
        # and failed Gradio's output-count check, so Clear never worked.
        return clear_session(), clear_session()

    demo_state = gr.State([])
    query.submit(respond, [query, demo_state], [chatbot, query, demo_state])
    clear_btn.click(clear_chat, [], [chatbot, demo_state])

demo.launch(share=True)