我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 ast.walk()；其中部分示例使用的是 os.walk() 或 asttokens.util.walk() 等其他遍历函数。
def get_local_vars(source, namespace):
    """Return the entries of ``namespace`` whose names are referenced in ``source``.

    ``source`` is parsed and every identifier that appears as an ``ast.Name``
    node is collected; only names that are also keys of ``namespace`` are kept.

    Args:
        source: Python source code to scan.
        namespace: Mapping from variable names to values (for example a
            frame's locals).

    Returns:
        dict: ``{name: namespace[name]}`` for every referenced name present in
        ``namespace``.
    """
    referenced = {
        node.id
        for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.Name)
    }
    # We might want to add a compiler-ish thing in the future.
    return {name: namespace[name] for name in referenced if name in namespace}
def search(func, depth=1, _seen=None):
    """Return the source of ``func`` preceded by the sources of callables it uses.

    ``func``'s source is parsed and two kinds of names are collected: names
    that are called directly (``name(...)``), and names that refer to callables
    in the caller's local variables. Each collected name that is present in
    those locals is searched recursively. The resulting sources are prepended
    to ``func``'s own source, each followed by a newline.

    Args:
        func: The callable whose source should be collected.
        depth: Stack depth (relative to this function) of the frame whose
            locals are used to resolve names. Recursive calls add one so the
            same caller frame is used throughout.
        _seen: Internal set of ``id()``s of callables already expanded. It
            stops endless recursion on self- or mutually-recursive functions
            and keeps each callable's source from being emitted twice.

    Returns:
        str: The concatenated source code.
    """
    if _seen is None:
        _seen = set()
    _seen.add(id(func))
    local_vars = sys._getframe(depth).f_locals
    source = get_source_code(func)
    tree = ast.parse(source)
    child_funcs = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                child_funcs.append(node.func.id)
        elif (isinstance(node, ast.Name) and node.id in local_vars
              and callable(local_vars[node.id])
              and node.id not in sys.builtin_module_names):
            child_funcs.append(node.id)
    child_load_str = ''
    # A called local name can be collected twice (once from the Call, once
    # from its Name child); dict.fromkeys de-duplicates in first-seen order.
    for child in dict.fromkeys(child_funcs):
        if child not in local_vars:
            continue
        child_obj = local_vars[child]
        if id(child_obj) in _seen:
            continue
        try:
            load_string = search(child_obj, depth=depth + 1, _seen=_seen)
        except Exception:
            # Best effort: skip children whose source cannot be obtained or parsed.
            continue
        child_load_str += load_string + '\n'
    return child_load_str + source
def get_statement_startend2(lineno, node):
    """Return the 0-based ``(start, end)`` statement range containing ``lineno``.

    All statements and except handlers under ``node`` are flattened into one
    sorted list of 0-based line numbers. ``start`` is the last entry that is
    ``<= lineno``. ``end`` is the next entry, or None if ``lineno`` falls in
    the final statement.
    """
    import ast
    from bisect import bisect_right

    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement that
                    # starts one line before its first statement (presumably
                    # the ``finally:`` / ``else:`` header line).
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    end = None if insert_index >= len(values) else values[insert_index]
    return start, end
def replaceHazards(a):
    """Rename unsupported imports and hazardous calls in place.

    Every node of the tree rooted at ``a`` is inspected:

    * ``import x`` where ``x`` is not in ``supportedLibraries`` becomes
      ``import xNotAllowed``;
    * ``from m import ...`` where ``m`` is not in ``supportedLibraries``
      becomes ``from mNotAllowed import ...``;
    * calls by bare name to ``compile``, ``eval``, ``execfile``, ``file``,
      ``open``, ``__import__`` or ``apply`` get ``NotAllowed`` appended.

    Names of the form ``r<digit>...`` and names already containing
    ``NotAllowed`` are left untouched. Non-AST input is ignored.
    """
    if not isinstance(a, ast.AST):
        return

    def is_exempt(name):
        # The length guard keeps one-character names such as "r" from
        # raising IndexError.
        return ((len(name) > 1 and name[0] == "r" and name[1] in "0123456789")
                or "NotAllowed" in name)

    hazardous_calls = ["compile", "eval", "execfile", "file", "open",
                       "__import__", "apply"]
    # Examine every node of the tree, not just the root.
    for node in ast.walk(a):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name not in supportedLibraries and not is_exempt(alias.name):
                    alias.name = alias.name + "NotAllowed"
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for relative imports ("from . import x");
            # there is no module name to rename in that case.
            if (node.module is not None
                    and node.module not in supportedLibraries
                    and not is_exempt(node.module)):
                node.module = node.module + "NotAllowed"
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name) and node.func.id in hazardous_calls:
                node.func.id = node.func.id + "NotAllowed"
