Spaces:
Sleeping
Sleeping
File size: 10,627 Bytes
b200bda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 |
"""
Plotting (requires matplotlib)
"""
from colorsys import hsv_to_rgb, hls_to_rgb
from .libmp import NoConvergence
from .libmp.backend import xrange
class VisualizationMethods(object):
    """
    Class onto which this module attaches ``plot``, ``cplot``, ``splot``,
    ``default_color_function`` and ``phase_color_function`` as methods.
    """

    # Exception types that plot() and cplot() treat as "no value here":
    # plot() leaves a gap in the curve, cplot() paints a gray pixel.
    plot_ignore = (
        ValueError,
        ArithmeticError,
        ZeroDivisionError,
        NoConvergence,
    )
def _plot_segments(ctx, func, xs, singularities):
    """
    Sample *func* at the points *xs* and group the results into segments,
    each of which can be drawn as one unbroken line.

    Returns a list of segments. A segment is a list of ``(x, y)`` pairs
    when *func* returned values with zero imaginary part, or of
    ``(x, re, im)`` triples when it returned values with a nonzero
    imaginary part. The current segment is closed, leaving a gap, when:

    * *func* raises one of ``ctx.plot_ignore``;
    * *func* returns NaN or a value whose magnitude exceeds 1e300;
    * one of *singularities* lies between two consecutive samples.

    A new segment is also started whenever the values switch between real
    and complex. Segments in the result may be empty.
    """
    segments = []
    segment = []
    in_complex = False
    for i, xi in enumerate(xs):
        try:
            if i != 0:
                prev = xs[i-1]
                for sing in singularities:
                    if prev <= sing and xi >= sing:
                        # Handled by the except clause below, which relies on
                        # ValueError being in ctx.plot_ignore (it is in
                        # VisualizationMethods.plot_ignore).
                        raise ValueError
            v = func(xi)
            if ctx.isnan(v) or abs(v) > 1e300:
                raise ValueError
            if hasattr(v, "imag") and v.imag:
                re_part = float(v.real)
                im_part = float(v.imag)
                if not in_complex:
                    in_complex = True
                    segments.append(segment)
                    segment = []
                segment.append((float(xi), re_part, im_part))
            else:
                if in_complex:
                    in_complex = False
                    segments.append(segment)
                    segment = []
                if hasattr(v, "real"):
                    v = v.real
                segment.append((float(xi), v))
        except ctx.plot_ignore:
            if segment:
                segments.append(segment)
                segment = []
    if segment:
        segments.append(segment)
    return segments


