Python sqlalchemy.sql.expression 模块，text() 实例源码

我们从 Python 开源项目中提取了以下 50 个代码示例，用于说明如何使用 sqlalchemy.sql.expression.text()。

项目:triage    作者:dssg    | 项目源码 | 文件源码
def where(self, date, intervals):
        """
        Generates a WHERE clause

        Args:
            date: the end date; rows must have ``self.date_column`` strictly
                before it
            intervals: interval strings spliced into ``interval '...'`` literals
                (presumably Postgres intervals such as '1 year' -- confirm);
                if it contains 'all', no interval-based lower bound is added

        Returns: a ``text`` clause filtering the from_obj to be between the date
            and the greatest interval, additionally bounded below by
            ``self.input_min_date`` when that attribute is not None
        """
        # NOTE(review): all values are interpolated straight into SQL; this
        # assumes date, intervals and input_min_date come from trusted config.

        # upper bound
        conditions = ["{date_column} < '{date}'".format(
            date_column=self.date_column, date=date)]

        # lower bound (if possible)
        if 'all' not in intervals:
            greatest = "greatest(%s)" % ",".join(
                "interval '%s'" % i for i in intervals)
            min_date = "'{date}'::date - {greatest}".format(date=date, greatest=greatest)
            conditions.append("{date_column} >= {min_date}".format(
                date_column=self.date_column, min_date=min_date))
        if self.input_min_date is not None:
            conditions.append("{date_column} >= '{bot}'::date".format(
                date_column=self.date_column, bot=self.input_min_date))

        # Join with explicit surrounding spaces: plain concatenation produced
        # "...'2016-01-01'AND ..." with no whitespace between predicates.
        return ex.text(" AND ".join(conditions))
项目:minstances    作者:0xa    | 项目源码 | 文件源码
def get_ping_aggregates(self, template):
        """Return per-time-bucket averages of this instance's ping metrics.

        Args:
            template: aggregation template name. Only '30m' is supported:
                30-minute buckets covering the last 7 days.

        Returns:
            list of ``PingAgg`` namedtuples
            ``(time, response_time, users, statuses, connections)``,
            one per bucket (the query groups by the bucket expression).

        Raises:
            ValueError: if ``template`` is not supported. ValueError subclasses
                Exception, so callers catching Exception keep working.
        """
        now = datetime.utcnow()

        if template == '30m':
            a = now - timedelta(days=7)
            # SQLite date functions: floor the row's epoch seconds to a 1800 s boundary.
            agg_date = sqle.text("datetime((strftime('%s', time) / 1800) * 1800, 'unixepoch')")
        else:
            raise ValueError('unsupported aggregation template: %r' % (template,))

        # Averages are computed as sum / count(id) over the bucket's rows.
        cnt = sqlf.count(Ping.id)
        q = sql.select([
            agg_date,
            sqlf.sum(Ping.response_time) / cnt,
            sqlf.sum(Ping.users) / cnt,
            sqlf.sum(Ping.statuses) / cnt,
            sqlf.sum(Ping.connections) / cnt,
        ])
        q = q.where((Ping.time >= a) & (Ping.instance_id == self.id))
        q = q.group_by(agg_date)
        # NOTE(review): Ping.id is not grouped; ordering by agg_date may be intended.
        q = q.order_by(Ping.id)
        Nt = namedtuple('PingAgg', ['time', 'response_time', 'users', 'statuses', 'connections'])
        r = db.session.execute(q)
        return [Nt(*t) for t in r]
项目:minstances    作者:0xa    | 项目源码 | 文件源码
def get_uptime_aggregates(self, template):
        """Return per-time-bucket, per-state averages of ping response time.

        Args:
            template: aggregation template name. Only '30m' is supported:
                30-minute buckets covering the last 7 days.

        Returns:
            list of ``PingUptimeAgg`` namedtuples ``(time, state, response_time)``,
            one per (bucket, state) group.

        Raises:
            ValueError: if ``template`` is not supported. ValueError subclasses
                Exception, so callers catching Exception keep working.
        """
        now = datetime.utcnow()

        if template == '30m':
            a = now - timedelta(days=7)
            # SQLite date functions: floor the row's epoch seconds to a 1800 s boundary.
            agg_date = sqle.text("datetime((strftime('%s', time) / 1800) * 1800, 'unixepoch')")
        else:
            raise ValueError('unsupported aggregation template: %r' % (template,))

        # Average response time is computed as sum / count(id) within each group.
        cnt = sqlf.count(Ping.id)
        q = sql.select([
            agg_date,
            Ping.state,
            sqlf.sum(Ping.response_time) / cnt,
        ])
        q = q.where((Ping.time >= a) & (Ping.instance_id == self.id))
        q = q.group_by(agg_date, Ping.state)
        # NOTE(review): Ping.id is not grouped; ordering by agg_date may be intended.
        q = q.order_by(Ping.id)
        Nt = namedtuple('PingUptimeAgg', ['time', 'state', 'response_time'])
        r = db.session.execute(q)
        return [Nt(*t) for t in r]
