|
import streamlit as st |
|
import google.generativeai as genai |
|
import requests |
|
import subprocess |
|
import os |
|
import pandas as pd |
|
import numpy as np |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.ensemble import RandomForestClassifier |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from transformers import AutoTokenizer, AutoModel |
|
import ast |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Authenticate the google-generativeai client with the key stored in Streamlit
# secrets. NOTE(review): a missing GOOGLE_API_KEY makes this line raise at
# startup; the exact exception type depends on the Streamlit version.
genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])


# Sampling settings attached to the Gemini model below (and so to every
# request sent through `chat_session`).
generation_config = {
    "temperature": 0.6,
    "top_p": 0.8,
    "top_k": 30,
    # NOTE(review): confirm 16384 is within the output-token limit of the
    # selected model.
    "max_output_tokens": 16384,
}

# Gemini model with a fixed persona ("Ath") supplied as the system instruction.
model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=generation_config,
    system_instruction="""
You are Ath, a highly advanced code assistant with deep knowledge in AI, machine learning, and software engineering. You provide cutting-edge, optimized, and secure code solutions. Speak casually and use tech jargon when appropriate.
"""
)

# Chat session used by generate_response. It is created with an empty history
# each time this module-level code runs. Under Streamlit's rerun model that is
# every interaction, so history does not carry over between button clicks.
chat_session = model.start_chat(history=[])
|
|
|
|
|
@st.cache_resource
def _load_codebert():
    """Load the CodeBERT tokenizer and encoder once per server process.

    Streamlit re-executes this script top to bottom on every widget
    interaction. As plain module-level statements, these ``from_pretrained``
    calls reloaded the full model into memory on every click.
    ``st.cache_resource`` keeps one shared copy across reruns and sessions.
    That is safe here because the encoder is only run for inference, under
    ``torch.no_grad()`` in analyze_code_quality.

    Returns:
        ``(tokenizer, model)`` for ``microsoft/codebert-base``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    loaded_model = AutoModel.from_pretrained("microsoft/codebert-base")
    return loaded_tokenizer, loaded_model


# Module-level names kept so existing references keep working.
tokenizer, codebert_model = _load_codebert()
|
|
|
class CodeImprovement(nn.Module):
    """Four-layer MLP mapping an ``input_dim`` feature vector to two sigmoid outputs.

    Layer widths: input_dim -> 512 -> 256 -> 128 -> 2. ReLU follows each hidden
    layer. The final layer applies an element-wise sigmoid (not softmax), so
    each of the two outputs lies independently in [0, 1].
    """

    def __init__(self, input_dim):
        super().__init__()
        # Registered in this order so parameter order / state_dict keys are
        # fc1..fc4.
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 2)

    def forward(self, x):
        """Run the MLP on ``x`` (last dimension must equal ``input_dim``)."""
        hidden = x
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(hidden_layer(hidden))
        logits = self.fc4(hidden)
        return torch.sigmoid(logits)
|
|
|
@st.cache_resource
def _build_quality_head():
    """Create the scoring head, its optimizer and its loss once per process.

    As module-level statements these were rebuilt on every Streamlit rerun. The
    head's random initial weights, and therefore the "quality score" reported
    for identical code, changed with every click. Caching keeps them stable.

    NOTE(review): nothing in this file trains the head, and ``optimizer`` and
    ``criterion`` are never used here. Its outputs are untrained placeholders.

    Returns:
        ``(head, optimizer, criterion)``.
    """
    # 768 presumably matches CodeBERT's hidden size, i.e. the [CLS] vector that
    # analyze_code_quality feeds in. Confirm if the encoder changes.
    head = CodeImprovement(768)
    head_optimizer = optim.Adam(head.parameters())
    loss_fn = nn.BCELoss()
    return head, head_optimizer, loss_fn


# Module-level names kept so existing references keep working.
code_improvement_model, optimizer, criterion = _build_quality_head()
|
|
|
def generate_response(user_input):
    """Send *user_input* to the shared Gemini chat session and return the reply text.

    Never raises. Any exception from sending the message or reading
    ``.text`` is returned as a string of the form ``"Error: <exception>"``.
    """
    try:
        reply = chat_session.send_message(user_input)
        reply_text = reply.text
    except Exception as exc:  # UI boundary: surface API failures as text.
        return f"Error: {exc}"
    return reply_text
|
|
|
def optimize_code(code):
    """Rewrite *code* via CodeAnalyzer/CodeTransformer, then lint the result with pylint.

    Args:
        code: Python source text.

    Returns:
        ``(optimized_code, lint_output)`` on success, where ``lint_output`` is
        pylint's stdout.

        - If *code* does not parse, returns ``(code, "SyntaxError: <details>")``.
          Callers detect this case by that prefix.
        - If pylint cannot be started or times out, ``lint_output`` is an
          explanatory message instead.
    """
    import tempfile

    try:
        tree = ast.parse(code)
        # NOTE(review): CodeAnalyzer and CodeTransformer are not defined in the
        # visible source. They are presumably ast.NodeVisitor/NodeTransformer
        # subclasses from elsewhere; confirm, otherwise this line raises NameError.
        analyzer = CodeAnalyzer()
        analyzer.visit(tree)
        transformer = CodeTransformer(analyzer.get_optimizations())
        optimized_tree = transformer.visit(tree)
        optimized_code = ast.unparse(optimized_tree)
    except SyntaxError as e:
        return code, f"SyntaxError: {str(e)}"

    # Lint inside a private temporary directory. The previous fixed
    # "temp_code.py" in the CWD was shared by every concurrent Streamlit session,
    # so they overwrote each other's file. It was also left behind whenever
    # pylint failed to start.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp_code.py")
        with open(path, "w", encoding="utf-8") as file:
            file.write(optimized_code)
        try:
            result = subprocess.run(
                ["pylint", path], capture_output=True, text=True, timeout=120
            )
        except FileNotFoundError:
            return optimized_code, "pylint is not installed or not on PATH; lint skipped."
        except subprocess.TimeoutExpired:
            return optimized_code, "pylint timed out; lint skipped."

    return optimized_code, result.stdout
|
|
|
def fetch_from_github(query, limit=5, timeout=10):
    """Search GitHub code for *query* and return up to *limit* result items.

    This is best effort: any non-200 response, network error or timeout yields
    ``[]``.

    Args:
        query: Search text. It is URL-encoded by ``requests``.
        limit: Maximum number of items to return.
        timeout: Seconds to wait for the GitHub API before giving up.

    Returns:
        A list of item dicts from the search API's ``items`` array.

    Raises:
        Whatever ``st.secrets`` raises if ``GITHUB_TOKEN`` is not configured.
    """
    headers = {"Authorization": f"token {st.secrets['GITHUB_TOKEN']}"}
    try:
        # Pass the query via params= so it is URL-encoded. The previous f-string
        # broke on '&', '#', '+', spaces, etc. in free-form prompts.
        response = requests.get(
            "https://api.github.com/search/code",
            params={"q": query},
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    return response.json().get("items", [])[:limit]
|
|
|
def interact_with_api(api_url, timeout=10):
    """GET *api_url* and return its decoded JSON body.

    Args:
        api_url: Absolute URL to fetch. NOTE(review): if this ever receives
            user-supplied URLs, validate or allow-list them first (SSRF risk).
        timeout: Seconds to wait for connect/read. Without a timeout, requests
            can block indefinitely on a stalled server and freeze the script.

    Returns:
        Whatever ``response.json()`` yields. The HTTP status is not checked,
        so a JSON error body is returned as-is.

    Raises:
        requests.RequestException: on network failure or timeout.
        ValueError (requests' JSON decode error): if the body is not JSON.
    """
    response = requests.get(api_url, timeout=timeout)
    return response.json()
|
|
|
def train_ml_model(code_data, test_size=0.2, random_state=None):
    """Fit a random-forest classifier on tabular *code_data*.

    Args:
        code_data: Anything ``pd.DataFrame`` accepts (list of dicts, dict of
            columns, ...). Must contain a ``'target'`` column. Every other
            column is used as a feature.
        test_size: Fraction of rows held out for evaluation, passed to
            ``train_test_split``.
        random_state: Seed for both the split and the forest, for reproducible
            runs. ``None`` (the default) keeps the previous non-deterministic
            behaviour.

    Returns:
        The fitted ``RandomForestClassifier``. Its accuracy on the held-out
        split is stored as ``holdout_score_``. Previously the split was
        computed and then discarded.

    Raises:
        KeyError: if there is no ``'target'`` column.
    """
    df = pd.DataFrame(code_data)
    features = df.drop(columns="target")
    labels = df["target"]
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=random_state
    )
    # Named `clf` so it does not shadow the module-level Gemini `model`.
    clf = RandomForestClassifier(
        n_estimators=100, max_depth=10, random_state=random_state
    )
    clf.fit(X_train, y_train)
    clf.holdout_score_ = clf.score(X_test, y_test)
    return clf
|
|
|
def analyze_code_quality(code, output_index=1):
    """Return a heuristic "quality" score for *code* from the CodeBERT-based head.

    Pipeline:
    1. *code* is tokenized (truncated/padded to 512 tokens).
    2. It is encoded with ``codebert_model``.
    3. The position-0 ([CLS]) vector is fed to ``code_improvement_model``,
       whose last layer emits two sigmoid units.

    Args:
        code: Source text to score.
        output_index: Which of the head's two sigmoid outputs to report.
            NOTE(review): nothing in this file assigns a meaning to either
            unit. Index 1 is assumed to be the "good quality" probability;
            confirm against however the head gets trained.

    Returns:
        A float in [0, 1]. NOTE(review): the head is not trained anywhere in
        this file, so the value is currently a placeholder.
    """
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding="max_length")

    # Inference only, so keep both the encoder and the head out of autograd.
    # Previously the head ran outside no_grad and built a graph on every call.
    with torch.no_grad():
        outputs = codebert_model(**inputs)
        # last_hidden_state is (batch, seq, hidden); take sequence position 0.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        prediction = code_improvement_model(cls_embedding)

    # prediction has shape (1, 2). Calling .item() on the whole tensor raised
    # "a Tensor with 2 elements cannot be converted to Scalar", so select one
    # unit explicitly.
    return prediction[0, output_index].item()
|
|
|
def visualize_code_structure(code):
    """Draw the abstract syntax tree of *code* as a directed graph.

    Each AST node becomes a graph node labelled with its type name (e.g.
    ``FunctionDef``), with edges from parent to child.

    Returns:
        A ``matplotlib.figure.Figure`` (accepted by ``st.pyplot``), or ``None``
        if *code* cannot be parsed.
    """
    # A standalone Figure is not registered with pyplot's global figure
    # manager. The previous plt.figure() per call accumulated open figures
    # across reruns, and the function returned the whole pyplot module.
    from matplotlib.figure import Figure

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):  # ValueError: null bytes on older Pythons
        return None

    graph = nx.DiGraph()

    def add_nodes_edges(node, parent=None):
        # id() is stable while `tree` is alive, so it serves as a unique key.
        node_id = id(node)
        graph.add_node(node_id, label=type(node).__name__)
        if parent is not None:
            graph.add_edge(id(parent), node_id)
        for child in ast.iter_child_nodes(node):
            add_nodes_edges(child, node)

    add_nodes_edges(tree)

    fig = Figure(figsize=(12, 8))
    ax = fig.add_subplot()
    pos = nx.spring_layout(graph)
    # with_labels=False: the default labels are the raw id() integers, which
    # were previously drawn underneath the type-name labels below.
    nx.draw(graph, pos, ax=ax, with_labels=False, node_color='lightblue', node_size=1000)
    labels = nx.get_node_attributes(graph, 'label')
    nx.draw_networkx_labels(graph, pos, labels, font_size=6, ax=ax)
    return fig
|
|
|
|
|
# Page-level settings: browser-tab title, tab icon and wide layout.
# NOTE(review): the icon "π" looks like a mis-encoded emoji; confirm the
# intended character.
st.set_page_config(page_title="Advanced AI Code Assistant", page_icon="π", layout="wide")

# Global stylesheet injected as raw HTML (hence unsafe_allow_html=True).
# - Some selectors target what are presumably Streamlit's own element classes
#   (.stApp, .stTextArea, .stButton, .stAlert, .stSpinner).
# - Others target the custom classes emitted by later st.markdown calls
#   (.main-container, .subtitle, .output-container, .code-block).
# NOTE(review): only "Inter" is loaded by the @import. "Fira Code", used by
# .code-block, is not imported, so it renders only if installed locally;
# otherwise the generic monospace fallback applies.
st.markdown("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600;700&display=swap');

body {
    font-family: 'Inter', sans-serif;
    background-color: #f0f4f8;
    color: #1a202c;
}
.stApp {
    max-width: 1200px;
    margin: 0 auto;
    padding: 2rem;
}
.main-container {
    background: #ffffff;
    border-radius: 16px;
    padding: 2rem;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
h1 {
    font-size: 2.5rem;
    font-weight: 700;
    color: #2d3748;
    text-align: center;
    margin-bottom: 1rem;
}
.subtitle {
    font-size: 1.1rem;
    text-align: center;
    color: #4a5568;
    margin-bottom: 2rem;
}
.stTextArea textarea {
    border: 2px solid #e2e8f0;
    border-radius: 8px;
    font-size: 1rem;
    padding: 0.75rem;
    transition: all 0.3s ease;
}
.stTextArea textarea:focus {
    border-color: #4299e1;
    box-shadow: 0 0 0 3px rgba(66, 153, 225, 0.5);
}
.stButton button {
    background-color: #4299e1;
    color: white;
    border: none;
    border-radius: 8px;
    font-size: 1.1rem;
    font-weight: 600;
    padding: 0.75rem 2rem;
    transition: all 0.3s ease;
    width: 100%;
}
.stButton button:hover {
    background-color: #3182ce;
}
.output-container {
    background: #f7fafc;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 2rem;
}
.code-block {
    background-color: #2d3748;
    color: #e2e8f0;
    font-family: 'Fira Code', monospace;
    font-size: 0.9rem;
    border-radius: 8px;
    padding: 1rem;
    margin-top: 1rem;
    overflow-x: auto;
}
.stAlert {
    background-color: #ebf8ff;
    color: #2b6cb0;
    border-radius: 8px;
    border: none;
    padding: 0.75rem 1rem;
}
.stSpinner {
    color: #4299e1;
}
</style>
""", unsafe_allow_html=True)
|
|
|
# Page header. The opening <div> below is closed by a separate st.markdown call
# at the very end of the script.
# NOTE(review): Streamlit renders each st.markdown call as its own element, so
# this wrapper may not actually enclose the widgets in between. Verify in the
# browser.
main_container_open = '<div class="main-container">'
subtitle_html = '<p class="subtitle">Powered by Google Gemini & Deep Learning</p>'

