# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr # # This file is part of astroid. # # astroid is free software: you can redistribute it and/or modify it # under the terms of the GNU Lesser General Public License as published by the # Free Software Foundation, either version 2.1 of the License, or (at your # option) any later version. # # astroid is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License # for more details. # # You should have received a copy of the GNU Lesser General Public License along # with astroid. If not, see . """visitor doing some postprocessing on the astroid tree. Try to resolve definitions (namespace) dictionary, relationship... This module has been imported from pyreverse """ __docformat__ = "restructuredtext en" from os.path import dirname import astroid from astroid.exceptions import InferenceError from astroid.utils import LocalsVisitor from astroid.modutils import get_module_part, is_relative, is_standard_module class IdGeneratorMixIn(object): """ Mixin adding the ability to generate integer uid """ def __init__(self, start_value=0): self.id_count = start_value def init_counter(self, start_value=0): """init the id counter """ self.id_count = start_value def generate_id(self): """generate a new identifier """ self.id_count += 1 return self.id_count class Linker(IdGeneratorMixIn, LocalsVisitor): """ walk on the project tree and resolve relationships. According to options the following attributes may be added to visited nodes: * uid, a unique identifier for the node (on astroid.Project, astroid.Module, astroid.Class and astroid.locals_type). Only if the linker has been instantiated with tag=True parameter (False by default). * Function a mapping from locals names to their bounded value, which may be a constant like a string or an integer, or an astroid node (on astroid.Module, astroid.Class and astroid.Function). * instance_attrs_type as locals_type but for klass member attributes (only on astroid.Class) * implements, list of implemented interface _objects_ (only on astroid.Class nodes) """ def __init__(self, project, inherited_interfaces=0, tag=False): IdGeneratorMixIn.__init__(self) LocalsVisitor.__init__(self) # take inherited interface in consideration or not self.inherited_interfaces = inherited_interfaces # tag nodes or not self.tag = tag # visited project self.project = project def visit_project(self, node): """visit an astroid.Project node * optionally tag the node with a unique id """ if self.tag: node.uid = self.generate_id() for module in node.modules: self.visit(module) def visit_package(self, node): """visit an astroid.Package node * optionally tag the node with a unique id """ if self.tag: node.uid = self.generate_id() for subelmt in node.values(): self.visit(subelmt) def visit_module(self, node): """visit an astroid.Module node * set the locals_type mapping * set the depends mapping * optionally tag the node with a unique id """ if hasattr(node, 'locals_type'): return node.locals_type = {} node.depends = [] if self.tag: node.uid = self.generate_id() def visit_class(self, node): """visit an astroid.Class node * set the locals_type and instance_attrs_type mappings * set the implements list and build it * optionally tag the node with a unique id """ if hasattr(node, 'locals_type'): return node.locals_type = {} if self.tag: node.uid = self.generate_id() # resolve ancestors for baseobj in node.ancestors(recurs=False): specializations = getattr(baseobj, 'specializations', []) specializations.append(node) baseobj.specializations = specializations # resolve instance attributes node.instance_attrs_type = {} for assattrs in node.instance_attrs.values(): for assattr in assattrs: self.handle_assattr_type(assattr, node) # resolve implemented interface try: node.implements = list(node.interfaces(self.inherited_interfaces)) except InferenceError: node.implements = () def visit_function(self, node): """visit an astroid.Function node * set the locals_type mapping * optionally tag the node with a unique id """ if hasattr(node, 'locals_type'): return node.locals_type = {} if self.tag: node.uid = self.generate_id() link_project = visit_project link_module = visit_module link_class = visit_class link_function = visit_function def visit_assname(self, node): """visit an astroid.AssName node handle locals_type """ # avoid double parsing done by different Linkers.visit # running over the same project: if hasattr(node, '_handled'): return node._handled = True if node.name in node.frame(): frame = node.frame() else: # the name has been defined as 'global' in the frame and belongs # there. Btw the frame is not yet visited as the name is in the # root locals; the frame hence has no locals_type attribute frame = node.root() try: values = node.infered() try: already_infered = frame.locals_type[node.name] for valnode in values: if not valnode in already_infered: already_infered.append(valnode) except KeyError: frame.locals_type[node.name] = values except astroid.InferenceError: pass def handle_assattr_type(self, node, parent): """handle an astroid.AssAttr node handle instance_attrs_type """ try: values = list(node.infer()) try: already_infered = parent.instance_attrs_type[node.attrname] for valnode in values: if not valnode in already_infered: already_infered.append(valnode) except KeyError: parent.instance_attrs_type[node.attrname] = values except astroid.InferenceError: pass def visit_import(self, node): """visit an astroid.Import node resolve module dependencies """ context_file = node.root().file for name in node.names: relative = is_relative(name[0], context_file) self._imported_module(node, name[0], relative) def visit_from(self, node): """visit an astroid.From node resolve module dependencies """ basename = node.modname context_file = node.root().file if context_file is not None: relative = is_relative(basename, context_file) else: relative = False for name in node.names: if name[0] == '*': continue # analyze dependencies fullname = '%s.%s' % (basename, name[0]) if fullname.find('.') > -1: try: # XXX: don't use get_module_part, missing package precedence fullname = get_module_part(fullname, context_file) except ImportError: continue if fullname != basename: self._imported_module(node, fullname, relative) def compute_module(self, context_name, mod_path): """return true if the module should be added to dependencies""" package_dir = dirname(self.project.path) if context_name == mod_path: return 0 elif is_standard_module(mod_path, (package_dir,)): return 1 return 0 # protected methods ######################################################## def _imported_module(self, node, mod_path, relative): """notify an imported module, used to analyze dependencies """ module = node.root() context_name = module.name if relative: mod_path = '%s.%s' % ('.'.join(context_name.split('.')[:-1]), mod_path) if self.compute_module(context_name, mod_path): # handle dependencies if not hasattr(module, 'depends'): module.depends = [] mod_paths = module.depends if not mod_path in mod_paths: mod_paths.append(mod_path)