summaryrefslogtreecommitdiff
path: root/python/util.py
blob: f04fcc26cf0d40d8937495a90019b4155d52f434 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/python

# CodeQ: an online programming tutor.
# Copyright (C) 2015 UL FRI
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import io
import re
from tokenize import tokenize, TokenError
import ast

def get_tokens(code):
    """Return the strings of all non-empty tokens in *code*.

    The code is encoded as UTF-8 and passed to ``tokenize.tokenize``, so the
    first element is the ENCODING token's string (normally ``'utf-8'``).
    NEWLINE/NL/INDENT tokens are kept. Tokens whose string is empty (such as
    DEDENT and ENDMARKER) are skipped.

    Returns an empty list when *code* cannot be tokenized.
    """
    try:
        stream = io.BytesIO(code.encode('utf-8'))
        return [t.string for t in tokenize(stream.readline) if t.string]
    except (TokenError, SyntaxError):
        # TokenError: e.g. EOF inside an unclosed bracket or string.
        # SyntaxError (notably IndentationError): raised directly by the
        # tokenizer on an inconsistent dedent; it is not a TokenError.
        return []

def all_tokens(code):
    """Return every ``TokenInfo`` produced by tokenizing *code*.

    Unlike get_tokens, nothing is filtered out: the ENCODING, DEDENT,
    ENDMARKER and other zero-width tokens are all included, and each item
    carries type and position information as well as the token string.

    Returns an empty list when *code* cannot be tokenized.
    """
    try:
        stream = io.BytesIO(code.encode('utf-8'))
        return list(tokenize(stream.readline))
    except (TokenError, SyntaxError):
        # TokenError: e.g. EOF inside an unclosed bracket or string.
        # SyntaxError (notably IndentationError): raised directly by the
        # tokenizer on an inconsistent dedent; it is not a TokenError.
        return []

def has_token_sequence(tokens, sequence):
    """Return True if *sequence* occurs as a contiguous run inside *tokens*.

    Slices of *tokens* are compared to *sequence* with ``==``, so both should
    be the same sequence type (e.g. the list from get_tokens and a list of
    token strings). An empty *sequence* always matches.
    """
    width = len(sequence)
    last_start = len(tokens) - width
    return any(tokens[start:start + width] == sequence
               for start in range(last_start + 1))

def almost_equal(a, b, prec):
    """Return True if *a* and *b* agree to within a tolerance of 10**-prec.

    The tolerance is scaled by max(|a|, |b|, 1). When both magnitudes are at
    most 1, it acts as an absolute tolerance of 10**-prec. Otherwise it acts
    as a relative tolerance, i.e. roughly *prec* significant digits.
    """
    scale = max(abs(a), abs(b), 1)
    tolerance = scale * 10 ** (-prec)
    return abs(a - b) <= tolerance

def get_numbers(s):
    """Return every number found in string *s*, as floats, in order of appearance.

    Integers, decimals (including forms like '.5' and '5.'), and optional
    exponents are recognised. A '+' or '-' immediately before the digits is
    taken as the sign, so '5-3' yields [5.0, -3.0]. Digits inside identifiers
    (e.g. 'x1') are also picked up.
    """
    number_pattern = re.compile(r'''
        [-+]? # optional sign
        (?:
             (?: \d* \. \d+ ) # .1 .12 .123 etc 9.1 etc 98.1 etc
             |
             (?: \d+ \.? ) # 1. 12. 123. etc 1 12 123 etc
        )
        # followed by optional exponent part if desired
        (?: [Ee] [+-]? \d+ ) ?
        ''', re.VERBOSE)
    return list(map(float, number_pattern.findall(s)))

def string_almost_equal(s, a, prec=3):
    """Return True if some number in string *s* is almost equal to *a*.

    Numbers are extracted with get_numbers and compared to *a* with
    almost_equal at precision *prec*. The search stops at the first match.
    """
    return any(almost_equal(found, a, prec) for found in get_numbers(s))

def string_contains_number(s, a):
    """Return True if some number extracted from string *s* equals *a* exactly.

    Numbers come from get_numbers (as floats) and are compared with ``==``.
    """
    return any(number == a for number in get_numbers(s))


def get_exception_desc(exc):
    """Translate an error message into a list with one hint descriptor.

    *exc* is searched with substring tests. Returns None when *exc* is falsy.
    Otherwise returns a one-element list of dicts with an 'id' key, checked
    in this priority order:
    EOFError, 'timed out' and 'sandbox violation' give a bare id;
    NameError and TypeError give an id plus 'args' carrying the message;
    anything else gives the generic 'error' id plus the message.
    """
    if not exc:
        return None

    # Markers whose hint needs no message; order sets precedence.
    for marker, desc_id in (('EOFError', 'eof_error'),
                            ('timed out', 'timed_out'),
                            ('sandbox violation', 'sandbox_violation')):
        if marker in exc:
            return [{'id': desc_id}]

    # Markers whose hint echoes the original message back to the user.
    for marker, desc_id in (('NameError', 'name_error'),
                            ('TypeError', 'type_error')):
        if marker in exc:
            return [{'id': desc_id, 'args': {'message': exc}}]

    return [{'id': 'error', 'args': {'message': exc}}]

def get_ast(code):
    """Parse the source string *code* and return its ``ast.Module`` tree.

    Handy when regular expressions over the raw text are not enough.
    A SyntaxError from ``ast.parse`` propagates to the caller.
    """
    tree = ast.parse(code)
    return tree

def has_comprehension(tree):
    """Return True if the AST *tree* contains any comprehension.

    List/set/dict comprehensions and generator expressions all count, since
    each contains an ``ast.comprehension`` node. Generator functions that use
    ``yield`` do not count.
    """
    return any(isinstance(node, ast.comprehension) for node in ast.walk(tree))

def has_loop(tree):
    """Return True if the AST *tree* contains a ``for`` or ``while`` statement.

    Only ``ast.For`` and ``ast.While`` nodes count. Comprehensions and
    ``async for`` are not treated as loops here.
    """
    loop_types = (ast.For, ast.While)
    return any(isinstance(node, loop_types) for node in ast.walk(tree))

if __name__ == '__main__':
    # Smoke check: the first expression contains the tokens '>=' '0',
    # the second does not, so this prints True then False.
    for sample in ('x + y >= 0', 'x + y > 0'):
        print(has_token_sequence(get_tokens(sample), ['>=', '0']))