# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device to GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model, moving the model to the target device
tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
# Falcon tokenizers have no pad token; fall back to EOS for generate()
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b").to(device)
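
# Note (assumption, not from the original script): Falcon-7B in fp32 needs
# roughly 28 GB of memory; on a single GPU one would typically also pass
# torch_dtype=torch.float16 to from_pretrained to roughly halve that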

def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
    """
    Generates text completion given a prompt and specified parameters.

    :param prompt: Input prompt for text generation.
    :type prompt: str
    :param max_length: Maximum length of generated text.
    :type max_length: int
    :param do_sample: Whether to use sampling for text generation.
    :type do_sample: bool
    :param temperature: Sampling temperature for text generation.
    :type temperature: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :return: Generated text completion.
    :rtype: str
    """
    
    # Format the prompt: prepend a newline and ensure it contains a comma
    formatted_prompt = "\n" + prompt
    if ',' not in prompt:
        formatted_prompt += ','
    
    # Tokenize the prompt and move tensors to the target device
    inputs = tokenizer(formatted_prompt, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    
    # Generate a completion with the specified decoding parameters
    out = model.generate(**inputs, max_length=max_length, do_sample=do_sample,
                         temperature=temperature, no_repeat_ngram_size=3,
                         top_k=top_k, top_p=top_p)
    # Decode, skipping special tokens such as the EOS marker
    clean_output = tokenizer.decode(out[0], skip_special_tokens=True)
    
    # Log generated text completion
    logger.info("Text generated: %s", clean_output)
    
    return clean_output
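
# Illustrative call with hypothetical values (not part of the original app):
#   generate_text("def quicksort(arr):", max_length=128, do_sample=True,
#                 temperature=0.2, top_k=50, top_p=0.9)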

# Define Gradio interface
custom_css = """
.gradio-container {
    background-color: #0D1525; 
    color:white
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters{
    border: none !important;
}
"""

def post_processing(prompt, completion):
    """
    Formats generated text completion for display.

    :param prompt: Input prompt for text generation.
    :type prompt: str
    :param completion: Generated text completion.
    :type completion: str
    :return: Formatted text completion.
    :rtype: str
    """
    return prompt + completion

def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    """
    Generates code completion given a prompt and specified parameters.

    :param prompt: Input prompt for code generation.
    :type prompt: str
    :param max_new_tokens: Maximum number of tokens to generate.
    :type max_new_tokens: int
    :param temperature: Sampling temperature for code generation.
    :type temperature: float
    :param seed: Random seed for code generation.
    :type seed: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param use_cache: Whether to use cache for code generation.
    :type use_cache: bool
    :param repetition_penalty: Value for repetition penalty.
    :type repetition_penalty: float
    :return: Generated code completion.
    :rtype: str
    """
    
    # Rough character-level guard against very long prompts; the tokenizer
    # call below enforces the actual token limit via truncation
    MAX_INPUT_TOKENS = 2048
    if len(prompt) > MAX_INPUT_TOKENS:
        prompt = prompt[-MAX_INPUT_TOKENS:]
    
    # Tokenize prompt and move to device
    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    logger.info("Prompt shape: %s", x.shape)
    
    # Generate a code completion; do_sample=True is required for the
    # temperature/top_p/top_k parameters to take effect
    set_seed(seed)
    y = model.generate(x,
                       max_new_tokens=max_new_tokens,
                       do_sample=True,
                       temperature=temperature,
                       pad_token_id=tokenizer.pad_token_id,
                       eos_token_id=tokenizer.eos_token_id,
                       top_p=top_p,
                       top_k=top_k,
                       use_cache=use_cache,
                       repetition_penalty=repetition_penalty)
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    completion = completion[len(prompt):]
    
    return post_processing(prompt, completion)
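
# Example invocation with hypothetical values (not part of the original app):
#   code_generation("# A function that reverses a string\n", max_new_tokens=64)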

description = """
### Falcoder

Falcoder is a Falcon-7B model fine-tuned for code generation. It can be used to generate code completions from a prompt.

### Text Generation

Use the text generation section to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.

### Code Generation

Use the code generation section to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use cache, and the repetition penalty.
"""

# gr.Interface wraps a single function, so build one interface per task and
# combine them as tabs; the input widgets (ranges/defaults are illustrative)
# map one-to-one onto each function's parameters, and the deprecated
# Gradio 2.x theme/layout kwargs are dropped
text_demo = gr.Interface(
    generate_text,
    ["textbox", gr.Slider(16, 512, value=128, step=1, label="Max length"),
     gr.Checkbox(value=True, label="Sample"), gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
     gr.Slider(0, 100, value=50, step=1, label="Top-k"), gr.Slider(0.0, 1.0, value=0.9, label="Top-p")],
    "textbox", description=description)
code_demo = gr.Interface(
    code_generation,
    ["textbox", gr.Slider(16, 512, value=128, step=1, label="Max new tokens"),
     gr.Slider(0.1, 2.0, value=0.2, label="Temperature"), gr.Number(value=42, precision=0, label="Seed"),
     gr.Slider(0.0, 1.0, value=0.9, label="Top-p"), gr.Slider(0, 100, value=50, step=1, label="Top-k"),
     gr.Checkbox(value=True, label="Use cache"), gr.Slider(1.0, 2.0, value=1.0, label="Repetition penalty")],
    "textbox", description=description)
demo = gr.TabbedInterface([text_demo, code_demo], tab_names=["Text Generation", "Code Generation"],
                          title="Falcoder", css=custom_css)

# Launch Gradio interface
demo.launch()
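
# Note: launch() serves the UI locally by default; gradio's share=True flag
# (optional, not used here) would create a temporary public URL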