def where(self, date, intervals):
    """
    Generates a WHERE clause

    Args:
        date: the end date; rows must have date_column strictly before it
        intervals: interval literals (rendered as Postgres `interval '...'`);
            if it contains 'all', no interval-based lower bound is added

    Returns: a clause for filtering the from_obj to be between the date and
        the greatest interval (and not before self.input_min_date, if set)
    """
    # upper bound
    clauses = ["{date_column} < '{date}'".format(
        date_column=self.date_column, date=date)]

    # lower bound (if possible)
    if 'all' not in intervals:
        greatest = "greatest(%s)" % ",".join(
            ["interval '%s'" % i for i in intervals])
        min_date = "'{date}'::date - {greatest}".format(date=date, greatest=greatest)
        clauses.append("{date_column} >= {min_date}".format(
            date_column=self.date_column, min_date=min_date))

    # optional absolute lower bound on the input data
    if self.input_min_date is not None:
        clauses.append("{date_column} >= '{bot}'::date".format(
            date_column=self.date_column, bot=self.input_min_date))

    # Join with explicit surrounding spaces; the old `w += "AND ..."` glued
    # each clause directly onto the previous one with no whitespace.
    return ex.text(" AND ".join(clauses))
def get_ping_aggregates(self, template):
    """Return this instance's pings averaged per time bucket.

    Args:
        template: bucket template; only '30m' is supported (30-minute
            buckets covering the last 7 days).

    Returns:
        list of ``PingAgg`` namedtuples
        ``(time, response_time, users, statuses, connections)``, ordered by
        bucket time. Each value is computed in SQL as ``sum / count``.

    Raises:
        ValueError: if ``template`` is not supported.
    """
    now = datetime.utcnow()
    if template == '30m':
        since = now - timedelta(days=7)
        # SQLite-specific: floor each timestamp to a 1800 s (30 min) boundary.
        agg_date = sqle.text("datetime((strftime('%s', time) / 1800) * 1800, 'unixepoch')")
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError('unsupported aggregation template: %r' % (template,))

    cnt = sqlf.count(Ping.id)
    # NOTE(review): sum / count may perform integer division for integer
    # columns on SQLite; confirm whether truncated averages are intended.
    q = sql.select([
        agg_date,
        sqlf.sum(Ping.response_time) / cnt,
        sqlf.sum(Ping.users) / cnt,
        sqlf.sum(Ping.statuses) / cnt,
        sqlf.sum(Ping.connections) / cnt,
    ])
    q = q.where((Ping.time >= since) & (Ping.instance_id == self.id))
    q = q.group_by(agg_date)
    # Order by the bucket itself: Ping.id is neither grouped nor aggregated,
    # so its per-group value (and hence the previous ordering) was arbitrary.
    q = q.order_by(agg_date)

    Nt = namedtuple('PingAgg', ['time', 'response_time', 'users', 'statuses', 'connections'])
    r = db.session.execute(q)
    return [Nt(*t) for t in r]
def get_uptime_aggregates(self, template):
    """Return this instance's pings grouped per time bucket and state.

    Args:
        template: bucket template; only '30m' is supported (30-minute
            buckets covering the last 7 days).

    Returns:
        list of ``PingUptimeAgg`` namedtuples ``(time, state, response_time)``,
        ordered by bucket time then state. ``response_time`` is computed in
        SQL as ``sum / count``.

    Raises:
        ValueError: if ``template`` is not supported.
    """
    now = datetime.utcnow()
    if template == '30m':
        since = now - timedelta(days=7)
        # SQLite-specific: floor each timestamp to a 1800 s (30 min) boundary.
        agg_date = sqle.text("datetime((strftime('%s', time) / 1800) * 1800, 'unixepoch')")
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError('unsupported aggregation template: %r' % (template,))

    cnt = sqlf.count(Ping.id)
    q = sql.select([
        agg_date,
        Ping.state,
        sqlf.sum(Ping.response_time) / cnt,
    ])
    q = q.where((Ping.time >= since) & (Ping.instance_id == self.id))
    q = q.group_by(agg_date, Ping.state)
    # Order by the grouping keys: Ping.id is neither grouped nor aggregated,
    # so its per-group value (and hence the previous ordering) was arbitrary.
    q = q.order_by(agg_date, Ping.state)

    Nt = namedtuple('PingUptimeAgg', ['time', 'state', 'response_time'])
    r = db.session.execute(q)
    return [Nt(*t) for t in r]
