Skip to main content
Wachter Space 🚀
  1. Posts/

b01lers CTF resnet Model Inversion

With KITCTF we participated in the b01lers CTF and finished 6th. There were some quite fun challenges, including the resnet challenge, which is a machine learning challenge. I hope to see more machine learning challenges in the future.

Challenge description:

A naive AI startup released a new visual password system based on State-of-the-Art Neural Network technology. Wanting to save on costs they reuse the popular Resnet model to create embeddings which input password images are checked against hoping to leverage the feature extraction capabilities of Resnet. Can you crack this visual password system? nc ctf.b01lers.com 9101

The challenge consists of two python files:

# resnet_password.py
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.resnet import resnet18

from io import BytesIO
import base64

from Embedding import key

# ResNet-18 with pretrained weights (pretrained=True), used as a fixed feature extractor.
# NOTE(review): model.eval() is never called, so the network stays in whatever
# mode resnet18() returns it in (PyTorch modules default to training mode).
# The stored key presumably depends on this mode — confirm before adding eval().
model = resnet18(pretrained=True)

# get embedding rather than logits from final layer:
# Identity returns its input unchanged, so the model now outputs the vector that
# used to feed `fc` (512 values, the same length as `key`) instead of class scores.
model.fc = torch.nn.Identity()

# Read the flag verbatim (any trailing newline is kept); fall back to a
# placeholder string when flag.txt does not exist.
try:
    with open("flag.txt") as f:
        FLAG = f.read()
except FileNotFoundError:
    FLAG = "TEMP_FLAG"
    
# Rebind the imported key list as a tensor so it can be subtracted from the
# model output element-wise in check_image().
key = torch.tensor(key)

def check_image():
    """Read a base64 image password from stdin and compare its embedding to ``key``.

    The decoded image is converted to 3-channel RGB and must be exactly
    224x224 pixels; otherwise an error message is printed and nothing else
    happens. The image is turned into a tensor with ``transforms.ToTensor``
    and passed through ``model``. The mean squared difference between the
    resulting embedding and ``key`` is printed, and the flag is printed when
    that difference is below 1e-4.
    """
    encoded = input("Input base64 image password\n")
    with Image.open(BytesIO(base64.b64decode(encoded))) as decoded:
        # PNGs are often RGBA, palette or greyscale. Normalize them to the 3
        # channels the model expects instead of failing (or silently
        # scrambling pixels) in a hard-coded reshape.
        rgb = decoded.convert("RGB")

    if rgb.size != (224, 224):
        width, height = rgb.size
        print(f"Image must be 224x224 pixels, got {width}x{height}")
        return

    tensored = transforms.ToTensor()(rgb)  # shape (3, 224, 224)

    with torch.no_grad():
        # Add the batch dimension the model expects and take the single result.
        embedding = model(tensored.unsqueeze(0))[0]
        diff = ((embedding - key) ** 2).mean().item()

    print(f"Image embedding differed by {diff}")
    if diff < 1e-4:
        print(f"Nice! Here's the flag: {FLAG}")

# Run a single password check when executed as a script (not on import).
if __name__ == '__main__':
    check_image()
