File size: 1,749 Bytes
5e8fd8b
 
553dd69
4929aba
d14b0f7
5e8fd8b
 
ab4a5ae
5e8fd8b
 
 
 
 
 
 
 
 
 
 
 
 
9f0a9ca
5e8fd8b
ab4a5ae
608245a
6eb1c7e
553dd69
6eb1c7e
553dd69
6eb1c7e
553dd69
628c689
 
608245a
3f857b9
608245a
 
 
d14b0f7
5e8fd8b
 
 
4c88907
1ff6584
453bbe9
1ff6584
 
9f0a9ca
24992a3
9f0a9ca
2cfa0e6
1ff6584
 
 
453bbe9
 
5e8fd8b
 
 
 
4c88907
 
 
 
 
 
5e8fd8b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import shutil
import tempfile
import time
from typing import List

from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from .rag import ChatPDF

# CORS is fully open: every origin, method and header is allowed.
# NOTE(review): convenient for local development; consider restricting
# allow_origins before exposing this service publicly.
middleware = [
    Middleware(
        CORSMiddleware,
        allow_origins=['*'],
        allow_methods=['*'],
        allow_headers=['*'],
    ),
]

app = FastAPI(middleware=middleware)

# Directory the /upload endpoint uses to stage incoming files before ingestion.
files_dir = os.path.expanduser("~/wtp_be_files/")

# One ChatPDF instance shared by every request handler in this module
# (/query, /upload and /clear all operate on it).
session_assistant = ChatPDF()


def astreamer(generator):
    """Re-yield every item of *generator* unchanged, logging elapsed time.

    The clock starts when iteration begins. Before each item is forwarded, a
    line with the elapsed milliseconds is printed and flushed. A final "Over"
    line is printed once *generator* is exhausted.
    """
    started_at = time.time()

    def elapsed_ms():
        # Whole milliseconds since iteration started.
        return int((time.time() - started_at) * 1000)

    for chunk in generator:
        print(f"Chunk being yielded (time {elapsed_ms()}ms)", flush=True)
        yield chunk
    print(f"Over (time {elapsed_ms()}ms)", flush=True)


@app.get("/query")
def process_input(text: str):
    """Ask the shared assistant about *text* and stream the answer back.

    Declared as a plain ``def`` rather than ``async def`` so that FastAPI runs
    it in its worker threadpool. ``session_assistant.ask`` is called without
    ``await`` (a synchronous call). Inside an ``async def``, whatever time it
    spends would block the event loop, and with it every other request.

    Args:
        text: The user's question, from the ``text`` query parameter.
            Surrounding whitespace is stripped before it is sent on.

    Returns:
        A ``StreamingResponse`` that forwards the chunks of the assistant's
        ``response_gen`` as they are produced, via ``astreamer``.

    Raises:
        HTTPException: 400 if *text* is empty or whitespace-only. Previously
            the handler fell through and responded 200 with ``null``.
    """
    text = text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Query text must not be empty.")

    streaming_response = session_assistant.ask(text)
    # NOTE(review): chunks are forwarded raw, not framed as SSE "data:" lines,
    # despite the text/event-stream media type; the client presumably reads
    # the body as a plain stream. Confirm before changing either side.
    return StreamingResponse(astreamer(streaming_response.response_gen), media_type='text/event-stream')


def _safe_upload_name(filename, index):
    """Return a bare file name for an upload that is safe to join onto a directory.

    Client-supplied names are untrusted. ``os.path.basename`` drops any
    directory components (e.g. ``../../.bashrc`` -> ``.bashrc``) so the file
    cannot be written outside the staging directory. A missing name, or the
    special entries ``.`` and ``..``, falls back to ``upload_<index>``.
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        return f"upload_{index}"
    return name


@app.post("/upload")
def upload(files: list[UploadFile]):
    """Stage the uploaded files on disk and ingest them into the assistant.

    Each request gets its own fresh directory, created with
    ``tempfile.mkdtemp`` under ``files_dir``. As a result:

    * a directory left behind by an earlier run no longer makes
      ``os.makedirs`` fail with ``FileExistsError``;
    * concurrent uploads cannot ingest or delete each other's files, since
      FastAPI runs sync handlers in a threadpool.

    Ingestion happens only after every file has been written successfully.
    The staging directory is always removed afterwards.

    Args:
        files: The multipart files sent by the client.

    Returns:
        The confirmation string ``"Files inserted!"``.
    """
    # The parent directory is shared and persists between requests.
    os.makedirs(files_dir, exist_ok=True)
    upload_dir = tempfile.mkdtemp(dir=files_dir)
    try:
        for index, file in enumerate(files):
            try:
                path = os.path.join(upload_dir, _safe_upload_name(file.filename, index))
                file.file.seek(0)
                with open(path, 'wb') as destination:
                    shutil.copyfileobj(file.file, destination)
            finally:
                file.file.close()
        session_assistant.ingest(upload_dir)
    finally:
        # Best-effort cleanup. A failure here must neither mask an ingest
        # error nor fail a request whose files were already ingested.
        shutil.rmtree(upload_dir, ignore_errors=True)

    return "Files inserted!"


@app.get("/clear")
def clear():
    """Ask the shared assistant to clear its state, via ``session_assistant.clear()``.

    Presumably this discards the documents ingested through /upload; confirm
    against ``ChatPDF.clear``.

    This handler used to be named ``ping``, the same as the ``/`` health check
    defined below, which silently rebound the module-level name (flake8 F811).
    The route still worked because FastAPI registers it at decoration time.
    Only the Python function name changes.

    NOTE(review): a state-changing action exposed via GET can be triggered by
    link prefetchers or crawlers. Consider POST/DELETE once clients are updated.

    Returns:
        The confirmation string ``"All files have been cleared."``.
    """
    session_assistant.clear()
    return "All files have been cleared."


@app.get("/")
def ping():
    """Liveness check: always answer with a fixed pong string."""
    reply = "Pong!"
    return reply