项目:sqlacodegen    作者:agronholm    | 项目源码 | 文件源码
def test_pk_default(self):
        """A primary key's server_default is rendered as a text(...) call."""
        expected = """\
# coding: utf-8
from sqlalchemy import Column, Integer, text
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
metadata = Base.metadata


class SimpleItem(Base):
    __tablename__ = 'simple_items'

    id = Column(Integer, primary_key=True, server_default=text("uuid_generate_v4()"))
"""
        pk_column = Column('id', INTEGER, primary_key=True,
                           server_default=text('uuid_generate_v4()'))
        Table('simple_items', self.metadata, pk_column)

        assert self.generate_code() == expected
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_updated_rse_counters(total_workers, worker_number, session=None):
    """
    Get updated rse_counters.

    When total_workers > 0, rows are partitioned between workers by hashing
    rse_id in the database, and only this worker's share is returned.

    :param total_workers:      Number of total workers.
    :param worker_number:      id of the executing worker.
    :param session:            Database session in use.
    :returns:                  List of rse_ids whose rse_counters need to be updated.
    """
    query = session.query(models.UpdatedRSECounter.rse_id).\
        distinct(models.UpdatedRSECounter.rse_id)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(expr, N) yields buckets 0..N, matching mod(..., N + 1) below.
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(rse_id, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Raw SQL fragments must be wrapped in text(); newer SQLAlchemy
            # versions reject plain strings passed to filter().
            query = query.filter(text('mod(md5(rse_id), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(rse_id))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    results = query.all()
    return [result.rse_id for result in results]
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_updated_account_counters(total_workers, worker_number, session=None):
    """
    Get updated account_counters.

    When total_workers > 0, rows are partitioned between workers by hashing
    concat(account, rse_id) in the database, and only this worker's share is
    returned.

    :param total_workers:      Number of total workers.
    :param worker_number:      id of the executing worker.
    :param session:            Database session in use.
    :returns:                  List of distinct (account, rse_id) rows whose
                               account counters need to be updated.
    """
    query = session.query(models.UpdatedAccountCounter.account, models.UpdatedAccountCounter.rse_id).\
        distinct(models.UpdatedAccountCounter.account, models.UpdatedAccountCounter.rse_id)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(expr, N) yields buckets 0..N, matching mod(..., N + 1) below.
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(CONCAT(account, rse_id), :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Raw SQL fragments must be wrapped in text(); newer SQLAlchemy
            # versions reject plain strings passed to filter().
            query = query.filter(text('mod(md5(concat(account, rse_id)), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(concat(account, rse_id)))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    return query.all()
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_db_time():
    """
    Gives the utc time on the db.

    Uses a dialect-specific expression: Oracle ``sys_extract_utc(systimestamp)``,
    MySQL ``utc_timestamp()``, SQLite ``datetime('now')`` (returned as text and
    parsed into a ``datetime`` here), and ``current_date`` for any other dialect.

    :returns: the value from the first result row, or None if no row is returned.
    """
    s = session.get_session()
    try:
        storage_date_format = None
        dialect = s.bind.dialect.name
        if dialect == 'oracle':
            query = select([text("sys_extract_utc(systimestamp)")])
        elif dialect == 'mysql':
            query = select([text("utc_timestamp()")])
        elif dialect == 'sqlite':
            # SQLite's 'now' is already UTC. The 'utc' modifier assumes its input
            # is local time, so 'now','utc' can shift the result by the host's
            # UTC offset (behaviour varies by SQLite version).
            query = select([text("datetime('now')")])
            storage_date_format = '%Y-%m-%d %H:%M:%S'
        else:
            # NOTE(review): current_date has no time-of-day component; confirm
            # whether a timestamp was intended for e.g. PostgreSQL.
            query = select([func.current_date()])

        for now, in s.execute(query):
            if storage_date_format:
                return datetime.strptime(now, storage_date_format)
            return now

    finally:
        s.remove()
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def update(self, id, user_id, ingredient_id, coefficient):
        """Overwrite the tblUserPreference row identified by ``id``.

        Commits the session afterwards and returns True when at least one row
        was affected (``rowcount > 0``), otherwise False.
        """
        bound_values = {
            'id': id,
            'user_id': user_id,
            'ingredient_id': ingredient_id,
            'coefficient': coefficient,
        }
        statement = text("""
        UPDATE tblUserPreference SET
            user_id=:user_id,
            ingredient_id=:ingredient_id,
            coefficient=:coefficient
        WHERE
            id=:id
        """)
        outcome = self.execute(statement, **bound_values)
        db.session.commit()
        return outcome.rowcount > 0
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def test_fetchone(self):
        """It gets a single row"""
        lookup = text("""
        SELECT id, name FROM tblExample
        WHERE id=:id;
        """)
        first_row = {'id': 1, 'name': 'Foo'}

        # A matching id returns exactly that row as a dict
        self.assertDictEqual(self.dao.fetchone(lookup, id=1), first_row)

        # Expect None to be returned if no row matches the id
        self.assertEqual(self.dao.fetchone(lookup, id=-1), None)

        # Only the first row comes back even when the query yields several
        self.assertDictEqual(
            self.dao.fetchone('SELECT id, name FROM tblExample ORDER BY id'),
            first_row)
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def test_fetchmany(self):
        """It gets many rows, capped at the DAO's MAX_RESULTS_SIZE"""
        query = text("""
        SELECT id, name FROM tblExample
        ORDER BY id
        """)

        # Expect multiple records to be returned
        result = self.dao.fetchmany(query)
        self.assertEqual(len(result), 3)
        self.assertDictEqual(result[0], {'id': 1, 'name': 'Foo'})
        self.assertDictEqual(result[1], {'id': 2, 'name': 'Bar'})
        self.assertDictEqual(result[2], {'id': 3, 'name': 'Baz'})

        # Expect rows to be limited by MAX_RESULTS_SIZE. The cap is lowered only
        # temporarily and restored afterwards, so it cannot leak into other tests
        # that share this DAO.
        original_max = self.dao.MAX_RESULTS_SIZE
        self.dao.MAX_RESULTS_SIZE = 2
        try:
            result = self.dao.fetchmany(query)
        finally:
            self.dao.MAX_RESULTS_SIZE = original_max
        self.assertEqual(len(result), 2)
项目:stream2segment    作者:rizac    | 项目源码 | 文件源码
def seed_identifier(cls):  # @NoSelf
        '''returns data_identifier if the latter is not None, else net.sta.loc.cha by querying the
        relative channel and station'''
        # The fallback is a correlated subquery; background:
        # http://docs.sqlalchemy.org/en/latest/orm/extensions/hybrid.html#correlated-subquery-relationship-hybrid
        # - limit(1): the query may match more than one row (no join or distinct
        #   is applied). Any row carrying the requested network, station,
        #   location and channel strings will do, so only the first is kept.
        # - label(...): turns the select into a labelled scalar subquery that can
        #   be used as a column expression; the docs are terse, see
        #   http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.label
        separator = text("'.'")
        full_code = concat(Station.network, separator, Station.station, separator,
                           Channel.location, separator, Channel.channel)
        belongs_to_row = (Channel.id == cls.channel_id) & (Station.id == Channel.station_id)
        fallback = select([full_code]).where(belongs_to_row).limit(1).label('seedidentifier')
        has_identifier = cls.data_identifier.isnot(None)
        return case([(has_identifier, cls.data_identifier)], else_=fallback)
项目:dota2-messenger-platform    作者:nico-arianto    | 项目源码 | 文件源码
def get_player(self, account_id=None, steam_id=None, real_name=None):
        """Look up players by the first truthy criterion given.

        account_id and steam_id each return the first matching Player (or None).
        real_name does a LIKE match against real_name or persona_name and returns
        up to LIMIT_DATA players. Raises ValueError when no criterion is given.
        """
        players = Database.session.query(Player)
        if account_id:
            return players.filter(Player.account_id == account_id).first()
        if steam_id:
            return players.filter(Player.steam_id == steam_id).first()
        if real_name:
            # recommended to be optimized by full-text search.
            # NOTE(review): '%' / '_' inside real_name are not escaped and act as wildcards.
            name_match = or_(text('real_name like :real_name'), text('persona_name like :real_name'))
            return players.filter(name_match).params(
                real_name="%" + real_name + "%").limit(LIMIT_DATA).all()
        raise ValueError('Account id or Steam id or real name must be specified!')
项目:dota2-messenger-platform    作者:nico-arianto    | 项目源码 | 文件源码
def get_match_summary_aggregate(self, match_id):
        """Per-account totals over MatchHero rows with match_id >= the given id.

        Returns a list of (account_id, player_win, matches) rows, where
        player_win is sum(match_heroes.player_win) and matches counts the rows.
        """
        # Raw SQL column reference; presumably match_heroes is MatchHero's table.
        wins = func.sum(text('match_heroes.player_win')).label('player_win')
        played = func.count(MatchHero.player_win).label('matches')
        query = Database.session.query(MatchHero.account_id, wins, played)
        query = query.filter(MatchHero.match_id >= match_id)
        query = query.group_by(MatchHero.account_id)
        return query.all()
项目:dota2-messenger-platform    作者:nico-arianto    | 项目源码 | 文件源码
def get_match_hero_summary_aggregate(self, match_id):
        """Per-(account, hero) totals over MatchHero rows with match_id >= the given id.

        Returns a list of (account_id, hero_id, player_win, matches) rows.
        """
        wins = func.sum(text('match_heroes.player_win')).label('player_win')
        played = func.count(MatchHero.player_win).label('matches')
        query = Database.session.query(MatchHero.account_id, MatchHero.hero_id, wins, played)
        query = query.filter(MatchHero.match_id >= match_id)
        # Successive group_by calls accumulate: GROUP BY account_id, hero_id
        query = query.group_by(MatchHero.account_id).group_by(MatchHero.hero_id)
        return query.all()
项目:dota2-messenger-platform    作者:nico-arianto    | 项目源码 | 文件源码
def get_match_item_summary_aggregate(self, match_id):
        """Per-(account, item) totals over MatchItem rows with match_id >= the given id.

        Returns a list of (account_id, item_id, player_win, matches) rows.
        """
        wins = func.sum(text('match_items.player_win')).label('player_win')
        played = func.count(MatchItem.player_win).label('matches')
        query = Database.session.query(MatchItem.account_id, MatchItem.item_id, wins, played)
        query = query.filter(MatchItem.match_id >= match_id)
        # Successive group_by calls accumulate: GROUP BY account_id, item_id
        query = query.group_by(MatchItem.account_id).group_by(MatchItem.item_id)
        return query.all()
项目:triage    作者:dssg    | 项目源码 | 文件源码
def __init__(self, aggregates, groups, from_obj, state_table,
                 state_group=None, prefix=None, suffix=None, schema=None):
        """
        Args:
            aggregates: collection of Aggregate objects.
            groups: a list of expressions to group by in the aggregation, or a
                dictionary of group: expr pairs where group is the alias (used in
                column names). A list is turned into {str(expr): expr}.
            from_obj: defines the from clause, e.g. the name of the table. It is
                normalized through make_sql_clause(from_obj, ex.text).
            state_table: schema.table to query for comprehensive set of state_group entities
                regardless of what exists in the from_obj
            state_group: the group level found in the state table (e.g., "entity_id");
                defaults to "entity_id"
            prefix: prefix for aggregation tables and column names, defaults to str(from_obj)
            suffix: suffix for aggregation table, defaults to "aggregation"
            schema: schema for aggregation tables

        The from_obj and group expressions are passed directly to the
            SQLAlchemy Select object so could be anything supported there.
            For details see:
            http://docs.sqlalchemy.org/en/latest/core/selectable.html

        Aggregates will have {collate_date} in their quantities substituted with the date
        of aggregation.
        """
        self.aggregates = aggregates
        self.from_obj = make_sql_clause(from_obj, ex.text)
        if isinstance(groups, dict):
            self.groups = groups
        else:
            self.groups = {str(g): g for g in groups}
        self.state_table = state_table
        # Falsy arguments fall back to their defaults.
        self.state_group = state_group or "entity_id"
        self.prefix = prefix or str(from_obj)
        self.suffix = suffix or "aggregation"
        self.schema = schema
项目:sqlacodegen    作者:agronholm    | 项目源码 | 文件源码
def remove_unicode_prefixes(text):
        """Rewrite matches of the module-level ``unicode_re`` as ``\\1\\2\\1``.

        Presumably this strips ``u`` prefixes from string literals in generated
        code; the exact pattern lives in ``unicode_re`` (defined elsewhere).
        """
        replacement = r"\1\2\1"
        return unicode_re.sub(replacement, text)
项目:sqlacodegen    作者:agronholm    | 项目源码 | 文件源码
def remove_unicode_prefixes(text):
        """No-op variant: hand *text* back exactly as received."""
        unchanged = text
        return unchanged
项目:sqlacodegen    作者:agronholm    | 项目源码 | 文件源码
def test_indexes_class(self):
        """Single-column indexes become column flags; multi-column ones go to __table_args__."""
        simple_items = Table(
            'simple_items', self.metadata,
            Column('id', INTEGER, primary_key=True),
            Column('number', INTEGER),
            Column('text', VARCHAR)
        )
        number_col = simple_items.c.number
        text_col = simple_items.c.text
        index_specs = [
            ('idx_number', (number_col,), {}),
            ('idx_text_number', (text_col, number_col), {}),
            ('idx_text', (text_col,), {'unique': True}),
        ]
        for index_name, index_cols, index_kwargs in index_specs:
            simple_items.indexes.add(Index(index_name, *index_cols, **index_kwargs))

        expected = """\
# coding: utf-8
from sqlalchemy import Column, Index, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
metadata = Base.metadata


class SimpleItem(Base):
    __tablename__ = 'simple_items'
    __table_args__ = (
        Index('idx_text_number', 'text', 'number'),
    )

    id = Column(Integer, primary_key=True)
    number = Column(Integer, index=True)
    text = Column(String, unique=True)
"""
        assert self.generate_code() == expected
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_updated_dids(total_workers, worker_number, limit=100, blacklisted_dids=None, session=None):
    """
    Get updated dids.

    :param total_workers:      Number of total workers.
    :param worker_number:      id of the executing worker.
    :param limit:              Maximum number of dids to fetch; a falsy value disables the limit.
    :param blacklisted_dids:   Blacklisted dids to filter, as (scope, name) tuples.
                               None (the default) means no blacklist.
    :param session:            Database session in use.
    :returns:                  List of UpdatedDID rows (id, scope, name, rule_evaluation_action)
                               not in blacklisted_dids, ordered by created_at.
    """
    # None instead of a shared mutable [] default.
    if blacklisted_dids is None:
        blacklisted_dids = []

    query = session.query(models.UpdatedDID.id,
                          models.UpdatedDID.scope,
                          models.UpdatedDID.name,
                          models.UpdatedDID.rule_evaluation_action)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(expr, N) yields buckets 0..N, matching mod(..., N + 1) below.
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    if limit:
        fetched_dids = query.order_by(models.UpdatedDID.created_at).limit(limit).all()
        filtered_dids = [did for did in fetched_dids if (did.scope, did.name) not in blacklisted_dids]
        if len(fetched_dids) == limit and len(filtered_dids) == 0:
            # A full page that is entirely blacklisted may hide further work:
            # retry once without the limit.
            return get_updated_dids(total_workers=total_workers,
                                    worker_number=worker_number,
                                    limit=None,
                                    blacklisted_dids=blacklisted_dids,
                                    session=session)
        else:
            return filtered_dids
    else:
        return [did for did in query.order_by(models.UpdatedDID.created_at).all() if (did.scope, did.name) not in blacklisted_dids]
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_rules_beyond_eol(date_check, worker_number, total_workers, session):
    """
    Get rules which have eol_at before a certain date.

    Rules are partitioned between workers by hashing the rule name in the
    database; only this worker's share is returned.

    :param date_check:         The reference date that should be compared to eol_at.
    :param worker_number:      id of the executing worker.
    :param total_workers:      Number of total workers.
    :param session:            Database session in use.
    :returns:                  List of rows (scope, name, rse_expression, locked, id,
                               eol_at, expires_at).
    """
    query = session.query(models.ReplicationRule.scope,
                          models.ReplicationRule.name,
                          models.ReplicationRule.rse_expression,
                          models.ReplicationRule.locked,
                          models.ReplicationRule.id,
                          models.ReplicationRule.eol_at,
                          models.ReplicationRule.expires_at).\
        filter(models.ReplicationRule.eol_at < date_check)

    dialect = session.bind.dialect.name
    if dialect == 'oracle':
        bindparams = [bindparam('worker_number', worker_number),
                      bindparam('total_workers', total_workers)]
        query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
    elif dialect == 'mysql':
        query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
    elif dialect == 'postgresql':
        query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))
    # Query.all() already returns a fresh list; no need to copy it again.
    return query.all()
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_expired_rules(total_workers, worker_number, limit=100, blacklisted_rules=None, session=None):
    """
    Get expired rules.

    Selects unlocked rules without a child rule whose expires_at lies in the
    past, ordered by expires_at and partitioned between workers by hashing the
    rule name.

    :param total_workers:      Number of total workers.
    :param worker_number:      id of the executing worker.
    :param limit:              Maximum number of rules to fetch; a falsy value disables the limit.
    :param blacklisted_rules:  Rule ids to skip. None (the default) means no blacklist.
    :param session:            Database session in use.
    :returns:                  List of (id, rse_expression) rows not in blacklisted_rules.
    """
    # None instead of a shared mutable [] default.
    if blacklisted_rules is None:
        blacklisted_rules = []

    query = session.query(models.ReplicationRule.id, models.ReplicationRule.rse_expression).filter(models.ReplicationRule.expires_at < datetime.utcnow(),
                                                                                                   models.ReplicationRule.locked == False,
                                                                                                   models.ReplicationRule.child_rule_id == None).\
        with_hint(models.ReplicationRule, "index(rules RULES_EXPIRES_AT_IDX)", 'oracle').\
        order_by(models.ReplicationRule.expires_at)  # NOQA

    dialect = session.bind.dialect.name
    if dialect == 'oracle':
        bindparams = [bindparam('worker_number', worker_number),
                      bindparam('total_workers', total_workers)]
        query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
    elif dialect == 'mysql':
        query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
    elif dialect == 'postgresql':
        query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    if limit:
        fetched_rules = query.limit(limit).all()
        filtered_rules = [rule for rule in fetched_rules if rule[0] not in blacklisted_rules]
        if len(fetched_rules) == limit and len(filtered_rules) == 0:
            # A full page that is entirely blacklisted may hide further work:
            # retry once without the limit.
            return get_expired_rules(total_workers=total_workers,
                                     worker_number=worker_number,
                                     limit=None,
                                     blacklisted_rules=blacklisted_rules,
                                     session=session)
        else:
            return filtered_rules
    else:
        return [rule for rule in query.all() if rule[0] not in blacklisted_rules]
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def list_bad_replicas_history(limit=10000, thread=None, total_threads=None, session=None):
    """
    List the bad file replicas history. Method only used by necromancer

    :param limit: The maximum number of replicas returned.
    :param thread: The assigned thread for this necromancer.
    :param total_threads: The total number of threads of all necromancers.
    :param session: The database session in use.
    :returns: dict mapping each rse_id to a list of {'scope': ..., 'name': ...} dicts.
    """
    query = session.query(models.BadReplicas.scope, models.BadReplicas.name, models.BadReplicas.rse_id).\
        filter(models.BadReplicas.state == BadFilesStatus.BAD)

    # Partition by name hash only when there is more than one bucket.
    buckets = (total_threads - 1) if total_threads else 0
    if buckets > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            bindparams = [bindparam('thread_number', thread), bindparam('total_threads', buckets)]
            query = query.filter(text('ORA_HASH(name, :total_threads) = :thread_number', bindparams=bindparams))
        elif dialect == 'mysql':
            query = query.filter(text('mod(md5(name), %s) = %s' % (buckets, thread)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (buckets, thread)))

    grouped = {}
    for scope, name, rse_id in query.limit(limit).yield_per(1000):
        grouped.setdefault(rse_id, []).append({'scope': scope, 'name': name})
    return grouped
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def get_replica_atime(replica, session=None):
    """
    Get the accessed_at timestamp for a single replica. Just for testing.

    :param replica: dictionary with 'scope' and 'name' plus either 'rse_id' or 'rse'.
                    When 'rse_id' is missing it is resolved from 'rse' and written
                    back into this dictionary.
    :param session: Database session to use.

    :returns: The accessed_at value of the matching replica row.
    """
    if 'rse_id' not in replica:
        replica['rse_id'] = get_rse_id(rse=replica['rse'], session=session)

    lookup = session.query(models.RSEFileAssociation.accessed_at).\
        filter_by(scope=replica['scope'], name=replica['name'], rse_id=replica['rse_id']).\
        with_hint(models.RSEFileAssociation, text="INDEX(REPLICAS REPLICAS_PK)", dialect_name='oracle')
    accessed_at, = lookup.one()
    return accessed_at
项目:rucio    作者:rucio01    | 项目源码 | 文件源码
def list_expired_temporary_dids(rse, limit, worker_number=None, total_workers=None,
                                session=None):
    """
    List expired temporary DIDs.

    :param rse: the rse name.
    :param limit: The maximum number of replicas returned.
    :param worker_number:      id of the executing worker (1-based; 1 is subtracted
                               before hashing).
    :param total_workers:      Number of total workers.
    :param session: The database session in use.

    :returns: a list of dictionary replica.
    """
    rse_id = get_rse_id(rse, session=session)
    # Comparing against a variable (not a None literal) keeps linters quiet;
    # SQLAlchemy still renders `!= None` as IS NOT NULL.
    is_none = None
    # The CASE yields NULL when expired_at is NULL, so only rows with an expiry
    # on this RSE match (presumably written this way to use TMP_DIDS_EXPIRED_AT_IDX).
    query = session.query(models.TemporaryDataIdentifier.scope,
                          models.TemporaryDataIdentifier.name,
                          models.TemporaryDataIdentifier.path,
                          models.TemporaryDataIdentifier.bytes).\
        with_hint(models.TemporaryDataIdentifier, "INDEX(tmp_dids TMP_DIDS_EXPIRED_AT_IDX)", 'oracle').\
        filter(case([(models.TemporaryDataIdentifier.expired_at != is_none, models.TemporaryDataIdentifier.rse_id), ]) == rse_id)

    if worker_number and total_workers and total_workers - 1 > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            bindparams = [bindparam('worker_number', worker_number - 1), bindparam('total_workers', total_workers - 1)]
            query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers - 1, worker_number - 1)))
        elif dialect == 'postgresql':
            # Hash `name` like the Oracle and MySQL branches (previously `path`),
            # so workers partition rows identically on every backend.
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers - 1, worker_number - 1)))

    return [{'path': path,
             'rse': rse,
             'rse_id': rse_id,
             'scope': scope,
             'name': name,
             'bytes': size}
            for scope, name, path, size in query.limit(limit)]
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def get_by_id(self, id):
        """Return the (id, name) row of tblBreakfast with this id, via self.fetchone."""
        return self.fetchone(text("""
        SELECT id, name FROM tblBreakfast
        WHERE id=:id;
        """), id=id)
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def list_all(self):
        """Return every (id, name) row of tblBreakfast, via self.fetchall."""
        select_all = text("""
        SELECT id, name FROM tblBreakfast;
        """)
        rows = self.fetchall(select_all)
        return rows
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def get_by_id(self, id):
        """Return the (id, name) row of tblIngredient with this id, via self.fetchone."""
        by_id = text("""
        SELECT id, name FROM tblIngredient
        WHERE id=:id;
        """)
        row = self.fetchone(by_id, id=id)
        return row
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def list_all(self):
        """Return every (id, name) row of tblIngredient, via self.fetchall."""
        return self.fetchall(text("""
        SELECT id, name FROM tblIngredient;
        """))
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def get_by_id(self, id):
        """Return the (id, first_name, last_name) row of tblUser with this id."""
        user_by_id = text("""
        SELECT id, first_name, last_name FROM tblUser
        WHERE id=:id;
        """)
        found = self.fetchone(user_by_id, id=id)
        return found
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def create(self, first_name, last_name):
        """Insert a tblUser row, commit, and return the new row's lastrowid."""
        insert_stmt = text("""
        INSERT INTO tblUser
            (first_name, last_name)
        VALUES
            (:first_name, :last_name);
        """)
        values = {'first_name': first_name, 'last_name': last_name}
        outcome = self.execute(insert_stmt, **values)
        db.session.commit()
        return outcome.lastrowid
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def update(self, id, first_name, last_name):
        """Overwrite the names of the tblUser row with this id and commit.

        Returns True when at least one row was affected, otherwise False.
        """
        statement = text("""
        UPDATE tblUser SET
            first_name=:first_name,
            last_name=:last_name
        WHERE
            id=:id
        """)
        outcome = self.execute(statement, id=id, first_name=first_name, last_name=last_name)
        db.session.commit()
        changed = outcome.rowcount > 0
        return changed
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def delete(self, id):
        """Delete the tblUser row with this id and commit.

        Returns True if a row was deleted, False if no row had that id.
        """
        query = text("""
        DELETE FROM tblUser WHERE id=:id
        """)
        result = self.execute(query, id=id)
        # Commit as create()/update() in this DAO do; otherwise the DELETE stays
        # in the open session transaction and may be rolled back.
        db.session.commit()
        return result.rowcount > 0
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def list_all(self):
        """Return every (id, first_name, last_name) row of tblUser, via self.fetchall."""
        every_user = text("""
        SELECT id, first_name, last_name FROM tblUser;
        """)
        return self.fetchall(every_user)
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def get_by_id(self, id):
        """Return the tblUserPreference row with this id, via self.fetchone."""
        return self.fetchone(text("""
        SELECT id, user_id, ingredient_id, coefficient FROM tblUserPreference
        WHERE id=:id;
        """), id=id)
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def create(self, user_id, ingredient_id, coefficient):
        """Insert a tblUserPreference row, commit, and return its lastrowid."""
        values = dict(user_id=user_id, ingredient_id=ingredient_id, coefficient=coefficient)
        insert_stmt = text("""
        INSERT INTO tblUserPreference
            (user_id, ingredient_id, coefficient)
        VALUES
            (:user_id, :ingredient_id, :coefficient);
        """)
        outcome = self.execute(insert_stmt, **values)
        db.session.commit()
        return outcome.lastrowid
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def delete(self, id):
        """Delete the tblUserPreference row with this id and commit.

        Returns True if a row was deleted, False if no row had that id.
        """
        query = text("""
        DELETE FROM tblUserPreference WHERE id=:id
        """)
        result = self.execute(query, id=id)
        # Commit as create()/update() in this DAO do; otherwise the DELETE stays
        # in the open session transaction and may be rolled back.
        db.session.commit()
        return result.rowcount > 0
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def list_all_for_user(self, user_id):
        """Return the tblUserPreference rows belonging to ``user_id``, via self.fetchall."""
        for_user = text("""
        SELECT id, user_id, ingredient_id, coefficient FROM tblUserPreference
        WHERE user_id=:user_id;
        """)
        preferences = self.fetchall(for_user, user_id=user_id)
        return preferences
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def list_all(self):
        """Return every (breakfast_id, ingredient_id, coefficient) row of tblBreakfastIngredient."""
        return self.fetchall(text("""
        SELECT breakfast_id, ingredient_id, coefficient FROM tblBreakfastIngredient;
        """))
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def get_all_breakfast_ingredients(self):
        """Return every (breakfast_id, ingredient_id, coefficient) row of tblBreakfastIngredient."""
        all_links = text("""
        SELECT breakfast_id, ingredient_id, coefficient FROM tblBreakfastIngredient;
        """)
        rows = self.fetchall(all_links)
        return rows
