#!/usr/bin/env python

import numpy as np
import threading,importlib,queue,sys,time

class Node:
    available_node_id=0
    def __init__(self,src,interfaces):
        """
        Set up the state shared between the node thread and the simulator.

        src: name of the module holding the node program (imported by run())
        interfaces: iterable of interface names; each one gets its own
                    reception queue and its own queue size counter
        """
        # Take the next free identifier and advance the class-wide counter
        self.node_id=Node.available_node_id
        Node.available_node_id=self.node_id+1
        self.src=src # Node source code (module name)
        self.args=None # Node arguments, filled in by run()
        self.rargs=None # Arguments of the request currently submitted
        self.plugins=[] # Registered node plugins
        self.rqueue=queue.Queue() # Acknowledgments sent by the simulator
        names=list(interfaces)
        self.chest={
            "state":"running",
            "turned_on":True,
            "request":None,
            "interfaces":{name:queue.Queue() for name in names},
            "interfaces_queue_size":{name:0 for name in names},
        }
        self.chest_lock=threading.Lock() # Guards every access to self.chest

    def plugin_register(self,plugin):
        """
        Add plugin to the plugins that plugin_notify() calls back
        """
        self.plugins+=[plugin]

    def plugin_notify(self,reason,args):
        """
        Call, on every registered plugin, the callback matching reason.
        The items of args are passed one by one (no argument unpacking) since
        this function strives to avoid using Python specific features.

        reason: "receive_return", "receivet_return", "send_call", "send_return"
                or "terminated"; any other value is silently ignored
        args: indexable holding the callback arguments (unused for "terminated")
        """
        for plugin in self.plugins:
            if reason == "receive_return" or reason == "receivet_return":
                # Both receive flavours end up in the same callback
                plugin.on_receive_return(args[0],args[1],args[2],args[3])
            elif reason == "send_call":
                plugin.on_send_call(args[0],args[1],args[2],args[3])
            elif reason == "send_return":
                plugin.on_send_return(args[0],args[1],args[2],args[3],args[4])
            elif reason == "terminated":
                plugin.on_terminated()

    def __getitem__(self,key):
        """
        Return self.chest[key] while holding the chest lock.

        Raises KeyError if key is not in the chest. The lock is released in
        every case, so a failed lookup cannot dead-lock the other thread.
        """
        with self.chest_lock:
            return self.chest[key]

    def __setitem__(self,key,value):
        """
        Store value under key in self.chest while holding the chest lock.

        The lock is released even if the assignment fails (e.g. unhashable key).
        """
        with self.chest_lock:
            self.chest[key]=value

    def log(self,msg):
        """
        Ask the simulator to print msg and block until it acknowledges
        """
        self.rargs=msg # Message the simulator prints when it serves the request
        self["request"]="log"
        self["state"]="call" # Written last: Simulator.sync_node() serves the request once it sees "call"
        self.wait_ack(["log"])

    def read(self, register):
        """
        Ask the simulator for the value of register and return it.
        Blocks until the simulator answers with a "read" acknowledgment.
        """
        self.rargs=register
        self["request"]="read"
        self["state"]="call" # Published last, once request and arguments are in place
        return self.wait_ack(["read"])[1]

    def wait(self,duration):
        """
        Block the node for duration of simulated time
        """
        self.rargs=duration # Delay of the timeout to register
        self["request"]="timeout_add"
        self["state"]="call" # Written last: the simulator serves the request once it sees "call"
        self.wait_ack(["timeout_add"]) # The timeout is now registered
        self["state"]="pending" # Nothing else to ask until the timeout fires
        self.wait_ack(["timeout"]) # Block until the simulator reports the timeout

    def wait_end(self):
        """
        Block the node until the simulator reports the end of the simulation
        """
        self["request"]="wait_end"
        self["state"]="request" # "request" (not "call"): collected by Simulator.sync_event()
        self.wait_ack(["wait_end"]) # The simulator registered this node
        self.wait_ack(["sim_end"]) # Sent by Simulator.run() once no event is left

    def turn_off(self):
        """
        Mark the node as turned off and let the simulator drop the
        communications that target it. Blocks until acknowledged.
        """
        self["turned_on"]=False # Flag checked by Simulator.communicate() before sending to this node
        self["request"]="turn_off"
        self["state"]="call" # Written last: the simulator serves the request once it sees "call"
        self.wait_ack(["turn_off"])
        
    def turn_on(self):
        """
        Mark the node as turned on again. Blocks until the simulator acknowledges.
        """
        self["turned_on"]=True # Flag checked by Simulator.communicate() before sending to this node
        self["request"]="turn_on"
        self["state"]="call" # Written last: the simulator serves the request once it sees "call"
        self.wait_ack(["turn_on"])
        
    def send(self, interface, data, datasize, dst):
        """
        Send data (datasize bytes) to dst on interface and block until the
        simulator reports the outcome.

        Plugins are notified before the request ("send_call") and after the
        acknowledgment ("send_return"). Returns the code carried by the
        "send" or "send_cancel" acknowledgment.
        """
        message=(interface,data,datasize,dst)
        self.plugin_notify("send_call",message)
        self.rargs=message
        self["request"]="send"
        self["state"]="request" # Published last, once request and arguments are in place
        code=self.wait_ack(["send","send_cancel"])[1]
        self.plugin_notify("send_return",message+(code,))
        return code

    def receive(self,interface):
        """
        Block until data is available on interface, then return (0,data).
        Plugins are notified ("receive_return") with the interface, the data
        and the two timestamps stored along with it.
        """
        self.rargs=interface
        self["request"]="receive"
        self["state"]="request" # Published last, once request and arguments are in place
        self.wait_ack(["receive"])
        inbox=self["interfaces"][interface]
        data,start_at,end_at=inbox.get()
        self.plugin_notify("receive_return",(interface,data,start_at,end_at))
        return (0,data)

    def sendt(self, interface, data, datasize, dst, timeout):
        """
        Same as send() but gives up once timeout of simulated time has elapsed.

        Plugins are notified exactly like in send(): "send_call" before the
        request and "send_return" with the returned code afterwards.
        Returns -1 if the timeout expired (the send is then cancelled),
        otherwise the code carried by the "send" or "send_cancel" acknowledgment.
        """
        self.plugin_notify("send_call",(interface,data,datasize,dst))
        # Register the timeout first
        self.rargs=timeout
        self["request"]="timeout_add"
        self["state"]="call"
        self.wait_ack(["timeout_add"])
        # Then submit the send itself
        self.rargs=(interface, data, datasize, dst)
        self["request"]="send"
        self["state"]="request"
        ack=self.wait_ack(["send","timeout","send_cancel"])
        if ack[0] == "timeout":
            # Too late: ask the simulator to drop the communication
            self["request"]="send_cancel"
            self["state"]="call"
            self.wait_ack(["send_cancel"])
            code=-1
        else:
            # The send ended before the timeout: the timeout is now useless
            self["request"]="timeout_remove"
            self["state"]="call"
            self.wait_ack(["timeout_remove"])
            code=ack[1]
        self.plugin_notify("send_return",(interface,data,datasize,dst,code))
        return code

    def receivet(self,interface, timeout):
        """
        Same as receive() but gives up once timeout of simulated time has elapsed.
        Returns (-1,None) if the timeout expired, (0,data) otherwise.
        """
        # Register the timeout first
        self.rargs=timeout
        self["request"]="timeout_add"
        self["state"]="call"
        self.wait_ack(["timeout_add"])
        # Then submit the receive itself
        self["request"]="receive"
        self.rargs=interface
        self["state"]="request"
        ack=self.wait_ack(["receive","timeout"])
        if ack[0] == "timeout":
            return (-1,None)
        # Data arrived before the timeout: the timeout is now useless
        self["request"]="timeout_remove"
        self["state"]="call"
        self.wait_ack(["timeout_remove"])
        data,start_at,end_at=self["interfaces"][interface].get()
        self.plugin_notify("receivet_return",(interface,data,start_at,end_at))
        return (0,data)

    def wait_ack(self, ack_types):
        """
        Block until an ack whose type (ack[0]) is in ack_types shows up in the
        request queue (rqueue) and return it. The acks skipped meanwhile are
        put back at the end of the queue, in the order they were read.
        """
        skipped=[]
        ack=self.rqueue.get() # Wait for simulator acknowledgments
        while ack[0] not in ack_types:
            skipped.append(ack)
            ack=self.rqueue.get()
        # Give the unrelated acks back to the queue
        for other in skipped:
            self.rqueue.put(other)
        return ack
    
    def sync(self):
        """
        Wait until node stop running
        """
        while self["state"] == "running":
            pass
        
    def run(self,args):
        """
        Load and run the user program
        """
        self.node=importlib.import_module(self.src)
        self.args=args # Allow access to arguments
        self.node.execute(self)
        self["state"]="terminated"

