Server : Apache System : Linux iad1-shared-b8-43 6.6.49-grsec-jammy+ #10 SMP Thu Sep 12 23:23:08 UTC 2024 x86_64 User : dh_edsupp ( 6597262) PHP Version : 8.2.26 Disable Function : NONE Directory : /lib/python3/dist-packages/trac/db/ |
Upload File : |
# -*- coding: utf-8 -*-
#
# Copyright (C) 2005-2021 Edgewall Software
# Copyright (C) 2005-2006 Christopher Lenz <cmlenz@gmx.de>
# Copyright (C) 2005 Jeff Weiss <trac@jeffweiss.org>
# Copyright (C) 2006 Andres Salomon <dilinger@athenacr.com>
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://trac.edgewall.org/wiki/TracLicense.
#
# This software consists of voluntary contributions made by many
# individuals. For the exact contribution history, see the revision
# history and logs, available at https://trac.edgewall.org/log/.

import os
import re
import sys
from contextlib import closing
from subprocess import Popen, PIPE

from trac.api import IEnvironmentSetupParticipant
from trac.core import *
from trac.config import Option
from trac.db.api import ConnectionBase, DatabaseManager, IDatabaseConnector, \
                        get_column_names, parse_connection_uri
from trac.db.util import ConnectionWrapper, IterableCursor
from trac.util import as_int, get_pkginfo
from trac.util.html import Markup
from trac.util.compat import close_fds
from trac.util.text import exception_to_unicode, to_unicode
from trac.util.translation import _

# Escapes the LIKE wildcards `_` and `%` (and the escape char `/` itself).
_like_escape_re = re.compile(r'([/_%])')


def _unmarkup(value):
    """Return `value` as a plain `str` if it is a `Markup`, else unchanged."""
    return str(value) if isinstance(value, Markup) else value


def _unmarkup_args(args):
    """Convert `Markup` values of one set of query arguments to `str`.

    A `dict` (PyMySQL named `%(name)s` placeholders) is converted
    value-wise and keeps its keys; any other iterable becomes a tuple.
    """
    if isinstance(args, dict):
        return {key: _unmarkup(value) for key, value in args.items()}
    return tuple(_unmarkup(arg) for arg in args)


try:
    import pymysql
except ImportError:
    pymysql = None
    pymsql_version = None
else:
    pymsql_version = get_pkginfo(pymysql).get('version',
                                              pymysql.__version__)

    class MySQLUnicodeCursor(pymysql.cursors.Cursor):
        """PyMySQL cursor that passes `Markup` arguments as plain `str`
        and returns `fetchall()` results as a list.
        """

        def execute(self, query, args=None):
            if args:
                args = _unmarkup_args(args)
            return super().execute(query, args)

        def executemany(self, query, args):
            if args:
                args = [_unmarkup_args(row) for row in args]
            return super().executemany(query, args)

        def fetchall(self):
            return list(super().fetchall())

    class MySQLSilentCursor(MySQLUnicodeCursor):
        """Cursor whose `_show_warnings` hook is a no-op.

        Used by `MySQLConnection.drop_table` for `DROP TABLE IF EXISTS`.
        """

        def _show_warnings(self, conn=None):
            pass


# Mapping from "abstract" SQL types to DB-specific types
_type_map = {
    'int64': 'bigint',
    'text': 'mediumtext',
}


def _quote(identifier):
    """Quote a MySQL identifier with backticks, doubling any embedded
    backtick.
    """
    return "`%s`" % identifier.replace('`', '``')


