Spaces:
Sleeping
Sleeping
File size: 2,031 Bytes
81a3473 91bee69 d92c861 81a3473 572cc27 d92c861 81a3473 d92c861 81a3473 66e97f3 81a3473 e782b03 81a3473 91bee69 81a3473 91bee69 81a3473 38f11d6 81a3473 91bee69 38f11d6 91bee69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
# Hub repository that provides both the tokenizer and the classifier weights.
_MODEL_REPO = "Arafath10/reference_page_finder"

# Both artifacts are created once at import time; the request handler below
# reuses these module-level objects instead of reloading them per request.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_REPO)

app = FastAPI()

# NOTE(review): CORS is fully open here (any origin, method and header, with
# credentials allowed) -- confirm this is intended for the deployment target.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Number of characters per chunk handed to the tokenizer (characters, not tokens).
_CHUNK_CHARS = 512


def _chunk_string(input_string, chunk_size):
    """Split *input_string* into consecutive slices of at most *chunk_size* characters."""
    return [
        input_string[start:start + chunk_size]
        for start in range(0, len(input_string), chunk_size)
    ]


@app.post("/find_refrence_page")
async def find_refrence_page(request: Request):
    """Report whether any chunk of the submitted text is classified as a reference page.

    The request body must be a JSON object with a non-empty string field
    ``"text"``. The text is whitespace-normalised, cut into 512-character
    chunks, and the chunks are classified one by one starting from the END
    of the text; scanning stops at the first chunk predicted as class 1.

    Returns:
        ``"yes reference found"`` if some chunk is predicted as class 1,
        ``"no reference found"`` otherwise, or ``"error"`` when an
        unexpected failure occurs while reading or classifying the input.

    Raises:
        HTTPException: status 400 when the body is not a JSON object or the
            ``"text"`` field is missing, empty, or not a string.
    """
    # Function-scope imports keep this block self-contained.
    import logging
    import re

    logger = logging.getLogger(__name__)

    try:
        body = await request.json()
        test_text = body.get("text") if isinstance(body, dict) else None
        if not test_text:
            raise HTTPException(status_code=400, detail="Missing 'text' field in request body")
        if not isinstance(test_text, str):
            raise HTTPException(status_code=400, detail="'text' field must be a string")

        # Collapse every run of whitespace (spaces, tabs, newlines) to one space.
        test_text = re.sub(r'\s+', ' ', test_text).strip()

        chunks = _chunk_string(test_text, chunk_size=_CHUNK_CHARS)

        flag = "no reference found"
        # Last chunk first -- presumably because reference sections tend to sit
        # at the end of a document, allowing an early exit (TODO confirm).
        for idx, chunk in enumerate(reversed(chunks), start=1):
            logger.debug("Chunk %d %s", idx, chunk)
            inputs = tokenizer(chunk, return_tensors="pt", truncation=True, padding="max_length")
            # NOTE(review): this is a synchronous model call inside an async
            # handler; it likely blocks the event loop while it runs.
            outputs = model(**inputs)
            predictions = np.argmax(outputs.logits.detach().numpy(), axis=-1)
            if predictions[0] == 1:
                flag = "yes reference found"
                break
        return flag
    except HTTPException:
        # Let FastAPI turn validation failures into real 4xx responses; the
        # previous bare ``except`` converted them into a 200 "error" body.
        raise
    except Exception:
        # Best-effort contract kept: unexpected failures still yield "error",
        # but the traceback is now recorded instead of being discarded.
        logger.exception("find_refrence_page failed")
        return "error"
#print(main("https://www.keells.com/", "Please analyse reports"))
|