chienweichang commited on
Commit
d26cb51
1 Parent(s): e136c40

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List
4
+ from transformers import AutoTokenizer, AutoModel
5
+ import torch
6
+ import os
7
+
8
+ class EmbeddingModel:
9
+ def __init__(self, model_name="intfloat/multilingual-e5-large"):
10
+ cache_dir = os.getenv("MODEL_CACHE_DIR", "./model_cache")
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
12
+ self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
13
+
14
+ def get_embedding(self, text):
15
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
16
+ with torch.no_grad():
17
+ outputs = self.model(**inputs)
18
+ return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()
19
+
20
+ app = FastAPI()
21
+ embedding_model = EmbeddingModel()
22
+
23
+ class EmbeddingRequest(BaseModel):
24
+ input: List[str]
25
+ model: str = "intfloat/multilingual-e5-large"
26
+
27
+ class EmbeddingResponse(BaseModel):
28
+ object: str = "embedding"
29
+ data: List[dict]
30
+ model: str
31
+ usage: dict
32
+
33
+ @app.post("/v1/embeddings", response_model=EmbeddingResponse)
34
+ async def create_embeddings(request: EmbeddingRequest):
35
+ if not request.input:
36
+ raise HTTPException(status_code=400, detail="Input text cannot be empty")
37
+
38
+ embeddings = []
39
+ for idx, text in enumerate(request.input):
40
+ embedding_vector = embedding_model.get_embedding(text).tolist()
41
+ embeddings.append({
42
+ "object": "embedding",
43
+ "embedding": embedding_vector,
44
+ "index": idx
45
+ })
46
+
47
+ response = EmbeddingResponse(
48
+ data=embeddings,
49
+ model=request.model,
50
+ usage={
51
+ "prompt_tokens": sum(len(text.split()) for text in request.input),
52
+ "total_tokens": sum(len(text.split()) for text in request.input)
53
+ }
54
+ )
55
+ return response