summaryrefslogtreecommitdiff
path: root/monkey
diff options
context:
space:
mode:
Diffstat (limited to 'monkey')
-rw-r--r--monkey/__init__.py3
-rw-r--r--monkey/edits.py12
-rwxr-xr-xmonkey/test.py4
3 files changed, 9 insertions, 10 deletions
diff --git a/monkey/__init__.py b/monkey/__init__.py
index 83ea215..fbb4535 100644
--- a/monkey/__init__.py
+++ b/monkey/__init__.py
@@ -83,7 +83,8 @@ def fix(code, edits, test, timeout=30, debug=False):
print('{}: {} → {}'.format(idx, stringify(a), stringify(b)))
# If the code is correct, we are done.
- if test(code):
+ n_correct, n_all = test(code)
+ if n_correct == n_all:
return code, path, total_time, n_tested
n_tested += 1
diff --git a/monkey/edits.py b/monkey/edits.py
index ae44301..73f692d 100644
--- a/monkey/edits.py
+++ b/monkey/edits.py
@@ -234,13 +234,11 @@ if __name__ == '__main__':
# Check for cached results.
normal_code = stringify(rename_vars(tokenize(code)))
code_key = (normal_code, tuple(dependencies))
- if code_key in test_results[pid]:
- return test_results[pid][code_key]
-
- aux_code = '\n' + solutions_for_problems(problem.language, dependencies) + '\n' + facts
- correct, hints = problem_module.test(code, aux_code)
- test_results[pid][code_key] = correct
- return correct
+ if code_key not in test_results[pid]:
+ aux_code = '\n' + solutions_for_problems(problem.language, dependencies) + '\n' + facts
+ n_correct, n_all, _ = problem_module.test(code, aux_code)
+ test_results[pid][code_key] = (n_correct, n_all)
+ return test_results[pid][code_key]
print('Analyzing traces for {}… '.format(problem.identifier), end='', flush=True)
print('{} traces… '.format(len(solutions)), end='', flush=True)
diff --git a/monkey/test.py b/monkey/test.py
index 1d82c33..bb49948 100755
--- a/monkey/test.py
+++ b/monkey/test.py
@@ -47,8 +47,8 @@ def test(code):
dependencies = sorted([p[2:] for p in other_problems
if p.identifier in used_predicate_identifiers])
aux_code = '\n' + solutions_for_problems('prolog', dependencies) + '\n' + facts
- correct, hints = problem_module.test(code, aux_code)
- return correct
+ n_correct, n_all, _ = problem_module.test(code, aux_code)
+ return n_correct, n_all
traces = [s.trace for s in Solution.filter(problem_id=problem.id)]