# File size: 8,774 Bytes
# a1094f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import os
import json
from typing import Optional
import gradio as gr
from gradio import Interface, Blocks
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
import community as community_louvain
import pyvis
from pyvis.network import Network
from smolagents import CodeAgent, HfApiModel, tool, GradioUI
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from openinference.instrumentation.smolagents import SmolagentsInstrumentor

# Set up telemetry: traces produced by smolagents are exported over OTLP/HTTP
# to the hosted Phoenix collector configured below.
PHOENIX_API_KEY = os.getenv("PHOENIX_API_KEY")
# Header string in "api_key=<value>" form; kept as a module-level name so any
# code that reads `api_key` continues to work.
api_key = f"api_key={PHOENIX_API_KEY}"
if PHOENIX_API_KEY:
    os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = api_key
    os.environ["PHOENIX_CLIENT_HEADERS"] = api_key
else:
    # os.getenv returns None when the variable is missing; exporting the
    # literal header "api_key=None" would also clobber headers that were
    # configured externally, so leave the environment alone and say why.
    print("Warning: PHOENIX_API_KEY is not set; trace export headers were not configured.")
os.environ["PHOENIX_COLLECTOR_ENDPOINT"] = "https://app.phoenix.arize.com"

# Updated endpoint from local to cloud
endpoint = "https://app.phoenix.arize.com/v1/traces"
trace_provider = TracerProvider()
# Spans are batched before being handed to the OTLP exporter.
trace_provider.add_span_processor(BatchSpanProcessor(OTLPSpanExporter(endpoint)))
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

# Sample questions offered in the UI as one-click prompts.
_EXAMPLE_PROMPTS = (
    "Analyze the degree, betweenness, and closeness centrality metrics for all families in the network, and highlight families with the highest values for each metric.",
    "Identify families with significant betweenness centrality and discuss their potential influence in the network.",
    "Compare the top three families by degree centrality with their closeness centrality rankings, and explain the differences.",
    "Visualize the network structure, emphasizing families with high centrality values using color or size variations.",
    "Explore the roles of families with above-average centrality values across all metrics and discuss their positions in the network.",
)

# Each prompt is wrapped in its own single-element list: consumers read the
# prompt text from index 0 of every row.
examples = [[prompt] for prompt in _EXAMPLE_PROMPTS]

class GradioUIWithExamples(GradioUI):
    """GradioUI subclass that shows a question/answer form plus example prompts.

    The interface is a single ``gr.Blocks`` page: a question textbox, a
    read-only response textbox, a submit button wired to ``self.agent.run``,
    and one button per example that copies the example text into the
    question box.
    """

    def __init__(self, agent, examples=None, **kwargs):
        """Store the example prompts and delegate the rest to ``GradioUI``.

        Args:
            agent: Agent passed to ``GradioUI.__init__``; its ``run`` method is
                invoked for every submitted question.
            examples: Optional sequence of rows. Only the first element of
                each row is used, as the prompt text. ``None`` or an empty
                sequence disables the examples section.
            **kwargs: Forwarded unchanged to ``GradioUI.__init__``.
        """
        super().__init__(agent, **kwargs)
        self.examples = examples

    def build_interface(self):
        """Build and return the ``gr.Blocks`` application (not yet launched)."""
        with gr.Blocks() as demo:
            gr.Markdown("## Florentine Families Network Analysis")

            # Main Input/Output
            input_box = gr.Textbox(
                label="Your Question",
                placeholder="Type your question about the Florentine Families graph...",
            )
            output_box = gr.Textbox(
                label="Agent's Response",
                placeholder="Response will appear here...",
                interactive=False,
            )
            submit_button = gr.Button("Submit")

            # Link submit button to agent logic
            # NOTE(review): relies on the base class having stored the agent as
            # ``self.agent`` -- confirm against the installed smolagents version.
            submit_button.click(
                self.agent.run,
                inputs=input_box,
                outputs=output_box,
            )

            # Add Examples
            if self.examples:
                gr.Markdown("### Examples")
                for example in self.examples:
                    prompt = example[0]
                    # The default argument freezes the current prompt; a plain
                    # closure would make every button return the last example.
                    gr.Button(prompt).click(
                        lambda x=prompt: x,  # Populate input box
                        inputs=[],
                        outputs=input_box,
                    )
        return demo

    def launch(self, **kwargs):
        """Build the custom interface and start the Gradio server.

        Args:
            **kwargs: Forwarded to ``gr.Blocks.launch`` (for example ``share``
                or ``server_port``). Previously this override accepted no
                arguments, so any launch option raised ``TypeError``.
        """
        # Use the custom-built interface instead of the base class's logic
        demo = self.build_interface()
        demo.launch(**kwargs)

# Initialize graph
# Module-level graph created by networkx's florentine_families_graph().
# NOTE(review): no other statement in this file reads this module-level name;
# analyze_graph receives its graph as an explicit argument -- confirm how the
# agent is expected to obtain this object.
graph = nx.florentine_families_graph()
# Alternative dataset, currently disabled:
#graph = nx.les_miserables_graph()

