|
import ast
import os
import re
import subprocess
import tempfile

import autopep8
import black
import clang.cindex
import google.generativeai as genai
import isort
import javalang
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import radon.complexity as radon_complexity
import radon.metrics as radon_metrics
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel, pipeline
|
|
|
|
|
# Authenticate the Gemini client with the API key read from Streamlit secrets.
genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])

# Generation settings handed to the GenerativeModel constructed below.
# NOTE(review): confirm that 32768 is within the output-token limit of the
# selected model.
generation_config = {
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
    "max_output_tokens": 32768,
}

# Gemini model bound to the settings above and to a fixed system prompt.
model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=generation_config,
    system_instruction="""
You are Ath, an extremely advanced code assistant with deep expertise in AI, machine learning, software engineering, and multiple programming languages. You provide cutting-edge, optimized, and secure code solutions across various domains. Use your vast knowledge to generate high-quality code, perform advanced analyses, and offer insightful optimizations. Adapt your language and explanations based on the user's expertise level.
"""
)
# Chat session started with an empty history; generate_response() sends the
# user's prompt through this object.
chat_session = model.start_chat(history=[])
|
|
|
|
|
# Streamlit re-executes this script on every user interaction. Without caching,
# the CodeBERT encoder and the GPT-Neo 2.7B pipeline would be reloaded on each
# rerun. st.cache_resource keeps one shared instance per server process.
# show_spinner=False stops the loaders from rendering anything before the
# st.set_page_config() call that appears later in this script.
@st.cache_resource(show_spinner=False)
def _load_codebert():
    """Load the CodeBERT tokenizer and encoder (cached across reruns)."""
    return (
        AutoTokenizer.from_pretrained("microsoft/codebert-base"),
        AutoModel.from_pretrained("microsoft/codebert-base"),
    )


@st.cache_resource(show_spinner=False)
def _load_code_generator():
    """Load the GPT-Neo text-generation pipeline (cached across reruns)."""
    return pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B")


