Python sympy 模块：cse() 实例源码

我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 sympy.cse()。

项目:Python-iBeacon-Scan    作者:NikNitro    | 项目源码 | 文件源码
def test_cse():
    """JIT-compile a cse()-reduced expression and compare with direct evaluation."""
    expr = a*a + b*b + sympy.exp(-a*a - b*b)
    reduced = sympy.cse(expr)
    compiled = g.llvm_callable([a, b], reduced)

    # Reference value: substitute into the unreduced expression.
    expected = float(expr.subs({a: 2.3, b: 0.1}).evalf())
    actual = compiled(2.3, 0.1)

    assert isclose(actual, expected)
项目:Python-iBeacon-Scan    作者:NikNitro    | 项目源码 | 文件源码
def test_cse_multiple():
    """JIT-compile a multi-expression cse() result and compare each output."""
    square_a = a*a
    sum_sq = a*a + b*b
    reduced = sympy.cse([square_a, sum_sq])

    # The 'scipy.integrate' callback type must reject this multi-output input.
    raises(NotImplementedError,
           lambda: g.llvm_callable([a, b], reduced,
                                   callback_type='scipy.integrate'))

    compiled = g.llvm_callable([a, b], reduced)
    got = compiled(0.1, 1.5)
    assert len(got) == 2

    expected = eval_cse(reduced, {a: 0.1, b: 1.5})
    for idx in range(2):
        assert isclose(expected[idx], got[idx])
项目:Python-iBeacon-Scan    作者:NikNitro    | 项目源码 | 文件源码
def test_callback_cubature_multiple():
    """Call a 'cubature'-style compiled callback with ctypes buffers.

    The callback receives (ndim, input array, None, fdim, output array),
    must return 0, and must fill the output array with the values of the
    three cse()-reduced expressions.
    """
    first = a*a
    second = a*a + b*b
    reduced = sympy.cse([first, second, 4*second])
    compiled = g.llvm_callable([a, b], reduced, callback_type='cubature')

    n_in = 2   # number of input variables
    n_out = 3  # number of output expression values

    point = {a: 0.2, b: 1.5}
    inputs = (ctypes.c_double * n_in)(point[a], point[b])
    outputs = (ctypes.c_double * n_out)()
    status = compiled(ctypes.c_int(n_in), inputs, None,
                      ctypes.c_int(n_out), outputs)

    assert status == 0

    expected = eval_cse(reduced, point)

    for idx in range(n_out):
        assert isclose(outputs[idx], expected[idx])
项目:zippy    作者:securesystemslab    | 项目源码 | 文件源码
def genfcode(lambdastr, use_cse=False):
    """
    Python lambda string -> C function code

    Translate a lambda source string such as ``"lambda x,y: x*y + x"`` into
    the source of a C function ``inline double f(double x, double y)``.
    Optionally cse() is used to eliminate common subexpressions; each one is
    emitted as a local ``double`` temporary ahead of the return statement.

    :param lambdastr: lambda source string whose parameter list and body are
        separated by the first occurrence of ``': '``.
    :param use_cse: if true, apply ``cse()`` to the body and emit temporaries
        for the common subexpressions.
    :returns: the generated C source code as a string.
    :raises ValueError: if ``lambdastr`` contains no ``': '`` separator, or if
        ``cse()`` does not reduce the body to exactly one expression.
    """
    # TODO: verify lambda string
    # interpret lambda string; split on the first ': ' only, so the body may
    # itself contain that sequence
    varstr, fstr = lambdastr.split(': ', 1)
    # Remove the 'lambda' keyword as a *prefix*. str.lstrip('lambda ') strips
    # any of the characters l, a, m, b, d and space, which would also eat
    # leading parameter names (e.g. 'lambda a, b' -> ', b').
    varstr = varstr.strip()
    if varstr.startswith('lambda'):
        varstr = varstr[len('lambda'):]
    # generate C parameter list; empty names (zero-argument lambda) are
    # skipped instead of producing a bare 'double'
    cvars = [v.strip() for v in varstr.split(',')]
    cvarstr = ', '.join('double %s' % v for v in cvars if v)
    # convert function string to C syntax
    if not use_cse:
        cfstr = ''
        finalexpr = cexpr(fstr)
    else:
        # eliminate common subexpressions; _gentmpvars() is defined elsewhere
        # in this module and presumably supplies the temporary symbols
        subs, finalexpr = cse(sympify(fstr), _gentmpvars())
        if len(finalexpr) != 1:
            raise ValueError("Length should be 1")
        vardec = ''
        cfstr = ''
        for symbol, expr in subs:
            vardec += '    double %s;\n' % symbol.name
            # dps: evalf precision, a module-level name outside this view
            cfstr += '    %s = %s;\n' % (
                symbol.name, cexpr(str(expr.evalf(dps))))
        # all declarations first, then the assignments
        cfstr = vardec + cfstr
        finalexpr = cexpr(str(finalexpr[0].evalf(dps)))
    # generate C code
    code = """
inline double f(%s)
    {
%s
    return %s;
    }
""" % (cvarstr, cfstr, finalexpr)
    return code
