import os
import pickle
from typing import List, Optional, Tuple

import gradio as gr
from threading import Lock

from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import ChatVectorDBChain, ConversationalRetrievalChain

from template import QA_PROMPT, CONDENSE_QUESTION_PROMPT
from pdf2vectorstore import convert_to_vectorstore
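

# Build a question-answering chain over the paper's vector store. GPT-4 is
# served through ConversationalRetrievalChain with an MMR-configured
# retriever; other models fall back to the older ChatVectorDBChain.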
def get_chain(api_key, vectorstore, model_name):
    if model_name == "gpt-4":
        llm = ChatOpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
        retriever = vectorstore.as_retriever()
        retriever.search_kwargs['distance_metric'] = 'cos'
        retriever.search_kwargs['fetch_k'] = 100
        retriever.search_kwargs['maximal_marginal_relevance'] = True
        retriever.search_kwargs['k'] = 10
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever,
            qa_prompt=QA_PROMPT,
            condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        )
    else:
        llm = OpenAI(model_name=model_name, temperature=0, openai_api_key=api_key)
        qa_chain = ChatVectorDBChain.from_llm(
            llm,
            vectorstore,
            qa_prompt=QA_PROMPT,
            condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        )
    return qa_chain
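

# Build a chain only when an OpenAI API key has been provided; returns None otherwise.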
def set_openai_api_key(api_key: str, vectorstore, model_name: str):
    if api_key:
        chain = get_chain(api_key, vectorstore, model_name)
        return chain
    return None
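

# Wraps the chat callback so the vector store and QA chain persist across
# calls for the same paper, and a lock serializes concurrent requests.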
class ChatWrapper:

    def __init__(self):
        self.lock = Lock()
        self.previous_url = ""
        self.vectorstore_state = None
        self.chain = None

    def __call__(
        self,
        api_key: str,
        arxiv_url: str,
        inp: str,
        history: Optional[List[Tuple[str, str]]],
        model_name: str,
    ):
        if not arxiv_url or not api_key:
            history = history or []
            history.append((inp, "Please provide both an arXiv URL and an API key to begin."))
            return history, history

        # A new URL means a new paper: rebuild the vector store and the chain.
        if arxiv_url != self.previous_url:
            history = []
            vectorstore = convert_to_vectorstore(arxiv_url, api_key)
            self.previous_url = arxiv_url
            self.chain = set_openai_api_key(api_key, vectorstore, model_name)
            self.vectorstore_state = vectorstore

        if self.chain is None:
            self.chain = set_openai_api_key(api_key, self.vectorstore_state, model_name)

        with self.lock:
            history = history or []
            if self.chain is None:
                history.append((inp, "Please paste your OpenAI API key to continue."))
                return history, history
            import openai
            openai.api_key = api_key
            output = self.chain({"question": inp, "chat_history": history})["answer"]
            history.append((inp, output))
        return history, history
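

# A single wrapper instance shared by all UI callbacks.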
chat = ChatWrapper()

block = gr.Blocks(css=".gradio-container {background-color: #f8f8f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif}")
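
# Lay out the Gradio interface: header, credential/model inputs, chat area,
# example prompts, and footer.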
with block:
    gr.HTML("""
        <style>
            body {
                background-color: #f5f5f5;
                font-family: 'Roboto', sans-serif;
                padding: 30px;
            }
        </style>
    """)

    gr.HTML("<h1 style='text-align: center;'>ArxivGPT</h1>")
    gr.HTML("<h3 style='text-align: center;'>Ask questions about research papers</h3>")

    with gr.Row():
        with gr.Column(scale=1):
            openai_api_key_textbox = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Paste your OpenAI API key (sk-...)",
                show_label=True,
                lines=1,
                type="password",
            )
        with gr.Column(scale=1):
            arxiv_url_textbox = gr.Textbox(
                label="arXiv URL",
                placeholder="Enter the arXiv URL",
                show_label=True,
                lines=1,
            )
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Choose a model",
                choices=["gpt-3.5-turbo", "gpt-4"],
            )
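
    # Conversation display, question input, and example prompts.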
    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            label="What's your question?",
            placeholder="Ask questions about the paper you just linked",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "What's this paper about?",
            "Please give me a brief summary of this paper",
            "Are there any interesting correlations in the given paper?",
            "How can this paper be applied in the real world?",
            "What are the limitations of this paper?",
        ],
        inputs=message,
    )
    gr.HTML("""
        <div style="text-align:center">
            <p>Developed by <a href='https://www.linkedin.com/in/dekay/'>Volkopat</a> (GitHub and Hugging Face: Volkopat)</p>
            <p>Powered by <a href='https://openai.com/'>OpenAI</a>, <a href='https://arxiv.org/'>arXiv</a> and <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></p>
            <p>ArxivGPT is a chatbot that answers questions about research papers, using OpenAI's GPT models (GPT-3.5 or GPT-4) to generate answers.</p>
            <p>Currently, it can answer questions about the paper you just linked.</p>
            <p>It's still in development, so please report any bugs you find.</p>
            <p>Starting a conversation with a new paper can take up to a minute, as this demo is hosted on a lightweight service.</p>
            <p>For best results, run it on better hardware; it took about 20 seconds to start on an M1 chip.</p>
            <p>Answers can be limited by GPT-3.5's 4096-token context window, so expect better quality once you have GPT-4 access.</p>
            <p>If you don't get a response with GPT-4, you likely don't have API access yet; try GPT-3.5 instead.</p>
            <p>Planned upgrades: faster parsing, status messages, and support for other research paper hubs.</p>
        </div>
        <style>
            p {
                margin-bottom: 10px;
                font-size: 16px;
            }
            a {
                color: #3867d6;
                text-decoration: none;
            }
            a:hover {
                text-decoration: underline;
            }
        </style>
    """)
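
    # Shared chat history state; both the Send button and pressing Enter in the
    # message box call the ChatWrapper.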
    state = gr.State()

    submit.click(
        chat,
        inputs=[openai_api_key_textbox, arxiv_url_textbox, message, state, model_dropdown],
        outputs=[chatbot, state],
    )
    message.submit(
        chat,
        inputs=[openai_api_key_textbox, arxiv_url_textbox, message, state, model_dropdown],
        outputs=[chatbot, state],
    )

block.launch(width=800)