ThisHumanA committed on
Commit c106bd2 · verified · 1 parent: d190243

Update app.py

Files changed (1)
  1. app.py +105 -97
app.py CHANGED
@@ -1,97 +1,105 @@
- import streamlit as st
- import gc
- from PIL import Image
- import torch
- from transformers import (
-     BlipProcessor,
-     BlipForConditionalGeneration,
-     DetrImageProcessor,
-     DetrForObjectDetection,
- )
- import google.generativeai as genai
-
- # Configure Generative AI
- genai.configure(api_key='AIzaSyB1wxTDQcB2YT_6l2nm4MrhAmCVPzfkHNU')
- gemini_model = genai.GenerativeModel("gemini-1.5-flash")
-
- # Load BLIP model and processor
- @st.cache_resource
- def load_blip():
-     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
-     return processor, model
-
- # Load DETR model and processor
- @st.cache_resource
- def load_detr():
-     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
-     model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
-     return processor, model
-
- blip_processor, blip_model = load_blip()
- detr_processor, detr_model = load_detr()
-
- # Streamlit app
- st.title("Art Of Visual Storytelling")
-
- # Dropdown menu for genre selection
- genre = st.selectbox(
-     "Select the genre of the story:",
-     ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
- )
-
- # Slider for specifying story length
- story_length = st.slider(
-     "Select the desired story length (number of words):",
-     min_value=50, max_value=1000, value=200, step=50
- )
-
- # Image upload and processing
- uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
-
- if uploaded_image:
-     # Display the uploaded image
-     raw_image = Image.open(uploaded_image).convert('RGB')
-     st.image(raw_image, caption='Uploaded Image', use_column_width=True)
-
-     # Detect objects using DETR model
-     detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
-     detr_model.eval()
-     with torch.no_grad():
-         detr_outputs = detr_model(**detr_inputs)
-
-     target_sizes = torch.tensor([raw_image.size[::-1]])  # (height, width)
-     results = detr_processor.post_process_object_detection(detr_outputs, target_sizes=target_sizes, threshold=0.9)[0]
-
-     # Extract object names
-     detected_objects = []
-     for score, label in zip(results["scores"], results["labels"]):
-         label_name = detr_model.config.id2label[label.item()]
-         detected_objects.append(label_name)
-
-     # Display the detected objects before generating the caption
-     st.subheader("Detected Objects:")
-     st.text(", ".join(set(detected_objects)))  # Show unique objects
-
-     # Generate caption using BLIP model
-     blip_inputs = blip_processor(raw_image, return_tensors="pt")
-     blip_outputs = blip_model.generate(**blip_inputs)
-     caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)
-
-     # Display the generated caption after detected objects
-     st.subheader("Generated Caption:")
-     st.text(caption)
-
-     # Submit button to generate the story
-     if st.button("Generate Story"):
-         prompt = (
-             f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
-             f"Write a {genre.lower()} story with approximately {story_length} words using elements of this caption and the detected objects."
-         )
-         response = gemini_model.generate_content(prompt)
-         st.subheader("Generated Story:")
-         st.text_area("Story Output", value=response.text, height=300)
-
-     # Cleanup memory
-     del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
-     gc.collect()
+ import streamlit as st
+ import gc
+ from PIL import Image
+ import torch
+ from transformers import (
+     BlipProcessor,
+     BlipForConditionalGeneration,
+     DetrImageProcessor,
+     DetrForObjectDetection,
+ )
+ import google.generativeai as genai
+
+ # Configure Generative AI
+ genai.configure(api_key='AIzaSyB1wxTDQcB2YT_6l2nm4MrhAmCVPzfkHNU')
+ gemini_model = genai.GenerativeModel("gemini-1.5-flash")
+
+ # Load BLIP model and processor
+ @st.cache_resource
+ def load_blip():
+     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+     model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+     return processor, model
+
+ # Load DETR model and processor
+ @st.cache_resource
+ def load_detr():
+     processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
+     return processor, model
+
+ blip_processor, blip_model = load_blip()
+ detr_processor, detr_model = load_detr()
+
+ # Streamlit app
+ st.title("Art Of Visual Storytelling")
+
+ # Dropdown menu for genre selection
+ genre = st.selectbox(
+     "Select the genre of the story:",
+     ["Fantasy", "Adventure", "Sci-Fi", "Romance", "Mystery", "Horror", "Comedy", "Drama"]
+ )
+
+ # Slider for specifying story length
+ story_length = st.slider(
+     "Select the desired story length (number of words):",
+     min_value=50, max_value=1000, value=200, step=50
+ )
+
+ # Image upload and processing
+ uploaded_image = st.file_uploader("Upload an Image", type=['png', 'jpg', 'jpeg'])
+
+ if uploaded_image:
+     # Display the uploaded image
+     raw_image = Image.open(uploaded_image).convert('RGB')
+     st.image(raw_image, caption='Uploaded Image', use_column_width=True)
+
+     # Detect objects using DETR model
+     detr_inputs = detr_processor(images=raw_image, return_tensors="pt")
+     detr_model.eval()
+     with torch.no_grad():
+         detr_outputs = detr_model(**detr_inputs)
+
+     target_sizes = torch.tensor([raw_image.size[::-1]])  # (height, width)
+     results = detr_processor.post_process_object_detection(detr_outputs, target_sizes=target_sizes, threshold=0.9)[0]
+
+     # Extract object names
+     detected_objects = []
+     for score, label in zip(results["scores"], results["labels"]):
+         label_name = detr_model.config.id2label[label.item()]
+         detected_objects.append(label_name)
+
+     # Display the detected objects before generating the caption
+     st.subheader("Detected Objects:")
+     st.text(", ".join(set(detected_objects)))  # Show unique objects
+
+     # Generate caption using BLIP model
+     blip_inputs = blip_processor(raw_image, return_tensors="pt")
+     blip_outputs = blip_model.generate(**blip_inputs)
+     caption = blip_processor.decode(blip_outputs[0], skip_special_tokens=True)
+
+     # Display the generated caption after detected objects
+     st.subheader("Generated Caption:")
+     st.text(caption)
+
+     # Language selection box appears after the caption
+     st.subheader("Select Story Language:")
+     language = st.selectbox(
+         "Select the language of the story:",
+         ["English", "Hindi", "Bengali", "Tamil"]
+     )
+
+     # Submit button to generate the story
+     if st.button("Generate Story"):
+         prompt = (
+             f"I have an image. The caption is '{caption}', and the detected objects are {', '.join(set(detected_objects))}. "
+             f"Write a {genre.lower()} story in {language.lower()} with approximately {story_length} words using elements of this caption and the detected objects."
+         )
+         response = gemini_model.generate_content(prompt)
+         st.subheader("Generated Story:")
+         st.text_area("Story Output", value=response.text, height=300)
+
+     # Cleanup memory
+     del raw_image, detr_inputs, detr_outputs, blip_inputs, blip_outputs, results
+     gc.collect()
+
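
Note on the configuration step above: both versions of app.py commit the Gemini API key in plain text, so anyone who can read the repository can use it. A minimal sketch of a safer alternative, assuming the key is supplied through a GOOGLE_API_KEY environment variable (the variable name is an illustrative choice, not part of this commit; Streamlit's st.secrets would work similarly):

import os
import streamlit as st
import google.generativeai as genai

# Read the key from the environment instead of hardcoding it in the source.
# "GOOGLE_API_KEY" is an assumed variable name; match it to your deployment.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    st.error("GOOGLE_API_KEY is not set; export it before starting the app.")
    st.stop()
genai.configure(api_key=api_key)

The rest of app.py is unaffected by this pattern; the app still starts with streamlit run app.py.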