import re
import threading
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
import requests
import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer (on GPU when available)
model_name = "dicta-il/dictalm2.0-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set the pad token to eos_token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]


def clear_session() -> History:
    return []


def history_to_messages(history: History) -> Messages:
    """Flatten (user, assistant) pairs into a chat-message list."""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({'role': 'user', 'content': user_msg.strip()})
        messages.append({'role': 'assistant', 'content': assistant_msg.strip()})
    return messages


def messages_to_history(messages: Messages) -> History:
    """Rebuild (user, assistant) pairs from a flat chat-message list."""
    history = []
    for q, r in zip(messages[0::2], messages[1::2]):
        history.append((q['content'], r['content']))
    return history


# Flask app setup
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    input_text = data.get('text', '')

    # Format the input text with instruction tokens
    formatted_text = f"[INST] {input_text} [/INST]"

    # Tokenize the input and move it to the model's device
    inputs = tokenizer(formatted_text, return_tensors='pt', padding=True,
                       truncation=True, max_length=1024).to(device)

    # Generate the output; max_new_tokens bounds the reply length so a long
    # prompt cannot exhaust the generation budget
    with torch.no_grad():
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, dropping the echoed prompt
    prompt_length = inputs['input_ids'].shape[1]
    prediction = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    return jsonify({"prediction": prediction})


def run_flask():
    app.run(host='0.0.0.0', port=5000)


def is_hebrew(text: str) -> bool:
    """Return True if the text contains at least one Hebrew character."""
    return bool(re.search(r'[\u0590-\u05FF]', text))


# Run Flask in a daemon thread so it does not keep the process alive on exit
threading.Thread(target=run_flask, daemon=True).start()


def model_chat(query: Optional[str], history: Optional[History]) -> Generator[str, None, None]:
    """Forward the user's message to the local Flask endpoint and yield the reply
    string, as gr.ChatInterface expects."""
    if query is None:
        query = ''
    if history is None:
        history = []
    if not query.strip():
        yield ''
        return
    response = requests.post("http://127.0.0.1:5000/predict", json={"text": query.strip()})
    if response.status_code == 200:
        yield response.json().get("prediction", "")
    else:
        yield ''


with gr.Blocks(css='''
.gr-group {direction: rtl;}
.chatbot {text-align: right;}
.dicta-header {
    background-color: var(--input-background-fill);
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    display: flex;
    flex-direction: row;
    align-items: center;
    box-shadow: var(--block-shadow);
    border-color: var(--block-border-color);
    border-width: 1px;
}
@media (max-width: 768px) {
    .dicta-header {
        flex-direction: column; /* Stack vertically on mobile devices */
    }
}
.chatbot.prose {
    font-size: 1.2em;
}
.dicta-logo {
    width: 150px; /* Adjust to the actual logo width */
    height: auto;
    margin-bottom: 20px;
}
.dicta-intro-text {
    margin-bottom: 20px;
    text-align: center;
    display: flex;
    flex-direction: column;
    align-items: center;
    width: 100%;
    font-size: 1.1em;
}
textarea {
    font-size: 1.2em;
}
''', js=None) as demo:
    # Page header (Hebrew): title "Initial demo", a welcome line inviting users
    # to explore the model's capabilities, and a credit line noting the demo was
    # built on the Dicta language model developed by MAFAT.
    gr.Markdown("""

<div class="dicta-header">
    <img src="/file=logo_am.png" alt="Logo" class="dicta-logo">
    <div class="dicta-intro-text">
        <h1>הדגמה ראשונית</h1>
        <span>
            ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם
            <br/>
            הדמו נכתב על ידי רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת
        </span>
    </div>
</div>
""") interface = gr.ChatInterface(model_chat, fill_height=False) interface.chatbot.rtl = True interface.textbox.placeholder = "הכנס שאלה בעברית (או באנגלית!)" interface.textbox.rtl = True interface.textbox.text_align = 'right' interface.theme_css += '.gr-group {direction: rtl !important;}' demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['logo_am.png'])