aixsatoshi commited on
Commit
5c13c72
·
verified ·
1 Parent(s): f714f44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -121
app.py CHANGED
@@ -11,142 +11,198 @@ from huggingface_hub import hf_hub_download
11
 
12
  # モデルのダウンロード
13
  hf_hub_download(
14
- repo_id="bartowski/gemma-2-9b-it-GGUF",
15
- filename="gemma-2-9b-it-Q5_K_M.gguf",
16
- local_dir="./models"
17
  )
18
 
19
  hf_hub_download(
20
- repo_id="bartowski/Gemma-2-9B-It-SPPO-Iter3-GGUF",
21
- filename="Gemma-2-9B-It-SPPO-Iter3-Q5_K_M.gguf",
22
- local_dir="./models"
 
 
 
 
 
 
 
 
 
 
 
 
23
  )
24
 
25
  # 推論関数
26
  @spaces.GPU(duration=120)
27
  def respond(
28
- message,
29
- history: list[tuple[str, str]],
30
- model,
31
- system_message,
32
- max_tokens,
33
- temperature,
34
- top_p,
35
- top_k,
36
- repeat_penalty,
37
  ):
38
- chat_template = MessagesFormatterType.GEMMA_2
39
-
40
- llm = Llama(
41
- model_path=f"models/{model}",
42
- flash_attn=True,
43
- n_gpu_layers=81,
44
- n_batch=1024,
45
- n_ctx=8192,
46
- )
47
- provider = LlamaCppPythonProvider(llm)
48
-
49
- agent = LlamaCppAgent(
50
- provider,
51
- system_prompt=f"{system_message}",
52
- predefined_messages_formatter_type=chat_template,
53
- debug_output=True
54
- )
55
-
56
- settings = provider.get_provider_default_settings()
57
- settings.temperature = temperature
58
- settings.top_k = top_k
59
- settings.top_p = top_p
60
- settings.max_tokens = max_tokens
61
- settings.repeat_penalty = repeat_penalty
62
- settings.stream = True
63
-
64
- messages = BasicChatHistory()
65
-
66
- for msn in history:
67
- user = {
68
- 'role': Roles.user,
69
- 'content': msn[0]
70
- }
71
- assistant = {
72
- 'role': Roles.assistant,
73
- 'content': msn[1]
74
- }
75
- messages.add_message(user)
76
- messages.add_message(assistant)
77
-
78
- stream = agent.get_chat_response(
79
- message,
80
- llm_sampling_settings=settings,
81
- chat_history=messages,
82
- returns_streaming_generator=True,
83
- print_output=False
84
- )
85
-
86
- outputs = ""
87
- for output in stream:
88
- outputs += output
89
- yield outputs
90
 
91
  # Gradioのインターフェースを作成
92
  def create_interface(model_name, description):
93
- return gr.ChatInterface(
94
- respond,
95
- additional_inputs=[
96
- gr.Textbox(value=model_name, label="Model", interactive=False),
97
- gr.Textbox(value="You are a helpful assistant.", label="System message"),
98
- gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
99
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
100
- gr.Slider(
101
- minimum=0.1,
102
- maximum=1.0,
103
- value=0.95,
104
- step=0.05,
105
- label="Top-p",
106
- ),
107
- gr.Slider(
108
- minimum=0,
109
- maximum=100,
110
- value=40,
111
- step=1,
112
- label="Top-k",
113
- ),
114
- gr.Slider(
115
- minimum=0.0,
116
- maximum=2.0,
117
- value=1.1,
118
- step=0.1,
119
- label="Repetition penalty",
120
- ),
121
- ],
122
- retry_btn="Retry",
123
- undo_btn="Undo",
124
- clear_btn="Clear",
125
- submit_btn="Send",
126
- title=f"Chat with Gemma 2 using llama.cpp - {model_name}",
127
- description=description,
128
- chatbot=gr.Chatbot(
129
- scale=1,
130
- likeable=False,
131
- show_copy_button=True
132
- )
133
- )
134
 
