Stevenqaq commited on
Commit
5045e42
·
verified ·
1 Parent(s): a0f4f74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -143
app.py CHANGED
@@ -1,146 +1,124 @@
1
- from fastai.vision.all import *
2
- from io import BytesIO
3
- import requests
4
- import streamlit as st
 
5
 
6
- import numpy as np
7
  import torch
8
- import time
9
- import cv2
10
- from numpy import random
11
- from models.experimental import attempt_load
12
- from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
13
- scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path
14
- from utils.plots import plot_one_box
15
-
16
- def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
17
- # Resize and pad image while meeting stride-multiple constraints
18
- shape = img.shape[:2] # current shape [height, width]
19
- if isinstance(new_shape, int):
20
- new_shape = (new_shape, new_shape)
21
-
22
- # Scale ratio (new / old)
23
- r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
24
- if not scaleup: # only scale down, do not scale up (for better test mAP)
25
- r = min(r, 1.0)
26
-
27
- # Compute padding
28
- ratio = r, r # width, height ratios
29
- new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
30
- dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
31
- if auto: # minimum rectangle
32
- dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
33
- elif scaleFill: # stretch
34
- dw, dh = 0.0, 0.0
35
- new_unpad = (new_shape[1], new_shape[0])
36
- ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
37
-
38
- dw /= 2 # divide padding into 2 sides
39
- dh /= 2
40
-
41
- if shape[::-1] != new_unpad: # resize
42
- img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
43
- top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
44
- left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
45
- img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
46
- return img, ratio, (dw, dh)
47
-
48
- def detect_modify(img0, model, conf=0.4, imgsz=640, conf_thres = 0.25, iou_thres=0.45):
49
- st.image(img0, caption="Your image", use_column_width=True)
50
-
51
- stride = int(model.stride.max()) # model stride
52
- imgsz = check_img_size(imgsz, s=stride) # check img_size
53
-
54
- # Padded resize
55
- img0 = cv2.cvtColor(np.asarray(img0), cv2.COLOR_RGB2BGR)
56
- img = letterbox(img0, imgsz, stride=stride)[0]
57
- # Convert
58
- img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
59
- img = np.ascontiguousarray(img)
60
-
61
-
62
- # Get names and colors
63
- names = model.module.names if hasattr(model, 'module') else model.names
64
- colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
65
-
66
- # Run inference
67
- old_img_w = old_img_h = imgsz
68
- old_img_b = 1
69
-
70
- t0 = time.time()
71
- img = torch.from_numpy(img).to(device)
72
- # img /= 255.0 # 0 - 255 to 0.0 - 1.0
73
- img = img/255.0
74
- if img.ndimension() == 3:
75
- img = img.unsqueeze(0)
76
-
77
- # Inference
78
- # t1 = time_synchronized()
79
- with torch.no_grad(): # Calculating gradients would cause a GPU memory leak
80
- pred = model(img)[0]
81
- # t2 = time_synchronized()
82
-
83
- # Apply NMS
84
- pred = non_max_suppression(pred, conf_thres, iou_thres)
85
- # t3 = time_synchronized()
86
-
87
- # Process detections
88
- # for i, det in enumerate(pred): # detections per image
89
-
90
- gn = torch.tensor(img0.shape)[[1, 0, 1, 0]] # normalization gain whwh
91
-
92
- det = pred[0]
93
- if len(det):
94
- # Rescale boxes from img_size to im0 size
95
- det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()
96
-
97
- # Print results
98
- s = ''
99
- for c in det[:, -1].unique():
100
- n = (det[:, -1] == c).sum() # detections per class
101
- s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
102
-
103
- # Write results
104
- for *xyxy, conf, cls in reversed(det):
105
- label = f'{names[int(cls)]} {conf:.2f}'
106
- plot_one_box(xyxy, img0, label=label, color=colors[int(cls)], line_thickness=1)
107
-
108
- f"""
109
- ### Prediction result:
110
- """
111
- img0 = cv2.cvtColor(np.asarray(img0), cv2.COLOR_BGR2RGB)
112
- st.image(img0, caption="Prediction Result", use_column_width=True)
113
-
114
- #set paramters
115
- weight_path = './best,pt'
116
- imgsz = 640
117
- conf = 0.4
118
- conf_thres = 0.25
119
- iou_thres=0.45
120
- device = torch.device("cpu")
121
- path = "./"
122
-
123
- # Load model
124
- model = attempt_load(weight_path, map_location=torch.device('cpu')) # load FP32 model
125
-
126
- """
127
- # YOLOv7
128
- This is a object detection model for [Objects].
129
- """
130
- option = st.radio("", ["Upload Image", "Image URL"])
131
-
132
- if option == "Upload Image":
133
- uploaded_file = st.file_uploader("Please upload an image.")
134
-
135
- if uploaded_file is not None:
136
- img = PILImage.create(uploaded_file)
137
- detect_modify(img, model, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
138
- else:
139
- url = st.text_input("Please input a url.")
140
- if url != "":
141
  try:
142
- response = requests.get(url)
143
- pil_img = PILImage.create(BytesIO(response.content))
144
- detect_modify(pil_img, model, conf=conf, imgsz=imgsz, conf_thres=conf_thres, iou_thres=iou_thres)
145
- except:
146
- st.text("Problem reading image from", url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import platform
3
+ import subprocess
4
+ import time
5
+ from pathlib import Path
6
 
7
+ import requests
8
  import torch
9
+
10
+
11
+ def gsutil_getsize(url=''):
12
+ # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
13
+ s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
14
+ return eval(s.split(' ')[0]) if len(s) else 0 # bytes
15
+
16
+
17
+ def attempt_download(file, repo='WongKinYiu/yolov7'):
18
+ # Attempt file download if does not exist
19
+ file = Path(str(file).strip().replace("'", '').lower())
20
+
21
+ if not file.exists():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  try:
23
+ response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
24
+ assets = [x['name'] for x in response['assets']] # release assets
25
+ tag = response['tag_name'] # i.e. 'v1.0'
26
+ except: # fallback plan
27
+ assets = ['yolov7.pt', 'yolov7-tiny.pt', 'yolov7x.pt', 'yolov7-d6.pt', 'yolov7-e6.pt',
28
+ 'yolov7-e6e.pt', 'yolov7-w6.pt']
29
+ try:
30
+ tag = subprocess.check_output('git tag', shell=True).decode().split()[-1]
31
+ except IndexError:
32
+ tag = 'default'
33
+
34
+ name = file.name
35
+ if name in assets:
36
+ msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/'
37
+ redundant = False # second download option
38
+ try: # GitHub
39
+ url = f'https://github.com/{repo}/releases/download/{tag}/{name}'
40
+ print(f'Downloading {url} to {file}...')
41
+ torch.hub.download_url_to_file(url, file)
42
+ assert file.exists() and file.stat().st_size > 1E6 # check
43
+ except Exception as e: # GCP
44
+ print(f'Download error: {e}')
45
+ assert redundant, 'No secondary mirror'
46
+ url = f'https://storage.googleapis.com/{repo}/ckpt/{name}'
47
+ print(f'Downloading {url} to {file}...')
48
+ os.system(f'curl -L {url} -o {file}') # torch.hub.download_url_to_file(url, weights)
49
+ finally:
50
+ if not file.exists() or file.stat().st_size < 1E6: # check
51
+ file.unlink(missing_ok=True) # remove partial downloads
52
+ print(f'ERROR: Download failure: {msg}')
53
+ print('')
54
+ return
55
+
56
+
57
+ def gdrive_download(id='', file='tmp.zip'):
58
+ # Downloads a file from Google Drive. from yolov7.utils.google_utils import *; gdrive_download()
59
+ t = time.time()
60
+ file = Path(file)
61
+ cookie = Path('cookie') # gdrive cookie
62
+ print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
63
+ file.unlink(missing_ok=True) # remove existing file
64
+ cookie.unlink(missing_ok=True) # remove existing cookie
65
+
66
+ # Attempt file download
67
+ out = "NUL" if platform.system() == "Windows" else "/dev/null"
68
+ os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
69
+ if os.path.exists('cookie'): # large file
70
+ s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}'
71
+ else: # small file
72
+ s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
73
+ r = os.system(s) # execute, capture return
74
+ cookie.unlink(missing_ok=True) # remove existing cookie
75
+
76
+ # Error check
77
+ if r != 0:
78
+ file.unlink(missing_ok=True) # remove partial
79
+ print('Download error ') # raise Exception('Download error')
80
+ return r
81
+
82
+ # Unzip if archive
83
+ if file.suffix == '.zip':
84
+ print('unzipping... ', end='')
85
+ os.system(f'unzip -q {file}') # unzip
86
+ file.unlink() # remove zip to free space
87
+
88
+ print(f'Done ({time.time() - t:.1f}s)')
89
+ return r
90
+
91
+
92
+ def get_token(cookie="./cookie"):
93
+ with open(cookie) as f:
94
+ for line in f:
95
+ if "download" in line:
96
+ return line.split()[-1]
97
+ return ""
98
+
99
+ # def upload_blob(bucket_name, source_file_name, destination_blob_name):
100
+ # # Uploads a file to a bucket
101
+ # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
102
+ #
103
+ # storage_client = storage.Client()
104
+ # bucket = storage_client.get_bucket(bucket_name)
105
+ # blob = bucket.blob(destination_blob_name)
106
+ #
107
+ # blob.upload_from_filename(source_file_name)
108
+ #
109
+ # print('File {} uploaded to {}.'.format(
110
+ # source_file_name,
111
+ # destination_blob_name))
112
+ #
113
+ #
114
+ # def download_blob(bucket_name, source_blob_name, destination_file_name):
115
+ # # Uploads a blob from a bucket
116
+ # storage_client = storage.Client()
117
+ # bucket = storage_client.get_bucket(bucket_name)
118
+ # blob = bucket.blob(source_blob_name)
119
+ #
120
+ # blob.download_to_filename(destination_file_name)
121
+ #
122
+ # print('Blob {} downloaded to {}.'.format(
123
+ # source_blob_name,
124
+ # destination_file_name))