Artificial-superintelligence committed on
Commit
d4a5735
·
verified ·
1 Parent(s): 7d9ae2b

Update app.py

Files changed (1)
  1. app.py +557 -209
app.py CHANGED
@@ -3,256 +3,604 @@ import google.generativeai as genai
3
  import requests
4
  import subprocess
5
  import os
6
- import pylint
7
  import pandas as pd
 
8
  from sklearn.model_selection import train_test_split
9
- from sklearn.ensemble import RandomForestClassifier
10
- import git
11
- import spacy
12
- from spacy.lang.en import English
13
- import boto3
14
- import unittest
15
 
16
  # Configure the Gemini API
17
  genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])
18
 
19
  # Create the model with optimized parameters and enhanced system instructions
20
  generation_config = {
21
- "temperature": 0.6, # Lower temperature for more deterministic responses
22
- "top_p": 0.8, # Adjusted for better diversity
23
- "top_k": 30, # Increased for more diverse tokens
24
- "max_output_tokens": 16384, # Increased for longer responses
25
  }
26
 
27
  model = genai.GenerativeModel(
28
  model_name="gemini-1.5-pro",
29
  generation_config=generation_config,
30
  system_instruction="""
31
- You are Ath, a highly knowledgeable code assistant. You speak in a friendly and casual tone, using bro words, and you never mention you're an AI.
32
- Your responses should contain optimized, secure, and high-quality code only, without explanations. You are designed to provide accurate, efficient, and cutting-edge code solutions.
33
  """
34
  )
35
  chat_session = model.start_chat(history=[])
36
 
37
- def generate_response(user_input):
38
- """Generate a response from the AI model."""
39
  try:
40
  response = chat_session.send_message(user_input)
41
  return response.text
42
  except Exception as e:
43
- return f"Error: {e}"
44
-
45
- def optimize_code(code):
46
- """Optimize the generated code using static analysis tools."""
47
- with open("temp_code.py", "w") as file:
48
- file.write(code)
49
- result = subprocess.run(["pylint", "temp_code.py"], capture_output=True, text=True)
50
- os.remove("temp_code.py")
51
- return code
52
-
53
- def fetch_from_github(query):
54
- """Fetch code snippets from GitHub."""
55
- # Placeholder for fetching code snippets from GitHub
56
- return ""
57
-
58
- def interact_with_api(api_url):
59
- """Interact with external APIs."""
60
- response = requests.get(api_url)
61
- return response.json()
62
-
63
- def train_ml_model(code_data):
64
- """Train a machine learning model to predict code improvements."""
65
- df = pd.DataFrame(code_data)
66
- X = df.drop('target', axis=1)
67
- y = df['target']
68
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
69
- model = RandomForestClassifier()
70
- model.fit(X_train, y_train)
71
- return model
72
-
73
- def handle_error(error):
74
- """Handle errors and log them."""
75
- st.error(f"An error occurred: {error}")
76
-
77
- def initialize_git_repo(repo_path):
78
- """Initialize or check the existence of a Git repository."""
79
- if not os.path.exists(repo_path):
80
- os.makedirs(repo_path)
81
- if not os.path.exists(os.path.join(repo_path, '.git')):
82
- repo = git.Repo.init(repo_path)
83
  else:
84
- repo = git.Repo(repo_path)
85
- return repo
86
-
87
- def integrate_with_git(repo_path, code):
88
- """Integrate the generated code with a Git repository."""
89
- repo = initialize_git_repo(repo_path)
90
- with open(os.path.join(repo_path, "generated_code.py"), "w") as file:
91
- file.write(code)
92
- repo.index.add(["generated_code.py"])
93
- repo.index.commit("Added generated code")
94
-
95
- def process_user_input(user_input):
96
- """Process user input using advanced natural language processing."""
97
- nlp = English()
98
- doc = nlp(user_input)
99
- return doc
100
-
101
- def interact_with_cloud_services(service_name, action, params):
102
- """Interact with cloud services using boto3."""
103
- client = boto3.client(service_name)
104
- response = getattr(client, action)(**params)
105
- return response
106
-
107
- def run_tests():
108
- """Run automated tests using unittest."""
109
- # Ensure the tests directory is importable
110
- tests_dir = os.path.join(os.getcwd(), 'tests')
111
- if not os.path.exists(tests_dir):
112
- os.makedirs(tests_dir)
113
- init_file = os.path.join(tests_dir, '__init__.py')
114
- if not os.path.exists(init_file):
115
- with open(init_file, 'w') as f:
116
- f.write('')
117
 
118
- test_suite = unittest.TestLoader().discover(tests_dir)
119
- test_runner = unittest.TextTestRunner()
120
- test_result = test_runner.run(test_suite)
121
- return test_result
122
 
123
- # Streamlit UI setup
124
- st.set_page_config(page_title="Sleek AI Code Assistant", page_icon="💻", layout="wide")
125
 
126
- st.markdown("""
127
- <style>
128
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap');
129
 
130
- body {
131
- font-family: 'Inter', sans-serif;
132
- background-color: #f0f4f8;
133
- color: #1a202c;
134
- }
135
- .stApp {
136
- max-width: 1000px;
137
- margin: 0 auto;
138
- padding: 2rem;
139
- }
140
- .main-container {
141
- background: #ffffff;
142
- border-radius: 16px;
143
- padding: 2rem;
144
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
145
- }
146
- h1 {
147
- font-size: 2.5rem;
148
- font-weight: 700;
149
- color: #2d3748;
150
- text-align: center;
151
- margin-bottom: 1rem;
152
- }
153
- .subtitle {
154
- font-size: 1.1rem;
155
- text-align: center;
156
- color: #4a5568;
157
- margin-bottom: 2rem;
158
- }
159
- .stTextArea textarea {
160
- border: 2px solid #e2e8f0;
161
- border-radius: 8px;
162
- font-size: 1rem;
163
- padding: 0.75rem;
164
- transition: all 0.3s ease;
165
- }
166
- .stTextArea textarea:focus {
167
- border-color: #4299e1;
168
- box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.5);
169
- }
170
- .stButton button {
171
- background-color: #4299e1;
172
- color: white;
173
- border: none;
174
- border-radius: 8px;
175
- font-size: 1.1rem;
176
- font-weight: 600;
177
- padding: 0.75rem 2rem;
178
- transition: all 0.3s ease;
179
- width: 100%;
180
- }
181
- .stButton button:hover {
182
- background-color: #3182ce;
183
- }
184
- .output-container {
185
- background: #f7fafc;
186
- border-radius: 8px;
187
- padding: 1rem;
188
- margin-top: 2rem;
189
- }
190
- .code-block {
191
- background-color: #2d3748;
192
- color: #e2e8f0;
193
- font-family: 'Fira Code', monospace;
194
- font-size: 0.9rem;
195
- border-radius: 8px;
196
- padding: 1rem;
197
- margin-top: 1rem;
198
- overflow-x: auto;
199
- }
200
- .stAlert {
201
- background-color: #ebf8ff;
202
- color: #2b6cb0;
203
- border-radius: 8px;
204
- border: none;
205
- padding: 0.75rem 1rem;
206
  }
207
- .stSpinner {
208
- color: #4299e1;
209
  }
210
  </style>
211
  """, unsafe_allow_html=True)
212
 
213
  st.markdown('<div class="main-container">', unsafe_allow_html=True)
214
- st.title("💻 Sleek AI Code Assistant")
215
- st.markdown('<p class="subtitle">Powered by Google Gemini</p>', unsafe_allow_html=True)
216
 
217
- prompt = st.text_area("What code can I help you with today?", height=120)
218
 
219
- if st.button("Generate Code"):
220
  if prompt.strip() == "":
221
- st.error("Please enter a valid prompt.")
222
  else:
223
- with st.spinner("Generating code..."):
224
- try:
225
- processed_input = process_user_input(prompt)
226
- completed_text = generate_response(processed_input.text)
227
- if "Error" in completed_text:
228
- handle_error(completed_text)
229
  else:
230
- optimized_code = optimize_code(completed_text)
231
- st.success("Code generated and optimized successfully!")
232
-
233
- st.markdown('<div class="output-container">', unsafe_allow_html=True)
234
- st.markdown('<div class="code-block">', unsafe_allow_html=True)
235
- st.code(optimized_code)
236
- st.markdown('</div>', unsafe_allow_html=True)
237
- st.markdown('</div>', unsafe_allow_html=True)
238
-
239
- # Integrate with Git
240
- repo_path = "./repo" # Replace with your repository path
241
- integrate_with_git(repo_path, optimized_code)
242
-
243
- # Run automated tests
244
- test_result = run_tests()
245
- if test_result.wasSuccessful():
246
- st.success("All tests passed successfully!")
247
- else:
248
- st.error("Some tests failed. Please check the code.")
249
- except Exception as e:
250
- handle_error(e)
251
 
252
  st.markdown("""
253
  <div style='text-align: center; margin-top: 2rem; color: #4a5568;'>
254
- Created with ❤️ by Your Sleek AI Code Assistant
255
  </div>
256
  """, unsafe_allow_html=True)
257
 
258
- st.markdown('</div>', unsafe_allow_html=True)
3
  import requests
4
  import subprocess
5
  import os
 
6
  import pandas as pd
7
+ import numpy as np
8
  from sklearn.model_selection import train_test_split
9
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
10
+ from sklearn.svm import SVC
11
+ from sklearn.neural_network import MLPClassifier
12
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.optim as optim
16
+ from transformers import AutoTokenizer, AutoModel, pipeline, GPT2LMHeadModel, GPT2Tokenizer
17
+ import ast
18
+ import networkx as nx
19
+ import matplotlib.pyplot as plt
20
+ import re
21
+ import javalang
22
+ import clang.cindex
23
+ import radon.metrics as radon_metrics
24
+ import radon.complexity as radon_complexity
25
+ import black
26
+ import isort
27
+ import autopep8
28
+ from typing import List, Dict, Any
29
+ import joblib
30
+ from fastapi import FastAPI
31
+ from pydantic import BaseModel
32
+ import uvicorn
33
 
34
  # Configure the Gemini API
35
  genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])
36
 
37
  # Create the model with optimized parameters and enhanced system instructions
38
  generation_config = {
39
+ "temperature": 0.7,
40
+ "top_p": 0.9,
41
+ "top_k": 40,
42
+ "max_output_tokens": 32768,
43
  }
44
 
45
  model = genai.GenerativeModel(
46
  model_name="gemini-1.5-pro",
47
  generation_config=generation_config,
48
  system_instruction="""
49
+ You are Ath, an extremely advanced code assistant with deep expertise in AI, machine learning, software engineering, and multiple programming languages. You provide cutting-edge, optimized, and secure code solutions across various domains. Use your vast knowledge to generate high-quality code, perform advanced analyses, and offer insightful optimizations. Adapt your language and explanations based on the user's expertise level. Incorporate the latest advancements in AI and software development to provide state-of-the-art solutions.
 
50
  """
51
  )
52
  chat_session = model.start_chat(history=[])
53
 
54
+ # Load pre-trained models for code understanding and generation
55
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
56
+ codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
57
+ code_generation_model = pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B")
58
+
59
+ # Load GPT-2 for more advanced text generation
60
+ gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large")
61
+ gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
62
+
63
+ class AdvancedCodeImprovement(nn.Module):
64
+ def __init__(self, input_dim):
65
+ super(AdvancedCodeImprovement, self).__init__()
66
+ self.lstm = nn.LSTM(input_dim, 512, num_layers=2, batch_first=True, bidirectional=True)
67
+ self.attention = nn.MultiheadAttention(1024, 8)
68
+ self.fc1 = nn.Linear(1024, 512)
69
+ self.fc2 = nn.Linear(512, 256)
70
+ self.fc3 = nn.Linear(256, 128)
71
+ self.fc4 = nn.Linear(128, 64)
72
+ self.fc5 = nn.Linear(64, 32)
73
+ self.fc6 = nn.Linear(32, 8) # Extended classification: style, efficiency, security, maintainability, scalability, readability, testability, modularity
74
+
75
+ def forward(self, x):
76
+ x, _ = self.lstm(x)
77
+ x, _ = self.attention(x, x, x)
78
+ x = x.mean(dim=1) # Global average pooling
79
+ x = torch.relu(self.fc1(x))
80
+ x = torch.relu(self.fc2(x))
81
+ x = torch.relu(self.fc3(x))
82
+ x = torch.relu(self.fc4(x))
83
+ x = torch.relu(self.fc5(x))
84
+ return torch.sigmoid(self.fc6(x))
85
+
86
+ code_improvement_model = AdvancedCodeImprovement(768) # 768 is BERT's output dimension
87
+ optimizer = optim.Adam(code_improvement_model.parameters())
88
+ criterion = nn.BCELoss()
89
+
90
+ # Load pre-trained code improvement model
91
+ if os.path.exists("code_improvement_model.pth"):
92
+ code_improvement_model.load_state_dict(torch.load("code_improvement_model.pth"))
93
+ code_improvement_model.eval()
94
+
95
+ def generate_response(user_input: str) -> str:
96
  try:
97
  response = chat_session.send_message(user_input)
98
  return response.text
99
  except Exception as e:
100
+ return f"Error in generating response: {str(e)}"
101
+
102
+ def detect_language(code: str) -> str:
103
+ # Enhanced language detection with more specific patterns
104
+ patterns = {
105
+ 'python': r'\b(def|class|import|from|if\s+__name__\s*==\s*[\'"]__main__[\'"])\b',
106
+ 'javascript': r'\b(function|var|let|const|=>|document\.getElementById)\b',
107
+ 'java': r'\b(public\s+class|private|protected|package|import\s+java)\b',
108
+ 'c++': r'\b(#include\s*<|using\s+namespace|template\s*<|std::)',
109
+ 'ruby': r'\b(def|class|module|require|attr_accessor)\b',
110
+ 'go': r'\b(func|package\s+main|import\s*\(|fmt\.Println)\b',
111
+ 'rust': r'\b(fn|let\s+mut|impl|pub\s+struct|use\s+std)\b',
112
+ 'typescript': r'\b(interface|type|namespace|readonly|abstract\s+class)\b',
113
+ }
114
+
115
+ for lang, pattern in patterns.items():
116
+ if re.search(pattern, code):
117
+ return lang
118
+ return 'unknown'
119
+
120
+ def validate_and_fix_code(code: str, language: str) -> tuple[str, str]:
121
+ if language == 'python':
122
+ try:
123
+ fixed_code = autopep8.fix_code(code)
124
+ fixed_code = isort.code(fixed_code)  # isort >= 5 API; SortImports was removed in isort 5
125
+ fixed_code = black.format_str(fixed_code, mode=black.FileMode())
126
+ return fixed_code, ""
127
+ except Exception as e:
128
+ return code, f"Error in fixing Python code: {str(e)}"
129
+ elif language == 'javascript':
130
+ # Use a JS beautifier (placeholder)
131
+ return code, ""
132
+ elif language == 'java':
133
+ # Use a Java formatter (placeholder)
134
+ return code, ""
135
+ elif language == 'c++':
136
+ # Use a C++ formatter (placeholder)
137
+ return code, ""
138
  else:
139
+ return code, ""
140
+
141
+ def optimize_code(code: str) -> tuple[str, str]:
142
+ language = detect_language(code)
143
+ fixed_code, fix_error = validate_and_fix_code(code, language)
144
 
145
+ if fix_error:
146
+ return fixed_code, fix_error
147
 
148
+ if language == 'python':
149
+ try:
150
+ tree = ast.parse(fixed_code)
151
+ # Perform advanced Python-specific optimizations
152
+ optimizer = PythonCodeOptimizer()
153
+ optimized_tree = optimizer.visit(tree)
154
+ optimized_code = ast.unparse(optimized_tree)
155
+ except SyntaxError as e:
156
+ return fixed_code, f"SyntaxError: {str(e)}"
157
+ elif language == 'java':
158
+ try:
159
+ tree = javalang.parse.parse(fixed_code)
160
+ # Perform Java-specific optimizations
161
+ optimizer = JavaCodeOptimizer()
162
+ optimized_code = optimizer.optimize(tree)
163
+ except javalang.parser.JavaSyntaxError as e:
164
+ return fixed_code, f"JavaSyntaxError: {str(e)}"
165
+ elif language == 'c++':
166
+ try:
167
+ index = clang.cindex.Index.create()
168
+ tu = index.parse('temp.cpp', args=['-std=c++14'], unsaved_files=[('temp.cpp', fixed_code)])
169
+ # Perform C++-specific optimizations
170
+ optimizer = CppCodeOptimizer()
171
+ optimized_code = optimizer.optimize(tu)
172
+ except Exception as e:
173
+ return fixed_code, f"C++ Parsing Error: {str(e)}"
174
+ else:
175
+ optimized_code = fixed_code # For unsupported languages, return the fixed code
176
 
177
+ # Run language-specific linter
178
+ lint_results = run_linter(optimized_code, language)
 
179
 
180
+ return optimized_code, lint_results
181
+
182
+ def run_linter(code: str, language: str) -> str:
183
+ if language == 'python':
184
+ with open("temp_code.py", "w") as file:
185
+ file.write(code)
186
+ result = subprocess.run(["pylint", "temp_code.py"], capture_output=True, text=True)
187
+ os.remove("temp_code.py")
188
+ return result.stdout
189
+ elif language == 'javascript':
190
+ # Run ESLint (placeholder)
191
+ return "JavaScript linting not implemented"
192
+ elif language == 'java':
193
+ # Run CheckStyle (placeholder)
194
+ return "Java linting not implemented"
195
+ elif language == 'c++':
196
+ # Run cppcheck (placeholder)
197
+ return "C++ linting not implemented"
198
+ else:
199
+ return "Linting not available for the detected language"
200
+
201
+ def fetch_from_github(query: str) -> List[Dict[str, Any]]:
202
+ headers = {"Authorization": f"token {st.secrets['GITHUB_TOKEN']}"}
203
+ response = requests.get(f"https://api.github.com/search/code?q={query}", headers=headers)
204
+ if response.status_code == 200:
205
+ return response.json()['items'][:5] # Return top 5 results
206
+ return []
207
+
208
+ def analyze_code_quality(code: str) -> Dict[str, float]:
209
+ inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
210
+
211
+ with torch.no_grad():
212
+ outputs = codebert_model(**inputs)
213
+
214
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
215
+ predictions = code_improvement_model(cls_embedding)
216
+
217
+ quality_scores = {
218
+ "style": predictions[0][0].item(),
219
+ "efficiency": predictions[0][1].item(),
220
+ "security": predictions[0][2].item(),
221
+ "maintainability": predictions[0][3].item(),
222
+ "scalability": predictions[0][4].item(),
223
+ "readability": predictions[0][5].item(),
224
+ "testability": predictions[0][6].item(),
225
+ "modularity": predictions[0][7].item()
226
  }
227
+
228
+ # Calculate additional metrics
229
+ language = detect_language(code)
230
+ if language == 'python':
231
+ complexity = radon_complexity.cc_visit(code)
232
+ maintainability = radon_metrics.mi_visit(code, True)
233
+ quality_scores["cyclomatic_complexity"] = complexity[0].complexity if complexity else 0
234
+ quality_scores["maintainability_index"] = maintainability
235
+
236
+ return quality_scores
237
+
238
+ def visualize_code_structure(code: str) -> plt.Figure:
239
+ try:
240
+ tree = ast.parse(code)
241
+ graph = nx.DiGraph()
242
+
243
+ def add_nodes_edges(node, parent=None):
244
+ node_id = id(node)
245
+ graph.add_node(node_id, label=f"{type(node).__name__}\n{ast.unparse(node)[:20]}")
246
+ if parent:
247
+ graph.add_edge(id(parent), node_id)
248
+ for child in ast.iter_child_nodes(node):
249
+ add_nodes_edges(child, node)
250
+
251
+ add_nodes_edges(tree)
252
+
253
+ plt.figure(figsize=(15, 10))
254
+ pos = nx.spring_layout(graph, k=0.9, iterations=50)
255
+ nx.draw(graph, pos, with_labels=True, node_color='lightblue', node_size=2000, font_size=8, font_weight='bold', arrows=True)
256
+ labels = nx.get_node_attributes(graph, 'label')
257
+ nx.draw_networkx_labels(graph, pos, labels, font_size=6)
258
+
259
+ return plt.gcf()  # return the Figure object, matching the declared plt.Figure return type
260
+ except SyntaxError:
261
+ return None
262
+
263
+ def suggest_improvements(code: str, quality_scores: Dict[str, float]) -> List[str]:
264
+ suggestions = []
265
+ thresholds = {
266
+ "style": 0.7,
267
+ "efficiency": 0.7,
268
+ "security": 0.8,
269
+ "maintainability": 0.7,
270
+ "scalability": 0.7,
271
+ "readability": 0.7,
272
+ "testability": 0.7,
273
+ "modularity": 0.7
274
  }
275
+
276
+ for metric, threshold in thresholds.items():
277
+ if quality_scores[metric] < threshold:
278
+ suggestions.append(f"Consider improving code {metric} (current score: {quality_scores[metric]:.2f}).")
279
+
280
+ if "cyclomatic_complexity" in quality_scores and quality_scores["cyclomatic_complexity"] > 10:
281
+ suggestions.append(f"Consider breaking down complex functions to reduce cyclomatic complexity (current: {quality_scores['cyclomatic_complexity']}).")
282
+
283
+ return suggestions
284
+
285
+ # New function for advanced code generation using GPT-2
286
+ def generate_advanced_code(prompt: str, language: str) -> str:
287
+ input_text = f"Generate {language} code for: {prompt}\n\n"
288
+ input_ids = gpt2_tokenizer.encode(input_text, return_tensors="pt")
289
+
290
+ output = gpt2_model.generate(
291
+ input_ids,
292
+ max_length=1000,
293
+ num_return_sequences=1,
294
+ no_repeat_ngram_size=2,
295
+ top_k=50,
296
+ top_p=0.95,
297
+ temperature=0.7
298
+ )
299
+
300
+ generated_code = gpt2_tokenizer.decode(output[0], skip_special_tokens=True)
301
+ return generated_code.split("\n\n", 1)[-1]  # Remove the input prompt; [-1] avoids an IndexError when no blank line is present
302
+
303
+ # New function for code similarity analysis
304
+ def analyze_code_similarity(code1: str, code2: str) -> float:
305
+ tokens1 = tokenizer.tokenize(code1)
306
+ tokens2 = tokenizer.tokenize(code2)
307
+
308
+ # Use Jaccard similarity for token-based comparison
309
+ set1 = set(tokens1)
310
+ set2 = set(tokens2)
311
+ similarity = len(set1.intersection(set2)) / len(set1.union(set2))
312
+
313
+ return similarity
314
+
315
+ # New function for code performance estimation
316
+ def estimate_code_performance(code: str) -> Dict[str, Any]:
317
+ language = detect_language(code)
318
+ if language == 'python':
319
+ # Use abstract syntax tree to estimate time complexity
320
+ tree = ast.parse(code)
321
+ analyzer = ComplexityAnalyzer()
322
+ analyzer.visit(tree)
323
+ return {
324
+ "time_complexity": analyzer.time_complexity,
325
+ "space_complexity": analyzer.space_complexity
326
+ }
327
+ else:
328
+ return {"error": "Performance estimation not supported for this language"}
329
+
330
+ class ComplexityAnalyzer(ast.NodeVisitor):
331
+ def __init__(self):
332
+ self.time_complexity = "O(1)"
333
+ self.space_complexity = "O(1)"
334
+ self.loop_depth = 0
335
+
336
+ def visit_For(self, node):
337
+ self.loop_depth += 1
338
+ self.generic_visit(node)
339
+ self.loop_depth -= 1
340
+ self.update_complexity()
341
+
342
+ def visit_While(self, node):
343
+ self.loop_depth += 1
344
+ self.generic_visit(node)
345
+ self.loop_depth -= 1
346
+ self.update_complexity()
347
+
348
+ def update_complexity(self):
349
+ if self.loop_depth > 0:
350
+ self.time_complexity = f"O(n^{self.loop_depth})"
351
+ self.space_complexity = "O(n)"
352
+
353
+ # New function for code translation between programming languages
354
+ def translate_code(code: str, source_lang: str, target_lang: str) -> str:
355
+ prompt = f"Translate the following {source_lang} code to {target_lang}:\n\n{code}\n\nTranslated {target_lang} code:"
356
+ translated_code = generate_advanced_code(prompt, target_lang)
357
+ return translated_code
358
+
359
+ # New function for generating unit tests
360
+ def generate_unit_tests(code: str, language: str) -> str:
361
+ prompt = f"Generate unit tests for the following {language} code:\n\n{code}\n\nUnit tests:"
362
+ unit_tests = generate_advanced_code(prompt, language)
363
+ return unit_tests
364
+
365
+ # New function for code documentation generation
366
+ def generate_documentation(code: str, language: str) -> str:
367
+ prompt = f"Generate comprehensive documentation for the following {language} code:\n\n{code}\n\nDocumentation:"
368
+ documentation = generate_advanced_code(prompt, language)
369
+ return documentation
370
+
371
+ # New function for advanced code refactoring suggestions
372
+ def suggest_refactoring(code: str, language: str) -> List[str]:
373
+ quality_scores = analyze_code_quality(code)
374
+ suggestions = suggest_improvements(code, quality_scores)
375
+
376
+ # Add more specific refactoring suggestions based on code analysis
377
+ tree = ast.parse(code)
378
+ analyzer = RefactoringAnalyzer()
379
+ analyzer.visit(tree)
380
+
381
+ suggestions.extend(analyzer.suggestions)
382
+ return suggestions
383
+
384
+ class RefactoringAnalyzer(ast.NodeVisitor):
385
+ def __init__(self):
386
+ self.suggestions = []
387
+ self.function_lengths = {}
388
+
389
+ def visit_FunctionDef(self, node):
390
+ function_length = len(node.body)
391
+ self.function_lengths[node.name] = function_length
392
+ if function_length > 20:
393
+ self.suggestions.append(f"Consider breaking down the function '{node.name}' into smaller, more manageable functions.")
394
+ self.generic_visit(node)
395
+
396
+ def visit_If(self, node):
397
+ if isinstance(node.test, ast.Compare) and len(node.test.ops) > 2:
398
+ self.suggestions.append("Consider simplifying complex conditional statements.")
399
+ self.generic_visit(node)
400
+
401
+ # New function for code security analysis
402
+ def analyze_code_security(code: str, language: str) -> List[str]:
403
+ vulnerabilities = []
404
+
405
+ if language == 'python':
406
+ tree = ast.parse(code)
407
+ analyzer = SecurityAnalyzer()
408
+ analyzer.visit(tree)
409
+ vulnerabilities.extend(analyzer.vulnerabilities)
410
+
411
+ # Add more language-specific security checks here
412
+
413
+ return vulnerabilities
414
+
415
+ class SecurityAnalyzer(ast.NodeVisitor):
416
+ def __init__(self):
417
+ self.vulnerabilities = []
418
+
419
+ def visit_Call(self, node):
420
+ if isinstance(node.func, ast.Name):
421
+ if node.func.id == 'eval':
422
+ self.vulnerabilities.append("Potential security risk: Use of 'eval' function detected.")
423
+ elif node.func.id == 'exec':
424
+ self.vulnerabilities.append("Potential security risk: Use of 'exec' function detected.")
425
+ self.generic_visit(node)
426
+
427
+ # New function for code optimization suggestions
428
+ def suggest_optimizations(code: str, language: str) -> List[str]:
429
+ suggestions = []
430
+
431
+ if language == 'python':
432
+ tree = ast.parse(code)
433
+ analyzer = OptimizationAnalyzer()
434
+ analyzer.visit(tree)
435
+ suggestions.extend(analyzer.suggestions)
436
+
437
+ # Add more language-specific optimization suggestions here
438
+
439
+ return suggestions
440
+
441
+ class OptimizationAnalyzer(ast.NodeVisitor):
442
+ def __init__(self):
443
+ self.suggestions = []
444
+ self.loop_variables = set()
445
+
446
+ def visit_For(self, node):
447
+ if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name) and node.iter.func.id == 'range':
448
+ self.suggestions.append("Consider using 'enumerate()' instead of 'range()' for index-based iteration.")
449
+ self.generic_visit(node)
450
+
451
+ def visit_ListComp(self, node):
452
+ if isinstance(node.elt, ast.Call) and isinstance(node.elt.func, ast.Name) and node.elt.func.id == 'append':
453
+ self.suggestions.append("Consider using a list comprehension instead of appending in a loop for better performance.")
454
+ self.generic_visit(node)
455
+
456
+ # Streamlit UI setup
457
+ st.set_page_config(page_title="Advanced AI Code Assistant", page_icon="🚀", layout="wide")
458
+
459
+ st.markdown("""
460
+ <style>
461
+ .main-container {
462
+ padding: 2rem;
463
+ border-radius: 10px;
464
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
465
+ background-color: #f8f9fa;
466
+ }
467
+ .title {
468
+ color: #2c3e50;
469
+ font-size: 2.5rem;
470
+ margin-bottom: 1rem;
471
+ }
472
+ .subtitle {
473
+ color: #34495e;
474
+ font-size: 1.2rem;
475
+ margin-bottom: 2rem;
476
+ }
477
+ .output-container {
478
+ margin-top: 2rem;
479
+ padding: 1rem;
480
+ border-radius: 5px;
481
+ background-color: #ffffff;
482
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
483
+ }
484
+ .code-block {
485
+ margin-bottom: 1rem;
486
+ }
487
+ .metric-container {
488
+ display: flex;
489
+ justify-content: space-between;
490
+ flex-wrap: wrap;
491
+ }
492
+ .metric-item {
493
+ flex-basis: 48%;
494
+ margin-bottom: 1rem;
495
+ }
496
  </style>
497
  """, unsafe_allow_html=True)
498
 
499
  st.markdown('<div class="main-container">', unsafe_allow_html=True)
500
+ st.markdown('<h1 class="title">🚀 Advanced AI Code Assistant</h1>', unsafe_allow_html=True)
501
+ st.markdown('<p class="subtitle">Powered by Cutting-Edge AI & Multi-Domain Expertise</p>', unsafe_allow_html=True)
502
 
503
+ task = st.selectbox("Select a task", [
504
+ "Generate Code", "Optimize Code", "Analyze Code Quality",
505
+ "Translate Code", "Generate Unit Tests", "Generate Documentation",
506
+ "Suggest Refactoring", "Analyze Code Security", "Suggest Optimizations"
507
+ ])
508
 
509
+ language = st.selectbox("Select programming language", [
510
+ "Python", "JavaScript", "Java", "C++", "Ruby", "Go", "Rust", "TypeScript"
511
+ ])
512
+
513
+ prompt = st.text_area("Enter your code or prompt", height=200)
514
+
515
+ if st.button("Execute Task"):
516
  if prompt.strip() == "":
517
+ st.error("Please enter a valid prompt or code snippet.")
518
  else:
519
+ with st.spinner("Processing your request..."):
520
+ if task == "Generate Code":
521
+ result = generate_advanced_code(prompt, language.lower())
522
+ st.code(result, language=language.lower())
523
+ elif task == "Optimize Code":
524
+ optimized_code, lint_results = optimize_code(prompt)
525
+ st.code(optimized_code, language=language.lower())
526
+ st.text(lint_results)
527
+ elif task == "Analyze Code Quality":
528
+ quality_scores = analyze_code_quality(prompt)
529
+ st.json(quality_scores)
530
+ elif task == "Translate Code":
531
+ target_lang = st.selectbox("Select target language", [
532
+ lang for lang in ["Python", "JavaScript", "Java", "C++", "Ruby", "Go", "Rust", "TypeScript"] if lang != language
533
+ ])
534
+ translated_code = translate_code(prompt, language.lower(), target_lang.lower())
535
+ st.code(translated_code, language=target_lang.lower())
536
+ elif task == "Generate Unit Tests":
537
+ unit_tests = generate_unit_tests(prompt, language.lower())
538
+ st.code(unit_tests, language=language.lower())
539
+ elif task == "Generate Documentation":
540
+ documentation = generate_documentation(prompt, language.lower())
541
+ st.markdown(documentation)
542
+ elif task == "Suggest Refactoring":
543
+ refactoring_suggestions = suggest_refactoring(prompt, language.lower())
544
+ for suggestion in refactoring_suggestions:
545
+ st.info(suggestion)
546
+ elif task == "Analyze Code Security":
547
+ vulnerabilities = analyze_code_security(prompt, language.lower())
548
+ if vulnerabilities:
549
+ for vuln in vulnerabilities:
550
+ st.warning(vuln)
551
  else:
552
+ st.success("No obvious security vulnerabilities detected.")
553
+ elif task == "Suggest Optimizations":
554
+ optimization_suggestions = suggest_optimizations(prompt, language.lower())
555
+ for suggestion in optimization_suggestions:
556
+ st.info(suggestion)
557
+
558
+ # Additional analysis for all tasks
559
+ quality_scores = analyze_code_quality(prompt)
560
+ performance_estimate = estimate_code_performance(prompt)
561
+
562
+ col1, col2 = st.columns(2)
563
+ with col1:
564
+ st.subheader("Code Quality Metrics")
565
+ for metric, score in quality_scores.items():
566
+ st.metric(metric.capitalize(), f"{score:.2f}")
567
+
568
+ with col2:
569
+ st.subheader("Performance Estimation")
570
+ st.json(performance_estimate)
571
+
572
+ visualization = visualize_code_structure(prompt)
573
+ if visualization:
574
+ st.subheader("Code Structure Visualization")
575
+ st.pyplot(visualization)
576
 
577
  st.markdown("""
578
  <div style='text-align: center; margin-top: 2rem; color: #4a5568;'>
579
+ Powered by Advanced AI & Multi-Domain Expertise
580
  </div>
581
  """, unsafe_allow_html=True)
582
 
583
+ st.markdown('</div>', unsafe_allow_html=True)
584
+
585
+ # FastAPI setup for potential API endpoints
586
+ app = FastAPI()
587
+
588
+ class CodeRequest(BaseModel):
589
+ code: str
590
+ language: str
591
+ task: str
592
+
593
+ @app.post("/analyze")
594
+ async def analyze_code(request: CodeRequest):
595
+ if request.task == "quality":
596
+ return analyze_code_quality(request.code)
597
+ elif request.task == "security":
598
+ return analyze_code_security(request.code, request.language)
599
+ elif request.task == "optimize":
600
+ optimized_code, _ = optimize_code(request.code)
601
+ return {"optimized_code": optimized_code}
602
+ else:
603
+ return {"error": "Invalid task"}
604
+
605
+ if __name__ == "__main__":
606
+ uvicorn.run(app, host="0.0.0.0", port=8000)
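
For reference, a minimal client sketch (illustrative only, not part of this commit) for the /analyze endpoint added above. It assumes the FastAPI app is served separately, for example via python app.py, and is reachable at localhost:8000 per the uvicorn call; the payload fields mirror the CodeRequest model, and the sample function is hypothetical.

import requests

# Hypothetical example payload; field names follow the CodeRequest model above.
payload = {
    "code": "def add(a, b):\n    return a + b\n",
    "language": "python",
    "task": "quality",  # the endpoint also accepts "security" and "optimize"
}

# POST to the /analyze route registered on the FastAPI app.
resp = requests.post("http://localhost:8000/analyze", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # for "quality", the score dictionary from analyze_code_quality

Since uvicorn.run blocks whichever process executes the script with __name__ set to "__main__", the API and the Streamlit UI would normally be run as separate processes.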