Sabareeshr committed
Commit bed46d8 · verified · 1 Parent(s): d93b643

Create app.py

Files changed (1):
  app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ # Inference API client for the hosted Mistral-7B-Instruct model.
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
+
+
+ def format_prompt(message, history):
+     # Build Mistral's instruction format: each past turn is wrapped in
+     # [INST] ... [/INST], followed by the model's reply.
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+
+ def generate(
+     prompt,
+     history,
+     temperature=0.9,
+     max_new_tokens=256,
+     top_p=0.95,
+     repetition_penalty=1.0,
+ ):
+     # Clamp temperature away from zero; the endpoint rejects 0.0 when sampling.
+     temperature = max(float(temperature), 1e-2)
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     # Stream tokens back so the UI updates while the model generates.
+     stream = client.text_generation(
+         formatted_prompt,
+         **generate_kwargs,
+         stream=True,
+         details=True,
+         return_full_text=False,
+     )
+     output = ""
+     for response in stream:
+         output += response.token.text
+         yield output
+
+
+ additional_inputs = [
+     gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         info="Higher values produce more diverse outputs",
+     ),
+     gr.Slider(
+         label="Max new tokens",
+         value=256,
+         minimum=0,
+         maximum=1024,
+         step=64,
+         info="The maximum number of new tokens",
+     ),
+     gr.Slider(
+         label="Top-p (nucleus sampling)",
+         value=0.90,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         info="Higher values sample more low-probability tokens",
+     ),
+     gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         info="Penalize repeated tokens",
+     ),
+ ]
+
+
+ # gr.ChatInterface passes (message, history) to `generate`, with history as a
+ # list of (user, bot) pairs, which is exactly what format_prompt expects.
+ interface = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=additional_inputs,
+     title="Mistral 7B",
+ )
+
+ interface.launch(share=False)
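As a quick sanity check, the prompt builder can be exercised on its own with no API call. This is only an illustration; the one-turn history and messages below are made up:

history = [("Hi there", "Hello! How can I help?")]
print(format_prompt("What is Gradio?", history))
# <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] What is Gradio? [/INST]

The leading <s> and the [INST] ... [/INST] wrapping follow Mistral-7B-Instruct's expected chat format: each past exchange is closed with </s> before the new user message is appended.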