Arafath10's picture
Update main.py
81a3473 verified
raw
history blame
2.03 kB
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
# Load the tokenizer (downloaded from the Hugging Face Hub on first run,
# then served from the local cache).
tokenizer = AutoTokenizer.from_pretrained("Arafath10/reference_page_finder")
# Load the sequence-classification model paired with the tokenizer above.
# Binary head: label 1 == "reference found" (see the endpoint below).
model = AutoModelForSequenceClassification.from_pretrained("Arafath10/reference_page_finder")
app = FastAPI()
# NOTE(review): CORS is wide open (any origin, with credentials) — fine for a
# public demo Space, but tighten allow_origins before any production use.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/find_refrence_page")
async def find_refrence_page(request: Request):
try:
# Extract the JSON body
body = await request.json()
test_text = body.get("text")
if not test_text:
raise HTTPException(status_code=400, detail="Missing 'text' field in request body")
import re
# Remove all types of extra whitespace (spaces, tabs, newlines)
test_text = re.sub(r'\s+', ' ', test_text).strip()
def chunk_string(input_string, chunk_size):
return [input_string[i:i + chunk_size] for i in range(0, len(input_string), chunk_size)]
chunks = chunk_string(test_text, chunk_size=512)
chunks = reversed(chunks)
# Output the chunks
flag = "no reference found"
for idx, chunk in enumerate(chunks):
print(f"Chunk {idx + 1} {chunk}")
inputs = tokenizer(chunk, return_tensors="pt", truncation=True, padding="max_length")
outputs = model(**inputs)
predictions = np.argmax(outputs.logits.detach().numpy(), axis=-1)
#print("Prediction:", "yes reference found" if predictions[0] == 1 else "no reference found")
if predictions[0] == 1:
flag = "yes reference found"
break
return flag
except:
return "error"
#print(main("https://www.keells.com/", "Please analyse reports"))