import streamlit as st
import google.generativeai as genai
import requests
import subprocess
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel
import ast
import networkx as nx
import matplotlib.pyplot as plt
# Configure the Gemini API client with the key stored in Streamlit secrets.
genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])

# Persona passed to the model as its system instruction.
_SYSTEM_INSTRUCTION = """
You are Ath, a highly advanced code assistant with deep knowledge in AI, machine learning, and software engineering. You provide cutting-edge, optimized, and secure code solutions. Speak casually and use tech jargon when appropriate.
"""

# Sampling settings (temperature / nucleus top_p / top_k) and the reply
# length cap, in output tokens.
generation_config = dict(
    temperature=0.6,
    top_p=0.8,
    top_k=30,
    max_output_tokens=16384,
)

model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=generation_config,
    system_instruction=_SYSTEM_INSTRUCTION,
)

# Chat session that starts with an empty history.
# NOTE(review): as module-level state in a Streamlit script this is presumably
# re-created on every rerun, so history does not carry over between clicks —
# confirm whether multi-turn chat is intended.
chat_session = model.start_chat(history=[])
# Load the pre-trained CodeBERT tokenizer and encoder used to embed code.
# show_spinner=False keeps the cached call from rendering a spinner element
# ahead of st.set_page_config (older Streamlit releases require that call to
# be the first Streamlit command).
@st.cache_resource(show_spinner=False)
def _load_codebert(model_name="microsoft/codebert-base"):
    """Return ``(tokenizer, model)`` for *model_name*, loaded once per process.

    Streamlit re-executes the whole script on every widget interaction;
    caching avoids re-reading the model weights from disk on every rerun.
    The returned objects are shared by all sessions of the app.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )


tokenizer, codebert_model = _load_codebert()
class CodeImprovement(nn.Module):
    """Four-layer MLP head: input_dim -> 512 -> 256 -> 128 -> 2.

    ReLU follows each hidden layer; the two outputs pass through an
    element-wise sigmoid, so each lies in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute names fc1..fc4 become the state_dict keys; keep them.
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 2)  # two outputs: "needs improvement" or not

    def forward(self, x):
        """Map a ``(..., input_dim)`` tensor to ``(..., 2)`` sigmoid scores."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.relu(layer(hidden))
        # Independent sigmoid per unit (not a softmax): the two scores need not sum to 1.
        return torch.sigmoid(self.fc4(hidden))
# Classifier head applied to CodeBERT's 768-dimensional [CLS] embedding.
# NOTE(review): no training loop or checkpoint loading for this head is visible
# in this file, so it appears to run with freshly initialised random weights.
code_improvement_model = CodeImprovement(input_dim=768)
# Optimizer and loss for training the head; neither is referenced anywhere
# else in the visible source.
optimizer = optim.Adam(params=code_improvement_model.parameters())
criterion = nn.BCELoss()
def generate_response(user_input):
    """Send *user_input* to the module-level Gemini chat session.

    Args:
        user_input: Prompt text to send.

    Returns:
        The reply text on success. On any exception (raised by the request
        or while reading the reply text), a string starting with
        "Error in generating response: " followed by the exception message.
    """
    try:
        reply = chat_session.send_message(user_input)
        text = reply.text
    except Exception as exc:  # API boundary: surface failures to the UI as text
        return "Error in generating response: " + str(exc)
    return text
def validate_and_fix_code(code):
    """Close single-line string literals that Python reports as unterminated.

    The code is parsed with ``ast``. Whenever the parser reports an
    unterminated string literal, a closing quote is appended to that line.
    The quote used is whichever of ``"`` or ``'`` actually terminates the
    literal. The code is then parsed again. Other syntax errors are left
    for the caller to report, and code that already parses is returned
    unchanged.

    Letting the parser decide avoids the pitfalls of counting quote
    characters per line, which misfires on apostrophes inside strings and on
    triple-quoted docstrings.

    Args:
        code: Python source text.

    Returns:
        The source with zero or more closing quotes appended to lines.
    """
    lines = code.split('\n')
    # Each successful repair closes one literal, so one pass per line suffices.
    for _ in range(len(lines)):
        line_no = _unterminated_string_line(lines)
        if line_no is None:
            break
        repaired = _close_string_on_line(lines, line_no)
        if repaired is None:
            break  # neither quote closes it (e.g. trailing backslash); give up
        lines = repaired
    return '\n'.join(lines)


def _unterminated_string_line(lines):
    """Return the 1-based line of an unterminated-string parse error, else None."""
    try:
        ast.parse('\n'.join(lines))
    except SyntaxError as err:
        message = err.msg or ""
        # "unterminated ... string literal" on Python 3.10+,
        # "EOL while scanning string literal" on older versions.
        if message.startswith((
            "unterminated string literal",
            "unterminated f-string literal",
            "unterminated t-string literal",
            "EOL while scanning string literal",
        )) and err.lineno:
            return err.lineno
    except ValueError:  # e.g. null bytes on Python < 3.12
        pass
    return None


def _close_string_on_line(lines, line_no):
    """Return a copy of *lines* with a quote appended that closes line *line_no*'s string, else None."""
    index = line_no - 1
    if not 0 <= index < len(lines):
        return None
    for quote in ('"', "'"):
        candidate = lines[:index] + [lines[index] + quote] + lines[index + 1:]
        if _unterminated_string_line(candidate) != line_no:
            return candidate
    return None
