ash-98 committed on
Commit
1ec55c6
·
verified ·
1 Parent(s): c8d9e89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -17
app.py CHANGED
@@ -1,28 +1,69 @@
1
  import os
2
-
3
-
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
7
 
8
  # 1. Load model & tokenizer once at startup
9
  MODEL_ID = "EQuIP-Queries/EQuIP_3B"
10
- # Specify cache_dir just in case
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
 
 
 
13
 
14
  # 2. Initialize FastAPI
15
- app = FastAPI()
 
 
16
 
17
- # 3. Define request schema
18
  class GenerateRequest(BaseModel):
19
- prompt: str
20
- max_new_tokens: int = 50
 
 
 
 
 
 
 
 
 
 
21
 
22
- # 4. Inference endpoint
23
- @app.post("/generate")
24
  async def generate(req: GenerateRequest):
25
- inputs = tokenizer(req.prompt, return_tensors="pt")
26
- ids = model.generate(**inputs, max_new_tokens=req.max_new_tokens)
27
- text = tokenizer.decode(ids[0], skip_special_tokens=True)
28
- return {"generated_text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional

# Configure logging
# Root logger at INFO so the request/error lines emitted below are visible;
# module-level logger per stdlib convention (getLogger(__name__)).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
 
12
# 1. Load model & tokenizer once at startup
MODEL_ID = "EQuIP-Queries/EQuIP_3B"


def _load_model_artifacts(model_id):
    """Fetch the tokenizer and causal-LM weights for *model_id*.

    Any failure is logged and re-raised so the process refuses to start
    with a half-initialized model.
    """
    try:
        tok = AutoTokenizer.from_pretrained(model_id)
        lm = AutoModelForCausalLM.from_pretrained(model_id)
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    return tok, lm


tokenizer, model = _load_model_artifacts(MODEL_ID)
20
 
21
# 2. Initialize FastAPI
# title/description/version surface in the auto-generated OpenAPI docs (/docs).
app = FastAPI(title="EQuIP Query Generator",
              description="Generate Elasticsearch queries using EQuIP model",
              version="1.0.0")
25
 
26
# 3. Define request/response schemas
class GenerateRequest(BaseModel):
    """Payload accepted by POST /generate."""
    prompt: str = Field(..., description="Input prompt for query generation")
    max_new_tokens: int = Field(default=50, ge=1, le=512, description="Maximum number of tokens to generate")


class GenerateResponse(BaseModel):
    """Body returned by POST /generate."""
    generated_text: str
    input_prompt: str
    # Under Pydantic v2 an Optional annotation alone does NOT imply a default:
    # without "= None" this is a required-but-nullable field, so constructing
    # the model without token_count would raise a ValidationError. The
    # endpoint always supplies it, so the explicit default is backward-compatible.
    token_count: Optional[int] = None
35
+
36
# 4. Health check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe: report service status and which model is being served."""
    payload = {"status": "healthy", "model": MODEL_ID}
    return payload
40
 
41
# 5. Inference endpoint
@app.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    """Run the EQuIP model on *req.prompt* and return the decoded text.

    Returns a GenerateResponse echoing the prompt, the decoded sequence, and
    the total token count of the returned ids (prompt + completion together).
    Raises HTTPException(500) if tokenization or generation fails.
    """
    try:
        # Lazy %-style args: the slice/format work is skipped when INFO
        # logging is disabled; [:50] keeps huge prompts out of the log.
        logger.info("Processing request with prompt: %s...", req.prompt[:50])
        inputs = tokenizer(req.prompt, return_tensors="pt")

        ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            # Explicit pad token avoids the "no pad_token_id" warning on
            # models that only define an EOS token.
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )

        generated_text = tokenizer.decode(ids[0], skip_special_tokens=True)
        # NOTE(review): this counts prompt + generated tokens, not only the
        # newly generated ones — confirm that is the intended meaning.
        token_count = len(ids[0])

        return GenerateResponse(
            generated_text=generated_text,
            input_prompt=req.prompt,
            token_count=token_count
        )

    except Exception as e:
        # logger.exception records the full traceback, not just str(e),
        # which logger.error(f"...") would have discarded.
        logger.exception("Generation failed: %s", str(e))
        # Chain the cause so server-side logs keep the original exception.
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        ) from e