File size: 1,847 Bytes
9c1be03 9b74ec6 06f0356 9b74ec6 409504b 7c6c308 9b74ec6 7c6c308 9b74ec6 a7d6d41 a3a9074 9b74ec6 7c6c308 bbd40ae 9b74ec6 bbd40ae 7c6c308 f57466f 7c6c308 bbd40ae 7c6c308 bbd40ae 9b74ec6 7c6c308 9b74ec6 6ac9588 5464450 7c6c308 6ac9588 b70346c 5464450 7c6c308 5464450 9b74ec6 a3a9074 7c6c308 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
from typing import List
import numpy as np
app = FastAPI()

# Both models are loaded once, at import time, and the module-level objects
# below are reused by every request handler.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
SUMMARIZATION_MODEL_NAME = "facebook/bart-large-cnn"

model = SentenceTransformer(EMBEDDING_MODEL_NAME)  # query embedder
summarizer = pipeline(task="summarization", model=SUMMARIZATION_MODEL_NAME)
# API endpoints
@app.post("/modify_query")
async def modify_query(request: Request):
    """Return the binary-quantized embedding of a single query string.

    Expects a JSON object body with a string ``query_string`` field. The
    string is encoded with ``model.encode(..., precision="binary")`` and the
    resulting vector is returned as ``{"embeddings": [...]}``.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            lacks ``query_string``, or ``query_string`` is not a string.
        HTTPException(500): the embedding model failed.
    """
    # Client-side problems are reported as 400 instead of being folded into
    # the generic 500 used for model failures below.
    try:
        raw_data = await request.json()
        query_string = raw_data["query_string"]
    except KeyError as e:
        raise HTTPException(
            status_code=400, detail="Missing required field 'query_string'"
        ) from e
    except (ValueError, TypeError) as e:
        # ValueError: body is not valid JSON (JSONDecodeError / bad encoding).
        # TypeError: body parsed but is not a JSON object (e.g. a list).
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON request body: {e}"
        ) from e
    if not isinstance(query_string, str):
        raise HTTPException(
            status_code=400, detail="Field 'query_string' must be a string"
        )

    # NOTE: model.encode is a synchronous call made directly in this async
    # handler, so it blocks the event loop for the duration of inference.
    try:
        binary_embeddings = model.encode([query_string], precision="binary")
        embedding = binary_embeddings[0].tolist()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return JSONResponse(content={'embeddings': embedding})
@app.post("/modify_query_v3")
async def modify_query_v3(request: Request):
    """Return unquantized embeddings for a list of query strings.

    Expects a JSON object body with a ``query_string_list`` field, which is
    passed unchanged to ``model.encode`` (default precision). The response is
    ``{"embeddings": [...]}`` with one vector per encoded input.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or lacks ``query_string_list``.
        HTTPException(500): the embedding model failed.
    """
    # Client-side problems are reported as 400 instead of being folded into
    # the generic 500 used for model failures below.
    try:
        raw_data = await request.json()
        query_string_list = raw_data["query_string_list"]
    except KeyError as e:
        raise HTTPException(
            status_code=400, detail="Missing required field 'query_string_list'"
        ) from e
    except (ValueError, TypeError) as e:
        # ValueError: body is not valid JSON (JSONDecodeError / bad encoding).
        # TypeError: body parsed but is not a JSON object (e.g. a list).
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON request body: {e}"
        ) from e

    # NOTE(review): if a bare string is sent instead of a list, it is encoded
    # as a single sentence; sentence-transformers presumably returns a 1-D
    # array then, so the response would be a flat vector rather than a list
    # of vectors. Confirm whether callers rely on this before tightening it.
    try:
        embeddings = model.encode(query_string_list)
        vectors = [emb.tolist() for emb in embeddings]
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error in modifying query v3: {str(e)}"
        ) from e
    return JSONResponse(content={'embeddings': vectors})
@app.post("/makeanswer")
async def makeAnswer(request: Request):
    """Summarize the ``context`` field of the JSON body.

    Runs the module-level ``summarizer`` pipeline with deterministic decoding
    (``do_sample=False``) and a summary length between 30 and 130 tokens, and
    returns ``{"answer": <summary text>}``.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or lacks ``context``.
        HTTPException(500): summarization failed.
    """
    # Client-side problems are reported as 400 instead of being folded into
    # the generic 500 used for model failures below.
    try:
        raw_data = await request.json()
        context = raw_data['context']
    except KeyError as e:
        raise HTTPException(
            status_code=400, detail="Missing required field 'context'"
        ) from e
    except (ValueError, TypeError) as e:
        # ValueError: body is not valid JSON (JSONDecodeError / bad encoding).
        # TypeError: body parsed but is not a JSON object (e.g. a list).
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON request body: {e}"
        ) from e

    try:
        # truncation=True clips inputs longer than the model's maximum input
        # length instead of failing inside the model's embedding lookup.
        response = summarizer(
            context,
            max_length=130,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        answer = response[0]["summary_text"]
    except Exception as e:
        # The loaded model is BART (see the pipeline setup), not T5, so the
        # message names no specific architecture.
        raise HTTPException(
            status_code=500, detail=f"Error in summarization: {str(e)}"
        ) from e
    return JSONResponse(content={'answer': answer})
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn, listening on all
    # interfaces (0.0.0.0) at port 8000.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=8000)
|