from tkinter import ttk
import tkinter as tk

import numpy as np
from sklearn import model_selection
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sqlalchemy import create_engine
import pandas as pd
import psycopg2
import tkinter as tk

from scapy.sendrecv import sniff

from scapy.layers.l2 import Dot3, LLC, SNAP, Ether, Dot1Q, ARP, STP
from scapy.contrib.dtp import *
from scapy.packet import ls, Raw
from sklearn.model_selection import train_test_split

def t1_create_table():
    """Create the ICMP capture table named in the "Create Table" tab.

    The table name is read from ``table_name_value``. The new table has a
    serial primary key plus text columns for source/destination MAC,
    source/destination IP, ICMP type and payload.

    Errors are printed to stdout instead of being raised, so a failed attempt
    does not propagate into the Tk event loop. The connection is held in a
    local that starts as ``None`` so the cleanup in ``finally`` is safe even
    when ``psycopg2.connect`` itself fails.
    """
    from psycopg2 import sql  # local import: safe quoting of SQL identifiers

    # PostgreSQL folds unquoted identifiers to lower case. Fold before quoting
    # so the quoted name addresses the same table an unquoted query would.
    name = table_name_value.get().strip().lower()
    if not name:
        print("Table name must not be empty")
        return

    # The identifier is composed with sql.Identifier rather than string
    # concatenation because it comes straight from a user-editable entry.
    # icmp_type is VARCHAR(3): the ICMP type is an 8-bit field (0-255).
    create_table_query = sql.SQL(
        """CREATE TABLE {} (
                msg_id SERIAL PRIMARY KEY,
                src_mac VARCHAR(17) NOT NULL,
                dst_mac VARCHAR(17) NOT NULL,
                src_ip VARCHAR(15) NOT NULL,
                dst_ip VARCHAR(15) NOT NULL,
                icmp_type VARCHAR(3) NOT NULL,
                payload VARCHAR(32) NOT NULL
            )"""
    ).format(sql.Identifier(name))

    connection = None
    try:
        connection = psycopg2.connect(user="postgres",
                                      password="DBpass123!",
                                      host="127.0.0.1",
                                      port="5432",
                                      database="test_DB")
        with connection.cursor() as cursor:
            cursor.execute(create_table_query)
        connection.commit()
    except Exception as error:  # UI boundary: report, do not raise
        print("Error while connecting to PostgreSQL", error)
    finally:
        if connection is not None:
            connection.close()
            print("PostgreSQL connection is closed")

from scapy.all import IP, ICMP

def t2_ICMP_Packet_Capture():
    """Sniff ICMP packets and store each one in the selected table.

    Reads the packet count from ``number_of_icmp_msg_value`` and the target
    table from ``T2_table_name_value``, captures that many packets using the
    ``"icmp"`` capture filter, prints the extracted fields and passes them to
    ``t3_insert_to_table``.

    The count must be a positive whole number; anything else is reported on
    stdout and the capture is not started (scapy treats ``count=0`` as
    "capture without limit", which would block this callback indefinitely).

    Packets without an Ethernet, IP or ICMP layer are skipped. A packet with
    no ``Raw`` layer is stored with an empty payload, and payload bytes that
    are not valid UTF-8 are replaced rather than raising
    ``UnicodeDecodeError`` and aborting the remaining packets.
    """
    table = T2_table_name_value.get()

    try:
        packet_count = int(number_of_icmp_msg_value.get())
    except ValueError:
        print("Packet count must be a whole number")
        return
    if packet_count <= 0:
        print("Packet count must be greater than zero")
        return

    packets = sniff(filter="icmp", count=packet_count)

    for packet in packets:
        ether_layer = packet.getlayer(Ether)
        ip_layer = packet.getlayer(IP)
        icmp_layer = packet.getlayer(ICMP)
        # getlayer() returns None for a missing layer; such a packet cannot
        # fill the table's NOT NULL columns, so it is skipped.
        if ether_layer is None or ip_layer is None or icmp_layer is None:
            print("Skipping packet without Ethernet/IP/ICMP layers")
            continue

        raw_layer = packet.getlayer(Raw)
        if raw_layer is None:
            payload = ""
        else:
            payload = raw_layer.load.decode(errors="replace")

        src_mac = ether_layer.src
        dst_mac = ether_layer.dst
        src_ip = ip_layer.src
        dst_ip = ip_layer.dst
        icmp_type = icmp_layer.type
        print(src_mac, dst_mac, src_ip, dst_ip, icmp_type, payload)
        t3_insert_to_table(table, src_mac, dst_mac, src_ip, dst_ip, icmp_type, payload)