def gatherAllNames(a, keep_orig=True):
    """Gather all names in the tree (variable or otherwise).

    Names are returned along with their original names (which are used in
    variable mapping).

    Args:
        a: An AST node, or a list of nodes (searched recursively).
        keep_orig: When true, pair each name with its ``originalId`` attribute
            if present; otherwise the second element is always None.

    Returns:
        set: ``(id, original_id_or_None)`` pairs for every ``ast.Name``. Empty
        for non-AST input.
    """
    if type(a) is list:
        all_ids = set()
        for line in a:
            # Forward keep_orig so list input honours the caller's choice.
            all_ids |= gatherAllNames(line, keep_orig)
        return all_ids
    if not isinstance(a, ast.AST):
        return set()
    return {
        (node.id,
         node.originalId if (keep_orig and hasattr(node, "originalId")) else None)
        for node in ast.walk(a)
        if isinstance(node, ast.Name)
    }
def gatherAllParameters(a, keep_orig=True):
    """Gather all parameters in the tree.

    Names are returned along with their original names (which are used in
    variable mapping).

    Args:
        a: An AST node, or a list of nodes (searched recursively).
        keep_orig: When true, pair each parameter with its ``originalId``
            attribute if present; otherwise the second element is None.

    Returns:
        set: ``(arg, original_id_or_None)`` pairs for every ``ast.arg``. Empty
        for non-AST input.
    """
    if type(a) is list:
        all_ids = set()
        for line in a:
            # Recurse into this function (not a variable-gathering one) and
            # forward keep_orig.
            all_ids |= gatherAllParameters(line, keep_orig)
        return all_ids
    if not isinstance(a, ast.AST):
        return set()
    return {
        (node.arg,
         node.originalId if (keep_orig and hasattr(node, "originalId")) else None)
        for node in ast.walk(a)
        if isinstance(node, ast.arg)
    }
def getAllImports(a):
    """Gather all imported module names.

    Walks the tree rooted at ``a`` and returns, in ``ast.walk`` order, the
    local name bound by each supported import:

    * ``import m [as n]`` contributes ``n`` (or ``m``) when ``m`` is in
      ``supportedLibraries``;
    * ``from m import f [as g]`` contributes ``g`` (or ``f``) when ``m`` is in
      ``supportedLibraries`` and ``f`` is listed in ``libraryMap[m]``.

    Anything else is reported through ``log(..., "bug")`` and skipped.
    Returns an empty list for non-AST input.
    """
    if not isinstance(a, ast.AST):
        return []
    imports = []
    for child in ast.walk(a):
        if isinstance(child, ast.Import):
            for alias in child.names:
                if alias.name in supportedLibraries:
                    imports.append(alias.asname if alias.asname is not None else alias.name)
                else:
                    log("astTools\tgetAllImports\tUnknown library: " + alias.name, "bug")
        elif isinstance(child, ast.ImportFrom):
            if child.module in supportedLibraries:
                for alias in child.names:  # these are all functions
                    if alias.name in libraryMap[child.module]:
                        imports.append(alias.asname if alias.asname is not None else alias.name)
                    else:
                        log("astTools\tgetAllImports\tUnknown import from name: " +
                            child.module + "," + alias.name, "bug")
            else:
                # child.module is None for relative imports such as
                # "from . import x"; concatenating None would raise TypeError.
                module = child.module if child.module is not None else "." * child.level
                log("astTools\tgetAllImports\tUnknown library: " + module, "bug")
    return imports
def get_version():
    """Return the version string declared in ``settei/version.py``.

    The file is parsed (not imported) and searched for a single-target
    assignment ``VERSION_INFO = (<number>, <number>, ...)``. The numbers are
    joined with dots. Returns None if no such assignment exists.
    """
    def is_number(elt):
        # Python 3.8+ parses numeric literals as ast.Constant. ast.Num is
        # deprecated and removed in Python 3.14. bool is excluded, matching
        # the old ast.Num check.
        return (isinstance(elt, ast.Constant)
                and isinstance(elt.value, (int, float, complex))
                and not isinstance(elt.value, bool))

    with open(os.path.join('settei', 'version.py'), encoding='utf-8') as f:
        tree = ast.parse(f.read(), f.name)
    for node in ast.walk(tree):
        if not (isinstance(node, ast.Assign) and len(node.targets) == 1):
            continue
        target, = node.targets
        value = node.value
        if not (isinstance(target, ast.Name) and target.id == 'VERSION_INFO'
                and isinstance(value, ast.Tuple)):
            continue
        elts = value.elts
        if not all(is_number(elt) for elt in elts):
            continue
        return '.'.join(str(elt.value) for elt in elts)
def process(fl, external, genfiles, vendor):
    """Collect the results of WORKSPACE rules called in file ``fl``.

    The file is parsed. For every call whose callee is a plain name that is
    also a callable attribute of ``WORKSPACE(external, genfiles, vendor)``:
    ``(path, name)`` is extracted with ``keywords``, a trailing ``.git`` is
    stripped, the path is remapped through ``pathmap``, and the value of
    ``fn(name, path)`` is collected.

    Returns:
        list: The collected results, in ``ast.walk`` order.
    """
    with open(fl) as fh:
        src = fh.read()
    tree = ast.parse(src, fl)
    lst = []
    wksp = WORKSPACE(external, genfiles, vendor)
    for stmt in ast.walk(tree):
        # Attribute calls (``obj.method(...)``) and other callee forms have no
        # ``.id``; only plain-name calls can map onto a WORKSPACE method.
        if not (isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name)):
            continue
        fn = getattr(wksp, stmt.func.id, "")
        if not callable(fn):
            continue
        path, name = keywords(stmt)
        if path.endswith(".git"):
            path = path[:-4]
        path = pathmap.get(path, path)
        lst.append(fn(name, path))
    return lst
