Handling the Technical Interview
I really like the blog series …ing the technical interview by Aphyr. Besides the humor, I enjoy seeing Turing completeness in parts of systems that many people use but that were never designed for writing entire programs.
Jeg snakker lite norsk (I speak little Norwegian), and I can’t write anything as funny as the original. But I write CTF challenges from time to time, and CTF reverse engineering challenges are the perfect place for implementing such weird things and enjoying the sight of many people trying to figure them out. So this is the writeup for the challenge haskell4j, which I-Al-Istannen, Hannes and I wrote together. But first, let me tell you about Java method handles.
The documentation states:
A method handle is a typed, directly executable reference to an underlying method, constructor, field, or similar low-level operation, with optional transformations of arguments or return values.
They are the underlying implementation of reflection, because they are a performant way of looking up methods and fields, invoking them, and adapting them slightly. Each lookup produces a handle that behaves like the corresponding bytecode instruction (invokevirtual, getfield and friends). To once again quote from the documentation, here is part of its example:
Object x, y; String s; int i;
MethodType mt; MethodHandle mh;
MethodHandles.Lookup lookup = MethodHandles.lookup();
// mt is (char,char)String
mt = MethodType.methodType(String.class, char.class, char.class);
mh = lookup.findVirtual(String.class, "replace", mt);
s = (String) mh.invokeExact("daddy",'d','n');
// s is now "nanny"
From this, it is not apparent that Java method handles are Turing complete.
But the MethodHandles class provides a lot of interesting methods for transforming method handles, of the kind frequently seen in functional languages (hence the name haskell4j), for example:
static MethodHandle dropArguments(MethodHandle target, int pos, Class<?>... valueTypes) Produces a method handle which will discard some dummy arguments before calling some other specified target method handle.
static MethodHandle dropReturn(MethodHandle target) Drop the return value of the target handle (if any).
static MethodHandle filterReturnValue(MethodHandle target, MethodHandle filter) Adapts a target method handle by post-processing its return value (if any) with a filter (another method handle).
static MethodHandle filterArguments(MethodHandle target, int pos, MethodHandle... filters) Adapts a target method handle by pre-processing one or more of its arguments, each with its own unary filter function, and then calling the target with each pre-processed argument replaced by the result of its corresponding filter function.
static MethodHandle foldArguments(MethodHandle target, int pos, MethodHandle combiner) Adapts a target method handle by pre-processing some of its arguments, starting at a given position, and then calling the target with the result of the pre-processing, inserted into the original sequence of arguments just before the folded arguments.
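To get a feel for these adapters, here is a minimal, self-contained sketch (not part of the challenge) that applies four of them to a plain sub(a, b) handle. The class CombinatorDemo and its helper methods are hypothetical names chosen for this example.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.Arrays;

class CombinatorDemo {
    static int sub(int a, int b) { return a - b; }
    static int twice(int a) { return 2 * a; }

    static MethodHandle find(String name, MethodType type) throws ReflectiveOperationException {
        return MethodHandles.lookup().findStatic(CombinatorDemo.class, name, type);
    }

    /** Evaluates four adapted versions of sub at (5, 3). */
    static int[] results() {
        try {
            MethodHandle sub = find("sub", MethodType.methodType(int.class, int.class, int.class));
            MethodHandle twice = find("twice", MethodType.methodType(int.class, int.class));

            // (String, int, int) -> int: the leading String is accepted and ignored
            MethodHandle dropped = MethodHandles.dropArguments(sub, 0, String.class);
            // (a, b) -> sub(twice(a), b)
            MethodHandle filtered = MethodHandles.filterArguments(sub, 0, twice);
            // (a, b) -> twice(sub(a, b))
            MethodHandle post = MethodHandles.filterReturnValue(sub, twice);
            // (a, b) -> sub(sub(a, b), b): the combiner's result is inserted in front of (a, b),
            // and dropArguments makes the target ignore the original a
            MethodHandle folded = MethodHandles.foldArguments(
                    MethodHandles.dropArguments(sub, 1, int.class), sub);

            return new int[]{
                    (int) dropped.invokeExact("ignored", 5, 3),
                    (int) filtered.invokeExact(5, 3),
                    (int) post.invokeExact(5, 3),
                    (int) folded.invokeExact(5, 3)
            };
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(results())); // [2, 7, 4, -1]
    }
}
```

The foldArguments case is the interesting one: the combiner’s result is prepended to the original arguments, which is the same trick the compose helper relies on.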
Most notably, there are transformations that can be used for control flow. The if:
static MethodHandle guardWithTest(MethodHandle test, MethodHandle target, MethodHandle fallback) Makes a method handle which adapts a target method handle, by guarding it with a test, a boolean-valued method handle.
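A minimal sketch of guardWithTest used as an if/else, assuming nothing beyond the JDK: it builds abs(x) from a boolean test, a target for the true branch and a fallback for the false branch. GuardDemo, isNegative and negate are hypothetical names.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class GuardDemo {
    static boolean isNegative(int x) { return x < 0; }
    static int negate(int x) { return -x; }

    static final MethodHandle ABS = buildAbs();

    /** abs as a method handle: if (isNegative(x)) negate(x) else x. */
    static MethodHandle buildAbs() {
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            MethodHandle test = lookup.findStatic(GuardDemo.class, "isNegative",
                    MethodType.methodType(boolean.class, int.class));
            MethodHandle target = lookup.findStatic(GuardDemo.class, "negate",
                    MethodType.methodType(int.class, int.class));
            MethodHandle fallback = MethodHandles.identity(int.class); // the "else" branch
            return MethodHandles.guardWithTest(test, target, fallback);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int abs(int x) {
        try {
            return (int) ABS.invokeExact(x);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(abs(-5) + " " + abs(7)); // 5 7
    }
}
```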
And a few loop types:
static MethodHandle loop(MethodHandle[]... clauses) Constructs a method handle representing a loop with several loop variables that are updated and checked upon each iteration.
static MethodHandle iteratedLoop(MethodHandle iterator, MethodHandle init, MethodHandle body) Constructs a loop that ranges over the values produced by an Iterator<T>.
static MethodHandle whileLoop(MethodHandle init, MethodHandle pred, MethodHandle body) Constructs a while loop from an initializer, a body, and a predicate.
static MethodHandle countedLoop(MethodHandle start, MethodHandle end, MethodHandle init, MethodHandle body) Constructs a loop that counts over a range of numbers.
static MethodHandle doWhileLoop(MethodHandle init, MethodHandle body, MethodHandle pred) Constructs a do-while loop from an initializer, a body, and a predicate.
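As a small illustration of the loop combinators, here is a sketch using the three-argument countedLoop(iterations, init, body) overload, which is also the one the challenge uses. The body receives the loop variable, then the counter, then the loop’s own arguments. LoopDemo and sumBelow are hypothetical names.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class LoopDemo {
    static int add(int a, int b) { return a + b; }

    static final MethodHandle SUM_BELOW = build();

    /** (n) -> 0 + 1 + ... + (n - 1), built with the three-argument countedLoop. */
    static MethodHandle build() {
        try {
            MethodHandle add = MethodHandles.lookup().findStatic(LoopDemo.class, "add",
                    MethodType.methodType(int.class, int.class, int.class));
            MethodHandle iterations = MethodHandles.identity(int.class);         // (n) -> n
            MethodHandle init = MethodHandles.dropArguments(
                    MethodHandles.constant(int.class, 0), 0, int.class);         // (n) -> 0
            MethodHandle body = MethodHandles.dropArguments(add, 2, int.class);  // (v, i, n) -> v + i
            return MethodHandles.countedLoop(iterations, init, body);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int sumBelow(int n) {
        try {
            return (int) SUM_BELOW.invokeExact(n);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(sumBelow(5)); // 10
    }
}
```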
My idea was to implement an easy “encryption” algorithm whose workings are not obvious just from looking at its outputs. You are given the output and the compiled file and have to figure out the original input.
Here is the algorithm, that is, the interview question, written as Python code:
def reverse_(message):
    msg = message[::-1]
    # index 11 is flipped, unless the message has 23 characters (then 11 is the untouched middle)
    if len(msg) > 11 and len(msg) != 11 * 2 + 1:
        msg[11] = msg[11] ^ 11
    return msg

def encrypt(message: list[int]) -> list[int]:
    key = message[0]
    message = [c ^ key for c in message]
    res = []
    while message:
        res.append(message[0] ^ message[-1] ^ 3)
        message = reverse_(message)[1:]
    return res
The only “plain Java” we allow ourselves is the following. All we really need is a functionally complete set of operators, so & and ~ are enough:
public static int and(int a, int b) {
return a & b;
}
public static int not(int a) {
return ~a;
}
…and turn them into method handles.
public static MethodHandle and() throws Throwable {
return MethodHandles.lookup()
.findStatic(MethodHandleFun.class, "and", methodType(int.class, int.class, int.class));
}
public static MethodHandle not() throws Throwable {
return MethodHandles.lookup()
.findStatic(MethodHandleFun.class, "not", methodType(int.class, int.class));
}
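To work on ints with bit-level access, and also because Integer.compress is always a good meme, let’s give ourselves a “shift right by one” method. Integer.compress(i, -2) (available since Java 19) packs bits 1 to 31 of i into the low bits, which is exactly i >>> 1.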
public static MethodHandle shiftR1() throws Throwable {
return insertArguments(
MethodHandles.lookup()
.findStatic(Integer.class, "compress", methodType(int.class, int.class, int.class)),
1,
-2
);
}
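A quick way to convince yourself that the mask -2 really turns Integer.compress into a logical right shift is to compare the handle against >>>. This sketch assumes Java 19 or newer (where Integer.compress was added); CompressDemo is a hypothetical name.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

// Requires Java 19+ for Integer.compress.
class CompressDemo {
    static final MethodHandle SHIFT_R1 = build();

    static MethodHandle build() {
        try {
            MethodHandle compress = MethodHandles.lookup().findStatic(Integer.class, "compress",
                    MethodType.methodType(int.class, int.class, int.class));
            // Fix the mask to -2 (0xFFFFFFFE): bits 1..31 are packed down into bits 0..30
            return MethodHandles.insertArguments(compress, 1, -2);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int shiftR1(int x) {
        try {
            return (int) SHIFT_R1.invokeExact(x);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        for (int x : new int[]{0, 1, 6, -1, Integer.MIN_VALUE}) {
            System.out.println(x + " -> " + shiftR1(x) + " (x >>> 1 = " + (x >>> 1) + ")");
        }
    }
}
```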
The first helper we will write is function composition: compose(f, g) returns a handle that runs g on all arguments and then f on g’s result, i.e. f(g(x...)). It is much more natural to have this explicit (well, explicit for us, for now), and it helps us keep our sanity while writing the challenge.
public static MethodHandle compose(MethodHandle f, MethodHandle g) {
return MethodHandles.foldArguments(
MethodHandles.dropArguments(f, 1, g.type().parameterList()),
g
);
}
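The argument order of compose is easy to get wrong, so here is a small sketch that copies the definition above and applies it to two unary functions; the results show that it computes f(g(x)). ComposeDemo, inc and twice are hypothetical names for this example.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class ComposeDemo {
    static int inc(int x) { return x + 1; }
    static int twice(int x) { return 2 * x; }

    /** Same shape as compose above: run g on all arguments, then f on g's result. */
    static MethodHandle compose(MethodHandle f, MethodHandle g) {
        return MethodHandles.foldArguments(
                MethodHandles.dropArguments(f, 1, g.type().parameterList()),
                g);
    }

    static MethodHandle find(String name) {
        try {
            return MethodHandles.lookup().findStatic(ComposeDemo.class, name,
                    MethodType.methodType(int.class, int.class));
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(e);
        }
    }

    static int apply(MethodHandle f, MethodHandle g, int x) {
        try {
            return (int) compose(f, g).invokeExact(x);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        MethodHandle inc = find("inc");
        MethodHandle twice = find("twice");
        System.out.println(apply(inc, twice, 5)); // inc(twice(5)) = 11
        System.out.println(apply(twice, inc, 5)); // twice(inc(5)) = 12
    }
}
```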
This allows for a straightforward definition of or in terms of and and not.
We also see filterArguments in action for the first time: it runs each selected argument through its own unary filter before calling the target.
/**
* Bitwise or.
* <pre>{@code
* a ∨ b = ¬(¬a ∧ ¬b) (anything but ¬a and ¬b)
* }</pre>
*/
public static MethodHandle or() throws Throwable {
return compose(not(), filterArguments(and(), 0, not(), not()));
}
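The De Morgan construction can be checked against Java’s own |. This sketch repeats the and/not primitives and compose from above inside one hypothetical class, OrDemo, so that it runs on its own.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class OrDemo {
    static int and(int a, int b) { return a & b; }
    static int not(int a) { return ~a; }

    static MethodHandle compose(MethodHandle f, MethodHandle g) {
        return MethodHandles.foldArguments(
                MethodHandles.dropArguments(f, 1, g.type().parameterList()), g);
    }

    static final MethodHandle OR = buildOr();

    /** a | b == ~(~a & ~b), assembled only from the and/not handles. */
    static MethodHandle buildOr() {
        try {
            MethodHandles.Lookup lookup = MethodHandles.lookup();
            MethodHandle and = lookup.findStatic(OrDemo.class, "and",
                    MethodType.methodType(int.class, int.class, int.class));
            MethodHandle not = lookup.findStatic(OrDemo.class, "not",
                    MethodType.methodType(int.class, int.class));
            // filterArguments negates both inputs, compose negates the result
            return compose(not, MethodHandles.filterArguments(and, 0, not, not));
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int or(int a, int b) {
        try {
            return (int) OR.invokeExact(a, b);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(or(0b1100, 0b1010)); // 14 (0b1110)
    }
}
```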
While this looks cute, take a look at XOR. No one can accuse us of obscure, undocumented code: there are comments, you just don’t get them when solving the challenge.
/**
* Implements bitwise xor using not and and.
* <pre>{@code
* a ^ b = ¬(a ∧ b) ∧ ¬(¬a ∧ ¬b)
* }</pre>
*/
public static MethodHandle xor() throws Throwable {
MethodHandle nand = compose(not(), and());
return MethodHandles.permuteArguments(
// x, y, x, y -> res
MethodHandles.filterArguments(
// x, y, !x, !y -> res
collectArguments(
// nand(x, y), !x, !y -> res
collectArguments(
// nand(x, y), nand(!x, !y) -> res
and(),
1,
nand
),
0,
nand
),
2,
not(), not()
),
methodType(int.class, int.class, int.class),
0, 1, 0, 1
);
}
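Because the nested collectArguments calls have to be read inside out, this sketch rebuilds the same xor chain step by step, with a local variable and a comment giving the shape of each intermediate handle, and checks it against ^. XorDemo and the step names are hypothetical; the combinators and their order follow xor() above.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class XorDemo {
    static int and(int a, int b) { return a & b; }
    static int not(int a) { return ~a; }

    static MethodHandle compose(MethodHandle f, MethodHandle g) {
        return MethodHandles.foldArguments(
                MethodHandles.dropArguments(f, 1, g.type().parameterList()), g);
    }

    static final MethodHandle XOR = buildXor();

    static MethodHandle buildXor() {
        try {
            MethodType binary = MethodType.methodType(int.class, int.class, int.class);
            MethodHandle and = MethodHandles.lookup().findStatic(XorDemo.class, "and", binary);
            MethodHandle not = MethodHandles.lookup().findStatic(XorDemo.class, "not",
                    MethodType.methodType(int.class, int.class));
            MethodHandle nand = compose(not, and);                                   // (x, y) -> ~(x & y)
            MethodHandle step1 = MethodHandles.collectArguments(and, 1, nand);      // (p, q, r) -> p & nand(q, r)
            MethodHandle step2 = MethodHandles.collectArguments(step1, 0, nand);    // (x, y, q, r) -> nand(x, y) & nand(q, r)
            MethodHandle step3 = MethodHandles.filterArguments(step2, 2, not, not); // (x, y, q, r) -> nand(x, y) & nand(~q, ~r)
            return MethodHandles.permuteArguments(step3, binary, 0, 1, 0, 1);       // (x, y) -> step3(x, y, x, y)
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int xor(int a, int b) {
        try {
            return (int) XOR.invokeExact(a, b);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(xor(0b1100, 0b1010)); // 6 (0b0110)
    }
}
```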
There may or may not be easier implementations for any of these functions, but these are the very best for our design goal: a reverse engineering challenge. Hire Hannes as a performance engineer!
An important method to have is comparison for numbers; comparison is everywhere.
The simplest thing we can start with is isNotZero, and with that we get isZero.
This is also where we get into ifs and loops. Pretty intuitive code, right?
/**
* Idea:
* <pre>{@code
* def isNotZero(a):
* for _ in range(32):
* if a & 1: # lowest bit set (or a already set to 1)
* a = 1
* else:
* a = a >> 1 # try next bit (>>> in Java)
* return a
* }</pre>
*/
public static MethodHandle isNotZero() throws Throwable {
MethodHandle loopiboy = countedLoop(
// iterations
dropArguments(constant(int.class, 32), 0, int.class),
// init
identity(int.class),
// body. v, i, a
dropArguments(
// v, a
dropArguments(
// v
guardWithTest(
// Convert to boolean for test. If we have a 1, will be true
explicitCastArguments(
identity(int.class),
methodType(boolean.class, int.class)
),
// keep our one value
dropArguments(constant(int.class, 1), 0, int.class),
// test next bit
shiftR1()
),
1,
int.class
),
1,
int.class
)
);
return compose(
explicitCastArguments(identity(int.class), methodType(boolean.class, int.class)),
loopiboy
);
}
public static MethodHandle isZero() throws Throwable {
return explicitCastArguments(
compose(
insertArguments(xor(), 1, 1),
explicitCastArguments(isNotZero(), methodType(int.class, int.class))
),
methodType(boolean.class, int.class)
);
}
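Both isNotZero and isZero hinge on how explicitCastArguments converts between int and boolean: the boolean is treated as a one-bit integer, so int to boolean looks only at the lowest bit (2 becomes false), and in the other direction true becomes 1. Here is a small sketch of those casts plus a plain-Java reading of the two handles; LowBitDemo and its methods are hypothetical names.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class LowBitDemo {
    // int -> boolean: explicitCastArguments keeps only the lowest bit
    static final MethodHandle INT_TO_BOOL = MethodHandles.explicitCastArguments(
            MethodHandles.identity(int.class), MethodType.methodType(boolean.class, int.class));
    // boolean -> int: true becomes 1, false becomes 0
    static final MethodHandle BOOL_TO_INT = MethodHandles.explicitCastArguments(
            MethodHandles.identity(boolean.class), MethodType.methodType(int.class, boolean.class));

    static boolean toBoolean(int x) {
        try {
            return (boolean) INT_TO_BOOL.invokeExact(x);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    static int toInt(boolean b) {
        try {
            return (int) BOOL_TO_INT.invokeExact(b);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    /** Plain-Java reading of isNotZero: 32 rounds of "keep the 1, or shift right". */
    static boolean isNotZero(int a) {
        for (int i = 0; i < 32; i++) {
            a = toBoolean(a) ? 1 : a >>> 1;
        }
        return toBoolean(a);
    }

    /** Plain-Java reading of isZero: flip the int form of isNotZero with xor 1. */
    static boolean isZero(int a) {
        return toBoolean(toInt(isNotZero(a)) ^ 1);
    }

    public static void main(String[] args) {
        System.out.println(toBoolean(2) + " " + toBoolean(3));              // false true
        System.out.println(isNotZero(Integer.MIN_VALUE) + " " + isZero(0)); // true true
    }
}
```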
By the way, if you are wondering why we define Java methods and local variables here: they won’t make it into the challenge. Automatic inlining before compilation takes care of them and provides maximum fun for everyone looking at the challenge. Now that we have this basic check, we can implement cool things like arithmetic. Wow, adding two numbers in Java! We just need to make sure that the right bits meet.
/**
* Implements
* {@snippet :
* int add(int a, int b) {
* while (b != 0) {
* a = a ^ b;
* b = ((~a) & b) << 1;
* }
* return a;
* }
*}
*/
public static MethodHandle add() throws Throwable {
return loop(
new MethodHandle[]{
// init
dropArguments(MethodHandles.identity(int.class), 1, int.class),
// step
xor(),
// pred. We do not exit here.
// dropArguments without value types is a no-op, but loop() pads the missing parameters for us
dropArguments(dropArguments(constant(boolean.class, true), 0), 0),
// fini
dropArguments(identity(int.class), 1, int.class)
},
new MethodHandle[]{
// init
dropArguments(MethodHandles.identity(int.class), 0, int.class),
// step
compose(shiftL1(), compose(and(), not())),
// pred
dropArguments(isNotZero(), 0, int.class),
// fini
dropArguments(identity(int.class), 1, int.class)
}
);
}
/**
* Implements two's complement negation and then uses add.
*/
public static MethodHandle subtractInt() throws Throwable {
return collectArguments(
add(),
1,
compose(insertArguments(add(), 0, 1), not())
);
}
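The loop in add() runs its clauses in order on every iteration: clause 1 updates a, then clause 2 computes the carry from the already-updated a and decides whether to stop. This sketch rebuilds the same two-clause structure from plain static methods (hypothetical names, not the challenge’s and/not-only versions) and adds subtraction via two’s complement in the spirit of subtractInt.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

class LoopClausesDemo {
    static int first(int a, int b) { return a; }
    static int second(int a, int b) { return b; }
    static int xor(int a, int b) { return a ^ b; }
    // Runs after clause 1 has replaced a with a ^ b, so ~a & b equals the old a & b
    static int carry(int a, int b) { return (~a & b) << 1; }
    static boolean alwaysTrue(int a, int b) { return true; }
    static boolean notZero(int a, int b) { return b != 0; }

    static final MethodHandle ADD = buildAdd();

    static MethodHandle h(String name, Class<?> returnType) throws ReflectiveOperationException {
        return MethodHandles.lookup().findStatic(LoopClausesDemo.class, name,
                MethodType.methodType(returnType, int.class, int.class));
    }

    /** Two clauses {init, step, pred, fini}: one for a, one for b. */
    static MethodHandle buildAdd() {
        try {
            MethodHandle first = h("first", int.class);
            MethodHandle[] aClause = {first, h("xor", int.class), h("alwaysTrue", boolean.class), first};
            MethodHandle[] bClause = {h("second", int.class), h("carry", int.class),
                    h("notZero", boolean.class), first};
            return MethodHandles.loop(aClause, bClause);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    static int add(int a, int b) {
        try {
            return (int) ADD.invokeExact(a, b);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    /** Two's complement subtraction: a - b == a + (~b + 1). */
    static int subtract(int a, int b) {
        return add(a, add(~b, 1));
    }

    public static void main(String[] args) {
        System.out.println(add(5, 7) + " " + subtract(3, 10)); // 12 -7
    }
}
```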
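Also, we can now turn shifting right by one into shifting left by one: rotating right by one bit 31 times is the same as rotating left by one, and clearing the lowest bit afterwards drops the bit that wrapped around. Hire me as a performance engineer!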
/**
* Implements this:
* {@snippet lang = java:
* public static int shiftL1(int a) {
* for (int i = 0; i < 31; i++) {
* int l = a & 1;
* a = a >>> 1;
* if (l != 0) {
* a = a | -2147483648;
* }
* }
* return a & ~1;
* }
*}
*/
public static MethodHandle shiftL1() throws Throwable {
// body(v, i, a...)
MethodHandle body = dropArguments( // we don't need `i` and don't supply any `a...`
permuteArguments( // add `l = a & 1` as parameter
filterArguments(
guardWithTest( // do the if check
dropArguments(
isNotZero(),
0,
int.class
),
dropArguments(
compose(
insertArguments(
or(),
0,
-2147483648
),
shiftR1()
),
1,
int.class
), // compose(f, g) is f(g(x)): shift right first, then set the top bit
dropArguments(
shiftR1(),
1,
int.class
)
),
1,
insertArguments(and(), 0, 1)
),
methodType(int.class, int.class, int.class),
0, 0
),
1
);
MethodHandle loopiboy = countedLoop(
constant(int.class, 31), // iterations
identity(int.class), // Loop variable (a)
dropArguments(
body,
2,
int.class
)
);
return compose(insertArguments(and(), 0, ~1), loopiboy);
}
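The rotation argument can be checked without any method handles. The sketch below mirrors the shiftL1 loop in plain Java and compares it with a closed form based on Integer.rotateRight; ShiftLeftDemo and both method names are hypothetical.

```java
class ShiftLeftDemo {
    /** Mirror of the shiftL1 loop: 31 one-bit right rotations, then clear bit 0. */
    static int shiftL1Loop(int a) {
        for (int i = 0; i < 31; i++) {
            int low = a & 1;
            a = a >>> 1;
            if (low != 0) {
                a |= Integer.MIN_VALUE; // move the dropped bit to the top: a rotate right by one
            }
        }
        return a & ~1;
    }

    /** Closed form: rotating right by 31 is rotating left by 1; the wrapped bit lands in bit 0. */
    static int shiftL1Closed(int a) {
        return Integer.rotateRight(a, 31) & ~1;
    }

    public static void main(String[] args) {
        int x = 0b1011;
        System.out.println(shiftL1Loop(x) + " " + shiftL1Closed(x) + " " + (x << 1)); // 22 22 22
    }
}
```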
Now we need some nuts and bolts to work on arrays.
Luckily, the MethodHandles class has us covered here:
static MethodHandle arrayConstructor(Class<?> arrayClass) Produces a method handle constructing arrays of a desired type, as if by the anewarray bytecode.
static MethodHandle arrayElementGetter(Class<?> arrayClass) Produces a method handle giving read access to elements of an array, as if by the aaload bytecode.
static MethodHandle arrayElementSetter(Class<?> arrayClass) Produces a method handle giving write access to elements of an array, as if by the astore bytecode.
static VarHandle arrayElementVarHandle(Class<?> arrayClass) Produces a VarHandle giving access to elements of an array of type arrayClass.
static MethodHandle arrayLength(Class<?> arrayClass) Produces a method handle returning the length of an array, as if by the arraylength bytecode.
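A short, self-contained sketch of four of the array factories from the list above, writing into and reading from a char[]. ArrayHandlesDemo is a hypothetical name; note that the handle types mirror the bytecode, for example the setter returns void.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

class ArrayHandlesDemo {
    /** Builds "hih" using only array method handles. */
    static char[] demo() {
        MethodHandle ctor = MethodHandles.arrayConstructor(char[].class);   // (int) -> char[]
        MethodHandle set = MethodHandles.arrayElementSetter(char[].class);  // (char[], int, char) -> void
        MethodHandle get = MethodHandles.arrayElementGetter(char[].class);  // (char[], int) -> char
        MethodHandle length = MethodHandles.arrayLength(char[].class);      // (char[]) -> int
        try {
            char[] arr = (char[]) ctor.invokeExact(3);
            set.invokeExact(arr, 0, 'h');
            set.invokeExact(arr, 1, 'i');
            char first = (char) get.invokeExact(arr, 0);
            int len = (int) length.invokeExact(arr);
            set.invokeExact(arr, len - 1, first); // copy the first char to the last slot
            return arr;
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(new String(demo())); // hih
    }
}
```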
/**
* Creates a method handle taking an array and a char and appends the char to the array,
* allocating a new one with one more entry.
* <p>
* {@snippet :
* private char[] append(char[] src, char c) {
* int end = iterations(src);
* char[] v = new char[src.length + 1];
* for (int i = 0; i < end; i++) {
* v = body(i, src, v, c);
* }
* return v;
* }
* private char[] body(int index, char[] src, char[] dest, char c) {
* tryFilter(index, src, dest, c);
* return dest;
* }
* private static void tryFilter(int index, char[] src, char[] dest, char c) {
* try {
* tryBody(index, src, dest);
* } catch (Exception e) {
* catchBody(index, dest, c);
* }
* }
* private static void catchBody(int index, char[] dest, char c) {
* dest[index] = c;
* }
* private static void tryBody(int index, char[] src, char[] dest) {
* dest[index] = src[index];
* }
* private int iterations(char[] src) {
* return src.length + 1;
* }
*}
*
* @param add performing integer addition
*/
public static MethodHandle createAppend(MethodHandle add) {
MethodHandle charArrayCtor = arrayConstructor(char[].class);
MethodHandle charArrayLength = arrayLength(char[].class);
MethodHandle get = arrayElementGetter(char[].class);
MethodHandle set = arrayElementSetter(char[].class);
MethodHandle addOne = insertArguments(add, 0, 1);
MethodHandle newSize = filterReturnValue(charArrayLength, addOne);
MethodHandle init = filterReturnValue(newSize, charArrayCtor);
MethodHandle copyOne = permuteArguments(
collectArguments(set, 2, get),
methodType(void.class, char[].class, int.class, char[].class),
0, 1, 2, 1
);
MethodHandle tryBody = dropArguments(copyOne, 3, char.class);
MethodHandle catchBody = dropArguments(set, 2, char[].class);
MethodHandle dupForReturn = foldArguments(
dropArguments(identity(char[].class), 1, int.class, char[].class, char.class),
catchException(tryBody, Exception.class, dropArguments(catchBody, 0, Exception.class))
);
return countedLoop(newSize, init, dupForReturn);
}
/**
* char[] in, char a
*/
public static MethodHandle append() throws Throwable {
return createAppend(add());
}
/**
* Pops the first element from an array.
* {@snippet :
* char[] newArr = new char[old.length - 1];
* for (int i = 0; i < old.length - 1; i++) {
* newArr[i] = old[i + 1];
* }
*}
*/
public static MethodHandle popFirst() throws Throwable {
MethodHandle iterations = MethodHandles.filterArguments(
MethodHandles.insertArguments(add(), 0, -1),
0,
arrayLength(char[].class)
);
MethodHandle init = compose(arrayConstructor(char[].class), iterations);
// newArray, i, oldArray
MethodHandle body = MethodHandles.permuteArguments(
// newArray, newArray, i, oldArray, i
MethodHandles.filterArguments(
// newArray, newArray, i, oldArray, i+1
MethodHandles.collectArguments(
// newArray, newArray, i, oldValue
MethodHandles.collectArguments(
// newArray
identity(char[].class),
1,
arrayElementSetter(char[].class)
),
3,
arrayElementGetter(char[].class)
),
4,
MethodHandles.collectArguments(add(), 0, constant(int.class, 1))
),
methodType(char[].class, char[].class, int.class, char[].class),
0, 0, 1, 2, 1
);
return countedLoop(
iterations,
init,
body
);
}
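The append handle above copies src[i] inside a try and writes the new character from the catch block once the index runs past the end of src. This sketch shows the same catchException pattern in isolation, as a hypothetical safeGet(arr, i, fallback). Unlike createAppend, which catches Exception, it catches only ArrayIndexOutOfBoundsException, the narrower choice if you adapt the pattern elsewhere.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

class SafeGetDemo {
    static final MethodHandle SAFE_GET = build();

    /** (arr, i, fallback) -> arr[i], or fallback if i is out of bounds. */
    static MethodHandle build() {
        MethodHandle get = MethodHandles.arrayElementGetter(char[].class);      // (char[], int) -> char
        MethodHandle target = MethodHandles.dropArguments(get, 2, char.class);  // (arr, i, fallback) -> arr[i]
        // The handler receives the exception first, then the target's arguments
        MethodHandle handler = MethodHandles.dropArguments(
                MethodHandles.identity(char.class), 0,
                ArrayIndexOutOfBoundsException.class, char[].class, int.class); // (e, arr, i, fallback) -> fallback
        return MethodHandles.catchException(target, ArrayIndexOutOfBoundsException.class, handler);
    }

    static char safeGet(char[] arr, int i, char fallback) {
        try {
            return (char) SAFE_GET.invokeExact(arr, i, fallback);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        char[] ab = "ab".toCharArray();
        System.out.println(safeGet(ab, 1, '?') + " " + safeGet(ab, 5, '?')); // b ?
    }
}
```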
Java is known for its fast development speed, allowing us to iterate fast.
With that little bit of setup, we can now swiftly implement a key function from our original (Python) algorithm: reverse.
public static MethodHandle reverse() throws Throwable {
VarHandle arrayElementVarHandle = arrayElementVarHandle(char.class.arrayType());
MethodHandle arrayGetAndSet = arrayElementVarHandle.toMethodHandle(VarHandle.AccessMode.GET_AND_SET);
MethodHandle arrayGetChar = arrayElementGetter(char.class.arrayType());
MethodHandle arraySetChar = arrayElementSetter(char.class.arrayType());
MethodHandle arrayLength = arrayLength(char.class.arrayType());
// v, i, (length - i - 1)
// (char[], int, int)
MethodHandle combinedSwitchP1 = permuteArguments(
// v, i, v, (length - i - 1)
collectArguments(
// v, i, v[length - i - 1]
// tmp = v[i] ; v[i] = v[length - i - 1] ; return tmp
interceptArraySetIfIndexIsNice(arrayGetAndSet),
2,
arrayGetChar // read v[length - i - 1]
),
methodType(char.class, char.class.arrayType(), int.class, int.class),
0, 1, 0, 2
);
// v, i, (length - i - 1)
// (char[], int, int)
MethodHandle combinedSwap = permuteArguments(
// v, (length - i - 1), v, i, (length - i - 1)
collectArguments(
// v, (length - i - 1), v[i]
interceptArraySetIfIndexIsNice(arraySetChar),
2,
combinedSwitchP1
),
methodType(void.class, char.class.arrayType(), int.class, int.class),
0, 2, 0, 1, 2
);
MethodHandle loopyboy = countedLoop(
// iterations(char[])
collectArguments(
// x / 2
shiftR1(),
0,
// char.length
arrayLength
),
null,
permuteArguments(
// v, i, v, i
collectArguments(
// v, i, (length - i - 1)
combinedSwap,
2,
collectArguments(
filterReturnValue(
subtractInt(),
insertArguments(subtractInt(), 1, 1)
),
0,
arrayLength
)
),
methodType(void.class, int.class, char.class.arrayType()),
1, 0, 1, 0
)
);
return permuteArguments(
// s, s
collectArguments(
identity(char[].class),
0,
loopyboy
),
methodType(char[].class, char[].class),
0, 0
);
}
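The core of reverse is a swap built from a VarHandle in GET_AND_SET mode (store the new value, get the old one back) followed by a plain array store. Stripped of the index-11 interception, it can be sketched as follows; SwapDemo is a hypothetical name.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

class SwapDemo {
    // (char[], int, char) -> char: stores the new value and returns the previous one
    static final MethodHandle GET_AND_SET = MethodHandles.arrayElementVarHandle(char[].class)
            .toMethodHandle(VarHandle.AccessMode.GET_AND_SET);
    static final MethodHandle SET = MethodHandles.arrayElementSetter(char[].class);

    /** tmp = getAndSet(arr, i, arr[j]); arr[j] = tmp; the swap reverse() performs. */
    static void swap(char[] arr, int i, int j) {
        try {
            char tmp = (char) GET_AND_SET.invokeExact(arr, i, arr[j]);
            SET.invokeExact(arr, j, tmp);
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    /** In-place reversal: swap i with length - i - 1 for the first length / 2 indices. */
    static char[] reverse(char[] arr) {
        for (int i = 0; i < arr.length / 2; i++) {
            swap(arr, i, arr.length - i - 1);
        }
        return arr;
    }

    public static void main(String[] args) {
        System.out.println(new String(reverse("haskell4j".toCharArray()))); // j4lleksah
    }
}
```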
Now, just a few trivial helpers here and there (it totally didn’t take hours to write them).
public static MethodHandle interceptArraySetIfIndexIsNice(
MethodHandle normalSet
) throws Throwable {
return guardWithTest(
// arrayref, index, value
dropArguments(
dropArguments(
compose(
isZero(),
insertArguments(xor(), 0, 11)
),
1,
char.class
),
0,
char[].class
),
// arrayref, index, value
filterArguments(
normalSet,
2,
explicitCastArguments(
insertArguments(
xor(),
0,
11
),
methodType(char.class, char.class)
)
),
normalSet
);
}
/**
* In-place xors the input by the value of its first element.
*/
public static MethodHandle xorInput() throws Throwable {
MethodHandle modifyArray = countedLoop(
arrayLength(char[].class),
insertArguments(arrayElementGetter(char[].class), 1, 0),
permuteArguments(
// arr[0], arr, i, arr, i
collectArguments(
// arr[0], arr, i, val
permuteArguments(
// arr, i, val, arr[0], arr[0]
collectArguments(
// arr, i, val ^ arr[0], arr[0]
collectArguments(
// arr[0]
identity(char.class),
0,
arrayElementSetter(char[].class)
),
2,
explicitCastArguments(
xor(),
methodType(char.class, char.class, char.class)
)
),
methodType(char.class, char.class, char[].class, int.class, char.class),
1, 2, 3, 0, 0
),
3,
arrayElementGetter(char[].class)
),
methodType(char.class, char.class, int.class, char[].class),
0, 2, 1, 2, 1
)
);
return permuteArguments(
// arr, arr
filterArguments(
// arr[0], arr
dropArguments(identity(char[].class), 0, char.class),
0,
modifyArray
),
methodType(char[].class, char[].class),
0, 0
);
}
And now just like that we have a CTF challenge.
/**
* Implements
* <pre>{@code
* def reverse_(message):
*     msg = message[::-1]
*     if len(msg) > 11 and len(msg) != 11 * 2 + 1:
*         msg[11] = msg[11] ^ 11
*     return msg
*
* def encrypt(message: list[int]) -> list[int]:
*     key = message[0]
*     message = [c ^ key for c in message]
*     res = []
*     while message:
*         res.append(message[0] ^ message[-1] ^ 3)
*         message = reverse_(message)[1:]
*     return res
* }</pre>
*/
public static MethodHandle encrypt() throws Throwable {
// Vs: res, message
MethodHandle encryptLoop = loop(
// res
new MethodHandle[]{
// init. We create a new empty char array
insertArguments(arrayConstructor(char[].class), 0, 0),
// step
permuteArguments(
// res, message, message
filterArguments(
// res, message[0], message[message.length - 1]
explicitCastArguments(
// res, message[0], message[message.length - 1]
collectArguments(
append(),
1,
explicitCastArguments(
filterArguments(
xor(),
0,
insertArguments( // xor first argument with constant
xor(),
1,
3
)
),
methodType(char.class, int.class, int.class)
)
),
methodType(char[].class, char[].class, char.class, char.class)
),
1,
// message -> message[0]
insertArguments(arrayElementGetter(char[].class), 1, 0),
// message -> message[message.length - 1]
permuteArguments(
// message, message
filterArguments(
// message, length
filterArguments(
// message, length - 1
arrayElementGetter(char[].class),
1,
insertArguments(add(), 0, -1)
),
1,
arrayLength(char[].class)
),
methodType(char.class, char[].class),
0, 0
)
),
methodType(char[].class, char[].class, char[].class, char[].class),
0, 1, 1
),
// pred. We never cancel the loop from here
constant(boolean.class, true),
// fini. Ignored, never returns.
dropArguments(identity(char[].class), 1, char[].class)
},
// message
new MethodHandle[]{
// init. We just keep the message as-is to start with
identity(char[].class),
// step. popFirst(reverse(message))
dropArguments(compose(popFirst(), reverse()), 0, char[].class),
// pred. if(message.length != 0)
dropArguments(compose(isNotZero(), arrayLength(char[].class)), 0, char[].class),
// fini. return res
dropArguments(identity(char[].class), 1, char[].class)
}
);
return filterArguments(
encryptLoop,
0,
xorInput()
);
}
public static void main(String[] args) throws Throwable {
char[] flag = new String(Files.readAllBytes(
Paths.get(
MethodHandleFun.class.getProtectionDomain().getCodeSource().getLocation().toURI()
).resolveSibling("flag.txt")
)).toCharArray();
char[] res = (char[]) encrypt().invoke(flag);
for (char re : res) {
System.out.print((int) re + " ");
}
System.out.println();
}
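For anyone attempting the challenge, it may help to see what the encrypt() handle computes as ordinary Java. The following is our own reading of the handle chains above, not code from the challenge: xorInput xors every character with the original first character, reverse xors any value it writes to index 11 with 11, and the loop appends message[0] ^ 3 ^ message[last] before replacing the message with popFirst(reverse(message)). EncryptMirror is a hypothetical name; if it ever disagrees with the method handles, the handles win.

```java
import java.util.Arrays;

class EncryptMirror {
    /** xorInput(): xor every element with the ORIGINAL first element, so index 0 becomes 0. */
    static char[] xorInput(char[] msg) {
        char key = msg[0];
        for (int i = 0; i < msg.length; i++) {
            msg[i] = (char) (msg[i] ^ key);
        }
        return msg;
    }

    /** reverse() with interceptArraySetIfIndexIsNice: every swap write to index 11 is xored with 11. */
    static char[] reverse(char[] v) {
        for (int i = 0; i < v.length / 2; i++) {
            int j = v.length - i - 1;
            char tmp = v[i];
            v[i] = (i == 11) ? (char) (v[j] ^ 11) : v[j];
            v[j] = (j == 11) ? (char) (tmp ^ 11) : tmp;
        }
        return v;
    }

    /** encrypt(): append, then popFirst(reverse(...)), until the message is empty. */
    static char[] encrypt(char[] message) {
        message = xorInput(message);
        char[] res = new char[0];
        do {
            res = Arrays.copyOf(res, res.length + 1);
            res[res.length - 1] = (char) (message[0] ^ 3 ^ message[message.length - 1]);
            message = Arrays.copyOfRange(reverse(message), 1, message.length);
        } while (message.length != 0);
        return res;
    }

    public static void main(String[] args) {
        for (char c : encrypt("ab".toCharArray())) {
            System.out.print((int) c + " ");
        }
        System.out.println(); // 0 3
    }
}
```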
If you have read this far, you have solved the reversing challenge as-is. But to provide a little more fun, we inline as much as Java allows, given the JVM’s limit of 64 KiB of bytecode per method. We use Spoon for the source code transformations. Hire I-Al-Istannen as a performance engineer!
public static void inlineAllLocalVariables(CtMethod<?> method) {
boolean foundVars = true;
while (foundVars) {
List<CtLocalVariable> vars = method
.getElements(new TypeFilter<>(CtLocalVariable.class))
.stream().filter(it -> !(it.getParent() instanceof CtForEach))
.toList();
method.accept(new CtScanner() {
@Override
public <T> void visitCtLocalVariableReference(CtLocalVariableReference<T> reference) {
if (reference.getParent() instanceof CtVariableRead<?>) {
if (reference.getDeclaration().getAssignment() != null) {
reference.getParent().replace(reference.getDeclaration().getAssignment().clone());
}
}
}
});
vars.forEach(CtElement::delete);
foundVars = !vars.isEmpty();
}
}
/**
* A very minimalistic function inliner. Requires a single return with a value at the end of the
* method. Can not handle (in-)direct recursive calls.
*
* @param toInline the method to inline
* @param call a call to the method
*/
public static void inline(CtMethod<?> toInline, CtInvocation<?> call) {
Factory factory = toInline.getFactory();
long returnCount = toInline.getBody()
.getStatements()
.stream()
.filter(it -> it instanceof CtReturn<?>)
.count();
if (!(toInline.getBody().getLastStatement() instanceof CtReturn<?>)) {
throw new IllegalArgumentException("Method did not end with return " + toInline);
}
if (returnCount != 1) {
throw new IllegalArgumentException("Not exactly one return statement in " + toInline);
}
if (toInline.getType().equals(factory.Type().voidPrimitiveType())) {
throw new IllegalArgumentException("Void return type in " + toInline);
}
if (call.getParent(CtExecutable.class).equals(toInline)) {
throw new IllegalArgumentException("Can not inline recursively!");
}
// Deduplicate variable names
Set<String> takenVariableNames = new HashSet<>();
call.getParent(CtExecutable.class).accept(new CtScanner() {
@Override
public <T> void visitCtLocalVariable(CtLocalVariable<T> localVariable) {
super.visitCtLocalVariable(localVariable);
takenVariableNames.add(localVariable.getSimpleName());
}
@Override
public <T> void visitCtParameter(CtParameter<T> parameter) {
super.visitCtParameter(parameter);
takenVariableNames.add(parameter.getSimpleName());
}
});
Set<String> ourVariableNames = new HashSet<>();
toInline.accept(new CtScanner() {
@Override
public <T> void visitCtLocalVariable(CtLocalVariable<T> localVariable) {
super.visitCtLocalVariable(localVariable);
ourVariableNames.add(localVariable.getSimpleName());
}
});
if (!Sets.intersection(takenVariableNames, ourVariableNames).isEmpty()) {
// we need to rename our variables
Map<String, String> renames = new HashMap<>();
for (String name : ourVariableNames) {
String current = name;
for (int counter = 0; true; counter++) {
boolean hasConflict = takenVariableNames.contains(current);
hasConflict |= (ourVariableNames.contains(current) && counter > 0);
if (!hasConflict) {
break;
}
current = current + counter;
}
renames.put(name, current);
}
toInline.accept(new CtScanner() {
@Override
public <T> void visitCtLocalVariable(CtLocalVariable<T> localVariable) {
super.visitCtLocalVariable(localVariable);
if (renames.containsKey(localVariable.getSimpleName())) {
localVariable.setSimpleName(renames.get(localVariable.getSimpleName()));
}
}
@Override
public <T> void visitCtLocalVariableReference(CtLocalVariableReference<T> reference) {
super.visitCtLocalVariableReference(reference);
if (renames.containsKey(reference.getSimpleName())) {
reference.setSimpleName(renames.get(reference.getSimpleName()));
}
}
});
}
// Copy method statements over
List<CtStatement> methodStatements = toInline.getBody().clone().getStatements();
CtStatement callStatement = call.getParent(CtStatement.class);
for (int i = 0; i < methodStatements.size() - 1; i++) {
callStatement.insertBefore(rewireParameters(toInline, call, methodStatements.get(i)));
}
// Replace return value
CtExpression<?> ourReturn = toInline.getBody()
.<CtReturn<?>>getLastStatement()
.getReturnedExpression()
.clone();
call.replace(rewireParameters(toInline, call, ourReturn));
}
private static <R extends CtElement> R rewireParameters(
CtMethod<?> toInline, CtInvocation<?> call, R element
) {
element.accept(new CtScanner() {
@Override
public <T> void visitCtVariableRead(CtVariableRead<T> variableRead) {
if (!(variableRead.getVariable() instanceof CtParameterReference<T> paramRef)) {
return;
}
CtParameter<?> param = toInline.getParameters()
.stream()
.filter(it -> it.getSimpleName().equals(paramRef.getSimpleName()))
.findFirst()
.orElseThrow(() -> new IllegalArgumentException("Could not find my parameter?"));
int index = toInline.getParameters().indexOf(param);
if (index < 0) {
throw new IllegalArgumentException("Could not find my parameter in call?");
}
variableRead.replace(call.getArguments().get(index).clone());
}
});
return element;
}
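To top that off, let’s replace small integer constants (0 to 254) with seeded calls to java.util.Random, and class literals with chains of arrayType() and getComponentType() calls.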
private static final int MAX_RANDOM_SEQUENCE_LENGTH = 2;
private static final List<String> CLASS_LIT_OPERATIONS = List.of(
".arrayType()",
".getComponentType()"
);
private final Factory factory;
ReplaceLiteralsScanner(Factory factory) {
this.factory = factory;
}
@Override
public <T> void visitCtLiteral(CtLiteral<T> literal) {
if (literal.getValue() instanceof Integer i) {
var cracked = crackSeed(i);
if (cracked.isEmpty()) {
return;
}
int seed = cracked.getAsInt();
int actual = new Random(seed).nextInt(255);
if (actual != i) {
throw new RuntimeException("OH NO");
}
literal.replace(factory.createCodeSnippetExpression(
"new java.util.Random({seed}).nextInt(255)"
.replace("{seed}", Integer.toString(seed))
));
}
super.visitCtLiteral(literal);
}
private OptionalInt crackSeed(int target) {
if (target >= 255 || target < 0) { // nextInt(255) only yields 0..254, so 255 would loop forever
return OptionalInt.empty();
}
while (true) {
int seed = ThreadLocalRandom.current().nextInt();
Random random = new Random(seed);
int ourTry = random.nextInt(255);
if (ourTry == target) {
return OptionalInt.of(seed);
}
}
}
@Override
public <T> void visitCtFieldRead(CtFieldRead<T> fieldRead) {
super.visitCtFieldRead(fieldRead);
if (!fieldRead.getVariable().getSimpleName().equals("class")) {
return;
}
if (!(fieldRead.getTarget() instanceof CtTypeAccess<?> typeAccess)) {
return;
}
int dimension = 0;
CtTypeReference<?> arrayType = typeAccess.getAccessedType();
if ((typeAccess.getAccessedType() instanceof CtArrayTypeReference<?> arrayRef)) {
arrayType = arrayRef.getArrayType();
dimension = arrayRef.getDimensionCount();
}
if (Set.of("Exception", "void").contains(arrayType.toString())) {
return;
}
var snippet = classLitForDimension(arrayType.toString(), dimension);
fieldRead.replace(factory.createCodeSnippetExpression(snippet));
}
private static String classLitForDimension(String arrayType, int dimension) {
int openingArrayBrackets = ThreadLocalRandom.current().nextInt(dimension + 1);
int currentDepth = openingArrayBrackets;
StringBuilder result = new StringBuilder(arrayType)
.append("[]".repeat(openingArrayBrackets))
.append(".class");
for (int i = 0; i < MAX_RANDOM_SEQUENCE_LENGTH; i++) {
double ascentPercentage = currentDepth > dimension
? 0.1
: currentDepth < dimension ? 0.8 : 0.5;
if (ThreadLocalRandom.current().nextDouble() >= ascentPercentage || currentDepth == 0) {
result.append(CLASS_LIT_OPERATIONS.get(0));
currentDepth++;
} else {
result.append(CLASS_LIT_OPERATIONS.get(1));
currentDepth--;
}
}
while (currentDepth > dimension) {
result.append(CLASS_LIT_OPERATIONS.get(1));
currentDepth--;
}
while (currentDepth < dimension) {
result.append(CLASS_LIT_OPERATIONS.get(0));
currentDepth++;
}
return result.toString();
}
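The seed trick works because java.util.Random is fully deterministic for a given seed, so a random search finds a matching seed for any value nextInt(255) can produce, i.e. 0 to 254. Here is a stand-alone sketch of that search, plus a check that a class-literal chain denotes the same Class object. SeedDemo is a hypothetical name, and the sketch assumes Java 12+ for Class.arrayType().

```java
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

class SeedDemo {
    /** Random search for a seed with new Random(seed).nextInt(255) == target. */
    static int crackSeed(int target) {
        // nextInt(255) only ever returns 0..254, so any other target would loop forever
        if (target < 0 || target > 254) {
            throw new IllegalArgumentException("unreachable target: " + target);
        }
        while (true) {
            int seed = ThreadLocalRandom.current().nextInt();
            if (new Random(seed).nextInt(255) == target) {
                return seed;
            }
        }
    }

    public static void main(String[] args) {
        int seed = crackSeed(42);
        System.out.println("new java.util.Random(" + seed + ").nextInt(255) == "
                + new Random(seed).nextInt(255));
        // The class-literal rewrite relies on chains like this denoting the same Class object
        System.out.println(char.class.arrayType().arrayType().getComponentType() == char[].class); // true
    }
}
```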
We provided the compiled jar and the output of the program as produced by the main method:
57 101 114 32 62 63 50 53 34 103 117 13 48 10 33 38 26 95 7 113 116 64 20 1 39 37 89 7 93 88 82 84 71 31 26 69 6 5 3 31 31 2 6 94 3 92 120 126 0 89 27 70 4 90 49 111 1 94 2 21 3
8 teams solved the challenge.
The full code is on GitHub.