Mark-Arcee committed on
Commit dfa37de · verified · 1 Parent(s): 9dc82d4

Update app.py

Files changed (1)
  1. app.py +93 -123
app.py CHANGED
@@ -1,128 +1,98 @@
-#!/usr/bin/env python
-
 import os
-from threading import Thread
-from typing import Iterator
 
 import gradio as gr
-import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-DESCRIPTION = "# Arcee Model Testing - current test: arcee-ai/SEC-Calme-7B-Instruct"
-
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-MAX_MAX_NEW_TOKENS = 4096
-DEFAULT_MAX_NEW_TOKENS = 4096
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-if torch.cuda.is_available():
-    model_id = "arcee-ai/SEC-Calme-7B-Instruct"
-    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-
-@spaces.GPU
-def generate(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 4096,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-)
-
-with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(
-        value="Duplicate Space for private use",
-        elem_id="duplicate-button",
-        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
-    )
-    chat_interface.render()
 
-if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
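
The removed version streams output by running model.generate on a background worker thread while the main thread drains decoded text from a TextIteratorStreamer. A minimal, self-contained sketch of that pattern; the tiny model id "sshleifer/tiny-gpt2" is only an illustrative stand-in, not a model used by this Space:

# model.generate blocks until generation finishes, so it runs on a worker
# thread; the streamer then yields decoded text on the main thread as
# tokens arrive. "sshleifer/tiny-gpt2" is a stand-in model for illustration.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok("Hello there", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
Thread(target=lm.generate,
       kwargs=dict(**inputs, streamer=streamer, max_new_tokens=16)).start()
for piece in streamer:
    print(piece, end="", flush=True)  # partial text, printed as it streams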
 import os
+os.system('pip install dashscope')
 
 import gradio as gr
+from http import HTTPStatus
+import dashscope
+from dashscope import Generation
+from dashscope.api_entities.dashscope_response import Role
+from typing import List, Optional, Tuple, Dict
+from urllib.error import HTTPError
+default_system = 'You are a helpful assistant.'
+
+YOUR_API_TOKEN = os.getenv('YOUR_API_TOKEN')
+dashscope.api_key = YOUR_API_TOKEN
+
+History = List[Tuple[str, str]]
+Messages = List[Dict[str, str]]
+
+def clear_session() -> Tuple[str, History]:
+    return '', []
+
+def modify_system_session(system: str) -> Tuple[str, str, History]:
+    if system is None or len(system) == 0:
+        system = default_system
+    return system, system, []
+
+def history_to_messages(history: History, system: str) -> Messages:
+    messages = [{'role': Role.SYSTEM, 'content': system}]
+    for h in history:
+        messages.append({'role': Role.USER, 'content': h[0]})
+        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
+    return messages
+
+
+def messages_to_history(messages: Messages) -> Tuple[str, History]:
+    assert messages[0]['role'] == Role.SYSTEM
+    system = messages[0]['content']
+    history = []
+    for q, r in zip(messages[1::2], messages[2::2]):
+        history.append([q['content'], r['content']])
+    return system, history
+
+
+def model_chat(query: Optional[str], history: Optional[History], system: str
+               ) -> Tuple[str, History, str]:
+    if query is None:
+        query = ''
+    if history is None:
+        history = []
+    messages = history_to_messages(history, system)
+    messages.append({'role': Role.USER, 'content': query})
+    gen = Generation.call(
+        model="arcee-ai/Saul-Base-Calme-7B-Instruct-slerp",
+        messages=messages,
+        result_format='message',
+        stream=True
     )
+    for response in gen:
+        if response.status_code == HTTPStatus.OK:
+            role = response.output.choices[0].message.role
+            response = response.output.choices[0].message.content
+            system, history = messages_to_history(messages + [{'role': role, 'content': response}])
+            yield '', history, system
+        else:
+            raise HTTPError('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+                response.request_id, response.status_code,
+                response.code, response.message
+            ))
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("""<center><font size=8>arcee-ai/Saul-Base-Calme-7B-Instruct-slerp</center>""")
+    gr.Markdown("""<center><font size=4>arcee-ai/Saul-Base-Calme-7B-Instruct-slerp.</center>""")
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            system_input = gr.Textbox(value=default_system, lines=1, label='System')
+        with gr.Column(scale=1):
+            modify_system = gr.Button("🛠️ Set system prompt and clear history.", scale=2)
+        system_state = gr.Textbox(value=default_system, visible=False)
+    chatbot = gr.Chatbot(label='arcee-ai/Saul-Base-Calme-7B-Instruct-slerp')
+    textbox = gr.Textbox(lines=2, label='Input')
+
+    with gr.Row():
+        clear_history = gr.Button("🧹 Clear history")
+        submit = gr.Button("🚀 Send")
+
+    submit.click(model_chat,
+                 inputs=[textbox, chatbot, system_state],
+                 outputs=[textbox, chatbot, system_input])
+    clear_history.click(fn=clear_session,
+                        inputs=[],
+                        outputs=[textbox, chatbot])
+    modify_system.click(fn=modify_system_session,
+                        inputs=[system_input],
+                        outputs=[system_state, system_input, chatbot])
 
+demo.queue(api_open=False).launch(max_threads=10, height=800, share=False)
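
The new version leans on dashscope's streaming behavior: with stream=True and result_format='message', each streamed chunk appears to carry the full assistant message generated so far (dashscope's default when incremental_output is not set), which is why model_chat can rebuild and yield the complete chat history on every iteration. A quick, illustrative sanity check of the two conversion helpers; it assumes the definitions from app.py are in scope and is not part of the Space itself:

# Round-trip check: history -> messages -> history should preserve the turns,
# since the messages list alternates user/assistant entries after the system
# message.
history = [('Hi', 'Hello!'), ('What is 2 + 2?', '4')]
messages = history_to_messages(history, default_system)
system, restored = messages_to_history(messages)
assert system == default_system
assert [tuple(turn) for turn in restored] == history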