summaryrefslogtreecommitdiff
path: root/db/models.py
blob: 994e8194d7cf38dd08b42c708408db9e60220295 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os, collections, psycopg2, json

# Connection settings.  They start out unset; init() fills them in from the
# environment and _get_connection() reads them when it opens the connection.
_host = None      # database server hostname or IP address
_port = None      # database server port number
_database = None  # name of the database to connect to
_username = None  # user name to log in with
_password = None  # password for that user

# The one connection shared by all queries in this module; opened lazily by
# _get_connection().
_conn = None

class CodeqUser(collections.namedtuple('CodeqUser', ['id', 'username', 'password', 'first_name', 'last_name', 'email', 'is_superuser', 'is_staff', 'is_active', 'date_joined', 'last_login'])):
    """One row of the ``codeq_user`` table as an immutable named tuple.

    The tuple fields correspond, in order, to the columns listed in
    ``__sql_prefix``; keep the two lists in sync.
    """
    # The namedtuple fields are the whole state, so skip the per-instance
    # __dict__ a plain subclass would otherwise get.
    __slots__ = ()

    __sql_prefix = 'select id, username, password, first_name, last_name, email, is_superuser, is_staff, is_active, date_joined, last_login from codeq_user'

    @staticmethod
    def get(**kwargs):
        """Return one user matching every ``column=value`` keyword, or None.

        None is also returned when no keyword is given.  Keyword names are
        inserted into the SQL verbatim, so they must be trusted column names.
        """
        return _general_get(kwargs, CodeqUser, CodeqUser.__sql_prefix)

    @staticmethod
    def list():
        """Return all users as a list, ordered by id."""
        return _general_list(CodeqUser, CodeqUser.__sql_prefix)

    @staticmethod
    def filter(**kwargs):
        """Return the users matching every ``column=value`` keyword, ordered by id.

        With no keywords, all users are returned.
        """
        return _general_filter(kwargs, CodeqUser, CodeqUser.__sql_prefix)

# known as Attempt in the original code
class Solution(collections.namedtuple('Solution', ['id', 'done', 'content', 'problem_id', 'codeq_user_id', 'trace'])):
    """One row of the ``solution`` table as an immutable named tuple.

    ``trace`` is selected as text (``trace::text``) and decoded with
    ``json.loads``; a NULL trace becomes None.
    """
    # The namedtuple fields are the whole state, so skip the per-instance
    # __dict__ a plain subclass would otherwise get.
    __slots__ = ()

    __sql_prefix = 'select id, done, content, problem_id, codeq_user_id, trace::text from solution'

    @staticmethod
    def __row_conversion(row):
        """Decode the JSON text in column 5 (trace); other columns pass through."""
        trace = row[5]
        # json.loads(None) raises TypeError, so map a NULL trace to None.
        decoded = json.loads(trace) if trace is not None else None
        return (row[0], row[1], row[2], row[3], row[4], decoded)

    @staticmethod
    def get(**kwargs):
        """Return one solution matching every ``column=value`` keyword, or None.

        None is also returned when no keyword is given.  Keyword names are
        inserted into the SQL verbatim, so they must be trusted column names.
        """
        return _general_get(kwargs, Solution, Solution.__sql_prefix, Solution.__row_conversion)

    @staticmethod
    def list():
        """Return all solutions as a list, ordered by id."""
        return _general_list(Solution, Solution.__sql_prefix, Solution.__row_conversion)

    @staticmethod
    def filter(**kwargs):
        """Return the solutions matching every ``column=value`` keyword, ordered by id.

        With no keywords, all solutions are returned.
        """
        return _general_filter(kwargs, Solution, Solution.__sql_prefix, Solution.__row_conversion)


def _get_connection():
    """Return the module's shared psycopg2 connection, opening it on demand.

    The connection is created from the settings stored by init() the first
    time it is needed.  It is also re-created when the cached one has been
    closed: psycopg2 sets ``connection.closed`` to a nonzero value for a
    closed or broken connection, and without this check a dropped
    connection would stay cached forever.
    """
    # NOTE(review): nothing in this module commits or rolls back this
    # connection; confirm how callers are expected to end transactions.
    global _conn  # only _conn is assigned; the settings are just read
    if _conn is None or _conn.closed:
        _conn = psycopg2.connect(host=_host, port=_port, database=_database, user=_username, password=_password)
    return _conn

def _no_row_conversion(row):
    return row