class MySQLConnector(Component):
    """Database connector for MySQL version 4.1 and greater.

    Database URLs should be of the form::
    {{{
    mysql://user[:password]@host[:port]/database[?param1=value&param2=value]
    }}}
    The following parameters are supported:
     * `compress`: Enable compression (0 or 1)
     * `init_command`: Command to run once the connection is created
     * `named_pipe`: Use a named pipe to connect on Windows (0 or 1)
     * `read_default_file`: Read default client values from the given file
     * `read_default_group`: Configuration group to use from the default file
     * `unix_socket`: Use a Unix socket at the given path to connect
    """
    implements(IDatabaseConnector, IEnvironmentSetupParticipant)

    # Set to True once a connection has been made through this connector.
    required = False

    mysqldump_path = Option('trac', 'mysqldump_path', 'mysqldump',
        """Location of mysqldump for MySQL database backups""")

    def __init__(self):
        if pymysql:
            self._mysql_version = \
                'server: (not-connected), client: "%s", thread-safe: %s' % \
                (pymysql.get_client_info(), pymysql.thread_safe())
        else:
            self._mysql_version = None

    # IDatabaseConnector methods

    def get_supported_schemes(self):
        yield 'mysql', 1

    def get_connection(self, path, log=None, user=None, password=None,
                       host=None, port=None, params={}):
        """Open a new `MySQLConnection`.

        On the first call, also record the server version for
        `get_system_info` and mark the connector as required.
        """
        cnx = MySQLConnection(path, log, user, password, host, port, params)
        if not self.required:
            self._mysql_version = \
                'server: "%s", client: "%s", thread-safe: %s' \
                % (cnx.cnx.get_server_info(), pymysql.get_client_info(),
                   pymysql.thread_safe())
            self.required = True
        return cnx

    def get_exceptions(self):
        return pymysql

    def init_db(self, path, schema=None, log=None, user=None, password=None,
                host=None, port=None, params={}):
        """Create the tables of `schema` (default: Trac's schema).

        The server variables and the resulting table status are verified.
        The connection used is closed afterwards.
        """
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        with closing(cnx):
            self._verify_variables(cnx)
            max_bytes = self._max_bytes(cnx)
            cursor = cnx.cursor()
            if schema is None:
                from trac.db_default import schema
            for table in schema:
                for stmt in self.to_sql(table, max_bytes=max_bytes):
                    self.log.debug(stmt)
                    cursor.execute(stmt)
            self._verify_table_status(cnx)
            cnx.commit()

    def destroy_db(self, path, log=None, user=None, password=None, host=None,
                   port=None, params={}):
        """Drop every table of the database; the connection is closed."""
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        with closing(cnx):
            for table_name in cnx.get_table_names():
                cnx.drop_table(table_name)
            cnx.commit()

    def db_exists(self, path, log=None, user=None, password=None, host=None,
                  port=None, params={}):
        """Return whether the database contains any table."""
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        with closing(cnx):
            return bool(cnx.get_table_names())

    def _max_bytes(self, cnx):
        """Return the maximum bytes per character of the connection charset
        (4 for `utf8mb4`, otherwise 3).

        When `cnx` is None, a temporary connection is opened from the
        configured connector and closed again.
        """
        if cnx is None:
            connector, args = DatabaseManager(self.env).get_connector()
            with closing(connector.get_connection(**args)) as cnx:
                charset = cnx.charset
        else:
            charset = cnx.charset
        return 4 if charset == 'utf8mb4' else 3

    _max_key_length = 3072

    def _collist(self, table, columns, max_bytes):
        """Take a list of columns and impose limits on each so that indexing
        works properly.

        Some versions of MySQL limit each index prefix to 3072 bytes total,
        with a max of 767 bytes per column.
        """
        cols = []
        limit_col = 767 // max_bytes
        limit = min(self._max_key_length // (max_bytes * len(columns)),
                    limit_col)
        for c in columns:
            name = _quote(c)
            table_col = list(filter((lambda x: x.name == c), table.columns))
            if len(table_col) == 1 and table_col[0].type.lower() == 'text':
                if table_col[0].key_size is not None:
                    name += '(%d)' % min(table_col[0].key_size, limit_col)
                else:
                    name += '(%s)' % limit
            # For non-text columns, we simply throw away the extra bytes.
            # That could certainly be optimized better, but for now let's KISS.
            cols.append(name)
        return ','.join(cols)

    def to_sql(self, table, max_bytes=None):
        """Yield the CREATE TABLE statement for `table`, followed by one
        CREATE INDEX statement per index.

        Note: auto-increment columns have their `type` rewritten to `'int'`
        on the passed-in table object.
        """
        if max_bytes is None:
            max_bytes = self._max_bytes(None)
        sql = ['CREATE TABLE %s (' % _quote(table.name)]
        coldefs = []
        for column in table.columns:
            ctype = column.type
            ctype = _type_map.get(ctype, ctype)
            if column.auto_increment:
                ctype = 'INT UNSIGNED NOT NULL AUTO_INCREMENT'
                # Override the column type, as a text field cannot
                # use auto_increment.
                column.type = 'int'
            coldefs.append(' %s %s' % (_quote(column.name), ctype))
        if len(table.key) > 0:
            coldefs.append(' PRIMARY KEY (%s)' %
                           self._collist(table, table.key,
                                         max_bytes=max_bytes))
        sql.append(',\n'.join(coldefs) + '\n)')
        yield '\n'.join(sql)
        for index in table.indices:
            unique = 'UNIQUE' if index.unique else ''
            idxname = '%s_%s_idx' % (table.name, '_'.join(index.columns))
            yield 'CREATE %s INDEX %s ON %s (%s)' % \
                  (unique, _quote(idxname), _quote(table.name),
                   self._collist(table, index.columns, max_bytes=max_bytes))

    def alter_column_types(self, table, columns):
        """Yield SQL statements altering the type of one or more columns of
        a table.

        Type changes are specified as a `columns` dict mapping column names
        to `(from, to)` SQL type tuples.
        """
        alterations = []
        for name, (from_, to) in sorted(columns.items()):
            to = _type_map.get(to, to)
            if to != _type_map.get(from_, from_):
                alterations.append((name, to))
        if alterations:
            yield "ALTER TABLE %s %s" % (table,
                ', '.join("MODIFY %s %s" % each for each in alterations))

    def backup(self, dest_file):
        """Dump the configured database to `dest_file` using mysqldump.

        The password is passed through the `MYSQL_PWD` environment variable
        rather than on the command line. Raises `TracError` if mysqldump
        cannot be run, exits non-zero, or creates no file.
        """
        db_url = self.env.config.get('trac', 'database')
        scheme, db_prop = parse_connection_uri(db_url)
        db_params = db_prop.setdefault('params', {})
        db_name = os.path.basename(db_prop['path'])

        args = [self.mysqldump_path, '--no-defaults']
        if 'host' in db_prop:
            args.extend(['-h', db_prop['host']])
        if 'port' in db_prop:
            args.extend(['-P', str(db_prop['port'])])
        if 'user' in db_prop:
            args.extend(['-u', db_prop['user']])
        for name, value in db_params.items():
            if name == 'compress' and as_int(value, 0):
                args.append('--compress')
            elif name == 'named_pipe' and as_int(value, 0):
                args.append('--protocol=pipe')
            elif name == 'read_default_file':  # Must be first
                args.insert(1, '--defaults-file=' + value)
            elif name == 'unix_socket':
                args.extend(['--protocol=socket', '--socket=' + value])
            elif name not in ('init_command', 'read_default_group',
                              'charset'):
                # `charset` is a valid connection parameter (see
                # MySQLConnection) that has no mysqldump counterpart here.
                self.log.warning("Invalid connection string parameter '%s'",
                                 name)
        args.extend(['-r', dest_file, db_name])

        environ = os.environ.copy()
        if 'password' in db_prop:
            environ['MYSQL_PWD'] = str(db_prop['password'])
        try:
            p = Popen(args, env=environ, stderr=PIPE, close_fds=close_fds)
        except OSError as e:
            raise TracError(_("Unable to run %(path)s: %(msg)s",
                              path=self.mysqldump_path,
                              msg=exception_to_unicode(e)))
        errmsg = p.communicate()[1]
        if p.returncode != 0:
            raise TracError(_("mysqldump failed: %(msg)s",
                              msg=to_unicode(errmsg.strip())))
        if not os.path.exists(dest_file):
            raise TracError(_("No destination file created"))
        return dest_file

    def get_system_info(self):
        yield 'MySQL', self._mysql_version
        # pymysql is None when PyMySQL is not installed.
        if pymysql:
            yield pymysql.__name__, pymsql_version

    # IEnvironmentSetupParticipant methods

    def environment_created(self):
        pass

    def environment_needs_upgrade(self):
        """Verify table status and server variables when this connector is
        in use. Never requests an upgrade; raises `TracError` instead.
        """
        if self.required:
            with self.env.db_query as db:
                self._verify_table_status(db)
                self._verify_variables(db)
        return False

    def upgrade_environment(self):
        pass

    UNSUPPORTED_ENGINES = ('MyISAM', 'EXAMPLE', 'ARCHIVE', 'CSV', 'ISAM')

    def _verify_table_status(self, db):
        """Raise `TracError` if any of Trac's tables uses an unsupported
        storage engine or a collation other than utf8_bin/utf8mb4_bin.
        """
        from trac.db_default import schema
        tables = [t.name for t in schema]
        cursor = db.cursor()
        cursor.execute("SHOW TABLE STATUS WHERE name IN (%s)" %
                       ','.join(('%s',) * len(tables)),
                       tables)
        cols = get_column_names(cursor)
        rows = [dict(zip(cols, row)) for row in cursor]

        engines = [row['Name'] for row in rows
                   if row['Engine'] in self.UNSUPPORTED_ENGINES]
        if engines:
            raise TracError(_(
                "All tables must be created as InnoDB or NDB storage engine "
                "to support transactions. The following tables have been "
                "created as storage engine which doesn't support "
                "transactions: %(tables)s", tables=', '.join(engines)))

        non_utf8bin = [row['Name'] for row in rows
                       if row['Collation'] not in ('utf8_bin', 'utf8mb4_bin',
                                                   None)]
        if non_utf8bin:
            raise TracError(_("All tables must be created with utf8_bin or "
                              "utf8mb4_bin as collation. The following tables "
                              "don't have the collations: %(tables)s",
                              tables=', '.join(non_utf8bin)))

    SUPPORTED_COLLATIONS = (('utf8', 'utf8_bin'), ('utf8mb4', 'utf8mb4_bin'))

    def _verify_variables(self, db):
        """Raise `TracError` if the default (or temporary-table) storage
        engine is unsupported, or if the database charset/collation pair
        is not one of `SUPPORTED_COLLATIONS`.
        """
        cursor = db.cursor()
        cursor.execute("SHOW VARIABLES WHERE variable_name IN ("
                       "'default_storage_engine','storage_engine',"
                       "'default_tmp_storage_engine',"
                       "'character_set_database','collation_database')")
        variables = {row[0].lower(): row[1] for row in cursor}

        engine = variables.get('default_storage_engine') or \
                 variables.get('storage_engine')
        if engine in self.UNSUPPORTED_ENGINES:
            raise TracError(_("The current storage engine is %(engine)s. "
                              "It must be InnoDB or NDB storage engine to "
                              "support transactions.", engine=engine))

        tmp_engine = variables.get('default_tmp_storage_engine')
        if tmp_engine in self.UNSUPPORTED_ENGINES:
            raise TracError(_("The current storage engine for TEMPORARY "
                              "tables is %(engine)s. It must be InnoDB or NDB "
                              "storage engine to support transactions.",
                              engine=tmp_engine))

        charset = variables['character_set_database']
        collation = variables['collation_database']
        if (charset, collation) not in self.SUPPORTED_COLLATIONS:
            raise TracError(_(
                "The charset and collation of database are '%(charset)s' and "
                "'%(collation)s'. The database must be created with one of "
                "%(supported)s.", charset=charset, collation=collation,
                supported=repr(self.SUPPORTED_COLLATIONS)))


