"""Discover and load entry points from installed packages.""" # Copyright (c) Thomas Kluyver and contributors # Distributed under the terms of the MIT license; see LICENSE file. from contextlib import contextmanager import glob from importlib import import_module import io import itertools import os.path as osp import re import sys import warnings import zipfile if sys.version_info[0] >= 3: import configparser else: from backports import configparser entry_point_pattern = re.compile(r""" (?P\w+(\.\w+)*) (:(?P\w+(\.\w+)*))? \s* (\[(?P.+)\])? $ """, re.VERBOSE) file_in_zip_pattern = re.compile(r""" (?P[^/\\]+)\.(dist|egg)-info [/\\]entry_points.txt$ """, re.VERBOSE) __version__ = '0.3' class BadEntryPoint(Exception): """Raised when an entry point can't be parsed. """ def __init__(self, epstr): self.epstr = epstr def __str__(self): return "Couldn't parse entry point spec: %r" % self.epstr @staticmethod @contextmanager def err_to_warnings(): try: yield except BadEntryPoint as e: warnings.warn(str(e)) class NoSuchEntryPoint(Exception): """Raised by :func:`get_single` when no matching entry point is found.""" def __init__(self, group, name): self.group = group self.name = name def __str__(self): return "No {!r} entry point found in group {!r}".format(self.name, self.group) class CaseSensitiveConfigParser(configparser.ConfigParser): optionxform = staticmethod(str) class EntryPoint(object): def __init__(self, name, module_name, object_name, extras=None, distro=None): self.name = name self.module_name = module_name self.object_name = object_name self.extras = extras self.distro = distro def __repr__(self): return "EntryPoint(%r, %r, %r, %r)" % \ (self.name, self.module_name, self.object_name, self.distro) def load(self): """Load the object to which this entry point refers. """ mod = import_module(self.module_name) obj = mod if self.object_name: for attr in self.object_name.split('.'): obj = getattr(obj, attr) return obj @classmethod def from_string(cls, epstr, name, distro=None): """Parse an entry point from the syntax in entry_points.txt :param str epstr: The entry point string (not including 'name =') :param str name: The name of this entry point :param Distribution distro: The distribution in which the entry point was found :rtype: EntryPoint :raises BadEntryPoint: if *epstr* can't be parsed as an entry point. """ m = entry_point_pattern.match(epstr) if m: mod, obj, extras = m.group('modulename', 'objectname', 'extras') if extras is not None: extras = re.split(r',\s*', extras) return cls(name, mod, obj, extras, distro) else: raise BadEntryPoint(epstr) class Distribution(object): def __init__(self, name, version): self.name = name self.version = version def __repr__(self): return "Distribution(%r, %r)" % (self.name, self.version) def iter_files_distros(path=None, repeated_distro='first'): if path is None: path = sys.path # Distributions found earlier in path will shadow those with the same name # found later. If these distributions used different module names, it may # actually be possible to import both, but in most cases this shadowing # will be correct. distro_names_seen = set() for folder in path: if folder.rstrip('/\\').endswith('.egg'): # Gah, eggs egg_name = osp.basename(folder) if '-' in egg_name: distro = Distribution(*egg_name.split('-')[:2]) if (repeated_distro == 'first') \ and (distro.name in distro_names_seen): continue distro_names_seen.add(distro.name) else: distro = None if osp.isdir(folder): ep_path = osp.join(folder, 'EGG-INFO', 'entry_points.txt') if osp.isfile(ep_path): cp = CaseSensitiveConfigParser(delimiters=('=',)) cp.read([ep_path]) yield cp, distro elif zipfile.is_zipfile(folder): z = zipfile.ZipFile(folder) try: info = z.getinfo('EGG-INFO/entry_points.txt') except KeyError: continue cp = CaseSensitiveConfigParser(delimiters=('=',)) with z.open(info) as f: fu = io.TextIOWrapper(f) cp.read_file(fu, source=osp.join(folder, 'EGG-INFO', 'entry_points.txt')) yield cp, distro # zip imports, not egg elif zipfile.is_zipfile(folder): with zipfile.ZipFile(folder) as zf: for info in zf.infolist(): m = file_in_zip_pattern.match(info.filename) if not m: continue distro_name_version = m.group('dist_version') if '-' in distro_name_version: distro = Distribution(*distro_name_version.split('-', 1)) if (repeated_distro == 'first') \ and (distro.name in distro_names_seen): continue distro_names_seen.add(distro.name) else: distro = None cp = CaseSensitiveConfigParser(delimiters=('=',)) with zf.open(info) as f: fu = io.TextIOWrapper(f) cp.read_file(fu, source=osp.join(folder, info.filename)) yield cp, distro # Regular file imports (not egg, not zip file) for path in itertools.chain( glob.iglob(osp.join(folder, '*.dist-info', 'entry_points.txt')), glob.iglob(osp.join(folder, '*.egg-info', 'entry_points.txt')) ): distro_name_version = osp.splitext(osp.basename(osp.dirname(path)))[0] if '-' in distro_name_version: distro = Distribution(*distro_name_version.split('-', 1)) if (repeated_distro == 'first') \ and (distro.name in distro_names_seen): continue distro_names_seen.add(distro.name) else: distro = None cp = CaseSensitiveConfigParser(delimiters=('=',)) cp.read([path]) yield cp, distro def get_single(group, name, path=None): """Find a single entry point. Returns an :class:`EntryPoint` object, or raises :exc:`NoSuchEntryPoint` if no match is found. """ for config, distro in iter_files_distros(path=path): if (group in config) and (name in config[group]): epstr = config[group][name] with BadEntryPoint.err_to_warnings(): return EntryPoint.from_string(epstr, name, distro) raise NoSuchEntryPoint(group, name) def get_group_named(group, path=None): """Find a group of entry points with unique names. Returns a dictionary of names to :class:`EntryPoint` objects. """ result = {} for ep in get_group_all(group, path=path): if ep.name not in result: result[ep.name] = ep return result def get_group_all(group, path=None): """Find all entry points in a group. Returns a list of :class:`EntryPoint` objects. """ result = [] for config, distro in iter_files_distros(path=path): if group in config: for name, epstr in config[group].items(): with BadEntryPoint.err_to_warnings(): result.append(EntryPoint.from_string(epstr, name, distro)) return result if __name__ == '__main__': import pprint pprint.pprint(get_group_all('console_scripts'))