kritsg commited on
Commit
aa3b50d
·
1 Parent(s): c33de2b

upload image functionality testing

Browse files
Files changed (2) hide show
  1. app.py +5 -4
  2. bayes/data_routines.py +39 -34
app.py CHANGED
@@ -21,18 +21,19 @@ from bayes.models import *
21
  from image_posterior import create_gif
22
 
23
 
24
- def get_image_data(image_name):
25
  """Gets the image data and model."""
26
- image = get_dataset_by_name(image_name, get_label=False)
 
27
  model_and_data = process_imagenet_get_model(image)
28
-
29
  return image, model_and_data
30
 
31
 
32
  def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
33
  print("Inputs Received:", image_name, c_width, n_top, n_gif_imgs)
34
 
35
- print("imagename", image_name.filename)
36
 
37
  return "yeehaw"
38
 
 
21
  from image_posterior import create_gif
22
 
23
 
24
+ def get_image_data(inp_image):
25
  """Gets the image data and model."""
26
+ image = get_dataset_by_name(inp_image, get_label=False)
27
+ print("image returned\n", image)
28
  model_and_data = process_imagenet_get_model(image)
29
+ print("model returned\n", model_and_data)
30
  return image, model_and_data
31
 
32
 
33
  def segmentation_generation(image_name, c_width, n_top, n_gif_imgs):
34
  print("Inputs Received:", image_name, c_width, n_top, n_gif_imgs)
35
 
36
+ get_image_data(image_name)
37
 
38
  return "yeehaw"
39
 
bayes/data_routines.py CHANGED
@@ -113,45 +113,49 @@ def get_PIL_transf():
113
  ])
114
  return transf
115
 
116
- def load_image(path):
 
 
 
 
 
 
117
  """Loads an image by path."""
118
- with open(os.path.abspath(path), 'rb') as f:
119
- with Image.open(f) as img:
120
- return img.convert('RGB')
121
 
122
- def get_imagenet(name, get_label=True):
123
  """Gets the imagenet data.
124
 
125
  Arguments:
126
  name: The name of the imagenet dataset
127
  """
128
- images_paths = []
129
 
130
  # Store all the paths of the images
131
- data_dir = os.path.join("./data", name)
132
- for (dirpath, dirnames, filenames) in os.walk(data_dir):
133
- for fn in filenames:
134
- if fn != ".DS_Store":
135
- images_paths.append(os.path.join(dirpath, fn))
136
 
137
  # Load & do transforms for the images
138
  pill_transf = get_PIL_transf()
139
  images, segs = [], []
140
- for img_path in images_paths:
141
- img = load_image(img_path)
142
- PIL_transformed_image = np.array(pill_transf(img))
143
- segments = slic(PIL_transformed_image, n_segments=NSEGMENTS, compactness=100, sigma=1)
144
 
145
- images.append(PIL_transformed_image)
146
- segs.append(segments)
147
 
148
  images = np.array(images)
149
 
150
- if get_label:
151
- assert name in IMAGENET_LABELS, "Get label set to True but name not in known imagenet labels"
152
- y = np.ones(images.shape[0]) * IMAGENET_LABELS[name]
153
- else:
154
- y = np.ones(images.shape[0]) * -1
155
 
156
  segs = np.array(segs)
157
 
@@ -203,16 +207,17 @@ def get_mnist(num):
203
 
204
  return output
205
 
206
- def get_dataset_by_name(name, get_label=True):
207
- if name == "compas":
208
- d = get_and_preprocess_compas_data()
209
- elif name == "german":
210
- d = get_and_preprocess_german()
211
- elif "mnist" in name:
212
- d = get_mnist(int(name[-1]))
213
- elif "imagenet" in name:
214
- d = get_imagenet(name[9:], get_label=get_label)
215
- else:
216
- raise NameError("Unkown dataset %s", name)
217
- d['name'] = name
 
218
  return d
 
113
  ])
114
  return transf
115
 
116
+ # def load_image(path):
117
+ # """Loads an image by path."""
118
+ # with open(os.path.abspath(path), 'rb') as f:
119
+ # with Image.open(f) as img:
120
+ # return img.convert('RGB')
121
+
122
+ def load_image(pil_image):
123
  """Loads an image by path."""
124
+ with Image.open(pil_image) as img:
125
+ return img.convert('RGB')
 
126
 
127
+ def get_imagenet(pil_image, get_label=True):
128
  """Gets the imagenet data.
129
 
130
  Arguments:
131
  name: The name of the imagenet dataset
132
  """
133
+ # images_paths = []
134
 
135
  # Store all the paths of the images
136
+ # data_dir = os.path.join("./data", name)
137
+ # for (dirpath, dirnames, filenames) in os.walk(data_dir):
138
+ # for fn in filenames:
139
+ # if fn != ".DS_Store":
140
+ # images_paths.append(os.path.join(dirpath, fn))
141
 
142
  # Load & do transforms for the images
143
  pill_transf = get_PIL_transf()
144
  images, segs = [], []
145
+ img = load_image(pil_image)
146
+ PIL_transformed_image = np.array(pill_transf(img))
147
+ segments = slic(PIL_transformed_image, n_segments=NSEGMENTS, compactness=100, sigma=1)
 
148
 
149
+ images.append(PIL_transformed_image)
150
+ segs.append(segments)
151
 
152
  images = np.array(images)
153
 
154
+ # if get_label:
155
+ # assert name in IMAGENET_LABELS, "Get label set to True but name not in known imagenet labels"
156
+ # y = np.ones(images.shape[0]) * IMAGENET_LABELS[name]
157
+ # else:
158
+ y = np.ones(images.shape[0]) * -1
159
 
160
  segs = np.array(segs)
161
 
 
207
 
208
  return output
209
 
210
+ def get_dataset_by_name(inp_image, get_label=True):
211
+ d = get_imagenet(inp_image, get_label=get_lable)
212
+ # if name == "compas":
213
+ # d = get_and_preprocess_compas_data()
214
+ # elif name == "german":
215
+ # d = get_and_preprocess_german()
216
+ # elif "mnist" in name:
217
+ # d = get_mnist(int(name[-1]))
218
+ # elif "imagenet" in name:
219
+ # d = get_imagenet(name[9:], get_label=get_label)
220
+ # else:
221
+ # raise NameError("Unkown dataset %s", name)
222
+ # d['name'] = name
223
  return d