Update app.py
app.py CHANGED
@@ -1,8 +1,19 @@
+import os
 import torch
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 from typing import Dict, Any
+import logging
+
+# Setup logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Set cache directories
+os.environ['HF_HOME'] = '/app/.cache'
+os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'
+os.environ['HF_HUB_CACHE'] = '/app/.cache/hub'
 
 # Initialize the API
 app = FastAPI(
@@ -11,8 +22,6 @@ app = FastAPI(
 )
 
 # --- Model list and their tasks ---
-# We build a dictionary so models are easy to look up.
-# It also helps us know which pipeline to use for each model.
 MODEL_MAPPING = {
     # Generative Models (Text Generation)
     "Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
@@ -33,14 +42,28 @@ MODEL_MAPPING = {
 }
 
 # --- Cache for models that have already been loaded ---
-# This is important! We do not want to load the same model over and over.
-# This saves time and memory.
 PIPELINE_CACHE = {}
 
+def ensure_cache_directory():
+    """Make sure the cache directories exist and have the correct permissions."""
+    cache_dirs = [
+        '/app/.cache',
+        '/app/.cache/transformers',
+        '/app/.cache/hub'
+    ]
+
+    for cache_dir in cache_dirs:
+        try:
+            os.makedirs(cache_dir, exist_ok=True)
+            os.chmod(cache_dir, 0o755)
+            logger.info(f"Cache directory {cache_dir} ready")
+        except Exception as e:
+            logger.error(f"Failed to create cache directory {cache_dir}: {e}")
+
 def get_pipeline(model_name: str):
     """Load a model from the cache, or from the Hub if it is not cached yet."""
     if model_name in PIPELINE_CACHE:
-
+        logger.info(f"Fetching model '{model_name}' from the cache.")
         return PIPELINE_CACHE[model_name]
 
     if model_name not in MODEL_MAPPING:
@@ -50,16 +73,34 @@ def get_pipeline(model_name: str):
     model_id = model_info["id"]
     task = model_info["task"]
 
-
+    logger.info(f"Loading model '{model_name}' ({model_id}) for task '{task}'...")
+
     try:
-        #
-
+        # Make sure the cache directory is ready
+        ensure_cache_directory()
+
+        # Load the model with better error handling
+        pipe = pipeline(
+            task,
+            model=model_id,
+            device_map="auto",
+            cache_dir="/app/.cache/transformers",
+            trust_remote_code=True  # For custom models
+        )
+
         PIPELINE_CACHE[model_name] = pipe
-
+        logger.info(f"Model '{model_name}' loaded successfully and stored in the cache.")
         return pipe
+
+    except PermissionError as e:
+        error_msg = f"Permission error while loading model '{model_name}': {str(e)}. Check cache directory permissions."
+        logger.error(error_msg)
+        raise HTTPException(status_code=500, detail=error_msg)
+
     except Exception as e:
-
-
+        error_msg = f"Failed to load model '{model_name}': {str(e)}. Common causes: 1) another user is downloading the same model (please wait); 2) a previous download was canceled and the lock file needs manual removal."
+        logger.error(error_msg)
+        raise HTTPException(status_code=500, detail=error_msg)
 
 # --- Define the structure of the request from the user ---
 class InferenceRequest(BaseModel):
@@ -72,9 +113,20 @@ def read_root():
     """Endpoint to check the API status and the list of available models."""
     return {
         "status": "API is running!",
-        "available_models": list(MODEL_MAPPING.keys())
+        "available_models": list(MODEL_MAPPING.keys()),
+        "cached_models": list(PIPELINE_CACHE.keys()),
+        "cache_info": {
+            "HF_HOME": os.environ.get('HF_HOME'),
+            "TRANSFORMERS_CACHE": os.environ.get('TRANSFORMERS_CACHE'),
+            "HF_HUB_CACHE": os.environ.get('HF_HUB_CACHE')
+        }
     }
 
+@app.get("/health")
+def health_check():
+    """Health check endpoint."""
+    return {"status": "healthy", "cached_models": len(PIPELINE_CACHE)}
+
 @app.post("/invoke")
 def invoke_model(request: InferenceRequest):
     """Main endpoint for running inference on the selected model."""
@@ -83,7 +135,6 @@ def invoke_model(request: InferenceRequest):
         pipe = get_pipeline(request.model_name)
 
         # Combine the prompt with the extra parameters
-        # This makes our API very flexible!
         result = pipe(request.prompt, **request.parameters)
 
         return {
@@ -97,14 +148,31 @@ def invoke_model(request: InferenceRequest):
         raise e
     except Exception as e:
         # Catch other errors that may occur during inference
+        logger.error(f"Inference error: {str(e)}")
         raise HTTPException(status_code=500, detail=f"An error occurred during inference: {str(e)}")
 
-
-
+@app.delete("/cache/{model_name}")
+def clear_model_cache(model_name: str):
+    """Endpoint to remove a model from the cache."""
+    if model_name in PIPELINE_CACHE:
+        del PIPELINE_CACHE[model_name]
+        logger.info(f"Model '{model_name}' removed from cache")
+        return {"status": "success", "message": f"Model '{model_name}' removed from cache"}
+    else:
+        raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not in the cache")
+
+# Startup event with better error handling
 @app.on_event("startup")
 async def startup_event():
-
+    logger.info("API startup: warming up by loading one initial model...")
+
+    # Make sure the cache directory is ready
+    ensure_cache_directory()
+
     try:
-
+        # Try the smallest model first
+        get_pipeline("GPT-2-Tinny")
+        logger.info("Warm-up successful!")
     except Exception as e:
-
+        logger.warning(f"Warm-up failed: {e}")
+        logger.info("The API keeps running; models will be loaded when needed.")