Sakalti commited on
Commit
a6a997a
Β·
verified Β·
1 Parent(s): 0ac2fac

Create Test.py

Browse files
Files changed (1) hide show
  1. Test.py +62 -0
Test.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import spaces
5
+
6
+ model_name = "Sakalti/SakalFusion-7B-Alpha"
7
+
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ model_name,
10
+ torch_dtype=torch.bfloat16,
11
+ device_map="auto"
12
+ )
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+
15
+ @spaces.gpu(duration=100)
16
+ def generate(prompt, history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
17
+ messages = [
18
+ {"role": "system", "content": "あγͺγŸγ―γƒ•γƒ¬γƒ³γƒ‰γƒͺγƒΌγͺγƒγƒ£γƒƒγƒˆγƒœγƒƒγƒˆγ§γ™γ€‚"},
19
+ {"role": "user", "content": prompt}
20
+ ]
21
+ text = tokenizer.apply_chat_template(
22
+ messages,
23
+ tokenize=False,
24
+ add_generation_prompt=True
25
+ )
26
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
27
+
28
+ generated_ids = model.generate(
29
+ **model_inputs,
30
+ max_new_tokens=max_new_tokens,
31
+ top_p=top_p,
32
+ top_k=top_k,
33
+ repetition_penalty=repetition_penalty,
34
+ temperature=temperature
35
+ )
36
+ generated_ids = [
37
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
38
+ ]
39
+
40
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
41
+ return response, history + [[prompt, response]]
42
+
43
+ with gr.Blocks() as demo:
44
+ chatbot = gr.Chatbot()
45
+ msg = gr.Textbox()
46
+ clear = gr.Button("Clear")
47
+
48
+ with gr.Row():
49
+ top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
50
+ top_k = gr.Slider(0, 100, value=50, label="Top K")
51
+ max_new_tokens = gr.Slider(1, 2048, value=864, label="Max New Tokens")
52
+ repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="Repetition Penalty")
53
+ temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
54
+
55
+ def respond(message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
56
+ bot_message, chat_history = generate(message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature)
57
+ return "", chat_history, chat_history
58
+
59
+ msg.submit(respond, [msg, chatbot, top_p, top_k, max_new_tokens, repetition_penalty, temperature], [msg, chatbot, chatbot])
60
+ clear.click(lambda: ([], []), None, [chatbot, msg])
61
+
62
+ demo.launch(share=True)