# HG changeset patch # User Pascal Volk # Date 1281332576 0 # Node ID 48bf20b43f2ef99b6b465cc679da4673f619a5cc # Parent a653c43048b18657db0cdb1654e761da22c5f3fc VMM/handler: Added support for psycopg2. diff -r a653c43048b1 -r 48bf20b43f2e VirtualMailManager/handler.py --- a/VirtualMailManager/handler.py Mon Aug 09 03:52:01 2010 +0000 +++ b/VirtualMailManager/handler.py Mon Aug 09 05:42:56 2010 +0000 @@ -18,8 +18,6 @@ from shutil import rmtree from subprocess import Popen, PIPE -from pyPgSQL import PgSQL # python-pgsql - http://pypgsql.sourceforge.net - from VirtualMailManager.account import Account from VirtualMailManager.alias import Alias from VirtualMailManager.aliasdomain import AliasDomain @@ -30,7 +28,7 @@ DATABASE_ERROR, DOMAINDIR_GROUP_MISMATCH, DOMAIN_INVALID, \ FOUND_DOTS_IN_PATH, INVALID_ARGUMENT, MAILDIR_PERM_MISMATCH, \ NOT_EXECUTABLE, NO_SUCH_ACCOUNT, NO_SUCH_ALIAS, NO_SUCH_BINARY, \ - NO_SUCH_DIRECTORY, NO_SUCH_RELOCATED, RELOCATED_EXISTS + NO_SUCH_DIRECTORY, NO_SUCH_RELOCATED, RELOCATED_EXISTS, VMM_ERROR from VirtualMailManager.domain import Domain, get_gid from VirtualMailManager.emailaddress import EmailAddress from VirtualMailManager.errors import \ @@ -42,6 +40,7 @@ _ = lambda msg: msg +_db_mod = None CFG_FILE = 'vmm.cfg' CFG_PATH = '/root:/usr/local/etc:/etc' @@ -59,7 +58,7 @@ class Handler(object): """Wrapper class to simplify the access on all the stuff from VirtualMailManager""" - __slots__ = ('_cfg', '_cfg_fname', '_dbh', '_warnings') + __slots__ = ('_cfg', '_cfg_fname', '_db_connect', '_dbh', '_warnings') def __init__(self, skip_some_checks=False): """Creates a new Handler instance. @@ -75,6 +74,7 @@ self._warnings = [] self._cfg = None self._dbh = None + self._db_connect = None if os.geteuid(): raise NotRootError(_(u"You are not root.\n\tGood bye!\n"), @@ -85,6 +85,7 @@ if not skip_some_checks: self._cfg.check() self._chkenv() + self._set_db_connect() def _find_cfg_file(self): """Search the CFG_FILE in CFG_PATH. @@ -143,21 +144,62 @@ else: raise - def _db_connect(self): + def _set_db_connect(self): + """check which module to use and set self._db_connect""" + global _db_mod + if self._cfg.dget('database.module').lower() == 'psycopg2': + try: + _db_mod = __import__('psycopg2') + except ImportError: + raise VMMError(_(u"Unable to import database module '%s'") % + 'psycopg2', VMM_ERROR) + self._db_connect = self._psycopg2_connect + else: + try: + tmp = __import__('pyPgSQL', globals(), locals(), ['PgSQL']) + except ImportError: + raise VMMError(_(u"Unable to import database module '%s'") % + 'pyPgSQL', VMM_ERROR) + _db_mod = tmp.PgSQL + self._db_connect = self._pypgsql_connect + + def _pypgsql_connect(self): """Creates a pyPgSQL.PgSQL.connection instance.""" - if self._dbh is None or (isinstance(self._dbh, PgSQL.Connection) and + if self._dbh is None or (isinstance(self._dbh, _db_mod.Connection) and not self._dbh._isOpen): try: - self._dbh = PgSQL.connect( + self._dbh = _db_mod.connect( database=self._cfg.dget('database.name'), user=self._cfg.pget('database.user'), host=self._cfg.dget('database.host'), + port=self._cfg.dget('database.port'), password=self._cfg.pget('database.pass'), client_encoding='utf8', unicode_results=True) dbc = self._dbh.cursor() dbc.execute("SET NAMES 'UTF8'") dbc.close() - except PgSQL.libpq.DatabaseError, err: + except _db_mod.libpq.DatabaseError, err: + raise VMMError(str(err), DATABASE_ERROR) + + def _psycopg2_connect(self): + """Return a new psycopg2 connection object.""" + if self._dbh is None or \ + (isinstance(self._dbh, _db_mod.extensions.connection) and + self._dbh.closed): + try: + self._dbh = _db_mod.connect( + host=self._cfg.dget('database.host'), + sslmode=self._cfg.dget('database.sslmode'), + port=self._cfg.dget('database.port'), + database=self._cfg.dget('database.name'), + user=self._cfg.pget('database.user'), + password=self._cfg.pget('database.pass')) + self._dbh.set_client_encoding('utf8') + _db_mod.extensions.register_type(_db_mod.extensions.UNICODE) + dbc = self._dbh.cursor() + dbc.execute("SET NAMES 'UTF8'") + dbc.close() + except _db_mod.DatabaseError, err: raise VMMError(str(err), DATABASE_ERROR) def _chk_other_address_types(self, address, exclude):