import gradio as gr
import requests
from typing import List, Dict, Tuple
from flask import Flask, request, jsonify, send_from_directory
from transformers import AutoTokenizer, AutoModelForCausalLM
import threading
import torch
import re

# Define the API URL to use the internal server
API_URL = "http://localhost:5000/chat"

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]

app = Flask(__name__)

# Load the model and tokenizer
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set the pad_token to eos_token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Route to serve static files (e.g., images)
@app.route('/static/<path:path>')
def send_static(path):
    return send_from_directory('static', path)

@app.route('/chat', methods=['POST'])
def chat():
    data = request.json
    messages = data.get('messages', [])
    if not messages:
        return jsonify({"response": "No messages provided"}), 400

    # Concatenate all user inputs into a single string.
    # (Note: this deliberately ignores assistant turns and the model's
    # chat template; it is a simplification.)
    user_input = " ".join([msg['content'] for msg in messages if msg['role'] == 'user'])

    inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    outputs = model.generate(input_ids, attention_mask=attention_mask,
                             max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
    # Decode, then strip the echoed prompt so only the completion is returned
    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
    return jsonify({"response": response_text})

# Function to run the Flask app
def run_flask():
    app.run(host='0.0.0.0', port=5000)

# Start the Flask app in a background daemon thread so it exits with the process
threading.Thread(target=run_flask, daemon=True).start()

# Gradio interface functions
def clear_session() -> Tuple[History, History]:
    # Reset both the visible chatbot and the stored state
    # (two outputs are wired to the clear button below)
    return [], []

def history_to_messages(history: History) -> Messages:
    messages = []
    for h in history:
        messages.append({'role': 'user', 'content': h[0].strip()})
        messages.append({'role': 'assistant', 'content': h[1].strip()})
    return messages

def messages_to_history(messages: Messages) -> History:
    history = []
    for q, r in zip(messages[0::2], messages[1::2]):
        history.append((q['content'], r['content']))
    return history

def is_hebrew(text: str) -> bool:
    # True if the text contains at least one character in the Hebrew Unicode block
    return bool(re.search(r'[\u0590-\u05FF]', text))

def model_chat(query: str, history: History) -> Tuple[str, History]:
    if not query.strip():
        return '', history
    messages = history_to_messages(history)
    messages.append({'role': 'user', 'content': query.strip()})
    try:
        response = requests.post(API_URL, json={"messages": messages})
        response.raise_for_status()  # Raises HTTPError on a non-2xx status code
        response_json = response.json()
        response_text = response_json.get("response", "Error: Response format is incorrect")
    except requests.exceptions.HTTPError as e:
        response_text = f"HTTPError: {str(e)}"
        print(f"HTTPError: {e.response.text}")  # Detailed error message
    except requests.exceptions.RequestException as e:
        response_text = f"RequestException: {str(e)}"
        print(f"RequestException: {e}")  # Debug print statement
    except ValueError as e:
        response_text = "ValueError: Invalid JSON response"
        print(f"ValueError: {e}")  # Debug print statement
    except Exception as e:
        response_text = f"Exception: {str(e)}"
        print(f"General Exception: {e}")  # Debug print statement
    history.append((query.strip(), response_text.strip()))
    return response_text.strip(), history
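# A minimal sketch of exercising the /chat endpoint directly (an assumption:
# the Flask thread above is already serving on localhost:5000). The helper
# name is hypothetical; it is never called by the app itself and is kept only
# as a reference for manual testing.
def _example_chat_request() -> None:
    payload = {"messages": [{"role": "user", "content": "שלום, מה שלומך?"}]}  # "Hello, how are you?"
    resp = requests.post(API_URL, json=payload)
    resp.raise_for_status()
    print(resp.json()["response"])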
# Gradio interface setup
with gr.Blocks(css='''
.gr-group {direction: rtl;}
.chatbot {text-align: right;}
.dicta-header {
    background-color: var(--input-background-fill);
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    display: flex;
    flex-direction: row;
    align-items: center;
    box-shadow: var(--block-shadow);
    border-color: var(--block-border-color);
    border-width: 1px;
}
@media (max-width: 768px) {
    .dicta-header {
        flex-direction: column;
    }
}
.chatbot.prose {
    font-size: 1.2em;
}
.dicta-logo {
    width: 150px;
    height: auto;
    margin-bottom: 20px;
}
.dicta-intro-text {
    margin-bottom: 20px;
    text-align: center;
    display: flex;
    flex-direction: column;
    align-items: center;
    width: 100%;
    font-size: 1.1em;
}
textarea {
    font-size: 1.2em;
}
''', js=None) as demo:
    # Header text (Hebrew): "Operational chat - initial demo. Welcome to the
    # first interactive demo. Explore the model's capabilities and see how it
    # can assist you with your tasks. The demo was written by Captain Roey
    # Rotem using the Dicta language model developed by MAFAT."
    gr.Markdown("""

## צ'אט מערכי - הדגמה ראשונית

ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם
הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת
""") chatbot = gr.Chatbot(height=600) query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True) clear_btn = gr.Button("נקה שיחה") def respond(query, history): print(f"Query: {query}") # Debug print statement response, history = model_chat(query, history) print(f"Response: {response}") # Debug print statement if is_hebrew(response): return history, gr.update(value="", interactive=True, lines=2, rtl=True), history else: return history, gr.update(value="", interactive=True, lines=2, rtl=False), history demo_state = gr.State([]) query.submit(respond, [query, demo_state], [chatbot, query, demo_state]) clear_btn.click(clear_session, [], [chatbot, demo_state]) demo.launch(share=True)