135
  # 各モデルのインターフェース
136
- description_9b = """<p align="center">Gemma-2 9B it Model</p>"""
137
- description_27b = """<p align="center">Gemma-2 9B SPPO it Model</p>"""
 
 
 
 
 
 
 
138
 
139
- interface_9b = create_interface('gemma-2-9b-it-Q5_K_M.gguf', description_9b)
140
- interface_27b = create_interface('Gemma-2-9B-It-SPPO-Iter3-Q5_K_M.gguf', description_27b)
141
 
142
- # Gradio Blocksで2つのインターフェースを並べて表示
143
- with gr.Blocks() as demo:
144
- #gr.Markdown("# Compare Gemma-2 9B and 27B Models")
145
- with gr.Row():
146
- with gr.Column():
147
- interface_9b.render()
148
- with gr.Column():
149
- interface_27b.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  if __name__ == "__main__":
152
- demo.launch()
 
11
 
12
  # モデルのダウンロード
13
  hf_hub_download(
14
+ repo_id="bartowski/gemma-2-9b-it-GGUF",
15
+ filename="gemma-2-9b-it-Q5_K_M.gguf",
16
+ local_dir="./models"
17
  )
18
 
19
  hf_hub_download(
20
+ repo_id="bartowski/Gemma-2-9B-It-SPPO-Iter3-GGUF",
21
+ filename="Gemma-2-9B-It-SPPO-Iter3-Q5_K_M.gguf",
22
+ local_dir="./models"
23
+ )
24
+
25
+ hf_hub_download(
26
+ repo_id="mradermacher/EZO-Common-9B-gemma-2-it-GGUF",
27
+ filename="EZO-Common-9B-gemma-2-it.Q5_K_M.gguf",
28
+ local_dir="./models"
29
+ )
30
+
31
+ hf_hub_download(
32
+ repo_id="mradermacher/EZO-Humanities-9B-gemma-2-it-GGUF",
33
+ filename="EZO-Humanities-9B-gemma-2-it.Q5_K_M.gguf",
34
+ local_dir="./models"
35
  )
36
 
37
  # 推論関数
38
  @spaces.GPU(duration=120)
39
  def respond(
40
+ message,
41
+ history: list[tuple[str, str]],
42
+ model,
43
+ system_message,
44
+ max_tokens,
45
+ temperature,
46
+ top_p,
47
+ top_k,
48
+ repeat_penalty,
49
  ):
50
+ chat_template = MessagesFormatterType.GEMMA_2
51
+
52
+ llm = Llama(
53
+ model_path=f"models/{model}",
54
+ flash_attn=True,
55
+ n_gpu_layers=81,
56
+ n_batch=1024,
57
+ n_ctx=8192,
58
+ )
59
+ provider = LlamaCppPythonProvider(llm)
60
+
61
+ agent = LlamaCppAgent(
62
+ provider,
63
+ system_prompt=f"{system_message}",
64
+ predefined_messages_formatter_type=chat_template,
65
+ debug_output=True
66
+ )
67
+
68
+ settings = provider.get_provider_default_settings()
69
+ settings.temperature = temperature
70
+ settings.top_k = top_k
71
+ settings.top_p = top_p
72
+ settings.max_tokens = max_tokens
73
+ settings.repeat_penalty = repeat_penalty
74
+ settings.stream = True
75
+
76
+ messages = BasicChatHistory()
77
+
78
+ for msn in history:
79
+ user = {
80
+ 'role': Roles.user,
81
+ 'content': msn[0]
82
+ }
83
+ assistant = {
84
+ 'role': Roles.assistant,
85
+ 'content': msn[1]
86
+ }
87
+ messages.add_message(user)
88
+ messages.add_message(assistant)
89
+
90
+ stream = agent.get_chat_response(
91
+ message,
92
+ llm_sampling_settings=settings,
93
+ chat_history=messages,
94
+ returns_streaming_generator=True,
95
+ print_output=False
96
+ )
97
+
98
+ outputs = ""
99
+ for output in stream:
100
+ outputs += output
101
+ yield outputs
102
 
