tldr
node accepts visitor, visitor visits node
Overview
Design patterns were (and still are) something I have found hard to fully understand
ever since I started learning them. Either I memorized a pattern and “thought”
I understood it, or I “knew” some of it but still didn’t fully understand it.
Today I’d like to write down my current understanding of the Visitor Pattern.
This is mostly a record of my confusions and ah-ha moments while learning it
(maybe for the 5th time, because I keep thinking I’ve got it and then forget it).
Why is this pattern needed?
It took me a while to really understand why we need the Visitor Pattern.
Polymorphism lets you call different code based on the instance, whereas the
visitor pattern lets you call different code based on both the instance and
a visitor.
This allows you to decouple data from behavior: the data stays in the class, and the behavior lives in the visitor.
When the class tree is big (say, 70 concrete nodes) and you need ad hoc behaviors, this pattern helps keep your code clean.
One of the classic places to use the Visitor Pattern is when dealing with an Abstract Syntax Tree.
In the classic shape example below, .area() calls the corresponding method
because the dispatch is based on the instance of the object.

    import math


    class Shape:
        def __init__(self):
            pass

        def area(self):
            pass


    class Circle(Shape):
        def __init__(self, _radius):
            self.radius = _radius

        def area(self):
            return math.pi * self.radius ** 2


    class Square(Shape):
        def __init__(self, _size):
            self.size = _size

        def area(self):
            return self.size ** 2


    circle = Circle(2)
    square = Square(2)
    shapes = [circle, square]
    for shape in shapes:
        print(shape.area())

All of the above seems good enough. If I want to get the XML export string of a shape, I
would just add a new method to Shape and implement it on all the subclasses.

    import math


    class Shape:
        def __init__(self):
            pass

        def area(self):
            pass

        def xml_export(self):
            pass


    class Circle(Shape):
        def __init__(self, _radius):
            self.radius = _radius

        def area(self):
            return math.pi * self.radius ** 2

        def xml_export(self):
            return f"<circle radius={self.radius} />"


    class Square(Shape):
        def __init__(self, _size):
            self.size = _size

        def area(self):
            return self.size ** 2

        def xml_export(self):
            return f"<square size={self.size} />"

This, however, has two problems:
- The classes become massive and bloated. A Circle class shouldn’t need to know about XML, JSON, or database connections. It violates the Single Responsibility Principle.
- The code change needs to happen on the concrete types. Say you want to separate your library code from your client code: you probably don’t want to change the Circle class in the library for every client need, and if this is a real library, you won’t be able to accommodate every client usage anyway. Something something Open/Closed Principle (open for extension, closed for modification).
How about we just pull all calculations out to functions
What if we could move the logic to somewhere outside of the concrete classes?
Let’s see what happens if we pull the area method out of all the shapes and implement it as a function.

    import math


    # Circle and Square are the classes defined earlier (minus their area() methods).
    def area(shape):
        res = 0
        if isinstance(shape, Circle):
            # A circle has a radius, not a size
            res = math.pi * shape.radius ** 2
        elif isinstance(shape, Square):
            res = shape.size ** 2
        return res


    circle = Circle(2)
    square = Square(2)
    shapes = [circle, square]
    for shape in shapes:
        print(area(shape))

This is kinda OK, but for every new shape we add we need to remember to update area(). That becomes a problem when there are many such functions: say we also have xmlexport(shape) and jsonexport(shape), then we would need to update 3 different functions. Another issue is that runtime type checks and casts can be costly or clumsy in some languages, compared with using the language’s built-in polymorphism.
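To make the first concern concrete, here is a minimal sketch of what can go wrong. Triangle is a hypothetical new shape that is not part of the example above, and strict_area() is a made-up variant added only for contrast. Because area() was never updated for Triangle, it silently falls through and returns 0.

```python
import math


class Circle:
    def __init__(self, radius):
        self.radius = radius


class Square:
    def __init__(self, size):
        self.size = size


class Triangle:
    # Hypothetical shape added later; area() below was never updated for it.
    def __init__(self, base, height):
        self.base = base
        self.height = height


def area(shape):
    res = 0
    if isinstance(shape, Circle):
        res = math.pi * shape.radius ** 2
    elif isinstance(shape, Square):
        res = shape.size ** 2
    return res


def strict_area(shape):
    # Same branching, but unknown types fail loudly instead of returning 0.
    if isinstance(shape, Circle):
        return math.pi * shape.radius ** 2
    if isinstance(shape, Square):
        return shape.size ** 2
    raise TypeError(f"area() does not support {type(shape).__name__}")


print(area(Triangle(3, 4)))  # prints 0, no error is raised
```

Raising on unknown types at least surfaces the omission at runtime, but you still have to remember to do that in every such function.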
Visitor Pattern
The following code isn’t that useful, but it’s trying to mimic the area()
functionality using an AreaVisitor. Note that accept() and visit_xxx()
are returning values to be printed in the for loop.

    import math


    class AreaVisitor:
        def __init__(self):
            pass

        def visit_circle(self, shape):
            return math.pi * shape.radius ** 2

        def visit_square(self, shape):
            return shape.size * shape.size


    class Shape:
        def __init__(self):
            pass

        def accept(self, visitor):
            pass


    class Circle(Shape):
        def __init__(self, _radius):
            self.radius = _radius

        def accept(self, visitor):
            return visitor.visit_circle(self)


    class Square(Shape):
        def __init__(self, _size):
            self.size = _size

        def accept(self, visitor):
            return visitor.visit_square(self)


    circle = Circle(2)
    square = Square(2)
    shapes = [circle, square]
    area_visitor = AreaVisitor()
    for shape in shapes:
        print(shape.accept(area_visitor))

The concrete classes provide all the information needed and no behavior. Each of them has an accept(visitor) method, which takes in a visitor and calls the corresponding visitor method with self.

    shape = Circle(2)
    area_visitor = AreaVisitor()
    shape.accept(area_visitor)
    # First dispatch: shape.accept(...) resolves to Circle.accept, with the
    # current instance (a circle) as self.
    # Second dispatch: inside it, visitor.visit_circle(self) calls
    # area_visitor.visit_circle(shape), where shape is the circle.

Let’s also try to add in the TagExportVisitor

    class TagExportVisitor:
        def __init__(self):
            pass

        def visit_circle(self, shape):
            return f"<circle radius={shape.radius} />"

        def visit_square(self, shape):
            return f"<square size={shape.size} />"


    circle = Circle(2)
    square = Square(2)
    shapes = [circle, square]
    tag_export_visitor = TagExportVisitor()
    for shape in shapes:
        print(shape.accept(tag_export_visitor))

    """
    Outputs
    <circle radius=2 />
    <square size=2 />
    """

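The payoff is that new behavior can be added without editing Circle or Square. As a rough sketch, PerimeterVisitor below is a hypothetical visitor that is not part of the original example; the shape classes are repeated only so that the snippet runs on its own.

```python
import math


class Circle:
    def __init__(self, radius):
        self.radius = radius

    def accept(self, visitor):
        return visitor.visit_circle(self)


class Square:
    def __init__(self, size):
        self.size = size

    def accept(self, visitor):
        return visitor.visit_square(self)


class PerimeterVisitor:
    # Hypothetical client-side visitor; the shape classes stay untouched.
    def visit_circle(self, shape):
        return 2 * math.pi * shape.radius

    def visit_square(self, shape):
        return 4 * shape.size


perimeter_visitor = PerimeterVisitor()
perimeters = [shape.accept(perimeter_visitor) for shape in [Circle(2), Square(2)]]
print(perimeters)
```

The flip side is that adding a new shape now means adding a visit_xxx() method to every existing visitor, so the pattern fits best when the set of node types is stable and the set of operations keeps growing.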
The Classic Visitor Pattern
In the classic visitor pattern, accept() and visit() do not return anything.
Visitors are stateful: they maintain the result themselves.

    import math


    class TotalAreaVisitor:
        def __init__(self):
            self.total_area = 0

        def visit_circle(self, shape):
            self.total_area += math.pi * shape.radius ** 2

        def visit_square(self, shape):
            self.total_area += shape.size * shape.size


    class Shape:
        def __init__(self):
            pass

        def accept(self, visitor):
            pass


    class Circle(Shape):
        def __init__(self, _radius):
            self.radius = _radius

        def accept(self, visitor):
            visitor.visit_circle(self)


    class Square(Shape):
        def __init__(self, _size):
            self.size = _size

        def accept(self, visitor):
            visitor.visit_square(self)


    circle = Circle(2)
    square = Square(2)
    shapes = [circle, square]
    total_area_visitor = TotalAreaVisitor()
    for shape in shapes:
        shape.accept(total_area_visitor)
    print(total_area_visitor.total_area)

Gemini:
Why GoF Designed It This Way
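In 1994, statically typed languages like C++ had no practical way to make accept() generic over its return type (C++ templates existed, but a virtual method cannot be a template). If accept() returned a value, you would have to hardcode the return type
(e.g., int). That would mean every visitor would be forced to return an int.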
By making the return type void (None) and storing the result as a class member, GoF decoupled the operation from the return type.
TotalAreaVisitor can store a float, while TagExportVisitor can store a List[str], and the Shape interface doesn’t have to care at all.
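A small sketch of that last point, assuming the same two shapes as above: two classic-style (non-returning) visitors that keep results of different types while accept() stays identical. TagListVisitor and its tags attribute are names made up for this illustration.

```python
import math


class Circle:
    def __init__(self, radius):
        self.radius = radius

    def accept(self, visitor):
        visitor.visit_circle(self)  # nothing is returned


class Square:
    def __init__(self, size):
        self.size = size

    def accept(self, visitor):
        visitor.visit_square(self)  # nothing is returned


class TotalAreaVisitor:
    def __init__(self):
        self.total_area = 0.0  # result kept as a float

    def visit_circle(self, shape):
        self.total_area += math.pi * shape.radius ** 2

    def visit_square(self, shape):
        self.total_area += shape.size ** 2


class TagListVisitor:
    # Hypothetical classic-style exporter: the result is a list of strings.
    def __init__(self):
        self.tags = []

    def visit_circle(self, shape):
        self.tags.append(f"<circle radius={shape.radius} />")

    def visit_square(self, shape):
        self.tags.append(f"<square size={shape.size} />")


shapes = [Circle(2), Square(2)]
area_visitor = TotalAreaVisitor()
tag_visitor = TagListVisitor()
for shape in shapes:
    shape.accept(area_visitor)
    shape.accept(tag_visitor)

print(area_visitor.total_area)
print(tag_visitor.tags)
```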
My feeling is that we don’t need to strictly abide by the “classic” version; do whatever you need to do. For instance, Go’s ast doesn’t even use a visit_xxx() interface. Instead, it has a single visit(node) method and type switches inside it. You can also see the evaluate example below returning the results of sub-expressions; if it strictly followed the no-return rule, the visitor would need to maintain a stack in its state.
What does this Accept even mean
Took me a while to realize that Accept is the instance “accepting” the visitor into the “house”.
Go ast’s Visitor
Golang’s Visitor interface doesn’t really follow the classic pattern. Gemini said this is because:
- Go prefers small interfaces over large ones
- Had Go’s visitor interface followed the classic version, it would have ended up too big: to implement a visitor that only wants to check comments, you would need to stub out all the other methods
What does this Walk thing have to do with the Visitor Pattern
The Visitor Pattern doesn’t really need to walk anything; in the examples above, the “Walk” is essentially the for loop.
However, one of the places where the visitor pattern works well is traversing an Abstract Syntax Tree, which is essentially a tree composed of different kinds of nodes.
So the visitor “walks” through the nodes in the tree, and each of the nodes “accepts” the visitor.
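Here is a rough Python sketch of the single-method style, loosely modeled on the shape of Go’s ast.Visitor and ast.Walk. It is an illustration, not a port of the Go API: walk(), children(), and NumberCollector are names invented for this example. The traversal lives in walk(), and the visitor has one visit() method that switches on the node type itself.

```python
class Number:
    def __init__(self, value):
        self.value = value

    def children(self):
        return []


class Add:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def children(self):
        return [self.left, self.right]


def walk(visitor, node):
    # The traversal lives here, outside both the nodes and the visitor.
    # visit() returns the visitor to use for the children, or None to
    # skip the subtree.
    child_visitor = visitor.visit(node)
    if child_visitor is None:
        return
    for child in node.children():
        walk(child_visitor, child)


class NumberCollector:
    # One visit() method; it checks the node type itself and simply
    # ignores node types it does not care about, so nothing is stubbed out.
    def __init__(self):
        self.values = []

    def visit(self, node):
        if isinstance(node, Number):
            self.values.append(node.value)
        return self


collector = NumberCollector()
walk(collector, Add(Number(1), Add(Number(2), Number(3))))
print(collector.values)  # [1, 2, 3]
```

Note that the nodes here need no accept() method at all; the price is the type check inside visit().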
Evaluating an AST
I asked Gemini to generate an example of using the Visitor Pattern to evaluate and print an AST that handles Add and Multiply.
The recursive part was a bit confusing for me, but once I accepted the fact that expr.left.accept(self) would dispatch to evaluator.visit_add when left is an Add, it was clearer.

      +
     / \
    1   2

The tree above would have the following calls:

    Add(1, 2).accept(evaluator)
        evaluator.visit_add( Add(1, 2) )
            # Evaluates the left side
            Number(1).accept(evaluator)
                evaluator.visit_number( Number(1) ) -> returns 1
            # Evaluates the right side
            Number(2).accept(evaluator)
                evaluator.visit_number( Number(2) ) -> returns 2
            returns 1 + 2 (which is 3)

    from abc import ABC, abstractmethod


    # ==========================================
    # 1. The AST Node Interface
    # ==========================================
    class Expr(ABC):
        @abstractmethod
        def accept(self, visitor: 'Visitor'):
            pass


    # ==========================================
    # 2. Concrete AST Nodes (Data Structures)
    # ==========================================
    class Number(Expr):
        def __init__(self, value: float):
            self.value = value

        def accept(self, visitor: 'Visitor'):
            return visitor.visit_number(self)


    class Add(Expr):
        def __init__(self, left: Expr, right: Expr):
            self.left = left
            self.right = right

        def accept(self, visitor: 'Visitor'):
            return visitor.visit_add(self)


    class Multiply(Expr):
        def __init__(self, left: Expr, right: Expr):
            self.left = left
            self.right = right

        def accept(self, visitor: 'Visitor'):
            return visitor.visit_multiply(self)


    # ==========================================
    # 3. The Visitor Interface
    # ==========================================
    class Visitor(ABC):
        @abstractmethod
        def visit_number(self, expr: Number):
            pass

        @abstractmethod
        def visit_add(self, expr: Add):
            pass

        @abstractmethod
        def visit_multiply(self, expr: Multiply):
            pass


    # ==========================================
    # 4. Concrete Visitors (The Logic)
    # ==========================================
    class EvaluateVisitor(Visitor):
        """Walks the tree and calculates the mathematical result."""

        def visit_number(self, expr: Number):
            # Base case: just return the number
            return expr.value

        def visit_add(self, expr: Add):
            # Recursively evaluate the left and right sides, then add them
            left_val = expr.left.accept(self)
            right_val = expr.right.accept(self)
            return left_val + right_val

        def visit_multiply(self, expr: Multiply):
            # Recursively evaluate the left and right sides, then multiply them
            left_val = expr.left.accept(self)
            right_val = expr.right.accept(self)
            return left_val * right_val


    class PrintVisitor(Visitor):
        """Walks the tree and returns a formatted string with parentheses."""

        def visit_number(self, expr: Number):
            return str(expr.value)

        def visit_add(self, expr: Add):
            left_str = expr.left.accept(self)
            right_str = expr.right.accept(self)
            return f"({left_str} + {right_str})"

        def visit_multiply(self, expr: Multiply):
            left_str = expr.left.accept(self)
            right_str = expr.right.accept(self)
            return f"({left_str} * {right_str})"


    # ==========================================
    # 5. Usage
    # ==========================================
    if __name__ == "__main__":
        # Let's build an AST for the equation: (3 + 5) * 2
        # Tree structure:
        #       *
        #      / \
        #     +   2
        #    / \
        #   3   5
        ast_root = Multiply(
            Add(Number(3), Number(5)),
            Number(2)
        )

        # 1. Evaluate the math
        evaluator = EvaluateVisitor()
        result = ast_root.accept(evaluator)

        # 2. Print the equation
        printer = PrintVisitor()
        equation_string = ast_root.accept(printer)

        print(f"Equation: {equation_string}")
        print(f"Result: {result}")
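For comparison, here is a sketch of what the earlier remark about the no-return rule could look like in practice: accept() and visit_xxx() return nothing, so the evaluator keeps intermediate results on a stack in its own state. StackEvaluateVisitor is a hypothetical name, and the node classes are trimmed-down copies so the snippet runs on its own.

```python
class Number:
    def __init__(self, value):
        self.value = value

    def accept(self, visitor):
        visitor.visit_number(self)


class Add:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def accept(self, visitor):
        visitor.visit_add(self)


class Multiply:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def accept(self, visitor):
        visitor.visit_multiply(self)


class StackEvaluateVisitor:
    # Hypothetical classic-style evaluator: results live on self.stack.
    def __init__(self):
        self.stack = []

    def visit_number(self, expr):
        self.stack.append(expr.value)

    def visit_add(self, expr):
        # Each accept() call leaves exactly one value on the stack.
        expr.left.accept(self)
        expr.right.accept(self)
        right = self.stack.pop()
        left = self.stack.pop()
        self.stack.append(left + right)

    def visit_multiply(self, expr):
        expr.left.accept(self)
        expr.right.accept(self)
        right = self.stack.pop()
        left = self.stack.pop()
        self.stack.append(left * right)


# Same tree as above: (3 + 5) * 2
ast_root = Multiply(Add(Number(3), Number(5)), Number(2))
evaluator = StackEvaluateVisitor()
ast_root.accept(evaluator)
result = evaluator.stack.pop()
print(result)  # 16
```

It works, but the returning version reads more naturally, which is one more reason not to treat the classic shape as mandatory.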