File size: 9,452 Bytes
3dbb4eb
d01c5cc
3dbb4eb
 
 
d01c5cc
36ca259
834ac1a
36ca259
 
 
d01c5cc
3dbb4eb
 
 
 
 
36ca259
3dbb4eb
 
d01c5cc
3dbb4eb
 
bbccbee
3dbb4eb
 
d01c5cc
3dbb4eb
 
d01c5cc
3dbb4eb
36ca259
3dbb4eb
36ca259
c6562a0
834ac1a
c6562a0
36ca259
 
 
3dbb4eb
36ca259
c6562a0
bbccbee
36ca259
3dbb4eb
 
36ca259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dbb4eb
 
36ca259
834ac1a
36ca259
 
 
834ac1a
36ca259
 
 
 
834ac1a
 
 
36ca259
834ac1a
36ca259
834ac1a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ca259
 
 
834ac1a
 
 
36ca259
 
 
 
3dbb4eb
 
36ca259
3dbb4eb
36ca259
d01c5cc
36ca259
3dbb4eb
36ca259
 
3dbb4eb
d01c5cc
36ca259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dbb4eb
 
 
 
 
 
834ac1a
3dbb4eb
 
 
 
 
834ac1a
3dbb4eb
 
 
 
 
 
 
 
 
 
36ca259
3dbb4eb
 
 
 
 
 
 
 
 
36ca259
3dbb4eb
 
bbccbee
 
 
d01c5cc
3dbb4eb
 
 
 
 
 
 
 
 
 
 
 
 
 
834ac1a
3dbb4eb
 
c6562a0
3dbb4eb
 
 
 
834ac1a
3dbb4eb
 
bbccbee
3dbb4eb
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import asyncio
import csv
import os
import tempfile
import xml.etree.ElementTree as ET
from io import StringIO

import networkx as nx
import openai
import pandas as pd
import requests
import streamlit as st
from pyvis.network import Network
from transformers import pipeline

# ---------------------------
# Model Loading & Caching
# ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer():
    """Build and return the Hugging Face summarization pipeline.

    The pipeline uses the ``facebook/bart-large-cnn`` checkpoint. The
    function is wrapped in ``st.cache_resource`` with the spinner disabled.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")

@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Build and return the Hugging Face text-generation pipeline.

    The pipeline uses the ``gpt2`` checkpoint (a demonstration model). The
    function is wrapped in ``st.cache_resource`` with the spinner disabled.
    """
    return pipeline("text-generation", model="gpt2")

# Module-level pipeline handles used by the rest of the script; the summarizer
# is still loaded first, then the text generator.
summarizer, generator = load_summarizer(), load_text_generator()

# ---------------------------
# Idea Generation Functions
# ---------------------------
def generate_ideas_with_hf(prompt):
    """Generate text for ``prompt`` with the module-level ``generator`` pipeline.

    A single sequence is requested, limited to 50 newly generated tokens
    (``max_new_tokens`` counts tokens produced beyond the prompt).

    Returns:
        The ``generated_text`` field of the first result.
    """
    outputs = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    return outputs[0]["generated_text"]

def generate_ideas_with_openai(prompt, api_key):
    """
    Generates research ideas using OpenAI's GPT-3.5 model with streaming.

    Works with both generations of the OpenAI Python SDK. On v1.0+ an
    ``openai.AsyncOpenAI`` client is used (``openai.ChatCompletion`` was
    removed in that release); on older releases the function falls back to
    ``openai.ChatCompletion.acreate``. The accumulated text is rendered into a
    Streamlit placeholder as each chunk arrives.

    Args:
        prompt: Text sent to the model as the user message.
        api_key: OpenAI API key used for this request.

    Returns:
        The complete generated text, or an empty string if no content was
        streamed.
    """
    model_name = "gpt-3.5-turbo"
    messages = [
        {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
        {"role": "user", "content": prompt},
    ]
    output_text = ""

    def piece_from_v1(chunk):
        # v1 chunks are typed objects; a chunk may carry no choices at all.
        return chunk.choices[0].delta.content if chunk.choices else None

    def piece_from_legacy(chunk):
        # Pre-v1 chunks are dict-like objects.
        choices = chunk["choices"]
        return choices[0].get("delta", {}).get("content") if choices else None

    async def render(response, read_piece):
        nonlocal output_text
        st_text = st.empty()  # Placeholder for streaming output
        async for chunk in response:
            # The content can be missing or None (e.g. a role-only or final
            # chunk), so coerce it to "" before concatenating.
            output_text += read_piece(chunk) or ""
            st_text.text(output_text)

    async def stream_chat():
        if hasattr(openai, "AsyncOpenAI"):
            # SDK v1.0+: pass the key to a per-call client instead of mutating
            # module-level state.
            client = openai.AsyncOpenAI(api_key=api_key)
            try:
                response = await client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    stream=True,
                )
                await render(response, piece_from_v1)
            finally:
                # Close the HTTP client before asyncio.run() tears down the loop.
                await client.close()
        else:
            # Legacy SDK (< 1.0): the key can only be set on the module.
            openai.api_key = api_key
            response = await openai.ChatCompletion.acreate(
                model=model_name,
                messages=messages,
                stream=True,
            )
            await render(response, piece_from_legacy)

    asyncio.run(stream_chat())
    return output_text

