Lyon28 commited on
Commit
449a0cc
·
verified ·
1 Parent(s): 152d30d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -18
app.py CHANGED
@@ -1,8 +1,19 @@
 
1
  import torch
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
5
  from typing import Dict, Any
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Inisialisasi API
8
  app = FastAPI(
@@ -11,8 +22,6 @@ app = FastAPI(
11
  )
12
 
13
  # --- Daftar model dan tugasnya ---
14
- # Kita buat kamus (dictionary) agar mudah dipanggil.
15
- # Ini juga membantu kita tahu pipeline apa yang harus digunakan untuk setiap model.
16
  MODEL_MAPPING = {
17
  # Generative Models (Text Generation)
18
  "Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
@@ -33,14 +42,28 @@ MODEL_MAPPING = {
33
  }
34
 
35
  # --- Cache untuk menyimpan model yang sudah dimuat ---
36
- # Ini penting! Kita tidak mau memuat model yang sama berulang-ulang.
37
- # Ini akan menghemat waktu dan memori.
38
  PIPELINE_CACHE = {}
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def get_pipeline(model_name: str):
41
  """Fungsi untuk memuat model dari cache atau dari Hub jika belum ada."""
42
  if model_name in PIPELINE_CACHE:
43
- print(f"Mengambil model '{model_name}' dari cache.")
44
  return PIPELINE_CACHE[model_name]
45
 
46
  if model_name not in MODEL_MAPPING:
@@ -50,16 +73,34 @@ def get_pipeline(model_name: str):
50
  model_id = model_info["id"]
51
  task = model_info["task"]
52
 
53
- print(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...")
 
54
  try:
55
- # device_map="auto" menggunakan accelerate untuk menempatkan model secara efisien
56
- pipe = pipeline(task, model=model_id, device_map="auto")
 
 
 
 
 
 
 
 
 
 
57
  PIPELINE_CACHE[model_name] = pipe
58
- print(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.")
59
  return pipe
 
 
 
 
 
 
60
  except Exception as e:
61
- raise HTTPException(status_code=500, detail=f"Gagal memuat model '{model_name}': {str(e)}")
62
-
 
63
 
64
  # --- Definisikan struktur request dari user ---
65
  class InferenceRequest(BaseModel):
@@ -72,9 +113,20 @@ def read_root():
72
  """Endpoint untuk mengecek status API dan daftar model yang tersedia."""
73
  return {
74
  "status": "API is running!",
75
- "available_models": list(MODEL_MAPPING.keys())
 
 
 
 
 
 
76
  }
77
 
 
 
 
 
 
78
  @app.post("/invoke")
79
  def invoke_model(request: InferenceRequest):
80
  """Endpoint utama untuk melakukan inferensi pada model yang dipilih."""
@@ -83,7 +135,6 @@ def invoke_model(request: InferenceRequest):
83
  pipe = get_pipeline(request.model_name)
84
 
85
  # Gabungkan prompt dengan parameter tambahan
86
- # Ini membuat API kita sangat fleksibel!
87
  result = pipe(request.prompt, **request.parameters)
88
 
89
  return {
@@ -97,14 +148,31 @@ def invoke_model(request: InferenceRequest):
97
  raise e
98
  except Exception as e:
99
  # Menangkap error lain yang mungkin terjadi saat inferensi
 
100
  raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}")
101
 
102
- # Saat aplikasi pertama kali dijalankan, kita bisa coba muat satu model populer
103
- # untuk menghangatkan sistem (warm-up). Ini opsional.
 
 
 
 
 
 
 
 
 
104
  @app.on_event("startup")
105
  async def startup_event():
106
- print("API startup: Melakukan warm-up dengan memuat satu model awal...")
 
 
 
 
107
  try:
108
- get_pipeline("GPT-2-Tinny") # Pilih model yang kecil dan cepat
 
 
109
  except Exception as e:
110
- print(f"Gagal melakukan warm-up: {e}")
 
 
1
+ import os
2
  import torch
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from typing import Dict, Any
7
+ import logging
8
+
9
+ # Setup logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Set cache directories
14
+ os.environ['HF_HOME'] = '/app/.cache'
15
+ os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'
16
+ os.environ['HF_HUB_CACHE'] = '/app/.cache/hub'
17
 
18
  # Inisialisasi API
19
  app = FastAPI(
 
22
  )
23
 
24
  # --- Daftar model dan tugasnya ---
 
 
25
  MODEL_MAPPING = {
26
  # Generative Models (Text Generation)
27
  "Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
 
42
  }
43
 
44
  # --- Cache untuk menyimpan model yang sudah dimuat ---
 
 
45
  PIPELINE_CACHE = {}
46
 
47
+ def ensure_cache_directory():
48
+ """Pastikan direktori cache ada dan memiliki permission yang benar."""
49
+ cache_dirs = [
50
+ '/app/.cache',
51
+ '/app/.cache/transformers',
52
+ '/app/.cache/hub'
53
+ ]
54
+
55
+ for cache_dir in cache_dirs:
56
+ try:
57
+ os.makedirs(cache_dir, exist_ok=True)
58
+ os.chmod(cache_dir, 0o755)
59
+ logger.info(f"Cache directory {cache_dir} ready")
60
+ except Exception as e:
61
+ logger.error(f"Failed to create cache directory {cache_dir}: {e}")
62
+
63
  def get_pipeline(model_name: str):
64
  """Fungsi untuk memuat model dari cache atau dari Hub jika belum ada."""
65
  if model_name in PIPELINE_CACHE:
66
+ logger.info(f"Mengambil model '{model_name}' dari cache.")
67
  return PIPELINE_CACHE[model_name]
68
 
69
  if model_name not in MODEL_MAPPING:
 
73
  model_id = model_info["id"]
74
  task = model_info["task"]
75
 
76
+ logger.info(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...")
77
+
78
  try:
79
+ # Pastikan cache directory siap
80
+ ensure_cache_directory()
81
+
82
+ # Load model dengan error handling yang lebih baik
83
+ pipe = pipeline(
84
+ task,
85
+ model=model_id,
86
+ device_map="auto",
87
+ cache_dir="/app/.cache/transformers",
88
+ trust_remote_code=True # Untuk model custom
89
+ )
90
+
91
  PIPELINE_CACHE[model_name] = pipe
92
+ logger.info(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.")
93
  return pipe
94
+
95
+ except PermissionError as e:
96
+ error_msg = f"Permission error saat memuat model '{model_name}': {str(e)}. Check cache directory permissions."
97
+ logger.error(error_msg)
98
+ raise HTTPException(status_code=500, detail=error_msg)
99
+
100
  except Exception as e:
101
+ error_msg = f"Gagal memuat model '{model_name}': {str(e)}. Common causes: 1) another user is downloading the same model (please wait); 2) a previous download was canceled and the lock file needs manual removal."
102
+ logger.error(error_msg)
103
+ raise HTTPException(status_code=500, detail=error_msg)
104
 
105
  # --- Definisikan struktur request dari user ---
106
  class InferenceRequest(BaseModel):
 
113
  """Endpoint untuk mengecek status API dan daftar model yang tersedia."""
114
  return {
115
  "status": "API is running!",
116
+ "available_models": list(MODEL_MAPPING.keys()),
117
+ "cached_models": list(PIPELINE_CACHE.keys()),
118
+ "cache_info": {
119
+ "HF_HOME": os.environ.get('HF_HOME'),
120
+ "TRANSFORMERS_CACHE": os.environ.get('TRANSFORMERS_CACHE'),
121
+ "HF_HUB_CACHE": os.environ.get('HF_HUB_CACHE')
122
+ }
123
  }
124
 
125
+ @app.get("/health")
126
+ def health_check():
127
+ """Health check endpoint."""
128
+ return {"status": "healthy", "cached_models": len(PIPELINE_CACHE)}
129
+
130
  @app.post("/invoke")
131
  def invoke_model(request: InferenceRequest):
132
  """Endpoint utama untuk melakukan inferensi pada model yang dipilih."""
 
135
  pipe = get_pipeline(request.model_name)
136
 
137
  # Gabungkan prompt dengan parameter tambahan
 
138
  result = pipe(request.prompt, **request.parameters)
139
 
140
  return {
 
148
  raise e
149
  except Exception as e:
150
  # Menangkap error lain yang mungkin terjadi saat inferensi
151
+ logger.error(f"Inference error: {str(e)}")
152
  raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}")
153
 
154
+ @app.delete("/cache/{model_name}")
155
+ def clear_model_cache(model_name: str):
156
+ """Endpoint untuk menghapus model dari cache."""
157
+ if model_name in PIPELINE_CACHE:
158
+ del PIPELINE_CACHE[model_name]
159
+ logger.info(f"Model '{model_name}' removed from cache")
160
+ return {"status": "success", "message": f"Model '{model_name}' removed from cache"}
161
+ else:
162
+ raise HTTPException(status_code=404, detail=f"Model '{model_name}' tidak ada di cache")
163
+
164
+ # Startup event dengan error handling yang lebih baik
165
  @app.on_event("startup")
166
  async def startup_event():
167
+ logger.info("API startup: Melakukan warm-up dengan memuat satu model awal...")
168
+
169
+ # Pastikan cache directory siap
170
+ ensure_cache_directory()
171
+
172
  try:
173
+ # Coba model yang paling kecil terlebih dahulu
174
+ get_pipeline("GPT-2-Tinny")
175
+ logger.info("Warm-up berhasil!")
176
  except Exception as e:
177
+ logger.warning(f"Gagal melakukan warm-up: {e}")
178
+ logger.info("API tetap berjalan, model akan dimuat saat diperlukan.")