def test_pk_default(self):
    # An INTEGER primary-key column with a server_default must keep that
    # default, rendered as server_default=text(...), in the generated class.
    Table(
        'simple_items', self.metadata,
        Column('id', INTEGER, primary_key=True, server_default=text('uuid_generate_v4()'))
    )

    assert self.generate_code() == """\
# coding: utf-8
from sqlalchemy import Column, Integer, text
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
metadata = Base.metadata


class SimpleItem(Base):
    __tablename__ = 'simple_items'

    id = Column(Integer, primary_key=True, server_default=text("uuid_generate_v4()"))
"""
def get_updated_rse_counters(total_workers, worker_number, session=None):
    """
    Get updated rse_counters.

    When total_workers > 0 the rows are partitioned between workers by
    hashing rse_id, so each worker only sees its own share.

    :param total_workers:      Number of total workers; 0 (or less) disables partitioning.
    :param worker_number:      id of the executing worker (bucket index 0..total_workers).
    :param session:            Database session in use.
    :returns:                  List of rse_ids whose rse_counters need to be updated.
    """
    query = session.query(models.UpdatedRSECounter.rse_id).\
        distinct(models.UpdatedRSECounter.rse_id)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(expr, N) yields buckets 0..N, i.e. N + 1 buckets like the mod below.
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(rse_id, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Textual SQL must be declared with text(); newer SQLAlchemy
            # rejects plain strings passed to Query.filter().
            query = query.filter(text('mod(md5(rse_id), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(rse_id))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    return [result.rse_id for result in query.all()]
def get_updated_account_counters(total_workers, worker_number, session=None):
    """
    Get updated account counters.

    When total_workers > 0 the rows are partitioned between workers by
    hashing concat(account, rse_id), so each worker only sees its own share.

    :param total_workers:      Number of total workers; 0 (or less) disables partitioning.
    :param worker_number:      id of the executing worker (bucket index 0..total_workers).
    :param session:            Database session in use.
    :returns:                  List of distinct (account, rse_id) rows whose account counters need to be updated.
    """
    query = session.query(models.UpdatedAccountCounter.account, models.UpdatedAccountCounter.rse_id).\
        distinct(models.UpdatedAccountCounter.account, models.UpdatedAccountCounter.rse_id)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(expr, N) yields buckets 0..N, i.e. N + 1 buckets like the mod below.
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(CONCAT(account, rse_id), :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Textual SQL must be declared with text(); newer SQLAlchemy
            # rejects plain strings passed to Query.filter().
            query = query.filter(text('mod(md5(concat(account, rse_id)), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(concat(account, rse_id)))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    return query.all()
def get_db_time():
    """
    Gives the utc time on the db.

    :returns: the database's current UTC timestamp (a datetime), or None if
              the query returned no row.
    """
    s = session.get_session()
    try:
        storage_date_format = None
        dialect = s.bind.dialect.name
        if dialect == 'oracle':
            query = select([text("sys_extract_utc(systimestamp)")])
        elif dialect == 'mysql':
            query = select([text("utc_timestamp()")])
        elif dialect == 'sqlite':
            # SQLite's 'now' is already UTC; the former 'utc' modifier treated
            # it as local time and could shift it a second time.
            query = select([text("datetime('now')")])
            # SQLite returns a string, which is parsed below.
            storage_date_format = '%Y-%m-%d %H:%M:%S'
        elif dialect == 'postgresql':
            # Naive timestamp expressed in UTC, like MySQL's utc_timestamp().
            query = select([text("timezone('UTC', now())")])
        else:
            # Previously current_date, which returns a date, not a time.
            # NOTE(review): current_timestamp is in the DB session time zone,
            # which may not be UTC on unknown backends.
            query = select([func.current_timestamp()])
        for now, in s.execute(query):
            if storage_date_format:
                return datetime.strptime(now, storage_date_format)
            return now
    finally:
        s.remove()
def update(self, id, user_id, ingredient_id, coefficient):
    """Overwrite the tblUserPreference row identified by ``id``.

    The session is committed after the UPDATE.

    Returns:
        True if the UPDATE reported a positive rowcount, False otherwise.
    """
    statement = text("""
        UPDATE tblUserPreference
        SET user_id=:user_id, ingredient_id=:ingredient_id, coefficient=:coefficient
        WHERE id=:id
    """)
    params = dict(id=id, user_id=user_id, ingredient_id=ingredient_id,
                  coefficient=coefficient)
    outcome = self.execute(statement, **params)
    db.session.commit()
    return outcome.rowcount > 0
def test_fetchone(self):
    """It gets a single row"""
    query = text("""
        SELECT id, name
        FROM tblExample
        WHERE id=:id;
    """)

    # Expect one record to be returned
    result = self.dao.fetchone(query, id=1)
    self.assertDictEqual(result, {'id': 1, 'name': 'Foo'})

    # Expect None to be returned if no row matches the query
    result = self.dao.fetchone(query, id=-1)
    self.assertIsNone(result)

    # Expect only the first record to be returned even if multiple rows match
    # (this call also passes a plain SQL string instead of a text() clause)
    result = self.dao.fetchone('SELECT id, name FROM tblExample ORDER BY id')
    self.assertDictEqual(result, {'id': 1, 'name': 'Foo'})
def test_fetchmany(self):
    """It gets many rows"""
    query = text("""
        SELECT id, name
        FROM tblExample
        ORDER BY id
    """)
    expected_rows = [
        {'id': 1, 'name': 'Foo'},
        {'id': 2, 'name': 'Bar'},
        {'id': 3, 'name': 'Baz'},
    ]

    # Expect multiple records to be returned, in id order
    rows = self.dao.fetchmany(query)
    self.assertEqual(len(rows), 3)
    for actual, expected in zip(rows, expected_rows):
        self.assertDictEqual(actual, expected)

    # Expect rows to be limited by MAX_RESULTS_SIZE
    self.dao.MAX_RESULTS_SIZE = 2
    limited_rows = self.dao.fetchmany(query)
    self.assertEqual(len(limited_rows), 2)
def seed_identifier(cls):  # @NoSelf
    '''SQL-side expression: data_identifier when it is not NULL, otherwise the
    "network.station.location.channel" string built from the related Station and
    Channel rows through a correlated scalar subquery.'''
    # Pattern reference (correlated subquery inside a hybrid expression):
    # http://docs.sqlalchemy.org/en/latest/orm/extensions/hybrid.html#correlated-subquery-relationship-hybrid
    # - limit(1): the subquery may produce more than one row (no join/distinct is
    #   applied); any row carrying the requested network, station, location and
    #   channel strings is good enough, so the first one is taken.
    # - label(...): turns the select into a labelled scalar subquery, which is what
    #   case() needs as its else_ value; see
    #   http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.label
    separator = text("'.'")
    pieces = (Station.network, separator, Station.station, separator,
              Channel.location, separator, Channel.channel)
    belongs_to_row = (Channel.id == cls.channel_id) & (Station.id == Channel.station_id)
    built_id = select([concat(*pieces)]).where(belongs_to_row).limit(1).label('seedidentifier')
    stored_id_branch = [(cls.data_identifier.isnot(None), cls.data_identifier)]
    return case(stored_id_branch, else_=built_id)
def get_player(self, account_id=None, steam_id=None, real_name=None):
    """Look up players by account id, Steam id, or (partial) name.

    The first truthy argument in the order account_id, steam_id, real_name
    decides the lookup. account_id / steam_id return a single Player (or None);
    real_name returns up to LIMIT_DATA players whose real_name or persona_name
    contains it (SQL LIKE '%...%').

    Raises:
        ValueError: if none of the arguments is truthy.
    """
    players = Database.session.query(Player)
    if account_id:
        return players.filter(Player.account_id == account_id).first()
    if steam_id:
        return players.filter(Player.steam_id == steam_id).first()
    if not real_name:
        raise ValueError('Account id or Steam id or real name must be specified!')

    # recommended to be optimized by full-text search.
    name_matches = or_(text('real_name like :real_name'),
                       text('persona_name like :real_name'))
    pattern = "%" + real_name + "%"
    return players.filter(name_matches).params(real_name=pattern).limit(LIMIT_DATA).all()
def get_match_summary_aggregate(self, match_id):
    """Per-account win/match totals over match_heroes rows with match_id >= ``match_id``.

    Returns a list of (account_id, player_win, matches) rows where player_win
    is SUM(match_heroes.player_win) and matches is COUNT(player_win).
    """
    # NOTE(review): the raw text() column reference is presumably deliberate
    # (e.g. so the mapped column type is not applied to the SUM); keep it.
    wins = func.sum(text('match_heroes.player_win')).label('player_win')
    played = func.count(MatchHero.player_win).label('matches')

    summary = Database.session.query(MatchHero.account_id, wins, played)
    summary = summary.filter(MatchHero.match_id >= match_id)
    summary = summary.group_by(MatchHero.account_id)
    return summary.all()
def get_match_hero_summary_aggregate(self, match_id):
    """Per-(account, hero) win/match totals over match_heroes rows with match_id >= ``match_id``.

    Returns a list of (account_id, hero_id, player_win, matches) rows where
    player_win is SUM(match_heroes.player_win) and matches is COUNT(player_win).
    """
    # NOTE(review): the raw text() column reference is presumably deliberate; keep it.
    wins = func.sum(text('match_heroes.player_win')).label('player_win')
    played = func.count(MatchHero.player_win).label('matches')

    summary = Database.session.query(MatchHero.account_id, MatchHero.hero_id, wins, played)
    summary = summary.filter(MatchHero.match_id >= match_id)
    summary = summary.group_by(MatchHero.account_id, MatchHero.hero_id)
    return summary.all()
def get_match_item_summary_aggregate(self, match_id):
    """Per-(account, item) win/match totals over match_items rows with match_id >= ``match_id``.

    Returns a list of (account_id, item_id, player_win, matches) rows where
    player_win is SUM(match_items.player_win) and matches is COUNT(player_win).
    """
    # NOTE(review): the raw text() column reference is presumably deliberate; keep it.
    wins = func.sum(text('match_items.player_win')).label('player_win')
    played = func.count(MatchItem.player_win).label('matches')

    summary = Database.session.query(MatchItem.account_id, MatchItem.item_id, wins, played)
    summary = summary.filter(MatchItem.match_id >= match_id)
    summary = summary.group_by(MatchItem.account_id, MatchItem.item_id)
    return summary.all()
def __init__(self, aggregates, groups, from_obj, state_table,
             state_group=None, prefix=None, suffix=None, schema=None):
    """
    Args:
        aggregates: collection of Aggregate objects.
        groups: either a list of expressions to group by in the aggregation,
            or a dictionary of group: expr pairs where group is the alias
            (used in column names). A list is turned into a dict keyed by
            str(expr).
        from_obj: defines the from clause, e.g. the name of the table. It is
            passed through make_sql_clause(from_obj, ex.text) (presumably
            wrapping plain strings as textual SQL; confirm in make_sql_clause).
        state_table: schema.table to query for comprehensive set of state_group
            entities regardless of what exists in the from_obj
        state_group: the group level found in the state table; defaults to
            "entity_id"
        prefix: prefix for aggregation tables and column names; defaults to
            str(from_obj)
        suffix: suffix for aggregation table; defaults to "aggregation"
        schema: schema for aggregation tables

    The from_obj and group expressions are passed directly to the
    SQLAlchemy Select object so could be anything supported there.
    For details see:
    http://docs.sqlalchemy.org/en/latest/core/selectable.html

    Aggregates will have {collate_date} in their quantities substituted
    with the date of aggregation.
    """
    self.aggregates = aggregates
    self.from_obj = make_sql_clause(from_obj, ex.text)
    if isinstance(groups, dict):
        self.groups = groups
    else:
        self.groups = {str(expr): expr for expr in groups}
    self.state_table = state_table
    # Falsy overrides fall back to the defaults, as before.
    self.state_group = state_group or "entity_id"
    self.prefix = prefix or str(from_obj)
    self.suffix = suffix or "aggregation"
    self.schema = schema
def remove_unicode_prefixes(text):
    """Rewrite ``text`` with the module-level ``unicode_re`` pattern.

    Every match is replaced by its groups as ``\\1\\2\\1`` (presumably stripping
    the ``u`` prefix from quoted string literals; confirm against unicode_re).
    """
    rewritten = unicode_re.sub(r"\1\2\1", text)
    return rewritten
def remove_unicode_prefixes(text):
    """Identity variant: hand ``text`` back untouched."""
    unchanged = text
    return unchanged
def test_indexes_class(self):
    # Expected rendering: the single-column indexes become index=True /
    # unique=True on the column, while the two-column index is emitted
    # in __table_args__.
    simple_items = Table(
        'simple_items', self.metadata,
        Column('id', INTEGER, primary_key=True),
        Column('number', INTEGER),
        Column('text', VARCHAR)
    )
    simple_items.indexes.add(Index('idx_number', simple_items.c.number))
    simple_items.indexes.add(Index('idx_text_number', simple_items.c.text, simple_items.c.number))
    simple_items.indexes.add(Index('idx_text', simple_items.c.text, unique=True))

    assert self.generate_code() == """\
# coding: utf-8
from sqlalchemy import Column, Index, Integer, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()
metadata = Base.metadata


class SimpleItem(Base):
    __tablename__ = 'simple_items'
    __table_args__ = (
        Index('idx_text_number', 'text', 'number'),
    )

    id = Column(Integer, primary_key=True)
    number = Column(Integer, index=True)
    text = Column(String, unique=True)
"""
def get_updated_dids(total_workers, worker_number, limit=100, blacklisted_dids=None, session=None):
    """
    Get updated dids.

    :param total_workers:      Number of total workers; 0 (or less) disables partitioning.
    :param worker_number:      id of the executing worker.
    :param limit:              Maximum number of dids to return; a falsy value means no limit.
    :param blacklisted_dids:   (scope, name) tuples to filter out; defaults to none.
    :param session:            Database session in use.
    :returns:                  List of (id, scope, name, rule_evaluation_action) rows, oldest first.
    """
    # None instead of a shared mutable default list.
    if blacklisted_dids is None:
        blacklisted_dids = []

    query = session.query(models.UpdatedDID.id,
                          models.UpdatedDID.scope,
                          models.UpdatedDID.name,
                          models.UpdatedDID.rule_evaluation_action)

    if total_workers > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            bindparams = [bindparam('worker_number', worker_number),
                          bindparam('total_workers', total_workers)]
            query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    query = query.order_by(models.UpdatedDID.created_at)

    if limit:
        fetched_dids = query.limit(limit).all()
        filtered_dids = [did for did in fetched_dids if (did.scope, did.name) not in blacklisted_dids]
        if len(fetched_dids) == limit and not filtered_dids:
            # The whole page was blacklisted: retry without a limit so the
            # remaining (non-blacklisted) work is not hidden from the caller.
            return get_updated_dids(total_workers=total_workers, worker_number=worker_number, limit=None,
                                    blacklisted_dids=blacklisted_dids, session=session)
        return filtered_dids
    return [did for did in query.all() if (did.scope, did.name) not in blacklisted_dids]
def get_rules_beyond_eol(date_check, worker_number, total_workers, session):
    """
    Get rules whose eol_at lies before ``date_check``, restricted to this worker's share.

    :param date_check:     Reference date; rules with eol_at < date_check are returned.
    :param worker_number:  id of the executing worker.
    :param total_workers:  Number of total workers.
    :param session:        Database session in use.
    :returns:              List of (scope, name, rse_expression, locked, id, eol_at, expires_at) rows.
    """
    rr = models.ReplicationRule
    query = session.query(rr.scope, rr.name, rr.rse_expression, rr.locked,
                          rr.id, rr.eol_at, rr.expires_at)
    query = query.filter(rr.eol_at < date_check)

    dialect = session.bind.dialect.name
    if dialect == 'oracle':
        hash_params = [bindparam('worker_number', worker_number),
                       bindparam('total_workers', total_workers)]
        query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=hash_params))
    elif dialect == 'mysql':
        query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
    elif dialect == 'postgresql':
        query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    return list(query.all())
def get_expired_rules(total_workers, worker_number, limit=100, blacklisted_rules=None, session=None):
    """
    Get expired rules.

    Only unlocked rules without a child rule whose expires_at is in the past
    are considered, ordered by expires_at.

    :param total_workers:      Number of total workers.
    :param worker_number:      id of the executing worker.
    :param limit:              Maximum number of rules to return; a falsy value means no limit.
    :param blacklisted_rules:  List of blacklisted rule ids; defaults to none.
    :param session:            Database session in use.
    :returns:                  List of (id, rse_expression) rows.
    """
    # None instead of a shared mutable default list.
    if blacklisted_rules is None:
        blacklisted_rules = []

    rule = models.ReplicationRule
    query = session.query(rule.id, rule.rse_expression).filter(
        rule.expires_at < datetime.utcnow(),
        rule.locked == False,  # NOQA: == is required to build the SQL comparison
        rule.child_rule_id == None,  # NOQA: renders as IS NULL
    )
    query = query.with_hint(rule, "index(rules RULES_EXPIRES_AT_IDX)", 'oracle')
    query = query.order_by(rule.expires_at)

    dialect = session.bind.dialect.name
    if dialect == 'oracle':
        bindparams = [bindparam('worker_number', worker_number),
                      bindparam('total_workers', total_workers)]
        query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
    elif dialect == 'mysql':
        query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers + 1, worker_number)))
    elif dialect == 'postgresql':
        query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers + 1, worker_number)))

    if limit:
        fetched_rules = query.limit(limit).all()
        filtered_rules = [r for r in fetched_rules if r[0] not in blacklisted_rules]
        if len(fetched_rules) == limit and not filtered_rules:
            # The whole page was blacklisted: retry without a limit so the
            # remaining (non-blacklisted) rules are not hidden from the caller.
            return get_expired_rules(total_workers=total_workers, worker_number=worker_number, limit=None,
                                     blacklisted_rules=blacklisted_rules, session=session)
        return filtered_rules
    return [r for r in query.all() if r[0] not in blacklisted_rules]
