import streamlit as st
import streamlit.components.v1 as components
from transformers import pipeline
import networkx as nx
from pyvis.network import Network
import tempfile
from openai import OpenAI
import requests
import xml.etree.ElementTree as ET
import base64
# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource
def load_summarizer():
    # Load and cache a summarization pipeline from Hugging Face (facebook/bart-large-cnn),
    # so the model is not reloaded on every Streamlit rerun.
    return pipeline("summarization", model="facebook/bart-large-cnn")


@st.cache_resource
def load_text_generator():
    # For demonstration, load and cache a small text-generation model (GPT-2).
    return pipeline("text-generation", model="gpt2")


summarizer = load_summarizer()
generator = load_text_generator()
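# Illustrative output shapes (hypothetical inputs; both pipelines return a list
# with one dict per generated sequence):
#   summarizer("Some long abstract ...", max_length=60, min_length=20)
#   -> [{"summary_text": "..."}]
#   generator("Some prompt ...", max_new_tokens=20)
#   -> [{"generated_text": "Some prompt ... <continuation>"}]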
# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    """
    Generates research ideas with the Hugging Face model; max_new_tokens caps
    how much text is generated beyond the prompt.
    """
    results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return results[0]["generated_text"]
def generate_ideas_with_openai(prompt, api_key):
    """
    Generates research ideas using OpenAI's GPT-3.5 via the v1.0 SDK,
    streaming tokens into a Streamlit placeholder as they arrive.
    """
    client = OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    output_text = ""
    st_text = st.empty()  # Placeholder updated in place as chunks stream in.
    for chunk in stream:
        # Each chunk carries an incremental delta; content can be None
        # (e.g., on the role-only first chunk and the final chunk).
        text_piece = chunk.choices[0].delta.content or ""
        output_text += text_piece
        st_text.text(output_text)
    return output_text
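# A hardened caller would catch API failures; a minimal sketch, assuming
# `from openai import OpenAIError` (the v1 SDK's base exception class):
#   try:
#       ideas = generate_ideas_with_openai(prompt, api_key)
#   except OpenAIError as exc:
#       st.error(f"OpenAI request failed: {exc}")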
# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------
def fetch_arxiv_results(query, max_results=5):
    """
    Queries arXiv's free Atom API and parses the result with ElementTree.
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    # Passing params lets requests URL-encode the query (spaces, etc.).
    response = requests.get(base_url, params=params)
    if response.status_code != 200:
        return []
    root = ET.fromstring(response.content)
    ns = {"atom": "http://www.w3.org/2005/Atom"}
    results = []
    for entry in root.findall("atom:entry", ns):
        title_elem = entry.find("atom:title", ns)
        title = title_elem.text.strip() if title_elem is not None else ""
        summary_elem = entry.find("atom:summary", ns)
        summary = summary_elem.text.strip() if summary_elem is not None else ""
        published_elem = entry.find("atom:published", ns)
        published = published_elem.text.strip() if published_elem is not None else ""
        link_elem = entry.find("atom:id", ns)
        link = link_elem.text.strip() if link_elem is not None else ""
        authors = [author.find("atom:name", ns).text.strip()
                   for author in entry.findall("atom:author", ns)
                   if author.find("atom:name", ns) is not None]
        results.append({
            "title": title,
            "summary": summary,
            "published": published,
            "link": link,
            "authors": ", ".join(authors),
        })
    return results
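# For reference, each parsed Atom <entry> looks roughly like this (trimmed,
# with hypothetical values):
#   <entry xmlns="http://www.w3.org/2005/Atom">
#     <id>http://arxiv.org/abs/2301.00001v1</id>
#     <title>Example Title</title>
#     <summary>Example abstract ...</summary>
#     <published>2023-01-01T00:00:00Z</published>
#     <author><name>Jane Doe</name></author>
#   </entry>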
# ---------------------------
# Utility Function: Graph Download Link
# ---------------------------
def get_download_link(file_path, filename="graph.html"):
    """Wraps an HTML file in a base64 data-URI anchor so it can be downloaded."""
    with open(file_path, "r", encoding="utf-8") as f:
        html_data = f.read()
    b64 = base64.b64encode(html_data.encode()).decode()
    return f'<a href="data:text/html;base64,{b64}" download="{filename}">Download Graph as HTML</a>'
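# Usage, as in the graph section below (path_to_html_file is a placeholder):
#   st.markdown(get_download_link(path_to_html_file), unsafe_allow_html=True)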
# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")

# Sidebar: Configuration and Layout Options
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
                                       ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
layout_option = st.sidebar.selectbox("Select Graph Layout", ["Default", "Force Atlas 2"])
# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if arxiv_query.strip():
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
        if results:
            st.subheader("arXiv Search Results:")
            for idx, paper in enumerate(results):
                st.markdown(f"**{idx + 1}. {paper['title']}**")
                st.markdown(f"*Authors:* {paper['authors']}")
                st.markdown(f"*Published:* {paper['published']}")
                st.markdown(f"*Summary:* {paper['summary']}")
                st.markdown(f"[Read more]({paper['link']})")
                st.markdown("---")
        else:
            st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")
# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        # truncation=True keeps abstracts longer than the model's input limit from erroring.
        summary = summarizer(paper_abstract, max_length=100, min_length=30,
                             do_sample=False, truncation=True)
        summary_text = summary[0]["summary_text"]
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        prompt = (
            "Based on the following research paper abstract, generate innovative "
            "and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            "Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 (SDK v1.0)..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                st.write(ideas)
        else:
            with st.spinner("Generating ideas using the Hugging Face open-source model..."):
                ideas = generate_ideas_with_hf(prompt)
            st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")
# --- Section 3: Knowledge Graph Visualization with Additional Features ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Enter paper details and citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n"
    "1,Graph of AI Ideas: Leveraging Knowledge Graphs and LLMs for AI Research Idea Generation,2;3\n"
    "2,Fundamental Approaches in AI Literature,\n"
    "3,Applications of LLMs in Research Idea Generation,2\n"
    "```"
)
# Optional filter input for node titles.
filter_text = st.text_input("Optional: Enter keyword to filter nodes in the graph:")
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        data = []
        for line in papers_csv.splitlines():
            parts = line.split(",")
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip()
                cited_list = [c.strip() for c in cited.split(";") if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
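        # With the example CSV above, `data` becomes:
        #   [{"paper_id": "1", "title": "Graph of AI Ideas: ...", "cited": ["2", "3"]},
        #    {"paper_id": "2", "title": "Fundamental Approaches in AI Literature", "cited": []},
        #    {"paper_id": "3", "title": "Applications of LLMs in Research Idea Generation", "cited": ["2"]}]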
        if data:
            # Build the full citation graph.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
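            # For the sample CSV above this yields a three-node DiGraph with
            # edges 1->2, 1->3, and 3->2 (pointing from citing to cited paper).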
            # Filter nodes if a keyword is provided.
            if filter_text.strip():
                filtered_nodes = [n for n, d in G.nodes(data=True)
                                  if filter_text.lower() in d.get("title", "").lower()]
                H = G.subgraph(filtered_nodes).copy() if filtered_nodes else nx.DiGraph()
            else:
                H = G
            st.subheader("Knowledge Graph")
            # Create the Pyvis network.
            net = Network(height="500px", width="100%", directed=True)
            # Add nodes with tooltips (the title is shown on hover).
            for node, node_data in H.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)),
                             title=node_data.get("title", "No Title"))
            for source, target in H.edges():
                net.add_edge(source, target)
            # Apply the layout selected in the sidebar.
            if layout_option == "Force Atlas 2":
                net.force_atlas_2based()
            # Write the graph to a temporary HTML file; close the handle first so
            # Pyvis can reopen the path for writing on all platforms.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()
            net.write_html(temp_file.name)
            # Render the graph inline.
            with open(temp_file.name, "r", encoding="utf-8") as f:
                html_content = f.read()
            components.html(html_content, height=500)
            # Provide a download link for the graph.
            st.markdown(get_download_link(temp_file.name), unsafe_allow_html=True)
    else:
        st.error("Please enter paper details for the knowledge graph.")