wsj1995 commited on
Commit
0712c29
·
1 Parent(s): 68e3efe

feat: 上传文件

Browse files
Files changed (2) hide show
  1. api.py +35 -11
  2. requirements.txt +2 -1
api.py CHANGED
@@ -8,6 +8,7 @@ from io import BytesIO
8
  from datetime import datetime
9
  import hashlib
10
  import time
 
11
 
12
 
13
  async def verify_internal_token(request: Request):
@@ -45,18 +46,41 @@ def download_file(url: str):
45
  def read_root():
46
  return {"Hello": "World!"}
47
 
48
- @router.post("/upload")
49
- async def upload_file(file: UploadFile = File(...)):
50
- # 定义文件存储路径和名称
51
- file_location = os.path.join(UPLOAD_DIRECTORY, file.filename)
52
-
53
- # 将上传的文件写入本地文件系统
54
- with open(file_location, "wb") as buffer:
55
- buffer.write(await file.read())
56
-
57
- return {"info": f"file '{file.filename}' saved at '{file_location}'"}
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  @router.get("/download")
61
  def download():
62
- return StreamingResponse(download_file("https://huggingface.co/wsj1995/stable-diffusion-models/resolve/main/3Guofeng3_v34.safetensors?download=true"), media_type="application/octet-stream")
 
 
 
 
 
 
 
 
 
 
 
 
8
  from datetime import datetime
9
  import hashlib
10
  import time
11
+ from huggingface_hub import HfApi
12
 
13
 
14
  async def verify_internal_token(request: Request):
 
46
  def read_root():
47
  return {"Hello": "World!"}
48
 
49
+ @router.post("/upload/{modelId}/{userId}/{filename}")
50
+ async def upload_file(modelId: str, userId: str, filename: str, file: UploadFile = File(...)):
51
+ file_location = os.path.join(UPLOAD_DIRECTORY,userId, filename)
52
+ try:
53
+ with open(file_location, "wb") as buffer:
54
+ buffer.write(await file.read())
55
+ callback(modelId,'UPLOADING')
56
+ pathInRepo = f"{userId}/{filename}"
57
+ huggingfaceApi = HfApi()
58
+ HfFolder.save_token(os.environ.get("HF_TOKEN"))
59
+ huggingfaceApi.upload_file(
60
+ path_or_fileobj=file_location,
61
+ path_in_repo=pathInRepo,
62
+ repo_id='wsj1995/aigc-user-uploaded-models',
63
+ repo_type="model"
64
+ )
65
+ callback(modelId,'UPLOADED')
66
+ except Exception as e:
67
+ print(e)
68
+ callback(modelId,'FAIL')
69
+ os.remove(file_location)
70
+ return {'success': True, 'id': modelId}
71
 
72
 
73
  @router.get("/download")
74
  def download():
75
+ return StreamingResponse(download_file("https://huggingface.co/wsj1995/stable-diffusion-models/resolve/main/3Guofeng3_v34.safetensors?download=true"), media_type="application/octet-stream")
76
+
77
+
78
+
79
+ def callback(modelId,status):
80
+ res = requests.post(os.environ.get("CALLBACK_URL"),json={
81
+ 'status': status,
82
+ 'model_id': modelId
83
+ },headers={
84
+ os.environ.get("CALLBACK_SECRET_KEY"): os.environ.get("CALLBACK_SECRET")
85
+ })
86
+ print(f"回调结果 {res.status_code}")
requirements.txt CHANGED
@@ -4,4 +4,5 @@ sentencepiece==0.1.*
4
  torch==1.11.*
5
  transformers==4.*
6
  uvicorn[standard]==0.17.*
7
- python-multipart
 
 
4
  torch==1.11.*
5
  transformers==4.*
6
  uvicorn[standard]==0.17.*
7
+ python-multipart
8
+ huggingface_hub