cszhzleo commited on
Commit
25a9d18
·
verified ·
1 Parent(s): 5245454

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ #from streamlit_datalist import stDatalist
3
+ from utils import convert_to_base64, convert_to_html
4
+ import requests
5
+
6
+ IP = '127.0.0.1'
7
+ PORT= 8080
8
+ url = f'http://{IP}:{PORT}/predictions/model'
9
+ headers = {'Content-Type': 'application/json'}
10
+ st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
11
+ #st.set_page_config(layout="wide")
12
+
13
+ st.title("Multimodal Model on AWS Inf2")
14
+ st.subheader("LLaVA-1.6-Mistral-7B")
15
+
16
+
17
+ def upload_image():
18
+ image_list=["./images/view.jpg",
19
+ "./images/cat.jpg",
20
+ "./images/olympic.jpg",
21
+ "./images/usa.jpg",
22
+ "./images/box.jpg"]
23
+ name_list=["view(https://llava-vl.github.io/static/images/view.jpg)",
24
+ "cat",
25
+ "paris 2024",
26
+ "statue of liberty",
27
+ "box(from my camera)"]
28
+ images_all = dict(zip(name_list, image_list))
29
+
30
+ user_option = st.selectbox("Select a preset image", ["–Select–"] + name_list)
31
+ print(user_option)
32
+ if user_option!="–Select–":
33
+ image_names=[images_all[user_option]]
34
+ else:
35
+ image_names=[]
36
+
37
+ st.text("OR")
38
+
39
+ images = st.file_uploader("Upload an image to chat about", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
40
+ #print(images)
41
+ # assert max number of images, e.g. 1
42
+ assert len(images) <= 1, (st.error("Please upload at most 1 image"), st.stop())
43
+
44
+ if images or image_names:
45
+ if images:
46
+ image_names=[]
47
+ # convert images to base64
48
+ images_b64 = []
49
+ for image in images+image_names:
50
+ image_b64 = convert_to_base64(image)
51
+ images_b64.append(image_b64)
52
+
53
+ # display images in multiple columns
54
+ cols = st.columns(len(images_b64)) ##only process first image
55
+ for i, col in enumerate(cols):
56
+ col.markdown(f"**Image {i+1}**")
57
+ col.markdown(convert_to_html(images_b64[i]), unsafe_allow_html=True)
58
+ break #only process first image
59
+ st.markdown("---")
60
+ return images_b64[0] #only process first image
61
+ st.stop()
62
+
63
+
64
+ @st.cache_data(show_spinner=False)
65
+ def ask_llm(prompt, byte_image):
66
+ payload = {
67
+ "prompt":prompt,
68
+ "image": byte_image,
69
+ "parameters": {
70
+ "top_k": 100,
71
+ "top_p": 0.1,
72
+ "temperature": 0.2,
73
+ }
74
+ }
75
+ response = requests.post(url, json=payload, headers=headers)
76
+
77
+ return response.text
78
+
79
+ def app():
80
+ st.markdown("---")
81
+ c1, c2 = st.columns(2)
82
+ with c2:
83
+ image_b64 = upload_image()
84
+ with c1:
85
+ question = st.chat_input("Ask a question about this image")
86
+
87
+ if not question: st.stop()
88
+ with c1:
89
+ with st.chat_message("question"):
90
+ st.markdown(question, unsafe_allow_html=True)
91
+ with st.spinner("Thinking..."):
92
+ res = ask_llm(question, image_b64)
93
+ with st.chat_message("response"):
94
+ st.write(res)
95
+
96
+ if __name__ == "__main__":
97
+ app()