import hmac
import os

from peewee import *
from playhouse.shortcuts import ReconnectMixin

from search import InventorySearch as ivs

class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
    """MySQLDatabase with playhouse's ReconnectMixin layered on top.

    The mixin is listed first so its overrides take precedence in the MRO;
    it is used here so the app re-establishes the MySQL connection when the
    server drops it instead of failing every subsequent query.
    """

# Shared database handle for every model below; ReconnectMySQLDatabase
# auto-reconnects if MySQL disconnects. Each connection setting can be
# overridden through the environment so credentials need not live in source
# control; the literal fallbacks are the values this module always used.
# TODO: drop the hard-coded password fallback once deployments set
# INVENTORY_DB_PASSWORD.
db = ReconnectMySQLDatabase(
    os.environ.get("INVENTORY_DB_NAME", "inventory"),
    thread_safe=True,
    user=os.environ.get("INVENTORY_DB_USER", "inventory"),
    password=os.environ.get("INVENTORY_DB_PASSWORD", "nfrwnfprifbwef"),
    host=os.environ.get("INVENTORY_DB_HOST", "database"),
    port=int(os.environ.get("INVENTORY_DB_PORT", "3306")),
)
# Search index over item rows; stays None until init() builds it.
search = None

class user(Model):
    """Login account, keyed by ``username`` (see create_user / user_login)."""
    name = CharField()  # display name; get_user() also accepts it as a lookup key
    username = CharField(unique=True, primary_key=True)
    password  = CharField() # stored and compared as plain text (create_user/user_login); TODO: replace with AD or a salted hash
    changepw = BooleanField(default=False)  # "must change password" flag; presumably read by the login UI — confirm

    class Meta:
        # every model in this module shares the one module-level connection
        database = db
        legacy_table_names = False

class office(Model):
    """An office that items belong to; create_item/update_item add rows on demand."""
    name = CharField(unique=True)
    officeid = AutoField()  # auto-increment primary key; search_item() filters the index on this id

    class Meta:
        database = db
        legacy_table_names = False

class location(Model):
    """A physical place items can live in; locations nest via ``parent``."""
    name = CharField()  # not unique: get_location() disambiguates by ancestor
    locationid = CharField(unique=True, primary_key=True)  # the location's barcode (create_location passes it here)
    description = CharField(null=True)

    # Self-referential FK (e.g. a shelf inside an office); None for a top-level location.
    parent = ForeignKeyField('self', null=True, backref="sublocations")

    class Meta:
        database = db
        legacy_table_names = False



class item(Model):
    """A tracked piece of equipment, keyed by its ``barcode``.

    Rows are mirrored into the search index as ``.dicts()`` output (see init,
    create_item and update_item), so field names double as index keys.
    """
    loc = ForeignKeyField(location, backref="items_here", null=True)  # where the item is kept; checkin() may move it; None = unknown
    office = ForeignKeyField(office, backref="items_here", null=True)  # owning office; search_item() filters on its id
    fullname = CharField(null=True)
    description = CharField(null=True)
    serial = CharField(null=True)
    checkout = BooleanField(default=False)  # True while checked out: set by checkout(), cleared by checkin()
    checkout_loc = ForeignKeyField(location, backref="items_checkedout_here", null=True)  # destination given to checkout(), if any
    checkout_user = ForeignKeyField(user, backref="items_held", null=True)  # set by checkout(); checkin() does not clear it
    checkout_start = DateTimeField(null=True)  # NOTE(review): not written by any function visible in this module
    checkout_end = DateTimeField(null=True)  # NOTE(review): not written by any function visible in this module
    mac = CharField(null=True)
    barcode = CharField(unique=True, primary_key=True)
    fwver = CharField(null=True)  # presumably firmware version string — confirm
    manufacturer = CharField()  # NOT NULL: create_item() defaults it to None, so callers must pass a value
    
    
    last_user = ForeignKeyField(user, null=True) # set by checkin(); remove null=True once user auth works

    class Meta:
        database = db
        legacy_table_names = False

