"""Streamlit app: an AI code assistant backed by Google Gemini.

Flow: the user's prompt goes to Gemini; Python code is pulled out of the reply,
normalised with an ``ast`` round-trip, linted with pylint, scored with CodeBERT
embeddings fed to a small classifier head, and its AST is drawn as a graph.
"""

import ast
import os
import re
import subprocess
import tempfile

import google.generativeai as genai
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import requests
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from transformers import AutoModel, AutoTokenizer

# st.set_page_config must be the first Streamlit command the script runs, so it
# comes before the cached loaders below (which could otherwise render output).
st.set_page_config(page_title="Advanced AI Code Assistant", page_icon="🚀", layout="wide")

# Configure the Gemini API.
genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])

# Sampling parameters for the Gemini model.
generation_config = {
    "temperature": 0.6,
    "top_p": 0.8,
    "top_k": 30,
    "max_output_tokens": 16384,
}

model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=generation_config,
    system_instruction="""
You are Ath, a highly advanced code assistant with deep knowledge in AI, machine learning, and software engineering. You provide cutting-edge, optimized, and secure code solutions. Speak casually and use tech jargon when appropriate.
When you return Python code, put it in a fenced ```python block so it can be extracted and analyzed.
""",
)

# NOTE: Streamlit re-executes this script on every interaction, so this starts a
# fresh chat (empty history) on each rerun.
chat_session = model.start_chat(history=[])

# Prefix generate_response() puts on failure messages; the UI checks for it.
_ERROR_PREFIX = "Error: "


@st.cache_resource(show_spinner=False)
def _load_codebert():
    """Load the CodeBERT tokenizer and encoder once per server process."""
    tok = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    encoder = AutoModel.from_pretrained("microsoft/codebert-base")
    encoder.eval()
    return tok, encoder


tokenizer, codebert_model = _load_codebert()


class CodeImprovement(nn.Module):
    """MLP head mapping an ``input_dim``-sized embedding to two sigmoid scores.

    analyze_code_quality() reads output column 1 as "needs improvement".
    NOTE(review): nothing in this file trains or loads weights for this head,
    so its outputs are random until a trained checkpoint is loaded.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 2)  # Binary classification: needs improvement or not

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))


@st.cache_resource(show_spinner=False)
def _build_code_improvement_model():
    """Create the classifier head once per server process so scores don't change per rerun."""
    head = CodeImprovement(768)  # 768 is CodeBERT's hidden size
    head.eval()
    return head


code_improvement_model = _build_code_improvement_model()
# Kept for a future training loop; nothing in this file uses them yet.
optimizer = optim.Adam(code_improvement_model.parameters())
criterion = nn.BCELoss()


def generate_response(user_input):
    """Send *user_input* to the Gemini chat and return the reply text.

    Any exception is returned as a string starting with ``_ERROR_PREFIX``
    instead of being raised, so the UI can display it.
    """
    try:
        response = chat_session.send_message(user_input)
        return response.text
    except Exception as e:  # top-level UI boundary: report, don't crash the app
        return f"{_ERROR_PREFIX}{e}"


# A Markdown code fence "```lang ...\n<body>```". The closing fence may be
# missing when the reply was cut off at max_output_tokens.
_FENCE_RE = re.compile(r"```([\w+-]*)[^\n]*\n(.*?)(?:```|\Z)", re.DOTALL)
_PYTHON_FENCE_TAGS = {"", "python", "python3", "py"}


def _extract_code(text):
    """Return the Python code inside Markdown code fences in *text*.

    All Python-tagged or untagged fenced blocks are joined with a blank line.
    If *text* has no such block it is returned unchanged, so plain-code
    replies still work.
    """
    blocks = [
        body.strip("\n")
        for tag, body in _FENCE_RE.findall(text)
        if tag.lower() in _PYTHON_FENCE_TAGS
    ]
    return "\n\n".join(blocks) if blocks else text


def _run_pylint(source):
    """Lint *source* with pylint and return pylint's stdout (or a short notice)."""
    # A private temp directory per call avoids clobbering between concurrent
    # sessions and is always cleaned up; the file name keeps pylint's
    # "Module temp_code" header the same as before.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "temp_code.py")
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(source)
        try:
            result = subprocess.run(
                ["pylint", path], capture_output=True, text=True, timeout=60
            )
        except FileNotFoundError:
            return "pylint is not installed; lint step skipped."
        except subprocess.TimeoutExpired:
            return "pylint timed out; lint step skipped."
    return result.stdout


