This article is a technical counterpart of my previous post Finding duplicated code with tools from your CS course. It is deliberately written in a terse manner, and I’m not going to hold your hand. Consider reading the previous post first and coming back here later.
Introduction

Given a term, we want to group its subterms into alpha-equivalence classes.
Terms are built from variables, abstractions, and applications, as in the lambda calculus.
This article describes:
- an algorithm for hashing subterms modulo alpha-equivalence,
- an algorithm for validating that the computed hashes are collision-free,
- a deterministic algorithm for computing alpha-equivalence classes of subterms.
Over the course of the article, we use Python-like pseudocode. A common pattern in the pseudocode is using dict[...] to associate temporary data with terms or variables. This should be read as using linear arrays addressed by unique term/variable indices, or alternatively ad-hoc fields in data types, as opposed to a hash table access.
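For instance (our illustration; the dense ids and the Node class are hypothetical):

n_terms = 3                # hypothetical number of subterms
class Node:
    def __init__(self, index: int):
        self.index = index # dense id assigned to every subterm up front

size = [0] * n_terms       # reads as `size: dict[Term, int] = {}` in the pseudocode
t = Node(2)
size[t.index] = 1          # O(1) array access, no hashing involved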
Prior art

Our first algorithm is an adaptation of the algorithm developed in:
Krzysztof Maziarz, Tom Ellis, Alan Lawrence, Andrew Fitzgibbon, and Simon Peyton Jones. 2021. Hashing modulo alpha-equivalence. In Proceedings of the 42nd ACM SIGPLAN International Conference on Programming Language Design and Implementation (PLDI 2021). Association for Computing Machinery, New York, NY, USA, 960–973. https://doi.org/10.1145/3453483.3454088
Maziarz et al.'s algorithm has
To the best of our knowledge, our algorithm for validating hashes is novel.
The third algorithm is an adaptation of:
Lasse Blaauwbroek, Miroslav Olšák, and Herman Geuvers. 2024. Hashing Modulo Context-Sensitive α-Equivalence. Proc. ACM Program. Lang. 8, PLDI, Article 229 (June 2024), 24 pages. https://doi.org/10.1145/3656459
Our algorithm has the same asymptotic complexity as described in the paper, but is adjusted to non-context-sensitive alpha-equivalence. For the purely syntactic part of the classes computation, we also draw on:
Michalis Christou, Maxime Crochemore, Tomáš Flouri, Costas S. Iliopoulos, Jan Janoušek, Bořivoj Melichar, and Solon P. Pissis. 2012. Computing all subtree repeats in ordered trees. Inf. Process. Lett. 112, 24 (December 2012), 958–962. https://doi.org/10.1016/j.ipl.2012.09.001
Hashing

We start with a named form, where all variables are accessed by names. This ensures that the innermost subterms, single variables, are already in the locally nameless form. We then compute the locally nameless forms of other terms recursively: an application concatenates the forms of its children, and an abstraction \x. body takes the form of body and replaces every access to x by name with the corresponding de Bruijn index.
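As a sketch, the recursion could be materialized as follows (our illustration: the article never builds these forms explicitly, it only hashes them; the constructors match the pseudocode below, and indices follow the nesting - variable_nesting convention used there):

def to_locally_nameless(t: Expr, nesting: int, binder_depth: dict[VariableName, int]) -> Expr:
    # Call as to_locally_nameless(t, 0, {}) on a named term.
    match t:
        case Variable(x):
            if x in binder_depth:  # bound: replace the name with an index
                return Variable(nesting - binder_depth[x])
            return t               # free: keep the name
        case Abstraction(x, body):
            return Abstraction(x, to_locally_nameless(body, nesting + 1, binder_depth | {x: nesting}))
        case Application(f, a):
            return Application(
                to_locally_nameless(f, nesting, binder_depth),
                to_locally_nameless(a, nesting, binder_depth),
            )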
Recomputing the hash of the adjusted string from scratch at every abstraction would be quadratic overall. However, some string hashes, most commonly rolling hashes, allow the hash to be recomputed efficiently if part of the string is changed. Adjusting a single character takes constant time: a character c at index i contributes c * b^i (mod p) to a polynomial hash, so replacing it with c' adds (c' - c) * b^i (mod p) to the hash.
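For a polynomial rolling hash, the update is a one-liner (a minimal sketch, assuming the same b and p as the code below):

def poly_hash(chars: list[int]) -> int:
    # hash = sum of chars[i] * b**i, modulo p
    return sum(c * pow(b, i, p) for i, c in enumerate(chars)) % p

def replace_char(h: int, i: int, old: int, new: int) -> int:
    # Only the contribution of position i changes. (pow makes this O(log i);
    # the implementation below memoizes powers of b instead.)
    return (h + (new - old) * pow(b, i, p)) % p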
An implementation of the algorithm is reproduced below. To avoid handling parentheses, we implicitly translate terms to postfix notation, denoting calls with !. For example, \x. x x reads as x x ! \ in postfix.
range_of_expr: dict[Expr, tuple[int, int]] = {}
# Assumes binder names are distinct (e.g. after alpha-renaming): these tables
# are keyed by name, so shadowing would clash.
variable_nesting: dict[VariableName, int] = {}
variable_accesses: dict[VariableName, list[tuple[int, int]]] = {}
current_location: int = 0
def collect_locations(expr: Expr, nesting: int):
global current_location
start = current_location
match expr:
case Variable(x):
# x
current_location += 1
variable_accesses[x].append((start, nesting - variable_nesting[x]))
case Abstraction(x, body):
# body, \
variable_nesting[x] = nesting
variable_accesses[x] = []
collect_locations(body, nesting + 1)
current_location += 1
case Application(f, a):
# f, a, !
collect_locations(f, nesting)
collect_locations(a, nesting)
current_location += 1
end = current_location
range_of_expr[expr] = (start, end)
collect_locations(root, 0)
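# The helpers below assume a prime modulus p and a base b. Their choice is left
# open by the article; a typical pick (our assumption) is a large prime with a
# random base.
import random

p = (1 << 61) - 1           # a large Mersenne prime
b = random.randrange(2, p)  # random base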
powers_of_b: list[int] = [1]
# Computes `h * b ** count % p` in amortized constant time.
def shift(h: int, count: int) -> int:
while len(powers_of_b) <= count:
powers_of_b.append(powers_of_b[-1] * b % p)
return h * powers_of_b[count] % p
# Functions capable of hashing variable names, de Bruijn indices, and the characters \, ! without
# collisions.
def hash_lambda() -> int: return 1
def hash_call() -> int: return 2
def hash_variable_name(x: VariableName) -> int: return x.int_id * 2 + 3
def hash_de_bruijn_index(i: int) -> int: return i * 2 + 4
def calculate_hashes(expr: Expr) -> int:
start, end = range_of_expr[expr]
match expr:
case Variable(x):
h = hash_variable_name(x)
        case Abstraction(x, body):
            # The body's hash, with \ appended at the last position of the range.
            h = calculate_hashes(body) + shift(hash_lambda(), end - start - 1)
            # Patch each access to x in place: swap the hash of the name for
            # the hash of the de Bruijn index, using the rolling-hash update.
            for location, de_bruijn_index in variable_accesses[x]:
                h += shift(
                    hash_de_bruijn_index(de_bruijn_index) - hash_variable_name(x),
                    location - start,
                )
            h %= p
case Application(f, a):
h = (
calculate_hashes(f)
+ shift(calculate_hashes(a), range_of_expr[a][0] - start)
+ shift(hash_call(), end - start - 1)
)
h %= p
print("The hash of", expr, "is", h)
return h
calculate_hashes(root)
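For intuition, here is one hypothetical way to make the sketch concrete and check that alpha-equivalent terms hash equally (the dataclasses are our assumption, as the pseudocode leaves the term representation open):

from dataclasses import dataclass

@dataclass(frozen=True)
class VariableName:
    int_id: int

@dataclass(eq=False)  # compared by identity, matching the dict[...] convention
class Variable:
    name: VariableName

@dataclass(eq=False)
class Abstraction:
    variable: VariableName
    body: object

@dataclass(eq=False)
class Application:
    function: object
    argument: object

x, y = VariableName(0), VariableName(1)
t1 = Abstraction(x, Variable(x))  # \x. x
t2 = Abstraction(y, Variable(y))  # \y. y
collect_locations(t1, 0)
collect_locations(t2, 0)  # locations are global, but hashing uses only offsets
assert calculate_hashes(t1) == calculate_hashes(t2)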
The probabilistic guarantees of this scheme depend entirely on the choice of the hash. The collision probability of rolling hashes typically scales linearly with the length of the input. In this case, the length of the input exactly matches the number of subterms of the hashed term.
For polynomial hashes with a random base modulo a prime p, the collision probability of two distinct inputs of length at most n is at most n / p.
Since there are fewer than n^2 / 2 pairs of subterms, a union bound controls the probability that any collision occurs at all.
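Spelled out (our arithmetic, standard for polynomial hashing):

\Pr[\text{some pair of distinct subterms collides}] \le \binom{n}{2} \cdot \frac{n}{p} < \frac{n^3}{2p}

so it suffices to pick p much larger than n^3, or to hash twice with independent moduli.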
Verification

To verify that the computed hashes don't produce collisions, we group terms by their hashes and validate that, in each group of size greater than one, all terms are alpha-equivalent to each other.
We now introduce some terminology.
We call subterms with non-unique hashes (i.e. subterms that are not alone in their groups) pivots.
We say an optimized predicate for
For a term
For a term
Note that
We write
We rely on the following propositions:
If
If
If
If
If
If
If
If
If a path
To verify
If
Otherwise, we look for copies of
Note that in the latter case, if
An implementation of this algorithm follows.
def compare(u1: Term, t1: Term, u2: Term, t2: Term, h21: dict[int, int]) -> bool:
    # h21 records, for pivots already visited inside t2, the hash of the
    # corresponding subterm of t1.
    if (u2 is not t2) and (u2 is a pivot):
        # A pivot with an alpha-equivalent copy outside t2 is not entered:
        # its hash alone decides the answer.
        if there is any term alpha-equivalent to u2 outside t2:
            return hash[u1] == hash[u2]
        # Otherwise, enter the pivot once and memoize the correspondence.
        if hash[u2] in h21:
            return h21[hash[u2]] == hash[u1]
        h21[hash[u2]] = hash[u1]
match (u1, u2):
case (Variable(x1), Variable(x2)):
x1 = (x1 as de Bruijn index) if x1 defined within t1 else (x1 as name)
x2 = (x2 as de Bruijn index) if x2 defined within t2 else (x2 as name)
return x1 == x2
case (Application(u11, u12), Application(u21, u22)):
return compare(u11, t1, u21, t2, h21) and compare(u12, t1, u22, t2, h21)
case (Abstraction(_, v1), Abstraction(_, v2)):
return compare(v1, t1, v2, t2, h21)
case _:
return False
def verify_hashes():
    # `classes` is assumed to hold the groups of terms sharing a hash.
# Not implemented: validate that, within each class, all terms have the same size.
# Not implemented: sort classes by increasing size of terms.
for class_members in classes:
t1 = class_members[0]
for t2 in class_members[1:]:
if not compare(t1, t1, t2, t2, {}):
return False
return True
It turns out that this algorithm takes linear time. We will now prove this.
The pair compare. Split such invocations into two categories depending on whether the path
Consider any path
Suppose that there are two pairs that pay with the same
This proves that the mapping
Notes:
The algorithm is linear even in the presence of collisions. The mapping
The arguments of compare are not taken into consideration during the proof. compare can be transformed to serialize, which lists non-entered terms as either hash values or backrefs, followed by an assertion that the serialized strings of all terms within a group are equal. This still takes linear time because the total string length is linear. Such an algorithm can resolve hash collisions locally by splitting groups in expected linear time, but is more complex and requires more memory.
The only reason a serialize-based algorithm needs to be pre-fed with hashes is to determine which terms are pivots – the exact hashes or even collisions between pivots are inconsequential. Pivots mostly matter because of the assumption that the path
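A rough sketch of that serialization (our reconstruction, not the article's code; is_pivot and has_copy_outside stand for the checks spelled out in compare):

def serialize(u: Term, t: Term, seen: dict[int, int], out: list):
    if (u is not t) and is_pivot(u):
        if has_copy_outside(u, t):
            out.append(("hash", hash[u]))  # never entered: the hash stands in
            return
        if hash[u] in seen:
            out.append(("backref", seen[hash[u]]))  # already entered once
            return
        seen[hash[u]] = len(seen)  # first visit: enter structurally
    match u:
        case Variable(x):
            out.append(("var", x))  # de Bruijn index if bound within t, name otherwise
        case Abstraction(_, body):
            out.append("lambda")
            serialize(body, t, seen, out)
        case Application(f, a):
            out.append("apply")
            serialize(f, t, seen, out)
            serialize(a, t, seen, out)

Terms of a group are then declared equivalent exactly when their token lists coincide.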
Classes

The high-level overview of our deterministic algorithm for computing equivalence classes is as follows.
We start with the root term
The algorithms we propose build a forest of term copies together with a mapping from subterms to forest nodes, arranged so that two mapped subterms are alpha-equivalent exactly when their forest nodes are syntactically equal.
Our algorithm for building the forest follows.
We start by adding an exact copy of the whole term to the forest and recurse into it:
size: dict[Term, int] = {}           # number of subterms
max_index: dict[Term, int] = {}      # largest dangling de Bruijn index; negative: locally closed
forest: list[Term] = []              # roots of the copied trees
term_to_node: dict[Term, Term] = {}  # locally closed subterms -> their forest nodes
def build_forest(t: Term) -> None:
t_prime = deep_copy(t)
forest.append(t_prime)
compute_term_properties(t)
recurse(t, t_prime, size[t])
def compute_term_properties(t: Term):
match t:
case Variable(x):
size[t] = 1
if x is a de Bruijn index:
max_index[t] = x
else: # x is a variable name
max_index[t] = -1
case Abstraction(x, u):
compute_term_properties(u)
size[t] = 1 + size[u]
max_index[t] = max_index[u] - 1
case Application(t1, t2):
compute_term_properties(t1)
compute_term_properties(t2)
size[t] = 1 + size[t1] + size[t2]
max_index[t] = max(max_index[t1], max_index[t2])
def recurse(t: Term, t_prime: Term, root_size: int):
    if max_index[t] < 0:  # locally closed: t_prime becomes t's representative
        term_to_node[t] = t_prime
    else:
        if 2 * size[t] < root_size:  # small: start a fresh tree in the forest
            build_forest(t)
            return
    match (t, t_prime):
        case (Abstraction(x, u), Abstraction(_, u_prime)):
            replace_mentions(x)  # not shown: replace all mentions of x in t with names
            recurse(u, u_prime, root_size)
        case (Application(u1, u2), Application(u1_prime, u2_prime)):
            recurse(u1, u1_prime, root_size)
            recurse(u2, u2_prime, root_size)
build_forest(root_t)
Since build_forest takes time linear in the size of the copied term, and a subterm starts a new tree only when its size is less than half the size of its current tree, every subterm is copied O(log n) times; the total size of the forest, and with it the time to build it, is O(n log n).
To calculate syntactic equivalence classes of the forest nodes, we bucket nodes by size and refine each bucket by the classes of the nodes' children, processing sizes in increasing order:
from dataclasses import dataclass

@dataclass
class SizeGroup:
# variable accesses are not stored explicitly
abstractions: list[Term]
applications: list[Term]
by_size: list[SizeGroup] = [SizeGroup([], []) for _ in range(n + 1)]  # n: size of the input term
node_classes: dict[Term, int] = {}
next_class: int = n_variables * 2 # leave space to easily number variable accesses
def populate_size_groups(t: Term) -> int:
match t:
case Variable(x):
# Populate classes of leaf nodes immediately.
if x is a de Bruijn index:
node_classes[t] = x
else: # x is a variable name, assuming an integer from 0 to `n_variables - 1`
node_classes[t] = n_variables + x
return 1
case Abstraction(_, u):
size = 1 + populate_size_groups(u)
by_size[size].abstractions.append(t)
return size
case Application(t1, t2):
size = 1 + populate_size_groups(t1) + populate_size_groups(t2)
by_size[size].applications.append(t)
return size
for t in forest:
populate_size_groups(t)
temporary_storage: list[list[Term]] = []
def group_by(nodes: list[Term], key: Callable[[Term], int]) -> list[list[Term]]:
present_keys: list[int] = []
for t in nodes:
k = key(t)
while k >= len(temporary_storage): # amortized O(|F|)
temporary_storage.append([])
if not temporary_storage[k]:
present_keys.append(k)
temporary_storage[k].append(t)
result = [temporary_storage[k] for k in present_keys]
for k in present_keys:
temporary_storage[k] = []
return result
for group in by_size:
for subgroup in group_by(group.abstractions, lambda t: node_classes[t.body]):
for t in subgroup:
node_classes[t] = next_class
next_class += 1
for subgroup1 in group_by(group.applications, lambda t: node_classes[t.function]):
for subgroup2 in group_by(subgroup1, lambda t: node_classes[t.argument]):
for t in subgroup2:
node_classes[t] = next_class
next_class += 1
Because children are strictly smaller than their parents, the classes of children are already final when a size group is processed, so a single pass in increasing size order suffices. Term classes can then be populated from node classes.
term_classes: dict[Term, int] = {}
def populate_term_classes(t: Term):
    global next_class
    if t in term_to_node:
        term_classes[t] = node_classes[term_to_node[t]]
    else:  # guaranteed to be unique (non-locally-closed and "big")
        term_classes[t] = next_class
        next_class += 1
match t:
case Abstraction(_, u):
populate_term_classes(u)
case Application(t1, t2):
populate_term_classes(t1)
populate_term_classes(t2)
populate_term_classes(root_t)
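Once the passes above have run, consuming the result is a constant-time lookup (our sketch; u and v are subterms of root_t):

def alpha_equivalent(u: Term, v: Term) -> bool:
    # Deterministic replacement for comparing the hashes of the first section.
    return term_classes[u] == term_classes[v]

Duplicated code then corresponds to classes with more than one member.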