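# Recommend Stable Diffusion models and prompts for an uploaded AI-generated
# image by comparing its Florence-2 embedding against Florence-2 embeddings
# of prompts from the CivitAI dataset.
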
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pandas as pd
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import warnings

warnings.filterwarnings('ignore')

# Load the Florence-2 model and processor (trust_remote_code is needed
# because Florence-2 ships its own modeling code)
model_name = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch_dtype,
    trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Florence-2 always expects an image input, so a blank image stands in
# whenever we only need a text embedding
DUMMY_IMAGE = Image.new('RGB', (224, 224), color='white')

# Load the first 1,000 rows of the CivitAI prompt dataset
print("Loading dataset...")
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
df = pd.DataFrame(dataset)
print("Dataset loaded successfully!")
text_embedding_cache = {}

def get_image_embedding(image):
    try:
        inputs = processor(
            images=image,
            text="Generate image description",
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)
        # Florence-2 is an encoder-decoder model, so a plain forward pass
        # needs decoder_input_ids; a one-token generate() call is a cheap
        # way to obtain a valid decoder prefix
        decoder_input_ids = model.generate(
            **inputs,
            max_length=1,
            min_length=1,
            num_beams=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids
        with torch.no_grad():
            outputs = model(**inputs)
        # Use the mean of the last hidden state as the embedding
        image_embeddings = outputs.last_hidden_state.mean(dim=1)
        return image_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error in get_image_embedding: {str(e)}")
        return None

def get_text_embedding(text):
    try:
        if text in text_embedding_cache:
            return text_embedding_cache[text]
        # Pair the text with the dummy image, since the processor requires one
        inputs = processor(
            images=DUMMY_IMAGE,
            text=text,
            return_tensors="pt",
            padding=True
        ).to(device, torch_dtype)
        # Same one-token generate() trick as in get_image_embedding
        decoder_input_ids = model.generate(
            **inputs,
            max_length=1,
            min_length=1,
            num_beams=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            return_dict_in_generate=True,
        ).sequences
        inputs['decoder_input_ids'] = decoder_input_ids
        with torch.no_grad():
            outputs = model(**inputs)
        text_embeddings = outputs.last_hidden_state.mean(dim=1)
        embedding = text_embeddings.cpu().numpy()
        text_embedding_cache[text] = embedding
        return embedding
    except Exception as e:
        print(f"Error in get_text_embedding: {str(e)}")
        return None

def precompute_embeddings():
    print("Pre-computing text embeddings...")
    for idx, row in df.iterrows():
        if row['prompt'] not in text_embedding_cache:
            _ = get_text_embedding(row['prompt'])
        if idx % 100 == 0:
            print(f"Processed {idx}/{len(df)} embeddings")
    print("Finished pre-computing embeddings")

def find_similar_images(uploaded_image, top_k=5):
    query_embedding = get_image_embedding(uploaded_image)
    if query_embedding is None:
        return [], []
    # Score the uploaded image against every dataset prompt
    similarities = []
    for idx, row in df.iterrows():
        prompt_embedding = get_text_embedding(row['prompt'])
        if prompt_embedding is not None:
            similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
            similarities.append({
                'similarity': similarity,
                'model': row['Model'],
                'prompt': row['prompt']
            })
    sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
    # Walk the ranked results, collecting up to top_k unique models and prompts
    top_models = []
    top_prompts = []
    seen_models = set()
    seen_prompts = set()
    for result in sorted_results:
        if len(top_models) < top_k and result['model'] not in seen_models:
            top_models.append(result['model'])
            seen_models.add(result['model'])
        if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
            top_prompts.append(result['prompt'])
            seen_prompts.add(result['prompt'])
        if len(top_models) == top_k and len(top_prompts) == top_k:
            break
    return top_models, top_prompts
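
# Note: the loop above scores one prompt at a time. If every prompt is already
# cached (precompute_embeddings aims to ensure this), an equivalent but faster
# sketch is a single matrix call:
#
#   prompt_matrix = np.vstack([text_embedding_cache[p] for p in df['prompt']])
#   scores = cosine_similarity(query_embedding, prompt_matrix)[0]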

def process_image(input_image):
    if input_image is None:
        return "Please upload an image.", "Please upload an image."
    try:
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)
        # Resize image to expected size
        input_image = input_image.resize((224, 224))
        recommended_models, recommended_prompts = find_similar_images(input_image)
        if not recommended_models or not recommended_prompts:
            return "Error processing image.", "Error processing image."
        models_text = "Recommended Models:\n" + "\n".join([f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
        prompts_text = "Recommended Prompts:\n" + "\n".join([f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
        return models_text, prompts_text
    except Exception as e:
        print(f"Error in process_image: {str(e)}")
        return "Error processing image.", "Error processing image."

# Pre-compute embeddings when starting the application
try:
    precompute_embeddings()
except Exception as e:
    print(f"Error in precompute_embeddings: {str(e)}")

# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload AI-generated image"),
    outputs=[
        gr.Textbox(label="Recommended Models", lines=6),
        gr.Textbox(label="Recommended Prompts", lines=6)
    ],
    title="AI Image Model & Prompt Recommender",
    description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
    examples=[],
    cache_examples=False
)

# Launch the interface
iface.launch()
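# When running outside Spaces, iface.launch(share=True) would also create a
# temporary public URL.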