|
import json
import logging

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
|
|
|
|
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the tail of every sequence matches a stop pattern.

    Each entry of ``stops`` is a sequence of token ids. ``__call__`` returns
    True as soon as, for one of the patterns, the last ``len(pattern)`` tokens
    of *all* rows of ``input_ids`` equal that pattern.
    """

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: Iterable of token-id patterns (tensors or lists) to match
                against the end of ``input_ids``. ``None`` means no patterns.
            encounters: Kept for API compatibility; stored on the instance but
                not used by the matching logic.
        """
        super().__init__()
        # Fresh list per instance instead of a shared mutable default.
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if every row of ``input_ids`` ends with one of ``self.stops``.

        ``scores`` and any extra keyword arguments forwarded by
        ``StoppingCriteriaList`` are accepted but unused.
        """
        seq_len = input_ids.shape[-1]
        for stop in self.stops:
            pattern_len = len(stop)
            # An empty pattern would slice the whole sequence (``-0:``), and a
            # pattern longer than the sequence yields too few columns to
            # broadcast against it; neither can be a genuine match.
            if pattern_len == 0 or pattern_len > seq_len:
                continue
            tail = input_ids[:, -pattern_len:]
            # Move/convert the pattern to the ids' device to avoid a
            # cross-device comparison error.
            target = torch.as_tensor(stop, device=tail.device)
            if torch.all(tail == target).item():
                return True
        return False
|
|
|
|
|
|
|
class Chat:
    """Image-to-recipe lookup with a keyword-driven question interface.

    An input image is embedded with ``model``'s visual encoder and scored
    against precomputed target image features; the best match selects a row of
    ``dataframe`` whose recipe fields ``ask`` returns as JSON text.
    """

    # Keyword groups checked in order against the lower-cased message; the
    # first group with a matching keyword decides which column is returned.
    _ASK_ROUTES = (
        (("nutrition", "nutrients"), "recipe_nutrients"),
        (("instruction",), "recipe_instructions"),
        (("ingredients",), "recipe_ingredients"),
        (("tag", "class"), "tags"),
    )

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        """
        Args:
            model: Object exposing ``visual_encoder`` and ``vision_proj``
                callables (used by ``encode_image``).
            transform: Callable turning a PIL RGB image into a tensor.
            dataframe: Recipe table; a row is selected positionally with
                ``iloc`` and read through the ``recipe_nutrients``,
                ``recipe_instructions``, ``recipe_ingredients`` and ``tags``
                columns.
            tar_img_feats: Target image features, scored against the query
                features with a matrix product.
            device: Device the query image tensor is moved to.
            stopping_criteria: Optional stopping-criteria list; when omitted a
                ``StoppingCriteriaList`` that stops on token id 2 is built.
        """
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None       # features of the most recently encoded image
        self.target_recipe = None   # dataframe row chosen by get_target
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            # NOTE(review): token id 2 is presumably the tokenizer's EOS id — confirm.
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        """Embed an image and select the best-matching recipe.

        Args:
            image_path: Despite its name, this was historically passed straight
                to ``PIL.Image.fromarray``, so array-like images keep working.
                A ``str`` file path or an existing ``PIL.Image.Image`` is also
                accepted.

        Side effects:
            Sets ``self.img_feats`` (L2-normalised projected features, on CPU)
            and, through ``get_target``, ``self.target_recipe``.
        """
        if isinstance(image_path, Image.Image):
            pil_img = image_path.convert("RGB")
        elif isinstance(image_path, str):
            # Close the file handle once the pixels are loaded by convert().
            with Image.open(image_path) as opened:
                pil_img = opened.convert("RGB")
        else:
            pil_img = Image.fromarray(image_path).convert("RGB")

        batch = self.transform(pil_img).unsqueeze(0).to(self.device)

        # Inference only: without no_grad the stored features would keep the
        # encoder's autograd graph (and its activations) alive.
        with torch.no_grad():
            img_embs = self.model.visual_encoder(batch)
            # Position 0 along dim 1 is projected (presumably the CLS token —
            # confirm), then L2-normalised along the last dim.
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        """Select the dataframe row whose target features score highest.

        Scores are ``img_feats @ tar_img_feats.T`` (batch dim squeezed); the
        argmax picks the row stored in ``self.target_recipe``.
        """
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        logging.getLogger(__name__).debug("retrieval scores: %s", score)
        # NOTE(review): the +1 shifts the best feature index by one dataframe
        # row. It presumably compensates for an offset between feature rows
        # and df rows; otherwise it is an off-by-one — confirm.
        index = int(np.argmax(score)) + 1
        self.target_recipe = self.df.iloc[index]

    def ask(self, msg):
        """Answer a question about the selected recipe by keyword matching.

        Matching is case-insensitive and follows ``_ASK_ROUTES`` order. The
        matched column is returned as indented JSON; messages matching no
        keyword get a fixed placeholder reply.

        Raises:
            RuntimeError: A recipe field is requested before any image has been
                encoded.
        """
        text = msg.lower()
        for keywords, column in self._ASK_ROUTES:
            if any(word in text for word in keywords):
                if self.target_recipe is None:
                    raise RuntimeError("No recipe selected yet; call encode_image() first.")
                return json.dumps(self.target_recipe[column], indent=4)
        return "Conversational capabilities will be included later."
|
|