项目:meshless    作者:compmech    | 项目源码 | 文件源码
def print_as_array(m, mname, sufix=None, use_cse=False, header=None,
        print_file=True, collect_for=None, pow_by_mul=True, order='C',
        op='+='):
    """Render the non-zero entries of ``m`` as array-update source lines.

    Every non-zero entry ``v`` at flat position ``i`` becomes the line
    ``'{mname}[pos+{i}] {op} {v}'``. Output sections, in order: ``header``
    (if truthy), the cse temporaries (``'# cdefs'`` declarations and
    ``'# subs'`` assignments, only with ``use_cse``), then a ``'# {name}'`` /
    ``'# {name}_num=N'`` banner followed by the entry lines.

    Parameters
    ----------
    m : matrix-like
        Must support iteration, ``numpy.ravel``, ``.T`` (for ``order='F'``)
        and, with ``use_cse``, flat integer item assignment. Presumably a
        SymPy ``Matrix`` -- confirm against callers.
        NOTE: with ``use_cse=True`` the entries of ``m`` are overwritten in
        place by their cse-reduced forms.
    mname : str
        Array name used on the left-hand side of every emitted entry line.
    sufix : str, optional
        If given, the banner/file name becomes ``'{mname}_{sufix}'``.
    use_cse : bool
        Run ``sympy.cse`` over ``m`` and emit the temporaries.
    header : str, optional
        Emitted verbatim as the first line when truthy.
    print_file : bool
        Also write the result to ``'print_{name}.txt'``.
    collect_for : optional
        If given, each entry is passed to
        ``collect(v, collect_for, evaluate=False)`` and emitted as one
        collected term per line.
    pow_by_mul : bool
        Rewrite ``name**N`` (integer ``N``) as ``(name*name*...)``.
    order : {'C', 'F'}
        Flattening order that defines the ``pos+i`` offsets.
    op : str
        Operator placed between the target and the value.

    Returns
    -------
    str
        The generated lines joined with newlines.

    Raises
    ------
    ValueError
        If ``order`` is neither ``'C'`` nor ``'F'``.
    """
    # Validate before mutating m (use_cse) or writing any file; previously an
    # unknown order surfaced later as an unbound-local error.
    if order not in ('C', 'F'):
        raise ValueError("order must be 'C' or 'F', got {0!r}".format(order))

    def expand_pow(match):
        # 'x**3' -> '(x*x*x)'
        base, exp = match.group(0).split('**')
        return '(' + '*'.join([base] * int(exp)) + ')'

    ls = []
    if use_cse:
        subs, m_list = sympy.cse(m)
        # Newer SymPy returns a Matrix argument as a single reduced Matrix in
        # the list, older releases return the flat entries; accept both.
        if len(m_list) == 1 and isinstance(m_list[0], sympy.MatrixBase):
            m_list = list(m_list[0])
        for i, v in enumerate(m_list):
            m[i] = v
    if sufix is None:
        namesufix = '{0}'.format(mname)
    else:
        namesufix = '{0}_{1}'.format(mname, sufix)
    filename = 'print_{0}.txt'.format(namesufix)
    if header:
        ls.append(header)
    if use_cse:
        ls.append('# cdefs')
        num = 9  # temporaries declared per 'cdef double' line
        for start in range(0, len(subs), num):
            ls.append('cdef double ' + ', '.join(
                str(sub[0]) for sub in subs[start:start + num]))
        ls.append('# subs')
        for sub in subs:
            ls.append('{0} = {1}'.format(*sub))
    ls.append('# {0}'.format(namesufix))
    num = sum(1 for entry in list(m) if entry)
    ls.append('# {0}_num={1}'.format(namesufix, num))
    flat = np.ravel(m) if order == 'C' else np.ravel(m.T)
    for i, v in enumerate(flat):
        if not v:
            continue
        if collect_for is not None:
            terms = collect(v, collect_for, evaluate=False)
            ls.append('{0}[pos+{1}] {2}'.format(mname, i, op))
            for k, expr in terms.items():
                ls.append('#   collected for {k}'.format(k=k))
                ls.append('    {expr}'.format(expr=k*expr))
        else:
            if pow_by_mul:
                # Replace each match in place. The old findall + str.replace
                # approach corrupted overlapping tokens ('x**2 + ax**2',
                # 'x**2 + x**23'); the lookahead leaves float exponents such
                # as 'x**2.0' untouched instead of yielding '(x*x).0'.
                v = re.sub(r'\w+\*\*\d+(?![\d.])', expand_pow, str(v))
            ls.append('{0}[pos+{1}] {2} {3}'.format(mname, i, op, v))
    string = '\n'.join(ls)
    if print_file:
        with open(filename, 'w') as f:
            f.write(string)
    return string