# Module-level names are kept so the rest of the file can keep using them.
tokenizer, codebert_model = _load_codebert()
# NOTE(review): code_generation_model is not referenced anywhere in the visible
# part of this file; confirm it is needed before paying for the 2.7B load.
code_generation_model = _load_code_generator()
|
|
|
class AdvancedCodeImprovement(nn.Module):
    """Eight-layer fully connected network ending in four sigmoid outputs.

    The layers are registered as ``fc1`` ... ``fc8`` and shrink the feature
    width from ``input_dim`` through 1024, 512, 256, 128, 64, 32 and 16 down
    to 4. ReLU follows each of the first seven layers; the last layer is
    passed through a sigmoid.
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 1024, 512, 256, 128, 64, 32, 16, 4)
        # Create the layers in order fc1..fc8 so attribute names and the
        # parameter registration order stay exactly as before.
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"fc{position}", nn.Linear(fan_in, fan_out))

    def forward(self, x):
        """Apply fc1..fc7 with ReLU, then fc8 with a sigmoid."""
        relu_stack = (
            self.fc1,
            self.fc2,
            self.fc3,
            self.fc4,
            self.fc5,
            self.fc6,
            self.fc7,
        )
        for layer in relu_stack:
            x = torch.relu(layer(x))
        return torch.sigmoid(self.fc8(x))
|
|
|
# Scoring network used by analyze_code_quality() on the CodeBERT CLS embedding.
# Presumably 768 is the CodeBERT-base hidden size — TODO confirm.
# NOTE(review): no training loop or checkpoint load is visible in this part of
# the file, so the network appears to run with its initial random weights —
# confirm where (or whether) it is trained.
code_improvement_model = AdvancedCodeImprovement(768)

# Adam optimizer over the scoring network's parameters (default settings).
optimizer = optim.Adam(code_improvement_model.parameters())

# Binary cross-entropy loss; not used in the visible part of this file.
criterion = nn.BCELoss()
|
|
|
def generate_response(user_input):
    """Send ``user_input`` through the module-level chat session.

    Returns the reply text. If anything raises while sending the message or
    reading the reply, the exception is converted into a string that starts
    with ``"Error in generating response: "`` instead of propagating.
    """
    try:
        reply_text = chat_session.send_message(user_input).text
    except Exception as exc:
        return f"Error in generating response: {exc}"
    return reply_text
|
|
|
def detect_language(code):
    """Guess the programming language of ``code`` from keyword heuristics.

    Returns one of ``'python'``, ``'javascript'``, ``'java'``, ``'c++'`` or
    ``'unknown'``. The checks run in order and the first match wins.
    """
    # A preprocessor include is the most specific signal, so test it first.
    # Otherwise C++ sources containing `class` or `const` are claimed by the
    # Python / JavaScript checks below. The pattern has no leading \b because
    # `#` is not a word character: r'\b#include' cannot match at the start of
    # a line.
    if re.search(r'#include\b', code):
        return 'c++'
    if re.search(r'\b(def|class|import)\b', code):
        return 'python'
    if re.search(r'\b(function|var|let|const)\b', code):
        return 'javascript'
    if re.search(r'\b(public|private|class)\b', code):
        return 'java'
    if re.search(r'\bint\s+main\b', code):
        return 'c++'
    return 'unknown'
|
|
|
def validate_and_fix_code(code, language):
    """Auto-format ``code`` and report any formatting failure.

    Always returns a ``(code, error)`` pair so callers can unpack the result:
    ``error`` is ``None`` on success, or a message starting with
    ``"Error in fixing Python code: "`` when a formatter raised (in which
    case the original, unmodified ``code`` is returned).

    Only Python is formatted (autopep8, then isort, then black). Every other
    language is passed through unchanged.
    """
    if language != 'python':
        # No formatter is wired up for JavaScript, Java, C++ or unknown input.
        return code, None

    try:
        fixed_code = autopep8.fix_code(code)
        # isort 5 replaced the SortImports class with isort.code(); support
        # both so the function works with whichever version is installed.
        if hasattr(isort, "code"):
            fixed_code = isort.code(fixed_code)
        else:
            fixed_code = isort.SortImports(file_contents=fixed_code).output
        fixed_code = black.format_str(fixed_code, mode=black.FileMode())
    except Exception as e:
        return code, f"Error in fixing Python code: {str(e)}"
    return fixed_code, None
|
|
|
def optimize_code(code):
    """Format, optimize and lint ``code``.

    Returns a ``(code, report)`` pair. On success ``report`` is the output of
    ``run_linter``. When formatting or parsing fails, the pair holds the best
    available code and an error message containing the word "Error".
    """
    language = detect_language(code)

    # validate_and_fix_code has returned either a bare string or a
    # (code, error) pair; accept both so unpacking can never fail.
    fix_result = validate_and_fix_code(code, language)
    if isinstance(fix_result, tuple):
        fixed_code, fix_error = fix_result
    else:
        fixed_code, fix_error = fix_result, None

    if fix_error:
        return fixed_code, fix_error

    # NOTE(review): PythonCodeOptimizer, JavaCodeOptimizer and CppCodeOptimizer
    # are not defined in the visible part of this file — confirm they exist.
    # The local is named code_optimizer so it does not shadow the module-level
    # torch `optimizer`.
    if language == 'python':
        try:
            tree = ast.parse(fixed_code)
            code_optimizer = PythonCodeOptimizer()
            optimized_tree = code_optimizer.visit(tree)
            optimized_code = ast.unparse(optimized_tree)
        except SyntaxError as e:
            return fixed_code, f"SyntaxError: {str(e)}"
    elif language == 'java':
        try:
            tree = javalang.parse.parse(fixed_code)
            code_optimizer = JavaCodeOptimizer()
            optimized_code = code_optimizer.optimize(tree)
        except javalang.parser.JavaSyntaxError as e:
            return fixed_code, f"JavaSyntaxError: {str(e)}"
    elif language == 'c++':
        try:
            index = clang.cindex.Index.create()
            # The source is supplied in memory under the name temp.cpp.
            tu = index.parse('temp.cpp', args=['-std=c++14'], unsaved_files=[('temp.cpp', fixed_code)])
            code_optimizer = CppCodeOptimizer()
            optimized_code = code_optimizer.optimize(tu)
        except Exception as e:
            return fixed_code, f"C++ Parsing Error: {str(e)}"
    else:
        # JavaScript and unknown input have no optimizer; keep the code as is.
        optimized_code = fixed_code

    lint_results = run_linter(optimized_code, language)

    return optimized_code, lint_results
|
|
|
def run_linter(code, language):
    """Lint ``code`` and return the linter's textual report.

    Python code is written to ``temp_code.py`` inside a private temporary
    directory and checked with ``pylint``; the directory is removed even if
    the linter fails to start. Other languages return a fixed message
    because no linter is wired up for them.
    """
    if language == 'python':
        # A private directory avoids collisions between concurrent sessions,
        # which a fixed file name in the working directory would cause.
        # Running pylint from that directory keeps the reported module name
        # as temp_code.
        # NOTE(review): pylint now runs outside the project directory, so a
        # project-local pylintrc may no longer be picked up — confirm.
        with tempfile.TemporaryDirectory() as work_dir:
            script_path = os.path.join(work_dir, "temp_code.py")
            with open(script_path, "w", encoding="utf-8") as file:
                file.write(code)
            try:
                result = subprocess.run(
                    ["pylint", "temp_code.py"],
                    capture_output=True,
                    text=True,
                    cwd=work_dir,
                )
            except FileNotFoundError:
                # The pylint executable is not on PATH.
                return "Linting unavailable: pylint is not installed"
        return result.stdout
    elif language == 'javascript':
        return "JavaScript linting not implemented"
    elif language == 'java':
        return "Java linting not implemented"
    elif language == 'c++':
        return "C++ linting not implemented"
    else:
        return "Linting not available for the detected language"
|
|
|
def fetch_from_github(query):
    """Search GitHub code for ``query`` and return at most five result items.

    Returns an empty list when the request fails, times out, answers with a
    non-200 status, or the response body is not the expected JSON object.
    """
    headers = {"Authorization": f"token {st.secrets['GITHUB_TOKEN']}"}
    try:
        response = requests.get(
            "https://api.github.com/search/code",
            # Passing the query via params URL-encodes it, so characters such
            # as '&' or '#' in the prompt cannot corrupt the request.
            params={"q": query},
            headers=headers,
            # Without a timeout a stalled connection would block the app.
            timeout=10,
        )
    except requests.RequestException:
        return []

    if response.status_code != 200:
        return []

    try:
        payload = response.json()
    except ValueError:
        return []
    if not isinstance(payload, dict):
        return []
    return payload.get('items', [])[:5]
|
|
|
def analyze_code_quality(code):
    """Score ``code`` and return a dict of quality metrics.

    The CodeBERT embedding of the first token is fed to
    ``code_improvement_model``; its four outputs are reported as ``style``,
    ``efficiency``, ``security`` and ``maintainability``. When the code is
    detected as Python and parses, radon adds ``maintainability_index`` and,
    if the code has at least one function or class, ``cyclomatic_complexity``
    (the highest value among the blocks radon reports).
    """
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding="max_length")

    # Inference only: keep both forward passes out of autograd.
    with torch.no_grad():
        outputs = codebert_model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        predictions = code_improvement_model(cls_embedding)

    quality_scores = {
        "style": predictions[0][0].item(),
        "efficiency": predictions[0][1].item(),
        "security": predictions[0][2].item(),
        "maintainability": predictions[0][3].item()
    }

    language = detect_language(code)
    if language == 'python':
        try:
            blocks = radon_complexity.cc_visit(code)
            maintainability = radon_metrics.mi_visit(code, True)
        except SyntaxError:
            # Language detection is heuristic; text that is not valid Python
            # simply gets no radon metrics.
            return quality_scores
        # cc_visit returns an empty list when the code has no functions or
        # classes, so indexing [0] would fail. Report the most complex block
        # rather than whichever block happens to come first.
        if blocks:
            quality_scores["cyclomatic_complexity"] = max(block.complexity for block in blocks)
        quality_scores["maintainability_index"] = maintainability

    return quality_scores
|
|
|
def visualize_code_structure(code):
    """Draw the Python AST of ``code`` as a directed graph.

    Returns the ``matplotlib.pyplot`` module with the drawing on its current
    figure, or ``None`` when ``code`` cannot be parsed as Python.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers sources that ast.parse rejects before parsing
        # (for example embedded null bytes on older Python versions).
        return None

    graph = nx.DiGraph()

    def add_nodes_edges(node, parent_id=None):
        # Use a running index as the node key. id(node) is not unique per
        # occurrence when the same AST object (for example an expression
        # context such as Load) is referenced from several places, which
        # would merge unrelated nodes.
        node_id = graph.number_of_nodes()
        graph.add_node(node_id, label=f"{type(node).__name__}\n{ast.unparse(node)[:20]}")
        if parent_id is not None:
            graph.add_edge(parent_id, node_id)
        for child in ast.iter_child_nodes(node):
            add_nodes_edges(child, node_id)

    add_nodes_edges(tree)

    plt.figure(figsize=(15, 10))
    pos = nx.spring_layout(graph, k=0.9, iterations=50)
    # with_labels=False: the default labels are the raw node keys and would
    # be printed underneath the descriptive labels drawn below.
    nx.draw(graph, pos, with_labels=False, node_color='lightblue', node_size=2000, arrows=True)
    labels = nx.get_node_attributes(graph, 'label')
    nx.draw_networkx_labels(graph, pos, labels, font_size=6)

    return plt
