import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# StarCoder2 is a decoder-only (causal) language model, so AutoModelForCausalLM
# is the appropriate auto class; the 3B checkpoint keeps memory requirements modest.
model_name = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
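
# Move the model to a GPU when one is available; generation also works on the
# CPU, just much more slowly. (For a long-running app, the loading above could
# also be wrapped in a function decorated with st.cache_resource so the model
# is not reloaded on every Streamlit rerun.)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)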


def code_complete(prompt, max_new_tokens=256, num_suggestions=3):
    """
    Generate code completion suggestions for the given prompt.

    Args:
        prompt (str): The incomplete code snippet.
        max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256.
        num_suggestions (int, optional): The number of completions to return. Defaults to 3.

    Returns:
        list: A list of code completion suggestions.
    """
    # A single prompt needs no padding (the tokenizer may not define a pad
    # token); very long inputs are truncated to keep generation tractable.
    inputs = tokenizer(prompt,
                       max_length=512,
                       truncation=True,
                       return_tensors="pt").to(model.device)

    # Sample several continuations; max_new_tokens bounds the completion itself
    # rather than the prompt plus completion. The end-of-sequence token is
    # reused as the pad token during generation.
    outputs = model.generate(inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_new_tokens=max_new_tokens,
                             do_sample=True,
                             num_return_sequences=num_suggestions,
                             pad_token_id=tokenizer.eos_token_id)

    suggestions = []
    for output in outputs:
        decoded_code = tokenizer.decode(output, skip_special_tokens=True)
        suggestions.append(decoded_code)

    return suggestions
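
# Example usage outside the UI (hypothetical prompt):
#   suggestions = code_complete("def fibonacci(n):\n    ")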


def code_fix(code):
    """
    Fix errors in the given code snippet.

    Args:
        code (str): The code snippet with errors.

    Returns:
        str: The corrected code snippet.
    """
    inputs = tokenizer(code,
                       max_length=512,
                       truncation=True,
                       return_tensors="pt").to(model.device)

    outputs = model.generate(inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_new_tokens=512,
                             pad_token_id=tokenizer.eos_token_id)

    corrected_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return corrected_code


def text_to_code(text, max_new_tokens=256):
    """
    Generate code from a natural language description.

    Args:
        text (str): The natural language description of the code.
        max_new_tokens (int, optional): The maximum number of tokens to generate. Defaults to 256.

    Returns:
        str: The generated code.
    """
    inputs = tokenizer(text,
                       max_length=512,
                       truncation=True,
                       return_tensors="pt").to(model.device)

    outputs = model.generate(inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_new_tokens=max_new_tokens,
                             pad_token_id=tokenizer.eos_token_id)

    generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_code


st.title("Codebot")
st.write("Welcome to Codebot! Use this app to generate code completions, fix errors in your code, or generate code from a natural language description.")

# Tabs are created as a group; each section of the UI lives in its own tab.
code_completion_tab, code_fixing_tab, text_to_code_tab = st.tabs(
    ["Code Completion", "Code Fixing", "Text-to-Code"]
)

with code_completion_tab:
    st.write("Enter an incomplete code snippet:")
    prompt_input = st.text_input("Prompt:", value="")
    generate_button = st.button("Generate Completions")

    if generate_button:
        completions = code_complete(prompt_input)
        st.write("Code completions:")
        for i, completion in enumerate(completions):
            st.write(f"Suggestion {i + 1}:")
            st.code(completion)

with code_fixing_tab:
    st.write("Enter a code snippet with errors:")
    code_input = st.text_area("Code:", height=300)
    fix_button = st.button("Fix Errors")

    if fix_button:
        corrected_code = code_fix(code_input)
        st.write("Corrected code:")
        st.code(corrected_code)

with text_to_code_tab:
    st.write("Enter a natural language description of the code:")
    text_input = st.text_input("Description:", value="")
    generate_button = st.button("Generate Code")

    if generate_button:
        generated_code = text_to_code(text_input)
        st.write("Generated code:")
        st.code(generated_code)

# Streamlit executes this script from top to bottom on every rerun, so no main
# entry point is needed; launch the app with `streamlit run <this_file>.py`.