class MySQLConnection(ConnectionBase, ConnectionWrapper):
    """Connection wrapper for MySQL."""

    poolable = True

    def __init__(self, path, log, user=None, password=None, host=None,
                 port=None, params={}):
        """Connect to database `path` (a leading '/' is stripped).

        The connection first uses charset `utf8`, or the `charset`
        parameter if it is `utf8`/`utf8mb4`. If the database's own
        character set differs, it reconnects with that character set.
        """
        if path.startswith('/'):
            path = path[1:]
        if password is None:
            password = ''
        if port is None:
            port = 3306
        opts = {'charset': 'utf8'}
        for name, value in params.items():
            if name == 'read_default_group':
                opts[name] = value
            elif name == 'init_command':
                opts[name] = value
            elif name in ('read_default_file', 'unix_socket'):
                opts[name] = value
            elif name in ('compress', 'named_pipe'):
                opts[name] = as_int(value, 0)
            elif name == 'charset':
                value = value.lower()
                if value in ('utf8', 'utf8mb4'):
                    opts[name] = value
                elif log:
                    # `self.log` is not assigned until
                    # ConnectionWrapper.__init__ below; use the argument.
                    log.warning("Invalid connection string parameter "
                                "'%s=%s'", name, value)
            elif log:
                log.warning("Invalid connection string parameter '%s'",
                            name)
        cnx = pymysql.connect(database=path, user=user, password=password,
                              host=host, port=port, **opts)
        cursor = cnx.cursor()
        cursor.execute("SHOW VARIABLES WHERE "
                       " variable_name='character_set_database'")
        self.charset = cursor.fetchone()[1]
        cursor.close()
        if self.charset != opts['charset']:
            # Reconnect using the database's own character set.
            cnx.close()
            opts['charset'] = self.charset
            cnx = pymysql.connect(database=path, user=user,
                                  password=password, host=host, port=port,
                                  **opts)
        self.schema = path
        ConnectionWrapper.__init__(self, cnx, log)
        self._is_closed = False

    def cursor(self):
        return IterableCursor(MySQLUnicodeCursor(self.cnx), self.log)

    def rollback(self):
        """Ping the server, then roll back. A `ProgrammingError` from the
        rollback marks this connection as closed.
        """
        self.cnx.ping()
        try:
            self.cnx.rollback()
        except pymysql.ProgrammingError:
            self._is_closed = True

    def close(self):
        if not self._is_closed:
            try:
                self.cnx.close()
            except pymysql.Error:
                # PyMySQL raises its base `Error("Already closed")` when the
                # underlying connection is already closed, so ignore it.
                pass
            self._is_closed = True

    def cast(self, column, type):
        """Return a SQL CAST of `column` to the MySQL type for `type`."""
        if type in ('int', 'int64'):
            type = 'signed'
        elif type == 'text':
            type = 'char'
        return 'CAST(%s AS %s)' % (column, type)

    def concat(self, *args):
        return 'concat(%s)' % ', '.join(args)

    def drop_column(self, table, column):
        """Drop `column` from `table` if present.

        Any composite index (including the primary key) involving the
        column is dropped first.
        """
        cursor = pymysql.cursors.Cursor(self.cnx)
        if column in self.get_column_names(table):
            quoted_table = self.quote(table)
            cursor.execute("SHOW INDEX FROM %s" % quoted_table)
            columns = get_column_names(cursor)
            keys = {}
            for row in cursor.fetchall():
                row = dict(zip(columns, row))
                keys.setdefault(row['Key_name'], []).append(row['Column_name'])
            # drop all composite indices which in the given column is involved
            for key, key_columns in keys.items():
                if len(key_columns) > 1 and column in key_columns:
                    if key == 'PRIMARY':
                        cursor.execute("ALTER TABLE %s DROP PRIMARY KEY" %
                                       quoted_table)
                    else:
                        cursor.execute("ALTER TABLE %s DROP KEY %s" %
                                       (quoted_table, self.quote(key)))
            cursor.execute("ALTER TABLE %s DROP COLUMN %s " %
                           (quoted_table, self.quote(column)))

    def drop_table(self, table):
        cursor = MySQLSilentCursor(self.cnx)
        cursor.execute("DROP TABLE IF EXISTS " + self.quote(table))

    def get_column_names(self, table):
        """Return the column names of `table` in ordinal order."""
        rows = self.execute("""
            SELECT column_name FROM information_schema.columns
            WHERE table_schema=%s AND table_name=%s
            ORDER BY ordinal_position
            """, (self.schema, table))
        return [row[0] for row in rows]

    def get_last_id(self, cursor, table, column='id'):
        return cursor.lastrowid

    def get_sequence_names(self):
        return []

    def get_table_names(self):
        rows = self.execute("""
            SELECT table_name FROM information_schema.tables
            WHERE table_schema=%s""", (self.schema,))
        return [row[0] for row in rows]

    def has_table(self, table):
        rows = self.execute("""
            SELECT EXISTS (SELECT * FROM information_schema.columns
                           WHERE table_schema=%s AND table_name=%s)
            """, (self.schema, table))
        return bool(rows[0][0])

    def like(self):
        return "LIKE %%s COLLATE %s_general_ci ESCAPE '/'" % self.charset

    def like_escape(self, text):
        return _like_escape_re.sub(r'/\1', text)

    def reset_tables(self):
        """Empty every table of the schema and return their names.

        Tables with an auto_increment column are truncated so the counter
        is reset; the others are emptied with DELETE FROM.
        """
        table_names = []
        if not self.schema:
            return table_names
        cursor = self.cursor()
        cursor.execute("""
            SELECT t.table_name,
                   EXISTS (SELECT * FROM information_schema.columns AS c
                           WHERE c.table_schema=t.table_schema
                           AND c.table_name=t.table_name
                           AND extra='auto_increment')
            FROM information_schema.tables AS t
            WHERE t.table_schema=%s
            """, (self.schema,))
        for table, has_autoinc in cursor.fetchall():
            table_names.append(table)
            quoted = self.quote(table)
            if not has_autoinc:
                # DELETE FROM is preferred to TRUNCATE TABLE, as the
                # auto_increment is not used.
                cursor.execute("DELETE FROM %s" % quoted)
            else:
                # TRUNCATE TABLE is preferred to DELETE FROM, as we
                # need to reset the auto_increment in MySQL.
                cursor.execute("TRUNCATE TABLE %s" % quoted)
        return table_names

    def prefix_match(self):
        return "LIKE %s ESCAPE '/'"

    def prefix_match_value(self, prefix):
        return self.like_escape(prefix) + '%'

    def quote(self, identifier):
        """Return the quoted identifier."""
        return _quote(identifier)

    def update_sequence(self, cursor, table, column='id'):
        # MySQL handles sequence updates automagically
        pass