mgbam committed on
Commit
3b19854
·
verified ·
1 Parent(s): 179c437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -33
app.py CHANGED
@@ -9,6 +9,7 @@ import xml.etree.ElementTree as ET
9
  import pandas as pd
10
  from io import StringIO
11
  import asyncio
 
12
 
13
  # ---------------------------
14
  # Model Loading & Caching
@@ -32,23 +33,19 @@ generator = load_text_generator()
32
  # Idea Generation Functions
33
  # ---------------------------
34
  def generate_ideas_with_hf(prompt):
35
- # Use Hugging Face's text-generation pipeline.
36
- # We use max_new_tokens so that new tokens are generated beyond the prompt.
37
  results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
38
  idea_text = results[0]['generated_text']
39
  return idea_text
40
 
41
  def generate_ideas_with_openai(prompt, api_key):
42
  """
43
- Generates research ideas using OpenAI's GPT-3.5 model with streaming.
44
- This function uses the latest OpenAI SDK v1.0 and asynchronous API calls.
45
  """
46
  openai.api_key = api_key
47
  output_text = ""
48
-
49
  async def stream_chat():
50
  nonlocal output_text
51
- # Asynchronously call the chat completion endpoint with streaming enabled.
52
  response = await openai.ChatCompletion.acreate(
53
  model="gpt-3.5-turbo",
54
  messages=[
@@ -63,7 +60,6 @@ def generate_ideas_with_openai(prompt, api_key):
63
  text_piece = delta.get("content", "")
64
  output_text += text_piece
65
  st_text.text(output_text)
66
-
67
  asyncio.run(stream_chat())
68
  return output_text
69
 
@@ -72,14 +68,13 @@ def generate_ideas_with_openai(prompt, api_key):
72
  # ---------------------------
73
  def fetch_arxiv_results(query, max_results=5):
74
  """
75
- Queries arXiv's free API to fetch relevant papers using XML parsing.
76
  """
77
  base_url = "http://export.arxiv.org/api/query?"
78
  search_query = "search_query=all:" + query
79
  start = "0"
80
  max_results_str = str(max_results)
81
  query_url = f"{base_url}{search_query}&start={start}&max_results={max_results_str}"
82
-
83
  response = requests.get(query_url)
84
  results = []
85
  if response.status_code == 200:
@@ -88,44 +83,48 @@ def fetch_arxiv_results(query, max_results=5):
88
  for entry in root.findall("atom:entry", ns):
89
  title_elem = entry.find("atom:title", ns)
90
  title = title_elem.text.strip() if title_elem is not None else ""
91
-
92
  summary_elem = entry.find("atom:summary", ns)
93
  summary = summary_elem.text.strip() if summary_elem is not None else ""
94
-
95
  published_elem = entry.find("atom:published", ns)
96
  published = published_elem.text.strip() if published_elem is not None else ""
97
-
98
  link_elem = entry.find("atom:id", ns)
99
  link = link_elem.text.strip() if link_elem is not None else ""
100
-
101
- authors = []
102
- for author in entry.findall("atom:author", ns):
103
- name_elem = author.find("atom:name", ns)
104
- if name_elem is not None:
105
- authors.append(name_elem.text.strip())
106
- authors_str = ", ".join(authors)
107
-
108
  results.append({
109
  "title": title,
110
  "summary": summary,
111
  "published": published,
112
  "link": link,
113
- "authors": authors_str
114
  })
115
  return results
116
  else:
117
  return []
118
 
 
 
 
 
 
 
 
 
 
 
 
119
  # ---------------------------
120
  # Streamlit Application Layout
121
  # ---------------------------
122
  st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")
123
 
124
- # Sidebar Configuration
125
  st.sidebar.header("Configuration")
126
  generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
127
  ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
128
  openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
 
129
 
130
  # --- Section 1: arXiv Paper Search ---
131
  st.header("arXiv Paper Search")
@@ -156,13 +155,10 @@ paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
156
  if st.button("Generate Ideas"):
157
  if paper_abstract.strip():
158
  st.subheader("Summarized Abstract")
159
- # Summarize the abstract to capture its key points.
160
  summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
161
  summary_text = summary[0]['summary_text']
162
  st.write(summary_text)
163
-
164
  st.subheader("Generated Research Ideas")
165
- # Build a combined prompt with the abstract and its summary.
166
  prompt = (
167
  f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
168
  f"Paper Abstract:\n{paper_abstract}\n\n"
@@ -183,13 +179,16 @@ if st.button("Generate Ideas"):
183
  else:
184
  st.error("Please enter a research paper abstract.")
185
 
186
- # --- Section 3: Knowledge Graph Visualization ---
187
  st.header("Knowledge Graph Visualization")
188
  st.markdown(
189
- "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format:\n\n"
190
  "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
191
- "Example:\n\n```\n1,Paper A,2;3\n2,Paper B,\n3,Paper C,2\n```"
192
  )
 
 
 
193
  papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
194
 
195
  if st.button("Generate Knowledge Graph"):
@@ -204,24 +203,47 @@ if st.button("Generate Knowledge Graph"):
204
  cited_list = [c.strip() for c in cited.split(';') if c.strip()]
205
  data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
206
  if data:
207
- # Build a directed graph using NetworkX.
208
  G = nx.DiGraph()
209
  for paper in data:
210
  G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
211
  for cited in paper["cited"]:
212
  G.add_edge(paper["paper_id"], cited)
213
 
 
 
 
 
 
 
 
 
 
 
214
  st.subheader("Knowledge Graph")
215
- # Create an interactive visualization using Pyvis.
216
  net = Network(height="500px", width="100%", directed=True)
217
- for node, node_data in G.nodes(data=True):
218
- net.add_node(node, label=node_data.get("title", str(node)))
219
- for source, target in G.edges():
 
 
220
  net.add_edge(source, target)
 
 
 
 
 
 
221
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
222
  net.write_html(temp_file.name)
 
 
223
  with open(temp_file.name, 'r', encoding='utf-8') as f:
224
  html_content = f.read()
225
  st.components.v1.html(html_content, height=500)
 
 
 
226
  else:
227
  st.error("Please enter paper details for the knowledge graph.")
 
9
  import pandas as pd
10
  from io import StringIO
11
  import asyncio
12
+ import base64
13
 
14
  # ---------------------------
15
  # Model Loading & Caching
 
33
  # Idea Generation Functions
34
  # ---------------------------
35
  def generate_ideas_with_hf(prompt):
36
+ # Generate ideas with the Hugging Face pipeline; max_new_tokens limits only the tokens generated beyond the prompt.
 
37
  results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
38
  idea_text = results[0]['generated_text']
39
  return idea_text
40
 
41
  def generate_ideas_with_openai(prompt, api_key):
42
  """
43
+ Generates research ideas using OpenAI's GPT-3.5 (Streaming).
 
44
  """
45
  openai.api_key = api_key
46
  output_text = ""
 
47
  async def stream_chat():
48
  nonlocal output_text
 
49
  response = await openai.ChatCompletion.acreate(
50
  model="gpt-3.5-turbo",
51
  messages=[
 
60
  text_piece = delta.get("content", "")
61
  output_text += text_piece
62
  st_text.text(output_text)
 
63
  asyncio.run(stream_chat())
64
  return output_text
65
 
 
68
  # ---------------------------
69
  def fetch_arxiv_results(query, max_results=5):
70
  """
71
+ Queries arXiv's free API and parses the result using ElementTree.
72
  """
73
  base_url = "http://export.arxiv.org/api/query?"
74
  search_query = "search_query=all:" + query
75
  start = "0"
76
  max_results_str = str(max_results)
77
  query_url = f"{base_url}{search_query}&start={start}&max_results={max_results_str}"
 
78
  response = requests.get(query_url)
79
  results = []
80
  if response.status_code == 200:
 
83
  for entry in root.findall("atom:entry", ns):
84
  title_elem = entry.find("atom:title", ns)
85
  title = title_elem.text.strip() if title_elem is not None else ""
 
86
  summary_elem = entry.find("atom:summary", ns)
87
  summary = summary_elem.text.strip() if summary_elem is not None else ""
 
88
  published_elem = entry.find("atom:published", ns)
89
  published = published_elem.text.strip() if published_elem is not None else ""
 
90
  link_elem = entry.find("atom:id", ns)
91
  link = link_elem.text.strip() if link_elem is not None else ""
92
+ authors = [author.find("atom:name", ns).text.strip()
93
+ for author in entry.findall("atom:author", ns)
94
+ if author.find("atom:name", ns) is not None]
 
 
 
 
 
95
  results.append({
96
  "title": title,
97
  "summary": summary,
98
  "published": published,
99
  "link": link,
100
+ "authors": ", ".join(authors)
101
  })
102
  return results
103
  else:
104
  return []
105
 
106
+ # ---------------------------
107
+ # Utility Function: Graph Download Link
108
+ # ---------------------------
109
+ def get_download_link(file_path, filename="graph.html"):
110
+ """Converts the HTML file to a downloadable link."""
111
+ with open(file_path, "r", encoding="utf-8") as f:
112
+ html_data = f.read()
113
+ b64 = base64.b64encode(html_data.encode()).decode()
114
+ href = f'<a href="data:text/html;base64,{b64}" download="{filename}">Download Graph as HTML</a>'
115
+ return href
116
+
117
  # ---------------------------
118
  # Streamlit Application Layout
119
  # ---------------------------
120
  st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")
121
 
122
+ # Sidebar: Configuration and Layout Options
123
  st.sidebar.header("Configuration")
124
  generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
125
  ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
126
  openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
127
+ layout_option = st.sidebar.selectbox("Select Graph Layout", ["Default", "Force Atlas 2"])
128
 
129
  # --- Section 1: arXiv Paper Search ---
130
  st.header("arXiv Paper Search")
 
155
  if st.button("Generate Ideas"):
156
  if paper_abstract.strip():
157
  st.subheader("Summarized Abstract")
 
158
  summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
159
  summary_text = summary[0]['summary_text']
160
  st.write(summary_text)
 
161
  st.subheader("Generated Research Ideas")
 
162
  prompt = (
163
  f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
164
  f"Paper Abstract:\n{paper_abstract}\n\n"
 
179
  else:
180
  st.error("Please enter a research paper abstract.")
181
 
182
+ # --- Section 3: Knowledge Graph Visualization with Additional Features ---
183
  st.header("Knowledge Graph Visualization")
184
  st.markdown(
185
+ "Enter paper details and citation relationships in CSV format:\n\n"
186
  "**PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
187
+ "Example:\n\n```\n1,Graph of AI Ideas: Leveraging Knowledge Graphs and LLMs for AI Research Idea Generation,2;3\n2,Fundamental Approaches in AI Literature,\n3,Applications of LLMs in Research Idea Generation,2\n```"
188
  )
189
+ # Optional filter input for node titles.
190
+ filter_text = st.text_input("Optional: Enter keyword to filter nodes in the graph:")
191
+
192
  papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
193
 
194
  if st.button("Generate Knowledge Graph"):
 
203
  cited_list = [c.strip() for c in cited.split(';') if c.strip()]
204
  data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
205
  if data:
206
+ # Build the full graph.
207
  G = nx.DiGraph()
208
  for paper in data:
209
  G.add_node(paper["paper_id"], title=paper.get("title", str(paper["paper_id"])))
210
  for cited in paper["cited"]:
211
  G.add_edge(paper["paper_id"], cited)
212
 
213
+ # Filter nodes if a keyword is provided.
214
+ if filter_text.strip():
215
+ filtered_nodes = [n for n, d in G.nodes(data=True) if filter_text.lower() in d.get("title", "").lower()]
216
+ if filtered_nodes:
217
+ H = G.subgraph(filtered_nodes).copy()
218
+ else:
219
+ H = nx.DiGraph()
220
+ else:
221
+ H = G
222
+
223
  st.subheader("Knowledge Graph")
224
+ # Create the Pyvis network.
225
  net = Network(height="500px", width="100%", directed=True)
226
+
227
+ # Add nodes with tooltips (show title on hover).
228
+ for node, node_data in H.nodes(data=True):
229
+ net.add_node(node, label=node_data.get("title", str(node)), title=node_data.get("title", "No Title"))
230
+ for source, target in H.edges():
231
  net.add_edge(source, target)
232
+
233
+ # Apply layout based on the user's selection.
234
+ if layout_option == "Force Atlas 2":
235
+ net.force_atlas_2based()
236
+
237
+ # Write graph to temporary HTML file.
238
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
239
  net.write_html(temp_file.name)
240
+
241
+ # Show the graph.
242
  with open(temp_file.name, 'r', encoding='utf-8') as f:
243
  html_content = f.read()
244
  st.components.v1.html(html_content, height=500)
245
+
246
+ # Provide a download link for the graph.
247
+ st.markdown(get_download_link(temp_file.name), unsafe_allow_html=True)
248
  else:
249
  st.error("Please enter paper details for the knowledge graph.")