chandrujobs committed
Commit 837c657 · verified · 1 Parent(s): 928e881

Update app.py

Files changed (1): app.py (+13 -9)
app.py CHANGED
@@ -8,23 +8,27 @@ def load_model():
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
     return tokenizer, model

-# Load the model and tokenizer (cached)
-with st.spinner("Loading model..."):
-    tokenizer, model = load_model()
+# Load model
+tokenizer, model = load_model()

-# Streamlit UI
-st.title("Code Generator with Hugging Face")
-st.write("Generate code snippets from natural language prompts!")
+st.title("Code Generator")
+st.write("Generate code snippets from natural language prompts using CodeT5!")

 prompt = st.text_area("Enter your coding task:", placeholder="Write a Python function to calculate factorial.")
-max_length = st.slider("Select maximum length of generated code:", min_value=20, max_value=200, value=50, step=10)
+max_length = st.slider("Maximum length of generated code:", 20, 300, 100)

 if st.button("Generate Code"):
     if prompt.strip():
         with st.spinner("Generating code..."):
             inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
-            outputs = model.generate(inputs.input_ids, max_length=max_length, num_beams=4, early_stopping=True)
+            outputs = model.generate(inputs.input_ids, max_length=max_length, num_beams=5, temperature=0.7, early_stopping=True)
+
+            st.write("### Debugging: Raw Model Output")
+            st.json(outputs.tolist())  # Debugging output
+
             generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            st.write("### Generated Code:")
+            st.code(generated_code, language="python")
     else:
         st.warning("Please enter a prompt!")
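For context, the hunk opens inside load_model(), whose first lines sit above the diff. A minimal sketch of that setup is shown below; the st.cache_resource decorator and the Salesforce/codet5-base checkpoint are assumptions (the removed comment only says the load is cached, and the UI text only mentions CodeT5), while model_name itself is defined above the hunk and is not shown here.

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

@st.cache_resource  # assumption: some Streamlit caching decorator, per the removed "(cached)" comment
def load_model():
    # assumption: a CodeT5 checkpoint; the real model_name is defined above the hunk
    model_name = "Salesforce/codet5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return tokenizer, model

One note on the new generate call: model.generate defaults to do_sample=False, so with num_beams=5 the decoding is deterministic beam search and the temperature=0.7 argument has no effect unless sampling is enabled.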