def list_bad_replicas_history(limit=10000, thread=None, total_threads=None, session=None):
    """
    List the bad file replicas history. Method only used by necromancer

    :param limit: The maximum number of replicas returned.
    :param thread: The assigned thread for this necromancer (bucket index 0..total_threads - 1).
    :param total_threads: The total number of threads of all necromancers.
    :param session: The database session in use.
    :returns: dict mapping rse_id to a list of {'scope': ..., 'name': ...} dicts.
    """
    query = session.query(models.BadReplicas.scope, models.BadReplicas.name, models.BadReplicas.rse_id).\
        filter(models.BadReplicas.state == BadFilesStatus.BAD)

    if total_threads and (total_threads - 1) > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(name, total_threads - 1) yields total_threads buckets (0..total_threads - 1).
            bindparams = [bindparam('thread_number', thread), bindparam('total_threads', total_threads - 1)]
            query = query.filter(text('ORA_HASH(name, :total_threads) = :thread_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Same bucket count as the Oracle branch; `total_threads - 1` here
            # left the highest-numbered thread without any work.
            query = query.filter(text('mod(md5(name), %s) = %s' % (total_threads, thread)))
        elif dialect == 'postgresql':
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_threads, thread)))

    query = query.limit(limit)

    bad_replicas = {}
    for scope, name, rse_id in query.yield_per(1000):
        bad_replicas.setdefault(rse_id, []).append({'scope': scope, 'name': name})
    return bad_replicas
