import asyncio
import base64
import csv
import html
import os
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO

import networkx as nx
import openai
import pandas as pd
import requests
import streamlit as st
import streamlit.components.v1 as components
from pyvis.network import Network
from transformers import pipeline

# ---------------------------
# Model Loading & Caching
# ---------------------------


@st.cache_resource(show_spinner=False)
def load_summarizer():
    """Return the facebook/bart-large-cnn summarization pipeline.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across reruns instead of being reloaded on every interaction.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")


@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Return a GPT-2 text-generation pipeline (cached like the summarizer)."""
    return pipeline("text-generation", model="gpt2")


summarizer = load_summarizer()
generator = load_text_generator()

# ---------------------------
# Idea Generation Functions
# ---------------------------


def generate_ideas_with_hf(prompt):
    """Generate research ideas with the local Hugging Face model.

    Args:
        prompt: Text the model continues.

    Returns:
        Only the newly generated continuation (at most 50 new tokens). The
        prompt itself is not echoed back.
    """
    # return_full_text=False: by default the text-generation pipeline prepends
    # the prompt to `generated_text`, which would repeat the whole abstract.
    # NOTE(review): GPT-2's context window is limited; very long prompts may
    # still fail here. Confirm the expected abstract sizes.
    results = generator(
        prompt,
        max_new_tokens=50,
        num_return_sequences=1,
        return_full_text=False,
    )
    return results[0]["generated_text"]


def generate_ideas_with_openai(prompt, api_key):
    """Generate research ideas with OpenAI GPT-3.5, streaming into the page.

    Uses the openai>=1.0 client API (``openai.ChatCompletion`` no longer exists
    in v1). Each streamed fragment is appended to an ``st.empty()``
    placeholder, so the text is already on the page when this returns.

    Args:
        prompt: User message sent to the model.
        api_key: OpenAI API key.

    Returns:
        The complete generated text.

    Raises:
        openai.OpenAIError: On authentication, rate-limit or network failures.
    """
    client = openai.OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    placeholder = st.empty()  # Placeholder for streaming output
    output_text = ""
    for chunk in stream:
        # Some stream chunks carry no choices, and the final delta has content=None.
        if not chunk.choices:
            continue
        text_piece = chunk.choices[0].delta.content
        if text_piece:
            output_text += text_piece
            placeholder.text(output_text)
    return output_text


# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------

ARXIV_API_URL = "http://export.arxiv.org/api/query"
ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}


def _clean_text(elem):
    """Return elem's text with whitespace runs collapsed, or "" if absent/empty."""
    if elem is None or elem.text is None:
        return ""
    # arXiv titles/summaries contain hard line breaks and double spaces.
    return " ".join(elem.text.split())


def fetch_arxiv_results(query, max_results=5, timeout=10):
    """Query arXiv's public API and parse the Atom feed with ElementTree.

    Args:
        query: Free-text search, sent as ``search_query=all:<query>``.
        max_results: Maximum number of entries to request.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A list of dicts with keys ``title``, ``summary``, ``published``,
        ``link`` and ``authors`` (comma-joined). It is empty on any HTTP,
        network or XML error.
    """
    # params= URL-encodes the query, so characters like '&', '#' and '+'
    # cannot corrupt the request.
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    try:
        response = requests.get(ARXIV_API_URL, params=params, timeout=timeout)
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    try:
        root = ET.fromstring(response.content)
    except ET.ParseError:
        return []

    results = []
    for entry in root.findall("atom:entry", ATOM_NS):
        author_names = (
            _clean_text(author.find("atom:name", ATOM_NS))
            for author in entry.findall("atom:author", ATOM_NS)
        )
        results.append({
            "title": _clean_text(entry.find("atom:title", ATOM_NS)),
            "summary": _clean_text(entry.find("atom:summary", ATOM_NS)),
            "published": _clean_text(entry.find("atom:published", ATOM_NS)),
            "link": _clean_text(entry.find("atom:id", ATOM_NS)),
            "authors": ", ".join(name for name in author_names if name),
        })
    return results


# ---------------------------
# Utility Function: Graph Download Link
# ---------------------------