def optimize_code(code):
    """Normalise *code* through the AST and lint it.

    Returns ``(code_out, report)``. On a syntax error, ``code_out`` is the
    input unchanged and ``report`` starts with ``"SyntaxError: "``. Otherwise
    ``code_out`` is the ``ast.unparse`` round-trip (comments are dropped) and
    ``report`` is pylint's output.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return code, f"SyntaxError: {str(e)}"
    # NOTE(review): this previously called CodeAnalyzer / CodeTransformer, which
    # are neither defined nor imported in this file (NameError at runtime). Until
    # real transformations exist, "optimising" is just an AST round-trip.
    optimized_code = ast.unparse(tree)
    return optimized_code, _run_pylint(optimized_code)


def fetch_from_github(query, limit=5):
    """Return up to *limit* GitHub code-search result items for *query*.

    Returns an empty list on any network error or non-200 response.
    """
    headers = {}
    token = st.secrets.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"token {token}"
    try:
        # params= URL-encodes the query; GitHub rejects queries over 256 chars.
        response = requests.get(
            "https://api.github.com/search/code",
            params={"q": query[:256]},
            headers=headers,
            timeout=10,
        )
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    return response.json().get("items", [])[:limit]


def interact_with_api(api_url, timeout=10):
    """GET *api_url* and return its decoded JSON body.

    *timeout* (seconds) keeps a slow endpoint from hanging the app.
    """
    response = requests.get(api_url, timeout=timeout)
    return response.json()


def train_ml_model(code_data, test_size=0.2, random_state=None):
    """Fit a random forest on tabular *code_data* and return the fitted classifier.

    *code_data* is anything ``pd.DataFrame`` accepts, with a ``'target'``
    label column; every other column is used as a feature. *test_size* and
    *random_state* are passed to ``train_test_split``.
    NOTE(review): the held-out split is never evaluated, and nothing in this
    file calls this function.
    """
    df = pd.DataFrame(code_data)
    X = df.drop("target", axis=1)
    y = df["target"]
    X_train, _X_test, y_train, _y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # Named clf so it doesn't shadow the module-level Gemini `model`.
    clf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=random_state)
    clf.fit(X_train, y_train)
    return clf


def analyze_code_quality(code):
    """Return the head's "needs improvement" score (0..1) for *code*.

    NOTE(review): the head is untrained (see CodeImprovement), so this value
    carries no signal until real weights are loaded.
    """
    # No fixed-length padding is needed for a single sequence; the attention
    # mask already hides padding, so padding to 512 only wasted compute.
    inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only: no autograd graph for either model
        outputs = codebert_model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token, shape (1, 768)
        prediction = code_improvement_model(cls_embedding)  # shape (1, 2)
    # .item() on the whole (1, 2) tensor raises; pick the "needs improvement" column.
    return prediction[0, 1].item()


def visualize_code_structure(code):
    """Draw the AST of *code* as a directed graph.

    Returns a matplotlib Figure, or ``None`` if *code* does not parse. The
    caller owns the figure and should close it after rendering.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return None

    graph = nx.DiGraph()
    # id() is a stable node key while `tree` is alive; labels are node types.
    for node in ast.walk(tree):
        graph.add_node(id(node), label=type(node).__name__)
        for child in ast.iter_child_nodes(node):
            graph.add_edge(id(node), id(child))

    fig, ax = plt.subplots(figsize=(12, 8))
    pos = nx.spring_layout(graph, seed=42)
    # with_labels=False: the default labels would be the raw id() integers,
    # drawn underneath the readable type-name labels.
    nx.draw(graph, pos, ax=ax, with_labels=False, node_color="lightblue", node_size=1000)
    labels = nx.get_node_attributes(graph, "label")
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=6, ax=ax)
    return fig


# ---------------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------------
# NOTE(review): the original custom CSS and the HTML wrapper strings passed to
# st.markdown(..., unsafe_allow_html=True) appear to have been stripped (they
# were empty or bare newlines, which is not even valid Python); restore them here.
st.title("🚀 Advanced AI Code Assistant")
st.markdown("Powered by Google Gemini & Deep Learning")

prompt = st.text_area("What advanced code task can I help you with today?", height=120)

if st.button("Generate Advanced Code"):
    if prompt.strip() == "":
        st.error("Please enter a valid prompt.")
    else:
        with st.spinner("Generating and analyzing code..."):
            completed_text = generate_response(prompt)
            # Prefix check: a bare `"Error" in text` also matched code that
            # merely mentions e.g. ValueError.
            if completed_text.startswith(_ERROR_PREFIX):
                st.error(completed_text)
            else:
                optimized_code, lint_results = optimize_code(_extract_code(completed_text))
                if lint_results.startswith("SyntaxError:"):
                    st.warning(f"Syntax error detected: {lint_results}")
                    st.code(completed_text)
                else:
                    quality_score = analyze_code_quality(optimized_code)
                    st.success(f"Code generated and optimized successfully! Quality Score: {quality_score:.2f}")
                    st.code(optimized_code)

                    figure = visualize_code_structure(optimized_code)
                    if figure is not None:
                        with st.expander("View Code Structure Visualization"):
                            st.pyplot(figure)
                        plt.close(figure)  # free the figure; reruns would otherwise accumulate them
                    else:
                        st.warning("Unable to generate code structure visualization due to syntax errors.")

                    with st.expander("View Lint Results"):
                        st.text(lint_results)

                    with st.expander("Fetch Similar Code from GitHub"):
                        github_results = fetch_from_github(prompt)
                        for item in github_results:
                            st.markdown(f"[{item['name']}]({item['html_url']})")

st.markdown("Crafted with 🚀 by Your Advanced AI Code Assistant")