abdullahmubeen10 commited on
Commit
4dccf51
·
verified ·
1 Parent(s): ee37064

Update Demo.py

Browse files
Files changed (1) hide show
  1. Demo.py +126 -126
Demo.py CHANGED
@@ -1,127 +1,127 @@
1
- import streamlit as st
2
- import sparknlp
3
- import os
4
- import pandas as pd
5
-
6
- from sparknlp.base import *
7
- from sparknlp.annotator import *
8
- from pyspark.ml import Pipeline
9
- from sparknlp.pretrained import PretrainedPipeline
10
- from streamlit_tags import st_tags
11
-
12
- # Page configuration
13
- st.set_page_config(
14
- layout="wide",
15
- initial_sidebar_state="auto"
16
- )
17
-
18
- # CSS for styling
19
- st.markdown("""
20
- <style>
21
- .main-title {
22
- font-size: 36px;
23
- color: #4A90E2;
24
- font-weight: bold;
25
- text-align: center;
26
- }
27
- .section {
28
- background-color: #f9f9f9;
29
- padding: 10px;
30
- border-radius: 10px;
31
- margin-top: 10px;
32
- }
33
- .section p, .section ul {
34
- color: #666666;
35
- }
36
- </style>
37
- """, unsafe_allow_html=True)
38
-
39
- @st.cache_resource
40
- def init_spark():
41
- return sparknlp.start()
42
-
43
- @st.cache_resource
44
- def create_pipeline(model):
45
- image_assembler = ImageAssembler() \
46
- .setInputCol("image") \
47
- .setOutputCol("image_assembler")
48
-
49
- image_classifier = ViTForImageClassification \
50
- .pretrained(model) \
51
- .setInputCols("image_assembler") \
52
- .setOutputCol("class")
53
-
54
- pipeline = Pipeline(stages=[
55
- image_assembler,
56
- image_classifier,
57
- ])
58
- return pipeline
59
-
60
- def fit_data(pipeline, data):
61
- empty_df = spark.createDataFrame([['']]).toDF('text')
62
- model = pipeline.fit(empty_df)
63
- light_pipeline = LightPipeline(model)
64
- annotations_result = light_pipeline.fullAnnotateImage(data)
65
- return annotations_result[0]['class'][0].result
66
-
67
- def save_uploadedfile(uploadedfile):
68
- filepath = os.path.join(IMAGE_FILE_PATH, uploadedfile.name)
69
- with open(filepath, "wb") as f:
70
- if hasattr(uploadedfile, 'getbuffer'):
71
- f.write(uploadedfile.getbuffer())
72
- else:
73
- f.write(uploadedfile.read())
74
-
75
- # Sidebar content
76
- model_list = ['image_classifier_vit_base_cats_vs_dogs', 'image_classifier_vit_base_patch16_224', 'image_classifier_vit_CarViT', 'image_classifier_vit_base_beans_demo', 'image_classifier_vit_base_food101', 'image_classifier_vit_base_patch16_224_in21k_finetuned_cifar10']
77
- model = st.sidebar.selectbox(
78
- "Choose the pretrained model",
79
- model_list,
80
- help="For more info about the models visit: https://sparknlp.org/models"
81
- )
82
-
83
- # Set up the page layout
84
- st.markdown(f'<div class="main-title">ViT for Image Classification</div>', unsafe_allow_html=True)
85
- # st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)
86
-
87
- # Reference notebook link in sidebar
88
- link = """
89
- <a href="https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/ViTForImageClassification.ipynb">
90
- <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
91
- </a>
92
- """
93
- st.sidebar.markdown('Reference notebook:')
94
- st.sidebar.markdown(link, unsafe_allow_html=True)
95
-
96
- # Load examples
97
- IMAGE_FILE_PATH = f"/content/sparknlp VIT Image Classification/inputs/{model}"
98
- image_files = sorted([file for file in os.listdir(IMAGE_FILE_PATH) if file.split('.')[-1]=='png' or file.split('.')[-1]=='jpg' or file.split('.')[-1]=='JPEG' or file.split('.')[-1]=='jpeg'])
99
-
100
- st.subheader("This model identifies image classes using the vision transformer (ViT).")
101
-
102
- img_options = st.selectbox("Select an image", image_files)
103
- uploadedfile = st.file_uploader("Try it for yourself!")
104
-
105
- if uploadedfile:
106
- file_details = {"FileName":uploadedfile.name,"FileType":uploadedfile.type}
107
- save_uploadedfile(uploadedfile)
108
- selected_image = f"{IMAGE_FILE_PATH}/{uploadedfile.name}"
109
- elif img_options:
110
- selected_image = f"{IMAGE_FILE_PATH}/{img_options}"
111
-
112
- st.subheader('Classified Image')
113
-
114
- image_size = st.slider('Image Size', 400, 1000, value=400, step = 100)
115
-
116
- try:
117
- st.image(f"{IMAGE_FILE_PATH}/{selected_image}", width=image_size)
118
- except:
119
- st.image(selected_image, width=image_size)
120
-
121
- st.subheader('Classification')
122
-
123
- spark = init_spark()
124
- Pipeline = create_pipeline(model)
125
- output = fit_data(Pipeline, selected_image)
126
-
127
  st.markdown(f'This document has been classified as : **{output}**')
 
