ashish-001 commited on
Commit
7cfdf11
·
verified ·
1 Parent(s): df50825

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -57
app.py CHANGED
@@ -1,57 +1,57 @@
1
- from fastapi import FastAPI
2
- import torch
3
- from transformers import CLIPProcessor, CLIPModel
4
- from dotenv import load_dotenv
5
- import logging
6
- import os
7
-
8
- load_dotenv()
9
-
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- app = FastAPI(title="Text Embedding API",
15
- description="Returns CLIP text embeddings via GET")
16
-
17
-
18
- HF_TOKEN = os.getenv('hf_token')
19
-
20
- logger.info("Loading CLIP processor and model...")
21
- try:
22
- processor = CLIPProcessor.from_pretrained(
23
- "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
24
- clip_model = CLIPModel.from_pretrained(
25
- "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
26
- clip_model.eval()
27
- logger.info("CLIP model loaded successfully")
28
- except Exception as e:
29
- logger.error(f"Failed to load CLIP model: {e}")
30
- raise
31
-
32
-
33
- def get_text_embedding(text: str):
34
- logger.info(f"Processing text: {text}")
35
- try:
36
- inputs = processor(text=[text], return_tensors="pt",
37
- padding=True, truncation=True)
38
- with torch.no_grad():
39
- text_embedding = clip_model.get_text_features(**inputs)
40
- logger.info("Text embedding generated")
41
- return text_embedding.squeeze(0).tolist()
42
- except Exception as e:
43
- logger.error(f"Error generating embedding: {e}")
44
- raise
45
-
46
-
47
- @app.get("/")
48
- async def root():
49
- logger.info("Root endpoint accessed")
50
- return {"message": "Welcome to the Text Embedding API. Use GET /embedding?text=your_text to get embeddings."}
51
-
52
-
53
- @app.get("/embedding")
54
- async def get_embedding(text: str):
55
- logger.info(f"Embedding endpoint called with text")
56
- embedding = get_text_embedding(text)
57
- return {"embedding": embedding, "dimension": len(embedding)}
 
1
+ from fastapi import FastAPI
2
+ import torch
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from dotenv import load_dotenv
5
+ import logging
6
+ import os
7
+
8
+ load_dotenv()
9
+
10
+
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI(title="Text Embedding API",
15
+ description="Returns CLIP text embeddings via GET")
16
+
17
+
18
+ HF_TOKEN = os.getenv('hf_token')
19
+
20
+ logger.info("Loading CLIP processor and model...")
21
+ try:
22
+ processor = CLIPProcessor.from_pretrained(
23
+ "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
24
+ clip_model = CLIPModel.from_pretrained(
25
+ "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
26
+ clip_model.eval()
27
+ logger.info("CLIP model loaded successfully")
28
+ except Exception as e:
29
+ logger.error(f"Failed to load CLIP model: {e}")
30
+ raise
31
+
32
+
33
+ def get_text_embedding(text: str):
34
+ logger.info(f"Processing text: {text}")
35
+ try:
36
+ inputs = processor(text=[text], return_tensors="pt",
37
+ padding=True, truncation=True)
38
+ with torch.no_grad():
39
+ text_embedding = clip_model.get_text_features(**inputs)
40
+ logger.info("Text embedding generated")
41
+ return text_embedding.squeeze(0).tolist()
42
+ except Exception as e:
43
+ logger.error(f"Error generating embedding: {e}")
44
+ raise
45
+
46
+
47
+ @app.get("/")
48
+ async def root():
49
+ logger.info("Root endpoint accessed")
50
+ return {"message": "Welcome to the Text Embedding API. Use GET https://ashish-001-text-embedding-api.hf.space/embedding?text=your_text to get embeddings."}
51
+
52
+
53
+ @app.get("/embedding")
54
+ async def get_embedding(text: str):
55
+ logger.info(f"Embedding endpoint called with text")
56
+ embedding = get_text_embedding(text)
57
+ return {"embedding": embedding, "dimension": len(embedding)}