项目:eggsnspam    作者:wayfair    | 项目源码 | 文件源码
def test_fetchall(self):
        """It gets all rows"""
        ordered = text("""
        SELECT id, name FROM tblExample
        ORDER BY id
        """)
        expected_rows = [
            {'id': 1, 'name': 'Foo'},
            {'id': 2, 'name': 'Bar'},
            {'id': 3, 'name': 'Baz'},
        ]

        # Expect every record, in id order
        rows = self.dao.fetchall(ordered)
        self.assertEqual(len(rows), 3)
        for position, expected in enumerate(expected_rows):
            self.assertDictEqual(rows[position], expected)
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def has_table(self, connection, table_name, schema=None):
        """Return True if dbc.tablesvx lists ``table_name`` in ``schema``.

        ``schema`` defaults to ``self.default_schema_name``.
        """
        target_schema = self.default_schema_name if schema is None else schema

        criteria = and_(text('DatabaseName=:schema'), text('TableName=:table_name'))
        stmt = select([column('tablename')], from_obj=[text('dbc.tablesvx')]).where(criteria)

        row = connection.execute(stmt, schema=target_schema, table_name=table_name).fetchone()
        return row is not None
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_columns(self, connection, table_name, schema=None, **kw):
        """
        Reflect the columns of ``table_name``.

        On servers older than major version 16 the column list comes from
        dbc.ColumnsV, and views get their per-column details from
        ``self._get_column_help``. Newer servers use dbc.ColumnsQV.

        Args:
            connection: connection used to run the dictionary queries.
            table_name: name of the table or view to inspect.
            schema: database name; defaults to ``self.default_schema_name``
                when ``None``.
            **kw: accepted for interface compatibility; unused here.

        Returns:
            A list containing ``self._get_column_info(row)`` for each column.
        """
        helpView = False

        if schema is None:
            schema = self.default_schema_name

        if int(self.server_version_info.split('.')[0]) < 16:
            dbc_columninfo = 'dbc.ColumnsV'

            # Check if the object is a view. Test for the presence of a row
            # instead of reading ``rowcount``: DB-API drivers need not report
            # a row count for SELECT statements (it may be -1), which would
            # silently skip the view handling below.
            stmt = select([column('tablekind')],
                          from_obj=[text('dbc.tablesV')]).where(
                and_(text('DatabaseName=:schema'),
                     text('TableName=:table_name'),
                     text("tablekind='V'")))
            view_row = connection.execute(stmt, schema=schema,
                                          table_name=table_name).first()
            helpView = view_row is not None
        else:
            # Server major version 16 or later.
            dbc_columninfo = 'dbc.ColumnsQV'

        stmt = select([column('columnname'), column('columntype'),
                       column('columnlength'), column('chartype'),
                       column('decimaltotaldigits'),
                       column('decimalfractionaldigits'),
                       column('columnformat'),
                       column('nullable'), column('defaultvalue'),
                       column('idcoltype')],
                      from_obj=[text(dbc_columninfo)]).where(
            and_(text('DatabaseName=:schema'),
                 text('TableName=:table_name')))

        res = connection.execute(stmt, schema=schema,
                                 table_name=table_name).fetchall()

        # If this is a view in a pre-16 version, get the types of the
        # individual columns through self._get_column_help.
        if helpView:
            res = [self._get_column_help(connection, schema, table_name,
                                         r['columnname'])
                   for r in res]

        return [self._get_column_info(row) for row in res]
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_table_names(self, connection, schema=None, **kw):
        """List the normalized names of tables in ``schema``.

        Rows of dbc.TablesVX whose tablekind is 'T' or 'O' are returned.

        Args:
            connection: connection used to run the dictionary query.
            schema: database name; defaults to ``self.default_schema_name``
                when ``None``.
            **kw: accepted for interface compatibility; unused here.
        """
        if schema is None:
            schema = self.default_schema_name

        kind_filter = or_(text("tablekind='T'"), text("tablekind='O'"))
        listing = select([column('tablename')], from_obj=[text('dbc.TablesVX')])
        listing = listing.where(and_(text('DatabaseName = :schema'), kind_filter))

        rows = connection.execute(listing, schema=schema).fetchall()
        return [self.normalize_name(row['tablename']) for row in rows]
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_schema_names(self, connection, **kw):
        """Return the normalized ``username`` values of dbc.UsersV, sorted by name.

        Args:
            connection: connection used to run the dictionary query.
            **kw: accepted for interface compatibility; unused here.
        """
        users_query = select([column('username')],
                             from_obj=[text('dbc.UsersV')],
                             order_by=[text('username')])
        user_rows = connection.execute(users_query).fetchall()
        names = []
        for user_row in user_rows:
            names.append(self.normalize_name(user_row['username']))
        return names
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_view_names(self, connection, schema=None, **kw):
        """List the normalized names of views (tablekind 'V') in ``schema``.

        Args:
            connection: connection used to run the dictionary query.
            schema: database name; defaults to ``self.default_schema_name``
                when ``None``.
            **kw: accepted for interface compatibility; unused here.
        """
        owner = schema if schema is not None else self.default_schema_name

        view_query = select([column('tablename')],
                            from_obj=[text('dbc.TablesVX')]).where(
            and_(text('DatabaseName = :schema'), text("tablekind='V'")))

        view_rows = connection.execute(view_query, schema=owner).fetchall()
        return [self.normalize_name(view_row['tablename']) for view_row in view_rows]
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
        """
        Overrides base class method

        Reflect the unique constraints (dbc.Indices rows with IndexType 'U')
        of ``table_name``.

        Args:
            connection: connection used to run the dictionary query.
            table_name: table whose constraints are listed.
            schema: database name; defaults to ``self.default_schema_name``
                when ``None``.
            **kw: accepted for interface compatibility; unused here.

        Returns:
            A list of dicts, one per constraint, with keys ``'name'``
            (normalized IndexName, possibly None for an unnamed constraint)
            and ``'column_names'`` (normalized column names).
        """
        if schema is None:
            schema = self.default_schema_name

        # IndexNumber is selected and ordered on as well as IndexName:
        # unnamed constraints all have a NULL IndexName, and grouping by name
        # alone would merge their columns into a single constraint.
        stmt = select([column('ColumnName'), column('IndexName'),
                       column('IndexNumber')],
                      from_obj=[text('dbc.Indices')]) \
            .where(and_(text('DatabaseName = :schema'),
                        text('TableName=:table'),
                        text('IndexType=:indextype'))) \
            .order_by(asc(column('IndexName')), asc(column('IndexNumber')))

        # U for Unique
        res = connection.execute(stmt, schema=schema, table=table_name, indextype='U').fetchall()

        def grouper(row):
            # Same key as the ORDER BY, so each index's rows form one run.
            return (row['IndexName'], row['IndexNumber'])

        unique_constraints = list()
        for (index_name, _index_number), constraint_cols in groupby(res, grouper):
            unique_constraints.append({
                'name': self.normalize_name(index_name),
                'column_names': [self.normalize_name(col['ColumnName'])
                                 for col in constraint_cols],
            })

        return unique_constraints
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_indexes(self, connection, table_name, schema=None, **kw):
        """
        Overrides base class method

        Reflect the indexes of ``table_name`` from dbc.Indices.

        Args:
            connection: connection used to run the dictionary query.
            table_name: table whose indexes are listed.
            schema: database name; defaults to ``self.default_schema_name``
                when ``None``.
            **kw: accepted for interface compatibility; unused here.

        Returns:
            A list of dicts with keys ``'name'``, ``'column_names'`` and
            ``'unique'``. ``'name'`` is the normalized IndexName, or the
            IndexNumber when the index has no name. ``'unique'`` is True when
            UniqueFlag is 'Y'.
        """

        if schema is None:
            schema = self.default_schema_name

        # Order on IndexNumber as well as IndexName: every unnamed index has a
        # NULL IndexName, so ordering by name alone can interleave their rows
        # and make groupby() below split one index into several entries.
        stmt = select(["*"], from_obj=[text('dbc.Indices')]) \
            .where(and_(text('DatabaseName = :schema'),
                        text('TableName=:table'))) \
            .order_by(asc(column('IndexName')), asc(column('IndexNumber')))

        res = connection.execute(stmt, schema=schema, table=table_name).fetchall()

        def grouper(row):
            # Same leading key as the ORDER BY, so each index is one run.
            return (row.IndexName, row.IndexNumber, row.UniqueFlag == 'Y')

        indices = list()
        for (index_name, index_number, unique), index_cols in groupby(res, grouper):
            indices.append({
                # Unnamed index: fall back to IndexNumber.
                # TODO: Check what to do (carried over from the original).
                'name': self.normalize_name(index_name) if index_name else index_number,
                'column_names': [self.normalize_name(col['ColumnName'])
                                 for col in index_cols],
                'unique': unique,
            })

        return indices
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def get_transaction_mode(self, connection, **kw):
        """
        Return the transaction_mode value of dbc.sessioninfov for the
        current session (sessionno=SESSION).
        T = TDBS
        A = ANSI
        """
        mode_query = (select([text('transaction_mode')],
                             from_obj=[text('dbc.sessioninfov')])
                      .where(text('sessionno=SESSION')))
        return connection.execute(mode_query).scalar()
项目:sqlalchemy-teradata    作者:Teradata    | 项目源码 | 文件源码
def _get_server_version_info(self, connection, **kw):
        """
        Return the InfoData value stored under InfoKey 'VERSION' in
        dbc.dbcinfov, i.e. the Teradata Database software version.
        """
        version_query = select([text('InfoData')], from_obj=[text('dbc.dbcinfov')])
        version_query = version_query.where(text("InfoKey='VERSION'"))
        version = connection.execute(version_query).scalar()
        return version