RamAnanth1 committed
Commit fc5172a · 1 Parent(s): 67e29ff

Add pose detection, controlnet using pose

Files changed (1)
  1. app.py +64 -6
app.py CHANGED
@@ -25,6 +25,7 @@ import cv2
 import einops
 from pytorch_lightning import seed_everything
 import random
+from controlnet_aux import OpenposeDetector
 
 VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
 
@@ -272,6 +273,61 @@ class canny2image:
         real_image.save(updated_image_path)
         return updated_image_path
 
+class image2pose:
+    def __init__(self):
+        print("Direct detect pose.")
+
+    def inference(self, inputs):
+        print("===>Starting image2pose Inference")
+        image = Image.open(inputs)
+        image = np.array(image)
+        pose_image = pose_model(image)  # module-level OpenposeDetector; not defined in this diff (see note after the diff)
+
+        updated_image_path = get_new_image_name(inputs, func_name="pose")
+        pose_image.save(updated_image_path)
+        return updated_image_path
+
+
+class pose2image:
+    def __init__(self, device):  # device is accepted but unused; enable_model_cpu_offload manages placement
+        print("Initialize the pose2image model.")
+
+
+        # Models
+        controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+        )
+        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
+
+        # This call loads the individual model components onto the GPU on demand,
+        # so we don't need to call pipe.to("cuda") explicitly.
+        self.pipe.enable_model_cpu_offload()
+
+        self.pipe.enable_xformers_memory_efficient_attention()
+
+        # Generator seed
+        self.generator = torch.manual_seed(0)
+
+
+    def get_pose(self, image):
+        return pose_model(image)
+
+    def inference(self, inputs):
+        print("===>Starting pose2image Inference")
+        image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
+        image = Image.open(image_path)
+        image = np.array(image)
+        prompt = instruct_text
+        pose_image = self.get_pose(image)
+        output = self.pipe(prompt, pose_image, generator=self.generator, num_images_per_prompt=1, num_inference_steps=20)
+
+        updated_image_path = get_new_image_name(image_path, func_name="pose2image")
+        real_image = output.images[0]  # take the first (index 0) image
+        real_image.save(updated_image_path)
+        return updated_image_path
+
+
 class BLIPVQA:
     def __init__(self, device):
         print("Initializing BLIP VQA to %s" % device)
@@ -296,6 +352,8 @@ class ConversationBot:
         self.t2i = T2I(device="cuda:0")
         self.image2canny = image2canny()
         self.canny2image = canny2image(device="cuda:0")
+        self.image2pose = image2pose()
+        self.pose2image = pose2image(device="cuda:0")
         self.BLIPVQA = BLIPVQA(device="cuda:0")
         #self.pix2pix = Pix2Pix(device="cuda:0")
         self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
@@ -361,12 +419,12 @@ class ConversationBot:
             #Tool(name="Generate Image Condition On Sketch Image", func=self.scribble2image.inference,
             #description="useful when you want to generate a new real image from both the user description and a scribble image or a sketch image. "
             #"The input to this tool should be a comma separated string of two, representing the image_path and the user description"),
-            #Tool(name="Pose Detection On Image", func=self.image2pose.inference,
-            #description="useful when you want to detect the human pose of the image. like: generate human poses of this image, or generate a pose image from this image. "
-            #"The input to this tool should be a string, representing the image_path"),
-            #Tool(name="Generate Image Condition On Pose Image", func=self.pose2image.inference,
-            #description="useful when you want to generate a new real image from both the user description and a human pose image. like: generate a real image of a human from this human pose image, or generate a new real image of a human from this pose. "
-            #"The input to this tool should be a comma separated string of two, representing the image_path and the user description")]
+            Tool(name="Pose Detection On Image", func=self.image2pose.inference,
+                 description="useful when you want to detect the human pose of the image. like: generate human poses of this image, or generate a pose image from this image. "
+                             "The input to this tool should be a string, representing the image_path"),
+            Tool(name="Generate Image Condition On Pose Image", func=self.pose2image.inference,
+                 description="useful when you want to generate a new real image from both the user description and a human pose image. like: generate a real image of a human from this human pose image, or generate a new real image of a human from this pose. "
+                             "The input to this tool should be a comma separated string of two, representing the image_path and the user description"),
         ]
 
     def init_langchain(self, api_key):
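
Note on pose_model: both image2pose.inference and pose2image.get_pose call a module-level pose_model, but this diff only adds the OpenposeDetector import and never instantiates it, so the detector is presumably created at module scope elsewhere in app.py or in a follow-up commit. A minimal sketch of what that missing line would look like, using the checkpoint id from the controlnet_aux examples (whether this Space uses the same id is an assumption):

from controlnet_aux import OpenposeDetector

# Hypothetical module-level detector; "lllyasviel/ControlNet" follows the
# controlnet_aux README and may not match what this Space actually loads.
pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")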
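
For reference, the comma-separated input contract described in the new Tool entries parses as follows; the split line is taken verbatim from pose2image.inference, and the path is a hypothetical example:

inputs = "image/example.png, a chef cooking in a kitchen"
image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
# image_path    -> "image/example.png"
# instruct_text -> " a chef cooking in a kitchen"  (the leading space survives and is passed into the prompt as-is)

The Pose Detection On Image tool takes only the image_path string, so it skips this split entirely.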