dwb2023 commited on
Commit
a1094f9
·
verified ·
1 Parent(s): b619603

alternate version of the app

Browse files
Files changed (1) hide show
  1. alternate.py +228 -0
alternate.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Optional
4
+ import gradio as gr
5
+ from gradio import Interface, Blocks
6
+ import networkx as nx
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ import community as community_louvain
10
+ import pyvis
11
+ from pyvis.network import Network
12
+ from smolagents import CodeAgent, HfApiModel, tool, GradioUI
13
+ from opentelemetry import trace
14
+ from opentelemetry.sdk.trace import TracerProvider
15
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor
16
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
17
+ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
18
+
19
+ # Set up telemetry
20
+ PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
21
+ api_key = f"api_key={PHOENIX_API_KEY}"
22
+ os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = api_key
23
+ os.environ["PHOENIX_CLIENT_HEADERS"] = api_key
24
+ os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"
25
+
26
+ # Updated endpoint from local to cloud
27
+ endpoint = "https://app.phoenix.arize.com/v1/traces"
28
+ trace_provider = TracerProvider()
29
+ trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint)))
30
+ SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
31
+
32
+ examples = [
33
+ ["Analyze the degree, betweenness, and closeness centrality metrics for all families in the network, and highlight families with the highest values for each metric."],
34
+ ["Identify families with significant betweenness centrality and discuss their potential influence in the network."],
35
+ ["Compare the top three families by degree centrality with their closeness centrality rankings, and explain the differences."],
36
+ ["Visualize the network structure, emphasizing families with high centrality values using color or size variations."],
37
+ ["Explore the roles of families with above-average centrality values across all metrics and discuss their positions in the network."]
38
+ ]
39
+
40
+ class GradioUIWithExamples(GradioUI):
41
+ def __init__(self, agent, examples=None, **kwargs):
42
+ super().__init__(agent, **kwargs)
43
+ self.examples = examples
44
+
45
+ def build_interface(self):
46
+ with gr.Blocks() as demo:
47
+ gr.Markdown("## Florentine Families Network Analysis")
48
+
49
+ # Main Input/Output
50
+ input_box = gr.Textbox(
51
+ label="Your Question",
52
+ placeholder="Type your question about the Florentine Families graph...",
53
+ )
54
+ output_box = gr.Textbox(
55
+ label="Agent's Response",
56
+ placeholder="Response will appear here...",
57
+ interactive=False,
58
+ )
59
+ submit_button = gr.Button("Submit")
60
+
61
+ # Link submit button to agent logic
62
+ submit_button.click(
63
+ self.agent.run,
64
+ inputs=input_box,
65
+ outputs=output_box,
66
+ )
67
+
68
+ # Add Examples
69
+ if self.examples:
70
+ gr.Markdown("### Examples")
71
+ for example in self.examples:
72
+ gr.Button(example[0]).click(
73
+ lambda x=example[0]: x, # Populate input box
74
+ inputs=[],
75
+ outputs=input_box,
76
+ )
77
+ return demo
78
+
79
+ def launch(self):
80
+ # Use the custom-built interface instead of the base class's logic
81
+ demo = self.build_interface()
82
+ demo.launch()
83
+
84
+ # Initialize graph
85
+ graph = nx.florentine_families_graph()
86
+ #graph = nx.les_miserables_graph()
87
+
88
+ @tool
89
+ def analyze_graph(graph: nx.Graph, metrics: Optional[str] = None, visualize: Optional[bool] = False) -> dict:
90
+ """
91
+ Performs an in-depth analysis of the Florentine families graph, a predefined social network representing relationships between Renaissance Florentine families. This graph has already been initialized and should be used for all analyses unless another graph is explicitly provided.
92
+
93
+ Args:
94
+ graph: A networkx graph object to analyze. This is a required argument.
95
+ metrics: A comma-separated string of centrality metrics to calculate.
96
+ Valid options include: 'degree', 'betweenness', 'closeness', 'eigenvector',
97
+ 'density', 'clustering_coefficient'. If None, all metrics will be calculated.
98
+ visualize: A boolean indicating whether to generate visualizations for the graph and its metrics.
99
+
100
+ Returns:
101
+ A dictionary containing:
102
+ - 'metrics': Numerical results for the requested centrality metrics.
103
+ - 'graph_summary': High-level statistics about the graph (number of nodes, edges, density, etc.).
104
+ - 'community_info': Detected communities, if applicable.
105
+ - 'visualizations': Paths to generated visualization files, if visualize is True.
106
+
107
+ Note:
108
+ - This tool defaults to analyzing the Florentine families graph. If a different graph is provided, it will override the default.
109
+ - Ensure that the 'metrics' argument contains valid options to avoid errors.
110
+ """
111
+
112
+ if metrics:
113
+ metrics = [metric.strip() for metric in metrics.split(',')]
114
+ else:
115
+ metrics = ['degree', 'betweenness', 'closeness', 'eigenvector', 'density', 'clustering_coefficient']
116
+
117
+ # Graph summary
118
+ graph_summary = {
119
+ "number_of_nodes": graph.number_of_nodes(),
120
+ "number_of_edges": graph.number_of_edges(),
121
+ "density": nx.density(graph),
122
+ "average_clustering": nx.average_clustering(graph),
123
+ "connected_components": len(list(nx.connected_components(graph))),
124
+ }
125
+
126
+ # Compute requested metrics
127
+ computed_metrics = {}
128
+ if 'degree' in metrics:
129
+ computed_metrics['degree_centrality'] = nx.degree_centrality(graph)
130
+ if 'betweenness' in metrics:
131
+ computed_metrics['betweenness_centrality'] = nx.betweenness_centrality(graph)
132
+ if 'closeness' in metrics:
133
+ computed_metrics['closeness_centrality'] = nx.closeness_centrality(graph)
134
+ if 'eigenvector' in metrics:
135
+ computed_metrics['eigenvector_centrality'] = nx.eigenvector_centrality(graph)
136
+ if 'density' in metrics:
137
+ computed_metrics['density'] = nx.density(graph)
138
+ if 'clustering_coefficient' in metrics:
139
+ computed_metrics['clustering_coefficient'] = nx.average_clustering(graph)
140
+
141
+ # Community detection
142
+ communities = community_louvain.best_partition(graph)
143
+
144
+ # Visualizations
145
+ visualizations = []
146
+ if visualize:
147
+ pos = nx.spring_layout(graph)
148
+ plt.figure(figsize=(10, 8))
149
+ nx.draw(
150
+ graph,
151
+ pos,
152
+ with_labels=True,
153
+ node_size=700,
154
+ node_color=list(communities.values()),
155
+ cmap=plt.cm.rainbow,
156
+ )
157
+ plt.title("Graph Visualization - Communities")
158
+ viz_path = "graph_communities.png"
159
+ plt.savefig(viz_path)
160
+ visualizations.append(viz_path)
161
+
162
+ return {
163
+ "metrics": computed_metrics,
164
+ "graph_summary": graph_summary,
165
+ "community_info": communities,
166
+ "visualizations": visualizations if visualize else "Visualizations not generated.",
167
+ }
168
+
169
+ @tool
170
+ def save_html_to_file(html_content: str, file_path: str) -> str:
171
+ """
172
+ Saves the provided HTML content to a file.
173
+
174
+ Args:
175
+ html_content: The HTML content to save.
176
+ file_path: The path where the HTML file will be saved.
177
+
178
+ Returns:
179
+ A confirmation message upon successful saving.
180
+ """
181
+ with open(file_path, 'w', encoding='utf-8') as file:
182
+ file.write(html_content)
183
+ return f"HTML content successfully saved to {file_path}"
184
+
185
+ @tool
186
+ def read_html_from_file(file_path: str) -> str:
187
+ """
188
+ Reads HTML content from a file.
189
+
190
+ Args:
191
+ file_path: The path of the HTML file to read.
192
+
193
+ Returns:
194
+ The HTML content as a string.
195
+ """
196
+ with open(file_path, 'r', encoding='utf-8') as file:
197
+ html_content = file.read()
198
+ return html_content
199
+
200
+ @tool
201
+ def export_graph_to_json(graph_data: dict) -> str:
202
+ """
203
+ Exports a NetworkX graph represented as a dictionary in node-link format to JSON.
204
+
205
+ Args:
206
+ graph_data: The graph data in node-link format.
207
+
208
+ Returns:
209
+ str: The JSON representation of the graph.
210
+
211
+ """
212
+ try:
213
+ graph = nx.node_link_graph(graph_data, edges="edges")
214
+ json_output = json.dumps(nx.node_link_data(graph), indent=4)
215
+ return json_output
216
+ except Exception as e:
217
+ return f"Error exporting graph to JSON: {str(e)}"
218
+
219
+ model = HfApiModel()
220
+ agent = CodeAgent(
221
+ tools=[analyze_graph, save_html_to_file, read_html_from_file, export_graph_to_json],
222
+ model=model,
223
+ additional_authorized_imports=["gradio","networkx","community_louvain","pyvis","matplotlib","json", "pandas"],
224
+ add_base_tools=True
225
+ )
226
+
227
+ interface = GradioUIWithExamples(agent, examples=examples)
228
+ interface.launch()