bgaspra committed on
Commit
827eb57
·
verified ·
1 Parent(s): 9b04cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -46
app.py CHANGED
@@ -8,18 +8,14 @@ import pandas as pd
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
- import matplotlib.pyplot as plt
12
- from PIL import Image
13
- import io
14
 
15
  # Load dataset
16
- dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:1000]')
17
 
18
  # Preprocess text data
19
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
20
-
21
  class CustomDataset(Dataset):
22
- def __init__(self, dataset):
23
  self.dataset = dataset
24
  self.transform = transforms.Compose([
25
  transforms.Resize((224, 224)),
@@ -27,83 +23,65 @@ class CustomDataset(Dataset):
27
  ])
28
  self.label_encoder = LabelEncoder()
29
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
30
-
31
- def __len__(self):
32
  return len(self.dataset)
33
-
34
- def __getitem__(self, idx):
35
  image = self.transform(self.dataset[idx]['image'])
36
  text = tokenizer(self.dataset[idx]['prompt'], padding='max_length', truncation=True, return_tensors='pt')
37
  label = self.labels[idx]
38
  return image, text, label
39
-
40
  # Define CNN for image processing
41
  class ImageModel(nn.Module):
42
- def __init__(self):
43
- super(ImageModel, self).__init__()
44
  self.model = models.resnet18(pretrained=True)
45
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
46
-
47
  def forward(self, x):
48
  return self.model(x)
49
-
50
  # Define MLP for text processing
51
  class TextModel(nn.Module):
52
- def __init__(self):
53
- super(TextModel, self).__init__()
54
  self.bert = BertModel.from_pretrained('bert-base-uncased')
55
  self.fc = nn.Linear(768, 512)
56
-
57
  def forward(self, x):
58
- output = self.bert(**x)
59
  return self.fc(output.pooler_output)
60
-
61
  # Combined model
62
  class CombinedModel(nn.Module):
63
- def __init__(self):
64
- super(CombinedModel, self).__init__()
65
  self.image_model = ImageModel()
66
  self.text_model = TextModel()
67
  self.fc = nn.Linear(1024, len(dataset['Model']))
68
-
69
  def forward(self, image, text):
70
  image_features = self.image_model(image)
71
  text_features = self.text_model(text)
72
  combined = torch.cat((image_features, text_features), dim=1)
73
  return self.fc(combined)
74
-
75
  # Instantiate model
76
  model = CombinedModel()
77
-
78
- def display_recommendations(image, recommended_models, distances):
79
- fig, axes = plt.subplots(1, len(recommended_models), figsize=(16, 4))
80
- fig.suptitle("Recommended Models")
81
-
82
- for i, (model, distance) in enumerate(zip(recommended_models, distances)):
83
- # Load and display the recommended model image
84
- model_image = Image.open(io.BytesIO(dataset.get_example(model)['image']))
85
- axes[i].imshow(model_image)
86
- axes[i].axis('off')
87
- axes[i].set_title(f"{model}\nDistance: {distance:.2f}")
88
-
89
- return fig
90
-
91
  def predict(image):
92
  model.eval()
93
  with torch.no_grad():
94
  image = transforms.ToTensor()(image).unsqueeze(0)
95
  image = transforms.Resize((224, 224))(image)
96
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
97
- output = model(image, text_input)
98
- distances, indices = torch.topk(-output.squeeze(), 5)
99
- recommended_models = [dataset['Model'][i] for i in indices]
100
- distances = (-distances).tolist()
101
- return display_recommendations(image, recommended_models, distances)
102
-
103
  interface = gr.Interface(fn=predict,
104
  inputs=gr.Image(type="pil"),
105
- outputs=gr.Plot(),
106
  title="AI Image Model Recommender",
107
  description="Upload an AI-generated image to receive model recommendations.")
108
-
109
  interface.launch()
 
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
11
 
12
  # Load dataset
13
+ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
17
  class CustomDataset(Dataset):
18
+ def __init__(self, dataset):
19
  self.dataset = dataset
20
  self.transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
 
23
  ])
24
  self.label_encoder = LabelEncoder()
25
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
26
+ def __len__(self):
 
27
  return len(self.dataset)
28
+ def __getitem__(self, idx):
 
29
  image = self.transform(self.dataset[idx]['image'])
30
  text = tokenizer(self.dataset[idx]['prompt'], padding='max_length', truncation=True, return_tensors='pt')
31
  label = self.labels[idx]
32
  return image, text, label
33
+
34
  # Define CNN for image processing
35
  class ImageModel(nn.Module):
36
+ def __init__(self):
37
+ super(ImageModel, self).__init__()
38
  self.model = models.resnet18(pretrained=True)
39
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
 
40
  def forward(self, x):
41
  return self.model(x)
42
+
43
  # Define MLP for text processing
44
  class TextModel(nn.Module):
45
+ def __init__(self):
46
+ super(TextModel, self).__init__()
47
  self.bert = BertModel.from_pretrained('bert-base-uncased')
48
  self.fc = nn.Linear(768, 512)
 
49
  def forward(self, x):
50
+ output = self.bert(**x)
51
  return self.fc(output.pooler_output)
52
+
53
  # Combined model
54
  class CombinedModel(nn.Module):
55
+ def __init__(self):
56
+ super(CombinedModel, self).__init__()
57
  self.image_model = ImageModel()
58
  self.text_model = TextModel()
59
  self.fc = nn.Linear(1024, len(dataset['Model']))
 
60
  def forward(self, image, text):
61
  image_features = self.image_model(image)
62
  text_features = self.text_model(text)
63
  combined = torch.cat((image_features, text_features), dim=1)
64
  return self.fc(combined)
65
+
66
  # Instantiate model
67
  model = CombinedModel()
68
+ # Define predict function
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def predict(image):
70
  model.eval()
71
  with torch.no_grad():
72
  image = transforms.ToTensor()(image).unsqueeze(0)
73
  image = transforms.Resize((224, 224))(image)
74
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
75
+ output = model(image, text_input)
76
+ _, indices = torch.topk(output, 5)
77
+ recommended_models = [dataset['Model'][i] for i in indices[0]]
78
+ return recommended_models
79
+
80
+ # Set up Gradio interface
81
  interface = gr.Interface(fn=predict,
82
  inputs=gr.Image(type="pil"),
83
+ outputs=gr.Textbox(label="Recommended Models"),
84
  title="AI Image Model Recommender",
85
  description="Upload an AI-generated image to receive model recommendations.")
86
+ # Launch the app
87
  interface.launch()