# Embedding.py
# Target embedding for the visual password: 512 floats. resnet_password.py
# converts this list to a tensor and accepts an image when the mean squared
# difference between the model output and these values is below 1e-4.
key = [
        0.8650, 0.9455, 0.8639, 1.0084, 0.9285, 0.8526, 0.8912, 1.1995, 0.9959,
        0.9118, 0.8543, 0.8345, 0.9492, 0.9275, 0.8548, 0.9360, 0.8214, 1.3323,
        0.9196, 0.8247, 0.9448, 1.0769, 0.8564, 1.0093, 0.8940, 0.9418, 0.9409,
        0.9712, 0.8296, 0.9297, 0.8755, 0.8383, 0.9345, 0.8950, 0.8965, 0.8585,
        0.9743, 0.9376, 0.8611, 0.9015, 0.8631, 0.8253, 0.7963, 0.9119, 0.8424,
        0.8542, 0.9713, 1.1977, 0.9288, 0.8498, 0.9680, 0.9709, 0.8566, 0.8771,
        0.9883, 0.8877, 1.1215, 0.9938, 1.0397, 0.9664, 0.8750, 1.0023, 1.0433,
        0.9407, 0.8281, 1.0676, 0.8490, 0.9219, 1.0369, 0.8166, 0.9104, 0.8466,
        0.9019, 0.9150, 0.9451, 0.8786, 0.8384, 0.9434, 0.9005, 0.8694, 0.9531,
        1.0324, 1.1126, 1.2170, 0.8468, 0.9347, 0.9186, 0.8807, 0.9584, 0.8694,
        0.8221, 0.8793, 0.9319, 0.9409, 0.8699, 0.8341, 0.9384, 0.9210, 0.9601,
        0.8372, 1.0205, 0.9668, 0.9575, 0.9817, 0.9190, 0.7899, 0.9782, 0.9639,
        0.8616, 0.9053, 0.8651, 0.9046, 0.8485, 0.8941, 0.8070, 0.8750, 0.9405,
        0.9611, 0.9338, 0.9463, 0.9585, 0.9170, 0.8921, 0.8913, 0.8737, 0.8281,
        0.8603, 0.9241, 0.9073, 0.9468, 0.8172, 0.8761, 0.8523, 0.8735, 1.0539,
        1.0376, 0.8064, 0.8629, 0.8900, 0.9022, 0.8998, 0.9001, 0.9887, 0.8502,
        0.8893, 0.9280, 0.8340, 1.0512, 0.9638, 0.9730, 0.8568, 0.9343, 0.9663,
        0.8496, 0.8383, 0.7968, 0.9864, 0.9984, 0.8385, 0.8920, 0.9880, 1.2400,
        0.9082, 1.2500, 0.9052, 0.9534, 0.9997, 0.9543, 0.8232, 0.9885, 0.8958,
        0.8891, 0.9299, 0.8389, 0.8473, 0.8692, 0.8773, 0.8880, 1.0001, 0.8334,
        0.9308, 0.9468, 1.2865, 0.9266, 0.9571, 1.0100, 0.8823, 0.9597, 0.9218,
        0.9291, 0.8751, 0.8698, 0.9300, 1.0193, 0.8342, 0.8409, 0.8747, 0.8876,
        0.8868, 0.8185, 0.8570, 1.0043, 1.0543, 0.8878, 1.0854, 0.9775, 0.8642,
        0.9046, 0.9525, 0.9877, 0.8551, 0.8609, 0.9107, 0.9581, 0.9754, 0.9230,
        0.9453, 0.9090, 0.8826, 0.9057, 0.9171, 0.9841, 0.9214, 0.9026, 0.9366,
        0.8717, 0.9639, 0.9030, 1.0163, 0.9076, 0.9112, 0.9738, 0.8439, 0.9663,
        0.8593, 0.9215, 1.0177, 0.9198, 0.8552, 1.0665, 0.9245, 0.8600, 1.0593,
        0.8057, 0.9690, 0.9620, 0.9400, 0.8539, 0.9616, 1.0030, 1.0665, 0.9096,
        0.9266, 0.9860, 0.9339, 1.2710, 0.8949, 0.9893, 0.9885, 1.0137, 0.9511,
        0.8313, 0.9359, 0.9718, 0.8608, 0.9710, 0.8859, 0.8902, 0.9090, 0.8831,
        0.8862, 0.8377, 0.9148, 0.8918, 0.9451, 0.8943, 1.0040, 0.8859, 1.0251,
        0.8077, 1.2893, 0.7872, 0.9042, 0.9990, 0.9819, 0.8438, 0.9236, 0.8789,
        1.0265, 0.9193, 0.8327, 0.8968, 0.8743, 0.8636, 0.9922, 0.8292, 0.9180,
        1.1223, 0.9540, 0.8736, 0.8395, 0.8859, 0.8306, 0.8534, 1.0874, 0.8890,
        0.9144, 0.9516, 0.8604, 0.9001, 1.0925, 0.9599, 1.0549, 0.8452, 0.8947,
        0.9556, 0.9757, 0.8775, 1.0404, 0.9159, 0.8583, 0.9495, 0.8469, 0.8919,
        0.8877, 0.9701, 0.9075, 0.9153, 0.8997, 0.9073, 0.9370, 1.1040, 0.9694,
        0.9153, 0.9745, 0.8301, 0.9659, 0.8538, 0.9454, 0.8904, 0.8712, 0.8812,
        0.9240, 1.1233, 0.9104, 0.8279, 0.8594, 0.9163, 0.8216, 0.9497, 0.8768,
        0.9002, 0.9149, 0.9308, 0.8898, 0.8224, 0.8996, 0.9806, 0.8820, 0.8134,
        0.9127, 1.0037, 0.8360, 0.8509, 1.0956, 0.9381, 0.9410, 0.8880, 0.9200,
        0.8738, 1.0781, 0.8949, 1.0256, 0.8457, 0.8585, 0.9247, 0.8767, 0.8310,
        0.8649, 0.9320, 1.0716, 0.8237, 0.9167, 1.0095, 1.0060, 1.0311, 0.9016,
        1.0073, 0.8987, 0.9918, 0.9710, 0.9505, 0.8579, 0.9880, 1.0872, 0.9249,
        0.9356, 0.9214, 0.8857, 0.9356, 1.1062, 1.0667, 0.8551, 0.9001, 0.9490,
        0.8467, 0.8118, 0.9020, 0.7785, 0.8705, 0.9300, 0.9352, 0.8551, 0.8192,
        0.9887, 0.9683, 0.8297, 0.8715, 0.8383, 0.8492, 0.9008, 0.8828, 0.9078,
        0.9090, 0.9956, 0.9091, 0.9729, 1.0581, 0.8514, 0.9229, 0.9135, 0.9439,
        0.9019, 0.9204, 0.8584, 1.1420, 0.9046, 0.8769, 0.9000, 0.8779, 0.8923,
        0.8457, 0.8609, 0.9619, 0.9375, 0.9699, 0.9989, 0.8633, 1.0134, 0.9025,
        0.8584, 0.9644, 0.9300, 0.8777, 0.8741, 0.8695, 0.9289, 0.8809, 0.9310,
        0.9730, 1.2499, 0.8951, 0.8661, 0.8809, 1.0089, 0.8673, 1.0068, 0.9371,
        1.0070, 0.9130, 0.9215, 1.0064, 0.8882, 0.9126, 0.8043, 0.8842, 0.9036,
        0.8836, 0.8597, 0.8899, 0.9195, 0.9754, 0.8787, 0.8744, 1.1108, 1.0311,
        0.9571, 0.9503, 1.1428, 0.8770, 1.0210, 1.0389, 0.9261, 0.9023, 0.9089,
        0.9123, 0.8272, 0.8943, 0.8782, 0.9152, 0.9302, 0.8577, 0.9566, 0.8462,
        0.8203, 0.8180, 1.0899, 1.0567, 0.9685, 0.9352, 0.9568, 0.9383
      ]

