import threading
from http import HTTPStatus
from typing import Dict, Generator, List, Optional, Tuple

import gradio as gr
import requests
import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
# (a local fine-tuned checkpoint, e.g. "./dictalm2.0-instruct-roys-chat",
# can be substituted for the Hub model below)
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

History = List[Tuple[str, str]]
Messages = List[Dict[str, str]]


def clear_session() -> Tuple[str, History]:
    """Reset the textbox and the chat history."""
    return '', []


def history_to_messages(history: History) -> Messages:
    """Flatten (user, assistant) pairs into a chat-message list."""
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({'role': 'user', 'content': user_msg.strip()})
        messages.append({'role': 'assistant', 'content': assistant_msg.strip()})
    return messages


def messages_to_history(messages: Messages) -> History:
    """Rebuild (user, assistant) pairs from a flat chat-message list."""
    history = []
    for q, r in zip(messages[0::2], messages[1::2]):
        history.append((q['content'], r['content']))
    return history


# Flask app setup
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    input_text = data.get('text', '')
    # Format the input text with Mistral-style instruction tokens
    formatted_text = f"[INST] {input_text} [/INST]"
    # Tokenize the input
    inputs = tokenizer(formatted_text, return_tensors='pt')
    # Generate the output (do_sample=True so temperature/top_p take effect)
    outputs = model.generate(
        inputs['input_ids'],
        max_length=1024,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Decode the output
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return jsonify({"prediction": prediction})


def run_flask():
    app.run(host='0.0.0.0', port=5000)


# Run Flask in a daemon thread so it shuts down with the main process
threading.Thread(target=run_flask, daemon=True).start()


def model_chat(query: Optional[str], history: Optional[History]) -> Generator[str, None, None]:
    """Chat handler for gr.ChatInterface: forwards the query to the local
    Flask endpoint and yields the model's reply. ChatInterface manages the
    history itself, so only the reply text is yielded."""
    if query is None or not query.strip():
        return
    response = requests.post("http://127.0.0.1:5000/predict", json={"text": query.strip()})
    if response.status_code == HTTPStatus.OK:
        yield response.json().get("prediction", "")
    else:
        yield "Error: Unable to get a response from the model."


with gr.Blocks(css='''
.gr-group {direction: rtl;}
.chatbot {text-align: right;}
.dicta-header {
    background-color: var(--input-background-fill); /* Replace with desired background color */
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    display: flex;
    flex-direction: row;
    align-items: center;
    box-shadow: var(--block-shadow);
    border-color: var(--block-border-color);
    border-width: 1px;
}

@media (max-width: 768px) {
    .dicta-header {
        flex-direction: column; /* Stack vertically on mobile devices */
    }
}

.chatbot.prose {
    font-size: 1.2em;
}
.dicta-logo {
    width: 150px; /* Replace with actual logo width as desired */
    height: auto;
    margin-bottom: 20px;
}
.dicta-intro-text {
    margin-bottom: 20px;
    text-align: center;
    display: flex;
    flex-direction: column;
    align-items: center;
    width: 100%;
    font-size: 1.1em;
}
textarea {
    font-size: 1.2em;
}
''', js=None) as demo:
    # Hebrew header text; in English: "Operational chat - initial demo.
    # Welcome to the first interactive demo. Explore the model's capabilities
    # and see how it can assist you with your tasks. The demo was written by
    # Captain Roy Rotem using the Dicta language model developed by MAFAT."
    gr.Markdown("""

צ'אט מערכי - הדגמה ראשונית

ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם
הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת
""") interface = gr.ChatInterface(model_chat, fill_height=False) interface.chatbot.rtl = True interface.textbox.placeholder = "הכנס שאלה בעברית (או באנגלית!)" interface.textbox.rtl = True interface.textbox.text_align = 'right' interface.theme_css += '.gr-group {direction: rtl !important;}' demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['logo_am.png'])