# fusion / app.py
# Last commit by whyumesh: "Update app.py" (b49d6ac, verified)
# File size: 7.15 kB
import torch
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor,
AutoModelForCausalLM,
AutoTokenizer
)
from qwen_vl_utils import process_vision_info
from PIL import Image
import cv2
import numpy as np
import gradio as gr
import spaces
# Load both models and their processors/tokenizers
def load_models(
    vision_model_id="Qwen/Qwen2-VL-2B-Instruct",
    code_model_id="Qwen/Qwen2.5-Coder-1.5B-Instruct",
):
    """Load the vision-language model, the code model, and their pre/post-processors.

    Each model ID is used for both the weights and the matching processor or
    tokenizer, so the two halves of a pair can never be loaded from different
    checkpoints by accident.

    Args:
        vision_model_id: Hugging Face Hub ID of the Qwen2-VL checkpoint.
        code_model_id: Hugging Face Hub ID of the causal-LM code checkpoint.

    Returns:
        Tuple ``(vision_model, vision_processor, code_model, code_tokenizer)``.
    """
    # Vision model: half-precision weights; device_map="auto" lets transformers
    # choose device placement.
    vision_model = Qwen2VLForConditionalGeneration.from_pretrained(
        vision_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    vision_processor = AutoProcessor.from_pretrained(vision_model_id)

    # Code model: same precision/placement policy as the vision model.
    code_model = AutoModelForCausalLM.from_pretrained(
        code_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    code_tokenizer = AutoTokenizer.from_pretrained(code_model_id)

    return vision_model, vision_processor, code_model, code_tokenizer
# Load all four objects once at import time; the functions below read these
# module-level globals on every request instead of reloading the models.
vision_model, vision_processor, code_model, code_tokenizer = load_models()
# Instructions for the vision model. process_image_for_code prepends this text to
# the user turn (no separate system message is sent to Qwen2-VL). The refusal
# sentence below is detected there by a substring match on its opening words
# ("I apologize, but I cannot process this content") -- keep the two in sync.
VISION_SYSTEM_PROMPT = """You are an AI assistant specialized in analyzing images and videos of code editors. Your primary task is to:
1. FIRST AND MOST IMPORTANTLY: Check if the image contains any inappropriate content such as:
- Harassment or bullying
- Hate speech or discriminatory content
- Sexually explicit material
- Dangerous or harmful content
If any such content is detected, respond ONLY with: "I apologize, but I cannot process this content as it appears to contain [type of inappropriate content]. Please provide only appropriate code-related images."
2. If the content is appropriate, then:
- Extract and describe any code snippets visible in the image
- Identify any error messages, warnings, or highlighting that indicates bugs
- Describe the programming language and context if visible
Be thorough and accurate in your description of appropriate content, as this will be used to fix the code."""
# System prompt for the code model; process_image_for_code sends it as the
# "system" role message, followed by the vision model's description as the user turn.
CODE_SYSTEM_PROMPT = """You are an expert code debugging assistant. Your tasks in order are:
1. Check if the input description contains any flags for inappropriate content.
If it does, respond ONLY with: "I apologize, but I cannot process this request as the original content was flagged as inappropriate."
2. If the content is appropriate, then based on the description of code and errors provided:
- Identify the bugs and issues in the code
- Provide a corrected version of the code
- Explain the fixes made and why they resolve the issues
- Provide the output in a well-structured format removing all unnecessary information
Be thorough in your explanation and ensure the corrected code is complete and functional."""
# Opening words of the refusal sentence that VISION_SYSTEM_PROMPT asks the
# vision model to emit for inappropriate images.
_VISION_REFUSAL_MARKER = "I apologize, but I cannot process this content"


def _new_tokens_only(input_ids, output_ids):
    """Slice the prompt tokens off the front of each generated sequence.

    Only the newly generated continuation is kept, so decoding does not echo
    the chat template back to the user.
    """
    return [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, output_ids)]


def _describe_image(image, max_new_tokens=512):
    """Ask the Qwen2-VL model to describe the code and errors visible in *image*.

    Args:
        image: Image accepted by ``qwen_vl_utils.process_vision_info``
            (this app passes a ``PIL.Image``).
        max_new_tokens: Generation budget for the description.

    Returns:
        The decoded description. If the model refuses, this is the refusal
        sentence requested by ``VISION_SYSTEM_PROMPT``.
    """
    # The instructions travel inside the user turn, ahead of the actual request.
    vision_messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": f"{VISION_SYSTEM_PROMPT}\n\nDescribe the code and any errors you see in this image."},
            ],
        }
    ]
    vision_text = vision_processor.apply_chat_template(
        vision_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(vision_messages)
    vision_inputs = vision_processor(
        text=[vision_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(vision_model.device)

    with torch.no_grad():
        vision_output_ids = vision_model.generate(**vision_inputs, max_new_tokens=max_new_tokens)

    return vision_processor.batch_decode(
        _new_tokens_only(vision_inputs.input_ids, vision_output_ids),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


def _fix_code(vision_description, max_new_tokens=1024, temperature=0.7, top_p=0.95):
    """Ask the code model to diagnose and fix the code described in *vision_description*.

    Args:
        vision_description: Text produced by ``_describe_image``.
        max_new_tokens: Generation budget for the answer.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The code model's decoded reply (explanation plus corrected code, as
        requested by ``CODE_SYSTEM_PROMPT``).
    """
    code_messages = [
        {"role": "system", "content": CODE_SYSTEM_PROMPT},
        {"role": "user", "content": f"Here's a description of code with errors:\n\n{vision_description}\n\nPlease analyze and fix the code."},
    ]
    code_text = code_tokenizer.apply_chat_template(
        code_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    code_inputs = code_tokenizer([code_text], return_tensors="pt").to(code_model.device)

    with torch.no_grad():
        code_output_ids = code_model.generate(
            **code_inputs,
            max_new_tokens=max_new_tokens,
            # temperature/top_p only take effect when sampling is enabled; turn it
            # on explicitly rather than relying on the checkpoint's generation_config.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

    return code_tokenizer.batch_decode(
        _new_tokens_only(code_inputs.input_ids, code_output_ids),
        skip_special_tokens=True,
    )[0]


def process_image_for_code(image):
    """Describe the code shown in *image*, then ask the code model to fix it.

    Args:
        image: The screenshot to analyze (a ``PIL.Image`` in this app).

    Returns:
        ``(vision_description, fixed_code_response)``. When the vision model
        refuses the image, the second element is a fixed placeholder message and
        the code model is not called.
    """
    vision_description = _describe_image(image)

    # Stop early if the vision model emitted its refusal sentence.
    if _VISION_REFUSAL_MARKER in vision_description:
        return vision_description, "No code analysis provided due to inappropriate content."

    return vision_description, _fix_code(vision_description)
def _sample_video_frames(video_path, max_frames, frame_interval):
    """Return up to *max_frames* RGB ``PIL.Image`` frames, taking every *frame_interval*-th frame.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        frame_index = 0
        while len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                # End of stream, or the file could not be opened/decoded.
                break
            if frame_index % frame_interval == 0:
                # OpenCV decodes to BGR channel order; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            frame_index += 1
    finally:
        # Release the decoder even if reading or conversion raised.
        cap.release()
    return frames


def process_video_for_code(video_path, max_frames=16, frame_interval=30):
    """Analyze the code shown in a screen recording.

    Only the first frame of the video is sent through ``process_image_for_code``.

    Args:
        video_path: Path to a video file readable by OpenCV.
        max_frames: Upper bound on sampled frames. Only the first frame is
            analyzed, so 0 means "analyze nothing" and any value >= 1 behaves the same.
        frame_interval: Sampling stride in frames; must be >= 1.

    Returns:
        ``(vision_description, fixed_code_response)``, or a pair of
        explanatory messages when no frame could be read.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
    """
    # Only frames[0] is used, and that is always the video's first frame, so
    # don't decode more than one frame.
    frames = _sample_video_frames(video_path, min(max_frames, 1), frame_interval)
    if not frames:
        return "No frames could be extracted from the video.", "No code could be analyzed."
    return process_image_for_code(frames[0])
# Lower-case file suffixes routed to the image and video pipelines.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp', '.bmp')
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.webm')


@spaces.GPU
def process_content(content):
    """Gradio handler: route an uploaded image or video to the matching pipeline.

    Args:
        content: The value from ``gr.File``. Depending on the Gradio version this
            is a path string or a file-like object with a ``.name`` path; both
            are accepted. ``None`` when nothing was uploaded.

    Returns:
        ``(vision_output, code_output)`` strings for the two output components.
        Errors are reported in the first element and never raised to Gradio.
    """
    if content is None:
        return "Please upload an image or video file of code with errors.", ""

    try:
        path = getattr(content, "name", content)
        lower_path = str(path).lower()
        if lower_path.endswith(_IMAGE_EXTENSIONS):
            # convert() loads the pixel data, so the file handle can be closed
            # by the context manager before inference starts.
            with Image.open(path) as opened:
                image = opened.convert("RGB")
            vision_output, code_output = process_image_for_code(image)
        elif lower_path.endswith(_VIDEO_EXTENSIONS):
            vision_output, code_output = process_video_for_code(path)
        else:
            return "Unsupported file type. Please provide an image or video file.", ""
    except Exception as e:
        # UI boundary: show the message to the user, but keep the traceback in
        # the server logs for debugging.
        import traceback
        traceback.print_exc()
        return f"An error occurred while processing the file: {str(e)}", ""

    return vision_output, code_output
# Gradio UI: a single file upload feeds process_content, whose two return values
# fill the two output panels in order.
# NOTE(review): CODE_SYSTEM_PROMPT asks the model to explain its fixes, so the
# reply can contain prose as well as code; gr.Code shows it verbatim with Python
# highlighting -- consider gr.Markdown if that reads poorly.
_upload_input = gr.File(label="Upload Image or Video of Code with Errors")
_result_outputs = [
    gr.Textbox(label="Vision Model Output (Code Description)"),
    gr.Code(label="Fixed Code", language="python"),
]

iface = gr.Interface(
    fn=process_content,
    inputs=_upload_input,
    outputs=_result_outputs,
    title="Vision Code Debugger",
    description="Upload an image or video of code with errors for AI analysis and fixes. Note: Only appropriate code-related content will be processed.",
)

if __name__ == "__main__":
    iface.launch()