Commit cc3c167 · giorgio-caparvi committed · 1 Parent(s): a3d4937

sending image and json with captions

api/app.py CHANGED
@@ -5,6 +5,7 @@ import os
 import io
 from model.src.utils.arg_parser import eval_parse_args  # New, corrected import
 import sys
+from PIL import Image
 
 from model.src import eval
 
@@ -15,17 +16,34 @@ CORS(app)
 def index():
     return render_template('index.html')
 
-@app.route('/generate-design', methods=['POST'])
+@app.route('/generate-design', methods=['GET', 'POST'])
 def generate_design():
     try:
+
+        # Getting JSON
+        json_data_from_req = request.get_json()
+        if not json_data_from_req:
+            return "Invalid or missing JSON data", 400
+        print(json_data_from_req)
+
 
         # Getting Image
-        image_file = request.files['image']
-        image = Image.open(image_file)
-        save_path = os.path.join('/api/model/assets/data/vitonhd/test/im_sketch', '03191_00.jpeg')
-        image.save(save_path, 'JPEG')
 
+        if 'image' not in request.files:
+            return "No image file in request", 400
 
+        image_file = request.files['image']
+        try:
+            image = Image.open(image_file)
+        except Exception as e:
+            return f"Failed to open the image: {str(e)}", 400
+
+        # Create an in-memory buffer to store the image (instead of saving to disk)
+        img_sketch_buffer = io.BytesIO()
+        # Save the image to the buffer in JPEG format
+        image.save(img_sketch_buffer, format='JPEG')
+        # Rewind the buffer's position to the beginning
+        img_sketch_buffer.seek(0)
 
 
         # Build a list of arguments like the ones you would pass via the CLI
@@ -42,7 +60,7 @@ def generate_design():
         ]
 
         # Run eval.py's `main()` function, passing it the arguments
-        final_image = eval.main()
+        final_image = eval.main(img_sketch_buffer, json_data_from_req)
 
         # Save the image to a BytesIO buffer to return via HTTP
         img_io = io.BytesIO()
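
For reference, a minimal standalone sketch of the in-memory image round-trip the new handler relies on (no Flask involved; the input filename is a hypothetical placeholder):

    import io
    from PIL import Image

    image = Image.open("sketch.jpg").convert("RGB")   # hypothetical input; convert() guards against non-RGB modes
    img_sketch_buffer = io.BytesIO()                  # in-memory buffer instead of a file on disk
    image.save(img_sketch_buffer, format='JPEG')      # write the JPEG bytes into the buffer
    img_sketch_buffer.seek(0)                         # rewind so the next reader starts at byte 0
    reloaded = Image.open(img_sketch_buffer)          # downstream code can now read the sketch from memory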
api/model/src/datasets/vitonhd.py CHANGED
@@ -26,6 +26,8 @@ class VitonHDDataset(data.Dataset):
         self,
         dataroot_path: str,
         phase: str,
+        im_sketch_buffer_from_request,
+        json_from_req,
         tokenizer,
         radius=5,
         caption_folder='captions.json', #######################################################3
@@ -48,6 +50,8 @@ class VitonHDDataset(data.Dataset):
         self.width = size[1]
         self.radius = radius
         self.tokenizer = tokenizer
+        self.im_sketch_buffer_from_request = im_sketch_buffer_from_request
+        self.json_from_req = json_from_req
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
@@ -70,12 +74,19 @@ class VitonHDDataset(data.Dataset):
         assert all(x in possible_outputs for x in outputlist)
 
         # Load Captions
+        model_data = self.json_from_req.get('MODEL', {})  # Safely get the 'MODEL' key, default to an empty dictionary if it doesn't exist
+        # Filter captions based on the length requirement (3 or more items)
+        self.captions_dict = {k: v for k, v in model_data.items() if len(v) >= 3}
+        '''
         with open(os.path.join(self.dataroot, self.caption_folder)) as f:
             # self.captions_dict = json.load(f)['items']
             self.captions_dict = json.load(f)
             self.captions_dict = {k: v for k, v in self.captions_dict.items() if len(v) >= 3}
+
 
         dataroot = self.dataroot
+
+
         if phase == 'train':
             filename = os.path.join(dataroot, f"{phase}_pairs.txt")
         else:
@@ -99,10 +110,10 @@ class VitonHDDataset(data.Dataset):
             im_names.append(im_name)
             c_names.append(c_name)
             dataroot_names.append(dataroot)
-
-        self.im_names = im_names
-        self.c_names = c_names
-        self.dataroot_names = dataroot_names
+        '''
+        self.im_names = []
+        self.c_names = []
+        self.dataroot_names = []
 
     def __getitem__(self, index):
         """
@@ -112,9 +123,10 @@ class VitonHDDataset(data.Dataset):
         :return: dict containing dataset samples
         :rtype: dict
         """
-        c_name = self.c_names[index]
-        im_name = self.im_names[index]
-        dataroot = self.dataroot_names[index]
+        c_name = list(self.captions_dict.keys())[0] + "_00.jpg"  # self.c_names[index]
+        im_name = list(self.captions_dict.keys())[0] + "_00.jpg"  # self.im_names[index]
+        # dataroot = self.dataroot_names[index]
+        dataroot = "./assets/data/vitonhd"
 
         sketch_threshold = random.randint(self.sketch_threshold_range[0], self.sketch_threshold_range[1])
 
@@ -146,7 +158,7 @@ class VitonHDDataset(data.Dataset):
             image = self.transform(image)  # [-1,1]
 
         if "im_sketch" in self.outputlist:
-            # Person image
+            '''# Person image
             # im_sketch = Image.open(os.path.join(dataroot, 'im_sketch', c_name.replace(".jpg", ".png")))
             if self.order == 'unpaired':
                 im_sketch = Image.open(
@@ -161,8 +173,12 @@ class VitonHDDataset(data.Dataset):
             else:
                 raise ValueError(
                     f"Order should be either paired or unpaired"
-                )
-
+                )'''
+            im_sketch = Image.open(self.im_sketch_buffer_from_request)
+            # Define a transform to convert the image to grayscale
+            transform = transforms.Grayscale()
+            # Apply the transform to the image
+            im_sketch = transform(im_sketch)
             im_sketch = im_sketch.resize((self.width, self.height))
             im_sketch = ImageOps.invert(im_sketch)
             # threshold grayscale pil image
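
Judging from how the dataset now consumes the request JSON above (json_from_req.get('MODEL', {}), the len(v) >= 3 filter, and the "<key>_00.jpg" name construction), the expected payload is shaped roughly like the sketch below; the ID and caption strings are invented examples, not a documented schema:

    # Hypothetical payload inferred from the code above.
    json_from_req = {
        "MODEL": {
            "03191": [                 # key becomes the image/cloth name 03191_00.jpg
                "sleeveless top",      # at least three caption entries survive the filter
                "light blue fabric",
                "round neckline",
            ]
        }
    }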
api/model/src/eval.py CHANGED
@@ -29,7 +29,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 os.environ["WANDB_START_METHOD"] = "thread"
 
 
-def main() -> None:
+def main(im_sketch: io.BytesIO, json_data_from_req: Dict) -> None:
     args = eval_parse_args()
     accelerator = Accelerator(
         mixed_precision=args.mixed_precision,
@@ -91,6 +91,8 @@ def main() -> None:
             radius=5,
             tokenizer=tokenizer,
             size=(512, 384),
+            im_sketch=im_sketch,
+            json_data_from_req=json_data_from_req
         )
     else:
         raise NotImplementedError
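
Read together with the api/app.py hunk above, the intended call into the new main() signature looks roughly like this sketch; the image file and captions below are hypothetical placeholders:

    import io
    from model.src import eval

    # Mirrors the call added in api/app.py: an in-memory JPEG sketch plus the parsed request JSON.
    with open("sketch.jpg", "rb") as f:     # hypothetical image file
        img_sketch_buffer = io.BytesIO(f.read())
    json_data_from_req = {"MODEL": {"03191": ["sleeveless top", "light blue fabric", "round neckline"]}}
    final_image = eval.main(img_sketch_buffer, json_data_from_req)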