diff options
-rw-r--r-- | db/__init__.py | 60 | ||||
-rw-r--r-- | db/models.py | 106 |
2 files changed, 108 insertions, 58 deletions
diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e02a1fe --- /dev/null +++ b/db/__init__.py @@ -0,0 +1,60 @@ +# coding=utf-8 + +import os + +import psycopg2 + +__all__ = ['get_connection', 'return_connection', 'setup', 'models'] + +_host = None # the database hostname/IP +_port = None # the database port number +_database = None # the name of the database +_username = None # the username to access the database +_password = None # the password to access the database + + +# database parameters setup + +def _get_port(): + try: + return int(os.environ.get('CODEQ_DB_PORT')) + except: + return 5432 + + +def setup( + host=os.environ.get('CODEQ_DB_HOST') or 'localhost', + port=_get_port(), + database=os.environ.get('CODEQ_DB_DATABASE') or 'codeq', + username=os.environ.get('CODEQ_DB_USER') or 'codeq', + password=os.environ.get('CODEQ_DB_PASS') or 'c0d3q' +): + """Sets the database location and authentication parameters.""" + global _host, _port, _database, _username, _password + _host = host + _port = port + _database = database + _username = username + _password = password + +# connection pooling + +_connection_pool = [] + + +def get_connection(): + """Retrieves a database connection from the connection pool.""" + if _host is None: + setup() # lazy init + if len(_connection_pool) > 0: + return _connection_pool.pop() + return psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password) + + +def return_connection(connection): + """Returns the given database connection to the pool.""" + _connection_pool.append(connection) + + +if __name__ == '__main__': + setup() diff --git a/db/models.py b/db/models.py index 994e819..dcd17af 100644 --- a/db/models.py +++ b/db/models.py @@ -1,14 +1,15 @@ -import os, collections, psycopg2, json +# coding=utf-8 -_conn = None # the database connection -_host = None # the database hostname/IP -_port = None # the database port number -_database = None # the name of the database -_username = None # the username to access the database -_password = None # the password to access the database +import collections +import json -class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'first_name', 'last_name', 'email', 'is_superuser', 'is_staff', 'is_active', 'date_joined', 'last_login'])): - __sql_prefix = 'select id, username, password, first_name, last_name, email, is_superuser, is_staff, is_active, date_joined, last_login from codeq_user' +from . import get_connection, return_connection + +__all__ = ['CodeqUser', 'Solution'] + + +class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'name', 'email', 'is_admin', 'is_active', 'date_joined', 'last_login'])): + __sql_prefix = 'select id, username, password, name, email, is_admin, is_active, date_joined, last_login from codeq_user' @staticmethod def get(**kwargs): @@ -40,12 +41,6 @@ class Solution(collections.namedtuple('Solution', ['id', 'done', 'content', 'pro return _general_filter(kwargs, Solution, Solution.__sql_prefix, Solution.__row_conversion) -def _get_connection(): - global _conn, _host, _port, _database, _username, _password - if _conn == None: - _conn = psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password) - return _conn - def _no_row_conversion(row): return row @@ -57,17 +52,20 @@ def _general_get(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conve parameters.append(field_value) if len(conditions) == 0: return None - conn = _get_connection() - cur = conn.cursor('crsr1') # a named cursor: scrolling is done on the server - cur.arraysize = 1 # scroll unit in the number of rows + conn = get_connection() try: - cur.execute(sql_select + ' where ' + ' and '.join(conditions), parameters) - row = cur.fetchone() - if row: - return clazz(*row_conversion_fn(row)) - return None + cur = conn.cursor('crsr1') # a named cursor: scrolling is done on the server + cur.arraysize = 1 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' where ' + ' and '.join(conditions), parameters) + row = cur.fetchone() + if row: + return clazz(*row_conversion_fn(row)) + return None + finally: + cur.close() finally: - cur.close() + return_connection(conn) def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion): conditions = [] @@ -77,45 +75,37 @@ def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_co parameters.append(field_value) if len(conditions) == 0: return _general_list(clazz, sql_select) - conn = _get_connection() - cur = conn.cursor('crsr2') # a named cursor: scrolling is done on the server - cur.arraysize = 10000 # scroll unit in the number of rows + conn = get_connection() try: - cur.execute(sql_select + ' where ' + ' and '.join(conditions) + ' order by id', parameters) - result = [] - row = cur.fetchone() - while row: - result.append(clazz(*row_conversion_fn(row))) + cur = conn.cursor('crsr2') # a named cursor: scrolling is done on the server + cur.arraysize = 10000 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' where ' + ' and '.join(conditions) + ' order by id', parameters) + result = [] row = cur.fetchone() - return result + while row: + result.append(clazz(*row_conversion_fn(row))) + row = cur.fetchone() + return result + finally: + cur.close() finally: - cur.close() + return_connection(conn) def _general_list(clazz, sql_select, row_conversion_fn=_no_row_conversion): - conn = _get_connection() - cur = conn.cursor('crsr3') # a named cursor: scrolling is done on the server - cur.arraysize = 10000 # scroll unit in the number of rows + conn = get_connection() try: - cur.execute(sql_select + ' order by id') - result = [] - row = cur.fetchone() - while row: - result.append(clazz(*row_conversion_fn(row))) + cur = conn.cursor('crsr3') # a named cursor: scrolling is done on the server + cur.arraysize = 10000 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' order by id') + result = [] row = cur.fetchone() - return result + while row: + result.append(clazz(*row_conversion_fn(row))) + row = cur.fetchone() + return result + finally: + cur.close() finally: - cur.close() - -def init(): - global _host, _port, _database, _username, _password - _host = os.environ.get('CODEQ_DB_HOST') or 'localhost' - try: - _port = int(os.environ.get('CODEQ_DB_PORT')) or 5432 - except: - _port = 5432 - _database = os.environ.get('CODEQ_DB_DATABASE') or 'codeq' - _username = os.environ.get('CODEQ_DB_USER') or 'codeq' - _password = os.environ.get('CODEQ_DB_PASS') or 'c0d3q' - -if __name__ == '__main__': - init() + return_connection(conn) |