|
import os
import re

import gradio as gr
import requests
from transformers import RobertaTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
|
|
|
# CodeT5 checkpoint used for generation.
model_name = "Salesforce/codet5-base"

# The codet5-base checkpoint ships a byte-level BPE vocabulary
# (vocab.json / merges.txt) and its model card loads it with RobertaTokenizer.
# T5Tokenizer expects a SentencePiece model file that this checkpoint does not
# provide, so it cannot be used here.
tokenizer = RobertaTokenizer.from_pretrained(model_name)

model = T5ForConditionalGeneration.from_pretrained(model_name)
|
|
|
|
|
def fetch_code_from_github(keyword, language="python", token=None, timeout=10):
    """Search GitHub code search for *keyword* and report the first hit's URL.

    Args:
        keyword: Free-text search term. ``None`` or a blank string is rejected
            without making a network request.
        language: Optional GitHub ``language:`` qualifier; omitted when empty.
        token: GitHub access token. Falls back to the ``GITHUB_TOKEN``
            environment variable. GitHub's code-search endpoint requires an
            authenticated request, so without a token the API will normally
            answer with an error status.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A human-readable status string for the UI: the ``html_url`` of the
        first search result, a "no results" message, or an error description.
    """
    if keyword is None or not str(keyword).strip():
        return "Please provide a keyword to search for."

    query = str(keyword).strip()
    if language:
        query += f" language:{language}"

    headers = {"Accept": "application/vnd.github.v3+json"}
    token = token or os.environ.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"

    try:
        # Let requests URL-encode the query so spaces, '&', '#', etc. in the
        # keyword cannot corrupt the request URL; the timeout keeps the UI
        # from hanging on a stalled connection.
        response = requests.get(
            "https://api.github.com/search/code",
            params={"q": query},
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as e:
        return str(e)

    if response.status_code != 200:
        return f"Error fetching code: {response.status_code}"

    try:
        items = response.json().get("items") or []
        if not items:
            return "No matching code found on GitHub."
        return f"Code fetched from: {items[0]['html_url']}"
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # Body was not JSON or did not have the expected search-result shape.
        return f"Unexpected response from GitHub: {e}"
|
|
|
|
|
def modify_or_generate_code(input_text, task="generate_code", max_length=512, num_beams=4):
    """Run the module-level CodeT5 model on *input_text* and decode the result.

    Args:
        input_text: Prompt/source text fed to the model. Inputs longer than
            *max_length* tokens are truncated by the tokenizer.
        task: Accepted for compatibility with the caller; it is not currently
            used to change the prompt or the generation settings.
        max_length: Maximum token length for both the encoded input and the
            generated output.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded generated text, a message asking for input when the prompt
        is empty, or an error message if tokenization/generation fails.
    """
    # A blank textbox (or a None cell, as in the fetch example row) would
    # otherwise run a full beam search on an empty prompt or fail inside the
    # tokenizer.
    if input_text is None or not str(input_text).strip():
        return "Please provide input text for code generation."

    try:
        inputs = tokenizer(
            input_text, return_tensors="pt", max_length=max_length, truncation=True
        )
        outputs = model.generate(
            **inputs, max_length=max_length, num_beams=num_beams, early_stopping=True
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        # Prefix so an error is not mistaken for generated code in the output box.
        return f"Error generating code: {e}"
|
|
|
|
|
def main_interface(task, input_text, fetch_keyword=None, language="python"):
    """Dispatch a UI request to the matching backend.

    ``"generate_code"`` sends *input_text* to the CodeT5 generator;
    ``"fetch_code"`` searches GitHub for *fetch_keyword* in *language*.
    Any other task value (including no selection) yields an error string.
    """
    if task == "generate_code":
        return modify_or_generate_code(input_text, task)
    if task == "fetch_code":
        return fetch_code_from_github(fetch_keyword, language)
    return "Invalid task selected."
|
|
|
|
|
# Gradio front end. The four input components map positionally onto
# main_interface(task, input_text, fetch_keyword, language).
interface = gr.Interface(
    fn=main_interface,
    title="CodeT5 Code Assistant",
    description="Generate or fetch code snippets using CodeT5 with extended features.",
    inputs=[
        gr.Radio(choices=["fetch_code", "generate_code"], label="Task"),
        gr.Textbox(
            lines=5,
            placeholder="Enter your input text or description here...",
            label="Input Text",
        ),
        gr.Textbox(
            placeholder="Enter keyword for fetching code...",
            label="Fetch Keyword (Optional)",
        ),
        gr.Textbox(
            value="python",
            placeholder="Language (default: python)",
            label="Programming Language",
        ),
    ],
    outputs="text",
    # Each example row supplies values for the inputs above, in the same order.
    examples=[
        ["generate_code", "Translate Python to Java: def add(a, b): return a + b", None, "python"],
        ["fetch_code", None, "factorial", "python"],
    ],
)


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    interface.launch()
|
|