File size: 7,096 Bytes
0abb626 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import os
from uuid import uuid4
import torch
from PIL import Image
from controlnet_aux import HEDdetector
from diffusers import StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
from flask import Flask, request, send_file
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
from transformers import pipeline
# The Flask application that every endpoint in this module is registered on.
app = Flask(import_name='chatgpt-plugin-extras')
class VitGPT2:
    """Image captioning with the ``nlpconnect/vit-gpt2-image-captioning`` pipeline."""

    def __init__(self, device):
        """Build the image-to-text pipeline on ``device`` (e.g. ``"cpu"``, ``"cuda:0"``)."""
        print(f"Initializing VitGPT2 ImageCaptioning to {device}")
        # Forward the requested device to the pipeline. Previously it was only
        # printed, so the model stayed wherever the pipeline placed it by default.
        self.pipeline = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning",
                                 device=device)

    def inference(self, image_path):
        """Return the generated caption for the image at ``image_path``."""
        # The pipeline returns a list of candidates; take the first one's text.
        captions = self.pipeline(image_path)[0]['generated_text']
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
class ImageCaptioning:
    """Caption images with BLIP (``Salesforce/blip-image-captioning-large``)."""

    def __init__(self, device):
        """Load the BLIP processor and captioning model onto ``device``."""
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Half precision only when the device string mentions CUDA; otherwise float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=self.torch_dtype).to(self.device)

    def inference(self, image_path):
        """Return a caption for the image stored at ``image_path``."""
        # Image.open is lazy. Using it as a context manager guarantees the file
        # handle is released once preprocessing has produced the tensors.
        with Image.open(image_path) as image:
            inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs)
        captions = self.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
class VQA:
    """Visual question answering with BLIP (``Salesforce/blip-vqa-base``)."""

    def __init__(self, device):
        """Load the BLIP processor and question-answering model onto ``device``."""
        print(f"Initializing Visual QA to {device}")
        self.device = device
        # Half precision only when the device string mentions CUDA; otherwise float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base",
                                                              torch_dtype=self.torch_dtype).to(self.device)

    def inference(self, image_path, question):
        """Return the model's answer to ``question`` about the image at ``image_path``."""
        # Close the lazily opened image file as soon as the inputs are built.
        with Image.open(image_path) as image:
            inputs = self.processor(image, question, return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs)
        answers = self.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed Visual QA, Input Image: {image_path}, Output Text: {answers}")
        return answers
class Image2Hed:
    """HED edge detection backed by ``controlnet_aux.HEDdetector``."""

    def __init__(self, device):
        """Load the HED detector weights from ``lllyasviel/ControlNet``.

        NOTE(review): ``device`` is accepted for parity with the other tools but is
        not used here.
        """
        print("Initializing Image2Hed")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    def inference(self, inputs, output_filename):
        """Write the HED edge map of the image at ``inputs`` to ``data/<output_filename>``.

        Returns the ``/result/<output_filename>`` URL path served by ``get_result``.
        """
        output_path = os.path.join('data', output_filename)
        # Release the source file handle once the detector has consumed the image.
        with Image.open(inputs) as image:
            hed = self.detector(image)
        hed.save(output_path)
        print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {output_path}")
        return '/result/' + output_filename
class Image2Scribble:
    """Scribble-style edge maps using ``controlnet_aux.HEDdetector`` in scribble mode."""

    def __init__(self, device):
        """Load the HED detector weights from ``lllyasviel/ControlNet``.

        NOTE(review): ``device`` is accepted for parity with the other tools but is
        not used here.
        """
        print("Initializing Image2Scribble")
        self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')

    def inference(self, inputs, output_filename):
        """Write a scribble map of the image at ``inputs`` to ``data/<output_filename>``.

        Returns the ``/result/<output_filename>`` URL path served by ``get_result``.
        """
        output_path = os.path.join('data', output_filename)
        # Release the source file handle once the detector has consumed the image.
        with Image.open(inputs) as image:
            scribble = self.detector(image, scribble=True)
        scribble.save(output_path)
        # Log under this tool's own name; the message was copy-pasted from Image2Hed.
        print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {output_path}")
        return '/result/' + output_filename
class InstructPix2Pix:
    """Instruction-driven image editing with ``timbrooks/instruct-pix2pix``."""

    def __init__(self, device):
        """Load the InstructPix2Pix pipeline (safety checker disabled) onto ``device``."""
        print(f"Initializing InstructPix2Pix to {device}")
        self.device = device
        # Half precision only when the device string mentions CUDA; otherwise float32.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
                                                                           safety_checker=None,
                                                                           torch_dtype=self.torch_dtype).to(device)
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)

    def inference(self, image_path, text, output_filename):
        """Change the style of the image at ``image_path`` following instruction ``text``.

        The edited image is saved as ``data/<output_filename>``. Returns the
        ``/result/<output_filename>`` URL path, matching the other tools.
        """
        print("===>Starting InstructPix2Pix Inference")
        # Convert to RGB so RGBA/palette/greyscale uploads become 3-channel input,
        # and close the source file handle once the pixels are copied.
        with Image.open(image_path) as source:
            original_image = source.convert("RGB")
        image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
        output_path = os.path.join('data', output_filename)
        image.save(output_path)
        print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
              f"Output Image: {output_path}")
        # Use the bare filename. '/result/' + output_path produced '/result/data/<name>',
        # which the single-segment /result/<filename> route cannot serve.
        return '/result/' + output_filename