def get_replica_atime(replica, session=None):
    """
    Get the accessed_at timestamp for a replica. Just for testing.

    :param replica: Dictionary with 'scope', 'name' and either 'rse_id' or 'rse'
                    (RSE name). When only 'rse' is given, the resolved 'rse_id'
                    is written back into this dictionary.
    :param session: Database session to use.
    :returns: A datetime timestamp with the last access time.
    """
    if 'rse_id' not in replica:
        replica['rse_id'] = get_rse_id(rse=replica['rse'], session=session)

    replica_model = models.RSEFileAssociation
    lookup = session.query(replica_model.accessed_at).filter_by(
        scope=replica['scope'], name=replica['name'], rse_id=replica['rse_id'])
    lookup = lookup.with_hint(replica_model, text="INDEX(REPLICAS REPLICAS_PK)", dialect_name='oracle')
    accessed_at, = lookup.one()
    return accessed_at
def list_expired_temporary_dids(rse, limit, worker_number=None, total_workers=None, session=None):
    """
    List expired temporary DIDs.

    :param rse: the rse name.
    :param limit: The maximum number of replicas returned.
    :param worker_number: id of the executing worker (presumably 1-based, since
                          worker_number - 1 is used as the hash bucket).
    :param total_workers: Number of total workers.
    :param session: The database session in use.

    :returns: a list of dictionary replica.
    """
    rse_id = get_rse_id(rse, session=session)
    tmp_did = models.TemporaryDataIdentifier

    # The CASE expression only yields rse_id for rows with expired_at set.
    query = session.query(tmp_did.scope, tmp_did.name, tmp_did.path, tmp_did.bytes).\
        with_hint(tmp_did, "INDEX(tmp_dids TMP_DIDS_EXPIRED_AT_IDX)", 'oracle').\
        filter(case([(tmp_did.expired_at.isnot(None), tmp_did.rse_id), ]) == rse_id)

    if worker_number and total_workers and total_workers - 1 > 0:
        dialect = session.bind.dialect.name
        if dialect == 'oracle':
            # ORA_HASH(name, total_workers - 1) yields total_workers buckets (0..total_workers - 1).
            bindparams = [bindparam('worker_number', worker_number - 1), bindparam('total_workers', total_workers - 1)]
            query = query.filter(text('ORA_HASH(name, :total_workers) = :worker_number', bindparams=bindparams))
        elif dialect == 'mysql':
            # Same bucket count as the Oracle branch (was total_workers - 1,
            # which left the last worker without any work).
            query = query.filter(text('mod(md5(name), %s) = %s' % (total_workers, worker_number - 1)))
        elif dialect == 'postgresql':
            # Hash `name` like the other branches; hashing `path` (possibly
            # NULL) could make rows unreachable by every worker.
            query = query.filter(text('mod(abs((\'x\'||md5(name))::bit(32)::int), %s) = %s' % (total_workers, worker_number - 1)))

    return [{'path': path,
             'rse': rse,
             'rse_id': rse_id,
             'scope': scope,
             'name': name,
             'bytes': size}
            for scope, name, path, size in query.limit(limit)]
