import pandas as pd
import json
from PIL import Image
import numpy as np
import os
from pathlib import Path
import torch
import torch.nn.functional as F
# from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs
from src.data.transforms import transform_test
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import gradio as gr
import spaces
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import asyncio
from flask import Flask, request, render_template
from flask_cors import CORS
from flask_socketio import SocketIO, emit, join_room, leave_room
# Load environment variables (GROQ_API_KEY must be set in .env; never hard-code secrets in source)
load_dotenv(".env")

USER_AGENT = os.getenv("USER_AGENT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SECRET_KEY = os.getenv("SECRET_KEY")

# Set environment variables
os.environ['USER_AGENT'] = USER_AGENT
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["TOKENIZERS_PARALLELISM"] = 'true'
# Initialize Flask app and SocketIO with CORS
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SESSION_COOKIE_SECURE'] = True  # Use HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['SECRET_KEY'] = SECRET_KEY
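# The Flask/SocketIO objects above are created but no routes or event handlers
# are registered anywhere in this script. A minimal sketch of how a handler
# could be wired up (the "user_query" event name and payload shape are
# hypothetical examples, not part of the original app):
#
# @socketio.on("user_query")
# def handle_user_query(payload):
#     # payload is whatever JSON the client emits, e.g. {"message": "..."}
#     emit("bot_response", {"content": payload.get("message", "")})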
# Initialize LLM
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)
# QA system prompt and chain
qa_system_prompt = """
You are a highly intelligent assistant. Use the following context to answer user questions. Analyze the data carefully and generate a clear, concise, and informative response to the user's question based on this data.

Response Guidelines:
- Use only the information provided in the data to answer the question.
- Ensure the answer is accurate and directly related to the question.
- If the data is insufficient to answer the question, politely apologize and tell the user that there is insufficient data available to answer their question.
- Provide the response in a conversational yet professional tone.

Context:
{context}
"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        ("human", "{input}")
    ]
)
question_answer_chain = qa_prompt | llm | StrOutputParser()
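# A minimal sketch of invoking the chain directly (the context and question
# values here are made-up examples):
#
# answer = question_answer_chain.invoke({
#     "context": '{"recipe_name": "Greek Salad", "calories": 320}',
#     "input": "How many calories does this dish have?",
# })
# print(answer)  # StrOutputParser returns a plain string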
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a mutable default argument; fall back to an empty list
        self.stops = stops if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        # Stop as soon as the most recently generated tokens match any stop sequence
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False
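# A minimal sketch of plugging the criteria into a HuggingFace generate() call
# (gen_model and input_ids are hypothetical; this script only uses the
# criteria object inside the Chat class further below):
#
# stop_ids = [torch.tensor([2]).to(device)]
# criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])
# output = gen_model.generate(input_ids, stopping_criteria=criteria)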
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    # Settings shared by both model sizes
    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config
print("Creating model") | |
config = get_blip_config("large") | |
model = blip_embs( | |
pretrained=config["pretrained"], | |
image_size=config["image_size"], | |
vit=config["vit"], | |
vit_grad_ckpt=config["vit_grad_ckpt"], | |
vit_ckpt_layer=config["vit_ckpt_layer"], | |
queue_size=config["queue_size"], | |
negative_all_rank=config["negative_all_rank"], | |
) | |
model = model.to(device) | |
model.eval() | |
print("Model Loaded !") | |
print("="*50) | |
transform = transform_test(384)

print("Loading Data")
df = pd.read_json("datasets/sidechef/my_recipes.json")

print("Loading Target Embedding")
tar_img_feats = []
for _id in df["id_"].tolist():
    tar_img_feats.append(torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))
tar_img_feats = torch.cat(tar_img_feats, dim=0)
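# Retrieval below works by cosine similarity: features are L2-normalized, so
# the dot product img_feats @ tar_img_feats.t() scores the query image against
# every precomputed recipe embedding. A minimal sketch (the feature dim 256 is
# an assumption taken from BLIP's projection head, not checked here):
#
# query = F.normalize(torch.randn(1, 256), dim=-1)   # (1, d) query feature
# scores = (query @ tar_img_feats.t()).squeeze(0)    # (N,) similarity scores
# best = int(scores.argmax())                        # row index of closest recipe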
class Chat:
    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_array):
        # Gradio passes the uploaded image as a numpy array
        img = Image.fromarray(image_array).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)

        with torch.no_grad():
            img_embs = self.model.visual_encoder(img)
            img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        # Cosine similarity between the query image and every recipe embedding
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        # Target embeddings were loaded in dataframe row order, so the argmax
        # indexes self.df directly (the original "+ 1" skipped the best match)
        index = int(np.argmax(score))
        self.target_recipe = self.df.iloc[index]

    def ask(self):
        # to_json() already returns a JSON string; wrapping it in json.dumps
        # would double-encode it
        return self.target_recipe.to_json()
chat = Chat(model, transform, df, tar_img_feats, device)
print("Chat Initialized!")
# Note: custom_css is defined but not passed to the Interface below, which
# supplies its own css string
custom_css = """
.primary {
    background-color: #4CAF50; /* Green */
}
"""
def respond_to_user(image, message):
    # Embed the uploaded image and retrieve the closest recipe as LLM context
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chat = Chat(model, transform, df, tar_img_feats, device)
    chat.encode_image(image)
    data = chat.ask()

    formatted_input = {
        'input': message,
        'context': data
    }
    try:
        response = question_answer_chain.invoke(formatted_input)
    except Exception as e:
        print(f"LLM call failed: {e}")
        # Return a plain string so the output type matches the success path
        response = "An error occurred while processing your request."
    return response
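# A minimal sketch of calling the handler directly, as Gradio would
# (the all-zeros array is a dummy stand-in for an uploaded photo):
#
# dummy_img = np.zeros((384, 384, 3), dtype=np.uint8)
# print(respond_to_user(dummy_img, "What are the main ingredients?"))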
iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
    outputs=gr.Textbox(label="Nutrition-GPT"),
    title="Nutrition-GPT Demo",
    description="Upload a food image and ask queries!",
    css=".component-12 {background-color: red}",
)

iface.launch()