acecalisto3 commited on
Commit
0717322
·
verified ·
1 Parent(s): 2a4e65c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -110
app.py CHANGED
@@ -19,6 +19,7 @@ from sentence_transformers import SentenceTransformer
19
  import faiss
20
  import numpy as np
21
  from PIL import Image
 
22
 
23
  # Configure logging
24
  logging.basicConfig(
@@ -36,26 +37,12 @@ DEFAULT_PORT = 7860
36
  MODEL_CACHE_DIR = Path("model_cache")
37
  TEMPLATE_DIR = Path("templates")
38
  TEMP_DIR = Path("temp")
39
- DATABASE_PATH = Path("code_database.json") #Path for our simple database
40
-
41
 
42
  # Ensure directories exist
43
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
44
  directory.mkdir(exist_ok=True, parents=True)
45
 
46
-
47
- @dataclass
48
- class Template:
49
- code: str
50
- description: str
51
- components: List[str]
52
- metadata: Dict[str, Any] = field(default_factory=dict)
53
- version: str = "1.0"
54
-
55
- class TemplateManager:
56
- # ... (TemplateManager remains the same) ...
57
-
58
-
59
  class RAGSystem:
60
  def __init__(self, model_name: str = "gpt2", device: str = "cuda" if torch.cuda.is_available() else "cpu", embedding_model="all-mpnet-base-v2"):
61
  try:
@@ -65,89 +52,91 @@ class RAGSystem:
65
  self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device)
66
  self.embedding_model = SentenceTransformer(embedding_model)
67
  self.load_database()
 
68
  except Exception as e:
69
- logger.error(f"Error loading language model or embedding model: {e}. Falling back to placeholder generation.")
70
  self.pipe = None
71
  self.embedding_model = None
72
  self.code_embeddings = None
73
 
74
-
75
  def load_database(self):
76
- """Loads or creates the code database"""
77
- if DATABASE_PATH.exists():
78
- try:
79
- with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
80
- self.database = json.load(f)
81
- self.code_embeddings = np.array(self.database['embeddings'])
82
- logger.info("Loaded code database from file")
83
- except (json.JSONDecodeError, KeyError) as e:
84
- logger.error(f"Error loading code database: {e}. Creating new database.")
85
- self.database = {'codes': [], 'embeddings': []}
86
- self.code_embeddings = np.array([])
87
 
88
- else:
89
- logger.info("Code database does not exist. Creating new database.")
90
- self.database = {'codes': [], 'embeddings': []}
91
- self.code_embeddings = np.array([])
92
 
93
- if self.embedding_model and len(self.database['codes']) != len(self.database['embeddings']):
94
- logger.warning("Mismatch between number of codes and embeddings, rebuilding embeddings")
95
- self.rebuild_embeddings()
96
- elif self.embedding_model is None:
97
- logger.warning("Embeddings are not supported in this context. ")
98
- #Index the embeddings for efficient searching
99
- if len(self.code_embeddings) > 0 and self.embedding_model:
100
- self.index = faiss.IndexFlatL2(self.code_embeddings.shape[1]) #L2 distance
101
- self.index.add(self.code_embeddings)
 
102
 
103
  def add_to_database(self, code: str):
104
  """Adds a code snippet to the database"""
105
  try:
106
- embedding = self.embedding_model.encode(code)
107
- self.database['codes'].append(code)
108
- self.database['embeddings'].append(embedding.tolist())
109
- self.code_embeddings = np.vstack((self.code_embeddings, embedding))
110
- self.index.add(np.array([embedding])) # update FAISS index
111
- self.save_database()
112
- logger.info(f"Added code snippet to database. Total size:{len(self.database['codes'])}")
113
  except Exception as e:
114
- logger.error(f"Error adding to database: {e}")
115
-
116
 
 
117
  def save_database(self):
118
- """Saves the database to a file"""
119
- try:
120
- with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
121
- json.dump(self.database, f, indent=2)
122
- logger.info(f"Saved database to {DATABASE_PATH}")
123
- except Exception as e:
124
- logger.error(f"Error saving database: {e}")
125
 
126
  def rebuild_embeddings(self):
127
- """rebuilds embeddings from the codes"""
128
- try:
129
- embeddings = self.embedding_model.encode(self.database['codes'])
130
- self.code_embeddings = embeddings
131
- self.database['embeddings'] = embeddings.tolist()
132
- self.index = faiss.IndexFlatL2(embeddings.shape[1]) #L2 distance
133
- self.index.add(embeddings)
134
- self.save_database()
135
- logger.info("Rebuilt and saved embeddings to the database")
136
- except Exception as e:
137
- logger.error(f"Error rebuilding embeddings: {e}")
138
-
139
 
140
  def retrieve_similar_code(self, description: str, top_k: int = 3) -> List[str]:
