nelbarman053 commited on
Commit
e6856f6
β€’
1 Parent(s): 94b8dff

app and caption generator file running and working

Browse files
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def main():
4
+
5
+ st.markdown('<div style="display: flex; justify-content: center;"><p style="font-size: 30px; font-weight: bold;">Artwork Caption Generation</p></div>', unsafe_allow_html=True)
6
+ col_left, col_right = st.columns(2)
7
+
8
+ with col_left:
9
+ col11, col12 = st.columns(2)
10
+ with col11:
11
+ st.image(image="assets\drawing.jpeg")
12
+ with col12:
13
+ st.write("<div style='text-align:right;'>Witness the birth of creativity. Our website breathes life into your artistic visions.</div>", unsafe_allow_html=True)
14
+ col21, col22 = st.columns(2)
15
+ with col21:
16
+ st.write("<div style='text-align:left;'>Delve into the depths of inspiration. Explore our platform to give voice to your masterpiece.</div>", unsafe_allow_html=True)
17
+ with col22:
18
+ st.image(image="assets\\thinking.jpeg")
19
+ col31, col32 = st.columns(2)
20
+ with col31:
21
+ st.image(image="assets\generate.jpeg")
22
+ with col32:
23
+ st.write("<div style='text-align:right;'>Crafting the narrative of your creation. Unleash your imagination with our caption generator for artwork.</div>", unsafe_allow_html=True)
24
+
25
+ with col_right:
26
+ st.subheader('How to use')
27
+
28
+ if st.button(label="Generator", type="primary"):
29
+ st.switch_page(page='pages/Caption_Generator.py')
30
+
31
+ if __name__ == "__main__":
32
+ main()
assets/an_artist_would_call.jpeg ADDED
assets/drawing.jpeg ADDED
assets/generate.jpeg ADDED
assets/thinking.jpeg ADDED
assets/to be added text.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ Witness the birth of creativity. Our website breathes life into your artistic visions.
5
+
6
+
7
+ Delve into the depths of inspiration. Explore our platform to give voice to your masterpiece.
8
+
9
+
10
+ Crafting the narrative of your creation. Unleash your imagination with our caption generator for artwork.
pages/Caption_Generator.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import streamlit as st
5
+ from torchvision.transforms import v2
6
+ from transformers import GenerationConfig
7
+ from transformers import GPT2TokenizerFast
8
+ from transformers import ViTImageProcessor
9
+ from transformers import VisionEncoderDecoderModel
10
+
11
+ # Page configuration settings
12
+ st.set_page_config(
13
+ layout="centered",
14
+ page_title="Generate Caption",
15
+ initial_sidebar_state="auto",
16
+ )
17
+
18
+ # Initializing session state keys
19
+ if all(key not in st.session_state.keys() for key in ("generate", "image")):
20
+ st.session_state["generate"] = False
21
+ st.session_state["image"] = None
22
+
23
+ # Loading necessary resources and caching them
24
+ @st.cache_resource(show_spinner="Loading Resources...")
25
+ def loadResources():
26
+ encoder = 'microsoft/swin-base-patch4-window7-224-in22k'
27
+ decoder = 'gpt2'
28
+
29
+ model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
30
+ encoder, decoder
31
+ )
32
+
33
+ processor = ViTImageProcessor.from_pretrained(encoder)
34
+ tokenizer = GPT2TokenizerFast.from_pretrained(decoder)
35
+
36
+ if 'gpt2' in decoder:
37
+ tokenizer.pad_token = tokenizer.eos_token
38
+ model.config.eos_token_id = tokenizer.eos_token_id
39
+ model.config.pad_token_id = tokenizer.pad_token_id
40
+ model.config.decoder_start_token_id = tokenizer.bos_token_id
41
+ else:
42
+ model.config.decoder_start_token_id = tokenizer.cls_token_id
43
+ model.config.pad_token_id = tokenizer.pad_token_id
44
+
45
+ model = torch.load("generator_model.pkl", map_location=torch.device("cpu"))
46
+ model.eval()
47
+ return processor, tokenizer, model
48
+
49
+ # Pre-processing image and caching
50
+ @st.cache_data
51
+ def preprocess_image(_processor, image):
52
+ transforms = v2.Compose([
53
+ v2.Resize(size=(224,224)),
54
+ v2.ToDtype(torch.float32, scale = True),
55
+ ])
56
+ image = transforms(image)
57
+ img = _processor(image, return_tensors = 'pt')
58
+ return img
59
+
60
+ # Generating caption and caching
61
+ @st.cache_data
62
+ def get_caption(_processor, _tokenizer, _model, image):
63
+ image = preprocess_image(_processor, image)
64
+ output = _model.generate(
65
+ **image,
66
+ generation_config = GenerationConfig(
67
+ pad_token_id = _tokenizer.pad_token_id
68
+ )
69
+ )
70
+
71
+ caption = _tokenizer.batch_decode(
72
+ output,
73
+ skip_special_tokens = True
74
+ )
75
+
76
+ return caption[0]
77
+
78
+ # Displaying elements
79
+ def DisplayInteractionElements():
80
+ st.markdown('<div style="display: flex; justify-content: center;"><p style="font-size: 40px; font-weight: bold;">πŸ‘‰ Caption Generator πŸ‘ˆ</p></div>', unsafe_allow_html=True)
81
+ st.file_uploader(accept_multiple_files=False, label='Upload an Image', type=['jpg', 'jpeg', 'png'], key="image_uploader")
82
+
83
+ if st.session_state['image_uploader']:
84
+ image = st.session_state['image_uploader']
85
+ im_file = Image.open(image).convert("RGB")
86
+ im_file = np.array(im_file)
87
+
88
+ st.session_state['image'] = im_file
89
+
90
+ col1, col2, col3 = st.columns(3)
91
+
92
+ col2.image(image=image, caption='Uploaded Image')
93
+
94
+ st.button(label='Generate Caption', use_container_width=True, type='primary', on_click=generateCaption)
95
+
96
+ # Triggering generate state
97
+ def generateCaption():
98
+ st.session_state['generate'] = True
99
+
100
+ def main():
101
+
102
+ DisplayInteractionElements()
103
+
104
+ processor, tokenizer, model = loadResources()
105
+
106
+ if not st.session_state['image_uploader']:
107
+ st.session_state['generate'] = False
108
+
109
+ if st.session_state['generate'] and st.session_state['image_uploader']:
110
+ caption = get_caption(processor, tokenizer, model, st.session_state['image'])
111
+ st.markdown(f'<div style="display: flex; justify-content: center;"><p style="font-size: 35px; font-weight: bold; color: blue;">{caption}</p></div>', unsafe_allow_html = True)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()
116
+
117
+
118
+
119
+
120
+
requirements.txt ADDED
Binary file (214 Bytes). View file