def optimize_code(code):
    """Extract, repair, syntax-check and lint generated Python code.

    Replies from the chat model are typically Markdown. When *code* contains
    fenced code blocks, only their contents are kept (see
    ``_extract_fenced_code``); plain source is used as-is. The result is
    passed through ``validate_and_fix_code`` and ``ast.parse``. Code that
    parses is then linted with pylint.

    Args:
        code: Python source, or a Markdown reply containing fenced code.

    Returns:
        A ``(code, report)`` tuple:

        * ``code`` is the extracted and repaired source.
        * ``report`` starts with ``"SyntaxError: "`` when the source does
          not parse.
        * Otherwise ``report`` is pylint's stdout, or a short note when
          pylint could not be run.
    """
    fixed_code = validate_and_fix_code(_extract_fenced_code(code))
    try:
        ast.parse(fixed_code)
    except (SyntaxError, ValueError) as e:  # ValueError: null bytes on Python < 3.12
        return fixed_code, f"SyntaxError: {e}"
    # Placeholder for actual optimization logic
    optimized_code = fixed_code
    return optimized_code, _run_pylint(optimized_code)


def _extract_fenced_code(text):
    """Return the code inside Markdown ``` fences in *text*, or *text* unchanged.

    Python-tagged fences (python / py / python3) are preferred, then untagged
    ones; several matching blocks are joined with a blank line.
    """
    import re

    blocks = re.findall(r"```[ \t]*([\w.+-]*)[^\n]*\n(.*?)```", text, re.DOTALL)
    for wanted in ({"python", "py", "python3"}, {""}):
        chosen = [body.strip("\n") for tag, body in blocks if tag.lower() in wanted]
        if chosen:
            return "\n\n".join(chosen)
    return text


def _run_pylint(source):
    """Run pylint on *source* via a private temporary file and return its stdout."""
    import tempfile

    # A unique file per call: a fixed "temp_code.py" in the working directory
    # would be shared, and clobbered, by concurrent runs of the app.
    fd, path = tempfile.mkstemp(suffix=".py")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(source)
        try:
            result = subprocess.run(
                ["pylint", path], capture_output=True, text=True, timeout=120
            )
        except FileNotFoundError:
            return "pylint could not be run: executable not found on PATH."
        except subprocess.TimeoutExpired:
            return "pylint could not be run: timed out after 120 seconds."
        return result.stdout
    finally:
        # Remove the temp file even if writing it or running pylint failed.
        os.remove(path)
def fetch_from_github(query, limit=5):
    """Search GitHub code for *query* and return up to *limit* result items.

    Args:
        query: Free-text search string, sent URL-encoded as the ``q`` parameter.
        limit: Maximum number of results to return (default 5).

    Returns:
        A list of result dicts from GitHub's code-search endpoint; the UI
        reads their ``name`` and ``html_url`` keys. Returns an empty list
        when the request fails or times out, when the status is not 200, or
        when the body is not valid JSON.
    """
    headers = {"Accept": "application/vnd.github+json"}
    # Send the token when configured. A missing secret no longer raises
    # KeyError; GitHub then answers with a non-200 status instead.
    token = st.secrets.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"token {token}"
    try:
        response = requests.get(
            "https://api.github.com/search/code",
            params={"q": query, "per_page": limit},  # requests handles URL encoding
            headers=headers,
            timeout=10,
        )
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    try:
        items = response.json().get("items", [])
    except ValueError:  # body was not valid JSON
        return []
    return items[:limit]
