"""Source code for pypika.terms."""

import inspect
import itertools
import re
from datetime import date
from enum import Enum

from typing import (
    Iterable,
    Union,
)

from pypika.enums import (
    Arithmetic,
    Boolean,
    Dialects,
    Equality,
    JSONOperators,
    Matching,
)
from pypika.utils import (
    CaseException,
    FunctionException,
    builder,
    format_alias_sql,
    format_quotes,
    ignore_copy,
    resolve_is_aggregate,
)

# This module only runs on Python 3 (it uses zero-argument ``super()`` and
# annotations), where ``basestring`` does not exist. Alias it to ``str`` so the
# string ``isinstance`` checks below keep their historical spelling.
basestring = str

__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"


class Term:
    """Base class for every expression that can appear in a query.

    Subclasses override :meth:`get_sql`. This class supplies the Python operator
    overloads (arithmetic, comparison, matching, ``~``, slicing) that combine terms
    into larger expressions, plus helpers for wrapping raw Python values.
    """
    is_aggregate = False

    def __init__(self, alias=None):
        self.alias = alias

    @builder
    def as_(self, alias):
        """Set the alias used when this term is rendered with ``with_alias``."""
        self.alias = alias

    @property
    def tables_(self):
        """Tables referenced by this term; the base term references none."""
        return set()

    @staticmethod
    def wrap_constant(val, wrapper_cls=None):
        """
        Used for wrapping raw inputs such as numbers in Criterions and Operator.

        For example, the expression F('abc')+1 stores the integer part in a ValueWrapper object.

        :param val:
            Any value.
        :param wrapper_cls:
            A pypika class which wraps a constant value so it can be handled as a component of the query.
        :return:
            Raw string, number, or decimal values will be returned in a ValueWrapper.  Fields and other parts of the
            querybuilder will be returned as inputted.
        """
        # Local import, presumably to avoid a circular import with pypika.queries.
        from .queries import QueryBuilder

        if isinstance(val, (Term, QueryBuilder, Interval)):
            return val
        if val is None:
            return NullValue()
        if isinstance(val, list):
            return Array(*val)
        if isinstance(val, tuple):
            return Tuple(*val)

        # Need to default here to avoid the recursion. ValueWrapper extends this class.
        wrapper_cls = wrapper_cls or ValueWrapper
        return wrapper_cls(val)

    @staticmethod
    def wrap_json(val, wrapper_cls=None):
        """Wrap ``val`` for use as the right-hand operand of a JSON operator.

        Terms, query builders and intervals are returned unchanged. ``None`` becomes
        :class:`NullValue`, and ``str``/``int``/``bool`` values are wrapped with
        ``wrapper_cls`` (default :class:`ValueWrapper`). Anything else is wrapped
        in :class:`JSON`.
        """
        from .queries import QueryBuilder

        if isinstance(val, (Term, QueryBuilder, Interval)):
            return val
        if val is None:
            return NullValue()
        if isinstance(val, (str, int, bool)):
            wrapper_cls = wrapper_cls or ValueWrapper
            return wrapper_cls(val)

        return JSON(val)

    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing
        fields across queries. The base implementation returns self because not all terms have
        a table property.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            Self.
        """
        return self

    def fields(self):
        """Fields referenced by this term; the base implementation returns ``[self]``."""
        return [self]

    # --- Named comparison / predicate helpers ------------------------------------------

    def eq(self, other):
        return self == other

    def isnull(self):
        return NullCriterion(self)

    def notnull(self):
        return self.isnull().negate()

    def bitwiseand(self, value):
        return BitwiseAndCriterion(self, value)

    def gt(self, other):
        return self > other

    def gte(self, other):
        return self >= other

    def lt(self, other):
        return self < other

    def lte(self, other):
        return self <= other

    def ne(self, other):
        return self != other

    def like(self, expr):
        return BasicCriterion(Matching.like, self, self.wrap_constant(expr))

    def not_like(self, expr):
        return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr))

    def ilike(self, expr):
        return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr))

    def not_ilike(self, expr):
        return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr))

    def regex(self, pattern):
        return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern))

    def between(self, lower, upper):
        return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper))

    def isin(self, arg):
        """``self IN arg``; lists, tuples and sets are turned into a :class:`Tuple` of constants."""
        if isinstance(arg, (list, tuple, set)):
            return ContainsCriterion(self, Tuple(*[self.wrap_constant(value) for value in arg]))
        return ContainsCriterion(self, arg)

    def notin(self, arg):
        return self.isin(arg).negate()

    def bin_regex(self, pattern):
        return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern))

    def negate(self):
        return Not(self)

    # --- Operator overloads -------------------------------------------------------------

    def __invert__(self):
        return Not(self)

    def __pos__(self):
        return self

    def __neg__(self):
        return Negative(self)

    def __add__(self, other):
        return ArithmeticExpression(Arithmetic.add, self, self.wrap_constant(other))

    def __sub__(self, other):
        return ArithmeticExpression(Arithmetic.sub, self, self.wrap_constant(other))

    def __mul__(self, other):
        return ArithmeticExpression(Arithmetic.mul, self, self.wrap_constant(other))

    def __truediv__(self, other):
        return ArithmeticExpression(Arithmetic.div, self, self.wrap_constant(other))

    def __pow__(self, other):
        return Pow(self, other)

    def __mod__(self, other):
        return Mod(self, other)

    def __radd__(self, other):
        return ArithmeticExpression(Arithmetic.add, self.wrap_constant(other), self)

    def __rsub__(self, other):
        return ArithmeticExpression(Arithmetic.sub, self.wrap_constant(other), self)

    def __rmul__(self, other):
        return ArithmeticExpression(Arithmetic.mul, self.wrap_constant(other), self)

    def __rtruediv__(self, other):
        return ArithmeticExpression(Arithmetic.div, self.wrap_constant(other), self)

    def __eq__(self, other):
        return BasicCriterion(Equality.eq, self, self.wrap_constant(other))

    def __ne__(self, other):
        return BasicCriterion(Equality.ne, self, self.wrap_constant(other))

    def __gt__(self, other):
        return BasicCriterion(Equality.gt, self, self.wrap_constant(other))

    def __ge__(self, other):
        return BasicCriterion(Equality.gte, self, self.wrap_constant(other))

    def __lt__(self, other):
        return BasicCriterion(Equality.lt, self, self.wrap_constant(other))

    def __le__(self, other):
        return BasicCriterion(Equality.lte, self, self.wrap_constant(other))

    def __getitem__(self, item):
        """Support ``term[lower:upper]`` as shorthand for :meth:`between`.

        :raises TypeError: if ``item`` is not a slice.
        """
        if not isinstance(item, slice):
            # Name the actual class (the old message had a stray quote and always said Field).
            raise TypeError("'{}' object is not subscriptable".format(type(self).__name__))
        return self.between(item.start, item.stop)

    def __str__(self):
        return self.get_sql(quote_char='"', secondary_quote_char="'")

    def __hash__(self):
        # Defined explicitly because overriding __eq__ would otherwise make Term unhashable.
        return hash(self.get_sql(with_alias=True))

    def get_sql(self, **kwargs):
        """Render this term as SQL; must be implemented by subclasses."""
        raise NotImplementedError()
class Parameter(Term):
    """A bind-parameter placeholder, rendered verbatim via ``str``."""
    # ``None`` marks the term as neutral when aggregate-ness is resolved.
    is_aggregate = None

    def __init__(self, placeholder):
        super().__init__()
        self.placeholder = placeholder

    def fields(self):
        """A placeholder references no fields."""
        return []

    def get_sql(self, **kwargs):
        """Return the placeholder text; rendering options are ignored."""
        return str(self.placeholder)
class Negative(Term):
    """Unary minus applied to another term, rendered as ``-<term>``."""

    def __init__(self, term):
        super(Negative, self).__init__()
        self.term = term

    @property
    def is_aggregate(self):
        return self.term.is_aggregate

    def get_sql(self, **kwargs):
        """Render ``-<term>``, bracketing the operand when a bare prefix would be wrong.

        Brackets are needed in two cases. ``-a+b`` is not ``-(a+b)``. A leading ``-``
        on the operand would produce ``--``, which starts an SQL line comment.
        """
        term_sql = '{}'.format(self.term.get_sql(**kwargs))
        is_additive = getattr(self.term, 'operator', None) in (Arithmetic.add, Arithmetic.sub)
        if is_additive or term_sql.startswith('-'):
            term_sql = '({})'.format(term_sql)
        return '-{term}'.format(term=term_sql)
class ValueWrapper(Term):
    """Wraps a raw Python value (string, number, date, bool, enum, ...) as an SQL literal."""
    # ``None`` marks constants as neutral when aggregate-ness is resolved.
    is_aggregate = None

    def __init__(self, value, alias=None):
        super(ValueWrapper, self).__init__(alias)
        self.value = value

    def fields(self):
        """A literal references no fields."""
        return []

    def get_value_sql(self, **kwargs):
        """Render the wrapped value as SQL text.

        Strings and dates are quoted with ``secondary_quote_char``, and embedded quote
        characters in strings are doubled. Bools become ``true``/``false`` and ``None``
        becomes ``null``. Enum members render their underlying value unquoted.
        """
        quote_char = kwargs.get('secondary_quote_char') or ''

        # FIXME escape values
        if isinstance(self.value, Term):
            return self.value.get_sql(**kwargs)
        if isinstance(self.value, Enum):
            # Unquoted as before, but always as text: a non-string member value (e.g. an
            # IntEnum) used to leak through as an int and break ``','.join`` in callers.
            return str(self.value.value)
        if isinstance(self.value, date):
            value = self.value.isoformat()
            return format_quotes(value, quote_char)
        if isinstance(self.value, basestring):
            value = self.value.replace(quote_char, quote_char * 2)
            return format_quotes(value, quote_char)
        if isinstance(self.value, bool):
            return str.lower(str(self.value))
        if self.value is None:
            return 'null'
        return str(self.value)

    def get_sql(self, quote_char=None, secondary_quote_char="'", **kwargs):
        sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
        return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
class JSON(Term):
    """A JSON document literal built from Python dicts/lists/scalars.

    The document is rendered as an SQL string literal quoted with ``secondary_quote_char``.
    """
    table = None

    def __init__(self, value=None, alias=None):
        super().__init__(alias)
        self.value = value

    def _recursive_get_sql(self, value, **kwargs):
        """Serialize ``value`` (dict, list, str or scalar) to JSON text."""
        if isinstance(value, dict):
            return self._get_dict_sql(value, **kwargs)
        if isinstance(value, list):
            return self._get_list_sql(value, **kwargs)
        if isinstance(value, str):
            return self._get_str_sql(value, **kwargs)
        # ``str()`` would give "True"/"False"/"None", which are not valid JSON literals.
        # bool must be handled here, before the generic number fallback (bool is an int).
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if value is None:
            return 'null'
        return str(value)

    def _get_dict_sql(self, value, **kwargs):
        pairs = ['{key}:{value}'.format(key=self._recursive_get_sql(k, **kwargs),
                                        value=self._recursive_get_sql(v, **kwargs))
                 for k, v in value.items()]
        return ''.join(['{', ','.join(pairs), '}'])

    def _get_list_sql(self, value, **kwargs):
        pairs = [self._recursive_get_sql(v, **kwargs) for v in value]
        return ''.join(['[', ','.join(pairs), ']'])

    @staticmethod
    def _get_str_sql(value, quote_char='"', **kwargs):
        return format_quotes(value, quote_char)

    def get_sql(self, secondary_quote_char="'", **kwargs):
        json_sql = self._recursive_get_sql(self.value)
        if secondary_quote_char:
            # Double any occurrence of the outer quote character (as ValueWrapper does for
            # strings) so a value such as "it's" cannot terminate the SQL literal early.
            json_sql = json_sql.replace(secondary_quote_char, secondary_quote_char * 2)
        return format_quotes(json_sql, secondary_quote_char)

    def get_json_value(self, key_or_index: Union[str, int]):
        """Criterion applying ``JSONOperators.GET_JSON_VALUE`` with ``key_or_index``."""
        return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index))

    def get_text_value(self, key_or_index: Union[str, int]):
        """Criterion applying ``JSONOperators.GET_TEXT_VALUE`` with ``key_or_index``."""
        return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index))

    def get_path_json_value(self, path_json: str):
        """Criterion applying ``JSONOperators.GET_PATH_JSON_VALUE`` with ``path_json``."""
        return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json))

    def get_path_text_value(self, path_json: str):
        """Criterion applying ``JSONOperators.GET_PATH_TEXT_VALUE`` with ``path_json``."""
        return BasicCriterion(JSONOperators.GET_PATH_TEXT_VALUE, self, self.wrap_json(path_json))

    def has_key(self, other):
        return BasicCriterion(JSONOperators.HAS_KEY, self, self.wrap_json(other))

    def contains(self, other):
        return BasicCriterion(JSONOperators.CONTAINS, self, self.wrap_json(other))

    def contained_by(self, other):
        return BasicCriterion(JSONOperators.CONTAINED_BY, self, self.wrap_json(other))

    def has_keys(self, other: Iterable):
        return BasicCriterion(JSONOperators.HAS_KEYS, self, Array(*other))

    def has_any_keys(self, other: Iterable):
        return BasicCriterion(JSONOperators.HAS_ANY_KEYS, self, Array(*other))
class Values(Term):
    """Renders ``VALUES(<field>)`` for a single field; a plain name is wrapped in :class:`Field`."""

    def __init__(self, field):
        super().__init__(None)
        self.field = field if isinstance(field, Field) else Field(field)

    def get_sql(self, quote_char=None, **kwargs):
        field_sql = self.field.get_sql(quote_char=quote_char, **kwargs)
        return 'VALUES({value})'.format(value=field_sql)
class NullValue(Term):
    """The SQL ``NULL`` literal, optionally aliased."""

    def fields(self):
        """``NULL`` references no fields."""
        return []

    def get_sql(self, **kwargs):
        return format_alias_sql('NULL', self.alias, **kwargs)
class Criterion(Term):
    """A boolean-valued term that can be combined with ``&`` (AND), ``|`` (OR) and ``^`` (XOR)."""

    def __and__(self, other):
        return ComplexCriterion(Boolean.and_, self, other)

    def __or__(self, other):
        return ComplexCriterion(Boolean.or_, self, other)

    def __xor__(self, other):
        return ComplexCriterion(Boolean.xor_, self, other)

    @staticmethod
    def any(terms=()):
        """OR together ``terms``.

        Returns an :class:`EmptyCriterion` when ``terms`` is empty and the single term
        itself when there is exactly one.
        """
        crit = EmptyCriterion()

        for term in terms:
            crit |= term

        return crit

    @staticmethod
    def all(terms=()):
        """AND together ``terms``.

        Returns an :class:`EmptyCriterion` when ``terms`` is empty and the single term
        itself when there is exactly one.
        """
        crit = EmptyCriterion()

        for term in terms:
            crit &= term

        return crit

    def fields(self):
        raise NotImplementedError()

    def get_sql(self, **kwargs):
        # Accept rendering options like Term.get_sql does, so a subclass that forgot to
        # override this raises NotImplementedError rather than an unrelated TypeError.
        raise NotImplementedError()
class EmptyCriterion:
    """Identity element for combining criteria.

    ``EmptyCriterion() & c``, ``EmptyCriterion() | c`` and ``EmptyCriterion() ^ c`` all
    return ``c`` unchanged, so it can seed :meth:`Criterion.any` / :meth:`Criterion.all`.
    """
    is_aggregate = None

    def __init__(self):
        # Per-instance set: a class-level ``set()`` is a single object shared by every
        # instance, so an in-place mutation through one instance leaked into all others.
        self.tables_ = set()

    def __and__(self, other):
        return other

    def __or__(self, other):
        return other

    def __xor__(self, other):
        return other
class Field(Criterion, JSON):
    """A column reference, optionally qualified by a table and given an alias."""

    def __init__(self, name, alias=None, table=None):
        # The MRO is Field -> Criterion -> JSON -> Term, and Criterion has no __init__, so this
        # call reaches JSON.__init__(value, alias). Passing ``alias`` positionally would bind it to
        # ``value`` and silently drop the alias, so both are passed by keyword.
        super(Field, self).__init__(value=None, alias=alias)
        self.name = name
        self.table = table

    def fields(self):
        return [self]

    @property
    def tables_(self):
        return {self.table}

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the field with the tables replaced.
        """
        self.table = new_table if self.table == current_table else self.table

    def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, secondary_quote_char="'", **kwargs):
        field_sql = format_quotes(self.name, quote_char)

        # Need to add namespace if the table has an alias
        if self.table and (with_namespace or self.table.alias):
            field_sql = "{namespace}.{name}".format(
                namespace=format_quotes(self.table.alias or self.table._table_name, quote_char),
                name=field_sql,
            )

        field_alias = getattr(self, 'alias', None)
        if with_alias:
            return format_alias_sql(field_sql, field_alias, quote_char=quote_char, **kwargs)
        return field_sql