|
|
|
def suggest_improvements(code, quality_scores):
    """Turn quality scores into a list of human-readable suggestions.

    A suggestion is added for each of ``style``, ``efficiency``,
    ``security`` and ``maintainability`` whose score is below its threshold,
    and one more when ``cyclomatic_complexity`` is present and above 10.
    ``code`` is accepted but not inspected.
    """
    # (score key, minimum acceptable value, message shown when below it)
    threshold_rules = (
        ("style", 0.7, "Consider improving code style for better readability."),
        ("efficiency", 0.7, "There might be room for optimizing the code's efficiency."),
        ("security", 0.8, "Review the code for potential security vulnerabilities."),
        ("maintainability", 0.7, "The code could be refactored to improve maintainability."),
    )
    advice = [
        message
        for key, minimum, message in threshold_rules
        if quality_scores[key] < minimum
    ]

    if "cyclomatic_complexity" not in quality_scores:
        return advice
    if quality_scores["cyclomatic_complexity"] > 10:
        advice.append("Consider breaking down complex functions to reduce cyclomatic complexity.")
    return advice
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: prompt box -> Gemini response -> optimize/lint -> metrics.
# NOTE(review): the page icon / title / footer glyph "π" looks like a
# mis-encoded emoji — confirm the intended character.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Highly Advanced AI Code Assistant", page_icon="π", layout="wide")

