Server : Apache System : Linux iad1-shared-b8-43 6.6.49-grsec-jammy+ #10 SMP Thu Sep 12 23:23:08 UTC 2024 x86_64 User : dh_edsupp ( 6597262) PHP Version : 8.2.26 Disable Function : NONE Directory : /lib/python3/dist-packages/trac/db/ |
Upload File : |
# -*- coding: utf-8 -*-
#
# Copyright (C) 2005-2021 Edgewall Software
# Copyright (C) 2005 Christopher Lenz <cmlenz@gmx.de>
# All rights reserved.
#
# This software is licensed as described in the file COPYING, which
# you should have received as part of this distribution. The terms
# are also available at https://trac.edgewall.org/wiki/TracLicense.
#
# This software consists of voluntary contributions made by many
# individuals. For the exact contribution history, see the revision
# history and logs, available at https://trac.edgewall.org/log/.
#
# Author: Christopher Lenz <cmlenz@gmx.de>

from ctypes.util import find_library
import ctypes
import os
import re
from pkg_resources import DistributionNotFound
from subprocess import Popen, PIPE

from trac.core import *
from trac.config import Option
from trac.db.api import ConnectionBase, IDatabaseConnector, \
                        parse_connection_uri
from trac.db.util import ConnectionWrapper, IterableCursor
from trac.util import get_pkginfo, lazy
from trac.util.compat import close_fds
from trac.util.html import Markup
from trac.util.text import empty, exception_to_unicode, to_unicode
from trac.util.translation import _

try:
    import psycopg2 as psycopg
    import psycopg2.extensions
    from psycopg2 import DataError, ProgrammingError
    from psycopg2.extensions import register_type, UNICODE, \
                                    register_adapter, AsIs, QuotedString
except ImportError:
    raise DistributionNotFound('psycopg2>=2.0 or psycopg2-binary', ['Trac'])
else:
    register_type(UNICODE)
    # Markup is sent as a plain quoted string; the `empty` sentinel is
    # sent as the SQL empty string literal.
    register_adapter(Markup, lambda markup: QuotedString(str(markup)))
    register_adapter(type(empty), lambda empty: AsIs("''"))
    psycopg2_version = get_pkginfo(psycopg).get('version',
                                                psycopg.__version__)
    # Path of the libpq shared library, used as a fallback to query the
    # client library version when psycopg2 lacks `libpq_version()`.
    _libpq_pathname = None
    if not hasattr(psycopg, 'libpq_version'):
        # search path of libpq only if it is dynamically linked
        _f = _match = None
        try:
            with open(psycopg._psycopg.__file__, 'rb') as _f:
                if os.name != 'nt':
                    # Look for a NUL-delimited libpq.so.N / libpq.N.dylib
                    # path embedded in the compiled extension module.
                    _match = re.search(
                        r'''
                            \0(
                            (?:/[^/\0]+)*/?
                            libpq\.(?:so\.[0-9]+|[0-9]+\.dylib)
                            )\0
                        '''.encode('utf-8'),
                        _f.read(), re.VERBOSE)
                    if _match:
                        _libpq_pathname = _match.group(1)
                else:
                    if re.search(r'\0libpq\.dll\0'.encode('utf-8'),
                                 _f.read(), re.IGNORECASE):
                        _libpq_pathname = find_library('libpq')
        except AttributeError:
            pass
        del _f, _match

# Escape LIKE wildcards (and the '/' escape character itself) with '/'.
_like_escape_re = re.compile(r'([/_%])')

# Mapping from "abstract" SQL types to DB-specific types
_type_map = {
    'int64': 'bigint',
}

min_postgresql_version = (9, 1, 0)


def assemble_pg_dsn(path, user=None, password=None, host=None, port=None):
    """Quote the parameters and assemble the DSN.

    Each non-empty parameter becomes a ``name='value'`` pair in which
    backslashes and single quotes are backslash-escaped. Falsy
    parameters are omitted entirely.
    """
    def quote(value):
        if not isinstance(value, str):
            value = str(value)
        return "'%s'" % value.replace('\\', r'\\').replace("'", r"\'")

    dsn = {'dbname': path, 'user': user, 'password': password, 'host': host,
           'port': port}
    return ' '.join("%s=%s" % (name, quote(value))
                    for name, value in dsn.items() if value)


