#
# Copyright (c) 2021 Contributors to the Eclipse Foundation
#
# This program and the accompanying materials are made
# available under the terms of the Eclipse Public License 2.0
# which is available at https://www.eclipse.org/legal/epl-2.0/
#
# SPDX-License-Identifier: EPL-2.0
#

import threading, random, asyncio, queue, ctypes, datetime, subprocess, os, io, json, traceback, sys, re, time
from typing import Dict, Callable, List, Optional, Union, Type
from snakes.nets import PetriNet


class WalkRecorder():
    """Record the events of taken steps to a file, or replay a recorded file.

    While recording, ``record`` writes ``str(step.event)`` on a line of its
    own.  While replaying, ``rerun`` reads one line per call and narrows the
    candidate steps down to those whose event text matches that line.
    """

    def __init__(self) -> None:
        self.logger = None  # callable taking one message; registered through log()
        self.saveInFile = ""  # path that is recorded to; "" while not recording
        self.rerunFromFile = ""  # path that is replayed; "" while not replaying
        self.file = None  # open handle used for recording or replaying
        self.extend = False  # after the replay: keep going and append to the replayed file
        self.continu = False  # after the replay: keep going without recording

    def log(self, logger):
        """Register the callable that receives the recorder's messages."""
        self.logger = logger

    def _log(self, message: str) -> None:
        """Forward ``message`` to the registered logger, if there is one."""
        if self.logger is not None:
            self.logger(message)

    def _close_file(self) -> None:
        """Close the current file handle (if any) and forget it."""
        if self.file is not None:
            try:
                if not self.file.closed:
                    self.file.close()  # close() also flushes pending writes
            finally:
                self.file = None

    def saveAs(self, file: str):
        """Start recording to ``file`` (truncating it); ``None`` is ignored."""
        if file is not None:
            # Do not leak a handle left over from an earlier saveAs/playFrom.
            self._close_file()
            self.saveInFile = file
            self.file = open(file, "w")

    def stop(self):
        """Stop recording and replaying and close the file in use."""
        self.saveInFile = ""
        self.rerunFromFile = ""
        self._close_file()

    def record(self, step: 'WalkerStep'):
        """Append the event of ``step`` to the recording, if recording is active."""
        if self.saveInFile != "":
            self.file.write(f"{str(step.event)}\n")

    def playFrom(self, file: str, extend: bool, continu: bool):
        """Start replaying ``file``; ``None`` is ignored.

        ``extend`` and ``continu`` select what ``rerun`` does once every
        recorded line has been consumed.
        """
        if file is not None:
            # Do not leak a handle left over from an earlier saveAs/playFrom.
            self._close_file()
            self.extend = extend
            self.continu = continu
            self.rerunFromFile = file
            self.file = open(file, "r")

    def rerun(self, steps: List['WalkerStep']) -> Optional[List['WalkerStep']]:
        """Restrict ``steps`` to the ones selected by the next recorded line.

        Returns ``steps`` unchanged when no replay is active.  Otherwise one
        line is read from the replayed file and the steps whose
        ``str(step.event)`` equals that line are returned.  A line with more
        than two space-separated parts flags every candidate step with
        ``breakStep = True`` and is compared on its first two parts only.

        Returns ``None`` when the line matches no step, or when the recording
        is exhausted and neither ``extend`` nor ``continu`` was requested.
        When the recording is exhausted and one of them was requested, all of
        ``steps`` is returned and the replay ends.
        """
        if self.rerunFromFile == "":
            return steps

        line = self.file.readline().strip()
        parts = line.split(' ')
        has_extra_part = len(parts) > 2
        expected = parts[0] + ' ' + parts[1].strip() if has_extra_part else line

        matching = []
        for step in steps:
            if has_extra_part:
                step.breakStep = True
            if str(step.event) == expected:
                matching.append(step)
        if matching:
            return matching

        if line != "":
            self._log("Cannot select step in {}. Is the SUT non-deterministic?".format(line))
            return None

        # An empty line is treated as the end of the recording.
        self._log("All recorded lines have been executed successfully")
        if self.extend:
            self._log("Continue testing and append to recorded file")
            # Close the read handle before the same file is reopened for appending.
            self._close_file()
            self.saveInFile = self.rerunFromFile
            self.rerunFromFile = ""
            self.file = open(self.saveInFile, "a")
            return steps
        if self.continu:
            self._log("Continue testing")
            self.rerunFromFile = ""
            self._close_file()  # the replayed file is no longer needed
            return steps
        return None


