import random

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Starcoder2 model and tokenizer
model_name = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
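# Optional sketch: on a CUDA machine the 3B model can be loaded in half precision to
# roughly halve memory use. This assumes torch with GPU support and the accelerate
# package are available in the environment:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16, device_map="auto"
# )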
# Load a sample dataset (you can change this to any dataset you prefer)
dataset = load_dataset("code_search_net", "python", split="train[:100]")
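# Note: code_search_net is a script-based dataset; depending on the installed `datasets`
# version, loading it may require trust_remote_code=True (environment assumption):
# dataset = load_dataset("code_search_net", "python", split="train[:100]", trust_remote_code=True)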
def generate_code(prompt, max_length=100):
    # Tokenize the prompt and generate a completion with greedy decoding
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=int(max_length))
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
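# Sketch of an alternative decoding setup inside generate_code: greedy decoding (the
# default above) can get repetitive on code, so sampling is a common tweak. The values
# below are illustrative, not tuned:
# outputs = model.generate(**inputs, max_length=int(max_length), do_sample=True,
#                          temperature=0.2, top_p=0.95)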
def get_random_sample():
    # Pick a uniformly random example from the loaded split
    random_sample = dataset[random.randint(0, len(dataset) - 1)]
    return random_sample['func_code_string']
with gr.Blocks() as demo:
    gr.Markdown("# Starcoder2 Code Generation Demo")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Prompt", lines=5)
            max_length = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Output Length")
            submit_btn = gr.Button("Generate Code")
            random_btn = gr.Button("Get Random Sample")
        with gr.Column():
            output_text = gr.Textbox(label="Generated Code", lines=10)

    submit_btn.click(generate_code, inputs=[input_text, max_length], outputs=output_text)
    random_btn.click(get_random_sample, outputs=input_text)

    gr.Markdown("""
    ## How to use:
    1. Enter a prompt in the input box or click 'Get Random Sample' to load a random code snippet.
    2. Adjust the max output length if needed.
    3. Click 'Generate Code' to see the model's output.
    """)
demo.launch()