import pandas as pd
import json
from PIL import Image
import numpy as np
import os
from pathlib import Path

import torch
import torch.nn.functional as F

# from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs
from src.data.transforms import transform_test

from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import gradio as gr
import spaces

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

from dotenv import load_dotenv
import asyncio

from flask import Flask, request, render_template
from flask_cors import CORS
from flask_socketio import SocketIO, emit, join_room, leave_room

# Load configuration from the environment; never hard-code API keys in source.
load_dotenv(".env")
USER_AGENT = os.getenv("USER_AGENT")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SECRET_KEY = os.getenv("SECRET_KEY")

# Set environment variables
os.environ['USER_AGENT'] = USER_AGENT
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
os.environ["TOKENIZERS_PARALLELISM"] = 'true'

# Initialize Flask app and SocketIO with CORS
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")
app.config['SESSION_COOKIE_SECURE'] = True  # Use HTTPS
app.config['SESSION_COOKIE_HTTPONLY'] = True
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
app.config['SECRET_KEY'] = SECRET_KEY

# Initialize LLM
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)

# QA system prompt and chain
qa_system_prompt = """
Prompt:
You are a highly intelligent assistant. Use the following context to answer user questions. Analyze the data carefully and generate a clear, concise, and informative response to the user's question based on this data.

Response Guidelines:
- Use only the information provided in the data to answer the question.
- Ensure the answer is accurate and directly related to the question.
- If the data is insufficient to answer the question, politely apologise and tell the user that there is insufficient data available to answer their question.
- Provide the response in a conversational yet professional tone.
Context: {context}
"""

qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        ("human", "{input}"),
    ]
)

question_answer_chain = qa_prompt | llm | StrOutputParser()


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Stop generation as soon as the tail of input_ids matches any stop sequence.
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True
        return False


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
        )
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config["pretrained"] = (
            "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        )
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config


print("Creating model")
config = get_blip_config("large")

model = blip_embs(
    pretrained=config["pretrained"],
    image_size=config["image_size"],
    vit=config["vit"],
    vit_grad_ckpt=config["vit_grad_ckpt"],
    vit_ckpt_layer=config["vit_ckpt_layer"],
    queue_size=config["queue_size"],
    negative_all_rank=config["negative_all_rank"],
)

model = model.to(device)
model.eval()
print("Model Loaded!")
print("=" * 50)

transform = transform_test(384)

print("Loading Data")
df = pd.read_json("datasets/sidechef/my_recipes.json")

print("Loading Target Embedding")
tar_img_feats = []
for _id in df["id_"].tolist():
    tar_img_feats.append(
        torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0)
    )
tar_img_feats = torch.cat(tar_img_feats, dim=0)


class Chat:
    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_array):
        # Gradio passes the uploaded image as a numpy array, not a file path.
        img = Image.fromarray(image_array).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)

        img_embs = self.model.visual_encoder(img)
        img_feats = F.normalize(self.model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats
        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats):
        # Cosine similarity between the query image and every recipe embedding
        # (both are L2-normalized); pick the closest match.
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        index = int(np.argmax(score))
        self.target_recipe = self.df.iloc[index]

    def ask(self):
        # Series.to_json() already returns a JSON string; wrapping it in
        # json.dumps() would double-encode it.
        return self.target_recipe.to_json()


chat = Chat(model, transform, df, tar_img_feats, device)
print("Chat Initialized!")

# Unused at the moment; the Interface below passes its own css string instead.
custom_css = """
.primary{
    background-color: #4CAF50; /* Green */
}
"""


@spaces.GPU
def respond_to_user(image, message):
    # Process the image and message here
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chat = Chat(model, transform, df, tar_img_feats, device)
    chat.encode_image(image)
    data = chat.ask()

    formatted_input = {
        'input': message,
        'context': data,
    }

    try:
        response = question_answer_chain.invoke(formatted_input)
    except Exception as e:
        # Return a plain string so the output type matches the success path.
        response = "An error occurred while processing your request."

    return response


iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(), gr.Textbox(label="Ask Query")],
    outputs=gr.Textbox(label="Nutrition-GPT"),
    title="Nutrition-GPT Demo",
    description="Upload a food image and ask queries!",
    css=".component-12 {background-color: red}",
)

iface.launch()
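
# Usage notes (a sketch, not part of the original app): running this script
# launches only the Gradio interface above, which serves on
# http://127.0.0.1:7860 by default. The Flask/SocketIO app configured near the
# top of this file is never started, since socketio.run(app) is never called.
#
# A hypothetical client-side query using the gradio_client package (the exact
# call depends on your gradio_client version; newer versions expect image
# paths to be wrapped with handle_file()):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict("path/to/food.jpg", "How many calories are in this dish?")
#   print(result)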