Not-Grim-Refer committed on
Commit fbd4b06
1 Parent(s): 19890e4

Update app.py

Files changed (1)
  1. app.py +142 -1
app.py CHANGED
@@ -1,3 +1,144 @@
 
+ import os
  import gradio as gr
+ import torch
 
- gr.Interface.load("models/TheBloke/WizardLM-Uncensored-Falcon-7B-GPTQ").launch()
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+
+ REPO = "mrm8488/falcoder-7b"  # model repo id; loaded once on the selected device below
+
+ description = """# <h1 style="text-align: center; color: white;"><span style='color: #F26207;'> Code Completion with falcoder-7b </h1>
+ <span style="color: white; text-align: center;"> Enter some code below and click the button to have falcoder-7b complete it.</span>"""
+
+
+ token = os.environ["HUB_TOKEN"]  # Hub access token, provided as a Space secret
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ PAD_TOKEN = "<|pad|>"
+ EOS_TOKEN = "<|endoftext|>"
+ UNK_TOKEN = "<|unk|>"
+ MAX_INPUT_TOKENS = 1024  # max tokens from context
+
+
+ tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
+ tokenizer.truncation_side = "left"  # ensures if truncated, we keep the last N tokens of the prompt going L -> R
+
+ if device == "cuda":
+     model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
+ else:
+     model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
+
+ model.eval()
+
+
+ custom_css = """
+ .gradio-container {
+     background-color: #0D1525;
+     color: white;
+ }
+ #orange-button {
+     background: #F26207 !important;
+     color: white;
+ }
+ .cm-gutters {
+     border: none !important;
+ }
+ """
+
+ def post_processing(prompt, completion):
+     return prompt + completion
+     # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
+     # prompt = "<span style='color: black;'>" + prompt + "</span>"
+     # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
+     # return code_html
+
+
+ def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
+     # truncate the prompt to MAX_INPUT_TOKENS if it's too long
+     x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
+     print("Prompt shape: ", x.shape)  # just adding to see in the Space logs in prod
+     set_seed(seed)
+     y = model.generate(x,
+                        max_new_tokens=max_new_tokens,
+                        do_sample=True,  # without this, generation is greedy and temperature/top_p/top_k are ignored
+                        temperature=temperature,
+                        pad_token_id=tokenizer.pad_token_id,
+                        eos_token_id=tokenizer.eos_token_id,
+                        top_p=top_p,
+                        top_k=top_k,
+                        use_cache=use_cache,
+                        repetition_penalty=repetition_penalty
+                        )
+     completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     completion = completion[len(prompt):]  # drop the echoed prompt, keep only the new tokens
+     return post_processing(prompt, completion)
+
+
+ demo = gr.Blocks(
+     css=custom_css
+ )
+
+ with demo:
+     gr.Markdown(value=description)
+     with gr.Row():
+         input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
+         with input_col:
+             code = gr.Code(lines=28, label="Input", value="def sieve_eratosthenes(n):")
+         with settings_col:
+             with gr.Accordion("Generation Settings", open=True):
+                 max_new_tokens = gr.Slider(
+                     minimum=8,
+                     maximum=128,
+                     step=1,
+                     value=48,
+                     label="Max Tokens",
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.1,
+                     maximum=2.5,
+                     step=0.1,
+                     value=0.2,
+                     label="Temperature",
+                 )
+                 repetition_penalty = gr.Slider(
+                     minimum=1.0,
+                     maximum=1.9,
+                     step=0.1,
+                     value=1.0,
+                     label="Repetition Penalty. 1.0 means no penalty.",
+                 )
+                 seed = gr.Slider(
+                     minimum=0,
+                     maximum=1000,
+                     step=1,
+                     value=42,  # match the function's default seed
+                     label="Random Seed",
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     step=0.1,
+                     value=0.9,
+                     label="Top P",
+                 )
+                 top_k = gr.Slider(
+                     minimum=1,
+                     maximum=64,
+                     step=1,
+                     value=4,
+                     label="Top K",
+                 )
+                 use_cache = gr.Checkbox(
+                     label="Use Cache",
+                     value=True
+                 )
+
+     with gr.Row():
+         run = gr.Button(elem_id="orange-button", value="Generate")
+
+     # with gr.Row():
+     #     # _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
+     #     # with middle_col_row_2:
+     #     output = gr.HTML(label="Generated Code")
+
+     event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")
+
+ demo.queue(max_size=40).launch()
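
Once this revision is running, the click handler's api_name="predict" also exposes the generator as an HTTP endpoint, so it can be driven from Python with the gradio_client package. A minimal sketch, assuming the Space is deployed and publicly reachable; the Space id below is a placeholder, and the positional arguments mirror the Blocks inputs in order:

from gradio_client import Client

# Placeholder Space id; substitute wherever this app is actually hosted.
client = Client("your-username/your-space")

# Inputs, in order: code, max_new_tokens, temperature, seed, top_p, top_k,
# use_cache, repetition_penalty.
completion = client.predict(
    "def sieve_eratosthenes(n):",  # prompt to complete
    48,    # max_new_tokens
    0.2,   # temperature
    42,    # seed
    0.9,   # top_p
    4,     # top_k
    True,  # use_cache
    1.0,   # repetition_penalty
    api_name="/predict",
)
print(completion)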