mgbam committed on
Commit
36ca259
·
verified ·
1 Parent(s): 7c0a761

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -44
app.py CHANGED
@@ -4,19 +4,24 @@ import networkx as nx
4
  from pyvis.network import Network
5
  import tempfile
6
  import openai
 
 
 
 
 
7
 
8
  # ---------------------------
9
  # Model Loading & Caching
10
  # ---------------------------
11
  @st.cache_resource(show_spinner=False)
12
  def load_summarizer():
13
- # Load a summarization pipeline from Hugging Face (using facebook/bart-large-cnn)
14
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
15
  return summarizer
16
 
17
  @st.cache_resource(show_spinner=False)
18
  def load_text_generator():
19
- # For a quick demo, we use a smaller text generation model (e.g., GPT-2)
20
  generator = pipeline("text-generation", model="gpt2")
21
  return generator
22
 
@@ -24,61 +29,123 @@ summarizer = load_summarizer()
24
  generator = load_text_generator()
25
 
26
  # ---------------------------
27
- # OpenAI Based Idea Generation (Streaming)
28
  # ---------------------------
 
 
 
 
 
 
29
  def generate_ideas_with_openai(prompt, api_key):
 
 
 
 
30
  openai.api_key = api_key
31
  output_text = ""
32
- # Create a chat completion request for streaming output
33
- response = openai.ChatCompletion.create(
34
- model="gpt-3.5-turbo",
35
- messages=[
36
- {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
37
- {"role": "user", "content": prompt}
38
- ],
39
- stream=True,
40
- )
41
- st_text = st.empty() # Placeholder for streaming output
42
- for chunk in response:
43
- if 'choices' in chunk and len(chunk['choices']) > 0:
44
- delta = chunk['choices'][0]['delta']
45
- if 'content' in delta:
46
- text_piece = delta['content']
47
- output_text += text_piece
48
- st_text.text(output_text)
 
 
 
49
  return output_text
50
 
51
- def generate_ideas_with_hf(prompt):
52
- # Use a Hugging Face text-generation pipeline for demo purposes.
53
- # (This may be less creative compared to GPT-3.5)
54
- results = generator(prompt, max_length=150, num_return_sequences=1)
55
- idea_text = results[0]['generated_text']
56
- return idea_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # ---------------------------
59
- # Streamlit App Layout
60
  # ---------------------------
61
- st.title("Graph of AI Ideas Application")
62
 
 
63
  st.sidebar.header("Configuration")
64
- generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
65
- ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
66
  openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
67
 
68
- # --- Section 1: Research Paper Input and Idea Generation ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  st.header("Research Paper Input")
70
  paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
71
 
72
  if st.button("Generate Ideas"):
73
  if paper_abstract.strip():
74
  st.subheader("Summarized Abstract")
75
- # Summarize the paper abstract to capture essential points
76
  summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
77
  summary_text = summary[0]['summary_text']
78
  st.write(summary_text)
79
 
80
  st.subheader("Generated Research Ideas")
81
- # Build a prompt that combines the abstract and its summary
82
  prompt = (
83
  f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
84
  f"Paper Abstract:\n{paper_abstract}\n\n"
@@ -89,7 +156,7 @@ if st.button("Generate Ideas"):
89
  if not openai_api_key.strip():
90
  st.error("Please provide your OpenAI API Key in the sidebar.")
91
  else:
92
- with st.spinner("Generating ideas using OpenAI GPT-3.5..."):
93
  ideas = generate_ideas_with_openai(prompt, openai_api_key)
94
  st.write(ideas)
95
  else:
@@ -99,21 +166,16 @@ if st.button("Generate Ideas"):
99
  else:
100
  st.error("Please enter a research paper abstract.")
101
 
102
- # --- Section 2: Knowledge Graph Visualization ---
103
  st.header("Knowledge Graph Visualization")
104
  st.markdown(
105
- "Simulate a knowledge graph by entering paper details and their citation relationships. "
106
- "Enter details in CSV format: **PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';'). "
107
  "Example:\n\n`1,Paper A,2;3`\n`2,Paper B,`\n`3,Paper C,2`"
108
  )
109
  papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
110
 
111
  if st.button("Generate Knowledge Graph"):
112
  if papers_csv.strip():
113
- import pandas as pd
114
- from io import StringIO
115
-
116
- # Process the CSV text input
117
  data = []
118
  for line in papers_csv.splitlines():
119
  parts = line.split(',')
@@ -123,9 +185,8 @@ if st.button("Generate Knowledge Graph"):
123
  cited = parts[2].strip()
124
  cited_list = [c.strip() for c in cited.split(';') if c.strip()]
125
  data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
126
-
127
  if data:
128
- # Build a directed graph
129
  G = nx.DiGraph()
130
  for paper in data:
131
  G.add_node(paper["paper_id"], title=paper["title"])
@@ -139,7 +200,7 @@ if st.button("Generate Knowledge Graph"):
139
  net.add_node(node, label=node_data["title"])
140
  for source, target in G.edges():
141
  net.add_edge(source, target)
142
- # Write and display the network as HTML in Streamlit
143
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
144
  net.write_html(temp_file.name)
145
  with open(temp_file.name, 'r', encoding='utf-8') as f:
 
4
  from pyvis.network import Network
5
  import tempfile
6
  import openai
7
+ import requests
8
+ import feedparser
9
+ import pandas as pd
10
+ from io import StringIO
11
+ import asyncio
12
 
13
  # ---------------------------
14
  # Model Loading & Caching
15
  # ---------------------------
@st.cache_resource(show_spinner=False)
def load_summarizer():
    """Build the Hugging Face summarization pipeline (facebook/bart-large-cnn).

    Wrapped in ``st.cache_resource`` with the spinner disabled, so the
    pipeline object is created through Streamlit's resource cache.

    Returns:
        The pipeline object produced by ``pipeline("summarization", ...)``.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")
21
 
@st.cache_resource(show_spinner=False)
def load_text_generator():
    """Build the Hugging Face text-generation pipeline backed by GPT-2.

    GPT-2 is used here as a demonstration model. The function is wrapped in
    ``st.cache_resource`` with the spinner disabled.

    Returns:
        The pipeline object produced by ``pipeline("text-generation", ...)``.
    """
    return pipeline("text-generation", model="gpt2")
27
 
 
29
  generator = load_text_generator()
30
 
31
  # ---------------------------
32
+ # Idea Generation Functions
33
  # ---------------------------
def generate_ideas_with_hf(prompt):
    """Generate research ideas with the module-level Hugging Face generator.

    Args:
        prompt: The full instruction text handed to the text-generation
            pipeline.

    Returns:
        The ``generated_text`` field of the first (and only) sequence the
        pipeline returns.
    """
    # ``max_length`` budgets prompt tokens and generated tokens together; the
    # prompt used by this app embeds a whole abstract plus its summary, which
    # can consume that budget before anything new is produced.
    # ``max_new_tokens`` limits only the continuation.
    results = generator(prompt, max_new_tokens=150, num_return_sequences=1)
    return results[0]["generated_text"]
39
+
def _openai_delta_text(chunk):
    """Return the text carried by one streamed chat chunk, or "" if none."""
    if isinstance(chunk, dict):
        # Legacy (<1.0) SDK: chunks behave like dictionaries.
        choices = chunk.get("choices") or []
        if not choices:
            return ""
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ""
    # SDK >= 1.0: chunks are objects with attribute access.
    choices = getattr(chunk, "choices", None) or []
    if not choices:
        return ""
    delta = getattr(choices[0], "delta", None)
    return getattr(delta, "content", None) or ""


def generate_ideas_with_openai(prompt, api_key):
    """Stream research ideas from OpenAI's gpt-3.5-turbo chat model.

    The partial answer is rendered into a Streamlit placeholder as chunks
    arrive. Both generations of the ``openai`` package are supported:
    ``openai.ChatCompletion`` was removed in SDK 1.0, so the client-based
    interface is used when ``openai.OpenAI`` exists and the legacy module-level
    call is used otherwise.

    Args:
        prompt: User message sent to the model.
        api_key: OpenAI API key. With SDK >= 1.0 it is passed to a client
            created for this call; with the legacy SDK it is assigned to
            ``openai.api_key`` because that interface requires it.

    Returns:
        The complete generated text (empty string if nothing was streamed).
    """
    messages = [
        {"role": "system", "content": "You are an expert AI research assistant who generates innovative research ideas."},
        {"role": "user", "content": prompt},
    ]

    if hasattr(openai, "OpenAI"):
        client = openai.OpenAI(api_key=api_key)
        stream = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
            stream=True,
        )
    else:
        openai.api_key = api_key
        stream = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            stream=True,
        )

    output_text = ""
    st_text = st.empty()  # Placeholder for streaming output
    for chunk in stream:
        # Some chunks carry no choices or a delta without content (e.g. the
        # role-only first delta); skip them instead of concatenating None.
        text_piece = _openai_delta_text(chunk)
        if not text_piece:
            continue
        output_text += text_piece
        st_text.text(output_text)
    return output_text
68
 
69
+ # ---------------------------
70
+ # arXiv API Integration
71
+ # ---------------------------
def fetch_arxiv_results(query, max_results=5, timeout=10):
    """Query arXiv's free API and return the matching papers.

    Args:
        query: Free-text search terms; searched across all fields (``all:``).
        max_results: Maximum number of entries to request.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of dicts with the keys ``title``, ``authors`` (comma-separated
        names), ``published``, ``summary`` and ``link``. An empty list is
        returned when the request fails, times out, or the response status is
        not 200.
    """
    # Let requests build the query string so spaces and reserved characters in
    # the user's query are percent-encoded instead of corrupting the URL.
    params = {
        "search_query": f"all:{query}",
        "start": 0,
        "max_results": max_results,
    }
    try:
        response = requests.get(
            "https://export.arxiv.org/api/query",
            params=params,
            timeout=timeout,
        )
    except requests.RequestException:
        # The caller reports an empty result as "no results or an error".
        return []
    if response.status_code != 200:
        return []

    feed = feedparser.parse(response.content)
    results = []
    for entry in feed.entries:
        # Use .get() throughout: an entry lacking a field should not abort
        # the whole search.
        author_names = [
            author.get("name")
            for author in entry.get("authors", [])
            if author.get("name")
        ]
        results.append({
            "title": entry.get("title", ""),
            "authors": ", ".join(author_names),
            "published": entry.get("published", ""),
            "summary": entry.get("summary", ""),
            "link": entry.get("link", ""),
        })
    return results
101
 
102
  # ---------------------------
103
+ # Streamlit Application Layout
104
  # ---------------------------
105
+ st.title("Graph of AI Ideas Application with arXiv Integration and OpenAI SDK v1.0")
106
 
107
+ # Sidebar Configuration
108
  st.sidebar.header("Configuration")
109
+ generation_mode = st.sidebar.selectbox("Select Idea Generation Mode",
110
+ ["Hugging Face Open Source", "OpenAI GPT-3.5 (Streaming)"])
111
  openai_api_key = st.sidebar.text_input("OpenAI API Key (for GPT-3.5 Streaming)", type="password")
112
 
113
+ # --- Section 1: arXiv Paper Search ---
114
+ st.header("arXiv Paper Search")
115
+ arxiv_query = st.text_input("Enter a search query for arXiv papers:")
116
+
117
+ if st.button("Search arXiv"):
118
+ if arxiv_query.strip():
119
+ with st.spinner("Searching arXiv..."):
120
+ results = fetch_arxiv_results(arxiv_query, max_results=5)
121
+ if results:
122
+ st.subheader("arXiv Search Results:")
123
+ for idx, paper in enumerate(results):
124
+ st.markdown(f"**{idx+1}. {paper['title']}**")
125
+ st.markdown(f"*Authors:* {paper['authors']}")
126
+ st.markdown(f"*Published:* {paper['published']}")
127
+ st.markdown(f"*Summary:* {paper['summary']}")
128
+ st.markdown(f"[Read more]({paper['link']})")
129
+ st.markdown("---")
130
+ else:
131
+ st.error("No results found or an error occurred with the arXiv API.")
132
+ else:
133
+ st.error("Please enter a valid query for the arXiv search.")
134
+
135
+ # --- Section 2: Research Paper Input and Idea Generation ---
136
  st.header("Research Paper Input")
137
  paper_abstract = st.text_area("Enter the research paper abstract:", height=200)
138
 
139
  if st.button("Generate Ideas"):
140
  if paper_abstract.strip():
141
  st.subheader("Summarized Abstract")
142
+ # Use the Hugging Face summarizer to capture key points
143
  summary = summarizer(paper_abstract, max_length=100, min_length=30, do_sample=False)
144
  summary_text = summary[0]['summary_text']
145
  st.write(summary_text)
146
 
147
  st.subheader("Generated Research Ideas")
148
+ # Build a combined prompt based on the abstract and its summary
149
  prompt = (
150
  f"Based on the following research paper abstract, generate innovative and promising research ideas for future work.\n\n"
151
  f"Paper Abstract:\n{paper_abstract}\n\n"
 
156
  if not openai_api_key.strip():
157
  st.error("Please provide your OpenAI API Key in the sidebar.")
158
  else:
159
+ with st.spinner("Generating ideas using OpenAI GPT-3.5 with SDK v1.0..."):
160
  ideas = generate_ideas_with_openai(prompt, openai_api_key)
161
  st.write(ideas)
162
  else:
 
166
  else:
167
  st.error("Please enter a research paper abstract.")
168
 
169
+ # --- Section 3: Knowledge Graph Visualization ---
170
  st.header("Knowledge Graph Visualization")
171
  st.markdown(
172
+ "Simulate a knowledge graph by entering paper details and their citation relationships in CSV format: **PaperID,Title,CitedPaperIDs** (CitedPaperIDs separated by ';').\n\n"
 
173
  "Example:\n\n`1,Paper A,2;3`\n`2,Paper B,`\n`3,Paper C,2`"
174
  )
175
  papers_csv = st.text_area("Enter paper details in CSV format:", height=150)
176
 
177
  if st.button("Generate Knowledge Graph"):
178
  if papers_csv.strip():
 
 
 
 
179
  data = []
180
  for line in papers_csv.splitlines():
181
  parts = line.split(',')
 
185
  cited = parts[2].strip()
186
  cited_list = [c.strip() for c in cited.split(';') if c.strip()]
187
  data.append({"paper_id": paper_id, "title": title, "cited": cited_list})
 
188
  if data:
189
+ # Build a directed graph using NetworkX
190
  G = nx.DiGraph()
191
  for paper in data:
192
  G.add_node(paper["paper_id"], title=paper["title"])
 
200
  net.add_node(node, label=node_data["title"])
201
  for source, target in G.edges():
202
  net.add_edge(source, target)
203
+ # Save the interactive visualization to an HTML file and embed it in Streamlit
204
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".html")
205
  net.write_html(temp_file.name)
206
  with open(temp_file.name, 'r', encoding='utf-8') as f: