Spaces:
Sleeping
Sleeping
chienweichang
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Query, File, UploadFile
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from sklearn.neighbors import KDTree
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
from geopy.distance import geodesic
|
7 |
+
import logging
|
8 |
+
|
9 |
+
app = FastAPI()
|
10 |
+
|
11 |
+
# 允許所有來源的跨域請求(可以根據需要進行限制)
|
12 |
+
app.add_middleware(
|
13 |
+
CORSMiddleware,
|
14 |
+
allow_origins=["*"], # 可以根據需要限制來源
|
15 |
+
allow_credentials=True,
|
16 |
+
allow_methods=["*"],
|
17 |
+
allow_headers=["*"],
|
18 |
+
)
|
19 |
+
|
20 |
+
# 設置日誌
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
poi_data = None
|
25 |
+
trees = {}
|
26 |
+
|
27 |
+
# 構建KD-tree(不保存到磁碟)
|
28 |
+
def build_kdtrees(poi_data):
|
29 |
+
trees = {}
|
30 |
+
for poi_type, group in poi_data.groupby('poi_type'):
|
31 |
+
coords = np.array(list(group['coordinates']))
|
32 |
+
tree = KDTree(coords, leaf_size=2)
|
33 |
+
trees[poi_type] = tree
|
34 |
+
return trees
|
35 |
+
|
36 |
+
@app.post("/upload-poi")
|
37 |
+
async def upload_poi(file: UploadFile = File(...)):
|
38 |
+
global poi_data, trees
|
39 |
+
try:
|
40 |
+
poi_data = pd.read_csv(file.file)
|
41 |
+
poi_data['coordinates'] = list(zip(poi_data.lat, poi_data.lng))
|
42 |
+
trees = build_kdtrees(poi_data)
|
43 |
+
return {"message": "POI data uploaded and KD-trees built successfully"}
|
44 |
+
except Exception as e:
|
45 |
+
logger.error(f"An error occurred while processing the uploaded POI data: {e}")
|
46 |
+
raise HTTPException(status_code=500, detail="An error occurred while processing the uploaded POI data")
|
47 |
+
|
48 |
+
@app.get("/poi/nearest")
|
49 |
+
def get_nearest_poi(lat: float, lng: float, poi_type: str = Query(...)):
|
50 |
+
global poi_data, trees
|
51 |
+
if poi_data is None or trees == {}:
|
52 |
+
raise HTTPException(status_code=400, detail="POI data not uploaded")
|
53 |
+
|
54 |
+
coords = np.array([[lat, lng]])
|
55 |
+
|
56 |
+
if poi_type == "all":
|
57 |
+
all_pois = []
|
58 |
+
|
59 |
+
# 遍歷所有KD-tree並找出最近的POI
|
60 |
+
for tree_poi_type, tree in trees.items():
|
61 |
+
dist, inds = tree.query(coords, k=10)
|
62 |
+
for i, distance in enumerate(dist[0]):
|
63 |
+
poi_candidate = poi_data[poi_data['poi_type'] == tree_poi_type].iloc[inds[0][i]]
|
64 |
+
candidate_distance = geodesic((lat, lng), (poi_candidate["lat"], poi_candidate["lng"])).meters
|
65 |
+
all_pois.append({
|
66 |
+
"name": poi_candidate["name"],
|
67 |
+
"poi_type": tree_poi_type,
|
68 |
+
"distance": round(candidate_distance, 2),
|
69 |
+
"latitude": poi_candidate["lat"],
|
70 |
+
"longitude": poi_candidate["lng"]
|
71 |
+
})
|
72 |
+
|
73 |
+
# 排序所有POI並取前10個
|
74 |
+
all_pois = sorted(all_pois, key=lambda x: x["distance"])[:10]
|
75 |
+
|
76 |
+
if not all_pois:
|
77 |
+
raise HTTPException(status_code=404, detail="No POI found")
|
78 |
+
|
79 |
+
return all_pois
|
80 |
+
else:
|
81 |
+
if poi_type not in trees:
|
82 |
+
raise HTTPException(status_code=404, detail="Model type not found")
|
83 |
+
|
84 |
+
tree = trees[poi_type]
|
85 |
+
dist, inds = tree.query(coords, k=10)
|
86 |
+
nearest_pois = []
|
87 |
+
|
88 |
+
for i, distance in enumerate(dist[0]):
|
89 |
+
nearest_poi = poi_data[poi_data['poi_type'] == poi_type].iloc[inds[0][i]]
|
90 |
+
distance_m = geodesic((lat, lng), (nearest_poi["lat"], nearest_poi["lng"])).meters
|
91 |
+
nearest_pois.append({
|
92 |
+
"name": nearest_poi["name"],
|
93 |
+
"poi_type": poi_type,
|
94 |
+
"distance": round(distance_m, 2),
|
95 |
+
"latitude": nearest_poi["lat"],
|
96 |
+
"longitude": nearest_poi["lng"]
|
97 |
+
})
|
98 |
+
|
99 |
+
return nearest_pois
|
100 |
+
|
101 |
+
@app.post("/clear-kdtrees")
|
102 |
+
def clear_kdtrees():
|
103 |
+
global poi_data, trees
|
104 |
+
poi_data = None
|
105 |
+
trees = {}
|
106 |
+
return {"message": "KD-trees and POI data cleared successfully"}
|
107 |
+
|
108 |
+
# 運行應用
|
109 |
+
if __name__ == "__main__":
|
110 |
+
import uvicorn
|
111 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|