"""Tools for generating forms based on SQLAlchemy models."""
import inspect
from sqlalchemy import inspect as sainspect
from wtforms import fields as wtforms_fields
from wtforms import validators
from wtforms.form import Form
from .fields import QuerySelectField
from .fields import QuerySelectMultipleField
# Public names of this module: what ``from <module> import *`` exports.
__all__ = (
    "model_fields",
    "model_form",
)
def converts(*args):
    """Decorator factory that tags a method as a converter.

    The decorated function gets a ``_converter_for`` attribute holding a
    frozenset of ``args`` (type names or relationship direction names);
    ``ModelConverterBase.__init__`` scans for that attribute to build its
    converter registry. The function itself is returned unchanged.
    """

    def tag(fn):
        fn._converter_for = frozenset(args)
        return fn

    return tag
class ModelConversionError(Exception):
    """Raised when a mapper property cannot be turned into a form field."""

    def __init__(self, message):
        super().__init__(message)
class ModelConverterBase:
    """Dispatch SQLAlchemy mapper properties to WTForms field converters.

    Converters live in ``self.converters``. Column properties are matched by
    their column type name (see :meth:`get_converter`); properties with a
    ``direction`` (relationships) are matched by ``prop.direction.name``.
    Methods decorated with :func:`converts` are registered automatically.
    """

    def __init__(self, converters, use_mro=True):
        """Build the converter registry.

        :param converters: Optional mapping of type/direction names to
            converter callables. It is copied, so the caller's mapping is
            never modified. ``@converts``-decorated methods of this instance
            are added on top and replace entries with the same key.
        :param use_mro: When true, :meth:`get_converter` also considers the
            base classes of a column's type, most specific first.
        """
        self.use_mro = use_mro

        # Copy instead of filling the caller's dict in place, so a mapping
        # shared between several converter instances is left untouched.
        registry = dict(converters) if converters else {}

        for name in dir(self):
            obj = getattr(self, name)
            if hasattr(obj, "_converter_for"):
                for classname in obj._converter_for:
                    registry[classname] = obj

        self.converters = registry

    def get_converter(self, column):
        """Return the converter registered for ``column``'s type.

        Two passes are made over the candidate types (the column type's MRO
        when ``use_mro`` is set, otherwise just the type itself): first by
        ``module.ClassName`` with a leading ``sqlalchemy.`` removed, then by
        bare class name.

        :raises ModelConversionError: if no converter matches.
        """
        if self.use_mro:
            types = inspect.getmro(type(column.type))
        else:
            types = [type(column.type)]

        # Search by module + name
        for col_type in types:
            type_string = f"{col_type.__module__}.{col_type.__name__}"
            # Strip the 'sqlalchemy.' prefix so registry keys are written
            # relative to the package (e.g. "dialects.mysql.types.YEAR").
            if type_string.startswith("sqlalchemy."):
                type_string = type_string[len("sqlalchemy."):]
            if type_string in self.converters:
                return self.converters[type_string]

        # Search by name
        for col_type in types:
            if col_type.__name__ in self.converters:
                return self.converters[col_type.__name__]

        raise ModelConversionError(
            "Could not find field converter for column %s (%r)."
            % (column.name, types[0])
        )

    @staticmethod
    def _column_default(column):
        """Return the form default derived from ``column.default``, or None.

        Only default objects exposing a non-None ``arg`` attribute are used
        (e.g. ``Column(DateTime, default=datetime.utcnow)``); ``arg`` may be
        a callable, which is invoked with ``None`` as its only argument, or a
        plain value returned as-is. Any other default object yields None.
        """
        default = getattr(column, "default", None)
        arg = getattr(default, "arg", None)
        if arg is None:
            return None
        return arg(None) if callable(arg) else arg

    def convert(self, model, mapper, prop, field_args, db_session=None):
        """Build a form field for a single mapper property.

        :param field_args: Optional dict of extra keyword arguments for the
            field. Its ``validators`` value is copied (never mutated), and an
            explicit ``default`` entry takes precedence over the column's
            default.
        :param db_session: Required for relationship properties; used by the
            generated ``query_factory`` to list the related objects.
        :returns: The converter's result, or ``None`` when ``prop`` has
            neither ``columns`` nor ``direction``.
        :raises TypeError: for properties mapped to more than one column.
        :raises ModelConversionError: if no converter matches the column
            type, or a relationship is converted without ``db_session``.
        """
        if not hasattr(prop, "columns") and not hasattr(prop, "direction"):
            return None
        elif not hasattr(prop, "direction") and len(prop.columns) != 1:
            raise TypeError(
                "Do not know how to convert multiple-column properties currently"
            )

        kwargs = {
            "validators": [],
            "filters": [],
            "default": None,
            "description": prop.doc,
        }
        if field_args:
            kwargs.update(field_args)

        # Always take a private copy: validators are appended below and by
        # the converters, and even an empty list supplied by the caller must
        # not be mutated. A ``None`` value becomes an empty list.
        kwargs["validators"] = list(kwargs["validators"] or ())

        column = None
        if not hasattr(prop, "direction"):
            column = prop.columns[0]
            # An explicit "default" in field_args wins over the column default.
            if not (field_args and "default" in field_args):
                kwargs["default"] = self._column_default(column)

            if column.nullable:
                kwargs["validators"].append(validators.Optional())
            else:
                kwargs["validators"].append(validators.InputRequired())

            converter = self.get_converter(column)
        else:
            # We have a property with a direction (a relationship); its
            # choices are queried through the session.
            if db_session is None:
                raise ModelConversionError(
                    "Cannot convert field %s, need DB session." % prop.key
                )

            foreign_model = prop.mapper.class_
            # A blank choice is allowed only if every local join column is
            # nullable.
            nullable = all(pair[0].nullable for pair in prop.local_remote_pairs)
            kwargs.update(
                {
                    "allow_blank": nullable,
                    "query_factory": lambda: db_session.query(foreign_model).all(),
                }
            )
            converter = self.converters[prop.direction.name]

        return converter(
            model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs
        )
