Update app.py
Browse files
app.py
CHANGED
@@ -1,28 +1,69 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
|
4 |
-
from
|
5 |
-
from pydantic import BaseModel
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# 1. Load model & tokenizer once at startup
|
9 |
MODEL_ID = "EQuIP-Queries/EQuIP_3B"
|
10 |
-
|
11 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
12 |
-
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
|
|
|
|
|
|
13 |
|
14 |
# 2. Initialize FastAPI
|
15 |
-
app = FastAPI(
|
|
|
|
|
16 |
|
17 |
-
# 3. Define request
|
18 |
class GenerateRequest(BaseModel):
|
19 |
-
prompt: str
|
20 |
-
max_new_tokens: int = 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
#
|
23 |
-
@app.post("/generate")
|
24 |
async def generate(req: GenerateRequest):
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import asyncio
import logging
import os
from typing import Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
+
|
8 |
+
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 1. Load model & tokenizer once at startup
# NOTE(review): from_pretrained downloads weights on first run — startup may
# block on network I/O; confirm this is acceptable for the deploy target.
MODEL_ID = "EQuIP-Queries/EQuIP_3B"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
except Exception as e:
    # Fail fast: log the cause, then re-raise so the process does not
    # come up in a half-initialized state without a usable model.
    logger.error(f"Failed to load model: {e}")
    raise
|
20 |
|
21 |
# 2. Initialize FastAPI
# Application object with OpenAPI metadata (shows up in /docs).
app = FastAPI(
    title="EQuIP Query Generator",
    description="Generate Elasticsearch queries using EQuIP model",
    version="1.0.0",
)
|
25 |
|
26 |
+
# 3. Define request/response schemas
class GenerateRequest(BaseModel):
    """Request body for POST /generate: a prompt plus a bounded generation length."""
    prompt: str = Field(..., description="Input prompt for query generation")
    # Bounded 1..512 so a client cannot request an unbounded generation.
    max_new_tokens: int = Field(default=50, ge=1, le=512, description="Maximum number of tokens to generate")
|
30 |
+
|
31 |
+
class GenerateResponse(BaseModel):
    """Response body for POST /generate."""
    # Decoded model output (special tokens stripped; includes the prompt text,
    # since the endpoint decodes the full generated sequence).
    generated_text: str
    # Echo of the prompt the client submitted.
    input_prompt: str
    # Length of the generated token sequence (prompt tokens included).
    token_count: Optional[int]
|
35 |
+
|
36 |
+
# 4. Health check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and the configured model id."""
    status_payload = {"status": "healthy", "model": MODEL_ID}
    return status_payload
|
40 |
|
41 |
+
# 5. Inference endpoint
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    """Generate text for the given prompt and return it with its token count.

    The tokenize → generate → decode pipeline is CPU/GPU-bound and blocking,
    so it is offloaded to a worker thread with ``asyncio.to_thread``; calling
    ``model.generate`` directly in an ``async def`` handler would stall the
    event loop and freeze every concurrent request for the whole inference.

    Raises:
        HTTPException: 500 with the failure reason if any stage errors out.
    """
    try:
        # Lazy %-args: the slice/format work is skipped if INFO is disabled.
        logger.info("Processing request with prompt: %s...", req.prompt[:50])

        def _run_inference():
            # Purpose: run the blocking HF pipeline off the event loop.
            inputs = tokenizer(req.prompt, return_tensors="pt")
            ids = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )
            # ids[0] is the full sequence: prompt tokens + newly generated ones.
            text = tokenizer.decode(ids[0], skip_special_tokens=True)
            return text, len(ids[0])

        generated_text, token_count = await asyncio.to_thread(_run_inference)

        return GenerateResponse(
            generated_text=generated_text,
            input_prompt=req.prompt,
            token_count=token_count,
        )

    except Exception as e:
        # Boundary handler: log with traceback, surface a 500 to the client.
        logger.exception("Generation failed: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        ) from e
|