def get_download_link(file_path, filename="graph.html"):
    """Return an HTML <a> tag that downloads the file's contents as `filename`.

    The file is embedded as a base64 data URI. Render the result with
    ``st.markdown(..., unsafe_allow_html=True)``.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        html_data = f.read()
    b64 = base64.b64encode(html_data.encode("utf-8")).decode("ascii")
    safe_name = html.escape(filename, quote=True)
    return (
        f'<a href="data:text/html;base64,{b64}" download="{safe_name}">'
        "Download Graph as HTML</a>"
    )


# ---------------------------
# Knowledge Graph Helpers
# ---------------------------


def _parse_papers_csv(text):
    """Parse 'PaperID,Title,CitedPaperIDs' rows into paper dicts.

    Parsing uses the csv module, so titles containing commas must be quoted.
    Rows with fewer than two fields or an empty PaperID are skipped. A
    missing third column means "no citations". Rows whose PaperID is the
    literal header "PaperID" are ignored.
    """
    papers = []
    for row in csv.reader(StringIO(text), skipinitialspace=True):
        if len(row) < 2:
            continue
        paper_id = row[0].strip()
        if not paper_id or paper_id.lower() == "paperid":
            continue
        cited_field = row[2] if len(row) > 2 else ""
        papers.append({
            "paper_id": paper_id,
            "title": row[1].strip(),
            "cited": [c.strip() for c in cited_field.split(";") if c.strip()],
        })
    return papers


def _build_citation_graph(papers):
    """Build a DiGraph with an edge paper -> cited paper for every citation."""
    graph = nx.DiGraph()
    for paper in papers:
        # add_node on an id first seen as a citation target just adds the title.
        graph.add_node(paper["paper_id"], title=paper["title"] or paper["paper_id"])
        for cited_id in paper["cited"]:
            graph.add_edge(paper["paper_id"], cited_id)
    return graph


def _filter_graph(graph, keyword):
    """Return the subgraph of nodes whose title contains `keyword` (case-insensitive).

    A blank keyword returns `graph` unchanged. Nodes known only as citation
    targets have no title and therefore never match.
    """
    needle = keyword.strip().lower()
    if not needle:
        return graph
    matches = [n for n, d in graph.nodes(data=True) if needle in d.get("title", "").lower()]
    return graph.subgraph(matches).copy()


def _render_graph_html(graph, layout):
    """Render `graph` with pyvis and return (html_text, download_link_markup).

    The intermediate HTML file is deleted once it has been read.
    """
    net = Network(height="500px", width="100%", directed=True)
    # Nodes show their title as label and as hover tooltip.
    for node, node_data in graph.nodes(data=True):
        net.add_node(
            node,
            label=node_data.get("title", str(node)),
            title=node_data.get("title", "No Title"),
        )
    for source, target in graph.edges():
        net.add_edge(source, target)
    if layout == "Force Atlas 2":
        net.force_atlas_2based()

    # Close our handle before pyvis reopens the path: reopening a still-open
    # NamedTemporaryFile by name is not supported on every platform (Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
        tmp_path = tmp.name
    try:
        net.write_html(tmp_path)
        with open(tmp_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        download_link = get_download_link(tmp_path)
    finally:
        os.unlink(tmp_path)
    return html_content, download_link


# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")

# Sidebar: Configuration and Layout Options
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox("Select Idea Generation Mode", ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
layout_option = st.sidebar.selectbox("Select Graph Layout", ["Default", "Force Atlas 2"])

# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if arxiv_query.strip():
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query.strip(), max_results=5)
        if results:
            st.subheader("arXiv Search Results:")
            for idx, paper in enumerate(results):
                st.markdown(f"**{idx+1}. {paper['title']}**")
                st.markdown(f"*Authors:* {paper['authors']}")
                st.markdown(f"*Published:* {paper['published']}")
                st.markdown(f"*Summary:* {paper['summary']}")
                st.markdown(f"[Read more]({paper['link']})")
                st.markdown("---")
        else:
            st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")

# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        with st.spinner("Summarizing abstract..."):
            # truncation=True: inputs longer than BART's maximum input length
            # are cut to fit instead of making the pipeline fail.
            summary = summarizer(
                paper_abstract,
                max_length=100,
                min_length=30,
                do_sample=False,
                truncation=True,
            )
        summary_text = summary[0]["summary_text"]
        st.write(summary_text)

        st.subheader("Generated Research Ideas")
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    try:
                        # The function streams the text into the page itself;
                        # writing the return value again would duplicate it.
                        generate_ideas_with_openai(prompt, openai_api_key.strip())
                    except openai.OpenAIError as err:
                        st.error(f"OpenAI request failed: {err}")
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
            st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")

# --- Section 3: Knowledge Graph Visualization with Additional Features ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Enter paper details and citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';'). "
    "Wrap titles that contain commas in double quotes; a header row is optional.\n\n"
    "Example:\n\n```\n1,Graph of AI Ideas: Leveraging Knowledge Graphs and LLMs for AI Research Idea Generation,2;3\n2,Fundamental Approaches in AI Literature,\n3,Applications of LLMs in Research Idea Generation,2\n```"
)

# Optional filter input for node titles.
filter_text = st.text_input("Optional: Enter keyword to filter nodes in the graph:")
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        papers = _parse_papers_csv(papers_csv)
        if papers:
            graph = _filter_graph(_build_citation_graph(papers), filter_text)
            st.subheader("Knowledge Graph")
            if graph.number_of_nodes() == 0:
                st.info("No papers match the filter keyword.")
            else:
                html_content, download_link = _render_graph_html(graph, layout_option)
                components.html(html_content, height=500)
                # Provide a download link for the graph.
                st.markdown(download_link, unsafe_allow_html=True)
        else:
            st.error("No valid rows found. Each row needs at least a PaperID and a Title.")
    else:
        st.error("Please enter paper details for the knowledge graph.")