bgaspra committed on
Commit 0365b37 · verified
1 Parent(s): e897bc2

Update app.py

Files changed (1)
  1. app.py +73 -25
app.py CHANGED
@@ -8,61 +8,88 @@ import pandas as pd
 from datasets import load_dataset
 from torch.utils.data import DataLoader, Dataset
 from sklearn.preprocessing import LabelEncoder
+import requests
+from PIL import Image
+from io import BytesIO
+import numpy as np

 # Load dataset
 dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

+# Download and cache images
+def download_image(url):
+    try:
+        response = requests.get(url)
+        img = Image.open(BytesIO(response.content))
+        return img
+    except:
+        return None
+
+# Create image cache
+image_cache = {}
+for idx, item in enumerate(dataset):
+    if idx % 100 == 0:  # Status update
+        print(f"Downloaded {idx} images")
+    url = item['url']
+    if url not in image_cache:
+        img = download_image(url)
+        if img is not None:
+            image_cache[url] = img
+
 # Preprocess text data
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

 class CustomDataset(Dataset):
-    def __init__(self, dataset):
+    def __init__(self, dataset, image_cache):
         self.dataset = dataset
+        self.image_cache = image_cache
         self.transform = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor(),
         ])
         self.label_encoder = LabelEncoder()
         self.labels = self.label_encoder.fit_transform(dataset['Model'])
-
+
     def __len__(self):
         return len(self.dataset)
-
+
     def __getitem__(self, idx):
-        image = self.transform(self.dataset[idx]['image'])
-        text = tokenizer(self.dataset[idx]['prompt'], padding='max_length', truncation=True, return_tensors='pt')
+        url = self.dataset[idx]['url']
+        image = self.transform(self.image_cache[url])
+        text = tokenizer(self.dataset[idx]['prompt'],
+                         padding='max_length',
+                         truncation=True,
+                         return_tensors='pt')
         label = self.labels[idx]
         return image, text, label

-# Define CNN for image processing
+# Model definitions remain the same
 class ImageModel(nn.Module):
     def __init__(self):
         super(ImageModel, self).__init__()
         self.model = models.resnet18(pretrained=True)
         self.model.fc = nn.Linear(self.model.fc.in_features, 512)
-
+
     def forward(self, x):
         return self.model(x)

-# Define MLP for text processing
 class TextModel(nn.Module):
     def __init__(self):
         super(TextModel, self).__init__()
         self.bert = BertModel.from_pretrained('bert-base-uncased')
         self.fc = nn.Linear(768, 512)
-
+
     def forward(self, x):
         output = self.bert(**x)
         return self.fc(output.pooler_output)

-# Combined model
 class CombinedModel(nn.Module):
     def __init__(self):
         super(CombinedModel, self).__init__()
         self.image_model = ImageModel()
         self.text_model = TextModel()
         self.fc = nn.Linear(1024, len(dataset['Model']))
-
+
     def forward(self, image, text):
         image_features = self.image_model(image)
         text_features = self.text_model(text)
@@ -72,24 +99,45 @@ class CombinedModel(nn.Module):
 # Instantiate model
 model = CombinedModel()

-# Define predict function
-def predict(image):
+# Modified prediction function
+def get_recommendations(input_image):
     model.eval()
     with torch.no_grad():
-        image = transforms.ToTensor()(image).unsqueeze(0)
-        image = transforms.Resize((224, 224))(image)
-        text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
-        output = model(image, text_input)
-        _, indices = torch.topk(output, 5)
-        recommended_models = [dataset['Model'][i] for i in indices[0]]
-        return recommended_models
+        # Process input image
+        transform = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor()
+        ])
+        input_tensor = transform(input_image).unsqueeze(0)
+
+        # Get dummy text input (since we're focusing on image similarity)
+        text_input = tokenizer("", return_tensors='pt', padding=True, truncation=True)
+
+        # Get model output
+        output = model(input_tensor, text_input)
+        scores, indices = torch.topk(output, 5)
+
+        # Prepare gallery output
+        gallery_images = []
+        for idx in indices[0]:
+            url = dataset[idx]['url']
+            model_name = dataset[idx]['Model']
+            score = scores[0][idx].item()
+
+            # Get image from cache
+            if url in image_cache:
+                gallery_images.append((image_cache[url], f"{model_name}\nScore: {score:.2f}"))
+
+        return gallery_images

 # Set up Gradio interface
-interface = gr.Interface(fn=predict,
-                         inputs=gr.Image(type="pil"),
-                         outputs=gr.Textbox(label="Recommended Models"),
-                         title="AI Image Model Recommender",
-                         description="Upload an AI-generated image to receive model recommendations.")
+interface = gr.Interface(
+    fn=get_recommendations,
+    inputs=gr.Image(type="pil"),
+    outputs=gr.Gallery(label="Recommended Images"),
+    title="Image Recommendation System",
+    description="Upload an image and get similar images with their model names and distances."
+)

 # Launch the app
 interface.launch()
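
The commit's download loop fetches each URL one at a time and its download_image silences every failure with a bare except. Below is a minimal sketch of a more defensive variant; the 10-second timeout, the RGB conversion, and the thread-pool fan-out are illustrative assumptions, not part of this commit.

# Sketch only: a more defensive downloader for the caching step above.
# Assumptions (not in the commit): 10 s timeout, RGB conversion, thread pool.
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import requests
from PIL import Image

def download_image(url, timeout=10):
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Force 3 channels so the Resize/ToTensor pipeline always sees RGB.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError):
        return None

def build_image_cache(urls, max_workers=8):
    cache = {}
    unique_urls = list(dict.fromkeys(urls))  # keep order, drop duplicate URLs
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for url, img in zip(unique_urls, pool.map(download_image, unique_urls)):
            if img is not None:
                cache[url] = img
    return cache

With the dataset already loaded, image_cache = build_image_cache(item['url'] for item in dataset) would fill the same cache that CustomDataset and get_recommendations read from.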