summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--db/__init__.py60
-rw-r--r--db/models.py106
2 files changed, 108 insertions, 58 deletions
diff --git a/db/__init__.py b/db/__init__.py
new file mode 100644
index 0000000..e02a1fe
--- /dev/null
+++ b/db/__init__.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+
+import os
+
+import psycopg2
+
+__all__ = ['get_connection', 'return_connection', 'setup', 'models']
+
+_host = None # the database hostname/IP
+_port = None # the database port number
+_database = None # the name of the database
+_username = None # the username to access the database
+_password = None # the password to access the database
+
+
+# database parameters setup
+
+def _get_port():
+ try:
+ return int(os.environ.get('CODEQ_DB_PORT'))
+ except:
+ return 5432
+
+
+def setup(
+ host=os.environ.get('CODEQ_DB_HOST') or 'localhost',
+ port=_get_port(),
+ database=os.environ.get('CODEQ_DB_DATABASE') or 'codeq',
+ username=os.environ.get('CODEQ_DB_USER') or 'codeq',
+ password=os.environ.get('CODEQ_DB_PASS') or 'c0d3q'
+):
+ """Sets the database location and authentication parameters."""
+ global _host, _port, _database, _username, _password
+ _host = host
+ _port = port
+ _database = database
+ _username = username
+ _password = password
+
+# connection pooling
+
+_connection_pool = []
+
+
+def get_connection():
+ """Retrieves a database connection from the connection pool."""
+ if _host is None:
+ setup() # lazy init
+ if len(_connection_pool) > 0:
+ return _connection_pool.pop()
+ return psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password)
+
+
+def return_connection(connection):
+ """Returns the given database connection to the pool."""
+ _connection_pool.append(connection)
+
+
+if __name__ == '__main__':
+ setup()
diff --git a/db/models.py b/db/models.py
index 994e819..dcd17af 100644
--- a/db/models.py
+++ b/db/models.py
@@ -1,14 +1,15 @@
-import os, collections, psycopg2, json
+# coding=utf-8
-_conn = None # the database connection
-_host = None # the database hostname/IP
-_port = None # the database port number
-_database = None # the name of the database
-_username = None # the username to access the database
-_password = None # the password to access the database
+import collections
+import json
-class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'first_name', 'last_name', 'email', 'is_superuser', 'is_staff', 'is_active', 'date_joined', 'last_login'])):
- __sql_prefix = 'select id, username, password, first_name, last_name, email, is_superuser, is_staff, is_active, date_joined, last_login from codeq_user'
+from . import get_connection, return_connection
+
+__all__ = ['CodeqUser', 'Solution']
+
+
+class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'name', 'email', 'is_admin', 'is_active', 'date_joined', 'last_login'])):
+ __sql_prefix = 'select id, username, password, name, email, is_admin, is_active, date_joined, last_login from codeq_user'
@staticmethod
def get(**kwargs):
@@ -40,12 +41,6 @@ class Solution(collections.namedtuple('Solution', ['id', 'done', 'content', 'pro
return _general_filter(kwargs, Solution, Solution.__sql_prefix, Solution.__row_conversion)
-def _get_connection():
- global _conn, _host, _port, _database, _username, _password
- if _conn == None:
- _conn = psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password)
- return _conn
-
def _no_row_conversion(row):
return row
@@ -57,17 +52,20 @@ def _general_get(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conve
parameters.append(field_value)
if len(conditions) == 0:
return None
- conn = _get_connection()
- cur = conn.cursor('crsr1') # a named cursor: scrolling is done on the server
- cur.arraysize = 1 # scroll unit in the number of rows
+ conn = get_connection()
try:
- cur.execute(sql_select + ' where ' + ' and '.join(conditions), parameters)
- row = cur.fetchone()
- if row:
- return clazz(*row_conversion_fn(row))
- return None
+ cur = conn.cursor('crsr1') # a named cursor: scrolling is done on the server
+ cur.arraysize = 1 # scroll unit in the number of rows
+ try:
+ cur.execute(sql_select + ' where ' + ' and '.join(conditions), parameters)
+ row = cur.fetchone()
+ if row:
+ return clazz(*row_conversion_fn(row))
+ return None
+ finally:
+ cur.close()
finally:
- cur.close()
+ return_connection(conn)
def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion):
conditions = []
@@ -77,45 +75,37 @@ def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_co
parameters.append(field_value)
if len(conditions) == 0:
return _general_list(clazz, sql_select)
- conn = _get_connection()
- cur = conn.cursor('crsr2') # a named cursor: scrolling is done on the server
- cur.arraysize = 10000 # scroll unit in the number of rows
+ conn = get_connection()
try:
- cur.execute(sql_select + ' where ' + ' and '.join(conditions) + ' order by id', parameters)
- result = []
- row = cur.fetchone()
- while row:
- result.append(clazz(*row_conversion_fn(row)))
+ cur = conn.cursor('crsr2') # a named cursor: scrolling is done on the server
+ cur.arraysize = 10000 # scroll unit in the number of rows
+ try:
+ cur.execute(sql_select + ' where ' + ' and '.join(conditions) + ' order by id', parameters)
+ result = []
row = cur.fetchone()
- return result
+ while row:
+ result.append(clazz(*row_conversion_fn(row)))
+ row = cur.fetchone()
+ return result
+ finally:
+ cur.close()
finally:
- cur.close()
+ return_connection(conn)
def _general_list(clazz, sql_select, row_conversion_fn=_no_row_conversion):
- conn = _get_connection()
- cur = conn.cursor('crsr3') # a named cursor: scrolling is done on the server
- cur.arraysize = 10000 # scroll unit in the number of rows
+ conn = get_connection()
try:
- cur.execute(sql_select + ' order by id')
- result = []
- row = cur.fetchone()
- while row:
- result.append(clazz(*row_conversion_fn(row)))
+ cur = conn.cursor('crsr3') # a named cursor: scrolling is done on the server
+ cur.arraysize = 10000 # scroll unit in the number of rows
+ try:
+ cur.execute(sql_select + ' order by id')
+ result = []
row = cur.fetchone()
- return result
+ while row:
+ result.append(clazz(*row_conversion_fn(row)))
+ row = cur.fetchone()
+ return result
+ finally:
+ cur.close()
finally:
- cur.close()
-
-def init():
- global _host, _port, _database, _username, _password
- _host = os.environ.get('CODEQ_DB_HOST') or 'localhost'
- try:
- _port = int(os.environ.get('CODEQ_DB_PORT')) or 5432
- except:
- _port = 5432
- _database = os.environ.get('CODEQ_DB_DATABASE') or 'codeq'
- _username = os.environ.get('CODEQ_DB_USER') or 'codeq'
- _password = os.environ.get('CODEQ_DB_PASS') or 'c0d3q'
-
-if __name__ == '__main__':
- init()
+ return_connection(conn)