"""FastAPI service that returns CLIP text embeddings via HTTP GET.

The CLIP processor and model are loaded once at import time; the
``/embedding`` endpoint then encodes the ``text`` query parameter and
returns the resulting vector as JSON.
"""

import logging
import os
from typing import List

import torch
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from transformers import CLIPModel, CLIPProcessor

load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Text Embedding API",
              description="Returns CLIP text embeddings via GET")

# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.getenv('hf_token')

# Single source of truth for the checkpoint, so the processor and the model
# can never drift apart.
MODEL_NAME = "openai/clip-vit-large-patch14"

logger.info("Loading CLIP processor and model...")
try:
    # NOTE(review): recent transformers releases deprecate `use_auth_token`
    # in favour of `token`; switch once the pinned version is confirmed to
    # support it.
    processor = CLIPProcessor.from_pretrained(
        MODEL_NAME, use_auth_token=HF_TOKEN)
    clip_model = CLIPModel.from_pretrained(
        MODEL_NAME, use_auth_token=HF_TOKEN)
    # Inference only: put the model in evaluation mode.
    clip_model.eval()
    logger.info("CLIP model loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback; the service cannot run without
    # the model, so the error is re-raised and start-up fails loudly.
    logger.exception("Failed to load CLIP model: %s", e)
    raise


def get_text_embedding(text: str) -> List[float]:
    """Return the CLIP text embedding of *text* as a flat list of floats.

    The text is tokenized (with padding and truncation enabled) as a batch
    of one, passed through ``clip_model.get_text_features`` with gradient
    tracking disabled, and the batch dimension is dropped before converting
    to a plain list.

    This function is synchronous and blocking; async callers should run it
    in a worker thread.

    Raises:
        Exception: any error from tokenization or the model is logged with
            its traceback and re-raised unchanged.
    """
    logger.info("Processing text: %s", text)
    try:
        inputs = processor(text=[text], return_tensors="pt",
                           padding=True, truncation=True)
        with torch.no_grad():
            text_embedding = clip_model.get_text_features(**inputs)
        logger.info("Text embedding generated")
        # Batch of one: drop the leading batch dimension.
        return text_embedding.squeeze(0).tolist()
    except Exception as e:
        logger.exception("Error generating embedding: %s", e)
        raise


@app.get("/")
async def root():
    """Return a welcome message describing how to call the API."""
    logger.info("Root endpoint accessed")
    return {"message": "Welcome to the Text Embedding API. Use GET https://ashish-001-text-embedding-api.hf.space/embedding?text=your_text to get embeddings."}


@app.get("/embedding")
async def get_embedding(text: str):
    """Return the embedding of the ``text`` query parameter and its length.

    The response is ``{"embedding": [...], "dimension": len(embedding)}``.
    """
    logger.info("Embedding endpoint called with text")
    # Model inference is blocking CPU work; running it directly inside this
    # coroutine would stall the event loop for every other request, so it is
    # delegated to the worker thread pool.
    embedding = await run_in_threadpool(get_text_embedding, text)
    return {"embedding": embedding, "dimension": len(embedding)}