Spaces:
Running
Running
"""
**********
Matplotlib
**********

Draw networks with matplotlib.

Examples
--------
>>> G = nx.complete_graph(5)
>>> nx.draw(G)

See Also
--------
 - :doc:`matplotlib <matplotlib:index>`
 - :func:`matplotlib.pyplot.scatter`
 - :obj:`matplotlib.patches.FancyArrowPatch`
"""
from numbers import Number | |
import networkx as nx | |
from networkx.drawing.layout import ( | |
circular_layout, | |
kamada_kawai_layout, | |
planar_layout, | |
random_layout, | |
shell_layout, | |
spectral_layout, | |
spring_layout, | |
) | |
# Names made available by a star-import of this module.
__all__ = [
    "draw",
    "draw_networkx",
    "draw_networkx_nodes",
    "draw_networkx_edges",
    "draw_networkx_labels",
    "draw_networkx_edge_labels",
    "draw_circular",
    "draw_kamada_kawai",
    "draw_random",
    "draw_spectral",
    "draw_spring",
    "draw_planar",
    "draw_shell",
]
def draw(G, pos=None, ax=None, **kwds):
    """Draw the graph G with Matplotlib.

    Produce a bare-bones picture of the graph: the figure background is
    set to white, the axis is switched off, and node labels are drawn
    only when requested.  When no axes is supplied and the current
    figure has none, a new axes filling the whole figure is created.
    See draw_networkx() for more full-featured drawing that allows
    title, axis labels etc.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    ax : Matplotlib Axes object, optional
        Draw the graph in specified Matplotlib axes.

    kwds : optional keywords
        See networkx.draw_networkx() for a description of optional keywords.
        If ``with_labels`` is not given, labels are drawn only when a
        ``labels`` keyword is present.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    See Also
    --------
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels

    Notes
    -----
    This function has the same name as pylab.draw and pyplot.draw
    so beware when using `from networkx import *`
    since you might overwrite the pylab.draw function.

    With pyplot use

    >>> import matplotlib.pyplot as plt
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)  # networkx draw()
    >>> plt.draw()  # pyplot draw()

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html
    """
    import matplotlib.pyplot as plt

    # Resolve the figure first: the current one, or the one owning `ax`.
    figure = plt.gcf() if ax is None else ax.get_figure()
    figure.set_facecolor("w")

    if ax is None:
        # Reuse the figure's current axes when it has any; otherwise add
        # one that spans the full figure area.
        ax = figure.gca() if figure.axes else figure.add_axes((0, 0, 1, 1))

    # Labels are opt-in here: only on when the caller passed `labels`.
    kwds.setdefault("with_labels", "labels" in kwds)

    draw_networkx(G, pos=pos, ax=ax, **kwds)
    ax.set_axis_off()
    plt.draw_if_interactive()
