import datetime
import random
import os.path
import os
import math
import json
import csv
import sys
from decimal import Decimal

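# Simulates test-order runs for order-dependent (od) tests.
# Assumption (unverified): each polluter group and each cleaner group is always just one test.
# Random/pairwise run simulation was added for the ISSTA20 submission.
# Usage: python get_od_fail_info.py <deps_csv> <test_name> <rounds> <order1;order2;...> <true|false>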

# Examples (the fourth argument is isVictim; expected result shown at the end of each line):
#print(wing_function(['v1', 'p1', 'c2', 'v2', 'n', 'c3', 'c1', 'p2'], 'v2', {'p1': ['c1', 'c2'], 'p2': ['c1', 'c3']}, True))  # PASS
#print(wing_function(['v1', 'p1', 'c2', 'v2', 'n', 'c3', 'c1', 'p2'], 'v1', {'p1': ['c1', 'c2'], 'p2': ['c1', 'c3']}, True))  # PASS
#print(wing_function(['v1', 'c2', 'p1', 'v2', 'n', 'c3', 'c1', 'p2'], 'v2', {'p1': ['c1', 'c2'], 'p2': ['c1', 'c3']}, True))  # FAIL
def wing_function(test_order, victim, d, isVictim): # _for_one_victim):
    """Simulate running test_order and report the outcome of `victim`.

    Args:
        test_order: sequence of test names, in execution order.
        victim: name of the order-dependent test whose result is wanted.
        d: dict mapping each polluter name to the list of its cleaners.
        isVictim: truthy when `victim` passes on a clean state (it fails
            when at least one polluter is still active); falsy for the
            inverse kind of test, which passes only when a polluter is
            still active.

    Returns:
        "PASS" or "FAIL".

    Raises:
        ValueError: if `victim` does not occur in test_order.
    """
    # Inverse of d: cleaner name -> polluters it cleans.
    polluters_of = {}
    for polluter, cleaners in d.items():
        for cleaner in cleaners:
            polluters_of.setdefault(cleaner, []).append(polluter)

    active = set()  # polluters that ran and have not been cleaned yet
    for t in test_order:
        if t == victim:
            state_is_clean = not active
            return "PASS" if state_is_clean == bool(isVictim) else "FAIL"
        if t in polluters_of:
            active.difference_update(polluters_of[t])
        # A test that is both polluter and cleaner cleans first and then
        # pollutes, i.e. it is effectively treated as a polluter.
        if t in d:
            active.add(t)

    print(str(test_order) + " victim: " + victim)
    # String exceptions are not valid; raise a real exception type instead.
    raise ValueError("od test not found: " + victim)

def some_function(dep_file):
    """Load the dependency CSV into a per-test lookup table.

    Each row is read as (test, polluter, cleaner, od_type, ...). The result
    maps test -> (polluter -> list of cleaners, od_type). A row with an empty
    cleaner column resets that polluter's cleaner list to []. The od_type
    stored for a test is the one from its last row.
    """
    table = {}
    with open(dep_file) as handle:
        for row in csv.reader(handle):
            test_name = row[0]
            od_type = row[3]
            if test_name in table:
                cleaners_by_polluter = table[test_name][0]
            else:
                cleaners_by_polluter = {}
            cleaner = row[2]
            polluter = row[1]
            if cleaner == '':
                cleaners_by_polluter[polluter] = []
            else:
                cleaners_by_polluter.setdefault(polluter, []).append(cleaner)
            table[test_name] = (cleaners_by_polluter, od_type)
    return table

# new_tests and old_tests are lists (they are concatenated with + and the result is shuffled)
def randomize(new_tests, old_tests):
    """Return a new list holding all old and new tests in random order.

    The concatenation creates a fresh list, so the inputs are not reordered.
    """
    shuffled = old_tests + new_tests
    random.shuffle(shuffled)
    return shuffled

def full_to_class_full(test):
    """Return (class_name, test) for a fully qualified test name.

    The class name is everything before the last '.'. A name without any
    '.' yields an empty class name (previously the last character of the
    name was silently dropped because rfind() returned -1).
    """
    class_name = test.rpartition('.')[0]
    return (class_name, test)

def class_to_tests(test_list):
    """Group fully qualified test names by class.

    Returns a dict mapping each class name (as computed by
    full_to_class_full) to the list of its tests, in input order.
    """
    grouped = {}
    for test in test_list:
        class_name, full_name = full_to_class_full(test)
        if class_name not in grouped:
            grouped[class_name] = []
        grouped[class_name].append(full_name)
    return grouped

# Module-level class -> tests lookup tables. They are filled in by
# randomize_junit_init() and read by randomize_junit() / randomize_junit2().
all_tests_c_m = {}
new_tests_c_m = {}
old_tests_c_m = {}
def randomize_junit_init(new_tests, old_tests):
    """Rebuild the module-level class -> tests tables from the given tests.

    new_tests_c_m and old_tests_c_m group the respective inputs by class;
    all_tests_c_m is the merge of the two (old tests first within a class).
    """
    global all_tests_c_m, new_tests_c_m, old_tests_c_m
    new_tests_c_m = class_to_tests(new_tests)
    old_tests_c_m = class_to_tests(old_tests)
    all_tests_c_m = merge_two_dicts(old_tests_c_m, new_tests_c_m)

def merge_two_dicts(x, y):
    """Merge two dicts of lists into a new dict.

    For keys present in y the result holds a new list: x's values (if any)
    followed by y's values. Keys only present in x keep x's own list object
    (the copy of x is shallow).
    """
    merged = x.copy()
    for key, extra in y.items():
        existing = merged.get(key, [])
        merged[key] = existing + extra
    return merged
    
def randomize_junit(new_tests, old_tests):
    """Return all tests in a random class order, shuffled within each class.

    The order is built from the module-level all_tests_c_m table (filled by
    randomize_junit_init); the new_tests/old_tests arguments are not used
    and exist only so that every order function has the same signature.
    The per-class lists in all_tests_c_m are shuffled in place.
    """
    # Materialize the keys: on Python 3 keys() is a view and cannot be
    # shuffled; on Python 2 this is just a copy of the same list.
    classes = list(all_tests_c_m.keys())
    random.shuffle(classes)
    ret_list = []
    for c in classes:
        random.shuffle(all_tests_c_m[c])
        # extend() avoids rebuilding the whole result list for every class.
        ret_list.extend(all_tests_c_m[c])
    return ret_list

def randomize_junit2(new_or_old):
    """Return the new or the old tests in a random class-grouped order.

    new_or_old == 'new' selects the module-level new_tests_c_m table; any
    other value selects old_tests_c_m. Classes are shuffled, and the tests
    of each class are shuffled in place inside the selected table.
    """
    pro = new_tests_c_m if new_or_old == 'new' else old_tests_c_m
    # Materialize the keys: on Python 3 keys() is a view and cannot be
    # shuffled; on Python 2 this is just a copy of the same list.
    classes = list(pro.keys())
    random.shuffle(classes)
    ret_list = []
    for c in classes:
        random.shuffle(pro[c])
        # extend() avoids rebuilding the whole result list for every class.
        ret_list.extend(pro[c])
    return ret_list

def randomize_new_before_old(new_tests, old_tests):
    """Shuffle both lists in place and return the new tests followed by the old tests."""
    for group in (new_tests, old_tests):
        random.shuffle(group)
    return new_tests + old_tests

def randomize_old_before_new(new_tests, old_tests):
    """Shuffle both lists in place and return the old tests followed by the new tests."""
    # The new tests are shuffled first, then the old ones.
    for group in (new_tests, old_tests):
        random.shuffle(group)
    return old_tests + new_tests

def randomize_new_before_old_junit(new_tests, old_tests):
    """Return a class-grouped random order with all new tests before all old tests.

    The arguments are not used; the order comes from the module-level
    tables read by randomize_junit2().
    """
    new_part = randomize_junit2('new')
    old_part = randomize_junit2('old')
    return new_part + old_part

def randomize_old_before_new_junit(new_tests, old_tests):
    """Return a class-grouped random order with all old tests before all new tests.

    The arguments are not used; the order comes from the module-level
    tables read by randomize_junit2().
    """
    old_part = randomize_junit2('old')
    new_part = randomize_junit2('new')
    return old_part + new_part

# NOTE: this re-definition replaces the identical `randomize` defined earlier in this module.
def randomize(new_tests, old_tests):
    """Return a freshly built list of the old and new tests in random order."""
    pool = old_tests + new_tests
    random.shuffle(pool)
    return pool

def randomize_reverse():
    """Marker for the 'randomize, then re-run passing orders in reverse' mode.

    check_orders_help() recognizes this function by its __name__ and
    substitutes randomize with runpassinginreverse=True. The body only has
    to be callable; `null` is not a Python name, so it returns None.
    """
    return None

def randomize_junit_reverse():
    """Marker for the 'randomize_junit, then re-run passing orders in reverse' mode.

    check_orders_help() recognizes this function by its __name__ and
    substitutes randomize_junit with runpassinginreverse=True. The body only
    has to be callable; `null` is not a Python name, so it returns None.
    """
    return None

def isolation(tests, victim, d, isVictim):
    """Simulate a single run of `tests` and report whether `victim` fails.

    Returns:
        (failc, runs, rounds): failc is 1 when the simulated result is
        "FAIL" and 0 otherwise. Exactly one order is simulated, so runs and
        rounds are both 1. (The previous code returned a `rounds` name that
        is not defined in this function and only exists as a global when
        the module is run as a script.)
    """
    result = wing_function(tests, victim, d, isVictim)
    failc = 1 if result == "FAIL" else 0
    return (failc, 1, 1)
    

def pairwise(new_tests, old_tests, victim, d, isVictim):
    """Simulate every (new test, old test) pair in both orders.

    Each pair of distinct tests counts as two rounds. A pair is only
    simulated (two runs) when one of its tests is the victim.

    Returns:
        (failc, runs, rounds): number of failing simulated runs, number of
        simulated runs, and number of counted rounds.
    """
    fail_total = 0
    executed = 0
    planned = 0
    for new_test in new_tests:
        for old_test in old_tests:
            if new_test == old_test:
                continue
            planned += 2
            if not (new_test == victim or old_test == victim):
                continue
            executed += 2
            outcomes = (
                wing_function([new_test, old_test], victim, d, isVictim),
                wing_function([old_test, new_test], victim, d, isVictim),
            )
            fail_total += outcomes.count("FAIL")
    return (fail_total, executed, planned)

def pairwise_new(new_tests, old_tests, victim, d, isVictim):
    """Run pairwise() over pairs of new tests only; old_tests is ignored."""
    candidates = new_tests
    return pairwise(candidates, candidates, victim, d, isVictim)

def random_main(new_tests, old_tests, victim, d, isVictim, rounds, order, runpassinginreverse):
    """Simulate up to `rounds` runs of randomly generated test orders.

    `order(new_tests, old_tests)` is called `rounds` times up front to
    generate the candidate orders. Each candidate is simulated; when
    runpassinginreverse is set and the candidate did not fail, the same
    order is reversed in place and simulated again, which uses up one more
    round.

    Returns:
        (failc, roundsrun, rounds, reverse_fail, reverse_runs)
    """
    candidate_orders = [order(new_tests, old_tests) for _ in range(rounds)]

    fail_count = 0
    reverse_fail = 0
    reverse_runs = 0
    executed = 0
    for candidate in candidate_orders:
        # Added after the TACAS submission: previously 99 could jump to 101
        # for 100 rounds and the loop would then continue over all orders.
        if executed >= rounds:
            break

        outcome = wing_function(candidate, victim, d, isVictim)
        executed += 1
        if runpassinginreverse and outcome != "FAIL":
            # Added after the TACAS submission: no round left for the reverse run.
            if executed == rounds:
                break
            executed += 1
            reverse_runs += 1
            candidate.reverse()
            outcome = wing_function(candidate, victim, d, isVictim)
            if outcome == "FAIL":
                reverse_fail += 1

        if outcome == "FAIL":
            fail_count += 1
    return (fail_count, executed, rounds, reverse_fail, reverse_runs)

def get_class_name(test):
    """Return everything before the last '.' of test ('' when there is no '.')."""
    class_name, _, _ = test.rpartition(".")
    return class_name

def add_to_pol(dicta, classname, polluter_or_cleaner, iscleaner):
    """Record a polluter or cleaner under its class in dicta.

    dicta maps class name -> (set of polluters, set of cleaners). The entry
    is created when missing; the name is added to the cleaner set when
    iscleaner is truthy and to the polluter set otherwise.
    """
    try:
        polluters, cleaners = dicta[classname]
    except KeyError:
        polluters, cleaners = set(), set()
    target = cleaners if iscleaner else polluters
    target.add(polluter_or_cleaner)
    dicta[classname] = (polluters, cleaners)

def prob_fail_formula(victim, d, order):
    """Compute the closed-form value used instead of simulating `order`.

    Args:
        victim: fully qualified name of the order-dependent test.
        d: dict mapping each polluter to the list of its cleaners.
        order: order function; only `randomize` and `randomize_junit` are
            supported (selected by order.__name__).

    Returns:
        A float. check_orders() multiplies it (or one minus it) by the
        number of rounds to obtain the fail count.

    Raises:
        ValueError: if `order` is not one of the supported order functions
            (previously this path tried to raise a str concatenated with a
            function object).
    """
    victim_c = get_class_name(victim)
    vpol = set()      # polluters in the victim's class
    vcleaner = set()  # cleaners in the victim's class
    # {class1: ({pol_name1, pol_name2}, {cleaner_name1}), ...} for all other classes
    pc_to_cc = {}
    for pol, cleaners in d.items():
        pol_c = get_class_name(pol)
        if victim_c == pol_c:
            vpol.add(pol)
        else:
            add_to_pol(pc_to_cc, pol_c, pol, False)
        for cleaner in cleaners:
            cleaner_c = get_class_name(cleaner)
            if victim_c == cleaner_c:
                vcleaner.add(cleaner)
            else:
                add_to_pol(pc_to_cc, cleaner_c, cleaner, True)

    order_name = order.__name__
    if order_name == "randomize":
        # All polluters / all cleaners, regardless of class.
        all_pols = vpol.union(*[pols for pols, _ in pc_to_cc.values()])
        all_cleaners = vcleaner.union(*[cls for _, cls in pc_to_cc.values()])
        plen = len(all_pols)
        clen = len(all_cleaners)
        return plen / (plen + clen + 1.0)
    elif order_name == "randomize_junit":
        # Sum over the other classes of (#polluters / (#polluters + #cleaners)).
        sm = 0.0
        for pols, cls in pc_to_cc.values():
            plen = len(pols) * 1.0
            sm += plen / (plen + len(cls))
        return (len(vpol) + ((1 / (len(pc_to_cc) + 1.0)) * sm)) / (len(vpol) + len(vcleaner) + 1)
    else:
        raise ValueError("unknown order for formulas " + order_name)

def simulate_orders(new_tests, old_tests, victim, d, isVictim, rounds, order, runpassinginreverse):
    """Run one simulation strategy and print its summary line.

    Order functions whose name starts with 'random' are driven through
    random_main(); any other function is called directly as
    order(new_tests, old_tests, victim, d, isVictim) and must return
    (failc, runs, rounds). In that case the reverse counters print as 0.
    """
    if order.__name__.startswith('random'):
        (failc, runs, roundsc, reverse_fail, reverse_runs) = random_main(
            new_tests, old_tests, victim, d, isVictim, rounds, order, runpassinginreverse)
    else:
        reverse_fail = 0
        reverse_runs = 0
        (failc, runs, roundsc) = order(new_tests, old_tests, victim, d, isVictim)

    summary = [
        "fails: ", str(failc),
        " runs: ", str(runs),
        " rounds: ", str(roundsc),
        " reverse_fail: ", str(reverse_fail),
        " reverse_runs: ", str(reverse_runs),
        " simulated====",
    ]
    print("".join(summary))

def check_orders_help(order, runpassinginreverse, new_tests, old_tests, victim, d, isVictim, rounds):
    """Resolve the '*_reverse' marker orders and run the simulation.

    randomize_reverse / randomize_junit_reverse (matched by __name__) are
    replaced by randomize / randomize_junit with runpassinginreverse forced
    to True; every other order is passed through unchanged.
    """
    reverse_aliases = {
        "randomize_reverse": randomize,
        "randomize_junit_reverse": randomize_junit,
    }
    base_order = reverse_aliases.get(order.__name__)
    if base_order is not None:
        order = base_order
        runpassinginreverse = True
    simulate_orders(new_tests, old_tests, victim, d, isVictim, rounds, order, runpassinginreverse)

def check_orders(new_tests, old_tests, victim, d, isVictim, rounds, orders, use_formulas):
    """Evaluate every order function for one victim and print the results.

    The random generator is re-seeded with 1 before each order. For
    randomize / randomize_junit with use_formulas set, the closed-form
    formula replaces the simulation when every polluter in d has the same
    cleaner list; all other cases are simulated via check_orders_help().
    """
    randomize_junit_init(new_tests, old_tests) #### this is ugly
    formula_orders = ("randomize", "randomize_junit")
    for order in orders:
        random.seed(1)
        print("Order: " + order.__name__)
        runpassinginreverse = False

        # The formula only applies when all polluters share the same cleaners.
        formula_applies = False
        if order.__name__ in formula_orders and use_formulas:
            reference = d[next(iter(d.keys()))]
            formula_applies = all(reference == d[x] for x in d.keys())

        if not formula_applies:
            check_orders_help(order, runpassinginreverse, new_tests, old_tests, victim, d, isVictim, rounds)
            continue

        if isVictim:
            failc = int(prob_fail_formula(victim, d, order) * rounds)
        else:
            failc = int((1 - prob_fail_formula(victim, d, order)) * rounds)
        print("fails: " + str(failc) + " runs: " + str(rounds) + " rounds: " + str(rounds) + " reverse_fail: " + str(-1)  + " reverse_runs: " + str(-1) + " used_formula====")

def parse_file(test_file):
    """Return the lines of test_file as a list, without line terminators."""
    with open(test_file) as handle:
        contents = handle.read()
    return contents.splitlines()

def new_old_tests_default(new_tests, old_tests):
    """Identity categorization: return the two lists unchanged as a (new, old) tuple."""
    categorized = (new_tests, old_tests)
    return categorized

def new_old_tests_classes(new_tests, old_tests):
    """Treat old tests that live in a class containing a new test as new.

    Returns:
        (new_tests + moved_old_tests, remaining_old_tests), where an old
        test is moved when its class (the name up to the last '.') also
        contains at least one new test. Relative order is preserved and
        the input lists are not modified.
    """
    new_classes = {new_test[:new_test.rfind('.')] for new_test in new_tests}

    # Single pass instead of re-scanning the moved list for every old test.
    moved_old_tests = []
    remaining_old_tests = []
    for old_test in old_tests:
        if old_test[:old_test.rfind('.')] in new_classes:
            moved_old_tests.append(old_test)
        else:
            remaining_old_tests.append(old_test)

    return (new_tests + moved_old_tests, remaining_old_tests)

def test_mode(test_names, rounds, d, orders, tests_cat, use_formulas):
    """Run check_orders() for every test in test_names.

    Args:
        test_names: iterable of victim test names (keys of d).
        rounds: number of rounds per order.
        d: dict mapping test name -> (polluter -> cleaners dict, od_type),
            as built by some_function().
        orders: list of order functions to evaluate.
        tests_cat: currently unused; the categorization functions are fixed
            below (kept for signature compatibility).
        use_formulas: forwarded to check_orders().

    Reads the module-level `version`; when it is not 'idflakies' the
    module-level `files` is also read to locate the new/old test lists.
    """
    for test_name in test_names:
        victim = test_name
        print("Test: " + test_name + " started at: " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        if version == 'idflakies':
            (deps, od_type) = d[victim]
            # Only the tests that matter for this victim: its polluters,
            # their cleaners and the victim itself. list() is required so
            # the concatenation also works when keys() returns a view.
            all_cleaners = set(item for sublist in deps.values() for item in sublist)
            new_tests = list(deps.keys()) + list(all_cleaners) + [victim]
            old_tests = []
        else:
            new_tests = parse_file([os.path.join(files['newtests'], x) for x in os.listdir(files['newtests']) if x.startswith(test_name)][0])
            old_tests = parse_file([os.path.join(files['oldtests'], x) for x in os.listdir(files['oldtests']) if x.startswith(test_name)][0])
            (deps, od_type) = d[victim]
        isVictim = od_type == "victim"

        #tests_cats = [new_old_tests_default, new_old_tests_classes]
        tests_cats = [new_old_tests_default]
        # Use a distinct loop name so the tests_cat parameter is not shadowed.
        for cat in tests_cats:
            print("Cat: " + cat.__name__ + " started at: " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            (new_tests, old_tests) = cat(new_tests, old_tests)
            check_orders(new_tests, old_tests, victim, deps, isVictim, rounds, orders, use_formulas)

def strat_mode(info_to_tests, rounds, d, orders, tests_cats):
    """Simulate the orders once per info key and aggregate over its victims.

    For every key of info_to_tests, `rounds` orders are generated per order
    function and each victim of that key is simulated on every generated
    order. Prints, per order function, how many orders made at least one
    victim fail and how many made all victims fail.

    Reads the module-level `files` to locate the new/old test lists of the
    first victim of each key.
    """
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    for info in info_to_tests.keys():
        victims = info_to_tests[info]
        print("Info: " + str(info) + " victims: " + str(len(victims)) + " started at: " + datetime.datetime.now().strftime(timestamp_format))
        test_name = victims[0]
        new_matches = [os.path.join(files['newtests'], x) for x in os.listdir(files['newtests']) if x.startswith(test_name)]
        new_tests = parse_file(new_matches[0])
        old_matches = [os.path.join(files['oldtests'], x) for x in os.listdir(files['oldtests']) if x.startswith(test_name)]
        old_tests = parse_file(old_matches[0])
        random.seed(1)
        randomize_junit_init(new_tests, old_tests)

        for tests_cat in tests_cats:
            print("Cat: " + tests_cat.__name__ + " started at: " + datetime.datetime.now().strftime(timestamp_format))
            (new_tests, old_tests) = tests_cat(new_tests, old_tests)
            for order in orders:
                print("Order: " + order.__name__)
                perm = [order(new_tests, old_tests) for _ in range(rounds)]

                oneOrMoreFail = 0
                allFail = 0
                for o in perm:
                    failc = 0
                    for victim in victims:
                        (deps, od_type) = d[victim]
                        isVictim = od_type == "victim"
                        if wing_function(o, victim, deps, isVictim) == "FAIL":
                            failc += 1

                    if failc == len(victims):
                        allFail += 1
                    if failc > 0:
                        oneOrMoreFail += 1
                print("oneOrMoreFails: " + str(oneOrMoreFail) + " allFails: " + str(allFail) + " rounds: " + str(rounds))

# test_mode() compares this value against 'idflakies' to decide how the
# new/old test lists are obtained.
version = 'idflakies'
tests_cat = [new_old_tests_default]  # , new_old_tests_classes]

# Default order functions (replaced from the command line when run as a script).
# Alternatives:
#   orders = [randomize_junit]
#   orders = [randomize_reverse, randomize_junit_reverse]
#   orders = [randomize, randomize_junit] # omegaA no reverse, omegaC no reverse
#   orders = [randomize, randomize_reverse] # omegaA no reverse, omegaA with reverse
#   orders = [randomize, randomize_reverse, randomize_junit, randomize_junit_reverse]
orders = [randomize_junit, randomize_junit_reverse]  # omegaC no reverse, omegaC with reverse

# Command-line order name -> order function.
mapp_orders = {
    "omega_c_reverse_passing": randomize_junit_reverse,
    "omega_c": randomize_junit,
    "omega_a_reverse_passing": randomize_reverse,
    "omega_a": randomize,
}


# Simulation for per test
# use_formulas = False

if __name__ == '__main__':
    # Arguments:
    #   1: dependency CSV file (read by some_function)
    #   2: name of the test to simulate
    #   3: number of rounds
    #   4: ';'-separated order names (keys of mapp_orders)
    #   5: "true"/"false" - whether formulas may replace the simulation
    inputs = sys.argv
    if len(inputs) < 6:
        sys.exit("usage: " + inputs[0] + " <deps_csv> <test_name> <rounds> <order1;order2;...> <true|false>")
    files = inputs[1]
    rounds = int(inputs[3])
    parse_orders = inputs[4].split(';')
    use_formulas_by_name = {"true": True, "false": False}
    if inputs[5].lower() not in use_formulas_by_name:
        raise ValueError(str.format("Unexpected use_formulas value passed in: {}", inputs[5]))
    use_formulas = use_formulas_by_name[inputs[5].lower()]
    orders = []
    for order in parse_orders:
        if order not in mapp_orders:
            raise ValueError(str.format("Unexpected order passed in: {}", order))
        orders.append(mapp_orders[order])
    d = some_function(files)

    test_info = [str.format("slug,sha,{}", inputs[2])]
    info_to_tests = {}
    test_to_info = {}
    for s_s_t in test_info:
        # Split at most twice so a test name that itself contains ','
        # stays intact in the third field.
        info = s_s_t.split(',', 2)
        test_to_info[info[2]] = (info[0], info[1])
        info_to_tests[(info[0],info[1])] = info_to_tests.get((info[0],info[1]), []) + [info[2]]

    test_mode(test_to_info.keys(), rounds, d, orders, tests_cat, use_formulas)
