Royrotem100 committed
Commit 4c9d398
1 Parent(s): 64f89ac

Add DictaLM 2.0 instruct model 9

Files changed (1)
  1. app.py +36 -88
app.py CHANGED
@@ -1,30 +1,20 @@
 import os
 import gradio as gr
 from http import HTTPStatus
+import openai
 from typing import Generator, List, Optional, Tuple, Dict
-import re
 from urllib.error import HTTPError
-from flask import Flask, request, jsonify
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import threading
-import requests
-import torch
 
-
-# Load the model and tokenizer
-#tokenizer = AutoTokenizer.from_pretrained("./dictalm2.0-instruct-roys-chat")
-#model = AutoModelForCausalLM.from_pretrained("./dictalm2.0-instruct-roys-chat")
-
-# Load the model and tokenizer
-model_name = "dicta-il/dictalm2.0-instruct"
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+API_URL = os.getenv('API_URL')
+API_KEY = os.getenv('API_KEY')
+CUSTOM_JS = os.getenv('CUSTOM_JS', None)
+oai_client = openai.OpenAI(api_key=API_KEY, base_url=API_URL)
 
 History = List[Tuple[str, str]]
 Messages = List[Dict[str, str]]
 
 def clear_session() -> History:
-    return []
+    return '', []
 
 def history_to_messages(history: History) -> Messages:
     messages = []
@@ -39,69 +29,29 @@ def messages_to_history(messages: Messages) -> Tuple[str, History]:
         history.append([q['content'], r['content']])
     return history
 
-
-# Flask app setup
-app = Flask(__name__)
-
-@app.route('/predict', methods=['POST'])
-def predict():
-    data = request.json
-    input_text = data.get('text', '')
-
-    # Format the input text with instruction tokens
-    formatted_text = f"<s>[INST] {input_text} [/INST]"
-
-    # Tokenize the input
-    inputs = tokenizer(formatted_text, return_tensors='pt', padding=True, truncation=True, max_length=1024)
-
-    # Generate the output
-    outputs = model.generate(
-        inputs['input_ids'],
-        attention_mask=inputs['attention_mask'],
-        max_length=1024,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    # Decode the output
-    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(formatted_text, '').strip()
-
-    # Remove the user input part from the response
-    if "[/INST]" in prediction:
-        prediction = prediction.split("[/INST]", 1)[-1].strip()
-
-    return jsonify({"prediction": prediction})
-
-def run_flask():
-    app.run(host='0.0.0.0', port=5000)
-
-# Run Flask in a separate thread
-threading.Thread(target=run_flask).start()
-
-
-def model_chat(query: Optional[str], history: Optional[History]) -> Tuple[History, str]:
+def model_chat(query: Optional[str], history: Optional[History]) -> Generator[Tuple[str, History], None, None]:
     if query is None:
         query = ''
     if history is None:
         history = []
     if not query.strip():
-        return history, ""
-
-    response = requests.post("http://127.0.0.1:5000/predict", json={"text": query.strip()})
-    if response.status_code == 200:
-        prediction = response.json().get("prediction", "")
-        history.append((query, prediction))
-        return history, prediction
-    else:
-        return history, "Error: Unable to get a response from the model."
-
-def respond(query: str, history: History) -> Tuple[History, str]:
-    history, response = model_chat(query, history)
-    return history, response  # Return history and response to show the model's response
-
-
+        return
+    messages = history_to_messages(history)
+    messages.append({'role': 'user', 'content': query.strip()})
+    gen = oai_client.chat.completions.create(
+        model='dicta-il/dictalm2.0-instruct',
+        messages=messages,
+        temperature=0.7,
+        max_tokens=1024,
+        top_p=0.9,
+        stream=True
+    )
+    full_response = ''
+    for completion in gen:
+        text = completion.choices[0].delta.content
+        full_response += text or ''
+        yield full_response
+
 with gr.Blocks(css='''
 .gr-group {direction: rtl;}
 .chatbot{text-align:right;}
@@ -147,27 +97,25 @@ with gr.Blocks(css='''
 textarea {
     font-size: 1.2em;
 }
-''', js=None) as demo:
+''', js=CUSTOM_JS) as demo:
     gr.Markdown("""
     <div class="dicta-header">
         <a href="">
-            <img src="file/logo_am.png" alt="Dicta Logo" class="dicta-logo">
+            <img src="file/logo111.png" alt="Dicta Logo" class="dicta-logo">
         </a>
         <div class="dicta-intro-text">
-            <h1>הדגמה ראשונית</h1>
+            <h1>צ'אט מערכי - הדגמה ראשונית</h1>
             <span dir='rtl'>ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם</span><br/>
-            <span dir='rtl'>הדמו נכתב על ידי רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
+            <span dir='rtl'>הדמו נכתב על ידי סרן רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
        </div>
     </div>
    """)
 
-    chatbot = gr.Chatbot()
-    query = gr.Textbox(placeholder="הכנס שאלה בעברית (או באנגלית!)", rtl=True)
-    clear_btn = gr.Button("נקה שיחה")
-
-    demo_state = gr.State([])
-
-    query.submit(respond, [query, demo_state], [chatbot, query, demo_state])
-    clear_btn.click(clear_session, [], demo_state, chatbot)
-
-demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['logo_am.png'])
+    interface = gr.ChatInterface(model_chat, fill_height=False)
+    interface.chatbot.rtl = True
+    interface.textbox.placeholder = "הכנס שאלה בעברית (או באנגלית!)"
+    interface.textbox.rtl = True
+    interface.textbox.text_align = 'right'
+    interface.theme_css += '.gr-group {direction: rtl !important;}'
+
+demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['dicta-logo.jpg'])
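
This commit replaces the in-process transformers model and local Flask route with an OpenAI-compatible client, so generation now happens on a server behind `API_URL` instead of on the Space's own hardware, and responses can be streamed token by token. Below is a minimal sketch, not part of the commit, of how that serving path can be exercised outside Gradio. It assumes `API_URL` points at an OpenAI-compatible server hosting `dicta-il/dictalm2.0-instruct` (for example, a vLLM deployment) and that `API_KEY` is a key that server accepts; both names mirror the environment variables read at the top of the new app.py.

```python
# Minimal standalone sketch of the new serving path. Assumed setup: an
# OpenAI-compatible server at API_URL serving dicta-il/dictalm2.0-instruct,
# with API_KEY accepted by it (same env vars as app.py).
import os
import openai

client = openai.OpenAI(
    api_key=os.getenv('API_KEY'),
    base_url=os.getenv('API_URL'),
)

# Same sampling parameters as model_chat in app.py.
stream = client.chat.completions.create(
    model='dicta-il/dictalm2.0-instruct',
    messages=[{'role': 'user', 'content': 'שלום, מה שלומך?'}],
    temperature=0.7,
    max_tokens=1024,
    top_p=0.9,
    stream=True,
)

# Streamed deltas may carry content=None (e.g. on the final chunk), hence
# the `or ''` guard -- the same reason model_chat uses `text or ''`.
for chunk in stream:
    print(chunk.choices[0].delta.content or '', end='', flush=True)
```

Consuming the stream this way mirrors the `for completion in gen:` loop in the diff: each chunk's delta is appended to the running text, which is why `model_chat` can yield a progressively longer `full_response` for Gradio's ChatInterface to render.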