class component(Model):
    """A barcoded sub-part of an item; get_item() falls back to this table."""
    owner = ForeignKeyField(item, backref="components")  # required parent item
    name = CharField()
    description = CharField(null=True)
    barcode = CharField(unique=True, primary_key=True)
    serial = CharField(null=True)

    class Meta:
        database = db
        legacy_table_names = False

def init():
    """Connect to MySQL, create any missing tables, and build the search index.

    Blocks until the database accepts a connection, then replaces the
    module-level ``search`` with a fresh InventorySearch populated from every
    item row. Each indexed document is the row's ``.dicts()`` form plus a
    ``"location"`` key holding the location name when the item has one.
    """
    global search
    _connect_with_retry()
    print("Checking & creating tables...")
    db.create_tables([location, office, item, component, user])
    print("Database initialized.")
    print("Creating cache index... ", end='', flush=True)
    search = ivs()
    print(_index_all_items(search))
    print("Cache build complete.")


def _connect_with_retry(delay=1):
    """Call db.connect() until it succeeds, sleeping *delay* seconds between tries."""
    import time
    print("Connecting to database...")
    while True:
        try:
            # reuse_if_open: calling init() twice must not loop forever on
            # peewee's "Connection already opened" error.
            db.connect(reuse_if_open=True)
            return
        except (DatabaseError, InterfaceError):
            # The database may still be starting up; keep waiting. Catching
            # only peewee's error types lets Ctrl+C (KeyboardInterrupt) and
            # configuration errors through instead of retrying forever.
            time.sleep(delay)


def _index_all_items(index):
    """Add every item row to *index* and return how many rows were indexed."""
    # One query for all location names instead of one lookup per item.
    location_names = {loc.locationid: loc.name for loc in location.select()}
    rows = item.select().dicts()
    for row in rows:
        # .dicts() keys FK columns by field name, so "loc" holds the location id.
        name = location_names.get(row.get("loc"))
        if name is not None:
            row["location"] = name
        index.add_document(row)
    return len(rows)

def search_item(query, filters: "dict | None" = None):
    """Run *query* against the search index, optionally narrowed by *filters*.

    Args:
        query: free-text search string, passed straight to ``search.search``.
        filters: mapping of indexed field -> value. Each pair becomes a
            ``key = value`` clause and the clauses are joined with ``AND``.
            The ``"office"`` key is special: its value is an office *name*,
            translated to the numeric ``officeid`` stored in the index, and
            the value ``'all'`` drops the office clause entirely.

    Returns:
        The ``"hits"`` entry of the index's search result.
    """
    clauses = []
    for key, val in (filters or {}).items():
        if key == "office":
            if val == 'all':
                continue
            off = office.get_or_none(office.name == val)
            # Unknown office: filter on an id that should match no item.
            # NOTE(review): collides if an office ever reaches id 999999.
            office_id = off.officeid if off is not None else 999999
            clauses.append(f"{key} = {office_id}")
        else:
            # NOTE(review): values are interpolated unquoted into the index's
            # filter syntax; quote/escape them if filters come from user input.
            clauses.append(f"{key} = {val}")
    return search.search(query, " AND ".join(clauses))["hits"]

def find_item(barcode):
    """Look *barcode* up in the module-level search index built by init().

    Returns whatever ``InventorySearch.get_barcode`` returns for it.
    """
    index = search
    return index.get_barcode(barcode)

def find_item_location(barcode):
    """Return the location row of the item whose barcode is *barcode*.

    Returns None when the item exists but has no location, and False when
    no item has that barcode or its location row no longer exists.
    """
    itm = item.get_or_none(item.barcode == barcode)
    if itm is None:
        return False
    try:
        return itm.loc
    except DoesNotExist:
        # FK points at a location row that has since been deleted.
        return False