1
+ import streamlit as st
2
+ import sparknlp
3
+ import os
4
+ import pandas as pd
5
+
6
+ from sparknlp.base import *
7
+ from sparknlp.annotator import *
8
+ from pyspark.ml import Pipeline
9
+ from sparknlp.pretrained import PretrainedPipeline
10
+ from streamlit_tags import st_tags
11
+
12
+ # Page configuration
13
+ st.set_page_config(
14
+ layout="wide",
15
+ initial_sidebar_state="auto"
16
+ )
17
+
18
+ # CSS for styling
19
+ st.markdown("""
20
+ <style>
21
+ .main-title {
22
+ font-size: 36px;
23
+ color: #4A90E2;
24
+ font-weight: bold;
25
+ text-align: center;
26
+ }
27
+ .section {
28
+ background-color: #f9f9f9;
29
+ padding: 10px;
30
+ border-radius: 10px;
31
+ margin-top: 10px;
32
+ }
33
+ .section p, .section ul {
34
+ color: #666666;
35
+ }
36
+ </style>
37
+ """, unsafe_allow_html=True)
38
+
39
+ @st.cache_resource
40
+ def init_spark():
41
+ return sparknlp.start()
42
+
43
+ @st.cache_resource
44
+ def create_pipeline(model):
45
+ image_assembler = ImageAssembler() \
46
+ .setInputCol("image") \
47
+ .setOutputCol("image_assembler")
48
+
49
+ image_classifier = ViTForImageClassification \
50
+ .pretrained(model) \
51
+ .setInputCols("image_assembler") \
52
+ .setOutputCol("class")
53
+
54
+ pipeline = Pipeline(stages=[
55
+ image_assembler,
56
+ image_classifier,
57
+ ])
58
+ return pipeline
59
+
60
+ def fit_data(pipeline, data):
61
+ empty_df = spark.createDataFrame([['']]).toDF('text')
62
+ model = pipeline.fit(empty_df)
63
+ light_pipeline = LightPipeline(model)
64
+ annotations_result = light_pipeline.fullAnnotateImage(data)
65
+ return annotations_result[0]['class'][0].result
66
+
67
+ def save_uploadedfile(uploadedfile):
68
+ filepath = os.path.join(IMAGE_FILE_PATH, uploadedfile.name)
69
+ with open(filepath, "wb") as f:
70
+ if hasattr(uploadedfile, 'getbuffer'):
71
+ f.write(uploadedfile.getbuffer())
72
+ else:
73
+ f.write(uploadedfile.read())
74
+
75
+ # Sidebar content
76
+ model_list = ['image_classifier_vit_base_cats_vs_dogs', 'image_classifier_vit_base_patch16_224', 'image_classifier_vit_CarViT', 'image_classifier_vit_base_beans_demo', 'image_classifier_vit_base_food101', 'image_classifier_vit_base_patch16_224_in21k_finetuned_cifar10']
77
+ model = st.sidebar.selectbox(
78
+ "Choose the pretrained model",
79
+ model_list,
80
+ help="For more info about the models visit: https://sparknlp.org/models"
81
+ )
82
+
83
+ # Set up the page layout
84
+ st.markdown(f'<div class="main-title">ViT for Image Classification</div>', unsafe_allow_html=True)
85
+ # st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)
86
+
87
+ # Reference notebook link in sidebar
88
+ link = """
89
+ <a href="https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/ViTForImageClassification.ipynb">
90
+ <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
91
+ </a>
92
+ """
93
+ st.sidebar.markdown('Reference notebook:')
94
+ st.sidebar.markdown(link, unsafe_allow_html=True)
95
+
96
+ # Load examples
97
+ IMAGE_FILE_PATH = f"inputs/{model}"
98
+ image_files = sorted([file for file in os.listdir(IMAGE_FILE_PATH) if file.split('.')[-1]=='png' or file.split('.')[-1]=='jpg' or file.split('.')[-1]=='JPEG' or file.split('.')[-1]=='jpeg'])
99
+
100
+ st.subheader("This model identifies image classes using the vision transformer (ViT).")
101
+
102
+ img_options = st.selectbox("Select an image", image_files)
103
+ uploadedfile = st.file_uploader("Try it for yourself!")
104
+
105
+ if uploadedfile:
106
+ file_details = {"FileName":uploadedfile.name,"FileType":uploadedfile.type}
107
+ save_uploadedfile(uploadedfile)
108
+ selected_image = f"{IMAGE_FILE_PATH}/{uploadedfile.name}"
109
+ elif img_options:
110
+ selected_image = f"{IMAGE_FILE_PATH}/{img_options}"
111
+
112
+ st.subheader('Classified Image')
113
+
114
+ image_size = st.slider('Image Size', 400, 1000, value=400, step = 100)
115
+
116
+ try:
117
+ st.image(f"{IMAGE_FILE_PATH}/{selected_image}", width=image_size)
118
+ except:
119
+ st.image(selected_image, width=image_size)
120
+
121
+ st.subheader('Classification')
122
+
123
+ spark = init_spark()
124
+ Pipeline = create_pipeline(model)
125
+ output = fit_data(Pipeline, selected_image)
126
+
127
  st.markdown(f'This document has been classified as : **{output}**')