我们从 Python 开源项目中提取了以下 5 个代码示例，用于说明如何使用 sympy.cse()。
def test_cse():
    """Check a JIT-compiled, cse()-reduced expression against SymPy.

    The expression is reduced with ``sympy.cse`` and compiled with
    ``g.llvm_callable``. Its result at (a, b) = (2.3, 0.1) must match SymPy's
    own numeric evaluation of the unreduced expression.
    """
    expr = a*a + b*b + sympy.exp(-a*a - b*b)
    reduced = sympy.cse(expr)
    jitted = g.llvm_callable([a, b], reduced)

    point = {a: 2.3, b: 0.1}
    expected = float(expr.subs(point).evalf())
    assert isclose(jitted(point[a], point[b]), expected)
def test_cse_multiple():
    """JIT-compile two cse()-reduced expressions and compare both outputs.

    The 'scipy.integrate' callback type must reject this multi-output input
    with NotImplementedError. The default callable must return two values
    that agree with ``eval_cse`` at (a, b) = (0.1, 1.5).
    """
    exprs = [a*a, a*a + b*b]
    reduced = sympy.cse(exprs)

    # Multiple outputs are not accepted by the 'scipy.integrate' callback.
    raises(NotImplementedError,
           lambda: g.llvm_callable([a, b], reduced,
                                   callback_type='scipy.integrate'))

    jitted = g.llvm_callable([a, b], reduced)
    got = jitted(0.1, 1.5)
    assert len(got) == 2

    want = eval_cse(reduced, {a: 0.1, b: 1.5})
    for idx in (0, 1):
        assert isclose(want[idx], got[idx])
def test_callback_cubature_multiple():
    """Drive a 'cubature' callback with three cse()-reduced outputs via ctypes.

    The compiled callback is called positionally with the input count, a
    C array of inputs, ``None``, the output count and a C output buffer. It
    must return 0 and fill the buffer with values matching ``eval_cse``.
    """
    base = a*a
    total = a*a + b*b
    reduced = sympy.cse([base, total, 4*total])
    callback = g.llvm_callable([a, b], reduced, callback_type='cubature')

    n_inputs = 2   # number of input variables
    n_outputs = 3  # number of output expression values

    point = {a: 0.2, b: 1.5}
    x = (ctypes.c_double * n_inputs)(point[a], point[b])
    fval = (ctypes.c_double * n_outputs)()

    status = callback(ctypes.c_int(n_inputs), x, None,
                      ctypes.c_int(n_outputs), fval)
    assert status == 0

    expected = eval_cse(reduced, point)
    for k in range(n_outputs):
        assert isclose(fval[k], expected[k])
def genfcode(lambdastr, use_cse=False):
    """Python lambda string -> C source of ``inline double f(...)``.

    Every lambda argument becomes a ``double`` parameter of ``f``. The lambda
    body becomes the returned expression, converted to C syntax by the module
    helper ``cexpr``.

    Parameters
    ----------
    lambdastr : str
        Lambda source such as ``'lambda x, y: x**2 + y'``. The leading
        ``lambda`` keyword is optional.
    use_cse : bool
        If true, common subexpressions of the body are eliminated with
        ``cse()``. Each one becomes a ``double`` temporary (named via
        ``_gentmpvars()``) that is declared and assigned before the
        ``return``. Expressions are numerically evaluated with
        ``evalf(dps)`` before conversion; ``dps`` is a module-level
        setting defined outside this function.

    Returns
    -------
    str
        The generated C function source.

    Raises
    ------
    ValueError
        If ``lambdastr`` has no ``:`` separating arguments from body, or if
        ``cse()`` returns more than one reduced expression.
    """
    # TODO: verify lambda string
    # Interpret the lambda string. Split at the first ':' because a plain
    # lambda argument list contains none, so this also works when the body
    # does not follow the colon with a space.
    head, sep, fstr = lambdastr.partition(':')
    if not sep:
        raise ValueError("not a lambda string: %r" % (lambdastr,))
    fstr = fstr.strip()

    # Remove the 'lambda' keyword as a whole word. The previous
    # str.lstrip('lambda ') stripped any of the *characters* l, a, m, b, d
    # and space, which mangled arguments named e.g. 'a' or 'b'.
    argstr = head.strip()
    if argstr.split(None, 1)[:1] == ['lambda']:
        argstr = argstr[len('lambda'):]

    # Generate the C parameter list, ignoring empty names (e.g. a trailing
    # comma or a lambda without arguments).
    cvars = [v.strip() for v in argstr.split(',') if v.strip()]
    cvarstr = ', '.join('double %s' % v for v in cvars)

    if not use_cse:
        # convert function string to C syntax
        cfstr = ''
        finalexpr = cexpr(fstr)
    else:
        # eliminate common subexpressions
        subs, finalexpr = cse(sympify(fstr), _gentmpvars())
        if len(finalexpr) != 1:
            raise ValueError("Length should be 1")
        # All declarations are emitted before the assignments, which follow
        # the order of the cse() replacement list.
        vardec = ''.join(' double %s;\n' % symbol.name for symbol, _ in subs)
        assigns = ''.join(
            ' %s = %s;\n' % (symbol.name, cexpr(str(expr.evalf(dps))))
            for symbol, expr in subs)
        cfstr = vardec + assigns
        finalexpr = cexpr(str(finalexpr[0].evalf(dps)))

    # generate C code
    code = """ inline double f(%s) { %s return %s; } """ % (
        cvarstr, cfstr, finalexpr)
    return code
