tecuts commited on
Commit
15fbf8b
·
verified ·
1 Parent(s): 4f0e710

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -7,6 +7,7 @@ import logging
7
  from typing import Optional
8
  from fastapi.staticfiles import StaticFiles
9
  from dotenv import load_dotenv
 
10
 
11
  # Set up logging
12
  logging.basicConfig(level=logging.INFO)
@@ -30,10 +31,18 @@ BASE_URL = "https://tecuts-depot.hf.space"
30
  ARL_TOKEN = os.getenv('ARL')
31
  dl = DeeLogin(arl=ARL_TOKEN)
32
 
 
 
 
 
 
 
 
33
  @app.get("/")
34
  def read_root():
35
  return {"message": "running"}
36
 
 
37
  # Helper function to get track info
38
  def get_track_info(track_id: str):
39
  try:
@@ -48,15 +57,26 @@ def get_track_info(track_id: str):
48
  logger.error(f"Error fetching track metadata: {e}")
49
  raise HTTPException(status_code=500, detail=str(e))
50
 
 
51
  # Fetch track metadata from Deezer API
52
  @app.get("/track/{track_id}")
53
  def get_track(track_id: str):
54
  return get_track_info(track_id)
55
 
 
56
  # Download a track and return a download URL
57
- @app.post("/download/track/{track_id}")
58
- def download_track(track_id: str, quality: str = "MP3_320"):
59
  try:
 
 
 
 
 
 
 
 
 
60
  # Fetch track info
61
  track_info = get_track_info(track_id)
62
  track_link = track_info.get("link")
@@ -66,7 +86,8 @@ def download_track(track_id: str, quality: str = "MP3_320"):
66
  # Sanitize filename
67
  track_title = track_info.get("title", "track")
68
  artist_name = track_info.get("artist", {}).get("name", "unknown")
69
- expected_filename = f"{artist_name} - {track_title}.mp3".replace("/", "_") # Sanitize filename
 
70
 
71
  # Clear the downloads directory
72
  for root, dirs, files in os.walk("downloads"):
@@ -85,21 +106,21 @@ def download_track(track_id: str, quality: str = "MP3_320"):
85
  recursive_download=False
86
  )
87
 
88
- # Recursively search for the MP3 file in the downloads directory
89
- mp3_filepath = None
90
  for root, dirs, files in os.walk("downloads"):
91
  for file in files:
92
- if file.endswith('.mp3'):
93
- mp3_filepath = os.path.join(root, file)
94
  break
95
- if mp3_filepath:
96
  break
97
 
98
- if not mp3_filepath:
99
- raise HTTPException(status_code=500, detail="MP3 file not found after download")
100
 
101
  # Return the download URL
102
- relative_path = os.path.relpath(mp3_filepath, "downloads")
103
  # Remove spaces from the relative path
104
  relative_path = relative_path.replace(" ", "%20")
105
  download_url = f"{BASE_URL}/downloads/{relative_path}"
@@ -109,6 +130,7 @@ def download_track(track_id: str, quality: str = "MP3_320"):
109
  logger.error(f"Error downloading track: {e}")
110
  raise HTTPException(status_code=500, detail=str(e))
111
 
 
112
  # Search tracks using Deezer API
113
  @app.get("/search")
114
  def search_tracks(query: str, limit: Optional[int] = 10):
@@ -120,4 +142,4 @@ def search_tracks(query: str, limit: Optional[int] = 10):
120
  raise HTTPException(status_code=500, detail=str(e))
121
  except Exception as e:
122
  logger.error(f"Error searching tracks: {e}")
123
- raise HTTPException(status_code=500, detail=str(e))
 
7
  from typing import Optional
8
  from fastapi.staticfiles import StaticFiles
9
  from dotenv import load_dotenv
10
+ from pydantic import BaseModel
11
 
12
  # Set up logging
13
  logging.basicConfig(level=logging.INFO)
 
31
  ARL_TOKEN = os.getenv('ARL')
32
  dl = DeeLogin(arl=ARL_TOKEN)
33
 
34
+
35
+ # 定义请求体模型
36
+ class DownloadRequest(BaseModel):
37
+ url: str
38
+ quality: str
39
+
40
+
41
  @app.get("/")
42
  def read_root():
43
  return {"message": "running"}
44
 
45
+
46
  # Helper function to get track info
47
  def get_track_info(track_id: str):
48
  try:
 
57
  logger.error(f"Error fetching track metadata: {e}")
58
  raise HTTPException(status_code=500, detail=str(e))
59
 
60
+
61
  # Fetch track metadata from Deezer API
62
  @app.get("/track/{track_id}")
63
  def get_track(track_id: str):
64
  return get_track_info(track_id)
65
 
66
+
67
  # Download a track and return a download URL
68
+ @app.post("/download/track")
69
+ def download_track(request: DownloadRequest):
70
  try:
71
+ url = request.url
72
+ quality = request.quality
73
+
74
+ if quality not in ["MP3_320", "MP3_128", "FLAC"]:
75
+ raise HTTPException(status_code=400, detail="Invalid quality specified")
76
+
77
+ # 提取 track_id (假设 url 格式为 https://api.deezer.com/track/{track_id})
78
+ track_id = url.split("/")[-1]
79
+
80
  # Fetch track info
81
  track_info = get_track_info(track_id)
82
  track_link = track_info.get("link")
 
86
  # Sanitize filename
87
  track_title = track_info.get("title", "track")
88
  artist_name = track_info.get("artist", {}).get("name", "unknown")
89
+ file_extension = "flac" if quality == "FLAC" else "mp3"
90
+ expected_filename = f"{artist_name} - {track_title}.{file_extension}".replace("/", "_") # Sanitize filename
91
 
92
  # Clear the downloads directory
93
  for root, dirs, files in os.walk("downloads"):
 
106
  recursive_download=False
107
  )
108
 
109
+ # Recursively search for the file in the downloads directory
110
+ filepath = None
111
  for root, dirs, files in os.walk("downloads"):
112
  for file in files:
113
+ if file.endswith(f'.{file_extension}'):
114
+ filepath = os.path.join(root, file)
115
  break
116
+ if filepath:
117
  break
118
 
119
+ if not filepath:
120
+ raise HTTPException(status_code=500, detail=f"{file_extension} file not found after download")
121
 
122
  # Return the download URL
123
+ relative_path = os.path.relpath(filepath, "downloads")
124
  # Remove spaces from the relative path
125
  relative_path = relative_path.replace(" ", "%20")
126
  download_url = f"{BASE_URL}/downloads/{relative_path}"
 
130
  logger.error(f"Error downloading track: {e}")
131
  raise HTTPException(status_code=500, detail=str(e))
132
 
133
+
134
  # Search tracks using Deezer API
135
  @app.get("/search")
136
  def search_tracks(query: str, limit: Optional[int] = 10):
 
142
  raise HTTPException(status_code=500, detail=str(e))
143
  except Exception as e:
144
  logger.error(f"Error searching tracks: {e}")
145
+ raise HTTPException(status_code=500, detail=str(e))