def create_item(fullname, serial, officename, barcode, locationid=None, description=None, manufacturer=None, mac=None, fwver=None):
    """Insert a new item and add it to the search index.

    The office named *officename* is created if it does not exist yet.
    *locationid* is a location barcode; an empty or unknown id leaves the
    item without a location.

    Returns:
        The saved ``item`` instance, or False if the insert violates a
        constraint (typically a duplicate barcode).
    """
    off, _ = office.get_or_create(name=officename)
    loc = get_location_id(locationid)
    if loc is False:  # get_location_id() signals "not found" with False
        loc = None
    try:
        itm = item(office=off, barcode=barcode, fullname=fullname, description=description, loc=loc, serial=serial, mac=mac, fwver=fwver, manufacturer=manufacturer)
        itm.save(force_insert=True)
    except IntegrityError:
        print("Duplicate item " + fullname)
        return False
    # Re-read as a plain dict so the index gets the same document shape init() builds.
    itmdict = item.select().where(item.barcode == barcode).dicts()[0]
    if loc is not None:
        itmdict["location"] = loc.name
    search.add_document(itmdict)
    print("added item: " + itm.fullname)
    return itm
    

def update_item(fullname, serial, officename, barcode, locationid=None, description=None, manufacturer=None, mac=None, fwver=None):
    """Overwrite the descriptive fields of an existing item and re-index it.

    Only the fields passed here are written. Checkout state, last_user and the
    checkout timestamps are left untouched. Previously a fresh instance was
    saved, which also wrote the ``checkout=False`` default and so silently
    checked the item back in. The office named *officename* is created if
    needed; an empty or unknown *locationid* clears the item's location.

    Returns:
        The updated ``item`` instance, or False if no item has *barcode* or
        the update violates a constraint.
    """
    itm = item.get_or_none(item.barcode == barcode)
    if itm is None:
        print("No item with barcode " + str(barcode))
        return False
    off, _ = office.get_or_create(name=officename)
    loc = get_location_id(locationid)
    if loc is False:  # get_location_id() signals "not found" with False
        loc = None
    itm.office = off
    itm.fullname = fullname
    itm.description = description
    itm.loc = loc
    itm.serial = serial
    itm.mac = mac
    itm.fwver = fwver
    itm.manufacturer = manufacturer
    try:
        # Write only the edited columns so concurrent checkout changes survive.
        itm.save(only=[item.office, item.fullname, item.description, item.loc,
                       item.serial, item.mac, item.fwver, item.manufacturer])
    except IntegrityError:
        print("Duplicate item " + fullname)
        return False
    itmdict = item.select().where(item.barcode == barcode).dicts()[0]
    if loc is not None:
        itmdict["location"] = loc.name
    search.add_document(itmdict)
    print("updated item: " + itm.fullname)
    return itm

def delete_item(itm):
    """Delete the database row behind *itm* (any object with delete_instance()).

    NOTE(review): the search index is not touched here, so the item may keep
    showing up in search results; confirm whether InventorySearch needs an
    explicit removal call.
    """
    _ = itm.delete_instance()  # result intentionally ignored; returns None

def delete_item_id(barcode):
    """Delete the item (or, failing that, the component) with *barcode*.

    Returns:
        True after deleting the match, False when get_item() finds nothing.
        Before this change, a missing barcode crashed on ``False.delete_instance()``.

    NOTE(review): as with delete_item(), the search index is not updated.
    """
    target = get_item(barcode)
    if not target:
        return False
    target.delete_instance()
    return True
    
def item_location_str(itm):
    """Return a human-readable location for *itm*.

    Returns the location's name when the item has one. Otherwise returns
    "Checked out to unknown location" for checked-out items, or "Unknown".
    """
    try:
        loc = itm.loc
    except DoesNotExist:
        # The FK points at a location row that has since been deleted.
        loc = None
    if loc is not None:
        return loc.name
    if itm.checkout:
        return "Checked out to unknown location"
    return "Unknown"

