pr4nav101 commited on
Commit
6518b25
·
verified ·
1 Parent(s): 2c95c7f

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -418
app.py DELETED
@@ -1,418 +0,0 @@
1
- from ultralytics import YOLO
2
- import base64
3
- import cv2
4
- import io
5
- import numpy as np
6
- from ultralytics.utils.plotting import Annotator
7
- import streamlit as st
8
- from streamlit_image_coordinates import streamlit_image_coordinates
9
- import pandas as pd
10
- import ollama
11
- import bs4
12
- import tempfile
13
-
14
- from langchain.text_splitter import RecursiveCharacterTextSplitter
15
- from langchain_community.document_loaders import WebBaseLoader
16
- from langchain_community.document_loaders import CSVLoader
17
- from langchain_community.vectorstores import Chroma
18
- from langchain_community.embeddings import OllamaEmbeddings
19
- from langchain_core.output_parsers import StrOutputParser
20
- from langchain_core.runnables import RunnablePassthrough
21
-
22
- def set_background(image_file1,image_file2):
23
-
24
- with open(image_file1, "rb") as f:
25
- img_data1 = f.read()
26
- b64_encoded1 = base64.b64encode(img_data1).decode()
27
- with open(image_file2, "rb") as f:
28
- img_data2 = f.read()
29
- b64_encoded2 = base64.b64encode(img_data2).decode()
30
- style = f"""
31
- <style>
32
- .stApp{{
33
- background-image: url(data:image/png;base64,{b64_encoded1});
34
- background-size: cover;
35
-
36
- }}
37
- .st-emotion-cache-6qob1r{{
38
- background-image: url(data:image/png;base64,{b64_encoded2});
39
- background-size: cover;
40
- border: 5px solid rgb(14, 17, 23);
41
-
42
- }}
43
- </style>
44
- """
45
- st.markdown(style, unsafe_allow_html=True)
46
-
47
- set_background('pngtree-city-map-navigation-interface-picture-image_1833642.png','2024-05-18_14-57-09_5235.png')
48
-
49
- st.title("Traffic Flow and Optimization Toolkit")
50
-
51
- sb = st.sidebar # defining the sidebar
52
-
53
- sb.markdown("🛰️ **Navigation**")
54
- page_names = ["PS1", "PS2", "PS3","Chat with Results"]
55
- page = sb.radio("", page_names, index=0)
56
- st.session_state['n'] = sb.slider("Number of ROIs",1,5)
57
-
58
- if page == 'PS1':
59
- uploaded_file = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
60
- model = YOLO('yolov8n.pt')
61
- if uploaded_file is not None:
62
- g = io.BytesIO(uploaded_file.read())
63
- with tempfile.NamedTemporaryFile() as temp:
64
- temp.write(g)
65
- if 'roi_list1' not in st.session_state:
66
- st.session_state['roi_list1'] = []
67
- if "all_rois1" not in st.session_state:
68
- st.session_state['all_rois1'] = []
69
- classes = model.names
70
-
71
- done_1 = st.button('Selection Done')
72
-
73
- while len(st.session_state["all_rois1"]) < st.session_state['n']:
74
- cap = cv2.VideoCapture(temp.name)
75
- while not done_1:
76
- ret,frame=cap.read()
77
- cv2.putText(frame,'SELECT ROI',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
78
- if not ret:
79
- st.write('ROI selection has concluded')
80
- break
81
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
82
- value = streamlit_image_coordinates(frame,key='numpy',width=750)
83
- st.session_state["roi_list1"].append([int(value['x']*2.55),int(value['y']*2.55)])
84
- st.write(st.session_state["roi_list1"])
85
- if cv2.waitKey(0)&0xFF==27:
86
- break
87
- cap.release()
88
- st.session_state["all_rois1"].append(st.session_state["roi_list1"])
89
- st.session_state["roi_list1"] = []
90
- done_1 = False
91
-
92
- st.write('ROI indices: ',st.session_state["all_rois1"][0])
93
-
94
-
95
-
96
- cap = cv2.VideoCapture(temp.name)
97
- st.write("Detection started")
98
- st.session_state['fps'] = cap.get(cv2.CAP_PROP_FPS)
99
- st.write(f"FPS OF VIDEO: {st.session_state['fps']}")
100
- avg_list = []
101
- count = 0
102
- frame_placeholder = st.empty()
103
- st.session_state["data1"] = {}
104
- for i in range(len(st.session_state["all_rois1"])):
105
- st.session_state["data1"][f"ROI{i}"] = []
106
- while cap.isOpened():
107
- ret,frame=cap.read()
108
- if not ret:
109
- break
110
- count += 1
111
- if count % 3 != 0:
112
- continue
113
- k = 0
114
- for roi_list_here1 in st.session_state["all_rois1"]:
115
- max = [0,0]
116
- min = [10000,10000]
117
- roi_list_here = roi_list_here1[1:]
118
- for i in range(len(roi_list_here)):
119
- if roi_list_here[i][0] > max[0]:
120
- max[0] = roi_list_here[i][0]
121
- if roi_list_here[i][1] > max[1]:
122
- max[1] = roi_list_here[i][1]
123
- if roi_list_here[i][0] < min[0]:
124
- min[0] = roi_list_here[i][0]
125
- if roi_list_here[i][1] < min[1]:
126
- min[1] = roi_list_here[i][1]
127
- frame_cropped = frame[min[1]:max[1],min[0]:max[0]]
128
- roi_corners = np.array([roi_list_here],dtype=np.int32)
129
- mask = np.zeros(frame.shape,dtype=np.uint8)
130
- mask.fill(255)
131
- channel_count = frame.shape[2]
132
- ignore_mask_color = (255,)*channel_count
133
- cv2.fillPoly(mask,roi_corners,0)
134
- mask_cropped = mask[min[1]:max[1],min[0]:max[0]]
135
- roi = cv2.bitwise_or(frame_cropped,mask_cropped)
136
-
137
- #roi = frame[roi_list_here[0][1]:roi_list_here[1][1],roi_list_here[0][0]:roi_list_here[1][0]]
138
- number = []
139
- results = model.predict(roi)
140
- for r in results:
141
- boxes = r.boxes
142
- counter = 0
143
- for box in boxes:
144
- counter += 1
145
- name = classes[box.cls.numpy()[0]]
146
- conf = str(round(box.conf.numpy()[0],2))
147
- text = name+""+conf
148
- bbox = box.xyxy[0].numpy()
149
- cv2.rectangle(frame,(int(bbox[0])+min[0],int(bbox[1])+min[1]),(int(bbox[2])+min[0],int(bbox[3])+min[1]),(0,255,0),2)
150
- cv2.putText(frame,text,(int(bbox[0])+min[0],int(bbox[1])+min[1]-5),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2)
151
- number.append(counter)
152
- avg = sum(number)/len(number)
153
- stats = str(round(avg,2))
154
- if count%10 == 0:
155
- st.session_state["data1"][f"ROI{k}"].append(avg)
156
- k+=1
157
- cv2.putText(frame,stats,(min[0],min[1]),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
158
- cv2.polylines(frame,roi_corners,True,(255,0,0),2)
159
- cv2.putText(frame,'The average number of vehicles in the Regions of Interest',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
160
- frame_placeholder.image(frame,channels='BGR')
161
- cap.release()
162
- df = pd.DataFrame(st.session_state["data1"])
163
- df.to_csv('PS1.csv', sep='\t', encoding='utf-8')
164
- else:
165
- st.error('PLEASE UPLOAD AN IMAGE OF THE FORMAT JPG,JPEG OR PNG', icon="🚨")
166
-
167
- elif page == "PS3":
168
- uploaded_file1 = st.file_uploader("Choose a video...", type=["mp4", "mpeg"])
169
- if uploaded_file1 is not None:
170
- g = io.BytesIO(uploaded_file1.read())
171
- temporary_location = "temp_PS2.mp4"
172
-
173
- with open(temporary_location, 'wb') as out:
174
- out.write(g.read())
175
- out.close()
176
- model1 = YOLO("yolov8n.pt")
177
- model2 = YOLO("best.pt")
178
- if 'roi_list2' not in st.session_state:
179
- st.session_state['roi_list2'] = []
180
- if "all_rois2" not in st.session_state:
181
- st.session_state['all_rois2'] = []
182
- classes = model1.names
183
-
184
- done_2 = st.button('Selection Done')
185
-
186
- while len(st.session_state["all_rois2"]) < st.session_state['n']:
187
- cap = cv2.VideoCapture('temp_PS2.mp4')
188
- while not done_2:
189
- ret,frame=cap.read()
190
- cv2.putText(frame,'SELECT ROI',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
191
- if not ret:
192
- st.write('ROI selection has concluded')
193
- break
194
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
195
- value = streamlit_image_coordinates(frame,key='numpy',width=750)
196
- st.session_state["roi_list2"].append([int(value['x']*2.5),int(value['y']*2.5)])
197
- st.write(st.session_state["roi_list2"])
198
- if cv2.waitKey(0)&0xFF==27:
199
- break
200
- cap.release()
201
- st.session_state["all_rois2"].append(st.session_state["roi_list2"])
202
- st.session_state["roi_list2"] = []
203
- done_2 = False
204
-
205
- st.write('ROI indices: ',st.session_state["all_rois2"][0])
206
-
207
-
208
-
209
- cap = cv2.VideoCapture('temp_PS2.MP4')
210
- st.write("Detection started")
211
- avg_list = []
212
- count = 0
213
- frame_placeholder = st.empty()
214
- st.session_state.data = {}
215
- for i in range(len(st.session_state["all_rois2"])):
216
- st.session_state["data"][f"ROI{i}"] = []
217
- for i in range(len(st.session_state['all_rois2'])):
218
- st.session_state.data[f"ROI{i}"] = []
219
- while cap.isOpened():
220
- ret,frame=cap.read()
221
- if not ret:
222
- break
223
- count += 1
224
- if count % 3 != 0:
225
- continue
226
- # rois = []
227
- k = 0
228
- for roi_list_here1 in st.session_state["all_rois2"]:
229
- max = [0,0]
230
- min = [10000,10000]
231
- roi_list_here = roi_list_here1[1:]
232
- for i in range(len(roi_list_here)-1):
233
- if roi_list_here[i][0] > max[0]:
234
- max[0] = roi_list_here[i][0]
235
- if roi_list_here[i][1] > max[1]:
236
- max[1] = roi_list_here[i][1]
237
- if roi_list_here[i][0] < min[0]:
238
- min[0] = roi_list_here[i][0]
239
- if roi_list_here[i][1] < min[1]:
240
- min[1] = roi_list_here[i][1]
241
- frame_cropped = frame[min[1]:max[1],min[0]:max[0]]
242
- roi_corners = np.array([roi_list_here],dtype=np.int32)
243
- mask = np.zeros(frame.shape,dtype=np.uint8)
244
- mask.fill(255)
245
- channel_count = frame.shape[2]
246
- ignore_mask_color = (255,)*channel_count
247
- cv2.fillPoly(mask,roi_corners,0)
248
- mask_cropped = mask[min[1]:max[1],min[0]:max[0]]
249
- roi = cv2.bitwise_or(frame_cropped,mask_cropped)
250
-
251
- #roi = frame[roi_list_here[0][1]:roi_list_here[1][1],roi_list_here[0][0]:roi_list_here[1][0]]
252
- number = []
253
- results = model1.predict(roi)
254
- results_pothole = model2.predict(source=frame)
255
- for r in results:
256
- boxes = r.boxes
257
- counter = 0
258
- for box in boxes:
259
- counter += 1
260
- name = classes[box.cls.numpy()[0]]
261
- conf = str(round(box.conf.numpy()[0],2))
262
- text = name+conf
263
- bbox = box.xyxy[0].numpy()
264
- cv2.rectangle(frame,(int(bbox[0])+min[0],int(bbox[1])+min[1]),(int(bbox[2])+min[0],int(bbox[3])+min[1]),(0,255,0),2)
265
- cv2.putText(frame,text,(int(bbox[0])+min[0],int(bbox[1])+min[1]-5),cv2.FONT_HERSHEY_SIMPLEX, 0.4,(0,0,255),2)
266
- number.append(counter)
267
- for r in results_pothole:
268
- masks = r.masks
269
- boxes = r.boxes.cpu().numpy()
270
- xyxys = boxes.xyxy
271
- confs = boxes.conf
272
- if masks is not None:
273
- shapes = np.ones_like(frame)
274
- for mask,conf,xyxy in zip(masks,confs,xyxys):
275
- polygon = mask.xy[0]
276
- if conf >= 0.49 and len(polygon)>=3:
277
- cv2.fillPoly(shapes,pts=np.int32([polygon]),color=(0,0,255,0.5))
278
- frame = cv2.addWeighted(frame,0.7,shapes,0.3,gamma=0)
279
- cv2.rectangle(frame,(int(xyxy[0]),int(xyxy[1])),(int(xyxy[2]),int(xyxy[3])),(0,0,255),2)
280
- cv2.putText(frame,'Pothole '+str(conf),(int(xyxy[0]),int(xyxy[1])-5),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2)
281
-
282
- avg = sum(number)/len(number)
283
- stats = str(round(avg,2))
284
- if count % 10 == 0:
285
- st.session_state.data[f"ROI{k}"].append(avg)
286
- k+=1
287
- cv2.putText(frame,stats,(min[0],min[1]),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
288
- cv2.polylines(frame,roi_corners,True,(255,0,0),2)
289
- if counter >= 5:
290
- cv2.putText(frame,'!!CONGESTION MORE THAN '+str(counter)+' Objects',(min[0]+20,min[1]+20),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,0),4)
291
- cv2.polylines(frame,roi_corners,True,(255,0,0),2)
292
- cv2.putText(frame,'Objects in the Regions of Interest',(100,100),cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),4)
293
- frame_placeholder.image(frame,channels='BGR')
294
- cap.release()
295
-
296
- df = pd.DataFrame(st.session_state.data)
297
- df.to_csv('PS2.csv', sep='\t', encoding='utf-8')
298
-
299
- else:
300
- st.error('PLEASE UPLOAD AN IMAGE OF THE FORMAT JPG,JPEG OR PNG', icon="🚨")
301
-
302
- elif page == "PS2":
303
- st.header("CLICK ON RUN SCRIPT TO START A TRAFFIC SIMULATION")
304
- script = st.button("RUN SCRIPT")
305
- st.session_state.con = -1
306
- if script:
307
- st.session_state.con += 1
308
- import gymnasium as gym
309
- import sumo_rl
310
- import os
311
- from stable_baselines3 import DQN
312
- from stable_baselines3.common.vec_env import DummyVecEnv
313
- from stable_baselines3.common.evaluation import evaluate_policy
314
- from sumo_rl import SumoEnvironment
315
- env = gym.make('sumo-rl-v0',
316
- net_file='single-intersection.net.xml',
317
- route_file='single-intersection-gen.rou.xml',
318
- out_csv_name='output',
319
- use_gui=True,
320
- single_agent=True,
321
- num_seconds=10000)
322
- model1 = DQN.load('DQN_MODEL3.zip',env=env)
323
- one,two = evaluate_policy(model1,env = env,n_eval_episodes=5,render=True)
324
- st.write("Evaluation Results: ",one,two)
325
- import matplotlib.pyplot as plt
326
- def eval_plot(path,metric,path_compare = None):
327
- data = pd.read_csv(path)
328
- if path_compare is not None:
329
- data1 = pd.read_csv(path_compare)
330
- x = []
331
- for i in range(0,len(data)):
332
- x.append(i)
333
-
334
- y = data[metric]
335
- y_1 = pd.to_numeric(y)
336
- y_arr = np.array(y_1)
337
- if path_compare is not None:
338
- y2 = data1[metric]
339
- y_2 = pd.to_numeric(y2)
340
- y_arr2 = np.array(y_2)
341
-
342
- x_arr = np.array(x)
343
-
344
- fig = plt.figure()
345
- ax1 = fig.add_subplot(2, 1, 1)
346
- ax1.set_title(metric)
347
- if path_compare is not None:
348
- ax2 = fig.add_subplot(2, 1, 2,sharey=ax1)
349
- ax2.set_title('compare '+metric)
350
-
351
- ax1.plot(x_arr,y_arr)
352
-
353
- if path_compare is not None:
354
- ax2.plot(x_arr,y_arr2)
355
-
356
- return fig
357
- for i in range(1,2):
358
- st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv','system_mean_waiting_time'))
359
- st.pyplot(eval_plot(f'output_conn{st.session_state.con}_ep{i}.csv','agents_total_accumulated_waiting_time'))
360
-
361
- elif page == "Chat with Results":
362
- st.title('Chat with the Results')
363
- st.write("Please upload the relevant CSV data to get started")
364
- reload = st.button('Reload')
365
- if 'isran' not in st.session_state or reload == True:
366
- st.session_state['isran'] = False
367
-
368
-
369
- uploaded_file = st.file_uploader('Choose your .csv file', type=["csv"])
370
- if uploaded_file is not None and st.session_state['isran'] == False:
371
- with open("temp.csv", "wb") as f:
372
- f.write(uploaded_file.getvalue())
373
- loader = CSVLoader('temp.csv')
374
- docs = loader.load()
375
- text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 200)
376
- splits = text_splitter.split_documents(docs)
377
-
378
- embeddings = OllamaEmbeddings(model='mistral')
379
- st.session_state.vectorstore = Chroma.from_documents(documents=splits,embedding=embeddings)
380
- st.session_state['isran'] = True
381
-
382
- if st.session_state['isran'] == True:
383
- st.write("Embedding created")
384
-
385
- def fdocs(docs):
386
- return "\n\n".join(doc.page_content for doc in docs)
387
-
388
- def llm(question,context):
389
- formatted_prompt = f"Question: {question}\n\nContext:{context}"
390
- response = ollama.chat(model='mistral', messages=[
391
- {
392
- 'role': 'user',
393
- 'content': formatted_prompt
394
- },
395
- ])
396
- return response['message']['content']
397
-
398
-
399
-
400
- def rag_chain(question):
401
- retriever = st.session_state.vectorstore.as_retriever()
402
- retrieved_docs = retriever.invoke(question)
403
- formatted_context = fdocs(retrieved_docs)
404
- return llm(question,formatted_context)
405
-
406
- if 'messages' not in st.session_state:
407
- st.session_state.messages = []
408
-
409
- for message in st.session_state.messages:
410
- st.chat_message(message['role']).markdown(message['content'])
411
-
412
- prompt = st.chat_input("Say something")
413
- response = rag_chain(prompt)
414
- if prompt:
415
- st.chat_message('user').markdown(prompt)
416
- st.session_state.messages.append({'role':'user','content':prompt})
417
- st.session_state.messages.append({'role':'AI','content':response})
418
- st.chat_message('AI').markdown(response)