File size: 5,110 Bytes
5b79ee5 6c06d27 5b79ee5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Load the StarCoder2 model and tokenizer.
# StarCoder2 is a decoder-only (causal) language model, so it has to be loaded
# through the causal-LM auto class; the seq2seq auto class rejects its config.
# The checkpoint id needs a size suffix; the 3B variant is the smallest one.
# NOTE(review): Streamlit re-executes the script on user interaction, so this
# load may be repeated -- consider caching it (e.g. st.cache_resource).
model_name = "bigcode/starcoder2-3b"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def code_complete(prompt, max_length=256):
    """Generate code completion suggestions for the given prompt.

    Args:
        prompt (str): The incomplete code snippet.
        max_length (int, optional): Token budget. The prompt is truncated to
            at most this many tokens, and at most this many new tokens are
            generated. Defaults to 256.

    Returns:
        list: One decoded string per sequence returned by ``model.generate``.
            An empty list is returned when ``prompt`` is empty or contains
            only whitespace.
    """
    # Nothing to complete: skip tokenisation and generation entirely.
    if not prompt or not prompt.strip():
        return []

    # Tokenize the input prompt. A single prompt needs no padding; padding it
    # up to ``max_length`` would consume the whole length budget before
    # generation even starts.
    inputs = tokenizer.encode_plus(
        prompt,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Generate code completion suggestions. ``max_new_tokens`` bounds only the
    # generated continuation; for decoder-only models ``max_length`` would
    # count the prompt tokens as well.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )

    # Decode every returned sequence back into text.
    return [
        tokenizer.decode(output, skip_special_tokens=True) for output in outputs
    ]
def code_fix(code, max_length=512):
    """Fix errors in the given code snippet.

    Args:
        code (str): The code snippet with errors.
        max_length (int, optional): Token budget. The input is truncated to
            at most this many tokens, and at most this many new tokens are
            generated. Defaults to 512.

    Returns:
        str: The decoded text of the first generated sequence, or an empty
            string when ``code`` is empty or contains only whitespace.
    """
    # Nothing to fix: skip tokenisation and generation entirely.
    if not code or not code.strip():
        return ""

    # Tokenize the input code. A single snippet needs no padding; padding it
    # up to ``max_length`` would consume the whole length budget before
    # generation even starts.
    inputs = tokenizer.encode_plus(
        code,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Generate corrected code. ``max_new_tokens`` bounds only the generated
    # part; for decoder-only models ``max_length`` would count the input
    # tokens as well.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )

    # Decode the first (and only requested) generated sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def text_to_code(text, max_length=256):
    """Generate code from a natural language description.

    Args:
        text (str): The natural language description of the code.
        max_length (int, optional): Token budget. The description is
            truncated to at most this many tokens, and at most this many new
            tokens are generated. Defaults to 256.

    Returns:
        str: The decoded text of the first generated sequence, or an empty
            string when ``text`` is empty or contains only whitespace.
    """
    # No description given: skip tokenisation and generation entirely.
    if not text or not text.strip():
        return ""

    # Tokenize the input text. A single description needs no padding; padding
    # it up to ``max_length`` would consume the whole length budget before
    # generation even starts.
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Generate code from the input text. ``max_new_tokens`` bounds only the
    # generated part; for decoder-only models ``max_length`` would count the
    # input tokens as well.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_length,
    )

    # Decode the first (and only requested) generated sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Create a Streamlit app
st.title("Codebot")
st.write("Welcome to the Codebot! You can use this app to generate code completions, fix errors in your code, or generate code from a natural language description.")

# Streamlit has no ``st.tab``; ``st.tabs`` takes the list of labels and
# returns one container per label, in the same order.
code_completion_tab, code_fixing_tab, text_to_code_tab = st.tabs(
    ["Code Completion", "Code Fixing", "Text-to-Code"]
)

# Tab for code completion
with code_completion_tab:
    st.write("Enter an incomplete code snippet:")
    prompt_input = st.text_input("Prompt:", value="")
    generate_button = st.button("Generate Completions")
    if generate_button:
        completions = code_complete(prompt_input)
        st.write("Code completions:")
        for i, completion in enumerate(completions):
            st.write(f"{i+1}. {completion}")

# Tab for code fixing
with code_fixing_tab:
    st.write("Enter a code snippet with errors:")
    code_input = st.text_area("Code:", height=300)
    fix_button = st.button("Fix Errors")
    if fix_button:
        corrected_code = code_fix(code_input)
        st.write("Corrected code:")
        st.code(corrected_code)

# Tab for text-to-code
with text_to_code_tab:
    st.write("Enter a natural language description of the code:")
    text_input = st.text_input("Description:", value="")
    generate_button = st.button("Generate Code")
    if generate_button:
        generated_code = text_to_code(text_input)
        st.write("Generated code:")
        st.code(generated_code)

# The app is started from the command line with ``streamlit run <this file>``.
# Streamlit exposes no ``st.run()`` function, so there is nothing to call here.