# Rec_Sys_Flo2 / app.py
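# Gradio Space: embeds an uploaded image with Florence-2, scores it against
# cached CivitAI prompt embeddings, and recommends Stable Diffusion models
# and prompts.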
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pandas as pd
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore')
# Load Florence-2 model and processor
model_name = "microsoft/Florence-2-base"
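# Prefer GPU with float16 for speed and memory; fall back to CPU with float32.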
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
# Create a dummy image for text-only processing
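# The processor expects pixel values with every call, so a blank 224x224
# canvas stands in when only text needs to be embedded.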
DUMMY_IMAGE = Image.new('RGB', (224, 224), color='white')
# Load CivitAI dataset
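# Only the first 1,000 rows are loaded, keeping startup and per-query
# similarity scans fast.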
print("Loading dataset...")
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
df = pd.DataFrame(dataset)
print("Dataset loaded successfully!")
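# Maps prompt text -> numpy embedding so each dataset prompt is embedded once.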
text_embedding_cache = {}
def get_image_embedding(image):
    try:
        inputs = processor(
            images=image,
            text="Generate image description",
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)
        # Run a single generation step so the model yields valid
        # decoder_input_ids (just the decoder start token) for the forward pass.
        decoder_input_ids = model.generate(
            **inputs,
            max_length=1,
            min_length=1,
            num_beams=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the last hidden state into a single embedding vector
        image_embeddings = outputs.last_hidden_state.mean(dim=1)
        return image_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error in get_image_embedding: {str(e)}")
        return None
def get_text_embedding(text):
    try:
        if text in text_embedding_cache:
            return text_embedding_cache[text]
        # Pair the text with the dummy image, since the processor always
        # expects pixel values.
        inputs = processor(
            images=DUMMY_IMAGE,
            text=text,
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)
        # Same single-step generation trick as in get_image_embedding.
        decoder_input_ids = model.generate(
            **inputs,
            max_length=1,
            min_length=1,
            num_beams=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the last hidden state, mirroring get_image_embedding
        text_embeddings = outputs.last_hidden_state.mean(dim=1)
        embedding = text_embeddings.cpu().numpy()
        text_embedding_cache[text] = embedding
        return embedding
    except Exception as e:
        print(f"Error in get_text_embedding: {str(e)}")
        return None
def precompute_embeddings():
    print("Pre-computing text embeddings...")
    for idx, row in df.iterrows():
        if row['prompt'] not in text_embedding_cache:
            _ = get_text_embedding(row['prompt'])
        if idx % 100 == 0:
            print(f"Processed {idx}/{len(df)} embeddings")
    print("Finished pre-computing embeddings")
def find_similar_images(uploaded_image, top_k=5):
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []
    # Score the uploaded image against every cached prompt embedding
    similarities = []
    for idx, row in df.iterrows():
        prompt_embedding = get_text_embedding(row['prompt'])
        if prompt_embedding is not None:
            similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
            similarities.append({
                'similarity': similarity,
                'model': row['Model'],
                'prompt': row['prompt']
            })
    sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
    # Collect the top-k unique models and prompts
    top_models = []
    top_prompts = []
    seen_models = set()
    seen_prompts = set()
    for result in sorted_results:
        if len(top_models) < top_k and result['model'] not in seen_models:
            top_models.append(result['model'])
            seen_models.add(result['model'])
        if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
            top_prompts.append(result['prompt'])
            seen_prompts.add(result['prompt'])
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts
def process_image(input_image):
    if input_image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # Resize to the size the embedding pipeline expects
        input_image = input_image.resize((224, 224))
        recommended_models, recommended_prompts = find_similar_images(input_image)
        if not recommended_models or not recommended_prompts:
            return "Error processing image.", "Error processing image."
        models_text = "Recommended Models:\n" + "\n".join(
            [f"{i+1}. {model}" for i, model in enumerate(recommended_models)]
        )
        prompts_text = "Recommended Prompts:\n" + "\n".join(
            [f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)]
        )
        return models_text, prompts_text
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        return "Error processing image.", "Error processing image."
# Pre-compute embeddings when starting the application
try:
    precompute_embeddings()
except Exception as e:
    print(f"Error in precompute_embeddings: {str(e)}")
# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload AI-generated image"),
    outputs=[
        gr.Textbox(label="Recommended Models", lines=6),
        gr.Textbox(label="Recommended Prompts", lines=6)
    ],
    title="AI Image Model & Prompt Recommender",
    description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
    examples=[],
    cache_examples=False
)
# Launch the interface
iface.launch()