bgaspra commited on
Commit
4caf7ea
·
verified ·
1 Parent(s): bebfeb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -30
app.py CHANGED
@@ -8,12 +8,9 @@ import pandas as pd
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
- from PIL import Image
12
- import requests
13
- from io import BytesIO
14
 
15
  # Load dataset
16
- dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:1000]')
17
 
18
  # Preprocess text data
19
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -79,36 +76,20 @@ model = CombinedModel()
79
  def predict(image):
80
  model.eval()
81
  with torch.no_grad():
82
- image_tensor = transforms.ToTensor()(image).unsqueeze(0)
83
- image_tensor = transforms.Resize((224, 224))(image_tensor)
84
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
85
- output = model(image_tensor, text_input)
86
  _, indices = torch.topk(output, 5)
87
-
88
- recommended_models = []
89
- for i in indices[0]:
90
- model_name = dataset['Model'][i]
91
- image_url = dataset['url'][i]
92
- response = requests.get(image_url)
93
- model_image = Image.open(BytesIO(response.content))
94
- recommended_models.append((model_name, model_image))
95
-
96
  return recommended_models
97
 
98
  # Set up Gradio interface
99
- def display_predictions(image):
100
- recommendations = predict(image)
101
- model_names = [rec[0] for rec in recommendations]
102
- model_images = [rec[1] for rec in recommendations]
103
- return model_names, model_images
104
-
105
- interface = gr.Interface(
106
- fn=display_predictions,
107
- inputs=gr.Image(type="pil"),
108
- outputs=[gr.Textbox(label="Recommended Models"), gr.Image(type="pil", label="Model Images")],
109
- title="AI Image Model Recommender",
110
- description="Upload an AI-generated image to receive model recommendations."
111
- )
112
 
113
  # Launch the app
114
- interface.launch()
 
8
  from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
 
 
 
11
 
12
  # Load dataset
13
+ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
76
  def predict(image):
77
  model.eval()
78
  with torch.no_grad():
79
+ image = transforms.ToTensor()(image).unsqueeze(0)
80
+ image = transforms.Resize((224, 224))(image)
81
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
82
+ output = model(image, text_input)
83
  _, indices = torch.topk(output, 5)
84
+ recommended_models = [dataset['Model'][i] for i in indices[0]]
 
 
 
 
 
 
 
 
85
  return recommended_models
86
 
87
  # Set up Gradio interface
88
+ interface = gr.Interface(fn=predict,
89
+ inputs=gr.inputs.Image(type="pil"),
90
+ outputs=gr.outputs.Textbox(label="Recommended Models"),
91
+ title="AI Image Model Recommender",
92
+ description="Upload an AI-generated image to receive model recommendations.")
 
 
 
 
 
 
 
 
93
 
94
  # Launch the app
95
+ interface.launch()