141
- """Retrieves similar code snippets from the database"""
142
- if self.embedding_model is None:
143
- return []
144
- try:
145
- embedding = self.embedding_model.encode(description)
146
- D, I = self.index.search(np.array([embedding]), top_k)
147
- return [self.database['codes'][i] for i in I[0]]
148
- except Exception as e:
149
- logger.error(f"Error retrieving similar code: {e}")
150
- return []
 
 
151
 
152
  def generate_code(self, description: str, template_code: str) -> str:
153
  retrieved_codes = self.retrieve_similar_code(description)
@@ -156,11 +145,13 @@ class RAGSystem:
156
  try:
157
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
158
  generated_code = generated_text.split("Generated Code:")[1].strip().split('```')[0]
 
159
  return generated_code
160
  except Exception as e:
161
  logger.error(f"Error generating code with language model: {e}. Returning template code.")
162
  return template_code
163
  else:
 
164
  return f"# Placeholder code generation. Description: {description}\n{template_code}"
165
 
166
  def generate_interface(self, screenshot: Optional[Image.Image], description: str) -> str:
@@ -173,16 +164,23 @@ class RAGSystem:
173
  try:
174
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
175
  generated_code = generated_text.split("```")[1].strip()
 
176
  return generated_code
177
  except Exception as e:
178
  logger.error(f"Error generating interface with language model: {e}. Returning placeholder.")
179
  return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
180
  else:
 
181
  return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
182
 
183
  class PreviewManager:
184
- # ... (PreviewManager remains largely the same) ...
 
185
 
 
 
 
 
186
 
187
  class GradioInterface:
188
  def __init__(self):
@@ -194,23 +192,18 @@ class GradioInterface:
194
 
195
  def _extract_components(self, code: str) -> List[str]:
196
  """Extract components from the code."""
197
- # This logic should analyze the code and extract components.
198
- # For example, you might look for function definitions, classes, etc.
199
  components = []
200
- # Simple regex to find function definitions
201
- function_matches = re.findall(r'def (\w+)', code)
202
- components.extend(function_matches)
203
-
204
- # Simple regex to find class definitions
205
  class_matches = re.findall(r'class (\w+)', code)
206
  components.extend(class_matches)
207
-
208
- # You can add more sophisticated logic here as needed
209
  return components
210
 
211
  def _get_template_choices(self) -> List[str]:
212
  """Get available template choices."""
213
- return list(self.template_manager.templates.keys())
 
 
214
 
215
  def launch(self, **kwargs):
216
  with gr.Blocks() as interface:
@@ -221,37 +214,35 @@ class GradioInterface:
221
  template_choice = gr.Dropdown(label="Select Template", choices=self._get_template_choices(), value=None)
222
  save_button = gr.Button("Save as Template")
223
 
224
- # Generate code button action
225
  generate_button.click(
226
  fn=self.generate_code,
227
- inputs=description_input,
228
  outputs=code_output
229
  )
230
 
231
- # Save template button action
232
  save_button.click(
233
  fn=self.save_template,
234
  inputs=[code_output, template_choice, description_input],
235
  outputs=code_output
236
  )
237
 
238
- # Additional UI elements can be added here
239
  gr.Markdown("### Preview")
240
  preview_output = gr.Textbox(label="Preview", interactive=False)
241
- self.preview_manager.update_preview(code_output) # Update preview with generated code
242
 
243
- # Update preview when code is generated
244
  generate_button.click(
245
  fn=lambda code: self.preview_manager.update_preview(code),
246
  inputs=code_output,
247
  outputs=preview_output
248
  )
249
 
 
250
  interface.launch(**kwargs)
251
 
252
- def generate_code(self, description: str) -> str:
253
- """Generate code based on the description."""
254
- template_code = "" # Placeholder for template code
 
255
  return self.rag_system.generate_code(description, template_code)
256
 
257
  def save_template(self, code: str, name: str, description: str) -> str:
@@ -261,25 +252,16 @@ class GradioInterface:
261
  template = Template(code=code, description=description, components=components)
262
  if self.template_manager.save_template(name, template):
263
  self.rag_system.add_to_database(code) # Add code to the database
 
264
  return f"✅ Template '{name}' saved successfully."
265
  else:
 
266
  return "❌ Failed to save template."
267
  except Exception as e:
268
  logger.error(f"Error saving template: {e}")
269
  return f"❌ Error saving template: {str(e)}"
270
 
271
-
272
  def main():
273
- # Configure logging
274
- logging.basicConfig(
275
- level=logging.INFO,
276
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
277
- handlers=[
278
- logging.StreamHandler(),
279
- logging.FileHandler('gradio_builder.log')
280
- ]
281
- )
282
- logger = logging.getLogger(__name__)
283
  logger.info("=== Application Startup ===")
284
 
285
  try:
 