So, not a lot of code. Let’s see what it does. First, it uses a pretrained resnet18. ResNet-18 is a convolutional neural network with residual connections, pretrained on ImageNet. We do not need the details of its architecture. All we need to know is that it takes images in tensor form with 3 color channels, and that we have direct access to the network, i.e. we can compute gradients.

Normally, resnet would output scores for each of its 1000 classes, but here the last layer is replaced with the identity function. This way it outputs the embedding (size 512) of an image. The embedding can be thought of as a compressed representation of the input, where inputs of the same class should have similar embeddings.

The challenge now wants us to give an image input that matches the given embedding key very closely. In CTF terms: The neural network is a crackme, and we have to reverse it. In machine learning security terms: We need to perform a model inversion attack.

Model inversion attacks have been presented by Fredrikson et al. in their paper “Model Inversion Attacks that Exploit Confidence Information and Basic Countermeasures”. What they do is, given a class, find an input that matches this class. Here is the famous training data reconstruction result from a facial recognition dataset.

Untitled

We have a very similar problem. In fact, we just need to swap out the loss function for the mean squared difference between the embedding and the key, diff = ((embedding - key)**2).mean(), exactly as computed in the challenge.

The attack implementation is fairly simple. It is very similar to normal neural network training, but instead of using backpropagation to optimize the model parameters, we optimize an input tensor. Optimization is pretty fast, and I did not spend a lot of time tweaking hyper-parameters. What took some time was finding a tensor that survives the lossy conversion to an 8-bit image file. PNG itself is lossless, but every pixel value gets rounded to one of 256 levels. With conservative clipping, torch.clip(x, min=0.01, max=0.99), and x.mul(255).div(255), I managed to find an input whose embedding matches the key. Note that x.mul(255).div(255) on its own is a floating-point no-op. To actually simulate the 8-bit conversion it needs a rounding step, e.g. x.mul(255).round().div(255). If you are interested, here is what the final image looks like. Fortunately for us, we do not need to find an input on the data manifold.

