# Researcher / app.py — Streamlit application (Hugging Face Space "mgbam", revision 834ac1a)
import asyncio
import os
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO

import networkx as nx
import openai
import pandas as pd
import requests
import streamlit as st
from pyvis.network import Network
from transformers import pipeline
# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer():
    """Return a cached Hugging Face summarization pipeline (facebook/bart-large-cnn)."""
    return pipeline("summarization", model="facebook/bart-large-cnn")
@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Return a cached GPT-2 text-generation pipeline used for idea generation."""
    return pipeline("text-generation", model="gpt2")
# Instantiate both pipelines once at module import; st.cache_resource ensures
# Streamlit reruns reuse the same loaded models instead of reloading them.
summarizer = load_summarizer()
generator = load_text_generator()
# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    """Generate research ideas locally with the GPT-2 pipeline.

    max_new_tokens is used (rather than max_length) so that 50 fresh tokens
    are produced beyond the prompt, regardless of the prompt's length.
    """
    outputs = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]["generated_text"]
def generate_ideas_with_openai(prompt, api_key):
    """
    Generate research ideas with OpenAI's GPT-3.5 model, streaming tokens to the UI.

    Fix: the previous implementation called ``openai.ChatCompletion.acreate``,
    which was removed in OpenAI SDK v1.0 (it raised at runtime despite the
    docstring claiming v1.0 support). This version uses the v1 ``OpenAI``
    client with synchronous streaming, which also avoids running a private
    asyncio event loop inside a Streamlit callback.

    Args:
        prompt: Fully assembled idea-generation prompt.
        api_key: The user's OpenAI API key.

    Returns:
        The concatenated streamed completion text.
    """
    client = openai.OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    placeholder = st.empty()  # Updated in place as tokens arrive.
    output_text = ""
    for chunk in stream:
        # delta.content is None for role/stop chunks; coalesce to "".
        piece = chunk.choices[0].delta.content or ""
        output_text += piece
        placeholder.text(output_text)
    return output_text
# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------
def _parse_arxiv_feed(xml_bytes):
    """Parse an arXiv Atom feed into a list of paper dicts.

    Helper for fetch_arxiv_results. Missing or empty elements map to "",
    matching the previous behavior, but without crashing when an element
    exists with no text (the old ``elem.text.strip()`` raised on None).
    """
    ns = {"atom": "http://www.w3.org/2005/Atom"}

    def text_of(parent, tag):
        # Return the stripped text of a child element, or "" if absent/empty.
        node = parent.find(tag, ns)
        return node.text.strip() if node is not None and node.text else ""

    root = ET.fromstring(xml_bytes)
    results = []
    for entry in root.findall("atom:entry", ns):
        authors = [
            text_of(author, "atom:name")
            for author in entry.findall("atom:author", ns)
        ]
        results.append({
            "title": text_of(entry, "atom:title"),
            "summary": text_of(entry, "atom:summary"),
            "published": text_of(entry, "atom:published"),
            "link": text_of(entry, "atom:id"),
            "authors": ", ".join(a for a in authors if a),
        })
    return results


def fetch_arxiv_results(query, max_results=5):
    """
    Query arXiv's free API for papers matching *query*.

    Fixes over the previous version: the query is URL-encoded via ``params=``
    (raw concatenation broke on spaces/special characters), the request has a
    timeout (it could previously hang forever), and network or XML-parse
    failures return [] instead of raising.

    Args:
        query: Free-text search terms (searched across all fields).
        max_results: Maximum number of entries to request.

    Returns:
        A list of dicts with keys title/summary/published/link/authors;
        empty on any error.
    """
    try:
        response = requests.get(
            "http://export.arxiv.org/api/query",
            params={
                "search_query": f"all:{query}",
                "start": 0,
                "max_results": max_results,
            },
            timeout=10,
        )
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    try:
        return _parse_arxiv_feed(response.content)
    except ET.ParseError:
        return []
# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")
# Sidebar Configuration: choose between the local Hugging Face model and
# OpenAI streaming, and collect the API key needed for the latter.
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
                                       ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
# Masked input; only required when the OpenAI streaming mode is selected.
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if arxiv_query.strip():
        # Blocking network call; the spinner gives feedback while it runs.
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
        if results:
            st.subheader("arXiv Search Results:")
            # Render each paper as a small markdown card separated by rules.
            for idx, paper in enumerate(results):
                st.markdown(f"**{idx+1}. {paper['title']}**")
                st.markdown(f"*Authors:* {paper['authors']}")
                st.markdown(f"*Published:* {paper['published']}")
                st.markdown(f"*Summary:* {paper['summary']}")
                st.markdown(f"[Read more]({paper['link']})")
                st.markdown("---")
        else:
            # fetch_arxiv_results returns [] both for "no hits" and API errors.
            st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")
# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        # Summarize the abstract to capture its key points; do_sample=False
        # keeps the summary deterministic across reruns.
        summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
        summary_text = summary[0]['summary_text']
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        # Build a combined prompt with the abstract and its summary so the
        # generator sees both the full context and the distilled key points.
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                    st.write(ideas)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
                st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")
# --- Section 3: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Paper A,2;3\n2,Paper B,\n3,Paper C,2\n```"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        # Parse each "PaperID,Title,CitedPaperIDs" row; lines with fewer than
        # three comma-separated fields are skipped. NOTE(review): titles that
        # themselves contain commas are not supported by this simple format.
        data = []
        for line in papers_csv.splitlines():
            parts = line.split(',')
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip()
                cited_list = [c.strip() for c in cited.split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build a directed graph using NetworkX: edge A -> B means A cites B.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            st.subheader("Knowledge Graph")
            # Create an interactive visualization using Pyvis. Pyvis can only
            # write to a file, so render to a temp file, read it back, and
            # embed the HTML in the page.
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)))
            for source, target in G.edges():
                net.add_edge(source, target)
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()  # Fix: close our handle so pyvis can reopen the path (required on Windows).
            try:
                net.write_html(temp_file.name)
                with open(temp_file.name, 'r', encoding='utf-8') as f:
                    html_content = f.read()
            finally:
                os.unlink(temp_file.name)  # Fix: the delete=False temp file was previously never removed.
            st.components.v1.html(html_content, height=500)
        else:
            # Fix: malformed input (no row with >= 3 fields) previously failed
            # silently with no feedback at all.
            st.error("No valid paper rows found. Use the format: PaperID,Title,CitedPaperIDs")
    else:
        st.error("Please enter paper details for the knowledge graph.")