acecalisto3 commited on
Commit
7a3c1ad
·
verified ·
1 Parent(s): a531914

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -28,13 +28,13 @@ def generate_code(task_description, max_length, temperature, num_return_sequence
28
  model, tokenizer = load_model_and_tokenizer(model_name)
29
  if model is None or tokenizer is None:
30
  return ["Error: Failed to load model and tokenizer."]
31
-
32
  try:
33
  logging.info(f"Generating code with input: {task_description}")
34
  prompt = f"Develop code for the following task: {task_description}"
35
 
36
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
37
- max_new_tokens = max_length - inputs.input_ids.shape[1]
38
 
39
  with torch.no_grad():
40
  output = model.generate(
@@ -73,7 +73,7 @@ def main():
73
  st.markdown("This application generates code based on the given task description using a text-generation model.")
74
 
75
  # Model Selection
76
- model_name = st.selectbox("Select Model", ["gpt2", "gpt2-medium", "gpt2-large"], help="Choose the model for code generation.")
77
 
78
  # Input Section
79
  st.header("Task Description")
 
28
  model, tokenizer = load_model_and_tokenizer(model_name)
29
  if model is None or tokenizer is None:
30
  return ["Error: Failed to load model and tokenizer."]
31
+
32
  try:
33
  logging.info(f"Generating code with input: {task_description}")
34
  prompt = f"Develop code for the following task: {task_description}"
35
 
36
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
37
+ max_new_tokens = max(max_length - inputs.input_ids.shape[1], 1) # Ensure max_new_tokens is at least 1
38
 
39
  with torch.no_grad():
40
  output = model.generate(
 
73
  st.markdown("This application generates code based on the given task description using a text-generation model.")
74
 
75
  # Model Selection
76
+ model_name = st.selectbox("Select Model", ["EleutherAI/gpt-neo-2.7B", "EleutherAI/gpt-j-6B"], help="Choose the model for code generation.")
77
 
78
  # Input Section
79
  st.header("Task Description")