Update app.py
Browse files
app.py
CHANGED
@@ -8,61 +8,88 @@ import pandas as pd
|
|
8 |
from datasets import load_dataset
|
9 |
from torch.utils.data import DataLoader, Dataset
|
10 |
from sklearn.preprocessing import LabelEncoder
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Load dataset
|
13 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Preprocess text data
|
16 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
17 |
|
18 |
class CustomDataset(Dataset):
|
19 |
-
def __init__(self, dataset):
|
20 |
self.dataset = dataset
|
|
|
21 |
self.transform = transforms.Compose([
|
22 |
transforms.Resize((224, 224)),
|
23 |
transforms.ToTensor(),
|
24 |
])
|
25 |
self.label_encoder = LabelEncoder()
|
26 |
self.labels = self.label_encoder.fit_transform(dataset['Model'])
|
27 |
-
|
28 |
def __len__(self):
|
29 |
return len(self.dataset)
|
30 |
-
|
31 |
def __getitem__(self, idx):
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
34 |
label = self.labels[idx]
|
35 |
return image, text, label
|
36 |
|
37 |
-
#
|
38 |
class ImageModel(nn.Module):
|
39 |
def __init__(self):
|
40 |
super(ImageModel, self).__init__()
|
41 |
self.model = models.resnet18(pretrained=True)
|
42 |
self.model.fc = nn.Linear(self.model.fc.in_features, 512)
|
43 |
-
|
44 |
def forward(self, x):
|
45 |
return self.model(x)
|
46 |
|
47 |
-
# Define MLP for text processing
|
48 |
class TextModel(nn.Module):
|
49 |
def __init__(self):
|
50 |
super(TextModel, self).__init__()
|
51 |
self.bert = BertModel.from_pretrained('bert-base-uncased')
|
52 |
self.fc = nn.Linear(768, 512)
|
53 |
-
|
54 |
def forward(self, x):
|
55 |
output = self.bert(**x)
|
56 |
return self.fc(output.pooler_output)
|
57 |
|
58 |
-
# Combined model
|
59 |
class CombinedModel(nn.Module):
|
60 |
def __init__(self):
|
61 |
super(CombinedModel, self).__init__()
|
62 |
self.image_model = ImageModel()
|
63 |
self.text_model = TextModel()
|
64 |
self.fc = nn.Linear(1024, len(dataset['Model']))
|
65 |
-
|
66 |
def forward(self, image, text):
|
67 |
image_features = self.image_model(image)
|
68 |
text_features = self.text_model(text)
|
@@ -72,24 +99,45 @@ class CombinedModel(nn.Module):
|
|
72 |
# Instantiate model
|
73 |
model = CombinedModel()
|
74 |
|
75 |
-
#
|
76 |
-
def
|
77 |
model.eval()
|
78 |
with torch.no_grad():
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
# Set up Gradio interface
|
88 |
-
interface = gr.Interface(
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
|
94 |
# Launch the app
|
95 |
interface.launch()
|
|
|
8 |
from datasets import load_dataset
|
9 |
from torch.utils.data import DataLoader, Dataset
|
10 |
from sklearn.preprocessing import LabelEncoder
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
import numpy as np
|
15 |
|
16 |
# Load dataset
|
17 |
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
|
18 |
|
19 |
+
# Download and cache images
|
20 |
+
def download_image(url):
    """Fetch the image at ``url`` and return it as a decoded RGB PIL image.

    Args:
        url: Address of the image to download.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` when the request
        fails, times out, returns an HTTP error status, or the payload
        cannot be decoded as an image.
    """
    try:
        # A timeout keeps one stalled host from hanging the whole cache build.
        response = requests.get(url, timeout=10)
        # Error pages (404/500) would otherwise be handed to PIL as "images".
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; convert() forces the decode to happen inside
        # this try block and normalises RGBA/greyscale/palette images to the
        # three channels the ResNet backbone expects.
        return img.convert('RGB')
    except Exception as exc:
        # Deliberate best-effort boundary: any per-image failure is skipped.
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate so a long download loop can still be interrupted.
        print(f"Skipping {url}: {exc}")
        return None
|
27 |
+
|
28 |
+
# Create image cache
|
29 |
+
# Eagerly download each distinct image URL once, keyed by URL.
# URLs whose download failed are simply absent from the cache.
image_cache = {}
for idx, item in enumerate(dataset):
    # Progress line every 100 rows (counts rows visited, not successes).
    if idx % 100 == 0:
        print(f"Downloaded {idx} images")
    url = item['url']
    if url in image_cache:
        continue
    img = download_image(url)
    if img is None:
        continue
    image_cache[url] = img
|
38 |
+
|
39 |
# Preprocess text data
|
40 |
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
41 |
|
42 |
class CustomDataset(Dataset):
    """Torch dataset yielding ``(image, text, label)`` samples.

    Only rows whose image is present in ``image_cache`` are exposed, so a
    failed download never surfaces as a ``KeyError`` during iteration.

    Args:
        dataset: Row-indexable dataset with ``'url'``, ``'prompt'`` and
            ``'Model'`` columns.
        image_cache: Mapping from image URL to an already-loaded PIL image.
    """

    def __init__(self, dataset, image_cache):
        self.dataset = dataset
        self.image_cache = image_cache
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        # Labels are encoded over every row and looked up by original row index.
        self.labels = self.label_encoder.fit_transform(dataset['Model'])
        # Original row indices that actually have a cached image.
        self.valid_indices = [
            row_idx
            for row_idx, url in enumerate(dataset['url'])
            if url in image_cache
        ]

    def __len__(self):
        """Return the number of rows that have a cached image."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(image_tensor, tokenized_prompt, encoded_label)`` for sample ``idx``."""
        row_idx = self.valid_indices[idx]
        # Fetch the row once instead of once per field.
        row = self.dataset[row_idx]
        # convert('RGB') guarantees three channels for the image backbone.
        image = self.transform(self.image_cache[row['url']].convert('RGB'))
        encoded = tokenizer(
            row['prompt'] or "",  # tolerate a missing prompt; TODO confirm schema
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dim of 1; drop it so the
        # DataLoader's default collation produces (batch, seq_len) tensors.
        text = {key: value.squeeze(0) for key, value in encoded.items()}
        label = self.labels[row_idx]
        return image, text, label
|
65 |
|
66 |
+
# Image encoder (ResNet-18 backbone)
|
67 |
class ImageModel(nn.Module):
    """Pretrained ResNet-18 whose final ``fc`` layer is replaced by a 512-unit linear layer."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the classification head for a 512-dimensional projection.
        backbone.fc = nn.Linear(backbone.fc.in_features, 512)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        return self.model(x)
|
75 |
|
|
|
76 |
class TextModel(nn.Module):
    """``bert-base-uncased`` encoder followed by a 768 -> 512 linear projection."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, 512)

    def forward(self, x):
        """Project BERT's pooled output; ``x`` is unpacked as keyword arguments to BERT."""
        bert_output = self.bert(**x)
        pooled = bert_output.pooler_output
        return self.fc(pooled)
|
85 |
|
|
|
86 |
class CombinedModel(nn.Module):
|
87 |
def __init__(self):
|
88 |
super(CombinedModel, self).__init__()
|
89 |
self.image_model = ImageModel()
|
90 |
self.text_model = TextModel()
|
91 |
self.fc = nn.Linear(1024, len(dataset['Model']))
|
92 |
+
|
93 |
def forward(self, image, text):
|
94 |
image_features = self.image_model(image)
|
95 |
text_features = self.text_model(text)
|
|
|
99 |
# Instantiate model
|
100 |
model = CombinedModel()
|
101 |
|
102 |
+
# Prediction function used by the Gradio interface
|
103 |
+
def get_recommendations(input_image):
    """Return gallery entries for the top-scoring dataset rows for an image.

    Args:
        input_image: PIL image supplied by the Gradio input, or ``None``
            when nothing was uploaded.

    Returns:
        A list of ``(image, caption)`` tuples for the highest-scoring rows
        whose image is available in ``image_cache``; the caption holds the
        row's model name and score. Empty when ``input_image`` is ``None``.
    """
    # Gradio passes None when the user submits without uploading an image.
    if input_image is None:
        return []

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    model.eval()
    with torch.no_grad():
        # Uploads may be RGBA or greyscale; the backbone needs three channels.
        input_tensor = transform(input_image.convert('RGB')).unsqueeze(0)

        # Dummy text input (only the image drives this query).
        text_input = tokenizer("", return_tensors='pt', padding=True, truncation=True)

        # Assumes the model returns one row of scores per input image — TODO confirm.
        output = model(input_tensor, text_input)
        top_k = min(5, output.size(-1))
        scores, indices = torch.topk(output, top_k)

    gallery_images = []
    # scores and indices are aligned by position, so iterate them together;
    # tolist() yields plain Python numbers that are safe to use as row indices.
    for score, idx in zip(scores[0].tolist(), indices[0].tolist()):
        row = dataset[idx]
        url = row['url']
        # Rows whose download failed have no cached image and are skipped.
        if url in image_cache:
            caption = f"{row['Model']}\nScore: {score:.2f}"
            gallery_images.append((image_cache[url], caption))

    return gallery_images
|
132 |
|
133 |
# Set up Gradio interface
|
134 |
+
# Gradio UI: one PIL image in, a captioned gallery out.
# The description says "scores" because the gallery captions show the model's
# raw score for each result, not a distance.
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names and scores.",
)
|
141 |
|
142 |
# Launch the app
|
143 |
interface.launch()
|