class ModelConverter(ModelConverterBase):
    """Default converter mapping common SQLAlchemy types to WTForms fields.

    Each converter method is registered through :func:`converts` for the
    type names (or relationship directions) it handles. Converters are
    called with keyword arguments ``model``, ``mapper``, ``prop``,
    ``column`` and ``field_args`` (the prepared field kwargs) and return a
    field instance.
    """

    def __init__(self, extra_converters=None, use_mro=True):
        super().__init__(extra_converters, use_mro=use_mro)

    @classmethod
    def _string_common(cls, column, field_args, **extra):
        # Add a Length(max=...) validator when the column type declares a
        # non-zero integer length.
        if isinstance(column.type.length, int) and column.type.length:
            field_args["validators"].append(validators.Length(max=column.type.length))

    @converts("String")  # includes Unicode
    def conv_String(self, field_args, **extra):
        self._string_common(field_args=field_args, **extra)
        return wtforms_fields.StringField(**field_args)

    @converts("Text", "LargeBinary", "Binary")  # includes UnicodeText
    def conv_Text(self, field_args, **extra):
        self._string_common(field_args=field_args, **extra)
        return wtforms_fields.TextAreaField(**field_args)

    @converts("Boolean", "dialects.mssql.base.BIT")
    def conv_Boolean(self, field_args, **extra):
        # NOTE(review): for non-nullable columns convert() has already added
        # InputRequired, which may reject an unchecked checkbox — confirm.
        return wtforms_fields.BooleanField(**field_args)

    @converts("Date")
    def conv_Date(self, field_args, **extra):
        return wtforms_fields.DateField(**field_args)

    @converts("DateTime")
    def conv_DateTime(self, field_args, **extra):
        return wtforms_fields.DateTimeField(**field_args)

    @converts("Enum")
    def conv_Enum(self, column, field_args, **extra):
        # Each enum value is used as both the stored value and its label.
        field_args["choices"] = [(e, e) for e in column.type.enums]
        return wtforms_fields.SelectField(**field_args)

    @converts("Integer")  # includes BigInteger and SmallInteger
    def handle_integer_types(self, column, field_args, **extra):
        unsigned = getattr(column.type, "unsigned", False)
        if unsigned:
            field_args["validators"].append(validators.NumberRange(min=0))
        return wtforms_fields.IntegerField(**field_args)

    @converts("Numeric")  # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE
    def handle_decimal_types(self, column, field_args, **extra):
        # override default decimal places limit, use database defaults instead
        field_args.setdefault("places", None)
        return wtforms_fields.DecimalField(**field_args)

    @converts("dialects.mysql.types.YEAR", "dialects.mysql.base.YEAR")
    def conv_MSYear(self, field_args, **extra):
        # Values are limited to 1901-2155. NumberRange compares the field data
        # with numbers, so the field must produce an int: a StringField would
        # hand it a str.
        field_args["validators"].append(validators.NumberRange(min=1901, max=2155))
        return wtforms_fields.IntegerField(**field_args)

    @converts("dialects.postgresql.types.INET", "dialects.postgresql.base.INET")
    def conv_PGInet(self, field_args, **extra):
        field_args.setdefault("label", "IP Address")
        field_args["validators"].append(validators.IPAddress())
        return wtforms_fields.StringField(**field_args)

    @converts("dialects.postgresql.types.MACADDR", "dialects.postgresql.base.MACADDR")
    def conv_PGMacaddr(self, field_args, **extra):
        field_args.setdefault("label", "MAC Address")
        field_args["validators"].append(validators.MacAddress())
        return wtforms_fields.StringField(**field_args)

    @converts(
        "sql.sqltypes.UUID",
        "dialects.postgresql.types.UUID",
        "dialects.postgresql.base.UUID",
    )
    def conv_PGUuid(self, field_args, **extra):
        field_args.setdefault("label", "UUID")
        field_args["validators"].append(validators.UUID())
        return wtforms_fields.StringField(**field_args)

    @converts("MANYTOONE")
    def conv_ManyToOne(self, field_args, **extra):
        return QuerySelectField(**field_args)

    @converts("MANYTOMANY", "ONETOMANY")
    def conv_ManyToMany(self, field_args, **extra):
        return QuerySelectMultipleField(**field_args)
