temp-9384289 commited on
Commit
8009a95
·
1 Parent(s): ffe27dd
Files changed (2) hide show
  1. app.py +106 -0
  2. requirements.txt +6 -2
app.py CHANGED
@@ -5,13 +5,19 @@ from diffusers import DiffusionPipeline
5
  import spaces
6
  # import torch
7
  import PIL.Image
 
 
8
  import gradio as gr
9
  import gradio.components as grc
10
  import numpy as np
11
  from huggingface_hub import from_pretrained_keras
 
12
  import keras
13
  import time
 
 
14
  import os
 
15
 
16
  # os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
17
 
@@ -68,6 +74,7 @@ def getModel(model):
68
  train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
69
 
70
  print(model_id)
 
71
  if 'diffusion' in model_id:
72
  pipe = DiffusionPipeline.from_pretrained(model_id)
73
  pipe = pipe.to("cpu")
@@ -78,6 +85,105 @@ def getModel(model):
78
  test = from_pretrained_keras('nathanReitinger/MNIST-GAN')
79
  image = pipe(generator= torch.manual_seed(42), num_inference_steps=40).images[0]
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  return image
82
 
83
 
 
5
  import spaces
6
  # import torch
7
  import PIL.Image
8
+ from PIL import Image
9
+ from torch.autograd import Variable
10
  import gradio as gr
11
  import gradio.components as grc
12
  import numpy as np
13
  from huggingface_hub import from_pretrained_keras
14
+ from image_similarity_measures.evaluate import evaluation
15
  import keras
16
  import time
17
+ import requests
18
+ import matplotlib.pyplot as plt
19
  import os
20
+ from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
21
 
22
  # os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
23
 
 
74
  train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
75
 
76
  print(model_id)
77
+ image = None
78
  if 'diffusion' in model_id:
79
  pipe = DiffusionPipeline.from_pretrained(model_id)
80
  pipe = pipe.to("cpu")
 
85
  test = from_pretrained_keras('nathanReitinger/MNIST-GAN')
86
  image = pipe(generator= torch.manual_seed(42), num_inference_steps=40).images[0]
87
 
88
+ ########################################### let's save this image for comparison to others
89
+ fig = plt.figure(figsize=(1, 1))
90
+ plt.subplot(1, 1, 0+1)
91
+ plt.imshow(image, cmap='gray')
92
+ plt.axis('off')
93
+ plt.savefig(file_path + 'generated_image.png')
94
+ plt.close()
95
+
96
+ API_URL = "https://api-inference.huggingface.co/models/farleyknight/mnist-digit-classification-2022-09-04"
97
+
98
+ # get a prediction on what number this is
99
+ def query(filename):
100
+ with open(filename, "rb") as f:
101
+ data = f.read()
102
+ response = requests.post(API_URL, data=data)
103
+ return response.json()
104
+
105
+ # use latest model to generate a new image, return path
106
+ ret = False
107
+ output = None
108
+ while ret == False:
109
+ output = query(file_path + 'generated_image.png')
110
+ if 'error' in output:
111
+ time.sleep(10)
112
+ ret = False
113
+ else:
114
+ ret = True
115
+ print(output)
116
+
117
+ low_score_log = ''
118
+ this_label_for_this_image = int(output[0]['label'])
119
+ low_score_log += "this image has been identified as a:" + str(this_label_for_this_image) + "\n" + str(output) + "\n"
120
+ print("===================")
121
+
122
+ lowest_score = 10000
123
+
124
+ for i in range(len(train_labels)):
125
+ # print(i)
126
+ if train_labels[i] == this_label_for_this_image:
127
+
128
+ ###
129
+ # get a real image (of correct number)
130
+ ###
131
+
132
+ # print(i)
133
+ to_check = train_images[i]
134
+ fig = plt.figure(figsize=(1, 1))
135
+ plt.subplot(1, 1, 0+1)
136
+ plt.imshow(to_check, cmap='gray')
137
+ plt.axis('off')
138
+ plt.savefig(file_path + 'real_deal.png')
139
+ plt.close()
140
+
141
+ # baseline = evaluation(org_img_path='results/real_deal.png', pred_img_path='results/real_deal.png', metrics=["rmse", "psnr"])
142
+ # print("---")
143
+
144
+ ###
145
+ # check how close that real training data is to generated number
146
+ ###
147
+ results = evaluation(org_img_path=file_path + 'real_deal.png', pred_img_path=file_path+'generated_image.png', metrics=["rmse", "psnr"])
148
+ if results['rmse'] < lowest_score:
149
+
150
+ lowest_score = results['rmse']
151
+
152
+ image1 = np.array(Image.open(file_path + 'real_deal.png'))
153
+ image2 = np.array(Image.open(file_path + 'generated_image.png'))
154
+ img1 = torch.from_numpy(image1).float().unsqueeze(0).unsqueeze(0)/255.0
155
+ img2 = torch.from_numpy(image2).float().unsqueeze(0).unsqueeze(0)/255.0
156
+ img1 = Variable( img1, requires_grad=False)
157
+ img2 = Variable( img2, requires_grad=True)
158
+ ssim_score = ssim(img1, img2).item()
159
+
160
+ # sys.exit()
161
+ # l2 = distance.euclidean(image1, image2)
162
+
163
+ low_score_log += 'rmse score:' + str(lowest_score) + "\n"
164
+ low_score_log += 'ssim score:' + str(ssim_score) + "\n"
165
+ low_score_log += 'found when:' + str(round( ((i/len(train_labels)) * 100),2 )) + '%' + "\n"
166
+
167
+ low_score_log += "---------\n"
168
+
169
+ print(lowest_score, ssim_score, str(round( ((i/len(train_labels)) * 100),2 )) + '%')
170
+
171
+ fig = plt.figure(figsize=(1, 1))
172
+ plt.subplot(1, 1, 0+1)
173
+ plt.imshow(to_check, cmap='gray')
174
+ plt.axis('off')
175
+ plt.savefig(file_path+str(i) + "--" + str(lowest_score) + '---most_close.png')
176
+ plt.close()
177
+
178
+
179
+ f = open(file_path + "score_log.txt", "w+")
180
+ f.write(low_score_log)
181
+ f.close()
182
+
183
+ print("Done!")
184
+
185
+
186
+ ############################################ return image that you just generated
187
  return image
188
 
189
 
requirements.txt CHANGED
@@ -3,9 +3,13 @@
3
  diffusers==0.27.2
4
  gradio==4.28.3
5
  huggingface-hub==0.22.2
 
6
  keras==2.11.0
7
- tensorflow==2.11.0
8
- numpy==1.23.4
9
  pillow==10.3.0
 
 
10
  spaces==0.26.2
 
11
  torch==2.2.2
 
3
  diffusers==0.27.2
4
  gradio==4.28.3
5
  huggingface-hub==0.22.2
6
+ image-similarity-measures==0.3.6
7
  keras==2.11.0
8
+ matplotlib==3.8.4
9
+ numpy==1.25.2
10
  pillow==10.3.0
11
+ pytorch-msssim==1.0.0
12
+ requests==2.31.0
13
  spaces==0.26.2
14
+ tensorflow==2.11.0
15
  torch==2.2.2