ThisHumanA committed
Commit a236b47 · verified · 1 Parent(s): d12acdf

Upload app2_blip.py

Files changed (1)
  1. app2_blip.py +97 -0
app2_blip.py ADDED
@@ -0,0 +1,97 @@
+ import streamlit as st
+ import gc
+ from PIL import Image
+ import torch
+ from transformers import (
+     BlipProcessor,
+     BlipForConditionalGeneration,
+     DetrImageProcessor,
+     DetrForObjectDetection,
+ )
+ import google.generativeai as genai
+
+ # Configure Generative AI
+ genai.configure(api_key='AIzaSyB1wxTDQcB2YT_6l2nm4MrhAmCVPzfkHNU')
+ gemini_model = genai.GenerativeModel("gemini-1.5-flash")
+
+ # Load BLIP model and processor
+ @st.cache_resource
+ def load_blip():
+     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+     return processor, model
+
+ # Load DETR model and processor
+ @st.cache_resource
+ def load_detr():
+     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     return processor, model
+
+ blip_processor, blip_model = load_blip()
+ detr_processor, detr_model = load_detr()
+
+ # Streamlit app
+ st.title("Art Of Visual Storytelling")
+
+ # Dropdown menu for genre selection
+ genre = st.selectbox(
+     "Select the genre of the story:",
+     ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
+ )
+
+ # Slider for specifying story length
+ story_length = st.slider(
+     "Select the desired story length (number of words):",
+     min_value=50, max_value=1000, value=200, step=50
+ )
+
+ # Image upload and processing
+ uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
+
+ if uploaded_image:
+     # Display the uploaded image
+     raw_image = Image.open(uploaded_image).convert('RGB')
+     st.image(raw_image, caption='Uploaded Image', use_column_width=True)
+
+     # Detect objects using DETR model
+     detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
+     detr_model.eval()
+     with torch.no_grad():
+         detr_outputs = detr_model(**detr_inputs)
+
+     target_sizes = torch.tensor([raw_image.size[::-1]])  # (height, width)
+     results = detr_processor.post_process_object_detection(detr_outputs, target_sizes=target_sizes, threshold=0.9)[0]
+
+     # Extract object names
+     detected_objects = []
+     for score, label in zip(results["scores"], results["labels"]):
+         label_name = detr_model.config.id2label[label.item()]
+         detected_objects.append(label_name)
+
+     # Display the detected objects before generating the caption
+     st.subheader("Detected Objects:")
+     st.text(", ".join(set(detected_objects)))  # Show unique objects
+
+     # Generate caption using BLIP model
+     blip_inputs = blip_processor(raw_image, return_tensors="pt")
+     blip_outputs = blip_model.generate(**blip_inputs)
+     caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)
+
+     # Display the generated caption after detected objects
+     st.subheader("Generated Caption:")
+     st.text(caption)
+
+     # Submit button to generate the story
+     if st.button("Generate Story"):
+         prompt = (
+             f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
+             f"Write a {genre.lower()} story with approximately {story_length} words using elements of this caption and the detected objects."
+         )
+         response = gemini_model.generate_content(prompt)
+         st.subheader("Generated Story:")
+         st.text_area("Story Output", value=response.text, height=300)
+
+     # Cleanup memory
+     del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
+     gc.collect()