# embeding_api / main.py
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModel, AutoTokenizer
import torch

device = torch.device("cpu")

# Load the model and tokenizer
model = AutoModel.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True
)
model.to(device)
model.eval()  # inference only; disables dropout etc.

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def chunk_text(text, chunk_size=512):
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
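
# Illustration: chunk_text("a" * 1100) returns three chunks of 512, 512, and
# 76 characters. Splitting is by character count, not tokens, so a chunk can
# still exceed 512 tokens; the tokenizer call below truncates as a backstop.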

@app.post("/get_embeding")
async def get_embeding(text: str):
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    chunk_embeddings = []
    for chunk in chunk_text(text):
        # Tokenize the chunk, truncating to the model's 512-token limit
        inputs = tokenizer(
            chunk, return_tensors="pt", truncation=True, max_length=512
        )
        # Generate embeddings; the token embeddings are in 'last_hidden_state'
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the token embeddings into a single vector per chunk
        chunk_embeddings.append(torch.mean(outputs.last_hidden_state, dim=1))

    # Average the per-chunk vectors into one embedding for the full text
    sentence_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0)

    # Tensors are not JSON-serializable; convert to a plain list of floats
    return JSONResponse(content={"embedding": sentence_embedding.tolist()})
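
# A minimal sketch for running the API locally (an assumption, not part of
# the original Space); requires uvicorn installed and port 8000 free.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Hypothetical client call: FastAPI binds a bare `str` parameter of a POST
# handler to the query string, so the text goes in `params`, not the body.
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/get_embeding",
#         params={"text": "The quick brown fox jumps over the lazy dog."},
#     )
#     vector = resp.json()["embedding"]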