st.markdown('<div class="main-container">', unsafe_allow_html=True)
st.title("π Highly Advanced AI Code Assistant")
st.markdown('<p class="subtitle">Powered by Advanced AI & Multi-Domain Expertise</p>', unsafe_allow_html=True)

prompt = st.text_area("What advanced code task can I assist you with today?", height=120)

if st.button("Generate Advanced Code"):
    if not prompt.strip():
        st.error("Please enter a valid prompt.")
    else:
        with st.spinner("Generating and analyzing code..."):
            completed_text = generate_response(prompt)
            # generate_response reports failures as a string with this prefix.
            if "Error in generating response" in completed_text:
                st.error(completed_text)
            else:
                optimized_code, lint_results = optimize_code(completed_text)

                # optimize_code signals formatting/parsing failures through a
                # report string containing "Error".
                if "Error" in lint_results:
                    st.warning("Issues detected in the generated code. Attempting to fix...")
                    st.code(optimized_code)
                    st.info("Please review the code above. It may contain errors or be incomplete.")
                else:
                    quality_scores = analyze_code_quality(optimized_code)
                    # Average only the four model scores. quality_scores may
                    # also hold cyclomatic complexity and a maintainability
                    # index, which are on different scales and would distort
                    # the mean.
                    model_score_keys = ("style", "efficiency", "security", "maintainability")
                    overall_quality = sum(quality_scores[key] for key in model_score_keys) / len(model_score_keys)
                    st.success(f"Code generated and optimized successfully! Overall Quality Score: {overall_quality:.2f}")

                    st.markdown('<div class="output-container">', unsafe_allow_html=True)
                    st.markdown('<div class="code-block">', unsafe_allow_html=True)
                    st.code(optimized_code)
                    st.markdown('</div>', unsafe_allow_html=True)

                    col1, col2 = st.columns(2)
                    with col1:
                        st.subheader("Code Quality Metrics")
                        for metric, score in quality_scores.items():
                            st.metric(metric.capitalize(), f"{score:.2f}")

                    with col2:
                        st.subheader("Improvement Suggestions")
                        suggestions = suggest_improvements(optimized_code, quality_scores)
                        for suggestion in suggestions:
                            st.info(suggestion)

                    visualization = visualize_code_structure(optimized_code)
                    if visualization:
                        with st.expander("View Advanced Code Structure Visualization"):
                            st.pyplot(visualization)
                    else:
                        st.warning("Unable to generate code structure visualization.")

                    with st.expander("View Detailed Lint Results"):
                        st.text(lint_results)

                    with st.expander("Explore Similar Code from GitHub"):
                        github_results = fetch_from_github(prompt)
                        for item in github_results:
                            st.markdown(f"[{item['name']}]({item['html_url']})")

                    st.markdown('</div>', unsafe_allow_html=True)

st.markdown("""
<div style='text-align: center; margin-top: 2rem; color: #4a5568;'>
Crafted with π by Your Highly Advanced AI Code Assistant
</div>
""", unsafe_allow_html=True)

st.markdown('</div>', unsafe_allow_html=True)