def get_by_id(self, id):
    """Fetch the tblBreakfast row (id, name) whose primary key is ``id``.

    Returns whatever ``self.fetchone`` yields for the lookup.
    """
    sql = text("""
        SELECT id, name
        FROM tblBreakfast
        WHERE id=:id;
    """)
    params = {'id': id}
    return self.fetchone(sql, **params)
def list_all(self):
    """Return every tblBreakfast row (id, name) via ``self.fetchall``."""
    sql = text("""
        SELECT id, name
        FROM tblBreakfast;
    """)
    rows = self.fetchall(sql)
    return rows
def get_by_id(self, id):
    """Fetch the tblIngredient row (id, name) whose primary key is ``id``.

    Returns whatever ``self.fetchone`` yields for the lookup.
    """
    sql = text("""
        SELECT id, name
        FROM tblIngredient
        WHERE id=:id;
    """)
    params = {'id': id}
    return self.fetchone(sql, **params)
def list_all(self):
    """Return every tblIngredient row (id, name) via ``self.fetchall``."""
    sql = text("""
        SELECT id, name
        FROM tblIngredient;
    """)
    rows = self.fetchall(sql)
    return rows
def get_by_id(self, id):
    """Fetch the tblUser row (id, first_name, last_name) whose primary key is ``id``.

    Returns whatever ``self.fetchone`` yields for the lookup.
    """
    sql = text("""
        SELECT id, first_name, last_name
        FROM tblUser
        WHERE id=:id;
    """)
    params = {'id': id}
    return self.fetchone(sql, **params)
