Spaces:
Runtime error
Runtime error
import gradio as gr
import spaces
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Hub identifier of the PEFT (LoRA) adapter; reused further down when the
# adapter weights are attached to the base model.
peft_model_id = "rootxhacker/CodeAstra-7B"
# Adapter configuration; its ``base_model_name_or_path`` field names the
# base checkpoint that the model and tokenizer are loaded from.
config = PeftConfig.from_pretrained(peft_model_id)
# Function to move tensors to CPU | |
def to_cpu(obj):
    """Return ``obj`` with every tensor inside it moved to the CPU.

    Lists, tuples and dicts are rebuilt recursively (tuples come back as
    plain tuples, dict keys are kept as-is); any other value is returned
    unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {name: to_cpu(entry) for name, entry in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = map(to_cpu, obj)
        # Preserve the container kind of the input sequence.
        return list(moved) if isinstance(obj, list) else tuple(moved)
    return obj
# Load the base model in 4-bit precision. Quantization is requested through an
# explicit BitsAndBytesConfig because passing ``load_in_4bit`` directly to
# ``from_pretrained`` is deprecated in transformers.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map='auto',
)
# The tokenizer is taken from the same base checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Attach the LoRA adapter weights on top of the quantized base model.
model = PeftModel.from_pretrained(model, peft_model_id)
def get_completion(query, model, tokenizer):
    """Generate a sampled completion for ``query`` and return it as text.

    The model is deliberately not moved between devices here. A model loaded
    with bitsandbytes quantization and a ``device_map`` is already placed on
    its target device(s); calling ``.cuda()`` / ``.cpu()`` on it per request
    is at best a full copy of the weights and at worst an error. Only the
    tokenized inputs are moved, to the device the model reports.

    Args:
        query: Prompt text to feed to the model.
        model: Object providing ``generate(**inputs, ...)`` and, optionally,
            a ``device`` attribute telling where the inputs must live.
        tokenizer: Callable that tokenizes ``query`` into a mapping with a
            ``.to(device)`` method, and that provides ``decode``.

    Returns:
        The decoded first generated sequence with special tokens skipped, or
        the string ``"An error occurred: <message>"`` if any step raised.
    """
    try:
        inputs = tokenizer(query, return_tensors="pt")
        # Follow the model's own placement instead of hard-coding 'cuda'.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
        # Move the generated ids to the CPU before decoding.
        sequence = outputs[0]
        if isinstance(sequence, torch.Tensor):
            sequence = sequence.cpu()
        return tokenizer.decode(sequence, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"An error occurred: {str(e)}"
    finally:
        # Release cached allocator blocks left over from generation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def code_review(code_to_analyze):
    """Ask the model to find vulnerabilities in ``code_to_analyze``.

    The submitted code is appended to a fixed instruction line, and the
    result of ``get_completion`` is returned without post-processing.
    """
    prompt = f"find all vulnerabilities which in the code\n{code_to_analyze} "
    return get_completion(prompt, model, tokenizer)
# Build the Gradio UI: a multi-line box for the submitted code and a text box
# for the model's review, wired to ``code_review``.
code_input = gr.Textbox(lines=10, label="Enter code to analyze")
review_output = gr.Textbox(label="Code Review Result")
iface = gr.Interface(
    fn=code_review,
    inputs=code_input,
    outputs=review_output,
    title="Code Review Expert",
    description=(
        "This tool analyzes code for potential security flaws, logic "
        "vulnerabilities, and provides guidance on secure coding practices."
    ),
)
# Start serving the app.
iface.launch()