def analyze_code_quality(code):
    """Score *code* with the CodeBERT encoder plus the CodeImprovement head.

    Args:
        code: Source text; it is truncated to at most 512 tokens.

    Returns:
        A float in [0, 1]: the sigmoid output of the head's first unit. The
        file describes this as the probability that the code needs
        improvement.

    NOTE(review): no training or checkpoint loading for
    ``code_improvement_model`` is visible in this file, so the score
    appears to come from random weights. Confirm before relying on it.
    """
    # A single sequence needs no padding; pad positions would be masked out anyway.
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = codebert_model(**inputs)
        # Use the [CLS] (first-position) embedding as the sequence summary.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        prediction = code_improvement_model(cls_embedding)
    # The head emits two units per sample (shape (1, 2) here), so calling
    # .item() on the whole tensor raises; select one unit explicitly.
    # NOTE(review): unit 0 is assumed to be "needs improvement", following the
    # order in CodeImprovement's comment ("needs improvement or not") — confirm.
    return prediction[0, 0].item()
def visualize_code_structure(code):
    """Render the abstract syntax tree of *code* as a directed graph.

    Each AST node becomes a graph node labelled with its class name
    (``Module``, ``FunctionDef``, ...). An edge runs from every node to each
    of its direct children.

    Args:
        code: Python source text.

    Returns:
        A ``matplotlib.figure.Figure`` with the drawing (usable with
        ``st.pyplot``), or ``None`` if *code* cannot be parsed.
    """
    from matplotlib.figure import Figure

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):  # ValueError: null bytes on Python < 3.12
        return None

    graph = nx.DiGraph()
    # ast.walk is iterative, so deeply nested code cannot exhaust the
    # recursion limit the way a recursive traversal could.
    for node in ast.walk(tree):
        graph.add_node(id(node), label=type(node).__name__)
        for child in ast.iter_child_nodes(node):
            graph.add_edge(id(node), id(child))

    # A standalone Figure is not registered with pyplot's global figure
    # manager, so repeated reruns do not pile up open figures.
    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()
    pos = nx.spring_layout(graph, seed=0)  # fixed seed: stable layout across reruns
    # Skip nx.draw's default labels (they would be the raw id() integers) and
    # overlay the AST class names instead.
    nx.draw(graph, pos, ax=ax, with_labels=False, node_color='lightblue', node_size=1000)
    nx.draw_networkx_labels(
        graph, pos, labels=nx.get_node_attributes(graph, 'label'), font_size=6, ax=ax
    )
    return fig
# Streamlit UI setup
st.set_page_config(page_title="Advanced AI Code Assistant", page_icon="🚀", layout="wide")
# NOTE(review): the custom CSS and HTML wrapper markup that several
# st.markdown calls here once carried appears to have been stripped from this
# file. Those literals were left empty or split across lines, which is a
# SyntaxError. The now-empty calls are omitted; restore the markup here if the
# custom styling is wanted.
st.title("🚀 Advanced AI Code Assistant")
st.markdown("Powered by Google Gemini & Deep Learning", unsafe_allow_html=True)

prompt = st.text_area("What advanced code task can I help you with today?", height=120)

if st.button("Generate Advanced Code"):
    if not prompt.strip():
        st.error("Please enter a valid prompt.")
    else:
        with st.spinner("Generating and analyzing code..."):
            completed_text = generate_response(prompt)
            # generate_response signals failure with this prefix; a substring
            # test would also fire on a reply that merely quotes the phrase.
            if completed_text.startswith("Error in generating response"):
                st.error(completed_text)
            else:
                optimized_code, lint_results = optimize_code(completed_text)
                # optimize_code reports parse failures as "SyntaxError: ...";
                # pylint output can mention SyntaxError without being one.
                if lint_results.startswith("SyntaxError"):
                    st.warning(f"Syntax error detected in the generated code. {lint_results}")
                    st.code(optimized_code)
                    st.info("Please review the code above. It may contain errors or be incomplete.")
                else:
                    # NOTE(review): analyze_code_quality describes its result as
                    # the probability the code *needs improvement*, so a higher
                    # "Quality Score" may mean worse code — consider relabelling.
                    quality_score = analyze_code_quality(optimized_code)
                    st.success(f"Code generated and optimized successfully! Quality Score: {quality_score:.2f}")
                    st.code(optimized_code)
                    visualization = visualize_code_structure(optimized_code)
                    if visualization is not None:
                        with st.expander("View Code Structure Visualization"):
                            st.pyplot(visualization)
                    else:
                        st.warning("Unable to generate code structure visualization due to syntax errors.")
                    with st.expander("View Lint Results"):
                        st.text(lint_results)
                    with st.expander("Fetch Similar Code from GitHub"):
                        github_results = fetch_from_github(prompt)
                        for item in github_results:
                            st.markdown(f"[{item['name']}]({item['html_url']})")
                        if not github_results:
                            st.info("No similar code found on GitHub.")

st.markdown("""
Crafted with 🚀 by Your Advanced AI Code Assistant
""", unsafe_allow_html=True)