def walk_python_files():
    u'''
    Generator that yields all CKAN Python source files.

    Yields 2-tuples containing the filename in absolute and relative (to
    the project root) form.

    Directories whose name starts with "." and directories whose
    project-relative path is in IGNORED_DIRS are not descended into.
    '''
    def _is_dir_ignored(root, d):
        # ``root`` is the project-relative directory containing ``d``. Use the
        # parameter rather than the enclosing loop variable.
        if d.startswith(u'.'):
            return True
        return os.path.join(root, d) in IGNORED_DIRS

    for abs_root, dirnames, filenames in os.walk(PROJECT_ROOT):
        rel_root = os.path.relpath(abs_root, PROJECT_ROOT)
        if rel_root == u'.':
            rel_root = u''
        # Prune in place so os.walk skips ignored subtrees.
        dirnames[:] = [d for d in dirnames if not _is_dir_ignored(rel_root, d)]
        for filename in filenames:
            if not filename.endswith(u'.py'):
                continue
            yield os.path.join(abs_root, filename), os.path.join(rel_root, filename)
def run(self):
    """Yield ``(line, offset, msg, rtype)`` error tuples for the checked tree.

    When the file name is ``'stdin'`` the source is re-read and re-parsed
    from standard input. A first pass records each node's parent on its
    children. A second pass hands every ``Assign`` to
    ``self.check_assignment`` and every ``FunctionDef`` to
    ``self.check_function_definition``, yielding whatever tuples they return.
    """
    tree = self.tree
    if self.filename == 'stdin':
        tree = ast.parse(stdin_utils.stdin_get_value())

    # Pass 1: link every child to its parent so the checks can look upward.
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            child.__flake8_builtins_parent = parent

    # Pass 2: run the relevant check for each node and forward its findings.
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            findings = self.check_assignment(node)
        elif isinstance(node, ast.FunctionDef):
            findings = self.check_function_definition(node)
        else:
            findings = None
        if not findings:
            continue
        for line, offset, msg, rtype in findings:
            yield line, offset, msg, rtype
def walk(node):
    """
    Recursively yield all descendant nodes in the tree starting at ``node`` (including ``node``
    itself), using depth-first pre-order traversal (yielding parents before their children).

    This is similar to ``ast.walk()``, but with a different order, and it works for both ``ast`` and
    ``astroid`` trees. Also, as ``iter_children()``, it skips singleton nodes generated by ``ast``.
    """
    children_of = iter_children_func(node)
    visited = set()
    pending = [node]
    while pending:
        current = pending.pop()
        # Protect against an infinite loop in case of a bad tree.
        assert current not in visited
        visited.add(current)
        yield current

        # Insert every child at the same position so they end up in reverse
        # order, leaving the first child on top of the stack (popped next).
        # This is faster than building a list and reversing it.
        insert_at = len(pending)
        for child in children_of(current):
            pending.insert(insert_at, child)
def print_timing(self):  # pylint: disable=no-self-use
    """Print sorted timings of ``ast.walk`` versus ``asttokens.util.walk``.

    Test the implementation of asttokens.util.walk, which uses the same
    approach as visit_tree(). This doesn't run as a normal unittest, but if
    you'd like to see timings, e.g. after experimenting with the
    implementation, run this to see them:

        nosetests -i print_timing -s tests.test_util
    """
    import timeit
    import textwrap

    setup = textwrap.dedent(
        '''
        import ast, asttokens
        source = "foo(bar(1 + 2), 'hello' + ', ' + 'world')"
        atok = asttokens.ASTTokens(source, parse=True)
        ''')
    benchmarks = (
        ("ast", 'len(list(ast.walk(atok.tree)))'),
        ("util", 'len(list(asttokens.util.walk(atok.tree)))'),
    )
    for label, stmt in benchmarks:
        timings = timeit.repeat(setup=setup, number=10000, stmt=stmt)
        print(label, sorted(timings))
def test_walk_ast(self):
    """asttokens.util.walk visits ``ast`` nodes parents-first, in source order."""
    import sys

    atok = asttokens.ASTTokens(self.source, parse=True)

    def view(node):
        return "%s:%s" % (node.__class__.__name__, atok.get_text(node))

    # Since Python 3.8 ast.parse produces ast.Constant for every literal;
    # older versions produce ast.Num and ast.Str.
    legacy = sys.version_info < (3, 8)
    num = "Num" if legacy else "Constant"
    text = "Str" if legacy else "Constant"

    scan = [view(n) for n in asttokens.util.walk(atok.tree)]
    self.assertEqual(scan, [
        "Module:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Expr:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Call:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        'Name:foo',
        'Call:bar(1 + 2)',
        'Name:bar',
        'BinOp:1 + 2',
        num + ':1',
        num + ':2',
        "BinOp:'hello' + ', ' + 'world'",
        "BinOp:'hello' + ', '",
        text + ":'hello'",
        text + ":', '",
        text + ":'world'",
    ])
