Initial commit (Clean history)

This commit is contained in:
anhduy-tech
2025-12-30 11:27:14 +07:00
commit ef48c93de0
19255 changed files with 3248867 additions and 0 deletions

View File

@@ -0,0 +1,70 @@
from typing import Any, Optional
from ._version import __version_info__, __version__ # noqa
from .collection import Collection # noqa
from .config import Config # noqa
from .context import Context, MockContext # noqa
from .exceptions import ( # noqa
AmbiguousEnvVar,
AuthFailure,
CollectionNotFound,
Exit,
ParseError,
PlatformError,
ResponseNotAccepted,
SubprocessPipeError,
ThreadException,
UncastableEnvVar,
UnexpectedExit,
UnknownFileType,
UnpicklableConfigMember,
WatcherError,
CommandTimedOut,
)
from .executor import Executor # noqa
from .loader import FilesystemLoader # noqa
from .parser import Argument, Parser, ParserContext, ParseResult # noqa
from .program import Program # noqa
from .runners import Runner, Local, Failure, Result, Promise # noqa
from .tasks import task, call, Call, Task # noqa
from .terminals import pty_size # noqa
from .watchers import FailingResponder, Responder, StreamWatcher # noqa
def run(command: str, **kwargs: Any) -> Optional[Result]:
    """
    Execute ``command`` in a local subprocess, returning a `.Result`.
    All keyword arguments are forwarded untouched; see `.Runner.run` for the
    full set of supported options.
    .. note::
        This is simply sugar for building a throwaway anonymous `.Context`
        and invoking its `~.Context.run` method (which by default delegates
        to a `.Local` runner subclass for the actual execution).
    .. versionadded:: 1.0
    """
    context = Context()
    return context.run(command, **kwargs)
def sudo(command: str, **kwargs: Any) -> Optional[Result]:
    """
    Execute ``command`` under ``sudo`` in a subprocess, returning a `.Result`.
    All keyword arguments (such as ``password``) are forwarded untouched; see
    `.Context.sudo` for API details.
    .. note::
        This is simply sugar for building a throwaway anonymous `.Context`
        and invoking its `~.Context.sudo` method (which by default delegates
        to a `.Local` runner subclass for the actual execution, plus the
        sudo-related bits & pieces).
    .. versionadded:: 1.4
    """
    context = Context()
    return context.sudo(command, **kwargs)

View File

@@ -0,0 +1,3 @@
# Entry point for ``python -m invoke``: import the singleton Program instance
# and hand it control (it parses sys.argv and dispatches to tasks).
from invoke.main import program
program.run()

View File

@@ -0,0 +1,2 @@
# Package version, exposed both as an int tuple and as a dotted string.
__version_info__ = (2, 2, 1)
__version__ = "{}.{}.{}".format(*__version_info__)

View File

@@ -0,0 +1,608 @@
import copy
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
from .util import Lexicon, helpline
from .config import merge_dicts, copy_dict
from .parser import Context as ParserContext
from .tasks import Task
class Collection:
    """
    A collection of executable tasks. See :doc:`/concepts/namespaces`.
    .. versionadded:: 1.0
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Create a new task collection/namespace.
        `.Collection` offers a set of methods for building a collection of
        tasks from scratch, plus a convenient constructor wrapping said API.
        In either case:
        * The first positional argument may be a string, which (if given) is
        used as the collection's default name when performing namespace
        lookups;
        * A ``loaded_from`` keyword argument may be given, which sets metadata
        indicating the filesystem path the collection was loaded from. This
        is used as a guide when loading per-project :ref:`configuration files
        <config-hierarchy>`.
        * An ``auto_dash_names`` kwarg may be given, controlling whether task
        and collection names have underscores turned to dashes in most cases;
        it defaults to ``True`` but may be set to ``False`` to disable.
        The CLI machinery will pass in the value of the
        ``tasks.auto_dash_names`` config value to this kwarg.
        **The method approach**
        May initialize with no arguments and use methods (e.g.
        `.add_task`/`.add_collection`) to insert objects::
            c = Collection()
            c.add_task(some_task)
        If an initial string argument is given, it is used as the default name
        for this collection, should it be inserted into another collection as a
        sub-namespace::
            docs = Collection('docs')
            docs.add_task(doc_task)
            ns = Collection()
            ns.add_task(top_level_task)
            ns.add_collection(docs)
            # Valid identifiers are now 'top_level_task' and 'docs.doc_task'
            # (assuming the task objects were actually named the same as the
            # variables we're using :))
        For details, see the API docs for the rest of the class.
        **The constructor approach**
        All ``*args`` given to `.Collection` (besides the abovementioned
        optional positional 'name' argument and ``loaded_from`` kwarg) are
        expected to be `.Task` or `.Collection` instances which will be passed
        to `.add_task`/`.add_collection` as appropriate. Module objects are
        also valid (as they are for `.add_collection`). For example, the below
        snippet results in the same two task identifiers as the one above::
            ns = Collection(top_level_task, Collection('docs', doc_task))
        If any ``**kwargs`` are given, the keywords are used as the initial
        name arguments for the respective values::
            ns = Collection(
                top_level_task=some_other_task,
                docs=Collection(doc_task)
            )
        That's exactly equivalent to::
            docs = Collection(doc_task)
            ns = Collection()
            ns.add_task(some_other_task, 'top_level_task')
            ns.add_collection(docs, 'docs')
        See individual methods' API docs for details.
        """
        # Initialize
        self.tasks = Lexicon()
        self.collections = Lexicon()
        self.default: Optional[str] = None
        self.name: Optional[str] = None
        self._configuration: Dict[str, Any] = {}
        # Specific kwargs if applicable
        self.loaded_from = kwargs.pop("loaded_from", None)
        self.auto_dash_names = kwargs.pop("auto_dash_names", None)
        # splat-kwargs version of default value (auto_dash_names=True)
        if self.auto_dash_names is None:
            self.auto_dash_names = True
        # Name if applicable
        _args = list(args)
        if _args and isinstance(args[0], str):
            self.name = self.transform(_args.pop(0))
        # Dispatch args/kwargs
        for arg in _args:
            self._add_object(arg)
        # Dispatch kwargs
        for name, obj in kwargs.items():
            self._add_object(obj, name)
    def _add_object(self, obj: Any, name: Optional[str] = None) -> None:
        """
        Dispatch ``obj`` to `add_task` or `add_collection` based on its type.
        Raises TypeError for anything that is not a Task, Collection, or
        module object.
        """
        method: Callable
        if isinstance(obj, Task):
            method = self.add_task
        elif isinstance(obj, (Collection, ModuleType)):
            method = self.add_collection
        else:
            raise TypeError("No idea how to insert {!r}!".format(type(obj)))
        method(obj, name=name)
    def __repr__(self) -> str:
        task_names = list(self.tasks.keys())
        collections = ["{}...".format(x) for x in self.collections.keys()]
        return "<Collection {!r}: {}>".format(
            self.name, ", ".join(sorted(task_names) + sorted(collections))
        )
    def __eq__(self, other: object) -> bool:
        # Equal iff name, tasks, and sub-collections all match; config is
        # intentionally not compared here.
        if isinstance(other, Collection):
            return (
                self.name == other.name
                and self.tasks == other.tasks
                and self.collections == other.collections
            )
        return False
    def __bool__(self) -> bool:
        # Truthy when there is at least one task anywhere in the tree.
        return bool(self.task_names)
    @classmethod
    def from_module(
        cls,
        module: ModuleType,
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        loaded_from: Optional[str] = None,
        auto_dash_names: Optional[bool] = None,
    ) -> "Collection":
        """
        Return a new `.Collection` created from ``module``.
        Inspects ``module`` for any `.Task` instances and adds them to a new
        `.Collection`, returning it. If any explicit namespace collections
        exist (named ``ns`` or ``namespace``) a copy of that collection object
        is preferentially loaded instead.
        When the implicit/default collection is generated, it will be named
        after the module's ``__name__`` attribute, or its last dotted section
        if it's a submodule. (I.e. it should usually map to the actual ``.py``
        filename.)
        Explicitly given collections will only be given that module-derived
        name if they don't already have a valid ``.name`` attribute.
        If the module has a docstring (``__doc__``) it is copied onto the
        resulting `.Collection` (and used for display in help, list etc
        output.)
        :param str name:
            A string, which if given will override any automatically derived
            collection name (or name set on the module's root namespace, if it
            has one.)
        :param dict config:
            Used to set config options on the newly created `.Collection`
            before returning it (saving you a call to `.configure`.)
            If the imported module had a root namespace object, ``config`` is
            merged on top of it (i.e. overriding any conflicts.)
        :param str loaded_from:
            Identical to the same-named kwarg from the regular class
            constructor - should be the path where the module was
            found.
        :param bool auto_dash_names:
            Identical to the same-named kwarg from the regular class
            constructor - determines whether emitted names are auto-dashed.
        .. versionadded:: 1.0
        """
        module_name = module.__name__.split(".")[-1]
        def instantiate(obj_name: Optional[str] = None) -> "Collection":
            # Explicitly given name wins over root ns name (if applicable),
            # which wins over actual module name.
            args = [name or obj_name or module_name]
            kwargs = dict(
                loaded_from=loaded_from, auto_dash_names=auto_dash_names
            )
            instance = cls(*args, **kwargs)
            instance.__doc__ = module.__doc__
            return instance
        # See if the module provides a default NS to use in lieu of creating
        # our own collection.
        for candidate in ("ns", "namespace"):
            obj = getattr(module, candidate, None)
            if obj and isinstance(obj, Collection):
                # TODO: make this into Collection.clone() or similar?
                ret = instantiate(obj_name=obj.name)
                ret.tasks = ret._transform_lexicon(obj.tasks)
                ret.collections = ret._transform_lexicon(obj.collections)
                ret.default = (
                    ret.transform(obj.default) if obj.default else None
                )
                # Explicitly given config wins over root ns config
                obj_config = copy_dict(obj._configuration)
                if config:
                    merge_dicts(obj_config, config)
                ret._configuration = obj_config
                return ret
        # Failing that, make our own collection from the module's tasks.
        tasks = filter(lambda x: isinstance(x, Task), vars(module).values())
        # Again, explicit name wins over implicit one from module path
        collection = instantiate()
        for task in tasks:
            collection.add_task(task)
        if config:
            collection.configure(config)
        return collection
    def add_task(
        self,
        task: "Task",
        name: Optional[str] = None,
        aliases: Optional[Tuple[str, ...]] = None,
        default: Optional[bool] = None,
    ) -> None:
        """
        Add `.Task` ``task`` to this collection.
        :param task: The `.Task` object to add to this collection.
        :param name:
            Optional string name to bind to (overrides the task's own
            self-defined ``name`` attribute and/or any Python identifier (i.e.
            ``.func_name``.)
        :param aliases:
            Optional iterable of additional names to bind the task as, on top
            of the primary name. These will be used in addition to any aliases
            the task itself declares internally.
        :param default: Whether this task should be the collection default.
        .. versionadded:: 1.0
        """
        if name is None:
            if task.name:
                name = task.name
            # XXX https://github.com/python/mypy/issues/1424
            elif hasattr(task.body, "func_name"):
                name = task.body.func_name  # type: ignore
            elif hasattr(task.body, "__name__"):
                name = task.__name__
            else:
                raise ValueError("Could not obtain a name for this task!")
        name = self.transform(name)
        if name in self.collections:
            err = "Name conflict: this collection has a sub-collection named {!r} already"  # noqa
            raise ValueError(err.format(name))
        self.tasks[name] = task
        # Bind both the task's own aliases and any caller-supplied ones.
        for alias in list(task.aliases) + list(aliases or []):
            self.tasks.alias(self.transform(alias), to=name)
        if default is True or (default is None and task.is_default):
            self._check_default_collision(name)
            self.default = name
    def add_collection(
        self,
        coll: "Collection",
        name: Optional[str] = None,
        default: Optional[bool] = None,
    ) -> None:
        """
        Add `.Collection` ``coll`` as a sub-collection of this one.
        :param coll: The `.Collection` to add.
        :param str name:
            The name to attach the collection as. Defaults to the collection's
            own internal name.
        :param default:
            Whether this sub-collection('s default task-or-collection) should
            be the default invocation of the parent collection.
        .. versionadded:: 1.0
        .. versionchanged:: 1.5
            Added the ``default`` parameter.
        """
        # Handle module-as-collection
        if isinstance(coll, ModuleType):
            coll = Collection.from_module(coll)
        # Ensure we have a name, or die trying
        name = name or coll.name
        if not name:
            raise ValueError("Non-root collections must have a name!")
        name = self.transform(name)
        # Test for conflict
        if name in self.tasks:
            err = "Name conflict: this collection has a task named {!r} already"  # noqa
            raise ValueError(err.format(name))
        # Insert
        self.collections[name] = coll
        if default:
            self._check_default_collision(name)
            self.default = name
    def _check_default_collision(self, name: str) -> None:
        """
        Raise ValueError if this collection already has a default set.
        """
        if self.default:
            msg = "'{}' cannot be the default because '{}' already is!"
            raise ValueError(msg.format(name, self.default))
    def _split_path(self, path: str) -> Tuple[str, str]:
        """
        Obtain first collection + remainder, of a task path.
        E.g. for ``"subcollection.taskname"``, return ``("subcollection",
        "taskname")``; for ``"subcollection.nested.taskname"`` return
        ``("subcollection", "nested.taskname")``, etc.
        An empty path becomes simply ``('', '')``.
        """
        parts = path.split(".")
        coll = parts.pop(0)
        rest = ".".join(parts)
        return coll, rest
    def subcollection_from_path(self, path: str) -> "Collection":
        """
        Given a ``path`` to a subcollection, return that subcollection.
        .. versionadded:: 1.0
        """
        parts = path.split(".")
        collection = self
        while parts:
            collection = collection.collections[parts.pop(0)]
        return collection
    def __getitem__(self, name: Optional[str] = None) -> Any:
        """
        Returns task named ``name``. Honors aliases and subcollections.
        If this collection has a default task, it is returned when ``name`` is
        empty or ``None``. If empty input is given and no task has been
        selected as the default, ValueError will be raised.
        Tasks within subcollections should be given in dotted form, e.g.
        'foo.bar'. Subcollection default tasks will be returned on the
        subcollection's name.
        .. versionadded:: 1.0
        """
        return self.task_with_config(name)[0]
    def _task_with_merged_config(
        self, coll: str, rest: str, ours: Dict[str, Any]
    ) -> Tuple["Task", Dict[str, Any]]:
        # Recurse into subcollection ``coll``, merging our own config ``ours``
        # over whatever the deeper levels returned.
        task, config = self.collections[coll].task_with_config(rest)
        return task, dict(config, **ours)
    def task_with_config(
        self, name: Optional[str]
    ) -> Tuple["Task", Dict[str, Any]]:
        """
        Return task named ``name`` plus its configuration dict.
        E.g. in a deeply nested tree, this method returns the `.Task`, and a
        configuration dict created by merging that of this `.Collection` and
        any nested `Collections <.Collection>`, up through the one actually
        holding the `.Task`.
        See `~.Collection.__getitem__` for semantics of the ``name`` argument.
        :returns: Two-tuple of (`.Task`, `dict`).
        .. versionadded:: 1.0
        """
        # Our top level configuration
        ours = self.configuration()
        # Default task for this collection itself
        if not name:
            if not self.default:
                raise ValueError("This collection has no default task.")
            return self[self.default], ours
        # Normalize name to the format we're expecting
        name = self.transform(name)
        # Non-default tasks within subcollections -> recurse (sorta)
        if "." in name:
            coll, rest = self._split_path(name)
            return self._task_with_merged_config(coll, rest, ours)
        # Default task for subcollections (via empty-name lookup)
        if name in self.collections:
            return self._task_with_merged_config(name, "", ours)
        # Regular task lookup
        return self.tasks[name], ours
    def __contains__(self, name: str) -> bool:
        # NOTE: only KeyError is treated as absence here; an empty ``name``
        # with no default task raises ValueError, which propagates.
        try:
            self[name]
            return True
        except KeyError:
            return False
    def to_contexts(
        self, ignore_unknown_help: Optional[bool] = None
    ) -> List[ParserContext]:
        """
        Returns all contained tasks and subtasks as a list of parser contexts.
        :param bool ignore_unknown_help:
            Passed on to each task's ``get_arguments()`` method. See the config
            option by the same name for details.
        .. versionadded:: 1.0
        .. versionchanged:: 1.7
            Added the ``ignore_unknown_help`` kwarg.
        """
        result = []
        for primary, aliases in self.task_names.items():
            task = self[primary]
            result.append(
                ParserContext(
                    name=primary,
                    aliases=aliases,
                    args=task.get_arguments(
                        ignore_unknown_help=ignore_unknown_help
                    ),
                )
            )
        return result
    def subtask_name(self, collection_name: str, task_name: str) -> str:
        # Build the dotted identifier for a task inside a subcollection,
        # applying name transformation to both halves.
        return ".".join(
            [self.transform(collection_name), self.transform(task_name)]
        )
    def transform(self, name: str) -> str:
        """
        Transform ``name`` with the configured auto-dashes behavior.
        If the collection's ``auto_dash_names`` attribute is ``True``
        (default), all non leading/trailing underscores are turned into dashes.
        (Leading/trailing underscores tend to get stripped elsewhere in the
        stack.)
        If it is ``False``, the inverse is applied - all dashes are turned into
        underscores.
        .. versionadded:: 1.0
        """
        # Short-circuit on anything non-applicable, e.g. empty strings, bools,
        # None, etc.
        if not name:
            return name
        from_, to = "_", "-"
        if not self.auto_dash_names:
            from_, to = "-", "_"
        replaced = []
        end = len(name) - 1
        for i, char in enumerate(name):
            # Don't replace leading or trailing underscores (+ taking dotted
            # names into account)
            # TODO: not 100% convinced of this / it may be exposing a
            # discrepancy between this level & higher levels which tend to
            # strip out leading/trailing underscores entirely.
            if (
                i not in (0, end)
                and char == from_
                and name[i - 1] != "."
                and name[i + 1] != "."
            ):
                char = to
            replaced.append(char)
        return "".join(replaced)
    def _transform_lexicon(self, old: Lexicon) -> Lexicon:
        """
        Take a Lexicon and apply `.transform` to its keys and aliases.
        :returns: A new Lexicon.
        """
        new = Lexicon()
        # Lexicons exhibit only their real keys in most places, so this will
        # only grab those, not aliases.
        for key, value in old.items():
            # Deepcopy the value so we're not just copying a reference
            new[self.transform(key)] = copy.deepcopy(value)
        # Also copy all aliases, which are string-to-string key mappings
        for key, value in old.aliases.items():
            new.alias(from_=self.transform(key), to=self.transform(value))
        return new
    @property
    def task_names(self) -> Dict[str, List[str]]:
        """
        Return all task identifiers for this collection as a one-level dict.
        Specifically, a dict with the primary/"real" task names as the key, and
        any aliases as a list value.
        It basically collapses the namespace tree into a single
        easily-scannable collection of invocation strings, and is thus suitable
        for things like flat-style task listings or transformation into parser
        contexts.
        .. versionadded:: 1.0
        """
        ret = {}
        # Our own tasks get no prefix, just go in as-is: {name: [aliases]}
        for name, task in self.tasks.items():
            ret[name] = list(map(self.transform, task.aliases))
        # Subcollection tasks get both name + aliases prefixed
        for coll_name, coll in self.collections.items():
            for task_name, aliases in coll.task_names.items():
                aliases = list(
                    map(lambda x: self.subtask_name(coll_name, x), aliases)
                )
                # Tack on collection name to alias list if this task is the
                # collection's default.
                if coll.default == task_name:
                    aliases += (coll_name,)
                ret[self.subtask_name(coll_name, task_name)] = aliases
        return ret
    def configuration(self, taskpath: Optional[str] = None) -> Dict[str, Any]:
        """
        Obtain merged configuration values from collection & children.
        :param taskpath:
            (Optional) Task name/path, identical to that used for
            `~.Collection.__getitem__` (e.g. may be dotted for nested tasks,
            etc.) Used to decide which path to follow in the collection tree
            when merging config values.
        :returns: A `dict` containing configuration values.
        .. versionadded:: 1.0
        """
        if taskpath is None:
            return copy_dict(self._configuration)
        return self.task_with_config(taskpath)[1]
    def configure(self, options: Dict[str, Any]) -> None:
        """
        (Recursively) merge ``options`` into the current `.configuration`.
        Options configured this way will be available to all tasks. It is
        recommended to use unique keys to avoid potential clashes with other
        config options
        For example, if you were configuring a Sphinx docs build target
        directory, it's better to use a key like ``'sphinx.target'`` than
        simply ``'target'``.
        :param options: An object implementing the dictionary protocol.
        :returns: ``None``.
        .. versionadded:: 1.0
        """
        merge_dicts(self._configuration, options)
    def serialized(self) -> Dict[str, Any]:
        """
        Return an appropriate-for-serialization version of this object.
        See the documentation for `.Program` and its ``json`` task listing
        format; this method is the driver for that functionality.
        .. versionadded:: 1.0
        """
        return {
            "name": self.name,
            "help": helpline(self),
            "default": self.default,
            "tasks": [
                {
                    "name": self.transform(x.name),
                    "help": helpline(x),
                    "aliases": [self.transform(y) for y in x.aliases],
                }
                for x in sorted(self.tasks.values(), key=lambda x: x.name)
            ],
            "collections": [
                x.serialized()
                for x in sorted(
                    self.collections.values(), key=lambda x: x.name or ""
                )
            ],
        }

View File

@@ -0,0 +1,32 @@
# Invoke tab-completion script to be sourced with Bash shell.
# Known to work on Bash 3.x, untested on 4.x.
# NOTE: this file is a Python str.format template: {binary} and
# {spaced_names} are substituted when the script is printed, and literal
# shell braces are escaped as {{ / }}.
_complete_{binary}() {{
    local candidates
    # COMP_WORDS contains the entire command string up til now (including
    # program name).
    # We hand it to Invoke so it can figure out the current context: spit back
    # core options, task names, the current task's options, or some combo.
    candidates=`{binary} --complete -- ${{COMP_WORDS[*]}}`
    # `compgen -W` takes list of valid options & a partial word & spits back
    # possible matches. Necessary for any partial word completions (vs
    # completions performed when no partial words are present).
    #
    # $2 is the current word or token being tabbed on, either empty string or a
    # partial word, and thus wants to be compgen'd to arrive at some subset of
    # our candidate list which actually matches.
    #
    # COMPREPLY is the list of valid completions handed back to `complete`.
    COMPREPLY=( $(compgen -W "${{candidates}}" -- $2) )
}}
# Tell shell builtin to use the above for completing our invocations.
# * -F: use given function name to generate completions.
# * -o default: when function generates no results, use filenames.
# * positional args: program names to complete for.
complete -F _complete_{binary} -o default {spaced_names}
# vim: set ft=sh :

View File

@@ -0,0 +1,129 @@
"""
Command-line completion mechanisms, executed by the core ``--complete`` flag.
"""
from typing import List
import glob
import os
import re
import shlex
from typing import TYPE_CHECKING
from ..exceptions import Exit, ParseError
from ..util import debug, task_name_sort_key
if TYPE_CHECKING:
from ..collection import Collection
from ..parser import Parser, ParseResult, ParserContext
def complete(
    names: List[str],
    core: "ParseResult",
    initial_context: "ParserContext",
    collection: "Collection",
    parser: "Parser",
) -> Exit:
    """
    Print completion candidates for the partial invocation in ``core.remainder``.

    Strips the program name (any of ``names``) off the invocation, tokenizes
    it, and inspects the final token:

    * Flag-like tail (starts with ``-``): the invocation is leniently parsed
      to find the current parser context (falling back to the last context a
      `ParseError` saw, then to ``initial_context``). Unknown/partial tails
      complete to that context's flag names; known flags complete to nothing
      (letting the shell fall back to e.g. filenames) when the flag takes a
      value, or to task names otherwise.
    * Anything else: completes to task names from ``collection``.

    Candidates are written to stdout one per line; always raises `.Exit`
    when done (there is no return value in practice).
    """
    # Strip out program name (scripts give us full command line)
    # TODO: this may not handle path/to/script though?
    invocation = re.sub(r"^({}) ".format("|".join(names)), "", core.remainder)
    debug("Completing for invocation: {!r}".format(invocation))
    # Tokenize (shlex will have to do)
    tokens = shlex.split(invocation)
    # Handle flags (partial or otherwise)
    if tokens and tokens[-1].startswith("-"):
        tail = tokens[-1]
        debug("Invocation's tail {!r} is flag-like".format(tail))
        # Gently parse invocation to obtain 'current' context.
        # Use last seen context in case of failure (required for
        # otherwise-invalid partial invocations being completed).
        contexts: List[ParserContext]
        try:
            debug("Seeking context name in tokens: {!r}".format(tokens))
            contexts = parser.parse_argv(tokens)
        except ParseError as e:
            msg = "Got parser error ({!r}), grabbing its last-seen context {!r}"  # noqa
            debug(msg.format(e, e.context))
            contexts = [e.context] if e.context is not None else []
        # Fall back to core context if no context seen.
        debug("Parsed invocation, contexts: {!r}".format(contexts))
        if not contexts or not contexts[-1]:
            context = initial_context
        else:
            context = contexts[-1]
        debug("Selected context: {!r}".format(context))
        # Unknown flags (could be e.g. only partially typed out; could be
        # wholly invalid; doesn't matter) complete with flags.
        debug("Looking for {!r} in {!r}".format(tail, context.flags))
        if tail not in context.flags:
            debug("Not found, completing with flag names")
            # Long flags - partial or just the dashes - complete w/ long flags
            if tail.startswith("--"):
                for name in filter(
                    lambda x: x.startswith("--"), context.flag_names()
                ):
                    print(name)
            # Just a dash, completes with all flags
            elif tail == "-":
                for name in context.flag_names():
                    print(name)
            # Otherwise, it's something entirely invalid (a shortflag not
            # recognized, or a java style flag like -foo) so return nothing
            # (the shell will still try completing with files, but that doesn't
            # hurt really.)
            else:
                pass
        # Known flags complete w/ nothing or tasks, depending
        else:
            # Flags expecting values: do nothing, to let default (usually
            # file) shell completion occur (which we actively want in this
            # case.)
            if context.flags[tail].takes_value:
                debug("Found, and it takes a value, so no completion")
                pass
            # Not taking values (eg bools): print task names
            else:
                debug("Found, takes no value, printing task names")
                print_task_names(collection)
    # If not a flag, is either task name or a flag value, so just complete
    # task names.
    else:
        debug("Last token isn't flag-like, just printing task names")
        print_task_names(collection)
    raise Exit
def print_task_names(collection: "Collection") -> None:
    """
    Print every task name in ``collection`` to stdout, one per line.
    """
    names = collection.task_names
    for primary in sorted(names, key=task_name_sort_key):
        print(primary)
        # Aliases simply trail their primary name; a strict global ordering
        # isn't worth bending over backwards for.
        for alias in names[primary]:
            print(alias)
def print_completion_script(shell: str, names: List[str]) -> None:
    """
    Render and print the bundled completion script for ``shell``.

    ``names`` are the program names to complete for; the first one is also
    used for the script's own internal invocations (and to build completion
    function names). Raises `.ParseError` for unsupported shells.
    """
    # Map shell name -> path for all *.completion files shipped alongside
    # this module. (These used to have no suffix, but surprise, that's super
    # fragile.)
    here = os.path.dirname(os.path.realpath(__file__))
    completions = {}
    for candidate in glob.glob(os.path.join(here, "*.completion")):
        shell_name = os.path.splitext(os.path.basename(candidate))[0]
        completions[shell_name] = candidate
    if shell not in completions:
        err = 'Completion for shell "{}" not supported (options are: {}).'
        raise ParseError(err.format(shell, ", ".join(sorted(completions))))
    path = completions[shell]
    debug("Printing completion script from {}".format(path))
    # Choose one arbitrary program name for script's own internal invocation
    # (also used to construct completion function names when necessary)
    binary = names[0]
    with open(path, "r") as script:
        print(
            script.read().format(binary=binary, spaced_names=" ".join(names))
        )

View File

@@ -0,0 +1,10 @@
# Invoke tab-completion script for the fish shell
# Copy it to the ~/.config/fish/completions directory
# NOTE: this file is a Python str.format template; {binary} is substituted
# with the program name when the script is printed.
function __complete_{binary}
    {binary} --complete -- (commandline --tokenize)
end
# --no-files: Don't complete files unless invoke gives an empty result
# TODO: find a way to honor all binary_names
complete --command {binary} --no-files --arguments '(__complete_{binary})'

View File

@@ -0,0 +1,33 @@
# Invoke tab-completion script to be sourced with the Z shell.
# Known to work on zsh 5.0.x, probably works on later 4.x releases as well (as
# it uses the older compctl completion system).
# NOTE: this file is a Python str.format template: {binary} and
# {spaced_names} are substituted at render time, and literal shell braces
# are escaped as {{ / }}.
_complete_{binary}() {{
    # `words` contains the entire command string up til now (including
    # program name).
    #
    # We hand it to Invoke so it can figure out the current context: spit back
    # core options, task names, the current task's options, or some combo.
    #
    # Before doing so, we attempt to tease out any collection flag+arg so we
    # can ensure it is applied correctly.
    collection_arg=''
    if [[ "${{words}}" =~ "(-c|--collection) [^ ]+" ]]; then
        collection_arg=$MATCH
    fi
    # `reply` is the array of valid completions handed back to `compctl`.
    # Use ${{=...}} to force whitespace splitting in expansion of
    # $collection_arg
    reply=( $({binary} ${{=collection_arg}} --complete -- ${{words}}) )
}}
# Tell shell builtin to use the above for completing our given binary name(s).
# * -K: use given function name to generate completions.
# * +: specifies 'alternative' completion, where options after the '+' are only
#   used if the completion from the options before the '+' result in no matches.
# * -f: when function generates no results, use filenames.
# * positional args: program names to complete for.
compctl -K _complete_{binary} + -f {spaced_names}
# vim: set ft=sh :

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,602 @@
import os
import re
from contextlib import contextmanager
from itertools import cycle
from os import PathLike
from typing import (
TYPE_CHECKING,
Any,
Generator,
Iterator,
List,
Optional,
Union,
)
from unittest.mock import Mock
from .config import Config, DataProxy
from .exceptions import Failure, AuthFailure, ResponseNotAccepted
from .runners import Result
from .watchers import FailingResponder
if TYPE_CHECKING:
from invoke.runners import Runner
class Context(DataProxy):
"""
Context-aware API wrapper & state-passing object.
`.Context` objects are created during command-line parsing (or, if desired,
by hand) and used to share parser and configuration state with executed
tasks (see :ref:`why-context`).
Specifically, the class offers wrappers for core API calls (such as `.run`)
which take into account CLI parser flags, configuration files, and/or
changes made at runtime. It also acts as a proxy for its `~.Context.config`
attribute - see that attribute's documentation for details.
Instances of `.Context` may be shared between tasks when executing
sub-tasks - either the same context the caller was given, or an altered
copy thereof (or, theoretically, a brand new one).
.. versionadded:: 1.0
"""
def __init__(self, config: Optional[Config] = None) -> None:
"""
:param config:
`.Config` object to use as the base configuration.
Defaults to an anonymous/default `.Config` instance.
"""
#: The fully merged `.Config` object appropriate for this context.
#:
#: `.Config` settings (see their documentation for details) may be
#: accessed like dictionary keys (``c.config['foo']``) or object
#: attributes (``c.config.foo``).
#:
#: As a convenience shorthand, the `.Context` object proxies to its
#: ``config`` attribute in the same way - e.g. ``c['foo']`` or
#: ``c.foo`` returns the same value as ``c.config['foo']``.
config = config if config is not None else Config()
self._set(_config=config)
#: A list of commands to run (via "&&") before the main argument to any
#: `run` or `sudo` calls. Note that the primary API for manipulating
#: this list is `prefix`; see its docs for details.
command_prefixes: List[str] = list()
self._set(command_prefixes=command_prefixes)
#: A list of directories to 'cd' into before running commands with
#: `run` or `sudo`; intended for management via `cd`, please see its
#: docs for details.
command_cwds: List[str] = list()
self._set(command_cwds=command_cwds)
    @property
    def config(self) -> Config:
        # Allows Context to expose a .config attribute even though DataProxy
        # otherwise considers it a config key.
        # (The actual object is stored under _config; see __init__.)
        return self._config
    @config.setter
    def config(self, value: Config) -> None:
        # NOTE: mostly used by client libraries needing to tweak a Context's
        # config at execution time; i.e. a Context subclass that bears its own
        # unique data may want to be stood up when parameterizing/expanding a
        # call list at start of a session, with the final config filled in at
        # runtime.
        # Assignment goes through _set, mirroring __init__'s storage of
        # _config.
        self._set(_config=value)
def run(self, command: str, **kwargs: Any) -> Optional[Result]:
"""
Execute a local shell command, honoring config options.
Specifically, this method instantiates a `.Runner` subclass (according
to the ``runner`` config option; default is `.Local`) and calls its
``.run`` method with ``command`` and ``kwargs``.
See `.Runner.run` for details on ``command`` and the available keyword
arguments.
.. versionadded:: 1.0
"""
runner = self.config.runners.local(self)
return self._run(runner, command, **kwargs)
# NOTE: broken out of run() to allow for runner class injection in
# Fabric/etc, which needs to juggle multiple runner class types (local and
# remote).
def _run(
self, runner: "Runner", command: str, **kwargs: Any
) -> Optional[Result]:
command = self._prefix_commands(command)
return runner.run(command, **kwargs)
def sudo(self, command: str, **kwargs: Any) -> Optional[Result]:
    """
    Execute a shell command via ``sudo`` with password auto-response.

    **Basics**

    Behaves like `run`, plus a handful of conveniences around invoking the
    ``sudo`` program. Nothing here is impossible to achieve by wrapping
    `run` yourself, but the use case is common enough to deserve built-in
    support.

    .. note::
        If you intend to answer sudo's password prompt by hand, just use
        ``run("sudo command")`` instead! The autoresponding features in
        this method will only get in your way.

    Concretely, `sudo`:

    * Places a `.FailingResponder` into the ``watchers`` kwarg (see
      :doc:`/concepts/watchers`) which:

      * watches for the configured ``sudo`` password prompt;
      * replies with the configured sudo password (``sudo.password`` from
        the :doc:`configuration </concepts/configuration>`);
      * can tell when that response causes an authentication failure (e.g.
        if the system requires a password and one was not configured), and
        raises `.AuthFailure` if so.

    * Builds a ``sudo`` command string from the supplied ``command``,
      prefixed by various flags (see below);
    * Executes that command via a call to `run`, returning the result.

    **Flags used**

    ``sudo`` flags used under the hood include:

    - ``-S`` so the password can be auto-submitted via stdin;
    - ``-p <prompt>`` to pin the exact prompt text, so our auto-responder
      knows what to look for;
    - ``-u <user>`` when ``user`` is not ``None``, to run as a user other
      than ``root``; in that case ``-H`` is added as well, so the
      subprocess gets the requested user's ``$HOME``.

    **Configuring behavior**

    A couple of avenues exist for changing how this method behaves:

    - Because it wraps `run`, all `run` config parameters and keyword
      arguments apply here too; ``c.sudo('command', echo=True)`` works,
      and config-layer settings such as ``run.warn = True`` take effect
      under `sudo` as well.
    - `sudo` also has its own keyword arguments (see below), each likewise
      controllable via the configuration system under the ``sudo.*`` tree
      - e.g. an ``invoke.json`` containing
      ``{"sudo": {"user": "someuser"}}`` pre-sets a sudo user.

    :param str password: Runtime override for ``sudo.password``.
    :param str user: Runtime override for ``sudo.user``.

    .. versionadded:: 1.0
    """
    local_runner = self.config.runners.local(self)
    return self._sudo(local_runner, command, **kwargs)
# NOTE: this is for runner injection; see NOTE above _run().
def _sudo(
    self, runner: "Runner", command: str, **kwargs: Any
) -> Optional[Result]:
    """
    Shared guts of `sudo`, split out so subclasses can inject ``runner``.

    Assembles the final ``sudo`` command string, installs a
    `.FailingResponder` for the password prompt, and transmutes
    responder-driven failures into `.AuthFailure`.
    """
    prompt = self.config.sudo.prompt
    # Runtime kwargs override the sudo.* config values.
    password = kwargs.pop("password", self.config.sudo.password)
    user = kwargs.pop("user", self.config.sudo.user)
    # NOTE: .get, not .pop - 'env' must remain in kwargs so it is also
    # forwarded to runner.run() below.
    env = kwargs.get("env", {})
    # TODO: allow subclassing for 'get the password' so users who REALLY
    # want lazy runtime prompting can have it easily implemented.
    # TODO: want to print a "cleaner" echo with just 'sudo <command>'; but
    # hard to do as-is, obtaining config data from outside a Runner one
    # holds is currently messy (could fix that), if instead we manually
    # inspect the config ourselves that duplicates logic. NOTE: once we
    # figure that out, there is an existing, would-fail-if-not-skipped test
    # for this behavior in test/context.py.
    # TODO: once that is done, though: how to handle "full debug" output
    # exactly (display of actual, real full sudo command w/ -S and -p), in
    # terms of API/config? Impl is easy, just go back to passing echo
    # through to 'run'...
    user_flags = ""
    if user is not None:
        # -H ensures the subprocess sees the target user's $HOME.
        user_flags = "-H -u {} ".format(user)
    env_flags = ""
    if env:
        env_flags = "--preserve-env='{}' ".format(",".join(env.keys()))
    # Apply any cd/prefix state before wrapping in sudo.
    command = self._prefix_commands(command)
    cmd_str = "sudo -S -p '{}' {}{}{}".format(
        prompt, env_flags, user_flags, command
    )
    watcher = FailingResponder(
        pattern=re.escape(prompt),
        response="{}\n".format(password),
        sentinel="Sorry, try again.\n",
    )
    # Ensure we merge any user-specified watchers with our own.
    # NOTE: If there are config-driven watchers, we pull those up to the
    # kwarg level; that lets us merge cleanly without needing complex
    # config-driven "override vs merge" semantics.
    # TODO: if/when those semantics are implemented, use them instead.
    # NOTE: config value for watchers defaults to an empty list; and we
    # want to clone it to avoid actually mutating the config.
    watchers = kwargs.pop("watchers", list(self.config.run.watchers))
    watchers.append(watcher)
    try:
        return runner.run(cmd_str, watchers=watchers, **kwargs)
    except Failure as failure:
        # Transmute failures driven by our FailingResponder, into auth
        # failures - the command never even ran.
        # TODO: wants to be a hook here for users that desire "override a
        # bad config value for sudo.password" manual input
        # NOTE: as noted in #294 comments, we MAY in future want to update
        # this so run() is given ability to raise AuthFailure on its own.
        # For now that has been judged unnecessary complexity.
        if isinstance(failure.reason, ResponseNotAccepted):
            # NOTE: not bothering with 'reason' here, it's pointless.
            error = AuthFailure(result=failure.result, prompt=prompt)
            raise error
        # Reraise for any other error so it bubbles up normally.
        else:
            raise
# TODO: wonder if it makes sense to move this part of things inside Runner,
# which would grow a `prefixes` and `cwd` init kwargs or similar. The less
# that's stuffed into Context, probably the better.
def _prefix_commands(self, command: str) -> str:
    """
    Return ``command`` with all active prefixes joined in front.

    Prefixes accumulate in ``command_prefixes`` via the `prefix` context
    manager; any working directory contributed by `cd` always comes first.
    """
    pieces = list(self.command_prefixes)
    directory = self.cwd
    if directory:
        pieces.insert(0, "cd {}".format(directory))
    pieces.append(command)
    return " && ".join(pieces)
@contextmanager
def prefix(self, command: str) -> Generator[None, None, None]:
    """
    Prefix all nested `run`/`sudo` commands with given command plus ``&&``.

    Most useful alongside shell snippets that mutate shell state, such as
    ones exporting or altering environment variables. A classic example is
    ``workon`` from `virtualenvwrapper
    <https://virtualenvwrapper.readthedocs.io/en/latest/>`_::

        with c.prefix('workon myvenv'):
            c.run('./manage.py migrate')

    which actually executes::

        $ workon myvenv && ./manage.py migrate

    `prefix` composes with `cd`; if your virtualenv doesn't ``cd`` in its
    ``postactivate`` script, you could do::

        with c.cd('/path/to/app'):
            with c.prefix('workon myvenv'):
                c.run('./manage.py migrate')
                c.run('./manage.py loaddata fixture')

    resulting in::

        $ cd /path/to/app && workon myvenv && ./manage.py migrate
        $ cd /path/to/app && workon myvenv && ./manage.py loaddata fixture

    Nesting `prefix` itself also works::

        with c.prefix('workon myenv'):
            c.run('ls')
            with c.prefix('source /some/script'):
                c.run('touch a_file')

    yielding::

        $ workon myenv && ls
        $ workon myenv && source /some/script && touch a_file

    .. versionadded:: 1.0
    """
    # Push on entry, and guarantee the pop on exit even if the body raises.
    self.command_prefixes.append(command)
    try:
        yield
    finally:
        self.command_prefixes.pop()
@property
def cwd(self) -> str:
    """
    Return the current working directory, accounting for uses of `cd`.

    .. versionadded:: 1.0
    """
    if not self.command_cwds:
        # TODO: should this be None? Feels cleaner, though there may be
        # benefits to it being an empty string, such as relying on a no-arg
        # `cd` typically being shorthand for "go to user's $HOME".
        return ""
    # Only paths from the most recent absolute (or ~-based) entry onwards
    # matter; everything before it was superseded.
    start = 0
    for idx in range(len(self.command_cwds) - 1, -1, -1):
        if self.command_cwds[idx].startswith(("~", "/")):
            start = idx
            break
    # TODO: see if there's a stronger "escape this path" function somewhere
    # we can reuse. e.g., escaping tildes or slashes in filenames.
    escaped = [p.replace(" ", r"\ ") for p in self.command_cwds[start:]]
    return str(os.path.join(*escaped))
@contextmanager
def cd(self, path: Union[PathLike, str]) -> Generator[None, None, None]:
    """
    Context manager that keeps directory state when executing commands.

    Any calls to `run` or `sudo` within the wrapped block implicitly get a
    string similar to ``"cd <path> && "`` prefixed, giving the sense of
    actual statefulness. Since `cd` affects all such invocations, code
    using the `cwd` property is affected as well.

    Like the real 'cd' shell builtin, `cd` accepts relative paths (bear in
    mind that your default starting directory is your user's ``$HOME``)
    and may be nested.

    A "normal" attempt at shell ``cd`` fails because each command runs in
    its own subprocess -- no state persists between invocations of `run`
    or `sudo`::

        c.run('cd /var/www')
        c.run('ls')

    That lists the user's ``$HOME`` rather than ``/var/www``. With `cd` it
    works as expected::

        with c.cd('/var/www'):
            c.run('ls') # Turns into "cd /var/www && ls"

    Nesting demonstration (see inline comments)::

        with c.cd('/var/www'):
            c.run('ls') # cd /var/www && ls
            with c.cd('website1'):
                c.run('ls') # cd /var/www/website1 && ls

    .. note::
        Space characters will be escaped automatically to make dealing with
        such directory names easier.

    .. versionadded:: 1.0
    .. versionchanged:: 1.5
        Explicitly cast the ``path`` argument (the only argument) to a
        string; this allows any object defining ``__str__`` to be handed in
        (such as the various ``Path`` objects out there), and not just
        string literals.
    """
    # str() up front so Path-like objects work; pop is guaranteed on exit.
    self.command_cwds.append(str(path))
    try:
        yield
    finally:
        self.command_cwds.pop()
class MockContext(Context):
    """
    A `.Context` whose methods' return values can be predetermined.

    Primarily useful for testing Invoke-using codebases.

    .. note::
        This class wraps its ``run``, etc methods in `unittest.mock.Mock`
        objects. This allows you to easily assert that the methods (still
        returning the values you prepare them with) were actually called.

    .. note::
        Methods not given `Results <.Result>` to yield will raise
        ``NotImplementedError`` if called (since the alternative is to call the
        real underlying method - typically undesirable when mocking.)

    .. versionadded:: 1.0
    .. versionchanged:: 1.5
        Added ``Mock`` wrapping of ``run`` and ``sudo``.
    """

    def __init__(self, config: Optional[Config] = None, **kwargs: Any) -> None:
        """
        Create a ``Context``-like object whose methods yield `.Result` objects.

        :param config:
            A Configuration object to use. Identical in behavior to `.Context`.

        :param run:
            A data structure indicating what `.Result` objects to return from
            calls to the instantiated object's `~.Context.run` method (instead
            of actually executing the requested shell command).

            Specifically, this kwarg accepts:

            - A single `.Result` object.
            - A boolean; if True, yields a `.Result` whose ``exited`` is ``0``,
              and if False, ``1``.
            - An iterable of the above values, which will be returned on each
              subsequent call to ``.run`` (the first item on the first call,
              the second on the second call, etc).
            - A dict mapping command strings or compiled regexen to the above
              values (including an iterable), allowing specific
              call-and-response semantics instead of assuming a call order.

        :param sudo:
            Identical to ``run``, but whose values are yielded from calls to
            `~.Context.sudo`.

        :param bool repeat:
            A flag determining whether results yielded by this class' methods
            repeat or are consumed.

            For example, when a single result is indicated, it will normally
            only be returned once, causing ``NotImplementedError`` afterwards.
            But when ``repeat=True`` is given, that result is returned on
            every call, forever.

            Similarly, iterable results are normally exhausted once, but when
            this setting is enabled, they are wrapped in `itertools.cycle`.

            Default: ``True``.

        :raises:
            ``TypeError``, if the values given to ``run`` or other kwargs
            aren't of the expected types.

        .. versionchanged:: 1.5
            Added support for boolean and string result values.
        .. versionchanged:: 1.5
            Added support for regex dict keys.
        .. versionchanged:: 1.5
            Added the ``repeat`` keyword argument.
        .. versionchanged:: 2.0
            Changed ``repeat`` default value from ``False`` to ``True``.
        """
        # Set up like any other Context would, with the config
        super().__init__(config)
        # Pull out behavioral kwargs
        # NOTE: stored via _set under the literal string "__repeat"; using
        # string-based set/getattr sidesteps class-level name mangling.
        self._set("__repeat", kwargs.pop("repeat", True))
        # The rest must be things like run/sudo - mock Context method info
        for method, results in kwargs.items():
            # For each possible value type, normalize to iterable of Result
            # objects (possibly repeating).
            singletons = (Result, bool, str)
            if isinstance(results, dict):
                # Dict form: normalize each value, preserving the keys
                # (command strings or compiled regexen) for lookup later.
                for key, value in results.items():
                    results[key] = self._normalize(value)
            elif isinstance(results, singletons) or hasattr(
                results, "__iter__"
            ):
                results = self._normalize(results)
            # Unknown input value: cry
            else:
                err = "Not sure how to yield results from a {!r}"
                raise TypeError(err.format(type(results)))
            # Save results for use by the method
            self._set("__{}".format(method), results)
            # Wrap the method in a Mock
            self._set(method, Mock(wraps=getattr(self, method)))

    def _normalize(self, value: Any) -> Iterator[Any]:
        """
        Normalize a user-supplied results value into an iterator of Results.
        """
        # First turn everything into an iterable
        # (strings are iterable but must be treated as single values here)
        if not hasattr(value, "__iter__") or isinstance(value, str):
            value = [value]
        # Then turn everything within into a Result
        results = []
        for obj in value:
            if isinstance(obj, bool):
                # Booleans are shorthand for exit code 0 (True) or 1 (False).
                obj = Result(exited=0 if obj else 1)
            elif isinstance(obj, str):
                # Bare strings become that string as the Result's stdout.
                obj = Result(obj)
            results.append(obj)
        # Finally, turn that iterable into an iteratOR, depending on repeat
        return cycle(results) if getattr(self, "__repeat") else iter(results)

    # TODO: _maybe_ make this more metaprogrammy/flexible (using __call__ etc)?
    # Pretty worried it'd cause more hard-to-debug issues than it's presently
    # worth. Maybe in situations where Context grows a _lot_ of methods (e.g.
    # in Fabric 2; though Fabric could do its own sub-subclass in that case...)
    def _yield_result(self, attname: str, command: str) -> Result:
        """
        Produce the next stored `.Result` for ``attname`` given ``command``.

        Raises ``NotImplementedError`` when no (further) result is available.
        """
        try:
            obj = getattr(self, attname)
            # Dicts need to try direct lookup or regex matching
            if isinstance(obj, dict):
                try:
                    obj = obj[command]
                except KeyError:
                    # TODO: could optimize by skipping this if not any regex
                    # objects in keys()?
                    for key, value in obj.items():
                        if hasattr(key, "match") and key.match(command):
                            obj = value
                            break
                    else:
                        # Nope, nothing did match.
                        raise KeyError
            # Here, the value was either never a dict or has been extracted
            # from one, so we can assume it's an iterable of Result objects due
            # to work done by __init__.
            result: Result = next(obj)
            # Populate Result's command string with what matched unless
            # explicitly given
            if not result.command:
                result.command = command
            return result
        except (AttributeError, IndexError, KeyError, StopIteration):
            # raise_from(NotImplementedError(command), None)
            raise NotImplementedError(command)

    def run(self, command: str, *args: Any, **kwargs: Any) -> Result:
        """
        Return the next predetermined `.Result` for ``command``.
        """
        # TODO: perform more convenience stuff associating args/kwargs with the
        # result? E.g. filling in .command, etc? Possibly useful for debugging
        # if one hits unexpected-order problems with what they passed in to
        # __init__.
        return self._yield_result("__run", command)

    def sudo(self, command: str, *args: Any, **kwargs: Any) -> Result:
        """
        Return the next predetermined `.Result` for ``command``.
        """
        # TODO: this completely nukes the top-level behavior of sudo(), which
        # could be good or bad, depending. Most of the time I think it's good.
        # No need to supply dummy password config, etc.
        # TODO: see the TODO from run() re: injecting arg/kwarg values
        return self._yield_result("__sudo", command)

    def set_result_for(
        self, attname: str, command: str, result: Result
    ) -> None:
        """
        Modify the stored mock results for given ``attname`` (e.g. ``run``).

        This is similar to how one instantiates `MockContext` with a ``run`` or
        ``sudo`` dict kwarg. For example, this::

            mc = MockContext(run={'mycommand': Result("mystdout")})
            assert mc.run('mycommand').stdout == "mystdout"

        is functionally equivalent to this::

            mc = MockContext()
            mc.set_result_for('run', 'mycommand', Result("mystdout"))
            assert mc.run('mycommand').stdout == "mystdout"

        `set_result_for` is mostly useful for modifying an already-instantiated
        `MockContext`, such as one created by test setup or helper methods.

        .. versionadded:: 1.0
        """
        attname = "__{}".format(attname)
        heck = TypeError(
            "Can't update results for non-dict or nonexistent mock results!"
        )
        # Get value & complain if it's not a dict.
        # TODO: should we allow this to set non-dict values too? Seems vaguely
        # pointless, at that point, just make a new MockContext eh?
        try:
            value = getattr(self, attname)
        except AttributeError:
            raise heck
        if not isinstance(value, dict):
            raise heck
        # OK, we're good to modify, so do so.
        # (Normalized so lookup machinery sees the same shape __init__ makes.)
        value[command] = self._normalize(result)

View File

@@ -0,0 +1,123 @@
"""
Environment variable configuration loading class.
Using a class here doesn't really model anything but makes state passing (in a
situation requiring it) more convenient.
This module is currently considered private/an implementation detail and should
not be included in the Sphinx API documentation.
"""
import os
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Sequence
from .exceptions import UncastableEnvVar, AmbiguousEnvVar
from .util import debug
if TYPE_CHECKING:
from .config import Config
class Environment:
    """
    Loads configuration overrides out of shell environment variables.
    """

    def __init__(self, config: "Config", prefix: str) -> None:
        self._config = config
        self._prefix = prefix
        self.data: Dict[str, Any] = {}  # Accumulator

    def load(self) -> Dict[str, Any]:
        """
        Return a nested dict containing values from `os.environ`.

        Specifically, values whose keys map to already-known configuration
        settings, allowing us to perform basic typecasting.

        See :ref:`env-vars` for details.
        """
        # Figure out which env vars we would even accept, and where each
        # one lands in the nested config structure.
        env_vars = self._crawl(key_path=[], env_vars={})
        msg = "Scanning for env vars according to prefix: {!r}, mapping: {!r}"
        debug(msg.format(self._prefix, env_vars))
        # For each candidate, honor the prefix and set values that exist.
        for var_name, key_path in env_vars.items():
            prefixed = (self._prefix or "") + var_name
            if prefixed in os.environ:
                self._path_set(key_path, os.environ[prefixed])
        debug("Obtained env var config: {!r}".format(self.data))
        return self.data

    def _crawl(
        self, key_path: List[str], env_vars: Mapping[str, Sequence[str]]
    ) -> Dict[str, Any]:
        """
        Examine config at location ``key_path`` & return potential env vars.

        Uses ``env_vars`` dict to determine if a conflict exists, and raises an
        exception if so. This dict is of the following form::

            {
                'EXPECTED_ENV_VAR_HERE': ['actual', 'nested', 'key_path'],
                ...
            }

        Returns another dictionary of new keypairs as per above.
        """
        discovered: Dict[str, List[str]] = {}
        node = self._path_get(key_path)
        # Duck-typed "is this a mapping?" check: recurse into dict-likes,
        # treat everything else as a leaf.
        if (
            hasattr(node, "keys")
            and callable(node.keys)
            and hasattr(node, "__getitem__")
        ):
            for child in node.keys():
                crawled = self._crawl(
                    key_path + [child], dict(env_vars, **discovered)
                )
                # Two different key paths mapping to one env var is ambiguous.
                for var_name in crawled:
                    if var_name in discovered:
                        err = "Found >1 source for {}"
                        raise AmbiguousEnvVar(err.format(var_name))
                discovered.update(crawled)
        else:
            discovered[self._to_env_var(key_path)] = key_path
        return discovered

    def _to_env_var(self, key_path: Iterable[str]) -> str:
        # e.g. ['outer', 'inner'] -> 'OUTER_INNER'
        return "_".join(key_path).upper()

    def _path_get(self, key_path: Iterable[str]) -> "Config":
        # Reads go against self._config, since that's what determines which
        # env vars are valid and/or what to typecast their values to.
        node = self._config
        for segment in key_path:
            node = node[segment]
        return node

    def _path_set(self, key_path: Sequence[str], value: str) -> None:
        # Writes go into self.data, the dict presented to the outer config
        # object (and used for debugging).
        node = self.data
        for segment in key_path[:-1]:
            node = node.setdefault(segment, {})
        existing = self._path_get(key_path)
        node[key_path[-1]] = self._cast(existing, value)

    def _cast(self, old: Any, new: Any) -> Any:
        # NOTE: bool is checked before the generic fallback on purpose -
        # bool("0") would otherwise be truthy via the __class__ branch.
        if isinstance(old, bool):
            return new not in ("0", "")
        if isinstance(old, str) or old is None:
            return new
        if isinstance(old, (list, tuple)):
            err = "Can't adapt an environment string into a {}!"
            raise UncastableEnvVar(err.format(type(old)))
        return old.__class__(new)

View File

@@ -0,0 +1,425 @@
"""
Custom exception classes.
These vary in use case from "we needed a specific data structure layout in
exceptions used for message-passing" to simply "we needed to express an error
condition in a way easily told apart from other, truly unexpected errors".
"""
from pprint import pformat
from traceback import format_exception
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
if TYPE_CHECKING:
from .parser import ParserContext
from .runners import Result
from .util import ExceptionWrapper
class CollectionNotFound(Exception):
    """
    Raised when a task collection could not be located by name.
    """

    def __init__(self, name: str, start: str) -> None:
        # Preserve the lookup parameters so callers can report what was
        # searched for and from where.
        self.name = name
        self.start = start
class Failure(Exception):
    """
    Exception subclass representing failure of a command execution.

    "Failure" may mean the command executed and the shell indicated an unusual
    result (usually, a non-zero exit code), or it may mean something else, like
    a ``sudo`` command which was aborted when the supplied password failed
    authentication.

    Two attributes allow introspection to determine the nature of the problem:

    * ``result``: a `.Result` instance with info about the command being
      executed and, if it ran to completion, how it exited.
    * ``reason``: a wrapped exception instance if applicable (e.g. a
      `.StreamWatcher` raised `WatcherError`) or ``None`` otherwise, in which
      case, it's probably a `Failure` subclass indicating its own specific
      nature, such as `UnexpectedExit` or `CommandTimedOut`.

    This class is only rarely raised by itself; most of the time `.Runner.run`
    (or a wrapper of same, such as `.Context.sudo`) will raise a specific
    subclass like `UnexpectedExit` or `AuthFailure`.

    .. versionadded:: 1.0
    """

    def __init__(
        self, result: "Result", reason: Optional["WatcherError"] = None
    ) -> None:
        self.result = result
        self.reason = reason

    def streams_for_display(self) -> Tuple[str, str]:
        """
        Return stdout/err streams as necessary for error display.

        Subject to the following rules:

        - If a given stream was *not* hidden during execution, a placeholder is
          used instead, to avoid printing it twice.
        - Only the last 10 lines of stream text is included.
        - PTY-driven execution will lack stderr, and a specific message to this
          effect is returned instead of a stderr dump.

        :returns: Two-tuple of stdout, stderr strings.

        .. versionadded:: 1.3
        """
        placeholder = " already printed"
        if "stdout" in self.result.hide:
            stdout = self.result.tail("stdout")
        else:
            stdout = placeholder
        if self.result.pty:
            stderr = " n/a (PTYs have no stderr)"
        elif "stderr" in self.result.hide:
            stderr = self.result.tail("stderr")
        else:
            stderr = placeholder
        return stdout, stderr

    def __repr__(self) -> str:
        return self._repr()

    def _repr(self, **kwargs: Any) -> str:
        """
        Return ``__repr__``-like value from inner result + any kwargs.
        """
        # TODO: expand?
        # TODO: truncate command?
        extra = ""
        if kwargs:
            pairs = ("{}={}".format(k, v) for k, v in kwargs.items())
            extra = " " + " ".join(pairs)
        return "<{}: cmd={!r}{}>".format(
            self.__class__.__name__, self.result.command, extra
        )
class UnexpectedExit(Failure):
    """
    A shell command ran to completion but exited with an unexpected exit code.

    Its string representation displays the following:

    - Command executed;
    - Exit code;
    - The last 10 lines of stdout, if it was hidden;
    - The last 10 lines of stderr, if it was hidden and non-empty (e.g.
      pty=False; when pty=True, stderr never happens.)

    .. versionadded:: 1.0
    """

    def __str__(self) -> str:
        stdout, stderr = self.streams_for_display()
        command = self.result.command
        exited = self.result.exited
        # NOTE: template content sits at module column so the rendered
        # message isn't indented.
        template = """Encountered a bad command exit code!

Command: {!r}

Exit code: {}

Stdout:{}

Stderr:{}

"""
        return template.format(command, exited, stdout, stderr)

    def _repr(self, **kwargs: Any) -> str:
        # Augment the base repr with the exit code.
        kwargs.setdefault("exited", self.result.exited)
        return super()._repr(**kwargs)
class CommandTimedOut(Failure):
    """
    Raised when a subprocess did not exit within a desired timeframe.
    """

    def __init__(self, result: "Result", timeout: int) -> None:
        super().__init__(result)
        # Number of seconds the command was given before being considered
        # timed out (used in __str__/__repr__ below).
        self.timeout = timeout

    def __repr__(self) -> str:
        return self._repr(timeout=self.timeout)

    def __str__(self) -> str:
        stdout, stderr = self.streams_for_display()
        command = self.result.command
        template = """Command did not complete within {} seconds!

Command: {!r}

Stdout:{}

Stderr:{}

"""
        return template.format(self.timeout, command, stdout, stderr)
class AuthFailure(Failure):
    """
    An authentication failure, e.g. due to an incorrect ``sudo`` password.

    .. note::
        `.Result` objects attached to these exceptions typically lack exit code
        information, since the command was never fully executed - the exception
        was raised instead.

    .. versionadded:: 1.0
    """

    def __init__(self, result: "Result", prompt: str) -> None:
        # NOTE: does not call Failure.__init__ - only result and the prompt
        # text are stored here.
        self.result = result
        self.prompt = prompt

    def __str__(self) -> str:
        template = "The password submitted to prompt {!r} was rejected."
        return template.format(self.prompt)
class ParseError(Exception):
    """
    An error arising from the parsing of command-line flags/arguments.

    Ambiguous input, invalid task names, invalid flags, etc.

    .. versionadded:: 1.0
    """

    def __init__(
        self, msg: str, context: Optional["ParserContext"] = None
    ) -> None:
        super().__init__(msg)
        # The parser context active when the error arose, if any.
        self.context = context
class Exit(Exception):
    """
    Simple custom stand-in for SystemExit.

    Replaces scattered sys.exit calls, improves testability, allows one to
    catch an exit request without intercepting real SystemExits (typically an
    unfriendly thing to do, as most users calling `sys.exit` rather expect it
    to truly exit.)

    Defaults to a non-printing, exit-0 friendly termination behavior if the
    exception is uncaught.

    If ``code`` (an int) given, that code is used to exit.

    If ``message`` (a string) given, it is printed to standard error, and the
    program exits with code ``1`` by default (unless overridden by also giving
    ``code`` explicitly.)

    .. versionadded:: 1.0
    """

    def __init__(
        self, message: Optional[str] = None, code: Optional[int] = None
    ) -> None:
        self.message = message
        self._code = code

    @property
    def code(self) -> int:
        # An explicitly-given code always wins; otherwise the presence of a
        # message implies an error exit (1), its absence a clean one (0).
        if self._code is None:
            return 1 if self.message else 0
        return self._code
class PlatformError(Exception):
    """
    Raised when an operation is illegal on the current platform.

    E.g. Windows users trying to use functionality requiring the ``pty``
    module. Typically used to present a clearer error message to the user.

    .. versionadded:: 1.0
    """
class AmbiguousEnvVar(Exception):
    """
    Raised when loading env var config keys has an ambiguous target.

    .. versionadded:: 1.0
    """
class UncastableEnvVar(Exception):
    """
    Raised on attempted env var loads whose default values are too rich.

    E.g. trying to stuff ``MY_VAR="foo"`` into ``{'my_var': ['uh', 'oh']}``
    doesn't make any sense until/if we implement some sort of transform option.

    .. versionadded:: 1.0
    """
class UnknownFileType(Exception):
    """
    A config file of an unknown type was specified and cannot be loaded.

    .. versionadded:: 1.0
    """
class UnpicklableConfigMember(Exception):
    """
    A config file contained module objects, which can't be pickled/copied.

    We raise this more easily catchable exception instead of letting the
    (unclearly phrased) TypeError bubble out of the pickle module. (However, to
    avoid our own fragile catching of that error, we head it off by explicitly
    testing for module members.)

    .. versionadded:: 1.0.2
    """
def _printable_kwargs(kwargs: Any) -> Dict[str, Any]:
    """
    Return print-friendly version of a thread-related ``kwargs`` dict.

    Extra care is taken with ``args`` members which are very long iterables -
    those need truncating to be useful. Non-``args`` members pass through
    unchanged.
    """
    printable = {}
    for key, value in kwargs.items():
        item = value
        if key == "args":
            item = []
            for arg in value:
                new_arg = arg
                if hasattr(arg, "__len__") and len(arg) > 10:
                    msg = "<... remainder truncated during error display ...>"
                    # BUG FIX: the old `arg[:10] + [msg]` raised TypeError
                    # for sized non-list values (str + list, tuple + list);
                    # keep strings as strings and normalize other sequences
                    # to lists before appending the truncation notice.
                    if isinstance(arg, str):
                        new_arg = arg[:10] + msg
                    else:
                        new_arg = list(arg[:10]) + [msg]
                item.append(new_arg)
        printable[key] = item
    return printable
class ThreadException(Exception):
    """
    One or more exceptions were raised within background threads.

    The real underlying exceptions are stored in the `exceptions` attribute;
    see its documentation for data structure details.

    .. note::
        Threads which did not encounter an exception, do not contribute to this
        exception object and thus are not present inside `exceptions`.

    .. versionadded:: 1.0
    """

    #: A tuple of `ExceptionWrappers <invoke.util.ExceptionWrapper>` containing
    #: the initial thread constructor kwargs (because `threading.Thread`
    #: subclasses should always be called with kwargs) and the caught exception
    #: for that thread as seen by `sys.exc_info` (so: type, value, traceback).
    #:
    #: .. note::
    #:     The ordering of this attribute is not well-defined.
    #:
    #: .. note::
    #:     Thread kwargs which appear to be very long (e.g. IO
    #:     buffers) will be truncated when printed, to avoid huge
    #:     unreadable error display.
    exceptions: Tuple["ExceptionWrapper", ...] = tuple()

    def __init__(self, exceptions: List["ExceptionWrapper"]) -> None:
        self.exceptions = tuple(exceptions)

    def __str__(self) -> str:
        details = []
        for x in self.exceptions:
            # Build useful display
            detail = "Thread args: {}\n\n{}"
            details.append(
                detail.format(
                    # Truncate overly-long kwargs before pretty-printing.
                    pformat(_printable_kwargs(x.kwargs)),
                    "\n".join(format_exception(x.type, x.value, x.traceback)),
                )
            )
        args = (
            len(self.exceptions),
            ", ".join(x.type.__name__ for x in self.exceptions),
            "\n\n".join(details),
        )
        # Template at module column so the rendered message isn't indented.
        return """
Saw {} exceptions within threads ({}):

{}
""".format(
            *args
        )
class WatcherError(Exception):
    """
    Generic parent exception class for `.StreamWatcher`-related errors.

    One of these usually means a `.StreamWatcher` spotted something anomalous
    in an output stream, such as an authentication response failure.

    `.Runner` catches these and attaches them to `.Failure` exceptions so they
    can be referenced by intermediate code and/or act as extra info for end
    users.

    .. versionadded:: 1.0
    """
class ResponseNotAccepted(WatcherError):
    """
    A responder/watcher class noticed a 'bad' response to its submission.

    Mostly used by `.FailingResponder` and subclasses, e.g. "oh dear I
    autosubmitted a sudo password and it was incorrect."

    .. versionadded:: 1.0
    """
class SubprocessPipeError(Exception):
    """
    Some problem was encountered handling subprocess pipes (stdout/err/in).

    Typically only for corner cases; most of the time, errors in this area are
    raised by the interpreter or the operating system, and end up wrapped in a
    `.ThreadException`.

    .. versionadded:: 1.3
    """

View File

@@ -0,0 +1,229 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from .config import Config
from .parser import ParserContext
from .util import debug
from .tasks import Call, Task
if TYPE_CHECKING:
from .collection import Collection
from .runners import Result
from .parser import ParseResult
class Executor:
"""
An execution strategy for Task objects.
Subclasses may override various extension points to change, add or remove
behavior.
.. versionadded:: 1.0
"""
def __init__(
    self,
    collection: "Collection",
    config: Optional["Config"] = None,
    core: Optional["ParseResult"] = None,
) -> None:
    """
    Initialize executor with handles to necessary data structures.

    :param collection:
        A `.Collection` used to look up requested tasks (and their default
        config data, if any) by name during execution.

    :param config:
        An optional `.Config` holding configuration state. Defaults to an
        empty `.Config` if not given.

    :param core:
        An optional `.ParseResult` holding parsed core program arguments.
        Defaults to ``None``.
    """
    self.collection = collection
    # Fall back to a default, empty Config when none was supplied.
    if config is None:
        config = Config()
    self.config = config
    self.core = core
def execute(
self, *tasks: Union[str, Tuple[str, Dict[str, Any]], ParserContext]
) -> Dict["Task", "Result"]:
"""
Execute one or more ``tasks`` in sequence.
:param tasks:
An all-purpose iterable of "tasks to execute", each member of which
may take one of the following forms:
**A string** naming a task from the Executor's `.Collection`. This
name may contain dotted syntax appropriate for calling namespaced
tasks, e.g. ``subcollection.taskname``. Such tasks are executed
without arguments.
**A two-tuple** whose first element is a task name string (as
above) and whose second element is a dict suitable for use as
``**kwargs`` when calling the named task. E.g.::
[
('task1', {}),
('task2', {'arg1': 'val1'}),
...
]
is equivalent, roughly, to::
task1()
task2(arg1='val1')
**A `.ParserContext`** instance, whose ``.name`` attribute is used
as the task name and whose ``.as_kwargs`` attribute is used as the
task kwargs (again following the above specifications).
.. note::
When called without any arguments at all (i.e. when ``*tasks``
is empty), the default task from ``self.collection`` is used
instead, if defined.
:returns:
A dict mapping task objects to their return values.
This dict may include pre- and post-tasks if any were executed. For
example, in a collection with a ``build`` task depending on another
task named ``setup``, executing ``build`` will result in a dict
with two keys, one for ``build`` and one for ``setup``.
.. versionadded:: 1.0
"""
# Normalize input
debug("Examining top level tasks {!r}".format([x for x in tasks]))
calls = self.normalize(tasks)
debug("Tasks (now Calls) with kwargs: {!r}".format(calls))
# Obtain copy of directly-given tasks since they should sometimes
# behave differently
direct = list(calls)
# Expand pre/post tasks
# TODO: may make sense to bundle expansion & deduping now eh?
expanded = self.expand_calls(calls)
# Get some good value for dedupe option, even if config doesn't have
# the tree we expect. (This is a concession to testing.)
try:
dedupe = self.config.tasks.dedupe
except AttributeError:
dedupe = True
# Dedupe across entire run now that we know about all calls in order
calls = self.dedupe(expanded) if dedupe else expanded
# Execute
results = {}
# TODO: maybe clone initial config here? Probably not necessary,
# especially given Executor is not designed to execute() >1 time at the
# moment...
for call in calls:
autoprint = call in direct and call.autoprint
debug("Executing {!r}".format(call))
# Hand in reference to our config, which will preserve user
# modifications across the lifetime of the session.
config = self.config
# But make sure we reset its task-sensitive levels each time
# (collection & shell env)
# TODO: load_collection needs to be skipped if task is anonymous
# (Fabric 2 or other subclassing libs only)
collection_config = self.collection.configuration(call.called_as)
config.load_collection(collection_config)
config.load_shell_env()
debug("Finished loading collection & shell env configs")
# Get final context from the Call (which will know how to generate
# an appropriate one; e.g. subclasses might use extra data from
# being parameterized), handing in this config for use there.
context = call.make_context(config)
args = (context, *call.args)
result = call.task(*args, **call.kwargs)
if autoprint:
print(result)
# TODO: handle the non-dedupe case / the same-task-different-args
# case, wherein one task obj maps to >1 result.
results[call.task] = result
return results
def normalize(
self,
tasks: Tuple[
Union[str, Tuple[str, Dict[str, Any]], ParserContext], ...
],
) -> List["Call"]:
"""
Transform arbitrary task list w/ various types, into `.Call` objects.
See docstring for `~.Executor.execute` for details.
.. versionadded:: 1.0
"""
calls = []
for task in tasks:
name: Optional[str]
if isinstance(task, str):
name = task
kwargs = {}
elif isinstance(task, ParserContext):
name = task.name
kwargs = task.as_kwargs
else:
name, kwargs = task
c = Call(self.collection[name], kwargs=kwargs, called_as=name)
calls.append(c)
if not tasks and self.collection.default is not None:
calls = [Call(self.collection[self.collection.default])]
return calls
def dedupe(self, calls: List["Call"]) -> List["Call"]:
"""
Deduplicate a list of `tasks <.Call>`.
:param calls: An iterable of `.Call` objects representing tasks.
:returns: A list of `.Call` objects.
.. versionadded:: 1.0
"""
deduped = []
debug("Deduplicating tasks...")
for call in calls:
if call not in deduped:
debug("{!r}: no duplicates found, ok".format(call))
deduped.append(call)
else:
debug("{!r}: found in list already, skipping".format(call))
return deduped
def expand_calls(self, calls: List["Call"]) -> List["Call"]:
"""
Expand a list of `.Call` objects into a near-final list of same.
The default implementation of this method simply adds a task's
pre/post-task list before/after the task itself, as necessary.
Subclasses may wish to do other things in addition (or instead of) the
above, such as multiplying the `calls <.Call>` by argument vectors or
similar.
.. versionadded:: 1.0
"""
ret = []
for call in calls:
# Normalize to Call (this method is sometimes called with pre/post
# task lists, which may contain 'raw' Task objects)
if isinstance(call, Task):
call = Call(call)
debug("Expanding task-call {!r}".format(call))
# TODO: this is where we _used_ to call Executor.config_for(call,
# config)...
# TODO: now we may need to preserve more info like where the call
# came from, etc, but I feel like that shit should go _on the call
# itself_ right???
# TODO: we _probably_ don't even want the config in here anymore,
# we want this to _just_ be about the recursion across pre/post
# tasks or parameterization...?
ret.extend(self.expand_calls(call.pre))
ret.append(call)
ret.extend(self.expand_calls(call.post))
return ret

View File

@@ -0,0 +1,154 @@
import os
import sys
from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from types import ModuleType
from typing import Any, Optional, Tuple
from . import Config
from .exceptions import CollectionNotFound
from .util import debug
class Loader:
    """
    Abstract class defining how to find/import a session's base `.Collection`.
    .. versionadded:: 1.0
    """
    def __init__(self, config: Optional["Config"] = None) -> None:
        """
        Set up a new loader with some `.Config`.
        :param config:
            An explicit `.Config` to use; it is referenced for loading-related
            config options. Defaults to an anonymous ``Config()`` if none is
            given.
        """
        if config is None:
            config = Config()
        self.config = config
    def find(self, name: str) -> Optional[ModuleSpec]:
        """
        Implementation-specific finder method seeking collection ``name``.
        Must return a ModuleSpec valid for use by `importlib`, which is
        typically a name string followed by the contents of the 3-tuple
        returned by `importlib.module_from_spec` (``name``, ``loader``,
        ``origin``.)
        For a sample implementation, see `.FilesystemLoader`.
        .. versionadded:: 1.0
        """
        raise NotImplementedError
    def load(self, name: Optional[str] = None) -> Tuple[ModuleType, str]:
        """
        Load and return collection module identified by ``name``.
        This method requires a working implementation of `.find` in order to
        function.
        In addition to importing the named module, it will add the module's
        parent directory to the front of `sys.path` to provide normal Python
        import behavior (i.e. so the loaded module may load local-to-it modules
        or packages.)
        :returns:
            Two-tuple of ``(module, directory)`` where ``module`` is the
            collection-containing Python module object, and ``directory`` is
            the string path to the directory the module was found in.
        :raises ImportError: if `.find` returned no usable spec.
        .. versionadded:: 1.0
        """
        if name is None:
            # Name defaults to the configured collection name (e.g. "tasks").
            name = self.config.tasks.collection_name
        spec = self.find(name)
        if spec and spec.loader and spec.origin:
            # Typically either tasks.py or tasks/__init__.py
            source_file = Path(spec.origin)
            # Will be 'the dir tasks.py is in', or 'tasks/', in both cases this
            # is what wants to be in sys.path for "from . import sibling"
            enclosing_dir = source_file.parent
            # Will be "the directory above the spot that 'import tasks' found",
            # namely the parent of "your task tree", i.e. "where project level
            # config files are looked for". So, same as enclosing_dir for
            # tasks.py, but one more level up for tasks/__init__.py...
            module_parent = enclosing_dir
            if spec.parent:  # it's a package, so we have to go up again
                module_parent = module_parent.parent
            # Get the enclosing dir on the path
            enclosing_str = str(enclosing_dir)
            if enclosing_str not in sys.path:
                sys.path.insert(0, enclosing_str)
            # Actual import
            module = module_from_spec(spec)
            sys.modules[spec.name] = module  # so 'from . import xxx' works
            spec.loader.exec_module(module)
            # Return the module and the folder it was found in
            return module, str(module_parent)
        msg = "ImportError loading {!r}, raising ImportError"
        debug(msg.format(name))
        # NOTE(review): raised without an argument, so callers catching it see
        # a bare ImportError with no message — confirm this is intended.
        raise ImportError
class FilesystemLoader(Loader):
    """
    Loads Python files from the filesystem (e.g. ``tasks.py``.)
    Searches recursively towards filesystem root from a given start point.
    .. versionadded:: 1.0
    """
    # TODO: could introduce config obj here for transmission to Collection
    # TODO: otherwise Loader has to know about specific bits to transmit, such
    # as auto-dashes, and has to grow one of those for every bit Collection
    # ever needs to know
    def __init__(self, start: Optional[str] = None, **kwargs: Any) -> None:
        """
        Create a loader rooted at ``start`` (or the configured search root).

        :param start:
            Directory to begin the upward search from. When ``None``, falls
            back to ``config.tasks.search_root`` (which may itself be falsey,
            deferring to the process CWD — see `.start`).
        """
        super().__init__(**kwargs)
        if start is None:
            start = self.config.tasks.search_root
        self._start = start
    @property
    def start(self) -> str:
        # Lazily determine default CWD if configured value is falsey
        return self._start or os.getcwd()
    def find(self, name: str) -> Optional[ModuleSpec]:
        """
        Seek ``<name>.py`` (or a ``<name>/__init__.py`` package), starting at
        `.start` and walking one directory up at a time towards the
        filesystem root; return a `ModuleSpec` for the first hit, or ``None``.

        :raises CollectionNotFound:
            when the walk exhausts all candidate directories without a match.
        """
        debug("FilesystemLoader find starting at {!r}".format(self.start))
        spec = None
        module = "{}.py".format(name)
        paths = self.start.split(os.sep)
        try:
            # walk the path upwards to check for dynamic import
            for x in reversed(range(len(paths) + 1)):
                path = os.sep.join(paths[0:x])
                # Plain module file, e.g. tasks.py
                if module in os.listdir(path):
                    spec = spec_from_file_location(
                        name, os.path.join(path, module)
                    )
                    break
                # Package directory, e.g. tasks/__init__.py
                elif name in os.listdir(path) and os.path.exists(
                    os.path.join(path, name, "__init__.py")
                ):
                    basepath = os.path.join(path, name)
                    spec = spec_from_file_location(
                        name,
                        os.path.join(basepath, "__init__.py"),
                        submodule_search_locations=[basepath],
                    )
                    break
            if spec:
                debug("Found module: {!r}".format(spec))
                return spec
        except (FileNotFoundError, ModuleNotFoundError):
            # Reaching an unlistable path (including the empty prefix past the
            # root) terminates the walk: translate into our own exception.
            msg = "ImportError loading {!r}, raising CollectionNotFound"
            debug(msg.format(name))
            raise CollectionNotFound(name=name, start=self.start)
        return None

View File

@@ -0,0 +1,14 @@
"""
Invoke's own 'binary' entrypoint.
Dogfoods the `program` module.
"""
from . import __version__, Program
program = Program(
name="Invoke",
binary="inv[oke]",
binary_names=["invoke", "inv"],
version=__version__,
)

View File

@@ -0,0 +1,5 @@
# flake8: noqa
from .parser import *
from .context import ParserContext
from .context import ParserContext as Context, to_flag, translate_underscores
from .argument import Argument

View File

@@ -0,0 +1,178 @@
from typing import Any, Iterable, Optional, Tuple
# TODO: dynamic type for kind
# T = TypeVar('T')
class Argument:
    """
    A command-line argument/flag.
    :param name:
        Syntactic sugar for ``names=[<name>]``. Giving both ``name`` and
        ``names`` is invalid.
    :param names:
        List of valid identifiers for this argument. For example, a "help"
        argument may be defined with a name list of ``['-h', '--help']``.
    :param kind:
        Type factory & parser hint. E.g. ``int`` will turn the default text
        value parsed, into a Python integer; and ``bool`` will tell the
        parser not to expect an actual value but to treat the argument as a
        toggle/flag.
    :param default:
        Default value made available to the parser if no value is given on the
        command line.
    :param help:
        Help text, intended for use with ``--help``.
    :param positional:
        Whether or not this argument's value may be given positionally. When
        ``False`` (default) arguments must be explicitly named.
    :param optional:
        Whether or not this (non-``bool``) argument requires a value.
    :param incrementable:
        Whether or not this (``int``) argument is to be incremented instead of
        overwritten/assigned to.
    :param attr_name:
        A Python identifier/attribute friendly name, typically filled in with
        the underscored version when ``name``/``names`` contain dashes.
    .. versionadded:: 1.0
    """
    def __init__(
        self,
        name: Optional[str] = None,
        names: Iterable[str] = (),
        kind: Any = str,
        default: Optional[Any] = None,
        help: Optional[str] = None,
        positional: bool = False,
        optional: bool = False,
        incrementable: bool = False,
        attr_name: Optional[str] = None,
    ) -> None:
        # Exactly one of name/names must be supplied.
        if name and names:
            raise TypeError(
                "Cannot give both 'name' and 'names' arguments! Pick one."
            )
        if not (name or names):
            raise TypeError("An Argument must have at least one name.")
        self.names = tuple(names) if names else (name,)
        self.kind = kind
        # Choose the pre-parse starting value: incrementable args begin at
        # their default, list-kind args begin as an empty list, everything
        # else begins unset (None).
        if incrementable:
            seed: Optional[Any] = default
        elif kind is list:
            seed = []
        else:
            seed = None
        self.raw_value = self._value = seed
        self.default = default
        self.help = help
        self.positional = positional
        self.optional = optional
        self.incrementable = incrementable
        self.attr_name = attr_name
    def __repr__(self) -> str:
        # Secondary names, if any, shown parenthesized after the main name.
        nicks = (
            " ({})".format(", ".join(self.nicknames)) if self.nicknames else ""
        )
        # Non-str kinds are called out in brackets.
        # TODO: store this default value somewhere other than signature of
        # Argument.__init__?
        kind = "" if self.kind == str else " [{}]".format(self.kind.__name__)
        # Positional/optional markers: '*' and '?' respectively.
        flags = ""
        if self.positional or self.optional:
            flags = " "
            if self.positional:
                flags += "*"
            if self.optional:
                flags += "?"
        return "<{}: {}{}{}{}>".format(
            self.__class__.__name__, self.name, nicks, kind, flags
        )
    @property
    def name(self) -> Optional[str]:
        """
        The canonical attribute-friendly name for this argument.
        Will be ``attr_name`` (if given to constructor) or the first name in
        ``names`` otherwise.
        .. versionadded:: 1.0
        """
        return self.attr_name or self.names[0]
    @property
    def nicknames(self) -> Tuple[str, ...]:
        # All names besides the primary (first) one.
        return self.names[1:]
    @property
    def takes_value(self) -> bool:
        # Bools are toggles and incrementables are counters; neither consumes
        # a value token from the command line.
        return not (self.kind is bool or self.incrementable)
    @property
    def value(self) -> Any:
        # TODO: should probably be optional instead
        return self.default if self._value is None else self._value
    @value.setter
    def value(self, arg: str) -> None:
        self.set_value(arg, cast=True)
    def set_value(self, value: Any, cast: bool = True) -> None:
        """
        Actual explicit value-setting API call.
        Sets ``self.raw_value`` to ``value`` directly.
        Sets ``self.value`` to ``self.kind(value)``, unless:
        - ``cast=False``, in which case the raw value is also used.
        - ``self.kind==list``, in which case the value is appended to
          ``self.value`` instead of cast & overwritten.
        - ``self.incrementable==True``, in which case the value is ignored and
          the current (assumed int) value is simply incremented.
        .. versionadded:: 1.0
        """
        self.raw_value = value
        if self.incrementable:
            # TODO: explode nicely if self.value was not an int to start with
            self._value = self.value + 1
        elif self.kind is list:
            # Append rather than overwrite (regardless of 'cast').
            self._value = self.value + [value]
        elif cast:
            self._value = self.kind(value)
        else:
            self._value = value
    @property
    def got_value(self) -> bool:
        """
        Returns whether the argument was ever given a (non-default) value.
        For most argument kinds, this simply checks whether the internally
        stored value is non-``None``; for others, such as ``list`` kinds,
        different checks may be used.
        .. versionadded:: 1.3
        """
        if self.kind is list:
            return bool(self._value)
        return self._value is not None

View File

@@ -0,0 +1,266 @@
import itertools
from typing import Any, Dict, List, Iterable, Optional, Tuple, Union
try:
from ..vendor.lexicon import Lexicon
except ImportError:
from lexicon import Lexicon # type: ignore[no-redef]
from .argument import Argument
def translate_underscores(name: str) -> str:
    """Translate identifier ``name`` to CLI style: trim edge underscores, map inner ones to dashes."""
    return name.strip("_").replace("_", "-")
def to_flag(name: str) -> str:
    """Return the CLI flag form of ``name``: ``-x`` for single characters, ``--name`` otherwise."""
    flag = translate_underscores(name)
    prefix = "-" if len(flag) == 1 else "--"
    return prefix + flag
def sort_candidate(arg: "Argument") -> str:
    """Return the name of ``arg`` used for sort comparisons: best short name if any exist, else best long name."""
    shorts = set()
    longs = set()
    # TODO: is there no "split into two buckets on predicate" builtin?
    for candidate in arg.names:
        bucket = shorts if len(candidate.strip("-")) == 1 else longs
        bucket.add(candidate)
    pool = shorts if shorts else longs
    return str(sorted(pool)[0])
def flag_key(arg: "Argument") -> List[Union[int, str]]:
    """
    Obtain useful key list-of-ints for sorting CLI flags.
    .. versionadded:: 1.0
    """
    candidate = sort_candidate(arg)
    # Long-style flags sort ahead of short-style ones: single-character
    # candidates rank 1 (later), everything else ranks 0 (earlier).
    length_rank = 1 if len(candidate) == 1 else 0
    # Next, compare the strings case-insensitively (plain alphabetical).
    folded = candidate.lower()
    # Final tiebreaker: case-swapped text so lowercase letters sort before
    # their uppercase counterparts.
    swapped = "".join(
        ch.lower() if ch.isupper() else ch.upper() for ch in candidate
    )
    return [length_rank, folded, swapped]
# Named slightly more verbose so Sphinx references can be unambiguous.
# Got real sick of fully qualified paths.
class ParserContext:
    """
    Parsing context with knowledge of flags & their format.
    Generally associated with the core program or a task.
    When run through a parser, will also hold runtime values filled in by the
    parser.
    .. versionadded:: 1.0
    """
    def __init__(
        self,
        name: Optional[str] = None,
        aliases: Iterable[str] = (),
        args: Iterable[Argument] = (),
    ) -> None:
        """
        Create a new ``ParserContext`` named ``name``, with ``aliases``.
        ``name`` is optional, and should be a string if given. It's used to
        tell ParserContext objects apart, and for use in a Parser when
        determining what chunk of input might belong to a given ParserContext.
        ``aliases`` is also optional and should be an iterable containing
        strings. Parsing will honor any aliases when trying to "find" a given
        context in its input.
        May give one or more ``args``, which is a quick alternative to calling
        ``for arg in args: self.add_arg(arg)`` after initialization.
        """
        self.args = Lexicon()
        self.positional_args: List[Argument] = []
        self.flags = Lexicon()
        self.inverse_flags: Dict[str, str] = {}  # No need for Lexicon here
        self.name = name
        self.aliases = aliases
        for arg in args:
            self.add_arg(arg)
    def __repr__(self) -> str:
        aliases = ""
        if self.aliases:
            aliases = " ({})".format(", ".join(self.aliases))
        name = (" {!r}{}".format(self.name, aliases)) if self.name else ""
        args = (": {!r}".format(self.args)) if self.args else ""
        return "<parser/Context{}{}>".format(name, args)
    def add_arg(self, *args: Any, **kwargs: Any) -> None:
        """
        Adds given ``Argument`` (or constructor args for one) to this context.
        The Argument in question is added to the following dict attributes:
        * ``args``: "normal" access, i.e. the given names are directly exposed
          as keys.
        * ``flags``: "flaglike" access, i.e. the given names are translated
          into CLI flags, e.g. ``"foo"`` is accessible via ``flags['--foo']``.
        * ``inverse_flags``: similar to ``flags`` but containing only the
          "inverse" versions of boolean flags which default to True. This
          allows the parser to track e.g. ``--no-myflag`` and turn it into a
          False value for the ``myflag`` Argument.
        .. versionadded:: 1.0
        """
        # Normalize
        if len(args) == 1 and isinstance(args[0], Argument):
            arg = args[0]
        else:
            arg = Argument(*args, **kwargs)
        # Uniqueness constraint: no name collisions
        for name in arg.names:
            if name in self.args:
                msg = "Tried to add an argument named {!r} but one already exists!"  # noqa
                raise ValueError(msg.format(name))
        # First name used as "main" name for purposes of aliasing
        main = arg.names[0]  # NOT arg.name
        self.args[main] = arg
        # Note positionals in distinct, ordered list attribute
        if arg.positional:
            self.positional_args.append(arg)
        # Add names & nicknames to flags, args
        self.flags[to_flag(main)] = arg
        for name in arg.nicknames:
            self.args.alias(name, to=main)
            self.flags.alias(to_flag(name), to=to_flag(main))
        # Add attr_name to args, but not flags
        if arg.attr_name:
            self.args.alias(arg.attr_name, to=main)
        # Add to inverse_flags if required
        if arg.kind == bool and arg.default is True:
            # Invert the 'main' flag name here, which will be a dashed version
            # of the primary argument name if underscore-to-dash transformation
            # occurred.
            inverse_name = to_flag("no-{}".format(main))
            self.inverse_flags[inverse_name] = to_flag(main)
    @property
    def missing_positional_args(self) -> List[Argument]:
        # Positional args which have not yet been given a runtime value.
        return [x for x in self.positional_args if x.value is None]
    @property
    def as_kwargs(self) -> Dict[str, Any]:
        """
        This context's arguments' values keyed by their ``.name`` attribute.
        Results in a dict suitable for use in Python contexts, where e.g. an
        arg named ``foo-bar`` becomes accessible as ``foo_bar``.
        .. versionadded:: 1.0
        """
        ret = {}
        for arg in self.args.values():
            ret[arg.name] = arg.value
        return ret
    def names_for(self, flag: str) -> List[str]:
        # All flag spellings (the flag itself plus its aliases), deduplicated.
        # TODO: should probably be a method on Lexicon/AliasDict
        return list(set([flag] + self.flags.aliases_of(flag)))
    def help_for(self, flag: str) -> Tuple[str, str]:
        """
        Return 2-tuple of ``(flag-spec, help-string)`` for given ``flag``.
        .. versionadded:: 1.0
        """
        # Obtain arg obj
        if flag not in self.flags:
            err = "{!r} is not a valid flag for this context! Valid flags are: {!r}"  # noqa
            raise ValueError(err.format(flag, self.flags.keys()))
        arg = self.flags[flag]
        # Determine expected value type, if any
        value = {str: "STRING", int: "INT"}.get(arg.kind)
        # Format & go
        full_names = []
        for name in self.names_for(flag):
            if value:
                # Short flags are -f VAL, long are --foo=VAL
                # When optional, also, -f [VAL] and --foo[=VAL]
                if len(name.strip("-")) == 1:
                    value_ = ("[{}]".format(value)) if arg.optional else value
                    valuestr = " {}".format(value_)
                else:
                    valuestr = "={}".format(value)
                    if arg.optional:
                        valuestr = "[{}]".format(valuestr)
            else:
                # no value => boolean
                # check for inverse
                if name in self.inverse_flags.values():
                    name = "--[no-]{}".format(name[2:])
                valuestr = ""
            # Tack together
            full_names.append(name + valuestr)
        namestr = ", ".join(sorted(full_names, key=len))
        helpstr = arg.help or ""
        return namestr, helpstr
    def help_tuples(self) -> List[Tuple[str, Optional[str]]]:
        """
        Return sorted iterable of help tuples for all member Arguments.
        Sorts like so:
        * General sort is alphanumerically
        * Short flags win over long flags
        * Arguments with *only* long flags and *no* short flags will come
          first.
        * When an Argument has multiple long or short flags, it will sort using
          the most favorable (lowest alphabetically) candidate.
        This will result in a help list like so::
            --alpha, --zeta # 'alpha' wins
            --beta
            -a, --query # short flag wins
            -b, --argh
            -c
        .. versionadded:: 1.0
        """
        # TODO: argument/flag API must change :(
        # having to call to_flag on 1st name of an Argument is just dumb.
        # To pass in an Argument object to help_for may require moderate
        # changes?
        return list(
            map(
                lambda x: self.help_for(to_flag(x.name)),
                sorted(self.flags.values(), key=flag_key),
            )
        )
    def flag_names(self) -> Tuple[str, ...]:
        """
        Similar to `help_tuples` but returns flag names only, no helpstrs.
        Specifically, all flag names, flattened, in rough order.
        .. versionadded:: 1.0
        """
        # Regular flag names
        flags = sorted(self.flags.values(), key=flag_key)
        names = [self.names_for(to_flag(x.name)) for x in flags]
        # Inverse flag names sold separately
        names.append(list(self.inverse_flags.keys()))
        return tuple(itertools.chain.from_iterable(names))

View File

@@ -0,0 +1,455 @@
import copy
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
try:
from ..vendor.lexicon import Lexicon
from ..vendor.fluidity import StateMachine, state, transition
except ImportError:
from lexicon import Lexicon # type: ignore[no-redef]
from fluidity import ( # type: ignore[no-redef]
StateMachine,
state,
transition,
)
from ..exceptions import ParseError
from ..util import debug
if TYPE_CHECKING:
from .context import ParserContext
def is_flag(value: str) -> bool:
    """True when ``value`` looks like a CLI flag (has a leading dash)."""
    return value[:1] == "-"
def is_long_flag(value: str) -> bool:
    """True when ``value`` looks like a long-style CLI flag (double dash)."""
    return value[:2] == "--"
class ParseResult(List["ParserContext"]):
    """
    A ``list`` subclass of `.ParserContext` objects with parse metadata.

    Carries a ``.remainder`` attribute (the string found after a ``--`` in any
    parsed argv list) and an ``.unparsed`` attribute (a list of tokens that
    were unable to be parsed).
    .. versionadded:: 1.0
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Everything after a '--' sentinel, re-joined into one string.
        self.remainder = ""
        # Tokens the parser gave up on.
        self.unparsed: List[str] = []
class Parser:
    """
    Create parser conscious of ``contexts`` and optional ``initial`` context.
    ``contexts`` should be an iterable of ``Context`` instances which will be
    searched when new context names are encountered during a parse. These
    Contexts determine what flags may follow them, as well as whether given
    flags take values.
    ``initial`` is optional and will be used to determine validity of "core"
    options/flags at the start of the parse run, if any are encountered.
    ``ignore_unknown`` determines what to do when contexts are found which do
    not map to any members of ``contexts``. By default it is ``False``, meaning
    any unknown contexts result in a parse error exception. If ``True``,
    encountering an unknown context halts parsing and populates the return
    value's ``.unparsed`` attribute with the remaining parse tokens.
    .. versionadded:: 1.0
    """
    def __init__(
        self,
        contexts: Iterable["ParserContext"] = (),
        initial: Optional["ParserContext"] = None,
        ignore_unknown: bool = False,
    ) -> None:
        self.initial = initial
        self.contexts = Lexicon()
        self.ignore_unknown = ignore_unknown
        # Register each context under its name, plus all of its aliases;
        # duplicate names/aliases are hard errors.
        for context in contexts:
            debug("Adding {}".format(context))
            if not context.name:
                raise ValueError("Non-initial contexts must have names.")
            exists = "A context named/aliased {!r} is already in this parser!"
            if context.name in self.contexts:
                raise ValueError(exists.format(context.name))
            self.contexts[context.name] = context
            for alias in context.aliases:
                if alias in self.contexts:
                    raise ValueError(exists.format(alias))
                self.contexts.alias(alias, to=context.name)
    def parse_argv(self, argv: List[str]) -> ParseResult:
        """
        Parse an argv-style token list ``argv``.
        Returns a list (actually a subclass, `.ParseResult`) of
        `.ParserContext` objects matching the order they were found in the
        ``argv`` and containing `.Argument` objects with updated values based
        on any flags given.
        Assumes any program name has already been stripped out. Good::
            Parser(...).parse_argv(['--core-opt', 'task', '--task-opt'])
        Bad::
            Parser(...).parse_argv(['invoke', '--core-opt', ...])
        :param argv: List of argument string tokens.
        :returns:
            A `.ParseResult` (a ``list`` subclass containing some number of
            `.ParserContext` objects).
        .. versionadded:: 1.0
        """
        machine = ParseMachine(
            # FIXME: initial should not be none
            initial=self.initial,  # type: ignore[arg-type]
            contexts=self.contexts,
            ignore_unknown=self.ignore_unknown,
        )
        # FIXME: Why isn't there str.partition for lists? There must be a
        # better way to do this. Split argv around the double-dash remainder
        # sentinel.
        debug("Starting argv: {!r}".format(argv))
        try:
            ddash = argv.index("--")
        except ValueError:
            ddash = len(argv)  # No remainder == body gets all
        body = argv[:ddash]
        remainder = argv[ddash:][1:]  # [1:] to strip off remainder itself
        if remainder:
            debug(
                "Remainder: argv[{!r}:][1:] => {!r}".format(ddash, remainder)
            )
        # NOTE: 'body' is intentionally mutated while being iterated —
        # 'mutations' insert split-off sub-tokens just after the current
        # index, so enumerate() yields them on subsequent iterations.
        for index, token in enumerate(body):
            # Handle non-space-delimited forms, if not currently expecting a
            # flag value and still in valid parsing territory (i.e. not in
            # "unknown" state which implies store-only)
            # NOTE: we do this in a few steps so we can
            # split-then-check-validity; necessary for things like when the
            # previously seen flag optionally takes a value.
            mutations = []
            orig = token
            if is_flag(token) and not machine.result.unparsed:
                # Equals-sign-delimited flags, eg --foo=bar or -f=bar
                if "=" in token:
                    token, _, value = token.partition("=")
                    msg = "Splitting x=y expr {!r} into tokens {!r} and {!r}"
                    debug(msg.format(orig, token, value))
                    mutations.append((index + 1, value))
                # Contiguous boolean short flags, e.g. -qv
                elif not is_long_flag(token) and len(token) > 2:
                    full_token = token[:]
                    rest, token = token[2:], token[:2]
                    err = "Splitting {!r} into token {!r} and rest {!r}"
                    debug(err.format(full_token, token, rest))
                    # Handle boolean flag block vs short-flag + value. Make
                    # sure not to test the token as a context flag if we've
                    # passed into 'storing unknown stuff' territory (e.g. on a
                    # core-args pass, handling what are going to be task args)
                    have_flag = (
                        token in machine.context.flags
                        and machine.current_state != "unknown"
                    )
                    if have_flag and machine.context.flags[token].takes_value:
                        msg = "{!r} is a flag for current context & it takes a value, giving it {!r}"  # noqa
                        debug(msg.format(token, rest))
                        mutations.append((index + 1, rest))
                    else:
                        _rest = ["-{}".format(x) for x in rest]
                        msg = "Splitting multi-flag glob {!r} into {!r} and {!r}"  # noqa
                        debug(msg.format(orig, token, _rest))
                        for item in reversed(_rest):
                            mutations.append((index + 1, item))
            # Here, we've got some possible mutations queued up, and 'token'
            # may have been overwritten as well. Whether we apply those and
            # continue as-is, or roll it back, depends:
            # - If the parser wasn't waiting for a flag value, we're already on
            # the right track, so apply mutations and move along to the
            # handle() step.
            # - If we ARE waiting for a value, and the flag expecting it ALWAYS
            # wants a value (it's not optional), we go back to using the
            # original token. (TODO: could reorganize this to avoid the
            # sub-parsing in this case, but optimizing for human-facing
            # execution isn't critical.)
            # - Finally, if we are waiting for a value AND it's optional, we
            # inspect the first sub-token/mutation to see if it would otherwise
            # have been a valid flag, and let that determine what we do (if
            # valid, we apply the mutations; if invalid, we reinstate the
            # original token.)
            if machine.waiting_for_flag_value:
                optional = machine.flag and machine.flag.optional
                subtoken_is_valid_flag = token in machine.context.flags
                if not (optional and subtoken_is_valid_flag):
                    token = orig
                    mutations = []
            for index, value in mutations:
                body.insert(index, value)
            machine.handle(token)
        machine.finish()
        result = machine.result
        result.remainder = " ".join(remainder)
        return result
class ParseMachine(StateMachine):
    """
    State machine driving token-by-token parsing of an argv-style list.

    The machine moves between a ``context`` state (parsing flags/args for
    the current parser context — the initial/core one, or a per-task one)
    and an ``unknown`` state (collecting unparseable tokens when
    ``ignore_unknown`` is set), ending in ``end``. Entering any state first
    "ties off" the in-progress flag and validates the current context (see
    `complete_flag` / `complete_context`).
    """

    initial_state = "context"
    # Every state entry completes the current flag & context before
    # proceeding, so bookkeeping never lags behind the transition.
    state("context", enter=["complete_flag", "complete_context"])
    state("unknown", enter=["complete_flag", "complete_context"])
    state("end", enter=["complete_flag", "complete_context"])
    transition(from_=("context", "unknown"), event="finish", to="end")
    transition(
        from_="context",
        event="see_context",
        action="switch_to_context",
        to="context",
    )
    transition(
        from_=("context", "unknown"),
        event="see_unknown",
        action="store_only",
        to="unknown",
    )

    def changing_state(self, from_: str, to: str) -> None:
        # StateMachine hook: log every transition to ease parse debugging.
        debug("ParseMachine: {!r} => {!r}".format(from_, to))

    def __init__(
        self,
        initial: "ParserContext",
        contexts: Lexicon,
        ignore_unknown: bool,
    ) -> None:
        """
        :param initial:
            The initial/core parser context (deep-copied before use).
        :param contexts:
            Lexicon mapping names/aliases to per-task parser contexts
            (also deep-copied, so parsing never mutates caller state).
        :param ignore_unknown:
            When ``True``, unknown tokens accumulate in ``result.unparsed``
            instead of raising `ParseError`.
        """
        # Initialize
        self.ignore_unknown = ignore_unknown
        # Deep-copy so repeated parses can't pollute the given objects.
        self.initial = self.context = copy.deepcopy(initial)
        debug("Initialized with context: {!r}".format(self.context))
        # Currently-active flag Argument, if any.
        self.flag = None
        # Whether the active (iterable-type) flag has received >=1 value.
        self.flag_got_value = False
        self.result = ParseResult()
        self.contexts = copy.deepcopy(contexts)
        debug("Available contexts: {!r}".format(self.contexts))
        # In case StateMachine does anything in __init__
        super().__init__()

    @property
    def waiting_for_flag_value(self) -> bool:
        """
        Whether the current flag still expects a (further) value token.
        """
        # Do we have a current flag, and does it expect a value (vs being a
        # bool/toggle)?
        takes_value = self.flag and self.flag.takes_value
        if not takes_value:
            return False
        # OK, this flag is one that takes values.
        # Is it a list type (which has only just been switched to)? Then it'll
        # always accept more values.
        # TODO: how to handle somebody wanting it to be some other iterable
        # like tuple or custom class? Or do we just say unsupported?
        if self.flag.kind is list and not self.flag_got_value:
            return True
        # Not a list, okay. Does it already have a value?
        has_value = self.flag.raw_value is not None
        # If it doesn't have one, we're waiting for one (which tells the parser
        # how to proceed and typically to store the next token.)
        # TODO: in the negative case here, we should do something else instead:
        # - Except, "hey you screwed up, you already gave that flag!"
        # - Overwrite, "oh you changed your mind?" - which requires more work
        # elsewhere too, unfortunately. (Perhaps additional properties on
        # Argument that can be queried, e.g. "arg.is_iterable"?)
        return not has_value

    def handle(self, token: str) -> None:
        """
        Dispatch a single token to the appropriate parse action/event.
        """
        debug("Handling token: {!r}".format(token))
        # Handle unknown state at the top: we don't care about even
        # possibly-valid input if we've encountered unknown input.
        if self.current_state == "unknown":
            debug("Top-of-handle() see_unknown({!r})".format(token))
            self.see_unknown(token)
            return
        # Flag
        if self.context and token in self.context.flags:
            debug("Saw flag {!r}".format(token))
            self.switch_to_flag(token)
        elif self.context and token in self.context.inverse_flags:
            debug("Saw inverse flag {!r}".format(token))
            self.switch_to_flag(token, inverse=True)
        # Value for current flag
        elif self.waiting_for_flag_value:
            debug(
                "We're waiting for a flag value so {!r} must be it?".format(
                    token
                )
            )  # noqa
            self.see_value(token)
        # Positional args (must come above context-name check in case we still
        # need a posarg and the user legitimately wants to give it a value that
        # just happens to be a valid context name.)
        elif self.context and self.context.missing_positional_args:
            msg = "Context {!r} requires positional args, eating {!r}"
            debug(msg.format(self.context, token))
            self.see_positional_arg(token)
        # New context
        elif token in self.contexts:
            self.see_context(token)
        # Initial-context flag being given as per-task flag (e.g. --help)
        elif self.initial and token in self.initial.flags:
            debug("Saw (initial-context) flag {!r}".format(token))
            flag = self.initial.flags[token]
            # Special-case for core --help flag: context name is used as value.
            if flag.name == "help":
                flag.value = self.context.name
                msg = "Saw --help in a per-task context, setting task name ({!r}) as its value"  # noqa
                debug(msg.format(flag.value))
            # All others: just enter the 'switch to flag' parser state
            else:
                # TODO: handle inverse core flags too? There are none at the
                # moment (e.g. --no-dedupe is actually 'no_dedupe', not a
                # default-False 'dedupe') and it's up to us whether we actually
                # put any in place.
                self.switch_to_flag(token)
        # Unknown
        else:
            if not self.ignore_unknown:
                debug("Can't find context named {!r}, erroring".format(token))
                self.error("No idea what {!r} is!".format(token))
            else:
                debug("Bottom-of-handle() see_unknown({!r})".format(token))
                self.see_unknown(token)

    def store_only(self, token: str) -> None:
        """
        Transition action for 'unknown': stash token in ``result.unparsed``.
        """
        # Start off the unparsed list
        debug("Storing unknown token {!r}".format(token))
        self.result.unparsed.append(token)

    def complete_context(self) -> None:
        """
        Finalize the current context and append it to the parse result.

        Raises a parse error (via `error`) if required positional args were
        never filled in.
        """
        debug(
            "Wrapping up context {!r}".format(
                self.context.name if self.context else self.context
            )
        )
        # Ensure all of context's positional args have been given.
        if self.context and self.context.missing_positional_args:
            err = "'{}' did not receive required positional arguments: {}"
            names = ", ".join(
                "'{}'".format(x.name)
                for x in self.context.missing_positional_args
            )
            self.error(err.format(self.context.name, names))
        if self.context and self.context not in self.result:
            self.result.append(self.context)

    def switch_to_context(self, name: str) -> None:
        """
        Make the named context current (using a private deep copy of it).
        """
        self.context = copy.deepcopy(self.contexts[name])
        debug("Moving to context {!r}".format(name))
        debug("Context args: {!r}".format(self.context.args))
        debug("Context flags: {!r}".format(self.context.flags))
        debug("Context inverse_flags: {!r}".format(self.context.inverse_flags))

    def complete_flag(self) -> None:
        """
        Tie off the currently-active flag, if any.

        Errors out when a value-requiring flag got none; optional-value
        flags seen without a value get ``True`` (bool-style) instead.
        """
        if self.flag:
            msg = "Completing current flag {} before moving on"
            debug(msg.format(self.flag))
        # Barf if we needed a value and didn't get one
        if (
            self.flag
            and self.flag.takes_value
            and self.flag.raw_value is None
            and not self.flag.optional
        ):
            err = "Flag {!r} needed value and was not given one!"
            self.error(err.format(self.flag))
        # Handle optional-value flags; at this point they were not given an
        # explicit value, but they were seen, ergo they should get treated like
        # bools.
        if self.flag and self.flag.raw_value is None and self.flag.optional:
            msg = "Saw optional flag {!r} go by w/ no value; setting to True"
            debug(msg.format(self.flag.name))
            # Skip casting so the bool gets preserved
            self.flag.set_value(True, cast=False)

    def check_ambiguity(self, value: Any) -> bool:
        """
        Guard against ambiguity when current flag takes an optional value.

        Raises `ParseError` when ``value`` could plausibly be either the
        optional-value flag's value or something else (a positional arg or
        a task/context name); returns ``False`` when no ambiguity exists.

        .. versionadded:: 1.0
        """
        # No flag is currently being examined, or one is but it doesn't take an
        # optional value? Ambiguity isn't possible.
        if not (self.flag and self.flag.optional):
            return False
        # We *are* dealing with an optional-value flag, but it's already
        # received a value? There can't be ambiguity here either.
        if self.flag.raw_value is not None:
            return False
        # Otherwise, there *may* be ambiguity if 1 or more of the below tests
        # fail.
        tests = []
        # Unfilled posargs still exist?
        tests.append(self.context and self.context.missing_positional_args)
        # Value matches another valid task/context name?
        tests.append(value in self.contexts)
        if any(tests):
            msg = "{!r} is ambiguous when given after an optional-value flag"
            raise ParseError(msg.format(value))

    def switch_to_flag(self, flag: str, inverse: bool = False) -> None:
        """
        Make the given flag the active one, completing any prior flag.

        :param flag: The flag token as seen on the command line.
        :param inverse: Whether the token was an inverse ("--no-xxx") form.
        """
        # Sanity check for ambiguity w/ prior optional-value flag
        self.check_ambiguity(flag)
        # Also tie it off, in case prior had optional value or etc. Seems to be
        # harmless for other kinds of flags. (TODO: this is a serious indicator
        # that we need to move some of this flag-by-flag bookkeeping into the
        # state machine bits, if possible - as-is it was REAL confusing re: why
        # this was manually required!)
        self.complete_flag()
        # Set flag/arg obj
        flag = self.context.inverse_flags[flag] if inverse else flag
        # Update state
        try:
            self.flag = self.context.flags[flag]
        except KeyError as e:
            # Try fallback to initial/core flag
            try:
                self.flag = self.initial.flags[flag]
            except KeyError:
                # If it wasn't in either, raise the original context's
                # exception, as that's more useful / correct.
                raise e
        debug("Moving to flag {!r}".format(self.flag))
        # Bookkeeping for iterable-type flags (where the typical 'value
        # non-empty/nondefault -> clearly it got its value already' test is
        # insufficient)
        self.flag_got_value = False
        # Handle boolean flags (which can immediately be updated)
        if self.flag and not self.flag.takes_value:
            val = not inverse
            debug("Marking seen flag {!r} as {}".format(self.flag, val))
            self.flag.value = val

    def see_value(self, value: Any) -> None:
        """
        Store ``value`` on the active flag (after an ambiguity check).
        """
        self.check_ambiguity(value)
        if self.flag and self.flag.takes_value:
            debug("Setting flag {!r} to value {!r}".format(self.flag, value))
            self.flag.value = value
            self.flag_got_value = True
        else:
            self.error("Flag {!r} doesn't take any value!".format(self.flag))

    def see_positional_arg(self, value: Any) -> None:
        """
        Fill the first still-unset positional arg of the context.
        """
        for arg in self.context.positional_args:
            if arg.value is None:
                arg.value = value
                break

    def error(self, msg: str) -> None:
        """
        Raise a `ParseError` bound to the current context.
        """
        raise ParseError(msg, self.context)

View File

@@ -0,0 +1,987 @@
import getpass
import inspect
import json
import os
import sys
import textwrap
from importlib import import_module # buffalo buffalo
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
)
from . import Collection, Config, Executor, FilesystemLoader
from .completion.complete import complete, print_completion_script
from .parser import Parser, ParserContext, Argument
from .exceptions import UnexpectedExit, CollectionNotFound, ParseError, Exit
from .terminals import pty_size
from .util import debug, enable_logging, helpline
if TYPE_CHECKING:
from .loader import Loader
from .parser import ParseResult
from .util import Lexicon
class Program:
    """
    Manages top-level CLI invocation, typically via ``setup.py`` entrypoints.
    Designed for distributing Invoke task collections as standalone programs,
    but also used internally to implement the ``invoke`` program itself.
    .. seealso::
        :ref:`reusing-as-a-binary` for a tutorial/walkthrough of this
        functionality.
    .. versionadded:: 1.0
    """

    # Parse result holding core (pre-task) args; populated during parsing.
    core: "ParseResult"
def core_args(self) -> List["Argument"]:
"""
Return default core `.Argument` objects, as a list.
.. versionadded:: 1.0
"""
# Arguments present always, even when wrapped as a different binary
return [
Argument(
names=("command-timeout", "T"),
kind=int,
help="Specify a global command execution timeout, in seconds.",
),
Argument(
names=("complete",),
kind=bool,
default=False,
help="Print tab-completion candidates for given parse remainder.", # noqa
),
Argument(
names=("config", "f"),
help="Runtime configuration file to use.",
),
Argument(
names=("debug", "d"),
kind=bool,
default=False,
help="Enable debug output.",
),
Argument(
names=("dry", "R"),
kind=bool,
default=False,
help="Echo commands instead of running.",
),
Argument(
names=("echo", "e"),
kind=bool,
default=False,
help="Echo executed commands before running.",
),
Argument(
names=("help", "h"),
optional=True,
help="Show core or per-task help and exit.",
),
Argument(
names=("hide",),
help="Set default value of run()'s 'hide' kwarg.",
),
Argument(
names=("list", "l"),
optional=True,
help="List available tasks, optionally limited to a namespace.", # noqa
),
Argument(
names=("list-depth", "D"),
kind=int,
default=0,
help="When listing tasks, only show the first INT levels.",
),
Argument(
names=("list-format", "F"),
help="Change the display format used when listing tasks. Should be one of: flat (default), nested, json.", # noqa
default="flat",
),
Argument(
names=("print-completion-script",),
kind=str,
default="",
help="Print the tab-completion script for your preferred shell (bash|zsh|fish).", # noqa
),
Argument(
names=("prompt-for-sudo-password",),
kind=bool,
default=False,
help="Prompt user at start of session for the sudo.password config value.", # noqa
),
Argument(
names=("pty", "p"),
kind=bool,
default=False,
help="Use a pty when executing shell commands.",
),
Argument(
names=("version", "V"),
kind=bool,
default=False,
help="Show version and exit.",
),
Argument(
names=("warn-only", "w"),
kind=bool,
default=False,
help="Warn, instead of failing, when shell commands fail.",
),
Argument(
names=("write-pyc",),
kind=bool,
default=False,
help="Enable creation of .pyc files.",
),
]
def task_args(self) -> List["Argument"]:
"""
Return default task-related `.Argument` objects, as a list.
These are only added to the core args in "task runner" mode (the
default for ``invoke`` itself) - they are omitted when the constructor
is given a non-empty ``namespace`` argument ("bundled namespace" mode).
.. versionadded:: 1.0
"""
# Arguments pertaining specifically to invocation as 'invoke' itself
# (or as other arbitrary-task-executing programs, like 'fab')
return [
Argument(
names=("collection", "c"),
help="Specify collection name to load.",
),
Argument(
names=("no-dedupe",),
kind=bool,
default=False,
help="Disable task deduplication.",
),
Argument(
names=("search-root", "r"),
help="Change root directory used for finding task modules.",
),
]
argv: List[str]
# Other class-level global variables a subclass might override sometime
# maybe?
leading_indent_width = 2
leading_indent = " " * leading_indent_width
indent_width = 4
indent = " " * indent_width
col_padding = 3
def __init__(
self,
version: Optional[str] = None,
namespace: Optional["Collection"] = None,
name: Optional[str] = None,
binary: Optional[str] = None,
loader_class: Optional[Type["Loader"]] = None,
executor_class: Optional[Type["Executor"]] = None,
config_class: Optional[Type["Config"]] = None,
binary_names: Optional[List[str]] = None,
) -> None:
"""
Create a new, parameterized `.Program` instance.
:param str version:
The program's version, e.g. ``"0.1.0"``. Defaults to ``"unknown"``.
:param namespace:
A `.Collection` to use as this program's subcommands.
If ``None`` (the default), the program will behave like ``invoke``,
seeking a nearby task namespace with a `.Loader` and exposing
arguments such as :option:`--list` and :option:`--collection` for
inspecting or selecting specific namespaces.
If given a `.Collection` object, will use it as if it had been
handed to :option:`--collection`. Will also update the parser to
remove references to tasks and task-related options, and display
the subcommands in ``--help`` output. The result will be a program
that has a static set of subcommands.
:param str name:
The program's name, as displayed in ``--version`` output.
If ``None`` (default), is a capitalized version of the first word
in the ``argv`` handed to `.run`. For example, when invoked from a
binstub installed as ``foobar``, it will default to ``Foobar``.
:param str binary:
Descriptive lowercase binary name string used in help text.
For example, Invoke's own internal value for this is ``inv[oke]``,
denoting that it is installed as both ``inv`` and ``invoke``. As
this is purely text intended for help display, it may be in any
format you wish, though it should match whatever you've put into
your ``setup.py``'s ``console_scripts`` entry.
If ``None`` (default), uses the first word in ``argv`` verbatim (as
with ``name`` above, except not capitalized).
:param binary_names:
List of binary name strings, for use in completion scripts.
This list ensures that the shell completion scripts generated by
:option:`--print-completion-script` instruct the shell to use
that completion for all of this program's installed names.
For example, Invoke's internal default for this is ``["inv",
"invoke"]``.
If ``None`` (the default), the first word in ``argv`` (in the
invocation of :option:`--print-completion-script`) is used in a
single-item list.
:param loader_class:
The `.Loader` subclass to use when loading task collections.
Defaults to `.FilesystemLoader`.
:param executor_class:
The `.Executor` subclass to use when executing tasks.
Defaults to `.Executor`; may also be overridden at runtime by the
:ref:`configuration system <default-values>` and its
``tasks.executor_class`` setting (anytime that setting is not
``None``).
:param config_class:
The `.Config` subclass to use for the base config object.
Defaults to `.Config`.
.. versionchanged:: 1.2
Added the ``binary_names`` argument.
"""
self.version = "unknown" if version is None else version
self.namespace = namespace
self._name = name
# TODO 3.0: rename binary to binary_help_name or similar. (Or write
# code to autogenerate it from binary_names.)
self._binary = binary
self._binary_names = binary_names
self.argv = []
self.loader_class = loader_class or FilesystemLoader
self.executor_class = executor_class or Executor
self.config_class = config_class or Config
def create_config(self) -> None:
"""
Instantiate a `.Config` (or subclass, depending) for use in task exec.
This Config is fully usable but will lack runtime-derived data like
project & runtime config files, CLI arg overrides, etc. That data is
added later in `update_config`. See `.Config` docstring for lifecycle
details.
:returns: ``None``; sets ``self.config`` instead.
.. versionadded:: 1.0
"""
self.config = self.config_class()
    def update_config(self, merge: bool = True) -> None:
        """
        Update the previously instantiated `.Config` with parsed data.
        For example, this is how ``--echo`` is able to override the default
        config value for ``run.echo``.
        :param bool merge:
            Whether to merge at the end, or defer. Primarily useful for
            subclassers. Default: ``True``.
        .. versionadded:: 1.0
        """
        # Now that we have parse results handy, we can grab the remaining
        # config bits:
        # - runtime config, as it is dependent on the runtime flag/env var
        # - the overrides config level, as it is composed of runtime flag data
        # NOTE: only fill in values that would alter behavior, otherwise we
        # want the defaults to come through.
        run = {}
        if self.args["warn-only"].value:
            run["warn"] = True
        if self.args.pty.value:
            run["pty"] = True
        if self.args.hide.value:
            run["hide"] = self.args.hide.value
        if self.args.echo.value:
            run["echo"] = True
        if self.args.dry.value:
            run["dry"] = True
        tasks = {}
        # --no-dedupe only exists in task-runner mode, hence the membership
        # guard before reading its value.
        if "no-dedupe" in self.args and self.args["no-dedupe"].value:
            tasks["dedupe"] = False
        timeouts = {}
        command = self.args["command-timeout"].value
        if command:
            timeouts["command"] = command
        # Handle "fill in config values at start of runtime", which for now is
        # just sudo password
        sudo = {}
        if self.args["prompt-for-sudo-password"].value:
            prompt = "Desired 'sudo.password' config value: "
            sudo["password"] = getpass.getpass(prompt)
        overrides = dict(run=run, tasks=tasks, sudo=sudo, timeouts=timeouts)
        # merge=False on the individual loads: defer the actual merge until
        # the end (or to the caller, when merge=False was requested).
        self.config.load_overrides(overrides, merge=False)
        runtime_path = self.args.config.value
        # CLI flag wins over the environment variable.
        if runtime_path is None:
            runtime_path = os.environ.get("INVOKE_RUNTIME_CONFIG", None)
        self.config.set_runtime_path(runtime_path)
        self.config.load_runtime(merge=False)
        if merge:
            self.config.merge()
    def run(self, argv: Optional[List[str]] = None, exit: bool = True) -> None:
        """
        Execute main CLI logic, based on ``argv``.
        :param argv:
            The arguments to execute against. May be ``None``, a list of
            strings, or a string. See `.normalize_argv` for details.
        :param bool exit:
            When ``False`` (default: ``True``), will ignore `.ParseError`,
            `.Exit` and `.Failure` exceptions, which otherwise trigger calls to
            `sys.exit`.
            .. note::
                This is mostly a concession to testing. If you're setting this
                to ``False`` in a production setting, you should probably be
                using `.Executor` and friends directly instead!
        .. versionadded:: 1.0
        """
        try:
            # Create an initial config, which will hold defaults & values from
            # most config file locations (all but runtime.) Used to inform
            # loading & parsing behavior.
            self.create_config()
            # Parse the given ARGV with our CLI parsing machinery, resulting in
            # things like self.args (core args/flags), self.collection (the
            # loaded namespace, which may be affected by the core flags) and
            # self.tasks (the tasks requested for exec and their own
            # args/flags)
            self.parse_core(argv)
            # Handle collection concerns including project config
            self.parse_collection()
            # Parse remainder of argv as task-related input
            self.parse_tasks()
            # End of parsing (typically bailout stuff like --list, --help)
            self.parse_cleanup()
            # Update the earlier Config with new values from the parse step -
            # runtime config file contents and flag-derived overrides (e.g. for
            # run()'s echo, warn, etc options.)
            self.update_config()
            # Create an Executor, passing in the data resulting from the prior
            # steps, then tell it to execute the tasks.
            self.execute()
        except (UnexpectedExit, Exit, ParseError) as e:
            debug("Received a possibly-skippable exception: {!r}".format(e))
            # Print error messages from parser, runner, etc if necessary;
            # prevents messy traceback but still clues interactive user into
            # problems.
            if isinstance(e, ParseError):
                print(e, file=sys.stderr)
            if isinstance(e, Exit) and e.message:
                print(e.message, file=sys.stderr)
            if isinstance(e, UnexpectedExit) and e.result.hide:
                print(e, file=sys.stderr, end="")
            # Terminate execution unless we were told not to.
            if exit:
                # Map exception type to process exit code: the subprocess's
                # own code for UnexpectedExit, the explicit code carried by
                # Exit, and a generic 1 for parse errors.
                if isinstance(e, UnexpectedExit):
                    code = e.result.exited
                elif isinstance(e, Exit):
                    code = e.code
                elif isinstance(e, ParseError):
                    code = 1
                sys.exit(code)
            else:
                debug("Invoked as run(..., exit=False), ignoring exception")
        except KeyboardInterrupt:
            sys.exit(1)  # Same behavior as Python itself outside of REPL
    def parse_core(self, argv: Optional[List[str]]) -> None:
        """
        Normalize ``argv`` and parse core (pre-task) arguments.

        Sets ``self.argv`` and (via `parse_core_args`) ``self.core``.
        Short-circuits by raising `.Exit` for ``--version`` and
        ``--print-completion-script``; also toggles debug logging and
        interpreter bytecode writing per their respective flags.
        """
        debug("argv given to Program.run: {!r}".format(argv))
        self.normalize_argv(argv)
        # Obtain core args (sets self.core)
        self.parse_core_args()
        debug("Finished parsing core args")
        # Set interpreter bytecode-writing flag
        sys.dont_write_bytecode = not self.args["write-pyc"].value
        # Enable debugging from here on out, if debug flag was given.
        # (Prior to this point, debugging requires setting INVOKE_DEBUG).
        if self.args.debug.value:
            enable_logging()
        # Short-circuit if --version
        if self.args.version.value:
            debug("Saw --version, printing version & exiting")
            self.print_version()
            raise Exit
        # Print (dynamic, no tasks required) completion script if requested
        if self.args["print-completion-script"].value:
            print_completion_script(
                shell=self.args["print-completion-script"].value,
                names=self.binary_names,
            )
            raise Exit
def parse_collection(self) -> None:
"""
Load a tasks collection & project-level config.
.. versionadded:: 1.0
"""
# Load a collection of tasks unless one was already set.
if self.namespace is not None:
debug(
"Program was given default namespace, not loading collection"
)
self.collection = self.namespace
else:
debug(
"No default namespace provided, trying to load one from disk"
) # noqa
# If no bundled namespace & --help was given, just print it and
# exit. (If we did have a bundled namespace, core --help will be
# handled *after* the collection is loaded & parsing is done.)
if self.args.help.value is True:
debug(
"No bundled namespace & bare --help given; printing help."
)
self.print_help()
raise Exit
self.load_collection()
# Set these up for potential use later when listing tasks
# TODO: be nice if these came from the config...! Users would love to
# say they default to nested for example. Easy 2.x feature-add.
self.list_root: Optional[str] = None
self.list_depth: Optional[int] = None
self.list_format = "flat"
self.scoped_collection = self.collection
# TODO: load project conf, if possible, gracefully
    def parse_cleanup(self) -> None:
        """
        Post-parsing, pre-execution steps such as --help, --list, etc.
        .. versionadded:: 1.0
        """
        halp = self.args.help.value
        # Core (no value given) --help output (only when bundled namespace)
        if halp is True:
            debug("Saw bare --help, printing help & exiting")
            self.print_help()
            raise Exit
        # Print per-task help, if necessary
        if halp:
            if halp in self.parser.contexts:
                msg = "Saw --help <taskname>, printing per-task help & exiting"
                debug(msg)
                self.print_task_help(halp)
                raise Exit
            else:
                # TODO: feels real dumb to factor this out of Parser, but...we
                # should?
                raise ParseError("No idea what '{}' is!".format(halp))
        # Print discovered tasks if necessary
        list_root = self.args.list.value  # will be True or string
        self.list_format = self.args["list-format"].value
        self.list_depth = self.args["list-depth"].value
        if list_root:
            # Not just --list, but --list some-root - do moar work
            if isinstance(list_root, str):
                self.list_root = list_root
                try:
                    sub = self.collection.subcollection_from_path(list_root)
                    self.scoped_collection = sub
                except KeyError:
                    msg = "Sub-collection '{}' not found!"
                    raise Exit(msg.format(list_root))
            self.list_tasks()
            raise Exit
        # Print completion helpers if necessary
        if self.args.complete.value:
            # NOTE(review): no Exit is raised here, so `complete` presumably
            # terminates execution itself — confirm against completion module.
            complete(
                names=self.binary_names,
                core=self.core,
                initial_context=self.initial_context,
                collection=self.collection,
                # NOTE: can't reuse self.parser as it has likely been mutated
                # between when it was set and now.
                parser=self._make_parser(),
            )
        # Fallback behavior if no tasks were given & no default specified
        # (mostly a subroutine for overriding purposes)
        # NOTE: when there is a default task, Executor will select it when no
        # tasks were found in CLI parsing.
        if not self.tasks and not self.collection.default:
            self.no_tasks_given()
def no_tasks_given(self) -> None:
debug(
"No tasks specified for execution and no default task; printing global help as fallback" # noqa
)
self.print_help()
raise Exit
    def execute(self) -> None:
        """
        Hand off data & tasks-to-execute specification to an `.Executor`.
        .. note::
            Client code just wanting a different `.Executor` subclass can just
            set ``executor_class`` in `.__init__`, or override
            ``tasks.executor_class`` anywhere in the :ref:`config system
            <default-values>` (which may allow you to avoid using a custom
            Program entirely).
        .. versionadded:: 1.0
        """
        klass = self.executor_class
        # A non-None tasks.executor_class config value (a dotted path string)
        # takes precedence over the constructor-supplied class.
        config_path = self.config.tasks.executor_class
        if config_path is not None:
            # TODO: why the heck is this not builtin to importlib?
            module_path, _, class_name = config_path.rpartition(".")
            # TODO: worth trying to wrap both of these and raising ImportError
            # for cases where module exists but class name does not? More
            # "normal" but also its own possible source of bugs/confusion...
            module = import_module(module_path)
            klass = getattr(module, class_name)
        executor = klass(self.collection, self.config, self.core)
        executor.execute(*self.tasks)
def normalize_argv(self, argv: Optional[List[str]]) -> None:
"""
Massages ``argv`` into a useful list of strings.
**If None** (the default), uses `sys.argv`.
**If a non-string iterable**, uses that in place of `sys.argv`.
**If a string**, performs a `str.split` and then executes with the
result. (This is mostly a convenience; when in doubt, use a list.)
Sets ``self.argv`` to the result.
.. versionadded:: 1.0
"""
if argv is None:
argv = sys.argv
debug("argv was None; using sys.argv: {!r}".format(argv))
elif isinstance(argv, str):
argv = argv.split()
debug("argv was string-like; splitting: {!r}".format(argv))
self.argv = argv
@property
def name(self) -> str:
"""
Derive program's human-readable name based on `.binary`.
.. versionadded:: 1.0
"""
return self._name or self.binary.capitalize()
@property
def called_as(self) -> str:
"""
Returns the program name we were actually called as.
Specifically, this is the (Python's os module's concept of a) basename
of the first argument in the parsed argument vector.
.. versionadded:: 1.2
"""
# XXX: defaults to empty string if 'argv' is '[]' or 'None'
return os.path.basename(self.argv[0]) if self.argv else ""
@property
def binary(self) -> str:
"""
Derive program's help-oriented binary name(s) from init args & argv.
.. versionadded:: 1.0
"""
return self._binary or self.called_as
@property
def binary_names(self) -> List[str]:
"""
Derive program's completion-oriented binary name(s) from args & argv.
.. versionadded:: 1.2
"""
return self._binary_names or [self.called_as]
# TODO 3.0: ugh rename this or core_args, they are too confusing
@property
def args(self) -> "Lexicon":
"""
Obtain core program args from ``self.core`` parse result.
.. versionadded:: 1.0
"""
return self.core[0].args
@property
def initial_context(self) -> ParserContext:
"""
The initial parser context, aka core program flags.
The specific arguments contained therein will differ depending on
whether a bundled namespace was specified in `.__init__`.
.. versionadded:: 1.0
"""
args = self.core_args()
if self.namespace is None:
args += self.task_args()
return ParserContext(args=args)
def print_version(self) -> None:
print("{} {}".format(self.name, self.version or "unknown"))
def print_help(self) -> None:
usage_suffix = "task1 [--task1-opts] ... taskN [--taskN-opts]"
if self.namespace is not None:
usage_suffix = "<subcommand> [--subcommand-opts] ..."
print("Usage: {} [--core-opts] {}".format(self.binary, usage_suffix))
print("")
print("Core options:")
print("")
self.print_columns(self.initial_context.help_tuples())
if self.namespace is not None:
self.list_tasks()
def parse_core_args(self) -> None:
"""
Filter out core args, leaving any tasks or their args for later.
Sets ``self.core`` to the `.ParseResult` from this step.
.. versionadded:: 1.0
"""
debug("Parsing initial context (core args)")
parser = Parser(initial=self.initial_context, ignore_unknown=True)
self.core = parser.parse_argv(self.argv[1:])
msg = "Core-args parse result: {!r} & unparsed: {!r}"
debug(msg.format(self.core, self.core.unparsed))
    def load_collection(self) -> None:
        """
        Load a task collection based on parsed core args, or die trying.
        .. versionadded:: 1.0
        """
        # NOTE: start, coll_name both fall back to configuration values within
        # Loader (which may, however, get them from our config.)
        start = self.args["search-root"].value
        loader = self.loader_class(  # type: ignore
            config=self.config, start=start
        )
        coll_name = self.args.collection.value
        try:
            module, parent = loader.load(coll_name)
            # This is the earliest we can load project config, so we should -
            # allows project config to affect the task parsing step!
            # TODO: is it worth merging these set- and load- methods? May
            # require more tweaking of how things behave in/after __init__.
            self.config.set_project_location(parent)
            self.config.load_project()
            self.collection = Collection.from_module(
                module,
                loaded_from=parent,
                auto_dash_names=self.config.tasks.auto_dash_names,
            )
        except CollectionNotFound as e:
            # Surface a friendly error instead of a traceback.
            raise Exit("Can't find any collection named {!r}!".format(e.name))
def _update_core_context(
self, context: ParserContext, new_args: Dict[str, Any]
) -> None:
# Update core context w/ core_via_task args, if and only if the
# via-task version of the arg was truly given a value.
# TODO: push this into an Argument-aware Lexicon subclass and
# .update()?
for key, arg in new_args.items():
if arg.got_value:
context.args[key]._value = arg._value
def _make_parser(self) -> Parser:
return Parser(
initial=self.initial_context,
contexts=self.collection.to_contexts(
ignore_unknown_help=self.config.tasks.ignore_unknown_help
),
)
def parse_tasks(self) -> None:
"""
Parse leftover args, which are typically tasks & per-task args.
Sets ``self.parser`` to the parser used, ``self.tasks`` to the
parsed per-task contexts, and ``self.core_via_tasks`` to a context
holding any core flags seen within the task contexts.
Also modifies ``self.core`` to include the data from ``core_via_tasks``
(so that it correctly reflects any supplied core flags regardless of
where they appeared).
.. versionadded:: 1.0
"""
self.parser = self._make_parser()
debug("Parsing tasks against {!r}".format(self.collection))
result = self.parser.parse_argv(self.core.unparsed)
self.core_via_tasks = result.pop(0)
self._update_core_context(
context=self.core[0], new_args=self.core_via_tasks.args
)
self.tasks = result
debug("Resulting task contexts: {!r}".format(self.tasks))
    def print_task_help(self, name: str) -> None:
        """
        Print help for a specific task, e.g. ``inv --help <taskname>``.
        .. versionadded:: 1.0
        """
        # Setup
        ctx = self.parser.contexts[name]
        tuples = ctx.help_tuples()
        docstring = inspect.getdoc(self.collection[name])
        header = "Usage: {} [--core-opts] {} {}[other tasks here ...]"
        # Only advertise an [--options] slot when the task has any flags.
        opts = "[--options] " if tuples else ""
        print(header.format(self.binary, name, opts))
        print("")
        print("Docstring:")
        if docstring:
            # Really wish textwrap worked better for this.
            # Indent non-empty docstring lines; keep blank lines blank.
            for line in docstring.splitlines():
                if line.strip():
                    print(self.leading_indent + line)
                else:
                    print("")
            print("")
        else:
            print(self.leading_indent + "none")
            print("")
        print("Options:")
        if tuples:
            self.print_columns(tuples)
        else:
            print(self.leading_indent + "none")
            print("")
def list_tasks(self) -> None:
# Short circuit if no tasks to show (Collection now implements bool)
focus = self.scoped_collection
if not focus:
msg = "No tasks found in collection '{}'!"
raise Exit(msg.format(focus.name))
# TODO: now that flat/nested are almost 100% unified, maybe rethink
# this a bit?
getattr(self, "list_{}".format(self.list_format))()
def list_flat(self) -> None:
pairs = self._make_pairs(self.scoped_collection)
self.display_with_columns(pairs=pairs)
def list_nested(self) -> None:
pairs = self._make_pairs(self.scoped_collection)
extra = "'*' denotes collection defaults"
self.display_with_columns(pairs=pairs, extra=extra)
def _make_pairs(
    self,
    coll: "Collection",
    ancestors: Optional[List[str]] = None,
) -> List[Tuple[str, Optional[str]]]:
    """
    Recursively build ``(display_name, help_line)`` pairs for ``coll``.

    :param coll: Collection whose tasks/subcollections to render.
    :param ancestors:
        Names of the parent collections leading to ``coll``; ``None`` (the
        default) means ``coll`` is the root of the display.
    :returns:
        A list of two-tuples suitable for `display_with_columns`, honoring
        ``self.list_format`` ("flat" vs "nested") and ``self.list_depth``.
    """
    if ancestors is None:
        ancestors = []
    pairs = []
    # Nested format: indentation grows with depth. Flat format: names are
    # joined into one dotted path instead.
    indent = len(ancestors) * self.indent
    ancestor_path = ".".join(x for x in ancestors)
    for name, task in sorted(coll.tasks.items()):
        is_default = name == coll.default
        # Start with just the name and just the aliases, no prefixes or
        # dots.
        displayname = name
        aliases = list(map(coll.transform, sorted(task.aliases)))
        # If displaying a sub-collection (or if we are displaying a given
        # namespace/root), tack on some dots to make it clear these names
        # require dotted paths to invoke.
        if ancestors or self.list_root:
            displayname = ".{}".format(displayname)
            aliases = [".{}".format(x) for x in aliases]
        # Nested? Indent, and add asterisks to default-tasks.
        if self.list_format == "nested":
            prefix = indent
            if is_default:
                displayname += "*"
        # Flat? Prefix names and aliases with ancestor names to get full
        # dotted path; and give default-tasks their collection name as the
        # first alias.
        if self.list_format == "flat":
            prefix = ancestor_path
            # Make sure leading dots are present for subcollections if
            # scoped display
            if prefix and self.list_root:
                prefix = "." + prefix
            aliases = [prefix + alias for alias in aliases]
            if is_default and ancestors:
                aliases.insert(0, prefix)
        # Generate full name and help columns and add to pairs.
        alias_str = " ({})".format(", ".join(aliases)) if aliases else ""
        full = prefix + displayname + alias_str
        pairs.append((full, helpline(task)))
    # Determine whether we're at max-depth or not
    truncate = self.list_depth and (len(ancestors) + 1) >= self.list_depth
    for name, subcoll in sorted(coll.collections.items()):
        displayname = name
        if ancestors or self.list_root:
            displayname = ".{}".format(displayname)
        if truncate:
            # At max depth: summarize the subtree as "[N tasks, M
            # collections]" instead of recursing into it.
            tallies = [
                "{} {}".format(len(getattr(subcoll, attr)), attr)
                for attr in ("tasks", "collections")
                if getattr(subcoll, attr)
            ]
            displayname += " [{}]".format(", ".join(tallies))
        if self.list_format == "nested":
            pairs.append((indent + displayname, helpline(subcoll)))
        elif self.list_format == "flat" and truncate:
            # NOTE: only adding coll-oriented pair if limiting by depth
            pairs.append((ancestor_path + displayname, helpline(subcoll)))
        # Recurse, if not already at max depth
        if not truncate:
            recursed_pairs = self._make_pairs(
                coll=subcoll, ancestors=ancestors + [name]
            )
            pairs.extend(recursed_pairs)
    return pairs
def list_json(self) -> None:
    """
    Dump the scoped collection's serialized form as JSON on stdout.

    :raises Exit: if ``--list-depth`` was given (unsupported for JSON).
    """
    # Sanity: we can't cleanly honor the --list-depth argument without
    # changing the data schema or otherwise acting strangely; and it also
    # doesn't make a ton of sense to limit depth when the output is for a
    # script to handle. So we just refuse, for now. TODO: find better way
    if self.list_depth:
        raise Exit(
            "The --list-depth option is not supported with JSON format!"
        )
    # TODO: consider using something more formal re: the format this emits,
    # eg json-schema or whatever. Would simplify the
    # relatively-concise-but-only-human docs that currently describe this.
    print(json.dumps(self.scoped_collection.serialized()))
def task_list_opener(self, extra: str = "") -> str:
    """
    Build the header line shown above task listings.

    :param str extra:
        Optional extra annotation (e.g. the nested-format default-task
        legend) appended inside the parenthetical.
    :returns: e.g. ``"Available 'docs' tasks (depth=2)"``.
    """
    # Bundled-namespace programs present tasks as subcommands instead.
    # TODO: do use cases w/ bundled namespace want to display things like
    # root and depth too? Leaving off for now...
    if self.namespace is not None:
        return "Subcommands"
    specifier = ""
    if self.list_root:
        specifier = " '{}'".format(self.list_root)
    parenthetical = ""
    if self.list_depth or extra:
        depth_part = (
            "depth={}".format(self.list_depth) if self.list_depth else ""
        )
        separator = "; " if (self.list_depth and extra) else ""
        parenthetical = " ({}{}{})".format(depth_part, separator, extra)
    return "Available{} tasks{}".format(specifier, parenthetical)
def display_with_columns(
    self, pairs: Sequence[Tuple[str, Optional[str]]], extra: str = ""
) -> None:
    """
    Print the listing header, the columnized ``pairs``, and a default-task
    footer when the scoped collection declares one.

    :param pairs: ``(name, help)`` tuples, as built by `_make_pairs`.
    :param str extra: Passed through to `task_list_opener`.
    """
    print("{}:\n".format(self.task_list_opener(extra=extra)))
    self.print_columns(pairs)
    # TODO: worth stripping this out for nested? since it's signified with
    # asterisk there? ugggh
    default = self.scoped_collection.default
    if default:
        scope = ""
        if self.list_root:
            scope = " '{}'".format(self.list_root)
            default = ".{}".format(default)
        # TODO: trim/prefix dots
        print("Default{} task: {}\n".format(scope, default))
def print_columns(
    self, tuples: Sequence[Tuple[str, Optional[str]]]
) -> None:
    """
    Print tabbed columns from (name, help) ``tuples``.

    Useful for listing tasks + docstrings, flags + help strings, etc.

    :param tuples:
        Non-empty sequence of ``(name, help)`` pairs; ``help`` may be
        ``None`` (rendered as an empty description).

    .. versionadded:: 1.0
    """
    # Calculate column sizes: don't wrap flag specs, give what's left over
    # to the descriptions.
    name_width = max(len(x[0]) for x in tuples)
    # Description column = terminal width minus name column, indent,
    # inter-column padding, and one spare cell.
    desc_width = (
        pty_size()[0]
        - name_width
        - self.leading_indent_width
        - self.col_padding
        - 1
    )
    wrapper = textwrap.TextWrapper(width=desc_width)
    for name, help_str in tuples:
        if help_str is None:
            help_str = ""
        # Wrap descriptions/help text
        help_chunks = wrapper.wrap(help_str)
        # Print flag spec + padding
        name_padding = name_width - len(name)
        spec = "".join(
            (
                self.leading_indent,
                name,
                name_padding * " ",
                self.col_padding * " ",
            )
        )
        # Print help text as needed; continuation lines are indented to
        # align under the first description line.
        if help_chunks:
            print(spec + help_chunks[0])
        for chunk in help_chunks[1:]:
            print((" " * len(spec)) + chunk)
        else:
            print(spec.rstrip())
    print("")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,519 @@
"""
This module contains the core `.Task` class & convenience decorators used to
generate new tasks.
"""
import inspect
import types
from copy import deepcopy
from functools import update_wrapper
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Generic,
Iterable,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from .context import Context
from .parser import Argument, translate_underscores
if TYPE_CHECKING:
from inspect import Signature
from .config import Config
T = TypeVar("T", bound=Callable)
class Task(Generic[T]):
    """
    Core object representing an executable task & its argument specification.

    For the most part, this object is a clearinghouse for all of the data that
    may be supplied to the `@task <invoke.tasks.task>` decorator, such as
    ``name``, ``aliases``, ``positional`` etc, which appear as attributes.

    In addition, instantiation copies some introspection/documentation friendly
    metadata off of the supplied ``body`` object, such as ``__doc__``,
    ``__name__`` and ``__module__``, allowing it to "appear as" ``body`` for
    most intents and purposes.

    .. versionadded:: 1.0
    """

    # TODO: store these kwarg defaults central, refer to those values both here
    # and in @task.
    # TODO: allow central per-session / per-taskmodule control over some of
    # them, e.g. (auto_)positional, auto_shortflags.
    # NOTE: we shadow __builtins__.help here on purpose - obfuscating to avoid
    # it feels bad, given the builtin will never actually be in play anywhere
    # except a debug shell whose frame is exactly inside this class.
    def __init__(
        self,
        body: Callable,
        name: Optional[str] = None,
        aliases: Iterable[str] = (),
        positional: Optional[Iterable[str]] = None,
        optional: Iterable[str] = (),
        default: bool = False,
        auto_shortflags: bool = True,
        help: Optional[Dict[str, Any]] = None,
        pre: Optional[Union[List[str], str]] = None,
        post: Optional[Union[List[str], str]] = None,
        autoprint: bool = False,
        iterable: Optional[Iterable[str]] = None,
        incrementable: Optional[Iterable[str]] = None,
    ) -> None:
        # Real callable
        self.body = body
        update_wrapper(self, self.body)
        # Copy a bunch of special properties from the body for the benefit of
        # Sphinx autodoc or other introspectors.
        self.__doc__ = getattr(body, "__doc__", "")
        self.__name__ = getattr(body, "__name__", "")
        self.__module__ = getattr(body, "__module__", "")
        # Default name, alternate names, and whether it should act as the
        # default for its parent collection
        self._name = name
        self.aliases = aliases
        self.is_default = default
        # Arg/flag/parser hints
        self.positional = self.fill_implicit_positionals(positional)
        self.optional = tuple(optional)
        self.iterable = iterable or []
        self.incrementable = incrementable or []
        self.auto_shortflags = auto_shortflags
        # NOTE: shallow copy, so caller's dict object isn't mutated; see
        # arg_opts below, which pops entries from this dict.
        self.help = (help or {}).copy()
        # Call chain bidness
        self.pre = pre or []
        self.post = post or []
        self.times_called = 0
        # Whether to print return value post-execution
        self.autoprint = autoprint

    @property
    def name(self) -> str:
        """Explicitly-given name if any, else the wrapped body's ``__name__``."""
        return self._name or self.__name__

    def __repr__(self) -> str:
        aliases = ""
        if self.aliases:
            aliases = " ({})".format(", ".join(self.aliases))
        return "<Task {!r}{}>".format(self.name, aliases)

    def __eq__(self, other: object) -> bool:
        # Tasks are equal when names match and bodies are the same callable
        # (or at least share a code object).
        if not isinstance(other, Task) or self.name != other.name:
            return False
        # Functions do not define __eq__ but func_code objects apparently do.
        # (If we're wrapping some other callable, they will be responsible for
        # defining equality on their end.)
        if self.body == other.body:
            return True
        else:
            try:
                return self.body.__code__ == other.body.__code__
            except AttributeError:
                return False

    def __hash__(self) -> int:
        # Presumes name and body will never be changed. Hrm.
        # Potentially cleaner to just not use Tasks as hash keys, but let's do
        # this for now.
        return hash(self.name) + hash(self.body)

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Invoke the wrapped body, requiring a `.Context` as first arg."""
        # Guard against calling tasks with no context.
        if not isinstance(args[0], Context):
            err = "Task expected a Context as its first arg, got {} instead!"
            # TODO: raise a custom subclass _of_ TypeError instead
            raise TypeError(err.format(type(args[0])))
        result = self.body(*args, **kwargs)
        self.times_called += 1
        return result

    @property
    def called(self) -> bool:
        """Whether this task has been invoked at least once."""
        return self.times_called > 0

    def argspec(self, body: Callable) -> "Signature":
        """
        Returns a modified `inspect.Signature` based on that of ``body``.

        :returns:
            an `inspect.Signature` matching that of ``body``, but with the
            initial context argument removed.
        :raises TypeError:
            if the task lacks an initial positional `.Context` argument.

        .. versionadded:: 1.0
        .. versionchanged:: 2.0
            Changed from returning a two-tuple of ``(arg_names, spec_dict)`` to
            returning an `inspect.Signature`.
        """
        # Handle callable-but-not-function objects
        func = (
            body
            if isinstance(body, types.FunctionType)
            else body.__call__  # type: ignore
        )
        # Rebuild signature with first arg dropped, or die usefully(ish trying
        sig = inspect.signature(func)
        params = list(sig.parameters.values())
        # TODO: this ought to also check if an extant 1st param _was_ a Context
        # arg, and yell similarly if not.
        if not len(params):
            # TODO: see TODO under __call__, this should be same type
            raise TypeError("Tasks must have an initial Context argument!")
        return sig.replace(parameters=params[1:])

    def fill_implicit_positionals(
        self, positional: Optional[Iterable[str]]
    ) -> Iterable[str]:
        """Derive positional arg names when the user gave none explicitly."""
        # If positionals is None, everything lacking a default
        # value will be automatically considered positional.
        if positional is None:
            positional = [
                x.name
                for x in self.argspec(self.body).parameters.values()
                if x.default is inspect.Signature.empty
            ]
        return positional

    def arg_opts(
        self, name: str, default: str, taken_names: Set[str]
    ) -> Dict[str, Any]:
        """
        Build the kwargs dict used to construct one parser `.Argument`.

        :param str name: The Python-level parameter name.
        :param default: The parameter's default value (or ``Signature.empty``).
        :param taken_names: Flag names already claimed; used for shortflags.

        NOTE: pops matched entries out of ``self.help``, so repeated
        `get_arguments` calls observe a progressively emptied help dict.
        """
        opts: Dict[str, Any] = {}
        # Whether it's positional or not
        opts["positional"] = name in self.positional
        # Whether it is a value-optional flag
        opts["optional"] = name in self.optional
        # Whether it should be of an iterable (list) kind
        if name in self.iterable:
            opts["kind"] = list
            # If user gave a non-None default, hopefully they know better
            # than us what they want here (and hopefully it offers the list
            # protocol...) - otherwise supply useful default
            opts["default"] = default if default is not None else []
        # Whether it should increment its value or not
        if name in self.incrementable:
            opts["incrementable"] = True
        # Argument name(s) (replace w/ dashed version if underscores present,
        # and move the underscored version to be the attr_name instead.)
        original_name = name  # For reference in eg help=
        if "_" in name:
            opts["attr_name"] = name
            name = translate_underscores(name)
        names = [name]
        if self.auto_shortflags:
            # Must know what short names are available
            for char in name:
                if not (char == name or char in taken_names):
                    names.append(char)
                    break
        opts["names"] = names
        # Handle default value & kind if possible
        if default not in (None, inspect.Signature.empty):
            # TODO: allow setting 'kind' explicitly.
            # NOTE: skip setting 'kind' if optional is True + type(default) is
            # bool; that results in a nonsensical Argument which gives the
            # parser grief in a few ways.
            kind = type(default)
            if not (opts["optional"] and kind is bool):
                opts["kind"] = kind
            opts["default"] = default
        # Help
        for possibility in name, original_name:
            if possibility in self.help:
                opts["help"] = self.help.pop(possibility)
                break
        return opts

    def get_arguments(
        self, ignore_unknown_help: Optional[bool] = None
    ) -> List[Argument]:
        """
        Return a list of Argument objects representing this task's signature.

        :param bool ignore_unknown_help:
            Controls whether unknown help flags cause errors. See the config
            option by the same name for details.

        .. versionadded:: 1.0
        .. versionchanged:: 1.7
            Added the ``ignore_unknown_help`` kwarg.
        """
        # Core argspec
        sig = self.argspec(self.body)
        # Prime the list of all already-taken names (mostly for help in
        # choosing auto shortflags)
        taken_names = set(sig.parameters.keys())
        # Build arg list (arg_opts will take care of setting up shortnames,
        # etc)
        args = []
        for param in sig.parameters.values():
            new_arg = Argument(
                **self.arg_opts(param.name, param.default, taken_names)
            )
            args.append(new_arg)
            # Update taken_names list with new argument's full name list
            # (which may include new shortflags) so subsequent Argument
            # creation knows what's taken.
            taken_names.update(set(new_arg.names))
        # If any values were leftover after consuming a 'help' dict, it implies
        # the user messed up & had a typo or similar. Let's explode.
        if self.help and not ignore_unknown_help:
            raise ValueError(
                "Help field was set for param(s) that don't exist: {}".format(
                    list(self.help.keys())
                )
            )
        # Now we need to ensure positionals end up in the front of the list, in
        # order given in self.positionals, so that when Context consumes them,
        # this order is preserved.
        for posarg in reversed(list(self.positional)):
            for i, arg in enumerate(args):
                if arg.name == posarg:
                    args.insert(0, args.pop(i))
                    break
        return args
def task(*args: Any, **kwargs: Any) -> Callable:
    """
    Marks wrapped callable object as a valid Invoke task.

    May be called without any parentheses if no extra options need to be
    specified. Otherwise, the following keyword arguments are allowed in the
    parenthese'd form:

    * ``name``: Default name to use when binding to a `.Collection`. Useful for
      avoiding Python namespace issues (i.e. when the desired CLI level name
      can't or shouldn't be used as the Python level name.)
    * ``aliases``: Specify one or more aliases for this task, allowing it to be
      invoked as multiple different names. For example, a task named ``mytask``
      with a simple ``@task`` wrapper may only be invoked as ``"mytask"``.
      Changing the decorator to be ``@task(aliases=['myothertask'])`` allows
      invocation as ``"mytask"`` *or* ``"myothertask"``.
    * ``positional``: Iterable overriding the parser's automatic "args with no
      default value are considered positional" behavior. If a list of arg
      names, no args besides those named in this iterable will be considered
      positional. (This means that an empty list will force all arguments to be
      given as explicit flags.)
    * ``optional``: Iterable of argument names, declaring those args to
      have :ref:`optional values <optional-values>`. Such arguments may be
      given as value-taking options (e.g. ``--my-arg=myvalue``, wherein the
      task is given ``"myvalue"``) or as Boolean flags (``--my-arg``, resulting
      in ``True``).
    * ``iterable``: Iterable of argument names, declaring them to :ref:`build
      iterable values <iterable-flag-values>`.
    * ``incrementable``: Iterable of argument names, declaring them to
      :ref:`increment their values <incrementable-flag-values>`.
    * ``default``: Boolean option specifying whether this task should be its
      collection's default task (i.e. called if the collection's own name is
      given.)
    * ``auto_shortflags``: Whether or not to automatically create short
      flags from task options; defaults to True.
    * ``help``: Dict mapping argument names to their help strings. Will be
      displayed in ``--help`` output. For arguments containing underscores
      (which are transformed into dashes on the CLI by default), either the
      dashed or underscored version may be supplied here.
    * ``pre``, ``post``: Lists of task objects to execute prior to, or after,
      the wrapped task whenever it is executed.
    * ``autoprint``: Boolean determining whether to automatically print this
      task's return value to standard output when invoked directly via the CLI.
      Defaults to False.
    * ``klass``: Class to instantiate/return. Defaults to `.Task`.

    If any non-keyword arguments are given, they are taken as the value of the
    ``pre`` kwarg for convenience's sake. (It is an error to give both
    ``*args`` and ``pre`` at the same time.)

    .. versionadded:: 1.0
    .. versionchanged:: 1.1
        Added the ``klass`` keyword argument.
    """
    klass: Type[Task] = kwargs.pop("klass", Task)
    # Bare-decorator case (@task with no parens): the decorated function
    # arrives as the sole positional argument.
    if len(args) == 1 and callable(args[0]) and not isinstance(args[0], Task):
        return klass(args[0], **kwargs)
    # Positional args are shorthand for pre-tasks; forbid both spellings.
    if args:
        if "pre" in kwargs:
            raise TypeError(
                "May not give *args and 'pre' kwarg simultaneously!"
            )
        kwargs["pre"] = args

    # Parenthesized form: return a decorator closing over the options.
    def inner(body: Callable) -> Task[T]:
        return klass(body, **kwargs)

    return inner
class Call:
    """
    Represents a call/execution of a `.Task` with given (kw)args.

    Similar to `~functools.partial` with some added functionality (such as the
    delegation to the inner task, and optional tracking of the name it's being
    called by.)

    .. versionadded:: 1.0
    """

    def __init__(
        self,
        task: "Task",
        called_as: Optional[str] = None,
        args: Optional[Tuple[str, ...]] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Create a new `.Call` object.

        :param task: The `.Task` object to be executed.

        :param str called_as:
            The name the task is being called as, e.g. if it was called by an
            alias or other rebinding. Defaults to ``None``, aka, the task was
            referred to by its default name.

        :param tuple args:
            Positional arguments to call with, if any. Default: ``None``.

        :param dict kwargs:
            Keyword arguments to call with, if any. Default: ``None``.
        """
        self.task = task
        self.called_as = called_as
        self.args = args or tuple()
        self.kwargs = kwargs or dict()

    # TODO: just how useful is this? feels like maybe overkill magic
    def __getattr__(self, name: str) -> Any:
        # Delegate any attribute we don't define ourselves to the inner task.
        return getattr(self.task, name)

    def __deepcopy__(self, memo: object) -> "Call":
        return self.clone()

    def __repr__(self) -> str:
        aka = ""
        if self.called_as is not None and self.called_as != self.task.name:
            aka = " (called as: {!r})".format(self.called_as)
        return "<{} {!r}{}, args: {!r}, kwargs: {!r}>".format(
            self.__class__.__name__,
            self.task.name,
            aka,
            self.args,
            self.kwargs,
        )

    def __eq__(self, other: object) -> bool:
        # FIX: previously, comparing against a non-Call (or anything lacking
        # .task/.args/.kwargs) raised AttributeError instead of comparing
        # unequal. Return NotImplemented so Python falls back to its default
        # handling (which yields False for ==).
        if not isinstance(other, Call):
            return NotImplemented  # type: ignore[return-value]
        # NOTE: Not comparing 'called_as'; a named call of a given Task with
        # same args/kwargs should be considered same as an unnamed call of the
        # same Task with the same args/kwargs (e.g. pre/post task specified w/o
        # name). Ditto tasks with multiple aliases.
        for attr in "task args kwargs".split():
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def make_context(self, config: "Config") -> "Context":
        """
        Generate a `.Context` appropriate for this call, with given config.

        .. versionadded:: 1.0
        """
        return Context(config=config)

    def clone_data(self) -> Dict[str, Any]:
        """
        Return keyword args suitable for cloning this call into another.

        .. versionadded:: 1.1
        """
        return dict(
            task=self.task,
            called_as=self.called_as,
            args=deepcopy(self.args),
            kwargs=deepcopy(self.kwargs),
        )

    def clone(
        self,
        into: Optional[Type["Call"]] = None,
        with_: Optional[Dict[str, Any]] = None,
    ) -> "Call":
        """
        Return a standalone copy of this Call.

        Useful when parameterizing task executions.

        :param into:
            A subclass to generate instead of the current class. Optional.

        :param dict with_:
            A dict of additional keyword arguments to use when creating the new
            clone; typically used when cloning ``into`` a subclass that has
            extra args on top of the base class. Optional.

            .. note::
                This dict is used to ``.update()`` the original object's data
                (the return value from its `clone_data`), so in the event of
                a conflict, values in ``with_`` will win out.

        .. versionadded:: 1.0
        .. versionchanged:: 1.1
            Added the ``with_`` kwarg.
        """
        klass = into if into is not None else self.__class__
        data = self.clone_data()
        if with_ is not None:
            data.update(with_)
        return klass(**data)
def call(task: "Task", *args: Any, **kwargs: Any) -> "Call":
    """
    Describes execution of a `.Task`, typically with pre-supplied arguments.

    Useful for setting up :ref:`pre/post task invocations
    <parameterizing-pre-post-tasks>`. It's actually just a convenient wrapper
    around the `.Call` class, which may be used directly instead if desired.

    For example, here's two build-like tasks that both refer to a ``setup``
    pre-task, one with no baked-in argument values (and thus no need to use
    `.call`), and one that toggles a boolean flag::

        @task
        def setup(c, clean=False):
            if clean:
                c.run("rm -rf target")
            # ... setup things here ...
            c.run("tar czvf target.tgz target")

        @task(pre=[setup])
        def build(c):
            c.run("build, accounting for leftover files...")

        @task(pre=[call(setup, clean=True)])
        def clean_build(c):
            c.run("build, assuming clean slate...")

    Please see the constructor docs for `.Call` for details - this function's
    ``args`` and ``kwargs`` map directly to the same arguments as in that
    method.

    .. versionadded:: 1.0
    """
    # Thin sugar: bundle the task + its baked-in arguments into a Call.
    wrapped = Call(task, args=args, kwargs=kwargs)
    return wrapped

View File

@@ -0,0 +1,248 @@
"""
Utility functions surrounding terminal devices & I/O.
Much of this code performs platform-sensitive branching, e.g. Windows support.
This is its own module to abstract away what would otherwise be distracting
logic-flow interruptions.
"""
from contextlib import contextmanager
from typing import Generator, IO, Optional, Tuple
import os
import select
import sys
# TODO: move in here? They're currently platform-agnostic...
from .util import has_fileno, isatty
WINDOWS = sys.platform == "win32"
"""
Whether or not the current platform appears to be Windows in nature.
Note that Cygwin's Python is actually close enough to "real" UNIXes that it
doesn't need (or want!) to use PyWin32 -- so we only test for literal Win32
setups (vanilla Python, ActiveState etc) here.
.. versionadded:: 1.0
"""
if sys.platform == "win32":
import msvcrt
from ctypes import (
Structure,
c_ushort,
windll,
POINTER,
byref,
)
from ctypes.wintypes import HANDLE, _COORD, _SMALL_RECT
else:
import fcntl
import struct
import termios
import tty
if sys.platform == "win32":

    def _pty_size() -> Tuple[Optional[int], Optional[int]]:
        """
        Windows console size probe via ``GetConsoleScreenBufferInfo``.

        :returns:
            ``(cols, rows)``, or ``(None, None)`` when the kernel32 call
            fails (e.g. no attached console).
        """

        class CONSOLE_SCREEN_BUFFER_INFO(Structure):
            _fields_ = [
                ("dwSize", _COORD),
                ("dwCursorPosition", _COORD),
                ("wAttributes", c_ushort),
                ("srWindow", _SMALL_RECT),
                ("dwMaximumWindowSize", _COORD),
            ]

        GetStdHandle = windll.kernel32.GetStdHandle
        GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo
        GetStdHandle.restype = HANDLE
        GetConsoleScreenBufferInfo.argtypes = [
            HANDLE,
            POINTER(CONSOLE_SCREEN_BUFFER_INFO),
        ]
        hstd = GetStdHandle(-11)  # STD_OUTPUT_HANDLE = -11
        csbi = CONSOLE_SCREEN_BUFFER_INFO()
        ret = GetConsoleScreenBufferInfo(hstd, byref(csbi))
        if ret:
            # Window rectangle is inclusive on both edges, hence the +1s.
            sizex = csbi.srWindow.Right - csbi.srWindow.Left + 1
            sizey = csbi.srWindow.Bottom - csbi.srWindow.Top + 1
            return sizex, sizey
        else:
            return (None, None)

else:

    def _pty_size() -> Tuple[Optional[int], Optional[int]]:
        """
        Suitable for most POSIX platforms.

        .. versionadded:: 1.0
        """
        # Sentinel values to be replaced w/ defaults by caller
        size = (None, None)
        # We want two short unsigned integers (rows, cols)
        # Note: TIOCGWINSZ struct contains 4 unsigned shorts, 2 unused
        fmt = "HHHH"
        # Create an empty (zeroed) buffer for ioctl to map onto. Yay for C!
        buf = struct.pack(fmt, 0, 0, 0, 0)
        # Call TIOCGWINSZ to get window size of stdout, returns our filled
        # buffer
        try:
            result = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, buf)
            # Unpack buffer back into Python data types
            # NOTE: this unpack gives us rows x cols, but we return the
            # inverse.
            rows, cols, *_ = struct.unpack(fmt, result)
            return (cols, rows)
        # Fallback to emptyish return value in various failure cases:
        # * sys.stdout being monkeypatched, such as in testing, and lacking
        #   .fileno
        # * sys.stdout having a .fileno but not actually being attached to a
        #   TTY
        # * termios not having a TIOCGWINSZ attribute (happens sometimes...)
        # * other situations where ioctl doesn't explode but the result isn't
        #   something unpack can deal with
        except (struct.error, TypeError, IOError, AttributeError):
            pass
        return size
def pty_size() -> Tuple[int, int]:
    """
    Determine current local pseudoterminal dimensions.

    :returns:
        A ``(num_cols, num_rows)`` two-tuple describing PTY size. Defaults to
        ``(80, 24)`` if unable to get a sensible result dynamically.

    .. versionadded:: 1.0
    """
    # TODO: make defaults configurable?
    width, height = _pty_size()
    # Substitute defaults for any dimension the platform probe couldn't
    # determine (None) or reported as zero.
    return (width if width else 80, height if height else 24)
def stdin_is_foregrounded_tty(stream: IO) -> bool:
    """
    Detect if given stdin ``stream`` seems to be in the foreground of a TTY.

    Specifically, compares the current Python process group ID to that of the
    stream's file descriptor to see if they match; if they do not match, it is
    likely that the process has been placed in the background.

    This is used as a test to determine whether we should manipulate an active
    stdin so it runs in a character-buffered mode; touching the terminal in
    this way when the process is backgrounded, causes most shells to pause
    execution.

    .. note::
        Processes that aren't attached to a terminal to begin with, will always
        fail this test, as it starts with "do you have a real ``fileno``?".

    .. versionadded:: 1.0
    """
    # Matching process group IDs means we own the terminal's foreground.
    if has_fileno(stream):
        return os.getpgrp() == os.tcgetpgrp(stream.fileno())
    return False
def cbreak_already_set(stream: IO) -> bool:
    # Explicitly not docstringed to remain private, for now. Eh.
    # Checks whether tty.setcbreak appears to have already been run against
    # ``stream`` (or if it would otherwise just not do anything).
    # Used to effect idempotency for character-buffering a stream, which also
    # lets us avoid multiple capture-then-restore cycles.
    settings = termios.tcgetattr(stream)
    local_flags = settings[3]
    control_chars = settings[6]
    # setcbreak sets ECHO and ICANON to 0/off, CC[VMIN] to 1-ish, and CC[VTIME]
    # to 0-ish. If any of that is not true we can reasonably assume it has not
    # yet been executed against this stream.
    if local_flags & termios.ECHO:
        return False
    if local_flags & termios.ICANON:
        return False
    vmin_ok = control_chars[termios.VMIN] in [1, b"\x01"]
    vtime_ok = control_chars[termios.VTIME] in [0, b"\x00"]
    return vmin_ok and vtime_ok
@contextmanager
def character_buffered(
    stream: IO,
) -> Generator[None, None, None]:
    """
    Force local terminal ``stream`` be character, not line, buffered.

    Only applies to Unix-based systems; on Windows this is a no-op.

    :param stream: The (typically stdin) stream to toggle.

    .. versionadded:: 1.0
    """
    # Skip entirely when there's nothing safe/useful to do: non-Unix
    # platform, not a TTY, backgrounded process, or cbreak already active
    # (keeps this idempotent and avoids nested capture/restore cycles).
    if (
        WINDOWS
        or not isatty(stream)
        or not stdin_is_foregrounded_tty(stream)
        or cbreak_already_set(stream)
    ):
        yield
    else:
        # Capture current settings first so they can be restored verbatim.
        old_settings = termios.tcgetattr(stream)
        tty.setcbreak(stream)
        try:
            yield
        finally:
            termios.tcsetattr(stream, termios.TCSADRAIN, old_settings)
def ready_for_reading(input_: IO) -> bool:
    """
    Test ``input_`` to determine whether a read action will succeed.

    :param input_: Input stream object (file-like).

    :returns: ``True`` if a read should succeed, ``False`` otherwise.

    .. versionadded:: 1.0
    """
    # Objects without a real file descriptor (StringIO, plain files) can be
    # read in a nonblocking fashion without help from select/kbhit.
    if not has_fileno(input_):
        return True
    if sys.platform == "win32":
        return msvcrt.kbhit()
    # Zero-timeout select: poll whether our stream is readable right now.
    readable, _, _ = select.select([input_], [], [], 0.0)
    return bool(readable and readable[0] is input_)
def bytes_to_read(input_: IO) -> int:
    """
    Query stream ``input_`` to see how many bytes may be readable.

    .. note::
        If we are unable to tell (e.g. if ``input_`` isn't a true file
        descriptor or isn't a valid TTY) we fall back to suggesting reading 1
        byte only.

    :param input: Input stream object (file-like).

    :returns: `int` number of bytes to read.

    .. versionadded:: 1.0
    """
    # NOTE: we have to check both possibilities here; situations exist where
    # it's not a tty but has a fileno, or vice versa; neither is typically
    # going to work re: ioctl().
    if WINDOWS or not isatty(input_) or not has_fileno(input_):
        return 1
    fionread = fcntl.ioctl(input_, termios.FIONREAD, b"  ")
    return int(struct.unpack("h", fionread)[0])

View File

@@ -0,0 +1,268 @@
from collections import namedtuple
from contextlib import contextmanager
from types import TracebackType
from typing import Any, Generator, List, IO, Optional, Tuple, Type, Union
import io
import logging
import os
import threading
import sys
# NOTE: This is the canonical location for commonly-used vendored modules,
# which is the only spot that performs this try/except to allow repackaged
# Invoke to function (e.g. distro packages which unvendor the vendored bits and
# thus must import our 'vendored' stuff from the overall environment.)
# All other uses of Lexicon, etc should do 'from .util import lexicon' etc.
# Saves us from having to update the same logic in a dozen places.
# TODO: would this make more sense to put _into_ invoke.vendor? That way, the
# import lines which now read 'from .util import <third party stuff>' would be
# more obvious. Requires packagers to leave invoke/vendor/__init__.py alone tho
try:
from .vendor.lexicon import Lexicon # noqa
from .vendor import yaml # noqa
except ImportError:
from lexicon import Lexicon # type: ignore[no-redef] # noqa
import yaml # type: ignore[no-redef] # noqa
LOG_FORMAT = "%(name)s.%(module)s.%(funcName)s: %(message)s"
def enable_logging() -> None:
    """Turn on DEBUG-level logging for the whole process, using LOG_FORMAT."""
    logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
# Allow from-the-start debugging (vs toggled during load of tasks module) via
# shell env var. Any non-empty value of INVOKE_DEBUG enables it.
if os.environ.get("INVOKE_DEBUG"):
    enable_logging()

# Add top level logger functions to global namespace. Meh.
# NOTE: `debug` is re-exported and used throughout the package as a bare
# function rather than via the logger object.
log = logging.getLogger("invoke")
debug = log.debug
def task_name_sort_key(name: str) -> Tuple[List[str], str]:
    """
    Return key tuple for use sorting dotted task names, via e.g. `sorted`.

    .. versionadded:: 1.0
    """
    *ancestry, leaf = name.split(".")
    # Grouping by ancestry first keeps each hierarchy together and floats
    # top-level tasks (whose ancestry is the empty list) to the front; ties
    # then break lexicographically on the leaf task name.
    return (ancestry, leaf)
# TODO: Make part of public API sometime
@contextmanager
def cd(where: str) -> Generator[None, None, None]:
    """Temporarily chdir into ``where``, restoring the previous cwd on exit."""
    previous = os.getcwd()
    os.chdir(where)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        os.chdir(previous)
def has_fileno(stream: IO) -> bool:
    """
    Cleanly determine whether ``stream`` has a useful ``.fileno()``.

    .. note::
        This function helps determine if a given file-like object can be used
        with various terminal-oriented modules and functions such as `select`,
        `termios`, and `tty`. For most of those, a fileno is all that is
        required; they'll function even if ``stream.isatty()`` is ``False``.

    :param stream: A file-like object.

    :returns:
        ``True`` if ``stream.fileno()`` returns an integer, ``False`` otherwise
        (this includes when ``stream`` lacks a ``fileno`` method, or when it
        has been closed).

    .. versionadded:: 1.0
    """
    try:
        return isinstance(stream.fileno(), int)
    # AttributeError: no .fileno at all; UnsupportedOperation: e.g. StringIO;
    # ValueError: fileno() on a *closed* file object raises this - previously
    # it escaped uncaught and crashed callers using us as a guard.
    except (AttributeError, io.UnsupportedOperation, ValueError):
        return False
def isatty(stream: IO) -> Union[bool, Any]:
    """
    Cleanly determine whether ``stream`` is a TTY.

    Specifically, first try calling ``stream.isatty()``, and if that fails
    (e.g. due to lacking the method entirely) fallback to `os.isatty`.

    .. note::
        Most of the time, we don't actually care about true TTY-ness, but
        merely whether the stream seems to have a fileno (per `has_fileno`).
        However, in some cases (notably the use of `pty.fork` to present a
        local pseudoterminal) we need to tell if a given stream has a valid
        fileno but *isn't* tied to an actual terminal. Thus, this function.

    :param stream: A file-like object.

    :returns:
        A boolean depending on the result of calling ``.isatty()`` and/or
        `os.isatty`.

    .. versionadded:: 1.0
    """
    # Prefer the stream's own answer, when it has one.
    probe = getattr(stream, "isatty", None)
    if callable(probe):
        return probe()
    # Otherwise fall back to asking the OS about the file descriptor.
    if has_fileno(stream):
        return os.isatty(stream.fileno())
    # Neither approach applies; assume not a real TTY.
    return False
def helpline(obj: object) -> Optional[str]:
    """
    Yield an object's first docstring line, or None if there was no docstring.

    Docstrings that are missing, whitespace-only, or merely inherited from the
    object's type (e.g. an ``int`` instance reporting ``int.__doc__``) count
    as absent.

    .. versionadded:: 1.0
    """
    doc = obj.__doc__
    unusable = not doc or not doc.strip() or doc == type(obj).__doc__
    if unusable:
        return None
    return doc.lstrip().splitlines()[0]
class ExceptionHandlingThread(threading.Thread):
    """
    Thread handler making it easier for parent to handle thread exceptions.

    Based in part on Fabric 1's ThreadHandler. See also Fabric GH issue #204.

    When used directly, can be used in place of a regular ``threading.Thread``.
    If subclassed, the subclass must do one of:

    - supply ``target`` to ``__init__``
    - define ``_run()`` instead of ``run()``

    This is because this thread's entire point is to wrap behavior around the
    thread's execution; subclasses could not redefine ``run()`` without
    breaking that functionality.

    .. versionadded:: 1.0
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Create a new exception-handling thread instance.

        Takes all regular `threading.Thread` keyword arguments, via
        ``**kwargs`` for easier display of thread identity when raising
        captured exceptions.
        """
        super().__init__(**kwargs)
        # No record of why, but Fabric used daemon threads ever since the
        # switch from select.select, so let's keep doing that.
        self.daemon = True
        # Track exceptions raised in run()
        self.kwargs = kwargs
        # TODO: legacy cruft that needs to be removed
        # Populated by run() with sys.exc_info() when the thread body raises.
        self.exc_info: Optional[
            Union[
                Tuple[Type[BaseException], BaseException, TracebackType],
                Tuple[None, None, None],
            ]
        ] = None

    def run(self) -> None:
        """Execute the thread body, capturing any exception into exc_info."""
        try:
            # Allow subclasses implemented using the "override run()'s body"
            # approach to work, by using _run() instead of run(). If that
            # doesn't appear to be the case, then assume we're being used
            # directly and just use super() ourselves.
            # XXX https://github.com/python/mypy/issues/1424
            if hasattr(self, "_run") and callable(self._run):  # type: ignore
                # TODO: this could be:
                # - io worker with no 'result' (always local)
                # - tunnel worker, also with no 'result' (also always local)
                # - threaded concurrent run(), sudo(), put(), etc, with a
                # result (not necessarily local; might want to be a subproc or
                # whatever eventually)
                # TODO: so how best to conditionally add a "capture result
                # value of some kind"?
                # - update so all use cases use subclassing, add functionality
                # alongside self.exception() that is for the result of _run()
                # - split out class that does not care about result of _run()
                # and let it continue acting like a normal thread (meh)
                # - assume the run/sudo/etc case will use a queue inside its
                # worker body, orthogonal to how exception handling works
                self._run()  # type: ignore
            else:
                super().run()
        except BaseException:
            # Store for actual reraising later
            self.exc_info = sys.exc_info()
            # And log now, in case we never get to later (e.g. if executing
            # program is hung waiting for us to do something)
            msg = "Encountered exception {!r} in thread for {!r}"
            # Name is either target function's dunder-name, or just "_run" if
            # we were run subclass-wise.
            name = "_run"
            if "target" in self.kwargs:
                name = self.kwargs["target"].__name__
            debug(msg.format(self.exc_info[1], name))  # noqa

    def exception(self) -> Optional["ExceptionWrapper"]:
        """
        If an exception occurred, return an `.ExceptionWrapper` around it.

        :returns:
            An `.ExceptionWrapper` managing the result of `sys.exc_info`, if an
            exception was raised during thread execution. If no exception
            occurred, returns ``None`` instead.

        .. versionadded:: 1.0
        """
        if self.exc_info is None:
            return None
        return ExceptionWrapper(self.kwargs, *self.exc_info)

    @property
    def is_dead(self) -> bool:
        """
        Returns ``True`` if not alive and has a stored exception.

        Used to detect threads that have excepted & shut down.

        .. versionadded:: 1.0
        """
        # NOTE: it seems highly unlikely that a thread could still be
        # is_alive() but also have encountered an exception. But hey. Why not
        # be thorough?
        return (not self.is_alive()) and self.exc_info is not None

    def __repr__(self) -> str:
        # TODO: beef this up more
        # NOTE(review): assumes a 'target' kwarg was given; raises KeyError
        # for subclass-style usage -- confirm intended.
        return str(self.kwargs["target"].__name__)
# NOTE: ExceptionWrapper defined here, not in exceptions.py, to avoid circular
# dependency issues (e.g. Failure subclasses need to use some bits from this
# module...)

#: A namedtuple wrapping a thread-borne exception & that thread's arguments.
#: Mostly used as an intermediate between `.ExceptionHandlingThread` (which
#: preserves initial exceptions) and `.ThreadException` (which holds 1..N such
#: exceptions, as typically multiple threads are involved.)
ExceptionWrapper = namedtuple(
    "ExceptionWrapper", ["kwargs", "type", "value", "traceback"]
)

View File

@@ -0,0 +1,4 @@
from .machine import (StateMachine, state, transition,
InvalidConfiguration, InvalidTransition,
GuardNotSatisfied, ForkedTransition)

View File

@@ -0,0 +1,8 @@
import sys

# Python 3.0/3.1 removed the ``callable`` builtin (it returned in 3.2), so
# provide a portable replacement; on Python 2 the builtin is re-exported.
if sys.version_info >= (3,):
    def callable(obj):
        # NOTE(review): hasattr-based check treats an *instance* attribute
        # named __call__ as making the object callable, which the 3.2+
        # builtin does not -- presumably acceptable for fluidity's usage.
        return hasattr(obj, '__call__')
else:
    callable = callable

View File

@@ -0,0 +1,270 @@
import re
import inspect
from .backwardscompat import callable
# metaclass implementation idea from
# http://blog.ianbicking.org/more-on-python-metaprogramming-comment-14.html

# Module-level accumulator for transition() calls made inside a class body;
# consumed and reset by MetaStateMachine.__new__.
_transition_gatherer = []


def transition(event, from_, to, action=None, guard=None):
    # Record a transition declared at class-definition time.
    _transition_gatherer.append([event, from_, to, action, guard])
# Module-level accumulator for state() calls made inside a class body;
# consumed and reset by MetaStateMachine.__new__.
_state_gatherer = []


def state(name, enter=None, exit=None):
    # Record a state (with optional enter/exit actions) declared at
    # class-definition time.
    _state_gatherer.append([name, enter, exit])
class MetaStateMachine(type):
    """
    Metaclass that adopts the states/transitions gathered while the class
    body executed (via module-level ``state()``/``transition()`` calls).
    """

    def __new__(cls, name, bases, dictionary):
        global _transition_gatherer, _state_gatherer
        Machine = super(MetaStateMachine, cls).__new__(cls, name, bases, dictionary)
        Machine._class_transitions = []
        Machine._class_states = {}
        # States must be registered before transitions, since transitions
        # look their endpoint states up in _class_states.
        for s in _state_gatherer:
            Machine._add_class_state(*s)
        for i in _transition_gatherer:
            Machine._add_class_transition(*i)
        # Reset the gatherers so the next machine class starts clean.
        _transition_gatherer = []
        _state_gatherer = []
        return Machine


# Python 2/3-compatible way of producing a base class with this metaclass.
StateMachineBase = MetaStateMachine('StateMachineBase', (object, ), {})
class StateMachine(StateMachineBase):
    """
    User-facing state machine base class.

    Subclasses declare ``initial_state`` plus ``state()``/``transition()``
    calls in their class body; instances may additionally add states and
    transitions at runtime via `add_state`/`add_transition`.
    """

    def __init__(self):
        self._bring_definitions_to_object_level()
        self._inject_into_parts()
        self._validate_machine_definitions()
        # initial_state may itself be a callable computing the state name.
        if callable(self.initial_state):
            self.initial_state = self.initial_state()
        self._current_state_object = self._state_by_name(self.initial_state)
        # Entering the initial state fires its enter action(s) immediately.
        self._current_state_object.run_enter(self)
        self._create_state_getters()

    def __new__(cls, *args, **kwargs):
        # Per-instance containers are set up in __new__ so __init__ can
        # copy class-level definitions into them.
        obj = super(StateMachine, cls).__new__(cls)
        obj._states = {}
        obj._transitions = []
        return obj

    def _bring_definitions_to_object_level(self):
        # Copy class-level (metaclass-gathered) states/transitions into the
        # instance, so runtime additions don't mutate the class.
        self._states.update(self.__class__._class_states)
        self._transitions.extend(self.__class__._class_transitions)

    def _inject_into_parts(self):
        # Give every state/transition a back-reference to this machine.
        for collection in [self._states.values(), self._transitions]:
            for component in collection:
                component.machine = self

    def _validate_machine_definitions(self):
        if len(self._states) < 2:
            raise InvalidConfiguration('There must be at least two states')
        if not getattr(self, 'initial_state', None):
            raise InvalidConfiguration('There must exist an initial state')

    @classmethod
    def _add_class_state(cls, name, enter, exit):
        cls._class_states[name] = _State(name, enter, exit)

    def add_state(self, name, enter=None, exit=None):
        # Runtime state addition: also binds the is_<name> getter onto the
        # instance.
        state = _State(name, enter, exit)
        setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))
        self._states[name] = state

    def _current_state_name(self):
        return self._current_state_object.name

    # Read-only view of the current state's name.
    current_state = property(_current_state_name)

    def changing_state(self, from_, to):
        """
        This method is called whenever a state change is executed
        """
        pass

    def _new_state(self, state):
        # Notify the hook, then switch.
        self.changing_state(self._current_state_object.name, state.name)
        self._current_state_object = state

    def _state_objects(self):
        return list(self._states.values())

    def states(self):
        # Names of all known states.
        return [s.name for s in self._state_objects()]

    @classmethod
    def _add_class_transition(cls, event, from_, to, action, guard):
        transition = _Transition(event, [cls._class_states[s] for s in _listize(from_)],
            cls._class_states[to], action, guard)
        cls._class_transitions.append(transition)
        # The event becomes a method on the machine class.
        setattr(cls, event, transition.event_method())

    def add_transition(self, event, from_, to, action=None, guard=None):
        # Runtime transition addition; event method is bound to this instance.
        transition = _Transition(event, [self._state_by_name(s) for s in _listize(from_)],
            self._state_by_name(to), action, guard)
        self._transitions.append(transition)
        setattr(self, event, transition.event_method().__get__(self, self.__class__))

    def _process_transitions(self, event_name, *args, **kwargs):
        # Filter by event name, then by current state, then by guards; exactly
        # one transition must survive, and it is run.
        transitions = self._transitions_by_name(event_name)
        transitions = self._ensure_from_validity(transitions)
        this_transition = self._check_guards(transitions)
        this_transition.run(self, *args, **kwargs)

    def _create_state_getters(self):
        # Bind an is_<state> predicate for every known state.
        for state in self._state_objects():
            setattr(self, state.getter_name(), state.getter_method().__get__(self, self.__class__))

    def _state_by_name(self, name):
        # Returns None when no state matches.
        for state in self._state_objects():
            if state.name == name:
                return state

    def _transitions_by_name(self, name):
        return list(filter(lambda transition: transition.event == name, self._transitions))

    def _ensure_from_validity(self, transitions):
        # Keep only transitions that may fire from the current state.
        valid_transitions = list(filter(
            lambda transition: transition.is_valid_from(self._current_state_object),
            transitions))
        if len(valid_transitions) == 0:
            raise InvalidTransition("Cannot %s from %s" % (
                transitions[0].event, self.current_state))
        return valid_transitions

    def _check_guards(self, transitions):
        # Guards must select exactly one transition for the event.
        allowed_transitions = []
        for transition in transitions:
            if transition.check_guard(self):
                allowed_transitions.append(transition)
        if len(allowed_transitions) == 0:
            raise GuardNotSatisfied("Guard is not satisfied for this transition")
        elif len(allowed_transitions) > 1:
            raise ForkedTransition("More than one transition was allowed for this event")
        return allowed_transitions[0]
class _Transition(object):
    """Internal record of a single event-driven transition."""

    def __init__(self, event, from_, to, action, guard):
        self.event = event
        self.from_ = from_  # list of permitted source _State objects
        self.to = to  # destination _State object
        self.action = action
        self.guard = _Guard(guard)

    def event_method(self):
        # Build the callable installed on machines under this event's name.
        def generated_event(machine, *args, **kwargs):
            # NOTE(review): the bound result is unused; generated events
            # intentionally return None.
            these_transitions = machine._process_transitions(self.event, *args, **kwargs)
        generated_event.__doc__ = 'event %s' % self.event
        generated_event.__name__ = self.event
        return generated_event

    def is_valid_from(self, from_):
        # May this transition fire when the machine sits in state ``from_``?
        return from_ in _listize(self.from_)

    def check_guard(self, machine):
        return self.guard.check(machine)

    def run(self, machine, *args, **kwargs):
        # Ordering matters: exit old state, switch, enter new state, then
        # fire the transition's own action with the caller's arguments.
        machine._current_state_object.run_exit(machine)
        machine._new_state(self.to)
        self.to.run_enter(machine)
        _ActionRunner(machine).run(self.action, *args, **kwargs)
class _Guard(object):
    """
    Wraps a transition's guard spec: None (always allowed), a callable, an
    attribute/method name, or a list/tuple mixing the latter two.
    """

    def __init__(self, action):
        self.action = action

    def check(self, machine):
        # No guard configured -> transition always allowed.
        if self.action is None:
            return True
        # All guard items must pass; evaluation short-circuits on the first
        # falsy result.
        verdict = True
        for entry in _listize(self.action):
            verdict = verdict and self._evaluate(machine, entry)
        return verdict

    def _evaluate(self, machine, entry):
        if callable(entry):
            # Plain callable guard: invoked with the machine instance.
            return entry(machine)
        # String guard: resolved as a machine attribute, called if a method.
        resolved = getattr(machine, entry)
        if callable(resolved):
            resolved = resolved()
        return resolved
class _State(object):
    """A named machine state with optional enter/exit actions."""

    def __init__(self, name, enter, exit):
        self.name = name
        self.enter = enter
        self.exit = exit

    def getter_name(self):
        # Name of the boolean predicate installed on machines, e.g. "is_idle".
        return 'is_%s' % self.name

    def getter_method(self):
        def state_getter(self_machine):
            # True when the machine currently sits in this state.
            return self_machine.current_state == self.name
        return state_getter

    def run_enter(self, machine):
        # Fire this state's enter action(s), if any.
        _ActionRunner(machine).run(self.enter)

    def run_exit(self, machine):
        # Fire this state's exit action(s), if any.
        _ActionRunner(machine).run(self.exit)
class _ActionRunner(object):
    """
    Executes enter/exit/transition actions: callables (invoked with the
    machine) or attribute-name strings (resolved on the machine), singly or
    in list/tuple form.
    """

    def __init__(self, machine):
        self.machine = machine

    def run(self, action_param, *args, **kwargs):
        # Nothing configured -> nothing to do.
        if not action_param:
            return
        for item in _listize(action_param):
            self._run_action(item, *args, **kwargs)

    def _run_action(self, action, *args, **kwargs):
        if callable(action):
            # Bare callables receive the machine as their first argument.
            self._try_to_run_with_args(action, self.machine, *args, **kwargs)
        else:
            bound = getattr(self.machine, action)
            self._try_to_run_with_args(bound, *args, **kwargs)

    def _try_to_run_with_args(self, action, *args, **kwargs):
        # Prefer forwarding the caller's arguments, but tolerate actions
        # declared without parameters.
        try:
            action(*args, **kwargs)
        except TypeError:
            action()
class InvalidConfiguration(Exception):
    """Machine definition is unusable (fewer than two states, or no initial state)."""
    pass


class InvalidTransition(Exception):
    """The requested event cannot fire from the machine's current state."""
    pass


class GuardNotSatisfied(Exception):
    """No candidate transition's guard allowed the event to proceed."""
    pass


class ForkedTransition(Exception):
    """More than one transition's guard allowed the same event (ambiguous)."""
    pass
def _listize(value):
    """
    Normalize ``value`` to a sequence: lists and tuples pass through
    unchanged, anything else is wrapped in a single-item list.

    Rewritten from the old ``cond and value or [value]`` idiom, which
    mis-handled *empty* lists/tuples: being falsy, they fell through to the
    ``or`` branch and came back double-wrapped (``[] -> [[]]``).
    """
    # type() (not isinstance) preserved from the original to keep subclass
    # handling identical.
    if type(value) in (list, tuple):
        return value
    return [value]

View File

@@ -0,0 +1,24 @@
from ._version import __version_info__, __version__ # noqa
from .attribute_dict import AttributeDict
from .alias_dict import AliasDict
class Lexicon(AttributeDict, AliasDict):
    """
    Dict subclass combining attribute-style access (AttributeDict) with
    key aliasing (AliasDict).
    """

    def __init__(self, *args, **kwargs):
        # Need to avoid combining AliasDict's initial attribute write on
        # self.aliases, with AttributeDict's __setattr__. Doing so results in
        # an infinite loop. Instead, just skip straight to dict() for both
        # explicitly (i.e. we override AliasDict.__init__ instead of extending
        # it.)
        # NOTE: could tickle AttributeDict.__init__ instead, in case it ever
        # grows one.
        dict.__init__(self, *args, **kwargs)
        dict.__setattr__(self, "aliases", {})

    def __getattr__(self, key):
        # Intercept deepcopy/etc driven access to self.aliases when not
        # actually set. (Only a problem for us, due to abovementioned combo of
        # Alias and Attribute Dicts, so not solvable in a parent alone.)
        if key == "aliases" and key not in self.__dict__:
            self.__dict__[key] = {}
        return super(Lexicon, self).__getattr__(key)

View File

@@ -0,0 +1,2 @@
# Single source of truth for the package version; __version__ is derived so
# the two values can never drift apart.
__version_info__ = (2, 0, 1)
__version__ = ".".join(map(str, __version_info__))

View File

@@ -0,0 +1,95 @@
class AliasDict(dict):
    """
    Dict subclass supporting aliases: keys that transparently forward reads,
    writes, deletes and containment checks to one real key (single-target
    alias, a string) or several (multi-target alias, an iterable of keys).
    """

    def __init__(self, *args, **kwargs):
        super(AliasDict, self).__init__(*args, **kwargs)
        # alias name -> target key (str) or target keys (iterable of str)
        self.aliases = {}

    def alias(self, from_, to):
        # Register ``from_`` as an alias of real key(s) ``to``.
        self.aliases[from_] = to

    def unalias(self, from_):
        # Raises KeyError if ``from_`` is not a registered alias.
        del self.aliases[from_]

    def aliases_of(self, name):
        """
        Returns other names for given real key or alias ``name``.

        If given a real key, returns its aliases.

        If given an alias, returns the real key it points to, plus any other
        aliases of that real key. (The given alias itself is not included in
        the return value.)
        """
        names = []
        key = name
        # self.aliases keys are aliases, not realkeys. Easy test to see if we
        # should flip around to the POV of a realkey when given an alias.
        if name in self.aliases:
            key = self.aliases[name]
            # Ensure the real key shows up in output.
            names.append(key)
        # 'key' is now a realkey, whose aliases are all keys whose value is
        # itself. Filter out the original name given.
        names.extend(
            [k for k, v in self.aliases.items() if v == key and k != name]
        )
        return names

    def _handle(self, key, value, single, multi, unaliased):
        # Central dispatch: route the operation through the alias table.
        # ``single``/``multi``/``unaliased`` are callbacks for the
        # single-target, multi-target and plain-key cases; a None ``multi``
        # means "apply ``single`` to each subkey" (used by writes/deletes).
        # Attribute existence test required to not blow up when deepcopy'd
        if key in getattr(self, "aliases", {}):
            target = self.aliases[key]
            # Single-string targets
            if isinstance(target, str):
                return single(self, target, value)
            # Multi-string targets
            else:
                if multi:
                    return multi(self, target, value)
                else:
                    for subkey in target:
                        single(self, subkey, value)
        else:
            return unaliased(self, key, value)

    def __setitem__(self, key, value):
        def single(d, target, value):
            d[target] = value

        def unaliased(d, key, value):
            # Bypass our own __setitem__ to avoid recursing.
            super(AliasDict, d).__setitem__(key, value)

        return self._handle(key, value, single, None, unaliased)

    def __getitem__(self, key):
        def single(d, target, value):
            return d[target]

        def unaliased(d, key, value):
            return super(AliasDict, d).__getitem__(key)

        def multi(d, target, value):
            msg = "Multi-target aliases have no well-defined value and can't be read."  # noqa
            raise ValueError(msg)

        return self._handle(key, None, single, multi, unaliased)

    def __contains__(self, key):
        def single(d, target, value):
            return target in d

        def multi(d, target, value):
            # A multi-target alias is "present" only if all its targets are.
            return all(subkey in self for subkey in self.aliases[key])

        def unaliased(d, key, value):
            return super(AliasDict, d).__contains__(key)

        return self._handle(key, None, single, multi, unaliased)

    def __delitem__(self, key):
        def single(d, target, value):
            del d[target]

        def unaliased(d, key, value):
            return super(AliasDict, d).__delitem__(key)

        return self._handle(key, None, single, None, unaliased)

View File

@@ -0,0 +1,16 @@
class AttributeDict(dict):
    """
    Dict subclass exposing its keys as attributes: ``d.foo`` reads, writes
    and deletes map onto ``d["foo"]``.
    """

    def __getattr__(self, key):
        if key in self:
            return self[key]
        # Missing attribute access must raise AttributeError (not KeyError)
        # to conform with the __getattr__ spec, so e.g. hasattr() works.
        raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]

    def __dir__(self):
        # Surface both regular class attributes and the current keys.
        return dir(type(self)) + list(self.keys())

View File

@@ -0,0 +1,427 @@
from .error import *
from .tokens import *
from .events import *
from .nodes import *
from .loader import *
from .dumper import *
__version__ = '5.4.1'
try:
from .cyaml import *
__with_libyaml__ = True
except ImportError:
__with_libyaml__ = False
import io
#------------------------------------------------------------------------------
# Warnings control
#------------------------------------------------------------------------------

# 'Global' warnings state: which warning categories are currently enabled.
_warnings_enabled = {
    'YAMLLoadWarning': True,
}


# Get or set global warnings' state
def warnings(settings=None):
    """
    With no argument, return the (mutable) global warnings-state dict.

    Given a dict, copy over any *recognized* keys; unknown keys are ignored
    and nothing is returned.
    """
    if settings is None:
        return _warnings_enabled
    if type(settings) is dict:
        for key, value in settings.items():
            if key in _warnings_enabled:
                _warnings_enabled[key] = value
# Warn when load() is called without Loader=...
class YAMLLoadWarning(RuntimeWarning):
    """Issued when yaml.load()/load_all() is called without an explicit Loader."""
    pass
def load_warning(method):
    """
    Emit a `YAMLLoadWarning` telling the caller that ``yaml.<method>()`` was
    invoked without an explicit Loader. Silenced when the 'YAMLLoadWarning'
    entry of the module's warnings state is False.
    """
    if _warnings_enabled['YAMLLoadWarning'] is False:
        return

    # Local import: the stdlib module's name is shadowed at module level by
    # this package's own warnings() function.
    import warnings

    message = (
        "calling yaml.%s() without Loader=... is deprecated, as the "
        "default Loader is unsafe. Please read "
        "https://msg.pyyaml.org/load for full details."
    ) % method
    warnings.warn(message, YAMLLoadWarning, stacklevel=3)
#------------------------------------------------------------------------------

def scan(stream, Loader=Loader):
    """
    Scan a YAML stream and produce scanning tokens.

    Generator; the Loader instance is disposed of when iteration finishes or
    is abandoned.
    """
    loader = Loader(stream)
    try:
        while loader.check_token():
            yield loader.get_token()
    finally:
        loader.dispose()
def parse(stream, Loader=Loader):
    """
    Parse a YAML stream and produce parsing events.

    Generator; the Loader instance is disposed of when iteration finishes or
    is abandoned.
    """
    loader = Loader(stream)
    try:
        while loader.check_event():
            yield loader.get_event()
    finally:
        loader.dispose()
def compose(stream, Loader=Loader):
    """
    Parse the first YAML document in a stream
    and produce the corresponding representation tree.
    """
    loader = Loader(stream)
    try:
        return loader.get_single_node()
    finally:
        loader.dispose()
def compose_all(stream, Loader=Loader):
    """
    Parse all YAML documents in a stream
    and produce corresponding representation trees.

    Generator; the Loader instance is disposed of when iteration finishes or
    is abandoned.
    """
    loader = Loader(stream)
    try:
        while loader.check_node():
            yield loader.get_node()
    finally:
        loader.dispose()
def load(stream, Loader=None):
    """
    Parse the first YAML document in a stream
    and produce the corresponding Python object.

    When no explicit Loader is given, emits a YAMLLoadWarning (via
    `load_warning`) and falls back to FullLoader.
    """
    if Loader is None:
        load_warning('load')
        Loader = FullLoader

    loader = Loader(stream)
    try:
        return loader.get_single_data()
    finally:
        loader.dispose()
def load_all(stream, Loader=None):
    """
    Parse all YAML documents in a stream
    and produce corresponding Python objects.

    When no explicit Loader is given, emits a YAMLLoadWarning (via
    `load_warning`) and falls back to FullLoader. Generator.
    """
    if Loader is None:
        load_warning('load_all')
        Loader = FullLoader

    loader = Loader(stream)
    try:
        while loader.check_data():
            yield loader.get_data()
    finally:
        loader.dispose()
def full_load(stream):
    """
    Parse the first YAML document in a stream
    and produce the corresponding Python object.

    Resolve all tags except those known to be
    unsafe on untrusted input.
    """
    return load(stream, FullLoader)
def full_load_all(stream):
    """
    Parse all YAML documents in a stream
    and produce corresponding Python objects.

    Resolve all tags except those known to be
    unsafe on untrusted input.
    """
    return load_all(stream, FullLoader)
def safe_load(stream):
    """
    Parse the first YAML document in a stream
    and produce the corresponding Python object.

    Resolve only basic YAML tags. This is known
    to be safe for untrusted input.
    """
    return load(stream, SafeLoader)
def safe_load_all(stream):
    """
    Parse all YAML documents in a stream
    and produce corresponding Python objects.

    Resolve only basic YAML tags. This is known
    to be safe for untrusted input.
    """
    return load_all(stream, SafeLoader)
def unsafe_load(stream):
    """
    Parse the first YAML document in a stream
    and produce the corresponding Python object.

    Resolve all tags, even those known to be
    unsafe on untrusted input.
    """
    return load(stream, UnsafeLoader)
def unsafe_load_all(stream):
    """
    Parse all YAML documents in a stream
    and produce corresponding Python objects.

    Resolve all tags, even those known to be
    unsafe on untrusted input.
    """
    return load_all(stream, UnsafeLoader)
def emit(events, stream=None, Dumper=Dumper,
        canonical=None, indent=None, width=None,
        allow_unicode=None, line_break=None):
    """
    Emit YAML parsing events into a stream.
    If stream is None, return the produced string instead.
    """
    getvalue = None
    if stream is None:
        # No stream supplied: capture the output in memory and return it.
        stream = io.StringIO()
        getvalue = stream.getvalue
    dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
            allow_unicode=allow_unicode, line_break=line_break)
    try:
        for event in events:
            dumper.emit(event)
    finally:
        # Always release dumper resources, even when emitting fails.
        dumper.dispose()
    if getvalue:
        return getvalue()
def serialize_all(nodes, stream=None, Dumper=Dumper,
        canonical=None, indent=None, width=None,
        allow_unicode=None, line_break=None,
        encoding=None, explicit_start=None, explicit_end=None,
        version=None, tags=None):
    """
    Serialize a sequence of representation trees into a YAML stream.
    If stream is None, return the produced string instead.
    """
    getvalue = None
    if stream is None:
        # No stream supplied: capture output in memory. An encoding implies
        # binary output, so pick the matching in-memory buffer type.
        if encoding is None:
            stream = io.StringIO()
        else:
            stream = io.BytesIO()
        getvalue = stream.getvalue
    dumper = Dumper(stream, canonical=canonical, indent=indent, width=width,
            allow_unicode=allow_unicode, line_break=line_break,
            encoding=encoding, version=version, tags=tags,
            explicit_start=explicit_start, explicit_end=explicit_end)
    try:
        dumper.open()
        for node in nodes:
            dumper.serialize(node)
        dumper.close()
    finally:
        # Always release dumper resources, even when serialization fails.
        dumper.dispose()
    if getvalue:
        return getvalue()
def serialize(node, stream=None, Dumper=Dumper, **kwds):
    """
    Serialize a representation tree into a YAML stream.
    If stream is None, return the produced string instead.

    Thin single-document wrapper over `serialize_all`.
    """
    return serialize_all([node], stream, Dumper=Dumper, **kwds)
def dump_all(documents, stream=None, Dumper=Dumper,
        default_style=None, default_flow_style=False,
        canonical=None, indent=None, width=None,
        allow_unicode=None, line_break=None,
        encoding=None, explicit_start=None, explicit_end=None,
        version=None, tags=None, sort_keys=True):
    """
    Serialize a sequence of Python objects into a YAML stream.
    If stream is None, return the produced string instead.
    """
    getvalue = None
    if stream is None:
        # No stream supplied: capture output in memory. An encoding implies
        # binary output, so pick the matching in-memory buffer type.
        if encoding is None:
            stream = io.StringIO()
        else:
            stream = io.BytesIO()
        getvalue = stream.getvalue
    dumper = Dumper(stream, default_style=default_style,
            default_flow_style=default_flow_style,
            canonical=canonical, indent=indent, width=width,
            allow_unicode=allow_unicode, line_break=line_break,
            encoding=encoding, version=version, tags=tags,
            explicit_start=explicit_start, explicit_end=explicit_end, sort_keys=sort_keys)
    try:
        dumper.open()
        for data in documents:
            dumper.represent(data)
        dumper.close()
    finally:
        # Always release dumper resources, even when representation fails.
        dumper.dispose()
    if getvalue:
        return getvalue()
def dump(data, stream=None, Dumper=Dumper, **kwds):
    """
    Serialize a Python object into a YAML stream.
    If stream is None, return the produced string instead.

    Thin single-document wrapper over `dump_all`.
    """
    return dump_all([data], stream, Dumper=Dumper, **kwds)
def safe_dump_all(documents, stream=None, **kwds):
    """
    Serialize a sequence of Python objects into a YAML stream.
    Produce only basic YAML tags.
    If stream is None, return the produced string instead.
    """
    return dump_all(documents, stream, Dumper=SafeDumper, **kwds)
def safe_dump(data, stream=None, **kwds):
    """
    Serialize a Python object into a YAML stream.
    Produce only basic YAML tags.
    If stream is None, return the produced string instead.
    """
    return dump_all([data], stream, Dumper=SafeDumper, **kwds)
def add_implicit_resolver(tag, regexp, first=None,
        Loader=None, Dumper=Dumper):
    """
    Add an implicit scalar detector.
    If an implicit scalar value matches the given regexp,
    the corresponding tag is assigned to the scalar.
    first is a sequence of possible initial characters or None.
    """
    if Loader is None:
        # No specific Loader: register on every safe-by-default loader class.
        # NOTE(review): ``loader`` here is the yaml.loader submodule --
        # presumably bound in this namespace as a side effect of the package's
        # ``from .loader import *``; verify if refactoring imports.
        loader.Loader.add_implicit_resolver(tag, regexp, first)
        loader.FullLoader.add_implicit_resolver(tag, regexp, first)
        loader.UnsafeLoader.add_implicit_resolver(tag, regexp, first)
    else:
        Loader.add_implicit_resolver(tag, regexp, first)
    Dumper.add_implicit_resolver(tag, regexp, first)
def add_path_resolver(tag, path, kind=None, Loader=None, Dumper=Dumper):
    """
    Add a path based resolver for the given tag.
    A path is a list of keys that forms a path
    to a node in the representation tree.
    Keys can be string values, integers, or None.
    """
    if Loader is None:
        # No specific Loader: register on every safe-by-default loader class
        # (``loader`` is the yaml.loader submodule).
        loader.Loader.add_path_resolver(tag, path, kind)
        loader.FullLoader.add_path_resolver(tag, path, kind)
        loader.UnsafeLoader.add_path_resolver(tag, path, kind)
    else:
        Loader.add_path_resolver(tag, path, kind)
    Dumper.add_path_resolver(tag, path, kind)
def add_constructor(tag, constructor, Loader=None):
    """
    Add a constructor for the given tag.
    Constructor is a function that accepts a Loader instance
    and a node object and produces the corresponding Python object.
    """
    if Loader is None:
        # No specific Loader: register on every safe-by-default loader class
        # (``loader`` is the yaml.loader submodule).
        loader.Loader.add_constructor(tag, constructor)
        loader.FullLoader.add_constructor(tag, constructor)
        loader.UnsafeLoader.add_constructor(tag, constructor)
    else:
        Loader.add_constructor(tag, constructor)
def add_multi_constructor(tag_prefix, multi_constructor, Loader=None):
    """
    Add a multi-constructor for the given tag prefix.
    Multi-constructor is called for a node if its tag starts with tag_prefix.
    Multi-constructor accepts a Loader instance, a tag suffix,
    and a node object and produces the corresponding Python object.
    """
    if Loader is None:
        # No specific Loader: register on every safe-by-default loader class
        # (``loader`` is the yaml.loader submodule).
        loader.Loader.add_multi_constructor(tag_prefix, multi_constructor)
        loader.FullLoader.add_multi_constructor(tag_prefix, multi_constructor)
        loader.UnsafeLoader.add_multi_constructor(tag_prefix, multi_constructor)
    else:
        Loader.add_multi_constructor(tag_prefix, multi_constructor)
def add_representer(data_type, representer, Dumper=Dumper):
    """
    Add a representer for the given type.
    Representer is a function accepting a Dumper instance
    and an instance of the given data type
    and producing the corresponding representation node.
    """
    Dumper.add_representer(data_type, representer)
def add_multi_representer(data_type, multi_representer, Dumper=Dumper):
    """
    Add a representer for the given type.
    Multi-representer is a function accepting a Dumper instance
    and an instance of the given data type or subtype
    and producing the corresponding representation node.
    """
    Dumper.add_multi_representer(data_type, multi_representer)
class YAMLObjectMetaclass(type):
    """
    The metaclass for YAMLObject.

    Auto-registers each subclass that declares a concrete ``yaml_tag``:
    its `from_yaml` becomes the tag's constructor and its `to_yaml` the
    class's representer.
    """
    def __init__(cls, name, bases, kwds):
        super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds)
        if 'yaml_tag' in kwds and kwds['yaml_tag'] is not None:
            # yaml_loader may be a single Loader class or a list of them.
            if isinstance(cls.yaml_loader, list):
                for loader in cls.yaml_loader:
                    loader.add_constructor(cls.yaml_tag, cls.from_yaml)
            else:
                cls.yaml_loader.add_constructor(cls.yaml_tag, cls.from_yaml)
            cls.yaml_dumper.add_representer(cls, cls.to_yaml)
class YAMLObject(metaclass=YAMLObjectMetaclass):
    """
    An object that can dump itself to a YAML stream
    and load itself from a YAML stream.

    Subclasses set ``yaml_tag`` (and optionally override the loader/dumper
    class attributes) to participate in (de)serialization automatically.
    """

    __slots__ = ()  # no direct instantiation, so allow immutable subclasses

    # Loader classes whose constructor table receives from_yaml.
    yaml_loader = [Loader, FullLoader, UnsafeLoader]
    # Dumper class whose representer table receives to_yaml.
    yaml_dumper = Dumper

    # Subclasses must set this to a concrete tag to be auto-registered.
    yaml_tag = None
    yaml_flow_style = None

    @classmethod
    def from_yaml(cls, loader, node):
        """
        Convert a representation node to a Python object.
        """
        return loader.construct_yaml_object(node, cls)

    @classmethod
    def to_yaml(cls, dumper, data):
        """
        Convert a Python object to a representation node.
        """
        return dumper.represent_yaml_object(cls.yaml_tag, data, cls,
                flow_style=cls.yaml_flow_style)

View File

@@ -0,0 +1,139 @@
__all__ = ['Composer', 'ComposerError']
from .error import MarkedYAMLError
from .events import *
from .nodes import *
class ComposerError(MarkedYAMLError):
    """Raised for structural problems while composing nodes (e.g. undefined
    or duplicate anchors, multiple documents where one was expected)."""
    pass
class Composer:
    """
    Mixin that turns parser events into a representation-node tree.

    NOTE(review): relies on ``check_event``/``get_event``/``peek_event`` and
    ``descend_resolver``/``ascend_resolver``/``resolve`` being supplied by
    sibling mixins (parser/resolver) in the concrete Loader class -- they are
    not defined here.
    """

    def __init__(self):
        # anchor name -> already-composed node, for alias resolution.
        self.anchors = {}

    def check_node(self):
        # Drop the STREAM-START event.
        if self.check_event(StreamStartEvent):
            self.get_event()

        # If there are more documents available?
        return not self.check_event(StreamEndEvent)

    def get_node(self):
        # Get the root node of the next document.
        if not self.check_event(StreamEndEvent):
            return self.compose_document()

    def get_single_node(self):
        # Drop the STREAM-START event.
        self.get_event()

        # Compose a document if the stream is not empty.
        document = None
        if not self.check_event(StreamEndEvent):
            document = self.compose_document()

        # Ensure that the stream contains no more documents.
        if not self.check_event(StreamEndEvent):
            event = self.get_event()
            raise ComposerError("expected a single document in the stream",
                    document.start_mark, "but found another document",
                    event.start_mark)

        # Drop the STREAM-END event.
        self.get_event()

        return document

    def compose_document(self):
        # Drop the DOCUMENT-START event.
        self.get_event()

        # Compose the root node.
        node = self.compose_node(None, None)

        # Drop the DOCUMENT-END event.
        self.get_event()

        # Anchors are document-scoped; reset for the next document.
        self.anchors = {}
        return node

    def compose_node(self, parent, index):
        if self.check_event(AliasEvent):
            # An alias simply re-uses the node composed under the anchor.
            event = self.get_event()
            anchor = event.anchor
            if anchor not in self.anchors:
                raise ComposerError(None, None, "found undefined alias %r"
                        % anchor, event.start_mark)
            return self.anchors[anchor]
        event = self.peek_event()
        anchor = event.anchor
        if anchor is not None:
            if anchor in self.anchors:
                raise ComposerError("found duplicate anchor %r; first occurrence"
                        % anchor, self.anchors[anchor].start_mark,
                        "second occurrence", event.start_mark)
        # Tell the resolver where we are in the tree, compose the node, then
        # pop back out.
        self.descend_resolver(parent, index)
        if self.check_event(ScalarEvent):
            node = self.compose_scalar_node(anchor)
        elif self.check_event(SequenceStartEvent):
            node = self.compose_sequence_node(anchor)
        elif self.check_event(MappingStartEvent):
            node = self.compose_mapping_node(anchor)
        self.ascend_resolver()
        return node

    def compose_scalar_node(self, anchor):
        event = self.get_event()
        tag = event.tag
        # '!' or no tag at all means "resolve the tag from the value".
        if tag is None or tag == '!':
            tag = self.resolve(ScalarNode, event.value, event.implicit)
        node = ScalarNode(tag, event.value,
                event.start_mark, event.end_mark, style=event.style)
        if anchor is not None:
            self.anchors[anchor] = node
        return node

    def compose_sequence_node(self, anchor):
        start_event = self.get_event()
        tag = start_event.tag
        if tag is None or tag == '!':
            tag = self.resolve(SequenceNode, None, start_event.implicit)
        node = SequenceNode(tag, [],
                start_event.start_mark, None,
                flow_style=start_event.flow_style)
        # Register the anchor *before* composing children so self-referencing
        # aliases inside the sequence resolve to this node.
        if anchor is not None:
            self.anchors[anchor] = node
        index = 0
        while not self.check_event(SequenceEndEvent):
            node.value.append(self.compose_node(node, index))
            index += 1
        end_event = self.get_event()
        node.end_mark = end_event.end_mark
        return node

    def compose_mapping_node(self, anchor):
        start_event = self.get_event()
        tag = start_event.tag
        if tag is None or tag == '!':
            tag = self.resolve(MappingNode, None, start_event.implicit)
        node = MappingNode(tag, [],
                start_event.start_mark, None,
                flow_style=start_event.flow_style)
        # Register the anchor *before* composing children (see above).
        if anchor is not None:
            self.anchors[anchor] = node
        while not self.check_event(MappingEndEvent):
            #key_event = self.peek_event()
            item_key = self.compose_node(node, None)
            #if item_key in node.value:
            #    raise ComposerError("while composing a mapping", start_event.start_mark,
            #            "found duplicate key", key_event.start_mark)
            item_value = self.compose_node(node, item_key)
            #node.value[item_key] = item_value
            # Mapping values are kept as (key, value) pairs; duplicate-key
            # detection above is intentionally disabled upstream.
            node.value.append((item_key, item_value))
        end_event = self.get_event()
        node.end_mark = end_event.end_mark
        return node

View File

@@ -0,0 +1,748 @@
__all__ = [
'BaseConstructor',
'SafeConstructor',
'FullConstructor',
'UnsafeConstructor',
'Constructor',
'ConstructorError'
]
from .error import *
from .nodes import *
import collections.abc, datetime, base64, binascii, re, sys, types
class ConstructorError(MarkedYAMLError):
    """Raised when a YAML node cannot be converted into a Python object."""
    pass
class BaseConstructor:
    """Convert composed node graphs into native Python objects.

    Subclasses register per-tag constructor functions through
    `add_constructor`/`add_multi_constructor`; both registries are
    copy-on-write class attributes so each subclass gets its own view.
    """

    # tag -> constructor function, and tag prefix -> multi-constructor.
    yaml_constructors = {}
    yaml_multi_constructors = {}

    def __init__(self):
        # Cache of node -> constructed object (supports aliases/sharing).
        self.constructed_objects = {}
        # Nodes currently under construction, for recursion detection.
        self.recursive_objects = {}
        # Pending generators that finish two-step (recursive) construction.
        self.state_generators = []
        self.deep_construct = False

    def check_data(self):
        # If there are more documents available?
        return self.check_node()

    def check_state_key(self, key):
        """Block special attributes/methods from being set in a newly created
        object, to prevent user-controlled methods from being called during
        deserialization"""
        if self.get_state_keys_blacklist_regexp().match(key):
            raise ConstructorError(None, None,
                "blacklisted key '%s' in instance state found" % (key,), None)

    def get_data(self):
        # Construct and return the next document.
        if self.check_node():
            return self.construct_document(self.get_node())

    def get_single_data(self):
        # Ensure that the stream contains a single document and construct it.
        node = self.get_single_node()
        if node is not None:
            return self.construct_document(node)
        return None

    def construct_document(self, node):
        """Construct a whole document, then drain deferred generators and
        reset per-document state."""
        data = self.construct_object(node)
        # Generators appended during construction may themselves append
        # more generators, so loop until the list stays empty.
        while self.state_generators:
            state_generators = self.state_generators
            self.state_generators = []
            for generator in state_generators:
                for dummy in generator:
                    pass
        self.constructed_objects = {}
        self.recursive_objects = {}
        self.deep_construct = False
        return data

    def construct_object(self, node, deep=False):
        """Construct the Python object for ``node``, dispatching on its tag.

        ``deep`` forces immediate (rather than deferred) completion of
        generator-based constructors for the subtree.
        """
        if node in self.constructed_objects:
            return self.constructed_objects[node]
        if deep:
            old_deep = self.deep_construct
            self.deep_construct = True
        if node in self.recursive_objects:
            raise ConstructorError(None, None,
                    "found unconstructable recursive node", node.start_mark)
        self.recursive_objects[node] = None
        constructor = None
        tag_suffix = None
        if node.tag in self.yaml_constructors:
            constructor = self.yaml_constructors[node.tag]
        else:
            # Fall back to prefix (multi) constructors, then the catch-all
            # None entries, then a default chosen by node kind.
            for tag_prefix in self.yaml_multi_constructors:
                if tag_prefix is not None and node.tag.startswith(tag_prefix):
                    tag_suffix = node.tag[len(tag_prefix):]
                    constructor = self.yaml_multi_constructors[tag_prefix]
                    break
            else:
                if None in self.yaml_multi_constructors:
                    tag_suffix = node.tag
                    constructor = self.yaml_multi_constructors[None]
                elif None in self.yaml_constructors:
                    constructor = self.yaml_constructors[None]
                elif isinstance(node, ScalarNode):
                    constructor = self.__class__.construct_scalar
                elif isinstance(node, SequenceNode):
                    constructor = self.__class__.construct_sequence
                elif isinstance(node, MappingNode):
                    constructor = self.__class__.construct_mapping
        if tag_suffix is None:
            data = constructor(self, node)
        else:
            data = constructor(self, tag_suffix, node)
        if isinstance(data, types.GeneratorType):
            # Two-step construction: the first yield gives the (possibly
            # empty) object; the generator's remainder fills it in later.
            generator = data
            data = next(generator)
            if self.deep_construct:
                for dummy in generator:
                    pass
            else:
                self.state_generators.append(generator)
        self.constructed_objects[node] = data
        del self.recursive_objects[node]
        if deep:
            self.deep_construct = old_deep
        return data

    def construct_scalar(self, node):
        """Return the scalar node's string value; reject non-scalars."""
        if not isinstance(node, ScalarNode):
            raise ConstructorError(None, None,
                    "expected a scalar node, but found %s" % node.id,
                    node.start_mark)
        return node.value

    def construct_sequence(self, node, deep=False):
        """Construct a list from a sequence node."""
        if not isinstance(node, SequenceNode):
            raise ConstructorError(None, None,
                    "expected a sequence node, but found %s" % node.id,
                    node.start_mark)
        return [self.construct_object(child, deep=deep)
                for child in node.value]

    def construct_mapping(self, node, deep=False):
        """Construct a dict from a mapping node; keys must be hashable."""
        if not isinstance(node, MappingNode):
            raise ConstructorError(None, None,
                    "expected a mapping node, but found %s" % node.id,
                    node.start_mark)
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if not isinstance(key, collections.abc.Hashable):
                raise ConstructorError("while constructing a mapping", node.start_mark,
                        "found unhashable key", key_node.start_mark)
            value = self.construct_object(value_node, deep=deep)
            mapping[key] = value
        return mapping

    def construct_pairs(self, node, deep=False):
        """Construct a list of (key, value) tuples, preserving duplicates."""
        if not isinstance(node, MappingNode):
            raise ConstructorError(None, None,
                    "expected a mapping node, but found %s" % node.id,
                    node.start_mark)
        pairs = []
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            value = self.construct_object(value_node, deep=deep)
            pairs.append((key, value))
        return pairs

    @classmethod
    def add_constructor(cls, tag, constructor):
        # Copy-on-write: give this subclass its own registry before mutating.
        if not 'yaml_constructors' in cls.__dict__:
            cls.yaml_constructors = cls.yaml_constructors.copy()
        cls.yaml_constructors[tag] = constructor

    @classmethod
    def add_multi_constructor(cls, tag_prefix, multi_constructor):
        # Copy-on-write, as in add_constructor.
        if not 'yaml_multi_constructors' in cls.__dict__:
            cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
        cls.yaml_multi_constructors[tag_prefix] = multi_constructor
class SafeConstructor(BaseConstructor):
    """Constructs only the standard YAML tags -- safe on untrusted input."""

    def construct_scalar(self, node):
        # A mapping may use the '=' (value) key to designate its scalar
        # form; unwrap it before delegating to the base class.
        if isinstance(node, MappingNode):
            for key_node, value_node in node.value:
                if key_node.tag == 'tag:yaml.org,2002:value':
                    return self.construct_scalar(value_node)
        return super().construct_scalar(node)

    def flatten_mapping(self, node):
        """Expand '<<' merge keys in place (YAML 1.1 merge key type).

        Merged entries are prepended so explicit keys override merged ones.
        """
        merge = []
        index = 0
        while index < len(node.value):
            key_node, value_node = node.value[index]
            if key_node.tag == 'tag:yaml.org,2002:merge':
                del node.value[index]
                if isinstance(value_node, MappingNode):
                    self.flatten_mapping(value_node)
                    merge.extend(value_node.value)
                elif isinstance(value_node, SequenceNode):
                    submerge = []
                    for subnode in value_node.value:
                        if not isinstance(subnode, MappingNode):
                            raise ConstructorError("while constructing a mapping",
                                    node.start_mark,
                                    "expected a mapping for merging, but found %s"
                                    % subnode.id, subnode.start_mark)
                        self.flatten_mapping(subnode)
                        submerge.append(subnode.value)
                    # Earlier mappings in the list take precedence.
                    submerge.reverse()
                    for value in submerge:
                        merge.extend(value)
                else:
                    raise ConstructorError("while constructing a mapping", node.start_mark,
                            "expected a mapping or list of mappings for merging, but found %s"
                            % value_node.id, value_node.start_mark)
            elif key_node.tag == 'tag:yaml.org,2002:value':
                key_node.tag = 'tag:yaml.org,2002:str'
                index += 1
            else:
                index += 1
        if merge:
            node.value = merge + node.value

    def construct_mapping(self, node, deep=False):
        # Resolve merge keys before the generic mapping construction.
        if isinstance(node, MappingNode):
            self.flatten_mapping(node)
        return super().construct_mapping(node, deep=deep)

    def construct_yaml_null(self, node):
        self.construct_scalar(node)
        return None

    # YAML 1.1 boolean spellings (looked up after lowercasing).
    bool_values = {
        'yes':      True,
        'no':       False,
        'true':     True,
        'false':    False,
        'on':       True,
        'off':      False,
    }

    def construct_yaml_bool(self, node):
        value = self.construct_scalar(node)
        return self.bool_values[value.lower()]

    def construct_yaml_int(self, node):
        """Parse YAML 1.1 integers: decimal, 0b/0x/leading-zero octal,
        and sexagesimal (colon-separated base 60)."""
        value = self.construct_scalar(node)
        value = value.replace('_', '')
        sign = +1
        if value[0] == '-':
            sign = -1
        if value[0] in '+-':
            value = value[1:]
        if value == '0':
            return 0
        elif value.startswith('0b'):
            return sign*int(value[2:], 2)
        elif value.startswith('0x'):
            return sign*int(value[2:], 16)
        elif value[0] == '0':
            # Leading zero means octal in YAML 1.1.
            return sign*int(value, 8)
        elif ':' in value:
            # Sexagesimal, e.g. '1:30' == 90.
            digits = [int(part) for part in value.split(':')]
            digits.reverse()
            base = 1
            value = 0
            for digit in digits:
                value += digit*base
                base *= 60
            return sign*value
        else:
            return sign*int(value)

    # Compute +inf portably by repeated squaring, then derive a NaN.
    inf_value = 1e300
    while inf_value != inf_value*inf_value:
        inf_value *= inf_value
    nan_value = -inf_value/inf_value   # Trying to make a quiet NaN (like C99).

    def construct_yaml_float(self, node):
        """Parse YAML 1.1 floats, including .inf/.nan and base-60 forms."""
        value = self.construct_scalar(node)
        value = value.replace('_', '').lower()
        sign = +1
        if value[0] == '-':
            sign = -1
        if value[0] in '+-':
            value = value[1:]
        if value == '.inf':
            return sign*self.inf_value
        elif value == '.nan':
            return self.nan_value
        elif ':' in value:
            # Sexagesimal float, e.g. '1:30.5'.
            digits = [float(part) for part in value.split(':')]
            digits.reverse()
            base = 1
            value = 0.0
            for digit in digits:
                value += digit*base
                base *= 60
            return sign*value
        else:
            return sign*float(value)

    def construct_yaml_binary(self, node):
        """Decode a !!binary base64 scalar into bytes."""
        try:
            value = self.construct_scalar(node).encode('ascii')
        except UnicodeEncodeError as exc:
            raise ConstructorError(None, None,
                    "failed to convert base64 data into ascii: %s" % exc,
                    node.start_mark)
        try:
            # decodestring is the legacy spelling kept for old interpreters.
            if hasattr(base64, 'decodebytes'):
                return base64.decodebytes(value)
            else:
                return base64.decodestring(value)
        except binascii.Error as exc:
            raise ConstructorError(None, None,
                    "failed to decode base64 data: %s" % exc, node.start_mark)

    # ISO 8601-style timestamp per the YAML 1.1 timestamp type.
    timestamp_regexp = re.compile(
            r'''^(?P<year>[0-9][0-9][0-9][0-9])
                -(?P<month>[0-9][0-9]?)
                -(?P<day>[0-9][0-9]?)
                (?:(?:[Tt]|[ \t]+)
                (?P<hour>[0-9][0-9]?)
                :(?P<minute>[0-9][0-9])
                :(?P<second>[0-9][0-9])
                (?:\.(?P<fraction>[0-9]*))?
                (?:[ \t]*(?P<tz>Z|(?P<tz_sign>[-+])(?P<tz_hour>[0-9][0-9]?)
                (?::(?P<tz_minute>[0-9][0-9]))?))?)?$''', re.X)

    def construct_yaml_timestamp(self, node):
        """Construct a datetime.date or (possibly tz-aware) datetime."""
        value = self.construct_scalar(node)   # NOTE(review): unused; the match below reads node.value directly
        match = self.timestamp_regexp.match(node.value)
        values = match.groupdict()
        year = int(values['year'])
        month = int(values['month'])
        day = int(values['day'])
        if not values['hour']:
            # Date-only form.
            return datetime.date(year, month, day)
        hour = int(values['hour'])
        minute = int(values['minute'])
        second = int(values['second'])
        fraction = 0
        tzinfo = None
        if values['fraction']:
            # Pad/truncate the fractional part to microseconds.
            fraction = values['fraction'][:6]
            while len(fraction) < 6:
                fraction += '0'
            fraction = int(fraction)
        if values['tz_sign']:
            tz_hour = int(values['tz_hour'])
            tz_minute = int(values['tz_minute'] or 0)
            delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
            if values['tz_sign'] == '-':
                delta = -delta
            tzinfo = datetime.timezone(delta)
        elif values['tz']:
            # A bare 'Z' designator means UTC.
            tzinfo = datetime.timezone.utc
        return datetime.datetime(year, month, day, hour, minute, second, fraction,
                                 tzinfo=tzinfo)

    def construct_yaml_omap(self, node):
        # Note: we do not check for duplicate keys, because it's too
        # CPU-expensive.
        # Two-step construction: yield the empty list first so recursive
        # references resolve, then fill it in.
        omap = []
        yield omap
        if not isinstance(node, SequenceNode):
            raise ConstructorError("while constructing an ordered map", node.start_mark,
                    "expected a sequence, but found %s" % node.id, node.start_mark)
        for subnode in node.value:
            if not isinstance(subnode, MappingNode):
                raise ConstructorError("while constructing an ordered map", node.start_mark,
                        "expected a mapping of length 1, but found %s" % subnode.id,
                        subnode.start_mark)
            if len(subnode.value) != 1:
                raise ConstructorError("while constructing an ordered map", node.start_mark,
                        "expected a single mapping item, but found %d items" % len(subnode.value),
                        subnode.start_mark)
            key_node, value_node = subnode.value[0]
            key = self.construct_object(key_node)
            value = self.construct_object(value_node)
            omap.append((key, value))

    def construct_yaml_pairs(self, node):
        # Note: the same code as `construct_yaml_omap`.
        pairs = []
        yield pairs
        if not isinstance(node, SequenceNode):
            raise ConstructorError("while constructing pairs", node.start_mark,
                    "expected a sequence, but found %s" % node.id, node.start_mark)
        for subnode in node.value:
            if not isinstance(subnode, MappingNode):
                raise ConstructorError("while constructing pairs", node.start_mark,
                        "expected a mapping of length 1, but found %s" % subnode.id,
                        subnode.start_mark)
            if len(subnode.value) != 1:
                raise ConstructorError("while constructing pairs", node.start_mark,
                        "expected a single mapping item, but found %d items" % len(subnode.value),
                        subnode.start_mark)
            key_node, value_node = subnode.value[0]
            key = self.construct_object(key_node)
            value = self.construct_object(value_node)
            pairs.append((key, value))

    def construct_yaml_set(self, node):
        # Two-step construction (see construct_yaml_omap).
        data = set()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_yaml_str(self, node):
        return self.construct_scalar(node)

    def construct_yaml_seq(self, node):
        # Two-step construction (see construct_yaml_omap).
        data = []
        yield data
        data.extend(self.construct_sequence(node))

    def construct_yaml_map(self, node):
        # Two-step construction (see construct_yaml_omap).
        data = {}
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_yaml_object(self, node, cls):
        """Build an instance of ``cls`` from a mapping node's state,
        honoring ``__setstate__`` when the class defines it."""
        data = cls.__new__(cls)
        yield data
        if hasattr(data, '__setstate__'):
            state = self.construct_mapping(node, deep=True)
            data.__setstate__(state)
        else:
            state = self.construct_mapping(node)
            data.__dict__.update(state)

    def construct_undefined(self, node):
        """Default handler for unknown tags: always an error."""
        raise ConstructorError(None, None,
                "could not determine a constructor for the tag %r" % node.tag,
                node.start_mark)
# Register the standard YAML 1.1 tag constructors on SafeConstructor;
# the final None entry makes any unrecognized tag an error.
SafeConstructor.add_constructor(
        'tag:yaml.org,2002:null',
        SafeConstructor.construct_yaml_null)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:bool',
        SafeConstructor.construct_yaml_bool)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:int',
        SafeConstructor.construct_yaml_int)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:float',
        SafeConstructor.construct_yaml_float)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:binary',
        SafeConstructor.construct_yaml_binary)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:timestamp',
        SafeConstructor.construct_yaml_timestamp)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:omap',
        SafeConstructor.construct_yaml_omap)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:pairs',
        SafeConstructor.construct_yaml_pairs)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:set',
        SafeConstructor.construct_yaml_set)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:str',
        SafeConstructor.construct_yaml_str)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:seq',
        SafeConstructor.construct_yaml_seq)

SafeConstructor.add_constructor(
        'tag:yaml.org,2002:map',
        SafeConstructor.construct_yaml_map)

SafeConstructor.add_constructor(None,
        SafeConstructor.construct_undefined)
class FullConstructor(SafeConstructor):
    """Constructs most Python-specific tags but refuses arbitrary code.

    Modules are never imported here (only already-imported ones are
    resolved), and instance state keys are filtered through a blacklist.
    """

    # 'extend' is blacklisted because it is used by
    # construct_python_object_apply to add `listitems` to a newly generate
    # python instance
    def get_state_keys_blacklist(self):
        """Regex fragments for state keys that must never be set."""
        return ['^extend$', '^__.*__$']

    def get_state_keys_blacklist_regexp(self):
        # Compiled lazily and cached on the instance.
        if not hasattr(self, 'state_keys_blacklist_regexp'):
            self.state_keys_blacklist_regexp = re.compile('(' + '|'.join(self.get_state_keys_blacklist()) + ')')
        return self.state_keys_blacklist_regexp

    def construct_python_str(self, node):
        return self.construct_scalar(node)

    def construct_python_unicode(self, node):
        return self.construct_scalar(node)

    def construct_python_bytes(self, node):
        """Decode a base64 scalar into bytes (same logic as !!binary)."""
        try:
            value = self.construct_scalar(node).encode('ascii')
        except UnicodeEncodeError as exc:
            raise ConstructorError(None, None,
                    "failed to convert base64 data into ascii: %s" % exc,
                    node.start_mark)
        try:
            if hasattr(base64, 'decodebytes'):
                return base64.decodebytes(value)
            else:
                return base64.decodestring(value)
        except binascii.Error as exc:
            raise ConstructorError(None, None,
                    "failed to decode base64 data: %s" % exc, node.start_mark)

    def construct_python_long(self, node):
        return self.construct_yaml_int(node)

    def construct_python_complex(self, node):
        return complex(self.construct_scalar(node))

    def construct_python_tuple(self, node):
        return tuple(self.construct_sequence(node))

    def find_python_module(self, name, mark, unsafe=False):
        """Return an already-imported module; import it only if ``unsafe``."""
        if not name:
            raise ConstructorError("while constructing a Python module", mark,
                    "expected non-empty name appended to the tag", mark)
        if unsafe:
            try:
                __import__(name)
            except ImportError as exc:
                raise ConstructorError("while constructing a Python module", mark,
                        "cannot find module %r (%s)" % (name, exc), mark)
        if name not in sys.modules:
            raise ConstructorError("while constructing a Python module", mark,
                    "module %r is not imported" % name, mark)
        return sys.modules[name]

    def find_python_name(self, name, mark, unsafe=False):
        """Resolve a dotted ``module.attr`` name (or a builtin) to its object.

        Imports the module only when ``unsafe``; otherwise the module must
        already be present in sys.modules.
        """
        if not name:
            raise ConstructorError("while constructing a Python object", mark,
                    "expected non-empty name appended to the tag", mark)
        if '.' in name:
            module_name, object_name = name.rsplit('.', 1)
        else:
            module_name = 'builtins'
            object_name = name
        if unsafe:
            try:
                __import__(module_name)
            except ImportError as exc:
                raise ConstructorError("while constructing a Python object", mark,
                        "cannot find module %r (%s)" % (module_name, exc), mark)
        if module_name not in sys.modules:
            raise ConstructorError("while constructing a Python object", mark,
                    "module %r is not imported" % module_name, mark)
        module = sys.modules[module_name]
        if not hasattr(module, object_name):
            raise ConstructorError("while constructing a Python object", mark,
                    "cannot find %r in the module %r"
                    % (object_name, module.__name__), mark)
        return getattr(module, object_name)

    def construct_python_name(self, suffix, node):
        # The node itself must be empty; the name lives in the tag suffix.
        value = self.construct_scalar(node)
        if value:
            raise ConstructorError("while constructing a Python name", node.start_mark,
                    "expected the empty value, but found %r" % value, node.start_mark)
        return self.find_python_name(suffix, node.start_mark)

    def construct_python_module(self, suffix, node):
        # The node itself must be empty; the module lives in the tag suffix.
        value = self.construct_scalar(node)
        if value:
            raise ConstructorError("while constructing a Python module", node.start_mark,
                    "expected the empty value, but found %r" % value, node.start_mark)
        return self.find_python_module(suffix, node.start_mark)

    def make_python_instance(self, suffix, node,
            args=None, kwds=None, newobj=False, unsafe=False):
        """Instantiate the class named by the tag suffix.

        ``newobj`` selects ``cls.__new__`` over calling the class; unless
        ``unsafe``, non-class callables are rejected.
        """
        if not args:
            args = []
        if not kwds:
            kwds = {}
        cls = self.find_python_name(suffix, node.start_mark)
        if not (unsafe or isinstance(cls, type)):
            raise ConstructorError("while constructing a Python instance", node.start_mark,
                    "expected a class, but found %r" % type(cls),
                    node.start_mark)
        if newobj and isinstance(cls, type):
            return cls.__new__(cls, *args, **kwds)
        else:
            return cls(*args, **kwds)

    def set_python_instance_state(self, instance, state, unsafe=False):
        """Apply pickle-style ``state`` to ``instance``.

        Honors ``__setstate__`` when present; otherwise updates
        ``__dict__``/slots, filtering keys through check_state_key unless
        ``unsafe``.
        """
        if hasattr(instance, '__setstate__'):
            instance.__setstate__(state)
        else:
            slotstate = {}
            # A 2-tuple state means (dict state, slot state), as in pickle.
            if isinstance(state, tuple) and len(state) == 2:
                state, slotstate = state
            if hasattr(instance, '__dict__'):
                if not unsafe and state:
                    for key in state.keys():
                        self.check_state_key(key)
                instance.__dict__.update(state)
            elif state:
                slotstate.update(state)
            for key, value in slotstate.items():
                if not unsafe:
                    self.check_state_key(key)
                setattr(instance, key, value)

    def construct_python_object(self, suffix, node):
        # Format:
        #   !!python/object:module.name { ... state ... }
        # Two-step construction: yield the bare instance first so
        # recursive references resolve, then apply its state.
        instance = self.make_python_instance(suffix, node, newobj=True)
        yield instance
        deep = hasattr(instance, '__setstate__')
        state = self.construct_mapping(node, deep=deep)
        self.set_python_instance_state(instance, state)

    def construct_python_object_apply(self, suffix, node, newobj=False):
        # Format:
        #   !!python/object/apply       # (or !!python/object/new)
        #   args: [ ... arguments ... ]
        #   kwds: { ... keywords ... }
        #   state: ... state ...
        #   listitems: [ ... listitems ... ]
        #   dictitems: { ... dictitems ... }
        # or short format:
        #   !!python/object/apply [ ... arguments ... ]
        # The difference between !!python/object/apply and !!python/object/new
        # is how an object is created, check make_python_instance for details.
        if isinstance(node, SequenceNode):
            args = self.construct_sequence(node, deep=True)
            kwds = {}
            state = {}
            listitems = []
            dictitems = {}
        else:
            value = self.construct_mapping(node, deep=True)
            args = value.get('args', [])
            kwds = value.get('kwds', {})
            state = value.get('state', {})
            listitems = value.get('listitems', [])
            dictitems = value.get('dictitems', {})
        instance = self.make_python_instance(suffix, node, args, kwds, newobj)
        if state:
            self.set_python_instance_state(instance, state)
        if listitems:
            instance.extend(listitems)
        if dictitems:
            for key in dictitems:
                instance[key] = dictitems[key]
        return instance

    def construct_python_object_new(self, suffix, node):
        return self.construct_python_object_apply(suffix, node, newobj=True)
# Register the python/* tags that FullConstructor is willing to handle.
FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/none',
    FullConstructor.construct_yaml_null)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/bool',
    FullConstructor.construct_yaml_bool)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/str',
    FullConstructor.construct_python_str)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/unicode',
    FullConstructor.construct_python_unicode)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/bytes',
    FullConstructor.construct_python_bytes)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/int',
    FullConstructor.construct_yaml_int)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/long',
    FullConstructor.construct_python_long)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/float',
    FullConstructor.construct_yaml_float)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/complex',
    FullConstructor.construct_python_complex)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/list',
    FullConstructor.construct_yaml_seq)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/tuple',
    FullConstructor.construct_python_tuple)

FullConstructor.add_constructor(
    'tag:yaml.org,2002:python/dict',
    FullConstructor.construct_yaml_map)

FullConstructor.add_multi_constructor(
    'tag:yaml.org,2002:python/name:',
    FullConstructor.construct_python_name)
class UnsafeConstructor(FullConstructor):
    """FullConstructor with every safety check disabled.

    Each override simply forwards to the FullConstructor implementation
    with ``unsafe=True``, permitting module imports, arbitrary callables,
    and unfiltered instance state.
    """

    def find_python_module(self, name, mark):
        return super().find_python_module(name, mark, unsafe=True)

    def find_python_name(self, name, mark):
        return super().find_python_name(name, mark, unsafe=True)

    def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
        return super().make_python_instance(
            suffix, node, args, kwds, newobj, unsafe=True)

    def set_python_instance_state(self, instance, state):
        return super().set_python_instance_state(
            instance, state, unsafe=True)
# Register the dangerous python/module and python/object tags only on
# UnsafeConstructor, never on FullConstructor.
UnsafeConstructor.add_multi_constructor(
    'tag:yaml.org,2002:python/module:',
    UnsafeConstructor.construct_python_module)

UnsafeConstructor.add_multi_constructor(
    'tag:yaml.org,2002:python/object:',
    UnsafeConstructor.construct_python_object)

UnsafeConstructor.add_multi_constructor(
    'tag:yaml.org,2002:python/object/new:',
    UnsafeConstructor.construct_python_object_new)

UnsafeConstructor.add_multi_constructor(
    'tag:yaml.org,2002:python/object/apply:',
    UnsafeConstructor.construct_python_object_apply)
# Constructor is same as UnsafeConstructor. Need to leave this in place in case
# people have extended it directly.
class Constructor(UnsafeConstructor):
    """Historical alias for UnsafeConstructor, kept for backwards compatibility."""
    pass

View File

@@ -0,0 +1,101 @@
__all__ = [
'CBaseLoader', 'CSafeLoader', 'CFullLoader', 'CUnsafeLoader', 'CLoader',
'CBaseDumper', 'CSafeDumper', 'CDumper'
]
from yaml._yaml import CParser, CEmitter
from .constructor import *
from .serializer import *
from .representer import *
from .resolver import *
class CBaseLoader(CParser, BaseConstructor, BaseResolver):
    """Loader on the libyaml CParser with only base construction/resolution."""

    def __init__(self, stream):
        CParser.__init__(self, stream)
        BaseConstructor.__init__(self)
        BaseResolver.__init__(self)
class CSafeLoader(CParser, SafeConstructor, Resolver):
    """Loader on the libyaml CParser limited to standard YAML tags."""

    def __init__(self, stream):
        CParser.__init__(self, stream)
        SafeConstructor.__init__(self)
        Resolver.__init__(self)
class CFullLoader(CParser, FullConstructor, Resolver):
    """Loader on the libyaml CParser using the FullConstructor."""

    def __init__(self, stream):
        CParser.__init__(self, stream)
        FullConstructor.__init__(self)
        Resolver.__init__(self)
class CUnsafeLoader(CParser, UnsafeConstructor, Resolver):
    """Loader on the libyaml CParser using the UnsafeConstructor."""

    def __init__(self, stream):
        CParser.__init__(self, stream)
        UnsafeConstructor.__init__(self)
        Resolver.__init__(self)
class CLoader(CParser, Constructor, Resolver):
    """Loader on the libyaml CParser using the (unsafe) Constructor."""

    def __init__(self, stream):
        CParser.__init__(self, stream)
        Constructor.__init__(self)
        Resolver.__init__(self)
class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver):
    """Dumper on the libyaml CEmitter with the base representer/resolver."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        CEmitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width, encoding=encoding,
                allow_unicode=allow_unicode, line_break=line_break,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        # NOTE(review): bases declare BaseRepresenter/BaseResolver but the
        # subclass initializers are invoked here -- mirrors upstream
        # PyYAML; confirm intentional before changing.
        Representer.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)
class CSafeDumper(CEmitter, SafeRepresenter, Resolver):
    """Dumper on the libyaml CEmitter restricted to the safe representer."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        CEmitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width, encoding=encoding,
                allow_unicode=allow_unicode, line_break=line_break,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        SafeRepresenter.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)
class CDumper(CEmitter, Serializer, Representer, Resolver):
    """Dumper on the libyaml CEmitter with the full Representer."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        # The CEmitter also covers serialization, so Serializer.__init__
        # is not called here.
        CEmitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width, encoding=encoding,
                allow_unicode=allow_unicode, line_break=line_break,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        Representer.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)

View File

@@ -0,0 +1,62 @@
__all__ = ['BaseDumper', 'SafeDumper', 'Dumper']
from .emitter import *
from .serializer import *
from .representer import *
from .resolver import *
class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver):
    """Pure-Python dumper with the base representer/resolver."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        Emitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width,
                allow_unicode=allow_unicode, line_break=line_break)
        Serializer.__init__(self, encoding=encoding,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        # NOTE(review): bases declare BaseRepresenter/BaseResolver but the
        # subclass initializers are invoked here -- mirrors upstream
        # PyYAML; confirm intentional before changing.
        Representer.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)
class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
    """Pure-Python dumper restricted to the safe representer."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        Emitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width,
                allow_unicode=allow_unicode, line_break=line_break)
        Serializer.__init__(self, encoding=encoding,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        SafeRepresenter.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)
class Dumper(Emitter, Serializer, Representer, Resolver):
    """Pure-Python dumper with the full Representer."""

    def __init__(self, stream,
            default_style=None, default_flow_style=False,
            canonical=None, indent=None, width=None,
            allow_unicode=None, line_break=None,
            encoding=None, explicit_start=None, explicit_end=None,
            version=None, tags=None, sort_keys=True):
        Emitter.__init__(self, stream, canonical=canonical,
                indent=indent, width=width,
                allow_unicode=allow_unicode, line_break=line_break)
        Serializer.__init__(self, encoding=encoding,
                explicit_start=explicit_start, explicit_end=explicit_end,
                version=version, tags=tags)
        Representer.__init__(self, default_style=default_style,
                default_flow_style=default_flow_style, sort_keys=sort_keys)
        Resolver.__init__(self)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,75 @@
__all__ = ['Mark', 'YAMLError', 'MarkedYAMLError']
class Mark:
    """A position (line/column/character index) inside a YAML source buffer.

    Knows how to render a short annotated excerpt of the surrounding text
    for error messages.
    """

    def __init__(self, name, index, line, column, buffer, pointer):
        self.name = name
        self.index = index
        self.line = line
        self.column = column
        self.buffer = buffer
        self.pointer = pointer

    def get_snippet(self, indent=4, max_length=75):
        """Return an indented excerpt around the mark with a '^' caret on a
        second line, or None when no buffer is available."""
        if self.buffer is None:
            return None
        breaks = '\0\r\n\x85\u2028\u2029'
        half = max_length/2-1
        # Walk left to the line start, truncating with ' ... ' if too far.
        head = ''
        start = self.pointer
        while start > 0 and self.buffer[start-1] not in breaks:
            start -= 1
            if self.pointer-start > half:
                head = ' ... '
                start += 5
                break
        # Walk right to the line end, truncating symmetrically.
        tail = ''
        end = self.pointer
        while end < len(self.buffer) and self.buffer[end] not in breaks:
            end += 1
            if end-self.pointer > half:
                tail = ' ... '
                end -= 5
                break
        excerpt = self.buffer[start:end]
        caret_pad = ' '*(indent+self.pointer-start+len(head))
        return ' '*indent + head + excerpt + tail + '\n' + caret_pad + '^'

    def __str__(self):
        snippet = self.get_snippet()
        where = ' in "%s", line %d, column %d' \
                % (self.name, self.line+1, self.column+1)
        if snippet is not None:
            where += ":\n"+snippet
        return where
class YAMLError(Exception):
    """Base class for all errors raised by this YAML package."""
    pass
class MarkedYAMLError(YAMLError):
    """YAMLError carrying context/problem descriptions and their Marks."""

    def __init__(self, context=None, context_mark=None,
            problem=None, problem_mark=None, note=None):
        self.context = context
        self.context_mark = context_mark
        self.problem = problem
        self.problem_mark = problem_mark
        self.note = note

    def __str__(self):
        parts = []
        if self.context is not None:
            parts.append(self.context)
        # Show the context mark only when it adds information beyond the
        # problem mark (different file/position, or no problem mark at all).
        redundant_context_mark = (self.problem is not None
                and self.problem_mark is not None
                and self.context_mark is not None
                and self.context_mark.name == self.problem_mark.name
                and self.context_mark.line == self.problem_mark.line
                and self.context_mark.column == self.problem_mark.column)
        if self.context_mark is not None and not redundant_context_mark:
            parts.append(str(self.context_mark))
        if self.problem is not None:
            parts.append(self.problem)
        if self.problem_mark is not None:
            parts.append(str(self.problem_mark))
        if self.note is not None:
            parts.append(self.note)
        return '\n'.join(parts)

View File

@@ -0,0 +1,86 @@
# Abstract classes.
class Event(object):
    """Abstract base class for all parsing/emitting events."""

    def __init__(self, start_mark=None, end_mark=None):
        self.start_mark = start_mark
        self.end_mark = end_mark

    def __repr__(self):
        # Show only the descriptive attributes this event actually carries.
        shown = ['anchor', 'tag', 'implicit', 'value']
        pairs = ['%s=%r' % (name, getattr(self, name))
                 for name in shown if hasattr(self, name)]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(pairs))
class NodeEvent(Event):
    """Base for events that may carry an anchor (alias/scalar/collection)."""

    def __init__(self, anchor, start_mark=None, end_mark=None):
        self.anchor = anchor
        self.start_mark, self.end_mark = start_mark, end_mark
class CollectionStartEvent(NodeEvent):
    """Start of a sequence or mapping; carries tag, implicit flags and
    flow style in addition to the anchor."""

    def __init__(self, anchor, tag, implicit, start_mark=None, end_mark=None,
            flow_style=None):
        self.anchor, self.tag, self.implicit = anchor, tag, implicit
        self.start_mark, self.end_mark = start_mark, end_mark
        self.flow_style = flow_style
class CollectionEndEvent(Event):
    """End of a sequence or mapping."""
    pass
# Implementations.
class StreamStartEvent(Event):
    """Start of the character stream; may record its encoding."""

    def __init__(self, start_mark=None, end_mark=None, encoding=None):
        self.start_mark, self.end_mark = start_mark, end_mark
        self.encoding = encoding
class StreamEndEvent(Event):
    """End of the character stream."""
    pass
class DocumentStartEvent(Event):
    """Start of a document; records explicitness, version and tag handles."""

    def __init__(self, start_mark=None, end_mark=None,
            explicit=None, version=None, tags=None):
        self.start_mark, self.end_mark = start_mark, end_mark
        self.explicit = explicit
        self.version = version
        self.tags = tags
class DocumentEndEvent(Event):
    """End of a document; records whether the '...' marker was explicit."""

    def __init__(self, start_mark=None, end_mark=None,
            explicit=None):
        self.start_mark, self.end_mark = start_mark, end_mark
        self.explicit = explicit
class AliasEvent(NodeEvent):
    """Reference (*alias) to a previously anchored node."""
    pass
class ScalarEvent(NodeEvent):
    """A scalar value, with its tag, implicit flags and quoting style."""

    def __init__(self, anchor, tag, implicit, value,
            start_mark=None, end_mark=None, style=None):
        self.anchor, self.tag = anchor, tag
        self.implicit = implicit
        self.value = value
        self.start_mark, self.end_mark = start_mark, end_mark
        self.style = style
class SequenceStartEvent(CollectionStartEvent):
    """Start of a sequence."""
    pass
class SequenceEndEvent(CollectionEndEvent):
    """End of a sequence."""
    pass
class MappingStartEvent(CollectionStartEvent):
    """Start of a mapping."""
    pass
class MappingEndEvent(CollectionEndEvent):
    """End of a mapping."""
    pass

View File

@@ -0,0 +1,63 @@
__all__ = ['BaseLoader', 'FullLoader', 'SafeLoader', 'Loader', 'UnsafeLoader']
from .reader import *
from .scanner import *
from .parser import *
from .composer import *
from .constructor import *
from .resolver import *
class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, BaseResolver):
    """Loader that builds only basic Python containers and strings, with no
    implicit tag resolution (every scalar stays a str)."""
    def __init__(self, stream):
        # Initialize each stage of the load pipeline in order; later stages
        # build on state set up by earlier ones.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        BaseConstructor.__init__(self)
        BaseResolver.__init__(self)
class FullLoader(Reader, Scanner, Parser, Composer, FullConstructor, Resolver):
    """Loader using FullConstructor: richer than SafeLoader, but still meant
    to refuse the clearly dangerous constructs of the unsafe loaders."""
    def __init__(self, stream):
        # Pipeline stages must be initialized in this order.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        FullConstructor.__init__(self)
        Resolver.__init__(self)
class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, Resolver):
    """Loader restricted to standard YAML tags; suitable for untrusted
    input."""
    def __init__(self, stream):
        # Pipeline stages must be initialized in this order.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        SafeConstructor.__init__(self)
        Resolver.__init__(self)
class Loader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
    """Full-power loader that can construct arbitrary Python objects.
    Unsafe on untrusted input; kept for backwards compatibility."""
    def __init__(self, stream):
        # Pipeline stages must be initialized in this order.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        Constructor.__init__(self)
        Resolver.__init__(self)
# UnsafeLoader is the same as Loader (which is and was always unsafe on
# untrusted input). Use of either Loader or UnsafeLoader should be rare, since
# FullLoad should be able to load almost all YAML safely. Loader is left intact
# to ensure backwards compatibility.
class UnsafeLoader(Reader, Scanner, Parser, Composer, Constructor, Resolver):
    """Explicitly-named alias of `Loader`; constructs arbitrary Python
    objects and must never be used on untrusted input."""
    def __init__(self, stream):
        # Pipeline stages must be initialized in this order.
        Reader.__init__(self, stream)
        Scanner.__init__(self)
        Parser.__init__(self)
        Composer.__init__(self)
        Constructor.__init__(self)
        Resolver.__init__(self)

View File

@@ -0,0 +1,49 @@
class Node(object):
    """Base class for representation-graph nodes: a ``tag`` string plus a
    ``value`` payload, with optional source marks for error reporting."""
    def __init__(self, tag, value, start_mark, end_mark):
        self.tag = tag
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
    def __repr__(self):
        """Show the class name, tag and the full repr of the value.

        (Historically the value was abbreviated here; it is now shown
        verbatim.)
        """
        shown = repr(self.value)
        return '%s(tag=%r, value=%s)' % (self.__class__.__name__,
                self.tag, shown)
class ScalarNode(Node):
    """Leaf node holding a single scalar value; ``style`` records the
    scalar style character, if any."""
    id = 'scalar'
    def __init__(self, tag, value,
            start_mark=None, end_mark=None, style=None):
        self.tag, self.value, self.style = tag, value, style
        self.start_mark = start_mark
        self.end_mark = end_mark
class CollectionNode(Node):
    """Base for sequence and mapping nodes; ``flow_style`` selects flow
    (True) or block (False) formatting."""
    def __init__(self, tag, value,
            start_mark=None, end_mark=None, flow_style=None):
        self.tag, self.value, self.flow_style = tag, value, flow_style
        self.start_mark = start_mark
        self.end_mark = end_mark
class SequenceNode(CollectionNode):
    """Node whose value is a list of child nodes."""
    id = 'sequence'
class MappingNode(CollectionNode):
    """Node whose value is a list of (key node, value node) pairs."""
    id = 'mapping'

View File

@@ -0,0 +1,589 @@
# The following YAML grammar is LL(1) and is parsed by a recursive descent
# parser.
#
# stream ::= STREAM-START implicit_document? explicit_document* STREAM-END
# implicit_document ::= block_node DOCUMENT-END*
# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
# block_node_or_indentless_sequence ::=
# ALIAS
# | properties (block_content | indentless_block_sequence)?
# | block_content
# | indentless_block_sequence
# block_node ::= ALIAS
# | properties block_content?
# | block_content
# flow_node ::= ALIAS
# | properties flow_content?
# | flow_content
# properties ::= TAG ANCHOR? | ANCHOR TAG?
# block_content ::= block_collection | flow_collection | SCALAR
# flow_content ::= flow_collection | SCALAR
# block_collection ::= block_sequence | block_mapping
# flow_collection ::= flow_sequence | flow_mapping
# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END
# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
# block_mapping ::= BLOCK-MAPPING_START
# ((KEY block_node_or_indentless_sequence?)?
# (VALUE block_node_or_indentless_sequence?)?)*
# BLOCK-END
# flow_sequence ::= FLOW-SEQUENCE-START
# (flow_sequence_entry FLOW-ENTRY)*
# flow_sequence_entry?
# FLOW-SEQUENCE-END
# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
# flow_mapping ::= FLOW-MAPPING-START
# (flow_mapping_entry FLOW-ENTRY)*
# flow_mapping_entry?
# FLOW-MAPPING-END
# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# FIRST sets:
#
# stream: { STREAM-START }
# explicit_document: { DIRECTIVE DOCUMENT-START }
# implicit_document: FIRST(block_node)
# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_sequence: { BLOCK-SEQUENCE-START }
# block_mapping: { BLOCK-MAPPING-START }
# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START BLOCK-ENTRY }
# indentless_sequence: { ENTRY }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_sequence: { FLOW-SEQUENCE-START }
# flow_mapping: { FLOW-MAPPING-START }
# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START KEY }
__all__ = ['Parser', 'ParserError']
from .error import MarkedYAMLError
from .tokens import *
from .events import *
from .scanner import *
class ParserError(MarkedYAMLError):
    """Raised when the token stream violates the YAML grammar."""
    pass
class Parser:
    """Turn the scanner's token stream into a stream of parsing events.

    Implements the LL(1) grammar given at the top of this module as a
    recursive-descent parser expressed as a state machine: ``self.state``
    holds the bound method that will produce the next event,
    ``self.states`` is a stack of states to return to afterwards, and
    ``self.marks`` records the start marks of open collections for error
    reporting.
    """
    # Since writing a recursive-descent parser is a straightforward task, we
    # do not give many comments here.

    DEFAULT_TAGS = {
        '!': '!',
        '!!': 'tag:yaml.org,2002:',
    }

    def __init__(self):
        self.current_event = None
        self.yaml_version = None
        self.tag_handles = {}
        self.states = []
        self.marks = []
        self.state = self.parse_stream_start

    def dispose(self):
        """Break reference cycles so the parser can be garbage-collected."""
        # Reset the state attributes (to clear self-references)
        self.states = []
        self.state = None

    def check_event(self, *choices):
        """Return True if the next event matches one of ``choices`` (or if
        any event is available when no choices are given)."""
        # Check the type of the next event.
        if self.current_event is None:
            if self.state:
                self.current_event = self.state()
        if self.current_event is not None:
            if not choices:
                return True
            for choice in choices:
                if isinstance(self.current_event, choice):
                    return True
        return False

    def peek_event(self):
        """Return the next event without consuming it."""
        # Get the next event.
        if self.current_event is None:
            if self.state:
                self.current_event = self.state()
        return self.current_event

    def get_event(self):
        """Return the next event and advance past it."""
        # Get the next event and proceed further.
        if self.current_event is None:
            if self.state:
                self.current_event = self.state()
        value = self.current_event
        self.current_event = None
        return value

    # stream    ::= STREAM-START implicit_document? explicit_document* STREAM-END
    # implicit_document ::= block_node DOCUMENT-END*
    # explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*

    def parse_stream_start(self):
        # Parse the stream start.
        token = self.get_token()
        event = StreamStartEvent(token.start_mark, token.end_mark,
                encoding=token.encoding)
        # Prepare the next state.
        self.state = self.parse_implicit_document_start
        return event

    def parse_implicit_document_start(self):
        # Parse an implicit document (one with no '---' or directives).
        if not self.check_token(DirectiveToken, DocumentStartToken,
                StreamEndToken):
            self.tag_handles = self.DEFAULT_TAGS
            token = self.peek_token()
            start_mark = end_mark = token.start_mark
            event = DocumentStartEvent(start_mark, end_mark,
                    explicit=False)
            # Prepare the next state.
            self.states.append(self.parse_document_end)
            self.state = self.parse_block_node
            return event
        else:
            return self.parse_document_start()

    def parse_document_start(self):
        # Parse any extra document end indicators.
        while self.check_token(DocumentEndToken):
            self.get_token()
        # Parse an explicit document.
        if not self.check_token(StreamEndToken):
            token = self.peek_token()
            start_mark = token.start_mark
            version, tags = self.process_directives()
            if not self.check_token(DocumentStartToken):
                raise ParserError(None, None,
                        "expected '<document start>', but found %r"
                        % self.peek_token().id,
                        self.peek_token().start_mark)
            token = self.get_token()
            end_mark = token.end_mark
            event = DocumentStartEvent(start_mark, end_mark,
                    explicit=True, version=version, tags=tags)
            self.states.append(self.parse_document_end)
            self.state = self.parse_document_content
        else:
            # Parse the end of the stream.
            token = self.get_token()
            event = StreamEndEvent(token.start_mark, token.end_mark)
            # At stream end all collections and pending states must be closed.
            assert not self.states
            assert not self.marks
            self.state = None
        return event

    def parse_document_end(self):
        # Parse the document end ('...' is optional, hence explicit flag).
        token = self.peek_token()
        start_mark = end_mark = token.start_mark
        explicit = False
        if self.check_token(DocumentEndToken):
            token = self.get_token()
            end_mark = token.end_mark
            explicit = True
        event = DocumentEndEvent(start_mark, end_mark,
                explicit=explicit)
        # Prepare the next state.
        self.state = self.parse_document_start
        return event

    def parse_document_content(self):
        # An empty document body becomes an empty scalar.
        if self.check_token(DirectiveToken,
                DocumentStartToken, DocumentEndToken, StreamEndToken):
            event = self.process_empty_scalar(self.peek_token().start_mark)
            self.state = self.states.pop()
            return event
        else:
            return self.parse_block_node()

    def process_directives(self):
        """Consume %YAML/%TAG directive tokens; return (version, tag map)."""
        self.yaml_version = None
        self.tag_handles = {}
        while self.check_token(DirectiveToken):
            token = self.get_token()
            if token.name == 'YAML':
                if self.yaml_version is not None:
                    raise ParserError(None, None,
                            "found duplicate YAML directive", token.start_mark)
                major, minor = token.value
                if major != 1:
                    raise ParserError(None, None,
                            "found incompatible YAML document (version 1.* is required)",
                            token.start_mark)
                self.yaml_version = token.value
            elif token.name == 'TAG':
                handle, prefix = token.value
                if handle in self.tag_handles:
                    raise ParserError(None, None,
                            "duplicate tag handle %r" % handle,
                            token.start_mark)
                self.tag_handles[handle] = prefix
        # Report only the explicitly declared handles to the caller, but keep
        # the defaults available for tag resolution below.
        if self.tag_handles:
            value = self.yaml_version, self.tag_handles.copy()
        else:
            value = self.yaml_version, None
        for key in self.DEFAULT_TAGS:
            if key not in self.tag_handles:
                self.tag_handles[key] = self.DEFAULT_TAGS[key]
        return value

    # block_node_or_indentless_sequence ::= ALIAS
    #               | properties (block_content | indentless_block_sequence)?
    #               | block_content
    #               | indentless_block_sequence
    # block_node    ::= ALIAS
    #                   | properties block_content?
    #                   | block_content
    # flow_node     ::= ALIAS
    #                   | properties flow_content?
    #                   | flow_content
    # properties    ::= TAG ANCHOR? | ANCHOR TAG?
    # block_content     ::= block_collection | flow_collection | SCALAR
    # flow_content      ::= flow_collection | SCALAR
    # block_collection  ::= block_sequence | block_mapping
    # flow_collection   ::= flow_sequence | flow_mapping

    def parse_block_node(self):
        return self.parse_node(block=True)

    def parse_flow_node(self):
        return self.parse_node()

    def parse_block_node_or_indentless_sequence(self):
        return self.parse_node(block=True, indentless_sequence=True)

    def parse_node(self, block=False, indentless_sequence=False):
        """Parse a single node; the workhorse shared by the three entry
        points above."""
        if self.check_token(AliasToken):
            token = self.get_token()
            event = AliasEvent(token.value, token.start_mark, token.end_mark)
            self.state = self.states.pop()
        else:
            anchor = None
            tag = None
            start_mark = end_mark = tag_mark = None
            # Properties may appear as ANCHOR TAG? or TAG ANCHOR?.
            if self.check_token(AnchorToken):
                token = self.get_token()
                start_mark = token.start_mark
                end_mark = token.end_mark
                anchor = token.value
                if self.check_token(TagToken):
                    token = self.get_token()
                    tag_mark = token.start_mark
                    end_mark = token.end_mark
                    tag = token.value
            elif self.check_token(TagToken):
                token = self.get_token()
                start_mark = tag_mark = token.start_mark
                end_mark = token.end_mark
                tag = token.value
                if self.check_token(AnchorToken):
                    token = self.get_token()
                    end_mark = token.end_mark
                    anchor = token.value
            if tag is not None:
                # Resolve the (handle, suffix) pair through the declared
                # tag handles.
                handle, suffix = tag
                if handle is not None:
                    if handle not in self.tag_handles:
                        raise ParserError("while parsing a node", start_mark,
                                "found undefined tag handle %r" % handle,
                                tag_mark)
                    tag = self.tag_handles[handle]+suffix
                else:
                    tag = suffix
            #if tag == '!':
            #    raise ParserError("while parsing a node", start_mark,
            #            "found non-specific tag '!'", tag_mark,
            #            "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag' and share your opinion.")
            if start_mark is None:
                start_mark = end_mark = self.peek_token().start_mark
            event = None
            implicit = (tag is None or tag == '!')
            if indentless_sequence and self.check_token(BlockEntryToken):
                end_mark = self.peek_token().end_mark
                event = SequenceStartEvent(anchor, tag, implicit,
                        start_mark, end_mark)
                self.state = self.parse_indentless_sequence_entry
            else:
                if self.check_token(ScalarToken):
                    token = self.get_token()
                    end_mark = token.end_mark
                    # implicit is a pair: (resolvable as plain, resolvable
                    # when quoted).
                    if (token.plain and tag is None) or tag == '!':
                        implicit = (True, False)
                    elif tag is None:
                        implicit = (False, True)
                    else:
                        implicit = (False, False)
                    event = ScalarEvent(anchor, tag, implicit, token.value,
                            start_mark, end_mark, style=token.style)
                    self.state = self.states.pop()
                elif self.check_token(FlowSequenceStartToken):
                    end_mark = self.peek_token().end_mark
                    event = SequenceStartEvent(anchor, tag, implicit,
                            start_mark, end_mark, flow_style=True)
                    self.state = self.parse_flow_sequence_first_entry
                elif self.check_token(FlowMappingStartToken):
                    end_mark = self.peek_token().end_mark
                    event = MappingStartEvent(anchor, tag, implicit,
                            start_mark, end_mark, flow_style=True)
                    self.state = self.parse_flow_mapping_first_key
                elif block and self.check_token(BlockSequenceStartToken):
                    end_mark = self.peek_token().start_mark
                    event = SequenceStartEvent(anchor, tag, implicit,
                            start_mark, end_mark, flow_style=False)
                    self.state = self.parse_block_sequence_first_entry
                elif block and self.check_token(BlockMappingStartToken):
                    end_mark = self.peek_token().start_mark
                    event = MappingStartEvent(anchor, tag, implicit,
                            start_mark, end_mark, flow_style=False)
                    self.state = self.parse_block_mapping_first_key
                elif anchor is not None or tag is not None:
                    # Empty scalars are allowed even if a tag or an anchor is
                    # specified.
                    event = ScalarEvent(anchor, tag, (implicit, False), '',
                            start_mark, end_mark)
                    self.state = self.states.pop()
                else:
                    if block:
                        node = 'block'
                    else:
                        node = 'flow'
                    token = self.peek_token()
                    raise ParserError("while parsing a %s node" % node, start_mark,
                            "expected the node content, but found %r" % token.id,
                            token.start_mark)
        return event

    # block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)* BLOCK-END

    def parse_block_sequence_first_entry(self):
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_block_sequence_entry()

    def parse_block_sequence_entry(self):
        if self.check_token(BlockEntryToken):
            token = self.get_token()
            if not self.check_token(BlockEntryToken, BlockEndToken):
                self.states.append(self.parse_block_sequence_entry)
                return self.parse_block_node()
            else:
                # '-' followed immediately by another entry or the end:
                # the item is an empty scalar.
                self.state = self.parse_block_sequence_entry
                return self.process_empty_scalar(token.end_mark)
        if not self.check_token(BlockEndToken):
            token = self.peek_token()
            raise ParserError("while parsing a block collection", self.marks[-1],
                    "expected <block end>, but found %r" % token.id, token.start_mark)
        token = self.get_token()
        event = SequenceEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    # indentless_sequence ::= (BLOCK-ENTRY block_node?)+

    def parse_indentless_sequence_entry(self):
        if self.check_token(BlockEntryToken):
            token = self.get_token()
            if not self.check_token(BlockEntryToken,
                    KeyToken, ValueToken, BlockEndToken):
                self.states.append(self.parse_indentless_sequence_entry)
                return self.parse_block_node()
            else:
                self.state = self.parse_indentless_sequence_entry
                return self.process_empty_scalar(token.end_mark)
        # No closing token exists for an indentless sequence; end it at the
        # next token's position without consuming that token.
        token = self.peek_token()
        event = SequenceEndEvent(token.start_mark, token.start_mark)
        self.state = self.states.pop()
        return event

    # block_mapping     ::= BLOCK-MAPPING_START
    #                       ((KEY block_node_or_indentless_sequence?)?
    #                       (VALUE block_node_or_indentless_sequence?)?)*
    #                       BLOCK-END

    def parse_block_mapping_first_key(self):
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_block_mapping_key()

    def parse_block_mapping_key(self):
        if self.check_token(KeyToken):
            token = self.get_token()
            if not self.check_token(KeyToken, ValueToken, BlockEndToken):
                self.states.append(self.parse_block_mapping_value)
                return self.parse_block_node_or_indentless_sequence()
            else:
                self.state = self.parse_block_mapping_value
                return self.process_empty_scalar(token.end_mark)
        if not self.check_token(BlockEndToken):
            token = self.peek_token()
            raise ParserError("while parsing a block mapping", self.marks[-1],
                    "expected <block end>, but found %r" % token.id, token.start_mark)
        token = self.get_token()
        event = MappingEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    def parse_block_mapping_value(self):
        if self.check_token(ValueToken):
            token = self.get_token()
            if not self.check_token(KeyToken, ValueToken, BlockEndToken):
                self.states.append(self.parse_block_mapping_key)
                return self.parse_block_node_or_indentless_sequence()
            else:
                self.state = self.parse_block_mapping_key
                return self.process_empty_scalar(token.end_mark)
        else:
            # Missing value: substitute an empty scalar.
            self.state = self.parse_block_mapping_key
            token = self.peek_token()
            return self.process_empty_scalar(token.start_mark)

    # flow_sequence     ::= FLOW-SEQUENCE-START
    #                       (flow_sequence_entry FLOW-ENTRY)*
    #                       flow_sequence_entry?
    #                       FLOW-SEQUENCE-END
    # flow_sequence_entry   ::= flow_node | KEY flow_node? (VALUE flow_node?)?
    #
    # Note that while production rules for both flow_sequence_entry and
    # flow_mapping_entry are equal, their interpretations are different.
    # For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
    # generates an inline mapping (set syntax).

    def parse_flow_sequence_first_entry(self):
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_flow_sequence_entry(first=True)

    def parse_flow_sequence_entry(self, first=False):
        if not self.check_token(FlowSequenceEndToken):
            if not first:
                if self.check_token(FlowEntryToken):
                    self.get_token()
                else:
                    token = self.peek_token()
                    raise ParserError("while parsing a flow sequence", self.marks[-1],
                            "expected ',' or ']', but got %r" % token.id, token.start_mark)
            if self.check_token(KeyToken):
                # A '?' inside a flow sequence starts a single-pair inline
                # mapping.
                token = self.peek_token()
                event = MappingStartEvent(None, None, True,
                        token.start_mark, token.end_mark,
                        flow_style=True)
                self.state = self.parse_flow_sequence_entry_mapping_key
                return event
            elif not self.check_token(FlowSequenceEndToken):
                self.states.append(self.parse_flow_sequence_entry)
                return self.parse_flow_node()
        token = self.get_token()
        event = SequenceEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    def parse_flow_sequence_entry_mapping_key(self):
        token = self.get_token()
        if not self.check_token(ValueToken,
                FlowEntryToken, FlowSequenceEndToken):
            self.states.append(self.parse_flow_sequence_entry_mapping_value)
            return self.parse_flow_node()
        else:
            self.state = self.parse_flow_sequence_entry_mapping_value
            return self.process_empty_scalar(token.end_mark)

    def parse_flow_sequence_entry_mapping_value(self):
        if self.check_token(ValueToken):
            token = self.get_token()
            if not self.check_token(FlowEntryToken, FlowSequenceEndToken):
                self.states.append(self.parse_flow_sequence_entry_mapping_end)
                return self.parse_flow_node()
            else:
                self.state = self.parse_flow_sequence_entry_mapping_end
                return self.process_empty_scalar(token.end_mark)
        else:
            self.state = self.parse_flow_sequence_entry_mapping_end
            token = self.peek_token()
            return self.process_empty_scalar(token.start_mark)

    def parse_flow_sequence_entry_mapping_end(self):
        self.state = self.parse_flow_sequence_entry
        token = self.peek_token()
        return MappingEndEvent(token.start_mark, token.start_mark)

    # flow_mapping  ::= FLOW-MAPPING-START
    #                   (flow_mapping_entry FLOW-ENTRY)*
    #                   flow_mapping_entry?
    #                   FLOW-MAPPING-END
    # flow_mapping_entry    ::= flow_node | KEY flow_node? (VALUE flow_node?)?

    def parse_flow_mapping_first_key(self):
        token = self.get_token()
        self.marks.append(token.start_mark)
        return self.parse_flow_mapping_key(first=True)

    def parse_flow_mapping_key(self, first=False):
        if not self.check_token(FlowMappingEndToken):
            if not first:
                if self.check_token(FlowEntryToken):
                    self.get_token()
                else:
                    token = self.peek_token()
                    raise ParserError("while parsing a flow mapping", self.marks[-1],
                            "expected ',' or '}', but got %r" % token.id, token.start_mark)
            if self.check_token(KeyToken):
                token = self.get_token()
                if not self.check_token(ValueToken,
                        FlowEntryToken, FlowMappingEndToken):
                    self.states.append(self.parse_flow_mapping_value)
                    return self.parse_flow_node()
                else:
                    self.state = self.parse_flow_mapping_value
                    return self.process_empty_scalar(token.end_mark)
            elif not self.check_token(FlowMappingEndToken):
                # Bare node used as a key: its value is an empty scalar.
                self.states.append(self.parse_flow_mapping_empty_value)
                return self.parse_flow_node()
        token = self.get_token()
        event = MappingEndEvent(token.start_mark, token.end_mark)
        self.state = self.states.pop()
        self.marks.pop()
        return event

    def parse_flow_mapping_value(self):
        if self.check_token(ValueToken):
            token = self.get_token()
            if not self.check_token(FlowEntryToken, FlowMappingEndToken):
                self.states.append(self.parse_flow_mapping_key)
                return self.parse_flow_node()
            else:
                self.state = self.parse_flow_mapping_key
                return self.process_empty_scalar(token.end_mark)
        else:
            self.state = self.parse_flow_mapping_key
            token = self.peek_token()
            return self.process_empty_scalar(token.start_mark)

    def parse_flow_mapping_empty_value(self):
        self.state = self.parse_flow_mapping_key
        return self.process_empty_scalar(self.peek_token().start_mark)

    def process_empty_scalar(self, mark):
        """Build the zero-width plain scalar used for omitted keys/values."""
        return ScalarEvent(None, None, (True, False), '', mark, mark)

View File

@@ -0,0 +1,185 @@
# This module contains abstractions for the input stream. You don't have to
# look further; there is no pretty code here.
#
# We define two classes here.
#
# Mark(source, line, column)
# It's just a record and its only use is producing nice error messages.
# Parser does not use it for any other purposes.
#
# Reader(source, data)
# Reader determines the encoding of `data` and converts it to unicode.
# Reader provides the following methods and attributes:
# reader.peek(length=1) - return the next `length` characters
# reader.forward(length=1) - advance the current position by `length` characters.
# reader.index - the number of the current character.
# reader.line, reader.column - the line and the column of the current character.
__all__ = ['Reader', 'ReaderError']
from .error import YAMLError, Mark
import codecs, re
class ReaderError(YAMLError):
    """Raised when the input stream contains undecodable bytes or characters
    outside the YAML printable range."""
    def __init__(self, name, position, character, encoding, reason):
        self.name = name
        self.character = character
        self.position = position
        self.encoding = encoding
        self.reason = reason
    def __str__(self):
        # A bytes character means a decode failure; a str (stored as its
        # ordinal in `character`) means a disallowed unicode character.
        if isinstance(self.character, bytes):
            return "'%s' codec can't decode byte #x%02x: %s\n" \
                    "  in \"%s\", position %d" \
                    % (self.encoding, ord(self.character), self.reason,
                            self.name, self.position)
        else:
            return "unacceptable character #x%04x: %s\n" \
                    "  in \"%s\", position %d" \
                    % (self.character, self.reason,
                            self.name, self.position)
class Reader(object):
    # Reader:
    # - determines the data encoding and converts it to a unicode string,
    # - checks if characters are in allowed range,
    # - adds '\0' to the end.

    # Reader accepts
    #  - a `bytes` object,
    #  - a `str` object,
    #  - a file-like object with its `read` method returning `str`,
    #  - a file-like object with its `read` method returning `unicode`.

    # Yeah, it's ugly and slow.

    def __init__(self, stream):
        """Wrap ``stream`` (str, bytes, or a file-like object) and prepare
        decoded-character access for the scanner."""
        self.name = None
        self.stream = None
        self.stream_pointer = 0     # bytes consumed from the raw stream
        self.eof = True
        self.buffer = ''            # decoded characters not yet consumed
        self.pointer = 0            # index of the current char in `buffer`
        self.raw_buffer = None      # undecoded bytes read from the stream
        self.raw_decode = None
        self.encoding = None
        self.index = 0              # absolute character index
        self.line = 0
        self.column = 0
        if isinstance(stream, str):
            self.name = "<unicode string>"
            self.check_printable(stream)
            self.buffer = stream+'\0'
        elif isinstance(stream, bytes):
            self.name = "<byte string>"
            self.raw_buffer = stream
            self.determine_encoding()
        else:
            self.stream = stream
            self.name = getattr(stream, 'name', "<file>")
            self.eof = False
            self.raw_buffer = None
            self.determine_encoding()

    def peek(self, index=0):
        """Return the character ``index`` positions ahead without consuming
        it; '\\0' past the end of input."""
        try:
            return self.buffer[self.pointer+index]
        except IndexError:
            self.update(index+1)
            return self.buffer[self.pointer+index]

    def prefix(self, length=1):
        """Return the next ``length`` characters without consuming them."""
        if self.pointer+length >= len(self.buffer):
            self.update(length)
        return self.buffer[self.pointer:self.pointer+length]

    def forward(self, length=1):
        """Advance the current position by ``length`` characters, keeping
        index/line/column counters up to date."""
        if self.pointer+length+1 >= len(self.buffer):
            self.update(length+1)
        while length:
            ch = self.buffer[self.pointer]
            self.pointer += 1
            self.index += 1
            # Any line break bumps the line counter; a lone '\r' counts
            # only when not followed by '\n' (CRLF is one break).
            if ch in '\n\x85\u2028\u2029'  \
                    or (ch == '\r' and self.buffer[self.pointer] != '\n'):
                self.line += 1
                self.column = 0
            elif ch != '\uFEFF':
                self.column += 1
            length -= 1

    def get_mark(self):
        """Return a Mark for the current position (with a buffer snippet
        only when reading from an in-memory string)."""
        if self.stream is None:
            return Mark(self.name, self.index, self.line, self.column,
                    self.buffer, self.pointer)
        else:
            return Mark(self.name, self.index, self.line, self.column,
                    None, None)

    def determine_encoding(self):
        """Pick the stream decoder from the BOM (UTF-16 LE/BE) or default
        to UTF-8, then decode the first chunk."""
        while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
            self.update_raw()
        if isinstance(self.raw_buffer, bytes):
            if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
                self.raw_decode = codecs.utf_16_le_decode
                self.encoding = 'utf-16-le'
            elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
                self.raw_decode = codecs.utf_16_be_decode
                self.encoding = 'utf-16-be'
            else:
                self.raw_decode = codecs.utf_8_decode
                self.encoding = 'utf-8'
        self.update(1)

    NON_PRINTABLE = re.compile('[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]')
    def check_printable(self, data):
        """Raise ReaderError if ``data`` contains characters outside the
        YAML printable set."""
        match = self.NON_PRINTABLE.search(data)
        if match:
            character = match.group()
            position = self.index+(len(self.buffer)-self.pointer)+match.start()
            raise ReaderError(self.name, position, ord(character),
                    'unicode', "special characters are not allowed")

    def update(self, length):
        """Decode raw bytes until at least ``length`` characters are
        available past the current pointer (or EOF is reached)."""
        if self.raw_buffer is None:
            return
        # Drop already-consumed characters before growing the buffer.
        self.buffer = self.buffer[self.pointer:]
        self.pointer = 0
        while len(self.buffer) < length:
            if not self.eof:
                self.update_raw()
            if self.raw_decode is not None:
                try:
                    # Incremental decode: at EOF require the tail to be a
                    # complete sequence (final=True).
                    data, converted = self.raw_decode(self.raw_buffer,
                            'strict', self.eof)
                except UnicodeDecodeError as exc:
                    character = self.raw_buffer[exc.start]
                    if self.stream is not None:
                        position = self.stream_pointer-len(self.raw_buffer)+exc.start
                    else:
                        position = exc.start
                    raise ReaderError(self.name, position, character,
                            exc.encoding, exc.reason)
            else:
                data = self.raw_buffer
                converted = len(data)
            self.check_printable(data)
            self.buffer += data
            self.raw_buffer = self.raw_buffer[converted:]
            if self.eof:
                # Terminate the decoded stream with the sentinel '\0'.
                self.buffer += '\0'
                self.raw_buffer = None
                break

    def update_raw(self, size=4096):
        """Read another chunk from the underlying stream into raw_buffer;
        an empty read marks EOF."""
        data = self.stream.read(size)
        if self.raw_buffer is None:
            self.raw_buffer = data
        else:
            self.raw_buffer += data
        self.stream_pointer += len(data)
        if not data:
            self.eof = True

View File

@@ -0,0 +1,389 @@
__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
'RepresenterError']
from .error import *
from .nodes import *
import datetime, copyreg, types, base64, collections
class RepresenterError(YAMLError):
    """Raised when a Python object cannot be turned into a YAML node."""
    pass
class BaseRepresenter:
    """Turn native Python objects into representation-graph nodes.

    Dispatch goes through two class-level registries: ``yaml_representers``
    (exact-type match) and ``yaml_multi_representers`` (matched along the
    MRO). ``represented_objects`` maps id(obj) -> node so that repeated
    objects become aliases.
    """

    # Shared registries; add_representer/add_multi_representer copy them
    # into the subclass on first modification (copy-on-write).
    yaml_representers = {}
    yaml_multi_representers = {}

    def __init__(self, default_style=None, default_flow_style=False, sort_keys=True):
        self.default_style = default_style
        self.sort_keys = sort_keys
        self.default_flow_style = default_flow_style
        self.represented_objects = {}
        self.object_keeper = []     # keeps ids stable while representing
        self.alias_key = None
    def represent(self, data):
        """Represent ``data`` and hand the node to the serializer, then
        reset per-document state."""
        node = self.represent_data(data)
        self.serialize(node)
        self.represented_objects = {}
        self.object_keeper = []
        self.alias_key = None
    def represent_data(self, data):
        """Return a node for ``data``, reusing an earlier node (alias) for
        repeated aliasable objects."""
        if self.ignore_aliases(data):
            self.alias_key = None
        else:
            self.alias_key = id(data)
        if self.alias_key is not None:
            if self.alias_key in self.represented_objects:
                node = self.represented_objects[self.alias_key]
                #if node is None:
                #    raise RepresenterError("recursive objects are not allowed: %r" % data)
                return node
            #self.represented_objects[alias_key] = None
            # Keep a reference so id(data) cannot be recycled mid-document.
            self.object_keeper.append(data)
        data_types = type(data).__mro__
        if data_types[0] in self.yaml_representers:
            node = self.yaml_representers[data_types[0]](self, data)
        else:
            # Fall back to multi-representers along the MRO, then to the
            # None-keyed catch-all entries.
            for data_type in data_types:
                if data_type in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[data_type](self, data)
                    break
            else:
                if None in self.yaml_multi_representers:
                    node = self.yaml_multi_representers[None](self, data)
                elif None in self.yaml_representers:
                    node = self.yaml_representers[None](self, data)
                else:
                    node = ScalarNode(None, str(data))
        #if alias_key is not None:
        #    self.represented_objects[alias_key] = node
        return node
    @classmethod
    def add_representer(cls, data_type, representer):
        """Register ``representer`` for exact type ``data_type`` on this
        class without mutating the parent's registry."""
        if not 'yaml_representers' in cls.__dict__:
            cls.yaml_representers = cls.yaml_representers.copy()
        cls.yaml_representers[data_type] = representer
    @classmethod
    def add_multi_representer(cls, data_type, representer):
        """Register ``representer`` for ``data_type`` and its subclasses."""
        if not 'yaml_multi_representers' in cls.__dict__:
            cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
        cls.yaml_multi_representers[data_type] = representer
    def represent_scalar(self, tag, value, style=None):
        """Build a ScalarNode and record it for potential aliasing."""
        if style is None:
            style = self.default_style
        node = ScalarNode(tag, value, style=style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        return node
    def represent_sequence(self, tag, sequence, flow_style=None):
        """Build a SequenceNode from an iterable of items.

        The node is registered for aliasing *before* children are
        represented, so self-referencing sequences terminate.
        """
        value = []
        node = SequenceNode(tag, value, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        # Prefer flow style only if every item is a plain scalar.
        best_style = True
        for item in sequence:
            node_item = self.represent_data(item)
            if not (isinstance(node_item, ScalarNode) and not node_item.style):
                best_style = False
            value.append(node_item)
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node
    def represent_mapping(self, tag, mapping, flow_style=None):
        """Build a MappingNode from a mapping or an iterable of pairs."""
        value = []
        node = MappingNode(tag, value, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node
        best_style = True
        if hasattr(mapping, 'items'):
            mapping = list(mapping.items())
            if self.sort_keys:
                try:
                    mapping = sorted(mapping)
                except TypeError:
                    # Unorderable keys: keep the original order.
                    pass
        for item_key, item_value in mapping:
            node_key = self.represent_data(item_key)
            node_value = self.represent_data(item_value)
            if not (isinstance(node_key, ScalarNode) and not node_key.style):
                best_style = False
            if not (isinstance(node_value, ScalarNode) and not node_value.style):
                best_style = False
            value.append((node_key, node_value))
        if flow_style is None:
            if self.default_flow_style is not None:
                node.flow_style = self.default_flow_style
            else:
                node.flow_style = best_style
        return node
    def ignore_aliases(self, data):
        """Return True for objects that should never become anchors/aliases;
        overridden by subclasses."""
        return False
class SafeRepresenter(BaseRepresenter):
    """Representer limited to standard Python types (the 'safe' subset)."""
    def ignore_aliases(self, data):
        # Immutable atomic values need no anchors; returns None (falsy)
        # for everything else.
        if data is None:
            return True
        if isinstance(data, tuple) and data == ():
            return True
        if isinstance(data, (str, bytes, bool, int, float)):
            return True
    def represent_none(self, data):
        return self.represent_scalar('tag:yaml.org,2002:null',
                'null')
    def represent_str(self, data):
        return self.represent_scalar('tag:yaml.org,2002:str', data)
    def represent_binary(self, data):
        # base64.encodestring was removed in Python 3.9; prefer encodebytes.
        if hasattr(base64, 'encodebytes'):
            data = base64.encodebytes(data).decode('ascii')
        else:
            data = base64.encodestring(data).decode('ascii')
        return self.represent_scalar('tag:yaml.org,2002:binary', data, style='|')
    def represent_bool(self, data):
        if data:
            value = 'true'
        else:
            value = 'false'
        return self.represent_scalar('tag:yaml.org,2002:bool', value)
    def represent_int(self, data):
        return self.represent_scalar('tag:yaml.org,2002:int', str(data))
    # Compute an infinity sentinel by squaring until repr stabilizes.
    inf_value = 1e300
    while repr(inf_value) != repr(inf_value*inf_value):
        inf_value *= inf_value
    def represent_float(self, data):
        # data != data detects NaN; the second clause presumably guards
        # platforms where NaN compares equal to everything -- TODO confirm.
        if data != data or (data == 0.0 and data == 1.0):
            value = '.nan'
        elif data == self.inf_value:
            value = '.inf'
        elif data == -self.inf_value:
            value = '-.inf'
        else:
            value = repr(data).lower()
            # Note that in some cases `repr(data)` represents a float number
            # without the decimal parts.  For instance:
            #   >>> repr(1e17)
            #   '1e17'
            # Unfortunately, this is not a valid float representation according
            # to the definition of the `!!float` tag.  We fix this by adding
            # '.0' before the 'e' symbol.
            if '.' not in value and 'e' in value:
                value = value.replace('e', '.0e', 1)
        return self.represent_scalar('tag:yaml.org,2002:float', value)
    def represent_list(self, data):
        #pairs = (len(data) > 0 and isinstance(data, list))
        #if pairs:
        #    for item in data:
        #        if not isinstance(item, tuple) or len(item) != 2:
        #            pairs = False
        #            break
        #if not pairs:
            return self.represent_sequence('tag:yaml.org,2002:seq', data)
        #value = []
        #for item_key, item_value in data:
        #    value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
        #        [(item_key, item_value)]))
        #return SequenceNode(u'tag:yaml.org,2002:pairs', value)
    def represent_dict(self, data):
        return self.represent_mapping('tag:yaml.org,2002:map', data)
    def represent_set(self, data):
        # A YAML set is a mapping with all-null values.
        value = {}
        for key in data:
            value[key] = None
        return self.represent_mapping('tag:yaml.org,2002:set', value)
    def represent_date(self, data):
        value = data.isoformat()
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
    def represent_datetime(self, data):
        value = data.isoformat(' ')
        return self.represent_scalar('tag:yaml.org,2002:timestamp', value)
    def represent_yaml_object(self, tag, data, cls, flow_style=None):
        # Prefer __getstate__ when the object customizes pickling state.
        if hasattr(data, '__getstate__'):
            state = data.__getstate__()
        else:
            state = data.__dict__.copy()
        return self.represent_mapping(tag, state, flow_style=flow_style)
    def represent_undefined(self, data):
        raise RepresenterError("cannot represent an object", data)
# Register the standard representers on SafeRepresenter. The final entry,
# keyed on None, installs represent_undefined — presumably the fallback
# consulted when no exact type matches (confirm in BaseRepresenter).
SafeRepresenter.add_representer(type(None), SafeRepresenter.represent_none)
SafeRepresenter.add_representer(str, SafeRepresenter.represent_str)
SafeRepresenter.add_representer(bytes, SafeRepresenter.represent_binary)
SafeRepresenter.add_representer(bool, SafeRepresenter.represent_bool)
SafeRepresenter.add_representer(int, SafeRepresenter.represent_int)
SafeRepresenter.add_representer(float, SafeRepresenter.represent_float)
SafeRepresenter.add_representer(list, SafeRepresenter.represent_list)
SafeRepresenter.add_representer(tuple, SafeRepresenter.represent_list)
SafeRepresenter.add_representer(dict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(set, SafeRepresenter.represent_set)
SafeRepresenter.add_representer(datetime.date, SafeRepresenter.represent_date)
SafeRepresenter.add_representer(datetime.datetime, SafeRepresenter.represent_datetime)
SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined)
class Representer(SafeRepresenter):
    """
    Full (unsafe) representer: extends the safe set with Python-specific
    tags (complex, tuple, names, modules) and arbitrary objects via the
    ``__reduce__`` pickle protocol.
    """
    def represent_complex(self, data):
        # Emit the shortest exact form: pure-real, pure-imaginary, or a+bj
        # (the sign of the imaginary part picks '+' vs the embedded '-').
        if data.imag == 0.0:
            data = '%r' % data.real
        elif data.real == 0.0:
            data = '%rj' % data.imag
        elif data.imag > 0:
            data = '%r+%rj' % (data.real, data.imag)
        else:
            data = '%r%rj' % (data.real, data.imag)
        return self.represent_scalar('tag:yaml.org,2002:python/complex', data)
    def represent_tuple(self, data):
        return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
    def represent_name(self, data):
        # Classes and functions are stored by dotted import path only; the
        # scalar value itself is empty.
        name = '%s.%s' % (data.__module__, data.__name__)
        return self.represent_scalar('tag:yaml.org,2002:python/name:'+name, '')
    def represent_module(self, data):
        return self.represent_scalar(
            'tag:yaml.org,2002:python/module:'+data.__name__, '')
    def represent_object(self, data):
        # We use __reduce__ API to save the data. data.__reduce__ returns
        # a tuple of length 2-5:
        #   (function, args, state, listitems, dictitems)
        # For reconstructing, we calls function(*args), then set its state,
        # listitems, and dictitems if they are not None.
        # A special case is when function.__name__ == '__newobj__'. In this
        # case we create the object with args[0].__new__(*args).
        # Another special case is when __reduce__ returns a string - we don't
        # support it.
        # We produce a !!python/object, !!python/object/new or
        # !!python/object/apply node.
        cls = type(data)
        if cls in copyreg.dispatch_table:
            reduce = copyreg.dispatch_table[cls](data)
        elif hasattr(data, '__reduce_ex__'):
            reduce = data.__reduce_ex__(2)
        elif hasattr(data, '__reduce__'):
            reduce = data.__reduce__()
        else:
            raise RepresenterError("cannot represent an object", data)
        # Pad the reduce tuple to exactly five elements.
        reduce = (list(reduce)+[None]*5)[:5]
        function, args, state, listitems, dictitems = reduce
        args = list(args)
        if state is None:
            state = {}
        if listitems is not None:
            listitems = list(listitems)
        if dictitems is not None:
            dictitems = dict(dictitems)
        if function.__name__ == '__newobj__':
            function = args[0]
            args = args[1:]
            tag = 'tag:yaml.org,2002:python/object/new:'
            newobj = True
        else:
            tag = 'tag:yaml.org,2002:python/object/apply:'
            newobj = False
        function_name = '%s.%s' % (function.__module__, function.__name__)
        # Compact forms: a plain !!python/object mapping when only dict state
        # exists, or a bare sequence of args when there is no state at all.
        if not args and not listitems and not dictitems \
                and isinstance(state, dict) and newobj:
            return self.represent_mapping(
                    'tag:yaml.org,2002:python/object:'+function_name, state)
        if not listitems and not dictitems  \
                and isinstance(state, dict) and not state:
            return self.represent_sequence(tag+function_name, args)
        # General form: explicit args/state/listitems/dictitems mapping.
        value = {}
        if args:
            value['args'] = args
        if state or not isinstance(state, dict):
            value['state'] = state
        if listitems:
            value['listitems'] = listitems
        if dictitems:
            value['dictitems'] = dictitems
        return self.represent_mapping(tag+function_name, value)
    def represent_ordered_dict(self, data):
        # Provide uniform representation across different Python versions.
        data_type = type(data)
        tag = 'tag:yaml.org,2002:python/object/apply:%s.%s' \
                % (data_type.__module__, data_type.__name__)
        items = [[key, value] for key, value in data.items()]
        return self.represent_sequence(tag, [items])
# Extend the safe table with Python-specific representers: complex numbers,
# native tuples, classes/functions/modules by name, OrderedDict, and — via
# the multi-representer on `object` — anything reachable through __reduce__.
Representer.add_representer(complex, Representer.represent_complex)
Representer.add_representer(tuple, Representer.represent_tuple)
Representer.add_representer(type, Representer.represent_name)
Representer.add_representer(collections.OrderedDict, Representer.represent_ordered_dict)
Representer.add_representer(types.FunctionType, Representer.represent_name)
Representer.add_representer(types.BuiltinFunctionType, Representer.represent_name)
Representer.add_representer(types.ModuleType, Representer.represent_module)
Representer.add_multi_representer(object, Representer.represent_object)

View File

@@ -0,0 +1,227 @@
__all__ = ['BaseResolver', 'Resolver']
from .error import *
from .nodes import *
import re
class ResolverError(YAMLError):
    """Raised when an invalid path/node/index checker is registered."""
    pass
class BaseResolver:
    """
    Decides which tag a node gets. Implicit resolvers match plain scalar
    values against regular expressions; path resolvers (experimental) match
    a node's position in the document tree.
    """

    DEFAULT_SCALAR_TAG = 'tag:yaml.org,2002:str'
    DEFAULT_SEQUENCE_TAG = 'tag:yaml.org,2002:seq'
    DEFAULT_MAPPING_TAG = 'tag:yaml.org,2002:map'

    # Class-level registries; copied on first write per subclass (see the
    # cls.__dict__ checks in the add_* classmethods below).
    yaml_implicit_resolvers = {}
    yaml_path_resolvers = {}

    def __init__(self):
        # Stacks tracking which path resolvers are still viable while
        # descending/ascending the node tree.
        self.resolver_exact_paths = []
        self.resolver_prefix_paths = []

    @classmethod
    def add_implicit_resolver(cls, tag, regexp, first):
        # Copy-on-write: give this class its own registry so registering on
        # a subclass does not mutate the parent's table.
        if not 'yaml_implicit_resolvers' in cls.__dict__:
            implicit_resolvers = {}
            for key in cls.yaml_implicit_resolvers:
                implicit_resolvers[key] = cls.yaml_implicit_resolvers[key][:]
            cls.yaml_implicit_resolvers = implicit_resolvers
        # `first` lists the characters a matching scalar may start with;
        # None means "any first character" (stored under the None key).
        if first is None:
            first = [None]
        for ch in first:
            cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))

    @classmethod
    def add_path_resolver(cls, tag, path, kind=None):
        # Note: `add_path_resolver` is experimental. The API could be changed.
        # `new_path` is a pattern that is matched against the path from the
        # root to the node that is being considered. `node_path` elements are
        # tuples `(node_check, index_check)`. `node_check` is a node class:
        # `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None`
        # matches any kind of a node. `index_check` could be `None`, a boolean
        # value, a string value, or a number. `None` and `False` match against
        # any _value_ of sequence and mapping nodes. `True` matches against
        # any _key_ of a mapping node. A string `index_check` matches against
        # a mapping value that corresponds to a scalar key which content is
        # equal to the `index_check` value. An integer `index_check` matches
        # against a sequence value with the index equal to `index_check`.
        if not 'yaml_path_resolvers' in cls.__dict__:
            # Copy-on-write, as in add_implicit_resolver above.
            cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
        new_path = []
        for element in path:
            if isinstance(element, (list, tuple)):
                if len(element) == 2:
                    node_check, index_check = element
                elif len(element) == 1:
                    node_check = element[0]
                    index_check = True
                else:
                    raise ResolverError("Invalid path element: %s" % element)
            else:
                node_check = None
                index_check = element
            # Convenience: allow the builtin types str/list/dict as aliases
            # for the corresponding node classes.
            if node_check is str:
                node_check = ScalarNode
            elif node_check is list:
                node_check = SequenceNode
            elif node_check is dict:
                node_check = MappingNode
            elif node_check not in [ScalarNode, SequenceNode, MappingNode]  \
                    and not isinstance(node_check, str)  \
                    and node_check is not None:
                raise ResolverError("Invalid node checker: %s" % node_check)
            if not isinstance(index_check, (str, int))   \
                    and index_check is not None:
                raise ResolverError("Invalid index checker: %s" % index_check)
            new_path.append((node_check, index_check))
        if kind is str:
            kind = ScalarNode
        elif kind is list:
            kind = SequenceNode
        elif kind is dict:
            kind = MappingNode
        elif kind not in [ScalarNode, SequenceNode, MappingNode]    \
                and kind is not None:
            raise ResolverError("Invalid node kind: %s" % kind)
        cls.yaml_path_resolvers[tuple(new_path), kind] = tag

    def descend_resolver(self, current_node, current_index):
        # Push the set of path resolvers still matching at this depth:
        # exact matches (path fully consumed) and prefixes (still open).
        if not self.yaml_path_resolvers:
            return
        exact_paths = {}
        prefix_paths = []
        if current_node:
            depth = len(self.resolver_prefix_paths)
            for path, kind in self.resolver_prefix_paths[-1]:
                if self.check_resolver_prefix(depth, path, kind,
                        current_node, current_index):
                    if len(path) > depth:
                        prefix_paths.append((path, kind))
                    else:
                        exact_paths[kind] = self.yaml_path_resolvers[path, kind]
        else:
            # Root of the document: seed from the full registry.
            for path, kind in self.yaml_path_resolvers:
                if not path:
                    exact_paths[kind] = self.yaml_path_resolvers[path, kind]
                else:
                    prefix_paths.append((path, kind))
        self.resolver_exact_paths.append(exact_paths)
        self.resolver_prefix_paths.append(prefix_paths)

    def ascend_resolver(self):
        if not self.yaml_path_resolvers:
            return
        self.resolver_exact_paths.pop()
        self.resolver_prefix_paths.pop()

    def check_resolver_prefix(self, depth, path, kind,
            current_node, current_index):
        # Does path element `depth-1` match the current position?
        # Returns True on match, bare None (falsy) otherwise.
        node_check, index_check = path[depth-1]
        if isinstance(node_check, str):
            if current_node.tag != node_check:
                return
        elif node_check is not None:
            if not isinstance(current_node, node_check):
                return
        # index True matches mapping keys only (current_index is None there);
        # False/None match values only.
        if index_check is True and current_index is not None:
            return
        if (index_check is False or index_check is None)    \
                and current_index is None:
            return
        if isinstance(index_check, str):
            if not (isinstance(current_index, ScalarNode)
                    and index_check == current_index.value):
                return
        elif isinstance(index_check, int) and not isinstance(index_check, bool):
            if index_check != current_index:
                return
        return True

    def resolve(self, kind, value, implicit):
        # Plain scalars (implicit[0]) get regexp-based resolution, keyed by
        # the scalar's first character, plus the None-keyed wildcards.
        if kind is ScalarNode and implicit[0]:
            if value == '':
                resolvers = self.yaml_implicit_resolvers.get('', [])
            else:
                resolvers = self.yaml_implicit_resolvers.get(value[0], [])
            wildcard_resolvers = self.yaml_implicit_resolvers.get(None, [])
            for tag, regexp in resolvers + wildcard_resolvers:
                if regexp.match(value):
                    return tag
            # NOTE: `implicit` is not read again below this reassignment.
            implicit = implicit[1]
        if self.yaml_path_resolvers:
            exact_paths = self.resolver_exact_paths[-1]
            if kind in exact_paths:
                return exact_paths[kind]
            if None in exact_paths:
                return exact_paths[None]
        # Fall back to the default tag for the node kind.
        if kind is ScalarNode:
            return self.DEFAULT_SCALAR_TAG
        elif kind is SequenceNode:
            return self.DEFAULT_SEQUENCE_TAG
        elif kind is MappingNode:
            return self.DEFAULT_MAPPING_TAG
class Resolver(BaseResolver):
    """BaseResolver preloaded (via the module-level ``add_implicit_resolver``
    calls that follow) with the standard YAML 1.1 implicit tags: bool, float,
    int, merge, null, timestamp, and value."""
    pass
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:bool',
re.compile(r'''^(?:yes|Yes|YES|no|No|NO
|true|True|TRUE|false|False|FALSE
|on|On|ON|off|Off|OFF)$''', re.X),
list('yYnNtTfFoO'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:float',
re.compile(r'''^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+][0-9]+)?
|\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
|[-+]?\.(?:inf|Inf|INF)
|\.(?:nan|NaN|NAN))$''', re.X),
list('-+0123456789.'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:int',
re.compile(r'''^(?:[-+]?0b[0-1_]+
|[-+]?0[0-7_]+
|[-+]?(?:0|[1-9][0-9_]*)
|[-+]?0x[0-9a-fA-F_]+
|[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X),
list('-+0123456789'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:merge',
re.compile(r'^(?:<<)$'),
['<'])
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:null',
re.compile(r'''^(?: ~
|null|Null|NULL
| )$''', re.X),
['~', 'n', 'N', ''])
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:timestamp',
re.compile(r'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
|[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
(?:[Tt]|[ \t]+)[0-9][0-9]?
:[0-9][0-9] :[0-9][0-9] (?:\.[0-9]*)?
(?:[ \t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
list('0123456789'))
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:value',
re.compile(r'^(?:=)$'),
['='])
# The following resolver is only for documentation purposes. It cannot work
# because plain scalars cannot start with '!', '&', or '*'.
Resolver.add_implicit_resolver(
'tag:yaml.org,2002:yaml',
re.compile(r'^(?:!|&|\*)$'),
list('!&*'))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,111 @@
__all__ = ['Serializer', 'SerializerError']
from .error import YAMLError
from .events import *
from .nodes import *
class SerializerError(YAMLError):
    """Raised on serializer misuse (wrong open/serialize/close ordering)."""
    pass
class Serializer:
    """
    Turns a node graph into an event stream: assigns anchors to nodes that
    are referenced more than once, then emits document/collection/scalar
    events in order. Must be ``open()``-ed before ``serialize()`` and
    ``close()``-d afterwards.
    """

    # Template for generated anchor names: id001, id002, ...
    ANCHOR_TEMPLATE = 'id%03d'

    def __init__(self, encoding=None,
            explicit_start=None, explicit_end=None, version=None, tags=None):
        self.use_encoding = encoding
        self.use_explicit_start = explicit_start
        self.use_explicit_end = explicit_end
        self.use_version = version
        self.use_tags = tags
        # serialized_nodes: nodes already emitted (later refs become aliases).
        self.serialized_nodes = {}
        # anchors: node -> anchor name, or None until a second ref is seen.
        self.anchors = {}
        self.last_anchor_id = 0
        # closed is None (never opened) / False (open) / True (closed).
        self.closed = None

    def open(self):
        if self.closed is None:
            self.emit(StreamStartEvent(encoding=self.use_encoding))
            self.closed = False
        elif self.closed:
            raise SerializerError("serializer is closed")
        else:
            raise SerializerError("serializer is already opened")

    def close(self):
        if self.closed is None:
            raise SerializerError("serializer is not opened")
        elif not self.closed:
            self.emit(StreamEndEvent())
            self.closed = True

    #def __del__(self):
    #    self.close()

    def serialize(self, node):
        # Emit one complete document for `node`, then reset per-document
        # anchor state so the serializer can be reused for the next document.
        if self.closed is None:
            raise SerializerError("serializer is not opened")
        elif self.closed:
            raise SerializerError("serializer is closed")
        self.emit(DocumentStartEvent(explicit=self.use_explicit_start,
            version=self.use_version, tags=self.use_tags))
        self.anchor_node(node)
        self.serialize_node(node, None, None)
        self.emit(DocumentEndEvent(explicit=self.use_explicit_end))
        self.serialized_nodes = {}
        self.anchors = {}
        self.last_anchor_id = 0

    def anchor_node(self, node):
        # First visit marks the node (anchor None); a second visit upgrades
        # it to a generated anchor so later references can emit an alias.
        if node in self.anchors:
            if self.anchors[node] is None:
                self.anchors[node] = self.generate_anchor(node)
        else:
            self.anchors[node] = None
            if isinstance(node, SequenceNode):
                for item in node.value:
                    self.anchor_node(item)
            elif isinstance(node, MappingNode):
                for key, value in node.value:
                    self.anchor_node(key)
                    self.anchor_node(value)

    def generate_anchor(self, node):
        self.last_anchor_id += 1
        return self.ANCHOR_TEMPLATE % self.last_anchor_id

    def serialize_node(self, node, parent, index):
        alias = self.anchors[node]
        if node in self.serialized_nodes:
            # Already emitted once: refer back via its alias.
            self.emit(AliasEvent(alias))
        else:
            self.serialized_nodes[node] = True
            self.descend_resolver(parent, index)
            if isinstance(node, ScalarNode):
                # implicit = (tag matches plain resolution, tag matches
                # non-plain resolution); lets the emitter omit the tag.
                detected_tag = self.resolve(ScalarNode, node.value, (True, False))
                default_tag = self.resolve(ScalarNode, node.value, (False, True))
                implicit = (node.tag == detected_tag), (node.tag == default_tag)
                self.emit(ScalarEvent(alias, node.tag, implicit, node.value,
                    style=node.style))
            elif isinstance(node, SequenceNode):
                implicit = (node.tag
                            == self.resolve(SequenceNode, node.value, True))
                self.emit(SequenceStartEvent(alias, node.tag, implicit,
                    flow_style=node.flow_style))
                index = 0
                for item in node.value:
                    self.serialize_node(item, node, index)
                    index += 1
                self.emit(SequenceEndEvent())
            elif isinstance(node, MappingNode):
                implicit = (node.tag
                            == self.resolve(MappingNode, node.value, True))
                self.emit(MappingStartEvent(alias, node.tag, implicit,
                    flow_style=node.flow_style))
                for key, value in node.value:
                    # Keys are serialized with index None; values carry their
                    # key as the index for path resolution.
                    self.serialize_node(key, node, None)
                    self.serialize_node(value, node, key)
                self.emit(MappingEndEvent())
            self.ascend_resolver()

View File

@@ -0,0 +1,104 @@
class Token(object):
    """Base class for scanner tokens; records the token's source span."""

    def __init__(self, start_mark, end_mark):
        self.start_mark = start_mark
        self.end_mark = end_mark

    def __repr__(self):
        # Show every instance attribute except the *_mark positions,
        # in sorted order, e.g. "ScalarToken(plain=True, value='x')".
        shown = sorted(
            name for name in self.__dict__ if not name.endswith('_mark')
        )
        described = ', '.join(
            '%s=%r' % (name, getattr(self, name)) for name in shown
        )
        return '%s(%s)' % (self.__class__.__name__, described)
#class BOMToken(Token):
# id = '<byte order mark>'

# Each token class carries a class-level `id` string used by the parser for
# error messages and token-kind checks.

class DirectiveToken(Token):
    # A %YAML or %TAG directive; `name` is the directive name and `value`
    # its parsed arguments.
    id = '<directive>'
    def __init__(self, name, value, start_mark, end_mark):
        self.name = name
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark

# Document boundary tokens ('---' and '...').
class DocumentStartToken(Token):
    id = '<document start>'
class DocumentEndToken(Token):
    id = '<document end>'

class StreamStartToken(Token):
    id = '<stream start>'
    def __init__(self, start_mark=None, end_mark=None,
            encoding=None):
        self.start_mark = start_mark
        self.end_mark = end_mark
        # Detected character encoding of the input stream, if any.
        self.encoding = encoding
class StreamEndToken(Token):
    id = '<stream end>'

# Block-structure tokens (emitted from indentation, not literal characters).
class BlockSequenceStartToken(Token):
    id = '<block sequence start>'
class BlockMappingStartToken(Token):
    id = '<block mapping start>'
class BlockEndToken(Token):
    id = '<block end>'

# Flow-structure and separator tokens (literal punctuation).
class FlowSequenceStartToken(Token):
    id = '['
class FlowMappingStartToken(Token):
    id = '{'
class FlowSequenceEndToken(Token):
    id = ']'
class FlowMappingEndToken(Token):
    id = '}'
class KeyToken(Token):
    id = '?'
class ValueToken(Token):
    id = ':'
class BlockEntryToken(Token):
    id = '-'
class FlowEntryToken(Token):
    id = ','

class AliasToken(Token):
    # '*name' reference to a previously anchored node.
    id = '<alias>'
    def __init__(self, value, start_mark, end_mark):
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
class AnchorToken(Token):
    # '&name' anchor definition.
    id = '<anchor>'
    def __init__(self, value, start_mark, end_mark):
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
class TagToken(Token):
    # Explicit tag; `value` is a (handle, suffix) pair.
    id = '<tag>'
    def __init__(self, value, start_mark, end_mark):
        self.value = value
        self.start_mark = start_mark
        self.end_mark = end_mark
class ScalarToken(Token):
    id = '<scalar>'
    def __init__(self, value, plain, start_mark, end_mark, style=None):
        self.value = value
        # `plain` is True for unquoted scalars; `style` records the quoting
        # or block style character otherwise.
        self.plain = plain
        self.start_mark = start_mark
        self.end_mark = end_mark
        self.style = style

View File

@@ -0,0 +1,145 @@
import re
import threading
from typing import Generator, Iterable
from .exceptions import ResponseNotAccepted
class StreamWatcher(threading.local):
    """
    Base class for objects that inspect subprocess stream data and may
    respond to it.

    Subclasses must implement the following interface; see `Responder` for a
    concrete example.

    * ``__init__`` is entirely subclass-defined, though as usual, subclasses
      *of* subclasses should remember to use `super` where appropriate.
    * `submit` receives the full contents of the watched stream so far, as a
      string, and may optionally return an iterable of strings (or yield
      them as a generator iterator), each of which is written to the
      subprocess' standard input.

    .. note::
        `StreamWatcher` subclasses exist in part to enable state tracking,
        such as detecting when a submitted password didn't work & erroring
        (or prompting a user, or etc). Such bookkeeping isn't easily
        achievable with simple callback functions.

    .. note::
        `StreamWatcher` subclasses `threading.local` so that its instances
        can be used to 'watch' both subprocess stdout and stderr in separate
        threads.

    .. versionadded:: 1.0
    """

    def submit(self, stream: str) -> Iterable[str]:
        """
        Act on ``stream`` data, potentially returning responses.

        :param str stream:
            All data read on this stream since the beginning of the session.

        :returns:
            An iterable of ``str`` (which may be empty).

        .. versionadded:: 1.0
        """
        raise NotImplementedError
class Responder(StreamWatcher):
    """
    A parameterizable object that submits responses to specific patterns.

    Commonly used to implement password auto-responds for things like
    ``sudo``.

    .. versionadded:: 1.0
    """

    def __init__(self, pattern: str, response: str) -> None:
        r"""
        Imprint this `Responder` with necessary parameters.

        :param pattern:
            A raw string (e.g. ``r"\[sudo\] password for .*:"``) which will
            be turned into a regular expression.

        :param response:
            The string to submit to the subprocess' stdin when ``pattern``
            is detected.
        """
        # TODO: precompile the keys into regex objects
        self.pattern = pattern
        self.response = response
        self.index = 0

    def pattern_matches(
        self, stream: str, pattern: str, index_attr: str
    ) -> Iterable[str]:
        """
        Generic "search for pattern in stream, using index" behavior.

        Used here and in some subclasses that want to track multiple
        patterns concurrently.

        :param str stream: The same data passed to ``submit``.
        :param str pattern: The pattern to search for.
        :param str index_attr: The name of the index attribute to use.
        :returns: An iterable of string matches.

        .. versionadded:: 1.0
        """
        # Scanning is kept generic so that more than one pattern can be
        # tracked at once (e.g. FailingResponder watches both its prompt and
        # its failure sentinel via distinct index attributes).
        already_seen = getattr(self, index_attr)
        fresh = stream[already_seen:]
        # re.S lets a pattern span line boundaries within the new content.
        hits = re.findall(pattern, fresh, re.S)
        # Advance this pattern's cursor only on a match, so the same content
        # is never responded to twice.
        if hits:
            setattr(self, index_attr, already_seen + len(fresh))
        return hits

    def submit(self, stream: str) -> Generator[str, None, None]:
        # Yield one response per match, in case findall() found several.
        for _hit in self.pattern_matches(stream, self.pattern, "index"):
            yield self.response
class FailingResponder(Responder):
    """
    Variant of `Responder` which is capable of detecting incorrect responses.

    This class adds a ``sentinel`` parameter to ``__init__``, and its
    ``submit`` will raise `.ResponseNotAccepted` if it detects that sentinel
    value in the stream.

    .. versionadded:: 1.0
    """

    def __init__(self, pattern: str, response: str, sentinel: str) -> None:
        """
        :param pattern: As for `Responder`.
        :param response: As for `Responder`.
        :param sentinel:
            Pattern whose appearance (after we have responded at least once)
            means our response was rejected.
        """
        super().__init__(pattern, response)
        self.sentinel = sentinel
        # Separate cursor for sentinel scanning (see pattern_matches).
        self.failure_index = 0
        # Whether we have actually submitted a response yet.
        self.tried = False

    def submit(self, stream: str) -> Generator[str, None, None]:
        # Behave like regular Responder initially, but materialize the
        # matches: the previous code tested the generator object returned by
        # super().submit(), which is ALWAYS truthy, so ``self.tried`` was set
        # on the very first call even when nothing matched — making a
        # sentinel that appeared before any real response raise spuriously.
        responses = list(super().submit(stream))
        # Also check stream for our failure sentinel.
        failed = self.pattern_matches(stream, self.sentinel, "failure_index")
        # Error out if we seem to have failed after a previous response.
        if self.tried and failed:
            err = 'Auto-response to r"{}" failed with {!r}!'.format(
                self.pattern, self.sentinel
            )
            raise ResponseNotAccepted(err)
        # Only note that we tried once a response was actually produced.
        if responses:
            self.tried = True
        # Again, behave regularly by default (callers just iterate this).
        return iter(responses)