How to write Asteroids without Casting
Since very early in my programming life, I have been writing Asteroids clones. I've done it many times, and I pretty much always hit the same obstacle: how do you use polymorphism when two objects with different dynamic types interact with each other? Actually, the question comes up in a lot of contexts, not just Asteroids, but Asteroids is a fun example. Anyway, I recently found a new approach (new to me), and I really like it, but I can't really explain it in just one sentence, so take a deep breath and come with me on an Asteroids endeavor.
So, we're writing Asteroids using instances of C++ classes to represent the objects in the game, something like this:
class Object
{
};
class Ship : public Object
{
};
class Bullet : public Object
{
};
class Asteroid : public Object
{
};
There are a bunch of objects floating around in space, and when two of them hit each other, something should happen, so it seems like there should be a function like:
class Object
{
public:
virtual void handleCollision(Object* other) = 0;
};
The subclasses then provide handleCollision() versions that do whatever should happen when things collide, one for each kind of object they can hit:
class Ship;
class Bullet;
class Asteroid;
class Object
{
public:
virtual void handleCollision(Object* other) = 0;
};
class Ship : public Object
{
public:
virtual void handleCollision(Ship* other);
virtual void handleCollision(Bullet* other);
virtual void handleCollision(Asteroid* other);
};
class Bullet : public Object
{
public:
virtual void handleCollision(Ship* other);
virtual void handleCollision(Bullet* other);
virtual void handleCollision(Asteroid* other);
};
class Asteroid : public Object
{
public:
virtual void handleCollision(Ship* other);
virtual void handleCollision(Bullet* other);
virtual void handleCollision(Asteroid* other);
};
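One subtlety in those declarations: handleCollision(Ship*) and its siblings are overloads, not overrides, of the pure virtual handleCollision(Object*), so the subclasses above are still abstract. Here is a minimal sketch, trimmed to two classes with placeholder empty bodies, that lets the compiler confirm this with std::is_abstract:

```cpp
#include <cassert>
#include <type_traits>

class Bullet;  // forward declaration so Ship can name Bullet*

class Object
{
public:
    virtual ~Object() {}
    virtual void handleCollision(Object* other) = 0;
};

// Declares a handleCollision(Bullet*) overload but never overrides
// handleCollision(Object*), so the pure virtual is still unimplemented.
class Ship : public Object
{
public:
    virtual void handleCollision(Bullet*) {}
};

// Matching the base signature exactly is what makes a class concrete.
class Bullet : public Object
{
public:
    virtual void handleCollision(Object*) {}
};

static_assert(std::is_abstract<Ship>::value, "Ship still has a pure virtual");
static_assert(!std::is_abstract<Bullet>::value, "Bullet overrides handleCollision(Object*)");
```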
And then you'd have a loop in the game that checks collisions:
for (Object* A : objects)
    for (Object* B : objects)
    {
        // collidesWith() is the geometric overlap test (not shown here)
        if (A->collidesWith(B))
        {
            A->handleCollision(B);
        }
    }
And then polymorphism happens and the right collide function gets called, right?
Well, no. Or rather, it's not that simple. When you call this:
A->handleCollision(B);
A and B both have the static type Object*. If A's dynamic type is Asteroid and B's dynamic type is Bullet, C++ will look in A's virtual-table for a function like this:
void Asteroid::handleCollision(Object*)
{
}
But it will not look at B's dynamic type and then search A's virtual-table for a function like this:
void Asteroid::handleCollision(Bullet*)
{
// Asteroid hit bullet go boom!
}
So, how do you make handleCollision take into account the dynamic type of both this and the function argument? It doesn't seem like you can.
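Both halves of that claim can be seen in a minimal sketch. It assumes handleCollision returns a std::string naming the function that ran, purely so the choice is observable (the versions above return void), and the collide helper is a hypothetical stand-in for the call inside the game loop:

```cpp
#include <cassert>
#include <string>

class Object
{
public:
    virtual ~Object() {}
    // Returns a label so we can see which function ran (illustrative only).
    virtual std::string handleCollision(Object* other) = 0;
};

class Bullet : public Object
{
public:
    std::string handleCollision(Object*) override { return "Bullet::handleCollision(Object*)"; }
};

class Asteroid : public Object
{
public:
    std::string handleCollision(Object*) override { return "Asteroid::handleCollision(Object*)"; }
    std::string handleCollision(Bullet*) { return "Asteroid::handleCollision(Bullet*)"; }
};

// Hypothetical helper mirroring the game loop: both pointers are Object*.
std::string collide(Object* A, Object* B)
{
    // The virtual call follows A's dynamic type (Asteroid), but the overload
    // was picked at compile time from B's static type (Object*).
    return A->handleCollision(B);
}
```

Both calls in the checks below pass the same Bullet object; only the static type of the argument differs, and that alone decides which overload runs.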
I've written stuff like this many times, and around now is when I usually give up and just tag my data:
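enum Type
{
    kShip,
    kAsteroid,
    kBullet
};
class Object
{
public:
    Type type;  // a second copy of the dynamic type
    virtual ~Object() {}
    virtual void handleCollision(Object* other) = 0;
protected:
    // A derived constructor can't initialize a base member directly,
    // so each subclass passes its tag through this constructor.
    explicit Object(Type t) : type(t) {}
};
class Ship : public Object
{
public:
    Ship() : Object(kShip) {}
    virtual void handleCollision(Object* other);
};
class Bullet : public Object
{
public:
    Bullet() : Object(kBullet) {}
    virtual void handleCollision(Object* other);
};
class Asteroid : public Object
{
public:
    Asteroid() : Object(kAsteroid) {}
    virtual void handleCollision(Object* other);
};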
And then make handleCollision() do its thing by switching on type:
void Asteroid::handleCollision(Object* other)
{
switch(other->type)
{
case kShip:
// Ship explode lose life / game over
break;
case kAsteroid:
// Asteroids pass like ships in the night
break;
case kBullet:
// Asteroid hit bullet go boom!
break;
}
}
That gets the job done, but there's something unsettling about it. The virtual-table already encodes the dynamic type of the object, and here I am doing it again. Another way would be to use a dynamic cast. That would eliminate the redundant data, but it would feel even worse.
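For comparison, here is a minimal sketch of the dynamic_cast version. As before, handleCollision is assumed to return a std::string (the article's returns void) so the outcome can be checked, and the strings reuse the comments from the switch above:

```cpp
#include <cassert>
#include <string>

class Object
{
public:
    virtual ~Object() {}  // dynamic_cast needs a polymorphic base
    virtual std::string handleCollision(Object* other) = 0;
};

class Ship : public Object
{
public:
    std::string handleCollision(Object*) override { return ""; }  // placeholder
};

class Bullet : public Object
{
public:
    std::string handleCollision(Object*) override { return ""; }  // placeholder
};

class Asteroid : public Object
{
public:
    std::string handleCollision(Object* other) override
    {
        // Ask the runtime "is other really an X?" one candidate at a time;
        // dynamic_cast yields nullptr when the answer is no.
        if (dynamic_cast<Ship*>(other))
            return "Ship explode lose life / game over";
        if (dynamic_cast<Bullet*>(other))
            return "Asteroid hit bullet go boom!";
        if (dynamic_cast<Asteroid*>(other))
            return "Asteroids pass like ships in the night";
        return "";
    }
};
```

Every class needs a chain like this, and every chain has to be edited whenever a new kind of object is added.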
Let's look at this line again:
A->handleCollision(B);
That will look at the dynamic type of A, but not B, so it will call this:
void Asteroid::handleCollision(Object* other)
{
}
What if that function, instead of trying to figure out what type other is, called another function on other and let C++ figure it out:
void Asteroid::handleCollision(Object* other)
{
    other->back_handleCollision(this);
}
C++ will then look into other's virtual-table and find this function:
void Bullet::back_handleCollision(Asteroid* other)
{
    other->handleCollision(this);
}
This calls back into handleCollision, this time with the static type Bullet* in the argument, so it goes here:
void Asteroid::handleCollision(Bullet* other)
{
// Asteroid hit bullet go boom!
}
And JUST LIKE THAT, we're where we want to be, and we didn't have to switch or cast or do any tags!
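Here is that round trip as a small hand-written sketch, cut down to two classes. It assumes the functions return a std::string describing the collision (the article's return void and print) so the final destination of each call can be checked; the strings are illustrative:

```cpp
#include <cassert>
#include <string>

class Asteroid;
class Bullet;

class Object
{
public:
    virtual ~Object() {}
    // First hop: virtual dispatch on the dynamic type of this.
    virtual std::string handleCollision(Object* other) = 0;
    // Second hop: one overload per concrete class, dispatched on other.
    virtual std::string back_handleCollision(Asteroid* other) = 0;
    virtual std::string back_handleCollision(Bullet* other) = 0;
};

class Asteroid : public Object
{
public:
    std::string handleCollision(Object* other) override;
    std::string back_handleCollision(Asteroid* other) override;
    std::string back_handleCollision(Bullet* other) override;
    // The real handlers: both types are known statically here.
    std::string handleCollision(Asteroid*) { return "Asteroid pass asteroid"; }
    std::string handleCollision(Bullet*) { return "Asteroid hit bullet"; }
};

class Bullet : public Object
{
public:
    std::string handleCollision(Object* other) override;
    std::string back_handleCollision(Asteroid* other) override;
    std::string back_handleCollision(Bullet* other) override;
    std::string handleCollision(Asteroid*) { return "Bullet hit asteroid"; }
    std::string handleCollision(Bullet*) { return "Bullet hits bullet"; }
};

// this has a concrete static type here, which selects the back_ overload;
// the virtual call then lands in other's class.
std::string Asteroid::handleCollision(Object* other) { return other->back_handleCollision(this); }
std::string Bullet::handleCollision(Object* other) { return other->back_handleCollision(this); }

// Bounce back to the original object, now passing this as a concrete type.
std::string Asteroid::back_handleCollision(Asteroid* other) { return other->handleCollision(this); }
std::string Asteroid::back_handleCollision(Bullet* other) { return other->handleCollision(this); }
std::string Bullet::back_handleCollision(Asteroid* other) { return other->handleCollision(this); }
std::string Bullet::back_handleCollision(Bullet* other) { return other->handleCollision(this); }
```

Adding a third class means adding one more back_handleCollision to Object and to every subclass, which is the tedium the generator takes care of.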
Of course, doing that for all the classes would be a lot of tedious C++ code to write.
That's why I wrote a python script to generate it!
import sys
import json

# Load the JSON config whose path is the first command-line argument.
with open(sys.argv[1]) as f:
    s = f.read()
config = json.loads(s)
classes = config["classes"]
functions = config["functions"]
baseClass = config["baseType"]
returnType = config["returnType"]

# Templates are filled in with str.replace(): "foo" is the function name,
# "Primary" the class being defined, "Secondary" the argument's class.
back_declaration_template = "virtual ReturnType back_foo(Secondary*);"
back_definition_template = ""
if returnType == "void":
    back_definition_template = "ReturnType Primary::back_foo(Secondary* _) { _->foo(this); }"
else:
    back_definition_template = "ReturnType Primary::back_foo(Secondary* _) { return _->foo(this); }"
declaration_template = "virtual ReturnType foo(Secondary*);"
definition_template = ""
if returnType == "void":
    definition_template = "ReturnType Primary::foo(Secondary* _) { _->back_foo(this); }"
else:
    definition_template = "ReturnType Primary::foo(Secondary* _) { return _->back_foo(this); }"
body_template = "ReturnType Primary::foo(Secondary* _)\n{\n defaultCode\n}\n"
forward_declaration_template = "class Primary;"
class_declaration_template = """
class Primary : public Base
{
public:
definitions
};""".replace("Base", baseClass)
base_class_declaration_template = """
class Base
{
public:
definitions
};""".replace("Base", baseClass)
class_declarations = []

# Every concrete class declares the same members, so emit them once as a macro.
def make_macro():
    prototypes = []
    for foo in functions:
        # Entry point taking the base type (first dispatch hop).
        prototypes.append(declaration_template \
            .replace("ReturnType", returnType) \
            .replace("Secondary", baseClass) \
            .replace("foo", foo))
        # One back_foo per concrete class (second dispatch hop).
        for secondary in classes:
            prototypes.append( back_declaration_template \
                .replace("ReturnType", returnType) \
                .replace("foo", foo) \
                .replace("Secondary", secondary) )
        # One foo per concrete class, where the real work happens.
        for secondary in classes:
            prototypes.append( declaration_template \
                .replace("ReturnType", returnType) \
                .replace("foo", foo) \
                .replace("Secondary", secondary) )
    return " " + "\\\n ".join(prototypes)

macroName = config["macro"]

# The base class declares the entry point and all back_ functions as pure virtual.
def make_base_class_declaration():
    prototypes = []
    for foo in functions:
        prototypes += [declaration_template \
            .replace("ReturnType", returnType) \
            .replace("Secondary", baseClass) \
            .replace("foo", foo) \
            .replace(";", " = 0;")]
        for secondary in classes:
            prototypes.append( back_declaration_template \
                .replace("ReturnType", returnType) \
                .replace("foo", foo) \
                .replace("Secondary", secondary) \
                .replace(";", " = 0;") )
    return base_class_declaration_template \
        .replace("definitions", " " + "\n ".join(prototypes) )

def make_primary_class_declarations():
    declarations = []
    for primary in classes:
        declarations.append(class_declaration_template \
            .replace("definitions", macroName) \
            .replace("Primary", primary)
        )
    return "\n".join(declarations)

def make_class_declarations():
    return make_base_class_declaration() + "\n"\
        + make_primary_class_declarations()

# The one-line forwarding functions that bounce between the two objects.
def make_level_one_source():
    bodies = []
    for foo in functions:
        for primary in classes:
            bodies.append(definition_template \
                .replace("ReturnType", returnType) \
                .replace("foo", foo) \
                .replace("Primary", primary) \
                .replace("Secondary", baseClass))
            for secondary in classes:
                bodies.append(back_definition_template \
                    .replace("ReturnType", returnType) \
                    .replace("foo", foo) \
                    .replace("Primary", primary) \
                    .replace("Secondary", secondary))
    return "\n".join(bodies)

defaultCode = config["defaultCode"]

# One stub body per (Primary, Secondary) pair, filled with defaultCode.
def make_source():
    bodies = []
    for foo in functions:
        for primary in classes:
            for secondary in classes:
                bodies.append(body_template \
                    .replace("ReturnType", returnType) \
                    .replace("foo", foo) \
                    .replace("Primary", primary) \
                    .replace("Secondary", secondary) \
                    .replace("defaultCode", defaultCode))
    return "\n".join(bodies)

def make_forward_declarations():
    declarations = [forward_declaration_template \
        .replace("Primary", baseClass)]
    for c in classes:
        declarations.append(forward_declaration_template.replace("Primary", c))
    return "\n".join(declarations)

headerName = config["header"]
sourceName = config["source"]
namespaceName = config["namespace"]

with open(headerName, "w") as f:
    headerOnceConstant = headerName.replace(".","_")
    f.write("#ifndef _" + headerOnceConstant + "_\n")
    f.write("#define _" + headerOnceConstant + "_\n\n")
    f.write("namespace " + namespaceName + "\n{\n")
    f.write(make_forward_declarations() + "\n\n")
    f.write("#define " + macroName + " \\\n" + make_macro() + "\n\n")
    f.write(make_class_declarations() + "\n\n")
    f.write("\n}\n")
    f.write("\n#endif\n")

with open(sourceName, "w") as f:
    f.write("#include \"" + headerName + "\"\n")
    f.write("namespace " + namespaceName + "\n")
    f.write("{\n")
    f.write(make_source() + "\n")
    f.write(make_level_one_source() + "\n\n")
    f.write("}")
Run it with this config file, passing the file's path as the first command-line argument:
{
"classes" : [
"Asteroid",
"Ship",
"Bullet"
]
, "baseType" : "Object"
, "returnType" : "void"
, "functions" : [
"handleCollision"
]
, "defaultCode" : ""
, "namespace" : "oids"
, "macro" : "ASTEROID_DEFINITIONS"
, "header" : "asteroids.h"
, "source" : "asteroids.cpp"
}
To get this header:
#ifndef _asteroids_h_
#define _asteroids_h_
namespace oids
{
class Object;
class Asteroid;
class Ship;
class Bullet;
#define ASTEROID_DEFINITIONS \
virtual void handleCollision(Object*);\
virtual void back_handleCollision(Asteroid*);\
virtual void back_handleCollision(Ship*);\
virtual void back_handleCollision(Bullet*);\
virtual void handleCollision(Asteroid*);\
virtual void handleCollision(Ship*);\
virtual void handleCollision(Bullet*);
class Object
{
public:
virtual void handleCollision(Object*) = 0;
virtual void back_handleCollision(Asteroid*) = 0;
virtual void back_handleCollision(Ship*) = 0;
virtual void back_handleCollision(Bullet*) = 0;
};
class Asteroid : public Object
{
public:
ASTEROID_DEFINITIONS
};
class Ship : public Object
{
public:
ASTEROID_DEFINITIONS
};
class Bullet : public Object
{
public:
ASTEROID_DEFINITIONS
};
}
#endif
And this source file:
#include "asteroids.h"
#include <stdio.h>
namespace oids
{
void Asteroid::handleCollision(Asteroid* _)
{
printf( "Asteroid pass asteroid in night\n" );
}
void Asteroid::handleCollision(Ship* _)
{
printf( "Asteroid hit ship BOOM split make smaller asteroids\n" );
}
void Asteroid::handleCollision(Bullet* _)
{
printf( "Asteroid hit bullet BOOM split make smaller asteorids\n" );
}
void Ship::handleCollision(Asteroid* _)
{
printf( "Ship hit asteroid BOOM ship go bye-bye\n" );
}
void Ship::handleCollision(Ship* _)
{
printf( "Ship pass ship in night (they're allies)\n" );
}
void Ship::handleCollision(Bullet* _)
{
printf( "Ship hits bullet. FRIENDLY FIRE!\n" );
}
void Bullet::handleCollision(Asteroid* _)
{
printf( "Bullet hit asteroid !\n" );
}
void Bullet::handleCollision(Ship* _)
{
printf( "Bullet hits ship. FRIENDLY FIRE!\n" );
}
void Bullet::handleCollision(Bullet* _)
{
printf( "Bullet hits bullet. Amazing marksmanship!\n" );
}
void Asteroid::handleCollision(Object* _) { _->back_handleCollision(this); }
void Asteroid::back_handleCollision(Asteroid* _) { _->handleCollision(this); }
void Asteroid::back_handleCollision(Ship* _) { _->handleCollision(this); }
void Asteroid::back_handleCollision(Bullet* _) { _->handleCollision(this); }
void Ship::handleCollision(Object* _) { _->back_handleCollision(this); }
void Ship::back_handleCollision(Asteroid* _) { _->handleCollision(this); }
void Ship::back_handleCollision(Ship* _) { _->handleCollision(this); }
void Ship::back_handleCollision(Bullet* _) { _->handleCollision(this); }
void Bullet::handleCollision(Object* _) { _->back_handleCollision(this); }
void Bullet::back_handleCollision(Asteroid* _) { _->handleCollision(this); }
void Bullet::back_handleCollision(Ship* _) { _->handleCollision(this); }
void Bullet::back_handleCollision(Bullet* _) { _->handleCollision(this); }
}
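I had to insert the printfs (and the #include <stdio.h>) by hand, since the generator fills those bodies with defaultCode, which is empty in this config.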
Then run with this main:
#include <vector>
#include "asteroids.h"
using namespace oids;
int main()
{
Asteroid* asteroid = new Asteroid;
Ship* ship = new Ship;
Bullet* bullet = new Bullet;
std::vector<Object*> objects;
objects.push_back(asteroid);
objects.push_back(ship);
objects.push_back(bullet);
for (Object* A : objects)
    for (Object* B : objects)
    {
        A->handleCollision(B);
    }
delete asteroid;
delete ship;
delete bullet;
return 0;
}
And lo.
Asteroid pass asteroid in night
Asteroid hit ship BOOM split make smaller asteroids
Asteroid hit bullet BOOM split make smaller asteorids
Ship hit asteroid BOOM ship go bye-bye
Ship pass ship in night (they're allies)
Ship hits bullet. FRIENDLY FIRE!
Bullet hit asteroid !
Bullet hits ship. FRIENDLY FIRE!
Bullet hits bullet. Amazing marksmanship!
There you have it, asteroids without casting (or tagging (which feels like casting)).