pr4nav101 committed
Commit 6dd9f64 · verified · 1 Parent(s): ce52b47

Upload 2 files

Files changed (1)
app.py +419 -0
app.py ADDED
@@ -0,0 +1,419 @@
+ from ultralytics import YOLO
+ import base64
+ import cv2
+ import io
+ import numpy as np
+ import streamlit as st
+ from streamlit_image_coordinates import streamlit_image_coordinates
+ import pandas as pd
+ import ollama
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import CSVLoader
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings import OllamaEmbeddings
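+ 
+ # set_background injects two base64-encoded images as CSS backgrounds: one for
+ # the app body (.stApp) and one for the sidebar container.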
+ def set_background(image_file1, image_file2):
+     with open(image_file1, "rb") as f:
+         img_data1 = f.read()
+     b64_encoded1 = base64.b64encode(img_data1).decode()
+     with open(image_file2, "rb") as f:
+         img_data2 = f.read()
+     b64_encoded2 = base64.b64encode(img_data2).decode()
+     style = f"""
+     <style>
+     .stApp{{
+         background-image: url(data:image/png;base64,{b64_encoded1});
+         background-size: cover;
+     }}
+     .st-emotion-cache-6qob1r{{
+         background-image: url(data:image/png;base64,{b64_encoded2});
+         background-size: cover;
+         border: 5px solid rgb(14, 17, 23);
+     }}
+     </style>
+     """
+     st.markdown(style, unsafe_allow_html=True)
+ 
+ set_background('pngtree-city-map-navigation-interface-picture-image_1833642.png', '2024-05-18_14-57-09_5235.png')
+ 
+ st.title("Traffic Flow and Optimization Toolkit")
+ 
+ sb = st.sidebar  # define the sidebar
+ 
+ sb.markdown("🛰️ **Navigation**")
+ page_names = ["PS1", "PS2", "PS3", "Chat with Results"]
+ page = sb.radio("", page_names, index=0)
+ st.session_state['n'] = sb.slider("Number of ROIs", 1, 5)
+ 
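+ # PS1: count vehicles inside user-drawn regions of interest on an uploaded video.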
+ if page == 'PS1':
+     uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
+     if uploaded_file is not None:
+         g = io.BytesIO(uploaded_file.read())
+         temporary_location = "temp_PS1.mp4"
+         with open(temporary_location, 'wb') as out:
+             out.write(g.read())
+         model = YOLO('yolov8n.pt')
+         if 'roi_list1' not in st.session_state:
+             st.session_state['roi_list1'] = []
+         if "all_rois1" not in st.session_state:
+             st.session_state['all_rois1'] = []
+         classes = model.names
+ 
+         done_1 = st.button('Selection Done')
+ 
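+         # Collect clicked points for each ROI polygon; clicks are rescaled by a
+         # hard-coded factor to map the 750 px preview back to the native frame size.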
+         while len(st.session_state["all_rois1"]) < st.session_state['n']:
+             cap = cv2.VideoCapture('temp_PS1.mp4')
+             while not done_1:
+                 ret, frame = cap.read()
+                 if not ret:
+                     st.write('ROI selection has concluded')
+                     break
+                 cv2.putText(frame, 'SELECT ROI', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 value = streamlit_image_coordinates(frame, key='numpy', width=750)
+                 if value is not None:  # None until the user has clicked
+                     st.session_state["roi_list1"].append([int(value['x']*2.55), int(value['y']*2.55)])
+                 st.write(st.session_state["roi_list1"])
+                 if cv2.waitKey(0) & 0xFF == 27:
+                     break
+             cap.release()
+             st.session_state["all_rois1"].append(st.session_state["roi_list1"])
+             st.session_state["roi_list1"] = []
+             done_1 = False
+ 
+         st.write('ROI indices: ', st.session_state["all_rois1"][0])
+ 
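+         # Detection pass: re-open the video, run YOLO on every third frame and
+         # log the average vehicle count per ROI.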
+         cap = cv2.VideoCapture('temp_PS1.mp4')
+         st.write("Detection started")
+         st.session_state['fps'] = cap.get(cv2.CAP_PROP_FPS)
+         st.write(f"FPS OF VIDEO: {st.session_state['fps']}")
+         count = 0
+         frame_placeholder = st.empty()
+         st.session_state["data1"] = {}
+         for i in range(len(st.session_state["all_rois1"])):
+             st.session_state["data1"][f"ROI{i}"] = []
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             count += 1
+             if count % 3 != 0:  # process every third frame
+                 continue
+             k = 0
+             for roi_list_here1 in st.session_state["all_rois1"]:
+                 max_pt = [0, 0]
+                 min_pt = [10000, 10000]
+                 roi_list_here = roi_list_here1[1:]
+                 for i in range(len(roi_list_here)):
+                     if roi_list_here[i][0] > max_pt[0]:
+                         max_pt[0] = roi_list_here[i][0]
+                     if roi_list_here[i][1] > max_pt[1]:
+                         max_pt[1] = roi_list_here[i][1]
+                     if roi_list_here[i][0] < min_pt[0]:
+                         min_pt[0] = roi_list_here[i][0]
+                     if roi_list_here[i][1] < min_pt[1]:
+                         min_pt[1] = roi_list_here[i][1]
+                 frame_cropped = frame[min_pt[1]:max_pt[1], min_pt[0]:max_pt[0]]
+                 roi_corners = np.array([roi_list_here], dtype=np.int32)
+                 # White mask with the ROI polygon filled black: OR-ing it with the
+                 # crop whites out everything outside the polygon.
+                 mask = np.zeros(frame.shape, dtype=np.uint8)
+                 mask.fill(255)
+                 cv2.fillPoly(mask, roi_corners, 0)
+                 mask_cropped = mask[min_pt[1]:max_pt[1], min_pt[0]:max_pt[0]]
+                 roi = cv2.bitwise_or(frame_cropped, mask_cropped)
+ 
+                 number = []
+                 results = model.predict(roi)
+                 for r in results:
+                     boxes = r.boxes
+                     counter = 0
+                     for box in boxes:
+                         counter += 1
+                         name = classes[int(box.cls.numpy()[0])]
+                         conf = str(round(float(box.conf.numpy()[0]), 2))
+                         text = name + " " + conf
+                         bbox = box.xyxy[0].numpy()
+                         # Shift box coordinates from ROI space back to the full frame.
+                         cv2.rectangle(frame, (int(bbox[0])+min_pt[0], int(bbox[1])+min_pt[1]), (int(bbox[2])+min_pt[0], int(bbox[3])+min_pt[1]), (0, 255, 0), 2)
+                         cv2.putText(frame, text, (int(bbox[0])+min_pt[0], int(bbox[1])+min_pt[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+                     number.append(counter)
+                 avg = sum(number)/len(number)
+                 stats = str(round(avg, 2))
+                 if count % 10 == 0:
+                     st.session_state["data1"][f"ROI{k}"].append(avg)
+                 k += 1
+                 cv2.putText(frame, stats, (min_pt[0], min_pt[1]), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                 cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+             cv2.putText(frame, 'The average number of vehicles in the Regions of Interest', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+             frame_placeholder.image(frame, channels='BGR')
+         cap.release()
+         df = pd.DataFrame(st.session_state["data1"])
+         df.to_csv('PS1.csv', sep='\t', encoding='utf-8')
+     else:
+         st.error('PLEASE UPLOAD A VIDEO OF THE FORMAT MP4 OR MPEG', icon="🚨")
+ 
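+ # PS3: same per-ROI vehicle counting, plus pothole segmentation drawn on the
+ # full frame with a second model ('best.pt').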
+ elif page == "PS3":
+     uploaded_file1 = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
+     if uploaded_file1 is not None:
+         g = io.BytesIO(uploaded_file1.read())
+         temporary_location = "temp_PS2.mp4"
+         with open(temporary_location, 'wb') as out:
+             out.write(g.read())
+         model1 = YOLO("yolov8n.pt")
+         model2 = YOLO("best.pt")
+         if 'roi_list2' not in st.session_state:
+             st.session_state['roi_list2'] = []
+         if "all_rois2" not in st.session_state:
+             st.session_state['all_rois2'] = []
+         classes = model1.names
+ 
+         done_2 = st.button('Selection Done')
+ 
+         while len(st.session_state["all_rois2"]) < st.session_state['n']:
+             cap = cv2.VideoCapture('temp_PS2.mp4')
+             while not done_2:
+                 ret, frame = cap.read()
+                 if not ret:
+                     st.write('ROI selection has concluded')
+                     break
+                 cv2.putText(frame, 'SELECT ROI', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 value = streamlit_image_coordinates(frame, key='numpy', width=750)
+                 if value is not None:  # None until the user has clicked
+                     st.session_state["roi_list2"].append([int(value['x']*2.5), int(value['y']*2.5)])
+                 st.write(st.session_state["roi_list2"])
+                 if cv2.waitKey(0) & 0xFF == 27:
+                     break
+             cap.release()
+             st.session_state["all_rois2"].append(st.session_state["roi_list2"])
+             st.session_state["roi_list2"] = []
+             done_2 = False
+ 
+         st.write('ROI indices: ', st.session_state["all_rois2"][0])
+ 
+         cap = cv2.VideoCapture('temp_PS2.mp4')
+         st.write("Detection started")
+         count = 0
+         frame_placeholder = st.empty()
+         st.session_state.data = {}
+         for i in range(len(st.session_state["all_rois2"])):
+             st.session_state.data[f"ROI{i}"] = []
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             count += 1
+             if count % 3 != 0:  # process every third frame
+                 continue
+             k = 0
+             for roi_list_here1 in st.session_state["all_rois2"]:
+                 max_pt = [0, 0]
+                 min_pt = [10000, 10000]
+                 roi_list_here = roi_list_here1[1:]
+                 for i in range(len(roi_list_here)):
+                     if roi_list_here[i][0] > max_pt[0]:
+                         max_pt[0] = roi_list_here[i][0]
+                     if roi_list_here[i][1] > max_pt[1]:
+                         max_pt[1] = roi_list_here[i][1]
+                     if roi_list_here[i][0] < min_pt[0]:
+                         min_pt[0] = roi_list_here[i][0]
+                     if roi_list_here[i][1] < min_pt[1]:
+                         min_pt[1] = roi_list_here[i][1]
+                 frame_cropped = frame[min_pt[1]:max_pt[1], min_pt[0]:max_pt[0]]
+                 roi_corners = np.array([roi_list_here], dtype=np.int32)
+                 mask = np.zeros(frame.shape, dtype=np.uint8)
+                 mask.fill(255)
+                 cv2.fillPoly(mask, roi_corners, 0)
+                 mask_cropped = mask[min_pt[1]:max_pt[1], min_pt[0]:max_pt[0]]
+                 roi = cv2.bitwise_or(frame_cropped, mask_cropped)
+ 
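+                 # Vehicles are counted on the masked ROI crop; potholes are
+                 # segmented on the full frame by the second model.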
+                 number = []
+                 results = model1.predict(roi)
+                 results_pothole = model2.predict(source=frame)
+                 for r in results:
+                     boxes = r.boxes
+                     counter = 0
+                     for box in boxes:
+                         counter += 1
+                         name = classes[int(box.cls.numpy()[0])]
+                         conf = str(round(float(box.conf.numpy()[0]), 2))
+                         text = name + " " + conf
+                         bbox = box.xyxy[0].numpy()
+                         cv2.rectangle(frame, (int(bbox[0])+min_pt[0], int(bbox[1])+min_pt[1]), (int(bbox[2])+min_pt[0], int(bbox[3])+min_pt[1]), (0, 255, 0), 2)
+                         cv2.putText(frame, text, (int(bbox[0])+min_pt[0], int(bbox[1])+min_pt[1]-5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 2)
+                     number.append(counter)
+                 for r in results_pothole:
+                     masks = r.masks
+                     boxes = r.boxes.cpu().numpy()
+                     xyxys = boxes.xyxy
+                     confs = boxes.conf
+                     if masks is not None:
+                         shapes = np.ones_like(frame)
+                         for mask, conf, xyxy in zip(masks, confs, xyxys):
+                             polygon = mask.xy[0]
+                             if conf >= 0.49 and len(polygon) >= 3:
+                                 # Blend a red fill over the pothole mask, then box and label it.
+                                 cv2.fillPoly(shapes, pts=np.int32([polygon]), color=(0, 0, 255))
+                                 frame = cv2.addWeighted(frame, 0.7, shapes, 0.3, gamma=0)
+                                 cv2.rectangle(frame, (int(xyxy[0]), int(xyxy[1])), (int(xyxy[2]), int(xyxy[3])), (0, 0, 255), 2)
+                                 cv2.putText(frame, 'Pothole '+str(conf), (int(xyxy[0]), int(xyxy[1])-5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+ 
+                 avg = sum(number)/len(number)
+                 stats = str(round(avg, 2))
+                 if count % 10 == 0:
+                     st.session_state.data[f"ROI{k}"].append(avg)
+                 k += 1
+                 cv2.putText(frame, stats, (min_pt[0], min_pt[1]), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                 cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+                 if counter >= 5:
+                     cv2.putText(frame, '!!CONGESTION: MORE THAN '+str(counter)+' OBJECTS', (min_pt[0]+20, min_pt[1]+20), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 4)
+                     cv2.polylines(frame, roi_corners, True, (255, 0, 0), 2)
+             cv2.putText(frame, 'Objects in the Regions of Interest', (100, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 4)
+             frame_placeholder.image(frame, channels='BGR')
+         cap.release()
+ 
+         df = pd.DataFrame(st.session_state.data)
+         df.to_csv('PS2.csv', sep='\t', encoding='utf-8')
+     else:
+         st.error('PLEASE UPLOAD A VIDEO OF THE FORMAT MP4 OR MPEG', icon="🚨")
+ 
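+ # PS2: run a pre-trained DQN traffic-signal agent in a SUMO simulation and
+ # plot the logged waiting-time metrics.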
+ elif page == "PS2":
+     st.header("CLICK ON RUN SCRIPT TO START A TRAFFIC SIMULATION")
+     script = st.button("RUN SCRIPT")
+     if 'con' not in st.session_state:  # persist the run counter across reruns
+         st.session_state.con = -1
+     if script:
+         st.session_state.con += 1
+         import gymnasium as gym
+         import sumo_rl  # registers the 'sumo-rl-v0' environment
+         from stable_baselines3 import DQN
+         from stable_baselines3.common.evaluation import evaluate_policy
+         env = gym.make('sumo-rl-v0',
+                        net_file='single-intersection.net.xml',
+                        route_file='single-intersection-gen.rou.xml',
+                        out_csv_name='output',
+                        use_gui=True,
+                        single_agent=True,
+                        num_seconds=10000)
+         model1 = DQN.load('DQN_MODEL3.zip', env=env)
+         mean_reward, std_reward = evaluate_policy(model1, env=env, n_eval_episodes=5, render=True)
+         st.write("Evaluation Results: ", mean_reward, std_reward)
+         import matplotlib.pyplot as plt
+ 
+         def eval_plot(path, metric, path_compare=None):
+             data = pd.read_csv(path)
+             if path_compare is not None:
+                 data1 = pd.read_csv(path_compare)
+             x = []
+             for i in range(0, len(data)):
+                 x.append(i)
+ 
+             y = data[metric]
+             y_1 = pd.to_numeric(y)
+             y_arr = np.array(y_1)
+             if path_compare is not None:
+                 y2 = data1[metric]
+                 y_2 = pd.to_numeric(y2)
+                 y_arr2 = np.array(y_2)
+ 
+             x_arr = np.array(x)
+ 
+             fig = plt.figure()
+             ax1 = fig.add_subplot(2, 1, 1)
+             ax1.set_title(metric)
+             if path_compare is not None:
+                 ax2 = fig.add_subplot(2, 1, 2, sharey=ax1)
+                 ax2.set_title('compare ' + metric)
+ 
+             ax1.plot(x_arr, y_arr)
+ 
+             if path_compare is not None:
+                 ax2.plot(x_arr, y_arr2)
+ 
+             return fig
+         for i in range(1, 2):
+             st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv', 'system_mean_waiting_time'))
+             st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv', 'agents_total_accumulated_waiting_time'))
+ 
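+ # Chat with Results: a small RAG loop over an uploaded CSV; rows are split,
+ # embedded with Ollama, stored in Chroma, and retrieved as chat context.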
+ elif page == "Chat with Results":
+     st.title('Chat with the Results')
+     st.write("Please upload the relevant CSV data to get started")
+     reload = st.button('Reload')
+     if 'isran' not in st.session_state or reload:
+         st.session_state['isran'] = False
+ 
+     uploaded_file = st.file_uploader('Choose your .csv file', type=["csv"])
+     if uploaded_file is not None and not st.session_state['isran']:
+         with open("temp.csv", "wb") as f:
+             f.write(uploaded_file.getvalue())
+         loader = CSVLoader('temp.csv')
+         docs = loader.load()
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+         splits = text_splitter.split_documents(docs)
+ 
+         embeddings = OllamaEmbeddings(model='mistral')
+         st.session_state.vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+         st.session_state['isran'] = True
+ 
+     if st.session_state['isran']:
+         st.write("Embedding created")
+ 
+         def fdocs(docs):
+             return "\n\n".join(doc.page_content for doc in docs)
+ 
+         def llm(question, context):
+             formatted_prompt = f"Question: {question}\n\nContext: {context}"
+             response = ollama.chat(model='mistral', messages=[
+                 {
+                     'role': 'user',
+                     'content': formatted_prompt
+                 },
+             ])
+             return response['message']['content']
+ 
+         def rag_chain(question):
+             retriever = st.session_state.vectorstore.as_retriever()
+             retrieved_docs = retriever.invoke(question)
+             formatted_context = fdocs(retrieved_docs)
+             return llm(question, formatted_context)
+ 
+         if 'messages' not in st.session_state:
+             st.session_state.messages = []
+ 
+         for message in st.session_state.messages:
+             st.chat_message(message['role']).markdown(message['content'])
+ 
+         prompt = st.chat_input("Say something")
+         if prompt:
+             # Only run retrieval + generation once the user has actually typed something.
+             response = rag_chain(prompt)
+             st.chat_message('user').markdown(prompt)
+             st.session_state.messages.append({'role': 'user', 'content': prompt})
+             st.session_state.messages.append({'role': 'AI', 'content': response})
+             st.chat_message('AI').markdown(response)