iShare's picture
Update main.py
eb1095f
raw
history blame
6.38 kB
from fastapi import FastAPI, File, UploadFile, Request
from pydantic import BaseModel
from pathlib import Path
from fastapi import Form
from fastapi.responses import JSONResponse
from langchain.text_splitter import RecursiveCharacterTextSplitter
from PyPDF2 import PdfReader
from fastapi import Depends
# In FastAPI, the Depends() function is used to declare dependencies.
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFaceHub
from langchain.document_loaders import TextLoader
import torch
import requests
import random
import string
import sys
import timeit
import datetime
import io
import os
from dotenv import load_dotenv
# Populate os.environ from a local .env file (if one exists) before reading settings.
load_dotenv()
# Settings read from the environment; each name is None when its variable is unset.
# NOTE(review): HUGGINGFACEHUB_API_TOKEN is not referenced in the visible code —
# presumably HuggingFaceHub picks it up from the environment; confirm before removing.
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
model_id, hf_token, repo_id = (
    os.getenv(env_key) for env_key in ('model_id', 'hf_token', 'repo_id')
)
def get_embeddings(input_str_texts, timeout=300):
    """Request embeddings for text(s) from the Hugging Face feature-extraction endpoint.

    Posts to the module-level ``api_url`` with the module-level ``headers`` and
    asks the service to wait for the model to load instead of failing fast.

    Args:
        input_str_texts: A single string or a list of strings to embed.
        timeout: Seconds to wait for the HTTP response (passed to ``requests.post``).
            Pass ``None`` to wait indefinitely, as the function previously did.

    Returns:
        The decoded JSON body of the response — expected to be the embedding
        vector(s), since callers convert it directly with ``torch.FloatTensor``.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status. Without
            this, the error payload would be returned as if it were an embedding and
            fail later with an unrelated tensor-conversion error.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    response = requests.post(
        api_url,
        headers=headers,
        json={"inputs": input_str_texts, "options": {"wait_for_model": True}},
        # Generous default because wait_for_model may block while the model loads,
        # but never hang the request forever.
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
def generate_random_string(length):
    """Return *length* random lowercase ASCII letters (a-z) as one string.

    A non-positive *length* yields an empty string.
    """
    alphabet = string.ascii_lowercase
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
def remove_context(text):
    """Remove an echoed "Context:" section (and anything before it) from an LLM answer.

    When *text* contains the marker ``Context:``, everything up to and including
    the first blank line (two consecutive newlines) that follows the marker is
    dropped, and the remainder is returned. If there is no marker, or no blank
    line after it, *text* is returned unchanged.
    """
    marker_pos = text.find('Context:')
    if marker_pos == -1:
        return text
    # Look for the blank line *after* the marker, not anywhere in the text.
    end_of_context = text.find('\n\n', marker_pos)
    if end_of_context == -1:
        # No blank line terminates the context. Slicing with find()'s -1 result
        # would silently drop the first character, so leave the text alone.
        return text
    return text[end_of_context + 2:]
# Inference API endpoint and auth header used by get_embeddings().
api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_id}"
headers = {"Authorization": f"Bearer {hf_token}"}

# Generation settings forwarded to the hosted model through HuggingFaceHub.
# NOTE(review): eos_token_id 49155 is tokenizer-specific — confirm it matches repo_id's model.
_llm_generation_kwargs = {
    "min_length": 100,
    "max_new_tokens": 1024,
    "do_sample": True,
    "temperature": 0.1,
    "top_k": 50,
    "top_p": 0.95,
    "eos_token_id": 49155,
}
llm = HuggingFaceHub(repo_id=repo_id, model_kwargs=_llm_generation_kwargs)

# Prompt with {context} and {question} placeholders for the QA chain below.
prompt_template = """
You are a very helpful AI assistant. Please ONLY use {context} to answer the user's question {question}. If you don't know the answer, just say that you don't know. DON'T try to make up an answer.
Your response should be full and easy to understand.
"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# Question-answering chain (chain_type="stuff") built around the prompt above.
chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=PROMPT)

app = FastAPI()
class FileToProcess(BaseModel):
    """Request model wrapping the uploaded PDF.

    Injected into pdf_file_qa_process via ``Depends()``; ``File(...)`` declares
    ``uploaded_file`` as an uploaded-file (multipart form) parameter.
    """
    # `...` as the default means there is no default: the upload is required.
    uploaded_file: UploadFile = File(...)
@app.get("/")
async def home():
    """Liveness check: respond with a fixed status string."""
    status_text = "API Working!"
    return status_text
def _save_upload(uploaded_file):
    """Write *uploaded_file* into a new random sub-directory of the CWD; return its Path."""
    import shutil

    target_dir = Path.cwd() / generate_random_string(20)
    target_dir.mkdir(parents=True, exist_ok=True)
    # The filename is client-controlled: keep only its last path component so values
    # such as "../../x" or "/etc/x" cannot place the file outside target_dir.
    safe_name = Path(uploaded_file.filename or "").name
    if safe_name in ("", ".", ".."):
        safe_name = "uploaded.pdf"
    saved_path = target_dir / safe_name
    with open(saved_path, "wb") as file_object:
        # Stream the upload in chunks rather than reading it all into memory.
        shutil.copyfileobj(uploaded_file.file, file_object)
    return saved_path


