joermd committed on
Commit
067af65
·
verified ·
1 Parent(s): 75bf6fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -65
app.py CHANGED
@@ -1,14 +1,14 @@
1
- import streamlit as st
2
- import streamlit.components.v1 as components
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
- import json
6
- from streamlit.runtime.scriptrunner import add_script_run_ctx
7
- import threading
8
 
9
- # تكوين النموذج والتوكنايزر
10
- @st.cache_resource
 
11
  def load_model():
 
12
  tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B")
13
  model = AutoModelForCausalLM.from_pretrained(
14
  "amd/AMD-OLMo-1B",
@@ -18,8 +18,8 @@ def load_model():
18
  return model, tokenizer
19
 
20
  def generate_response(prompt, model, tokenizer):
 
21
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
22
-
23
  with torch.no_grad():
24
  outputs = model.generate(
25
  **inputs,
@@ -30,62 +30,29 @@ def generate_response(prompt, model, tokenizer):
30
  repetition_penalty=1.2,
31
  pad_token_id=tokenizer.eos_token_id
32
  )
33
-
34
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
- return response.replace(prompt, "").strip()
36
 
37
- def main():
38
- st.set_page_config(
39
- page_title="سبيدي",
40
- page_icon="💬",
41
- layout="wide"
42
- )
43
-
44
- # إخفاء عناصر Streamlit الافتراضية
45
- hide_streamlit_style = """
46
- <style>
47
- #MainMenu {visibility: hidden;}
48
- footer {visibility: hidden;}
49
- header {visibility: hidden;}
50
- </style>
51
- """
52
- st.markdown(hide_streamlit_style, unsafe_allow_html=True)
53
-
54
- # تحميل النموذج والتوكنايزر
55
- model, tokenizer = load_model()
56
-
57
- # قراءة ملف HTML
58
- def read_html():
59
- with open('index.html', 'r', encoding='utf-8') as file:
60
- return file.read()
61
-
62
- # معالجة الرسائل الواردة من JavaScript
63
- def handle_message(message_data):
64
- try:
65
- data = json.loads(message_data)
66
- user_message = data.get('message', '')
67
-
68
- if user_message:
69
- response = generate_response(user_message, model, tokenizer)
70
- return {"response": response}
71
-
72
- return {"response": "عذراً، لم أفهم رسالتك"}
73
-
74
- except Exception as e:
75
- return {"response": f"عذراً، حدث خطأ: {str(e)}"}
76
-
77
- # تكوين معالج الرسائل
78
- def message_handler(message_data):
79
- ctx = add_script_run_ctx()
80
- response = handle_message(message_data)
81
- ctx.enqueue(json.dumps(response))
82
-
83
- # عرض الواجهة
84
- components.html(
85
- read_html(),
86
- height=800,
87
- on_message=message_handler
88
- )
89
 
90
- if __name__ == "__main__":
91
- main()
 
1
+ # app.py
2
+ from flask import Flask, render_template, request, jsonify
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
+ from functools import lru_cache
 
 
6
 
7
+ app = Flask(__name__)
8
+
9
+ @lru_cache(maxsize=1)
10
  def load_model():
11
+ """Load model and tokenizer with caching"""
12
  tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B")
13
  model = AutoModelForCausalLM.from_pretrained(
14
  "amd/AMD-OLMo-1B",
 
18
  return model, tokenizer
19
 
20
  def generate_response(prompt, model, tokenizer):
21
+ """Generate response from the model"""
22
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
23
  with torch.no_grad():
24
  outputs = model.generate(
25
  **inputs,
 
30
  repetition_penalty=1.2,
31
  pad_token_id=tokenizer.eos_token_id
32
  )
33
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
34
+ return response.replace(prompt, "").strip()
 
35
 
36
@app.route('/')
def home():
    """Render and return the chat front-end page (templates/index.html)."""
    page = render_template('index.html')
    return page
39
+
40
@app.route('/message', methods=['POST'])
def message():
    """Handle a chat message POSTed as JSON and return the model's reply.

    Expects a JSON body of the form {"message": "<user text>"}.
    Returns JSON {"response": "<model reply or error text>"}.
    """
    try:
        # request.json raises BadRequest on a non-JSON body and yields None
        # for a JSON `null`, which would crash data.get with AttributeError.
        # get_json(silent=True) returns None instead of raising; fall back to
        # an empty dict so malformed input takes the explicit error path.
        data = request.get_json(silent=True) or {}
        user_message = data.get('message', '')

        if not user_message:
            return jsonify({"response": "عذراً، لم أفهم رسالتك"})

        # Model/tokenizer are cached by load_model (lru_cache), so repeated
        # requests do not reload the weights.
        model, tokenizer = load_model()
        response = generate_response(user_message, model, tokenizer)

        return jsonify({"response": response})

    except Exception as e:
        # Top-level boundary: surface the failure to the client rather than
        # returning an HTTP 500 with no body.
        return jsonify({"response": f"عذراً، حدث خطأ: {str(e)}"})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
# Entry point for local development only.
if __name__ == '__main__':
    # NOTE(review): debug=True enables Flask's auto-reloader AND the Werkzeug
    # interactive debugger, which allows arbitrary code execution in the
    # browser — must be disabled before any non-local deployment.
    app.run(debug=True)