import socket
import binascii


def build_ethernet_header(dst_mac, src_mac, protocol_type):
    """Build a 14-byte Ethernet II header.

    :param dst_mac: destination MAC address as 12 hex digits, optionally
        separated by ``:`` or ``-`` (e.g. ``"11:22:33:44:55:66"``).
    :param src_mac: source MAC address, same format as *dst_mac*.
    :param protocol_type: EtherType name: ``"ipv4"``, ``"ipv6"`` or ``"arp"``
        (case-insensitive).
    :return: destination MAC (6 bytes) + source MAC (6 bytes) + EtherType (2 bytes).
    :raises ValueError: if a MAC address is not valid hex or is not exactly 6 bytes.
    :raises KeyError: if *protocol_type* is not one of the supported names.
    """
    def mac_to_bytes(mac):
        # binascii.Error (raised for bad hex) is a subclass of ValueError.
        raw = binascii.unhexlify(mac.replace(":", "").replace("-", ""))
        if len(raw) != 6:
            # Without this check a short/long MAC yields a malformed frame silently.
            raise ValueError(f"MAC address must be 6 bytes, got {len(raw)}: {mac!r}")
        return raw

    pt_options = {
        "ipv4": b'\x08\x00',
        "ipv6": b'\x86\xdd',
        "arp": b'\x08\x06',
    }
    pt_byte = pt_options[protocol_type.lower()]

    return mac_to_bytes(dst_mac) + mac_to_bytes(src_mac) + pt_byte


def build_ipv4_header(src_ip, dst_ip, protocol, packet_payload, ttl=13):
    """Build a 20-byte IPv4 header (no options) with a valid header checksum.

    :param src_ip: source address in dotted-quad notation, e.g. ``"10.0.0.1"``.
    :param dst_ip: destination address in dotted-quad notation.
    :param protocol: upper-layer protocol name: ``"icmp"``, ``"tcp"`` or ``"udp"``.
    :param packet_payload: bytes that will follow this header; only its length
        is used, to fill in the Total Length field.
    :param ttl: Time To Live value, 0-255. Defaults to 13.
    :return: the 20-byte header as ``bytes``.
    :raises KeyError: if *protocol* is not one of the supported names.
    :raises OSError: if an address cannot be parsed by ``socket.inet_aton``.
    :raises OverflowError: if *ttl* does not fit in one byte or the total
        length does not fit in 16 bits.
    """
    version_and_hl_byte = b'\x45'  # version 4, IHL = 5 32-bit words (20 bytes, no options)
    tos_byte = b'\x00'  # Type of Service, relates to quality of service; 0 = default
    identification_byte = b'\x00\x00'
    flags_byte = b'\x40\x00'  # flags = DF (don't fragment), fragment offset = 0

    ttl_byte = ttl.to_bytes(1, 'big')

    protocol_options = {
        "icmp": 1,
        "tcp": 6,
        "udp": 17,
    }
    protocol_byte = protocol_options[protocol].to_bytes(1, 'big')

    # Total Length = 20-byte header + payload, encoded as a 16-bit big-endian field.
    total_length_byte = (20 + len(packet_payload)).to_bytes(2, 'big')

    src_ip_byte = socket.inet_aton(src_ip)
    dst_ip_byte = socket.inet_aton(dst_ip)

    def assemble(checksum_bytes):
        # Field order as laid out in the IPv4 header.
        return b''.join((
            version_and_hl_byte, tos_byte, total_length_byte, identification_byte, flags_byte,
            ttl_byte, protocol_byte, checksum_bytes, src_ip_byte, dst_ip_byte,
        ))

    # The checksum is computed over the header with its checksum field zeroed.
    checksum = compute_header_checksum(assemble(b'\x00\x00'), 160)
    # Always emit exactly 2 bytes: a hex()/unhexlify round-trip drops leading
    # zeros (1-byte result) and fails outright on odd-length hex strings.
    header_checksum_byte = int(checksum, 2).to_bytes(2, 'big')

    return assemble(header_checksum_byte)