def _quote(identifier):
    """Return `identifier` as a double-quoted SQL identifier, doubling any
    embedded double quotes."""
    return '"%s"' % identifier.replace('"', '""')


def _version_tuple(ver):
    """Convert an integer version (e.g. ``90123`` or ``100004``) into a
    tuple (``(9, 1, 23)`` or ``(10, 4)``). Return `None` for a falsy
    `ver`."""
    if ver:
        major, minor = divmod(ver, 10000)
        if major >= 10:
            # Extract 10.4 from 100004.
            return major, minor
        else:
            # Extract 9.1.23 from 90123.
            minor, patch = divmod(minor, 100)
            return major, minor, patch


def _version_string(ver):
    """Format an integer or tuple version as a dotted string, or
    ``'(unknown)'`` when no version is available."""
    if ver and not isinstance(ver, tuple):
        ver = _version_tuple(ver)
    if ver:
        return '.'.join(map(str, ver))
    else:
        return '(unknown)'


class PostgreSQLConnector(Component):
    """Database connector for PostgreSQL.

    Database URLs should be of the form:
    {{{
    postgres://user[:password]@host[:port]/database[?schema=my_schema]
    }}}
    """
    implements(IDatabaseConnector)

    # Set once the first connection has passed the server version check.
    required = False

    pg_dump_path = Option('trac', 'pg_dump_path', 'pg_dump',
        """Location of pg_dump for Postgres database backups""")

    def __init__(self):
        # Reported by get_system_info() until a connection is made.
        self._postgresql_version = \
            'server: (not-connected), client: %s' % \
            _version_string(self._client_version)

    # IDatabaseConnector methods

    def get_supported_schemes(self):
        """Yield the ``postgres`` URI scheme as a ``(scheme, 1)`` tuple."""
        yield 'postgres', 1

    def get_connection(self, path, log=None, user=None, password=None,
                       host=None, port=None, params={}):
        """Open and return a `PostgreSQLConnection`.

        On the first successful connection, the server version is checked
        against `min_postgresql_version`, and the server/client versions
        are recorded for `get_system_info`.

        :param params: extra URI parameters; ``schema`` defaults to
                       ``'public'``. The mapping is not modified.
        :raises TracError: if the server is older than the minimum
                           supported version.
        """
        # Work on a copy: setdefault() must not mutate the caller's dict
        # or the shared default argument object.
        params = dict(params)
        params.setdefault('schema', 'public')
        cnx = PostgreSQLConnection(path, log, user, password, host, port,
                                   params)
        server_ver = _version_string(cnx.server_version)
        client_ver = _version_string(self._client_version)
        if not self.required:
            if cnx.server_version < min_postgresql_version:
                error = _(
                    "PostgreSQL version is %(version)s. Minimum required "
                    "version is %(min_version)s.",
                    version=server_ver,
                    min_version=_version_string(min_postgresql_version))
                # Don't leak the connection we are refusing to hand out.
                cnx.close()
                raise TracError(error)
            self._postgresql_version = \
                'server: %s, client: %s' % (server_ver, client_ver)
            self.required = True
        return cnx

    def get_exceptions(self):
        """Return the DB-API module whose exception classes callers use."""
        return psycopg

    def init_db(self, path, schema=None, log=None, user=None, password=None,
                host=None, port=None, params={}):
        """Create the (non-public) schema if needed and all tables of
        `schema`, which defaults to `trac.db_default.schema`."""
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        cursor = cnx.cursor()
        if cnx.schema and cnx.schema != 'public':
            cursor.execute('CREATE SCHEMA ' + _quote(cnx.schema))
            cursor.execute('SET search_path TO %s', (cnx.schema,))
        if schema is None:
            from trac.db_default import schema
        for table in schema:
            for stmt in self.to_sql(table):
                cursor.execute(stmt)
        cnx.commit()

    def destroy_db(self, path, log=None, user=None, password=None, host=None,
                   port=None, params={}):
        """Drop a non-public schema with CASCADE, or else drop every table
        listed in the current schema."""
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        if cnx.schema and cnx.schema != 'public':
            cnx.execute('DROP SCHEMA %s CASCADE' % _quote(cnx.schema))
        else:
            for table in cnx.get_table_names():
                cnx.execute('DROP TABLE %s' % _quote(table))
        cnx.commit()

    def db_exists(self, path, log=None, user=None, password=None, host=None,
                  port=None, params={}):
        """Return whether the connection's schema exists in
        ``pg_namespace``."""
        cnx = self.get_connection(path, log, user, password, host, port,
                                  params)
        cursor = cnx.cursor()
        cursor.execute("""
            SELECT EXISTS(SELECT 1 FROM pg_namespace
                          WHERE nspname=%s);
            """, (cnx.schema,))
        return cursor.fetchone()[0]

    def to_sql(self, table):
        """Yield the CREATE TABLE statement for `table`, followed by one
        CREATE [UNIQUE] INDEX statement per index."""
        sql = ['CREATE TABLE %s (' % _quote(table.name)]
        coldefs = []
        for column in table.columns:
            ctype = column.type
            ctype = _type_map.get(ctype, ctype)
            if column.auto_increment:
                ctype = 'SERIAL'
            if len(table.key) == 1 and column.name in table.key:
                ctype += ' PRIMARY KEY'
            coldefs.append('    %s %s' % (_quote(column.name), ctype))
        if len(table.key) > 1:
            # Composite keys become a named table constraint.
            coldefs.append('    CONSTRAINT %s PRIMARY KEY (%s)' %
                           (_quote(table.name + '_pk'),
                            ','.join(_quote(col) for col in table.key)))
        sql.append(',\n'.join(coldefs) + '\n)')
        yield '\n'.join(sql)
        for index in table.indices:
            unique = 'UNIQUE' if index.unique else ''
            yield 'CREATE %s INDEX %s ON %s (%s)' % \
                  (unique,
                   _quote('%s_%s_idx' % (table.name, '_'.join(index.columns))),
                   _quote(table.name),
                   ','.join(_quote(col) for col in index.columns))

    def alter_column_types(self, table, columns):
        """Yield SQL statements altering the type of one or more columns of
        a table.

        Type changes are specified as a `columns` dict mapping column names
        to `(from, to)` SQL type tuples.
        """
        alterations = []
        for name, (from_, to) in sorted(columns.items()):
            to = _type_map.get(to, to)
            # Skip columns whose mapped type does not actually change.
            if to != _type_map.get(from_, from_):
                alterations.append((name, to))
        if alterations:
            yield 'ALTER TABLE %s %s' % \
                  (_quote(table),
                   ', '.join('ALTER COLUMN %s TYPE %s' % (_quote(name), type_)
                             for name, type_ in alterations))

    def backup(self, dest_file):
        """Dump the configured database schema with ``pg_dump``.

        The dump is gzip-compressed (``-Z 8``) and written to
        ``dest_file + '.gz'``, whose path is returned.

        :raises TracError: if pg_dump cannot be run, exits with a non-zero
                           status, or produces no output file.
        """
        db_url = self.env.config.get('trac', 'database')
        scheme, db_prop = parse_connection_uri(db_url)
        db_params = db_prop.setdefault('params', {})
        db_params.setdefault('schema', 'public')
        db_name = os.path.basename(db_prop['path'])

        args = [self.pg_dump_path, '-C', '--inserts', '-x', '-Z', '8']
        if 'user' in db_prop:
            args.extend(['-U', db_prop['user']])
        host = db_params.get('host', db_prop.get('host'))
        if host:
            args.extend(['-h', host])
            # A host containing '/' is a Unix socket directory: no port.
            if '/' not in host:
                args.extend(['-p', str(db_prop.get('port', '5432'))])

        # Need quote for -n (--schema) option; _quote also doubles any
        # embedded double quotes, as pg_dump's pattern syntax requires.
        args.extend(['-n', _quote(db_params['schema'])])

        dest_file += ".gz"
        args.extend(['-f', dest_file, db_name])

        environ = os.environ.copy()
        if 'password' in db_prop:
            # Pass the password via the environment, not the command line.
            environ['PGPASSWORD'] = str(db_prop['password'])
        try:
            p = Popen(args, env=environ, stderr=PIPE, close_fds=close_fds)
        except OSError as e:
            raise TracError(_("Unable to run %(path)s: %(msg)s",
                              path=self.pg_dump_path,
                              msg=exception_to_unicode(e)))
        errmsg = p.communicate()[1]
        if p.returncode != 0:
            raise TracError(_("pg_dump failed: %(msg)s",
                              msg=to_unicode(errmsg.strip())))
        if not os.path.exists(dest_file):
            raise TracError(_("No destination file created"))
        return dest_file

    def get_system_info(self):
        """Yield PostgreSQL server/client and psycopg2 version info."""
        yield 'PostgreSQL', self._postgresql_version
        yield 'psycopg2', psycopg2_version

    @lazy
    def _client_version(self):
        """libpq client version as a tuple, or `None` if unavailable."""
        version = None
        if hasattr(psycopg, 'libpq_version'):
            version = psycopg.libpq_version()
        elif _libpq_pathname:
            try:
                lib = ctypes.CDLL(_libpq_pathname)
                version = lib.PQlibVersion()
            except Exception as e:
                # NOTE(review): no separator before %s, presumably because
                # exception_to_unicode(..., traceback=True) starts with a
                # newline; confirm before changing.
                self.log.warning("Exception caught while retrieving libpq's "
                                 "version%s",
                                 exception_to_unicode(e, traceback=True))
        return _version_tuple(version)

    def _pgdump_version(self):
        """Return the raw stdout (bytes) of ``pg_dump --version``.

        :raises TracError: if pg_dump cannot be run.
        """
        try:
            p = Popen([self.pg_dump_path, '--version'], stdout=PIPE,
                      close_fds=close_fds)
        except OSError as e:
            raise TracError(_("Unable to run %(path)s: %(msg)s",
                              path=self.pg_dump_path,
                              msg=exception_to_unicode(e)))
        return p.communicate()[0]