input.png

Here is the full attack code:

from Embedding import key
import torch
from torchvision.models.resnet import resnet18
import numpy as np
from PIL import Image
import base64
from io import BytesIO
from pwn import *

def invert(target_model, target_embedding, input_shape=(1, 3, 224, 224), num_itr=150, step_size=4,
           clip_min=0.01, clip_max=0.99):
    """Gradient-descend an input image so ``target_model`` maps it to ``target_embedding``.

    A continuous image ``x`` (initially all zeros) is optimized, but every
    iteration evaluates the model on its 8-bit quantized version
    ``round(x * 255) / 255``. Those are exactly the pixel values that survive
    the conversion to a PNG. The gradient taken at the quantized image is
    applied to the continuous image (a straight-through estimate), which is
    then clipped to ``[clip_min, clip_max]``.

    Args:
        target_model: differentiable callable mapping a tensor of shape
            ``input_shape`` to an embedding.
        target_embedding: tensor broadcast-compatible with the model output.
        input_shape: shape of the optimized input tensor.
        num_itr: number of gradient-descent steps.
        step_size: plain gradient-descent learning rate.
        clip_min: lower bound applied to the continuous image after each step.
        clip_max: upper bound applied to the continuous image after each step.

    Returns:
        The quantized image (detached, no grad) with the lowest mean squared
        embedding error seen, or ``None`` if ``num_itr`` is 0.
    """
    x = torch.zeros(input_shape)

    min_loss = float("inf")
    best_img = None

    for i in range(num_itr):
        # Evaluate on the quantized image so the recorded loss is the loss of
        # the image that will actually be submitted.
        candidate = x.mul(255).round().div(255).requires_grad_(True)
        pred = target_model(candidate)
        loss = ((target_embedding - pred) ** 2).mean()
        loss.backward()

        loss_value = loss.item()
        if loss_value < min_loss:
            min_loss = loss_value
            # Snapshot the evaluated image. The clone guarantees later steps
            # can never alter the stored result.
            best_img = candidate.detach().clone()

        with torch.no_grad():
            x = torch.clip(x - step_size * candidate.grad, min=clip_min, max=clip_max)
        print(f"epoch {i}: {loss_value}")
    return best_img

def tensor_to_base64img(inv_tensor):
    """Encode a channel-first image tensor as a base64 PNG.

    Args:
        inv_tensor: float tensor with 3 * 224 * 224 elements (e.g. the
            ``(1, 3, 224, 224)`` result of ``invert``), values expected in [0, 1].

    Returns:
        The base64-encoded PNG file as ``bytes``.
    """
    chw = inv_tensor.detach().cpu().reshape(3, 224, 224)
    # Round (rather than truncate) to the nearest 8-bit level, so a value of
    # k/255 maps back to k even with floating-point error. Clamp so
    # out-of-range values saturate at 0/255 before the uint8 cast.
    pixels = chw.mul(255).round().clamp(0, 255).to(torch.uint8)
    hwc = np.transpose(pixels.numpy(), (1, 2, 0))  # PIL expects height x width x channels
    img = Image.fromarray(hwc)

    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue())

def send_solution(encoded):
    """Connect to the challenge service, send ``encoded`` as one line, then go interactive."""
    host, port = 'ctf.b01lers.com', 9101
    conn = remote(host, port)
    conn.sendline(encoded)
    # Hand the session to the user so the server's response can be read.
    conn.interactive()

def main():
    """Invert the ResNet-18 embedding ``key`` into an image and submit it."""
    # Mirror the server's network: pretrained ResNet-18 with its fc head
    # replaced by Identity, so it returns embeddings instead of logits.
    net = resnet18(pretrained=True)
    net.fc = torch.nn.Identity()

    wanted = torch.tensor(key)

    inv_tensor = invert(
        net,
        wanted,
        input_shape=(1, 3, 224, 224),
        num_itr=350,
        step_size=4,
    )

    payload = tensor_to_base64img(inv_tensor)
    send_solution(payload)

# Only launch the attack (which connects to the remote service) when run as a
# script, not when this module is imported.
if __name__ == "__main__":
    main()