Update app.py
app.py CHANGED
@@ -2,89 +2,122 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import torch
+import os
+from functools import lru_cache
 
 # Logging Setup
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-#
-
-
-
+# Cache the model and tokenizer loading
+@lru_cache(maxsize=None)
+def load_model_and_tokenizer(model_name):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        return model, tokenizer
+    except Exception as e:
+        logging.error(f"Error loading model and tokenizer: {e}")
+        return None, None
+
+# Initialize session state
+if 'generated_codes' not in st.session_state:
+    st.session_state.generated_codes = []
 
-
-
+@st.cache_data
+def generate_code(task_description, max_length, temperature, num_return_sequences, model_name):
+    model, tokenizer = load_model_and_tokenizer(model_name)
+    if model is None or tokenizer is None:
+        return ["Error: Failed to load model and tokenizer."]
 
-def generate_code(task_description, max_length, temperature, num_return_sequences):
     try:
         logging.info(f"Generating code with input: {task_description}")
         prompt = f"Develop code for the following task: {task_description}"
 
         inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-
-        # Calculate the maximum new tokens
         max_new_tokens = max_length - inputs.input_ids.shape[1]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with torch.no_grad():
+            output = model.generate(
+                inputs.input_ids,
+                max_new_tokens=max_new_tokens,
+                num_return_sequences=num_return_sequences,
+                temperature=temperature,
+                do_sample=True,
+                no_repeat_ngram_size=2,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=tokenizer.eos_token_id,
+                attention_mask=inputs.attention_mask,
+            )
 
-        # Decode the output
         codes = [tokenizer.decode(seq, skip_special_tokens=True) for seq in output]
-
         logging.info("Code generation completed successfully.")
         return codes
     except Exception as e:
         logging.error(f"Error generating code: {e}")
         return [f"Error generating code: {e}"]
 
+def save_code(code, file_name):
+    try:
+        with open(file_name, "w") as file:
+            file.write(code)
+        return True
+    except Exception as e:
+        logging.error(f"Error saving code: {e}")
+        return False
+
 def main():
     st.set_page_config(page_title="Advanced Code Generator", layout="wide")
 
     st.title("Advanced Code Generator")
     st.markdown("This application generates code based on the given task description using a text-generation model.")
 
+    # Model Selection
+    model_name = st.selectbox("Select Model", ["gpt2", "gpt2-medium", "gpt2-large"], help="Choose the model for code generation.")
+
     # Input Section
     st.header("Task Description")
     task_description = st.text_area("Describe the task for which you need code:", height=150)
 
     # Options Section
     st.header("Options")
-
-
-
+    col1, col2, col3 = st.columns(3)
+    with col1:
+        max_length = st.slider("Max Length", min_value=50, max_value=1000, value=250, step=50, help="Maximum length of the generated code.")
+    with col2:
+        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1, help="Controls the creativity of the generated code.")
+    with col3:
+        num_return_sequences = st.slider("Number of Sequences", min_value=1, max_value=5, value=1, step=1, help="Number of code snippets to generate.")
 
     # Generate Code Button
     if st.button("Generate Code"):
         if task_description.strip():
             with st.spinner("Generating code..."):
-                generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences)
+                st.session_state.generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences, model_name)
             st.header("Generated Code")
-            for idx, code in enumerate(generated_codes):
-                st.
-
+            for idx, code in enumerate(st.session_state.generated_codes):
+                with st.expander(f"Generated Code {idx + 1}", expanded=True):
+                    st.code(code, language='python')
         else:
             st.error("Please enter a task description.")
 
     # Save Code Section
-    if
+    if st.session_state.generated_codes:
         st.header("Save Code")
-        selected_code_idx = st.selectbox("Select which code to save:", range(1, len(generated_codes) + 1)) - 1
-
-
+        selected_code_idx = st.selectbox("Select which code to save:", range(1, len(st.session_state.generated_codes) + 1)) - 1
+        col1, col2 = st.columns(2)
+        with col1:
+            file_name = st.text_input("Enter file name to save:", value="generated_code.py")
+        with col2:
+            save_button = st.button("Save", key="save_code")
+
+        if save_button:
             if file_name:
-
-
-
+                if save_code(st.session_state.generated_codes[selected_code_idx], file_name):
+                    st.success(f"Code saved to {file_name}")
+                else:
+                    st.error("Failed to save the code. Please try again.")
             else:
                 st.error("Please enter a valid file name.")
 
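For reference, a minimal standalone sketch of the generation path this commit wires up, runnable outside Streamlit. The prompt template and sampling parameters mirror the values in the new generate_code(); the max(1, ...) clamp on the token budget is an assumption added here, since the committed max_length - prompt_length arithmetic can go non-positive (and generation can fail) when the prompt is longer than the selected Max Length.

# Standalone sketch (not part of app.py): exercises the same generation
# path as the updated generate_code(), using the smallest selectable model.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name = "gpt2"  # the app also offers gpt2-medium and gpt2-large
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

prompt = "Develop code for the following task: reverse a string"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

# Assumed guard: the committed code subtracts the prompt length from
# max_length directly, which can yield a token budget <= 0.
max_new_tokens = max(1, 250 - inputs.input_ids.shape[1])

with torch.no_grad():
    output = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=inputs.attention_mask,
    )

print(tokenizer.decode(output[0], skip_special_tokens=True))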