class Index(Term):
    """A named index, rendered as its name quoted with ``quote_char``."""

    def __init__(self, name, alias=None):
        super().__init__(alias)
        self.name = name

    def get_sql(self, quote_char=None, **kwargs):
        return format_quotes(self.name, quote_char)
class Star(Field):
    """``*`` or ``<table>.*`` in a select list."""

    def __init__(self, table=None):
        super(Star, self).__init__('*', table=table)

    @property
    def tables_(self):
        # Must be a set: ``{}`` is an empty dict, which cannot be unioned (``|``) with the
        # ``tables_`` sets of other terms.
        if self.table is None:
            return set()
        return {self.table}

    def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, **kwargs):
        if self.table and (with_namespace or self.table.alias):
            namespace = self.table.alias or getattr(self.table, '_table_name')
            return "{}.*".format(format_quotes(namespace, quote_char))

        return '*'
class Tuple(Criterion):
    """A parenthesised, comma-separated list of terms; raw values are wrapped as constants."""

    def __init__(self, *values):
        super(Tuple, self).__init__()
        self.values = [self.wrap_constant(value) for value in values]

    def fields(self):
        return list(itertools.chain(*[value.fields() for value in self.values]))

    def get_sql(self, **kwargs):
        return '({})'.format(','.join(term.get_sql(**kwargs) for term in self.values))

    @property
    def is_aggregate(self):
        # Resolve like the other composite terms. ``None`` entries (e.g. constants) are
        # neutral, rather than making a tuple of constants count as non-aggregate.
        return resolve_is_aggregate([value.is_aggregate for value in self.values])

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the field with the tables replaced.
        """
        # ``self.values`` is a flat list of terms (see __init__), so replace in each term directly.
        self.values = [value.replace_table(current_table, new_table) for value in self.values]
class Array(Tuple):
    """A list literal: ``ARRAY[...]`` for PostgreSQL/Redshift, ``[...]`` for other dialects."""

    def get_sql(self, **kwargs):
        items_sql = ','.join(term.get_sql(**kwargs) for term in self.values)
        if kwargs.get('dialect', None) in (Dialects.POSTGRESQL, Dialects.REDSHIFT):
            return 'ARRAY[{}]'.format(items_sql)
        return '[{}]'.format(items_sql)
class Bracket(Tuple):
    """Wraps a single term in parentheses, with an optional alias."""

    def __init__(self, term):
        super().__init__(term)

    def get_sql(self, **kwargs):
        bracketed = super().get_sql(**kwargs)
        return format_alias_sql(sql=bracketed, alias=self.alias, **kwargs)
class NestedCriterion(Criterion):
    """Renders ``<left><comparator><right><nested_comparator><nested>`` with no added spaces."""

    def __init__(self, comparator, nested_comparator, left, right, nested, alias=None):
        super().__init__(alias)
        self.left = left
        self.comparator = comparator
        self.nested_comparator = nested_comparator
        self.right = right
        self.nested = nested

    def fields(self):
        return list(itertools.chain(self.right.fields(), self.left.fields(), self.nested.fields()))

    @property
    def is_aggregate(self):
        return resolve_is_aggregate([term.is_aggregate for term in [self.left, self.right, self.nested]])

    @property
    def tables_(self):
        return self.left.tables_ | self.right.tables_ | self.nested.tables_

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.left = self.left.replace_table(current_table, new_table)
        self.right = self.right.replace_table(current_table, new_table)
        # Each operand is replaced from itself; ``nested`` must not be derived from ``right``.
        self.nested = self.nested.replace_table(current_table, new_table)

    def get_sql(self, with_alias=False, **kwargs):
        sql = '{left}{comparator}{right}{nested_comparator}{nested}'.format(
            left=self.left.get_sql(**kwargs),
            comparator=self.comparator.value,
            right=self.right.get_sql(**kwargs),
            nested_comparator=self.nested_comparator.value,
            nested=self.nested.get_sql(**kwargs),
        )

        if with_alias:
            return format_alias_sql(sql=sql, alias=self.alias, **kwargs)

        return sql
class BasicCriterion(Criterion):
    """A binary criterion ``<left><comparator><right>``, e.g. equality or ``LIKE``."""

    def __init__(self, comparator, left, right, alias=None):
        """
        A wrapper for a basic criterion such as equality or inequality. This wraps three parts,
        a left and right term and a comparator which defines the type of comparison.

        :param comparator:
            Type: Comparator
            This defines the type of comparison, such as ``=`` or ``>``.
        :param left:
            The term on the left side of the expression.
        :param right:
            The term on the right side of the expression.
        """
        super().__init__(alias)
        self.comparator = comparator
        self.left = left
        self.right = right

    @property
    def is_aggregate(self):
        return resolve_is_aggregate([self.left.is_aggregate, self.right.is_aggregate])

    @property
    def tables_(self):
        return self.left.tables_ | self.right.tables_

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.left = self.left.replace_table(current_table, new_table)
        self.right = self.right.replace_table(current_table, new_table)

    def fields(self):
        return self.left.fields() + self.right.fields()

    def get_sql(self, quote_char='"', with_alias=False, **kwargs):
        left_sql = self.left.get_sql(quote_char=quote_char, **kwargs)
        right_sql = self.right.get_sql(quote_char=quote_char, **kwargs)
        sql = '{left}{comparator}{right}'.format(
            left=left_sql,
            comparator=self.comparator.value,
            right=right_sql,
        )
        # The alias is always double-quoted here, independent of ``quote_char``.
        if not (with_alias and self.alias):
            return sql
        return '{sql} "{alias}"'.format(sql=sql, alias=self.alias)
class ContainsCriterion(Criterion):
    """``<term> [NOT ]IN <container>`` membership test."""

    def __init__(self, term, container, alias=None):
        """
        A wrapper for a "IN" criterion.  This wraps two parts, a term and a container.  The term is the part of the
        expression that is checked for membership in the container.  The container can either be a list or a subquery.

        :param term:
            The term to assert membership for within the container.
        :param container:
            A list or subquery.
        """
        super().__init__(alias)
        self.term = term
        self.container = container
        self._is_negated = False

    @property
    def tables_(self):
        # Only the tested term's tables are reported; the container's are not.
        return self.term.tables_

    @property
    def is_aggregate(self):
        return self.term.is_aggregate

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)

    def fields(self):
        if not self.term.fields:
            return []
        return self.term.fields()

    def get_sql(self, subquery=None, **kwargs):
        term_sql = self.term.get_sql(**kwargs)
        # The container is always rendered as a subquery (e.g. bracketed).
        container_sql = self.container.get_sql(subquery=True, **kwargs)
        negation = 'NOT ' if self._is_negated else ''
        return "{term} {not_}IN {container}".format(term=term_sql, not_=negation, container=container_sql)

    def negate(self):
        """Switch to ``NOT IN``; note this mutates and returns this same instance."""
        self._is_negated = True
        return self
class BetweenCriterion(Criterion):
    """``<term> BETWEEN <start> AND <end>``."""

    def __init__(self, term, start, end, alias=None):
        super().__init__(alias)
        self.term = term
        self.start = start
        self.end = end

    @property
    def tables_(self):
        return self.term.tables_

    @property
    def is_aggregate(self):
        return self.term.is_aggregate

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)

    def get_sql(self, **kwargs):
        # FIXME escape
        term_sql = self.term.get_sql(**kwargs)
        start_sql = self.start.get_sql(**kwargs)
        end_sql = self.end.get_sql(**kwargs)
        return "{term} BETWEEN {start} AND {end}".format(term=term_sql, start=start_sql, end=end_sql)

    def fields(self):
        if not self.term.fields:
            return []
        return self.term.fields()
class BitwiseAndCriterion(Criterion):
    """``(<term> & <value>)``; ``value`` is inserted via ``str.format`` without wrapping."""

    def __init__(self, term, value, alias=None):
        super().__init__(alias)
        self.term = term
        self.value = value

    @property
    def tables_(self):
        return self.term.tables_

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)

    def get_sql(self, **kwargs):
        term_sql = self.term.get_sql(**kwargs)
        return "({term} & {value})".format(term=term_sql, value=self.value)

    def fields(self):
        if not self.term.fields:
            return []
        return self.term.fields()
class NullCriterion(Criterion):
    """``<term> IS NULL``."""

    def __init__(self, term, alias=None):
        super().__init__(alias)
        self.term = term

    @property
    def tables_(self):
        return self.term.tables_

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)

    def get_sql(self, **kwargs):
        term_sql = self.term.get_sql(**kwargs)
        return "{term} IS NULL".format(term=term_sql)

    def fields(self):
        if not self.term.fields:
            return []
        return self.term.fields()
class ComplexCriterion(BasicCriterion):
    """Two criteria joined by a boolean operator, e.g. ``a AND b``."""

    def fields(self):
        return self.left.fields() + self.right.fields()

    def get_sql(self, subcriterion=False, **kwargs):
        left_sql = self.left.get_sql(subcriterion=self.needs_brackets(self.left), **kwargs)
        right_sql = self.right.get_sql(subcriterion=self.needs_brackets(self.right), **kwargs)
        sql = '{left} {comparator} {right}'.format(
            left=left_sql,
            comparator=self.comparator.value,
            right=right_sql,
        )
        return '({criterion})'.format(criterion=sql) if subcriterion else sql

    def needs_brackets(self, term):
        """A nested complex criterion needs brackets only if it uses a different boolean operator."""
        return isinstance(term, ComplexCriterion) and term.comparator != self.comparator
class ArithmeticExpression(Term):
    """
    Wrapper for an arithmetic function.  Can be simple with two terms or complex with nested terms. Order of operations
    are also preserved.
    """

    mul_order = [Arithmetic.mul, Arithmetic.div]
    add_order = [Arithmetic.add, Arithmetic.sub]

    def __init__(self, operator, left, right, alias=None):
        """
        Wrapper for an arithmetic expression.

        :param operator:
            Type: Arithmetic
            An operator for the expression such as ``+`` or ``/``
        :param left:
            The term on the left side of the expression.
        :param right:
            The term on the right side of the expression.
        :param alias:
            (Optional) an alias for the term which can be used inside a select statement.
        :return:
        """
        super(ArithmeticExpression, self).__init__(alias)
        self.operator = operator
        self.left = left
        self.right = right

    @property
    def is_aggregate(self):
        # True if both left and right terms are True or None. None if both terms are None. Otherwise, False
        return resolve_is_aggregate([self.left.is_aggregate, self.right.is_aggregate])

    @property
    def tables_(self):
        return self.left.tables_ | self.right.tables_

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the term with the tables replaced.
        """
        self.left = self.left.replace_table(current_table, new_table)
        self.right = self.right.replace_table(current_table, new_table)

    def fields(self):
        return self.left.fields() + self.right.fields()

    def _left_needs_brackets(self, left_op):
        """Whether the left operand, whose operator is ``left_op`` (or ``None``), needs brackets.

        Rendering is left-associative, so only a lower-precedence ``+``/``-`` under ``*``/``/``
        must be bracketed.
        """
        return self.operator in self.mul_order and left_op in self.add_order

    def _right_needs_brackets(self, right_op):
        """Whether the right operand, whose operator is ``right_op`` (or ``None``), needs brackets."""
        if right_op is None:
            return False
        if self.operator == Arithmetic.add:
            # a+(b+c) == a+b+c and a+(b-c) == a+b-c.
            return False
        if self.operator == Arithmetic.sub:
            # a-(b+c) != a-b+c; a higher-precedence right side (a-b*c) is already correct.
            return right_op in self.add_order
        if self.operator == Arithmetic.mul:
            # a*(b+c) needs brackets. a*(b/c) is bracketed too, because SQL integer division
            # truncates and a*b/c can then differ from a*(b/c).
            return right_op != Arithmetic.mul
        # Division (and any other operator): a/(b op c) never equals a/b op c.
        return True

    def get_sql(self, with_alias=False, **kwargs):
        left_op = getattr(self.left, 'operator', None)
        right_op = getattr(self.right, 'operator', None)

        left_sql = self.left.get_sql(**kwargs)
        right_sql = str(self.right.get_sql(**kwargs))

        if self._left_needs_brackets(left_op):
            left_sql = '({})'.format(left_sql)
        # "a--5" would start an SQL line comment, so also bracket a negative right operand of "-".
        if self._right_needs_brackets(right_op) or (
                self.operator == Arithmetic.sub and right_sql.startswith('-')):
            right_sql = '({})'.format(right_sql)

        arithmetic_sql = '{left}{operator}{right}'.format(
            left=left_sql,
            operator=self.operator.value,
            right=right_sql,
        )

        if with_alias:
            return format_alias_sql(arithmetic_sql, self.alias, **kwargs)

        return arithmetic_sql