def _general_get(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion):
    """Fetch a single row matching every entry of *kwargs_dict*.

    Each key/value pair becomes a ``<key> = %s`` test; the tests are ANDed
    and appended to *sql_select* as a WHERE clause.  Values travel as query
    parameters, but keys are spliced into the SQL verbatim and must be
    trusted column names.  No ORDER BY is added, so if several rows match,
    the one returned is whichever the server yields first.

    Args:
        kwargs_dict: mapping of column name to required value.
        clazz: record type; the converted row is unpacked into its constructor.
        sql_select: SELECT statement without a WHERE clause.
        row_conversion_fn: applied to the raw row before building *clazz*.

    Returns:
        A *clazz* instance, or None when *kwargs_dict* is empty or nothing matches.
    """
    if not kwargs_dict:
        return None
    names = list(kwargs_dict)
    where_clause = ' and '.join(name + ' = %s' for name in names)
    values = [kwargs_dict[name] for name in names]
    cursor = _get_connection().cursor('crsr1')  # named (server-side) cursor
    cursor.arraysize = 1  # only one row is wanted
    try:
        cursor.execute(sql_select + ' where ' + where_clause, values)
        first = cursor.fetchone()
        return clazz(*row_conversion_fn(first)) if first else None
    finally:
        cursor.close()

def _general_filter(kwargs_dict, clazz, sql_select, row_conversion_fn=_no_row_conversion):
    """Return every row matching all entries of *kwargs_dict*, ordered by id.

    Each key/value pair becomes a ``<key> = %s`` test; the tests are ANDed
    and appended to *sql_select* as a WHERE clause, followed by
    ``order by id``.  Values travel as query parameters, but keys are
    spliced into the SQL verbatim and must be trusted column names.

    Args:
        kwargs_dict: mapping of column name to required value; when empty,
            all rows are returned (delegates to _general_list).
        clazz: record type; each converted row is unpacked into its constructor.
        sql_select: SELECT statement without WHERE/ORDER BY clauses.
        row_conversion_fn: applied to each raw row before building *clazz*.

    Returns:
        A list of *clazz* instances (possibly empty).
    """
    if not kwargs_dict:
        # Forward row_conversion_fn so unfiltered results are converted the
        # same way as filtered ones (e.g. Solution.trace gets JSON-decoded).
        return _general_list(clazz, sql_select, row_conversion_fn)
    fields = list(kwargs_dict)
    where_clause = ' and '.join(field + ' = %s' for field in fields)
    parameters = [kwargs_dict[field] for field in fields]
    conn = _get_connection()
    cur = conn.cursor('crsr2')  # a named cursor: the result set stays on the server
    # Rows requested per fetch.  arraysize only applies to fetchmany(), so
    # rows are read in fetchmany() batches rather than with fetchone().
    cur.arraysize = 10000
    try:
        cur.execute(sql_select + ' where ' + where_clause + ' order by id', parameters)
        result = []
        while True:
            batch = cur.fetchmany()
            if not batch:
                break
            result.extend(clazz(*row_conversion_fn(row)) for row in batch)
        return result
    finally:
        cur.close()

def _general_list(clazz, sql_select, row_conversion_fn=_no_row_conversion):
    """Return every row produced by *sql_select*, ordered by id.

    Args:
        clazz: record type; each converted row is unpacked into its constructor.
        sql_select: SELECT statement without WHERE/ORDER BY clauses;
            ``order by id`` is appended.
        row_conversion_fn: applied to each raw row before building *clazz*.

    Returns:
        A list of *clazz* instances (possibly empty).
    """
    conn = _get_connection()
    cur = conn.cursor('crsr3')  # a named cursor: the result set stays on the server
    # Rows requested per fetch.  arraysize only applies to fetchmany(), so
    # rows are read in fetchmany() batches rather than with fetchone().
    cur.arraysize = 10000
    try:
        cur.execute(sql_select + ' order by id')
        result = []
        while True:
            batch = cur.fetchmany()
            if not batch:
                break
            result.extend(clazz(*row_conversion_fn(row)) for row in batch)
        return result
    finally:
        cur.close()

def init():
    """Load the database connection settings from the environment.

    Reads CODEQ_DB_HOST, CODEQ_DB_PORT, CODEQ_DB_DATABASE, CODEQ_DB_USER and
    CODEQ_DB_PASS.  A variable that is unset or empty falls back to its
    default ('localhost', 5432, 'codeq', 'codeq', 'c0d3q').  A port that is
    not an integer, or is 0, also falls back to 5432.

    Only the module-level settings are stored; init() does not open or
    touch the connection itself.
    """
    global _host, _port, _database, _username, _password
    _host = os.environ.get('CODEQ_DB_HOST') or 'localhost'
    try:
        _port = int(os.environ.get('CODEQ_DB_PORT')) or 5432
    except (TypeError, ValueError):  # unset (int(None)) or not a number
        _port = 5432
    _database = os.environ.get('CODEQ_DB_DATABASE') or 'codeq'
    _username = os.environ.get('CODEQ_DB_USER') or 'codeq'
    # NOTE(review): hard-coded fallback password; presumably a development
    # default — confirm it is always overridden in deployments.
    _password = os.environ.get('CODEQ_DB_PASS') or 'c0d3q'

# When run as a script, only load the connection settings from the
# environment; no connection is opened and no query is issued here.
if __name__ == '__main__':
    init()