# Graph of AI Ideas — Streamlit application (Hugging Face Space).
# (The "Spaces: Sleeping" lines here were web-page scrape residue, not code.)
import asyncio
import os
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO

import networkx as nx
import openai
import pandas as pd
import requests
import streamlit as st
from pyvis.network import Network
from transformers import pipeline
# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource
def load_summarizer():
    """Load and cache the Hugging Face summarization pipeline.

    Uses facebook/bart-large-cnn. Decorated with ``st.cache_resource`` so the
    model is loaded once per server process rather than on every Streamlit
    rerun — the section header promised caching, but the original reloaded
    the model on each script execution.

    Returns:
        A transformers summarization pipeline.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")
@st.cache_resource
def load_text_generator():
    """Load and cache a GPT-2 text-generation pipeline (demonstration model).

    Cached with ``st.cache_resource`` for the same reason as
    ``load_summarizer``: avoid re-downloading/re-instantiating the model on
    every Streamlit rerun.

    Returns:
        A transformers text-generation pipeline.
    """
    return pipeline("text-generation", model="gpt2")
# Instantiate both models once at module import; subsequent Streamlit reruns
# reuse these module-level objects.
summarizer = load_summarizer()
generator = load_text_generator()
# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    """Generate research ideas with the local Hugging Face text-generation model.

    ``max_new_tokens`` caps the generation at 50 tokens beyond the prompt;
    a single sequence is requested and its full generated text returned.
    """
    outputs = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]['generated_text']
def generate_ideas_with_openai(prompt, api_key):
    """
    Generate research ideas using OpenAI's GPT-3.5 model with streaming output.

    Fixed for the OpenAI SDK v1.0+: the legacy ``openai.ChatCompletion.acreate``
    endpoint the original called was removed in v1.0 (it raised an APIRemovedInV1
    error at runtime). This version uses the v1 ``OpenAI`` client with its
    synchronous streaming iterator — which also avoids calling ``asyncio.run``
    inside a Streamlit script, where an event loop may already be running.

    Args:
        prompt: User prompt describing the paper/abstract to ideate from.
        api_key: OpenAI API key supplied via the sidebar.

    Returns:
        The full generated text accumulated from the stream.
    """
    client = openai.OpenAI(api_key=api_key)
    output_text = ""
    placeholder = st.empty()  # live-updating widget for the streamed tokens
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    for chunk in stream:
        # Each chunk carries an incremental delta; content can be None on
        # role/finish chunks, so coalesce to the empty string.
        piece = chunk.choices[0].delta.content or ""
        output_text += piece
        placeholder.text(output_text)
    return output_text
# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------
def fetch_arxiv_results(query, max_results=5):
    """
    Query arXiv's free Atom API and return a list of paper dictionaries.

    Fixes over the original:
      * the search query is passed via ``requests``' ``params`` so it is
        properly URL-encoded (queries containing spaces or special characters
        previously produced a malformed URL);
      * a request timeout and an exception guard, so a network failure cannot
        hang or crash the Streamlit script;
      * missing/empty XML text no longer raises (``.text`` can be None).

    Args:
        query: Free-text search query (sent as ``all:<query>``).
        max_results: Maximum number of entries to request.

    Returns:
        List of dicts with keys: title, summary, published, link, authors.
        An empty list on any HTTP or network error.
    """
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    try:
        response = requests.get("http://export.arxiv.org/api/query",
                                params=params, timeout=30)
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []

    ns = {"atom": "http://www.w3.org/2005/Atom"}  # Atom feed namespace

    def text_of(elem, tag):
        # Stripped text of a child element; "" when absent or empty.
        node = elem.find(tag, ns)
        return node.text.strip() if node is not None and node.text else ""

    root = ET.fromstring(response.content)
    results = []
    for entry in root.findall("atom:entry", ns):
        authors = [text_of(author, "atom:name")
                   for author in entry.findall("atom:author", ns)]
        results.append({
            "title": text_of(entry, "atom:title"),
            "summary": text_of(entry, "atom:summary"),
            "published": text_of(entry, "atom:published"),
            "link": text_of(entry, "atom:id"),
            "authors": ", ".join(a for a in authors if a),
        })
    return results
# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")
# Sidebar Configuration
st.sidebar.header("Configuration")
# Backend choice for idea generation; the OpenAI path additionally requires
# the API key entered below.
generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
                                       ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if not arxiv_query.strip():
        # Guard clause: nothing to search for.
        st.error("Please enter a valid query for the arXiv search.")
    else:
        with st.spinner("Searching arXiv..."):
            papers = fetch_arxiv_results(arxiv_query, max_results=5)
            if not papers:
                st.error("No results found or an error occurred with the arXiv API.")
            else:
                st.subheader("arXiv Search Results:")
                # Render each hit as a numbered markdown card.
                for rank, paper in enumerate(papers, start=1):
                    st.markdown(f"**{rank}. {paper['title']}**")
                    st.markdown(f"*Authors:* {paper['authors']}")
                    st.markdown(f"*Published:* {paper['published']}")
                    st.markdown(f"*Summary:* {paper['summary']}")
                    st.markdown(f"[Read more]({paper['link']})")
                    st.markdown("---")
# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if not paper_abstract.strip():
        # Guard clause: an abstract is required before anything runs.
        st.error("Please enter a research paper abstract.")
    else:
        st.subheader("Summarized Abstract")
        # Condense the abstract so the generation prompt stays focused.
        summary_output = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
        summary_text = summary_output[0]['summary_text']
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        # Combined prompt: full abstract plus its summary, ending with a cue.
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if openai_api_key.strip():
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    st.write(generate_ideas_with_openai(prompt, openai_api_key))
            else:
                st.error("Please provide your OpenAI API Key in the sidebar.")
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                st.write(generate_ideas_with_hf(prompt))
# --- Section 3: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Paper A,2;3\n2,Paper B,\n3,Paper C,2\n```"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        data = []
        for line in papers_csv.splitlines():
            if not line.strip():
                continue  # skip blank lines
            parts = line.split(',')
            # Accept rows without a citations column ("1,Paper A"); the
            # original required >= 3 fields and silently dropped such rows.
            if len(parts) >= 2:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip() if len(parts) >= 3 else ""
                cited_list = [c.strip() for c in cited.split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build a directed citation graph with NetworkX.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            st.subheader("Knowledge Graph")
            # Interactive visualization via Pyvis, embedded as raw HTML.
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)))
            for source, target in G.edges():
                net.add_edge(source, target)
            # Write the HTML to a temp file, read it back, then delete it.
            # The original never closed nor unlinked the file: one leaked
            # file per click, and an open handle breaks pyvis on Windows.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()  # pyvis reopens the path by name
            try:
                net.write_html(temp_file.name)
                with open(temp_file.name, 'r', encoding='utf-8') as f:
                    html_content = f.read()
            finally:
                os.unlink(temp_file.name)
            st.components.v1.html(html_content, height=500)
    else:
        st.error("Please enter paper details for the knowledge graph.")