VirtualMailManager/handler.py
branchv0.6.x
changeset 355 48bf20b43f2e
parent 351 4bba5fb90b78
child 366 d6573da35b5f
--- a/VirtualMailManager/handler.py	Mon Aug 09 03:52:01 2010 +0000
+++ b/VirtualMailManager/handler.py	Mon Aug 09 05:42:56 2010 +0000
@@ -18,8 +18,6 @@
 from shutil import rmtree
 from subprocess import Popen, PIPE
 
-from pyPgSQL import PgSQL  # python-pgsql - http://pypgsql.sourceforge.net
-
 from VirtualMailManager.account import Account
 from VirtualMailManager.alias import Alias
 from VirtualMailManager.aliasdomain import AliasDomain
@@ -30,7 +28,7 @@
      DATABASE_ERROR, DOMAINDIR_GROUP_MISMATCH, DOMAIN_INVALID, \
      FOUND_DOTS_IN_PATH, INVALID_ARGUMENT, MAILDIR_PERM_MISMATCH, \
      NOT_EXECUTABLE, NO_SUCH_ACCOUNT, NO_SUCH_ALIAS, NO_SUCH_BINARY, \
-     NO_SUCH_DIRECTORY, NO_SUCH_RELOCATED, RELOCATED_EXISTS
+     NO_SUCH_DIRECTORY, NO_SUCH_RELOCATED, RELOCATED_EXISTS, VMM_ERROR
 from VirtualMailManager.domain import Domain, get_gid
 from VirtualMailManager.emailaddress import EmailAddress
 from VirtualMailManager.errors import \
@@ -42,6 +40,7 @@
 
 
 _ = lambda msg: msg
+_db_mod = None
 
 CFG_FILE = 'vmm.cfg'
 CFG_PATH = '/root:/usr/local/etc:/etc'
@@ -59,7 +58,7 @@
 class Handler(object):
     """Wrapper class to simplify the access on all the stuff from
     VirtualMailManager"""
-    __slots__ = ('_cfg', '_cfg_fname', '_dbh', '_warnings')
+    __slots__ = ('_cfg', '_cfg_fname', '_db_connect', '_dbh', '_warnings')
 
     def __init__(self, skip_some_checks=False):
         """Creates a new Handler instance.
@@ -75,6 +74,7 @@
         self._warnings = []
         self._cfg = None
         self._dbh = None
+        self._db_connect = None
 
         if os.geteuid():
             raise NotRootError(_(u"You are not root.\n\tGood bye!\n"),
@@ -85,6 +85,7 @@
         if not skip_some_checks:
             self._cfg.check()
             self._chkenv()
+            self._set_db_connect()
 
     def _find_cfg_file(self):
         """Search the CFG_FILE in CFG_PATH.
@@ -143,21 +144,62 @@
                 else:
                     raise
 
-    def _db_connect(self):
+    def _set_db_connect(self):
+        """check which module to use and set self._db_connect"""
+        global _db_mod
+        if self._cfg.dget('database.module').lower() == 'psycopg2':
+            try:
+                _db_mod = __import__('psycopg2')
+            except ImportError:
+                raise VMMError(_(u"Unable to import database module '%s'") %
+                               'psycopg2', VMM_ERROR)
+            self._db_connect = self._psycopg2_connect
+        else:
+            try:
+                tmp = __import__('pyPgSQL', globals(), locals(), ['PgSQL'])
+            except ImportError:
+                raise VMMError(_(u"Unable to import database module '%s'") %
+                               'pyPgSQL', VMM_ERROR)
+            _db_mod = tmp.PgSQL
+            self._db_connect = self._pypgsql_connect
+
+    def _pypgsql_connect(self):
         """Creates a pyPgSQL.PgSQL.connection instance."""
-        if self._dbh is None or (isinstance(self._dbh, PgSQL.Connection) and
+        if self._dbh is None or (isinstance(self._dbh, _db_mod.Connection) and
                                   not self._dbh._isOpen):
             try:
-                self._dbh = PgSQL.connect(
+                self._dbh = _db_mod.connect(
                         database=self._cfg.dget('database.name'),
                         user=self._cfg.pget('database.user'),
                         host=self._cfg.dget('database.host'),
+                        port=self._cfg.dget('database.port'),
                         password=self._cfg.pget('database.pass'),
                         client_encoding='utf8', unicode_results=True)
                 dbc = self._dbh.cursor()
                 dbc.execute("SET NAMES 'UTF8'")
                 dbc.close()
-            except PgSQL.libpq.DatabaseError, err:
+            except _db_mod.libpq.DatabaseError, err:
+                raise VMMError(str(err), DATABASE_ERROR)
+
+    def _psycopg2_connect(self):
+        """Return a new psycopg2 connection object."""
+        if self._dbh is None or \
+          (isinstance(self._dbh, _db_mod.extensions.connection) and
+           self._dbh.closed):
+            try:
+                self._dbh = _db_mod.connect(
+                        host=self._cfg.dget('database.host'),
+                        sslmode=self._cfg.dget('database.sslmode'),
+                        port=self._cfg.dget('database.port'),
+                        database=self._cfg.dget('database.name'),
+                        user=self._cfg.pget('database.user'),
+                        password=self._cfg.pget('database.pass'))
+                self._dbh.set_client_encoding('utf8')
+                _db_mod.extensions.register_type(_db_mod.extensions.UNICODE)
+                dbc = self._dbh.cursor()
+                dbc.execute("SET NAMES 'UTF8'")
+                dbc.close()
+            except _db_mod.DatabaseError, err:
                 raise VMMError(str(err), DATABASE_ERROR)
 
     def _chk_other_address_types(self, address, exclude):