# ---------------------------
# arXiv API Integration using xml.etree.ElementTree
# ---------------------------
def _arxiv_child_text(parent, path, namespaces):
    """Return the stripped text of the first ``path`` match under ``parent``, or ""."""
    elem = parent.find(path, namespaces)
    # An element that exists but is empty has ``text`` set to None.
    if elem is None or elem.text is None:
        return ""
    return elem.text.strip()


def fetch_arxiv_results(query, max_results=5, timeout=15):
    """
    Queries arXiv's free API to fetch relevant papers using XML parsing.

    Args:
        query: Free-text search string; it is searched in all fields
            (``all:<query>``).
        max_results: Maximum number of entries to request.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of dicts with the keys ``title``, ``summary``, ``published``,
        ``link`` and ``authors`` (a comma-separated string). An empty list is
        returned when the request fails, the status code is not 200, or the
        response is not well-formed XML.
    """
    base_url = "http://export.arxiv.org/api/query"
    # Passing the values through ``params`` lets requests URL-encode them, so
    # characters such as "&", "#" or "+" in the query cannot corrupt the URL.
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }

    try:
        response = requests.get(base_url, params=params, timeout=timeout)
    except requests.RequestException:
        # Callers treat an empty list as "no results or an error occurred".
        return []
    if response.status_code != 200:
        return []

    try:
        root = ET.fromstring(response.content)
    except ET.ParseError:
        return []

    ns = {"atom": "http://www.w3.org/2005/Atom"}
    results = []
    for entry in root.findall("atom:entry", ns):
        author_names = (
            _arxiv_child_text(author, "atom:name", ns)
            for author in entry.findall("atom:author", ns)
        )
        results.append({
            "title": _arxiv_child_text(entry, "atom:title", ns),
            "summary": _arxiv_child_text(entry, "atom:summary", ns),
            "published": _arxiv_child_text(entry, "atom:published", ns),
            "link": _arxiv_child_text(entry, "atom:id", ns),
            "authors": ", ".join(name for name in author_names if name),
        })
    return results

# ---------------------------
# Streamlit Application Layout
# ---------------------------
st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")

# Sidebar: choose the idea-generation backend and supply the OpenAI key.
st.sidebar.header("Configuration")
generation_mode = st.sidebar.selectbox(
    label="Select Idea Generation Mode",
    options=["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"],
)
openai_api_key = st.sidebar.text_input(
    label="OpenAI API Key (for GPT-3.5 Streaming)",
    type="password",
)

# --- Section 1: arXiv Paper Search ---
st.header("arXiv Paper Search")
arxiv_query = st.text_input("Enter a search query for arXiv papers:")

