|
""" |
|
This module contain solvers for all kinds of equations: |
|
|
|
- algebraic or transcendental, use solve() |
|
|
|
- recurrence, use rsolve() |
|
|
|
- differential, use dsolve() |
|
|
|
- nonlinear (numerically), use nsolve() |
|
(you will need a good starting point) |
|
|
|
""" |
|
from __future__ import annotations |
|
|
|
from sympy.core import (S, Add, Symbol, Dummy, Expr, Mul) |
|
from sympy.core.assumptions import check_assumptions |
|
from sympy.core.exprtools import factor_terms |
|
from sympy.core.function import (expand_mul, expand_log, Derivative, |
|
AppliedUndef, UndefinedFunction, nfloat, |
|
Function, expand_power_exp, _mexpand, expand, |
|
expand_func) |
|
from sympy.core.logic import fuzzy_not |
|
from sympy.core.numbers import Float, Rational, _illegal |
|
from sympy.core.intfunc import integer_log, ilcm |
|
from sympy.core.power import Pow |
|
from sympy.core.relational import Eq, Ne |
|
from sympy.core.sorting import ordered, default_sort_key |
|
from sympy.core.sympify import sympify, _sympify |
|
from sympy.core.traversal import preorder_traversal |
|
from sympy.logic.boolalg import And, BooleanAtom |
|
|
|
from sympy.functions import (log, exp, LambertW, cos, sin, tan, acos, asin, atan, |
|
Abs, re, im, arg, sqrt, atan2) |
|
from sympy.functions.combinatorial.factorials import binomial |
|
from sympy.functions.elementary.hyperbolic import HyperbolicFunction |
|
from sympy.functions.elementary.piecewise import piecewise_fold, Piecewise |
|
from sympy.functions.elementary.trigonometric import TrigonometricFunction |
|
from sympy.integrals.integrals import Integral |
|
from sympy.ntheory.factor_ import divisors |
|
from sympy.simplify import (simplify, collect, powsimp, posify, |
|
powdenest, nsimplify, denom, logcombine, sqrtdenest, fraction, |
|
separatevars) |
|
from sympy.simplify.sqrtdenest import sqrt_depth |
|
from sympy.simplify.fu import TR1, TR2i, TR10, TR11 |
|
from sympy.strategies.rl import rebuild |
|
from sympy.matrices.exceptions import NonInvertibleMatrixError |
|
from sympy.matrices import Matrix, zeros |
|
from sympy.polys import roots, cancel, factor, Poly |
|
from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys |
|
from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError |
|
from sympy.polys.polytools import gcd |
|
from sympy.utilities.lambdify import lambdify |
|
from sympy.utilities.misc import filldedent, debugf |
|
from sympy.utilities.iterables import (connected_components, |
|
generate_bell, uniq, iterable, is_sequence, subsets, flatten, sift) |
|
from sympy.utilities.decorator import conserve_mpmath_dps |
|
|
|
from mpmath import findroot |
|
|
|
from sympy.solvers.polysys import solve_poly_system |
|
|
|
from types import GeneratorType |
|
from collections import defaultdict |
|
from itertools import combinations, product |
|
|
|
import warnings |
|
|
|
|
|
def recast_to_symbols(eqs, symbols): |
|
""" |
|
Return (e, s, d) where e and s are versions of *eqs* and |
|
*symbols* in which any non-Symbol objects in *symbols* have |
|
been replaced with generic Dummy symbols and d is a dictionary |
|
that can be used to restore the original expressions. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy.solvers.solvers import recast_to_symbols |
|
>>> from sympy import symbols, Function |
|
>>> x, y = symbols('x y') |
|
>>> fx = Function('f')(x) |
|
>>> eqs, syms = [fx + 1, x, y], [fx, y] |
|
>>> e, s, d = recast_to_symbols(eqs, syms); (e, s, d) |
|
([_X0 + 1, x, y], [_X0, y], {_X0: f(x)}) |
|
|
|
The original equations and symbols can be restored using d: |
|
|
|
>>> assert [i.xreplace(d) for i in eqs] == eqs |
|
>>> assert [d.get(i, i) for i in s] == syms |
|
|
|
""" |
|
if not iterable(eqs) and iterable(symbols): |
|
raise ValueError('Both eqs and symbols must be iterable') |
|
orig = list(symbols) |
|
symbols = list(ordered(symbols)) |
|
swap_sym = {} |
|
i = 0 |
|
for s in symbols: |
|
if not isinstance(s, Symbol) and s not in swap_sym: |
|
swap_sym[s] = Dummy('X%d' % i) |
|
i += 1 |
|
new_f = [] |
|
for i in eqs: |
|
isubs = getattr(i, 'subs', None) |
|
if isubs is not None: |
|
new_f.append(isubs(swap_sym)) |
|
else: |
|
new_f.append(i) |
|
restore = {v: k for k, v in swap_sym.items()} |
|
return new_f, [swap_sym.get(i, i) for i in orig], restore |
|
|
|
|
|
def _ispow(e): |
|
"""Return True if e is a Pow or is exp.""" |
|
return isinstance(e, Expr) and (e.is_Pow or isinstance(e, exp)) |
|
|
|
|
|
def _simple_dens(f, symbols): |
|
|
|
|
|
|
|
|
|
dens = set() |
|
for d in denoms(f, symbols): |
|
if d.is_Pow and d.exp.is_Number: |
|
if d.exp.is_zero: |
|
continue |
|
d = d.base |
|
dens.add(d) |
|
return dens |
|
|
|
|
|
def denoms(eq, *symbols): |
|
""" |
|
Return (recursively) set of all denominators that appear in *eq* |
|
that contain any symbol in *symbols*; if *symbols* are not |
|
provided then all denominators will be returned. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy.solvers.solvers import denoms |
|
>>> from sympy.abc import x, y, z |
|
|
|
>>> denoms(x/y) |
|
{y} |
|
|
|
>>> denoms(x/(y*z)) |
|
{y, z} |
|
|
|
>>> denoms(3/x + y/z) |
|
{x, z} |
|
|
|
>>> denoms(x/2 + y/z) |
|
{2, z} |
|
|
|
If *symbols* are provided then only denominators containing |
|
those symbols will be returned: |
|
|
|
>>> denoms(1/x + 1/y + 1/z, y, z) |
|
{y, z} |
|
|
|
""" |
|
|
|
pot = preorder_traversal(eq) |
|
dens = set() |
|
for p in pot: |
|
|
|
|
|
if not isinstance(p, Expr): |
|
continue |
|
den = denom(p) |
|
if den is S.One: |
|
continue |
|
dens.update(Mul.make_args(den)) |
|
if not symbols: |
|
return dens |
|
elif len(symbols) == 1: |
|
if iterable(symbols[0]): |
|
symbols = symbols[0] |
|
return {d for d in dens if any(s in d.free_symbols for s in symbols)} |
|
|
|
|
|
def checksol(f, symbol, sol=None, **flags): |
|
""" |
|
Checks whether sol is a solution of equation f == 0. |
|
|
|
Explanation |
|
=========== |
|
|
|
Input can be either a single symbol and corresponding value |
|
or a dictionary of symbols and values. When given as a dictionary |
|
and flag ``simplify=True``, the values in the dictionary will be |
|
simplified. *f* can be a single equation or an iterable of equations. |
|
A solution must satisfy all equations in *f* to be considered valid; |
|
if a solution does not satisfy any equation, False is returned; if one or |
|
more checks are inconclusive (and none are False) then None is returned. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import checksol, symbols |
|
>>> x, y = symbols('x,y') |
|
>>> checksol(x**4 - 1, x, 1) |
|
True |
|
>>> checksol(x**4 - 1, x, 0) |
|
False |
|
>>> checksol(x**2 + y**2 - 5**2, {x: 3, y: 4}) |
|
True |
|
|
|
To check if an expression is zero using ``checksol()``, pass it |
|
as *f* and send an empty dictionary for *symbol*: |
|
|
|
>>> checksol(x**2 + x - x*(x + 1), {}) |
|
True |
|
|
|
None is returned if ``checksol()`` could not conclude. |
|
|
|
flags: |
|
'numerical=True (default)' |
|
do a fast numerical check if ``f`` has only one symbol. |
|
'minimal=True (default is False)' |
|
a very fast, minimal testing. |
|
'warn=True (default is False)' |
|
show a warning if checksol() could not conclude. |
|
'simplify=True (default)' |
|
simplify solution before substituting into function and |
|
simplify the function before trying specific simplifications |
|
'force=True (default is False)' |
|
make positive all symbols without assumptions regarding sign. |
|
|
|
""" |
|
from sympy.physics.units import Unit |
|
|
|
minimal = flags.get('minimal', False) |
|
|
|
if sol is not None: |
|
sol = {symbol: sol} |
|
elif isinstance(symbol, dict): |
|
sol = symbol |
|
else: |
|
msg = 'Expecting (sym, val) or ({sym: val}, None) but got (%s, %s)' |
|
raise ValueError(msg % (symbol, sol)) |
|
|
|
if iterable(f): |
|
if not f: |
|
raise ValueError('no functions to check') |
|
rv = True |
|
for fi in f: |
|
check = checksol(fi, sol, **flags) |
|
if check: |
|
continue |
|
if check is False: |
|
return False |
|
rv = None |
|
return rv |
|
|
|
f = _sympify(f) |
|
|
|
if f.is_number: |
|
return f.is_zero |
|
|
|
if isinstance(f, Poly): |
|
f = f.as_expr() |
|
elif isinstance(f, (Eq, Ne)): |
|
if f.rhs in (S.true, S.false): |
|
f = f.reversed |
|
B, E = f.args |
|
if isinstance(B, BooleanAtom): |
|
f = f.subs(sol) |
|
if not f.is_Boolean: |
|
return |
|
elif isinstance(f, Eq): |
|
f = Add(f.lhs, -f.rhs, evaluate=False) |
|
|
|
if isinstance(f, BooleanAtom): |
|
return bool(f) |
|
elif not f.is_Relational and not f: |
|
return True |
|
|
|
illegal = set(_illegal) |
|
if any(sympify(v).atoms() & illegal for k, v in sol.items()): |
|
return False |
|
|
|
attempt = -1 |
|
numerical = flags.get('numerical', True) |
|
while 1: |
|
attempt += 1 |
|
if attempt == 0: |
|
val = f.subs(sol) |
|
if isinstance(val, Mul): |
|
val = val.as_independent(Unit)[0] |
|
if val.atoms() & illegal: |
|
return False |
|
elif attempt == 1: |
|
if not val.is_number: |
|
if not val.is_constant(*list(sol.keys()), simplify=not minimal): |
|
return False |
|
|
|
_, val = val.as_content_primitive() |
|
val = _mexpand(val.as_numer_denom()[0], recursive=True) |
|
elif attempt == 2: |
|
if minimal: |
|
return |
|
if flags.get('simplify', True): |
|
for k in sol: |
|
sol[k] = simplify(sol[k]) |
|
|
|
|
|
val = simplify(f.subs(sol)) |
|
if flags.get('force', True): |
|
val, reps = posify(val) |
|
|
|
exval = _mexpand(val, recursive=True) |
|
if exval.is_number: |
|
|
|
val = exval |
|
else: |
|
|
|
|
|
pot = preorder_traversal(expand_mul(val)) |
|
seen = set() |
|
saw_pow_func = False |
|
for p in pot: |
|
if p in seen: |
|
continue |
|
seen.add(p) |
|
if p.is_Pow and not p.exp.is_Integer: |
|
saw_pow_func = True |
|
elif p.is_Function: |
|
saw_pow_func = True |
|
elif isinstance(p, UndefinedFunction): |
|
saw_pow_func = True |
|
if saw_pow_func: |
|
break |
|
if saw_pow_func is False: |
|
return False |
|
if flags.get('force', True): |
|
|
|
val = val.subs(reps) |
|
nz = fuzzy_not(val.is_zero) |
|
if nz is not None: |
|
|
|
|
|
|
|
|
|
|
|
if val.is_number and val.has(LambertW): |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
return not nz |
|
break |
|
if val.is_Rational: |
|
return val == 0 |
|
if numerical and val.is_number: |
|
return (abs(val.n(18).n(12, chop=True)) < 1e-9) is S.true |
|
|
|
if flags.get('warn', False): |
|
warnings.warn("\n\tWarning: could not verify solution %s." % sol) |
|
|
|
|
|
|
|
|
|
def solve(f, *symbols, **flags): |
|
r""" |
|
Algebraically solves equations and systems of equations. |
|
|
|
Explanation |
|
=========== |
|
|
|
Currently supported: |
|
- polynomial |
|
- transcendental |
|
- piecewise combinations of the above |
|
- systems of linear and polynomial equations |
|
- systems containing relational expressions |
|
- systems implied by undetermined coefficients |
|
|
|
Examples |
|
======== |
|
|
|
The default output varies according to the input and might |
|
be a list (possibly empty), a dictionary, a list of |
|
dictionaries or tuples, or an expression involving relationals. |
|
For specifics regarding different forms of output that may appear, see :ref:`solve_output`. |
|
Let it suffice here to say that to obtain a uniform output from |
|
`solve` use ``dict=True`` or ``set=True`` (see below). |
|
|
|
>>> from sympy import solve, Poly, Eq, Matrix, Symbol |
|
>>> from sympy.abc import x, y, z, a, b |
|
|
|
The expressions that are passed can be Expr, Equality, or Poly |
|
classes (or lists of the same); a Matrix is considered to be a |
|
list of all the elements of the matrix: |
|
|
|
>>> solve(x - 3, x) |
|
[3] |
|
>>> solve(Eq(x, 3), x) |
|
[3] |
|
>>> solve(Poly(x - 3), x) |
|
[3] |
|
>>> solve(Matrix([[x, x + y]]), x, y) == solve([x, x + y], x, y) |
|
True |
|
|
|
If no symbols are indicated to be of interest and the equation is |
|
univariate, a list of values is returned; otherwise, the keys in |
|
a dictionary will indicate which (of all the variables used in |
|
the expression(s)) variables and solutions were found: |
|
|
|
>>> solve(x**2 - 4) |
|
[-2, 2] |
|
>>> solve((x - a)*(y - b)) |
|
[{a: x}, {b: y}] |
|
>>> solve([x - 3, y - 1]) |
|
{x: 3, y: 1} |
|
>>> solve([x - 3, y**2 - 1]) |
|
[{x: 3, y: -1}, {x: 3, y: 1}] |
|
|
|
If you pass symbols for which solutions are sought, the output will vary |
|
depending on the number of symbols you passed, whether you are passing |
|
a list of expressions or not, and whether a linear system was solved. |
|
Uniform output is attained by using ``dict=True`` or ``set=True``. |
|
|
|
>>> #### *** feel free to skip to the stars below *** #### |
|
>>> from sympy import TableForm |
|
>>> h = [None, ';|;'.join(['e', 's', 'solve(e, s)', 'solve(e, s, dict=True)', |
|
... 'solve(e, s, set=True)']).split(';')] |
|
>>> t = [] |
|
>>> for e, s in [ |
|
... (x - y, y), |
|
... (x - y, [x, y]), |
|
... (x**2 - y, [x, y]), |
|
... ([x - 3, y -1], [x, y]), |
|
... ]: |
|
... how = [{}, dict(dict=True), dict(set=True)] |
|
... res = [solve(e, s, **f) for f in how] |
|
... t.append([e, '|', s, '|'] + [res[0], '|', res[1], '|', res[2]]) |
|
... |
|
>>> # ******************************************************* # |
|
>>> TableForm(t, headings=h, alignments="<") |
|
e | s | solve(e, s) | solve(e, s, dict=True) | solve(e, s, set=True) |
|
--------------------------------------------------------------------------------------- |
|
x - y | y | [x] | [{y: x}] | ([y], {(x,)}) |
|
x - y | [x, y] | [(y, y)] | [{x: y}] | ([x, y], {(y, y)}) |
|
x**2 - y | [x, y] | [(x, x**2)] | [{y: x**2}] | ([x, y], {(x, x**2)}) |
|
[x - 3, y - 1] | [x, y] | {x: 3, y: 1} | [{x: 3, y: 1}] | ([x, y], {(3, 1)}) |
|
|
|
* If any equation does not depend on the symbol(s) given, it will be |
|
eliminated from the equation set and an answer may be given |
|
implicitly in terms of variables that were not of interest: |
|
|
|
>>> solve([x - y, y - 3], x) |
|
{x: y} |
|
|
|
When you pass all but one of the free symbols, an attempt |
|
is made to find a single solution based on the method of |
|
undetermined coefficients. If it succeeds, a dictionary of values |
|
is returned. If you want an algebraic solutions for one |
|
or more of the symbols, pass the expression to be solved in a list: |
|
|
|
>>> e = a*x + b - 2*x - 3 |
|
>>> solve(e, [a, b]) |
|
{a: 2, b: 3} |
|
>>> solve([e], [a, b]) |
|
{a: -b/x + (2*x + 3)/x} |
|
|
|
When there is no solution for any given symbol which will make all |
|
expressions zero, the empty list is returned (or an empty set in |
|
the tuple when ``set=True``): |
|
|
|
>>> from sympy import sqrt |
|
>>> solve(3, x) |
|
[] |
|
>>> solve(x - 3, y) |
|
[] |
|
>>> solve(sqrt(x) + 1, x, set=True) |
|
([x], set()) |
|
|
|
When an object other than a Symbol is given as a symbol, it is |
|
isolated algebraically and an implicit solution may be obtained. |
|
This is mostly provided as a convenience to save you from replacing |
|
the object with a Symbol and solving for that Symbol. It will only |
|
work if the specified object can be replaced with a Symbol using the |
|
subs method: |
|
|
|
>>> from sympy import exp, Function |
|
>>> f = Function('f') |
|
|
|
>>> solve(f(x) - x, f(x)) |
|
[x] |
|
>>> solve(f(x).diff(x) - f(x) - x, f(x).diff(x)) |
|
[x + f(x)] |
|
>>> solve(f(x).diff(x) - f(x) - x, f(x)) |
|
[-x + Derivative(f(x), x)] |
|
>>> solve(x + exp(x)**2, exp(x), set=True) |
|
([exp(x)], {(-sqrt(-x),), (sqrt(-x),)}) |
|
|
|
>>> from sympy import Indexed, IndexedBase, Tuple |
|
>>> A = IndexedBase('A') |
|
>>> eqs = Tuple(A[1] + A[2] - 3, A[1] - A[2] + 1) |
|
>>> solve(eqs, eqs.atoms(Indexed)) |
|
{A[1]: 1, A[2]: 2} |
|
|
|
* To solve for a function within a derivative, use :func:`~.dsolve`. |
|
|
|
To solve for a symbol implicitly, use implicit=True: |
|
|
|
>>> solve(x + exp(x), x) |
|
[-LambertW(1)] |
|
>>> solve(x + exp(x), x, implicit=True) |
|
[-exp(x)] |
|
|
|
It is possible to solve for anything in an expression that can be |
|
replaced with a symbol using :obj:`~sympy.core.basic.Basic.subs`: |
|
|
|
>>> solve(x + 2 + sqrt(3), x + 2) |
|
[-sqrt(3)] |
|
>>> solve((x + 2 + sqrt(3), x + 4 + y), y, x + 2) |
|
{y: -2 + sqrt(3), x + 2: -sqrt(3)} |
|
|
|
* Nothing heroic is done in this implicit solving so you may end up |
|
with a symbol still in the solution: |
|
|
|
>>> eqs = (x*y + 3*y + sqrt(3), x + 4 + y) |
|
>>> solve(eqs, y, x + 2) |
|
{y: -sqrt(3)/(x + 3), x + 2: -2*x/(x + 3) - 6/(x + 3) + sqrt(3)/(x + 3)} |
|
>>> solve(eqs, y*x, x) |
|
{x: -y - 4, x*y: -3*y - sqrt(3)} |
|
|
|
* If you attempt to solve for a number, remember that the number |
|
you have obtained does not necessarily mean that the value is |
|
equivalent to the expression obtained: |
|
|
|
>>> solve(sqrt(2) - 1, 1) |
|
[sqrt(2)] |
|
>>> solve(x - y + 1, 1) # /!\ -1 is targeted, too |
|
[x/(y - 1)] |
|
>>> [_.subs(z, -1) for _ in solve((x - y + 1).subs(-1, z), 1)] |
|
[-x + y] |
|
|
|
**Additional Examples** |
|
|
|
``solve()`` with check=True (default) will run through the symbol tags to |
|
eliminate unwanted solutions. If no assumptions are included, all possible |
|
solutions will be returned: |
|
|
|
>>> x = Symbol("x") |
|
>>> solve(x**2 - 1) |
|
[-1, 1] |
|
|
|
By setting the ``positive`` flag, only one solution will be returned: |
|
|
|
>>> pos = Symbol("pos", positive=True) |
|
>>> solve(pos**2 - 1) |
|
[1] |
|
|
|
When the solutions are checked, those that make any denominator zero |
|
are automatically excluded. If you do not want to exclude such solutions, |
|
then use the check=False option: |
|
|
|
>>> from sympy import sin, limit |
|
>>> solve(sin(x)/x) # 0 is excluded |
|
[pi] |
|
|
|
If ``check=False``, then a solution to the numerator being zero is found |
|
but the value of $x = 0$ is a spurious solution since $\sin(x)/x$ has the well |
|
known limit (without discontinuity) of 1 at $x = 0$: |
|
|
|
>>> solve(sin(x)/x, check=False) |
|
[0, pi] |
|
|
|
In the following case, however, the limit exists and is equal to the |
|
value of $x = 0$ that is excluded when check=True: |
|
|
|
>>> eq = x**2*(1/x - z**2/x) |
|
>>> solve(eq, x) |
|
[] |
|
>>> solve(eq, x, check=False) |
|
[0] |
|
>>> limit(eq, x, 0, '-') |
|
0 |
|
>>> limit(eq, x, 0, '+') |
|
0 |
|
|
|
**Solving Relationships** |
|
|
|
When one or more expressions passed to ``solve`` is a relational, |
|
a relational result is returned (and the ``dict`` and ``set`` flags |
|
are ignored): |
|
|
|
>>> solve(x < 3) |
|
(-oo < x) & (x < 3) |
|
>>> solve([x < 3, x**2 > 4], x) |
|
((-oo < x) & (x < -2)) | ((2 < x) & (x < 3)) |
|
>>> solve([x + y - 3, x > 3], x) |
|
(3 < x) & (x < oo) & Eq(x, 3 - y) |
|
|
|
Although checking of assumptions on symbols in relationals |
|
is not done, setting assumptions will affect how certain |
|
relationals might automatically simplify: |
|
|
|
>>> solve(x**2 > 4) |
|
((-oo < x) & (x < -2)) | ((2 < x) & (x < oo)) |
|
|
|
>>> r = Symbol('r', real=True) |
|
>>> solve(r**2 > 4) |
|
(2 < r) | (r < -2) |
|
|
|
There is currently no algorithm in SymPy that allows you to use |
|
relationships to resolve more than one variable. So the following |
|
does not determine that ``q < 0`` (and trying to solve for ``r`` |
|
and ``q`` will raise an error): |
|
|
|
>>> from sympy import symbols |
|
>>> r, q = symbols('r, q', real=True) |
|
>>> solve([r + q - 3, r > 3], r) |
|
(3 < r) & Eq(r, 3 - q) |
|
|
|
You can directly call the routine that ``solve`` calls |
|
when it encounters a relational: :func:`~.reduce_inequalities`. |
|
It treats Expr like Equality. |
|
|
|
>>> from sympy import reduce_inequalities |
|
>>> reduce_inequalities([x**2 - 4]) |
|
Eq(x, -2) | Eq(x, 2) |
|
|
|
If each relationship contains only one symbol of interest, |
|
the expressions can be processed for multiple symbols: |
|
|
|
>>> reduce_inequalities([0 <= x - 1, y < 3], [x, y]) |
|
(-oo < y) & (1 <= x) & (x < oo) & (y < 3) |
|
|
|
But an error is raised if any relationship has more than one |
|
symbol of interest: |
|
|
|
>>> reduce_inequalities([0 <= x*y - 1, y < 3], [x, y]) |
|
Traceback (most recent call last): |
|
... |
|
NotImplementedError: |
|
inequality has more than one symbol of interest. |
|
|
|
**Disabling High-Order Explicit Solutions** |
|
|
|
When solving polynomial expressions, you might not want explicit solutions |
|
(which can be quite long). If the expression is univariate, ``CRootOf`` |
|
instances will be returned instead: |
|
|
|
>>> solve(x**3 - x + 1) |
|
[-1/((-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)) - |
|
(-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3, |
|
-(-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3 - |
|
1/((-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)), |
|
-(3*sqrt(69)/2 + 27/2)**(1/3)/3 - |
|
1/(3*sqrt(69)/2 + 27/2)**(1/3)] |
|
>>> solve(x**3 - x + 1, cubics=False) |
|
[CRootOf(x**3 - x + 1, 0), |
|
CRootOf(x**3 - x + 1, 1), |
|
CRootOf(x**3 - x + 1, 2)] |
|
|
|
If the expression is multivariate, no solution might be returned: |
|
|
|
>>> solve(x**3 - x + a, x, cubics=False) |
|
[] |
|
|
|
Sometimes solutions will be obtained even when a flag is False because the |
|
expression could be factored. In the following example, the equation can |
|
be factored as the product of a linear and a quadratic factor so explicit |
|
solutions (which did not require solving a cubic expression) are obtained: |
|
|
|
>>> eq = x**3 + 3*x**2 + x - 1 |
|
>>> solve(eq, cubics=False) |
|
[-1, -1 + sqrt(2), -sqrt(2) - 1] |
|
|
|
**Solving Equations Involving Radicals** |
|
|
|
Because of SymPy's use of the principle root, some solutions |
|
to radical equations will be missed unless check=False: |
|
|
|
>>> from sympy import root |
|
>>> eq = root(x**3 - 3*x**2, 3) + 1 - x |
|
>>> solve(eq) |
|
[] |
|
>>> solve(eq, check=False) |
|
[1/3] |
|
|
|
In the above example, there is only a single solution to the |
|
equation. Other expressions will yield spurious roots which |
|
must be checked manually; roots which give a negative argument |
|
to odd-powered radicals will also need special checking: |
|
|
|
>>> from sympy import real_root, S |
|
>>> eq = root(x, 3) - root(x, 5) + S(1)/7 |
|
>>> solve(eq) # this gives 2 solutions but misses a 3rd |
|
[CRootOf(7*x**5 - 7*x**3 + 1, 1)**15, |
|
CRootOf(7*x**5 - 7*x**3 + 1, 2)**15] |
|
>>> sol = solve(eq, check=False) |
|
>>> [abs(eq.subs(x,i).n(2)) for i in sol] |
|
[0.48, 0.e-110, 0.e-110, 0.052, 0.052] |
|
|
|
The first solution is negative so ``real_root`` must be used to see that it |
|
satisfies the expression: |
|
|
|
>>> abs(real_root(eq.subs(x, sol[0])).n(2)) |
|
0.e-110 |
|
|
|
If the roots of the equation are not real then more care will be |
|
necessary to find the roots, especially for higher order equations. |
|
Consider the following expression: |
|
|
|
>>> expr = root(x, 3) - root(x, 5) |
|
|
|
We will construct a known value for this expression at x = 3 by selecting |
|
the 1-th root for each radical: |
|
|
|
>>> expr1 = root(x, 3, 1) - root(x, 5, 1) |
|
>>> v = expr1.subs(x, -3) |
|
|
|
The ``solve`` function is unable to find any exact roots to this equation: |
|
|
|
>>> eq = Eq(expr, v); eq1 = Eq(expr1, v) |
|
>>> solve(eq, check=False), solve(eq1, check=False) |
|
([], []) |
|
|
|
The function ``unrad``, however, can be used to get a form of the equation |
|
for which numerical roots can be found: |
|
|
|
>>> from sympy.solvers.solvers import unrad |
|
>>> from sympy import nroots |
|
>>> e, (p, cov) = unrad(eq) |
|
>>> pvals = nroots(e) |
|
>>> inversion = solve(cov, x)[0] |
|
>>> xvals = [inversion.subs(p, i) for i in pvals] |
|
|
|
Although ``eq`` or ``eq1`` could have been used to find ``xvals``, the |
|
solution can only be verified with ``expr1``: |
|
|
|
>>> z = expr - v |
|
>>> [xi.n(chop=1e-9) for xi in xvals if abs(z.subs(x, xi).n()) < 1e-9] |
|
[] |
|
>>> z1 = expr1 - v |
|
>>> [xi.n(chop=1e-9) for xi in xvals if abs(z1.subs(x, xi).n()) < 1e-9] |
|
[-3.0] |
|
|
|
Parameters |
|
========== |
|
|
|
f : |
|
- a single Expr or Poly that must be zero |
|
- an Equality |
|
- a Relational expression |
|
- a Boolean |
|
- iterable of one or more of the above |
|
|
|
symbols : (object(s) to solve for) specified as |
|
- none given (other non-numeric objects will be used) |
|
- single symbol |
|
- denested list of symbols |
|
(e.g., ``solve(f, x, y)``) |
|
- ordered iterable of symbols |
|
(e.g., ``solve(f, [x, y])``) |
|
|
|
flags : |
|
dict=True (default is False) |
|
Return list (perhaps empty) of solution mappings. |
|
set=True (default is False) |
|
Return list of symbols and set of tuple(s) of solution(s). |
|
exclude=[] (default) |
|
Do not try to solve for any of the free symbols in exclude; |
|
if expressions are given, the free symbols in them will |
|
be extracted automatically. |
|
check=True (default) |
|
If False, do not do any testing of solutions. This can be |
|
useful if you want to include solutions that make any |
|
denominator zero. |
|
numerical=True (default) |
|
Do a fast numerical check if *f* has only one symbol. |
|
minimal=True (default is False) |
|
A very fast, minimal testing. |
|
warn=True (default is False) |
|
Show a warning if ``checksol()`` could not conclude. |
|
simplify=True (default) |
|
Simplify all but polynomials of order 3 or greater before |
|
returning them and (if check is not False) use the |
|
general simplify function on the solutions and the |
|
expression obtained when they are substituted into the |
|
function which should be zero. |
|
force=True (default is False) |
|
Make positive all symbols without assumptions regarding sign. |
|
rational=True (default) |
|
Recast Floats as Rational; if this option is not used, the |
|
system containing Floats may fail to solve because of issues |
|
with polys. If rational=None, Floats will be recast as |
|
rationals but the answer will be recast as Floats. If the |
|
flag is False then nothing will be done to the Floats. |
|
manual=True (default is False) |
|
Do not use the polys/matrix method to solve a system of |
|
equations, solve them one at a time as you might "manually." |
|
implicit=True (default is False) |
|
Allows ``solve`` to return a solution for a pattern in terms of |
|
other functions that contain that pattern; this is only |
|
needed if the pattern is inside of some invertible function |
|
like cos, exp, ect. |
|
particular=True (default is False) |
|
Instructs ``solve`` to try to find a particular solution to |
|
a linear system with as many zeros as possible; this is very |
|
expensive. |
|
quick=True (default is False; ``particular`` must be True) |
|
Selects a fast heuristic to find a solution with many zeros |
|
whereas a value of False uses the very slow method guaranteed |
|
to find the largest number of zeros possible. |
|
cubics=True (default) |
|
Return explicit solutions when cubic expressions are encountered. |
|
When False, quartics and quintics are disabled, too. |
|
quartics=True (default) |
|
Return explicit solutions when quartic expressions are encountered. |
|
When False, quintics are disabled, too. |
|
quintics=True (default) |
|
Return explicit solutions (if possible) when quintic expressions |
|
are encountered. |
|
|
|
See Also |
|
======== |
|
|
|
rsolve: For solving recurrence relationships |
|
dsolve: For solving differential equations |
|
|
|
""" |
|
from .inequalities import reduce_inequalities |
|
|
|
|
|
|
|
|
|
|
|
|
|
hints = ('cubics', 'quartics', 'quintics') |
|
default = True |
|
for k in hints: |
|
default = flags.setdefault(k, bool(flags.get(k, default))) |
|
|
|
|
|
implicit = flags.get('implicit', False) |
|
|
|
|
|
warn = flags.get('warn', False) |
|
|
|
|
|
|
|
as_set = flags.get('set', False) |
|
|
|
|
|
bare_f = not iterable(f) |
|
|
|
|
|
|
|
if flags.get('quick', None) is not None: |
|
if not flags.get('particular', None): |
|
raise ValueError('when using `quick`, `particular` should be True') |
|
if flags.get('particular', False) and bare_f: |
|
raise ValueError(filldedent(""" |
|
The 'particular/quick' flag is usually used with systems of |
|
equations. Either pass your equation in a list or |
|
consider using a solver like `diophantine` if you are |
|
looking for a solution in integers.""")) |
|
|
|
|
|
|
|
|
|
def _sympified_list(w): |
|
return list(map(sympify, w if iterable(w) else [w])) |
|
f, symbols = (_sympified_list(w) for w in [f, symbols]) |
|
|
|
|
|
|
|
|
|
ordered_symbols = None |
|
if not symbols: |
|
|
|
symbols = set().union(*[fi.free_symbols for fi in f]) |
|
if len(symbols) < len(f): |
|
for fi in f: |
|
pot = preorder_traversal(fi) |
|
for p in pot: |
|
if isinstance(p, AppliedUndef): |
|
if not as_set: |
|
flags['dict'] = True |
|
symbols.add(p) |
|
pot.skip() |
|
ordered_symbols = False |
|
symbols = list(ordered(symbols)) |
|
else: |
|
if len(symbols) == 1 and iterable(symbols[0]): |
|
symbols = symbols[0] |
|
ordered_symbols = symbols and is_sequence(symbols, |
|
include=GeneratorType) |
|
_symbols = list(uniq(symbols)) |
|
if len(_symbols) != len(symbols): |
|
ordered_symbols = False |
|
symbols = list(ordered(symbols)) |
|
else: |
|
symbols = _symbols |
|
|
|
|
|
if len(symbols) != len(set(symbols)): |
|
raise ValueError('duplicate symbols given') |
|
|
|
exclude = flags.pop('exclude', set()) |
|
if exclude: |
|
if isinstance(exclude, Expr): |
|
exclude = [exclude] |
|
exclude = set().union(*[e.free_symbols for e in sympify(exclude)]) |
|
symbols = [s for s in symbols if s not in exclude] |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(f, list): |
|
f = [s for s in f if s is not S.true] |
|
|
|
|
|
for i, fi in enumerate(f): |
|
if isinstance(fi, (Eq, Ne)): |
|
if 'ImmutableDenseMatrix' in [type(a).__name__ for a in fi.args]: |
|
fi = fi.lhs - fi.rhs |
|
else: |
|
L, R = fi.args |
|
if isinstance(R, BooleanAtom): |
|
L, R = R, L |
|
if isinstance(L, BooleanAtom): |
|
if isinstance(fi, Ne): |
|
L = ~L |
|
if R.is_Relational: |
|
fi = ~R if L is S.false else R |
|
elif R.is_Symbol: |
|
return L |
|
elif R.is_Boolean and (~R).is_Symbol: |
|
return ~L |
|
else: |
|
raise NotImplementedError(filldedent(''' |
|
Unanticipated argument of Eq when other arg |
|
is True or False. |
|
''')) |
|
elif isinstance(fi, Eq): |
|
fi = Add(fi.lhs, -fi.rhs, evaluate=False) |
|
f[i] = fi |
|
|
|
|
|
|
|
if fi.is_Relational: |
|
if len(symbols) != 1: |
|
raise ValueError("can only solve for one symbol at a time") |
|
if warn and symbols[0].assumptions0: |
|
warnings.warn(filldedent(""" |
|
\tWarning: assumptions about variable '%s' are |
|
not handled currently.""" % symbols[0])) |
|
return reduce_inequalities(f, symbols=symbols) |
|
|
|
|
|
if isinstance(fi, Poly): |
|
f[i] = fi.as_expr() |
|
|
|
|
|
|
|
f[i] = f[i].replace(lambda w: isinstance(w, HyperbolicFunction) and \ |
|
w.has_free(*symbols), lambda w: w.rewrite(exp)) |
|
|
|
|
|
if f[i].is_Matrix: |
|
bare_f = False |
|
f.extend(list(f[i])) |
|
f[i] = S.Zero |
|
|
|
|
|
freei = f[i].free_symbols |
|
if freei and all(s.is_extended_real or s.is_imaginary for s in freei): |
|
fr, fi = f[i].as_real_imag() |
|
|
|
had = f[i].atoms(re, im, arg, atan2) |
|
if fr and fi and fr != fi and not any( |
|
i.atoms(re, im, arg, atan2) - had for i in (fr, fi)): |
|
if bare_f: |
|
bare_f = False |
|
f[i: i + 1] = [fr, fi] |
|
|
|
|
|
if any(isinstance(fi, (bool, BooleanAtom)) for fi in f): |
|
if as_set: |
|
return [], set() |
|
return [] |
|
|
|
for i, fi in enumerate(f): |
|
|
|
while True: |
|
was = fi |
|
fi = fi.replace(Abs, lambda arg: |
|
separatevars(Abs(arg)).rewrite(Piecewise) if arg.has(*symbols) |
|
else Abs(arg)) |
|
if was == fi: |
|
break |
|
|
|
for e in fi.find(Abs): |
|
if e.has(*symbols): |
|
raise NotImplementedError('solving %s when the argument ' |
|
'is not real or imaginary.' % e) |
|
|
|
|
|
fi = fi.replace(arg, lambda a: arg(a).rewrite(atan2).rewrite(atan)) |
|
|
|
|
|
f[i] = fi |
|
|
|
|
|
freim = [fi for fi in f if fi.has(re, im)] |
|
if freim: |
|
irf = [] |
|
for s in symbols: |
|
if s.is_real or s.is_imaginary: |
|
continue |
|
|
|
if any(fi.has(re(s), im(s)) for fi in freim): |
|
irf.append((s, re(s) + S.ImaginaryUnit*im(s))) |
|
if irf: |
|
for s, rhs in irf: |
|
f = [fi.xreplace({s: rhs}) for fi in f] + [s - rhs] |
|
symbols.extend([re(s), im(s)]) |
|
if bare_f: |
|
bare_f = False |
|
flags['dict'] = True |
|
|
|
|
|
|
|
f, symbols, swap_sym = recast_to_symbols(f, symbols) |
|
|
|
symset = set(symbols) |
|
|
|
|
|
|
|
|
|
|
|
newf = [] |
|
for fi in f: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ok = False |
|
if fi.free_symbols & symset: |
|
ok = True |
|
else: |
|
if fi.is_number: |
|
if fi.is_Number: |
|
if fi.is_zero: |
|
continue |
|
return [] |
|
ok = True |
|
else: |
|
if fi.is_constant(): |
|
ok = True |
|
if ok: |
|
newf.append(fi) |
|
if not newf: |
|
if as_set: |
|
return symbols, set() |
|
return [] |
|
f = newf |
|
del newf |
|
|
|
|
|
|
|
|
|
seen = set() |
|
non_inverts = set() |
|
for fi in f: |
|
pot = preorder_traversal(fi) |
|
for p in pot: |
|
if not isinstance(p, Expr) or isinstance(p, Piecewise): |
|
pass |
|
elif (isinstance(p, bool) or |
|
not p.args or |
|
p in symset or |
|
p.is_Add or p.is_Mul or |
|
p.is_Pow and not implicit or |
|
p.is_Function and not implicit) and p.func not in (re, im): |
|
continue |
|
elif p not in seen: |
|
seen.add(p) |
|
if p.free_symbols & symset: |
|
non_inverts.add(p) |
|
else: |
|
continue |
|
pot.skip() |
|
del seen |
|
non_inverts = dict(list(zip(non_inverts, [Dummy() for _ in non_inverts]))) |
|
f = [fi.subs(non_inverts) for fi in f] |
|
|
|
|
|
|
|
non_inverts = [(v, k.xreplace(swap_sym).subs(swap_sym)) for k, v in non_inverts.items()] |
|
|
|
|
|
floats = False |
|
if flags.get('rational', True) is not False: |
|
for i, fi in enumerate(f): |
|
if fi.has(Float): |
|
floats = True |
|
f[i] = nsimplify(fi, rational=True) |
|
|
|
|
|
|
|
flags['_denominators'] = _simple_dens(f[0], symbols) |
|
|
|
|
|
|
|
|
|
|
|
def _has_piecewise(e): |
|
if e.is_Piecewise: |
|
return e.has(*symbols) |
|
return any(_has_piecewise(a) for a in e.args) |
|
for i, fi in enumerate(f): |
|
if _has_piecewise(fi): |
|
f[i] = piecewise_fold(fi) |
|
|
|
|
|
|
|
|
|
|
|
|
|
targs = {t for fi in f for t in fi.atoms(TrigonometricFunction)} |
|
if len(targs) > 1: |
|
add, other = sift(targs, lambda x: x.args[0].is_Add, binary=True) |
|
add, other = [[i for i in l if i.has_free(*symbols)] for l in (add, other)] |
|
trep = {} |
|
for t in add: |
|
a = t.args[0] |
|
ind, dep = a.as_independent(*symbols) |
|
if dep in symbols or -dep in symbols: |
|
|
|
n = Dummy() if not ind.is_Number else ind |
|
trep[t] = TR10(t.func(dep + n)).xreplace({n: ind}) |
|
if other and len(other) <= 2: |
|
base = gcd(*[i.args[0] for i in other]) if len(other) > 1 else other[0].args[0] |
|
for i in other: |
|
trep[i] = TR11(i, base) |
|
f = [fi.xreplace(trep) for fi in f] |
|
|
|
|
|
|
|
|
|
if bare_f: |
|
solution = None |
|
if len(symbols) != 1: |
|
solution = _solve_undetermined(f[0], symbols, flags) |
|
if not solution: |
|
solution = _solve(f[0], *symbols, **flags) |
|
else: |
|
linear, solution = _solve_system(f, symbols, **flags) |
|
assert type(solution) is list |
|
assert not solution or type(solution[0]) is dict, solution |
|
|
|
|
|
|
|
|
|
as_dict = flags.get('dict', False) |
|
|
|
|
|
tuple_format = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s] |
|
if as_dict or as_set: |
|
unpack = None |
|
elif bare_f: |
|
if len(symbols) == 1: |
|
unpack = lambda s: [i[symbols[0]] for i in s] |
|
elif len(solution) == 1 and len(solution[0]) == len(symbols): |
|
|
|
unpack = lambda s: s[0] |
|
elif ordered_symbols: |
|
unpack = tuple_format |
|
else: |
|
unpack = lambda s: s |
|
else: |
|
if solution: |
|
if linear and len(solution) == 1: |
|
|
|
|
|
unpack = lambda s: s[0] |
|
elif ordered_symbols: |
|
unpack = tuple_format |
|
else: |
|
unpack = lambda s: s |
|
else: |
|
unpack = None |
|
|
|
|
|
if non_inverts and type(solution) is list: |
|
solution = [{k: v.subs(non_inverts) for k, v in s.items()} |
|
for s in solution] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if swap_sym: |
|
symbols = [swap_sym.get(k, k) for k in symbols] |
|
for i, sol in enumerate(solution): |
|
solution[i] = {swap_sym.get(k, k): v.subs(swap_sym) |
|
for k, v in sol.items()} |
|
|
|
|
|
|
|
|
|
check = flags.get('check', True) |
|
|
|
|
|
if floats and solution and flags.get('rational', None) is None: |
|
solution = nfloat(solution, exponent=False) |
|
|
|
solution = _remove_duplicate_solutions(solution) |
|
|
|
if check and solution: |
|
warn = flags.get('warn', False) |
|
got_None = [] |
|
no_False = [] |
|
for sol in solution: |
|
a_None = False |
|
for symb, val in sol.items(): |
|
test = check_assumptions(val, **symb.assumptions0) |
|
if test: |
|
continue |
|
if test is False: |
|
break |
|
a_None = True |
|
else: |
|
no_False.append(sol) |
|
if a_None: |
|
got_None.append(sol) |
|
|
|
solution = no_False |
|
if warn and got_None: |
|
warnings.warn(filldedent(""" |
|
\tWarning: assumptions concerning following solution(s) |
|
cannot be checked:""" + '\n\t' + |
|
', '.join(str(s) for s in got_None))) |
|
|
|
|
|
|
|
|
|
|
|
if not solution: |
|
if as_set: |
|
return symbols, set() |
|
return [] |
|
|
|
|
|
if not as_set: |
|
solution = [{k: s[k] for k in ordered(s)} for s in solution] |
|
solution.sort(key=default_sort_key) |
|
|
|
if not (as_set or as_dict): |
|
return unpack(solution) |
|
|
|
if as_dict: |
|
return solution |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ordered_symbols: |
|
k = symbols |
|
else: |
|
|
|
k = list(ordered(set(flatten(tuple(i.keys()) for i in solution)))) |
|
return k, {tuple([s.get(ki, ki) for ki in k]) for s in solution} |
|
|
|
|
|
def _solve_undetermined(g, symbols, flags): |
|
"""solve helper to return a list with one dict (solution) else None |
|
|
|
A direct call to solve_undetermined_coeffs is more flexible and |
|
can return both multiple solutions and handle more than one independent |
|
variable. Here, we have to be more cautious to keep from solving |
|
something that does not look like an undetermined coeffs system -- |
|
to minimize the surprise factor since singularities that cancel are not |
|
prohibited in solve_undetermined_coeffs. |
|
""" |
|
if g.free_symbols - set(symbols): |
|
sol = solve_undetermined_coeffs(g, symbols, **dict(flags, dict=True, set=None)) |
|
if len(sol) == 1: |
|
return sol |
|
|
|
|
|
def _solve(f, *symbols, **flags): |
|
"""Return a checked solution for *f* in terms of one or more of the |
|
symbols in the form of a list of dictionaries. |
|
|
|
If no method is implemented to solve the equation, a NotImplementedError |
|
will be raised. In the case that conversion of an expression to a Poly |
|
gives None a ValueError will be raised. |
|
""" |
|
|
|
not_impl_msg = "No algorithms are implemented to solve equation %s" |
|
|
|
if len(symbols) != 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nonlin_s = [] |
|
got_s = set() |
|
rhs_s = set() |
|
result = [] |
|
for s in symbols: |
|
xi, v = solve_linear(f, symbols=[s]) |
|
if xi == s: |
|
|
|
if flags.get('simplify', True): |
|
v = simplify(v) |
|
vfree = v.free_symbols |
|
if vfree & got_s: |
|
|
|
|
|
|
|
continue |
|
rhs_s |= vfree |
|
got_s.add(xi) |
|
result.append({xi: v}) |
|
elif xi: |
|
nonlin_s.append(s) |
|
if not nonlin_s: |
|
return result |
|
for s in nonlin_s: |
|
try: |
|
soln = _solve(f, s, **flags) |
|
for sol in soln: |
|
if sol[s].free_symbols & got_s: |
|
|
|
continue |
|
got_s.add(s) |
|
result.append(sol) |
|
except NotImplementedError: |
|
continue |
|
if got_s: |
|
return result |
|
else: |
|
raise NotImplementedError(not_impl_msg % f) |
|
|
|
|
|
|
|
symbol = symbols[0] |
|
|
|
|
|
f = f.replace(lambda e: isinstance(e, binomial) and e.has(symbol), |
|
lambda e: expand_func(e)) |
|
|
|
|
|
|
|
|
|
flags['check'] = checkdens = check = flags.pop('check', True) |
|
|
|
|
|
if f.is_Mul: |
|
result = set() |
|
for m in f.args: |
|
if m in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}: |
|
result = set() |
|
break |
|
soln = _vsolve(m, symbol, **flags) |
|
result.update(set(soln)) |
|
result = [{symbol: v} for v in result] |
|
if check: |
|
|
|
|
|
|
|
dens = flags.get('_denominators', _simple_dens(f, symbols)) |
|
result = [s for s in result if |
|
not any(checksol(den, s, **flags) for den in |
|
dens)] |
|
|
|
|
|
check = False |
|
flags['simplify'] = False |
|
|
|
elif f.is_Piecewise: |
|
result = set() |
|
if any(e.is_zero for e, c in f.args): |
|
f = f.simplify() |
|
|
|
cond = neg = True |
|
for expr, cnd in f.args: |
|
|
|
|
|
cond = And(neg, cnd) |
|
neg = And(neg, ~cond) |
|
|
|
if expr.is_zero and cond.simplify() != False: |
|
raise NotImplementedError(filldedent(''' |
|
An expression is already zero when %s. |
|
This means that in this *region* the solution |
|
is zero but solve can only represent discrete, |
|
not interval, solutions. If this is a spurious |
|
interval it might be resolved with simplification |
|
of the Piecewise conditions.''' % cond)) |
|
candidates = _vsolve(expr, symbol, **flags) |
|
|
|
for candidate in candidates: |
|
if candidate in result: |
|
|
|
continue |
|
try: |
|
v = cond.subs(symbol, candidate) |
|
_eval_simplify = getattr(v, '_eval_simplify', None) |
|
if _eval_simplify is not None: |
|
|
|
v = _eval_simplify(ratio=2, measure=lambda x: 1) |
|
except TypeError: |
|
|
|
continue |
|
if v == False: |
|
continue |
|
if v == True: |
|
result.add(candidate) |
|
else: |
|
result.add(Piecewise( |
|
(candidate, v), |
|
(S.NaN, True))) |
|
|
|
|
|
return [{symbol: r} for r in result] |
|
else: |
|
|
|
|
|
f_num, sol = solve_linear(f, symbols=symbols) |
|
if f_num.is_zero or sol is S.NaN: |
|
return [] |
|
elif f_num.is_Symbol: |
|
|
|
if flags.get('simplify', True): |
|
sol = simplify(sol) |
|
return [{f_num: sol}] |
|
|
|
poly = None |
|
|
|
if not f_num.is_Add: |
|
add_args = [i for i in f_num.atoms(Add) |
|
if symbol in i.free_symbols] |
|
if len(add_args) == 1: |
|
gen = add_args[0] |
|
spart = gen.as_independent(symbol)[1].as_base_exp()[0] |
|
if spart == symbol: |
|
try: |
|
poly = Poly(f_num, spart) |
|
except PolynomialError: |
|
pass |
|
|
|
result = False |
|
msg = '' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
if poly is None: |
|
poly = Poly(f_num) |
|
if poly is None: |
|
raise ValueError('could not convert %s to Poly' % f_num) |
|
except GeneratorsNeeded: |
|
simplified_f = simplify(f_num) |
|
if simplified_f != f_num: |
|
return _solve(simplified_f, symbol, **flags) |
|
raise ValueError('expression appears to be a constant') |
|
|
|
gens = [g for g in poly.gens if g.has(symbol)] |
|
|
|
def _as_base_q(x): |
|
"""Return (b**e, q) for x = b**(p*e/q) where p/q is the leading |
|
Rational of the exponent of x, e.g. exp(-2*x/3) -> (exp(x), 3) |
|
""" |
|
b, e = x.as_base_exp() |
|
if e.is_Rational: |
|
return b, e.q |
|
if not e.is_Mul: |
|
return x, 1 |
|
c, ee = e.as_coeff_Mul() |
|
if c.is_Rational and c is not S.One: |
|
return b**ee, c.q |
|
return x, 1 |
|
|
|
if len(gens) > 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bases, qs = list(zip(*[_as_base_q(g) for g in gens])) |
|
bases = set(bases) |
|
|
|
if len(bases) > 1 or not all(q == 1 for q in qs): |
|
funcs = {b for b in bases if b.is_Function} |
|
|
|
trig = {_ for _ in funcs if |
|
isinstance(_, TrigonometricFunction)} |
|
other = funcs - trig |
|
if not other and len(funcs.intersection(trig)) > 1: |
|
newf = None |
|
if f_num.is_Add and len(f_num.args) == 2: |
|
|
|
_args = f_num.args |
|
t = a, b = [i.atoms(Function).intersection( |
|
trig) for i in _args] |
|
if all(len(i) == 1 for i in t): |
|
a, b = [i.pop() for i in t] |
|
if isinstance(a, cos): |
|
a, b = b, a |
|
_args = _args[::-1] |
|
if isinstance(a, sin) and isinstance(b, cos |
|
) and a.args[0] == b.args[0]: |
|
|
|
newf, _d = (TR2i(_args[0]/_args[1]) + 1 |
|
).as_numer_denom() |
|
if not _d.is_Number: |
|
newf = None |
|
if newf is None: |
|
newf = TR1(f_num).rewrite(tan) |
|
if newf != f_num: |
|
|
|
|
|
flags['check'] = False |
|
result = _solve(newf, symbol, **flags) |
|
flags['check'] = check |
|
|
|
|
|
|
|
|
|
|
|
|
|
if result is False and funcs: |
|
funcs = list(ordered(funcs)) |
|
f1 = funcs[0] |
|
t = Dummy('t') |
|
|
|
ftry = f_num.subs(f1, t) |
|
|
|
|
|
if not ftry.has(symbol): |
|
cv_sols = _solve(ftry, t, **flags) |
|
cv_inv = list(ordered(_vsolve(t - f1, symbol, **flags)))[0] |
|
result = [{symbol: cv_inv.subs(sol)} for sol in cv_sols] |
|
|
|
if result is False: |
|
msg = 'multiple generators %s' % gens |
|
|
|
else: |
|
|
|
u = bases.pop() |
|
t = Dummy('t') |
|
inv = _vsolve(u - t, symbol, **flags) |
|
if isinstance(u, (Pow, exp)): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _expand(p): |
|
b, e = p.as_base_exp() |
|
e = expand_mul(e) |
|
return expand_power_exp(b**e) |
|
ftry = f_num.replace( |
|
lambda w: w.is_Pow or isinstance(w, exp), |
|
_expand).subs(u, t) |
|
if not ftry.has(symbol): |
|
soln = _solve(ftry, t, **flags) |
|
result = [{symbol: i.subs(s)} for i in inv for s in soln] |
|
|
|
elif len(gens) == 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
poly = Poly(f_num, gens[0], composite=True) |
|
|
|
|
|
if not flags.pop('tsolve', False): |
|
soln = None |
|
deg = poly.degree() |
|
flags['tsolve'] = True |
|
hints = ('cubics', 'quartics', 'quintics') |
|
solvers = {h: flags.get(h) for h in hints} |
|
soln = roots(poly, **solvers) |
|
if sum(soln.values()) < deg: |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
soln = poly.all_roots() |
|
except NotImplementedError: |
|
if not flags.get('incomplete', True): |
|
raise NotImplementedError( |
|
filldedent(''' |
|
Neither high-order multivariate polynomials |
|
nor sorting of EX-domain polynomials is supported. |
|
If you want to see any results, pass keyword incomplete=True to |
|
solve; to see numerical values of roots |
|
for univariate expressions, use nroots. |
|
''')) |
|
else: |
|
pass |
|
else: |
|
soln = list(soln.keys()) |
|
|
|
if soln is not None: |
|
u = poly.gen |
|
if u != symbol: |
|
try: |
|
t = Dummy('t') |
|
inv = _vsolve(u - t, symbol, **flags) |
|
soln = {i.subs(t, s) for i in inv for s in soln} |
|
except NotImplementedError: |
|
|
|
soln = None |
|
else: |
|
check = False |
|
if soln is not None: |
|
if len(soln) > 2: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flags['simplify'] = flags.get('simplify', False) |
|
if soln is not None: |
|
result = [{symbol: v} for v in soln] |
|
|
|
|
|
|
|
if result is False: |
|
|
|
if flags.pop('_unrad', True): |
|
try: |
|
u = unrad(f_num, symbol) |
|
except (ValueError, NotImplementedError): |
|
u = False |
|
if u: |
|
eq, cov = u |
|
if cov: |
|
isym, ieq = cov |
|
inv = _vsolve(ieq, symbol, **flags)[0] |
|
rv = {inv.subs(xi) for xi in _solve(eq, isym, **flags)} |
|
else: |
|
try: |
|
rv = set(_vsolve(eq, symbol, **flags)) |
|
except NotImplementedError: |
|
rv = None |
|
if rv is not None: |
|
result = [{symbol: v} for v in rv] |
|
|
|
|
|
flags['simplify'] = flags.get('simplify', False) |
|
else: |
|
pass |
|
|
|
|
|
if result is False: |
|
flags.pop('tsolve', None) |
|
try: |
|
soln = _tsolve(f_num, symbol, **flags) |
|
if soln is not None: |
|
result = [{symbol: v} for v in soln] |
|
except PolynomialError: |
|
pass |
|
|
|
|
|
if result is False: |
|
raise NotImplementedError('\n'.join([msg, not_impl_msg % f])) |
|
|
|
result = _remove_duplicate_solutions(result) |
|
|
|
if flags.get('simplify', True): |
|
result = [{k: d[k].simplify() for k in d} for d in result] |
|
|
|
result = _remove_duplicate_solutions(result) |
|
|
|
|
|
flags['simplify'] = False |
|
|
|
if checkdens: |
|
|
|
|
|
dens = _simple_dens(f, symbols) |
|
result = [r for r in result if |
|
not any(checksol(d, r, **flags) |
|
for d in dens)] |
|
if check: |
|
|
|
result = [r for r in result if |
|
checksol(f_num, r, **flags) is not False] |
|
return result |
|
|
|
|
|
def _remove_duplicate_solutions(solutions: list[dict[Expr, Expr]] |
|
) -> list[dict[Expr, Expr]]: |
|
"""Remove duplicates from a list of dicts""" |
|
solutions_set = set() |
|
solutions_new = [] |
|
|
|
for sol in solutions: |
|
solset = frozenset(sol.items()) |
|
if solset not in solutions_set: |
|
solutions_new.append(sol) |
|
solutions_set.add(solset) |
|
|
|
return solutions_new |
|
|
|
|
|
def _solve_system(exprs, symbols, **flags): |
|
"""return ``(linear, solution)`` where ``linear`` is True |
|
if the system was linear, else False; ``solution`` |
|
is a list of dictionaries giving solutions for the symbols |
|
""" |
|
if not exprs: |
|
return False, [] |
|
|
|
if flags.pop('_split', True): |
|
|
|
V = exprs |
|
symsset = set(symbols) |
|
exprsyms = {e: e.free_symbols & symsset for e in exprs} |
|
E = [] |
|
sym_indices = {sym: i for i, sym in enumerate(symbols)} |
|
for n, e1 in enumerate(exprs): |
|
for e2 in exprs[:n]: |
|
|
|
if exprsyms[e1] & exprsyms[e2]: |
|
E.append((e1, e2)) |
|
G = V, E |
|
subexprs = connected_components(G) |
|
if len(subexprs) > 1: |
|
subsols = [] |
|
linear = True |
|
for subexpr in subexprs: |
|
subsyms = set() |
|
for e in subexpr: |
|
subsyms |= exprsyms[e] |
|
subsyms = sorted(subsyms, key = lambda x: sym_indices[x]) |
|
flags['_split'] = False |
|
_linear, subsol = _solve_system(subexpr, subsyms, **flags) |
|
if linear: |
|
linear = linear and _linear |
|
if not isinstance(subsol, list): |
|
subsol = [subsol] |
|
subsols.append(subsol) |
|
|
|
sols = [] |
|
for soldicts in product(*subsols): |
|
sols.append(dict(item for sd in soldicts |
|
for item in sd.items())) |
|
return linear, sols |
|
|
|
polys = [] |
|
dens = set() |
|
failed = [] |
|
result = [] |
|
solved_syms = [] |
|
linear = True |
|
manual = flags.get('manual', False) |
|
checkdens = check = flags.get('check', True) |
|
|
|
for j, g in enumerate(exprs): |
|
dens.update(_simple_dens(g, symbols)) |
|
i, d = _invert(g, *symbols) |
|
if d in symbols: |
|
if linear: |
|
linear = solve_linear(g, 0, [d])[0] == d |
|
g = d - i |
|
g = g.as_numer_denom()[0] |
|
if manual: |
|
failed.append(g) |
|
continue |
|
|
|
poly = g.as_poly(*symbols, extension=True) |
|
|
|
if poly is not None: |
|
polys.append(poly) |
|
else: |
|
failed.append(g) |
|
|
|
if polys: |
|
if all(p.is_linear for p in polys): |
|
n, m = len(polys), len(symbols) |
|
matrix = zeros(n, m + 1) |
|
|
|
for i, poly in enumerate(polys): |
|
for monom, coeff in poly.terms(): |
|
try: |
|
j = monom.index(1) |
|
matrix[i, j] = coeff |
|
except ValueError: |
|
matrix[i, m] = -coeff |
|
|
|
|
|
if flags.pop('particular', False): |
|
result = minsolve_linear_system(matrix, *symbols, **flags) |
|
else: |
|
result = solve_linear_system(matrix, *symbols, **flags) |
|
result = [result] if result else [] |
|
if failed: |
|
if result: |
|
solved_syms = list(result[0].keys()) |
|
else: |
|
solved_syms = [] |
|
|
|
else: |
|
linear = False |
|
if len(symbols) > len(polys): |
|
|
|
free = set().union(*[p.free_symbols for p in polys]) |
|
free = list(ordered(free.intersection(symbols))) |
|
got_s = set() |
|
result = [] |
|
for syms in subsets(free, len(polys)): |
|
try: |
|
|
|
res = solve_poly_system(polys, *syms) |
|
if res: |
|
for r in set(res): |
|
skip = False |
|
for r1 in r: |
|
if got_s and any(ss in r1.free_symbols |
|
for ss in got_s): |
|
|
|
|
|
skip = True |
|
if not skip: |
|
got_s.update(syms) |
|
result.append(dict(list(zip(syms, r)))) |
|
except NotImplementedError: |
|
pass |
|
if got_s: |
|
solved_syms = list(got_s) |
|
else: |
|
raise NotImplementedError('no valid subset found') |
|
else: |
|
try: |
|
result = solve_poly_system(polys, *symbols) |
|
if result: |
|
solved_syms = symbols |
|
result = [dict(list(zip(solved_syms, r))) for r in set(result)] |
|
except NotImplementedError: |
|
failed.extend([g.as_expr() for g in polys]) |
|
solved_syms = [] |
|
|
|
|
|
result = result or [{}] |
|
|
|
if failed: |
|
linear = False |
|
|
|
|
|
|
|
|
|
|
|
def _ok_syms(e, sort=False): |
|
rv = e.free_symbols & legal |
|
|
|
|
|
|
|
|
|
|
|
def key(sym): |
|
ep = e.as_poly(sym) |
|
if ep is None: |
|
complexity = (S.Infinity, S.Infinity, S.Infinity) |
|
else: |
|
coeff_syms = ep.LC().free_symbols |
|
complexity = (ep.degree(), len(coeff_syms & rv), len(coeff_syms)) |
|
return complexity + (default_sort_key(sym),) |
|
|
|
if sort: |
|
rv = sorted(rv, key=key) |
|
return rv |
|
|
|
legal = set(symbols) |
|
|
|
u = Dummy() |
|
for eq in ordered(failed, lambda _: len(_ok_syms(_))): |
|
newresult = [] |
|
bad_results = [] |
|
hit = False |
|
for r in result: |
|
got_s = set() |
|
|
|
eq2 = eq.subs(r) |
|
|
|
|
|
if check and r: |
|
b = checksol(u, u, eq2, minimal=True) |
|
if b is not None: |
|
|
|
|
|
|
|
if b: |
|
newresult.append(r) |
|
else: |
|
bad_results.append(r) |
|
continue |
|
|
|
|
|
ok_syms = _ok_syms(eq2, sort=True) |
|
if not ok_syms: |
|
if r: |
|
newresult.append(r) |
|
break |
|
for s in ok_syms: |
|
try: |
|
soln = _vsolve(eq2, s, **flags) |
|
except NotImplementedError: |
|
continue |
|
|
|
|
|
|
|
for sol in soln: |
|
if got_s and any(ss in sol.free_symbols for ss in got_s): |
|
|
|
continue |
|
rnew = r.copy() |
|
for k, v in r.items(): |
|
rnew[k] = v.subs(s, sol) |
|
|
|
rnew[s] = sol |
|
|
|
iset = set(rnew.items()) |
|
for i in newresult: |
|
if len(i) < len(iset) and not set(i.items()) - iset: |
|
|
|
|
|
break |
|
else: |
|
|
|
newresult.append(rnew) |
|
hit = True |
|
got_s.add(s) |
|
if not hit: |
|
raise NotImplementedError('could not solve %s' % eq2) |
|
else: |
|
result = newresult |
|
for b in bad_results: |
|
if b in result: |
|
result.remove(b) |
|
|
|
if not result: |
|
return False, [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_simplify = bool(failed) |
|
if flags.get('simplify', default_simplify): |
|
for r in result: |
|
for k in r: |
|
r[k] = simplify(r[k]) |
|
flags['simplify'] = False |
|
|
|
if checkdens: |
|
result = [r for r in result |
|
if not any(checksol(d, r, **flags) for d in dens)] |
|
|
|
if check and not linear: |
|
result = [r for r in result |
|
if not any(checksol(e, r, **flags) is False for e in exprs)] |
|
|
|
result = [r for r in result if r] |
|
return linear, result |
|
|
|
|
|
def solve_linear(lhs, rhs=0, symbols=[], exclude=[]): |
|
r""" |
|
Return a tuple derived from ``f = lhs - rhs`` that is one of |
|
the following: ``(0, 1)``, ``(0, 0)``, ``(symbol, solution)``, ``(n, d)``. |
|
|
|
Explanation |
|
=========== |
|
|
|
``(0, 1)`` meaning that ``f`` is independent of the symbols in *symbols* |
|
that are not in *exclude*. |
|
|
|
``(0, 0)`` meaning that there is no solution to the equation amongst the |
|
symbols given. If the first element of the tuple is not zero, then the |
|
function is guaranteed to be dependent on a symbol in *symbols*. |
|
|
|
``(symbol, solution)`` where symbol appears linearly in the numerator of |
|
``f``, is in *symbols* (if given), and is not in *exclude* (if given). No |
|
simplification is done to ``f`` other than a ``mul=True`` expansion, so the |
|
solution will correspond strictly to a unique solution. |
|
|
|
``(n, d)`` where ``n`` and ``d`` are the numerator and denominator of ``f`` |
|
when the numerator was not linear in any symbol of interest; ``n`` will |
|
never be a symbol unless a solution for that symbol was found (in which case |
|
the second element is the solution, not the denominator). |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import cancel, Pow |
|
|
|
``f`` is independent of the symbols in *symbols* that are not in |
|
*exclude*: |
|
|
|
>>> from sympy import cos, sin, solve_linear |
|
>>> from sympy.abc import x, y, z |
|
>>> eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 |
|
>>> solve_linear(eq) |
|
(0, 1) |
|
>>> eq = cos(x)**2 + sin(x)**2 # = 1 |
|
>>> solve_linear(eq) |
|
(0, 1) |
|
>>> solve_linear(x, exclude=[x]) |
|
(0, 1) |
|
|
|
The variable ``x`` appears as a linear variable in each of the |
|
following: |
|
|
|
>>> solve_linear(x + y**2) |
|
(x, -y**2) |
|
>>> solve_linear(1/x - y**2) |
|
(x, y**(-2)) |
|
|
|
When not linear in ``x`` or ``y`` then the numerator and denominator are |
|
returned: |
|
|
|
>>> solve_linear(x**2/y**2 - 3) |
|
(x**2 - 3*y**2, y**2) |
|
|
|
If the numerator of the expression is a symbol, then ``(0, 0)`` is |
|
returned if the solution for that symbol would have set any |
|
denominator to 0: |
|
|
|
>>> eq = 1/(1/x - 2) |
|
>>> eq.as_numer_denom() |
|
(x, 1 - 2*x) |
|
>>> solve_linear(eq) |
|
(0, 0) |
|
|
|
But automatic rewriting may cause a symbol in the denominator to |
|
appear in the numerator so a solution will be returned: |
|
|
|
>>> (1/x)**-1 |
|
x |
|
>>> solve_linear((1/x)**-1) |
|
(x, 0) |
|
|
|
Use an unevaluated expression to avoid this: |
|
|
|
>>> solve_linear(Pow(1/x, -1, evaluate=False)) |
|
(0, 0) |
|
|
|
If ``x`` is allowed to cancel in the following expression, then it |
|
appears to be linear in ``x``, but this sort of cancellation is not |
|
done by ``solve_linear`` so the solution will always satisfy the |
|
original expression without causing a division by zero error. |
|
|
|
>>> eq = x**2*(1/x - z**2/x) |
|
>>> solve_linear(cancel(eq)) |
|
(x, 0) |
|
>>> solve_linear(eq) |
|
(x**2*(1 - z**2), x) |
|
|
|
A list of symbols for which a solution is desired may be given: |
|
|
|
>>> solve_linear(x + y + z, symbols=[y]) |
|
(y, -x - z) |
|
|
|
A list of symbols to ignore may also be given: |
|
|
|
>>> solve_linear(x + y + z, exclude=[x]) |
|
(y, -x - z) |
|
|
|
(A solution for ``y`` is obtained because it is the first variable |
|
from the canonically sorted list of symbols that had a linear |
|
solution.) |
|
|
|
""" |
|
if isinstance(lhs, Eq): |
|
if rhs: |
|
raise ValueError(filldedent(''' |
|
If lhs is an Equality, rhs must be 0 but was %s''' % rhs)) |
|
rhs = lhs.rhs |
|
lhs = lhs.lhs |
|
dens = None |
|
eq = lhs - rhs |
|
n, d = eq.as_numer_denom() |
|
if not n: |
|
return S.Zero, S.One |
|
|
|
free = n.free_symbols |
|
if not symbols: |
|
symbols = free |
|
else: |
|
bad = [s for s in symbols if not s.is_Symbol] |
|
if bad: |
|
if len(bad) == 1: |
|
bad = bad[0] |
|
if len(symbols) == 1: |
|
eg = 'solve(%s, %s)' % (eq, symbols[0]) |
|
else: |
|
eg = 'solve(%s, *%s)' % (eq, list(symbols)) |
|
raise ValueError(filldedent(''' |
|
solve_linear only handles symbols, not %s. To isolate |
|
non-symbols use solve, e.g. >>> %s <<<. |
|
''' % (bad, eg))) |
|
symbols = free.intersection(symbols) |
|
symbols = symbols.difference(exclude) |
|
if not symbols: |
|
return S.Zero, S.One |
|
|
|
|
|
|
|
|
|
derivs = defaultdict(list) |
|
for der in n.atoms(Derivative): |
|
csym = der.free_symbols & symbols |
|
for c in csym: |
|
derivs[c].append(der) |
|
|
|
all_zero = True |
|
for xi in sorted(symbols, key=default_sort_key): |
|
|
|
if isinstance(derivs[xi], list): |
|
derivs[xi] = {der: der.doit() for der in derivs[xi]} |
|
newn = n.subs(derivs[xi]) |
|
dnewn_dxi = newn.diff(xi) |
|
|
|
|
|
free = dnewn_dxi.free_symbols |
|
if dnewn_dxi and (not free or any(dnewn_dxi.diff(s) for s in free) or free == symbols): |
|
all_zero = False |
|
if dnewn_dxi is S.NaN: |
|
break |
|
if xi not in dnewn_dxi.free_symbols: |
|
vi = -1/dnewn_dxi*(newn.subs(xi, 0)) |
|
if dens is None: |
|
dens = _simple_dens(eq, symbols) |
|
if not any(checksol(di, {xi: vi}, minimal=True) is True |
|
for di in dens): |
|
|
|
irep = [(i, i.doit()) for i in vi.atoms(Integral) if |
|
i.function.is_number] |
|
|
|
vi = expand_mul(vi.subs(irep)) |
|
return xi, vi |
|
if all_zero: |
|
return S.Zero, S.One |
|
if n.is_Symbol: |
|
return S.Zero, S.Zero |
|
return n, d |
|
|
|
|
|
def minsolve_linear_system(system, *symbols, **flags): |
|
r""" |
|
Find a particular solution to a linear system. |
|
|
|
Explanation |
|
=========== |
|
|
|
In particular, try to find a solution with the minimal possible number |
|
of non-zero variables using a naive algorithm with exponential complexity. |
|
If ``quick=True``, a heuristic is used. |
|
|
|
""" |
|
quick = flags.get('quick', False) |
|
|
|
s0 = solve_linear_system(system, *symbols, **flags) |
|
if not s0 or all(v == 0 for v in s0.values()): |
|
return s0 |
|
if quick: |
|
|
|
|
|
s = solve_linear_system(system, *symbols) |
|
def update(determined, solution): |
|
delete = [] |
|
for k, v in solution.items(): |
|
solution[k] = v.subs(determined) |
|
if not solution[k].free_symbols: |
|
delete.append(k) |
|
determined[k] = solution[k] |
|
for k in delete: |
|
del solution[k] |
|
determined = {} |
|
update(determined, s) |
|
while s: |
|
|
|
k = max((k for k in s.values()), |
|
key=lambda x: (len(x.free_symbols), default_sort_key(x))) |
|
kfree = k.free_symbols |
|
x = next(reversed(list(ordered(kfree)))) |
|
if len(kfree) != 1: |
|
determined[x] = S.Zero |
|
else: |
|
val = _vsolve(k, x, check=False)[0] |
|
if not val and not any(v.subs(x, val) for v in s.values()): |
|
determined[x] = S.One |
|
else: |
|
determined[x] = val |
|
update(determined, s) |
|
return determined |
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
N = len(symbols) |
|
bestsol = minsolve_linear_system(system, *symbols, quick=True) |
|
n0 = len([x for x in bestsol.values() if x != 0]) |
|
for n in range(n0 - 1, 1, -1): |
|
debugf('minsolve: %s', n) |
|
thissol = None |
|
for nonzeros in combinations(range(N), n): |
|
subm = Matrix([system.col(i).T for i in nonzeros] + [system.col(-1).T]).T |
|
s = solve_linear_system(subm, *[symbols[i] for i in nonzeros]) |
|
if s and not all(v == 0 for v in s.values()): |
|
subs = [(symbols[v], S.One) for v in nonzeros] |
|
for k, v in s.items(): |
|
s[k] = v.subs(subs) |
|
for sym in symbols: |
|
if sym not in s: |
|
if symbols.index(sym) in nonzeros: |
|
s[sym] = S.One |
|
else: |
|
s[sym] = S.Zero |
|
thissol = s |
|
break |
|
if thissol is None: |
|
break |
|
bestsol = thissol |
|
return bestsol |
|
|
|
|
|
def solve_linear_system(system, *symbols, **flags): |
|
r""" |
|
Solve system of $N$ linear equations with $M$ variables, which means |
|
both under- and overdetermined systems are supported. |
|
|
|
Explanation |
|
=========== |
|
|
|
The possible number of solutions is zero, one, or infinite. Respectively, |
|
this procedure will return None or a dictionary with solutions. In the |
|
case of underdetermined systems, all arbitrary parameters are skipped. |
|
This may cause a situation in which an empty dictionary is returned. |
|
In that case, all symbols can be assigned arbitrary values. |
|
|
|
Input to this function is a $N\times M + 1$ matrix, which means it has |
|
to be in augmented form. If you prefer to enter $N$ equations and $M$ |
|
unknowns then use ``solve(Neqs, *Msymbols)`` instead. Note: a local |
|
copy of the matrix is made by this routine so the matrix that is |
|
passed will not be modified. |
|
|
|
The algorithm used here is fraction-free Gaussian elimination, |
|
which results, after elimination, in an upper-triangular matrix. |
|
Then solutions are found using back-substitution. This approach |
|
is more efficient and compact than the Gauss-Jordan method. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import Matrix, solve_linear_system |
|
>>> from sympy.abc import x, y |
|
|
|
Solve the following system:: |
|
|
|
x + 4 y == 2 |
|
-2 x + y == 14 |
|
|
|
>>> system = Matrix(( (1, 4, 2), (-2, 1, 14))) |
|
>>> solve_linear_system(system, x, y) |
|
{x: -6, y: 2} |
|
|
|
A degenerate system returns an empty dictionary: |
|
|
|
>>> system = Matrix(( (0,0,0), (0,0,0) )) |
|
>>> solve_linear_system(system, x, y) |
|
{} |
|
|
|
""" |
|
assert system.shape[1] == len(symbols) + 1 |
|
|
|
|
|
eqs = list(system * Matrix(symbols + (-1,))) |
|
eqs, ring = sympy_eqs_to_ring(eqs, symbols) |
|
sol = solve_lin_sys(eqs, ring, _raw=False) |
|
if sol is not None: |
|
sol = {sym:val for sym, val in sol.items() if sym != val} |
|
return sol |
|
|
|
|
|
def solve_undetermined_coeffs(equ, coeffs, *syms, **flags): |
|
r""" |
|
Solve a system of equations in $k$ parameters that is formed by |
|
matching coefficients in variables ``coeffs`` that are on |
|
factors dependent on the remaining variables (or those given |
|
explicitly by ``syms``. |
|
|
|
Explanation |
|
=========== |
|
|
|
The result of this function is a dictionary with symbolic values of those |
|
parameters with respect to coefficients in $q$ -- empty if there |
|
is no solution or coefficients do not appear in the equation -- else |
|
None (if the system was not recognized). If there is more than one |
|
solution, the solutions are passed as a list. The output can be modified using |
|
the same semantics as for `solve` since the flags that are passed are sent |
|
directly to `solve` so, for example the flag ``dict=True`` will always return a list |
|
of solutions as dictionaries. |
|
|
|
This function accepts both Equality and Expr class instances. |
|
The solving process is most efficient when symbols are specified |
|
in addition to parameters to be determined, but an attempt to |
|
determine them (if absent) will be made. If an expected solution is not |
|
obtained (and symbols were not specified) try specifying them. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import Eq, solve_undetermined_coeffs |
|
>>> from sympy.abc import a, b, c, h, p, k, x, y |
|
|
|
>>> solve_undetermined_coeffs(Eq(a*x + a + b, x/2), [a, b], x) |
|
{a: 1/2, b: -1/2} |
|
>>> solve_undetermined_coeffs(a - 2, [a]) |
|
{a: 2} |
|
|
|
The equation can be nonlinear in the symbols: |
|
|
|
>>> X, Y, Z = y, x**y, y*x**y |
|
>>> eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z |
|
>>> coeffs = a, b, c |
|
>>> syms = x, y |
|
>>> solve_undetermined_coeffs(eq, coeffs, syms) |
|
{a: 1, b: 2, c: 3} |
|
|
|
And the system can be nonlinear in coefficients, too, but if |
|
there is only a single solution, it will be returned as a |
|
dictionary: |
|
|
|
>>> eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p |
|
>>> solve_undetermined_coeffs(eq, (h, p, k), x) |
|
{h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} |
|
|
|
Multiple solutions are always returned in a list: |
|
|
|
>>> solve_undetermined_coeffs(a**2*x + b - x, [a, b], x) |
|
[{a: -1, b: 0}, {a: 1, b: 0}] |
|
|
|
Using flag ``dict=True`` (in keeping with semantics in :func:`~.solve`) |
|
will force the result to always be a list with any solutions |
|
as elements in that list. |
|
|
|
>>> solve_undetermined_coeffs(a*x - 2*x, [a], dict=True) |
|
[{a: 2}] |
|
""" |
|
if not (coeffs and all(i.is_Symbol for i in coeffs)): |
|
raise ValueError('must provide symbols for coeffs') |
|
|
|
if isinstance(equ, Eq): |
|
eq = equ.lhs - equ.rhs |
|
else: |
|
eq = equ |
|
|
|
ceq = cancel(eq) |
|
xeq = _mexpand(ceq.as_numer_denom()[0], recursive=True) |
|
|
|
free = xeq.free_symbols |
|
coeffs = free & set(coeffs) |
|
if not coeffs: |
|
return ([], {}) if flags.get('set', None) else [] |
|
|
|
if not syms: |
|
|
|
|
|
|
|
|
|
|
|
ind, dep = xeq.as_independent(*coeffs, as_Add=True) |
|
dfree = dep.free_symbols |
|
syms = dfree & ind.free_symbols |
|
if not syms: |
|
|
|
|
|
|
|
syms = dfree - set(coeffs) |
|
if not syms: |
|
syms = [Dummy()] |
|
else: |
|
if len(syms) == 1 and iterable(syms[0]): |
|
syms = syms[0] |
|
e, s, _ = recast_to_symbols([xeq], syms) |
|
xeq = e[0] |
|
syms = s |
|
|
|
|
|
|
|
gens = set(xeq.as_coefficients_dict(*syms).keys()) - {1} |
|
cset = set(coeffs) |
|
if any(g.has_xfree(cset) for g in gens): |
|
return |
|
|
|
|
|
|
|
e, gens, _ = recast_to_symbols([xeq], list(gens)) |
|
xeq = e[0] |
|
|
|
|
|
|
|
system = list(collect(xeq, gens, evaluate=False).values()) |
|
|
|
|
|
|
|
soln = solve(system, coeffs, **flags) |
|
|
|
|
|
|
|
settings = flags.get('dict', None) or flags.get('set', None) |
|
if type(soln) is dict or settings or len(soln) != 1: |
|
return soln |
|
return soln[0] |
|
|
|
|
|
def solve_linear_system_LU(matrix, syms): |
|
""" |
|
Solves the augmented matrix system using ``LUsolve`` and returns a |
|
dictionary in which solutions are keyed to the symbols of *syms* as ordered. |
|
|
|
Explanation |
|
=========== |
|
|
|
The matrix must be invertible. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import Matrix, solve_linear_system_LU |
|
>>> from sympy.abc import x, y, z |
|
|
|
>>> solve_linear_system_LU(Matrix([ |
|
... [1, 2, 0, 1], |
|
... [3, 2, 2, 1], |
|
... [2, 0, 0, 1]]), [x, y, z]) |
|
{x: 1/2, y: 1/4, z: -1/2} |
|
|
|
See Also |
|
======== |
|
|
|
LUsolve |
|
|
|
""" |
|
if matrix.rows != matrix.cols - 1: |
|
raise ValueError("Rows should be equal to columns - 1") |
|
A = matrix[:matrix.rows, :matrix.rows] |
|
b = matrix[:, matrix.cols - 1:] |
|
soln = A.LUsolve(b) |
|
solutions = {} |
|
for i in range(soln.rows): |
|
solutions[syms[i]] = soln[i, 0] |
|
return solutions |
|
|
|
|
|
def det_perm(M): |
|
""" |
|
Return the determinant of *M* by using permutations to select factors. |
|
|
|
Explanation |
|
=========== |
|
|
|
For sizes larger than 8 the number of permutations becomes prohibitively |
|
large, or if there are no symbols in the matrix, it is better to use the |
|
standard determinant routines (e.g., ``M.det()``.) |
|
|
|
See Also |
|
======== |
|
|
|
det_minor |
|
det_quick |
|
|
|
""" |
|
args = [] |
|
s = True |
|
n = M.rows |
|
list_ = M.flat() |
|
for perm in generate_bell(n): |
|
fac = [] |
|
idx = 0 |
|
for j in perm: |
|
fac.append(list_[idx + j]) |
|
idx += n |
|
term = Mul(*fac) |
|
args.append(term if s else -term) |
|
s = not s |
|
return Add(*args) |
|
|
|
|
|
def det_minor(M): |
|
""" |
|
Return the ``det(M)`` computed from minors without |
|
introducing new nesting in products. |
|
|
|
See Also |
|
======== |
|
|
|
det_perm |
|
det_quick |
|
|
|
""" |
|
n = M.rows |
|
if n == 2: |
|
return M[0, 0]*M[1, 1] - M[1, 0]*M[0, 1] |
|
else: |
|
return sum((1, -1)[i % 2]*Add(*[M[0, i]*d for d in |
|
Add.make_args(det_minor(M.minor_submatrix(0, i)))]) |
|
if M[0, i] else S.Zero for i in range(n)) |
|
|
|
|
|
def det_quick(M, method=None): |
|
""" |
|
Return ``det(M)`` assuming that either |
|
there are lots of zeros or the size of the matrix |
|
is small. If this assumption is not met, then the normal |
|
Matrix.det function will be used with method = ``method``. |
|
|
|
See Also |
|
======== |
|
|
|
det_minor |
|
det_perm |
|
|
|
""" |
|
if any(i.has(Symbol) for i in M): |
|
if M.rows < 8 and all(i.has(Symbol) for i in M): |
|
return det_perm(M) |
|
return det_minor(M) |
|
else: |
|
return M.det(method=method) if method else M.det() |
|
|
|
|
|
def inv_quick(M): |
|
"""Return the inverse of ``M``, assuming that either |
|
there are lots of zeros or the size of the matrix |
|
is small. |
|
""" |
|
if not all(i.is_Number for i in M): |
|
if not any(i.is_Number for i in M): |
|
det = lambda _: det_perm(_) |
|
else: |
|
det = lambda _: det_minor(_) |
|
else: |
|
return M.inv() |
|
n = M.rows |
|
d = det(M) |
|
if d == S.Zero: |
|
raise NonInvertibleMatrixError("Matrix det == 0; not invertible") |
|
ret = zeros(n) |
|
s1 = -1 |
|
for i in range(n): |
|
s = s1 = -s1 |
|
for j in range(n): |
|
di = det(M.minor_submatrix(i, j)) |
|
ret[j, i] = s*di/d |
|
s = -s |
|
return ret |
|
|
|
|
|
|
|
multi_inverses = { |
|
sin: lambda x: (asin(x), S.Pi - asin(x)), |
|
cos: lambda x: (acos(x), 2*S.Pi - acos(x)), |
|
} |
|
|
|
|
|
def _vsolve(e, s, **flags): |
|
"""return list of scalar values for the solution of e for symbol s""" |
|
return [i[s] for i in _solve(e, s, **flags)] |
|
|
|
|
|
def _tsolve(eq, sym, **flags): |
|
""" |
|
Helper for ``_solve`` that solves a transcendental equation with respect |
|
to the given symbol. Various equations containing powers and logarithms, |
|
can be solved. |
|
|
|
There is currently no guarantee that all solutions will be returned or |
|
that a real solution will be favored over a complex one. |
|
|
|
Either a list of potential solutions will be returned or None will be |
|
returned (in the case that no method was known to get a solution |
|
for the equation). All other errors (like the inability to cast an |
|
expression as a Poly) are unhandled. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import log, ordered |
|
>>> from sympy.solvers.solvers import _tsolve as tsolve |
|
>>> from sympy.abc import x |
|
|
|
>>> list(ordered(tsolve(3**(2*x + 5) - 4, x))) |
|
[-5/2 + log(2)/log(3), (-5*log(3)/2 + log(2) + I*pi)/log(3)] |
|
|
|
>>> tsolve(log(x) + 2*x, x) |
|
[LambertW(2)/2] |
|
|
|
""" |
|
if 'tsolve_saw' not in flags: |
|
flags['tsolve_saw'] = [] |
|
if eq in flags['tsolve_saw']: |
|
return None |
|
else: |
|
flags['tsolve_saw'].append(eq) |
|
|
|
rhs, lhs = _invert(eq, sym) |
|
|
|
if lhs == sym: |
|
return [rhs] |
|
try: |
|
if lhs.is_Add: |
|
|
|
|
|
f = factor(powdenest(lhs - rhs)) |
|
if f.is_Mul: |
|
return _vsolve(f, sym, **flags) |
|
if rhs: |
|
f = logcombine(lhs, force=flags.get('force', True)) |
|
if f.count(log) != lhs.count(log): |
|
if isinstance(f, log): |
|
return _vsolve(f.args[0] - exp(rhs), sym, **flags) |
|
return _tsolve(f - rhs, sym, **flags) |
|
|
|
elif lhs.is_Pow: |
|
if lhs.exp.is_Integer: |
|
if lhs - rhs != eq: |
|
return _vsolve(lhs - rhs, sym, **flags) |
|
|
|
if sym not in lhs.exp.free_symbols: |
|
return _vsolve(lhs.base - rhs**(1/lhs.exp), sym, **flags) |
|
|
|
|
|
if any(t.is_Dummy for t in rhs.free_symbols): |
|
raise NotImplementedError |
|
|
|
|
|
if not rhs: |
|
|
|
|
|
sol_base = _vsolve(lhs.base, sym, **flags) |
|
return [s for s in sol_base if lhs.exp.subs(sym, s) != 0] |
|
|
|
|
|
if not lhs.base.has(sym): |
|
if lhs.base == 0: |
|
return _vsolve(lhs.exp, sym, **flags) if rhs != 0 else [] |
|
|
|
|
|
if lhs.base == rhs.as_base_exp()[0]: |
|
|
|
sol = _vsolve(lhs.exp - rhs.as_base_exp()[1], sym, **flags) |
|
else: |
|
|
|
|
|
f = exp(log(lhs.base)*lhs.exp) - exp(log(rhs)) |
|
sol = _vsolve(f, sym, **flags) |
|
|
|
|
|
def equal(expr1, expr2): |
|
_ = Dummy() |
|
eq = checksol(expr1 - _, _, expr2) |
|
if eq is None: |
|
if nsimplify(expr1) != nsimplify(expr2): |
|
return False |
|
|
|
|
|
eq = expr1.equals(expr2) |
|
return eq |
|
|
|
|
|
e_rat = nsimplify(log(abs(rhs))/log(abs(lhs.base))) |
|
e_rat = simplify(posify(e_rat)[0]) |
|
n, d = fraction(e_rat) |
|
if expand(lhs.base**n - rhs**d) == 0: |
|
sol = [s for s in sol if not equal(lhs.exp.subs(sym, s), e_rat)] |
|
sol.extend(_vsolve(lhs.exp - e_rat, sym, **flags)) |
|
|
|
return list(set(sol)) |
|
|
|
|
|
else: |
|
sol = [] |
|
logform = lhs.exp*log(lhs.base) - log(rhs) |
|
if logform != lhs - rhs: |
|
try: |
|
sol.extend(_vsolve(logform, sym, **flags)) |
|
except NotImplementedError: |
|
pass |
|
|
|
|
|
check = [] |
|
if rhs == 1: |
|
|
|
check.extend(_vsolve(lhs.exp, sym, **flags)) |
|
check.extend(_vsolve(lhs.base - 1, sym, **flags)) |
|
check.extend(_vsolve(lhs.base + 1, sym, **flags)) |
|
elif rhs.is_Rational: |
|
for d in (i for i in divisors(abs(rhs.p)) if i != 1): |
|
e, t = integer_log(rhs.p, d) |
|
if not t: |
|
continue |
|
for s in divisors(abs(rhs.q)): |
|
if s**e== rhs.q: |
|
r = Rational(d, s) |
|
check.extend(_vsolve(lhs.base - r, sym, **flags)) |
|
check.extend(_vsolve(lhs.base + r, sym, **flags)) |
|
check.extend(_vsolve(lhs.exp - e, sym, **flags)) |
|
elif rhs.is_irrational: |
|
b_l, e_l = lhs.base.as_base_exp() |
|
n, d = (e_l*lhs.exp).as_numer_denom() |
|
b, e = sqrtdenest(rhs).as_base_exp() |
|
check = [sqrtdenest(i) for i in (_vsolve(lhs.base - b, sym, **flags))] |
|
check.extend([sqrtdenest(i) for i in (_vsolve(lhs.exp - e, sym, **flags))]) |
|
if e_l*d != 1: |
|
check.extend(_vsolve(b_l**n - rhs**(e_l*d), sym, **flags)) |
|
for s in check: |
|
ok = checksol(eq, sym, s) |
|
if ok is None: |
|
ok = eq.subs(sym, s).equals(0) |
|
if ok: |
|
sol.append(s) |
|
return list(set(sol)) |
|
|
|
elif lhs.is_Function and len(lhs.args) == 1: |
|
if lhs.func in multi_inverses: |
|
|
|
soln = [] |
|
for i in multi_inverses[type(lhs)](rhs): |
|
soln.extend(_vsolve(lhs.args[0] - i, sym, **flags)) |
|
return list(set(soln)) |
|
elif lhs.func == LambertW: |
|
return _vsolve(lhs.args[0] - rhs*exp(rhs), sym, **flags) |
|
|
|
rewrite = lhs.rewrite(exp) |
|
rewrite = rebuild(rewrite) |
|
if rewrite != lhs: |
|
return _vsolve(rewrite - rhs, sym, **flags) |
|
except NotImplementedError: |
|
pass |
|
|
|
|
|
if flags.pop('bivariate', True): |
|
|
|
|
|
|
|
|
|
|
|
logs = eq.atoms(log) |
|
spow = min( |
|
{i.exp for j in logs for i in j.atoms(Pow) |
|
if i.base == sym} or {1}) |
|
if spow != 1: |
|
p = sym**spow |
|
u = Dummy('bivariate-cov') |
|
ueq = eq.subs(p, u) |
|
if not ueq.has_free(sym): |
|
sol = _vsolve(ueq, u, **flags) |
|
inv = _vsolve(p - u, sym) |
|
return [i.subs(u, s) for i in inv for s in sol] |
|
|
|
g = _filtered_gens(eq.as_poly(), sym) |
|
up_or_log = set() |
|
for gi in g: |
|
if isinstance(gi, (exp, log)) or (gi.is_Pow and gi.base == S.Exp1): |
|
up_or_log.add(gi) |
|
elif gi.is_Pow: |
|
gisimp = powdenest(expand_power_exp(gi)) |
|
if gisimp.is_Pow and sym in gisimp.exp.free_symbols: |
|
up_or_log.add(gi) |
|
eq_down = expand_log(expand_power_exp(eq)).subs( |
|
dict(list(zip(up_or_log, [0]*len(up_or_log))))) |
|
eq = expand_power_exp(factor(eq_down, deep=True) + (eq - eq_down)) |
|
rhs, lhs = _invert(eq, sym) |
|
if lhs.has(sym): |
|
try: |
|
poly = lhs.as_poly() |
|
g = _filtered_gens(poly, sym) |
|
_eq = lhs - rhs |
|
sols = _solve_lambert(_eq, sym, g) |
|
|
|
|
|
for n, s in enumerate(sols): |
|
ns = nsimplify(s) |
|
if ns != s and ns.count_ops() <= s.count_ops(): |
|
ok = checksol(_eq, sym, ns) |
|
if ok is None: |
|
ok = _eq.subs(sym, ns).equals(0) |
|
if ok: |
|
sols[n] = ns |
|
return sols |
|
except NotImplementedError: |
|
|
|
if len(g) == 2: |
|
try: |
|
gpu = bivariate_type(lhs - rhs, *g) |
|
if gpu is None: |
|
raise NotImplementedError |
|
g, p, u = gpu |
|
flags['bivariate'] = False |
|
inversion = _tsolve(g - u, sym, **flags) |
|
if inversion: |
|
sol = _vsolve(p, u, **flags) |
|
return list({i.subs(u, s) |
|
for i in inversion for s in sol}) |
|
except NotImplementedError: |
|
pass |
|
else: |
|
pass |
|
|
|
if flags.pop('force', True): |
|
flags['force'] = False |
|
pos, reps = posify(lhs - rhs) |
|
if rhs == S.ComplexInfinity: |
|
return [] |
|
for u, s in reps.items(): |
|
if s == sym: |
|
break |
|
else: |
|
u = sym |
|
if pos.has(u): |
|
try: |
|
soln = _vsolve(pos, u, **flags) |
|
return [s.subs(reps) for s in soln] |
|
except NotImplementedError: |
|
pass |
|
else: |
|
pass |
|
|
|
return |
|
|
|
|
|
|
|
|
|
@conserve_mpmath_dps |
|
def nsolve(*args, dict=False, **kwargs): |
|
r""" |
|
Solve a nonlinear equation system numerically: ``nsolve(f, [args,] x0, |
|
modules=['mpmath'], **kwargs)``. |
|
|
|
Explanation |
|
=========== |
|
|
|
``f`` is a vector function of symbolic expressions representing the system. |
|
*args* are the variables. If there is only one variable, this argument can |
|
be omitted. ``x0`` is a starting vector close to a solution. |
|
|
|
Use the modules keyword to specify which modules should be used to |
|
evaluate the function and the Jacobian matrix. Make sure to use a module |
|
that supports matrices. For more information on the syntax, please see the |
|
docstring of ``lambdify``. |
|
|
|
If the keyword arguments contain ``dict=True`` (default is False) ``nsolve`` |
|
will return a list (perhaps empty) of solution mappings. This might be |
|
especially useful if you want to use ``nsolve`` as a fallback to solve since |
|
using the dict argument for both methods produces return values of |
|
consistent type structure. Please note: to keep this consistent with |
|
``solve``, the solution will be returned in a list even though ``nsolve`` |
|
(currently at least) only finds one solution at a time. |
|
|
|
Overdetermined systems are supported. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import Symbol, nsolve |
|
>>> import mpmath |
|
>>> mpmath.mp.dps = 15 |
|
>>> x1 = Symbol('x1') |
|
>>> x2 = Symbol('x2') |
|
>>> f1 = 3 * x1**2 - 2 * x2**2 - 1 |
|
>>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 |
|
>>> print(nsolve((f1, f2), (x1, x2), (-1, 1))) |
|
Matrix([[-1.19287309935246], [1.27844411169911]]) |
|
|
|
For one-dimensional functions the syntax is simplified: |
|
|
|
>>> from sympy import sin, nsolve |
|
>>> from sympy.abc import x |
|
>>> nsolve(sin(x), x, 2) |
|
3.14159265358979 |
|
>>> nsolve(sin(x), 2) |
|
3.14159265358979 |
|
|
|
To solve with higher precision than the default, use the prec argument: |
|
|
|
>>> from sympy import cos |
|
>>> nsolve(cos(x) - x, 1) |
|
0.739085133215161 |
|
>>> nsolve(cos(x) - x, 1, prec=50) |
|
0.73908513321516064165531208767387340401341175890076 |
|
>>> cos(_) |
|
0.73908513321516064165531208767387340401341175890076 |
|
|
|
To solve for complex roots of real functions, a nonreal initial point |
|
must be specified: |
|
|
|
>>> from sympy import I |
|
>>> nsolve(x**2 + 2, I) |
|
1.4142135623731*I |
|
|
|
``mpmath.findroot`` is used and you can find their more extensive |
|
documentation, especially concerning keyword parameters and |
|
available solvers. Note, however, that functions which are very |
|
steep near the root, the verification of the solution may fail. In |
|
this case you should use the flag ``verify=False`` and |
|
independently verify the solution. |
|
|
|
>>> from sympy import cos, cosh |
|
>>> f = cos(x)*cosh(x) - 1 |
|
>>> nsolve(f, 3.14*100) |
|
Traceback (most recent call last): |
|
... |
|
ValueError: Could not find root within given tolerance. (1.39267e+230 > 2.1684e-19) |
|
>>> ans = nsolve(f, 3.14*100, verify=False); ans |
|
312.588469032184 |
|
>>> f.subs(x, ans).n(2) |
|
2.1e+121 |
|
>>> (f/f.diff(x)).subs(x, ans).n(2) |
|
7.4e-15 |
|
|
|
One might safely skip the verification if bounds of the root are known |
|
and a bisection method is used: |
|
|
|
>>> bounds = lambda i: (3.14*i, 3.14*(i + 1)) |
|
>>> nsolve(f, bounds(100), solver='bisect', verify=False) |
|
315.730061685774 |
|
|
|
Alternatively, a function may be better behaved when the |
|
denominator is ignored. Since this is not always the case, however, |
|
the decision of what function to use is left to the discretion of |
|
the user. |
|
|
|
>>> eq = x**2/(1 - x)/(1 - 2*x)**2 - 100 |
|
>>> nsolve(eq, 0.46) |
|
Traceback (most recent call last): |
|
... |
|
ValueError: Could not find root within given tolerance. (10000 > 2.1684e-19) |
|
Try another starting point or tweak arguments. |
|
>>> nsolve(eq.as_numer_denom()[0], 0.46) |
|
0.46792545969349058 |
|
|
|
""" |
|
|
|
|
|
if 'method' in kwargs: |
|
raise ValueError(filldedent(''' |
|
Keyword "method" should not be used in this context. When using |
|
some mpmath solvers directly, the keyword "method" is |
|
used, but when using nsolve (and findroot) the keyword to use is |
|
"solver".''')) |
|
|
|
if 'prec' in kwargs: |
|
import mpmath |
|
mpmath.mp.dps = kwargs.pop('prec') |
|
|
|
|
|
as_dict = dict |
|
from builtins import dict |
|
|
|
|
|
if len(args) == 3: |
|
f = args[0] |
|
fargs = args[1] |
|
x0 = args[2] |
|
if iterable(fargs) and iterable(x0): |
|
if len(x0) != len(fargs): |
|
raise TypeError('nsolve expected exactly %i guess vectors, got %i' |
|
% (len(fargs), len(x0))) |
|
elif len(args) == 2: |
|
f = args[0] |
|
fargs = None |
|
x0 = args[1] |
|
if iterable(f): |
|
raise TypeError('nsolve expected 3 arguments, got 2') |
|
elif len(args) < 2: |
|
raise TypeError('nsolve expected at least 2 arguments, got %i' |
|
% len(args)) |
|
else: |
|
raise TypeError('nsolve expected at most 3 arguments, got %i' |
|
% len(args)) |
|
modules = kwargs.get('modules', ['mpmath']) |
|
if iterable(f): |
|
f = list(f) |
|
for i, fi in enumerate(f): |
|
if isinstance(fi, Eq): |
|
f[i] = fi.lhs - fi.rhs |
|
f = Matrix(f).T |
|
if iterable(x0): |
|
x0 = list(x0) |
|
if not isinstance(f, Matrix): |
|
|
|
if isinstance(f, Eq): |
|
f = f.lhs - f.rhs |
|
elif f.is_Relational: |
|
raise TypeError('nsolve cannot accept inequalities') |
|
syms = f.free_symbols |
|
if fargs is None: |
|
fargs = syms.copy().pop() |
|
if not (len(syms) == 1 and (fargs in syms or fargs[0] in syms)): |
|
raise ValueError(filldedent(''' |
|
expected a one-dimensional and numerical function''')) |
|
|
|
|
|
|
|
|
|
|
|
|
|
f = lambdify(fargs, f, modules) |
|
x = sympify(findroot(f, x0, **kwargs)) |
|
if as_dict: |
|
return [{fargs: x}] |
|
return x |
|
|
|
if len(fargs) > f.cols: |
|
raise NotImplementedError(filldedent(''' |
|
need at least as many equations as variables''')) |
|
verbose = kwargs.get('verbose', False) |
|
if verbose: |
|
print('f(x):') |
|
print(f) |
|
|
|
J = f.jacobian(fargs) |
|
if verbose: |
|
print('J(x):') |
|
print(J) |
|
|
|
f = lambdify(fargs, f.T, modules) |
|
J = lambdify(fargs, J, modules) |
|
|
|
x = findroot(f, x0, J=J, **kwargs) |
|
if as_dict: |
|
return [dict(zip(fargs, [sympify(xi) for xi in x]))] |
|
return Matrix(x) |
|
|
|
|
|
def _invert(eq, *symbols, **kwargs): |
|
""" |
|
Return tuple (i, d) where ``i`` is independent of *symbols* and ``d`` |
|
contains symbols. |
|
|
|
Explanation |
|
=========== |
|
|
|
``i`` and ``d`` are obtained after recursively using algebraic inversion |
|
until an uninvertible ``d`` remains. If there are no free symbols then |
|
``d`` will be zero. Some (but not necessarily all) solutions to the |
|
expression ``i - d`` will be related to the solutions of the original |
|
expression. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy.solvers.solvers import _invert as invert |
|
>>> from sympy import sqrt, cos |
|
>>> from sympy.abc import x, y |
|
>>> invert(x - 3) |
|
(3, x) |
|
>>> invert(3) |
|
(3, 0) |
|
>>> invert(2*cos(x) - 1) |
|
(1/2, cos(x)) |
|
>>> invert(sqrt(x) - 3) |
|
(3, sqrt(x)) |
|
>>> invert(sqrt(x) + y, x) |
|
(-y, sqrt(x)) |
|
>>> invert(sqrt(x) + y, y) |
|
(-sqrt(x), y) |
|
>>> invert(sqrt(x) + y, x, y) |
|
(0, sqrt(x) + y) |
|
|
|
If there is more than one symbol in a power's base and the exponent |
|
is not an Integer, then the principal root will be used for the |
|
inversion: |
|
|
|
>>> invert(sqrt(x + y) - 2) |
|
(4, x + y) |
|
>>> invert(sqrt(x + y) + 2) # note +2 instead of -2 |
|
(4, x + y) |
|
|
|
If the exponent is an Integer, setting ``integer_power`` to True |
|
will force the principal root to be selected: |
|
|
|
>>> invert(x**2 - 4, integer_power=True) |
|
(2, x) |
|
|
|
""" |
|
eq = sympify(eq) |
|
if eq.args: |
|
|
|
eq = eq.func(*eq.args) |
|
free = eq.free_symbols |
|
if not symbols: |
|
symbols = free |
|
if not free & set(symbols): |
|
return eq, S.Zero |
|
|
|
dointpow = bool(kwargs.get('integer_power', False)) |
|
|
|
lhs = eq |
|
rhs = S.Zero |
|
while True: |
|
was = lhs |
|
while True: |
|
indep, dep = lhs.as_independent(*symbols) |
|
|
|
|
|
if lhs.is_Add: |
|
|
|
if indep.is_zero: |
|
break |
|
|
|
lhs = dep |
|
rhs -= indep |
|
|
|
|
|
else: |
|
|
|
if indep is S.One: |
|
break |
|
|
|
lhs = dep |
|
rhs /= indep |
|
|
|
|
|
if lhs.is_Add: |
|
terms = {} |
|
for a in lhs.args: |
|
i, d = a.as_independent(*symbols) |
|
terms.setdefault(d, []).append(i) |
|
if any(len(v) > 1 for v in terms.values()): |
|
args = [] |
|
for d, i in terms.items(): |
|
if len(i) > 1: |
|
args.append(Add(*i)*d) |
|
else: |
|
args.append(i[0]*d) |
|
lhs = Add(*args) |
|
|
|
|
|
|
|
if lhs.is_Add and not rhs and len(lhs.args) == 2 and \ |
|
not lhs.is_polynomial(*symbols): |
|
a, b = ordered(lhs.args) |
|
ai, ad = a.as_independent(*symbols) |
|
bi, bd = b.as_independent(*symbols) |
|
if any(_ispow(i) for i in (ad, bd)): |
|
a_base, a_exp = ad.as_base_exp() |
|
b_base, b_exp = bd.as_base_exp() |
|
if a_base == b_base and a_exp.extract_additively(b_exp) is None: |
|
|
|
|
|
|
|
rat = powsimp(powdenest(ad/bd)) |
|
lhs = rat |
|
rhs = -bi/ai |
|
else: |
|
rat = ad/bd |
|
_lhs = powsimp(ad/bd) |
|
if _lhs != rat: |
|
lhs = _lhs |
|
rhs = -bi/ai |
|
elif ai == -bi: |
|
if isinstance(ad, Function) and ad.func == bd.func: |
|
if len(ad.args) == len(bd.args) == 1: |
|
lhs = ad.args[0] - bd.args[0] |
|
elif len(ad.args) == len(bd.args): |
|
|
|
|
|
raise NotImplementedError( |
|
'equal function with more than 1 argument') |
|
else: |
|
raise ValueError( |
|
'function with different numbers of args') |
|
|
|
elif lhs.is_Mul and any(_ispow(a) for a in lhs.args): |
|
lhs = powsimp(powdenest(lhs)) |
|
|
|
if lhs.is_Function: |
|
if hasattr(lhs, 'inverse') and lhs.inverse() is not None and len(lhs.args) == 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
rhs = lhs.inverse()(rhs) |
|
lhs = lhs.args[0] |
|
elif isinstance(lhs, atan2): |
|
y, x = lhs.args |
|
lhs = 2*atan(y/(sqrt(x**2 + y**2) + x)) |
|
elif lhs.func == rhs.func: |
|
if len(lhs.args) == len(rhs.args) == 1: |
|
lhs = lhs.args[0] |
|
rhs = rhs.args[0] |
|
elif len(lhs.args) == len(rhs.args): |
|
|
|
|
|
|
|
raise NotImplementedError( |
|
'equal function with more than 1 argument') |
|
else: |
|
raise ValueError( |
|
'function with different numbers of args') |
|
|
|
|
|
if rhs and lhs.is_Pow and lhs.exp.is_Integer and lhs.exp < 0: |
|
lhs = 1/lhs |
|
rhs = 1/rhs |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if lhs.is_Pow and ( |
|
lhs.exp.is_Integer and dointpow or not lhs.exp.is_Integer and |
|
len(symbols) > 1 and len(lhs.base.free_symbols & set(symbols)) > 1): |
|
rhs = rhs**(1/lhs.exp) |
|
lhs = lhs.base |
|
|
|
if lhs == was: |
|
break |
|
return rhs, lhs |
|
|
|
|
|
def unrad(eq, *syms, **flags): |
|
""" |
|
Remove radicals with symbolic arguments and return (eq, cov), |
|
None, or raise an error. |
|
|
|
Explanation |
|
=========== |
|
|
|
None is returned if there are no radicals to remove. |
|
|
|
NotImplementedError is raised if there are radicals and they cannot be |
|
removed or if the relationship between the original symbols and the |
|
change of variable needed to rewrite the system as a polynomial cannot |
|
be solved. |
|
|
|
Otherwise the tuple, ``(eq, cov)``, is returned where: |
|
|
|
*eq*, ``cov`` |
|
*eq* is an equation without radicals (in the symbol(s) of |
|
interest) whose solutions are a superset of the solutions to the |
|
original expression. *eq* might be rewritten in terms of a new |
|
variable; the relationship to the original variables is given by |
|
``cov`` which is a list containing ``v`` and ``v**p - b`` where |
|
``p`` is the power needed to clear the radical and ``b`` is the |
|
radical now expressed as a polynomial in the symbols of interest. |
|
For example, for sqrt(2 - x) the tuple would be |
|
``(c, c**2 - 2 + x)``. The solutions of *eq* will contain |
|
solutions to the original equation (if there are any). |
|
|
|
*syms* |
|
An iterable of symbols which, if provided, will limit the focus of |
|
radical removal: only radicals with one or more of the symbols of |
|
interest will be cleared. All free symbols are used if *syms* is not |
|
set. |
|
|
|
*flags* are used internally for communication during recursive calls. |
|
Two options are also recognized: |
|
|
|
``take``, when defined, is interpreted as a single-argument function |
|
that returns True if a given Pow should be handled. |
|
|
|
Radicals can be removed from an expression if: |
|
|
|
* All bases of the radicals are the same; a change of variables is |
|
done in this case. |
|
* If all radicals appear in one term of the expression. |
|
* There are only four terms with sqrt() factors or there are less than |
|
four terms having sqrt() factors. |
|
* There are only two terms with radicals. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy.solvers.solvers import unrad |
|
>>> from sympy.abc import x |
|
>>> from sympy import sqrt, Rational, root |
|
|
|
>>> unrad(sqrt(x)*x**Rational(1, 3) + 2) |
|
(x**5 - 64, []) |
|
>>> unrad(sqrt(x) + root(x + 1, 3)) |
|
(-x**3 + x**2 + 2*x + 1, []) |
|
>>> eq = sqrt(x) + root(x, 3) - 2 |
|
>>> unrad(eq) |
|
(_p**3 + _p**2 - 2, [_p, _p**6 - x]) |
|
|
|
""" |
|
|
|
uflags = {"check": False, "simplify": False} |
|
|
|
def _cov(p, e): |
|
if cov: |
|
|
|
oldp, olde = cov |
|
if Poly(e, p).degree(p) in (1, 2): |
|
cov[:] = [p, olde.subs(oldp, _vsolve(e, p, **uflags)[0])] |
|
else: |
|
raise NotImplementedError |
|
else: |
|
cov[:] = [p, e] |
|
|
|
def _canonical(eq, cov): |
|
if cov: |
|
|
|
p, e = cov |
|
rep = {p: Dummy(p.name)} |
|
eq = eq.xreplace(rep) |
|
cov = [p.xreplace(rep), e.xreplace(rep)] |
|
|
|
|
|
|
|
eq = factor_terms(_mexpand(eq.as_numer_denom()[0], recursive=True), clear=True) |
|
if eq.is_Mul: |
|
args = [] |
|
for f in eq.args: |
|
if f.is_number: |
|
continue |
|
if f.is_Pow: |
|
args.append(f.base) |
|
else: |
|
args.append(f) |
|
eq = Mul(*args) |
|
|
|
|
|
margs = list(Mul.make_args(eq)) |
|
changed = False |
|
for i, m in enumerate(margs): |
|
if m.could_extract_minus_sign(): |
|
margs[i] = -m |
|
changed = True |
|
if changed: |
|
eq = Mul(*margs, evaluate=False) |
|
|
|
return eq, cov |
|
|
|
def _Q(pow): |
|
|
|
c = pow.as_base_exp()[1].as_coeff_Mul()[0] |
|
if not c.is_Rational: |
|
return S.One |
|
return c.q |
|
|
|
|
|
def _take(d): |
|
|
|
for pow in Mul.make_args(d): |
|
if not pow.is_Pow: |
|
continue |
|
if _Q(pow) == 1: |
|
continue |
|
if pow.free_symbols & syms: |
|
return True |
|
return False |
|
_take = flags.setdefault('_take', _take) |
|
|
|
if isinstance(eq, Eq): |
|
eq = eq.lhs - eq.rhs |
|
elif not isinstance(eq, Expr): |
|
return |
|
|
|
cov, nwas, rpt = [flags.setdefault(k, v) for k, v in |
|
sorted({"cov": [], "n": None, "rpt": 0}.items())] |
|
|
|
|
|
eq = powdenest(factor_terms(eq, radical=True, clear=True)) |
|
eq = eq.as_numer_denom()[0] |
|
eq = _mexpand(eq, recursive=True) |
|
if eq.is_number: |
|
return |
|
|
|
|
|
syms = set(syms) or eq.free_symbols |
|
poly = eq.as_poly() |
|
gens = [g for g in poly.gens if _take(g)] |
|
if not gens: |
|
return |
|
|
|
|
|
poly = eq.as_poly(*gens) |
|
|
|
|
|
if poly is None: |
|
return |
|
|
|
|
|
if any(g.exp.has(*syms) for g in gens): |
|
return |
|
|
|
def _rads_bases_lcm(poly): |
|
|
|
|
|
|
|
lcm = 1 |
|
rads = set() |
|
bases = set() |
|
for g in poly.gens: |
|
q = _Q(g) |
|
if q != 1: |
|
rads.add(g) |
|
lcm = ilcm(lcm, q) |
|
bases.add(g.base) |
|
return rads, bases, lcm |
|
rads, bases, lcm = _rads_bases_lcm(poly) |
|
|
|
covsym = Dummy('p', nonnegative=True) |
|
|
|
|
|
|
|
newsyms = set() |
|
for r in rads: |
|
newsyms.update(syms & r.free_symbols) |
|
if newsyms != syms: |
|
syms = newsyms |
|
|
|
drad = dict(zip(rads, range(len(rads)))) |
|
rterms = {(): []} |
|
args = Add.make_args(poly.as_expr()) |
|
for t in args: |
|
if _take(t): |
|
common = set(t.as_poly().gens).intersection(rads) |
|
key = tuple(sorted([drad[i] for i in common])) |
|
else: |
|
key = () |
|
rterms.setdefault(key, []).append(t) |
|
others = Add(*rterms.pop(())) |
|
rterms = [Add(*rterms[k]) for k in rterms.keys()] |
|
|
|
|
|
|
|
rterms = list(reversed(list(ordered(rterms)))) |
|
|
|
ok = False |
|
depth = sqrt_depth(eq) |
|
|
|
if len(rterms) == 1 and not (rterms[0].is_Add and lcm > 2): |
|
eq = rterms[0]**lcm - ((-others)**lcm) |
|
ok = True |
|
else: |
|
if len(rterms) == 1 and rterms[0].is_Add: |
|
rterms = list(rterms[0].args) |
|
if len(bases) == 1: |
|
b = bases.pop() |
|
if len(syms) > 1: |
|
x = b.free_symbols |
|
else: |
|
x = syms |
|
x = list(ordered(x))[0] |
|
try: |
|
inv = _vsolve(covsym**lcm - b, x, **uflags) |
|
if not inv: |
|
raise NotImplementedError |
|
eq = poly.as_expr().subs(b, covsym**lcm).subs(x, inv[0]) |
|
_cov(covsym, covsym**lcm - b) |
|
return _canonical(eq, cov) |
|
except NotImplementedError: |
|
pass |
|
|
|
if len(rterms) == 2: |
|
if not others: |
|
eq = rterms[0]**lcm - (-rterms[1])**lcm |
|
ok = True |
|
elif not log(lcm, 2).is_Integer: |
|
|
|
r0, r1 = rterms |
|
if flags.get('_reverse', False): |
|
r1, r0 = r0, r1 |
|
i0 = _rads0, _bases0, lcm0 = _rads_bases_lcm(r0.as_poly()) |
|
i1 = _rads1, _bases1, lcm1 = _rads_bases_lcm(r1.as_poly()) |
|
for reverse in range(2): |
|
if reverse: |
|
i0, i1 = i1, i0 |
|
r0, r1 = r1, r0 |
|
_rads1, _, lcm1 = i1 |
|
_rads1 = Mul(*_rads1) |
|
t1 = _rads1**lcm1 |
|
c = covsym**lcm1 - t1 |
|
for x in syms: |
|
try: |
|
sol = _vsolve(c, x, **uflags) |
|
if not sol: |
|
raise NotImplementedError |
|
neweq = r0.subs(x, sol[0]) + covsym*r1/_rads1 + \ |
|
others |
|
tmp = unrad(neweq, covsym) |
|
if tmp: |
|
eq, newcov = tmp |
|
if newcov: |
|
newp, newc = newcov |
|
_cov(newp, c.subs(covsym, |
|
_vsolve(newc, covsym, **uflags)[0])) |
|
else: |
|
_cov(covsym, c) |
|
else: |
|
eq = neweq |
|
_cov(covsym, c) |
|
ok = True |
|
break |
|
except NotImplementedError: |
|
if reverse: |
|
raise NotImplementedError( |
|
'no successful change of variable found') |
|
else: |
|
pass |
|
if ok: |
|
break |
|
elif len(rterms) == 3: |
|
|
|
|
|
|
|
info = [_rads_bases_lcm(i.as_poly()) for i in rterms] |
|
RAD = 0 |
|
BASES = 1 |
|
LCM = 2 |
|
if info[0][LCM] != 3: |
|
info.append(info.pop(0)) |
|
rterms.append(rterms.pop(0)) |
|
elif info[1][LCM] != 3: |
|
info.append(info.pop(1)) |
|
rterms.append(rterms.pop(1)) |
|
if info[0][LCM] == info[1][LCM] == 3: |
|
if info[1][BASES] != info[2][BASES]: |
|
info[0], info[1] = info[1], info[0] |
|
rterms[0], rterms[1] = rterms[1], rterms[0] |
|
if info[1][BASES] == info[2][BASES]: |
|
eq = rterms[0]**3 + (rterms[1] + rterms[2] + others)**3 |
|
ok = True |
|
elif info[2][LCM] < 5: |
|
|
|
a, b, c, d, A, B = [Dummy(i) for i in 'abcdAB'] |
|
|
|
|
|
zz = (c - d)*(A**3*a**9 + 3*A**2*B*a**6*b**3 - |
|
3*A**2*a**6*c**3 + 9*A**2*a**6*c**2*d - 9*A**2*a**6*c*d**2 + |
|
3*A**2*a**6*d**3 + 3*A*B**2*a**3*b**6 + 21*A*B*a**3*b**3*c**3 - |
|
63*A*B*a**3*b**3*c**2*d + 63*A*B*a**3*b**3*c*d**2 - |
|
21*A*B*a**3*b**3*d**3 + 3*A*a**3*c**6 - 18*A*a**3*c**5*d + |
|
45*A*a**3*c**4*d**2 - 60*A*a**3*c**3*d**3 + 45*A*a**3*c**2*d**4 - |
|
18*A*a**3*c*d**5 + 3*A*a**3*d**6 + B**3*b**9 - 3*B**2*b**6*c**3 + |
|
9*B**2*b**6*c**2*d - 9*B**2*b**6*c*d**2 + 3*B**2*b**6*d**3 + |
|
3*B*b**3*c**6 - 18*B*b**3*c**5*d + 45*B*b**3*c**4*d**2 - |
|
60*B*b**3*c**3*d**3 + 45*B*b**3*c**2*d**4 - 18*B*b**3*c*d**5 + |
|
3*B*b**3*d**6 - c**9 + 9*c**8*d - 36*c**7*d**2 + 84*c**6*d**3 - |
|
126*c**5*d**4 + 126*c**4*d**5 - 84*c**3*d**6 + 36*c**2*d**7 - |
|
9*c*d**8 + d**9) |
|
def _t(i): |
|
b = Mul(*info[i][RAD]) |
|
return cancel(rterms[i]/b), Mul(*info[i][BASES]) |
|
aa, AA = _t(0) |
|
bb, BB = _t(1) |
|
cc = -rterms[2] |
|
dd = others |
|
eq = zz.xreplace(dict(zip( |
|
(a, A, b, B, c, d), |
|
(aa, AA, bb, BB, cc, dd)))) |
|
ok = True |
|
|
|
if not ok: |
|
if log(lcm, 2).is_Integer and (not others and |
|
len(rterms) == 4 or len(rterms) < 4): |
|
def _norm2(a, b): |
|
return a**2 + b**2 + 2*a*b |
|
|
|
if len(rterms) == 4: |
|
|
|
r0, r1, r2, r3 = rterms |
|
eq = _norm2(r0, r1) - _norm2(r2, r3) |
|
ok = True |
|
elif len(rterms) == 3: |
|
|
|
r0, r1, r2 = rterms |
|
eq = _norm2(r1, r2) - _norm2(r0, others) |
|
ok = True |
|
elif len(rterms) == 2: |
|
|
|
r0, r1 = rterms |
|
eq = r0**2 - _norm2(r1, others) |
|
ok = True |
|
|
|
new_depth = sqrt_depth(eq) if ok else depth |
|
rpt += 1 |
|
if not ok or ( |
|
nwas is not None and len(rterms) == nwas and |
|
new_depth is not None and new_depth == depth and |
|
rpt > 3): |
|
raise NotImplementedError('Cannot remove all radicals') |
|
|
|
flags.update({"cov": cov, "n": len(rterms), "rpt": rpt}) |
|
neq = unrad(eq, *syms, **flags) |
|
if neq: |
|
eq, cov = neq |
|
eq, cov = _canonical(eq, cov) |
|
return eq, cov |
|
|
|
|
|
|
|
from sympy.solvers.bivariate import ( |
|
bivariate_type, _solve_lambert, _filtered_gens) |
|
|