def compute_header_checksum(created_header, bit_size):
    """Return the Internet checksum (RFC 1071) of *created_header* as a bit string.

    The data is split into 16-bit big-endian words. These are summed with
    end-around carry (one's-complement addition), and the one's complement of
    that sum is returned.

    If the data is not a whole number of 16-bit words (e.g. an odd-length
    ICMP payload), it is padded with zero bits on the right, as RFC 1071
    requires. Previously the trailing byte was silently dropped.

    :param created_header: bytes to checksum; the checksum field inside it
        should be zeroed by the caller.
    :param bit_size: declared length of the data in bits. If it exceeds
        ``8 * len(created_header)``, the data is left-padded with zero bits up
        to this width (as the previous ``zfill`` did). Otherwise the real
        length of *created_header* is used.
    :return: 16-character string of ``'0'``/``'1'`` characters.
    """
    width = max(bit_size, 8 * len(created_header))
    value = int.from_bytes(created_header, 'big')

    # Right-pad to a multiple of 16 bits so the final partial word is kept.
    pad = (-width) % 16
    value <<= pad
    width += pad

    total = 0
    for shift in range(width - 16, -1, -16):
        total += (value >> shift) & 0xFFFF

    # Fold carries back into the low 16 bits (end-around carry).
    while total >> 16:
        total = (total & 0xFFFF) + (total >> 16)

    # One's complement of the folded sum.
    return format(~total & 0xFFFF, '016b')


def build_icmp_echo_request_header(msg, identifier=0, sequence_number=0):
    """Build an ICMP Echo Request message (type 8, code 0) including its payload.

    :param msg: payload; a ``str`` is encoded with ``str.encode()`` (UTF-8),
        while ``bytes``/``bytearray`` are used as-is.
    :param identifier: 16-bit Identifier field. Defaults to 0.
    :param sequence_number: 16-bit Sequence Number field. Defaults to 0.
    :return: the 8-byte ICMP header followed by the payload, with the checksum
        filled in over the whole message.
    :raises OverflowError: if *identifier* or *sequence_number* does not fit
        in 16 bits.
    """
    type_byte = b'\x08'  # 8 = Echo Request
    code_byte = b'\x00'
    identifier_byte = identifier.to_bytes(2, 'big')
    sequence_number_byte = sequence_number.to_bytes(2, 'big')
    payload = bytes(msg) if isinstance(msg, (bytes, bytearray)) else msg.encode()

    def assemble(checksum_bytes):
        return b''.join((
            type_byte, code_byte, checksum_bytes, identifier_byte, sequence_number_byte, payload,
        ))

    # The ICMP checksum covers header + payload, with the checksum field zeroed.
    zeroed = assemble(b'\x00\x00')
    checksum = compute_header_checksum(zeroed, len(zeroed) * 8)
    # Always exactly 2 bytes; hex()/unhexlify lost leading zeros and failed
    # on odd-length hex strings.
    checksum_byte = int(checksum, 2).to_bytes(2, 'big')

    return assemble(checksum_byte)


def compute_checksum(sent_message, k):
    """One's-complement add the first two k-bit words of a bit string.

    The first ``k`` characters and the next ``k`` characters of
    *sent_message* (a string of ``'0'``/``'1'``) are read as unsigned binary
    numbers and added. Any carry out of the top bit is folded back into the
    low bits (end-around carry). Characters past ``2 * k`` are ignored.

    :return: the sum as a binary string, left-padded with zeros to ``k`` characters.
    """
    first_word = int(sent_message[:k], 2)
    second_word = int(sent_message[k:2 * k], 2)
    total = first_word + second_word

    overflow = total >> k
    if overflow:
        # Wrap the overflow bits back into the low k bits.
        total = (total & ((1 << k) - 1)) + overflow

    return format(total, '0{}b'.format(k))


if __name__ == '__main__':
    # AF_PACKET raw sockets are Linux-only and typically require root or
    # CAP_NET_RAW. The context manager closes the socket even if building or
    # sending the frame raises.
    with socket.socket(socket.AF_PACKET, socket.SOCK_RAW) as s:
        s.bind(("enp0s3", 0))

        ethernet_header = build_ethernet_header("11:22:33:11:22:33", "aa:bb:cc:aa:bb:cc", "ipv4")

        icmp_header = build_icmp_echo_request_header("TEST")
        # The ICMP message (header + payload) is the IPv4 payload.
        ipv4_header = build_ipv4_header("10.20.30.40", "22.22.22.22", "icmp", icmp_header)

        packet = ethernet_header + ipv4_header + icmp_header

        s.send(packet)