def t3_insert_to_table(table_name, src_mac, dst_mac, src_ip, dst_ip, icmp_type, payload):
    """Insert one captured ICMP message into ``table_name``.

    Args:
        table_name: Name of the target table. It is quoted as an SQL
            identifier instead of being concatenated into the statement.
        src_mac: Source MAC address.
        dst_mac: Destination MAC address.
        src_ip: Source IP address.
        dst_ip: Destination IP address.
        icmp_type: ICMP type of the message.
        payload: Decoded ICMP payload.

    Errors are printed to stdout instead of being raised, so one failed
    insert does not abort the caller's capture loop. The connection is held
    in a local that starts as ``None`` so the cleanup in ``finally`` is safe
    even when ``psycopg2.connect`` itself fails.
    """
    from psycopg2 import sql  # local import: safe quoting of SQL identifiers

    # PostgreSQL folds unquoted identifiers to lower case. Fold before quoting
    # so the quoted name addresses the same table an unquoted query would.
    insert_query = sql.SQL(
        """INSERT INTO {}
            (src_mac, dst_mac, src_ip, dst_ip, icmp_type, payload)
            VALUES (%s, %s, %s, %s, %s, %s)
            """
    ).format(sql.Identifier(table_name.strip().lower()))
    value_tuple = (src_mac, dst_mac, src_ip, dst_ip, icmp_type, payload)

    connection = None
    try:
        connection = psycopg2.connect(user="postgres",
                                      password="DBpass123!",
                                      host="127.0.0.1",
                                      port="5432",
                                      database="test_DB")
        with connection.cursor() as cursor:
            cursor.execute(insert_query, value_tuple)
        connection.commit()
        print("1 Record inserted successfully")
    except Exception as error:  # best effort: report, do not raise
        print("Error while connecting to PostgreSQL", error)
    finally:
        if connection is not None:
            connection.close()
            print("PostgreSQL connection is closed")


def t4_ICMP_Packet_Analyze():
    """List stored ICMP messages whose payload contains the secret string.

    Reads the table name from ``T3_table_name_value`` and the search string
    from ``T3_secret_string_value``, fetches every row of the table and shows
    each row whose column 6 (the payload in the capture table layout)
    contains the search string as a label on ``tab3``. The IP pair of each
    match is also printed to stdout.

    Result widgets left over from a previous run are removed before the new
    results are drawn, so a shorter result list does not leave stale rows
    behind. Errors are printed to stdout instead of being raised.
    """
    from psycopg2 import sql  # local import: safe quoting of SQL identifiers

    # PostgreSQL folds unquoted identifiers to lower case. Fold before quoting
    # so the quoted name addresses the same table an unquoted query would.
    table = T3_table_name_value.get().strip().lower()
    secret = T3_secret_string_value.get()  # read once, not once per row

    connection = None
    try:
        connection = psycopg2.connect(user="postgres",
                                      password="DBpass123!",
                                      host="127.0.0.1",
                                      port="5432",
                                      database="test_DB")
        query = sql.SQL("SELECT * from {}").format(sql.Identifier(table))
        with connection.cursor() as cursor:
            cursor.execute(query)
            records = cursor.fetchall()

        # This function only ever grids results at rows 4 and up; anything
        # found there is output from an earlier run and is destroyed.
        for widget in tab3.grid_slaves():
            if int(widget.grid_info()["row"]) >= 4:
                widget.destroy()

        matches = [record for record in records if secret in record[6]]
        if matches:
            ttk.Label(tab3, text="Dangerous Traffic:").grid(column=0, row=4)
        for index, record in enumerate(matches, start=5):
            ttk.Label(
                tab3,
                text=f"{record[1]}--{record[2]}=={record[3]}--{record[4]}:{record[6]}",
            ).grid(column=0, row=index)
            print(record[3], record[4])
    except Exception as error:  # UI boundary: report, do not raise
        print("Error while connecting to PostgreSQL", error)
    finally:
        if connection is not None:
            connection.close()
            print("PostgreSQL connection is closed")


