Having contributed to this Community Wiki on StackOverflow regarding validating user input, I thought I'd finally sit down and write something more robust to do these kinds of tasks. I wanted something that would allow detailed configuration of the acceptable inputs, but without a huge list of arguments or a spaghetti of processing code, and that would be easily extensible (e.g. for adding new validation options).
The code is written to be compatible with both Python 2.x and 3.x (tested in 2.7.9 and 3.4.2 on Mavericks), and has been `pylint`ed in both ("Your code has been rated at 10.00/10"). I have also included some built-in testing.
As well as any general comments/suggestions you may have, there are a few specific points I'd be particularly interested in feedback on:
- The use of classes (note I've had to disable `too-few-public-methods` - was OOP the right way to go, or should I have used e.g. a function factory?);
- The implementation of instance caching (is `Cached` sufficiently reusable in other contexts, or too tightly tied to `InputValidator`?); and
- How the testing is set up.
Code [also available as a gist]:
"""Functionality for validating user inputs."""
# pylint: disable=too-few-public-methods
from __future__ import print_function
import sys
__all__ = ['InputValidator']
def create_choice_validator(choices):
    """Build a validator that accepts only one of the given choices.

    Notes:
      The choices are stored in a set where possible, to speed up
      membership tests. If a set cannot be built from them (set() raises
      TypeError), the choices are used exactly as supplied.

    Arguments:
      choices (collection): The valid choices.

    Returns:
      callable: A validation function to apply to user input.

    """
    try:
        permitted = set(choices)
    except TypeError:
        permitted = choices

    def validator(ui_):
        """Raise ValueError unless the input is a permitted choice."""
        if ui_ in permitted:
            return
        raise ValueError('Input must be one of {!r}'.format(permitted))

    return validator
def create_empty_validator(allow_empty):
    """Create a validation function based on input presence.

    Notes:
      The validator is meant to be applied to input that may already
      have been converted to another type, so "empty" means None or a
      zero-length value (e.g. '' or []). Falsy values that are genuine
      input, such as 0 or False, count as present.

    Arguments:
      allow_empty (bool): Whether to allow empty input.

    Returns:
      callable: A validation function to apply to user input.

    """
    if allow_empty:
        def validator(ui_):  # pylint: disable=unused-argument
            """Accept any input, empty or not."""
            return None
    else:
        def validator(ui_):
            """Reject None and zero-length input."""
            try:
                is_empty = len(ui_) == 0
            except TypeError:
                # No length (e.g. a number): only None counts as missing.
                is_empty = ui_ is None
            if is_empty:
                raise ValueError('Input must be present.')
    return validator
def create_len_validator(len_):
    """Create a validation function based on input length.

    Arguments:
      len_ (int or tuple): Either the acceptable length, or a tuple
        (min_len, max_len) giving an inclusive range.

    Returns:
      callable: A validation function to apply to user input.

    Raises:
      ValueError: If len_ is an iterable that does not contain exactly
        two items, or if min_len is greater than max_len.

    """
    try:
        min_, max_ = len_
    except TypeError:
        # Not iterable, so treat len_ as a single exact length.
        def validator(ui_):
            """Validate user input based on exact length."""
            if len(ui_) != len_:
                msg = 'Input must contain {} elements.'
                raise ValueError(msg.format(len_))
        return validator

    if min_ > max_:
        # No input can satisfy a reversed range, so every attempt would be
        # rejected; fail when the validator is configured instead.
        msg = 'Minimum length {!r} exceeds maximum length {!r}.'
        raise ValueError(msg.format(min_, max_))

    def validator(ui_):
        """Validate user input based on a length range."""
        size = len(ui_)
        if size < min_:
            msg = 'Input must contain at least {} elements.'
            raise ValueError(msg.format(min_))
        elif size > max_:
            msg = 'Input must contain at most {} elements.'
            raise ValueError(msg.format(max_))
    return validator
def create_max_validator(max_):
    """Build a validator enforcing an upper bound on the input.

    Arguments:
      max_: The maximum permitted value.

    Returns:
      callable: A validation function to apply to user input.

    """
    template = 'Input must be at most {}.'

    def validator(ui_):
        """Raise ValueError if the input is greater than the maximum."""
        if ui_ > max_:
            raise ValueError(template.format(max_))

    return validator
def create_min_validator(min_):
    """Build a validator enforcing a lower bound on the input.

    Arguments:
      min_: The minimum permitted value.

    Returns:
      callable: A validation function to apply to user input.

    """
    template = 'Input must be at least {}.'

    def validator(ui_):
        """Raise ValueError if the input is less than the minimum."""
        if ui_ < min_:
            raise ValueError(template.format(min_))

    return validator
class Cached(object):
    """Cache instances by positional arguments.

    Notes:
      Every class using this base keeps its own cache, stored in a
      'cache' class attribute that is created the first time the class
      is instantiated. A subclass never reuses its parent's cache, so it
      always gets instances of its own type.

      Keyword arguments are ignored when looking up the cache, and calls
      without positional arguments are never cached. The positional
      arguments are used as a dictionary key, so they must be hashable.

      Only object creation is cached: Python still calls __init__ on the
      returned instance every time the class is called.

    """

    def __new__(cls, *args, **_):
        # Check the class's own namespace rather than using hasattr, which
        # would also find (and then share) a cache inherited from a parent.
        if 'cache' not in vars(cls):
            setattr(cls, 'cache', {})
        cache = cls.cache  # pylint: disable=no-member
        if not args:
            return super(Cached, cls).__new__(cls)
        if args not in cache:
            cache[args] = super(Cached, cls).__new__(cls)
        return cache[args]
class InputValidator(Cached):
    """Create validators for user input.

    Notes:
      Type is validated first - the argument to all other validation
      functions is the type-converted input.

      Validators created with a name are cached: using the same name
      again returns the existing instance, provided the configuration
      matches the one it was created with.

    The following **config options are supported:

      - choices (collection): The valid choices for the input.
      - prompt (str): The default prompt to use if not supplied to
        get_input (defaults to InputValidator.DEFAULT_PROMPT).
      - allow_empty (bool): Whether to allow empty input (defaults to
        False).
      - len_ (int or tuple): Either the exact length permitted, or a
        tuple (min_len, max_len).
      - min_: The minimum value permitted.
      - max_: The maximum value permitted.
      - source (callable): The function to use to take user input
        (defaults to [raw_]input).
      - type_ (callable): The type to attempt to convert the input to
        (defaults to str). It must raise ValueError for input it cannot
        convert; any other exception propagates out of get_input.

    Arguments:
      name (str, optional): The name to store the validator under.
        Defaults to None (i.e. not stored).
      **config (dict): The configuration options for the validator.

    Raises:
      TypeError: If name refers to an existing validator that was
        created with a different configuration.

    Attributes:
      DEFAULT_PROMPT (str): The default prompt to use if not supplied
        in config or the call to get_input.
      VALIDATORS (list): Pairs of (arguments for config.get, function
        that creates a validator from the configured value). A
        validator is only created when the looked-up value is not None.

    """

    DEFAULT_PROMPT = '> '

    VALIDATORS = [
        (('choices',), create_choice_validator),
        (('allow_empty', False), create_empty_validator),
        (('len_',), create_len_validator),
        (('min_',), create_min_validator),
        (('max_',), create_max_validator),
    ]

    def __new__(cls, name=None, **config):
        # Only named validators go through the cache.
        if name is None:
            self = super(InputValidator, cls).__new__(cls)
        else:
            self = super(InputValidator, cls).__new__(cls, name)
        # A cached instance has already been initialised, so it has a
        # config to compare against; a brand new one does not.
        if hasattr(self, 'config') and self.config != config:
            raise TypeError('Configuration conflict')
        return self

    def __init__(self, name=None, **config):
        # Basic arguments
        self.config = config
        self.name = name
        # Select appropriate source for user input
        source = config.get('source')
        if source is None:
            if sys.version_info.major < 3:
                source = raw_input  # pylint: disable=undefined-variable
            else:
                source = input
        self.source = source
        # Default configuration
        self.empty = config.get('allow_empty', False)
        self.prompt = config.get('prompt', self.DEFAULT_PROMPT)
        self.type_ = config.get('type_', str)
        # Validation functions
        self.validators = []
        for get_args, creator in self.VALIDATORS:
            item = config.get(*get_args)  # pylint: disable=star-args
            if item is not None:
                self.validators.append(creator(item))

    def get_input(self, prompt=None):
        """Get validated input.

        Notes:
          Keeps asking until an input passes type conversion and every
          configured validator; the reason for each rejection is
          printed.

        Arguments:
          prompt (str, optional): The prompt to use. Defaults to the
            instance's prompt attribute.

        Returns:
          The type-converted, validated input.

        """
        if prompt is None:
            prompt = self.prompt
        while True:
            ui_ = self.source(prompt)
            # Basic type validation
            try:
                ui_ = self.type_(ui_)
            except ValueError:
                msg = 'Input must be {!r}.'
                print(msg.format(self.type_))
                continue
            # Any other validation required
            for validate in self.validators:
                try:
                    validate(ui_)
                except ValueError as err:
                    print(err)
                    break
            else:
                # No validator objected (the loop was not broken).
                return ui_

    def __call__(self, *args, **kwargs):
        """Allow direct call, invoking get_input."""
        return self.get_input(*args, **kwargs)
if __name__ == '__main__':

    # Built-in testing
    # The checks below run in order and share state (input_test.args, the
    # replaced sys.stdout and the validator cache), so keep their sequence.

    from ast import literal_eval

    class SuppressStdOut(object):
        """Suppress the standard output for testing."""

        def flush(self, *_, **__):
            """Don't flush anything."""
            pass

        def write(self, *_, **__):
            """Don't write anything."""
            pass

    # Hide the rejection messages that get_input prints while testing.
    # NOTE: the original sys.stdout is not restored afterwards.
    sys.stdout = SuppressStdOut()

    def input_test(_):
        """Return whatever is first in args."""
        # Stands in for [raw_]input: the prompt is ignored and the next
        # queued string is removed from the front of input_test.args.
        return input_test.args.pop(0)

    # 1. Caching

    # Ensure caching isn't activated without name argument
    assert InputValidator() is not InputValidator()

    # Ensure caching is activated with positional name...
    assert InputValidator('name') is InputValidator('name')

    # ...and keyword name...
    assert InputValidator('name') is InputValidator(name='name')

    # ...and handles configuration conflicts
    # ('name' was created above with an empty configuration)
    try:
        _ = InputValidator('name', option='other')
    except TypeError:
        pass
    else:
        assert False, 'TypeError not thrown for configuration conflict'

    # 2. Calling
    # One queued input for each of the two call forms
    input_test.args = ['test', 'test']

    # Test both call forms return correct value
    VALIDATOR = InputValidator(source=input_test)
    assert VALIDATOR.get_input() == VALIDATOR() == 'test'

    # 3. Numerical validation
    # '-1' is below min_, '11' is above max_ and 'foo' is not an int, so
    # only '5' is accepted
    input_test.args = ['-1', '11', 'foo', '5']
    VALIDATOR = InputValidator(source=input_test, type_=int, min_=0, max_=10)
    assert VALIDATOR() == 5

    # 4. Empty string validation

    # Test empty not allowed...
    # (the trailing '' is never consumed; args is replaced below)
    input_test.args = ['', 'test', '']
    VALIDATOR = InputValidator(source=input_test)
    assert VALIDATOR() == 'test'

    # ...and allowed
    input_test.args = ['']
    VALIDATOR = InputValidator(source=input_test, allow_empty=True)
    assert VALIDATOR() == ''

    # 5. Choice validation
    input_test.args = ['foo', 'bar']
    VALIDATOR = InputValidator(source=input_test, choices=['bar'])
    assert VALIDATOR() == 'bar'

    # 6. Length validation

    # Test exact length...
    # (one too long, one too short, then correct)
    CORRECT_LEN = 10
    input_test.args = [
        'a' * (CORRECT_LEN + 1),
        'a' * (CORRECT_LEN - 1),
        'a' * CORRECT_LEN
    ]
    VALIDATOR = InputValidator(source=input_test, len_=CORRECT_LEN)
    assert VALIDATOR() == 'a' * CORRECT_LEN

    # ...and length range...
    # (one too short, one too long, then the upper boundary itself)
    MIN_LEN = 5
    MAX_LEN = 10
    input_test.args = [
        'a' * (MIN_LEN - 1),
        'a' * (MAX_LEN + 1),
        'a' * MAX_LEN
    ]
    VALIDATOR = InputValidator(source=input_test, len_=(MIN_LEN, MAX_LEN))
    assert VALIDATOR() == 'a' * MAX_LEN

    # ...and errors
    # (a three-character string cannot be unpacked into a
    # (min_len, max_len) pair, which raises ValueError)
    LEN = 'foo'
    try:
        _ = InputValidator(len_=LEN)
    except ValueError:
        pass
    else:
        assert False, 'ValueError not thrown for {!r}.'.format(LEN)

    # 7. Something completely different
    # literal_eval turns each input into a list; the first two inputs are
    # rejected before the three-element list is accepted
    OUTPUT = ['foo', 'bar', 'baz']
    input_test.args = ['[]', '["foo"]', repr(OUTPUT)]
    VALIDATOR = InputValidator(source=input_test, len_=3, type_=literal_eval)
    assert VALIDATOR() == OUTPUT