def model_fields(
    model,
    db_session=None,
    only=None,
    exclude=None,
    field_args=None,
    converter=None,
    exclude_pk=False,
    exclude_fk=False,
):
    """Return a ``{property_name: field}`` dict for a SQLAlchemy model.

    Parameters mirror those of `model_form`; see its docstring for details.
    Properties for which the converter returns ``None`` are omitted.
    """
    mapper = sainspect(model)
    if not converter:
        converter = ModelConverter()
    if not field_args:
        field_args = {}

    def dropped_as_key(prop):
        # The primary/foreign-key filters only inspect the first column.
        cols = getattr(prop, "columns", None)
        if not cols:
            return False
        head = cols[0]
        return bool(
            (exclude_fk and head.foreign_keys) or (exclude_pk and head.primary_key)
        )

    candidates = [
        (prop.key, prop) for prop in mapper.attrs.values() if not dropped_as_key(prop)
    ]

    if only:
        selected = (pair for pair in candidates if pair[0] in only)
    elif exclude:
        selected = (pair for pair in candidates if pair[0] not in exclude)
    else:
        selected = candidates

    fields = {}
    for key, prop in selected:
        built = converter.convert(model, mapper, prop, field_args.get(key), db_session)
        if built is not None:
            fields[key] = built
    return fields