class Case(Term):
    """``CASE WHEN ... THEN ... [ELSE ...] END`` built with :meth:`when` and :meth:`else_`."""

    def __init__(self, alias=None):
        super(Case, self).__init__(alias=alias)
        self._cases = []
        self._else = None

    @property
    def is_aggregate(self):
        # True if all cases are True or None. None all cases are None. Otherwise, False
        return resolve_is_aggregate([term.is_aggregate for _, term in self._cases]
                                    + [self._else.is_aggregate if self._else else None])

    @builder
    def when(self, criterion, term):
        """Append a ``WHEN criterion THEN term`` branch; ``term`` is wrapped as a constant."""
        # Rebind instead of appending in place. @builder may hand this method a shallow copy,
        # and an in-place append would then also add the branch to the original object.
        self._cases = self._cases + [(criterion, self.wrap_constant(term))]

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the term with the tables replaced.
        """
        # Keep (criterion, term) tuples, matching what ``when`` stores.
        self._cases = [(criterion.replace_table(current_table, new_table),
                        term.replace_table(current_table, new_table))
                       for criterion, term in self._cases]
        self._else = self._else.replace_table(current_table, new_table) if self._else else None

    @builder
    def else_(self, term):
        """Set the ``ELSE`` result; ``term`` is wrapped as a constant."""
        self._else = self.wrap_constant(term)
        return self

    def get_sql(self, with_alias=False, **kwargs):
        """Render the CASE expression.

        :raises CaseException: if no ``when`` branch has been added.
        """
        if not self._cases:
            raise CaseException("At least one 'when' case is required for a CASE statement.")

        cases = " ".join('WHEN {when} THEN {then}'.format(
            when=criterion.get_sql(**kwargs),
            then=term.get_sql(**kwargs)
        ) for criterion, term in self._cases)
        else_ = (' ELSE {}'.format(self._else.get_sql(**kwargs))
                 if self._else
                 else '')

        case_sql = 'CASE {cases}{else_} END'.format(cases=cases, else_=else_)

        if with_alias:
            return format_alias_sql(case_sql, self.alias, **kwargs)

        return case_sql

    def fields(self):
        fields = []

        for criterion, term in self._cases:
            fields += criterion.fields() + term.fields()

        if self._else is not None:
            fields += self._else.fields()

        return fields

    @property
    def tables_(self):
        tables = set()
        if self._cases:
            tables |= {table
                       for case in self._cases
                       for part in case
                       for table in part.tables_
                       if hasattr(part, 'tables_')}

        if self._else and hasattr(self._else, 'tables_'):
            tables |= {table for table in self._else.tables_}

        return tables
class Not(Criterion):
    """Logical negation of a term, rendered as ``NOT <term>``."""

    def __init__(self, term, alias=None):
        super(Not, self).__init__(alias=alias)
        self.term = term

    def fields(self):
        return self.term.fields() if self.term.fields else []

    def get_sql(self, **kwargs):
        # Ask the wrapped term to bracket itself when it is a compound criterion.
        kwargs['subcriterion'] = True
        sql = "NOT {term}".format(term=self.term.get_sql(**kwargs))
        return format_alias_sql(sql, self.alias, **kwargs)

    @ignore_copy
    def __getattr__(self, name):
        """
        Delegate method calls to the class wrapped by Not().
        Re-wrap methods on child classes of Term (e.g. isin, eq...) to retain 'NOT <term>' output.
        """
        item_func = getattr(self.term, name)

        if not inspect.ismethod(item_func):
            return item_func

        # ``item_func`` is already bound to ``self.term``, so forward every argument unchanged.
        # Requiring a separate leading positional argument broke zero-argument and
        # keyword-only calls (e.g. ``Not(x).notnull()``).
        def inner(*args, **kwargs):
            result = item_func(*args, **kwargs)
            if isinstance(result, (Term,)):
                return Not(result)
            return result

        return inner

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.term = self.term.replace_table(current_table, new_table)

    @property
    def tables_(self):
        return self.term.tables_
class CustomFunction:
    """Factory for calls to a user-named SQL function, optionally checking the argument count.

    :param name: SQL function name.
    :param params: Expected parameter names. When ``None``, calls ignore positional
        arguments and build a zero-argument function.
    """

    def __init__(self, name, params=None):
        self.name = name
        self.params = params

    def __call__(self, *args, **kwargs):
        """Build a :class:`Function`; only the ``alias`` keyword is used.

        :raises FunctionException: if ``params`` is set and the argument count differs.
        """
        alias = kwargs.get('alias')

        if not self._has_params():
            return Function(self.name, alias=alias)

        if self._is_valid_function_call(*args):
            return Function(self.name, *args, alias=alias)

        raise FunctionException("Function {name} require these arguments ({params}), ({args}) passed".format(
            name=self.name,
            params=', '.join(str(p) for p in self.params),
            args=', '.join(str(p) for p in args)
        ))

    def _has_params(self):
        return self.params is not None

    def _is_valid_function_call(self, *args):
        # Only the number of arguments is checked, not their types.
        return len(args) == len(self.params)
class Function(Criterion):
    """A generic SQL function call ``NAME(arg, ...)``, optionally schema-qualified.

    Positional ``args`` are wrapped as constants. The ``alias`` and ``schema`` keywords
    are read from ``kwargs``.
    """

    def __init__(self, name, *args, **kwargs):
        super().__init__(kwargs.get('alias'))
        self.name = name
        self.args = [self.wrap_constant(arg) for arg in args]
        self.schema = kwargs.get('schema')

    @property
    def tables_(self):
        tables = set()
        for arg in self.args:
            tables.update(arg.tables_)
        return tables

    def fields(self):
        collected = []
        for arg in self.args:
            if hasattr(arg, 'fields'):
                collected.extend(arg.fields())
        return collected

    @property
    def is_aggregate(self):
        """
        This is a shortcut that assumes if a function has a single argument and that argument is aggregated, then this
        function is also aggregated. A more sophisticated approach is needed, however it is unclear how that might work.
        :returns:
            True if the function accepts one argument and that argument is aggregate.
        """
        return len(self.args) == 1 and self.args[0].is_aggregate

    @builder
    def replace_table(self, current_table, new_table):
        """
        Replaces all occurrences of the specified table with the new table. Useful when reusing fields across queries.

        :param current_table:
            The table to be replaced.
        :param new_table:
            The table to replace with.
        :return:
            A copy of the criterion with the tables replaced.
        """
        self.args = [arg.replace_table(current_table, new_table) for arg in self.args]

    def get_special_params_sql(self, **kwargs):
        """Extra SQL appended inside the parentheses after the arguments; none by default."""
        return None

    def get_function_sql(self, **kwargs):
        special_params_sql = self.get_special_params_sql(**kwargs)
        rendered_args = [
            arg.get_sql(with_alias=False, **kwargs) if hasattr(arg, 'get_sql') else str(arg)
            for arg in self.args
        ]
        return '{name}({args}{special})'.format(
            name=self.name,
            args=','.join(rendered_args),
            special=(' ' + special_params_sql) if special_params_sql else '',
        )

    def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs):
        # FIXME escape
        # Only these three options are forwarded when rendering the arguments.
        function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, dialect=dialect)

        if self.schema is not None:
            schema_sql = self.schema.get_sql(quote_char=quote_char, dialect=dialect, **kwargs)
            function_sql = '{schema}.{function}'.format(schema=schema_sql, function=function_sql)

        if not with_alias:
            return function_sql
        return format_alias_sql(function_sql, self.alias, quote_char=quote_char, **kwargs)
class AggregateFunction(Function):
    """A :class:`Function` whose ``is_aggregate`` is always True, regardless of its arguments."""

    is_aggregate = True
class AnalyticFunction(Function):
    """A function that may carry an ``OVER(PARTITION BY ... ORDER BY ...)`` clause."""
    is_analytic = True

    def __init__(self, name, *args, **kwargs):
        super(AnalyticFunction, self).__init__(name, *args, **kwargs)
        self._partition = []
        self._orderbys = []
        self._include_over = False

    @builder
    def over(self, *terms):
        """Add PARTITION BY terms and enable the OVER clause."""
        self._include_over = True
        # Rebind rather than extend in place. @builder may operate on a shallow copy, and an
        # in-place ``+=`` would also modify the list of the object it was copied from.
        self._partition = self._partition + list(terms)

    @builder
    def orderby(self, *terms, **kwargs):
        """Add ORDER BY terms and enable the OVER clause.

        The ``order`` keyword, when given, is rendered after each term via its ``.value``.
        """
        self._include_over = True
        self._orderbys = self._orderbys + [(term, kwargs.get('order')) for term in terms]

    def _orderby_field(self, field, orient, **kwargs):
        if orient is None:
            return field.get_sql(**kwargs)

        return '{field} {orient}'.format(
            field=field.get_sql(**kwargs),
            orient=orient.value,
        )

    def get_partition_sql(self, **kwargs):
        """Render the contents of the OVER(...) clause (PARTITION BY and ORDER BY parts)."""
        terms = []
        if self._partition:
            terms.append('PARTITION BY {args}'.format(
                args=','.join(p.get_sql(**kwargs) if hasattr(p, 'get_sql') else str(p)
                              for p in self._partition)))

        if self._orderbys:
            terms.append('ORDER BY {orderby}'.format(
                orderby=','.join(
                    self._orderby_field(field, orient, **kwargs)
                    for field, orient in self._orderbys
                )))

        return ' '.join(terms)

    def get_function_sql(self, **kwargs):
        function_sql = super(AnalyticFunction, self).get_function_sql(**kwargs)
        partition_sql = self.get_partition_sql(**kwargs)

        if not self._include_over:
            return function_sql

        return '{function_sql} OVER({partition_sql})'.format(
            function_sql=function_sql,
            partition_sql=partition_sql
        )
class WindowFrameAnalyticFunction(AnalyticFunction):
    """Analytic function that also supports a ``ROWS``/``RANGE`` window frame clause."""

    class Edge:
        """One frame boundary. Subclasses must supply the ``modifier`` text appended after the value."""

        def __init__(self, value=None):
            self.value = value

        def __str__(self):
            # Only ``None`` means unbounded; an explicit 0 (e.g. "0 PRECEDING") must be kept.
            return '{value} {modifier}'.format(
                value='UNBOUNDED' if self.value is None else self.value,
                modifier=self.modifier,
            )

    def __init__(self, name, *args, **kwargs):
        super(WindowFrameAnalyticFunction, self).__init__(name, *args, **kwargs)
        self.frame = None
        self.bound = None

    def _set_frame_and_bounds(self, frame, bound, and_bound):
        """Record the frame type and bound(s).

        :raises AttributeError: if a frame has already been set.
        """
        if self.frame or self.bound:
            raise AttributeError('A window frame has already been set for this function.')

        self.frame = frame
        self.bound = (bound, and_bound) if and_bound else bound

    @builder
    def rows(self, bound, and_bound=None):
        self._set_frame_and_bounds('ROWS', bound, and_bound)

    @builder
    def range(self, bound, and_bound=None):
        self._set_frame_and_bounds('RANGE', bound, and_bound)

    def get_frame_sql(self):
        """Render ``<frame> <bound>`` or ``<frame> BETWEEN <lower> AND <upper>``."""
        if not isinstance(self.bound, tuple):
            return '{frame} {bound}'.format(
                frame=self.frame,
                bound=self.bound
            )

        lower, upper = self.bound
        return '{frame} BETWEEN {lower} AND {upper}'.format(
            frame=self.frame,
            lower=lower,
            upper=upper,
        )

    def get_partition_sql(self, **kwargs):
        partition_sql = super(WindowFrameAnalyticFunction, self).get_partition_sql(**kwargs)

        if not self.frame and not self.bound:
            return partition_sql

        return '{over} {frame}'.format(
            over=partition_sql,
            frame=self.get_frame_sql()
        )
class IgnoreNullsAnalyticFunction(AnalyticFunction):
    """Analytic function that can optionally be rendered with ``IGNORE NULLS``."""

    def __init__(self, name, *args, **kwargs):
        super(IgnoreNullsAnalyticFunction, self).__init__(name, *args, **kwargs)
        # Switched on by ignore_nulls(); read by get_special_params_sql().
        self._ignore_nulls = False

    @builder
    def ignore_nulls(self):
        """Request the ``IGNORE NULLS`` modifier."""
        self._ignore_nulls = True

    def get_special_params_sql(self, **kwargs):
        """Return ``'IGNORE NULLS'`` when it was requested, otherwise ``None``."""
        return 'IGNORE NULLS' if self._ignore_nulls else None
class Interval:
    """
    A SQL ``INTERVAL`` literal.

    If ``quarters`` is non-zero, it is used on its own. Otherwise, if ``weeks``
    is non-zero, it is used on its own. In either case every other unit
    argument is ignored. Otherwise, each non-zero value among
    years..microseconds is stored (truncated with ``int``), and the rendered
    unit spans from the largest to the smallest of them (e.g. ``DAY_SECOND``).
    """
    templates = {
        # MySQL requires no single quotes around the expr and unit
        Dialects.MYSQL: 'INTERVAL {expr} {unit}',

        # PostgreSQL, Redshift and Vertica require quotes around the expr and unit e.g. INTERVAL '1 week'
        Dialects.POSTGRESQL: 'INTERVAL \'{expr} {unit}\'',
        Dialects.REDSHIFT: 'INTERVAL \'{expr} {unit}\'',
        Dialects.VERTICA: 'INTERVAL \'{expr} {unit}\'',

        # Oracle requires just single quotes around the expr
        Dialects.ORACLE: 'INTERVAL \'{expr}\' {unit}'
    }

    # Parallel lists: instance attribute name -> SQL unit label, largest first.
    units = ['years', 'months', 'days', 'hours', 'minutes', 'seconds', 'microseconds']
    labels = ['YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE', 'SECOND', 'MICROSECOND']

    # Strips leading/trailing zero components from the "Y-M-D h:m:s.us" string built in get_sql.
    trim_pattern = re.compile(r'(^0+\.)|(\.0+$)|(^[0\-.: ]+[\-: ])|([\-:. ][0\-.: ]+$)')

    def __init__(self, years=0, months=0, days=0, hours=0, minutes=0, seconds=0, microseconds=0, quarters=0,
                 weeks=0, dialect=None):
        self.dialect = dialect
        # Labels of the largest / smallest non-zero unit; both stay None for
        # quarters, weeks, or an all-zero interval.
        self.largest = None
        self.smallest = None

        if quarters:
            self.quarters = quarters
            return

        if weeks:
            self.weeks = weeks
            return

        for unit, label, value in zip(self.units, self.labels,
                                      [years, months, days, hours, minutes, seconds, microseconds]):
            if value:
                setattr(self, unit, int(value))
                self.largest = self.largest or label
                self.smallest = label

    def __str__(self):
        return self.get_sql()

    @property
    def tables_(self):
        """An interval references no tables; returns an empty ``set``, matching ``Term.tables_``."""
        return set()

    def fields(self):
        """An interval contains no fields."""
        return []

    def get_sql(self, **kwargs):
        """
        Render the interval for the instance's dialect, or else for ``kwargs['dialect']``.

        Dialects without an entry in ``templates`` use ``INTERVAL '<expr> <unit>'``.
        """
        dialect = self.dialect or kwargs.get('dialect')

        if self.largest == 'MICROSECOND':
            expr = self.microseconds
            unit = 'MICROSECOND'

        elif hasattr(self, 'quarters'):
            expr = self.quarters
            unit = 'QUARTER'

        elif hasattr(self, 'weeks'):
            expr = self.weeks
            unit = 'WEEK'

        else:
            # Create the whole expression but trim out the unnecessary fields
            expr = "{years}-{months}-{days} {hours}:{minutes}:{seconds}.{microseconds}".format(
                years=getattr(self, 'years', 0),
                months=getattr(self, 'months', 0),
                days=getattr(self, 'days', 0),
                hours=getattr(self, 'hours', 0),
                minutes=getattr(self, 'minutes', 0),
                seconds=getattr(self, 'seconds', 0),
                microseconds=getattr(self, 'microseconds', 0),
            )
            expr = self.trim_pattern.sub('', expr)

            # NOTE(review): for an all-zero Interval, largest/smallest are both None,
            # so the unit renders as the text 'None' - confirm whether that should raise.
            unit = '{largest}_{smallest}'.format(
                largest=self.largest,
                smallest=self.smallest,
            ) if self.largest != self.smallest else self.largest

        return self.templates.get(dialect, 'INTERVAL \'{expr} {unit}\'') \
            .format(expr=expr, unit=unit)
class Pow(Function):
    """``POW(term, exponent)``: ``term`` raised to the power ``exponent``."""

    def __init__(self, term, exponent, alias=None):
        super(Pow, self).__init__(
            'POW',
            term,
            exponent,
            alias=alias,
        )
class Mod(Function):
    """``MOD(term, modulus)``: the remainder of ``term`` divided by ``modulus``."""

    def __init__(self, term, modulus, alias=None):
        super(Mod, self).__init__(
            'MOD',
            term,
            modulus,
            alias=alias,
        )
class Rollup(Function):
    """``ROLLUP(term1, term2, ...)``; every positional term is passed through as an argument."""

    def __init__(self, *terms):
        super(Rollup, self).__init__(
            'ROLLUP',
            *terms
        )
class PseudoColumn(Term):
    """
    Represents a pseudo column (a "column" which yields a value when selected
    but is not actually a real table column).

    The name is rendered verbatim by :meth:`get_sql`; no quoting or
    namespacing is applied.
    """

    def __init__(self, name):
        # Run Term's initialiser so that ``self.alias`` exists (as None) before
        # ``as_()`` is ever called; previously it was left unset.
        super(PseudoColumn, self).__init__(alias=None)
        self.name = name

    def get_sql(self, **kwargs):
        """Return the pseudo column's name unchanged; all keyword options are ignored."""
        return self.name

    def fields(self):
        """A pseudo column references no table fields."""
        return []