class Simulator:
    """
    Flow-Level Discrete Event Simulator for Cyber-Physical Systems
    The general format for an event is (type,timestamp,event,priority)
    Event types:
        - 0 send                  (0,timestamp,(src,dst,interface,data,datasize,duration,datasize_remaining,start_at), 1)
        - 1 timeout               (1,timestamp,node_id,3)
        - 2 breakpoint_manual     (2,timestamp,0,0)
        - 3 breakpoint_auto       (3,timestamp,0,0)

    Very important: when the simulator wakes up a node (changing its state to running),
    data that should be received by that node at the current simulated time SHOULD be in the queue!
    Thus, send events must be handled before timeout events (priority 1 versus 3). Otherwise plugins such as the
    power states one may not give accurate results because of missing entries in the nodes' receive queues.
    """
    
    def __init__(self,netmat):
        """
        Format of netmat: { "wlan0": (BW,L,IS_WIRED), "eth0": (BW,L,IS_WIRED) }
        Where BW are the bandwidth matrices and L the latency matrices. IS_WIRED is a
        boolean specifying if the interface is wired or wireless.
        """
        self.netmat=netmat
        self.nodes=list()
        self.sharing=dict()
        for interface in netmat.keys():
            if netmat[interface]["is_wired"]:
                self.sharing[interface]=np.zeros(len(netmat[interface]["bandwidth"]))
        self.events=np.empty((0,4),dtype=object)
        self.events_dirty=True # For optimization reasons
        self.startat=-1
        self.time=0
        self.debug_file_path="./esds.debug"
        self.precision=".3f"
        self.interferences=True
        self.wait_end_nodes=list() # Keep track of nodes that wait for the end of the simulation
        self.time_truncated=format(self.time,self.precision) # Truncated version is used in log print

    def update_network(self,netmat):
        for event in self.events:
            if int(event[0]) == 0:
                cur_event=event[2]
                ts=float(event[1])
                src_id,dst_id,interface, data, datasize,duration, datasize_remaining,start_at=cur_event
                new_bw=netmat[interface]["bandwidth"][int(src_id),int(dst_id)]
                old_bw=self.netmat[interface]["bandwidth"][int(src_id),int(dst_id)]
                new_lat=netmat[interface]["latency"][int(src_id),int(dst_id)]
                old_lat=self.netmat[interface]["latency"][int(src_id),int(dst_id)]
                if new_bw != old_bw or new_lat != old_lat:
                    new_datasize_remaining=float(datasize_remaining)*((ts-self.time)/float(duration))
                    if new_datasize_remaining > 0:
                        latency_factor=new_datasize_remaining/float(datasize)
                        if interface == "wlan0":
                            new_duration=new_datasize_remaining*8/new_bw+new_lat*latency_factor
                        else:
                            new_duration=new_datasize_remaining*8/(new_bw/self.sharing[interface][int(dst_id)])+new_lat*latency_factor
                        event[1]=self.time+new_duration
                        event[2][6]=new_datasize_remaining
                        event[2][5]=new_duration
        self.netmat=netmat
            
    def debug(self):
        """
        Log all the informations for debugging
        """
        stdout_save = sys.stdout
        with open(self.debug_file_path, "a") as debug_file:
            sys.stdout = debug_file 
            print("-----------------------------------------------")
            print("Started since {}s".format(round(time.time()-self.startat,2)))
            print("Simulated time {}s (or more precisely {}s)".format(self.time_truncated,self.time))
            states=dict()
            timeout_mode=list()
            sharing=dict()
            for node in self.nodes:
                s=node["state"]
                states[s]=states[s]+1 if s in states else 1
                if self.sharing["eth0"][node.node_id] > 0:
                    sharing["n"+str(node.node_id)]=str(int(self.sharing["eth0"][node.node_id]))
            print("Node number per state: ",end="")
            for key in states:
                print(key+"="+str(states[key]), end=" ")
            print("\nNode sharing: ",end="")
            for node_id in sharing:
                print(node_id+"="+sharing[node_id], end=" ")
            print("\nIds of node in timeout mode: ", end="")
            for n in timeout_mode:
                print(n,end=" ")
            print("\nSorted events list:")
            print(self.events)
            sys.stdout = stdout_save
          
    def create_node(self, src, args=None):
        """
        Create a node, register it in self.nodes and start the (non-daemon)
        thread that runs it.

        src: name of the module holding the node program
        args: value handed to Node.run(), made available to the program
        """
        new_node=Node(src, self.netmat.keys())
        self.nodes.append(new_node)
        threading.Thread(target=new_node.run, args=[args], daemon=False).start()

    def log(self,msg,node=None):
        src = "esds" if node is None else "n"+str(node)
        print("[t="+str(self.time_truncated)+",src="+src+"] "+msg)

    def sort_events(self):
        """
        Sort the events by timestamp and priorities
        """
        sorted_indexes=np.lexsort((self.events[:,3],self.events[:,1]))
        self.events=self.events[sorted_indexes]
        
    def sync_node(self,node):
        """
        Process all call request and wait for Node.sync() to return
        """
        node.sync()
        while node["state"] == "call":
            if node["request"] == "log":
                self.log(node.rargs,node=node.node_id)
                node["state"]="running"
                node.rqueue.put(("log",0))
            elif node["request"] == "timeout_add":
                self.add_event(1,self.time+node.rargs,node.node_id,priority=3)
                node["state"]="running"
                node.rqueue.put(("timeout_add",0))
            elif node["request"] == "timeout_remove":
                selector=list()
                for event in self.events:
                    if event[0] == 1 and event[2]==node.node_id:
                        selector.append(True)
                    else:
                        selector.append(False)
                self.events=self.events[~np.array(selector)]
                node["state"]="running"
                node.rqueue.put(("timeout_remove",0))
            elif node["request"] == "read":
                node["state"]="running"
                if node.rargs == "clock":
                    node.rqueue.put(("read",self.time))
                elif node.rargs == "wlan0_ncom":
                    count=0
                    # Count number of communication on wlan0
                    for event in self.events:
                        if event[0] == 0 and event[2][1] == node.node_id and event[2][2] == "wlan0":
                            count+=1
                    node.rqueue.put(("read",count))
                elif node.rargs == "eth0_ncom":
                    count=0
                    # Count number of communication on eth0
                    for event in self.events:
                        if event[0] == 0 and event[2][1] == node.node_id and event[2][2] == "eth0":
                            count+=1
                    node.rqueue.put(("read",count))
                else:
                    node.rqueue.put(("read",0)) # Always return 0 if register is unknown
            elif node["request"] == "turn_on":
                node["state"]="running"
                node.rqueue.put(("turn_on",0))
                self.log("Turned on",node=node.node_id)
            elif node["request"] == "turn_off":
                selector_wlan0=list()
                selector_other=list()
                for event in self.events:
                    if event[0]==0 and int(event[2][1])==node.node_id:
                        if event[2][2] == "wlan0":
                            selector_wlan0.append(True)
                            selector_other.append(False)
                        else:
                            selector_wlan0.append(False)
                            selector_other.append(True)
                    else:
                        selector_wlan0.append(False)
                        selector_other.append(False)
                # Informed sender to cancel send
                for event in self.events[selector_other]:
                    sender=self.nodes[int(event[2][0])]
                    sender["state"]="running"
                    sender.rqueue.put(("send_cancel",2))
                # Remove communications
                if(len(self.events) != 0):
                    self.events=self.events[~(np.array(selector_wlan0)|np.array(selector_other))]
                for interface in self.sharing.keys():
                    self.sharing[interface][node.node_id]=0 # Sharing goes back to zero
                node["state"]="running"
                node.rqueue.put(("turn_off",0))
                self.log("Turned off",node=node.node_id)
            elif node["request"] == "send_cancel":
                selector=list()
                for event in self.events:
                    if event[0]==0 and int(event[2][0]) == node.node_id:
                        selector.append(True)
                        if event[2][2] != "wlan0":
                            self.update_sharing(int(event[2][1]),-1,event[2][2])
                    else:
                        selector.append(False)
                self.events=self.events[~np.array(selector)]
                node["state"]="running"
                node.rqueue.put(("send_cancel",0))
            node.sync()

    def update_sharing(self, dst, amount,interface):
        """
        Manage bandwidth sharing on wired interfaces
        """
        sharing=self.sharing[interface][dst]
        new_sharing=sharing+amount
        for event in self.events:
            if event[0] == 0 and event[2][2] != "wlan0" and int(event[2][1]) == dst:
                remaining=event[1]-self.time
                if remaining > 0:
                    remaining=remaining/sharing if sharing>1 else remaining # First restore sharing
                    remaining=remaining*new_sharing if new_sharing > 1 else remaining # Then apply new sharing
                    event[2][5]=remaining # Update duration
                    event[1]=self.time+remaining # Update timestamp
        self.sharing[interface][dst]=new_sharing
        self.sort_events()

    def handle_interferences(self,sender,receiver):
        """
        Interferences are detected by looking for conflicts between
        new events and existing events.
        """
        status=False
        selector=list()
        notify=set()
        for event in self.events:
            event_type=event[0]
            com=event[2]
            if event_type==0 and com[2] == "wlan0":
                com_sender=int(com[0])
                com_receiver=int(com[1])
                select=False
                if receiver==com_sender:
                    status=True
                    notify.add(receiver)
                elif receiver==com_receiver:
                    status=True
                    select=True
                    notify.add(receiver)
                if sender==com_receiver and com_sender != com_receiver:
                    select=True
                    notify.add(sender)
                selector.append(select)
            else:
                selector.append(False)
        if len(selector) != 0:
            self.events=self.events[~np.array(selector)]
            for node in notify:
                self.log("Interferences on wlan0",node=node)
        return status
    
    def sync_event(self, node):
        """
        Collect events from the nodes
        """
        if node["state"] == "request":
            if node["request"] == "send":
                node["state"]="pending"
                interface, data, datasize, dst=node.rargs
                self.communicate(interface, node.node_id, dst, data, datasize)
            elif node["request"] == "receive":
                interface=node.rargs
                if node["interfaces_queue_size"][interface] > 0:
                    node["interfaces_queue_size"][interface]-=1
                    node.rqueue.put(("receive",0))
                    node["state"]="running"
                    # Do not forget to collect the next event. This is the only request which is processed here
                    self.sync_node(node)
                    self.sync_event(node)
            elif node["request"] == "wait_end":
                node["state"]="pending"
                node.rqueue.put(("wait_end",0))
                self.wait_end_nodes.append(node.node_id)

    def communicate(self, interface, src, dst, data, datasize):
        """
        Create communication event between src and dst
        """
        nsrc=self.nodes[src]
        if interface=="wlan0":
            self.log("Send "+str(datasize)+" bytes on "+interface,node=src)
            for dst in self.list_wireless_receivers(nsrc):
                if self.nodes[dst]["turned_on"]:
                    duration=datasize*8/self.netmat["wlan0"]["bandwidth"][src,dst]+self.netmat["wlan0"]["latency"][src,dst]
                    if src == dst:
                        self.add_event(0,duration+self.time,(src,dst,interface,data,datasize,duration,datasize,self.time))
                    elif not self.interferences:
                        self.add_event(0,duration+self.time,(src,dst,interface,data,datasize,duration,datasize,self.time))
                    elif not self.handle_interferences(src,dst):
                        self.add_event(0,duration+self.time,(src,dst,interface,data,datasize,duration,datasize,self.time))
        else:
            if self.nodes[dst]["turned_on"]:
                self.log("Send "+str(datasize)+" bytes to n"+str(dst)+" on "+interface,node=src)
                self.update_sharing(dst,1,interface) # Update sharing first
                # Note that in the following we send more data than expected to handle bandwidth sharing (datasize*8*sharing):
                duration=datasize*8/(self.netmat["eth0"]["bandwidth"][src,dst]/self.sharing["eth0"][dst])+self.netmat["eth0"]["latency"][src,dst]
                self.add_event(0,duration+self.time,(src,dst,interface,data,datasize,duration,datasize,self.time))
            else:
                nsrc["state"]="request" # Try later when node is on
        
        
    def list_wireless_receivers(self,node):
        """
        Deduce reachable receivers from the bandwidth matrix
        """
        selector = self.netmat["wlan0"]["bandwidth"][node.node_id,] > 0
        return np.arange(0,selector.shape[0])[selector]

            
    def add_event(self,event_type,event_ts,event,priority=1):
        """
        Append the event (event_type,event_ts,event,priority) to the event
        list and sort the list again. Every call triggers a full sort, so call
        this function the least amount of time possible.
        """
        payload=np.array(event,dtype=object)
        row=np.array([event_type,event_ts,payload,priority],dtype=object)
        self.events=np.concatenate([self.events,[row]]) # Add new events
        self.sort_events()
            
    def run(self, breakpoints=[],breakpoint_callback=lambda s:None,breakpoints_every=None,debug=False,interferences=True):
        """
        Run the simulation with the created nodes

        breakpoints: timestamps at which breakpoint_callback is called
        breakpoint_callback: function called with the simulator on every breakpoint
        breakpoints_every: if not None, period of the automatic breakpoints
        debug: if True, (re)create the debug file and append a report to it on every loop iteration
        interferences: if False, wlan0 communications are created without calling handle_interferences()

        The loop ends when no event is left, or when the only event left is an
        automatic breakpoint.
        """
        ##### Setup simulation
        self.startat=time.time()
        self.interferences=interferences
        for bp in breakpoints:
            self.add_event(2,bp,0,0)
        if breakpoints_every != None:
            self.add_event(3,breakpoints_every,0,0)
        if debug:
            # "w" mode: start from a fresh debug file (debug() appends to it)
            with open(self.debug_file_path, "w") as f:
                f.write("Python version {}\n".format(sys.version))
                f.write("Simulation started at {}\n".format(self.startat))
                f.write("Number of nodes is "+str(len(self.nodes))+"\n")
                f.write("Manual breakpoints list: "+str(breakpoints)+"\n")
                f.write("Breakpoints every "+str(breakpoints_every)+"s\n")
        ##### Simulation loop
        while True:
            # Synchronize every node (serve their "call" requests until they stop running)
            for node in self.nodes:
                self.sync_node(node)
            # Manage events (collect the pending "request" of every node)
            for node in self.nodes:
                self.sync_event(node)
            # Generate debug logs
            if debug:
                self.debug()
            # Simulation end
            if len(self.events) <= 0 or len(self.events) == 1 and self.events[0,0] == 3:
                # Notify nodes that wait for the end of the simulation
                # Note that we do not allow them to create new events (even if they try, they will not be processed)
                for node_id in self.wait_end_nodes:
                    self.nodes[node_id].rqueue.put(("sim_end",0))
                    self.nodes[node_id]["state"]="running"
                    self.sync_node(self.nodes[node_id]) # Allow them to make call requests (printing logs for example)
                break # End the event processing loop

            # Update simulation time (events are sorted: the first one is the next to happen)
            self.time=self.events[0,1]
            self.time_truncated=format(self.time,self.precision) # refresh truncated time

            # Process every event scheduled at the current simulated time
            while len(self.events) > 0 and self.events[0,1] == self.time:
                event_type=int(self.events[0,0])
                ts=self.events[0,1]
                event=self.events[0,2]
                self.events=np.delete(self.events,0,0) # Consume events NOW! not at the end of the loop (event list may change in between)
                if event_type == 0:
                    src_id,dst_id,interface, data, datasize,duration,datasize_remaining,start_at=event
                    src=self.nodes[int(src_id)]
                    dst=self.nodes[int(dst_id)]
                    if interface == "wlan0":
                        if src.node_id != dst.node_id:
                            dst["interfaces"][interface].put((data,start_at,self.time))
                            dst["interfaces_queue_size"][interface]+=1
                            self.log("Receive "+str(datasize)+" bytes on "+interface,node=int(dst_id))
                            # If the node is receiving, make it consume (this way if there is a timeout, it will be removed!)
                            if dst["state"] == "request" and dst["request"] == "receive":
                                dst["interfaces_queue_size"][interface]-=1
                                dst.rqueue.put(("receive",0))
                                dst["state"]="running"
                                self.sync_node(dst)
                        else:
                            # The event a sender addresses to itself marks the end of its wlan0 send
                            src["state"]="running"
                            src.rqueue.put(("send",0))
                    else:
                        dst["interfaces"][interface].put((data,start_at,self.time))
                        dst["interfaces_queue_size"][interface]+=1
                        self.update_sharing(dst.node_id,-1,interface) # This flow no longer shares the bandwidth towards dst
                        self.log("Receive "+str(datasize)+" bytes on "+interface,node=int(dst_id))
                        # If the node is receiving, make it consume (this way if there is a timeout, it will be removed!)
                        if dst["state"] == "request" and dst["request"] == "receive":
                            dst["interfaces_queue_size"][interface]-=1
                            dst.rqueue.put(("receive",0))
                            dst["state"]="running"
                            self.sync_node(dst)
                        src["state"]="running"
                        src.rqueue.put(("send",0))
                elif event_type == 1:
                    # Timeout: wake up the node and serve its next call requests
                    node=self.nodes[int(event)]
                    node["state"]="running"
                    node.rqueue.put(("timeout",0))
                    self.sync_node(node)
                elif event_type == 2 or event_type == 3:
                    breakpoint_callback(self)
                    if event_type == 3:
                        # Automatic breakpoints re-arm themselves
                        self.add_event(3,self.time+breakpoints_every,0,0)

        ##### Simulation ends
        self.log("Simulation ends")