bgaspra commited on
Commit
3ffa6ad
·
verified ·
1 Parent(s): 73e27ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -32
app.py CHANGED
@@ -9,8 +9,10 @@ from datasets import load_dataset
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
 
12
- # Load dataset
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
 
 
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -77,47 +79,31 @@ class CombinedModel(nn.Module):
77
  # Instantiate model
78
  model = CombinedModel()
79
 
80
- def get_recommendations(image):
 
81
  model.eval()
82
  with torch.no_grad():
83
- # Process image
84
- transform = transforms.Compose([
85
- transforms.Resize((224, 224)),
86
- transforms.ToTensor()
87
- ])
88
- image_tensor = transform(image).unsqueeze(0)
89
-
90
- # Process text
91
  text_input = tokenizer(
92
  "Sample prompt",
93
  return_tensors='pt',
94
  padding=True,
95
  truncation=True
96
  )
97
-
98
- # Get predictions
99
- output = model(image_tensor, text_input)
100
- scores, indices = torch.topk(output, 5)
101
-
102
- # Prepare gallery output
103
- recommendations = []
104
- for idx, score in zip(indices[0], scores[0]):
105
- sample_data = dataset[int(idx)]
106
- recommendations.append({
107
- 'image': sample_data['image'],
108
- 'label': f"Model: {sample_data['Model']}\nScore: {score:.2f}"
109
- })
110
-
111
- return recommendations
112
 
113
- # Gradio interface
114
  interface = gr.Interface(
115
- fn=get_recommendations,
116
  inputs=gr.Image(type="pil"),
117
- outputs=gr.Gallery(label="Recommended Images"),
118
- title="Image Recommendation System",
119
- description="Upload an image and get similar images with their model names and distances."
120
  )
121
 
122
- if __name__ == "__main__":
123
- interface.launch()
 
9
  from torch.utils.data import DataLoader, Dataset
10
  from sklearn.preprocessing import LabelEncoder
11
 
12
+ # Load dataset and filter out null/none values
13
  dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
14
+ # Filter out entries where Model is None or empty
15
+ dataset = dataset.filter(lambda example: example['Model'] is not None and example['Model'].strip() != '')
16
 
17
  # Preprocess text data
18
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
79
  # Instantiate model
80
  model = CombinedModel()
81
 
82
+ # Define predict function
83
+ def predict(image):
84
  model.eval()
85
  with torch.no_grad():
86
+ image = transforms.ToTensor()(image).unsqueeze(0)
87
+ image = transforms.Resize((224, 224))(image)
 
 
 
 
 
 
88
  text_input = tokenizer(
89
  "Sample prompt",
90
  return_tensors='pt',
91
  padding=True,
92
  truncation=True
93
  )
94
+ output = model(image, text_input)
95
+ _, indices = torch.topk(output, 5)
96
+ recommended_models = [dataset['Model'][i] for i in indices[0]]
97
+ return recommended_models
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ # Set up Gradio interface
100
  interface = gr.Interface(
101
+ fn=predict,
102
  inputs=gr.Image(type="pil"),
103
+ outputs=gr.Textbox(label="Recommended Models"),
104
+ title="AI Image Model Recommender",
105
+ description="Upload an AI-generated image to receive model recommendations."
106
  )
107
 
108
+ # Launch the app
109
+ interface.launch()