def test_walk_astroid(self):
    """asttokens.util.walk visits astroid nodes parents-first, in source order."""
    expected = [
        "Module:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Expr:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        "Call:foo(bar(1 + 2), 'hello' + ', ' + 'world')",
        'Name:foo',
        'Call:bar(1 + 2)',
        'Name:bar',
        'BinOp:1 + 2',
        'Const:1',
        'Const:2',
        "BinOp:'hello' + ', ' + 'world'",
        "BinOp:'hello' + ', '",
        "Const:'hello'",
        "Const:', '",
        "Const:'world'",
    ]
    tree = astroid.builder.parse(self.source)
    atok = asttokens.ASTTokens(self.source, tree=tree)

    def describe(node):
        return "%s:%s" % (node.__class__.__name__, atok.get_text(node))

    self.assertEqual([describe(n) for n in asttokens.util.walk(atok.tree)], expected)
def test_replace(self):
    """asttokens.util.replace applies ``(start, end, text)`` replacements to a string."""
    self.assertEqual(asttokens.util.replace("this is a test", [(0, 4, "X"), (8, 9, "THE")]),
                     "X is THE test")
    self.assertEqual(asttokens.util.replace("this is a test", []), "this is a test")
    self.assertEqual(asttokens.util.replace("this is a test", [(7, 7, " NOT")]),
                     "this is NOT a test")

    def is_string_literal(node):
        # Python 3.8+ parses string literals as ast.Constant. ast.Str is
        # deprecated and removed in Python 3.14; the class-name check covers
        # interpreters older than 3.8.
        if isinstance(node, ast.Constant):
            return isinstance(node.value, str)
        return type(node).__name__ == "Str"

    source = "foo(bar(1 + 2), 'hello' + ', ' + 'world')"
    atok = asttokens.ASTTokens(source, parse=True)
    names = [n for n in asttokens.util.walk(atok.tree) if isinstance(n, ast.Name)]
    strings = [n for n in asttokens.util.walk(atok.tree) if is_string_literal(n)]
    repl1 = [atok.get_text_range(n) + ('TEST',) for n in names]
    repl2 = [atok.get_text_range(n) + ('val',) for n in strings]
    self.assertEqual(asttokens.util.replace(source, repl1 + repl2),
                     "TEST(TEST(1 + 2), val + val + val)")
    self.assertEqual(asttokens.util.replace(source, repl2 + repl1),
                     "TEST(TEST(1 + 2), val + val + val)")
def _fine_property_definition(self, property_name):
    """Find the lines in the source code that contain this property's name and definition.

    This function can find both attribute assignments (plain, multi-target or
    annotated) as well as methods/functions (sync or async). The first match
    in ``ast.walk`` order of ``self._source`` wins.

    Args:
        property_name (str): the name of the property to look up in the template definition

    Returns:
        tuple: 0-based start line and the value of ``self._get_node_line_end``
        for the defining node.

    Raises:
        ValueError: if no matching assignment or function definition exists.
    """
    for node in ast.walk(ast.parse(self._source)):
        if isinstance(node, ast.Assign):
            # Targets may be attributes, tuples or subscripts, which have no
            # ``.id``; only plain names can match.
            if any(isinstance(t, ast.Name) and t.id == property_name for t in node.targets):
                return node.targets[0].lineno - 1, self._get_node_line_end(node)
        elif isinstance(node, ast.AnnAssign):
            if isinstance(node.target, ast.Name) and node.target.id == property_name:
                return node.target.lineno - 1, self._get_node_line_end(node)
        elif (isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
              and node.name == property_name):
            return node.lineno - 1, self._get_node_line_end(node)
    raise ValueError('The requested node could not be found.')
def _find_non_builtin_globals(source, codeobj):
    """List names used in ``source`` that are neither locals of ``codeobj`` nor builtins.

    Every ``ast.Name`` node in ``source`` is visited in ``ast.walk`` order.
    Its id is kept, duplicates included, unless it appears in
    ``codeobj.co_varnames`` or in the builtins namespace. Returns None when
    the ``ast`` module cannot be imported.
    """
    try:
        import ast
    except ImportError:
        return None
    try:
        import __builtin__
    except ImportError:
        import builtins as __builtin__

    local_names = dict.fromkeys(codeobj.co_varnames)
    builtin_names = __builtin__.__dict__
    found = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Name):
            continue
        if node.id in local_names or node.id in builtin_names:
            continue
        found.append(node.id)
    return found
def linerange(node):
    """Get line number range from a node.

    Nested statement blocks (``body``, ``orelse``, ``handlers`` and
    ``finalbody``) are temporarily emptied, so only the lines of the node
    itself (for example an ``if`` header and its condition) are counted. They
    are restored afterwards, even if the walk raises.

    Returns:
        list[int]: every line number from the smallest to the largest
        ``lineno`` found, or ``[0, 1]`` when no node carries a usable line
        number.
    """
    stripped = {}
    for key in ("body", "orelse", "handlers", "finalbody"):
        if hasattr(node, key):
            stripped[key] = getattr(node, key)
            setattr(node, key, [])
    try:
        linenos = [n.lineno for n in ast.walk(node) if hasattr(n, 'lineno')]
    finally:
        # Restore every stripped field, including ones whose value was None.
        for key, value in stripped.items():
            setattr(node, key, value)
    if linenos and max(linenos) > -1:
        return list(range(min(linenos), max(linenos) + 1))
    return [0, 1]