# Run the search only when the button is pressed.
if st.button("Search arXiv"):
    if not arxiv_query.strip():
        st.error("Please enter a valid query for the arXiv search.")
    else:
        with st.spinner("Searching arXiv..."):
            results = fetch_arxiv_results(arxiv_query, max_results=5)
            if not results:
                # fetch_arxiv_results returns an empty list both for "no
                # matches" and for a failed request.
                st.error("No results found or an error occurred with the arXiv API.")
            else:
                st.subheader("arXiv Search Results:")
                for position, paper in enumerate(results, start=1):
                    # One markdown call per field, closed by a horizontal rule.
                    for markdown_line in (
                        f"**{position}. {paper['title']}**",
                        f"*Authors:* {paper['authors']}",
                        f"*Published:* {paper['published']}",
                        f"*Summary:* {paper['summary']}",
                        f"[Read more]({paper['link']})",
                        "---",
                    ):
                        st.markdown(markdown_line)

# --- Section 2: Research Paper Input and Idea Generation ---
st.header("Research Paper Input")
paper_abstract = st.text_area("Enter the research paper abstract:", height=200)

if st.button("Generate Ideas"):
    if paper_abstract.strip():
        st.subheader("Summarized Abstract")
        # Summarize the abstract to capture its key points.
        # truncation=True clips an over-long abstract to the model's maximum
        # input length instead of letting the pipeline fail on it.
        summary = summarizer(
            paper_abstract,
            max_length=100,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        summary_text = summary[0]['summary_text']
        st.write(summary_text)

        st.subheader("Generated Research Ideas")
        # Build a combined prompt with the abstract and its summary.
        prompt = (
            "Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
            f"Paper Abstract:\n{paper_abstract}\n\n"
            f"Summary:\n{summary_text}\n\n"
            "Research Ideas:"
        )
        if generation_mode == "OpenAI GPT-3.5 (Streaming)":
            if not openai_api_key.strip():
                st.error("Please provide your OpenAI API Key in the sidebar.")
            else:
                with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
                    # generate_ideas_with_openai streams the text into its own
                    # placeholder, so writing the returned value again would
                    # show the ideas twice.
                    ideas = generate_ideas_with_openai(prompt, openai_api_key)
        else:
            with st.spinner("Generating ideas using Hugging Face open source model..."):
                ideas = generate_ideas_with_hf(prompt)
                st.write(ideas)
    else:
        st.error("Please enter a research paper abstract.")

# --- Section 3: Knowledge Graph Visualization ---
st.header("Knowledge Graph Visualization")
st.markdown(
    "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format:\n\n"
    "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
    "Example:\n\n```\n1,Paper A,2;3\n2,Paper B,\n3,Paper C,2\n```"
)
papers_csv = st.text_area("Enter paper details in CSV format:", height=150)

if st.button("Generate Knowledge Graph"):
    if papers_csv.strip():
        data = []
        # csv.reader (rather than str.split) keeps a quoted title that contains
        # commas, e.g. 1,"Graphs, Trees and Ideas",2;3, in a single field.
        for parts in csv.reader(papers_csv.splitlines(), skipinitialspace=True):
            # Rows with fewer than three fields do not match the documented
            # format and are skipped.
            if len(parts) >= 3:
                paper_id = parts[0].strip()
                title = parts[1].strip()
                cited_list = [c.strip() for c in parts[2].split(';') if c.strip()]
                data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
        if data:
            # Build a directed graph using NetworkX: one node per paper and one
            # edge from each paper to every paper it cites.
            G = nx.DiGraph()
            for paper in data:
                G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
                for cited in paper["cited"]:
                    # A cited ID with no row of its own becomes a node without
                    # a title and is labelled with its ID below.
                    G.add_edge(paper["paper_id"], cited)

            st.subheader("Knowledge Graph")
            # Create an interactive visualization using Pyvis.
            net = Network(height="500px", width="100%", directed=True)
            for node, node_data in G.nodes(data=True):
                net.add_node(node, label=node_data.get("title", str(node)))
            for source, target in G.edges():
                net.add_edge(source, target)
            # Write the HTML into a temporary directory that is removed as soon
            # as the content has been read back, so no file or handle outlives
            # this run.
            with tempfile.TemporaryDirectory() as tmp_dir:
                html_path = os.path.join(tmp_dir, "knowledge_graph.html")
                net.write_html(html_path)
                with open(html_path, 'r', encoding='utf-8') as f:
                    html_content = f.read()
            st.components.v1.html(html_content, height=500)
        else:
            st.error("No valid rows found. Each line must have the form PaperID,Title,CitedPaperIDs.")
    else:
        st.error("Please enter paper details for the knowledge graph.")