@app.route('/result/<filename>')
def get_result(filename):
    """Serve a generated PNG stored in the ``data`` directory.

    ``send_from_directory`` refuses names that would resolve outside ``data``
    (e.g. backslash-separated ``..`` paths on Windows). It answers 404 for a
    missing file instead of raising an unhandled error.
    """
    from flask import send_from_directory
    return send_from_directory('data', filename, mimetype='image/png')
# Every route saves its upload under data/upload, and the tools write their
# outputs to data/. Create both up front, since saving into a missing directory fails.
os.makedirs(os.path.join('data', 'upload'), exist_ok=True)

# Models are loaded once when the module is imported and shared by all requests.
DEVICE = "cpu"
ic = ImageCaptioning(DEVICE)
vqa = VQA(DEVICE)
i2h = Image2Hed(DEVICE)
i2s = Image2Scribble(DEVICE)
# Optional, heavier models: uncomment to enable the alternative captioner and
# the /instruct-pix2pix endpoint.
# vgic = VitGPT2(DEVICE)
# ip2p = InstructPix2Pix(DEVICE)
@app.route('/image2hed', methods=['POST'])
def imag2hed():
    """Run HED edge detection on an uploaded image.

    The multipart form field ``file`` is stored under ``data/upload`` with a random
    UUID name. The response body is the ``/result/...`` path of the generated PNG.
    """
    upload = request.files['file']  # the uploaded image
    saved_path = os.path.join('data', 'upload', f'{uuid4()}.png')
    upload.save(saved_path)
    return i2h.inference(saved_path, f'{uuid4()}.png')
@app.route('/image2Scribble', methods=['POST'])
def image2Scribble():
    """Turn an uploaded image into a scribble-style edge map.

    The multipart form field ``file`` is stored under ``data/upload`` with a random
    UUID name. The response body is the ``/result/...`` path of the generated PNG.
    """
    upload = request.files['file']  # the uploaded image
    saved_path = os.path.join('data', 'upload', f'{uuid4()}.png')
    upload.save(saved_path)
    return i2s.inference(saved_path, f'{uuid4()}.png')
@app.route('/image-captioning', methods=['POST'])
def image_caption():
    """Caption an uploaded image with the BLIP captioner.

    The multipart form field ``file`` is stored under ``data/upload`` with a random
    UUID name, and the caption text is returned as the response body. The
    alternative VitGPT2 captioner (``vgic``) is not consulted.
    """
    upload = request.files['file']  # the uploaded image
    saved_path = os.path.join('data', 'upload', f'{uuid4()}.png')
    upload.save(saved_path)
    return ic.inference(saved_path)
@app.route('/visual-qa', methods=['POST'])
def visual_qa():
    """Answer a question about an uploaded image.

    Expects the image in the multipart form field ``file`` and the question in the
    query-string parameter ``q``. Returns the answer text, or HTTP 400 when ``q`` is
    missing or empty.
    """
    file = request.files['file']  # the uploaded image
    # Validate before writing to disk. Without a question the processor would get
    # None as its text input and inference would fail with an unhandled error (500),
    # leaving an orphaned upload behind.
    question = request.args.get('q')
    if not question:
        return 'Missing required query parameter "q".', 400
    filename = str(uuid4()) + '.png'
    filepath = os.path.join('data', 'upload', filename)
    file.save(filepath)
    return vqa.inference(filepath, question=question)
@app.route('/instruct-pix2pix', methods=['POST'], endpoint='InstructPix2Pix')
def instruct_pix2pix():
    """Edit an uploaded image according to a text instruction.

    Expects the image in the multipart form field ``file`` and the instruction in the
    query-string parameter ``t``. Returns the ``/result/...`` path of the edited image.
    Responds with HTTP 400 when ``t`` is missing or empty, and with HTTP 503 when the
    module-level ``ip2p`` model has not been created.
    """
    # This view used to be named ``InstructPix2Pix``, rebinding the model class of the
    # same name. ``endpoint=`` keeps the old endpoint name so url_for() still works.
    editor = globals().get('ip2p')
    if editor is None:
        # ip2p is undefined unless the module instantiates it. Referencing it
        # directly raised NameError (HTTP 500) on every request.
        return 'InstructPix2Pix model is not loaded on this server.', 503
    file = request.files['file']  # the uploaded image
    instruction = request.args.get('t')
    if not instruction:
        return 'Missing required query parameter "t".', 400
    filename = str(uuid4()) + '.png'
    filepath = os.path.join('data', 'upload', filename)
    file.save(filepath)
    output_filename = str(uuid4()) + '.png'
    return editor.inference(filepath, instruction, output_filename)
if __name__ == '__main__':
    # Start Flask's built-in server bound to all interfaces (0.0.0.0), so the
    # endpoints are reachable from other hosts, not only localhost.
    app.run('0.0.0.0')
|