def create_component(parentitem, name, barcode, serial=None, description=None):
    """Insert a component owned by *parentitem*.

    Returns the saved component, or False (after printing a message) when
    the insert raises IntegrityError, e.g. because *barcode* is already used.
    """
    part = component(owner=parentitem, name=name, barcode=barcode,
                     description=description, serial=serial)
    try:
        part.save(force_insert=True)
    except IntegrityError:
        print("Duplicate component " + name)
        return False
    print("added component: " + part.name)
    return part

def get_item(barcode):
    """Return the item with *barcode*; otherwise defer to get_component().

    Because of that fallback, the result may be an ``item``, a ``component``,
    or False when neither table has the barcode.
    """
    matches = item.select().where(item.barcode == barcode)
    return matches[0] if len(matches) == 1 else get_component(barcode)


def get_component(barcode):
    """Return the component whose barcode is *barcode*, or False if there is none."""
    matches = component.select().where(component.barcode == barcode)
    return matches[0] if len(matches) == 1 else False

def create_user(name, username, password, changepw=False):
    """Insert a user account.

    NOTE(review): *password* is stored as given (plain text); user_login()
    compares against it directly, so both must change together if hashing is
    introduced.

    Returns the saved user, or False (after printing a message) when the
    insert raises IntegrityError, e.g. because *username* already exists.
    """
    account = user(username=username, name=name, password=password, changepw=changepw)
    try:
        account.save(force_insert=True)
    except IntegrityError:
        print("User " + username + " already exists.")
        return False
    return account

def change_password(username, password):
    """Set a new password for *username* and clear its change-password flag.

    *username* is resolved with get_user(), so a unique display name works too.

    Returns:
        True on success, False if no matching user exists.
    """
    usr = get_user(username)
    if not usr:
        return False
    usr.password = password  # NOTE(review): stored as plain text, matching create_user()
    # Set the flag on the instance. The old code assigned ``user.changepw`` on
    # the model class, which replaced the field itself.
    usr.changepw = False
    # A plain save() UPDATEs the existing row. force_insert=True tried to
    # INSERT a duplicate primary key and always raised IntegrityError.
    usr.save()
    return True

def checkout(user, barcode, loc=None):
    """Mark the item with *barcode* as checked out to *user*, optionally at *loc*.

    Returns the saved record, or False when get_item() finds nothing.
    NOTE(review): checkout_start/checkout_end are not written here.
    """
    target = get_item(barcode)
    if not target:
        return False
    target.checkout = True
    target.checkout_user = user
    target.checkout_loc = loc
    target.save()
    return target

def checkin(user, barcode, loc=None):
    """Mark the item with *barcode* as returned by *user*.

    If *loc* is given it becomes the item's location; otherwise the location
    is unchanged. Returns the saved record, or False when get_item() finds
    nothing.
    NOTE(review): checkout_user and checkout_loc are left as they were.
    Confirm whether they should be cleared here.
    """
    target = get_item(barcode)
    if not target:
        return False
    target.checkout = False
    target.last_user = user
    if loc is not None:
        target.loc = loc
    target.save()
    return target

        
def create_location(name, barcode, parent=None, description=None):
    """Insert a location whose primary key (``locationid``) is *barcode*.

    Args:
        name: display name (need not be unique).
        barcode: the location's barcode, used as its primary key.
        parent: enclosing location instance, or None for a top-level one.
        description: optional free text.

    Returns:
        The saved location, or False when the database rejects the insert
        (e.g. *barcode* already exists).
    """
    loc = location(name=name, locationid=barcode, parent=parent, description=description)
    try:
        loc.save(force_insert=True)
    except DatabaseError as exc:
        # Was a bare except that hid every error; report database failures
        # like the other create_* helpers and let programming errors surface.
        print("Could not create location " + str(barcode) + ": " + str(exc))
        return False
    print(loc.name, loc.locationid)
    return loc

def _find_parent(loc, parent):
    if hasattr(loc, 'parent'):
        if loc.parent.locationid == parent.locationid:
            return True
        else:
            return _find_parent(loc.parent, parent)
    else:
        return False

