From 27d4458613a5b61f16ad9bf59ca1de460fea3b3a Mon Sep 17 00:00:00 2001 From: Timotej Lazar Date: Mon, 9 Jan 2017 18:07:23 +0100 Subject: First commit is the best commit --- db/__init__.py | 82 +++++++++++++++++++++++++++ db/models.py | 173 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 db/__init__.py create mode 100644 db/models.py (limited to 'db') diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..783df6c --- /dev/null +++ b/db/__init__.py @@ -0,0 +1,82 @@ +# CodeQ: an online programming tutor. +# Copyright (C) 2015 UL FRI +# +# This program is free software: you can redistribute it and/or modify it under +# the terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +# details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import os +import threading +import psycopg2 + +__all__ = ['get_connection', 'return_connection', 'setup', 'models'] + +_module_access_lock = threading.Lock() + +_host = None # the database hostname/IP +_port = None # the database port number +_database = None # the name of the database +_username = None # the username to access the database +_password = None # the password to access the database + + +# database parameters setup + +def _get_port(): + try: + return int(os.environ.get('CODEQ_DB_PORT')) + except: + return 5432 + + +def setup( + host=os.environ.get('CODEQ_DB_HOST') or 'localhost', + port=_get_port(), + database=os.environ.get('CODEQ_DB_DATABASE') or 'codeq', + username=os.environ.get('CODEQ_DB_USER') or 'codeq', + password=os.environ.get('CODEQ_DB_PASS') or 'c0d3q' +): + """Sets the database location and authentication parameters.""" + global _host, _port, _database, _username, _password + _host = host + _port = port + _database = database + _username = username + _password = password + +# connection pooling + +_connection_pool = [] + + +def get_connection(): + """Retrieves a database connection from the connection pool.""" + with _module_access_lock: + if _host is None: + setup() # lazy init + if len(_connection_pool) > 0: + return _connection_pool.pop() + return psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password) + + +def return_connection(connection): + """Returns the given database connection to the pool.""" + try: + connection.rollback() # sanity check + except: + return + with _module_access_lock: + _connection_pool.append(connection) + + +if __name__ == '__main__': + setup() diff --git a/db/models.py b/db/models.py new file mode 100644 index 0000000..9edbec4 --- /dev/null +++ b/db/models.py @@ -0,0 +1,173 @@ +# CodeQ: an online programming tutor. +# Copyright (C) 2015 UL FRI +# +# This program is free software: you can redistribute it and/or modify it under +# the terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) any +# later version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +# details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import collections + +from . import get_connection, return_connection + +__all__ = ['CodeqUser', 'Solution'] + +class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'name', 'email', 'is_admin', 'is_active', 'date_joined', 'last_login', 'gui_lang', 'robot_address', 'saml_data', 'gui_layout'])): + __sql_prefix = 'select id, username, password, name, email, is_admin, is_active, date_joined, last_login, gui_lang, robot_address, saml_data, gui_layout from codeq_user' + + @staticmethod + def get(**kwargs): + return _general_get(kwargs, CodeqUser, CodeqUser.__sql_prefix) + + @staticmethod + def list(): + return _general_list(CodeqUser, CodeqUser.__sql_prefix) + + @staticmethod + def filter(**kwargs): + return _general_filter(kwargs, CodeqUser, CodeqUser.__sql_prefix) + + @staticmethod + def solved_problems(user_id, language): + return _run_sql('select g.identifier, p.identifier from solution s inner join problem p on p.id = s.problem_id inner join problem_group g on g.id = p.problem_group_id inner join language l on l.id = p.language_id where s.codeq_user_id = %s and l.identifier = %s and s.done = True', (user_id, language), fetch_one=False) + +class Problem(collections.namedtuple('Problem', ['id', 'language', 'group', 'identifier'])): + __sql_prefix = '''\ + select p.id, l.identifier, g.identifier, p.identifier + from problem p + inner join problem_group g on g.id = p.problem_group_id + inner join language l on l.id = p.language_id''' + __sql_order = 'p.language_id, p.problem_group_id, p.id' + + @staticmethod + def get(**kwargs): + kwargs = {'p.'+k: v for k, v in kwargs.items()} + return _general_get(kwargs, Problem, Problem.__sql_prefix) + + @staticmethod + def list(): + return _general_list(Problem, Problem.__sql_prefix, order=Problem.__sql_order) + + @staticmethod + def filter(**kwargs): + kwargs = {'p.'+k: v for k, v in kwargs.items()} + return _general_filter(kwargs, Problem, Problem.__sql_prefix, order=Problem.__sql_order) + + # get a list of problems with the given language identifier + @staticmethod + def filter_language(language): + kwargs = {'l.identifier': language} + return _general_filter(kwargs, Problem, Problem.__sql_prefix, order=Problem.__sql_order) + +# known as Attempt in the original code +class Solution(collections.namedtuple('Solution', ['id', 'done', 'content', 'problem_id', 'codeq_user_id', 'trace'])): + __sql_prefix = 'select id, done, content, problem_id, codeq_user_id, trace from solution' + + @staticmethod + def get(**kwargs): + return _general_get(kwargs, Solution, Solution.__sql_prefix) + + @staticmethod + def list(): + return _general_list(Solution, Solution.__sql_prefix) + + @staticmethod + def filter(**kwargs): + return _general_filter(kwargs, Solution, Solution.__sql_prefix) + + +def _no_row_conversion(row): + return row + +def _general_get(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion): + conditions = [] + parameters = [] + for field_name, field_value in kwargs_dict.items(): + conditions.append(field_name + ' = %s') + parameters.append(field_value) + if len(conditions) == 0: + return None + conn = get_connection() + try: + cur = conn.cursor('crsr1') # a named cursor: scrolling is done on the server + cur.arraysize = 1 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' where ' + ' and '.join(conditions), parameters) + row = cur.fetchone() + if row: + return clazz(*row_conversion_fn(row)) + return None + finally: + cur.close() + finally: + conn.commit() + return_connection(conn) + +def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion, order='id'): + conditions = [] + parameters = [] + for field_name, field_value in kwargs_dict.items(): + conditions.append(field_name + ' = %s') + parameters.append(field_value) + if len(conditions) == 0: + return _general_list(clazz, sql_select) + conn = get_connection() + try: + cur = conn.cursor('crsr2') # a named cursor: scrolling is done on the server + cur.arraysize = 10000 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' where ' + ' and '.join(conditions) + ' order by ' + order, parameters) + result = [] + row = cur.fetchone() + while row: + result.append(clazz(*row_conversion_fn(row))) + row = cur.fetchone() + return result + finally: + cur.close() + finally: + conn.commit() + return_connection(conn) + +def _general_list(clazz, sql_select, row_conversion_fn=_no_row_conversion, order='id'): + conn = get_connection() + try: + cur = conn.cursor('crsr3') # a named cursor: scrolling is done on the server + cur.arraysize = 10000 # scroll unit in the number of rows + try: + cur.execute(sql_select + ' order by ' + order) + result = [] + row = cur.fetchone() + while row: + result.append(clazz(*row_conversion_fn(row))) + row = cur.fetchone() + return result + finally: + cur.close() + finally: + conn.commit() + return_connection(conn) + +def _run_sql(sql, params, fetch_one=False): + conn = get_connection() + try: + cur = conn.cursor() + try: + cur.execute(sql, params) + if fetch_one: + return cur.fetchone() + else: + return cur.fetchall() + finally: + cur.close() + finally: + conn.commit() + return_connection(conn) -- cgit v1.2.1