def t5_ML_DataSet():
    """Build a labelled data set from a capture table and store it.

    Reads the source table named in ``T4_table_name_value`` into a DataFrame,
    adds a ``label`` column that is ``"Danger"`` when column 6 (the payload in
    the capture table layout) contains the string from
    ``T4_danger_string_value`` and ``"OK"`` otherwise, and writes the result
    to the table named in ``T4_final_table_with_lable_value``, replacing it if
    it already exists. Both frames are printed to stdout.

    The labels are derived from the same DataFrame that is written back, so
    the label list always has exactly one entry per row. The database
    connection and the SQLAlchemy engine are released even when a step fails.
    Exceptions are not caught here and propagate to the caller.
    """
    from psycopg2 import sql  # local import: safe quoting of SQL identifiers

    # PostgreSQL folds unquoted identifiers to lower case. Fold before quoting
    # so the quoted name addresses the same table an unquoted query would.
    source_table = T4_table_name_value.get().strip().lower()
    danger = T4_danger_string_value.get()
    target_table = T4_final_table_with_lable_value.get()

    connection = psycopg2.connect(user="postgres",
                                  password="DBpass123!",
                                  host="127.0.0.1",
                                  port="5432",
                                  database="test_DB")
    try:
        query = sql.SQL("select * from {}").format(sql.Identifier(source_table))
        # pandas expects a plain string when given a DBAPI connection.
        df = pd.read_sql(query.as_string(connection), con=connection)
    finally:
        connection.close()
    print(df)

    mark = ["Danger" if danger in payload else "OK" for payload in df.iloc[:, 6]]
    df2 = df.assign(label=mark)

    engine = create_engine("postgresql://postgres:DBpass123!@localhost:5432/test_DB")
    try:
        df2.to_sql(target_table, engine, if_exists='replace')
    finally:
        engine.dispose()
    print(df2)


