Spaces:
Paused
Paused
import numpy as np | |
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import os | |
# التحقق من توفر GPU | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
def load_model(): | |
""" | |
تحميل النموذج والمُرمِّز مع التخزين المؤقت | |
""" | |
model_name = "mistralai/Mistral-7B-Instruct-v0.2" | |
# تهيئة الـtokenizer أولاً | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# تهيئة النموذج مع إعدادات مناسبة | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
torch_dtype=torch.float16, | |
low_cpu_mem_usage=True, | |
device_map="auto" | |
) | |
return model, tokenizer | |
def reset_conversation(): | |
''' | |
إعادة تعيين المحادثة | |
''' | |
st.session_state.conversation = [] | |
st.session_state.messages = [] | |
return None | |
def generate_response(model, tokenizer, prompt, temperature=0.7, max_length=500): | |
""" | |
توليد استجابة من النموذج | |
""" | |
try: | |
# تحضير المدخلات | |
inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
# توليد النص | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_length=max_length, | |
temperature=temperature, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
# فك ترميز النص المولد | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return response | |
except Exception as e: | |
return f"حدث خطأ أثناء توليد الاستجابة: {str(e)}" | |
# تهيئة Streamlit | |
st.title("Mistral Chat 🤖") | |
# إضافة أزرار التحكم في الشريط الجانبي | |
with st.sidebar: | |
st.header("إعدادات") | |
temperature = st.slider("درجة الحرارة", min_value=0.1, max_value=1.0, value=0.7, step=0.1) | |
max_tokens = st.slider("الحد الأقصى للكلمات", min_value=50, max_value=1000, value=500, step=50) | |
if st.button("مسح المحادثة"): | |
reset_conversation() | |
# تحميل النموذج | |
try: | |
with st.spinner("جاري تحميل النموذج... قد يستغرق هذا بضع دقائق..."): | |
model, tokenizer = load_model() | |
st.sidebar.success("تم تحميل النموذج بنجاح! 🎉") | |
except Exception as e: | |
st.error(f"حدث خطأ أثناء تحميل النموذج: {str(e)}") | |
st.stop() | |
# تهيئة سجل المحادثة | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
# عرض المحادثة السابقة | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.write(message["content"]) | |
# معالجة إدخال المستخدم | |
if prompt := st.chat_input(): | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
with st.chat_message("user"): | |
st.write(prompt) | |
with st.chat_message("assistant"): | |
with st.spinner("جاري التفكير..."): | |
response = generate_response( | |
model=model, | |
tokenizer=tokenizer, | |
prompt=prompt, | |
temperature=temperature, | |
max_length=max_tokens | |
) | |
st.write(response) | |
st.session_state.messages.append({"role": "assistant", "content": response}) |