Niha14 committed · verified
Commit 2cc0fa9 · 1 Parent(s): b708b02

Create app.py

Files changed (1)
  1. app.py +91 -0
app.py ADDED
import tempfile

import streamlit as st
from PIL import Image
from byaldi import RAGMultiModalModel
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


# Minimal stand-in for qwen_vl_utils.process_vision_info: collects the PIL
# images already present in the chat messages. This app never sends videos,
# so video_inputs stays None (matching what the processor expects).
def process_vision_info(messages):
    image_inputs = [
        item["image"]
        for msg in messages
        for item in msg["content"]
        if isinstance(item, dict) and item.get("type") == "image"
    ]
    video_inputs = None
    return image_inputs, video_inputs


# Upload an image, index it with ColPali, run Qwen2-VL inference, and display the output
def upload_image_and_infer():
    # Step 1: Let the user upload an image file
    uploaded_file = st.file_uploader("Upload an image file", type=["jpg", "png", "jpeg"])

    if uploaded_file is not None:
        # Step 2: Save the uploaded image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
            temp_file.write(uploaded_file.read())
            temp_path = temp_file.name

        # Step 3: Display the uploaded image
        image = Image.open(temp_path)
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Step 4: Load the retrieval model, the VLM, and its processor
        RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", torch_dtype="auto", device_map="auto"
        )
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8")

        # Index the uploaded image and retrieve the best match for the query
        text_query = "extract the details?"
        RAG.index(
            input_path=temp_path,  # The uploaded image's temporary path
            index_name="image_index",
            store_collection_with_index=False,
            overwrite=True,
        )
        results = RAG.search(text_query, k=1)

        # Step 5: Prepare messages for inference. With a single indexed image
        # the top result always points back to it, so image_index is unused here.
        image_index = results[0]["page_num"] - 1
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # The uploaded image
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        # Step 6: Prepare inputs for the model
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        # Tokenize and prepare inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(model.device)  # Follow the device picked by device_map="auto"

        # Step 7: Generate, then trim the prompt tokens from the output
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the generated output
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        # Step 8: Display the output in Streamlit
        st.write("Generated Output:", output_text)
    else:
        st.write("Please upload an image.")


# Run the function inside the Streamlit app
upload_image_and_infer()
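
One follow-up worth noting: as written, app.py reloads both ColPali and Qwen2-VL on every Streamlit rerun, since each upload re-executes the whole script. Below is a minimal sketch of caching the loads with Streamlit's st.cache_resource; the load_models name is hypothetical and the model IDs are simply the ones used in the file above.

import streamlit as st
from byaldi import RAGMultiModalModel
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Sketch only: st.cache_resource keeps the returned objects alive across
# Streamlit reruns, so the heavy from_pretrained calls run just once.
@st.cache_resource
def load_models():
    rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8", torch_dtype="auto", device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8")
    return rag, model, processor

With this in place, RAG, model, processor = load_models() would replace the three from_pretrained calls inside the upload handler. To try the app locally (assuming the packages it imports are installed: streamlit, transformers, byaldi, pillow), run it with streamlit run app.py.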