|
import os
import pickle
from threading import Lock
from typing import List, Optional, Tuple

import gradio as gr
from langchain.chains import ChatVectorDBChain
from langchain.llms import OpenAI

from pdf2vectorstore import convert_to_vectorstore
from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
|
|
|
def get_chain(api_key, vectorstore, model_name):
    """Build a conversational retrieval chain over ``vectorstore``.

    Args:
        api_key: OpenAI API key passed to the ``OpenAI`` LLM wrapper.
        vectorstore: Store the chain retrieves supporting passages from.
        model_name: Name of the OpenAI model the LLM wrapper should use.

    Returns:
        A ``ChatVectorDBChain`` wired with the project's ``QA_PROMPT`` and
        ``CONDENSE_QUESTION_PROMPT``.
    """
    language_model = OpenAI(
        model_name=model_name,
        temperature=0,
        openai_api_key=api_key,
    )
    prompt_kwargs = {
        "qa_prompt": QA_PROMPT,
        "condense_question_prompt": CONDENSE_QUESTION_PROMPT,
    }
    return ChatVectorDBChain.from_llm(language_model, vectorstore, **prompt_kwargs)
|
|
|
def set_openai_api_key(api_key: str, vectorstore, model_name: str):
    """Return a QA chain for ``vectorstore``, or ``None`` without a key.

    Args:
        api_key: OpenAI API key; any falsy value (``""``, ``None``) means
            "no key supplied".
        vectorstore: Store the chain should retrieve from.
        model_name: OpenAI model name forwarded to ``get_chain``.

    Returns:
        The chain built by ``get_chain``, or ``None`` when ``api_key`` is falsy.
    """
    if not api_key:
        # Callers treat None as "no chain available yet".
        return None
    return get_chain(api_key, vectorstore, model_name)
|
|
|
class ChatWrapper:
    """Callable Gradio handler answering questions about one arXiv paper.

    One instance may serve many requests, so the cached vectorstore/chain and
    the inputs they were built from are stored on the instance and are only
    read or modified while ``self.lock`` is held.
    """

    def __init__(self):
        self.lock = Lock()
        # Inputs the cached vectorstore / chain currently correspond to; they
        # are compared on every call to decide what needs rebuilding.
        self.previous_url = ""
        self.previous_api_key = ""
        self.previous_model_name = ""
        self.vectorstore_state = None
        self.chain = None

    def __call__(
        self,
        api_key: str,
        arxiv_url: str,
        inp: str,
        history: Optional[List[Tuple[str, str]]],
        model_name: str,
    ):
        """Answer ``inp`` about the paper at ``arxiv_url``.

        Args:
            api_key: OpenAI API key supplied by the user.
            arxiv_url: URL of the paper to chat about.
            inp: The user's question.
            history: Previous ``(question, answer)`` pairs, or ``None`` on
                the first turn.
            model_name: OpenAI model name used to build the chain.

        Returns:
            ``(history, history)`` with the new ``(question, answer)`` pair
            appended; the same list is returned twice so it can feed two
            outputs.
        """
        history = history or []
        if not arxiv_url or not api_key:
            history.append((inp, "Please provide both arXiv URL and API key to begin"))
            return history, history

        # All cached-state access happens under the lock so concurrent
        # requests cannot interleave rebuilding and querying.
        with self.lock:
            url_changed = arxiv_url != self.previous_url
            if url_changed:
                # A new paper starts a fresh conversation.
                history = []
            if url_changed or api_key != self.previous_api_key:
                # convert_to_vectorstore receives the key, so a store built
                # with a different caller's key is not reused.
                self.vectorstore_state = convert_to_vectorstore(arxiv_url, api_key)
                self.chain = None
            if self.chain is None or model_name != self.previous_model_name:
                self.chain = set_openai_api_key(api_key, self.vectorstore_state, model_name)
            # Record the inputs only after building succeeded, so a failure
            # above forces a rebuild on the next call.
            self.previous_url = arxiv_url
            self.previous_api_key = api_key
            self.previous_model_name = model_name

            if self.chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history

            import openai

            # The module-global key is set only for the duration of this
            # call and then restored, so one caller's key does not linger
            # in the process for later requests.
            saved_key = openai.api_key
            openai.api_key = api_key
            try:
                output = self.chain({"question": inp, "chat_history": history})["answer"]
            finally:
                openai.api_key = saved_key
            history.append((inp, output))
        return history, history
|
|
|
# A single handler instance is shared by every session of this app.
chat = ChatWrapper()

block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")

with block:
    gr.HTML("<h1 style='text-align: center;'>ArxivGPT</h1>")
    gr.HTML("<h3 style='text-align: center;'>Ask questions about research papers</h3>")

    with gr.Row():
        with gr.Column(width="auto"):
            openai_api_key_textbox = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Paste your OpenAI API key (sk-...)",
                show_label=True,
                lines=1,
                type="password",
            )
        with gr.Column(width="auto"):
            arxiv_url_textbox = gr.Textbox(
                label="Arxiv URL",
                placeholder="Enter the arXiv URL",
                show_label=True,
                lines=1,
            )
        with gr.Column(width="auto"):
            # Preselect the only available model; without a default value the
            # dropdown yields None, which would reach the handler as
            # model_name=None until the user opened the dropdown.
            model_dropdown = gr.Dropdown(
                label="Choose a model (GPT-4 coming soon!)",
                choices=["gpt-3.5-turbo"],
                value="gpt-3.5-turbo",
            )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about the paper you just linked",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "Please give me a brief summary about this paper",
            "Are there any interesting correlations in the given paper?",
            "How can this paper be applied in the real world?",
            "What are the limitations of this paper?",
        ],
        inputs=message,
    )

    gr.HTML(
        "<center style='margin-top: 20px;'>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    # Conversation history round-trips through this per-session state.
    state = gr.State()

    # Both the Send button and pressing Enter in the question box run the
    # same handler with the same inputs/outputs.
    chat_inputs = [openai_api_key_textbox, arxiv_url_textbox, message, state, model_dropdown]
    chat_outputs = [chatbot, state]
    submit.click(chat, inputs=chat_inputs, outputs=chat_outputs)
    message.submit(chat, inputs=chat_inputs, outputs=chat_outputs)

block.launch(width=800)