class TestStrategy():
    """Select the next step the test application takes.

    Two strategies are supported: ``"Random"`` picks any possible step,
    ``"Prioritize non-selected"`` prefers steps whose clause is not yet in
    ``walker.seen_clauses``.  In both cases the candidates are first passed
    through the recorder (replay) and the debugger (breakpoints / manual
    choice) before a random choice is made.
    """

    def __init__(self, strategy: str, recorder: 'WalkRecorder', log: Callable[[str], None], debugger: 'Debugger') -> None:
        self.strategy = strategy
        self.recorder = recorder
        self.log = log
        self.debugger = debugger
        self.taken_transitions = []

    def next_step(self, walker: 'Walker', take_reply_to_cmd: Optional['Event']):
        """Return the next step according to the configured strategy.

        Returns ``None`` when no step is possible or the recorder rejects all
        candidates.  Raises ``AssertionError`` for an unknown strategy name.
        """
        if self.strategy == "Random":
            return self.next_step_orig(walker, take_reply_to_cmd)
        if self.strategy == "Prioritize non-selected":
            return self.next_step_improved(walker, take_reply_to_cmd)
        # Raised explicitly: a plain ``assert False`` is stripped under ``python -O``.
        raise AssertionError(f"Unknown test strategy '{self.strategy}'")

    def _candidate_steps(self, walker, take_reply_to_cmd):
        """Collect the steps that are currently possible.

        With a pending command only the matching reply on that command's
        connection qualifies; otherwise the steps of all connections do.
        """
        if take_reply_to_cmd is not None:
            steps = walker.next_steps((take_reply_to_cmd.port, take_reply_to_cmd.component))
            return [c for c in steps if c.event.port == take_reply_to_cmd.port and c.event.component == take_reply_to_cmd.component
                and c.event.kind == EventType.Reply and c.event.method == take_reply_to_cmd.method]
        steps = []
        for connection in list(walker.nets.keys()):
            steps.extend(walker.next_steps(connection))
        return steps

    def _clause_key(self, walker, step):
        """Return the key under which the clause of ``step`` is looked up in ``walker.seen_clauses``.

        ``step.clause`` is normalised in place: everything from the last
        ``_join`` / ``_split`` onwards is dropped and the part after the last
        underscore is replaced by ``0``.  Raises ``Exception`` (chained to the
        underlying error) when the clause cannot be resolved in the net.
        """
        net = walker.nets[step.event.connectionKey()]
        try:
            if "_join" in step.clause:
                step.clause = step.clause[0: (step.clause.rfind("_join"))]
            if "_split" in step.clause:
                step.clause = step.clause[0: (step.clause.rfind("_split"))]
            if not step.clause.endswith("_0"):
                step.clause = step.clause[0: (step.clause.rfind("_")+1)] + "0"
            return step.event.interface + "." + step.event.port + "." + step.clause + "." + net._place[step.clause].meta['sourceline']
        except Exception as e:
            # Keep the original error as the cause so it is not lost to the caller.
            raise Exception(f"Could not determine the clause of a step: {e!r}") from e

    def next_step_orig(self, walker, take_reply_to_cmd):
        """Pick a random step among all candidates; ``None`` when there is none."""
        steps = self._candidate_steps(walker, take_reply_to_cmd)
        if len(steps) == 0:
            return None
        steps = self.recorder.rerun(steps)
        if steps is None:
            return None
        steps = self.debugger.debug_next_step(steps)
        step = random.choice(steps)
        assert step.event is not None
        return step

    def next_step_improved(self, walker, take_reply_to_cmd):
        """Pick a random step, preferring steps whose clause was not taken before."""
        steps = self._candidate_steps(walker, take_reply_to_cmd)
        if len(steps) == 0:
            return None
        steps = self.recorder.rerun(steps)
        if steps is None:
            return None
        # If possible, choose a step that has not been taken before
        unseen = [step for step in steps if self._clause_key(walker, step) not in walker.seen_clauses]
        candidates = unseen if unseen else steps
        candidates = self.debugger.debug_next_step(candidates)
        return random.choice(candidates)

