pr4nav101 commited on
Commit
ecc74a5
·
verified ·
1 Parent(s): 6518b25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +416 -0
app.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import base64
3
+ import cv2
4
+ import io
5
+ import numpy as np
6
+ from ultralytics.utils.plotting import Annotator
7
+ import streamlit as st
8
+ from streamlit_image_coordinates import streamlit_image_coordinates
9
+ import pandas as pd
10
+ import ollama
11
+ import bs4
12
+ import tempfile
13
+
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain_community.document_loaders import WebBaseLoader
16
+ from langchain_community.document_loaders import CSVLoader
17
+ from langchain_community.vectorstores import Chroma
18
+ from langchain_community.embeddings import OllamaEmbeddings
19
+ from langchain_core.output_parsers import StrOutputParser
20
+ from langchain_core.runnables import RunnablePassthrough
21
+
22
+ def set_background(image_file1,image_file2):
23
+
24
+ with open(image_file1, "rb") as f:
25
+ img_data1 = f.read()
26
+ b64_encoded1 = base64.b64encode(img_data1).decode()
27
+ with open(image_file2, "rb") as f:
28
+ img_data2 = f.read()
29
+ b64_encoded2 = base64.b64encode(img_data2).decode()
30
+ style = f"""
31
+ <style>
32
+ .stApp{{
33
+ background-image: url(data:image/png;base64,{b64_encoded1});
34
+ background-size: cover;
35
+
36
+ }}
37
+ .st-emotion-cache-6qob1r{{
38
+ background-image: url(data:image/png;base64,{b64_encoded2});
39
+ background-size: cover;
40
+ border: 5px solid rgb(14, 17, 23);
41
+
42
+ }}
43
+ </style>
44
+ """
45
+ st.markdown(style, unsafe_allow_html=True)
46
+
47
+ set_background('pngtree-city-map-navigation-interface-picture-image_1833642.png','2024-05-18_14-57-09_5235.png')
48
+
49
+ st.title("Traffic Flow and Optimization Toolkit")
50
+
51
+ sb = st.sidebar # defining the sidebar
52
+
53
+ sb.markdown("🛰️ **Navigation**")
54
+ page_names = ["PS1", "PS2", "PS3","Chat with Results"]
55
+ page = sb.radio("", page_names, index=0)
56
+ st.session_state['n'] = sb.slider("Number of ROIs",1,5)
57
+
58
+ if page == 'PS1':
59
+ uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
60
+ model = YOLO('yolov8n.pt')
61
+ if uploaded_file is not None:
62
+ with tempfile.NamedTemporaryFile() as temp:
63
+ temp.write(uploaded_file.getbuffer())
64
+ if 'roi_list1' not in st.session_state:
65
+ st.session_state['roi_list1'] = []
66
+ if "all_rois1" not in st.session_state:
67
+ st.session_state['all_rois1'] = []
68
+ classes = model.names
69
+
70
+ done_1 = st.button('Selection Done')
71
+
72
+ while len(st.session_state["all_rois1"]) < st.session_state['n']:
73
+ cap = cv2.VideoCapture(temp.name)
74
+ while not done_1:
75
+ ret,frame=cap.read()
76
+ cv2.putText(frame,'SELECT ROI',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
77
+ if not ret:
78
+ st.write('ROI selection has concluded')
79
+ break
80
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
81
+ value = streamlit_image_coordinates(frame,key='numpy',width=750)
82
+ st.session_state["roi_list1"].append([int(value['x']*2.55),int(value['y']*2.55)])
83
+ st.write(st.session_state["roi_list1"])
84
+ if cv2.waitKey(0)&0xFF==27:
85
+ break
86
+ cap.release()
87
+ st.session_state["all_rois1"].append(st.session_state["roi_list1"])
88
+ st.session_state["roi_list1"] = []
89
+ done_1 = False
90
+
91
+ st.write('ROI indices: ',st.session_state["all_rois1"][0])
92
+
93
+
94
+
95
+ cap = cv2.VideoCapture(temp.name)
96
+ st.write("Detection started")
97
+ st.session_state['fps'] = cap.get(cv2.CAP_PROP_FPS)
98
+ st.write(f"FPS OF VIDEO: {st.session_state['fps']}")
99
+ avg_list = []
100
+ count = 0
101
+ frame_placeholder = st.empty()
102
+ st.session_state["data1"] = {}
103
+ for i in range(len(st.session_state["all_rois1"])):
104
+ st.session_state["data1"][f"ROI{i}"] = []
105
+ while cap.isOpened():
106
+ ret,frame=cap.read()
107
+ if not ret:
108
+ break
109
+ count += 1
110
+ if count % 3 != 0:
111
+ continue
112
+ k = 0
113
+ for roi_list_here1 in st.session_state["all_rois1"]:
114
+ max = [0,0]
115
+ min = [10000,10000]
116
+ roi_list_here = roi_list_here1[1:]
117
+ for i in range(len(roi_list_here)):
118
+ if roi_list_here[i][0] > max[0]:
119
+ max[0] = roi_list_here[i][0]
120
+ if roi_list_here[i][1] > max[1]:
121
+ max[1] = roi_list_here[i][1]
122
+ if roi_list_here[i][0] < min[0]:
123
+ min[0] = roi_list_here[i][0]
124
+ if roi_list_here[i][1] < min[1]:
125
+ min[1] = roi_list_here[i][1]
126
+ frame_cropped = frame[min[1]:max[1],min[0]:max[0]]
127
+ roi_corners = np.array([roi_list_here],dtype=np.int32)
128
+ mask = np.zeros(frame.shape,dtype=np.uint8)
129
+ mask.fill(255)
130
+ channel_count = frame.shape[2]
131
+ ignore_mask_color = (255,)*channel_count
132
+ cv2.fillPoly(mask,roi_corners,0)
133
+ mask_cropped = mask[min[1]:max[1],min[0]:max[0]]
134
+ roi = cv2.bitwise_or(frame_cropped,mask_cropped)
135
+
136
+ #roi = frame[roi_list_here[0][1]:roi_list_here[1][1],roi_list_here[0][0]:roi_list_here[1][0]]
137
+ number = []
138
+ results = model.predict(roi)
139
+ for r in results:
140
+ boxes = r.boxes
141
+ counter = 0
142
+ for box in boxes:
143
+ counter += 1
144
+ name = classes[box.cls.numpy()[0]]
145
+ conf = str(round(box.conf.numpy()[0],2))
146
+ text = name+""+conf
147
+ bbox = box.xyxy[0].numpy()
148
+ cv2.rectangle(frame,(int(bbox[0])+min[0],int(bbox[1])+min[1]),(int(bbox[2])+min[0],int(bbox[3])+min[1]),(0,255,0),2)
149
+ cv2.putText(frame,text,(int(bbox[0])+min[0],int(bbox[1])+min[1]-5),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2)
150
+ number.append(counter)
151
+ avg = sum(number)/len(number)
152
+ stats = str(round(avg,2))
153
+ if count%10 == 0:
154
+ st.session_state["data1"][f"ROI{k}"].append(avg)
155
+ k+=1
156
+ cv2.putText(frame,stats,(min[0],min[1]),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
157
+ cv2.polylines(frame,roi_corners,True,(255,0,0),2)
158
+ cv2.putText(frame,'The average number of vehicles in the Regions of Interest',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
159
+ frame_placeholder.image(frame,channels='BGR')
160
+ st.write(st.session_state.data1)
161
+ cap.release()
162
+ else:
163
+ st.error('PLEASE UPLOAD AN IMAGE OF THE FORMAT JPG,JPEG OR PNG', icon="🚨")
164
+
165
+ elif page == "PS3":
166
+ uploaded_file1 = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
167
+ if uploaded_file1 is not None:
168
+ g = io.BytesIO(uploaded_file1.read())
169
+ temporary_location = "temp_PS2.mp4"
170
+
171
+ with open(temporary_location, 'wb') as out:
172
+ out.write(g.read())
173
+ out.close()
174
+ model1 = YOLO("yolov8n.pt")
175
+ model2 = YOLO("best.pt")
176
+ if 'roi_list2' not in st.session_state:
177
+ st.session_state['roi_list2'] = []
178
+ if "all_rois2" not in st.session_state:
179
+ st.session_state['all_rois2'] = []
180
+ classes = model1.names
181
+
182
+ done_2 = st.button('Selection Done')
183
+
184
+ while len(st.session_state["all_rois2"]) < st.session_state['n']:
185
+ cap = cv2.VideoCapture('temp_PS2.mp4')
186
+ while not done_2:
187
+ ret,frame=cap.read()
188
+ cv2.putText(frame,'SELECT ROI',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
189
+ if not ret:
190
+ st.write('ROI selection has concluded')
191
+ break
192
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
193
+ value = streamlit_image_coordinates(frame,key='numpy',width=750)
194
+ st.session_state["roi_list2"].append([int(value['x']*2.5),int(value['y']*2.5)])
195
+ st.write(st.session_state["roi_list2"])
196
+ if cv2.waitKey(0)&0xFF==27:
197
+ break
198
+ cap.release()
199
+ st.session_state["all_rois2"].append(st.session_state["roi_list2"])
200
+ st.session_state["roi_list2"] = []
201
+ done_2 = False
202
+
203
+ st.write('ROI indices: ',st.session_state["all_rois2"][0])
204
+
205
+
206
+
207
+ cap = cv2.VideoCapture('temp_PS2.MP4')
208
+ st.write("Detection started")
209
+ avg_list = []
210
+ count = 0
211
+ frame_placeholder = st.empty()
212
+ st.session_state.data = {}
213
+ for i in range(len(st.session_state["all_rois2"])):
214
+ st.session_state["data"][f"ROI{i}"] = []
215
+ for i in range(len(st.session_state['all_rois2'])):
216
+ st.session_state.data[f"ROI{i}"] = []
217
+ while cap.isOpened():
218
+ ret,frame=cap.read()
219
+ if not ret:
220
+ break
221
+ count += 1
222
+ if count % 3 != 0:
223
+ continue
224
+ # rois = []
225
+ k = 0
226
+ for roi_list_here1 in st.session_state["all_rois2"]:
227
+ max = [0,0]
228
+ min = [10000,10000]
229
+ roi_list_here = roi_list_here1[1:]
230
+ for i in range(len(roi_list_here)-1):
231
+ if roi_list_here[i][0] > max[0]:
232
+ max[0] = roi_list_here[i][0]
233
+ if roi_list_here[i][1] > max[1]:
234
+ max[1] = roi_list_here[i][1]
235
+ if roi_list_here[i][0] < min[0]:
236
+ min[0] = roi_list_here[i][0]
237
+ if roi_list_here[i][1] < min[1]:
238
+ min[1] = roi_list_here[i][1]
239
+ frame_cropped = frame[min[1]:max[1],min[0]:max[0]]
240
+ roi_corners = np.array([roi_list_here],dtype=np.int32)
241
+ mask = np.zeros(frame.shape,dtype=np.uint8)
242
+ mask.fill(255)
243
+ channel_count = frame.shape[2]
244
+ ignore_mask_color = (255,)*channel_count
245
+ cv2.fillPoly(mask,roi_corners,0)
246
+ mask_cropped = mask[min[1]:max[1],min[0]:max[0]]
247
+ roi = cv2.bitwise_or(frame_cropped,mask_cropped)
248
+
249
+ #roi = frame[roi_list_here[0][1]:roi_list_here[1][1],roi_list_here[0][0]:roi_list_here[1][0]]
250
+ number = []
251
+ results = model1.predict(roi)
252
+ results_pothole = model2.predict(source=frame)
253
+ for r in results:
254
+ boxes = r.boxes
255
+ counter = 0
256
+ for box in boxes:
257
+ counter += 1
258
+ name = classes[box.cls.numpy()[0]]
259
+ conf = str(round(box.conf.numpy()[0],2))
260
+ text = name+conf
261
+ bbox = box.xyxy[0].numpy()
262
+ cv2.rectangle(frame,(int(bbox[0])+min[0],int(bbox[1])+min[1]),(int(bbox[2])+min[0],int(bbox[3])+min[1]),(0,255,0),2)
263
+ cv2.putText(frame,text,(int(bbox[0])+min[0],int(bbox[1])+min[1]-5),cv2.FONT_HERSHEY_SIMPLEX, 0.4,(0,0,255),2)
264
+ number.append(counter)
265
+ for r in results_pothole:
266
+ masks = r.masks
267
+ boxes = r.boxes.cpu().numpy()
268
+ xyxys = boxes.xyxy
269
+ confs = boxes.conf
270
+ if masks is not None:
271
+ shapes = np.ones_like(frame)
272
+ for mask,conf,xyxy in zip(masks,confs,xyxys):
273
+ polygon = mask.xy[0]
274
+ if conf >= 0.49 and len(polygon)>=3:
275
+ cv2.fillPoly(shapes,pts=np.int32([polygon]),color=(0,0,255,0.5))
276
+ frame = cv2.addWeighted(frame,0.7,shapes,0.3,gamma=0)
277
+ cv2.rectangle(frame,(int(xyxy[0]),int(xyxy[1])),(int(xyxy[2]),int(xyxy[3])),(0,0,255),2)
278
+ cv2.putText(frame,'Pothole '+str(conf),(int(xyxy[0]),int(xyxy[1])-5),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2)
279
+
280
+ avg = sum(number)/len(number)
281
+ stats = str(round(avg,2))
282
+ if count % 10 == 0:
283
+ st.session_state.data[f"ROI{k}"].append(avg)
284
+ k+=1
285
+ cv2.putText(frame,stats,(min[0],min[1]),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
286
+ cv2.polylines(frame,roi_corners,True,(255,0,0),2)
287
+ if counter >= 5:
288
+ cv2.putText(frame,'!!CONGESTION MORE THAN '+str(counter)+' Objects',(min[0]+20,min[1]+20),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
289
+ cv2.polylines(frame,roi_corners,True,(255,0,0),2)
290
+ cv2.putText(frame,'Objects in the Regions of Interest',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
291
+ frame_placeholder.image(frame,channels='BGR')
292
+ cap.release()
293
+
294
+ df = pd.DataFrame(st.session_state.data)
295
+ df.to_csv('PS2.csv', sep='\t', encoding='utf-8')
296
+
297
+ else:
298
+ st.error('PLEASE UPLOAD AN IMAGE OF THE FORMAT JPG,JPEG OR PNG', icon="🚨")
299
+
300
+ elif page == "PS2":
301
+ st.header("CLICK ON RUN SCRIPT TO START A TRAFFIC SIMULATION")
302
+ script = st.button("RUN SCRIPT")
303
+ st.session_state.con = -1
304
+ if script:
305
+ st.session_state.con += 1
306
+ import gymnasium as gym
307
+ import sumo_rl
308
+ import os
309
+ from stable_baselines3 import DQN
310
+ from stable_baselines3.common.vec_env import DummyVecEnv
311
+ from stable_baselines3.common.evaluation import evaluate_policy
312
+ from sumo_rl import SumoEnvironment
313
+ env = gym.make('sumo-rl-v0',
314
+ net_file='single-intersection.net.xml',
315
+ route_file='single-intersection-gen.rou.xml',
316
+ out_csv_name='output',
317
+ use_gui=True,
318
+ single_agent=True,
319
+ num_seconds=10000)
320
+ model1 = DQN.load('DQN_MODEL3.zip',env=env)
321
+ one,two = evaluate_policy(model1,env = env,n_eval_episodes=5,render=True)
322
+ st.write("Evaluation Results: ",one,two)
323
+ import matplotlib.pyplot as plt
324
+ def eval_plot(path,metric,path_compare = None):
325
+ data = pd.read_csv(path)
326
+ if path_compare is not None:
327
+ data1 = pd.read_csv(path_compare)
328
+ x = []
329
+ for i in range(0,len(data)):
330
+ x.append(i)
331
+
332
+ y = data[metric]
333
+ y_1 = pd.to_numeric(y)
334
+ y_arr = np.array(y_1)
335
+ if path_compare is not None:
336
+ y2 = data1[metric]
337
+ y_2 = pd.to_numeric(y2)
338
+ y_arr2 = np.array(y_2)
339
+
340
+ x_arr = np.array(x)
341
+
342
+ fig = plt.figure()
343
+ ax1 = fig.add_subplot(2, 1, 1)
344
+ ax1.set_title(metric)
345
+ if path_compare is not None:
346
+ ax2 = fig.add_subplot(2, 1, 2,sharey=ax1)
347
+ ax2.set_title('compare '+metric)
348
+
349
+ ax1.plot(x_arr,y_arr)
350
+
351
+ if path_compare is not None:
352
+ ax2.plot(x_arr,y_arr2)
353
+
354
+ return fig
355
+ for i in range(1,2):
356
+ st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv','system_mean_waiting_time'))
357
+ st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv','agents_total_accumulated_waiting_time'))
358
+
359
+ elif page == "Chat with Results":
360
+ st.title('Chat with the Results')
361
+ st.write("Please upload the relevant CSV data to get started")
362
+ reload = st.button('Reload')
363
+ if 'isran' not in st.session_state or reload == True:
364
+ st.session_state['isran'] = False
365
+
366
+
367
+ uploaded_file = st.file_uploader('Choose your .csv file', type=["csv"])
368
+ if uploaded_file is not None and st.session_state['isran'] == False:
369
+ with open("temp.csv", "wb") as f:
370
+ f.write(uploaded_file.getvalue())
371
+ loader = CSVLoader('temp.csv')
372
+ docs = loader.load()
373
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200)
374
+ splits = text_splitter.split_documents(docs)
375
+
376
+ embeddings = OllamaEmbeddings(model='mistral')
377
+ st.session_state.vectorstore = Chroma.from_documents(documents=splits,embedding=embeddings)
378
+ st.session_state['isran'] = True
379
+
380
+ if st.session_state['isran'] == True:
381
+ st.write("Embedding created")
382
+
383
+ def fdocs(docs):
384
+ return "\n\n".join(doc.page_content for doc in docs)
385
+
386
+ def llm(question,context):
387
+ formatted_prompt = f"Question: {question}\n\nContext:{context}"
388
+ response = ollama.chat(model='mistral', messages=[
389
+ {
390
+ 'role': 'user',
391
+ 'content': formatted_prompt
392
+ },
393
+ ])
394
+ return response['message']['content']
395
+
396
+
397
+
398
+ def rag_chain(question):
399
+ retriever = st.session_state.vectorstore.as_retriever()
400
+ retrieved_docs = retriever.invoke(question)
401
+ formatted_context = fdocs(retrieved_docs)
402
+ return llm(question,formatted_context)
403
+
404
+ if 'messages' not in st.session_state:
405
+ st.session_state.messages = []
406
+
407
+ for message in st.session_state.messages:
408
+ st.chat_message(message['role']).markdown(message['content'])
409
+
410
+ prompt = st.chat_input("Say something")
411
+ response = rag_chain(prompt)
412
+ if prompt:
413
+ st.chat_message('user').markdown(prompt)
414
+ st.session_state.messages.append({'role':'user','content':prompt})
415
+ st.session_state.messages.append({'role':'AI','content':response})
416
+ st.chat_message('AI').markdown(response)