#!/usr/bin/python3

from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from secret import FLAG

def print_banner():
    """Print the three-line service banner to standard output."""
    banner_lines = (
        "== EaaS v2.0 ==",
        "This is the new and improved Encryption as a Service!",
        "Now can you break it?",
    )
    # A single print of the joined lines emits the same text as three prints.
    print("\n".join(banner_lines))

def menu():
    """Show the option list and return the line the user types.

    Returns:
        The raw string read from standard input (no trailing newline),
        exactly as returned by ``input()``; no validation is done here.
    """
    option_lines = (
        "Options:",
        "1. Get my encrypted confidential data.",
        "2. Send me your encrypted data.",
        "Enter your choice:",
    )
    for text in option_lines:
        print(text)
    # The prompt text was already printed above, so input() takes no prompt.
    reply = input()
    return reply

def encrypt(data: bytes, key: bytes, iv: bytes) -> bytes:
    """Encrypt *data* with AES-CBC and return the IV followed by the ciphertext.

    Args:
        data: Plaintext to encrypt; it is padded to a multiple of
            ``AES.block_size`` before encryption.
        key: AES key handed to ``AES.new``.
        iv: Initialization vector for CBC mode; it is also prepended to the
            returned value.

    Returns:
        ``iv`` concatenated with the ciphertext of the padded plaintext.
    """
    cbc = AES.new(key, AES.MODE_CBC, iv)
    padded_plaintext = pad(data, AES.block_size)
    ciphertext = cbc.encrypt(padded_plaintext)
    # The IV travels in front of the ciphertext as the first block.
    return iv + ciphertext

def decrypt(data: bytes, key: bytes, iv: bytes) -> bytes:
    """Decrypt *data* with AES-CBC, strip the padding, and drop the first block.

    The whole of *data* is decrypted, including its first block. This mirrors
    the layout produced by ``encrypt``, which puts the IV in front of the
    ciphertext, so the first decrypted block is not part of the message and
    is discarded.

    Args:
        data: Bytes to decrypt, expected in the ``iv + ciphertext`` layout
            that ``encrypt`` returns.
        key: AES key handed to ``AES.new``.
        iv: Initialization vector for CBC mode.

    Returns:
        The unpadded plaintext with its first 16 bytes removed.

    Raises:
        Whatever ``cipher.decrypt`` or ``unpad`` raise for malformed input
        (for example a wrong length or invalid padding); nothing is caught
        here.
    """
    cbc = AES.new(key, AES.MODE_CBC, iv)
    decrypted = cbc.decrypt(data)
    unpadded = unpad(decrypted, AES.block_size)
    # Skip the leading block, which corresponds to the prepended IV.
    return unpadded[AES.block_size:]
 
def eaas_v2(key: bytes, iv: bytes):
    """Run one round of the service: show the menu and act on the choice.

    Choice ``'1'`` prints the hex-encoded encryption of ``FLAG``.
    Choice ``'2'`` reads a hex string, decrypts it, and reports whether the
    result equals ``FLAG``. Any other choice does nothing.

    Args:
        key: AES key used for both encryption and decryption.
        iv: CBC initialization vector used for both operations.

    Raises:
        Nothing is caught here: errors from ``bytes.fromhex`` or ``decrypt``
        on malformed input propagate to the caller.
    """
    selection = menu()

    if selection == '1':
        print("Here is your encrypted confidential data:")
        sealed = encrypt(FLAG, key, iv)
        print(sealed.hex())
        return

    if selection != '2':
        # Unrecognised options are ignored silently.
        return

    print("Enter your encrypted data:")
    submitted = bytes.fromhex(input())
    recovered = decrypt(submitted, key, iv)

    verdict = (
        "Whoops! You are not supposed to know that!"
        if recovered == FLAG
        else "That's not it!"
    )
    print(verdict)

def main():
    """Generate a fresh key and IV, print the banner, and serve requests.

    One random key and one random IV (each ``AES.block_size`` bytes) are
    created per process and reused for every round. Each round is delegated
    to ``eaas_v2``; an error raised during a round is printed and the loop
    continues with the next round.

    The loop ends when standard input reaches end-of-file. Without that
    check, the ``EOFError`` raised by ``input()`` on a closed stream would be
    caught by the generic handler on every iteration, so the process would
    spin forever printing the same error after the client disconnects.
    """
    key = get_random_bytes(AES.block_size)
    iv = get_random_bytes(AES.block_size)

    print_banner()
    while True:
        try:
            eaas_v2(key, iv)
        except EOFError:
            # Input stream is closed: no further requests can arrive.
            break
        except Exception as e:
            # Report the failure and keep serving.
            print("Error:", e)
        
# Start the service only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
