ecyht2 committed
Commit 2f0b879
1 Parent(s): 02277c5

Update app.py

Files changed (1)
  1. app.py +129 -4
app.py CHANGED
@@ -1,8 +1,133 @@
+import os
+import logging
+
+from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import gradio as gr
+
+log_level = os.environ.get("LOG_LEVEL", "WARNING")
+logging.basicConfig(encoding='utf-8', level=log_level)
 
-tokenizer = AutoTokenizer.from_pretrained(
+logging.info("Creating Inference Client")
+client = InferenceClient(
     "cognitivecomputations/dolphin-2.6-mistral-7b"
 )
-model = AutoModelForCausalLM.from_pretrained(
-    "cognitivecomputations/dolphin-2.6-mistral-7b"
-)
+
+
+def format_prompt(message, history, system_prompt=""):
+    """Format the conversation as a ChatML prompt for the model."""
+    logging.info("Formatting Prompt")
+    logging.debug("Input Message: %s", message)
+    logging.debug("Input History: %s", history)
+
+    # Fall back to the default persona when no system prompt is given.
+    if not system_prompt:
+        system_prompt = "You are Dolphin, a helpful AI assistant."
+
+    prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+    # Replay earlier turns so the model keeps the conversation context.
+    for user_turn, bot_turn in history:
+        prompt += f"<|im_start|>user\n{user_turn}<|im_end|>\n"
+        prompt += f"<|im_start|>assistant\n{bot_turn}<|im_end|>\n"
+    prompt += f"<|im_start|>user\n{message}<|im_end|>\n"
+    prompt += "<|im_start|>assistant\n"
+
+    return prompt
+
+
+def generate(
+    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256,
+    top_p=0.95, repetition_penalty=1.0,
+):
+    """Stream a response to the latest user prompt from the model."""
+    logging.info("Generating Response")
+    logging.debug("Input Prompt: %s", prompt)
+    logging.debug("Input History: %s", history)
+    logging.debug("Input System Prompt: %s", system_prompt)
+    logging.debug("Input Temperature: %s", temperature)
+    logging.debug("Input Max New Tokens: %s", max_new_tokens)
+    logging.debug("Input Top P: %s", top_p)
+    logging.debug("Input Repetition Penalty: %s", repetition_penalty)
+
+    logging.info("Converting Parameters to Correct Type")
+    # text_generation rejects temperature <= 0, so clamp it to a small positive value.
+    temperature = max(float(temperature), 1e-2)
+    top_p = float(top_p)
+    logging.debug("Temperature: %s", temperature)
+    logging.debug("Top P: %s", top_p)
+
+    logging.info("Creating Generate kwargs")
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+    logging.debug("Generate Args: %s", generate_kwargs)
+
+    formatted_prompt = format_prompt(prompt, history, system_prompt)
+    logging.debug("Prompt: %s", formatted_prompt)
+
+    logging.info("Generating Text")
+    stream = client.text_generation(
+        formatted_prompt, **generate_kwargs, stream=True, details=True,
+        return_full_text=False)
+
+    logging.info("Creating Output")
+    output = ""
+    for response in stream:
+        output += response.token.text
+        yield output
+
+    logging.debug("Output: %s", output)
+
+
+additional_inputs = [
+    gr.Textbox(
+        label="System Prompt",
+        max_lines=1,
+        interactive=True,
+    ),
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1024,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    ),
+]
+
+examples = []
+
+logging.info("Creating Chat Interface")
+gr.ChatInterface(
+    fn=generate,
+    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
+                       show_copy_button=True, likeable=True, layout="panel"),
+    additional_inputs=additional_inputs,
+    title="Dolphin Mistral",
+    examples=examples,
+    concurrency_limit=20,
+).launch(show_api=False)
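
For reference, a minimal sketch of the ChatML string format_prompt builds for a short conversation. It assumes gr.ChatInterface passes history as a list of [user, assistant] pairs; the turns below are hypothetical:

    # Hypothetical one-turn history, in the [user, assistant] pair format
    # that gr.ChatInterface supplies to its fn.
    history = [["What is 2 + 2?", "2 + 2 equals 4."]]
    print(format_prompt("And times 3?", history))

    # Expected output:
    # <|im_start|>system
    # You are Dolphin, a helpful AI assistant.<|im_end|>
    # <|im_start|>user
    # What is 2 + 2?<|im_end|>
    # <|im_start|>assistant
    # 2 + 2 equals 4.<|im_end|>
    # <|im_start|>user
    # And times 3?<|im_end|>
    # <|im_start|>assistant

Because the app reads LOG_LEVEL from the environment, launching it with LOG_LEVEL=DEBUG also logs this formatted prompt and the generate kwargs on every request.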