103
  # Gradioのインターフェースを作成
104
  def create_interface(model_name, description):
105
+ return gr.ChatInterface(
106
+ respond,
107
+ additional_inputs=[
108
+ gr.Textbox(value=model_name, label="Model", interactive=False),
109
+ gr.Textbox(value="You are a helpful assistant.", label="System message"),
110
+ gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
111
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
112
+ gr.Slider(
113
+ minimum=0.1,
114
+ maximum=1.0,
115
+ value=0.95,
116
+ step=0.05,
117
+ label="Top-p",
118
+ ),
119
+ gr.Slider(
120
+ minimum=0,
121
+ maximum=100,
122
+ value=40,
123
+ step=1,
124
+ label="Top-k",
125
+ ),
126
+ gr.Slider(
127
+ minimum=0.0,
128
+ maximum=2.0,
129
+ value=1.1,
130
+ step=0.1,
131
+ label="Repetition penalty",
132
+ ),
133
+ ],
134
+ retry_btn="Retry",
135
+ undo_btn="Undo",
136
+ clear_btn="Clear",
137
+ submit_btn="Send",
138
+ title=f"{model_name}",
139
+ description=description,
140
+ chatbot=gr.Chatbot(
141
+ scale=1,
142
+ likeable=False,
143
+ show_copy_button=True
144
+ )
145
+ )
146
 
147
  # 各モデルのインターフェース
148
+ description_a = """<p align="center">Gemma-2 9B it Q5</p>"""
149
+ description_b = """<p align="center">Gemma-2 9B SPPO it Q5</p>"""
150
+ description_c = """<p align="center">EZO Common 9B it Q5</p>"""
151
+ description_d = """<p align="center">EZO Humanities 9B it Q5</p>"""
152
+
153
+ interface_a = create_interface('gemma-2-9b-it-Q5_K_M.gguf', description_a)
154
+ interface_b = create_interface('Gemma-2-9B-It-SPPO-Iter3-Q5_K_M.gguf', description_b)
155
+ interface_c = create_interface('EZO-Common-9B-gemma-2-it.Q5_K_M.gguf', description_c)
156
+ interface_d = create_interface('EZO-Humanities-9B-gemma-2-it.Q5_K_M.gguf', description_d)
157
 
158
+ # Gradio Blocksで4つのインターフェースを並べて表示
159
+ demo = gr.Blocks()
160
 
161
+ with demo:
162
+ gr.HTML("""
163
+ <style>
164
+ body, html {
165
+ height: 100%;
166
+ margin: 0;
167
+ overflow: auto;
168
+ }
169
+ .gradio-container {
170
+ max-height: 100vh;
171
+ overflow-y: auto;
172
+ }
173
+ .gr-row {
174
+ display: flex;
175
+ flex-wrap: wrap;
176
+ height: 50vh;
177
+ margin-bottom: 20px;
178
+ }
179
+ .gr-column {
180
+ flex: 1 1 calc(50% - 20px);
181
+ margin: 10px;
182
+ max-height: 100%;
183
+ overflow-y: auto;
184
+ }
185
+ .gr-column > div {
186
+ height: 100%;
187
+ display: flex;
188
+ flex-direction: column;
189
+ }
190
+ .gr-column > div > div:last-child {
191
+ flex-grow: 1;
192
+ overflow-y: auto;
193
+ }
194
+ </style>
195
+ """)
196
+ with gr.Row():
197
+ with gr.Column():
198
+ interface_c.render()
199
+ with gr.Column():
200
+ interface_d.render()
201
+ with gr.Row():
202
+ with gr.Column():
203
+ interface_a.render()
204
+ with gr.Column():
205
+ interface_b.render()
206
 
207
  if __name__ == "__main__":
208
+ demo.launch()