19
  import faiss
20
  import numpy as np
21
  from PIL import Image
22
+ from templates import TemplateManager, Template # Import TemplateManager and Template
23
 
24
  # Configure logging
25
  logging.basicConfig(
 
37
  MODEL_CACHE_DIR = Path("model_cache")
38
  TEMPLATE_DIR = Path("templates")
39
  TEMP_DIR = Path("temp")
40
+ DATABASE_PATH = Path("code_database.json") # Path for our simple database
 
41
 
42
  # Ensure directories exist
43
  for directory in [MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR]:
44
  directory.mkdir(exist_ok=True, parents=True)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  class RAGSystem:
47
  def __init__(self, model_name: str = "gpt2", device: str = "cuda" if torch.cuda.is_available() else "cpu", embedding_model="all-mpnet-base-v2"):
48
  try:
 
52
  self.pipe = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device=self.device)
53
  self.embedding_model = SentenceTransformer(embedding_model)
54
  self.load_database()
55
+ logger.info("RAG system initialized successfully.")
56
  except Exception as e:
57
+ logger.error(f"Error loading language model or embedding model: {e}. Falling back to placeholder generation.")
58
  self.pipe = None
59
  self.embedding_model = None
60
  self.code_embeddings = None
61
 
 
62
  def load_database(self):
63
+ """Loads or creates the code database"""
64
+ if DATABASE_PATH.exists():
65
+ try:
66
+ with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
67
+ self.database = json.load(f)
68
+ self.code_embeddings = np.array(self.database['embeddings'])
69
+ logger.info("Loaded code database from file.")
70
+ except (json.JSONDecodeError, KeyError) as e:
71
+ logger.error(f"Error loading code database: {e}. Creating new database.")
72
+ self.database = {'codes': [], 'embeddings': []}
73
+ self.code_embeddings = np.array([])
74
 
75
+ else:
76
+ logger.info("Code database does not exist. Creating new database.")
77
+ self.database = {'codes': [], 'embeddings': []}
78
+ self.code_embeddings = np.array([])
79
 
80
+ if self.embedding_model and len(self.database['codes']) != len(self.database['embeddings']):
81
+ logger.warning("Mismatch between number of codes and embeddings, rebuilding embeddings.")
82
+ self.rebuild_embeddings()
83
+ elif self.embedding_model is None:
84
+ logger.warning("Embeddings are not supported in this context.")
85
+
86
+ # Index the embeddings for efficient searching
87
+ if len(self.code_embeddings) > 0 and self.embedding_model:
88
+ self.index = faiss.IndexFlatL2(self.code_embeddings.shape[1]) # L2 distance
89
+ self.index.add(self.code_embeddings)
90
 
91
  def add_to_database(self, code: str):
92
  """Adds a code snippet to the database"""
93
  try:
94
+ embedding = self.embedding_model.encode(code)
95
+ self.database['codes'].append(code)
96
+ self.database['embeddings'].append(embedding.tolist())
97
+ self.code_embeddings = np.vstack((self.code_embeddings, embedding))
98
+ self.index.add(np.array([embedding])) # update FAISS index
99
+ self.save_database()
100
+ logger.info(f"Added code snippet to database. Total size: {len(self.database['codes'])}.")
101
  except Exception as e:
102
+ logger.error(f"Error adding to database: {e}")
 
103
 
104
+ ```python
105
  def save_database(self):
106
+ """Saves the database to a file"""
107
+ try:
108
+ with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
109
+ json.dump(self.database, f, indent=2)
110
+ logger.info(f"Saved database to {DATABASE_PATH}.")
111
+ except Exception as e:
112
+ logger.error(f"Error saving database: {e}")
113
 
114
  def rebuild_embeddings(self):
115
+ """Rebuilds embeddings from the codes"""
116
+ try:
117
+ embeddings = self.embedding_model.encode(self.database['codes'])
118
+ self.code_embeddings = embeddings
119
+ self.database['embeddings'] = embeddings.tolist()
120
+ self.index = faiss.IndexFlatL2(embeddings.shape[1]) # L2 distance
121
+ self.index.add(embeddings)
122
+ self.save_database()
123
+ logger.info("Rebuilt and saved embeddings to the database.")
124
+ except Exception as e:
125
+ logger.error(f"Error rebuilding embeddings: {e}")
 
126
 
127
  def retrieve_similar_code(self, description: str, top_k: int = 3) -> List[str]:
128
+ """Retrieves similar code snippets from the database"""
129
+ if self.embedding_model is None:
130
+ logger.warning("Embedding model is not available. Cannot retrieve similar code.")
131
+ return []
132
+ try:
133
+ embedding = self.embedding_model.encode(description)
134
+ D, I = self.index.search(np.array([embedding]), top_k)
135
+ logger.info(f"Retrieved {top_k} similar code snippets for description: {description}.")
136
+ return [self.database['codes'][i] for i in I[0]]
137
+ except Exception as e:
138
+ logger.error(f"Error retrieving similar code: {e}")
139
+ return []
140
 
141
  def generate_code(self, description: str, template_code: str) -> str:
142
  retrieved_codes = self.retrieve_similar_code(description)
 
145
  try:
146
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
147
  generated_code = generated_text.split("Generated Code:")[1].strip().split('```')[0]
