# Script added by SPDraptor
import time
from typing import Any

import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image, ImageDraw
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoModelForCausalLM, AutoProcessor

device = torch.device('cuda')

# Florence-2 is loaded once at startup and shared across requests.
model_id = 'microsoft/Florence-2-large'
models_dict = {
    'Florence_model': AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True).to(device).eval(),
    'Florence_processor': AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True),
}

# SAM2 weights and config shipped with the Space.
SAM_CHECKPOINT = "/home/user/app/sam2_hiera_large.pt"
SAM_CONFIG = "sam2_hiera_l.yaml"


def load_sam_image_model(
    device: torch.device,
    config: str = SAM_CONFIG,
    checkpoint: str = SAM_CHECKPOINT
) -> SAM2ImagePredictor:
    """Build a SAM2 image predictor from the local checkpoint."""
    model = build_sam2(config, checkpoint)
    return SAM2ImagePredictor(sam_model=model)
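
# Usage sketch (assumes a CUDA device and the checkpoint path above exist;
# "photo.jpg" is a placeholder):
#   predictor = load_sam_image_model(device=torch.device('cuda'))
#   predictor.set_image(np.array(Image.open("photo.jpg").convert("RGB")))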


def run_sam_inference(
    model: Any,
    image: Image.Image,
    detections: sv.Detections
) -> dict:
    """Run SAM2 on Florence-2 boxes and attach the masks to `detections`."""
    image = np.array(image.convert("RGB"))
    model.set_image(image)
    if detections.xyxy.size == 0:
        return {
            'code': 400,
            'data': None,
            'message': 'The AI couldn’t detect the object you want to mask.'
        }
    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
    # SAM2 adds an extra leading axis when given multiple boxes; squeeze it.
    if len(mask.shape) == 4:
        mask = np.squeeze(mask)
    detections.mask = mask.astype(bool)
    return {
        'code': 200,
        'data': detections,
        'message': 'Mask generated successfully.'
    }
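
# The dict above acts as a small status contract (values are illustrative):
#   success -> {'code': 200, 'data': <sv.Detections with .mask set>, 'message': ...}
#   failure -> {'code': 400, 'data': None, 'message': ...}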


def florence2(image, task_prompt, text_input=None):
    """Run the Microsoft Florence-2 model and return its parsed output."""
    model = models_dict['Florence_model']
    processor = models_dict['Florence_processor']
    prompt = task_prompt if text_input is None else task_prompt + text_input
    input_florence = processor(
        text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=input_florence["input_ids"],
        pixel_values=input_florence["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height))
    return parsed_answer
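
# Usage sketch ("sample.jpg" and "tiger" are placeholder inputs):
#   img = Image.open("sample.jpg").convert("RGB")
#   out = florence2(img, '<OPEN_VOCABULARY_DETECTION>', text_input='tiger')
#   # `out` is a dict keyed by the task prompt, e.g.
#   # out['<OPEN_VOCABULARY_DETECTION>'] holds the detected boxes and labels.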


def draw_MASK(image, prediction, fill_mask=False):
    """
    Draw segmentation polygons as a white-on-black mask image.

    Parameters:
    - image: PIL image; only its size is used, the polygons are drawn on black.
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
      'polygons' is a list of lists, each containing vertices of a polygon.
      'labels' is a list of labels corresponding to each polygon.
    - fill_mask: Boolean indicating whether to fill the polygons with color.
    """
    width = image.width
    height = image.height
    new_image = Image.new("RGB", (width, height), color="black")
    draw = ImageDraw.Draw(new_image)
    scale = 1
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        color = "white"
        fill_color = "white" if fill_mask else None
        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print('Invalid polygon:', _polygon)
                continue
            _polygon = (_polygon * scale).reshape(-1).tolist()
            if fill_mask:
                draw.polygon(_polygon, outline=color, fill=fill_color)
            else:
                draw.polygon(_polygon, outline=color)
            draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
    return new_image
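
# Example `prediction` shape for draw_MASK (hypothetical coordinates),
# matching Florence-2's '<REGION_TO_SEGMENTATION>' output format:
#   prediction = {
#       'polygons': [[[12, 10, 80, 14, 76, 90, 15, 88]]],  # flat x,y pairs
#       'labels': ['tiger'],
#   }
#   mask_img = draw_MASK(Image.new("RGB", (100, 100)), prediction, fill_mask=True)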


def masking_process(image, obj):
    """Detect `obj` in `image` with Florence-2, then segment it with SAM2."""
    start_time = time.time()
    image = Image.fromarray(image).convert("RGB")
    # Alternative path (unused): Florence-2's '<REGION_TO_SEGMENTATION>'
    # task plus draw_MASK can produce a mask directly, without SAM2.
    task_prompt = '<OPEN_VOCABULARY_DETECTION>'
    Florence_results = florence2(image, task_prompt, text_input=obj)
    SAM_IMAGE_MODEL = load_sam_image_model(device=device)
    detections = sv.Detections.from_lmm(
        lmm=sv.LMM.FLORENCE_2,
        result=Florence_results,
        resolution_wh=image.size
    )
    response = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
    print(f'Time taken by masking model: {time.time() - start_time}')
    if response['code'] == 400:
        print("no object found")
        return "no object found"
    detections2 = response['data']
    mask = Image.fromarray(detections2.mask[0])
    torch.cuda.empty_cache()
    return mask
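

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: "tiger.jpeg" is a placeholder path,
    # and a CUDA GPU plus the SAM2 checkpoint above are assumed to be present.
    test_image = np.array(Image.open("tiger.jpeg").convert("RGB"))
    result = masking_process(test_image, "tiger")
    if isinstance(result, Image.Image):
        result.save("mask.png")  # binary mask of the first detected instance
    else:
        print(result)  # "no object found"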