def create(self, first_name, last_name):
    """Insert a tblUser row and return its new id.

    The session is committed after the INSERT; the return value is the
    result's ``lastrowid``.
    """
    insert_sql = text("""
        INSERT INTO tblUser (first_name, last_name)
        VALUES (:first_name, :last_name);
    """)
    values = {'first_name': first_name, 'last_name': last_name}
    outcome = self.execute(insert_sql, **values)
    db.session.commit()
    return outcome.lastrowid
def update(self, id, first_name, last_name):
    """Overwrite first_name / last_name of the tblUser row identified by ``id``.

    The session is committed after the UPDATE.

    Returns:
        True if the UPDATE reported a positive rowcount, False otherwise.
    """
    statement = text("""
        UPDATE tblUser
        SET first_name=:first_name, last_name=:last_name
        WHERE id=:id
    """)
    params = dict(id=id, first_name=first_name, last_name=last_name)
    outcome = self.execute(statement, **params)
    db.session.commit()
    return outcome.rowcount > 0
def delete(self, id):
    """Delete the tblUser row for ``id``.

    The session is committed after the DELETE so the removal is persisted;
    previously it was left uncommitted.

    Returns:
        True if the DELETE reported a positive rowcount, False otherwise.
    """
    query = text("""
        DELETE FROM tblUser
        WHERE id=:id
    """)
    result = self.execute(query, id=id)
    db.session.commit()
    return result.rowcount > 0
def list_all(self):
    """Return every tblUser row (id, first_name, last_name) via ``self.fetchall``."""
    sql = text("""
        SELECT id, first_name, last_name
        FROM tblUser;
    """)
    rows = self.fetchall(sql)
    return rows