148
+ logger.info("Code generated successfully.")
149
  return generated_code
150
  except Exception as e:
151
  logger.error(f"Error generating code with language model: {e}. Returning template code.")
152
  return template_code
153
  else:
154
+ logger.warning("Text generation pipeline is not available. Returning placeholder code.")
155
  return f"# Placeholder code generation. Description: {description}\n{template_code}"
156
 
157
  def generate_interface(self, screenshot: Optional[Image.Image], description: str) -> str:
 
164
  try:
165
  generated_text = self.pipe(prompt, max_length=500, num_return_sequences=1)[0]['generated_text']
166
  generated_code = generated_text.split("```")[1].strip()
167
+ logger.info("Interface code generated successfully.")
168
  return generated_code
169
  except Exception as e:
170
  logger.error(f"Error generating interface with language model: {e}. Returning placeholder.")
171
  return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
172
  else:
173
+ logger.warning("Text generation pipeline is not available. Returning placeholder interface code.")
174
  return "import gradio as gr\n\ndemo = gr.Interface(fn=lambda x:x, inputs='text', outputs='text')\ndemo.launch()"
175
 
176
  class PreviewManager:
177
+ def __init__(self):
178
+ self.preview_code = ""
179
 
180
+ def update_preview(self, code: str):
181
+ """Update the preview with the generated code."""
182
+ self.preview_code = code
183
+ logger.info("Preview updated with new code.")
184
 
185
  class GradioInterface:
186
  def __init__(self):
 
192
 
193
  def _extract_components(self, code: str) -> List[str]:
194
  """Extract components from the code."""
 
 
195
  components = []
196
+ function_matches = re.findall(r'def (\w+)', code components.extend(function_matches)
 
 
 
 
197
  class_matches = re.findall(r'class (\w+)', code)
198
  components.extend(class_matches)
199
+ logger.info(f"Extracted components: {components}")
 
200
  return components
201
 
202
  def _get_template_choices(self) -> List[str]:
203
  """Get available template choices."""
204
+ choices = list(self.template_manager.templates.keys())
205
+ logger.info(f"Available template choices: {choices}")
206
+ return choices
207
 
208
  def launch(self, **kwargs):
209
  with gr.Blocks() as interface:
 
214
  template_choice = gr.Dropdown(label="Select Template", choices=self._get_template_choices(), value=None)
215
  save_button = gr.Button("Save as Template")
216
 
 
217
  generate_button.click(
218
  fn=self.generate_code,
219
+ inputs=[description_input, template_choice],
220
  outputs=code_output
221
  )
222
 
 
223
  save_button.click(
224
  fn=self.save_template,
225
  inputs=[code_output, template_choice, description_input],
226
  outputs=code_output
227
  )
228
 
 
229
  gr.Markdown("### Preview")
230
  preview_output = gr.Textbox(label="Preview", interactive=False)
231
+ self.preview_manager.update_preview(code_output)
232
 
 
233
  generate_button.click(
234
  fn=lambda code: self.preview_manager.update_preview(code),
235
  inputs=code_output,
236
  outputs=preview_output
237
  )
238
 
239
+ logger.info("Launching Gradio interface.")
240
  interface.launch(**kwargs)
241
 
242
+ def generate_code(self, description: str, template_choice: Optional[str]) -> str:
243
+ """Generate code based on the description and selected template."""
244
+ template_code = self.template_manager.get_template(template_choice) if template_choice else "" # Get template code if selected
245
+ logger.info(f"Generating code for description: {description} with template: {template_choice}")
246
  return self.rag_system.generate_code(description, template_code)
247
 
248
  def save_template(self, code: str, name: str, description: str) -> str:
 
252
  template = Template(code=code, description=description, components=components)
253
  if self.template_manager.save_template(name, template):
254
  self.rag_system.add_to_database(code) # Add code to the database
255
+ logger.info(f"Template '{name}' saved successfully.")
256
  return f"✅ Template '{name}' saved successfully."
257
  else:
258
+ logger.error("Failed to save template.")
259
  return "❌ Failed to save template."
260
  except Exception as e:
261
  logger.error(f"Error saving template: {e}")
262
  return f"❌ Error saving template: {str(e)}"
263
 
 
264
  def main():
 
 
 
 
 
 
 
 
 
 
265
  logger.info("=== Application Startup ===")
266
 
267
  try: