bgaspra commited on
Commit
bebfeb8
·
verified ·
1 Parent(s): 7c292b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -10
app.py CHANGED
@@ -8,6 +8,9 @@ import pandas as pd
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
11
 
12
  # Load dataset
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:1000]')
@@ -76,20 +79,36 @@ model = CombinedModel()
76
  def predict(image):
77
  model.eval()
78
  with torch.no_grad():
79
- image = transforms.ToTensor()(image).unsqueeze(0)
80
- image = transforms.Resize((224, 224))(image)
81
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
82
- output = model(image, text_input)
83
  _, indices = torch.topk(output, 5)
84
- recommended_models = [dataset['Model'][i] for i in indices[0]]
 
 
 
 
 
 
 
 
85
  return recommended_models
86
 
87
  # Set up Gradio interface
88
- interface = gr.Interface(fn=predict,
89
- inputs=gr.Image(type="pil"),
90
- outputs=gr.Textbox(label="Recommended Models"),
91
- title="AI Image Model Recommender",
92
- description="Upload an AI-generated image to receive model recommendations.")
 
 
 
 
 
 
 
 
93
 
94
  # Launch the app
95
- interface.launch()
 
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
+ from PIL import Image
12
+ import requests
13
+ from io import BytesIO
14
 
15
  # Load dataset
16
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:1000]')
 
79
  def predict(image):
80
  model.eval()
81
  with torch.no_grad():
82
+ image_tensor = transforms.ToTensor()(image).unsqueeze(0)
83
+ image_tensor = transforms.Resize((224, 224))(image_tensor)
84
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
85
+ output = model(image_tensor, text_input)
86
  _, indices = torch.topk(output, 5)
87
+
88
+ recommended_models = []
89
+ for i in indices[0]:
90
+ model_name = dataset['Model'][i]
91
+ image_url = dataset['url'][i]
92
+ response = requests.get(image_url)
93
+ model_image = Image.open(BytesIO(response.content))
94
+ recommended_models.append((model_name, model_image))
95
+
96
  return recommended_models
97
 
98
  # Set up Gradio interface
99
+ def display_predictions(image):
100
+ recommendations = predict(image)
101
+ model_names = [rec[0] for rec in recommendations]
102
+ model_images = [rec[1] for rec in recommendations]
103
+ return model_names, model_images
104
+
105
+ interface = gr.Interface(
106
+ fn=display_predictions,
107
+ inputs=gr.Image(type="pil"),
108
+ outputs=[gr.Textbox(label="Recommended Models"), gr.Image(type="pil", label="Model Images")],
109
+ title="AI Image Model Recommender",
110
+ description="Upload an AI-generated image to receive model recommendations."
111
+ )
112
 
113
  # Launch the app
114
+ interface.launch()