def get_by_id(self, id):
    """Fetch the tblUserPreference row whose primary key is ``id``.

    Selected columns: id, user_id, ingredient_id, coefficient. Returns whatever
    ``self.fetchone`` yields for the lookup.
    """
    sql = text("""
        SELECT id, user_id, ingredient_id, coefficient
        FROM tblUserPreference
        WHERE id=:id;
    """)
    params = {'id': id}
    return self.fetchone(sql, **params)
def create(self, user_id, ingredient_id, coefficient):
    """Insert a tblUserPreference row and return its new id.

    The session is committed after the INSERT; the return value is the
    result's ``lastrowid``.
    """
    insert_sql = text("""
        INSERT INTO tblUserPreference (user_id, ingredient_id, coefficient)
        VALUES (:user_id, :ingredient_id, :coefficient);
    """)
    values = {'user_id': user_id, 'ingredient_id': ingredient_id,
              'coefficient': coefficient}
    outcome = self.execute(insert_sql, **values)
    db.session.commit()
    return outcome.lastrowid
def delete(self, id):
    """Delete the tblUserPreference row for ``id``.

    The session is committed after the DELETE so the removal is persisted;
    previously it was left uncommitted.

    Returns:
        True if the DELETE reported a positive rowcount, False otherwise.
    """
    query = text("""
        DELETE FROM tblUserPreference
        WHERE id=:id
    """)
    result = self.execute(query, id=id)
    db.session.commit()
    return result.rowcount > 0
def list_all_for_user(self, user_id):
    """Return the tblUserPreference rows that belong to ``user_id``.

    Selected columns: id, user_id, ingredient_id, coefficient.
    """
    sql = text("""
        SELECT id, user_id, ingredient_id, coefficient
        FROM tblUserPreference
        WHERE user_id=:user_id;
    """)
    params = {'user_id': user_id}
    return self.fetchall(sql, **params)
def list_all(self):
    """Return every breakfast/ingredient link (breakfast_id, ingredient_id, coefficient)."""
    sql = text("""
        SELECT breakfast_id, ingredient_id, coefficient
        FROM tblBreakfastIngredient;
    """)
    rows = self.fetchall(sql)
    return rows
def get_all_breakfast_ingredients(self):
    """Return every breakfast/ingredient link (breakfast_id, ingredient_id, coefficient)."""
    sql = text("""
        SELECT breakfast_id, ingredient_id, coefficient
        FROM tblBreakfastIngredient;
    """)
    rows = self.fetchall(sql)
    return rows
def test_fetchall(self):
    """It gets all rows"""
    query = text("""
        SELECT id, name
        FROM tblExample
        ORDER BY id
    """)
    expected_rows = [
        {'id': 1, 'name': 'Foo'},
        {'id': 2, 'name': 'Bar'},
        {'id': 3, 'name': 'Baz'},
    ]

    # Expect multiple records to be returned, in id order
    rows = self.dao.fetchall(query)
    self.assertEqual(len(rows), 3)
    for actual, expected in zip(rows, expected_rows):
        self.assertDictEqual(actual, expected)
def has_table(self, connection, table_name, schema=None):
    """Return True if dbc.tablesvx lists ``table_name`` in ``schema``.

    ``schema`` defaults to the dialect's default schema.
    """
    if schema is None:
        schema = self.default_schema_name
    matches_table = and_(text('DatabaseName=:schema'), text('TableName=:table_name'))
    stmt = select([column('tablename')], from_obj=[text('dbc.tablesvx')]).where(matches_table)
    first_row = connection.execute(stmt, schema=schema, table_name=table_name).fetchone()
    return first_row is not None
def get_columns(self, connection, table_name, schema=None, **kw):
    """Reflect column information for ``table_name`` in ``schema``.

    Columns are read from dbc.ColumnsV when the server major version is below
    16 and from dbc.ColumnsQV otherwise. On pre-16 servers, if the object is a
    view, each column is re-described via ``self._get_column_help``. Every
    row is finally converted with ``self._get_column_info``.
    """
    help_view = False
    if schema is None:
        schema = self.default_schema_name

    # server_version_info is a dotted version string; compare its major part.
    if int(self.server_version_info.split('.')[0]) < 16:
        dbc_columninfo = 'dbc.ColumnsV'

        # Check if the object is a view
        stmt = select([column('tablekind')],
                      from_obj=[text('dbc.tablesV')]).where(
            and_(text('DatabaseName=:schema'),
                 text('TableName=:table_name'),
                 text("tablekind='V'")))
        # DB-API rowcount is unreliable (often -1) for SELECT statements,
        # so test for a returned row instead.
        row = connection.execute(stmt, schema=schema, table_name=table_name).fetchone()
        help_view = row is not None
    else:
        dbc_columninfo = 'dbc.ColumnsQV'

    stmt = select([column('columnname'), column('columntype'),
                   column('columnlength'), column('chartype'),
                   column('decimaltotaldigits'), column('decimalfractionaldigits'),
                   column('columnformat'),
                   column('nullable'), column('defaultvalue'), column('idcoltype')],
                  from_obj=[text(dbc_columninfo)]).where(
        and_(text('DatabaseName=:schema'),
             text('TableName=:table_name')))

    res = connection.execute(stmt, schema=schema, table_name=table_name).fetchall()

    # If this is a view in pre-16 version, get types for individual columns
    if help_view:
        res = [self._get_column_help(connection, schema, table_name, r['columnname']) for r in res]

    return [self._get_column_info(row) for row in res]
