Sasiraj01 committed on
Commit b85f85d · verified · 1 Parent(s): 2e86476

Upload app-2.py

Files changed (1)
  1. app-2.py +58 -0
app-2.py ADDED
@@ -0,0 +1,58 @@
+import gradio as gr
+import faiss
+import torch
+from transformers import AutoProcessor, LlavaForConditionalGeneration
+from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, set_global_service_context
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.vector_stores.faiss import FaissVectorStore
+from llama_index.storage.storage_context import StorageContext
+from PIL import Image
+
+# Load LLaVA model and processor
+model_id = "llava-hf/llava-1.5-7b-hf"
+processor = AutoProcessor.from_pretrained(model_id)
+model = LlavaForConditionalGeneration.from_pretrained(
+    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+)
+model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+# Embedding model for retrieval (BAAI/bge-small-en produces 384-dim vectors)
+embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
+service_context = ServiceContext.from_defaults(embed_model=embed_model)
+set_global_service_context(service_context)
+
+# Load documents and build a FAISS-backed index
+documents = SimpleDirectoryReader("docs").load_data()
+faiss_index = faiss.IndexFlatL2(384)  # dimension must match the embedding model
+vector_store = FaissVectorStore(faiss_index=faiss_index)
+storage_context = StorageContext.from_defaults(vector_store=vector_store)
+index = VectorStoreIndex.from_documents(
+    documents, storage_context=storage_context, service_context=service_context
+)
+query_engine = index.as_query_engine()
+
+def multimodal_rag(image, question):
+    # Step 1: RAG to retrieve text context for the question
+    context = query_engine.query(question)
+
+    # Step 2: Answer with LLaVA (llava-1.5 uses the USER: <image> ... ASSISTANT: template)
+    prompt = f"USER: <image>\nContext: {context}\n\nQuestion: {question}\nASSISTANT:"
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)
+    output = model.generate(**inputs, max_new_tokens=100)
+    # Decode only the newly generated tokens, not the echoed prompt
+    answer = processor.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+    return answer
+
+demo = gr.Interface(
+    fn=multimodal_rag,
+    inputs=[
+        gr.Image(type="pil", label="Upload Image"),
+        gr.Textbox(label="Enter your question"),
+    ],
+    outputs="text",
+    title="Multimodal RAG with LLaVA and FAISS",
+    description="Upload an image and ask a question. The system retrieves relevant text using FAISS and answers using LLaVA.",
+)
+
+if __name__ == "__main__":
+    demo.launch()
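
Once the Space is running, the endpoint can also be exercised from Python. A minimal sketch, assuming a recent gradio_client, a local launch on the default port, and an illustrative image path and question (none of these are part of the commit):

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")  # assumed local Gradio address
result = client.predict(
    handle_file("sample.jpg"),         # hypothetical image path
    "What does this image show?",      # hypothetical question
    api_name="/predict",               # default endpoint name for a single gr.Interface
)
print(result)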