pr4nav101 committed on
Commit 2c95c7f · verified · 1 Parent(s): 132947d

Upload app.py

Files changed (1)
app.py +418 -0
app.py ADDED
@@ -0,0 +1,418 @@
+ from ultralytics import YOLO
+ import base64
+ import cv2
+ import io
+ import numpy as np
+ import streamlit as st
+ from streamlit_image_coordinates import streamlit_image_coordinates
+ import pandas as pd
+ import ollama
+ import tempfile
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import CSVLoader
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings import OllamaEmbeddings
+
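+ # Streamlit app with four pages:
+ #   PS1: per-ROI vehicle counting on an uploaded video (YOLOv8)
+ #   PS2: SUMO-RL traffic-signal simulation with a pre-trained DQN
+ #   PS3: vehicle counting plus a pothole segmentation overlay
+ #   Chat with Results: RAG chat over an exported results CSV (Ollama + Mistral)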
+ def set_background(image_file1, image_file2):
+     # Inline two base64-encoded images as CSS backgrounds for the main
+     # app container and the sidebar.
+     with open(image_file1, "rb") as f:
+         img_data1 = f.read()
+     b64_encoded1 = base64.b64encode(img_data1).decode()
+     with open(image_file2, "rb") as f:
+         img_data2 = f.read()
+     b64_encoded2 = base64.b64encode(img_data2).decode()
+     style = f"""
+         <style>
+         .stApp {{
+             background-image: url(data:image/png;base64,{b64_encoded1});
+             background-size: cover;
+         }}
+         /* Sidebar container; this emotion class name is specific to the
+            installed Streamlit version and may change between releases. */
+         .st-emotion-cache-6qob1r {{
+             background-image: url(data:image/png;base64,{b64_encoded2});
+             background-size: cover;
+             border: 5px solid rgb(14, 17, 23);
+         }}
+         </style>
+     """
+     st.markdown(style, unsafe_allow_html=True)
+
+ set_background('pngtree-city-map-navigation-interface-picture-image_1833642.png',
+                '2024-05-18_14-57-09_5235.png')
+
+ st.title("Traffic Flow and Optimization Toolkit")
+
+ sb = st.sidebar  # sidebar handle for the navigation widgets
+
+ sb.markdown("🛰️ **Navigation**")
+ page_names = ["PS1", "PS2", "PS3", "Chat with Results"]
+ page = sb.radio("Page", page_names, index=0, label_visibility="collapsed")
+ st.session_state['n'] = sb.slider("Number of ROIs", 1, 5)
+
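+ # --- PS1: per-ROI vehicle counting on an uploaded video ---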
+ if page == 'PS1':
+     uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
+     model = YOLO('yolov8n.pt')
+     if uploaded_file is not None:
+         g = io.BytesIO(uploaded_file.read())
+         with tempfile.NamedTemporaryFile() as temp:
+             temp.write(g.getvalue())  # write the raw bytes, not the BytesIO object itself
+             if 'roi_list1' not in st.session_state:
+                 st.session_state['roi_list1'] = []
+             if "all_rois1" not in st.session_state:
+                 st.session_state['all_rois1'] = []
+             classes = model.names
+
+             done_1 = st.button('Selection Done')
+
+             # Record one polygon (a list of clicked points) per ROI until the
+             # requested number of ROIs has been collected.
+             while len(st.session_state["all_rois1"]) < st.session_state['n']:
+                 cap = cv2.VideoCapture(temp.name)
+                 while not done_1:
+                     ret, frame = cap.read()
+                     if not ret:  # check before drawing on a possibly empty frame
+                         st.write('ROI selection has concluded')
+                         break
+                     cv2.putText(frame, 'SELECT ROI', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+                     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                     value = streamlit_image_coordinates(frame, key='numpy', width=750)
+                     if value is not None:  # None until the user clicks
+                         # Scale the click from the 750 px preview back to frame
+                         # coordinates (the 2.55 factor assumes a ~1920 px wide source).
+                         st.session_state["roi_list1"].append([int(value['x'] * 2.55), int(value['y'] * 2.55)])
+                     st.write(st.session_state["roi_list1"])
+                     if cv2.waitKey(0) & 0xFF == 27:  # Esc aborts selection
+                         break
+                 cap.release()
+                 st.session_state["all_rois1"].append(st.session_state["roi_list1"])
+                 st.session_state["roi_list1"] = []
+                 done_1 = False
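+
+             # Detection phase: everything outside each polygon is masked to
+             # white so YOLO only sees the region of interest, and detections
+             # are mapped back to full-frame coordinates via the polygon's
+             # bounding box.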
+             st.write('ROI indices: ', st.session_state["all_rois1"][0])
+
+             cap = cv2.VideoCapture(temp.name)
+             st.write("Detection started")
+             st.session_state['fps'] = cap.get(cv2.CAP_PROP_FPS)
+             st.write(f"FPS OF VIDEO: {st.session_state['fps']}")
+             count = 0
+             frame_placeholder = st.empty()
+             st.session_state["data1"] = {}
+             for i in range(len(st.session_state["all_rois1"])):
+                 st.session_state["data1"][f"ROI{i}"] = []
+             while cap.isOpened():
+                 ret, frame = cap.read()
+                 if not ret:
+                     break
+                 count += 1
+                 if count % 3 != 0:  # sample every third frame
+                     continue
+                 k = 0
+                 for roi_list_here1 in st.session_state["all_rois1"]:
+                     # Axis-aligned bounding box of the polygon (the first click is dropped).
+                     roi_list_here = roi_list_here1[1:]
+                     xs = [p[0] for p in roi_list_here]
+                     ys = [p[1] for p in roi_list_here]
+                     x_min, x_max = min(xs), max(xs)
+                     y_min, y_max = min(ys), max(ys)
+                     frame_cropped = frame[y_min:y_max, x_min:x_max]
+                     # White mask with the polygon filled black: OR-ing keeps the
+                     # pixels inside the polygon and whites out everything else.
+                     roi_corners = np.array([roi_list_here], dtype=np.int32)
+                     mask = np.full(frame.shape, 255, dtype=np.uint8)
+                     cv2.fillPoly(mask, roi_corners, 0)
+                     mask_cropped = mask[y_min:y_max, x_min:x_max]
+                     roi = cv2.bitwise_or(frame_cropped, mask_cropped)
+
+                     number = []
+                     results = model.predict(roi)
+                     for r in results:
+                         boxes = r.boxes
+                         counter = 0
+                         for box in boxes:
+                             counter += 1
+                             name = classes[int(box.cls.numpy()[0])]  # class ids come back as floats
+                             conf = str(round(box.conf.numpy()[0], 2))
+                             text = name + " " + conf
+                             bbox = box.xyxy[0].numpy()
+                             # Detections are in crop coordinates; offset by the ROI origin.
+                             cv2.rectangle(frame, (int(bbox[0]) + x_min, int(bbox[1]) + y_min),
+                                           (int(bbox[2]) + x_min, int(bbox[3]) + y_min), (0, 255, 0), 2)
+                             cv2.putText(frame, text, (int(bbox[0]) + x_min, int(bbox[1]) + y_min - 5),
+                                         cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+                         number.append(counter)
+                     avg = sum(number) / len(number)
+                     stats = str(round(avg, 2))
+                     if count % 10 == 0:
+                         st.session_state["data1"][f"ROI{k}"].append(avg)
+                     k += 1
+                     cv2.putText(frame, stats, (x_min, y_min), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                     cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+                 cv2.putText(frame, 'The average number of vehicles in the Regions of Interest',
+                             (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+                 frame_placeholder.image(frame, channels='BGR')
+             cap.release()
+             df = pd.DataFrame(st.session_state["data1"])
+             df.to_csv('PS1.csv', sep='\t', encoding='utf-8')
+     else:
+         st.error('PLEASE UPLOAD A VIDEO OF THE FORMAT MP4 OR MPEG', icon="🚨")
+
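+ # --- PS3: per-ROI vehicle counting plus pothole segmentation ---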
+ elif page == "PS3":
+     uploaded_file1 = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
+     if uploaded_file1 is not None:
+         g = io.BytesIO(uploaded_file1.read())
+         temporary_location = "temp_PS2.mp4"
+         with open(temporary_location, 'wb') as out:
+             out.write(g.read())  # the with block closes the file; no explicit close needed
+         model1 = YOLO("yolov8n.pt")  # vehicle detector
+         model2 = YOLO("best.pt")     # custom pothole segmentation model
+         if 'roi_list2' not in st.session_state:
+             st.session_state['roi_list2'] = []
+         if "all_rois2" not in st.session_state:
+             st.session_state['all_rois2'] = []
+         classes = model1.names
+
+         done_2 = st.button('Selection Done')
+
+         while len(st.session_state["all_rois2"]) < st.session_state['n']:
+             cap = cv2.VideoCapture(temporary_location)
+             while not done_2:
+                 ret, frame = cap.read()
+                 if not ret:  # check before drawing on a possibly empty frame
+                     st.write('ROI selection has concluded')
+                     break
+                 cv2.putText(frame, 'SELECT ROI', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 value = streamlit_image_coordinates(frame, key='numpy', width=750)
+                 if value is not None:  # None until the user clicks
+                     st.session_state["roi_list2"].append([int(value['x'] * 2.5), int(value['y'] * 2.5)])
+                 st.write(st.session_state["roi_list2"])
+                 if cv2.waitKey(0) & 0xFF == 27:
+                     break
+             cap.release()
+             st.session_state["all_rois2"].append(st.session_state["roi_list2"])
+             st.session_state["roi_list2"] = []
+             done_2 = False
+
+         st.write('ROI indices: ', st.session_state["all_rois2"][0])
+
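+         # Two models run per sampled frame: model1 counts vehicles inside
+         # each ROI mask, model2 segments potholes on the full frame, and
+         # the pothole masks are alpha-blended onto the displayed image.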
+         cap = cv2.VideoCapture('temp_PS2.mp4')  # lowercase, matching the file written above
+         st.write("Detection started")
+         count = 0
+         frame_placeholder = st.empty()
+         st.session_state.data = {}
+         for i in range(len(st.session_state["all_rois2"])):
+             st.session_state.data[f"ROI{i}"] = []
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             count += 1
+             if count % 3 != 0:  # sample every third frame
+                 continue
+             k = 0
+             for roi_list_here1 in st.session_state["all_rois2"]:
+                 # Axis-aligned bounding box of the polygon (the first click is dropped).
+                 roi_list_here = roi_list_here1[1:]
+                 xs = [p[0] for p in roi_list_here]
+                 ys = [p[1] for p in roi_list_here]
+                 x_min, x_max = min(xs), max(xs)
+                 y_min, y_max = min(ys), max(ys)
+                 frame_cropped = frame[y_min:y_max, x_min:x_max]
+                 roi_corners = np.array([roi_list_here], dtype=np.int32)
+                 mask = np.full(frame.shape, 255, dtype=np.uint8)
+                 cv2.fillPoly(mask, roi_corners, 0)
+                 mask_cropped = mask[y_min:y_max, x_min:x_max]
+                 roi = cv2.bitwise_or(frame_cropped, mask_cropped)
+
+                 number = []
+                 results = model1.predict(roi)
+                 results_pothole = model2.predict(source=frame)
+                 for r in results:
+                     boxes = r.boxes
+                     counter = 0
+                     for box in boxes:
+                         counter += 1
+                         name = classes[int(box.cls.numpy()[0])]
+                         conf = str(round(box.conf.numpy()[0], 2))
+                         text = name + " " + conf
+                         bbox = box.xyxy[0].numpy()
+                         cv2.rectangle(frame, (int(bbox[0]) + x_min, int(bbox[1]) + y_min),
+                                       (int(bbox[2]) + x_min, int(bbox[3]) + y_min), (0, 255, 0), 2)
+                         cv2.putText(frame, text, (int(bbox[0]) + x_min, int(bbox[1]) + y_min - 5),
+                                     cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 2)
+                     number.append(counter)
+                 for r in results_pothole:
+                     masks = r.masks
+                     boxes = r.boxes.cpu().numpy()
+                     xyxys = boxes.xyxy
+                     confs = boxes.conf
+                     if masks is not None:
+                         shapes = np.ones_like(frame)
+                         for seg, conf, xyxy in zip(masks, confs, xyxys):
+                             polygon = seg.xy[0]
+                             if conf >= 0.49 and len(polygon) >= 3:
+                                 cv2.fillPoly(shapes, pts=np.int32([polygon]), color=(0, 0, 255))
+                                 frame = cv2.addWeighted(frame, 0.7, shapes, 0.3, gamma=0)
+                                 cv2.rectangle(frame, (int(xyxy[0]), int(xyxy[1])),
+                                               (int(xyxy[2]), int(xyxy[3])), (0, 0, 255), 2)
+                                 cv2.putText(frame, 'Pothole ' + str(round(float(conf), 2)),
+                                             (int(xyxy[0]), int(xyxy[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+
+                 avg = sum(number) / len(number)
+                 stats = str(round(avg, 2))
+                 if count % 10 == 0:
+                     st.session_state.data[f"ROI{k}"].append(avg)
+                 k += 1
+                 cv2.putText(frame, stats, (x_min, y_min), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                 cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+                 if counter >= 5:  # flag congestion when the last ROI pass saw 5+ objects
+                     cv2.putText(frame, '!!CONGESTION MORE THAN ' + str(counter) + ' Objects',
+                                 (x_min + 20, y_min + 20), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                     cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+             cv2.putText(frame, 'Objects in the Regions of Interest', (100, 100),
+                         cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+             frame_placeholder.image(frame, channels='BGR')
+         cap.release()
+
+         df = pd.DataFrame(st.session_state.data)
+         df.to_csv('PS2.csv', sep='\t', encoding='utf-8')
+
+     else:
+         st.error('PLEASE UPLOAD A VIDEO OF THE FORMAT MP4 OR MPEG', icon="🚨")
+
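+ # --- PS2: SUMO-RL single-intersection simulation evaluated with a pre-trained DQN ---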
+ elif page == "PS2":
+     st.header("CLICK ON RUN SCRIPT TO START A TRAFFIC SIMULATION")
+     script = st.button("RUN SCRIPT")
+     if 'con' not in st.session_state:
+         st.session_state.con = -1  # run counter; must persist across Streamlit reruns
+     if script:
+         st.session_state.con += 1
+         # Heavy simulation dependencies are imported lazily, only when needed.
+         import gymnasium as gym
+         import sumo_rl  # registers the 'sumo-rl-v0' environment
+         from stable_baselines3 import DQN
+         from stable_baselines3.common.evaluation import evaluate_policy
+         env = gym.make('sumo-rl-v0',
+                        net_file='single-intersection.net.xml',
+                        route_file='single-intersection-gen.rou.xml',
+                        out_csv_name='output',
+                        use_gui=True,
+                        single_agent=True,
+                        num_seconds=10000)
+         model1 = DQN.load('DQN_MODEL3.zip', env=env)
+         mean_reward, std_reward = evaluate_policy(model1, env=env, n_eval_episodes=5, render=True)
+         st.write("Evaluation Results: ", mean_reward, std_reward)
+         import matplotlib.pyplot as plt
+
+         def eval_plot(path, metric, path_compare=None):
+             # Plot one metric over simulation steps, optionally against a second run.
+             data = pd.read_csv(path)
+             x_arr = np.arange(len(data))
+             y_arr = np.array(pd.to_numeric(data[metric]))
+             fig = plt.figure()
+             ax1 = fig.add_subplot(2, 1, 1)
+             ax1.set_title(metric)
+             ax1.plot(x_arr, y_arr)
+             if path_compare is not None:
+                 data1 = pd.read_csv(path_compare)
+                 y_arr2 = np.array(pd.to_numeric(data1[metric]))
+                 ax2 = fig.add_subplot(2, 1, 2, sharey=ax1)
+                 ax2.set_title('compare ' + metric)
+                 ax2.plot(x_arr, y_arr2)
+             return fig
+
+         for i in range(1, 2):
+             st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv', 'system_mean_waiting_time'))
+             st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv', 'agents_total_accumulated_waiting_time'))
+
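+ # --- Chat with Results: retrieval-augmented chat over an exported CSV ---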
+ elif page == "Chat with Results":
+     st.title('Chat with the Results')
+     st.write("Please upload the relevant CSV data to get started")
+     reload = st.button('Reload')
+     if 'isran' not in st.session_state or reload:
+         st.session_state['isran'] = False
+
+     uploaded_file = st.file_uploader('Choose your .csv file', type=["csv"])
+     if uploaded_file is not None and st.session_state['isran'] == False:
+         with open("temp.csv", "wb") as f:
+             f.write(uploaded_file.getvalue())
+         loader = CSVLoader('temp.csv')
+         docs = loader.load()
+         # Chunk the CSV rows and index them for retrieval.
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         splits = text_splitter.split_documents(docs)
+         embeddings = OllamaEmbeddings(model='mistral')
+         st.session_state.vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+         st.session_state['isran'] = True
+
+     if st.session_state['isran']:
+         st.write("Embedding created")
+
+     def fdocs(docs):
+         return "\n\n".join(doc.page_content for doc in docs)
+
+     def llm(question, context):
+         formatted_prompt = f"Question: {question}\n\nContext: {context}"
+         response = ollama.chat(model='mistral', messages=[
+             {
+                 'role': 'user',
+                 'content': formatted_prompt
+             },
+         ])
+         return response['message']['content']
+
+     def rag_chain(question):
+         retriever = st.session_state.vectorstore.as_retriever()
+         retrieved_docs = retriever.invoke(question)
+         formatted_context = fdocs(retrieved_docs)
+         return llm(question, formatted_context)
+
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     for message in st.session_state.messages:
+         st.chat_message(message['role']).markdown(message['content'])
+
+     prompt = st.chat_input("Say something")
+     if prompt:  # only run the chain once the user has actually typed something
+         response = rag_chain(prompt)
+         st.chat_message('user').markdown(prompt)
+         st.session_state.messages.append({'role': 'user', 'content': prompt})
+         st.session_state.messages.append({'role': 'AI', 'content': response})
+         st.chat_message('AI').markdown(response)
+ st.chat_message('AI').markdown(response)