# test / app.py
# Author: joermd — last commit a85fc0b ("Update app.py", verified), 2.75 kB.
# (File-viewer header pasted along with the source; kept as a comment so the file parses.)
import streamlit as st
import streamlit.components.v1 as components
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
from streamlit.runtime.scriptrunner import add_script_run_ctx
import threading
# Model and tokenizer setup.
@st.cache_resource
def load_model(model_name="amd/AMD-OLMo-1B"):
    """Load the causal-LM checkpoint and its tokenizer.

    ``st.cache_resource`` caches the returned pair (keyed on *model_name*),
    so script reruns reuse the already-loaded weights instead of reloading.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # float16 halves GPU memory, but many CPU kernels have no half-precision
    # implementation; device_map="auto" places the model on CPU when no GPU
    # exists, so only use float16 when CUDA is actually available.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto",
    )
    return model, tokenizer
def generate_response(prompt, model, tokenizer, max_new_tokens=200):
    """Generate the model's continuation of *prompt* and return only the new text.

    Args:
        prompt: User text, fed to the model verbatim.
        model: A causal LM exposing ``.device`` and ``.generate()``.
        tokenizer: The tokenizer matching *model*.
        max_new_tokens: Upper bound on generated tokens, not counting the
            prompt.

    Returns:
        The generated continuation (prompt excluded), stripped of
        surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_length would include the prompt tokens, leaving long prompts
            # little or no room for an answer.
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            # temperature/top_p only take effect when sampling is enabled;
            # under the default greedy decoding they are ignored.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens rather than using str.replace(prompt, ""):
    # decoding need not reproduce the prompt byte-for-byte, and replace would
    # also delete any repetition of the prompt inside the answer itself.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def main():
    """Build the "سبيدي" chat page and answer each user message with the model.

    Streamlit re-executes the script from the top on every interaction, so
    the conversation is kept in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts and replayed on each run.

    NOTE(review): the previous version rendered ``index.html`` through
    ``components.html(..., on_message=...)``. ``components.html`` accepts no
    ``on_message`` argument (it renders one-way static HTML), so that call
    raised ``TypeError`` and no message could reach Python;
    ``add_script_run_ctx()`` likewise returns a thread, not an object with
    ``enqueue``. Streamlit's native chat widgets provide the round trip
    instead. To restore the custom front end, port ``index.html`` to a
    bidirectional component created with ``components.declare_component``.
    """
    st.set_page_config(
        page_title="سبيدي",
        page_icon="💬",
        layout="wide",
    )

    # Hide Streamlit's default menu, footer and header.
    hide_streamlit_style = (
        "<style>"
        "#MainMenu {visibility: hidden;}"
        "footer {visibility: hidden;}"
        "header {visibility: hidden;}"
        "</style>"
    )
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)

    # load_model is wrapped in st.cache_resource, so only the first run pays
    # the loading cost.
    model, tokenizer = load_model()

    def handle_message(user_message):
        """Return the model's reply to *user_message*, or an apology string."""
        if not user_message.strip():
            return "عذراً، لم أفهم رسالتك"
        try:
            return generate_response(user_message, model, tokenizer)
        except Exception as e:  # UI boundary: report the failure in the chat
            return f"عذراً، حدث خطأ: {str(e)}"

    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the stored conversation; every rerun starts from a blank page.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_message = st.chat_input("اكتب رسالتك هنا")
    if user_message is None:
        # Nothing was submitted on this run.
        return

    st.session_state.messages.append({"role": "user", "content": user_message})
    with st.chat_message("user"):
        st.markdown(user_message)

    with st.chat_message("assistant"):
        with st.spinner():
            reply = handle_message(user_message)
        st.markdown(reply)
    st.session_state.messages.append({"role": "assistant", "content": reply})
# Script entry point: `streamlit run app.py` executes this file as the
# __main__ module, so main() rebuilds the page on every rerun.
if __name__ == "__main__":
    main()