def get_location(name, parent=None):
    """Find a location by display *name*.

    Without *parent*, succeed only when exactly one location has that name.
    With *parent*, return the first same-named location that has *parent* as
    an ancestor (see _find_parent).

    Returns:
        The matching location, or False when there is no unambiguous match
        or the lookup fails.
    """
    try:
        candidates = location.select().where(location.name == name)
        if parent is None:
            return candidates[0] if len(candidates) == 1 else False
        for loc in candidates:
            if _find_parent(loc, parent):
                return loc
        return False
    except Exception:
        # Deliberate best-effort lookup: any DB or traversal error counts as
        # "not found". Unlike the old bare except, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return False
    
def get_location_id(barcode):
    """Return the location whose ``locationid`` exactly equals *barcode*.

    Returns:
        The location, or False when *barcode* is empty or not a string, when
        nothing matches, or when the lookup hits a database error.
    """
    # Non-strings (e.g. None) never matched in the old implementation either.
    if not isinstance(barcode, str) or not barcode:
        return False
    try:
        # Primary-key lookup instead of scanning every location in Python.
        loc = location.get_or_none(location.locationid == barcode)
    except (DatabaseError, InterfaceError):
        return False
    # Re-check in Python: the database collation may compare case-insensitively,
    # while this function has always required an exact match.
    if loc is None or loc.locationid != barcode:
        return False
    return loc

def get_user(name):
    """Resolve *name* to a user: first by username, then by display name.

    Each lookup only counts if it yields exactly one row. Returns the user,
    or False when neither lookup does.
    """
    for column in (user.username, user.name):
        matches = user.select().where(column == name)
        if len(matches) == 1:
            return matches[0]
    return False

def user_login(username, password):
    """Return True if *password* matches the stored password for *username*.

    *username* is resolved with get_user(), so a unique display name also
    works. Returns False for unknown users or a non-string *password*.

    NOTE(review): passwords are stored in plain text (see create_user); this
    only hardens the comparison.
    """
    account = get_user(username)  # renamed from ``user`` so the model isn't shadowed
    if not account or not isinstance(password, str):
        return False
    # Constant-time comparison so response timing doesn't reveal how much of
    # the password matched. compare_digest only accepts ASCII-only str, so
    # both sides are encoded to bytes first.
    return hmac.compare_digest(account.password.encode("utf-8"),
                               password.encode("utf-8"))
    
def check_db_connection():
    """Run a query against the office table so a broken connection shows up.

    Nothing is caught here, so any database error propagates to the caller.
    Returns None.
    """
    # Queries are lazy; taking len() forces this one to execute.
    _row_count = len(office.select())

def test():
    """Manual smoke test: seed sample rows through the public helpers, then list them.

    Writes to whatever database ``db`` points at, so run it only against a
    scratch instance. init() must run first (the ``__main__`` guard does this)
    because create_item() adds documents to the search index. The previous
    version called helpers with stale signatures and crashed immediately.
    """
    create_user("Costa Aralis", "caralis", "12345")

    # create_location() returns False when the id is already taken (e.g. on a
    # second run), so fall back to fetching the existing row.
    chicago = create_location("Chicago CIC", "test-chicago-cic") or get_location_id("test-chicago-cic")
    create_location("Shelf 2A", "test-shelf-2a", parent=chicago or None)

    # manufacturer is a NOT NULL column, so it has to be supplied.
    create_item("BRS50-00242Q2Q-STCY99HHSESXX.X.XX", None, "Chicago CIC", "12345678",
                locationid="test-shelf-2a", manufacturer="unknown")
    create_item("BRS50-00242W2W-STCY99HHSESXX.X.XX", None, "Chicago CIC", "123456789",
                locationid="test-chicago-cic", manufacturer="unknown")

    print("Querying...")
    for itm in item.select().where(item.fullname.startswith("BRS50")):
        print(itm.fullname)


def _main():
    """Script entry point: connect and build the index, then run the smoke test."""
    init()
    test()


if __name__ == "__main__":
    _main()