import streamlit as st
from transformers import pipeline
import networkx as nx
from pyvis.network import Network
import tempfile
import os
import csv
import openai
import requests
import feedparser
import pandas as pd
from io import StringIO
import asyncio

# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer():
    """Load and cache a Hugging Face summarization pipeline (facebook/bart-large-cnn)."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Load and cache a small local text-generation model (GPT-2) for idea generation."""
    return pipeline("text-generation", model="gpt2")


summarizer = load_summarizer()
generator = load_text_generator()


# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    """Generate research ideas locally with the Hugging Face pipeline.

    Uses ``max_new_tokens`` (not ``max_length``) so the generation budget is
    counted in newly produced tokens, independent of the prompt length.

    Args:
        prompt: Full text prompt including the abstract and instructions.

    Returns:
        The generated text (prompt + continuation) as a single string.
    """
    results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return results[0]["generated_text"]


def generate_ideas_with_openai(prompt, api_key):
    """Generate research ideas with OpenAI GPT-3.5, streaming tokens to the UI.

    Fixed for the OpenAI SDK v1.x: ``openai.ChatCompletion.acreate`` was
    removed in v1.0. The v1 client streams synchronously, which also avoids
    spinning up an asyncio event loop inside Streamlit's script thread.

    Args:
        prompt: Full text prompt including the abstract and instructions.
        api_key: The user's OpenAI API key.

    Returns:
        The accumulated streamed completion text.
    """
    client = openai.OpenAI(api_key=api_key)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
            {"role": "user", "content": prompt},
        ],
        stream=True,
    )
    output_text = ""
    st_text = st.empty()  # Placeholder updated in place as chunks arrive
    for chunk in stream:
        # v1.x chunks are typed objects, not dicts; content may be None on
        # the first/last delta, so coalesce to "".
        text_piece = chunk.choices[0].delta.content or ""
        output_text += text_piece
        st_text.text(output_text)
    return output_text


# ---------------------------
# arXiv API Integration
# ---------------------------
def fetch_arxiv_results(query, max_results=5):
    """Query arXiv's free API and return a list of paper metadata dicts.

    Uses ``requests`` parameter encoding so queries containing spaces or
    special characters are URL-encoded correctly, and a timeout so a slow
    arXiv response cannot hang the Streamlit script.

    Args:
        query: Free-text search query.
        max_results: Maximum number of entries to request.

    Returns:
        A list of dicts with keys title/authors/published/summary/link;
        an empty list on any network error or non-200 response.
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    try:
        response = requests.get(base_url, params=params, timeout=15)
    except requests.RequestException:
        return []
    if response.status_code != 200:
        return []
    feed = feedparser.parse(response.content)
    results = []
    for entry in feed.entries:
        results.append({
            "title": entry.title,
            "authors": ", ".join(author.name for author in entry.authors),
            "published": entry.published,
            "summary": entry.summary,
            "link": entry.link,
        })
    return results


# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")

# Sidebar Configuration
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox(
    "Select Idea Generation Mode",
    ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"],
)
openai_api_key = st.sidebar.text_input(
    "OpenAI API Key (for GPT-3.5 Streaming)", type="password"
)

# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")
if st.button("Search arXiv"):
    if arxiv_query.strip():
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
        if results:
            st.subheader("arXiv Search Results:")
            for idx, paper in enumerate(results):
                st.markdown(f"**{idx+1}. {paper['title']}**")
                st.markdown(f"*Authors:* {paper['authors']}")
                st.markdown(f"*Published:* {paper['published']}")
                st.markdown(f"*Summary:* {paper['summary']}")
                st.markdown(f"[Read more]({paper['link']})")
                st.markdown("---")
        else:
            st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")

# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        # Summarize the abstract to capture its key points
        summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
        summary_text = summary[0]["summary_text"]
        st.write(summary_text)

        st.subheader("Generated Research Ideas")
        # Build a combined prompt with the abstract and its summary
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                st.write(ideas)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
            st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")

# --- Section 3: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Paper A,2;3\n2,Paper B,\n3,Paper C,2\n```"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        # Use the csv module so titles containing commas (quoted fields)
        # are parsed correctly instead of a naive split(',').
        data = []
        for parts in csv.reader(papers_csv.splitlines()):
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited_list = [c.strip() for c in parts[2].strip().split(";") if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})

        if data:
            # Build a directed citation graph with NetworkX.
            G = nx.DiGraph()
            for paper in data:
                # Ensure each node has a 'title' key, using the node id as fallback.
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)

            st.subheader("Knowledge Graph")
            # Create an interactive visualization using Pyvis.
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)))
            for source, target in G.edges():
                net.add_edge(source, target)

            # Render to a temp HTML file, embed it, then delete the file so
            # repeated renders don't leak temp files. The handle is closed
            # before pyvis writes (required on Windows).
            tmp_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as tmp:
                    tmp_path = tmp.name
                net.write_html(tmp_path)
                with open(tmp_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
            finally:
                if tmp_path and os.path.exists(tmp_path):
                    os.unlink(tmp_path)
            st.components.v1.html(html_content, height=500)
    else:
        st.error("Please enter paper details for the knowledge graph.")