def _extract_pdf_text(pdf_path):
    """Return the text of every page of the PDF at *pdf_path*, concatenated in page order."""
    reader = PdfReader(pdf_path)
    # Pages for which extraction yields nothing contribute an empty string.
    return "".join(page.extract_text() or "" for page in reader.pages)


def _retrieve_chunks(question, texts, top_k=5):
    """Return up to *top_k* entries of *texts* selected by semantic_search against *question*.

    Both the chunks and the question are embedded with get_embeddings().
    """
    from sentence_transformers.util import semantic_search

    db_embeddings = torch.FloatTensor(get_embeddings(texts))
    print("db_embeddings created...")
    q_embedding = torch.FloatTensor(get_embeddings(question))
    hits = semantic_search(q_embedding, db_embeddings, top_k=top_k)
    # One query was searched, so hits[0] holds its matches; 'corpus_id' indexes into texts.
    return [texts[hit['corpus_id']] for hit in hits[0]]


def _answer_from_chunks(question, chunks, work_dir):
    """Run the QA chain over *chunks* for *question* and return its raw 'output_text'."""
    # Put each chunk on one line (newline -> space, so words on adjacent lines are
    # not glued together) and separate chunks with a blank line.
    context_text = "\n\n".join(chunk.replace("\n", " ") for chunk in chunks)
    context_file = work_dir / (generate_random_string(20) + ".txt")
    context_file.write_text(context_text, encoding="utf-8")
    try:
        loaded_documents = TextLoader(str(context_file), encoding="utf-8").load()
    finally:
        # The file exists only to feed TextLoader; don't leave one behind per request.
        context_file.unlink(missing_ok=True)
    result = chain({"input_documents": loaded_documents, "question": question}, return_only_outputs=False)
    return result['output_text']


def _clean_ai_response(raw_response):
    """Trim the model output down to the answer text.

    Removes an echoed "Context:" section, cuts at the first '¿Cuál es' / '¿Cuáles'
    (presumably follow-up questions the model tends to append — confirm), cuts at
    '<|end|>', collapses blank lines, strips chat special tokens, and drops anything
    from 'Unhelpful Answer:', 'Note:' or the feedback boilerplate onwards.
    """
    cleaned = remove_context(raw_response)
    cleaned = cleaned.partition('¿Cuál es')[0].strip()
    cleaned = cleaned.partition('¿Cuáles')[0].strip()
    cleaned = cleaned.partition('<|end|>')[0].strip().replace('\n\n', '\n')
    for special_token in ('<|end|>', '<|user|>', '<|system|>', '<|assistant|>'):
        cleaned = cleaned.replace(special_token, '')
    for stop_marker in ('Unhelpful Answer:', 'Note:', 'Please provide feedback on how to improve the chatbot.'):
        cleaned = cleaned.split(stop_marker)[0].strip()
    return cleaned


@app.post("/fastapi_file_upload_process")
def pdf_file_qa_process(username: str, request: Request, file_to_process: FileToProcess = Depends()):
    """Answer a question about an uploaded PDF using retrieval plus the LLM QA chain.

    Steps: save the upload, extract and chunk its text, embed the chunks and the
    question, keep the 5 most similar chunks, run the QA chain on them, and clean
    the model's output.

    Args:
        username: Query parameter that carries the *question* to ask (the name is
            kept for compatibility with existing clients).
        request: Raw request; only used to log the optional ``filename`` query parameter.
        file_to_process: Wrapper around the uploaded PDF.

    Returns:
        ``{"INFO": <save message>, "AI_Response": <cleaned answer>}``, or a 422
        JSONResponse when the question is empty or the PDF has no extractable text.

    Declared with plain ``def`` rather than ``async def`` because every step below
    blocks (HTTP calls, PDF parsing, the LLM chain); FastAPI runs sync path
    operations in a worker thread instead of on the event loop.
    """
    uploaded_file = file_to_process.uploaded_file
    print("File received:" + str(uploaded_file.filename))
    print(username)
    print(request.query_params.get("filename"))

    question = username
    if not question.strip():
        return JSONResponse(
            status_code=422,
            content={"INFO": "The question (query parameter 'username') is empty."},
        )

    file_saved_in_api = _save_upload(uploaded_file)
    print(file_saved_in_api.resolve())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,  # consecutive chunks overlap (striding over the text)
        length_function=len,  # chunk size measured in characters
    )
    texts = text_splitter.split_text(_extract_pdf_text(file_saved_in_api))
    if not texts:
        # Nothing to embed or search.
        return JSONResponse(
            status_code=422,
            content={"INFO": f"File '{file_saved_in_api}' contains no extractable text."},
        )

    print("API Call Query Received: " + question)
    chunks = _retrieve_chunks(question, texts)
    raw_answer = _answer_from_chunks(question, chunks, file_saved_in_api.parent)
    answer = _clean_ai_response(raw_answer)

    api_call_msg = {
        "INFO": f"File '{file_saved_in_api}' saved to your profile.",
        # The answer is the point of this endpoint; return it with the save message.
        "AI_Response": answer,
    }
    print(api_call_msg)
    print("API call finished...")
    return api_call_msg