class Debugger():
    """Pause the walker at breakpoints and let a user pick the step to take.

    ``debug_next_step`` is called with the candidate steps and blocks while
    the debugger is in the waiting state; the other ``debug_*`` methods are
    the controls that change that state.  All shared state is guarded by
    ``lock`` (``enabled`` by ``enable_lock``), always through ``with`` so an
    exception can never leave a lock held.
    """

    def __init__(self):
        self.choice_func = None  # called with the candidate steps when the walker pauses
        self.lock = threading.Lock()  # guards wait and take_step
        self.enable_lock = threading.Lock()  # guards enabled
        self.wait = False  # True while debug_next_step has to block
        self.enabled = False  # breakpoints and break steps are only honoured when True
        self.take_step = -1  # index of the step chosen by debug_step; -1 means no choice

    def debug_register_choice_func(self, func):
        """Register the callable that is shown the candidate steps on a pause."""
        self.choice_func = func

    def debug_set(self, state: bool):
        """Enable or disable breakpoint handling."""
        with self.enable_lock:
            self.enabled = state

    def debug_next_step(self, steps: List['WalkerStep']):
        """Block while paused and return the steps that may be taken next.

        When enabled, a candidate with a breakpoint on its event or with
        ``breakStep`` set puts the debugger in the waiting state.  While
        waiting, the state is polled every 0.1 s until one of the control
        methods clears it.  If a step index was chosen with ``debug_step`` only
        that step is returned (and the debugger pauses again on the next
        call); otherwise ``steps`` is returned unchanged.

        Raises ``IndexError`` when the chosen index does not exist in ``steps``.
        """
        with self.enable_lock:
            enabled = self.enabled
        with self.lock:
            if enabled:
                for step in steps:
                    if step.event.has_breakpoint():
                        self.wait = True
                    if step.breakStep:
                        self.wait = True
            has_to_wait = self.wait
            if has_to_wait and self.choice_func is not None:
                self.choice_func(steps)
        while has_to_wait:
            time.sleep(0.1)
            with self.lock:
                has_to_wait = self.wait
        with self.lock:
            if self.take_step != -1:
                # Single-stepping: take the chosen step and pause again afterwards.
                self.wait = True
                steps = [steps[self.take_step]]
            else:
                self.wait = False
        return steps

    def debug_pause(self):
        """Make the next ``debug_next_step`` call block."""
        with self.lock:
            self.wait = True

    def debug_continue(self):
        """Resume the walker and drop any chosen step index."""
        with self.lock:
            self.wait = False
            self.take_step = -1

    def debug_step(self, index: str):
        """Resume the walker for one step: the candidate at position ``index``.

        Raises ``ValueError`` when ``index`` is not an integer.
        """
        # Convert before taking the lock so that bad input leaves the state untouched.
        chosen = int(index)
        with self.lock:
            self.wait = False
            self.take_step = chosen

    def debug_timeout(self, timeout: int):
        """Return the queue timeout to use; currently ``timeout`` unchanged."""
        # self.lock.acquire()
        # if self.wait or self.take_step != -1:
        #     timeout = None # Infinite timeout during debugging
        # self.lock.release()
        return timeout


