acecalisto3 committed
Commit 3ff0de1 · verified · 1 Parent(s): ea88422

Update app.py

Files changed (1)
  1. app.py +46 -11
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
 from transformers import pipeline
 import logging
 import torch
+import numpy as np
 
 # Logging Setup
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -20,9 +21,38 @@ def get_model_pipeline(model_name):
         logging.error(f"Error loading model pipeline: {e}")
         return None
 
+# Beam search implementation
+def beam_search(model, prompt, beam_width=3, max_length=20):
+    sequences = [[list(prompt), 0.0]]
+
+    for _ in range(max_length):
+        all_candidates = list()
+
+        for seq, score in sequences:
+            if len(seq) > 0 and seq[-1] == model.tokenizer.eos_token_id:
+                all_candidates.append((seq, score))
+                continue
+
+            inputs = model.tokenizer(seq, return_tensors='pt')
+            outputs = model.model(**inputs)
+            logits = outputs.logits[0, -1, :]
+            probabilities = torch.nn.functional.softmax(logits, dim=-1).detach().cpu().numpy()
+
+            candidates = np.argsort(probabilities)[-beam_width:]
+
+            for candidate in candidates:
+                new_seq = seq + [candidate]
+                new_score = score + np.log(probabilities[candidate])
+                all_candidates.append((new_seq, new_score))
+
+        ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
+        sequences = ordered[:beam_width]
+
+    return sequences[0][0]
+
 # Function to generate code
 @st.cache_data
-def generate_code(task_description, max_length, temperature, num_return_sequences, model_name):
+def generate_code(task_description, max_length, temperature, num_return_sequences, model_name, beam_width=3):
     code_pipeline = get_model_pipeline(model_name)
     if code_pipeline is None:
         return ["Error: Failed to load model pipeline."]
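Note: the hand-rolled loop above passes the running list of token IDs back through model.tokenizer(seq, ...), which expects text rather than IDs, and it scores one beam per forward pass. transformers already ships beam search in generate(). A minimal sketch of the equivalent call, assuming code_pipeline is a standard text-generation pipeline exposing .model and .tokenizer; the helper name beam_search_builtin is illustrative, not part of the commit:

import torch

def beam_search_builtin(code_pipeline, prompt, beam_width=3, max_new_tokens=200):
    # Encode the raw prompt string once; generate() handles the beam bookkeeping.
    tokenizer = code_pipeline.tokenizer
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = code_pipeline.model.generate(
            **inputs,
            num_beams=beam_width,                 # built-in beam search
            max_new_tokens=max_new_tokens,        # budget for newly generated tokens
            early_stopping=True,                  # stop once all beams have finished
            pad_token_id=tokenizer.eos_token_id,  # avoid warnings on models without a pad token
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

Besides avoiding the re-encoding issue, generate() batches all beams into a single forward pass per step instead of one pass per candidate.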
@@ -30,14 +60,17 @@ def generate_code(task_description, max_length, temperature, num_return_sequence
     try:
         logging.info(f"Generating code with input: {task_description}")
         prompt = f"Develop code for the following task: {task_description}"
-
-        outputs = code_pipeline(
-            prompt,
-            max_length=max_length,
-            num_return_sequences=num_return_sequences,
-            temperature=temperature,
-            truncation=True  # Added truncation
-        )
+
+        # Tokenize prompt for beam search
+        inputs = code_pipeline.tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs['input_ids'][0].tolist()
+
+        outputs = []
+        for _ in range(num_return_sequences):
+            output_tokens = beam_search(code_pipeline, input_ids, beam_width=beam_width, max_length=max_length)
+            output_text = code_pipeline.tokenizer.decode(output_tokens, skip_special_tokens=True)
+            outputs.append({'generated_text': output_text})
+
        codes = [output['generated_text'] for output in outputs]
 
        logging.info("Code generation completed successfully.")
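The same effect is available without the per-sequence Python loop: the text-generation pipeline forwards generation keyword arguments to generate(), so beam search can be enabled on the single pipeline call that this hunk removes. A sketch under that assumption; the wrapper name generate_code_builtin is illustrative:

def generate_code_builtin(code_pipeline, task_description, max_length, num_return_sequences, beam_width=3):
    prompt = f"Develop code for the following task: {task_description}"
    outputs = code_pipeline(
        prompt,
        max_length=max_length,
        num_beams=max(beam_width, num_return_sequences),  # num_return_sequences must not exceed num_beams
        num_return_sequences=num_return_sequences,
        do_sample=False,   # beam search is normally run without sampling; temperature is then unused
        truncation=True,
    )
    return [output["generated_text"] for output in outputs]

With num_return_sequences > 1, generate() returns the top beams directly, whereas the loop in the hunk reruns an identical deterministic beam search and therefore yields the same text num_return_sequences times.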
@@ -78,13 +111,15 @@ def main():
 
     # Options Section
     st.header("Options")
-    col1, col2, col3 = st.columns(3)
+    col1, col2, col3, col4 = st.columns(4)
     with col1:
         max_length = st.slider("Max Length", min_value=50, max_value=2048, value=250, step=50, help="Maximum length of the generated code.")
     with col2:
         temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1, help="Controls the creativity of the generated code.")
     with col3:
         num_return_sequences = st.slider("Number of Sequences", min_value=1, max_value=5, value=1, step=1, help="Number of code snippets to generate.")
+    with col4:
+        beam_width = st.slider("Beam Width", min_value=1, max_value=10, value=3, step=1, help="Beam width for beam search.")
 
     # Generate Code Button
     if st.button("Generate Code"):
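Since every slider move reruns the Streamlit script from the top, the option widgets could optionally be grouped in a form so that new values are only applied on submit. A sketch, assuming the other three sliders stay as in the hunk above; the form wrapper is not part of the commit:

import streamlit as st

with st.form("generation_options"):  # hypothetical form wrapper, not in the diff
    col1, col2, col3, col4 = st.columns(4)
    with col4:
        beam_width = st.slider("Beam Width", min_value=1, max_value=10, value=3, step=1)
    # ... max_length, temperature and num_return_sequences sliders as above ...
    submitted = st.form_submit_button("Generate Code")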
@@ -92,7 +127,7 @@ def main():
         # Clear previous generated codes
         st.session_state.generated_codes = []
         with st.spinner("Generating code..."):
-            st.session_state.generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences, model_name)
+            st.session_state.generated_codes = generate_code(task_description, max_length, temperature, num_return_sequences, model_name, beam_width)
         st.header("Generated Code")
         for idx, code in enumerate(st.session_state.generated_codes):
             with st.expander(f"Generated Code {idx + 1}", expanded=True):
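One interaction worth noting: generate_code is decorated with @st.cache_data, so results are memoized on the full argument tuple and the new beam_width value becomes part of the cache key. Stale entries can be dropped explicitly if needed; a small sketch, where the extra button is hypothetical and not in the diff:

if st.button("Clear cached generations"):  # hypothetical helper button
    generate_code.clear()                   # functions wrapped by st.cache_data expose .clear()
    st.session_state.generated_codes = []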
 