def _pow_by_mul(text):
    """Rewrite integer powers ``name**N`` in *text* as ``(name*name*...)``.

    ``re.sub`` replaces each match where it occurs. The look-arounds stop a
    match from starting inside an identifier or a float literal, and from
    ending before more digits or a decimal point. So ``x**2`` is not
    rewritten inside ``ax**2``, and ``1.5**2`` and ``x**2.0`` are left
    untouched.
    """
    def _expand(match):
        var, exp = match.group(1), int(match.group(2))
        return '(' + '*'.join([var] * exp) + ')'
    return re.sub(r'(?<![\w.])(\w+)\*\*(\d+)(?![\w.])', _expand, text)


def print_as_array(m, mname, sufix=None, use_cse=False, header=None,
                   print_file=True, collect_for=None, pow_by_mul=True,
                   order='C', op='+='):
    """Render the non-zero entries of *m* as ``mname[pos+i] <op> <expr>`` lines.

    Parameters
    ----------
    m : matrix-like
        Entries to print. It is flattened with ``np.ravel`` in C order, or
        its transpose ``m.T`` in 'F' order. Presumably a SymPy Matrix
        (needs ``.T`` for order='F'; TODO confirm with callers).
    mname : str
        Array name used on the left-hand side of each line.
    sufix : str, optional
        Appended as ``<mname>_<sufix>`` in comments and in the file name.
    use_cse : bool
        Run ``sympy.cse(m)`` first. Cython-style ``cdef double`` declarations
        (9 names per line) and the substitution assignments are emitted
        before the entries. NOTE: this overwrites the entries of the
        caller's *m* in place.
    header : str, optional
        Line emitted first, if truthy.
    print_file : bool
        Also write the result to ``print_<mname>[_<sufix>].txt`` in the
        current directory.
    collect_for : optional
        If given, each entry is split with ``collect(v, collect_for,
        evaluate=False)`` and every collected term is emitted on its own
        line.
    pow_by_mul : bool
        In the non-collected branch, rewrite ``x**N`` (integer N) as
        repeated multiplication.
    order : {'C', 'F'}
        Flattening order that determines the index ``i``.
    op : str
        Assignment operator, e.g. ``'+='`` or ``'='``.

    Returns
    -------
    str
        The generated lines joined with newlines.

    Raises
    ------
    ValueError
        If *order* is neither 'C' nor 'F'.
    """
    # Validate before any side effect (cse mutation of m, file write).
    if order not in ('C', 'F'):
        raise ValueError("order must be 'C' or 'F', got {0!r}".format(order))

    lines = []
    if use_cse:
        # NOTE(review): assumes sympy.cse(m) returns one reduced expression
        # per entry of m; confirm against the SymPy version in use.
        subs, m_list = sympy.cse(m)
        for i, v in enumerate(m_list):
            m[i] = v

    if sufix is None:
        namesufix = '{0}'.format(mname)
    else:
        namesufix = '{0}_{1}'.format(mname, sufix)
    filename = 'print_{0}.txt'.format(namesufix)

    if header:
        lines.append(header)
    if use_cse:
        lines.append('# cdefs')
        per_line = 9
        for start in range(0, len(subs), per_line):
            names = [str(sub[0]) for sub in subs[start:start + per_line]]
            lines.append('cdef double ' + ', '.join(names))
        lines.append('# subs')
        for sub in subs:
            lines.append('{0} = {1}'.format(*sub))

    lines.append('# {0}'.format(namesufix))
    num = sum(1 for v in m if v)
    lines.append('# {0}_num={1}'.format(namesufix, num))

    source = m if order == 'C' else m.T
    for i, v in enumerate(np.ravel(source)):
        if not v:
            continue
        if collect_for is not None:
            terms = collect(v, collect_for, evaluate=False)
            # Use the requested operator here too (was hard-coded to '+=').
            lines.append('{0}[pos+{1}] {2}'.format(mname, i, op))
            for k, expr in terms.items():
                lines.append('# collected for {k}'.format(k=k))
                lines.append(' {expr}'.format(expr=k*expr))
        else:
            if pow_by_mul:
                v = _pow_by_mul(str(v))
            lines.append('{0}[pos+{1}] {2} {3}'.format(mname, i, op, v))

    string = '\n'.join(lines)
    if print_file:
        with open(filename, 'w') as f:
            f.write(string)
    return string