ecyht2 commited on
Commit
fefc78b
·
verified ·
1 Parent(s): 36e2af1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import gradio as gr
6
+
7
+ log_level = os.environ.get("LOG_LEVEL", "WARNING")
8
+ logging.basicConfig(encoding='utf-8', level=log_level)
9
+
10
+ logging.info("Loading Model")
11
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
12
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
13
+
14
+ def format_prompt(message, history):
15
+ """Formats the prompt for the AI"""
16
+ logging.info("Formatting Prompt")
17
+ logging.debug("Input Message: %s", message)
18
+ logging.debug("Input History: %s", history)
19
+
20
+ prompt = f"Instruct: {message}\n"
21
+ prompt += "Output: "
22
+ return prompt
23
+
24
+
25
+ def generate(
26
+ prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
27
+ ):
28
+ logging.info("Generating Response")
29
+ logging.debug("Input Prompt: %s", prompt)
30
+ logging.debug("Input History: %s", history)
31
+ logging.debug("Input System Prompt: %s", system_prompt)
32
+ logging.debug("Input Temperature: %s", temperature)
33
+ logging.debug("Input Max New Tokens: %s", max_new_tokens)
34
+ logging.debug("Input Top P: %s", top_p)
35
+ logging.debug("Input Repetition Penalty: %s", repetition_penalty)
36
+
37
+ logging.info("Converting Parameters to Correct Type")
38
+ temperature = float(temperature)
39
+ if temperature < 1e-2:
40
+ temperature = 1e-2
41
+ top_p = float(top_p)
42
+ logging.debug("Temperature: %s", temperature)
43
+ logging.debug("Top P: %s", top_p)
44
+
45
+ logging.info("Creating Generate kwargs")
46
+ generate_kwargs = dict(
47
+ temperature=temperature,
48
+ max_new_tokens=max_new_tokens,
49
+ top_p=top_p,
50
+ repetition_penalty=repetition_penalty,
51
+ do_sample=True,
52
+ seed=42,
53
+ )
54
+ logging.debug("Generate Args: %s", generate_kwargs)
55
+
56
+ formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
57
+ logging.debug("Prompt: %s", formatted_prompt)
58
+
59
+ logging.info("Generating Text")
60
+ stream = model.generate(tokenizer(prompt, return_tensors="pt"), **generate_kwargs)
61
+
62
+ logging.info("Creating Output")
63
+ output = ""
64
+ for response in stream:
65
+ output += response.token.text
66
+ yield output
67
+
68
+ logging.debug("Output: %s", output)
69
+ return output
70
+
71
+
72
+ additional_inputs = [
73
+ gr.Textbox(
74
+ label="System Prompt",
75
+ max_lines=1,
76
+ interactive=True,
77
+ ),
78
+ gr.Slider(
79
+ label="Temperature",
80
+ value=0.9,
81
+ minimum=0.0,
82
+ maximum=1.0,
83
+ step=0.05,
84
+ interactive=True,
85
+ info="Higher values produce more diverse outputs",
86
+ ),
87
+ gr.Slider(
88
+ label="Max new tokens",
89
+ value=256,
90
+ minimum=0,
91
+ maximum=1048,
92
+ step=64,
93
+ interactive=True,
94
+ info="The maximum numbers of new tokens",
95
+ ),
96
+ gr.Slider(
97
+ label="Top-p (nucleus sampling)",
98
+ value=0.90,
99
+ minimum=0.0,
100
+ maximum=1,
101
+ step=0.05,
102
+ interactive=True,
103
+ info="Higher values sample more low-probability tokens",
104
+ ),
105
+ gr.Slider(
106
+ label="Repetition penalty",
107
+ value=1.2,
108
+ minimum=1.0,
109
+ maximum=2.0,
110
+ step=0.05,
111
+ interactive=True,
112
+ info="Penalize repeated tokens",
113
+ )
114
+ ]
115
+
116
+ examples = []
117
+
118
+ logging.info("Creating Chat Interface")
119
+ gr.ChatInterface(
120
+ fn=generate,
121
+ chatbot=gr.Chatbot(show_label=False, show_share_button=False,
122
+ show_copy_button=True, likeable=True, layout="panel"),
123
+ additional_inputs=additional_inputs,
124
+ title="Mixtral Instruct",
125
+ examples=examples,
126
+ concurrency_limit=20,
127
+ ).launch(show_api=False)