class TestApplicationWalker():
    """Drive a ``Walker`` from a background thread.

    Received events are queued with ``received_event`` and applied to the
    walker; when no event arrives within the timeout, the test strategy picks
    a step whose event is sent with ``send_event``.  ``stopped`` is called
    with an error text (or ``None``) when the run ends without ``stop``
    having been requested.
    """

    def __init__(self, nets: Dict[str, Callable[[], PetriNet]], constraints: List[Type['Constraint']], send_event: Callable[['Event'], None], 
                 stopped: Callable[[Optional[str]], None], strategy: str, log: Callable[[str], None], recorder: 'WalkRecorder', 
                 debugger: 'Debugger') -> None:
        self.send_event = send_event
        self.stopped = stopped
        self.walker = Walker(nets, constraints, log)
        self.event_queue: queue.Queue[Union[Event, None]] = queue.Queue()
        self.thread: Optional[threading.Thread] = None
        self.stop_requested = False
        self.recorder = recorder
        self.test_strategy = TestStrategy(strategy, self.recorder, log, debugger)
        self.debugger = debugger

    def start(self):
        """Start the walker thread; may only be called when none was started."""
        assert self.thread is None, "Already running"
        self.thread = threading.Thread(target=self.__run_non_async)
        self.thread.start()

    def stop(self):
        """Ask the walker thread to stop and wait for it to finish."""
        if self.thread is not None:
            self.stop_requested = True
            self.event_queue.put(None) # Force run to stop
            # Joining the current thread would raise, so it is skipped when
            # stop() is called from the walker thread itself.
            # NOTE(review): in that case the flag is cleared again right away,
            # so the run loop only sees the None sentinel -- confirm intended.
            if threading.current_thread() != self.thread: self.thread.join()
            self.stop_requested = False

    def received_event(self, event: 'Event'):
        """Queue ``event`` for processing by the walker thread."""
        self.event_queue.put(event)

    def __run_non_async(self):
        """Thread entry point: run ``__run`` on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self.__run())
        finally:
            # Close the loop even when the run (or the stopped callback) raises.
            loop.close()

    async def __run(self):
        """Process queued events and take test steps until stopped or failed."""
        port_notification_during_command_transition: Dict[str, List[Event]] = {}
        take_reply_to_cmd: Optional[Event] = None
        stop_on_no_events = False
        error: Optional[str] = None
        # Bound up front: the error report at the end refers to both names, also
        # when the failure happens before any event or place has been handled.
        event: Optional[Event] = None
        place = None

        self.walker.log("Initial states:")
        for (connection, states) in self.walker.states.items():
            for (machine, state) in states.items():
                self.walker.log(f"Connection ({connection[1]}, {connection[0]}) , machine '{machine}' is in state '{state}'")

        try:
            while not self.stop_requested and error is None:
                event = None
                try:
                    timeout = NO_EVENTS_TIMEOUT if stop_on_no_events else DEFAULT_TIMEOUT
                    timeout = self.debugger.debug_timeout(timeout)
                    if self.event_queue.qsize() == 0:
                        event = self.event_queue.get(True, timeout)
                    else:
                        event = self.event_queue.get()
                    if event is None: continue # None means we have to stop (added in stop())
                except queue.Empty:
                    pass

                if event is not None:
                    stop_on_no_events = False
                    parameter_place_name = f"P_{event.method}{'_reply' if event.kind == EventType.Reply else ''}"
                    # self.walker.log(f"Process event: '{str(event)}'")
                    if event.connectionKey() not in self.walker.nets:
                        error = f"Received event '{str(event)}' from unknown port '{event.port}'"
                        continue
                    if parameter_place_name not in self.walker.nets[event.connectionKey()]._place:
                        error = f"Event '{event.method}' is unknown for port '{event.port}'"
                        continue
                    place = self.walker.nets[event.connectionKey()]._place[parameter_place_name]
                    # self.walker.log(f"Process place: '{str(place)}'")
                    place.add([Parameters([p.value for p in event.parameters])])
                    # self.walker.log(f"Process parameters: '{str(place.meta)}'")
                    steps = [e for e in self.walker.next_steps(event.connectionKey()) if e.event == event]
                    if len(steps) == 0:
                        # A notification that arrives while a command is in progress
                        # is parked until the reply to that command has been handled.
                        if event.kind == EventType.Notification and event.port in port_notification_during_command_transition:
                            port_notification_during_command_transition[event.port].append(event)
                        else:
                            error = f"Event '{str(event)}' is not possible"
                    else:
                        step = random.choice(steps)
                        # self.walker.log(f"Process step: '{str(step)}'")
                        self.walker.take_step(step)
                        if event.kind == EventType.Reply:
                            # pop() with a default: a reply on a port without a
                            # registered command simply has no parked notifications.
                            parked = port_notification_during_command_transition.pop(event.port, [])
                            for notification in parked:
                                steps = [e for e in self.walker.next_steps(event.connectionKey()) if e.event == notification]
                                if len(steps) == 0:
                                    error = f"Event '{str(notification)}' is not possible"
                                    break
                                else:
                                    self.walker.take_step(random.choice(steps))
                        elif event.kind == EventType.Command:
                            take_reply_to_cmd = event
                else:
                    step = self.test_strategy.next_step(self.walker, take_reply_to_cmd)
                    # self.walker.log(f"Process step: '{str(step)}'")
                    if step is None:
                        # Give the other side one more (shorter or longer) timeout
                        # period before concluding that nothing can happen anymore.
                        if stop_on_no_events:
                            error = "No next steps possible from test application"
                        else:
                            stop_on_no_events = True
                    else:
                        if step.event.kind == EventType.Command:
                            port_notification_during_command_transition[step.event.port] = []
                        self.send_event(step.event)
                        self.recorder.record(step)
                        take_reply_to_cmd = None
                        self.walker.take_step(step)
        except Exception as e:
            error = f"Error while running: {repr(e)}, event: {str(event)}, place: {str(place)}"
            traceback.print_exc()

        if not self.stop_requested:
            self.stopped(error)

