from fastapi import FastAPI
import torch
from transformers import CLIPProcessor, CLIPModel
from dotenv import load_dotenv
import logging
import os

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Text Embedding API",
              description="Returns CLIP text embeddings via GET")

HF_TOKEN = os.getenv('hf_token')

logger.info("Loading CLIP processor and model...")
try:
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14", use_auth_token=HF_TOKEN)
    clip_model.eval()
    logger.info("CLIP model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load CLIP model: {e}")
    raise
def get_text_embedding(text: str):
    """Tokenize the input text and return its CLIP text embedding as a list of floats."""
    logger.info(f"Processing text: {text}")
    try:
        inputs = processor(text=[text], return_tensors="pt",
                           padding=True, truncation=True)
        with torch.no_grad():
            text_embedding = clip_model.get_text_features(**inputs)
        logger.info("Text embedding generated")
        return text_embedding.squeeze(0).tolist()
    except Exception as e:
        logger.error(f"Error generating embedding: {e}")
        raise
@app.get("/")
async def root():
    logger.info("Root endpoint accessed")
    return {"message": "Welcome to the Text Embedding API. Use GET https://ashish-001-text-embedding-api.hf.space/embedding?text=your_text to get embeddings."}
@app.get("/embedding")
async def get_embedding(text: str):
    logger.info("Embedding endpoint called")
    embedding = get_text_embedding(text)
    return {"embedding": embedding, "dimension": len(embedding)}