import spaces
from transformers import (
TextIteratorStreamer,
)
from transformers import (
AutoProcessor,
# BitsAndBytesConfig,
LlavaForConditionalGeneration,
)
from PIL import Image
import gradio as gr
from threading import Thread
from dotenv import load_dotenv
# Import Supabase functions
from db_client import get_user_history, update_user_history, delete_user_history
# Timestamps, login UI, typing helpers, and the classes used by the TESTING-mode model
from datetime import datetime
import pytz
from gradio.components import LoginButton
from typing import Optional
from transformers import AutoModelForCausalLM, CodeGenTokenizerFast as Tokenizer
import torch
from theme import Seafoam
# Load environment variables from a .env file (python-dotenv).
load_dotenv()

# Switch for the model-loading code below: when True, a different model
# ("vikhyatk/moondream1") is loaded instead of `model_id`.
TESTING: bool = False

# Module-level login state read by bot_streaming (history lookup and the
# decision whether to persist the conversation). Nothing in this part of the
# file modifies them -- presumably the login handling elsewhere does.
IS_LOGGED_IN: bool = False
USER_ID = None

# Hugging Face model id
# model_id = "mistral-community/pixtral-12b"
model_id = "blanchon/PixDiet-pixtral-nutrition-v2"

# BitsAndBytesConfig int-4 config (disabled; it pairs with the commented-out
# `quantization_config` argument in the model load below).
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )
# Jinja chat template for the prompt format used by the model: every user turn
# is wrapped in [INST] ... [/INST], each image item contributes a newline
# followed by an "[IMG]" placeholder, and assistant turns contribute only
# their text items.
#
# The literal is kept flush-left so that no indentation leaks into the
# rendered prompt. It must NOT be post-processed with a blanket
# `.replace(" ", "")`: that also deletes the spaces inside the Jinja tags
# (`{%- for message in messages %}` -> `{%-formessageinmessages%}`), which
# Jinja rejects as an unknown tag when the template is first rendered.
CHAT_TEMPLATE = """
{%- for message in messages %}
{%- if message.role == "user" %}
[INST]
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- elif item.type == "image" %}
\n[IMG]
{%- endif %}
{%- endfor %}
[/INST]
{%- elif message.role == "assistant" %}
{%- for item in message.content %}
{%- if item.type == "text" %}
{{ item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}
"""

# Load the model and its processor once, at import time.
if TESTING:
    # Alternate model for local testing; it ships custom modelling code,
    # hence trust_remote_code=True. Here `processor` is a bare tokenizer.
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        # quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)

    # Set the chat template used by processor.apply_chat_template().
    processor.chat_template = CHAT_TEMPLATE

    # Pad with the EOS token (presumably because the tokenizer defines no
    # dedicated pad token -- confirm against the model's tokenizer config).
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
@spaces.GPU
def bot_streaming(chatbot, image_input, max_new_tokens=250):
    """Stream the model's nutrition analysis into the latest chatbot turn.

    Generator: the stored history for ``USER_ID`` is extended with a new user
    message (a fixed nutrition prompt stamped with the current Paris time,
    plus the image when one is supplied), the model generates in a background
    thread, and the partially built answer is written into the last chatbot
    pair and yielded after every streamed chunk.

    Args:
        chatbot: Chat history as a list of ``[user_message, bot_message]``
            pairs. The last pair is the turn being answered; its second slot
            is overwritten in place with the growing response.
        image_input: ``None``, a ``PIL.Image.Image``, or an array accepted by
            ``Image.fromarray``. Converted to RGB before use.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``model.generate``.

    Yields:
        The ``chatbot`` list after each streamed chunk has been appended.

    Side effects:
        Prints a debug transcript to stdout once generation has finished and,
        when ``IS_LOGGED_IN`` is true, stores the extended history (including
        the assistant reply) through ``update_user_history``.
    """
    # Start from the conversation stored for the current user.
    messages = get_user_history(USER_ID)

    # Indexing the pending turn up front makes a malformed/empty chat fail
    # before any generation work is started. The typed text itself is not
    # forwarded to the model (see the note on the prompt below).
    _typed_text = chatbot[-1][0]

    # Current time in the Paris timezone, formatted like "07:45PM".
    paris_tz = pytz.timezone("Europe/Paris")
    current_time = datetime.now(paris_tz).strftime("%I:%M%p")

    # NOTE(review): this used to be an `if text_input != "": ... else: ...`
    # whose two branches were identical, so the same fixed prompt is always
    # sent and the user's own text is discarded. The behaviour is kept as-is;
    # if the typed text is meant to reach the model, this is the place to
    # merge it into the prompt.
    text_input = f"Current time: {current_time}. You are a nutrition expert. Identify the food/ingredients in this image. Is this a healthy meal? Can you think of how to improve it?"

    # Build the new user message; the image placeholder is only added when an
    # image was actually supplied.
    content = [{"type": "text", "text": text_input}]
    images = []
    if image_input is not None:
        # Accept either a ready PIL image or a raw array.
        if isinstance(image_input, Image.Image):
            pil_image = image_input
        else:
            pil_image = Image.fromarray(image_input)
        images.append(pil_image.convert("RGB"))
        content.append({"type": "image"})
    messages.append({"role": "user", "content": content})

    # Render the whole conversation with the processor's chat template.
    texts = processor.apply_chat_template(messages)

    # Tokenise (and preprocess images, when present) and move to the GPU.
    processor_kwargs = {"text": texts, "return_tensors": "pt"}
    if images:
        processor_kwargs["images"] = images
    inputs = processor(**processor_kwargs).to("cuda")

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer and forwards text as it arrives.
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()

    # Debug output
    print("*" * 60)
    print("*" * 60)
    print("BOT_STREAMING_CONV_START")
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f"Q{i}:\n {request}")
        print(f"A{i}:\n {answer}")
    print("New_Q:\n", text_input)
    print("New_A:\n", response)
    print("BOT_STREAMING_CONV_END")

    # Persist the exchange only for logged-in users.
    if IS_LOGGED_IN:
        new_history = messages + [
            {"role": "assistant", "content": [{"type": "text", "text": response}]}
        ]
        update_user_history(USER_ID, new_history)
# Instantiate the Seafoam theme imported from the local `theme` module
# (presumably handed to the Gradio UI built further down).
seafoam = Seafoam()
# Define the HTML content for the header
html = """
🍽️ PixDiet