st.markdown(main_container_open, unsafe_allow_html=True)
# NOTE(review): "π" looks like a mis-encoded emoji; confirm the intended character.
st.title("π Advanced AI Code Assistant")
st.markdown(subtitle_html, unsafe_allow_html=True)

# Free-form task description. It is sent verbatim to Gemini when the button
# below is pressed.
prompt = st.text_area(
    "What advanced code task can I help you with today?",
    height=120,
)
|
|
|
def _extract_python_code(text):
    """Return the Python source embedded in a Markdown chat reply.

    Selection order:
    1. The body of the first fenced block tagged ``python``/``py``.
    2. Failing that, the first fenced block of any kind.
    3. Failing that, *text* unchanged, so a reply that is already bare code
       still works.

    Chat replies are typically Markdown (prose plus ``` fences), which
    ``ast.parse`` rejects as a whole.
    """
    import re

    for pattern in (r"```(?:python|py)[ \t]*\n(.*?)```", r"```[^\n]*\n(.*?)```"):
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1)
    return text


if st.button("Generate Advanced Code"):
    if not prompt.strip():
        st.error("Please enter a valid prompt.")
    else:
        with st.spinner("Generating and analyzing code..."):
            completed_text = generate_response(prompt)
            # generate_response reports failures as "Error: <exception>". Match
            # that prefix only: a substring test also fired on valid replies
            # that merely mention e.g. ValueError or "error handling".
            if completed_text.startswith("Error: "):
                st.error(completed_text)
            else:
                optimized_code, lint_results = optimize_code(
                    _extract_python_code(completed_text)
                )

                # optimize_code flags unparsable input with a "SyntaxError:"
                # prefix. A substring test could also match ordinary lint text.
                if lint_results.startswith("SyntaxError:"):
                    st.warning(f"Syntax error detected: {lint_results}")
                    st.code(completed_text)
                else:
                    quality_score = analyze_code_quality(optimized_code)
                    st.success(f"Code generated and optimized successfully! Quality Score: {quality_score:.2f}")

                    st.markdown('<div class="output-container">', unsafe_allow_html=True)
                    st.markdown('<div class="code-block">', unsafe_allow_html=True)
                    st.code(optimized_code)
                    st.markdown('</div>', unsafe_allow_html=True)

                    visualization = visualize_code_structure(optimized_code)
                    if visualization is not None:
                        with st.expander("View Code Structure Visualization"):
                            st.pyplot(visualization)
                    else:
                        st.warning("Unable to generate code structure visualization due to syntax errors.")

                    with st.expander("View Lint Results"):
                        st.text(lint_results)

                    # Note: expanders do not defer execution; the search runs
                    # on every successful generation.
                    with st.expander("Fetch Similar Code from GitHub"):
                        github_results = fetch_from_github(prompt)
                        for item in github_results:
                            st.markdown(f"[{item['name']}]({item['html_url']})")

                    # Closes the output-container <div> opened above.
                    st.markdown('</div>', unsafe_allow_html=True)
|
|
|
# Centered footer, rendered on every run below the main content.
# NOTE(review): "π" here looks like a mis-encoded emoji (plausibly a heart);
# confirm the intended character.
st.markdown("""
<div style='text-align: center; margin-top: 2rem; color: #4a5568;'>
Crafted with π by Your Advanced AI Code Assistant
</div>
""", unsafe_allow_html=True)


# Closes the <div class="main-container"> opened by the header markup near the
# top of the page.
st.markdown('</div>', unsafe_allow_html=True)