import streamlit as st
import google.generativeai as genai
import requests
import subprocess
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
import ast
import networkx as nx
import matplotlib.pyplot as plt
# Configure the Gemini API
genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])
# Create the model with optimized parameters and enhanced system instructions
generation_config = {
"temperature": 0.6,
"top_p": 0.8,
"top_k": 30,
"max_output_tokens": 16384,
}
model = genai.GenerativeModel(
model_name="gemini-1.5-pro",
generation_config=generation_config,
system_instruction="""
You are Ath, a highly advanced code assistant with deep knowledge in AI, machine learning, and software engineering. You provide cutting-edge, optimized, and secure code solutions. Speak casually and use tech jargon when appropriate.
"""
)
chat_session = model.start_chat(history=[])
# Load pre-trained BERT model for code understanding
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
codebert_model = AutoModel.from_pretrained("microsoft/codebert-base")
class CodeImprovement(nn.Module):
def __init__(self, input_dim):
super(CodeImprovement, self).__init__()
self.fc1 = nn.Linear(input_dim, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 1)  # Single sigmoid output: probability that the code needs improvement
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
return torch.sigmoid(self.fc4(x))
code_improvement_model = CodeImprovement(768) # 768 is BERT's output dimension
optimizer = optim.Adam(code_improvement_model.parameters())
criterion = nn.BCELoss()
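
# optimize_code() below relies on CodeAnalyzer and CodeTransformer, which are not defined
# anywhere else in this script. The minimal sketches here are assumptions: the analyzer only
# records which function definitions it visits, and the transformer applies no rewrites yet,
# so ast.unparse() returns the code essentially unchanged until real optimization passes are
# plugged in.
class CodeAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.function_names = []

    def visit_FunctionDef(self, node):
        # Record every function definition encountered during traversal
        self.function_names.append(node.name)
        self.generic_visit(node)

    def get_optimizations(self):
        # Placeholder: no optimization hints are produced yet
        return []


class CodeTransformer(ast.NodeTransformer):
    def __init__(self, optimizations):
        self.optimizations = optimizations

    def generic_visit(self, node):
        # No rewrites are applied yet; the tree passes through untouched
        return super().generic_visit(node)
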
def generate_response(user_input):
try:
response = chat_session.send_message(user_input)
return response.text
except Exception as e:
return f"Error: {e}"
def optimize_code(code):
# Use abstract syntax tree for advanced code analysis
try:
tree = ast.parse(code)
analyzer = CodeAnalyzer()
analyzer.visit(tree)
# Apply code transformations based on analysis
transformer = CodeTransformer(analyzer.get_optimizations())
optimized_tree = transformer.visit(tree)
optimized_code = ast.unparse(optimized_tree)
except SyntaxError as e:
return code, f"SyntaxError: {str(e)}"
    # Run pylint for additional style/quality suggestions (assumes pylint is installed)
    with open("temp_code.py", "w") as file:
        file.write(optimized_code)
    try:
        result = subprocess.run(["pylint", "temp_code.py"], capture_output=True, text=True)
        lint_output = result.stdout
    except FileNotFoundError:
        lint_output = "pylint not found on PATH; skipping lint analysis."
    finally:
        os.remove("temp_code.py")
    return optimized_code, lint_output
def fetch_from_github(query):
headers = {"Authorization": f"token {st.secrets['GITHUB_TOKEN']}"}
response = requests.get(f"https://api.github.com/search/code?q={query}", headers=headers)
if response.status_code == 200:
return response.json()['items'][:5] # Return top 5 results
return []
def interact_with_api(api_url):
response = requests.get(api_url)
return response.json()
def train_ml_model(code_data):
df = pd.DataFrame(code_data)
X = df.drop('target', axis=1)
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = RandomForestClassifier(n_estimators=100, max_depth=10)
model.fit(X_train, y_train)
return model
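
# Note: train_ml_model() above is not yet wired into the Streamlit UI. It assumes code_data
# can be turned into a DataFrame with numeric feature columns plus a 'target' label column,
# e.g. (illustrative shape only, hypothetical feature names):
#   code_data = [{"num_functions": 3, "avg_line_length": 41.2, "target": 1}, ...]
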
def analyze_code_quality(code):
# Tokenize and encode the code
inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
# Get BERT embeddings
with torch.no_grad():
outputs = codebert_model(**inputs)
# Use the [CLS] token embedding for classification
cls_embedding = outputs.last_hidden_state[:, 0, :]
# Pass through our code improvement model
prediction = code_improvement_model(cls_embedding)
return prediction.item() # Return the probability of needing improvement
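
# Note: code_improvement_model starts from randomly initialized weights; the optimizer and
# criterion defined earlier are never invoked in this script, so the quality score only
# becomes meaningful once a training loop is added.
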
def visualize_code_structure(code):
try:
tree = ast.parse(code)
graph = nx.DiGraph()
def add_nodes_edges(node, parent=None):
node_id = id(node)
graph.add_node(node_id, label=type(node).__name__)
if parent:
graph.add_edge(id(parent), node_id)
for child in ast.iter_child_nodes(node):
add_nodes_edges(child, node)
add_nodes_edges(tree)
        fig = plt.figure(figsize=(12, 8))
        pos = nx.spring_layout(graph)
        nx.draw(graph, pos, with_labels=False, node_color='lightblue', node_size=1000)
        labels = nx.get_node_attributes(graph, 'label')
        nx.draw_networkx_labels(graph, pos, labels, font_size=8, font_weight='bold')
        return fig
except SyntaxError:
return None
# Streamlit UI setup
st.set_page_config(page_title="Advanced AI Code Assistant", page_icon="🚀", layout="wide")
st.markdown("""
""", unsafe_allow_html=True)
st.markdown('
', unsafe_allow_html=True)
st.title("🚀 Advanced AI Code Assistant")
st.markdown("Powered by Google Gemini & Deep Learning")
prompt = st.text_area("What advanced code task can I help you with today?", height=120)
if st.button("Generate Advanced Code"):
if prompt.strip() == "":
st.error("Please enter a valid prompt.")
else:
with st.spinner("Generating and analyzing code..."):
completed_text = generate_response(prompt)
if "Error" in completed_text:
st.error(completed_text)
else:
optimized_code, lint_results = optimize_code(completed_text)
if "SyntaxError" in lint_results:
st.warning(f"Syntax error detected: {lint_results}")
st.code(completed_text)
else:
quality_score = analyze_code_quality(optimized_code)
st.success(f"Code generated and optimized successfully! Quality Score: {quality_score:.2f}")
st.code(optimized_code)
visualization = visualize_code_structure(optimized_code)
if visualization:
with st.expander("View Code Structure Visualization"):
st.pyplot(visualization)
else:
st.warning("Unable to generate code structure visualization due to syntax errors.")
with st.expander("View Lint Results"):
st.text(lint_results)
with st.expander("Fetch Similar Code from GitHub"):
github_results = fetch_from_github(prompt)
for item in github_results:
st.markdown(f"[{item['name']}]({item['html_url']})")
st.markdown("""
Crafted with 🚀 by Your Advanced AI Code Assistant
""", unsafe_allow_html=True)