class PostgreSQLConnection(ConnectionBase, ConnectionWrapper):
    """Connection wrapper for PostgreSQL."""

    poolable = True

    def __init__(self, path, log=None, user=None, password=None, host=None,
                 port=None, params={}):
        if path.startswith('/'):
            path = path[1:]
        # A `host` URI parameter overrides the host from the URI netloc.
        if 'host' in params:
            host = params['host']

        cnx = psycopg.connect(assemble_pg_dsn(path, user, password, host,
                                              port))

        cnx.set_client_encoding('UNICODE')
        self.schema = params.get('schema', 'public')
        if self.schema != 'public':
            try:
                cnx.cursor().execute('SET search_path TO %s', (self.schema,))
                cnx.commit()
            except (DataError, ProgrammingError):
                # probably the schema doesn't exist
                cnx.rollback()
        ConnectionWrapper.__init__(self, cnx, log)

    def cursor(self):
        """Return a new cursor wrapped in an `IterableCursor`."""
        return IterableCursor(self.cnx.cursor(), self.log)

    def cast(self, column, type):
        """Return SQL casting `column` to `type` (abstract types mapped)."""
        # Temporary hack needed for the union of selects in the search module
        return 'CAST(%s AS %s)' % (column, _type_map.get(type, type))

    def concat(self, *args):
        """Return SQL concatenating the given expressions with ``||``."""
        return '||'.join(args)

    def drop_column(self, table, column):
        """Drop `column` from `table` if it exists."""
        self.execute("""
            ALTER TABLE %s DROP COLUMN IF EXISTS %s
            """ % (self.quote(table), self.quote(column)))

    def drop_table(self, table):
        """Drop `table` if it exists."""
        self.execute("DROP TABLE IF EXISTS " + self.quote(table))

    def get_column_names(self, table):
        """Return the column names of `table` in the current schema, in
        ordinal order."""
        rows = self.execute("""
            SELECT column_name FROM information_schema.columns
            WHERE table_schema=current_schema() AND table_name=%s
            ORDER BY ordinal_position
            """, (table,))
        return [row[0] for row in rows]

    def get_last_id(self, cursor, table, column='id'):
        """Return CURRVAL of the ``<table>_<column>_seq`` sequence."""
        cursor.execute("SELECT CURRVAL(%s)",
                       (self.quote(self._sequence_name(table, column)),))
        return cursor.fetchone()[0]

    def get_sequence_names(self):
        """Return, sorted, the names of tables that have a matching
        ``<table>_id_seq`` sequence in the current search schemas."""
        seqs = [name[:-len('_id_seq')] for name, in self.execute("""
                SELECT c.relname
                FROM pg_class c
                INNER JOIN pg_namespace n ON c.relnamespace = n.oid
                WHERE n.nspname = ANY (current_schemas(false))
                AND c.relkind='S' AND c.relname LIKE %s ESCAPE '!'
                """, ('%!_id!_seq',))]
        # Query the table names once, rather than once per sequence.
        table_names = set(self.get_table_names())
        return sorted(name for name in seqs if name in table_names)

    def get_table_names(self):
        """Return the names listed in information_schema.tables for the
        current schema."""
        rows = self.execute("""
            SELECT table_name FROM information_schema.tables
            WHERE table_schema=current_schema()""")
        return [row[0] for row in rows]

    def has_table(self, table):
        """Return whether `table` has any column in the current schema."""
        rows = self.execute("""
            SELECT EXISTS (SELECT * FROM information_schema.columns
                           WHERE table_schema=current_schema()
                           AND table_name=%s)
            """, (table,))
        return rows[0][0]

    def like(self):
        """Return a case-insensitive LIKE clause using '/' as escape."""
        return "ILIKE %s ESCAPE '/'"

    def like_escape(self, text):
        """Escape '/', '_' and '%' in `text` for use with `like()`."""
        return _like_escape_re.sub(r'/\1', text)

    def ping(self):
        """Run a trivial query to check the connection is usable."""
        cursor = self.cnx.cursor()
        cursor.execute('SELECT 1')

    def prefix_match(self):
        """Return a case-sensitive LIKE clause using '/' as escape."""
        return "LIKE %s ESCAPE '/'"

    def prefix_match_value(self, prefix):
        """Return the escaped LIKE pattern matching `prefix`."""
        return self.like_escape(prefix) + '%'

    def quote(self, identifier):
        """Return `identifier` as a quoted SQL identifier."""
        return _quote(identifier)

    def reset_tables(self):
        """Restart every sequence of this connection's schema at 1, delete
        all rows of every table, and return the table names."""
        # reset sequences
        cursor = self.cursor()
        cursor.execute("""
            SELECT sequence_name FROM information_schema.sequences
            WHERE sequence_schema=%s
            """, (self.schema,))
        for seq, in cursor.fetchall():
            # Quote and schema-qualify: the catalog returns the exact
            # (case-sensitive) name, which unquoted SQL would case-fold.
            cursor.execute("ALTER SEQUENCE %s.%s RESTART WITH 1"
                           % (self.quote(self.schema), self.quote(seq)))
        # clear tables
        table_names = self.get_table_names()
        for name in table_names:
            cursor.execute("DELETE FROM " + self.quote(name))
        # PostgreSQL supports TRUNCATE TABLE as well
        # (see https://www.postgresql.org/docs/9.1/static/sql-truncate.html)
        # but on the small tables used here, DELETE is actually much faster
        return table_names

    def update_sequence(self, cursor, table, column='id'):
        """Set the ``<table>_<column>_seq`` sequence to MAX(`column`)."""
        cursor.execute("SELECT SETVAL(%%s, (SELECT MAX(%s) FROM %s))"
                       % (self.quote(column), self.quote(table)),
                       (self.quote(self._sequence_name(table, column)),))

    def _sequence_name(self, table, column):
        return '%s_%s_seq' % (table, column)

    @lazy
    def server_version(self):
        """Server version as a tuple, from psycopg's integer version."""
        return _version_tuple(self.cnx.server_version)