bgaspra commited on
Commit
7265ee9
·
verified ·
1 Parent(s): 4caf7ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -8,9 +8,12 @@ import pandas as pd
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
11
 
12
  # Load dataset
13
- dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -72,7 +75,19 @@ class CombinedModel(nn.Module):
72
  # Instantiate model
73
  model = CombinedModel()
74
 
75
- # Define predict function
 
 
 
 
 
 
 
 
 
 
 
 
76
  def predict(image):
77
  model.eval()
78
  with torch.no_grad():
@@ -80,16 +95,15 @@ def predict(image):
80
  image = transforms.Resize((224, 224))(image)
81
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
82
  output = model(image, text_input)
83
- _, indices = torch.topk(output, 5)
84
- recommended_models = [dataset['Model'][i] for i in indices[0]]
85
- return recommended_models
 
86
 
87
- # Set up Gradio interface
88
  interface = gr.Interface(fn=predict,
89
- inputs=gr.inputs.Image(type="pil"),
90
- outputs=gr.outputs.Textbox(label="Recommended Models"),
91
  title="AI Image Model Recommender",
92
  description="Upload an AI-generated image to receive model recommendations.")
93
 
94
- # Launch the app
95
  interface.launch()
 
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
+ import matplotlib.pyplot as plt
12
+ from PIL import Image
13
+ import io
14
 
15
  # Load dataset
16
+ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:1000]')
17
 
18
  # Preprocess text data
19
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
75
  # Instantiate model
76
  model = CombinedModel()
77
 
78
+ def display_recommendations(image, recommended_models, distances):
79
+ fig, axes = plt.subplots(1, len(recommended_models), figsize=(16, 4))
80
+ fig.suptitle("Recommended Models")
81
+
82
+ for i, (model, distance) in enumerate(zip(recommended_models, distances)):
83
+ # Load and display the recommended model image
84
+ model_image = Image.open(io.BytesIO(dataset.get_example(model)['image']))
85
+ axes[i].imshow(model_image)
86
+ axes[i].axis('off')
87
+ axes[i].set_title(f"{model}\nDistance: {distance:.2f}")
88
+
89
+ return fig
90
+
91
  def predict(image):
92
  model.eval()
93
  with torch.no_grad():
 
95
  image = transforms.Resize((224, 224))(image)
96
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
97
  output = model(image, text_input)
98
+ distances, indices = torch.topk(-output.squeeze(), 5)
99
+ recommended_models = [dataset['Model'][i] for i in indices]
100
+ distances = (-distances).tolist()
101
+ return display_recommendations(image, recommended_models, distances)
102
 
 
103
  interface = gr.Interface(fn=predict,
104
+ inputs=gr.Image(type="pil"),
105
+ outputs=gr.Plot(),
106
  title="AI Image Model Recommender",
107
  description="Upload an AI-generated image to receive model recommendations.")
108
 
 
109
  interface.launch()