import os
import random

import requests
import streamlit as st
import torch
from accelerate import Accelerator
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Define the model IDs
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"


@st.cache_resource(show_spinner=False)
def _load_blip(model_id: str = blip_model_id):
    """Load the BLIP processor and captioning model once per server process.

    Streamlit re-executes this script on every user interaction; without the
    resource cache the ``from_pretrained`` calls below would run on each
    rerun. ``show_spinner=False`` keeps the cached call from emitting a UI
    element at import time, before the page has been configured.

    Returns:
        A ``(processor, model)`` tuple.
    """
    blip_processor = BlipProcessor.from_pretrained(model_id)
    blip_model = BlipForConditionalGeneration.from_pretrained(model_id)
    return blip_processor, blip_model


# Initialize BLIP processor and model (same module-level names as before)
processor, model = _load_blip()

# Initialize the accelerator
# NOTE(review): `accelerator` is not referenced by any code visible in this
# file; kept in case other code relies on it.
accelerator = Accelerator()


def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    """Build a HuggingFaceEndpoint client for the given model.

    Args:
        model_id: Hub repository ID passed as ``repo_id``.
        max_new_tokens: Generation length limit forwarded to the endpoint.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        The ``HuggingFaceEndpoint`` instance, or ``None`` if construction
        raised (the error is reported in the UI via ``st.error``).
    """
    try:
        # NOTE(review): the access token is read from the HF_TOKEN environment
        # variable and passed as `token=`; confirm this is the keyword the
        # installed langchain_huggingface version expects.
        return HuggingFaceEndpoint(
            repo_id=model_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            token=os.getenv("HF_TOKEN"),
        )
    except Exception as e:
        # UI boundary: surface the failure to the user and let callers
        # check for None instead of crashing the app.
        st.error(f"Error loading model: {e}")
        return None


def generate_caption(image, min_len=30, max_len=100):
    """Generate a caption for ``image`` with the module-level BLIP model.

    Args:
        image: Image handed to the BLIP processor (callers in this file pass
            an RGB ``PIL.Image``).
        min_len: ``min_length`` forwarded to ``model.generate``.
        max_len: ``max_length`` forwarded to ``model.generate``.

    Returns:
        The decoded caption string, or ``'Unable to generate caption.'`` if
        any step raised (the error is reported via ``st.error``).
    """
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error generating caption: {e}")
        return 'Unable to generate caption.'
# Configure the Streamlit app
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="🤗")
st.title("Personal HuggingFace ChatBot")
st.markdown(f"*This is a simple chatbot using the HuggingFace transformers library with {llm_model_id}.*")

# Initialize session state (only keys that are missing, so values survive reruns)
_session_defaults = {
    "avatars": {'user': None, 'assistant': None},
    "user_text": None,
    "max_response_length": 256,
    "system_message": "friendly AI conversing with a human user",
    "starter_message": "Hello, there! How can I help you today?",
    "uploaded_image_path": None,
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Sidebar for settings
with st.sidebar:
    st.header("System Settings")

    st.session_state.system_message = st.text_area(
        "System Message", value="You are a friendly AI conversing with a human user."
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value="Hello, there! How can I help you today?"
    )
    # min_value keeps a zero or negative token budget from reaching the endpoint.
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=128, min_value=1
    )

    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar", options=["🤗", "💬", "🤖"], index=0
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar", options=["👤", "👱‍♂️", "👨🏾", "👩", "👧🏾"], index=0
        )

    reset_history = st.button("Reset Chat History")

# Initialize or reset chat history
if "chat_history" not in st.session_state or reset_history:
    st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]


def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    """Ask the LLM for a reply to ``user_text`` and record the exchange.

    Args:
        system_message: Instruction text placed at the top of the prompt.
        chat_history: List of ``{'role', 'content'}`` dicts. It is interpolated
            into the prompt and, on success, extended in place with the new
            user and assistant messages.
        user_text: The latest user message.
        max_new_tokens: Generation limit passed to ``get_llm_hf_inference``.

    Returns:
        A ``(response, chat_history)`` tuple. If the endpoint cannot be built
        or the call fails, ``response`` is ``"Error with model inference."``
        and ``chat_history`` is returned unmodified.
    """
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if hf is None:
        return "Error with model inference.", chat_history

    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')

    try:
        response = chat.invoke(
            input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history)
        )
    except Exception as e:
        # Same reporting style as the other helpers in this file: show the
        # failure in the UI rather than aborting the script run.
        st.error(f"Error during model inference: {e}")
        return "Error with model inference.", chat_history

    # Keep only the text after the last "AI:" marker in the output.
    response = response.split("AI:")[-1]

    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history


# Chat interface
chat_interface = st.container()
with chat_interface:
    # Created first so the conversation is displayed above the uploader and
    # the input box, even though it is filled in further down the script.
    output_container = st.container()

    # Image upload and captioning
    uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_image is not None:
        # The file name comes from the client; basename() drops any directory
        # components so the image cannot be written outside session_images/.
        image_path = os.path.join("session_images", os.path.basename(uploaded_image.name))

        # The uploader keeps returning the same file on every rerun, so only
        # process a file whose path differs from the last one handled.
        # Previously the stored path was never updated, which re-captioned the
        # image and appended duplicate messages on each rerun.
        if image_path != st.session_state.uploaded_image_path:
            try:
                with st.spinner("Processing image... 0%"):
                    image = Image.open(uploaded_image).convert("RGB")

                    # Save image to local session directory
                    os.makedirs("session_images", exist_ok=True)
                    image.save(image_path)

                    # Generate caption
                    caption = generate_caption(image)
            except OSError as e:
                st.error(f"Error processing image: {e}")
            else:
                st.session_state.chat_history.append(
                    {'role': 'user', 'content': f'![uploaded image]({image_path})'}
                )
                st.session_state.chat_history.append({'role': 'assistant', 'content': caption})
                st.session_state.uploaded_image_path = image_path

    st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")

with output_container:
    # Replay the stored conversation (starter message, captions, earlier
    # turns); without this nothing recorded in chat_history is ever shown.
    for message in st.session_state.chat_history:
        with st.chat_message(message['role'], avatar=st.session_state.avatars.get(message['role'])):
            st.markdown(message['content'])

    if st.session_state.user_text:
        with st.chat_message("user", avatar=st.session_state.avatars['user']):
            st.markdown(st.session_state.user_text)
        with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
            # st.spinner only displays when used as a context manager; the
            # former bare st.spinner(...) calls were no-ops.
            with st.spinner("Thinking..."):
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    chat_history=st.session_state.chat_history,
                    user_text=st.session_state.user_text,
                    max_new_tokens=st.session_state.max_response_length
                )
            st.markdown(response)