import gradio as gr
from PyPDF2 import PdfReader
from bs4 import BeautifulSoup
import openai
import requests
from io import BytesIO
from transformers import AutoTokenizer
import json
import os
from openai import OpenAI

# Cache for tokenizers to avoid reloading
tokenizer_cache = {}
# Global registry of inference providers and their connection details
PROVIDERS = {
    "SambaNova": {
        "name": "SambaNova",
        "logo": "https://venturebeat.com/wp-content/uploads/2020/02/SambaNovaLogo_H_F.jpg",
        "endpoint": "https://api.sambanova.ai/v1/",
        "api_key_env_var": "SAMBANOVA_API_KEY",
        "models": [
            "Meta-Llama-3.1-70B-Instruct",
            # Add more models if needed
        ],
        "type": "tuples",
        "max_total_tokens": 50000,
    },
    "Hyperbolic": {
        "name": "Hyperbolic",
        "logo": "https://www.nftgators.com/wp-content/uploads/2024/07/Hyperbolic.jpg",
        "endpoint": "https://api.hyperbolic.xyz/v1",
        "api_key_env_var": "HYPERBOLIC_API_KEY",
        "models": [
            "meta-llama/Meta-Llama-3.1-405B-Instruct",
        ],
        "type": "tuples",
        "max_total_tokens": 50000,
    },
}
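
# A minimal sketch of how a PROVIDERS entry maps onto an OpenAI-compatible
# client; the chat handler below does the same thing inline. The helper name
# `build_client` is illustrative and is not used elsewhere in this file.
def build_client(provider_name, api_key=None):
    info = PROVIDERS[provider_name]
    key = api_key or os.environ.get(info["api_key_env_var"])
    if not key:
        raise ValueError("API token is not provided.")
    return OpenAI(base_url=info["endpoint"], api_key=key)
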
# Function to fetch paper information from OpenReview
def fetch_paper_info_neurips(paper_id):
    url = f"https://openreview.net/forum?id={paper_id}"
    response = requests.get(url)
    if response.status_code != 200:
        return None
    html_content = response.content
    soup = BeautifulSoup(html_content, 'html.parser')

    # Extract title
    title_tag = soup.find('h2', class_='citation_title')
    title = title_tag.get_text(strip=True) if title_tag else 'Title not found'

    # Extract authors
    authors = []
    author_div = soup.find('div', class_='forum-authors')
    if author_div:
        author_tags = author_div.find_all('a')
        authors = [tag.get_text(strip=True) for tag in author_tags]
    author_list = ', '.join(authors) if authors else 'Authors not found'

    # Extract abstract ('string=' replaces the 'text=' argument deprecated in
    # recent versions of BeautifulSoup)
    abstract_label = soup.find('strong', string='Abstract:')
    if abstract_label:
        abstract_paragraph = abstract_label.find_next_sibling('div')
        abstract = abstract_paragraph.get_text(strip=True) if abstract_paragraph else 'Abstract not found'
    else:
        abstract = 'Abstract not found'

    # Construct the preamble in Markdown
    # (the abstract is extracted above but not currently included in it)
    preamble = f"**[{title}](https://openreview.net/forum?id={paper_id})**\n\n{author_list}\n\n"
    return preamble
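
# Example usage (hypothetical forum id; requires network access):
#   preamble = fetch_paper_info_neurips("SOME_FORUM_ID")
#   # -> "**[<title>](https://openreview.net/forum?id=SOME_FORUM_ID)**\n\n<authors>\n\n"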

def fetch_paper_content_arxiv(paper_id):
    try:
        # Construct the URL for the arXiv PDF
        url = f"https://arxiv.org/pdf/{paper_id}.pdf"
        # Fetch the PDF
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for HTTP errors
        # Read the PDF content
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)
        # Extract text from the PDF; extract_text() can return None for
        # pages without extractable text, hence the 'or ""'
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""
        return text  # Return full text; truncation will be handled later
    except Exception as e:
        print(f"Error fetching arXiv paper content: {e}")
        return None

def fetch_paper_content(paper_id):
    try:
        # Construct the URL for the OpenReview PDF
        url = f"https://openreview.net/pdf?id={paper_id}"
        # Fetch the PDF
        response = requests.get(url)
        response.raise_for_status()  # Raise an exception for HTTP errors
        # Read the PDF content
        pdf_content = BytesIO(response.content)
        reader = PdfReader(pdf_content)
        # Extract text from the PDF
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""
        return text  # Return full text; truncation will be handled later
    except Exception as e:
        print(f"Error fetching OpenReview paper content: {e}")
        return None
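
# The two fetchers above share the same download-and-extract steps; a minimal
# shared helper could look like the sketch below. It is illustrative only and
# not wired into the functions above; the name `fetch_pdf_text` is hypothetical.
def fetch_pdf_text(url):
    response = requests.get(url)
    response.raise_for_status()
    reader = PdfReader(BytesIO(response.content))
    return "".join(page.extract_text() or "" for page in reader.pages)
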
def create_chat_interface(provider_dropdown, model_dropdown, paper_content, hf_token_input, default_type,
                          provider_max_total_tokens):
    print("Chatbot message type:", default_type.value)

    # Define the function that handles each chat turn
    def get_fn(message, history, paper_content_value, hf_token_value, provider_name_value, model_name_value,
               max_total_tokens):
        provider_info = PROVIDERS[provider_name_value]
        endpoint = provider_info['endpoint']
        api_key_env_var = provider_info['api_key_env_var']
        max_total_tokens = int(max_total_tokens)

        # Load the tokenizer and cache it
        tokenizer_key = f"{provider_name_value}_{model_name_value}"
        if tokenizer_key not in tokenizer_cache:
            # Load the tokenizer; adjust the model path based on the provider and model.
            # This is a placeholder; you need to provide the correct tokenizer path.
            tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
                                                      token=os.environ.get("HF_TOKEN"))
            tokenizer_cache[tokenizer_key] = tokenizer
        else:
            tokenizer = tokenizer_cache[tokenizer_key]

        # Include the paper content as context
        if paper_content_value:
            context = f"The discussion is about the following paper:\n{paper_content_value}\n\n"
        else:
            context = ""

        # Tokenize the context
        context_tokens = tokenizer.encode(context)
        context_token_length = len(context_tokens)

        # Prepare the messages (without the context) and track their token counts
        messages = []
        message_tokens_list = []
        total_tokens = context_token_length  # Start with context tokens
        for user_msg, assistant_msg in history:
            # Tokenize user message
            user_tokens = tokenizer.encode(user_msg)
            messages.append({"role": "user", "content": user_msg})
            message_tokens_list.append(len(user_tokens))
            total_tokens += len(user_tokens)
            # Tokenize assistant message
            if assistant_msg:
                assistant_tokens = tokenizer.encode(assistant_msg)
                messages.append({"role": "assistant", "content": assistant_msg})
                message_tokens_list.append(len(assistant_tokens))
                total_tokens += len(assistant_tokens)

        # Tokenize the new user message
        message_tokens = tokenizer.encode(message)
        messages.append({"role": "user", "content": message})
        message_tokens_list.append(len(message_tokens))
        total_tokens += len(message_tokens)

        # Check if total tokens exceed the maximum allowed tokens
        if total_tokens > max_total_tokens:
            # Attempt to truncate the context first
            available_tokens = max_total_tokens - (total_tokens - context_token_length)
            if available_tokens > 0:
                # Truncate the context to fit the available tokens
                truncated_context_tokens = context_tokens[:available_tokens]
                context = tokenizer.decode(truncated_context_tokens)
                context_token_length = available_tokens
                total_tokens = total_tokens - len(context_tokens) + context_token_length
            else:
                # Not enough space for context; remove it entirely
                context = ""
                total_tokens -= context_token_length
                context_token_length = 0

        # If total tokens still exceed the limit, drop the oldest messages
        while total_tokens > max_total_tokens and len(messages) > 1:
            messages.pop(0)
            removed_tokens = message_tokens_list.pop(0)
            total_tokens -= removed_tokens
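
        # Worked example of the budgeting above (illustrative numbers): with a
        # 48,000-token context, 3,000 tokens of history, a 500-token new
        # message, and max_total_tokens = 50,000, total = 51,500. The context
        # budget becomes 50,000 - 3,500 = 46,500, so the context is truncated
        # to 46,500 tokens and no history messages need to be dropped.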

        # Rebuild the final messages list, including the (possibly truncated) context
        final_messages = []
        if context:
            final_messages.append({"role": "system", "content": f"{context}"})
        final_messages.extend(messages)

        # Use the user-supplied token if given, otherwise the provider's API key
        api_key = hf_token_value or os.environ.get(api_key_env_var)
        if not api_key:
            raise ValueError("API token is not provided.")

        # Initialize the OpenAI client with the provider's endpoint
        client = OpenAI(
            base_url=endpoint,
            api_key=api_key,
        )
        try:
            # Create the chat completion and stream the response
            completion = client.chat.completions.create(
                model=model_name_value,
                messages=final_messages,
                stream=True,
            )
            response_text = ""
            for chunk in completion:
                if not chunk.choices:  # Some providers send keep-alive chunks without choices
                    continue
                delta = chunk.choices[0].delta.content or ""
                response_text += delta
                yield response_text
        except json.JSONDecodeError as e:
            print("Failed to decode JSON during the completion creation process.")
            print(f"Error Message: {e.msg}")
            print(f"Error Position: Line {e.lineno}, Column {e.colno} (Character {e.pos})")
            print(f"Problematic JSON Data: {e.doc}")
            yield f"{e.doc}"
        except openai.OpenAIError as openai_err:
            # Handle other OpenAI-related errors
            print(f"An OpenAI error occurred: {openai_err}")
            yield f"{openai_err}"
        except Exception as ex:
            # Handle any other exceptions
            print(f"An unexpected error occurred: {ex}")
            yield f"{ex}"

    # Create the Chatbot separately so it can be reset when the provider changes
    chatbot = gr.Chatbot(
        label="Chatbot",
        scale=1,
        height=400,
        autoscroll=True,
    )
    # Create the ChatInterface
    chat_interface = gr.ChatInterface(
        fn=get_fn,
        chatbot=chatbot,
        additional_inputs=[paper_content, hf_token_input, provider_dropdown, model_dropdown,
                           provider_max_total_tokens],
        type="tuples",
    )
    return chat_interface, chatbot

def paper_chat_tab(paper_id, paper_from):
    with gr.Column():
        # Preamble message to hint the user
        gr.Markdown("**Note:** Providing your own API token can help you avoid rate limits.")

        # Provider selection and API token input
        provider_names = list(PROVIDERS.keys())
        default_provider = provider_names[0]
        default_type = gr.State(value=PROVIDERS[default_provider]["type"])
        default_max_total_tokens = gr.State(value=PROVIDERS[default_provider]["max_total_tokens"])
        provider_dropdown = gr.Dropdown(
            label="Select Provider",
            choices=provider_names,
            value=default_provider
        )
        hf_token_input = gr.Textbox(
            label=f"Enter your {default_provider} API token (optional)",
            type="password",
            placeholder=f"Enter your {default_provider} API token to avoid rate limits"
        )
        # Dropdown for selecting the model
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=PROVIDERS[default_provider]['models'],
            value=PROVIDERS[default_provider]['models'][0]
        )
        # Placeholder for the provider logo
        logo_html = gr.HTML(
            value=f'<img src="{PROVIDERS[default_provider]["logo"]}" width="100px" />'
        )
        # Note about the provider
        note_markdown = gr.Markdown(f"**Note:** This model is supported by {default_provider}.")

        # State to store the paper content
        paper_content = gr.State()
        # Markdown area that displays the paper title and authors
        content = gr.Markdown(value="")

        # Create the chat interface and get the chatbot component
        chat_interface, chatbot = create_chat_interface(provider_dropdown, model_dropdown, paper_content,
                                                        hf_token_input, default_type, default_max_total_tokens)

        # Update models, logo, note, and token label when the provider changes
        def update_provider(selected_provider):
            provider_info = PROVIDERS[selected_provider]
            models = provider_info['models']
            logo_url = provider_info['logo']
            chatbot_message_type = provider_info['type']
            max_total_tokens = provider_info['max_total_tokens']
            # Update the models dropdown
            model_dropdown_choices = gr.update(choices=models, value=models[0])
            # Update the logo image
            logo_html_content = f'<img src="{logo_url}" width="100px" />'
            logo_html_update = gr.update(value=logo_html_content)
            # Update the note markdown
            note_markdown_update = gr.update(value=f"**Note:** This model is supported by {selected_provider}.")
            # Update the hf_token_input label and placeholder
            hf_token_input_update = gr.update(
                label=f"Enter your {selected_provider} API token (optional)",
                placeholder=f"Enter your {selected_provider} API token to avoid rate limits"
            )
            # Reset the chatbot history
            chatbot_reset = []  # This resets the chatbot conversation
            return model_dropdown_choices, logo_html_update, note_markdown_update, hf_token_input_update, \
                chatbot_message_type, max_total_tokens, chatbot_reset

        provider_dropdown.change(
            fn=update_provider,
            inputs=provider_dropdown,
            outputs=[model_dropdown, logo_html, note_markdown, hf_token_input, default_type,
                     default_max_total_tokens, chatbot],
            queue=False
        )
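        # The seven return values of update_provider map positionally onto the
        # seven components in `outputs` above: the model dropdown, logo, note,
        # token textbox, the two State holders, and the chatbot history.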

        # Update the displayed metadata and the stored paper text
        def update_paper_info(paper_id_value, paper_from_value, selected_model):
            if paper_from_value == "neurips":
                preamble = fetch_paper_info_neurips(paper_id_value)
                text = fetch_paper_content(paper_id_value)
                if preamble is None:
                    preamble = "Paper not found or could not retrieve paper information."
                if text is None:
                    return preamble, None, []
                return preamble, text, []
            elif paper_from_value == "paper_page":
                # Fetch the paper information from the Hugging Face API
                url = f"https://huggingface.co/api/papers/{paper_id_value}?field=comments"
                response = requests.get(url)
                if response.status_code != 200:
                    return "Paper not found or could not retrieve paper information.", None, []
                paper_info = response.json()

                # Extract the required information
                title = paper_info.get('title', 'No Title')
                link = f"https://huggingface.co/papers/{paper_id_value}"
                authors_list = [author.get('name', 'Unknown') for author in paper_info.get('authors', [])]
                authors = ', '.join(authors_list)
                summary = paper_info.get('summary', 'No Summary')
                num_comments = len(paper_info.get('comments', []))
                num_upvotes = paper_info.get('upvotes', 0)

                # Format the preamble (👍 shows upvotes, 💬 shows comments)
                preamble = f"🤗 [paper-page]({link})<br/>"
                preamble += f"**Title:** {title}<br/>"
                preamble += f"**Authors:** {authors}<br/>"
                preamble += f"**Summary:**<br/>\n> {summary}<br/>"
                preamble += f"👍{num_upvotes} 💬{num_comments} <br/>"

                # Fetch the paper content
                text = fetch_paper_content_arxiv(paper_id_value)
                if text is None:
                    text = "Paper content could not be retrieved."
                return preamble, text, []
            else:
                return "", "", []

        # Update the paper content when the paper ID changes
        paper_id.change(
            fn=update_paper_info,
            inputs=[paper_id, paper_from, model_dropdown],
            outputs=[content, paper_content, chatbot]
        )

def main():
    """
    Launches the Gradio app.
    """
    with gr.Blocks(css_paths="style.css") as demo:
        # Create an input for paper_id
        paper_id = gr.Textbox(label="Paper ID", value="")
        # Create an input for paper_from (e.g., 'neurips' or 'paper_page')
        paper_from = gr.Radio(
            label="Paper Source",
            choices=["neurips", "paper_page"],
            value="neurips"
        )
        # Build the paper chat tab
        paper_chat_tab(paper_id, paper_from)
    demo.launch(ssr_mode=False)


# Run the main function when the script is executed
if __name__ == "__main__":
    main()
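
# To run locally, execute this script directly with Python; the css_paths
# argument above assumes a style.css file sits alongside this script.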