acecalisto3 committed on
Commit 98ae0b4 · verified · 1 Parent(s): 5b177d7

Update app.py

Files changed (1)
  1. app.py +71 -38
app.py CHANGED
@@ -2,89 +2,122 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 import torch
+import os
+from functools import lru_cache

 # Logging Setup
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

-# Model and tokenizer initialization
-model_name = "gpt2"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+# Cache the model and tokenizer loading
+@lru_cache(maxsize=None)
+def load_model_and_tokenizer(model_name):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        return model, tokenizer
+    except Exception as e:
+        logging.error(f"Error loading model and tokenizer: {e}")
+        return None, None
+
+# Initialize session state
+if 'generated_codes' not in st.session_state:
+    st.session_state.generated_codes = []

-# Set padding token
-tokenizer.pad_token = tokenizer.eos_token
+@st.cache_data
+def generate_code(task_description, max_length, temperature, num_return_sequences, model_name):
+    model, tokenizer = load_model_and_tokenizer(model_name)
+    if model is None or tokenizer is None:
+        return ["Error: Failed to load model and tokenizer."]

-def generate_code(task_description, max_length, temperature, num_return_sequences):
     try:
         logging.info(f"Generating code with input: {task_description}")
         prompt = f"Develop code for the following task: {task_description}"

         inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
-
-        # Calculate the maximum new tokens
         max_new_tokens = max_length - inputs.input_ids.shape[1]

-        # Generate the output
-        output = model.generate(
-            inputs.input_ids,
-            max_new_tokens=max_new_tokens,  # Use max_new_tokens instead of max_length
-            num_return_sequences=num_return_sequences,
-            temperature=temperature,
-            do_sample=True,
-            no_repeat_ngram_size=2,
-            top_k=50,
-            top_p=0.95,
-            pad_token_id=tokenizer.eos_token_id,  # Explicitly set pad_token_id
-            attention_mask=inputs.attention_mask,  # Pass the attention mask
-        )
+        with torch.no_grad():
+            output = model.generate(
+                inputs.input_ids,
+                max_new_tokens=max_new_tokens,
+                num_return_sequences=num_return_sequences,
+                temperature=temperature,
+                do_sample=True,
+                no_repeat_ngram_size=2,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=tokenizer.eos_token_id,
+                attention_mask=inputs.attention_mask,
+            )

-        # Decode the output
         codes = [tokenizer.decode(seq, skip_special_tokens=True) for seq in output]
-
         logging.info("Code generation completed successfully.")
         return codes
     except Exception as e:
         logging.error(f"Error generating code: {e}")
         return [f"Error generating code: {e}"]

+def save_code(code, file_name):
+    try:
+        with open(file_name, "w") as file:
+            file.write(code)
+        return True
+    except Exception as e:
+        logging.error(f"Error saving code: {e}")
+        return False
+
 def main():
     st.set_page_config(page_title="Advanced Code Generator", layout="wide")

     st.title("Advanced Code Generator")
     st.markdown("This application generates code based on the given task description using a text-generation model.")

+    # Model Selection
+    model_name = st.selectbox("Select Model", ["gpt2", "gpt2-medium", "gpt2-large"], help="Choose the model for code generation.")
+
     # Input Section
     st.header("Task Description")
     task_description = st.text_area("Describe the task for which you need code:", height=150)

     # Options Section
     st.header("Options")
-    max_length = st.slider("Max Length", min_value=50, max_value=1000, value=250, step=50, help="Maximum length of the generated code.")
-    temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1, help="Controls the creativity of the generated code.")
-    num_return_sequences = st.slider("Number of Sequences", min_value=1, max_value=5, value=1, step=1, help="Number of code snippets to generate.")
+    col1, col2, col3 = st.columns(3)
+    with col1:
+        max_length = st.slider("Max Length", min_value=50, max_value=1000, value=250, step=50, help="Maximum length of the generated code.")
+    with col2:
+        temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1, help="Controls the creativity of the generated code.")
+    with col3:
+        num_return_sequences = st.slider("Number of Sequences", min_value=1, max_value=5, value=1, step=1, help="Number of code snippets to generate.")

     # Generate Code Button
     if st.button("Generate Code"):
         if task_description.strip():
             with st.spinner("Generating code..."):
-                generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences)
+                st.session_state.generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences, model_name)
                 st.header("Generated Code")
-                for idx, code in enumerate(generated_codes):
-                    st.subheader(f"Generated Code {idx + 1}")
-                    st.code(code, language='python')
+                for idx, code in enumerate(st.session_state.generated_codes):
+                    with st.expander(f"Generated Code {idx + 1}", expanded=True):
+                        st.code(code, language='python')
         else:
             st.error("Please enter a task description.")

     # Save Code Section
-    if 'generated_codes' in locals() and generated_codes:
+    if st.session_state.generated_codes:
         st.header("Save Code")
-        selected_code_idx = st.selectbox("Select which code to save:", range(1, len(generated_codes) + 1)) - 1
-        file_name = st.text_input("Enter file name to save:", value="generated_code.py")
-        if st.button("Save", key="save_code"):
+        selected_code_idx = st.selectbox("Select which code to save:", range(1, len(st.session_state.generated_codes) + 1)) - 1
+        col1, col2 = st.columns(2)
+        with col1:
+            file_name = st.text_input("Enter file name to save:", value="generated_code.py")
+        with col2:
+            save_button = st.button("Save", key="save_code")
+
+        if save_button:
             if file_name:
-                with open(file_name, "w") as file:
-                    file.write(generated_codes[selected_code_idx])
-                st.success(f"Code saved to {file_name}")
+                if save_code(st.session_state.generated_codes[selected_code_idx], file_name):
+                    st.success(f"Code saved to {file_name}")
+                else:
+                    st.error("Failed to save the code. Please try again.")
             else:
                 st.error("Please enter a valid file name.")

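For reference, the core change in this commit is the lru_cache-based loader, so each model name is loaded from the Hub once per process and reused across reruns. The sketch below is illustrative only and is not part of the commit: it mirrors the load_model_and_tokenizer pattern from the diff in a standalone script (it does not import app.py or the Streamlit wiring), and it assumes transformers and torch are installed and the gpt2 weights can be downloaded.

# Illustrative sketch of the caching pattern introduced in this commit.
# Mirrors load_model_and_tokenizer from app.py; names are reused from the diff.
from functools import lru_cache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


@lru_cache(maxsize=None)
def load_model_and_tokenizer(model_name):
    # Each model name is loaded once per process; later calls reuse the same objects.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


if __name__ == "__main__":
    m1, t1 = load_model_and_tokenizer("gpt2")
    m2, t2 = load_model_and_tokenizer("gpt2")
    print(m1 is m2, t1 is t2)  # True True: the second call hits the cache

    prompt = "Develop code for the following task: reverse a string"
    inputs = t1(prompt, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = m1.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=32,
            do_sample=True,
            temperature=0.7,
            pad_token_id=t1.eos_token_id,
        )
    print(t1.decode(output[0], skip_special_tokens=True))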