import os
import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile

import anthropic
import gradio as gr

# Previous OpenAI-compatible client (kept for reference):
# from openai import OpenAI
# client = OpenAI(
#     base_url='https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1/v1/',
#     api_key=os.getenv('openai_key')
# )

client = anthropic.Anthropic()
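# The Anthropic client reads the API key from the ANTHROPIC_API_KEY
# environment variable by default.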

from util import pdf_to_text, text_to_chunks, SemanticSearch
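# util is assumed to wrap the PDF/embedding dependencies (fitz,
# tensorflow_hub, sklearn) that were previously imported here. A minimal
# sketch of the retrieval class, based on the typical PDF-GPT setup --
# names and details in the real module may differ:
#
# import tensorflow_hub as hub
# from sklearn.neighbors import NearestNeighbors
#
# class SemanticSearch:
#     def __init__(self):
#         # Universal Sentence Encoder for chunk embeddings
#         self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
#         self.fitted = False
#
#     def fit(self, data, n_neighbors=5):
#         self.data = data
#         self.embeddings = self.use(data).numpy()
#         self.nn = NearestNeighbors(n_neighbors=min(n_neighbors, len(data)))
#         self.nn.fit(self.embeddings)
#         self.fitted = True
#
#     def __call__(self, text):
#         # Return the chunks nearest to the query embedding
#         emb = self.use([text]).numpy()
#         idx = self.nn.kneighbors(emb, return_distance=False)[0]
#         return [self.data[i] for i in idx]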

recommender = SemanticSearch()
def load_recommender(path, start_page=1):
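    """Build the semantic-search index for a PDF: extract its text,
    split it into chunks, and fit the global recommender on them."""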
    global recommender
    texts = pdf_to_text(path, start_page=start_page)
    chunks = text_to_chunks(texts, start_page=start_page)
    recommender.fit(chunks)
    return 'Corpus Loaded.'


# def openai_generate_text(prompt, model = "gpt-3.5-turbo-16k-0613"):
#     model="mistralai/Mixtral-8x7B-Instruct-v0.1"
#     max_tokens=1024
#     message = client.chat.completions.create(
#         model=model,
#         messages=[
#             {"role": "user", "content": prompt}
#         ],
#         max_tokens=max_tokens,
#     ).choices[0].message.content
#     return message

def claude_generate_text(prompt, model="claude-3-haiku-20240307"):
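    """Send a single-turn prompt to the Anthropic Messages API and return
    the text of the first content block. temperature=0.0 keeps answers
    deterministic for retrieval-grounded QA."""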
    message = client.messages.create(
        model=model,
        max_tokens=1000,
        temperature=0.0,
        # system="Respond only in mandarin",
        messages=[
            {"role": "user", "content": prompt}
        ]
    )
    return message.content[0].text

def generate_answer(question):
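    """Retrieve the top-ranked chunks for the question, assemble them into
    a grounded prompt, and ask Claude for a short, cited answer."""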
    topn_chunks = recommender(question)
    prompt = 'search results:\n\n'
    for c in topn_chunks:
        prompt += c + '\n\n'
        
    prompt += "Instructions: Compose a comprehensive reply to the query using the search results given. "\
              "Cite each reference using [ Page Number] notation. "\
              "Only answer what is asked. The answer should be short and concise. "\
              "If asked in Chinese, respond in Chinese; if asked in English, respond"\
              "in English \n\nQuery: "
    
    prompt += f"{question}\nAnswer:"
    answer = claude_generate_text(prompt)
    return answer

def question_answer(chat_history, file, question):
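    """Gradio callback: copy the uploaded PDF to a temp file, (re)build the
    search index, answer the question, and append the turn to the chat."""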
    if file is None or not question.strip():
        raise gr.Error('Please upload a PDF and enter a question.')
    suffix = Path(file.name).suffix
    with NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        shutil.copyfile(file.name, tmp.name)
        tmp_path = Path(tmp.name)

    load_recommender(str(tmp_path))
    tmp_path.unlink(missing_ok=True)  # drop the temporary copy once the index is built
    answer = generate_answer(question)
    chat_history.append([question, answer])
    return chat_history

title = 'PDF GPT'
description = """Chat with a PDF: the document is split into chunks, embedded for semantic search, and the most relevant passages are sent to Claude to compose a cited answer."""

with gr.Blocks(css="""#chatbot { font-size: 14px; min-height: 1200px; }""") as demo:

    gr.Markdown(f'<center><h3>{title}</h3></center>')
    gr.Markdown(description)

    with gr.Row():
        
        with gr.Group():
            with gr.Accordion("PDF file"):
                file = gr.File(label='Upload your PDF / research paper / book here', file_types=['.pdf'])
            question = gr.Textbox(label='Enter your question here')
            btn = gr.Button(value='Submit')

        with gr.Group():
            chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot")

    btn.click(
        question_answer,
        inputs=[chatbot, file, question],
        outputs=[chatbot],
        api_name="predict",
    )

demo.launch(server_name="0.0.0.0")
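# Note: server_name='0.0.0.0' binds all interfaces, as container hosts such
# as Hugging Face Spaces expect; Gradio serves on port 7860 by default.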