Not-Grim-Refer committed
Commit: c12c1d4
1 Parent(s): 465c5b4

Update app.py

Files changed (1):
  1. app.py (+116, -99)
app.py CHANGED
@@ -1,36 +1,61 @@
-import os
+# Import necessary libraries
+from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
 import gradio as gr
 import torch
+import logging
 
-from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b", trust_remote_code=True)
-
-description = """# <h1 style="text-align: center; color: white;"><span style='color: #F26207;'> Code Completion with falcoder-7b </h1>
-<span style="color: white; text-align: center;"> falcoder-7b You can click the button to generate your code.</span>"""
-
-token = os.environ["HUB_TOKEN"]
+# Set device to GPU if available, otherwise CPU
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-PAD_TOKEN = "<|pad|>"
-EOS_TOKEN = "<|endoftext|>"
-UNK_TOKEN = "<|unk|>"
-MAX_INPUT_TOKENS = 1024  # max tokens from context
-
-REPO = "mrm8488/falcoder-7b"
-
-tokenizer = AutoTokenizer.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True)
-tokenizer.truncation_side = "left"  # ensures that if we truncate, we keep the last N tokens of the prompt (L -> R)
-
-if device == "cuda":
-    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True).to(device, dtype=torch.bfloat16)
-else:
-    model = AutoModelForCausalLM.from_pretrained(REPO, use_auth_token=token, trust_remote_code=True, low_cpu_mem_usage=True)
-
-model.eval()
+# Load tokenizer and model, then move the model to the selected device
+tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b", trust_remote_code=True).to(device)
+
+def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
+    """
+    Generates a text completion given a prompt and the specified parameters.
+
+    :param prompt: Input prompt for text generation.
+    :type prompt: str
+    :param max_length: Maximum length of the generated text.
+    :type max_length: int
+    :param do_sample: Whether to use sampling for text generation.
+    :type do_sample: bool
+    :param temperature: Sampling temperature for text generation.
+    :type temperature: float
+    :param top_k: Value for top-k sampling.
+    :type top_k: int
+    :param top_p: Value for top-p sampling.
+    :type top_p: float
+    :return: Generated text completion.
+    :rtype: str
+    """
+    # Format the prompt
+    formatted_prompt = "\n" + prompt
+    if ',' not in prompt:
+        formatted_prompt += ','
+
+    # Tokenize the prompt and move it to the device
+    inputs = tokenizer(formatted_prompt, return_tensors="pt")
+    inputs = {key: value.to(device) for key, value in inputs.items()}
+
+    # Generate a completion with the specified parameters
+    out = model.generate(**inputs, max_length=max_length, do_sample=do_sample, temperature=temperature,
+                         no_repeat_ngram_size=3, top_k=top_k, top_p=top_p)
+    output = tokenizer.decode(out[0], skip_special_tokens=True)
+
+    # Log the generated completion
+    logger.info("Text generated: %s", output)
+
+    return output
+
+# Define Gradio interface
 custom_css = """
 .gradio-container {
     background-color: #0D1525;
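
The new loader above keeps the model in float32, while the removed branch loaded it in bfloat16 on GPU to halve memory. The same behaviour still fits the rewritten single-call structure; a minimal sketch, assuming a CUDA host and a checkpoint that supports bfloat16:

    # Optional low-memory GPU load, mirroring the removed bfloat16 path
    import torch
    from transformers import AutoModelForCausalLM

    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            "mrm8488/falcoder-7b",
            trust_remote_code=True,
            low_cpu_mem_usage=True,      # avoid materialising a second copy in host RAM
            torch_dtype=torch.bfloat16,  # half the memory of float32
        ).to("cuda")
        model.eval()                     # inference only, as in the previous revision
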
@@ -46,18 +71,52 @@ custom_css = """
 """
 
 def post_processing(prompt, completion):
+    """
+    Formats a generated completion for display.
+
+    :param prompt: Input prompt for text generation.
+    :type prompt: str
+    :param completion: Generated text completion.
+    :type completion: str
+    :return: Formatted text completion.
+    :rtype: str
+    """
     return prompt + completion
-    # completion = "<span style='color: #499cd5;'>" + completion + "</span>"
-    # prompt = "<span style='color: black;'>" + prompt + "</span>"
-    # code_html = f"<hr><br><pre style='font-size: 14px'><code>{prompt}{completion}</code></pre><br><hr>"
-    # return code_html
 
-
 def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
-
-    # truncates the prompt to MAX_INPUT_TOKENS if it's too long
+    """
+    Generates a code completion given a prompt and the specified parameters.
+
+    :param prompt: Input prompt for code generation.
+    :type prompt: str
+    :param max_new_tokens: Maximum number of tokens to generate.
+    :type max_new_tokens: int
+    :param temperature: Sampling temperature for code generation.
+    :type temperature: float
+    :param seed: Random seed for code generation.
+    :type seed: int
+    :param top_p: Value for top-p sampling.
+    :type top_p: float
+    :param top_k: Value for top-k sampling.
+    :type top_k: int
+    :param use_cache: Whether to use the KV cache during generation.
+    :type use_cache: bool
+    :param repetition_penalty: Repetition penalty (1.0 means no penalty).
+    :type repetition_penalty: float
+    :return: Generated code completion.
+    :rtype: str
+    """
+    # Truncate the prompt if it is too long (a rough character-based guard;
+    # the tokenizer call below enforces the actual token limit)
+    MAX_INPUT_TOKENS = 2048
+    if len(prompt) > MAX_INPUT_TOKENS:
+        prompt = prompt[-MAX_INPUT_TOKENS:]
+
+    # Tokenize the prompt and move it to the device
     x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
-    print("Prompt shape: ", x.shape)  # just adding to see in the space logs in prod
+    logger.info("Prompt shape: %s", x.shape)
+
+    # Generate the code completion with the specified parameters
     set_seed(seed)
     y = model.generate(x,
                        max_new_tokens=max_new_tokens,
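
A caveat on the new truncation guard above: len(prompt) counts characters, not tokens, so the real budget is enforced only by the tokenizer call. The removed revision instead kept the end of over-long prompts, token-accurately, by setting the tokenizer's truncation side; the same one-liner still works with the tokenizer loaded in this file:

    # When truncating, keep the last MAX_INPUT_TOKENS tokens rather than the first
    tokenizer.truncation_side = "left"
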
@@ -71,75 +130,33 @@ def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9,
     )
     completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
     completion = completion[len(prompt):]
+
     return post_processing(prompt, completion)
 
-demo = gr.Blocks(
-    css=custom_css
-)
-
-with demo:
-    gr.Markdown(value=description)
-    with gr.Row():
-        input_col, settings_col = gr.Column(scale=6), gr.Column(scale=6)
-        with input_col:
-            code = gr.Code(lines=28, label='Input', value="def sieve_eratosthenes(n):")
-        with settings_col:
-            with gr.Accordion("Generation Settings", open=True):
-                max_new_tokens = gr.Slider(
-                    minimum=8,
-                    maximum=128,
-                    step=1,
-                    value=48,
-                    label="Max Tokens",
-                )
-                temperature = gr.Slider(
-                    minimum=0.1,
-                    maximum=2.5,
-                    step=0.1,
-                    value=0.2,
-                    label="Temperature",
-                )
-                repetition_penalty = gr.Slider(
-                    minimum=1.0,
-                    maximum=1.9,
-                    step=0.1,
-                    value=1.0,
-                    label="Repetition Penalty. 1.0 means no penalty.",
-                )
-                seed = gr.Slider(
-                    minimum=0,
-                    maximum=1000,
-                    step=1,
-                    label="Random Seed"
-                )
-                top_p = gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    step=0.1,
-                    value=0.9,
-                    label="Top P",
-                )
-                top_k = gr.Slider(
-                    minimum=1,
-                    maximum=64,
-                    step=1,
-                    value=4,
-                    label="Top K",
-                )
-                use_cache = gr.Checkbox(
-                    label="Use Cache",
-                    value=True
-                )
-
-    with gr.Row():
-        run = gr.Button(elem_id="orange-button", value="Generate")
-
-    # with gr.Row():
-    #     # _, middle_col_row_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
-    #     # with middle_col_row_2:
-    #     output = gr.HTML(label="Generated Code")
-
-    event = run.click(code_generation, [code, max_new_tokens, temperature, seed, top_p, top_k, use_cache, repetition_penalty], code, api_name="predict")
-
-demo.queue(max_size=40).launch()
+description = """
+### Falcoder
+
+Falcoder is a Falcon-7B model fine-tuned on code. It can be used to generate code completions given a prompt.
+
+### Text Generation
+
+Use the text generation tab to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.
+
+### Code Generation
+
+Use the code generation tab to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use the cache, and the repetition penalty.
+"""
+
+# Expose each generator in its own tab so its inputs match its signature
+text_demo = gr.Interface(
+    fn=generate_text,
+    inputs=[gr.Textbox(label="Prompt"),
+            gr.Slider(16, 512, value=128, step=1, label="Max Length"),
+            gr.Checkbox(value=True, label="Do Sample"),
+            gr.Slider(0.1, 2.5, value=0.7, step=0.1, label="Temperature"),
+            gr.Slider(1, 64, value=40, step=1, label="Top K"),
+            gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="Top P")],
+    outputs=gr.Textbox(label="Generated Text"),
+    description=description,
+)
+code_demo = gr.Interface(
+    fn=code_generation,
+    inputs=[gr.Textbox(label="Prompt"),
+            gr.Slider(8, 128, value=48, step=1, label="Max New Tokens")],
+    outputs=gr.Textbox(label="Generated Code"),
+)
+demo = gr.TabbedInterface([text_demo, code_demo],
+                          ["Text Generation", "Code Generation"],
+                          title="Falcoder", css=custom_css)
+
+# Launch Gradio interface
+demo.launch()
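For a quick sanity check of the committed app outside Gradio, the generation call can be reproduced with plain transformers, mirroring the repo and default parameters used above (do_sample=True is made explicit here so that temperature and top_p actually take effect); a minimal sketch:

    # Standalone smoke test reproducing code_generation's call
    from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    repo = "mrm8488/falcoder-7b"
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    set_seed(42)  # same default seed as code_generation
    x = tokenizer.encode("def sieve_eratosthenes(n):", return_tensors="pt")
    y = model.generate(x, max_new_tokens=48, do_sample=True, temperature=0.2,
                       top_p=0.9, repetition_penalty=1.0, use_cache=True)
    print(tokenizer.decode(y[0], skip_special_tokens=True))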