def plot(ctx, f, xlim=(-5, 5), ylim=None, points=200, file=None, dpi=None,
         singularities=(), axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception (one of
    ``ctx.plot_ignore``) or returns NaN or a value of magnitude above
    1e300 are removed from the graph. Singularities can also be excluded
    explicitly as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])   # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])  # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Other parameters:

    * *ylim* - optional ``[ymin, ymax]`` for the vertical axis.
    * *points* - the interval is sampled with step ``(b-a)/points``.
    * *axes* - an existing matplotlib axes to draw into. When given (and
      *file* is not), nothing is shown or saved.
    * *file*, *dpi* - when *file* is given, a new figure is created and
      saved to *file* with resolution *dpi*, instead of being shown.

    Successive functions are drawn in blue, red, green, magenta and black,
    cycling.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    # The sample grid does not depend on the function, so build it once.
    xs = ctx.arange(a, b, (b-a)/float(points))
    for n, func in enumerate(f):
        c = colors[n % len(colors)]
        for segment in _plot_segments(ctx, func, xs, singularities):
            if not segment:
                continue
            px = [s[0] for s in segment]
            py = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex segment: dashed real part, dotted imaginary part.
                pz = [s[2] for s in segment]
                axes.plot(px, py, '--'+c, linewidth=3)
                axes.plot(px, pz, ':'+c, linewidth=3)
            else:
                axes.plot(px, py, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
def default_color_function(ctx, z):
    """
    Map a complex value *z* to an RGB triple for cplot().

    The hue encodes the phase arg(z) and the lightness encodes the
    magnitude. Lightness is ``1 - 1/(1 + |z|**0.3)``, which is 0 (black) at
    z = 0 and approaches 1 (white) as |z| grows. Saturation is fixed at 0.8.
    Infinite values are drawn white and NaN values mid-gray.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Rescale arg(z) to a hue in [0, 1], then rotate by half a turn so
    # that arg(z) == 0 maps to hue 0.
    a = (float(ctx.arg(z)) + ctx.pi) / (2*ctx.pi)
    a = (a + 0.5) % 1.0
    b = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(a, b, 0.8)
# Piecewise-linear color map used by phase_color_function. Each entry is
# (w, (r, g, b)) with w = arg(z)/pi; colors are interpolated linearly
# between consecutive entries. The final (2.0, ...) entry is a sentinel
# above every clamped w, so the lookup always finds an upper neighbor.
blue_orange_colors = [
  (-1.0,  (0.0, 0.0, 0.0)),
  (-0.95, (0.1, 0.2, 0.5)),   # dark blue
  (-0.5,  (0.0, 0.5, 1.0)),   # blueish
  (-0.05, (0.4, 0.8, 0.8)),   # cyanish
  ( 0.0,  (1.0, 1.0, 1.0)),
  ( 0.05, (1.0, 0.9, 0.3)),   # yellowish
  ( 0.5,  (0.9, 0.5, 0.0)),   # orangeish
  ( 0.95, (0.7, 0.1, 0.0)),   # redish
  ( 1.0,  (0.0, 0.0, 0.0)),
  ( 2.0,  (0.0, 0.0, 0.0)),
]


def phase_color_function(ctx, z):
    """
    Map a complex value *z* to an RGB triple that depends only on its phase.

    Positive reals are white and negative reals black. The upper half plane
    shades through yellow, orange and red; the lower half plane shades
    through cyan and blue (see ``blue_orange_colors``). Infinite values are
    drawn white and NaN values mid-gray.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Normalize the phase to w in [-1, 1] using the context's own pi rather
    # than a truncated literal, so that arg(z) == pi gives exactly w == 1.0
    # (pure black) instead of a value just below it.
    w = float(ctx.arg(z)) / float(ctx.pi)
    w = max(min(w, 1.0), -1.0)
    for (a, (ra, ga, ba)), (b, (rb, gb, bb)) in zip(blue_orange_colors,
                                                    blue_orange_colors[1:]):
        if b > w:
            s = (w-a) / (b-a)
            return ra+(rb-ra)*s, ga+(gb-ga)*s, ba+(bb-ba)*s
def cplot(ctx, f, re=[-5,5], im=[-5,5], points=2000, color=None,
          verbose=False, file=None, dpi=None, axes=None):
    """
    Draw a domain-coloring image of the complex function *f* over the
    rectangle ``re[0] <= Re(z) <= re[1]``, ``im[0] <= Im(z) <= im[1]``.
    Examples::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    Each sample ``f(z)`` is turned into a pixel by a color function. The
    default one shows the phase as hue and the magnitude as lightness.
    *color* may be any callable that takes a complex number and returns an
    RGB 3-tuple of floats in the range 0.0-1.0, or one of these strings:

    * "default" - the default color scheme
    * "phase" - a scheme that shows only the phase of the value: white on
      the positive reals, black on the negative reals, gold in the upper
      half plane and blue in the lower half plane.

    *points* is the approximate total number of samples. A sharp image may
    need around 100,000 of them. Evaluating *f* that often can be slow, so
    pass ``verbose=True`` to print progress after each row.

    Points where evaluation raises one of ``ctx.plot_ignore`` are drawn
    gray. If *axes* is given (and *file* is not), the image is drawn into
    it and nothing is shown or saved. Otherwise a new figure is created and
    either saved to *file* (with resolution *dpi*) or displayed.

    .. note :: This function requires matplotlib (pylab).
    """
    # Translate the named color schemes into the context's color functions.
    if color is None or color == "default":
        color = ctx.default_color_function
    elif color == "phase":
        color = ctx.phase_color_function
    import pylab
    fig = None
    if file:
        # Saving to a file always draws into a fresh figure.
        axes = None
    if not axes:
        fig = pylab.figure()
        axes = fig.add_subplot(111)

    re_lo, re_hi = re
    im_lo, im_hi = im
    width = re_hi - re_lo
    height = im_hi - im_lo
    # Split the sample budget into an ncols x nrows grid whose aspect ratio
    # follows the width/height of the rectangle.
    ncols = int(ctx.sqrt(points*width/height)+1)
    nrows = int(ctx.sqrt(points*height/width)+1)
    re_axis = pylab.linspace(re_lo, re_hi, ncols)
    im_axis = pylab.linspace(im_lo, im_hi, nrows)

    # Image rows run along the imaginary axis and columns along the real
    # axis; with origin='lower' below, Im(z) = im[0] ends up at the bottom.
    # Orientation sanity checks:
    #     cplot(lambda z: z if z.real < 0 else 0)
    #     cplot(lambda z: z if z.imag < 0 else 0)
    image = pylab.zeros((nrows, ncols, 3))
    for row in range(nrows):
        im_val = im_axis[row]
        for col in range(ncols):
            z = ctx.mpc(re_axis[col], im_val)
            try:
                rgb = color(f(z))
            except ctx.plot_ignore:
                rgb = (0.5, 0.5, 0.5)
            image[row, col] = rgb
        if verbose:
            print(str(row) + ' of ' + str(nrows))

    extent = tuple(float(t) for t in (re_lo, re_hi, im_lo, im_hi))
    axes.imshow(image, extent=extent, origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
def splot(ctx, f, u=(-5, 5), v=(-5, 5), points=100, keep_aspect=True,
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x = u` and `y = v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    *points* is either one sample count used in both directions or a pair
    ``[M, N]`` giving the number of samples along `u` and `v`. With
    *wireframe* the surface is drawn as a wireframe instead of a shaded
    surface. With *keep_aspect* (the default), the axis limits are widened
    so that all three axes span the same length around the data.

    If *axes* (a 3D axes) is given and *file* is not, the surface is drawn
    into it and nothing is shown or saved. Otherwise a new figure is
    created and either saved to *file* (with resolution *dpi*) or shown.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    .. note :: This function requires matplotlib (pylab) 1.0 or higher,
       for the ``'3d'`` projection provided by mpl_toolkits.mplot3d.
    """
    import pylab
    # Importing mplot3d registers the '3d' projection with matplotlib.
    import mpl_toolkits.mplot3d
    if file:
        axes = None
    fig = None
    if not axes:
        fig = pylab.figure()
        # Axes3D(fig) no longer attaches itself to the figure in recent
        # matplotlib releases (deprecated in 3.4); add_subplot does.
        axes = fig.add_subplot(111, projection='3d')
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    u = pylab.linspace(ua, ub, M)
    v = pylab.linspace(va, vb, N)
    x, y, z = [pylab.zeros((M, N)) for _ in range(3)]
    # Running [min, max] of each coordinate. Start from an empty interval
    # rather than [0, 0], so that the origin is not forced into the box
    # used by keep_aspect.
    inf = float('inf')
    xab, yab, zab = [[inf, -inf] for _ in range(3)]
    for n in range(N):
        for m in range(M):
            fdata = f(ctx.convert(u[m]), ctx.convert(v[n]))
            try:
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # f returned a single value: plot the graph z = f(x, y).
                x[m,n], y[m,n], z[m,n] = u[m], v[n], fdata
            for c, cab in [(x[m,n], xab), (y[m,n], yab), (z[m,n], zab)]:
                if c < cab[0]:
                    cab[0] = c
                if c > cab[1]:
                    cab[1] = c
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    if keep_aspect:
        # A coordinate whose samples were all NaN never updated its bounds;
        # fall back to [0, 0] for it.
        for cab in (xab, yab, zab):
            if cab[0] > cab[1]:
                cab[0] = cab[1] = 0
        dx, dy, dz = [cab[1] - cab[0] for cab in [xab, yab, zab]]
        maxd = max(dx, dy, dz)
        # Widen each shorter axis symmetrically to the longest extent.
        if dx < maxd:
            delta = maxd - dx
            axes.set_xlim3d(xab[0] - delta / 2.0, xab[1] + delta / 2.0)
        if dy < maxd:
            delta = maxd - dy
            axes.set_ylim3d(yab[0] - delta / 2.0, yab[1] + delta / 2.0)
        if dz < maxd:
            delta = maxd - dz
            axes.set_zlim3d(zab[0] - delta / 2.0, zab[1] + delta / 2.0)
    if fig:
        if file:
            pylab.savefig(file, dpi=dpi)
        else:
            pylab.show()
# Install the module-level functions above as methods of VisualizationMethods;
# when called on an instance, that instance is passed as the first argument `ctx`.
for _name, _method in [
    ("plot", plot),
    ("default_color_function", default_color_function),
    ("phase_color_function", phase_color_function),
    ("cplot", cplot),
    ("splot", splot),
]:
    setattr(VisualizationMethods, _name, _method)
del _name, _method
|