summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rwxr-xr-x  monkey/monkey.py  282
-rwxr-xr-x  monkey/test.py  12
-rw-r--r--  prolog/util.py  190
3 files changed, 290 insertions, 194 deletions
diff --git a/monkey/monkey.py b/monkey/monkey.py
index 79584e0..a7beb6e 100755
--- a/monkey/monkey.py
+++ b/monkey/monkey.py
@@ -5,7 +5,7 @@ import time
from .edits import classify_edits
from prolog.engine import test
-from prolog.util import Token, compose, decompose, map_vars, normalized, rename_vars, stringify
+from prolog.util import Token, annotate, compose, map_vars, normalized, rename_vars, stringify
from .util import PQueue
# Starting from [code], find a sequence of edits that transforms it into a
@@ -20,107 +20,198 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False):
inserts, removes, changes = classify_edits(edits)
# Generate states that can be reached from the given program with one edit.
- # Program code is given as a list of [lines], where each line is a list of
- # tokens. Rule ranges are given in [rules] (see prolog.util.decompose).
- def step(lines, rules, path=None):
- # Apply edits in order from top to bottom; skip lines with index lower
- # than last step.
- start_line = path[-1][1] if path else 0
-
- for start, end in rules:
- rule_lines = lines[start:end]
- rule_vars = [t.val for line in rule_lines for t in line
- if t.type == 'VARIABLE' and t.val != '_']
-
- # Prepend a new rule (fact) before this rule (only if no line in
- # the current rule has been modified yet).
- if start_line == 0 or start > start_line:
- for after, cost in inserts.items():
- new_lines = lines[:start] + (after,) + lines[start:]
- new_rules = []
- for old_start, old_end in rules:
- if old_start == start:
- new_rules.append((start, start+1))
- new_rules.append((old_start + (0 if old_start < start else 1),
- old_end + (0 if old_end < start else 1)))
- new_step = ('add_rule', start, (tuple(), after))
- # Decrease probability as we add more rules.
- new_cost = cost * math.pow(0.3, len(rules))
-
- yield (new_lines, new_rules, new_step, new_cost)
-
- # Apply some edits for each line in this rule.
- for line_idx in range(start, end):
- if line_idx < start_line:
- continue
- line = lines[line_idx]
- line_normal = tuple(rename_vars(line))
+ # The program is given as a list of tokens.
+ def step(program, path=None):
+ # Apply edits to program in order; skip tokens with index lower than
+ # last step.
+ start = path[-1][4] if path else 0
+
+ first = 0 # first token in the current part
+ variables = []
+ for i, token in enumerate(program):
+ # Get variable names in the current rule.
+ if i == 0 or program[i-1].type == 'PERIOD':
+ variables = []
+ for t in program[i:]:
+ if t.type == 'VARIABLE' and not t.val.startswith('_'):
+ if t.val not in variables:
+ variables.append(t.val)
+ elif t.type == 'PERIOD':
+ break
+
+ # Skip already modified parts of the program.
+ if i < start:
+ continue
+
+ if i == 0 or program[i-1].stop:
+ first = i
+
+ # Add a new fact at beginning or after another rule.
+ if i == 0 or token.type == 'PERIOD':
+ new_start = i+1 if token.type == 'PERIOD' else 0
+ n_rules = program.count(Token('PERIOD', '.'))
+ rule = 0 if i == 0 else program[i-1].rule+1 # index of new rule
+
+ for new, cost in inserts.items():
+ # New rule must start with correct predicate name.
+ if not (new[0].type == 'NAME' and new[0].val in name):
+ continue
- # Apply edits whose left-hand side matches this line.
+ # Here we always insert a fact, so drop a trailing ':-'; a '.' is appended below.
+ if new[-1].type == 'FROM':
+ new = new[:-1]
+
+ new_real = tuple([t.clone(rule=rule, part=0) for t in new + (Token('PERIOD', '.', stop=True),)])
+ new_after = tuple([t.clone(rule=t.rule+1) for t in program[new_start:]])
+
+ new_program = program[:new_start] + new_real + new_after
+ new_step = ('add_rule', new_start, (), new_real, new_start)
+ new_cost = cost * math.pow(0.3, n_rules)
+ yield (new_program, new_step, new_cost)
+
+ if token.stop and i > first:
+ real_last = last = i+1
+ if token.type != 'FROM':
+ real_last -= 1
+ part = program[first:real_last]
+ part_whole = program[first:last]
+ part_normal = tuple(rename_vars(part))
+
+ # Apply each edit a→b where a matches this part.
seen = False # has there been such an edit?
- for (before, after), cost in changes.items():
- if line_normal == before:
+ for (a, b), cost in changes.items():
+ if part_normal == a:
seen = True
- after_real = tuple(map_vars(before, after, line, rule_vars))
- new_lines = lines[:line_idx] + (after_real,) + lines[line_idx+1:]
- new_step = ('change_line', line_idx, (tuple(line), after_real))
-
- yield (new_lines, rules, new_step, cost)
-
- # Remove the current line.
- if line_normal in removes.keys() or not seen:
- new_lines = lines[:line_idx] + lines[line_idx+1:]
- new_rules = []
- for old_start, old_end in rules:
- new_start, new_end = (old_start - (0 if old_start <= line_idx else 1),
- old_end - (0 if old_end <= line_idx else 1))
- if new_end > new_start:
- new_rules.append((new_start, new_end))
- new_step = ('remove_line', line_idx, (tuple(line), ()))
- new_cost = removes.get(line_normal, 1.0) * 0.99
-
- yield (new_lines, new_rules, new_step, new_cost)
-
- # Add a line after this one.
- for after, cost in inserts.items():
- # Don't try to insert a head into the body.
- if after[-1].type == 'FROM':
- continue
- after_real = tuple(map_vars([], after, [], rule_vars))
-
- idx = line_idx+1
- new_lines = lines[:idx] + (after_real,) + lines[idx:]
- new_rules = []
- for old_start, old_end in rules:
- new_rules.append((old_start + (0 if old_start < idx else 1),
- old_end + (0 if old_end < idx else 1)))
- new_step = ('add_subgoal', idx, ((), after_real))
- new_cost = cost * 0.4
-
- yield (new_lines, new_rules, new_step, new_cost)
-
- # Return a cleaned-up list of steps:
- # - multiple line edits in a single line are merged into one
- # - check if any lines can be removed from the program
+ if b[-1].type == 'FROM':
+ b = b[:-1] + (b[-1].clone(stop=True),)
+ b_real = tuple([t.clone(rule=program[first].rule,
+ part=program[first].part)
+ for t in map_vars(a, b, part, variables)])
+
+ new_program = program[:first] + b_real + program[real_last:]
+ new_step = ('change_part', first, part, b_real, first)
+ yield (new_program, new_step, cost)
+
+
+ # Remove this part.
+ if token.type in ('COMMA', 'SEMI'):
+ if part_normal in removes.keys() or not seen:
+ new_after = list(program[last:])
+ for j, t in enumerate(new_after):
+ if t.rule > token.rule:
+ break
+ new_after[j] = t.clone(part=t.part-1)
+ new_program = program[:first] + tuple(new_after)
+ new_step = ('remove_part', first, part_whole, (), first-1)
+ new_cost = removes.get(part_normal, 1.0) * 0.99
+ yield (new_program, new_step, new_cost)
+
+ # Remove this part at the end of the current rule.
+ if token.type == 'PERIOD' and token.part > 0:
+ if part_normal in removes.keys() or not seen:
+ if token.part == 0: # part is fact, remove rule
+ new_after = list(program[last+1:])
+ for j, t in enumerate(new_after):
+ new_after[j] = t.clone(rule=t.rule-1)
+ new_program = program[:first] + tuple(new_after)
+ new_step = ('remove_rule', first, part, (), first)
+ new_cost = removes.get(part_normal, 1.0) * 0.99
+ yield (new_program, new_step, new_cost)
+
+ else: # part is subgoal, remove part
+ new_after = list(program[last-1:])
+ for j, t in enumerate(new_after):
+ if t.rule > token.rule:
+ break
+ new_after[j] = t.clone(part=t.part-1)
+ new_program = program[:first-1] + tuple(new_after)
+ new_step = ('remove_part', first-1, (program[first-1],) + part, (), first-1)
+ new_cost = removes.get(part_normal, 1.0) * 0.99
+ yield (new_program, new_step, new_cost)
+
+
+ # Insert a new part (goal) after this part.
+ if token.type in ('COMMA', 'FROM'):
+ for new, cost in inserts.items():
+ # Don't try to insert a head into the body.
+ if new[-1].type == 'FROM':
+ continue
+
+ new_real = tuple([t.clone(rule=program[first].rule,
+ part=program[first].part+1)
+ for t in map_vars([], new, [], variables) + [Token('COMMA', ',', stop=True)]])
+
+ new_after = list(program[last:])
+ for j, t in enumerate(new_after):
+ if t.rule > token.rule:
+ break
+ new_after[j] = t.clone(rule=program[first].rule, part=t.part+1)
+
+ new_program = program[:last] + new_real + tuple(new_after)
+ new_step = ('add_part', last, (), new_real, last)
+ new_cost = cost * 0.4
+ yield (new_program, new_step, new_cost)
+
+ # Insert a new part (goal) at the end of current rule.
+ if token.type == 'PERIOD':
+ for new, cost in inserts.items():
+ # Don't try to insert a head into the body.
+ if new[-1].type == 'FROM':
+ continue
+
+ prepend = Token('FROM', ':-') if token.part == 0 else Token('COMMA', ',')
+ new_real = (prepend.clone(stop=True, rule=token.rule, part=token.part),) + \
+ tuple([t.clone(rule=token.rule, part=token.part+1)
+ for t in map_vars([], new, [], variables)])
+
+ new_after = list(program[last-1:])
+ for j, t in enumerate(new_after):
+ if t.rule > token.rule:
+ break
+ new_after[j] = t.clone(rule=t.rule, part=t.part+1)
+
+ new_program = program[:last-1] + new_real + tuple(new_after)
+ new_step = ('add_part', last-1, (), new_real, last)
+ new_cost = cost * 0.4
+ yield (new_program, new_step, new_cost)
+
+ # Return a cleaned-up list of steps.
def postprocess(steps):
new_steps = []
for step in steps:
+ # Remove the last field from each step as it is unnecessary after a
+ # path has been found.
+ step = step[:4]
if new_steps:
prev = new_steps[-1]
+ if prev[0] == 'remove_part' and step[0] == 'remove_part' and \
+ prev[1] == step[1]:
+ new_steps[-1] = ('remove_part', prev[1], prev[2]+step[2], ())
+ continue
+
+ if prev[0] == 'remove_part' and step[0] == 'add_part' and \
+ prev[1] == step[1]:
+ new_steps[-1] = ('change_part', prev[1], prev[2], step[3])
+ continue
- if prev[1] == step[1] and \
- prev[0] in ('add_rule', 'add_subgoal', 'change_line') and \
- step[0] == 'change_line' and \
- normalized(prev[2][1]) == normalized(step[2][0]):
- new_steps[-1] = (prev[0], prev[1], (prev[2][0], step[2][1]))
+ if prev[0] == 'change_part' and step[0] == 'change_part' and \
+ prev[1] == step[1] and step[2] == prev[3]:
+ new_steps[-1] = ('change_part', prev[1], prev[2], step[3])
continue
- if prev[0] == 'add_subgoal' and step[0] == 'remove_line' and \
- prev[1]+1 == step[1]:
- new_steps[-1] = ('change_line', prev[1], (step[2][0], prev[2][1]))
+ if prev[0] in ('add_part', 'change_part') and step[0] == 'change_part' and \
+ prev[1] == step[1] and step[2] == prev[3][:-1]:
+ new_steps[-1] = (prev[0], prev[1], prev[2], step[3]+(prev[3][-1],))
continue
+ if prev[0] in ('add_part', 'change_part') and step[0] == 'change_part' and \
+ step[1] == prev[1]+1 and step[2] == prev[3][1:]:
+ new_steps[-1] = (prev[0], prev[1], prev[2], (prev[3][0],)+step[3])
+ continue
new_steps.append(step)
+ for step in new_steps:
+ print('index {}: {} {} → {}'.format(
+ step[1], step[0], stringify(step[2]), stringify(step[3])))
return new_steps
# Main loop: best-first search through generated programs.
@@ -129,8 +220,7 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False):
# Each program gets a task with the sequence of edits that generated it and
# the associated cost. First candidate with cost 1 is the initial program.
- lines, rules = decompose(code)
- todo.push(((tuple(lines), tuple(rules)), (), 1.0), -1.0)
+ todo.push((program, (), 1.0), -1.0)
n_tested = 0
start_time = time.monotonic()
@@ -140,10 +230,10 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False):
task = todo.pop()
if task == None:
break
- (lines, rules), path, path_cost = task
+ program, path, path_cost = task
# If we have already seen this code, skip it.
- code = compose(lines, rules)
+ code = compose(program)
if code in done:
continue
done.add(code)
@@ -151,8 +241,8 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False):
# Print some info about the current task.
if debug:
print('Cost {:.12f}'.format(path_cost))
- for step_type, line, (before, after) in path:
- print('line {}: {} {} → {}'.format(line, step_type, stringify(before), stringify(after)))
+ for step_type, idx, a, b, _ in path:
+ print('index {}: {} {} → {}'.format(idx, step_type, stringify(a), stringify(b)))
# If the code is correct, we are done.
if test(name, code + '\n' + aux_code):
@@ -161,12 +251,12 @@ def fix(name, code, edits, aux_code='', timeout=30, debug=False):
n_tested += 1
# Otherwise generate new solutions.
- for new_lines, new_rules, new_step, new_cost in step(lines, rules, path):
+ for new_program, new_step, new_cost in step(program, path):
new_path_cost = path_cost * new_cost
if new_path_cost < 0.01:
continue
new_path = path + (new_step,)
- todo.push(((tuple(new_lines), tuple(new_rules)), new_path, new_path_cost), -new_path_cost)
+ todo.push((new_program, new_path, new_path_cost), -new_path_cost)
total_time = time.monotonic() - start_time
diff --git a/monkey/test.py b/monkey/test.py
index 83aa0c2..17c27fe 100755
--- a/monkey/test.py
+++ b/monkey/test.py
@@ -11,7 +11,7 @@ from .edits import classify_edits, trace_graph
from .graph import graphviz
from .monkey import fix
from prolog.engine import test
-from prolog.util import compose, decompose, stringify
+from prolog.util import annotate, compose, stringify
from .util import indent
# Load django models.
@@ -57,10 +57,10 @@ def print_hint(solution, steps, fix_time, n_tested):
if solution:
print(colored('Hint found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'green'))
print(colored(' Edits', 'blue'))
- for step_type, line, (before, after) in steps:
- print(' {}: {} {} → {}'.format(line, step_type, stringify(before), stringify(after)))
+ for step_type, pos, a, b in steps:
+ print(' {}: {} {} → {}'.format(pos, step_type, stringify(a), stringify(b)))
print(colored(' Final version', 'blue'))
- print(indent(compose(*decompose(solution)), 2))
+ print(indent(compose(annotate(solution)), 2))
else:
print(colored('Hint not found! Tested {} programs in {:.1f} s.'.format(n_tested, fix_time), 'red'))
@@ -75,7 +75,7 @@ if len(sys.argv) >= 3 and sys.argv[2] == 'test':
if program in done:
continue
print(colored('Analyzing program {0}/{1}…'.format(i+1, len(incorrect)), 'yellow'))
- print(indent(compose(*decompose(program)), 2))
+ print(indent(compose(annotate(program)), 2))
solution, steps, fix_time, n_tested = fix(problem.name, program, edits, aux_code=aux_code, timeout=timeout)
if solution:
@@ -119,7 +119,7 @@ elif len(sys.argv) >= 3 and sys.argv[2] == 'info':
for p in sorted(incorrect):
if p in done:
continue
- print(indent(compose(*decompose(p)), 2))
+ print(indent(compose(annotate(p)), 2))
print()
# Print all student queries and their counts.
elif sys.argv[3] == 'queries':
diff --git a/prolog/util.py b/prolog/util.py
index 30f12da..e9c1811 100644
--- a/prolog/util.py
+++ b/prolog/util.py
@@ -30,14 +30,14 @@ class Token(namedtuple('Token', ['type', 'val', 'pos', 'rule', 'part', 'stop']))
def __hash__(self):
return hash(self[1])
-# Return a new Token, possibly modifying some fields.
-def clone_token(token, val=None, pos=None, rule=None, part=None):
- return Token(token.type,
- token.val if val is None else val,
- token.pos if pos is None else pos,
- token.rule if rule is None else rule,
- token.part if part is None else part,
- token.stop)
+ # Return a copy of this token, possibly modifying some fields.
+ def clone(self, type=None, val=None, pos=None, rule=None, part=None, stop=None):
+ return Token(self.type if type is None else type,
+ self.val if val is None else val,
+ self.pos if pos is None else pos,
+ self.rule if rule is None else rule,
+ self.part if part is None else part,
+ self.stop if stop is None else stop)
# Return a list of tokens in [text].
def tokenize(text):
@@ -63,83 +63,92 @@ def split(code):
yield stringify(tokens[start:idx])
start = idx + 1
-# Return a list of lines in [code] and a list of rule ranges.
-def decompose(code):
- lines = []
- rules = []
-
- rule_start = 0 # lowest line number in the current rule
- line = [] # tokens in the current line
- break_line = True # for each comma, go to a new line
+# Lex [code] into tokens with rule indexes and stop markers.
+def annotate(code):
+ rule = 0
+ part = 0
parens = [] # stack of currently open parens/brackets/braces
-
- for t in tokenize(code) + [Token('EOF')]:
- # Always break the line on a semicolon, even inside parens.
- if t.type == 'SEMI':
- if line:
- lines.append(tuple(line))
- line = []
- lines.append((t,))
- continue
-
- # Break the line on these tokens if we are not inside parens. Don't
- # append the final token unless it is the :- operator.
- if break_line and t.type in ('PERIOD', 'FROM', 'COMMA', 'EOF'):
- # Only append :- at the end of the line, ignore commas and periods.
- if t.type == 'FROM':
- line.append(t)
-
- # Append nonempty lines to the output list.
- if line:
- lines.append(tuple(line))
- line = []
-
- # Commit a new rule if it contains some lines.
- if t.type in ('PERIOD', 'EOF') and rule_start < len(lines):
- rules.append((rule_start, len(lines)))
- rule_start = len(lines)
- continue
-
- # Handle parens.
- if t.type == 'LPAREN':
- # Disallow breaking lines inside "name( )" (e.g. member(X, L)) but
- # not other ( ).
- if line and line[-1].type == 'NAME':
- parens.append('paren')
- break_line = False
+ in_parens = 0 # COMMA means a new part if this is 0
+
+ token = None
+ lexer.input(code)
+ for t in lexer:
+ tok_rule = rule
+ tok_part = part
+ tok_stop = True
+
+ if t.type == 'PERIOD': # .
+ rule += 1
+ part = 0
+ in_parens = 0
+ parens = []
+ elif t.type in ('FROM', 'SEMI'): # :- ;
+ part += 1
+ elif t.type == 'COMMA': # ,
+ if not parens or in_parens == 0:
+ part += 1
else:
- parens.append('ignore')
- elif t.type in ('LBRACKET', 'LBRACE'):
- # Disallow breaking lines inside "[ ]" and "{ }".
- parens.append('paren')
- break_line = False
- elif parens:
- if t.type in ('RPAREN', 'RBRACE', 'RBRACKET'):
+ tok_stop = False
+
+ # Handle left parens.
+ elif t.type == 'LPAREN': # (
+ if token and token.type == 'NAME': # name(
+ tok_stop = False
+ parens.append('COMPOUND')
+ in_parens += 1
+ else:
+ parens.append(t.type) # …, (
+ elif t.type == 'LBRACKET': # [
+ tok_stop = False
+ parens.append(t.type)
+ in_parens += 1
+ elif t.type == 'LBRACE': # {
+ parens.append(t.type)
+
+ # Handle right parens.
+ elif t.type == 'RPAREN': # )
+ if parens:
+ if parens[-1] == 'COMPOUND': # name(…)
+ tok_stop = False
+ parens.pop()
+ in_parens -= 1
+ elif parens[-1] == 'LPAREN': # (…)
+ parens.pop()
+ elif t.type == 'RBRACKET': # ]
+ if parens and parens[-1] == 'LBRACKET': # […]
+ tok_stop = False
+ parens.pop()
+ in_parens -= 1
+ elif t.type == 'RBRACE': # }
+ if parens and parens[-1] == 'LBRACE': # {…}
parens.pop()
- break_line = 'paren' not in parens
- # Append this token to the current line.
- line.append(t)
+ # Normal tokens.
+ else:
+ tok_stop = False
- return lines, rules
+ token = Token(t.type, t.value, t.lexpos, tok_rule, tok_part, tok_stop)
+ yield token
-# Format a list of [lines] according to [rules] (as returned by decompose).
-def compose(lines, rules):
+# Format a list of annotated [tokens].
+def compose(tokens):
code = ''
- for start, end in rules:
- for i in range(start, end):
- line = lines[i]
- if i > start:
+ prev = None
+ for t in tokens:
+ if t.type == 'SEMI':
+ code += '\n '
+ if prev and (prev.part != t.part or prev.rule != t.rule):
+ code += '\n'
+ if t.part:
code += ' '
- code += stringify(line)
- if i == end-1:
- code += '.\n'
- elif i == start:
- code += '\n'
- else:
- if line and line[-1].type != 'SEMI' and lines[i+1][-1].type != 'SEMI':
- code += ','
- code += '\n'
+
+ if t.type in ('PERIOD', 'COMMA'):
+ code += t.val + ' '
+ elif t.type in operators.values():
+ code += ' ' + t.val + ' '
+ else:
+ code += t.val
+ prev = t
return code.strip()
# Rename variables in [tokens] to A0, A1, A2,… in order of appearance.
@@ -150,19 +159,20 @@ def rename_vars(tokens, names=None):
# Return a new list.
tokens = list(tokens)
- for i in range(len(tokens)):
- if tokens[i].type == 'PERIOD':
+ for i, t in enumerate(tokens):
+ if t.type == 'PERIOD':
names.clear()
next_id = 0
- elif tokens[i] == Token('VARIABLE', '_'):
- tokens[i] = Token('VARIABLE', 'A{}'.format(next_id))
- next_id += 1
- elif tokens[i].type == 'VARIABLE':
- cur_name = tokens[i].val
- if cur_name not in names:
- names[cur_name] = 'A{}'.format(next_id)
+ elif t.type == 'VARIABLE':
+ if t.val.startswith('_'):
+ tokens[i] = t.clone(val='A{}'.format(next_id))
next_id += 1
- tokens[i] = Token('VARIABLE', names[cur_name])
+ else:
+ cur_name = t.val
+ if cur_name not in names:
+ names[cur_name] = 'A{}'.format(next_id)
+ next_id += 1
+ tokens[i] = t.clone(val=names[cur_name])
return tokens
# Helper function to remove trailing punctuation from lines and rename
@@ -203,14 +213,10 @@ def map_vars(a, b, tokens, variables):
for i, formal_name in enumerate(remaining_formal):
mapping[formal_name] = remaining_actual[i]
- return [t if t.type != 'VARIABLE' else clone_token(t, val=mapping[t.val]) for t in b]
+ return [t if t.type != 'VARIABLE' else t.clone(val=mapping[t.val]) for t in b]
# Basic sanity check.
if __name__ == '__main__':
- code = 'dup([H|T], [H1|T1]) :- dup(T1, T2). '
- lines, rules = decompose(code)
- print(compose(lines, rules))
-
var_names = {}
before = rename_vars(tokenize("dup([A0|A1], [A2|A3])"), var_names)
after = rename_vars(tokenize("dup([A0|A1], [A5, A4|A3])"), var_names)