import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

description = """# <p style="text-align: center; color: white;"> 🎅 <span style='color: #ff75b3;'>SantaCoder:</span> Code Generation </p>
<span style='color: white;'>This is a demo to generate code with <a href="https://huggingface.co/bigcode/santacoder" style="color: #ff75b3;">SantaCoder</a>,
a 1.1B parameter model for code generation in Python, Java & JavaScript. The model can also do infilling, just specify where you would like the model to complete code
with the <span style='color: #ff75b3;'>&lt;FILL-HERE&gt;</span> token.</span>"""
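
# --- Model setup ----------------------------------------------------------
# The original top of this file (model/tokenizer loading) is missing, so the
# block below is a minimal reconstruction. It assumes the public
# bigcode/santacoder checkpoint and the FIM special tokens documented in its
# model card; adjust CHECKPOINT and the device selection to your environment.
CHECKPOINT = "bigcode/santacoder"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fill-in-the-middle (FIM) special tokens from the SantaCoder tokenizer.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"

# Heading prepended to every generated-code HTML block (assumed styling).
GENERATION_TITLE = "<p style='font-size: 16px; color: white;'>Generated code:</p>"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
# Separate tokenizer for FIM: left padding plus the FIM tokens registered as
# special tokens, so batched infilling prompts can be padded and decoded.
tokenizer_fim = AutoTokenizer.from_pretrained(CHECKPOINT, padding_side="left")
tokenizer_fim.add_special_tokens({
    "additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
    "pad_token": EOD,
})

model = AutoModelForCausalLM.from_pretrained(CHECKPOINT, trust_remote_code=True).to(device)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)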

def post_processing(prompt, completion):
    # Wrap the prompt (blue) and the completion (pink) in colored spans for HTML display.
    completion = "<span style='color: #ff75b3;'>" + completion + "</span>"
    prompt = "<span style='color: #727cd6;'>" + prompt + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prompt}{completion}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def post_processing_fim(prefix, middle, suffix):
    # Same as post_processing, but highlights the infilled middle in pink.
    prefix = "<span style='color: #727cd6;'>" + prefix + "</span>"
    middle = "<span style='color: #ff75b3;'>" + middle + "</span>"
    suffix = "<span style='color: #727cd6;'>" + suffix + "</span>"
    code_html = f"<br><hr><br><pre style='font-size: 12px'><code>{prefix}{middle}{suffix}</code></pre><br><hr>"
    return GENERATION_TITLE + code_html

def fim_generation(prompt, max_new_tokens, temperature):
    # Split the prompt at the <FILL-HERE> marker and infill the middle.
    prefix, suffix = prompt.split("<FILL-HERE>", 1)
    [middle] = infill((prefix, suffix), max_new_tokens, temperature)
    return post_processing_fim(prefix, middle, suffix)

def extract_fim_part(s: str):
    # Extract the infilled middle: the text between the FIM_MIDDLE token and
    # the end-of-document token.
    start = s.find(FIM_MIDDLE) + len(FIM_MIDDLE)
    stop = s.find(EOD, start)
    if stop == -1:  # str.find returns -1 (truthy) when EOD is absent, so `or` won't do
        stop = len(s)
    return s[start:stop]

def infill(prefix_suffix_tuples, max_new_tokens, temperature):
    # Accept a single (prefix, suffix) pair or a batch of them.
    if isinstance(prefix_suffix_tuples, tuple):
        prefix_suffix_tuples = [prefix_suffix_tuples]

    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer_fim(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer_fim.pad_token_id  # inputs were padded with tokenizer_fim's pad token
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    return [
        extract_fim_part(tokenizer_fim.decode(tensor, skip_special_tokens=False))
        for tensor in outputs
    ]
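
# Illustrative example (assuming the FIM tokens defined above): for the prompt
#   "def add(a, b):\n    <FILL-HERE>\n"
# fim_generation() builds the model input
#   "<fim-prefix>def add(a, b):\n    <fim-suffix>\n<fim-middle>"
# and extract_fim_part() keeps the model's continuation up to <|endoftext|>.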


def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42):
    set_seed(seed)  # make generation reproducible for a given seed

    if "<FILL-HERE>" in prompt:
        return fim_generation(prompt, max_new_tokens, temperature=temperature)
    else:
        completion = pipe(prompt, do_sample=True, top_p=0.95, temperature=temperature, max_new_tokens=max_new_tokens)[0]['generated_text']
        completion = completion[len(prompt):]
        return post_processing(prompt, completion)
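
# Note: pipe() returns the prompt concatenated with the completion, so the
# prompt prefix is stripped before highlighting. A hypothetical call such as
#   code_generation("def fib(n):", max_new_tokens=48)
# returns an HTML fragment suitable for the gr.HTML output below.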


with gr.Blocks() as demo:
    with gr.Row():
        _, column_2, _ = gr.Column(scale=1), gr.Column(scale=6), gr.Column(scale=1)
        with column_2:
            gr.Markdown(value=description)
            code = gr.Textbox(lines=5, label="Input code", value="def all_odd_elements(sequence):\n    \"\"\"Returns every odd element of the sequence.\"\"\"")
            
            with gr.Accordion("Advanced settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=8,
                    maximum=1024,
                    step=1,
                    value=48,
                    label="Number of tokens to generate",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                seed = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=42,  # matches code_generation's default seed
                    label="Random seed to use for generation",
                )
            run = gr.Button("Generate code")
            output = gr.HTML(label="Generated code")

    event = run.click(code_generation, [code, max_new_tokens, temperature, seed], output, api_name="predict")
    gr.HTML(label="Contact", value="<img src='https://huggingface.co/datasets/bigcode/admin/resolve/main/bigcode_contact.png' alt='contact' style='display: block; margin: auto; max-width: 800px;'>")

demo.launch()