def t6_ICMP_Packet_Clasification():
    """Classify the ICMP message entered on ``tab5`` with a random forest.

    Loads the labelled table named in ``T5_train_table_value``, label-encodes
    the address and payload columns together with the user-entered sample,
    and trains a ``RandomForestClassifier`` on a 75/25 train/test split. The
    5-fold cross-validation score, the hold-out accuracy and the predicted
    label for the entered sample are shown on ``tab5`` and partly printed to
    stdout.

    The prediction is shown as a plain label on a green background for
    ``"OK"`` and on a red background otherwise, with white text in both
    cases. Errors are printed to stdout instead of being raised, and the
    database connection is closed in every case.
    """
    from psycopg2 import sql  # local import: safe quoting of SQL identifiers

    # Cover the value cells of a previous run with blank labels so an old
    # result is not left visible if this run fails part-way.
    for result_row in (10, 11, 13):
        ttk.Label(tab5, text="                      ", background="white").grid(column=1, row=result_row)

    connection = None
    try:
        connection = psycopg2.connect(user="postgres",
                                      password="DBpass123!",
                                      host="127.0.0.1",
                                      port="5432",
                                      database="test_DB")
        # PostgreSQL folds unquoted identifiers to lower case. Fold before
        # quoting so the quoted name addresses the same table as before.
        train_table = T5_train_table_value.get().strip().lower()
        query = sql.SQL("select * from {}").format(sql.Identifier(train_table))
        # pandas expects a plain string when given a DBAPI connection.
        df = pd.read_sql(query.as_string(connection), con=connection)

        my_test_json = {
            "src_mac": [T5_src_mac_value.get()],
            "dst_mac": [T5_dst_mac_value.get()],
            "src_ip": [T5_src_ip_value.get()],
            "dst_ip": [T5_dst_ip_value.get()],
            "icmp_type": [T5_icmp_type_value.get()],
            "payload": [T5_payload_value.get()]
        }
        my_test_df = pd.DataFrame(my_test_json)

        # Columns 2..7 are taken as the six features. NOTE(review): this
        # presumes the layout t5_ML_DataSet writes (index, msg_id, features,
        # label) - confirm if other tables are used for training.
        x_matrix = df.iloc[:, 2:8]
        # Append the sample so the encoders see its values as well.
        x_matrix = pd.concat([x_matrix, my_test_df], ignore_index=True, axis=0)

        column_header = ["src_mac", "dst_mac", "src_ip", "dst_ip", "payload"]
        for column in column_header:
            x_matrix[column] = LabelEncoder().fit_transform(x_matrix[column])

        # Split the encoded sample (last row) off the training rows again.
        my_test_df = x_matrix.tail(1)
        x_matrix = x_matrix.iloc[:-1, :]

        y_vector = np.array(df.label)

        X_train, X_test, y_train, y_test = train_test_split(x_matrix, y_vector, test_size=0.25, random_state=0)

        rf = RandomForestClassifier(criterion="entropy", max_depth=5, n_estimators=5)
        scores = model_selection.cross_val_score(rf, X_train, y_train, cv=5, n_jobs=-1)
        # Report the mean score over the cross-validation folds.
        print("\nPriemerne score najlepsieho modelu: " + str(np.mean(scores)))

        # Fit on the training split and predict the hold-out split.
        rf.fit(X_train, y_train)
        final_predict = rf.predict(X_test)
        tts_acc = (y_test == final_predict).mean()
        print("Total Correct:\t", (y_test == final_predict).sum())

        # Predict the label of the user-entered sample.
        my_value = rf.predict(my_test_df)
        print("Predikcia je:", my_value)
        # predict() returns a one-element array for the single sample; use
        # the scalar so the comparison and the displayed text are plain.
        prediction = my_value[0]

        ttk.Label(tab5, text="Result\n:############################").grid(column=0, row=9)
        ttk.Label(tab5, text="Average Accuracy of Classificator (K-Fold): ").grid(column=0, row=10)
        ttk.Label(tab5, text=f"{str(np.mean(scores))}").grid(column=1, row=10)
        ttk.Label(tab5, text="Average Accuracy of Classificator (Train_Test_Split): ").grid(column=0, row=11)
        ttk.Label(tab5, text=f"{tts_acc}").grid(column=1, row=11)
        ttk.Label(tab5, text="############################").grid(column=0, row=12)
        ttk.Label(tab5, text="Classification Result: ").grid(column=0, row=13)
        # White text in both cases so the result stays readable.
        result_background = "green" if prediction == "OK" else "red"
        ttk.Label(tab5, text=f"{prediction}", background=result_background,
                  foreground="white").grid(column=1, row=13)

    except Exception as error:  # UI boundary: report, do not raise
        print("Error while connecting to PostgreSQL", error)
    finally:
        if connection is not None:
            connection.close()
            print("PostgreSQL connection is closed")


