Advait3009 committed on
Commit
2f62b14
·
verified ·
1 Parent(s): 087290f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import base64
4
+ import torch
5
+ from PIL import Image
6
+ from utils.retriever import FAISSRetriever
7
+ from utils.embedder import MultiModalEmbedder
8
+ from utils.memory import ChatMemory
9
+ from utils.model_loader import load_llava_model
10
+ from transformers import TextStreamer
11
+
12
+ # Initialize components with caching
13
@st.cache_resource
def load_components():
    """Build the heavy shared objects used to answer a query.

    Wrapped in ``st.cache_resource`` so the objects are created once and
    reused on later script reruns instead of being rebuilt each time.

    Returns:
        A 3-tuple ``(embedder, retriever, llava_pipe)`` holding a
        ``MultiModalEmbedder``, a ``FAISSRetriever`` and whatever
        ``load_llava_model()`` returns, constructed in that order.
    """
    # Tuple elements are evaluated left to right, so construction order
    # is embedder -> retriever -> model.
    return MultiModalEmbedder(), FAISSRetriever(), load_llava_model()
19
+
20
def main():
    """Render the chat UI and answer at most one query per script run.

    Flow of a single run:
      1. Ensure ``st.session_state`` holds the message history and a
         ``ChatMemory`` instance.
      2. Draw the sidebar uploader, the chat input, the image uploader and
         the stored chat history.
      3. If the user supplied text and/or an image, echo it into the chat,
         retrieve context through the retriever, query the model and
         append the answer to the history.

    Returns:
        None. All output goes through Streamlit widgets and session state.
    """
    st.title("MultiModal RAG Chatbot 🤖🖼️")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemory()

    # Sidebar for document upload
    with st.sidebar:
        st.header("Knowledge Base")
        # NOTE(review): these files are collected but never passed to the
        # embedder/retriever in this function -- TODO wire up indexing.
        uploaded_files = st.file_uploader(
            "Upload documents/images",
            type=["pdf", "jpg", "png", "jpeg"],
            accept_multiple_files=True
        )

    # Chat input
    user_input = st.chat_input("Ask something or upload an image...")
    uploaded_image = st.file_uploader("Upload image", type=["jpg", "png", "jpeg"], key="img_upload")

    # Display chat history
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            if msg["type"] == "text":
                st.markdown(msg["content"])
            elif msg["type"] == "image":
                st.image(msg["content"])

    # Nothing to process on this run.
    if not user_input and uploaded_image is None:
        return

    embedder, retriever, llava_pipe = load_components()

    # Handle image upload
    # NOTE(review): the uploader keeps its file across reruns, so the same
    # image is presumably re-processed on every rerun until the user clears
    # it -- confirm whether that is intended.
    image = None
    if uploaded_image is not None:
        image = Image.open(uploaded_image).convert("RGB")
        with st.chat_message("user"):
            st.image(image, caption="Uploaded Image", use_column_width=True)
        st.session_state.messages.append({
            "role": "user",
            "type": "image",
            "content": image
        })

    # Echo and store the user's text so it appears in the chat history;
    # previously only images and assistant replies were recorded.
    if user_input:
        with st.chat_message("user"):
            st.markdown(user_input)
        st.session_state.messages.append({
            "role": "user",
            "type": "text",
            "content": user_input
        })

    # Query actually sent to the model (fallback for image-only turns).
    query = user_input or "Explain this image"

    # Generate response
    with st.spinner("Thinking..."):
        # Retrieve context
        if image is not None:
            image_emb = embedder.embed_image(image)
            text_emb = embedder.embed_text(user_input) if user_input else None
            context = retriever.search(image_emb, text_emb)
        else:
            context = retriever.search(text_emb=embedder.embed_text(user_input))

        # Generate LLM response.
        # No streamer is passed: TextStreamer requires a tokenizer argument
        # (the bare TextStreamer() call raised TypeError) and would only
        # print to stdout, while the full answer is rendered below anyway.
        prompt = f"CONTEXT: {context}\n\nQUERY: {query}"
        response = llava_pipe(
            prompt,
            image=image,
            max_new_tokens=512,
            return_full_text=False
        )[0]['generated_text']

    # Update memory and display; record the query that was really used so
    # image-only turns do not store None as the user message.
    st.session_state.memory.update(query, response)
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({
        "role": "assistant",
        "type": "text",
        "content": response
    })
95
+
96
+ if __name__ == "__main__":
97
+ main()