def get_orig_line_from_s_orig(s_orig, line_no):
    """Map a line number of the processed program back to its original line.

    ``s_orig`` is parsed with ``py_ast.get_ast``. The first node in
    ``ast.walk`` order that has ``lineno == line_no`` and an ``orig_lineno``
    attribute supplies the answer.

    Returns:
        int: that node's ``orig_lineno``; ``line_no`` unchanged if no such
        node exists; ``-1`` if ``line_no`` is None.
    """
    if line_no is None:
        return -1
    node = py_ast.get_ast(s_orig)
    for candidate in ast.walk(node):
        if (hasattr(candidate, 'lineno') and hasattr(candidate, 'orig_lineno')
                and candidate.lineno == line_no):
            return candidate.orig_lineno
    # No node maps this line; fall back to the line number unchanged.
    return line_no
def get_ast(source_prog):
    """Returns the ast of the program, with comments converted into string literals.

    Every statement node for which ``getLineNum`` returns something other
    than -1 gets that value stored as ``orig_lineno``. Any ``parent``
    attributes (presumably set by ``add_parent_info``) are removed before
    returning.

    Args:
      source_prog, string, string version of the source code
    """
    wrapped_str = comment_to_str(source_prog, TRANS_PREFIXES)
    node = ast.parse(wrapped_str)
    add_parent_info(node)
    statements = [n for n in ast.walk(node) if isinstance(n, ast.stmt)]
    for stmt in statements:
        if hasattr(stmt, 'lineno'):
            orig_lineno = getLineNum(stmt)
            if orig_lineno != -1:
                stmt.orig_lineno = orig_lineno
    # Remove every ``parent`` attribute before handing the tree back.
    for n in list(ast.walk(node)):
        if hasattr(n, 'parent'):
            delattr(n, 'parent')
    return node