if __name__ == '__main__':
    # Main window: a notebook with one tab per step of the workflow.
    root = tk.Tk()
    root.title("Network Security Tool")
    root.geometry("700x300")

    notebook = ttk.Notebook(root)
    tab1, tab2, tab3, tab4, tab5 = (ttk.Frame(notebook) for _ in range(5))
    for page, title in (
        (tab1, "Create Table"),
        (tab2, "Capture Data"),
        (tab3, "Analyze Data"),
        (tab4, "ML-part1"),
        (tab5, "ML-part2"),
    ):
        notebook.add(page, text=title)
    notebook.pack(expand=1, fill="both")

    def _add_field(parent, caption, variable, row):
        """Grid a caption label (column 0) and its bound entry (column 1)."""
        ttk.Label(parent, text=caption).grid(column=0, row=row)
        ttk.Entry(parent, textvariable=variable).grid(column=1, row=row)

    # --- Tab 1: create the capture table --------------------------------
    table_name_value = tk.StringVar(root, "")
    ttk.Label(tab1, text="Creating Table for Storing Data").grid(columnspan=2, row=0)
    _add_field(tab1, "Set Table Name:", table_name_value, 1)
    ttk.Button(tab1, text="Create Table", command=t1_create_table).grid(column=0, row=4)

    # --- Tab 2: capture ICMP messages -----------------------------------
    number_of_icmp_msg_value, T2_table_name_value = (
        tk.StringVar(root, "") for _ in range(2)
    )
    ttk.Label(tab2, text="Capturing ICMP messages").grid(columnspan=2, row=0)
    _add_field(tab2, "Set Packet Count:", number_of_icmp_msg_value, 1)
    _add_field(tab2, "Set Table Name:", T2_table_name_value, 2)
    ttk.Button(tab2, text="Start Capture", command=t2_ICMP_Packet_Capture).grid(column=0, row=3)

    # --- Tab 3: search stored payloads ----------------------------------
    T3_secret_string_value, T3_table_name_value = (
        tk.StringVar(root, "") for _ in range(2)
    )
    ttk.Label(tab3, text="Analyze ICMP messages").grid(columnspan=2, row=0)
    _add_field(tab3, "Set Secret Word to Find:", T3_secret_string_value, 1)
    _add_field(tab3, "Set Table Name:", T3_table_name_value, 2)
    ttk.Button(tab3, text="Analyze", command=t4_ICMP_Packet_Analyze).grid(column=0, row=3)

    # --- Tab 4: build the labelled data set -----------------------------
    ttk.Label(tab4, text="DataSet Creation").grid(columnspan=2, row=0)
    T4_table_name_value, T4_danger_string_value, T4_final_table_with_lable_value = (
        tk.StringVar(root, "") for _ in range(3)
    )
    _add_field(tab4, "Set Table Name:", T4_table_name_value, 1)
    _add_field(tab4, "Set Danger Payload:", T4_danger_string_value, 2)
    _add_field(tab4, "Set Final Table Name (labeled):", T4_final_table_with_lable_value, 3)
    ttk.Button(tab4, text="Create DataSet", command=t5_ML_DataSet).grid(column=0, row=4)

    # --- Tab 5: classify a single message -------------------------------
    (
        T5_src_mac_value,
        T5_dst_mac_value,
        T5_src_ip_value,
        T5_dst_ip_value,
        T5_icmp_type_value,
        T5_payload_value,
        T5_train_table_value,
    ) = (tk.StringVar(root, "") for _ in range(7))

    ttk.Label(tab5, text="Classification of ICMP messages using ML").grid(columnspan=2, row=0)
    for field_row, (field_caption, field_variable) in enumerate(
        (
            ("Set Source MAC address:", T5_src_mac_value),
            ("Set Destination MAC address:", T5_dst_mac_value),
            ("Set Source IP address:", T5_src_ip_value),
            ("Set Destionation IP address:", T5_dst_ip_value),
            ("Set ICMP Type:", T5_icmp_type_value),
            ("Set Payload:", T5_payload_value),
            ("Set Train Table Name:", T5_train_table_value),
        ),
        start=1,
    ):
        _add_field(tab5, field_caption, field_variable, field_row)
    ttk.Button(tab5, text="Start Clasification", command=t6_ICMP_Packet_Clasification).grid(column=0, row=8)

    # Hand control to Tk; returns when the window is closed.
    root.mainloop()
