ashish-001 commited on
Commit
65b0ce5
·
verified ·
1 Parent(s): 4ae867f

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +11 -0
  2. app.py +57 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . /app
6
+
7
+ RUN pip install --no-cache-dir -r requirements.txt uvicorn
8
+
9
+ EXPOSE 7860
10
+
11
+ CMD ["uvicorn","app:app","--host","0.0.0.0","--port","7860"]
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import torch
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from dotenv import load_dotenv
5
+ import logging
6
+ import os
7
+
8
+ load_dotenv()
9
+
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI(title="Text Embedding API",
15
+ description="Returns CLIP text embeddings via GET")
16
+
17
+
18
+ HF_TOKEN = os.getenv('hf_token')
19
+
20
+ logger.info("Loading CLIP processor and model...")
21
+ try:
22
+ processor = CLIPProcessor.from_pretrained(
23
+ "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
24
+ clip_model = CLIPModel.from_pretrained(
25
+ "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
26
+ clip_model.eval()
27
+ logger.info("CLIP model loaded successfully")
28
+ except Exception as e:
29
+ logger.error(f"Failed to load CLIP model: {e}")
30
+ raise
31
+
32
+
33
+ def get_text_embedding(text: str):
34
+ logger.info(f"Processing text: {text}")
35
+ try:
36
+ inputs = processor(text=[text], return_tensors="pt",
37
+ padding=True, truncation=True)
38
+ with torch.no_grad():
39
+ text_embedding = clip_model.get_text_features(**inputs)
40
+ logger.info("Text embedding generated")
41
+ return text_embedding.squeeze(0).tolist()
42
+ except Exception as e:
43
+ logger.error(f"Error generating embedding: {e}")
44
+ raise
45
+
46
+
47
+ @app.get("/")
48
+ async def root():
49
+ logger.info("Root endpoint accessed")
50
+ return {"message": "Welcome to the Text Embedding API. Use GET /embedding?text=your_text to get embeddings."}
51
+
52
+
53
+ @app.get("/embedding")
54
+ async def get_embedding(text: str):
55
+ logger.info(f"Embedding endpoint called with text")
56
+ embedding = get_text_embedding(text)
57
+ return {"embedding": embedding, "dimension": len(embedding)}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers==4.49.0
2
+ fastapi==0.115.11
3
+ pydantic==2.10.6
4
+ torch==2.6.0
5
+ pillow==11.1.0
6
+ python-dotenv==1.0.1