def main():
    """Run ``check`` on every ``.py`` file under a directory and report errors.

    For each Python file found recursively under the ``dir`` command-line
    argument, every ``(lineno, msg)`` reported by ``check`` is printed along
    with the offending line. If any error was found the program exits through
    ``sys.exit`` with a summary message.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dir')
    args = parser.parse_args()

    n_err = 0
    for dirpath, _, filenames in os.walk(args.dir):
        for filename in filenames:
            if os.path.splitext(filename)[1] != '.py':
                continue
            path = os.path.join(dirpath, filename)
            # Close each file promptly instead of leaking the handle.
            with open(path) as f:
                lines = f.readlines()
            for lineno, msg in check(''.join(lines)):
                print('{:s}:{:d} : {:s}'.format(path, lineno, msg))
                print(lines[lineno - 1])
                n_err += 1
    if n_err > 0:
        sys.exit('{:d} style errors are found.'.format(n_err))
def check_nesting(self, **kwargs):
    """Inspect the code for too deeply nested blocks.

    The nesting level is the largest number of body-bearing nodes (nodes with
    a ``body`` attribute, e.g. ``def``, ``class``, ``if``, ``for``, ``while``,
    ``with``, ``try``, ``lambda``) on a single path from the first top-level
    statement down through its descendants. Sibling blocks do not add to each
    other's depth. An ``elif`` is an ``If`` nested in the parent's ``orelse``,
    so it counts as one extra level.

    Keyword Args:
        max_nesting: Maximum allowed nesting level. When absent the check
            is skipped.

    When the level exceeds ``max_nesting`` a ``nesting_too_deep`` issue is
    recorded on the line after the innermost body-bearing node.
    """
    try:
        max_nesting = kwargs['max_nesting']
    except KeyError:
        return
    if not self.parsed_code.body:
        return

    def deepest(node):
        # Return (depth, lineno) of the deepest chain of body-bearing nodes
        # rooted at ``node``; lineno belongs to that chain's innermost node.
        # On ties the later child wins.
        depth, lineno = 0, None
        for child in ast.iter_child_nodes(node):
            child_depth, child_lineno = deepest(child)
            if child_depth and child_depth >= depth:
                depth, lineno = child_depth, child_lineno
        if hasattr(node, 'body'):
            depth += 1
            if lineno is None:
                lineno = getattr(node, 'lineno', None)
        return depth, lineno

    nesting_level, deepest_line = deepest(self.parsed_code.body[0])
    if nesting_level > max_nesting and deepest_line is not None:
        # The line number where the error was found
        # is the next one (thus + 1):
        line_number = deepest_line + 1
        self.issues[line_number].add(
            self.code_errors.nesting_too_deep(nesting_level, max_nesting)
        )
def check_indentation(self, **kwargs):
    """Inspect the code for indentation size errors.

    Statement blocks under the first top-level statement are visited in
    ``ast.walk`` order. A block's first statement is flagged when its column
    exceeds the previous block's first-statement column by more than
    ``indentation_size``.

    Keyword Args:
        indentation_size: Allowed indentation step; defaults to
            ``self.DEFAULT_RULES['indentation_size']``.
    """
    try:
        indentation_size = kwargs['indentation_size']
    except KeyError:
        # Use the default value instead:
        indentation_size = self.DEFAULT_RULES['indentation_size']
    if not self.parsed_code.body:
        return
    # Only statement blocks are considered. ``lambda`` and conditional
    # expressions also have a ``body`` attribute, but it is a single
    # expression rather than a list of statements and cannot be indexed.
    nodes = [node for node in ast.walk(self.parsed_code.body[0])
             if isinstance(getattr(node, 'body', None), list) and node.body]
    # Use the previous line offset as a guide for the next line indentation.
    last_offset = 0
    for node in nodes:
        first = node.body[0]
        if first.col_offset > last_offset + indentation_size:
            offset = first.col_offset - last_offset
            self.issues[first.lineno].add(
                self.code_errors.indentation(offset, indentation_size)
            )
        last_offset = first.col_offset
def check_methods_per_class(self, **kwargs):
    """
    Inspect the code for too many methods per class.

    Only applies when the first top-level statement is a class. Its methods
    are the (async) function definitions directly in the class body; nested
    helpers and inner classes' methods are not counted. If there are more
    than ``methods_per_class`` of them, an issue is recorded at the line of
    the last method.

    Keyword Args:
        methods_per_class: Maximum allowed number of methods. When absent the
            check is skipped.
    """
    try:
        methods_per_class = kwargs['methods_per_class']
    except KeyError:
        return
    if not self.parsed_code.body:
        return
    klass = self.parsed_code.body[0]
    if not isinstance(klass, ast.ClassDef):
        return
    methods = [node for node in klass.body
               if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))]
    # Only report when the limit is actually exceeded.
    if len(methods) > methods_per_class:
        # Get the last method of the class and its line number:
        line_number = methods[-1].lineno
        self.issues[line_number].add(
            self.code_errors.too_many_methods_per_class(
                len(methods), methods_per_class
            )
        )
def walk(self, prog_ast, _seen=None):
    """Return every node of ``prog_ast`` followed by the nodes of modules it imports.

    Modules named in ``import`` statements are expanded first, then those
    named in ``from ... import`` statements (``typing`` is skipped there).
    Each module's AST comes from ``ImportHandler`` and is walked recursively.

    Args:
        prog_ast: The AST to walk.
        _seen: Internal set of module names already expanded. It is shared
            across the recursion so circular imports terminate and each
            imported module's nodes are included only once.

    Returns:
        list: AST nodes of ``prog_ast`` in ``ast.walk`` order, followed by the
        nodes of the imported modules.
    """
    if _seen is None:
        _seen = set()
    result = list(ast.walk(prog_ast))
    module_names = [alias.name
                    for node in result if isinstance(node, ast.Import)
                    for alias in node.names]
    # FIXME ignore typing for now, not to break type vars
    module_names += [node.module
                     for node in result
                     if isinstance(node, ast.ImportFrom) and node.module != "typing"]
    for name in module_names:
        if name in _seen:
            continue
        _seen.add(name)
        if ImportHandler.is_builtin(name):
            new_ast = ImportHandler.get_builtin_ast(name)
        else:
            new_ast = ImportHandler.get_module_ast(name, self.base_folder)
        result += self.walk(new_ast, _seen=_seen)
    return result
def ingest(self, rootdir):
    """
    Collect all the .py files to perform analysis upon

    Every ``.py`` file under ``rootdir`` is parsed and its AST stored in
    ``self.__fn_to_ast`` keyed by path. When ``self.perform_jinja_analysis``
    is set, the template directory is recorded as well.

    Raises:
        Exception: if ``rootdir`` is not a directory.
    """
    if not os.path.isdir(rootdir):
        raise Exception("directory %s passed in is not a dir" % rootdir)

    self.__target_dir = rootdir

    # walk the dirs/files
    for root, _subdirs, files in os.walk(self.__target_dir):
        for f in files:
            if not f.endswith(".py"):
                continue
            fullpath = os.path.join(root, f)
            # The Python 2 ``file()`` builtin does not exist on Python 3;
            # ``open`` in a ``with`` block also closes the handle.
            with open(fullpath) as fh:
                contents = fh.read()
            self.__fn_to_ast[fullpath] = ast.parse(contents)

    # potentially analyze .html files for jinja templates
    if self.perform_jinja_analysis:
        self.__template_dir = self.get_template_dir()
def find_template_dir(self):
    """Return the common directory of all jinja-looking ``.html`` files.

    Walks ``self.__target_dir`` and treats any ``.html`` file containing
    ``{%`` as a jinja template. Returns None when no template is found.
    """
    # web-p2 is web-p2/partners/templates
    # login is login/templates
    # TODO: look for invocations of `jinja2.Environment` and see if
    # we can pull the template directory / package from there? Should work
    # for most.
    template_dirs = set()
    for root, _subdirs, files in os.walk(self.__target_dir):
        for fname in files:
            if not fname.endswith(".html"):
                continue
            with open(os.path.join(root, fname), "rb") as f:
                # Hmm, smells like a jinja template!
                if b"{%" in f.read():
                    template_dirs.add(root)
    if not template_dirs:
        return None
    # If there are multiple template directories in a repo we might need
    # repo-specific overrides. commonpath compares whole path components;
    # commonprefix would turn ".../templates" and ".../tpl" into ".../t".
    return os.path.commonpath(sorted(template_dirs))
def get_template_dir(self):
    """
    return the directory containing jinja2 templates
    ex: web-p2 is web-p2/partners/templates

    Walks ``self.__target_dir`` and treats any ``.html`` file containing
    ``{%`` as a jinja template. Returns the common directory of all such
    files, or None when none is found.
    """
    template_dirs = set()
    for root, _subdirs, files in os.walk(self.__target_dir):
        for fname in files:
            if not fname.endswith(".html"):
                continue
            with open(os.path.join(root, fname), "rb") as f:
                # Hmm, smells like a jinja template!
                if b"{%" in f.read():
                    template_dirs.add(root)
    if not template_dirs:
        return None
    # If there are multiple template directories in a repo we might need
    # repo-specific overrides. commonpath compares whole path components;
    # commonprefix would turn ".../templates" and ".../tpl" into ".../t".
    return os.path.commonpath(sorted(template_dirs))
def file_contains_pluggable(file_path, pluggable):
    """Report whether a Python file defines a class deriving from ``pluggable``.

    A base matches when it is a plain name (``Base``) or an attribute
    (``mod.Base``) whose final component equals ``pluggable``. Other base
    expressions (``Generic[T]``, calls, ...) are ignored.

    Returns:
        list: ``[True, class_name]`` for the last matching class in
        ``ast.walk`` order, or ``[False, None]`` if none matches or the file
        does not exist.
    """
    plugin_class = None
    try:
        # Reading bytes lets ast.parse honour any PEP 263 coding declaration.
        with open(file_path, "rb") as f:
            syntax_tree = ast.parse(f.read())
    except FileNotFoundError:
        return [False, None]

    for statement in ast.walk(syntax_tree):
        if not isinstance(statement, ast.ClassDef):
            continue
        bases = [b.id if isinstance(b, ast.Name) else b.attr
                 for b in statement.bases
                 if isinstance(b, (ast.Name, ast.Attribute))]
        if pluggable in bases:
            plugin_class = statement.name

    return [plugin_class is not None, plugin_class]
def _walk_files(self, path):
    """Walk paths and yield Python paths

    If ``path`` is a file it is yielded as-is. Otherwise directories and
    files are yielded in alphabetical order. Directories starting with a "."
    are skipped, as are those for which ``self._is_ignored`` returns true.
    Only files accepted by ``self._is_python`` are yielded. A path that is
    neither a file nor a directory is logged as an error and yields nothing.
    """
    if os.path.isfile(path):
        yield path
        return
    if not os.path.isdir(path):
        LOG.error("The path '%s' can't be found.", path)
        # ``raise StopIteration`` inside a generator becomes RuntimeError
        # under PEP 479 (Python 3.7+); a plain return ends the generator.
        return

    for root, dirs, filenames in os.walk(path):
        # Remove dot-directories from the dirs list (in place, so os.walk
        # does not descend into them).
        dirs[:] = sorted(d for d in dirs
                         if not d.startswith('.') and not self._is_ignored(d))
        for filename in sorted(filenames):
            if self._is_python(filename):
                yield os.path.join(root, filename)
def _walk_ast(self, node, top=False):
    """Annotate ``node`` and all its descendants with parent information.

    A node without a ``parent`` attribute gets ``parent = None`` and
    ``parents = []``. Each child is processed recursively first. Then
    ``self._set_parnt_fields`` is called with the child, its parent and the
    field name, plus the index when the field is a list.

    Returns:
        ``ast.walk(node)`` when ``top`` is true, otherwise None.
    """
    if not hasattr(node, 'parent'):
        node.parent = None
        node.parents = []
    for field_name, field_value in ast.iter_fields(node):
        if isinstance(field_value, ast.AST):
            self._walk_ast(field_value)
            self._set_parnt_fields(field_value, node, field_name)
        elif isinstance(field_value, list):
            for position, child in enumerate(field_value):
                if not isinstance(child, ast.AST):
                    continue
                self._walk_ast(child)
                self._set_parnt_fields(child, node, field_name, position)
    return ast.walk(node) if top else None
def runTest(self):
    """Makes a simple test of the output

    The candidate code must contain a ``'...{}...'.format`` attribute access
    on a string literal. When executed, it must print exactly
    ``Talk is cheap. Show me the code.``.
    """
    body = ast.parse(self.candidate_code, self.file_name, 'exec')
    code = compile(self.candidate_code, self.file_name, 'exec')
    # ast.Str and its ``.s`` attribute are deprecated and removed in Python
    # 3.14; since 3.8 string literals are ast.Constant nodes with a str value.
    format_nodes = [
        node for node in ast.walk(body)
        if isinstance(node, ast.Attribute)
        and node.attr == 'format'
        and isinstance(node.value, ast.Constant)
        and isinstance(node.value.value, str)
        and '{}' in node.value.value
    ]
    self.assertGreater(
        len(format_nodes), 0,
        "It should have at least one format call with curly braces {}"
    )
    exec(code)
    self.assertMultiLineEqual('Talk is cheap. Show me the code.\n',
                              self.__mockstdout.getvalue(),
                              'Output is not correct')
def runTest(self):
    """Makes a simple test of the output

    The candidate code must use the ``*`` operator at least once. When
    executed, it must print ``ka`` ten times followed by a newline.
    """
    source, filename = self.candidate_code, self.file_name
    body = ast.parse(source, filename, 'exec')
    code = compile(source, filename, 'exec')
    mult_count = sum(1 for node in ast.walk(body) if isinstance(node, ast.Mult))
    self.assertGreater(mult_count, 0, "It should have at least one duplication")
    exec(code)
    expected = 'ka' * 10 + '\n'
    self.assertMultiLineEqual(expected, self.__mockstdout.getvalue(),
                              "Should have printed ka 10 times")
def runTest(self):
    """Makes a simple test of the output

    The candidate code must contain at least one ``if`` statement. When
    executed, its printed output must equal ``self.correct_output``.
    """
    body = ast.parse(self.candidate_code, self.file_name, 'exec')
    code = compile(self.candidate_code, self.file_name, 'exec', optimize=0)
    if_statements = [node for node in ast.walk(body) if isinstance(node, ast.If)]
    # Check the structure before executing the candidate code.
    self.assertGreater(len(if_statements), 0, "Should have at least one if statement")
    exec(code)
    self.assertMultiLineEqual(self.correct_output, self.__mockstdout.getvalue(),
                              "Output should be correct")
def get_init(self, filename="__init__.py"):
    """ Get various info from the package without importing them

    Parses ``filename`` and returns the literal values assigned to
    ``__author__``, ``__email__``, ``__license__`` and ``__version__``. For
    each name, the first plain ``name = <literal>`` assignment in
    ``ast.walk`` order is used.

    Raises:
        ValueError: if one of the four names is never assigned.
    """
    import ast

    # Reading bytes lets ast.parse honour any PEP 263 coding declaration.
    with open(filename, "rb") as init_file:
        module = ast.parse(init_file.read())

    def values_of(name):
        # Only plain-name targets are compared; attribute or tuple targets
        # (``obj.attr = ...``, ``a, b = ...``) have no ``.id``.
        return (ast.literal_eval(node.value)
                for node in ast.walk(module)
                if isinstance(node, ast.Assign)
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == name)

    try:
        return next(values_of("__author__")), \
            next(values_of("__email__")), \
            next(values_of("__license__")), \
            next(values_of("__version__"))
    except StopIteration:
        raise ValueError("One of author, email, license, or version"
                         " cannot be found in {}".format(filename))
def get_instance_variables(node, bound_name_classifier=BOUND_METHOD_ARGUMENT_NAME):
    """
    Return instance variables used in an AST node

    An instance variable is an ``ast.Attribute`` whose
    ``get_attribute_name_id`` equals ``bound_name_classifier`` and whose
    object name is not also the name of a call somewhere in ``node``.
    Results keep ``ast.walk`` order.
    """
    node_attributes = [
        child for child in ast.walk(node)
        if isinstance(child, ast.Attribute)
        and get_attribute_name_id(child) == bound_name_classifier
    ]
    # A set turns the membership test below into O(1) per attribute instead
    # of scanning every call name each time. Assumes get_object_name returns
    # hashable names (e.g. strings) -- TODO confirm.
    node_function_call_names = {
        get_object_name(child) for child in ast.walk(node)
        if isinstance(child, ast.Call)
    }
    return [
        attribute for attribute in node_attributes
        if get_object_name(attribute) not in node_function_call_names
    ]
def __init__(self, expr):
    """Parse a service spec expression and import the modules it references.

    Every ``ast.Name`` in ``expr`` that is not a builtin is imported with
    ``__import__``. The resulting module is stored in ``self.ns`` under that
    name.

    Raises:
        argparse.ArgumentTypeError: if ``expr`` is not valid Python syntax or
            one of the referenced names cannot be imported.
    """
    self.expr = expr
    self.ns = {}
    try:
        tree = ast.parse(expr)
    except SyntaxError as exc:
        # exc.offset may be None, and ' ' * None would raise TypeError and
        # hide the real parse error.
        raise argparse.ArgumentTypeError('Invalid service spec %r. Parse error:\n'
                                         ' %s %s^\n'
                                         '%s' % (expr, exc.text, ' ' * (exc.offset or 0), exc)) from exc
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            if not hasattr(builtins, node.id):
                try:
                    __import__(node.id)
                except ImportError as exc:
                    raise argparse.ArgumentTypeError(
                        'Invalid service spec %r. Import error: %s' % (expr, exc)) from exc
                self.ns[node.id] = sys.modules[node.id]
def childHasTag(a, tag):
    """ Includes the AST itself

    Return True when ``a`` itself, any element of ``a`` (if it is exactly a
    list, searched recursively), or any node in the AST rooted at ``a``
    has an attribute named ``tag``.
    """
    if hasattr(a, tag):
        return True
    if type(a) is list:
        return any(childHasTag(item, tag) for item in a)
    if isinstance(a, ast.AST):
        return any(hasattr(node, tag) for node in ast.walk(a))
    return False
def hasMultiComp(a):
    """Return True if any node in the AST ``a`` has a truthy ``multiComp`` attribute.

    Non-AST input always yields False.
    """
    if isinstance(a, ast.AST):
        return any(getattr(node, "multiComp", False) for node in ast.walk(a))
    return False