from transformers import Blip2ForConditionalGeneration
from transformers import Blip2Processor
from peft import PeftModel
import streamlit as st
from PIL import Image
#import torch #Only needed if the GPU inference lines below are uncommented
import os

preprocess_ckp = "Salesforce/blip2-opt-2.7b" #Checkpoint used to preprocess the input image
base_model_ckp = "./model/blip2-opt-2.7b-fp16-sharded" #Base model checkpoint path
peft_model_ckp = "./model/blip2_peft" #PEFT model checkpoint path
sample_img_path = "./sample_images" #Directory containing the sample images

map_sampleid_name = {
                    'dress' : '00fe223d-9d1f-4bd3-a556-7ece9d28e6fb.jpeg',
                    'earrings': '0b3862ae-f89e-419c-bc1e-57418abd4180.jpeg',
                    'sweater': '0c21ba7b-ceb6-4136-94a4-1d4394499986.jpeg',
                    'sunglasses': '0e44ec10-e53b-473a-a77f-ac8828bb5e01.jpeg',
                    'shoe': '4cd37d6d-e7ea-4c6e-aab2-af700e480bc1.jpeg',
                    'hat': '69aeb517-c66c-47b8-af7d-bdf1fde57ed0.jpeg',
                    'heels':'447abc42-6ac7-4458-a514-bdcd570b1cd1.jpeg',
                    'socks': 'd188836c-b734-4031-98e5-423d5ff1239d.jpeg',
                    'tee': 'e2d8637a-5478-429d-a2a8-3d5859dbc64d.jpeg',
                    'bracelet': 'e78518ac-0f54-4483-a233-fad6511f0b86.jpeg'
                    }
#init_model_required = True

def init_model():

    #if init_model_required:

    #Preprocess input 
    processor = Blip2Processor.from_pretrained(preprocess_ckp)

    #Model   
    #Inference on a GPU device. This will fail on a CPU-only system, because "load_in_8bit" comes from the bitsandbytes library and only works on GPU
    #model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp, load_in_8bit = True, device_map = "auto") 

    #Inference on CPU
    model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp) 

    model = PeftModel.from_pretrained(model, peft_model_ckp)

        #init_model_required = False

    return processor, model
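
#Note: init_model() is called from main() on every Streamlit rerun, so the model is
#reloaded from disk whenever a widget changes. A minimal cached variant is sketched
#below as an illustration only; it assumes Streamlit >= 1.18 (where st.cache_resource
#is available), and the name init_model_cached is not part of the original app.
@st.cache_resource
def init_model_cached():

    #Same loading steps as init_model(), but the decorator keeps the returned
    #processor and model in memory across reruns instead of rebuilding them each time
    processor = Blip2Processor.from_pretrained(preprocess_ckp)
    model = Blip2ForConditionalGeneration.from_pretrained(base_model_ckp)
    model = PeftModel.from_pretrained(model, peft_model_ckp)

    return processor, model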

def main():

    st.title("Fashion Image Captioning with BLIP-2")

    processor, model = init_model()

    #Select a sample image from one of the clothing categories
    st.text("Select image:")
    option = st.selectbox('From sample', ('None', 'dress', 'earrings', 'sweater', 'sunglasses', 'shoe', 'hat', 'heels', 'socks', 'tee', 'bracelet'), index = 0)
    st.text("OR")
    file_name = st.file_uploader("Upload an image")
  
    image = None
    if file_name is not None:     

        image = Image.open(file_name)

    elif option != 'None': 

        file_name = os.path.join(sample_img_path, map_sampleid_name[option])
        image = Image.open(file_name)

    if image is not None:

        image_col, caption_text = st.columns(2)
        image_col.header("Image")
        image_col.image(image, use_column_width = True)

        #Preprocess the image
        #Inference on GPU. Running this line on a CPU-only system raises errors such as: '"slow_conv2d_cpu" not implemented for 'Half'' and 'Input type (float) and bias type (struct c10::Half)'
        #inputs = processor(images = image, return_tensors = "pt").to('cuda', torch.float16)

        #Inference on CPU
        inputs = processor(images = image, return_tensors = "pt")

        pixel_values = inputs.pixel_values

        #Predict the caption for the image
        generated_ids = model.generate(pixel_values = pixel_values, max_length = 25)
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]  

        #Output the predicted caption
        caption_text.header("Generated Caption")
        caption_text.text(generated_caption)
    

if __name__ == "__main__":
    main()
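
#To launch the app, run Streamlit against this script from the project root
#(the filename app.py below is only an example; use whatever this file is saved as):
#    streamlit run app.py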