def draw_networkx(G, pos=None, arrows=None, with_labels=True, **kwds):
    r"""Draw the graph G using Matplotlib.

    Draw the graph with Matplotlib with options for node positions,
    labeling, titles, and many other drawing features.
    See draw() for simple drawing without labels or axes.

    The extra keywords are validated against the signatures of
    draw_networkx_nodes(), draw_networkx_edges() and
    draw_networkx_labels(), and each keyword is forwarded to every one of
    those functions that accepts it.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary, optional
        A dictionary with nodes as keys and positions as values.
        If not specified a spring layout positioning will be computed.
        See :py:mod:`networkx.drawing.layout` for functions that
        compute node positions.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).
        Note: Arrows will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs, choose the style of the arrowheads.
        For undirected graphs default to '-'

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width. A list of values can be passed in to assign a different size
        to each arrow.

        See `matplotlib.patches.FancyArrowPatch` for attribute `mutation_scale`
        for more info.

    with_labels : bool (default=True)
        Set to True to draw labels on the nodes.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    nodelist : list (default=list(G))
        Draw only specified nodes

    edgelist : list (default=list(G.edges()))
        Draw only specified edges

    node_size : scalar or array (default=300)
        Size of nodes.  If an array is specified it must be the
        same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node.  Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or None (default=None)
        The node and edge transparency

    cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of nodes

    vmin,vmax : float, optional
        Minimum and maximum for node colormap scaling

    linewidths : scalar or sequence, optional
        Line width of symbol border

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    style : string (default='solid')
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    labels : dictionary (default=None)
        Node labels in a dictionary of text labels keyed by node

    font_size : int (default=12)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    label : string, optional
        Label for graph legend

    kwds : optional keywords
        See networkx.draw_networkx_nodes(), networkx.draw_networkx_edges(), and
        networkx.draw_networkx_labels() for a description of optional keywords.

    Raises
    ------
    ValueError
        If a keyword is passed that none of draw_networkx_nodes(),
        draw_networkx_edges() or draw_networkx_labels() accepts.

    Notes
    -----
    For directed graphs, arrows  are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nx.draw(G)
    >>> nx.draw(G, pos=nx.spring_layout(G))  # use spring layout

    >>> import matplotlib.pyplot as plt
    >>> limits = plt.axis("off")  # turn off axis

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from inspect import signature

    import matplotlib.pyplot as plt

    # Parameter names understood by each of the three worker functions.
    node_params = set(signature(draw_networkx_nodes).parameters)
    edge_params = set(signature(draw_networkx_edges).parameters)
    label_params = set(signature(draw_networkx_labels).parameters)

    # Everything the workers accept, minus this function's own arguments.
    own_params = {"G", "pos", "arrows", "with_labels"}
    accepted = node_params.union(edge_params, label_params).difference(own_params)

    rejected = [name for name in kwds if name not in accepted]
    if rejected:
        invalid_args = ", ".join(rejected)
        raise ValueError(f"Received invalid argument(s): {invalid_args}")

    def _subset(param_names):
        # Keep only the keywords a given worker function understands.
        return {name: value for name, value in kwds.items() if name in param_names}

    node_kwds = _subset(node_params)
    edge_kwds = _subset(edge_params)
    label_kwds = _subset(label_params)

    if pos is None:
        pos = nx.drawing.spring_layout(G)  # default to spring layout

    draw_networkx_nodes(G, pos, **node_kwds)
    draw_networkx_edges(G, pos, arrows=arrows, **edge_kwds)
    if with_labels:
        draw_networkx_labels(G, pos, **label_kwds)
    plt.draw_if_interactive()
def draw_networkx_nodes(
    G,
    pos,
    nodelist=None,
    node_size=300,
    node_color="#1f78b4",
    node_shape="o",
    alpha=None,
    cmap=None,
    vmin=None,
    vmax=None,
    ax=None,
    linewidths=None,
    edgecolors=None,
    label=None,
    margins=None,
):
    """Draw the nodes of the graph G.

    Only the nodes are drawn, as a single scatter plot on the axes;
    edges and labels are left to the companion functions.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.
        The current axes is used when omitted.

    nodelist : list (default list(G))
        Draw only specified nodes

    node_size : scalar or array (default=300)
        Size of nodes.  If an array it must be the same length as nodelist.

    node_color : color or array of colors (default='#1f78b4')
        Node color. Can be a single color or a sequence of colors with the same
        length as nodelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the cmap and vmin,vmax parameters. See
        matplotlib.scatter for more details.

    node_shape : string (default='o')
        The shape of the node.  Specification is as matplotlib.scatter
        marker, one of 'so^>v<dph8'.

    alpha : float or array of floats (default=None)
        The node transparency.  This can be a single alpha value,
        in which case it will be applied to all the nodes of color. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    cmap : Matplotlib colormap (default=None)
        Colormap for mapping intensities of nodes

    vmin,vmax : floats or None (default=None)
        Minimum and maximum for node colormap scaling

    linewidths : [None | scalar | sequence] (default=None)
        Line width of symbol border.  Passed to the scatter call unchanged.

    edgecolors : [None | scalar | sequence] (default=None)
        Colors of node borders. Can be a single color or a sequence of colors with the
        same length as nodelist. Color can be string or rgb (or rgba) tuple of floats
        from 0-1. If numeric values are specified they will be mapped to colors
        using the cmap and vmin,vmax parameters. See `~matplotlib.pyplot.scatter` for more details.

    label : [None | string]
        Label for legend

    margins : float or 2-tuple, optional
        Sets the padding for axis autoscaling. Increase margin to prevent
        clipping for nodes that are near the edges of an image. Values should
        be in the range ``[0, 1]``. See :meth:`matplotlib.axes.Axes.margins`
        for details. The default is `None`, which uses the Matplotlib default.

    Returns
    -------
    matplotlib.collections.PathCollection
        `PathCollection` of the nodes.  For an empty `nodelist` an empty
        collection is returned and nothing is drawn.

    Raises
    ------
    NetworkXError
        If a node in `nodelist` has no entry in `pos`.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> nodes = nx.draw_networkx_nodes(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_edges
    draw_networkx_labels
    draw_networkx_edge_labels
    """
    from collections.abc import Iterable

    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if nodelist is None:
        nodelist = list(G)

    if len(nodelist) == 0:  # empty nodelist, no drawing
        return mpl.collections.PathCollection(None)

    try:
        coordinates = np.asarray([pos[node] for node in nodelist])
    except KeyError as err:
        raise nx.NetworkXError(f"Node {err} has no position.") from err

    if isinstance(alpha, Iterable):
        # Per-node alphas are folded into the colors, so the scatter call
        # itself must not apply a second, global alpha.
        node_color = apply_alpha(node_color, alpha, nodelist, cmap, vmin, vmax)
        alpha = None

    scatter_options = {
        "s": node_size,
        "c": node_color,
        "marker": node_shape,
        "cmap": cmap,
        "vmin": vmin,
        "vmax": vmax,
        "alpha": alpha,
        "linewidths": linewidths,
        "edgecolors": edgecolors,
        "label": label,
    }
    node_collection = ax.scatter(
        coordinates[:, 0], coordinates[:, 1], **scatter_options
    )

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    if margins is not None:
        # A scalar is passed as one argument, an iterable is unpacked.
        margin_args = margins if isinstance(margins, Iterable) else (margins,)
        ax.margins(*margin_args)

    node_collection.set_zorder(2)
    return node_collection
def draw_networkx_edges(
    G,
    pos,
    edgelist=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=None,
    arrowstyle=None,
    arrowsize=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=None,
    label=None,
    node_size=300,
    nodelist=None,
    node_shape="o",
    connectionstyle="arc3",
    min_source_margin=0,
    min_target_margin=0,
):
    r"""Draw the edges of the graph G.

    This draws only the edges of the graph G.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edgelist : collection of edge tuples (default=G.edges())
        Draw only specified edges

    width : float or array of floats (default=1.0)
        Line width of edges

    edge_color : color or array of colors (default='k')
        Edge color. Can be a single color or a sequence of colors with the same
        length as edgelist. Color can be string or rgb (or rgba) tuple of
        floats from 0-1. If numeric values are specified they will be
        mapped to colors using the edge_cmap and edge_vmin,edge_vmax parameters.

    style : string or array of strings (default='solid')
        Edge line style e.g.: '-', '--', '-.', ':'
        or words like 'solid' or 'dashed'.
        Can be a single style or a sequence of styles with the same
        length as the edge list.
        If less styles than edges are given the styles will cycle.
        If more styles than edges are given the styles will be used sequentially
        and not be exhausted.
        Also, `(offset, onoffseq)` tuples can be used as style instead of a strings.
        (See `matplotlib.patches.FancyArrowPatch`: `linestyle`)

    alpha : float or array of floats (default=None)
        The edge transparency.  This can be a single alpha value,
        in which case it will be applied to all specified edges. Otherwise,
        if it is an array, the elements of alpha will be applied to the colors
        in order (cycling through alpha multiple times if necessary).

    edge_cmap : Matplotlib colormap, optional
        Colormap for mapping intensities of edges

    edge_vmin,edge_vmax : floats, optional
        Minimum and maximum for edge colormap scaling

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.

    arrows : bool or None, optional (default=None)
        If `None`, directed graphs draw arrowheads with
        `~matplotlib.patches.FancyArrowPatch`, while undirected graphs draw edges
        via `~matplotlib.collections.LineCollection` for speed.
        If `True`, draw arrowheads with FancyArrowPatches (bendable and stylish).
        If `False`, draw edges using LineCollection (linear and fast).

        Note: Arrowheads will be the same color as edges.

    arrowstyle : str (default='-\|>' for directed graphs)
        For directed graphs and `arrows==True` defaults to '-\|>',
        For undirected graphs default to '-'.

        See `matplotlib.patches.ArrowStyle` for more options.

    arrowsize : int or list (default=10)
        For directed graphs, choose the size of the arrow head's length and
        width.  A list must have the same length as `edgelist` and assigns
        one size per edge.  See `matplotlib.patches.FancyArrowPatch` for
        attribute `mutation_scale` for more info.

    connectionstyle : string (default="arc3")
        Pass the connectionstyle parameter to create curved arc of rounding
        radius rad. For example, connectionstyle='arc3,rad=0.2'.
        See `matplotlib.patches.ConnectionStyle` and
        `matplotlib.patches.FancyArrowPatch` for more info.

    node_size : scalar or array (default=300)
        Size of nodes. Though the nodes are not drawn with this function, the
        node size is used in determining edge positioning.

    nodelist : list, optional (default=G.nodes())
       This provides the node order for the `node_size` array (if it is an array).

    node_shape : string (default='o')
        The marker used for nodes, used in determining edge positioning.
        Specification is as a `matplotlib.markers` marker, e.g. one of 'so^>v<dph8'.

    label : None or string
        Label for legend

    min_source_margin : int (default=0)
        The minimum margin (gap) at the beginning of the edge at the source.

    min_target_margin : int (default=0)
        The minimum margin (gap) at the end of the edge at the target.

    Returns
    -------
     matplotlib.collections.LineCollection or a list of matplotlib.patches.FancyArrowPatch
        If ``arrows=True``, a list of FancyArrowPatches is returned.
        If ``arrows=False``, a LineCollection is returned.
        If ``arrows=None`` (the default), then a LineCollection is returned if
        `G` is undirected, otherwise returns a list of FancyArrowPatches.
        An empty list is returned when there are no edges to draw.

    Raises
    ------
    ValueError
        If `arrowsize` is a list whose length differs from `edgelist`.

    Warns
    -----
    UserWarning
        One warning for each of `arrowstyle`, `arrowsize`,
        `connectionstyle`, `min_source_margin` and `min_target_margin`
        that is set to a non-default value while the edges are drawn with
        a LineCollection, which ignores these options.

    Notes
    -----
    For directed graphs, arrows are drawn at the head end.  Arrows can be
    turned off with keyword arrows=False or by passing an arrowstyle without
    an arrow on the end.

    Be sure to include `node_size` as a keyword argument; arrows are
    drawn considering the size of nodes.

    Self-loops are always drawn with `~matplotlib.patches.FancyArrowPatch`
    regardless of the value of `arrows` or whether `G` is directed.
    When ``arrows=False`` or ``arrows=None`` and `G` is undirected, the
    FancyArrowPatches corresponding to the self-loops are not explicitly
    returned. They should instead be accessed via the ``Axes.patches``
    attribute (see examples).

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edges = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))

    >>> G = nx.DiGraph()
    >>> G.add_edges_from([(1, 2), (1, 3), (2, 3)])
    >>> arcs = nx.draw_networkx_edges(G, pos=nx.spring_layout(G))
    >>> alphas = [0.3, 0.4, 0.5]
    >>> for i, arc in enumerate(arcs):  # change alpha values of arcs
    ...     arc.set_alpha(alphas[i])

    The FancyArrowPatches corresponding to self-loops are not always
    returned, but can always be accessed via the ``patches`` attribute of the
    `matplotlib.Axes` object.

    >>> import matplotlib.pyplot as plt
    >>> fig, ax = plt.subplots()
    >>> G = nx.Graph([(0, 1), (0, 0)])  # Self-loop at node 0
    >>> edge_collection = nx.draw_networkx_edges(G, pos=nx.circular_layout(G), ax=ax)
    >>> self_loop_fap = ax.patches[0]

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_labels
    draw_networkx_edge_labels

    """
    import matplotlib as mpl
    import matplotlib.collections  # call as mpl.collections
    import matplotlib.colors  # call as mpl.colors
    import matplotlib.patches  # call as mpl.patches
    import matplotlib.path  # call as mpl.path
    import matplotlib.pyplot as plt
    import numpy as np

    # The default behavior is to use LineCollection to draw edges for
    # undirected graphs (for performance reasons) and use FancyArrowPatches
    # for directed graphs.
    # The `arrows` keyword can be used to override the default behavior
    use_linecollection = not G.is_directed()
    if arrows in (True, False):
        use_linecollection = not arrows

    # Some kwargs only apply to FancyArrowPatches. Warn users when they use
    # non-default values for these kwargs when LineCollection is being used
    # instead of silently ignoring the specified option
    if use_linecollection:
        ignored_kwargs = [
            name
            for name, is_overridden in (
                ("arrowstyle", arrowstyle is not None),
                ("arrowsize", arrowsize != 10),
                ("connectionstyle", connectionstyle != "arc3"),
                ("min_source_margin", min_source_margin != 0),
                ("min_target_margin", min_target_margin != 0),
            )
            if is_overridden
        ]
        if ignored_kwargs:
            import warnings

            msg = (
                "\n\nThe {0} keyword argument is not applicable when drawing edges\n"
                "with LineCollection.\n\n"
                "To make this warning go away, either specify `arrows=True` to\n"
                "force FancyArrowPatches or use the default value for {0}.\n"
                "Note that using FancyArrowPatches may be slow for large graphs.\n"
            )
            # Format the template afresh for every offending keyword so each
            # one is reported; re-formatting an already formatted message
            # would only ever name the first.
            for name in ignored_kwargs:
                warnings.warn(msg.format(name), category=UserWarning, stacklevel=2)

    if arrowstyle is None:
        if G.is_directed():
            arrowstyle = "-|>"
        else:
            arrowstyle = "-"

    if ax is None:
        ax = plt.gca()

    if edgelist is None:
        edgelist = list(G.edges())

    if len(edgelist) == 0:  # no edges!
        return []

    if nodelist is None:
        nodelist = list(G.nodes())

    # FancyArrowPatch handles color=None different from LineCollection
    if edge_color is None:
        edge_color = "k"
    # Tuple form of every edge, used for membership and index lookups below
    edgelist_tuple = list(map(tuple, edgelist))

    # set edge positions
    edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edgelist])

    # Check if edge_color is an array of floats and map to edge_cmap.
    # This is the only case handled differently from matplotlib
    if (
        np.iterable(edge_color)
        and (len(edge_color) == len(edge_pos))
        and np.all([isinstance(c, Number) for c in edge_color])
    ):
        if edge_cmap is not None:
            assert isinstance(edge_cmap, mpl.colors.Colormap)
        else:
            edge_cmap = plt.get_cmap()
        if edge_vmin is None:
            edge_vmin = min(edge_color)
        if edge_vmax is None:
            edge_vmax = max(edge_color)
        color_normal = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
        edge_color = [edge_cmap(color_normal(e)) for e in edge_color]

    def _draw_networkx_edges_line_collection():
        # Draw every edge in `edge_pos` as one LineCollection added to `ax`.
        edge_collection = mpl.collections.LineCollection(
            edge_pos,
            colors=edge_color,
            linewidths=width,
            antialiaseds=(1,),
            linestyle=style,
            alpha=alpha,
        )
        edge_collection.set_cmap(edge_cmap)
        edge_collection.set_clim(edge_vmin, edge_vmax)
        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
        ax.add_collection(edge_collection)

        return edge_collection

    def _draw_networkx_edges_fancy_arrow_patch():
        # Draw the edges in `edge_pos` as individual FancyArrowPatches.
        # `fancy_edges_indices` gives, for each entry of `edge_pos`, the
        # position of that edge in `edgelist`.

        # Note: Waiting for someone to implement arrow to intersection with
        # marker.  Meanwhile, this works well for polygons with more than 4
        # sides and circle.

        def to_marker_edge(marker_size, marker):
            if marker in "s^>v<d":  # `large` markers need extra space
                return np.sqrt(2 * marker_size) / 2
            else:
                return np.sqrt(marker_size) / 2

        # Draw arrows with `matplotlib.patches.FancyarrowPatch`
        arrow_collection = []

        if isinstance(arrowsize, list):
            # The list is indexed by position in `edgelist` (see the loop
            # below), so validate against `edgelist`: `edge_pos` may have
            # been narrowed to the self-loops only.
            if len(arrowsize) != len(edgelist):
                raise ValueError("arrowsize should have the same length as edgelist")
        else:
            mutation_scale = arrowsize  # scale factor of arrow head

        base_connection_style = mpl.patches.ConnectionStyle(connectionstyle)

        # Fallback for self-loop scale. Left outside of _connectionstyle so it is
        # only computed once
        max_nodesize = np.array(node_size).max()

        def _connectionstyle(posA, posB, *args, **kwargs):
            # check if we need to do a self-loop
            if np.all(posA == posB):
                # Self-loops are scaled by view extent, except in cases the extent
                # is 0, e.g. for a single node. In this case, fall back to scaling
                # by the maximum node size
                selfloop_ht = 0.005 * max_nodesize if h == 0 else h
                # this is called with _screen space_ values so convert back
                # to data space
                data_loc = ax.transData.inverted().transform(posA)
                v_shift = 0.1 * selfloop_ht
                h_shift = v_shift * 0.5
                # put the top of the loop first so arrow is not hidden by node
                path = [
                    # 1
                    data_loc + np.asarray([0, v_shift]),
                    # 4 4 4
                    data_loc + np.asarray([h_shift, v_shift]),
                    data_loc + np.asarray([h_shift, 0]),
                    data_loc,
                    # 4 4 4
                    data_loc + np.asarray([-h_shift, 0]),
                    data_loc + np.asarray([-h_shift, v_shift]),
                    data_loc + np.asarray([0, v_shift]),
                ]

                ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
            # if not, fall back to the user specified behavior
            else:
                ret = base_connection_style(posA, posB, *args, **kwargs)

            return ret

        # FancyArrowPatch doesn't handle color strings
        arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
        for i, (src, dst) in zip(fancy_edges_indices, edge_pos):
            x1, y1 = src
            x2, y2 = dst
            shrink_source = 0  # space from source to tail
            shrink_target = 0  # space from head to target

            if isinstance(arrowsize, list):
                # Scale each factor of each arrow based on arrowsize list
                mutation_scale = arrowsize[i]

            if np.iterable(node_size):  # many node sizes
                source, target = edgelist[i][:2]
                source_node_size = node_size[nodelist.index(source)]
                target_node_size = node_size[nodelist.index(target)]
                shrink_source = to_marker_edge(source_node_size, node_shape)
                shrink_target = to_marker_edge(target_node_size, node_shape)
            else:
                shrink_source = shrink_target = to_marker_edge(node_size, node_shape)

            if shrink_source < min_source_margin:
                shrink_source = min_source_margin

            if shrink_target < min_target_margin:
                shrink_target = min_target_margin

            if len(arrow_colors) > i:
                arrow_color = arrow_colors[i]
            elif len(arrow_colors) == 1:
                arrow_color = arrow_colors[0]
            else:  # Cycle through colors
                arrow_color = arrow_colors[i % len(arrow_colors)]

            if np.iterable(width):
                if len(width) > i:
                    line_width = width[i]
                else:
                    line_width = width[i % len(width)]
            else:
                line_width = width

            if (
                np.iterable(style)
                and not isinstance(style, str)
                and not isinstance(style, tuple)
            ):
                if len(style) > i:
                    linestyle = style[i]
                else:  # Cycle through styles
                    linestyle = style[i % len(style)]
            else:
                linestyle = style

            arrow = mpl.patches.FancyArrowPatch(
                (x1, y1),
                (x2, y2),
                arrowstyle=arrowstyle,
                shrinkA=shrink_source,
                shrinkB=shrink_target,
                mutation_scale=mutation_scale,
                color=arrow_color,
                linewidth=line_width,
                connectionstyle=_connectionstyle,
                linestyle=linestyle,
                zorder=1,
            )  # arrows go behind nodes

            arrow_collection.append(arrow)
            ax.add_patch(arrow)

        return arrow_collection

    # compute initial view (must happen before `edge_pos` is narrowed to
    # the self-loops below; `h` is also read by `_connectionstyle`)
    minx = np.amin(np.ravel(edge_pos[:, :, 0]))
    maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
    miny = np.amin(np.ravel(edge_pos[:, :, 1]))
    maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
    w = maxx - minx
    h = maxy - miny

    # Draw the edges
    if use_linecollection:
        edge_viz_obj = _draw_networkx_edges_line_collection()
        # Make sure selfloop edges are also drawn.
        # Membership is tested against the tuple form of the edges so that
        # it agrees with the `.index` lookup below even when `edgelist`
        # holds lists rather than tuples.
        selfloops_to_draw = [
            loop for loop in nx.selfloop_edges(G) if loop in edgelist_tuple
        ]
        if selfloops_to_draw:
            fancy_edges_indices = [
                edgelist_tuple.index(loop) for loop in selfloops_to_draw
            ]
            edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in selfloops_to_draw])
            arrowstyle = "-"
            _draw_networkx_edges_fancy_arrow_patch()
    else:
        fancy_edges_indices = range(len(edgelist))
        edge_viz_obj = _draw_networkx_edges_fancy_arrow_patch()

    # update view after drawing
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return edge_viz_obj
def draw_networkx_labels(
    G,
    pos,
    labels=None,
    font_size=12,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    clip_on=True,
):
    """Draw node labels on the graph G.

    One Matplotlib text object is placed at the position of every
    labelled node.

    Parameters
    ----------
    G : graph
        A networkx graph

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    labels : dictionary (default={n: n for n in G})
        Node labels in a dictionary of text labels keyed by node.
        Node-keys in labels should appear as keys in `pos`.
        If needed use: `{n:lab for n,lab in labels.items() if n in pos}`
        Labels that are not strings are converted with ``str``.

    font_size : int (default=12)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None (default=None)
        The text transparency

    bbox : Matplotlib bbox, (default is Matplotlib's ax.text default)
        Specify text box properties (e.g. shape, color etc.) for node labels.

    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes.
        The current axes is used when omitted.

    clip_on : bool (default=True)
        Turn on clipping of node labels at axis boundaries

    Returns
    -------
    dict
        `dict` of labels keyed on the nodes

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> labels = nx.draw_networkx_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_edge_labels
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()

    if labels is None:
        labels = {node: node for node in G.nodes()}

    # Text properties shared by every label.
    text_options = {
        "size": font_size,
        "color": font_color,
        "family": font_family,
        "weight": font_weight,
        "alpha": alpha,
        "horizontalalignment": horizontalalignment,
        "verticalalignment": verticalalignment,
        "transform": ax.transData,
        "bbox": bbox,
        "clip_on": clip_on,
    }

    # Matplotlib has no text collection, so a dict keyed by node stands in.
    text_items = {}
    for node, raw_label in labels.items():
        x, y = pos[node]
        # Stringify so that "1" and 1 are labeled the same.
        text = raw_label if isinstance(raw_label, str) else str(raw_label)
        text_items[node] = ax.text(x, y, text, **text_options)

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
):
    """Draw text labels on edges.

    Each label is placed on the straight segment joining the positions of
    the two end nodes of its edge and, optionally, rotated so that it runs
    parallel to that segment.

    Parameters
    ----------
    G : graph
        A networkx graph. It is only consulted when `edge_labels` is None.

    pos : dictionary
        A dictionary with nodes as keys and positions as values.
        Positions should be sequences of length 2.

    edge_labels : dictionary (default=None)
        Labels keyed by edge two-tuple ``(u, v)``. Only the edges that
        appear as keys are labelled. If None, every edge yielded by
        ``G.edges(data=True)`` is labelled with its data.

    label_pos : float (default=0.5)
        Position of the label along the edge ``(u, v)``: 0 puts the label
        at ``pos[v]`` (head), 0.5 at the midpoint and 1 at ``pos[u]`` (tail).

    font_size : int (default=10)
        Font size for text labels

    font_color : color (default='k' black)
        Font color string. Color can be string or rgb (or rgba) tuple of
        floats from 0-1.

    font_weight : string (default='normal')
        Font weight

    font_family : string (default='sans-serif')
        Font family

    alpha : float or None (default=None)
        The text transparency

    bbox : Matplotlib bbox, optional
        Text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)},
        i.e. a white box with a white border.

    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}

    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}

    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes. If None, the
        current axes (``plt.gca()``) are used.

    rotate : bool (default=True)
        Rotate edge labels to lie parallel to edges

    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries

    Returns
    -------
    dict
        `dict` of the text objects returned by ``ax.text``, keyed by edge
        two-tuple ``(u, v)``

    Raises
    ------
    NetworkXError
        If the first key of the label dictionary cannot be unpacked into
        exactly two nodes, which is what happens for multiedge keys.

    Notes
    -----
    Ticks and tick labels of `ax` are switched off on both axes.

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G))

    Also see the NetworkX drawing examples at
    https://networkx.org/documentation/latest/auto_examples/index.html

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()

    if edge_labels is not None:
        labels = edge_labels
    else:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}

    # Probe only the first key: a key that does not unpack into two nodes
    # (e.g. a multiedge ``(u, v, key)``) gets an informative exception.
    for first_edge in labels:
        try:
            _tail, _head = first_edge
        except ValueError as err:
            raise nx.NetworkXError(
                "draw_networkx_edge_labels does not support multiedges."
            ) from err
        break

    # Fall back to a white box with a white border behind each label.
    if bbox is None:
        bbox = {"boxstyle": "round", "ec": (1.0, 1.0, 1.0), "fc": (1.0, 1.0, 1.0)}

    # Styling shared by every label.
    shared_style = {
        "size": font_size,
        "color": font_color,
        "family": font_family,
        "weight": font_weight,
        "alpha": alpha,
        "horizontalalignment": horizontalalignment,
        "verticalalignment": verticalalignment,
        "bbox": bbox,
        "zorder": 1,
        "clip_on": clip_on,
    }

    text_items = {}
    for (src, dst), label in labels.items():
        x1, y1 = pos[src]
        x2, y2 = pos[dst]
        # Linear interpolation between the two end points.
        x = x1 * label_pos + x2 * (1.0 - label_pos)
        y = y1 * label_pos + y2 * (1.0 - label_pos)

        if not rotate:
            trans_angle = 0.0
        else:
            # Slope of the edge in data coordinates, in degrees.
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # Keep the text "right-side-up".
            if angle > 90:
                angle -= 180
            elif angle < -90:
                angle += 180
            # Convert the data-coordinate angle to a screen-coordinate angle.
            anchor = np.array((x, y)).reshape((1, 2))
            trans_angle = ax.transData.transform_angles(np.array((angle,)), anchor)[0]

        # Stringify so that "1" and 1 are labelled the same.
        text = label if isinstance(label, str) else str(label)

        text_items[(src, dst)] = ax.text(
            x,
            y,
            text,
            rotation=trans_angle,
            transform=ax.transData,
            **shared_style,
        )

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items
def draw_circular(G, **kwargs):
    """Draw the graph `G` using a circular layout.

    Shorthand for::

        nx.draw(G, pos=nx.circular_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.circular_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.circular_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_circular(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.circular_layout`
    """
    positions = circular_layout(G)
    draw(G, positions, **kwargs)
def draw_kamada_kawai(G, **kwargs):
    """Draw the graph `G` using a Kamada-Kawai force-directed layout.

    Shorthand for::

        nx.draw(G, pos=nx.kamada_kawai_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.kamada_kawai_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.kamada_kawai_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_kamada_kawai(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.kamada_kawai_layout`
    """
    positions = kamada_kawai_layout(G)
    draw(G, positions, **kwargs)
def draw_random(G, **kwargs):
    """Draw the graph `G` using a random layout.

    Shorthand for::

        nx.draw(G, pos=nx.random_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.random_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.random_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.lollipop_graph(4, 3)
    >>> nx.draw_random(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.random_layout`
    """
    positions = random_layout(G)
    draw(G, positions, **kwargs)
def draw_spectral(G, **kwargs):
    """Draw the graph `G` using a spectral 2D layout.

    Shorthand for::

        nx.draw(G, pos=nx.spectral_layout(G), **kwargs)

    How the node positions are determined is described in
    `~networkx.drawing.layout.spectral_layout`.

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.spectral_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spectral_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(5)
    >>> nx.draw_spectral(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.spectral_layout`
    """
    positions = spectral_layout(G)
    draw(G, positions, **kwargs)
def draw_spring(G, **kwargs):
    """Draw the graph `G` using a spring layout.

    Shorthand for::

        nx.draw(G, pos=nx.spring_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    `draw` already falls back to
    `~networkx.drawing.layout.spring_layout` when no positions are given,
    so this function is equivalent to `draw`.

    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.spring_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.spring_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(20)
    >>> nx.draw_spring(G)

    See Also
    --------
    draw
    :func:`~networkx.drawing.layout.spring_layout`
    """
    positions = spring_layout(G)
    draw(G, positions, **kwargs)
def draw_shell(G, nlist=None, **kwargs):
    """Draw the networkx graph `G` using a shell layout.

    Shorthand for::

        nx.draw(G, pos=nx.shell_layout(G, nlist=nlist), **kwargs)

    Parameters
    ----------
    G : graph
        A networkx graph

    nlist : list of list of nodes, optional
        A list containing lists of nodes representing the shells.
        Default is `None`, meaning all nodes are in a single shell.
        See `~networkx.drawing.layout.shell_layout` for details.

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.shell_layout`
    once and reuse the positions::

        >>> G = nx.complete_graph(5)
        >>> pos = nx.shell_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> shells = [[0], [1, 2, 3]]
    >>> nx.draw_shell(G, nlist=shells)

    See Also
    --------
    :func:`~networkx.drawing.layout.shell_layout`
    """
    positions = shell_layout(G, nlist=nlist)
    draw(G, positions, **kwargs)
def draw_planar(G, **kwargs):
    """Draw a planar networkx graph `G` using a planar layout.

    Shorthand for::

        nx.draw(G, pos=nx.planar_layout(G), **kwargs)

    Parameters
    ----------
    G : graph
        A planar networkx graph

    kwargs : optional keywords
        Passed through to `draw`. See `draw_networkx` for a description of
        optional keywords.

    Raises
    ------
    NetworkXException
        When `G` is not planar

    Notes
    -----
    Every call recomputes the layout. When drawing repeatedly it is much
    more efficient to call `~networkx.drawing.layout.planar_layout`
    once and reuse the positions::

        >>> G = nx.path_graph(5)
        >>> pos = nx.planar_layout(G)
        >>> nx.draw(G, pos=pos)  # Draw the original graph
        >>> # Draw a subgraph, reusing the same node positions
        >>> nx.draw(G.subgraph([0, 1, 2]), pos=pos, node_color="red")

    Examples
    --------
    >>> G = nx.path_graph(4)
    >>> nx.draw_planar(G)

    See Also
    --------
    :func:`~networkx.drawing.layout.planar_layout`
    """
    positions = planar_layout(G)
    draw(G, positions, **kwargs)
def apply_alpha(colors, alpha, elem_list, cmap=None, vmin=None, vmax=None):
    """Convert `colors` to an RGBA array and write `alpha` into its last column.

    Parameters
    ----------
    colors : color string or array of floats
        Color of element. Can be a single color format string, or a
        sequence of colors. If it has the same length as `elem_list` and
        its first entry is a number, the values are mapped to colors
        using the `cmap` and `vmin`, `vmax` parameters.

    alpha : float or array of floats
        Alpha values for elements. A single value is applied to every row
        of the result. A sequence is applied in order, cycling through
        `alpha` again if it is shorter than the number of rows.

    elem_list : array of networkx objects
        The list of elements which are being colored. These could be nodes,
        edges or labels.

    cmap : matplotlib colormap
        Color map used when `colors` holds numbers to be mapped to colors.

    vmin, vmax : float
        Minimum and maximum values for normalizing colors if a colormap is used

    Returns
    -------
    rgba_colors : numpy ndarray
        Array with one RGBA row per color, whose last column holds the
        alpha values.
    """
    from itertools import cycle, islice

    import matplotlib as mpl
    import matplotlib.cm  # call as mpl.cm
    import matplotlib.colors  # call as mpl.colors
    import numpy as np

    n_elems = len(elem_list)

    # One number per element: run the values through the color map.
    if len(colors) == n_elems and isinstance(colors[0], Number):
        scalar_map = mpl.cm.ScalarMappable(cmap=cmap)
        scalar_map.set_clim(vmin, vmax)
        rgba = scalar_map.to_rgba(colors)
    else:
        # Otherwise go through matplotlib's color converter. The result is
        # wrapped in an ndarray to match what ScalarMappable.to_rgba returns.
        convert = mpl.colors.colorConverter.to_rgba
        try:
            # `colors` is a single color specification ...
            rgba = np.array([convert(colors)])
        except ValueError:
            # ... or a sequence of them.
            rgba = np.array([convert(single) for single in colors])

    # Fill in the alpha column. A scalar `alpha` has no len(), so the
    # TypeError raised below routes it to the broadcast assignment.
    try:
        # Grow to one row per element when there are more alphas than colors,
        # or when the array has exactly as many entries as there are elements
        # (scatter() would otherwise read it as values for a colormap).
        if len(alpha) > len(rgba) or rgba.size == n_elems:
            rgba = np.resize(rgba, (n_elems, 4))
            # Every row takes the RGB channels of the first row.
            rgba[1:, :3] = rgba[0, :3]
        rgba[:, 3] = list(islice(cycle(alpha), len(rgba)))
    except TypeError:
        rgba[:, -1] = alpha

    return rgba