diff options
-rwxr-xr-x | monkey/monkey.py | 282 | ||||
-rwxr-xr-x | monkey/test.py | 12 | ||||
-rw-r--r-- | prolog/util.py | 190 |
3 files changed, 290 insertions, 194 deletions
diff --git a/monkey/monkey.py b/monkey/monkey.py index 79584e0..a7beb6e 100755 --- a/monkey/monkey.py +++ b/monkey/monkey.py @@ -5,7 +5,7 @@ import time from .edits import classify_edits from prolog.engine import test -from prolog.util import Token, compose, decompose, map_vars, normalized, rename_vars, stringify +from prolog.util import Token, annotate, compose, map_vars, normalized, rename_vars, stringify from .util import PQueue # Starting from [code], find a sequence of edits that transforms it into a @@ -20,107 +20,198 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False): inserts, removes, changes = classify_edits(edits) # Generate states that can be reached from the given program with one edit. - # Program code is given as a list of [lines], where each line is a list of - # tokens. Rule ranges are given in [rules] (see prolog.util.decompose). - def step(lines, rules, path=None): - # Apply edits in order from top to bottom; skip lines with index lower - # than last step. - start_line = path[-1][1] if path else 0 - - for start, end in rules: - rule_lines = lines[start:end] - rule_vars = [t.val for line in rule_lines for t in line - if t.type == 'VARIABLE' and t.val != '_'] - - # Prepend a new rule (fact) before this rule (only if no line in - # the current rule has been modified yet). - if start_line == 0 or start > start_line: - for after, cost in inserts.items(): - new_lines = lines[:start] + (after,) + lines[start:] - new_rules = [] - for old_start, old_end in rules: - if old_start == start: - new_rules.append((start, start+1)) - new_rules.append((old_start + (0 if old_start < start else 1), - old_end + (0 if old_end < start else 1))) - new_step = ('add_rule', start, (tuple(), after)) - # Decrease probability as we add more rules. - new_cost = cost * math.pow(0.3, len(rules)) - - yield (new_lines, new_rules, new_step, new_cost) - - # Apply some edits for each line in this rule. - for line_idx in range(start, end): - if line_idx < start_line: - continue - line = lines[line_idx] - line_normal = tuple(rename_vars(line)) + # The program is given as a list of tokens. + def step(program, path=None): + # Apply edits to program in order; skip tokens with index lower than + # last step. + start = path[-1][4] if path else 0 + + first = 0 # first token in the current part + variables = [] + for i, token in enumerate(program): + # Get variable names in the current rule. + if i == 0 or program[i-1].type == 'PERIOD': + variables = [] + for t in program[i:]: + if t.type == 'VARIABLE' and not t.val.startswith('_'): + if t.val not in variables: + variables.append(t.val) + elif t.type == 'PERIOD': + break + + # Skip already modified parts of the program. + if i < start: + continue + + if i == 0 or program[i-1].stop: + first = i + + # Add a new fact at beginning or after another rule. + if i == 0 or token.type == 'PERIOD': + new_start = i+1 if token.type == 'PERIOD' else 0 + n_rules = program.count(Token('PERIOD', '.')) + rule = 0 if i == 0 else program[i-1].rule+1 # index of new rule + + for new, cost in inserts.items(): + # New rule must start with correct predicate name. + if not (new[0].type == 'NAME' and new[0].val in name): + continue - # Apply edits whose left-hand side matches this line. + # Here we always insert a fact, so replace trailing :- with .. + if new[-1].type == 'FROM': + new = new[:-1] + + new_real = tuple([t.clone(rule=rule, part=0) for t in new + (Token('PERIOD', '.', stop=True),)]) + new_after = tuple([t.clone(rule=t.rule+1) for t in program[new_start:]]) + + new_program = program[:new_start] + new_real + new_after + new_step = ('add_rule', new_start, (), new_real, new_start) + new_cost = cost * math.pow(0.3, n_rules) + yield (new_program, new_step, new_cost) + + if token.stop and i > first: + real_last = last = i+1 + if token.type != 'FROM': + real_last -= 1 + part = program[first:real_last] + part_whole = program[first:last] + part_normal = tuple(rename_vars(part)) + + # Apply each edit a→b where a matches this part. seen = False # has there been such an edit? - for (before, after), cost in changes.items(): - if line_normal == before: + for (a, b), cost in changes.items(): + if part_normal == a: seen = True - after_real = tuple(map_vars(before, after, line, rule_vars)) - new_lines = lines[:line_idx] + (after_real,) + lines[line_idx+1:] - new_step = ('change_line', line_idx, (tuple(line), after_real)) - - yield (new_lines, rules, new_step, cost) - - # Remove the current line. - if line_normal in removes.keys() or not seen: - new_lines = lines[:line_idx] + lines[line_idx+1:] - new_rules = [] - for old_start, old_end in rules: - new_start, new_end = (old_start - (0 if old_start <= line_idx else 1), - old_end - (0 if old_end <= line_idx else 1)) - if new_end > new_start: - new_rules.append((new_start, new_end)) - new_step = ('remove_line', line_idx, (tuple(line), ())) - new_cost = removes.get(line_normal, 1.0) * 0.99 - - yield (new_lines, new_rules, new_step, new_cost) - - # Add a line after this one. - for after, cost in inserts.items(): - # Don't try to insert a head into the body. - if after[-1].type == 'FROM': - continue - after_real = tuple(map_vars([], after, [], rule_vars)) - - idx = line_idx+1 - new_lines = lines[:idx] + (after_real,) + lines[idx:] - new_rules = [] - for old_start, old_end in rules: - new_rules.append((old_start + (0 if old_start < idx else 1), - old_end + (0 if old_end < idx else 1))) - new_step = ('add_subgoal', idx, ((), after_real)) - new_cost = cost * 0.4 - - yield (new_lines, new_rules, new_step, new_cost) - - # Return a cleaned-up list of steps: - # - multiple line edits in a single line are merged into one - # - check if any lines can be removed from the program + if b[-1].type == 'FROM': + b = b[:-1] + (b[-1].clone(stop=True),) + b_real = tuple([t.clone(rule=program[first].rule, + part=program[first].part) + for t in map_vars(a, b, part, variables)]) + + new_program = program[:first] + b_real + program[real_last:] + new_step = ('change_part', first, part, b_real, first) + yield (new_program, new_step, cost) + + + # Remove this part. + if token.type in ('COMMA', 'SEMI'): + if part_normal in removes.keys() or not seen: + new_after = list(program[last:]) + for j, t in enumerate(new_after): + if t.rule > token.rule: + break + new_after[j] = t.clone(part=t.part-1) + new_program = program[:first] + tuple(new_after) + new_step = ('remove_part', first, part_whole, (), first-1) + new_cost = removes.get(part_normal, 1.0) * 0.99 + yield (new_program, new_step, new_cost) + + # Remove this part at the end of the current rule. + if token.type == 'PERIOD' and token.part > 0: + if part_normal in removes.keys() or not seen: + if token.part == 0: # part is fact, remove rule + new_after = list(program[last+1:]) + for j, t in enumerate(new_after): + new_after[j] = t.clone(rule=t.rule-1) + new_program = program[:first] + tuple(new_after) + new_step = ('remove_rule', first, part, (), first) + new_cost = removes.get(part_normal, 1.0) * 0.99 + yield (new_program, new_step, new_cost) + + else: # part is subgoal, remove part + new_after = list(program[last-1:]) + for j, t in enumerate(new_after): + if t.rule > token.rule: + break + new_after[j] = t.clone(part=t.part-1) + new_program = program[:first-1] + tuple(new_after) + new_step = ('remove_part', first-1, (program[first-1],) + part, (), first-1) + new_cost = removes.get(part_normal, 1.0) * 0.99 + yield (new_program, new_step, new_cost) + + + # Insert a new part (goal) after this part. + if token.type in ('COMMA', 'FROM'): + for new, cost in inserts.items(): + # Don't try to insert a head into the body. + if new[-1].type == 'FROM': + continue + + new_real = tuple([t.clone(rule=program[first].rule, + part=program[first].part+1) + for t in map_vars([], new, [], variables) + [Token('COMMA', ',', stop=True)]]) + + new_after = list(program[last:]) + for j, t in enumerate(new_after): + if t.rule > token.rule: + break + new_after[j] = t.clone(rule=program[first].rule, part=t.part+1) + + new_program = program[:last] + new_real + tuple(new_after) + new_step = ('add_part', last, (), new_real, last) + new_cost = cost * 0.4 + yield (new_program, new_step, new_cost) + + # Insert a new part (goal) at the end of current rule. + if token.type == 'PERIOD': + for new, cost in inserts.items(): + # Don't try to insert a head into the body. + if new[-1].type == 'FROM': + continue + + prepend = Token('FROM', ':-') if token.part == 0 else Token('COMMA', ',') + new_real = (prepend.clone(stop=True, rule=token.rule, part=token.part),) + \ + tuple([t.clone(rule=token.rule, part=token.part+1) + for t in map_vars([], new, [], variables)]) + + new_after = list(program[last-1:]) + for j, t in enumerate(new_after): + if t.rule > token.rule: + break + new_after[j] = t.clone(rule=t.rule, part=t.part+1) + + new_program = program[:last-1] + new_real + tuple(new_after) + new_step = ('add_part', last-1, (), new_real, last) + new_cost = cost * 0.4 + yield (new_program, new_step, new_cost) + + # Return a cleaned-up list of steps. def postprocess(steps): new_steps = [] for step in steps: + # Remove the last field from each step as it is unnecessary after a + # path has been found. + step = step[:4] if new_steps: prev = new_steps[-1] + if prev[0] == 'remove_part' and step[0] == 'remove_part' and \ + prev[1] == step[1]: + new_steps[-1] = ('remove_part', prev[1], prev[2]+step[2], ()) + continue + + if prev[0] == 'remove_part' and step[0] == 'add_part' and \ + prev[1] == step[1]: + new_steps[-1] = ('change_part', prev[1], prev[2], step[3]) + continue - if prev[1] == step[1] and \ - prev[0] in ('add_rule', 'add_subgoal', 'change_line') and \ - step[0] == 'change_line' and \ - normalized(prev[2][1]) == normalized(step[2][0]): - new_steps[-1] = (prev[0], prev[1], (prev[2][0], step[2][1])) + if prev[0] == 'change_part' and step[0] == 'change_part' and \ + prev[1] == step[1] and step[2] == prev[3]: + new_steps[-1] = ('change_part', prev[1], prev[2], step[3]) continue - if prev[0] == 'add_subgoal' and step[0] == 'remove_line' and \ - prev[1]+1 == step[1]: - new_steps[-1] = ('change_line', prev[1], (step[2][0], prev[2][1])) + if prev[0] in ('add_part', 'change_part') and step[0] == 'change_part' and \ + prev[1] == step[1] and step[2] == prev[3][:-1]: + new_steps[-1] = (prev[0], prev[1], prev[2], step[3]+(prev[3][-1],)) continue + if prev[0] in ('add_part', 'change_part') and step[0] == 'change_part' and \ + step[1] == prev[1]+1 and step[2] == prev[3][1:]: + new_steps[-1] = (prev[0], prev[1], prev[2], (prev[3][0],)+step[3]) + continue new_steps.append(step) + for step in new_steps: + print('index {}: {} {} → {}'.format( + step[1], step[0], stringify(step[2]), stringify(step[3]))) return new_steps # Main loop: best-first search through generated programs. @@ -129,8 +220,7 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False): # Each program gets a task with the sequence of edits that generated it and # the associated cost. First candidate with cost 1 is the initial program. - lines, rules = decompose(code) - todo.push(((tuple(lines), tuple(rules)), (), 1.0), -1.0) + todo.push((program, (), 1.0), -1.0) n_tested = 0 start_time = time.monotonic() @@ -140,10 +230,10 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False): task = todo.pop() if task == None: break - (lines, rules), path, path_cost = task + program, path, path_cost = task # If we have already seen this code, skip it. - code = compose(lines, rules) + code = compose(program) if code in done: continue done.add(code) @@ -151,8 +241,8 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False): # Print some info about the current task. if debug: print('Cost {:.12f}'.format(path_cost)) - for step_type, line, (before, after) in path: - print('line {}: {} {} → {}'.format(line, step_type, stringify(before), stringify(after))) + for step_type, idx, a, b, _ in path: + print('index {}: {} {} → {}'.format(idx, step_type, stringify(a), stringify(b))) # If the code is correct, we are done. if test(name, code + '\n' + aux_code): @@ -161,12 +251,12 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False): n_tested += 1 # Otherwise generate new solutions. - for new_lines, new_rules, new_step, new_cost in step(lines, rules, path): + for new_program, new_step, new_cost in step(program, path): new_path_cost = path_cost * new_cost if new_path_cost < 0.01: continue new_path = path + (new_step,) - todo.push(((tuple(new_lines), tuple(new_rules)), new_path, new_path_cost), -new_path_cost) + todo.push((new_program, new_path, new_path_cost), -new_path_cost) total_time = time.monotonic() - start_time diff --git a/monkey/test.py b/monkey/test.py index 83aa0c2..17c27fe 100755 --- a/monkey/test.py +++ b/monkey/test.py @@ -11,7 +11,7 @@ from .edits import classify_edits, trace_graph from .graph import graphviz from .monkey import fix from prolog.engine import test -from prolog.util import compose, decompose, stringify +from prolog.util import annotate, compose, stringify from .util import indent # Load django models. @@ -57,10 +57,10 @@ def print_hint(solution, steps, fix_time, n_tested): if solution: print(colored('Hint found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'green')) print(colored(' Edits', 'blue')) - for step_type, line, (before, after) in steps: - print(' {}: {} {} → {}'.format(line, step_type, stringify(before), stringify(after))) + for step_type, pos, a, b in steps: + print(' {}: {} {} → {}'.format(pos, step_type, stringify(a), stringify(b))) print(colored(' Final version', 'blue')) - print(indent(compose(*decompose(solution)), 2)) + print(indent(compose(annotate(solution)), 2)) else: print(colored('Hint not found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'red')) @@ -75,7 +75,7 @@ if len(sys.argv) >= 3 and sys.argv[2] == 'test': if program in done: continue print(colored('Analyzing program {0}/{1}…'.format(i+1, len(incorrect)), 'yellow')) - print(indent(compose(*decompose(program)), 2)) + print(indent(compose(annotate(program)), 2)) solution, steps, fix_time, n_tested = fix(problem.name, program, edits, aux_code=aux_code, timeout=timeout) if solution: @@ -119,7 +119,7 @@ elif len(sys.argv) >= 3 and sys.argv[2] == 'info': for p in sorted(incorrect): if p in done: continue - print(indent(compose(*decompose(p)), 2)) + print(indent(compose(annotate(p)), 2)) print() # Print all student queries and their counts. elif sys.argv[3] == 'queries': diff --git a/prolog/util.py b/prolog/util.py index 30f12da..e9c1811 100644 --- a/prolog/util.py +++ b/prolog/util.py @@ -30,14 +30,14 @@ class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop'])) def __hash__(self): return hash(self[1]) -# Return a new Token, possibly modifying some fields. -def clone_token(token, val=None, pos=None, rule=None, part=None): - return Token(token.type, - token.val if val is None else val, - token.pos if pos is None else pos, - token.rule if rule is None else rule, - token.part if part is None else part, - token.stop) + # Return a copy of this token, possibly modifying some fields. + def clone(self, type=None, val=None, pos=None, rule=None, part=None, stop=None): + return Token(self.type if type is None else type, + self.val if val is None else val, + self.pos if pos is None else pos, + self.rule if rule is None else rule, + self.part if part is None else part, + self.stop if stop is None else stop) # Return a list of tokens in [text]. def tokenize(text): @@ -63,83 +63,92 @@ def split(code): yield stringify(tokens[start:idx]) start = idx + 1 -# Return a list of lines in [code] and a list of rule ranges. -def decompose(code): - lines = [] - rules = [] - - rule_start = 0 # lowest line number in the current rule - line = [] # tokens in the current line - break_line = True # for each comma, go to a new line +# Lex [code] into tokens with rule indexes and stop markers. +def annotate(code): + rule = 0 + part = 0 parens = [] # stack of currently open parens/brackets/braces - - for t in tokenize(code) + [Token('EOF')]: - # Always break the line on a semicolon, even inside parens. - if t.type == 'SEMI': - if line: - lines.append(tuple(line)) - line = [] - lines.append((t,)) - continue - - # Break the line on these tokens if we are not inside parens. Don't - # append the final token unless it is the :- operator. - if break_line and t.type in ('PERIOD', 'FROM', 'COMMA', 'EOF'): - # Only append :- at the end of the line, ignore commas and periods. - if t.type == 'FROM': - line.append(t) - - # Append nonempty lines to the output list. - if line: - lines.append(tuple(line)) - line = [] - - # Commit a new rule if it contains some lines. - if t.type in ('PERIOD', 'EOF') and rule_start < len(lines): - rules.append((rule_start, len(lines))) - rule_start = len(lines) - continue - - # Handle parens. - if t.type == 'LPAREN': - # Disallow breaking lines inside "name( )" (e.g. member(X, L)) but - # not other ( ). - if line and line[-1].type == 'NAME': - parens.append('paren') - break_line = False + in_parens = 0 # COMMA means a new part if this is 0 + + token = None + lexer.input(code) + for t in lexer: + tok_rule = rule + tok_part = part + tok_stop = True + + if t.type == 'PERIOD': # . + rule += 1 + part = 0 + in_parens = 0 + parens = [] + elif t.type in ('FROM', 'SEMI'): # :- ; + part += 1 + elif t.type == 'COMMA': # , + if not parens or in_parens == 0: + part += 1 else: - parens.append('ignore') - elif t.type in ('LBRACKET', 'LBRACE'): - # Disallow breaking lines inside "[ ]" and "{ }". - parens.append('paren') - break_line = False - elif parens: - if t.type in ('RPAREN', 'RBRACE', 'RBRACKET'): + tok_stop = False + + # Handle left parens. + elif t.type == 'LPAREN': # ( + if token and token.type == 'NAME': # name( + tok_stop = False + parens.append('COMPOUND') + in_parens += 1 + else: + parens.append(t.type) # …, ( + elif t.type == 'LBRACKET': # [ + tok_stop = False + parens.append(t.type) + in_parens += 1 + elif t.type == 'LBRACE': # { + parens.append(t.type) + + # Handle right parens. + elif t.type == 'RPAREN': # ) + if parens: + if parens[-1] == 'COMPOUND': # name(…) + tok_stop = False + parens.pop() + in_parens -= 1 + elif parens[-1] == 'LPAREN': # (…) + parens.pop() + elif t.type == 'RBRACKET': # ] + if parens and parens[-1] == 'LBRACKET': # […] + tok_stop = False + parens.pop() + in_parens -= 1 + elif t.type == 'RBRACE': # } + if parens and parens[-1] == 'LBRACE': # {…} parens.pop() - break_line = 'paren' not in parens - # Append this token to the current line. - line.append(t) + # Normal tokens. + else: + tok_stop = False - return lines, rules + token = Token(t.type, t.value, t.lexpos, tok_rule, tok_part, tok_stop) + yield token -# Format a list of [lines] according to [rules] (as returned by decompose). -def compose(lines, rules): +# Format a list of annotated [tokens]. +def compose(tokens): code = '' - for start, end in rules: - for i in range(start, end): - line = lines[i] - if i > start: + prev = None + for t in tokens: + if t.type == 'SEMI': + code += '\n ' + if prev and (prev.part != t.part or prev.rule != t.rule): + code += '\n' + if t.part: code += ' ' - code += stringify(line) - if i == end-1: - code += '.\n' - elif i == start: - code += '\n' - else: - if line and line[-1].type != 'SEMI' and lines[i+1][-1].type != 'SEMI': - code += ',' - code += '\n' + + if t.type in ('PERIOD', 'COMMA'): + code += t.val + ' ' + elif t.type in operators.values(): + code += ' ' + t.val + ' ' + else: + code += t.val + prev = t return code.strip() # Rename variables in [tokens] to A0, A1, A2,… in order of appearance. @@ -150,19 +159,20 @@ def rename_vars(tokens, names=None): # Return a new list. tokens = list(tokens) - for i in range(len(tokens)): - if tokens[i].type == 'PERIOD': + for i, t in enumerate(tokens): + if t.type == 'PERIOD': names.clear() next_id = 0 - elif tokens[i] == Token('VARIABLE', '_'): - tokens[i] = Token('VARIABLE', 'A{}'.format(next_id)) - next_id += 1 - elif tokens[i].type == 'VARIABLE': - cur_name = tokens[i].val - if cur_name not in names: - names[cur_name] = 'A{}'.format(next_id) + elif t.type == 'VARIABLE': + if t.val.startswith('_'): + tokens[i] = t.clone(val='A{}'.format(next_id)) next_id += 1 - tokens[i] = Token('VARIABLE', names[cur_name]) + else: + cur_name = t.val + if cur_name not in names: + names[cur_name] = 'A{}'.format(next_id) + next_id += 1 + tokens[i] = t.clone(val=names[cur_name]) return tokens # Helper function to remove trailing punctuation from lines and rename @@ -203,14 +213,10 @@ def map_vars(a, b, tokens, variables): for i, formal_name in enumerate(remaining_formal): mapping[formal_name] = remaining_actual[i] - return [t if t.type != 'VARIABLE' else clone_token(t, val=mapping[t.val]) for t in b] + return [t if t.type != 'VARIABLE' else t.clone(val=mapping[t.val]) for t in b] # Basic sanity check. if __name__ == '__main__': - code = 'dup([H|T], [H1|T1]) :- dup(T1, T2). ' - lines, rules = decompose(code) - print(compose(lines, rules)) - var_names = {} before = rename_vars(tokenize("dup([A0|A1], [A2|A3])"), var_names) after = rename_vars(tokenize("dup([A0|A1], [A5, A4|A3])"), var_names) |