def get_table_names(self, connection, schema=None, **kw):
    """List normalized names of objects with tablekind 'T' or 'O' in ``schema``.

    ``schema`` defaults to the dialect's default schema.
    """
    if schema is None:
        schema = self.default_schema_name
    kind_filter = or_(text("tablekind='T'"), text("tablekind='O'"))
    condition = and_(text('DatabaseName = :schema'), kind_filter)
    stmt = select([column('tablename')], from_obj=[text('dbc.TablesVX')]).where(condition)
    rows = connection.execute(stmt, schema=schema).fetchall()
    return [self.normalize_name(row['tablename']) for row in rows]
def get_schema_names(self, connection, **kw):
    """Return normalized user names from dbc.UsersV, ordered by username."""
    sort_key = text('username')
    stmt = select([column('username')],
                  from_obj=[text('dbc.UsersV')],
                  order_by=[sort_key])
    rows = connection.execute(stmt).fetchall()
    return [self.normalize_name(row['username']) for row in rows]
def get_view_names(self, connection, schema=None, **kw):
    """List normalized names of views (tablekind 'V') in ``schema``.

    ``schema`` defaults to the dialect's default schema.
    """
    if schema is None:
        schema = self.default_schema_name
    condition = and_(text('DatabaseName = :schema'), text("tablekind='V'"))
    stmt = select([column('tablename')], from_obj=[text('dbc.TablesVX')]).where(condition)
    rows = connection.execute(stmt, schema=schema).fetchall()
    return [self.normalize_name(row['tablename']) for row in rows]
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
    """
    Overrides base class method

    Returns one dict per unique index ('U' in dbc.Indices) on the table, with
    keys 'name' and 'column_names' (both normalized).
    """
    if schema is None:
        schema = self.default_schema_name

    stmt = select([column('ColumnName'), column('IndexName')], from_obj=[text('dbc.Indices')]) \
        .where(and_(text('DatabaseName = :schema'),
                    text('TableName=:table'),
                    text('IndexType=:indextype'))) \
        .order_by(asc(column('IndexName')))

    # U for Unique
    res = connection.execute(stmt, schema=schema, table=table_name, indextype='U').fetchall()

    def grouper(row):
        # groupby() merges adjacent rows; the query orders by IndexName.
        return {
            'name': self.normalize_name(row['IndexName']),
        }

    unique_constraints = []
    for constraint_info, constraint_cols in groupby(res, grouper):
        unique_constraints.append({
            # Already normalized by grouper; normalizing a second time would
            # turn a lower-cased name into a force-quoted one (with
            # SQLAlchemy's default normalize_name).
            'name': constraint_info['name'],
            'column_names': [self.normalize_name(col['ColumnName']) for col in constraint_cols],
        })
    return unique_constraints
def get_indexes(self, connection, table_name, schema=None, **kw):
    """
    Overrides base class method

    Returns one dict per index on the table with keys 'name' (IndexName, or
    IndexNumber when the index is unnamed), 'column_names' (normalized) and
    'unique' (UniqueFlag == 'Y').
    """
    if schema is None:
        schema = self.default_schema_name

    # Order by IndexNumber as well: unnamed indexes all share a NULL IndexName,
    # and groupby() below only merges *adjacent* rows with the same key.
    stmt = select(["*"], from_obj=[text('dbc.Indices')]) \
        .where(and_(text('DatabaseName = :schema'),
                    text('TableName=:table'))) \
        .order_by(asc(column('IndexName')), asc(column('IndexNumber')))

    res = connection.execute(stmt, schema=schema, table=table_name).fetchall()

    def grouper(row):
        return {
            'name': row.IndexName or row.IndexNumber,  # If IndexName is None TODO: Check what to do
            'unique': row.UniqueFlag == 'Y',
        }

    # TODO: Check if there's a better way
    indices = []
    for index_info, index_cols in groupby(res, grouper):
        indices.append({
            'name': index_info['name'],
            'column_names': [self.normalize_name(col['ColumnName']) for col in index_cols],
            'unique': index_info['unique'],
        })
    return indices
def get_transaction_mode(self, connection, **kw):
    """Return the transaction mode of the current session.

    Read from dbc.sessioninfov for the row of the current SESSION:
    'T' means TDBS (Teradata) mode, 'A' means ANSI mode.
    """
    query = select([text('transaction_mode')], from_obj=[text('dbc.sessioninfov')])
    query = query.where(text('sessionno=SESSION'))
    mode = connection.execute(query).scalar()
    return mode
def _get_server_version_info(self, connection, **kw):
    """Return the Teradata Database software version.

    The value is the InfoData of the 'VERSION' row in dbc.dbcinfov, returned
    as-is (a scalar from the driver).
    """
    query = select([text('InfoData')], from_obj=[text('dbc.dbcinfov')])
    query = query.where(text("InfoKey='VERSION'"))
    version = connection.execute(query).scalar()
    return version