@tool
def analyze_graph(graph: nx.Graph, metrics: Optional[str] = None, visualize: Optional[bool] = False) -> dict:
    """
    Performs an in-depth analysis of the Florentine families graph, a predefined social network representing relationships between Renaissance Florentine families. This graph has already been initialized and should be used for all analyses unless another graph is explicitly provided.

    Args:
        graph: A networkx graph object to analyze. This is a required argument.
        metrics: A comma-separated string of centrality metrics to calculate.
            Valid options include: 'degree', 'betweenness', 'closeness', 'eigenvector',
            'density', 'clustering_coefficient'. Names are matched case-insensitively and
            surrounding whitespace is ignored. If None, all metrics will be calculated.
        visualize: A boolean indicating whether to generate visualizations for the graph and its metrics.

    Returns:
        A dictionary containing:
        - 'metrics': Numerical results for the requested centrality metrics.
        - 'graph_summary': High-level statistics about the graph (number of nodes, edges, density, etc.).
        - 'community_info': Detected communities, if applicable.
        - 'visualizations': Paths to generated visualization files, if visualize is True.

    Note:
        - This tool defaults to analyzing the Florentine families graph. If a different graph is provided, it will override the default.
        - Ensure that the 'metrics' argument contains valid options to avoid errors.
    """

    if metrics:
        # Normalise so that "Degree, Betweenness" selects the same metrics as
        # "degree,betweenness"; unrecognised names are still ignored.
        requested = [name.strip().lower() for name in metrics.split(',')]
    else:
        requested = ['degree', 'betweenness', 'closeness', 'eigenvector', 'density', 'clustering_coefficient']

    # These two values appear both in the summary and (optionally) in the
    # metrics, so compute each of them only once.
    density = nx.density(graph)
    average_clustering = nx.average_clustering(graph)

    # Graph summary
    graph_summary = {
        "number_of_nodes": graph.number_of_nodes(),
        "number_of_edges": graph.number_of_edges(),
        "density": density,
        "average_clustering": average_clustering,
        "connected_components": len(list(nx.connected_components(graph))),
    }

    # Compute requested metrics (result keys are inserted in this fixed order).
    centrality_functions = (
        ('degree', 'degree_centrality', nx.degree_centrality),
        ('betweenness', 'betweenness_centrality', nx.betweenness_centrality),
        ('closeness', 'closeness_centrality', nx.closeness_centrality),
        ('eigenvector', 'eigenvector_centrality', nx.eigenvector_centrality),
    )
    computed_metrics = {}
    for metric_name, result_key, compute in centrality_functions:
        if metric_name in requested:
            computed_metrics[result_key] = compute(graph)
    if 'density' in requested:
        computed_metrics['density'] = density
    if 'clustering_coefficient' in requested:
        computed_metrics['clustering_coefficient'] = average_clustering

    # Community detection (maps each node to a community id)
    communities = community_louvain.best_partition(graph)

    # Visualizations
    visualizations = []
    if visualize:
        pos = nx.spring_layout(graph)
        figure = plt.figure(figsize=(10, 8))
        try:
            nx.draw(
                graph,
                pos,
                with_labels=True,
                node_size=700,
                # Look colours up per node in the graph's own iteration order
                # instead of trusting the partition dict to share that order.
                node_color=[communities[node] for node in graph.nodes()],
                cmap=plt.cm.rainbow,
            )
            plt.title("Graph Visualization - Communities")
            viz_path = "graph_communities.png"
            figure.savefig(viz_path)
            visualizations.append(viz_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each call would leak one figure.
            plt.close(figure)

    return {
        "metrics": computed_metrics,
        "graph_summary": graph_summary,
        "community_info": communities,
        "visualizations": visualizations if visualize else "Visualizations not generated.",
    }

@tool
def save_html_to_file(html_content: str, file_path: str) -> str:
    """
    Saves the provided HTML content to a file.

    Args:
        html_content: The HTML content to save.
        file_path: The path where the HTML file will be saved. Missing parent directories are created.

    Returns:
        A confirmation message upon successful saving.
    """
    # open() does not create intermediate directories, so a path such as
    # "output/graph.html" would fail when "output" does not exist yet.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # An existing file at file_path is overwritten; text is encoded as UTF-8.
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(html_content)
    return f"HTML content successfully saved to {file_path}"

@tool
def read_html_from_file(file_path: str) -> str:
    """
    Reads HTML content from a file.

    Args:
        file_path: The path of the HTML file to read.

    Returns:
        The HTML content as a string.
    """
    # The docstring above is consumed by the @tool decorator, so it is kept
    # verbatim. The file is decoded as UTF-8 and returned in one piece.
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()

@tool
def export_graph_to_json(graph_data: dict) -> str:
    """
    Exports a NetworkX graph represented as a dictionary in node-link format to JSON.

    Args:
        graph_data: The graph data in node-link format, with the edge list stored under the "edges" key.

    Returns:
        str: The JSON representation of the graph, using the same "edges" key as the input. If the conversion fails, an error message string is returned instead.

    """
    try:
        graph = nx.node_link_graph(graph_data, edges="edges")
        # Serialise with the same edge key that was used for parsing so the
        # output can be fed back into this function, and so the result does
        # not depend on the library's default key for node_link_data.
        json_output = json.dumps(nx.node_link_data(graph, edges="edges"), indent=4)
        return json_output
    except Exception as e:
        # Deliberately broad: failures are reported to the caller as text
        # rather than raised.
        return f"Error exporting graph to JSON: {str(e)}"

# Build the agent: a default HfApiModel driving a CodeAgent that can call the
# four tools defined in this file plus the library's base tools.
model = HfApiModel()
agent = CodeAgent(
    tools=[analyze_graph, save_html_to_file, read_html_from_file, export_graph_to_json],
    model=model,
    # Modules that agent-generated code is allowed to import.
    # "community_louvain" is only this file's alias for the module named
    # "community", so "community" is added to let agent code import the
    # Louvain package; the old entry is kept for backward compatibility.
    additional_authorized_imports=[
        "gradio",
        "networkx",
        "community",
        "community_louvain",
        "pyvis",
        "matplotlib",
        "json",
        "pandas",
    ],
    add_base_tools=True
)

# Start the web UI with the example prompts; this runs at module level.
interface = GradioUIWithExamples(agent, examples=examples)
interface.launch()