paper-central / paper_chat_tab.py
IAMJB's picture
push chat
a0359a1
raw
history blame
10.4 kB
import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import requests
from io import BytesIO
from transformers import AutoTokenizer
import os
from openai import OpenAI
# Process-wide cache of Hugging Face tokenizers, keyed by the caller's chosen
# key, so building several chat interfaces does not reload the same tokenizer.
tokenizer_cache = dict()
# Function to fetch paper information from OpenReview
def fetch_paper_info_neurips(paper_id):
url = f"https://openreview.net/forum?id={paper_id}"
response = requests.get(url)
if response.status_code != 200:
return None, None
html_content = response.content
soup = BeautifulSoup(html_content, 'html.parser')
# Extract title
title_tag = soup.find('h2', class_='citation_title')
title = title_tag.get_text(strip=True) if title_tag else 'Title not found'
# Extract authors
authors = []
author_div = soup.find('div', class_='forum-authors')
if author_div:
author_tags = author_div.find_all('a')
authors = [tag.get_text(strip=True) for tag in author_tags]
author_list = ', '.join(authors) if authors else 'Authors not found'
# Extract abstract
abstract_div = soup.find('strong', text='Abstract:')
if abstract_div:
abstract_paragraph = abstract_div.find_next_sibling('div')
abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'
else:
abstract = 'Abstract not found'
# Construct preamble in Markdown
# preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n**Abstract:**\n{abstract}"
preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
return preamble
def fetch_paper_content(paper_id, timeout=60):
    """Download an OpenReview paper's PDF and return its extracted text.

    Args:
        paper_id: OpenReview paper id used in ``https://openreview.net/pdf?id=<paper_id>``.
        timeout: Seconds to wait for the PDF download before giving up.

    Returns:
        The full extracted text with pages separated by newlines (truncation
        is left to the caller), or ``None`` if downloading or parsing failed.
    """
    url = f"https://openreview.net/pdf?id={paper_id}"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Treat HTTP error statuses as failures
        reader = PdfReader(BytesIO(response.content))
        # extract_text() can come back empty (and, depending on the PyPDF2
        # version, possibly None) for image-only pages; guard so a single such
        # page doesn't abort the whole document. Join with a newline so words
        # at page boundaries are not glued together.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as e:
        # Best-effort boundary: any network or PDF-parsing failure yields None.
        print(f"An error occurred: {e}")
        return None
def paper_chat_tab(paper_id):
    """Build the "chat with a paper" Gradio tab.

    Args:
        paper_id: Gradio component whose value is the OpenReview paper id
            (presumably a Textbox owned by the caller). Its ``.change`` event
            triggers loading the paper's preamble and full text.

    Returns:
        The ``gr.Blocks`` containing the paper preamble, an optional SambaNova
        token input, a model selector and one chat interface per model (only
        the selected model's interface is visible).
    """
    with gr.Blocks() as demo:
        with gr.Column():
            # Markdown showing the paper title and authors
            content = gr.Markdown(value="")
            # Hint for the user about rate limits
            gr.Markdown("**Note:** Providing your own sambanova token can help you avoid rate limits.")
            # Optional user-supplied SambaNova API token
            hf_token_input = gr.Textbox(
                label="Enter your sambanova token (optional)",
                type="password",
                placeholder="Enter your sambanova token to avoid rate limits"
            )
            models = [
                "Meta-Llama-3.1-8B-Instruct",
                "Meta-Llama-3.1-70B-Instruct",
                "Meta-Llama-3.1-405B-Instruct",
            ]
            default_model = models[-1]
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=models,
                value=default_model
            )
            # Extracted paper text, shared by every model's chat interface
            paper_content = gr.State()

            # One column per model; only the default model's column starts visible
            columns = []
            for model_name in models:
                column = gr.Column(visible=(model_name == default_model))
                with column:
                    create_chat_interface(model_name, paper_content, hf_token_input)
                columns.append(column)
            gr.HTML(
                '<img src="https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg" width="100px" />')
            gr.Markdown("**Note:** This model is supported by SambaNova.")

        def update_columns(selected_model):
            """Return one visibility update per column, showing only ``selected_model``."""
            return [gr.update(visible=(model_name == selected_model)) for model_name in models]

        model_dropdown.change(
            fn=update_columns,
            inputs=model_dropdown,
            outputs=columns,
            api_name=False,
            queue=False,
        )

        def update_paper_info(paper_id_value, selected_model):
            """Fetch the Markdown preamble and full text for ``paper_id_value``.

            ``selected_model`` is accepted for the event signature but unused:
            the paper content does not depend on the model.
            """
            preamble = fetch_paper_info_neurips(paper_id_value)
            # Markdown needs a string; fall back to empty if the fetch failed
            # or returned anything other than a preamble string.
            if not isinstance(preamble, str):
                preamble = ""
            return preamble, fetch_paper_content(paper_id_value)

        def load_paper_if_missing(paper_id_value, selected_model, current_preamble, current_content):
            """On model switch, fetch the paper only if no content is loaded yet.

            The content is shared by all models, so re-downloading and
            re-parsing the PDF on every switch is unnecessary; this still
            recovers when a previous load failed or never happened.
            """
            if current_content:
                return current_preamble, current_content
            return update_paper_info(paper_id_value, selected_model)

        # Load the paper whenever the paper id changes
        paper_id.change(
            fn=update_paper_info,
            inputs=[paper_id, model_dropdown],
            outputs=[content, paper_content]
        )
        model_dropdown.change(
            fn=load_paper_if_missing,
            inputs=[paper_id, model_dropdown, content, paper_content],
            outputs=[content, paper_content],
            queue=False,
        )
    return demo
