import streamlit as st
import streamlit.components.v1 as components
from transformers import pipeline
import networkx as nx
from pyvis.network import Network
import tempfile
from openai import AsyncOpenAI
import requests
import xml.etree.ElementTree as ET
import asyncio
import base64

# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer():
    # Load the facebook/bart-large-cnn summarization pipeline from Hugging Face.
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer

@st.cache_resource(show_spinner=False)
def load_text_generator():
    # For demonstration, we load a text-generation model such as GPT-2.
    generator = pipeline("text-generation", model="gpt2")
    return generator

summarizer = load_summarizer()
generator = load_text_generator()
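
# Note: @st.cache_resource loads each model once per server process and reuses
# the same objects across script reruns and user sessions.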

# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    # Generate ideas with the local Hugging Face model; max_new_tokens caps
    # how much text is generated beyond the prompt.
    results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    idea_text = results[0]['generated_text']
    return idea_text

def generate_ideas_with_openai(prompt, api_key):
    """
    Generates research ideas with OpenAI's GPT-3.5, streaming tokens into the
    Streamlit UI as they arrive (OpenAI Python SDK v1.x chat completions API).
    """
    client = AsyncOpenAI(api_key=api_key)
    output_text = ""
    async def stream_chat():
        nonlocal output_text
        stream = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
                {"role": "user", "content": prompt},
            ],
            stream=True,
        )
        st_text = st.empty()  # Placeholder updated in place as tokens stream in
        async for chunk in stream:
            text_piece = chunk.choices[0].delta.content or ""
            output_text += text_piece
            st_text.text(output_text)
    asyncio.run(stream_chat())
    return output_text
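
# Note: under SDK v1.x the first and last streamed chunks typically carry
# delta.content == None, hence the `or ""` guard in the loop above.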

# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------
def fetch_arxiv_results(query, max_results=5):
    """
    Queries arXiv's free API and parses the result using ElementTree.
    """
    base_url = "http://export.arxiv.org/api/query"
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    # Let requests build and URL-encode the query string so multi-word
    # queries are handled correctly.
    response = requests.get(base_url, params=params, timeout=30)
    results = []
    if response.status_code == 200:
        root = ET.fromstring(response.content)
        ns = {"atom": "http://www.w3.org/2005/Atom"}
        for entry in root.findall("atom:entry", ns):
            title_elem = entry.find("atom:title", ns)
            title = title_elem.text.strip() if title_elem is not None else ""
            summary_elem = entry.find("atom:summary", ns)
            summary = summary_elem.text.strip() if summary_elem is not None else ""
            published_elem = entry.find("atom:published", ns)
            published = published_elem.text.strip() if published_elem is not None else ""
            link_elem = entry.find("atom:id", ns)
            link = link_elem.text.strip() if link_elem is not None else ""
            authors = [author.find("atom:name", ns).text.strip()
                       for author in entry.findall("atom:author", ns)
                       if author.find("atom:name", ns) is not None]
            results.append({
                "title": title,
                "summary": summary,
                "published": published,
                "link": link,
                "authors": ", ".join(authors)
            })
        return results
    else:
        return []
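
# Illustrative call (record shape only; values hypothetical):
#   fetch_arxiv_results("knowledge graphs", max_results=2)
#   -> [{"title": "...", "summary": "...", "published": "2023-...",
#        "link": "http://arxiv.org/abs/...", "authors": "A. Author, B. Author"}]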

# ---------------------------
# Utility Function: Graph Download Link
# ---------------------------
def get_download_link(file_path, filename="graph.html"):
    """Converts the HTML file to a downloadable link."""
    with open(file_path, "r", encoding="utf-8") as f:
        html_data = f.read()
    b64 = base64.b64encode(html_data.encode()).decode()
    href = f'<a href="data:text/html;base64,{b64}" download="{filename}">Download Graph as HTML</a>'
    return href
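
# Embedding the HTML in a base64 data: URI keeps the download self-contained,
# so no extra static file has to be served by the Streamlit app.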

# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")

# Sidebar: Configuration and Layout Options
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
                                       ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
layout_option = st.sidebar.selectbox("Select Graph Layout", ["Default", "Force Atlas 2"])

# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")

if st.button("Search arXiv"):
    if arxiv_query.strip():
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
            if results:
                st.subheader("arXiv Search Results:")
                for idx, paper in enumerate(results):
                    st.markdown(f"**{idx+1}. {paper['title']}**")
                    st.markdown(f"*Authors:* {paper['authors']}")
                    st.markdown(f"*Published:* {paper['published']}")
                    st.markdown(f"*Summary:* {paper['summary']}")
                    st.markdown(f"[Read more]({paper['link']})")
                    st.markdown("---")
            else:
                st.error("No results found or an error occurred with the arXiv API.")
    else:
        st.error("Please enter a valid query for the arXiv search.")

# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)

if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False, truncation=True)
        summary_text = summary[0]['summary_text']
        st.write(summary_text)
        st.subheader("Generated Research Ideas")
        prompt = (
            f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            f"Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
                    st.write(ideas)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
                st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")

# --- Section 3: Knowledge Graph Visualization with Additional Features ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Enter paper details and citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Graph of AI Ideas: Leveraging Knowledge Graphs and LLMs for AI Research Idea Generation,2;3\n2,Fundamental Approaches in AI Literature,\n3,Applications of LLMs in Research Idea Generation,2\n```"
)
# Optional filter input for node titles.
filter_text = st.text_input("Optional: Enter keyword to filter nodes in the graph:")

papers_csv = st.text_area("Enter paper details in CSV format:", height=150)

if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        data = []
        for line in papers_csv.splitlines():
            # Naive comma split: titles containing commas are not supported.
            parts = line.split(',')
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited = parts[2].strip()
                cited_list = [c.strip() for c in cited.split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build the full graph.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    G.add_edge(paper["paper_id"], cited)
            
            # Filter nodes if a keyword is provided.
            if filter_text.strip():
                filtered_nodes = [n for n, d in G.nodes(data=True) if filter_text.lower() in d.get("title", "").lower()]
                if filtered_nodes:
                    H = G.subgraph(filtered_nodes).copy()
                else:
                    H = nx.DiGraph()
            else:
                H = G
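            # Note: G.subgraph(...) returns a read-only view; .copy() above
            # materializes an independent graph before it is rendered.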

            st.subheader("Knowledge Graph")
            # Create the Pyvis network.
            net = Network(height="500px", width="100%", directed=True)
            
            # Add nodes with tooltips (show title on hover).
            for node, node_data in H.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)), title=node_data.get("title", "No Title"))
            for source, target in H.edges():
                net.add_edge(source, target)
            
            # Apply layout based on the user's selection.
            if layout_option == "Force Atlas 2":
                net.force_atlas_2based()
            
            # Write the graph to a temporary HTML file; close the handle first
            # so pyvis can reopen the path by name (this also works on Windows).
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
            temp_file.close()
            net.write_html(temp_file.name)
            
            # Show the graph.
            with open(temp_file.name, 'r', encoding='utf-8') as f:
                html_content = f.read()
            components.html(html_content, height=500)
            
            # Provide a download link for the graph.
            st.markdown(get_download_link(temp_file.name), unsafe_allow_html=True)
        else:
            st.error("No valid paper rows could be parsed from the CSV input.")
    else:
        st.error("Please enter paper details for the knowledge graph.")