CCSC 2024 - Ancient VM


Ancient VM Writeup

Challenge

An Andromeda hacker was able to extract an ancient, rusty VM that belongs to Project Echo, but he couldn’t get the file that contains the secret information. Do you think you can recover it?

We are given three files as part of the challenge:

  • ancient_vm - A stripped Rust binary
  • program.txt - A small program written in some weird esolang
  • output.txt - The output of the program, if run on the VM. This consists of 60 numbers, separated by spaces.

Solution

I am a beginner in rev-ing, and the stripped binary looked quite intimidating, especially since I had never reversed Rust before. Based on the challenge description and the files given, I assumed that the VM interprets program.txt and writes its result to output.txt. The given output was produced with the real flag, so I would have to understand how the VM works and what the program does in order to get the flag.

If we run the binary, we get the error Failed to read flag file. The VM is trying to read the flag from a file and will presumably compute the output from it. Since I don’t have the flag (yet), I simply created a flag-looking flag.txt file (with something like CCSC{I_dont_know_how_to_rev}). Running it again gives a different error: byte index 29 is out of bounds of `CCSC{I_dont_know_how_to_rev}`. This means that the actual flag is longer, so I added some more characters to my flag. Repeating this a few more times, we find that the flag length is 46. Finally, the binary runs without any errors and writes an output.txt file containing 60 numbers, which does not match the given output.txt.

$ cat flag.txt && rm output.txt &&  ./vm && cat output.txt
CCSC{I_dont_know_how_to_revAAAAAAAAAAAAAAAAA}
1839 12 1036 111 7236 1045 99 -181 8450 1269 1048 -6 61945 -24 436 1049 153 47992 56225 48685 28504 90 1196 83 7540 13209 349 283 67 3560 367 31731 11440 0 4314 425 1451 18915 260 265 205 65152 650 12224 70207 700 1294 49465 153 25141 246 1044 712 47320 364 1105 72 762 206 30095
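The trial-and-error search for the flag length could also be automated. This is a minimal sketch, not what I actually ran: find_flag_length keeps padding a dummy flag up to the byte index reported in the panic message until the VM stops complaining. run_vm is a hypothetical helper that writes its argument to flag.txt, runs the binary and returns its stderr (where Rust prints panic messages).

```python
import re
from typing import Callable

# Matches the Rust panic seen above, e.g.
# "byte index 29 is out of bounds of `CCSC{I_dont_know_how_to_rev}`"
PANIC_RE = re.compile(r"byte index (\d+) is out of bounds")


def find_flag_length(run_vm: Callable[[bytes], str], prefix: bytes = b"CCSC{") -> int:
    """Pad a dummy flag until the VM stops panicking; return the accepted length.

    run_vm is a hypothetical helper: it should write its argument to flag.txt,
    run the binary and return whatever the VM printed to stderr.
    """
    flag = bytearray(prefix)
    while True:
        match = PANIC_RE.search(run_vm(bytes(flag)))
        if match is None:
            return len(flag)
        # Pad up to the reported index (at least one byte, so the loop always progresses)
        flag += b"A" * max(1, int(match.group(1)) - len(flag))
```

A real run_vm could be built on subprocess.run(['./vm'], capture_output=True, text=True).stderr after writing flag.txt.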

Initially, I thought that each character in the flag would map to one of the numbers in the output through some transformation. However, that didn’t seem to be the case: the lengths differ (46 characters versus 60 numbers), and although the first 5 characters of my flag were correct, the first number of my output still didn’t match the given one.

I thought that instead of trying to understand the VM, I could just slowly change the flag and see how the output changes in response. I focused on one number at a time, starting with the first number of the output. With my flag, I got 1839, but I should be getting 1705. Changing the first 3 characters resulted in no change, but changing the first few characters following CCSC{ changed the output. To determine the exact range of indices that affect the first number, I wrote this script:

import os
import random
import string

def get_output(flag: bytes):
    if os.path.exists('output.txt'):
        os.remove('output.txt')
    with open('flag.txt', 'wb') as f:
        f.write(flag)
    os.system('./vm')
    with open('output.txt') as f:
        return list(map(int, f.read().split()))

FLAG_LENGTH = 46

flag_orig = bytearray(ord(random.choice(string.ascii_uppercase)) for _ in range(FLAG_LENGTH))
original_output = get_output(flag_orig)
print(f'Original output: {original_output[0]}')
for i in range(FLAG_LENGTH):
    flag = flag_orig.copy()
    flag[i] += 0x20
    output = get_output(flag)
    if output[0] != original_output[0]:
        print(f'Index {i:<3} affects the first number: {output[0]}')

Running it, we get the following output:

Original output: 1367
Index 3   affects the first number: 1383
Index 4   affects the first number: 1383
Index 5   affects the first number: 1383
Index 6   affects the first number: 1383
Index 7   affects the first number: 1383
Index 8   affects the first number: 1383
Index 9   affects the first number: 1383
Index 16  affects the first number: 1431
Index 17  affects the first number: 1431
Index 18  affects the first number: 1431
Index 19  affects the first number: 1431
Index 20  affects the first number: 1431
Index 21  affects the first number: 1431
Index 22  affects the first number: 1431

And we now have the list of indices which affect the first output number. After some experimentation, I also found that the first output is calculated as follows:

output[0] = flag[3]//2 + flag[4]//2 + flag[5]//2 + flag[6]//2 + flag[7]//2 + flag[8]//2 + flag[9]//2 + 2*(flag[16] + flag[17] + flag[18] + flag[19] + flag[20] + flag[21] + flag[22])

The coefficients can be read off the differences above: adding 32 to any of indices 3 to 9 adds 16 to the output (a factor of 1/2), while adding 32 to any of indices 16 to 22 adds 64 (a factor of 2).
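The same reasoning can be written as a tiny finite-difference check. A minimal sketch, assuming the output[0] rule above; it reuses the test flag from the first run (CCSC{I_dont_know_how_to_rev, 17 A's and a closing }), and the helper names output0 and coefficient are just for illustration. Because the bump 0x20 is even, the floor division always adds exactly 16, whatever the parity of the original character.

```python
def output0(flag: bytes) -> int:
    """Reconstructed rule for output[0]."""
    return sum(c // 2 for c in flag[3:10]) + 2 * sum(flag[16:23])


def coefficient(flag: bytes, index: int, change: int = 0x20) -> float:
    """Bump one character and divide the change in output[0] by the bump."""
    bumped = bytearray(flag)
    bumped[index] += change
    return (output0(bytes(bumped)) - output0(flag)) / change


test_flag = b"CCSC{I_dont_know_how_to_rev" + b"A" * 17 + b"}"
print(output0(test_flag))          # 1839, the first number of the first run
print(coefficient(test_flag, 3))   # 0.5 -> flag[3]//2
print(coefficient(test_flag, 16))  # 2.0 -> 2*flag[16]
print(coefficient(test_flag, 12))  # 0.0 -> flag[12] does not contribute
```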

After repeating the process for the next few numbers, I found that the first few fields of the output are calculated as follows:

output[0] = flag[3]//2 + flag[4]//2 + flag[5]//2 + flag[6]//2 + flag[7]//2 + flag[8]//2 + flag[9]//2 + 2*(flag[16] + flag[17] + flag[18] + flag[19] + flag[20] + flag[21] + flag[22])
output[1] = flag[8] + flag[9] + flag[10] - flag[14] - flag[15] - flag[16]
output[2] = flag[12] + flag[13] + flag[14] + flag[15] + flag[16] + flag[17] + flag[29] + flag[30] + flag[31] + flag[32] + flag[33] + flag[34]
output[5] = flag[21] + flag[22] + flag[23] + flag[24] + flag[25] + flag[26] + flag[38] + flag[39] + flag[40] + flag[41] + flag[42] + flag[43]
output[6] = -flag[2] - flag[3] - flag[4] - flag[5] - flag[6] + flag[14] + flag[15] + flag[16] + flag[17] + flag[18]
output[7] = -flag[19] - flag[20] - flag[21] - flag[22] + flag[39] + flag[40] + flag[41] + flag[42]
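As a sanity check, these six rules can be evaluated on the test flag from the first run (CCSC{I_dont_know_how_to_rev followed by 17 A's and a closing }) and compared with the first eight numbers that run printed. This is only a sketch; rules is a hypothetical helper keyed by output index.

```python
def rules(f: bytes) -> dict[int, int]:
    """The six rules found so far, keyed by output index."""
    return {
        0: sum(c // 2 for c in f[3:10]) + 2 * sum(f[16:23]),
        1: sum(f[8:11]) - sum(f[14:17]),
        2: sum(f[12:18]) + sum(f[29:35]),
        5: sum(f[21:27]) + sum(f[38:44]),
        6: sum(f[14:19]) - sum(f[2:7]),
        7: sum(f[39:43]) - sum(f[19:23]),
    }


# The test flag and the start of the output from the first run
test_flag = b"CCSC{I_dont_know_how_to_rev" + b"A" * 17 + b"}"
observed = [int(x) for x in "1839 12 1036 111 7236 1045 99 -181".split()]

for i, value in rules(test_flag).items():
    print(i, value, value == observed[i])  # every line ends in True
```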

You may notice that I didn’t figure out the rules for the numbers at indices 3 and 4. Every time I ran the experiment, the rule seemed to be different, so I ignored them for now. At this rate, figuring out all the outputs by hand would take too long, so I wrote a script that derives the constraints automatically, following the same process:

  1. Create a random flag
  2. Change one character at a time and record how much each output number changes relative to the change in the flag (did it increase by the same amount, decrease, double, etc.).
  3. Repeat for all numbers in the output.

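# Reuses get_output() from the previous script
FLAG_LENGTH = 46
OUTPUTS = 60

flag_orig = bytearray(ord(random.choice(string.ascii_uppercase)) for _ in range(FLAG_LENGTH))
# Baseline output for this random flag (must be recomputed for every new flag)
original_output = get_output(flag_orig)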

multipliers: list[dict[int, float]] = [{} for _ in range(OUTPUTS)]

for i in range(FLAG_LENGTH):
    flag = flag_orig.copy()
    change = 0x20
    flag[i] += change
    output = get_output(flag)
    for j, x in enumerate(output):
        if x == original_output[j]:
            continue
        multipliers[j][i] = (x - original_output[j]) / change

for i, m in enumerate(multipliers):
    output = f'output[{i}] = '
    for j, x in m.items():
        if abs(x) < 1:
            output += f'flag[{j}]//{int(1/x)} + '
        else:
            output += f'flag[{j}]*{int(x)} + '
    print(output[:-2])

And with this, we have the constraints for a large portion of the numbers in the output.

output[0] = flag[3]//2 + flag[4]//2 + flag[5]//2 + flag[6]//2 + flag[7]//2 + flag[8]//2 + flag[9]//2 + flag[16]*2 + flag[17]*2 + flag[18]*2 + flag[19]*2 + flag[20]*2 + flag[21]*2 + flag[22]*2 
output[1] = flag[8]*1 + flag[9]*1 + flag[10]*1 + flag[14]*-1 + flag[15]*-1 + flag[16]*-1 
output[2] = flag[12]*1 + flag[13]*1 + flag[14]*1 + flag[15]*1 + flag[16]*1 + flag[17]*1 + flag[29]*1 + flag[30]*1 + flag[31]*1 + flag[32]*1 + flag[33]*1 + flag[34]*1 
output[3] = flag[2]*1 + flag[3]*1 + flag[4]*1 + flag[5]*1 + flag[11]*1 + flag[12]*1 + flag[13]*1 + flag[14]*1 
output[4] = flag[8]//16 + flag[9]//16 + flag[10]//16 + flag[11]//16 + flag[12]//16 + flag[13]//16 + flag[39]*16 + flag[40]*16 + flag[41]*16 + flag[42]*16 + flag[43]*16 + flag[44]*16 
output[5] = flag[21]*1 + flag[22]*1 + flag[23]*1 + flag[24]*1 + flag[25]*1 + flag[26]*1 + flag[38]*1 + flag[39]*1 + flag[40]*1 + flag[41]*1 + flag[42]*1 + flag[43]*1 
output[6] = flag[2]*-1 + flag[3]*-1 + flag[4]*-1 + flag[5]*-1 + flag[6]*-1 + flag[14]*1 + flag[15]*1 + flag[16]*1 + flag[17]*1 + flag[18]*1 
output[7] = flag[19]*-1 + flag[20]*-1 + flag[21]*-1 + flag[22]*-1 + flag[39]*1 + flag[40]*1 + flag[41]*1 + flag[42]*1 
output[8] = flag[32]*65 + flag[33]*72 + flag[37]*69 + flag[38]*75 

However, the results are not always correct. The rule for output[8], for example, is wrong: my script assumes that each output changes linearly with each flag character, which is not the case here. Instead, some outputs look like sums of products of two flag characters. If flag[a]*flag[b] is one of the terms, adding 32 to flag[a] changes the output by 32*flag[b], so the measured "multiplier" of flag[a] is the value of its partner flag[b] (in the line above, 65, 72, 69 and 75 are all uppercase letters). I extended the script to detect such pairs. Sometimes several indices could be paired together, because the random flag contains repeated letters, so in those cases I had to repeat the process with a new random flag.

# Continues from the previous script: uses multipliers and flag_orig
for i, m in enumerate(multipliers):
    output = f'output[{i}] = '

    matchings: dict[int, list[int]] = {}

    for j, x in m.items():
        if abs(x) >= ord('A'):
            # Too large for a linear coefficient, so I assume flag[j] is multiplied
            # by another flag character, whose value must then be equal to x
            matching = [flag_i for flag_i, flag_val in enumerate(flag_orig) if flag_i != j and x == flag_val and flag_i in m]
            matchings[j] = matching
        elif abs(x) < 1:
            output += f'flag[{j}]//{int(1/x)} + '
        else:
            output += f'flag[{j}]*{int(x)} + '

    used: set[int] = set()

    for x, mx in matchings.items():
        if x in used:
            continue
        possible_pairs = []
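        for y in mx:
            # y is a valid partner only if the relation is mutual; y may also be a
            # plain linear term with no entry in matchings, hence .get()
            if x in matchings.get(y, []):
                possible_pairs.append(y)

        if len(possible_pairs) != 1:
            # Zero or several candidates: rerun with another random flag
            print(f"Error: ambiguous pairing for flag[{x}] in output[{i}]")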
        else:
            used.add(possible_pairs[0])
            output += f'flag[{x}]*flag[{possible_pairs[0]}] + '

    print(output[:-2])

Running it prints, among others:

output[8] = flag[32]*flag[37] + flag[33]*flag[38]
output[12] = flag[10]*flag[37] + flag[11]*flag[38] + flag[12]*flag[39] + flag[13]*flag[40] + flag[14]*flag[41] + flag[15]*flag[42] + flag[16]*flag[43] + flag[17]*flag[44]
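Why the multiplier equals the partner's value can be seen on output[8] alone. This minimal sketch evaluates the recovered rule and measures each multiplier the same way the first script does. The characters at indices 32, 33, 37 and 38 (E, K, A and H) are the values consistent with the multipliers in the wrong output[8] line above; the rest of the flag is arbitrary filler, since output[8] only depends on these four positions. The helper names are hypothetical.

```python
def output8(flag: bytes) -> int:
    """Rule for output[8] as it is eventually recovered."""
    return flag[32] * flag[37] + flag[33] * flag[38]


def measured_multiplier(rule, flag: bytes, index: int, change: int = 0x20) -> float:
    """What the first script measures: output change divided by the bump."""
    bumped = bytearray(flag)
    bumped[index] += change
    return (rule(bytes(bumped)) - rule(bytes(flag))) / change


# Filler flag; only the four positions used by output8 matter
flag = bytearray(b"A" * 46)
flag[32], flag[33] = ord("E"), ord("K")
flag[37], flag[38] = ord("A"), ord("H")

for i in (32, 33, 37, 38):
    print(f"flag[{i}] multiplier: {measured_multiplier(output8, flag, i)}")
# flag[32] multiplier: 65.0  (the value of flag[37])
# flag[33] multiplier: 72.0  (the value of flag[38])
# flag[37] multiplier: 69.0  (the value of flag[32])
# flag[38] multiplier: 75.0  (the value of flag[33])
```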

Running this a few times, I collected the constraints for 42 out of the 60 numbers. Having the constraints, I could now try to solve them and get the flag. For this, I used the following Z3 script:

import z3

with open('./output_flag.txt', 'r') as f:
    output = list(map(int, f.read().split()))

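# 32-bit bit-vectors, so that the sums and products below cannot wrap around
# (with 8-bit vectors every equation would only hold modulo 256).
# Only indices 0-44 appear in the collected constraints.
flag = [z3.BitVec(f'flag_{i}', 32) for i in range(45)]

s = z3.Solver()

# Every flag character is printable ASCII
for c in flag:
    s.add(c >= 0x20, c <= 0x7e)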

s.add(flag[0] == ord('C'))
s.add(flag[1] == ord('C'))
s.add(flag[2] == ord('S'))
s.add(flag[3] == ord('C'))
s.add(flag[4] == ord('{'))
s.add(output[0] == (flag[3]>>1) + (flag[4]>>1) + (flag[5]>>1) + (flag[6]>>1) + (flag[7]>>1) + (flag[8]>>1) + (flag[9]>>1) + 2*(flag[16] + flag[17] + flag[18] + flag[19] + flag[20] + flag[21] + flag[22]))
s.add(output[1] == flag[8] + flag[9] + flag[10] - flag[14] - flag[15] - flag[16])
s.add(output[2] == flag[12] + flag[13] + flag[14] + flag[15] + flag[16] + flag[17] + flag[29] + flag[30] + flag[31] + flag[32] + flag[33] + flag[34])
s.add(output[5] == flag[21] + flag[22] + flag[23] + flag[24] + flag[25] + flag[26] + flag[38] + flag[39] + flag[40] + flag[41] + flag[42] + flag[43])
s.add(output[6] == -flag[2] - flag[3] - flag[4] - flag[5] - flag[6] + flag[14] + flag[15] + flag[16] + flag[17] + flag[18])
s.add(output[7] == -flag[19] - flag[20] - flag[21] - flag[22] + flag[39] + flag[40] + flag[41] + flag[42])
s.add(output[8] == flag[32] * flag[37] + flag[33] * flag[38])
s.add(output[9] == flag[18] + flag[19] + flag[20] + flag[21] + flag[22] + flag[23] + flag[24] + flag[26] + flag[27] + flag[28] + flag[29] + flag[30] + flag[31] + flag[32])
s.add(output[10] == flag[32] + flag[33] + flag[34] + flag[35] + flag[36] + flag[37] + flag[10] + flag[11] + flag[12] + flag[13] + flag[14] + flag[15])
s.add(output[11] == flag[20] + flag[21] - flag[12] - flag[13])
s.add(output[12] == flag[10] * flag[37] + flag[11] * flag[38] + flag[12] * flag[39] + flag[13] * flag[40] + flag[14] * flag[41] + flag[15] * flag[42] + flag[16] * flag[43] + flag[17] * flag[44])
s.add(output[13] == -flag[0] - flag[1] - flag[2] - flag[3] + flag[29] + flag[30] + flag[31] + flag[32])
s.add(output[14] == flag[14] + flag[15] + flag[22] + flag[23])
s.add(output[15] == flag[1] + (flag[2] + flag[3] + flag[4] + flag[5] + flag[6])*2 + flag[7])
s.add(output[17] == flag[2] * flag[11] + flag[3] * flag[12] + flag[4] * flag[13] + flag[5] * flag[14] + flag[6] * flag[15])
s.add(output[18] == flag[17] * flag[27] + flag[18] * flag[28] + flag[19] * flag[29] + flag[20] * flag[30] + flag[21] * flag[31] + flag[22] * flag[32] + flag[23] * flag[33] + flag[24] * flag[34])
s.add(output[19] == flag[7] * flag[29] + flag[8] * flag[30] + flag[9] * flag[31] + flag[10] * flag[32] + flag[11] * flag[33] + flag[12] * flag[34] + flag[13] * flag[35] )
s.add(output[20] == flag[5] * flag[26] + flag[6] * flag[27] + flag[7] * flag[28] + flag[8] * flag[29])
s.add(output[22] == flag[11] * 1 + flag[12] * 1 + flag[13] * 1 + flag[14] * 1 + flag[15] * 1 + flag[16] * 1 + flag[17] * 1 + flag[29] * 1 + flag[30] * 1 + flag[31] * 1 + flag[32] * 1 + flag[33] * 1 + flag[34] * 1 + flag[35] * 1)
s.add(output[24] == flag[10] * flag[40])
s.add(output[25] == flag[19] * flag[22])
s.add(output[26] == flag[25] * 1 + flag[26] * 1 + flag[30] * 1 + flag[31] * 1)
s.add(output[30] == (flag[24]>>1) + (flag[25]>>1) + flag[32]*2 + flag[33]*2)
s.add(output[31] == flag[8] * flag[25] + flag[9] * flag[26] + flag[10] * flag[27] )
s.add(output[32] == flag[9] * flag[17] )
s.add(output[33] == flag[30] * 1 + flag[31] * 1 + flag[36] * -1 + flag[37] * -1 )
s.add(output[35] == flag[7] * 1 + flag[8] * 1 + flag[15] * 1 + flag[16] * 1 )
s.add(output[36] == flag[5] * 1 + flag[6] * 1 + flag[7] * 1 + flag[8] * 1 + flag[9] * 1 + flag[10] * 1 + flag[11] * 1 + flag[19] * 1 + flag[20] * 1 + flag[21] * 1 + flag[22] * 1 + flag[23] * 1 + flag[24] * 1 + flag[25] * 1 )
s.add(output[37] == flag[4] * flag[35] + flag[5] * flag[36] + flag[6] * flag[37] )
s.add(output[38] == flag[33] * 1 + flag[34] * 1 + flag[40] * 1 + flag[41] * 1 )
s.add(output[40] == flag[13] * 1 + flag[20] * 1 )
s.add(output[42] == flag[30] * 1 + flag[31] * 1 + flag[32] * 1 + flag[33] * 1 + flag[34] * 1 + flag[35] * 1 + flag[36] * 1 + flag[37] * 1 + flag[38] * 1 + flag[39] * 1)
s.add(output[43] == flag[2] * 16 + flag[3] * 16 + flag[4] * 16 + flag[5] * 16 + flag[6]*16 + flag[7]*16 + flag[8]*16 + flag[9]*16 + (flag[34]>>4) + (flag[35]>>4) + (flag[36]>>4) + (flag[37]>>4) + (flag[38]>>4) + (flag[39]>>4) + (flag[40]>>4) + (flag[41]>>4))
s.add(output[45] == flag[0] * 1 + flag[1] * 1 + flag[2] * 1 + flag[3] * 1 + flag[6] * 1 + flag[7] * 1 + flag[8] * 1 + flag[9] * 1 )
s.add(output[46] == flag[7] * 1 + flag[8] * 1 + flag[9] * 1 + flag[10] * 1 + flag[11] * 1 + flag[12] * 1 + flag[21] * 1 + flag[22] * 1 + flag[23] * 1 + flag[24] * 1 + flag[25] * 1 + flag[26] * 1 )
s.add(output[47] == flag[18] * flag[35] + flag[19] * flag[36] + flag[20] * flag[37] + flag[21] * flag[38] + flag[22] * flag[39] + flag[23] * flag[40] + flag[24] * flag[41] )
s.add(output[51] == (flag[42]>>4) + flag[43] * 16)
s.add(output[52] == flag[0] * 1 + flag[1] * 1 + flag[2] * 1 + flag[3] * 1 + flag[23] * 1 + flag[24] * 1 + flag[25] * 1 + flag[26] * 1 )
s.add(output[53] == flag[4] * flag[29] + flag[5] * flag[30] + flag[6] * flag[31] + flag[7] * flag[32] + flag[8] * flag[33] + flag[9] * flag[34] + flag[10] * flag[35] )
s.add(output[54] == flag[0] * 1 + flag[1] * 1 + flag[14] * 1 + flag[15] * 1 )
s.add(output[58] == flag[11] * 1 + flag[14] * 1 )
s.add(output[59] == flag[24] * flag[28] + flag[25] * flag[29] + flag[26] * flag[30] + flag[27] * flag[31]  + flag[28] * flag[32])


result = s.check()

if result == z3.sat:
    m = s.model()
    # Read every character back from the model and print the recovered flag
    print(''.join(chr(m.eval(c, model_completion=True).as_long()) for c in flag))
else:
    print("unsat")

And, finally:

CCSC{r3v_s0m3_rust_4nd_s0lv3_s0m3_c0nstra1nt5
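As a final sanity check, the recovered characters can be plugged back into the output[0] rule found at the start; the result should be 1705, the first number of the given output.txt. A small sketch (output0 is just an illustrative helper):

```python
def output0(flag: bytes) -> int:
    """Rule for output[0] found at the start of the analysis."""
    return sum(c // 2 for c in flag[3:10]) + 2 * sum(flag[16:23])


recovered = b"CCSC{r3v_s0m3_rust_4nd_s0lv3_s0m3_c0nstra1nt5"
print(output0(recovered))  # 1705
```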

I didn’t r3v_s0m3_rust, but I definitely did s0lv3_s0m3_c0nstra1nt5!