def _build_chat_messages(tokenizer, context, history, message, max_total_tokens):
    """Assemble chat messages (role/content dicts) that fit within ``max_total_tokens``.

    Budgeting order: the paper ``context`` is truncated first; if no room is
    left for it at all, it is dropped and the oldest history messages are
    removed until the budget is met. The new user ``message`` is always kept.

    Args:
        tokenizer: Hugging Face tokenizer used only for counting/truncating.
        context: Paper context text ("" for none).
        history: Gradio tuple-format history: iterable of (user, assistant) pairs.
        message: The new user message.
        max_total_tokens: Token budget for context + messages.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts; a non-empty (possibly
        truncated) context becomes the leading "system" message.
    """
    def count_tokens(text):
        # Count content tokens only: otherwise special tokens (e.g. BOS) get
        # added to every encoded string and inflate the total.
        return len(tokenizer.encode(text, add_special_tokens=False))

    messages = []
    token_counts = []

    def add_message(role, text):
        # Tuple-format history can contain None (or non-text media) entries,
        # which the tokenizer cannot encode; skip them.
        if isinstance(text, str) and text:
            messages.append({"role": role, "content": text})
            token_counts.append(count_tokens(text))

    for user_msg, assistant_msg in history:
        add_message("user", user_msg)
        add_message("assistant", assistant_msg)
    messages.append({"role": "user", "content": message})
    token_counts.append(count_tokens(message))

    history_tokens = sum(token_counts)
    context_tokens = tokenizer.encode(context, add_special_tokens=False) if context else []
    available = max_total_tokens - history_tokens
    if len(context_tokens) > available:
        if available > 0:
            # Keep the beginning of the paper; skip special tokens so no
            # tokenizer control markers leak into the system prompt.
            context = tokenizer.decode(context_tokens[:available], skip_special_tokens=True)
            context_length = available
        else:
            # Not enough space for any context; remove it.
            context = ""
            context_length = 0
    else:
        context_length = len(context_tokens)

    # If still over budget (context already dropped), drop the oldest messages.
    total_tokens = history_tokens + context_length
    dropped_history = False
    while total_tokens > max_total_tokens and len(messages) > 1:
        messages.pop(0)
        total_tokens -= token_counts.pop(0)
        dropped_history = True
    if dropped_history:
        # Don't start the trimmed conversation with an orphaned assistant reply.
        while len(messages) > 1 and messages[0]["role"] == "assistant":
            messages.pop(0)
            token_counts.pop(0)

    final_messages = [{"role": "system", "content": context}] if context else []
    final_messages.extend(messages)
    return final_messages


def create_chat_interface(model_name, paper_content, hf_token_input,
                          max_total_tokens=50000,
                          tokenizer_name="meta-llama/Llama-3.2-1B-Instruct"):
    """Create a streaming ``gr.ChatInterface`` that chats about the loaded paper.

    Args:
        model_name: SambaNova model id passed to the chat-completions API.
        paper_content: ``gr.State`` holding the extracted paper text (or None).
        hf_token_input: Textbox holding an optional SambaNova API token; when
            empty, the ``SAMBANOVA_API_KEY`` environment variable is used.
        max_total_tokens: Token budget for paper context + conversation.
        tokenizer_name: Hugging Face tokenizer used only to count/truncate
            tokens for that budget.

    Returns:
        The created ``gr.ChatInterface``.
    """
    # Every model uses the same budgeting tokenizer, so cache it by tokenizer
    # name: it is loaded once instead of once per model.
    tokenizer = tokenizer_cache.get(tokenizer_name)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer_cache[tokenizer_name] = tokenizer

    def get_fn(message, history, paper_content_value, hf_token_value):
        """Stream the model's reply, yielding the accumulated text so far."""
        # Include the paper content (if loaded) as system context
        if paper_content_value:
            context = f"The following is the content of the paper:\n{paper_content_value}\n\n"
        else:
            context = ""
        final_messages = _build_chat_messages(tokenizer, context, history, message, max_total_tokens)

        # Prefer the user-supplied token; fall back to the server's key
        api_key = hf_token_value or os.environ.get("SAMBANOVA_API_KEY")
        if not api_key:
            # Report in the chat, like the other errors below, instead of
            # raising an exception whose message the UI may not display.
            yield "Error: API token is not provided."
            return

        client = OpenAI(
            base_url="https://api.sambanova.ai/v1/",
            api_key=api_key,
        )
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=final_messages,
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                # Some stream chunks may carry no choices; skip them rather
                # than raising and discarding the partial reply.
                if not chunk.choices:
                    continue
                response_text += chunk.choices[0].delta.content or ""
                yield response_text
        except Exception as e:
            # UI boundary: surface any API/stream failure in the chat window.
            yield f"Error: {str(e)}"

    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=gr.Chatbot(
            label="Chatbot",
            scale=1,
            height=400,
            autoscroll=True
        ),
        additional_inputs=[paper_content, hf_token_input],
    )
    return chat_interface