import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os
import logging
from datetime import datetime
import json
from typing import List, Dict
import warnings

# Filter out CUDA/NVML warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables
HF_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")

# Cache directory for the model
CACHE_DIR = "/home/user/.cache/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)

# Set environment variables for GPU
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
class Review:
    def __init__(self, code: str, language: str, suggestions: str):
        self.code = code
        self.language = language
        self.suggestions = suggestions
        self.timestamp = datetime.now().isoformat()
        self.response_time = 0.0


class CodeReviewer:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.review_history: List[Review] = []
        self.metrics = {
            'total_reviews': 0,
            'avg_response_time': 0.0,
            'reviews_today': 0
        }
        self.initialize_model()
    def initialize_model(self):
        """Initialize the model and tokenizer."""
        try:
            if HF_TOKEN:
                login(token=HF_TOKEN, add_to_git_credential=False)

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=HF_TOKEN,
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            )

            logger.info("Loading model...")
            # Initialize model with specific configuration
            model_kwargs = {
                "torch_dtype": torch.float16,
                "trust_remote_code": True,
                "low_cpu_mem_usage": True,
                "cache_dir": CACHE_DIR,
                "token": HF_TOKEN
            }

            # Try loading with different configurations
            try:
                # First try with device_map="auto"
                self.model = AutoModelForCausalLM.from_pretrained(
                    MODEL_NAME,
                    device_map="auto",
                    **model_kwargs
                )
                self.device = next(self.model.parameters()).device
            except Exception as e1:
                logger.warning(f"Failed to load with device_map='auto': {e1}")
                try:
                    # Fall back to placing the whole model on one device
                    if torch.cuda.is_available():
                        self.device = torch.device("cuda:0")
                    else:
                        self.device = torch.device("cpu")
                    model_kwargs["device_map"] = None
                    self.model = AutoModelForCausalLM.from_pretrained(
                        MODEL_NAME,
                        **model_kwargs
                    ).to(self.device)
                except Exception as e2:
                    logger.error(f"Failed to load model on specific device: {e2}")
                    raise

            logger.info(f"Model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise
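
    # Alternative loader, not called anywhere in this app: a minimal
    # sketch of loading the model in 4-bit to fit smaller GPUs, assuming
    # the `bitsandbytes` package is installed. Expect some quality loss
    # versus float16.
    def initialize_model_4bit(self):
        """Load MODEL_NAME quantized to 4-bit via bitsandbytes (sketch)."""
        from transformers import BitsAndBytesConfig
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, token=HF_TOKEN, cache_dir=CACHE_DIR
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=quant_config,
            device_map="auto",
            cache_dir=CACHE_DIR,
            token=HF_TOKEN
        )
        self.device = next(self.model.parameters()).device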
    def create_review_prompt(self, code: str, language: str) -> str:
        """Create a structured prompt for code review."""
        return f"""Review this {language} code. List specific points in these sections:
Issues:
Improvements:
Best Practices:
Security:

Code:
```{language}
{code}
```"""
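
    # Optional variant, not wired into review_code below: wrap the review
    # request in the model's chat template. Instruction-tuned checkpoints
    # such as gemma-2b-it are trained on a specific turn format, so
    # apply_chat_template tends to yield better-structured answers than a
    # raw prompt. A sketch, not the app's current path.
    def create_chat_prompt(self, code: str, language: str) -> str:
        """Format the review request with the tokenizer's chat template."""
        messages = [
            {"role": "user", "content": self.create_review_prompt(code, language)}
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )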
    def review_code(self, code: str, language: str) -> str:
        """Perform code review using the model."""
        try:
            start_time = datetime.now()
            prompt = self.create_review_prompt(code, language)

            # Tokenize with error handling
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                    padding=True
                ).to(self.device)
            except Exception as token_error:
                logger.error(f"Tokenization error: {token_error}")
                return "Error: Failed to process input code. Please try again."

            # Generate with error handling
            try:
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.95,
                        num_beams=1
                    )
            except Exception as gen_error:
                logger.error(f"Generation error: {gen_error}")
                return "Error: Failed to generate review. Please try again."

            # Decode only the newly generated tokens; slicing the decoded
            # string by len(prompt) is fragile because special tokens and
            # whitespace can shift the character offsets.
            try:
                new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
                suggestions = self.tokenizer.decode(
                    new_tokens, skip_special_tokens=True
                ).strip()
            except Exception as decode_error:
                logger.error(f"Decoding error: {decode_error}")
                return "Error: Failed to decode model output. Please try again."

            # Create review and update metrics
            end_time = datetime.now()
            review = Review(code, language, suggestions)
            review.response_time = (end_time - start_time).total_seconds()
            self.review_history.append(review)
            self.update_metrics(review)

            # Clear GPU memory
            if torch.cuda.is_available():
                del inputs, outputs
                torch.cuda.empty_cache()

            return suggestions
        except Exception as e:
            logger.error(f"Error during code review: {e}")
            return f"Error performing code review: {str(e)}"
    def update_metrics(self, review: Review):
        """Update metrics with a new review."""
        self.metrics['total_reviews'] += 1

        # Update running average response time
        total_time = self.metrics['avg_response_time'] * (self.metrics['total_reviews'] - 1)
        total_time += review.response_time
        self.metrics['avg_response_time'] = total_time / self.metrics['total_reviews']

        # Update today's review count
        today = datetime.now().date()
        self.metrics['reviews_today'] = sum(
            1 for r in self.review_history
            if datetime.fromisoformat(r.timestamp).date() == today
        )
    def get_history(self) -> List[Dict]:
        """Get formatted review history."""
        return [
            {
                'timestamp': r.timestamp,
                'language': r.language,
                'code': r.code,
                'suggestions': r.suggestions,
                'response_time': f"{r.response_time:.2f}s"
            }
            for r in reversed(self.review_history[-10:])  # Last 10 reviews, newest first
        ]

    def get_metrics(self) -> Dict:
        """Get current metrics."""
        return {
            'Total Reviews': self.metrics['total_reviews'],
            'Average Response Time': f"{self.metrics['avg_response_time']:.2f}s",
            'Reviews Today': self.metrics['reviews_today'],
            'Device': str(self.device)
        }
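
    # Persistence sketch, not called anywhere in the app: dump the
    # in-memory history to JSON (this is what the otherwise-unused `json`
    # import would be for). Note that a Space's local filesystem is
    # ephemeral, so the file survives a restart only on a
    # persistent-storage mount.
    def save_history(self, path: str = "review_history.json"):
        """Write the formatted review history to a JSON file."""
        with open(path, "w") as f:
            json.dump(self.get_history(), f, indent=2)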
# Initialize reviewer
reviewer = CodeReviewer()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Code Review Assistant")
    gr.Markdown("An automated code review system powered by Gemma-2b")

    with gr.Tabs():
        with gr.Tab("Review Code"):
            with gr.Row():
                with gr.Column():
                    code_input = gr.Textbox(
                        lines=10,
                        placeholder="Enter your code here...",
                        label="Code"
                    )
                    language_input = gr.Dropdown(
                        choices=["python", "javascript", "java", "cpp", "typescript", "go", "rust"],
                        value="python",
                        label="Language"
                    )
                    submit_btn = gr.Button("Submit for Review")
                with gr.Column():
                    output = gr.Textbox(
                        label="Review Results",
                        lines=10
                    )
        with gr.Tab("History"):
            refresh_history = gr.Button("Refresh History")
            history_output = gr.Textbox(
                label="Review History",
                lines=20
            )
        with gr.Tab("Metrics"):
            refresh_metrics = gr.Button("Refresh Metrics")
            metrics_output = gr.JSON(
                label="Performance Metrics"
            )
    # Set up event handlers
    def review_code_interface(code: str, language: str) -> str:
        if not code.strip():
            return "Please enter some code to review."
        try:
            return reviewer.review_code(code, language)
        except Exception as e:
            logger.error(f"Interface error: {e}")
            return f"Error: {str(e)}"

    def get_history_interface() -> str:
        try:
            history = reviewer.get_history()
            if not history:
                return "No reviews yet."
            result = ""
            for review in history:
                result += f"Time: {review['timestamp']}\n"
                result += f"Language: {review['language']}\n"
                result += f"Response Time: {review['response_time']}\n"
                result += "Code:\n```\n" + review['code'] + "\n```\n"
                result += "Suggestions:\n" + review['suggestions'] + "\n"
                result += "-" * 80 + "\n\n"
            return result
        except Exception as e:
            logger.error(f"History error: {e}")
            return "Error retrieving history"

    def get_metrics_interface() -> Dict:
        try:
            return reviewer.get_metrics()
        except Exception as e:
            logger.error(f"Metrics error: {e}")
            return {"error": str(e)}
    submit_btn.click(
        review_code_interface,
        inputs=[code_input, language_input],
        outputs=output
    )
    refresh_history.click(
        get_history_interface,
        outputs=history_output
    )
    refresh_metrics.click(
        get_metrics_interface,
        outputs=metrics_output
    )

    # Add example inputs
    gr.Examples(
        examples=[
            ["""def add_numbers(a, b):
    return a + b""", "python"],
            ["""function calculateSum(numbers) {
    let sum = 0;
    for (let i = 0; i < numbers.length; i++) {
        sum += numbers[i];
    }
    return sum;
}""", "javascript"]
        ],
        inputs=[code_input, language_input]
    )
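
# Optional hardening, not in the original app: enabling Gradio's request
# queue serializes generation calls, which keeps concurrent users from
# racing on a single GPU and avoids timeouts on long generations.
# Uncomment to enable:
# iface.queue()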
# Launch the app
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False
    )