import contextlib
import io
import logging
import os
import os.path
import pickle
import sqlite3

import numpy as np

import src.kill_all as kill_all
from src.timeGeo import timeGeoObject

class ImageTableNotInDbException(Exception):
    """Raised when an SQLite database file exists but has no ``image`` table.

    The formatted text is kept on ``self.message`` and also passed to
    ``Exception`` so that ``str(exc)`` shows it.
    """

    def __init__(self, dbName=""):
        text = "Image table not in db {}".format(dbName)
        self.message = text
        super().__init__(text)

def adapt_array(arr):
    """Serialise *arr* in numpy ``.npy`` format and wrap it as an SQLite BLOB.

    Used as the sqlite3 adapter for ``np.ndarray`` parameters.
    Source: http://stackoverflow.com/a/31312102/190597 (SoulNibbler)
    """
    buffer = io.BytesIO()
    np.save(buffer, arr)
    # getvalue() returns the whole buffer, equivalent to seek(0) + read().
    return sqlite3.Binary(buffer.getvalue())

def convert_array(text):
    """Decode ``.npy``-format bytes (as written by ``np.save``) into a numpy array.

    Used as the sqlite3 converter for columns declared with type ``array``.
    A freshly constructed BytesIO is already positioned at offset 0.
    """
    return np.load(io.BytesIO(text))

# Adapter: numpy arrays passed directly as query parameters are stored as
# .npy-encoded BLOBs (see adapt_array), not as TEXT.
# NOTE: dbs_class does not go through this adapter; it stores raw
# ``arr.tobytes()`` buffers wrapped in sqlite3.Binary itself.
sqlite3.register_adapter(np.ndarray, adapt_array)

# Converter: values of columns whose declared type is "array" are decoded back
# into numpy arrays on SELECT. This only happens on connections opened with
# ``detect_types=sqlite3.PARSE_DECLTYPES``; no connection opened in this module
# passes that flag, and the image table declares its array columns as BLOB.
sqlite3.register_converter("array", convert_array)

class dbs_class:
    """Access layer for an SQLite file holding one row per image in table ``image``.

    Every public method opens a short-lived connection through ``_connect``,
    which commits (or rolls back) and always closes it.
    """

    def __init__(self, db_file_name: str):
        self.db_file_name = db_file_name
        # Not used by any method of this class; kept in case external code reads them.
        self.connector = None
        self.cursor = None
        self.check_create_db(db_file_name)
        with self._connect() as connector:
            rows = connector.execute("SELECT name,type FROM pragma_table_info('image')").fetchall()
        # column name -> declared SQL type, in table order (the order of SELECT *)
        self.image_columns = dict(rows)
        # BLOB columns are decoded as float64 arrays by _row_to_dict
        self.image_arrays = [name for name, col_type in self.image_columns.items() if 'BLOB' in col_type]

    @contextlib.contextmanager
    def _connect(self, path=None):
        """Yield a connection to *path* (default: ``self.db_file_name``).

        The connection commits on normal exit and rolls back if the body raises.
        It is always closed afterwards; ``sqlite3.Connection.__exit__`` alone
        does not close it.
        """
        connector = sqlite3.connect(self.db_file_name if path is None else path)
        try:
            with connector:
                yield connector
        finally:
            connector.close()

    def _row_to_dict(self, row):
        """Map a ``SELECT *`` row of ``image`` to {column: value}.

        BLOB columns are decoded with ``np.frombuffer(..., dtype=np.float64)``.
        The resulting arrays are read-only views of the fetched bytes.
        """
        record = dict(zip(self.image_columns.keys(), row))
        for col in self.image_arrays:
            record[col] = np.frombuffer(record[col], dtype=np.float64)
        return record

    def check_create_db(self, name: str):
        """Ensure *name* is a usable image database, creating it if missing.

        - If no file exists at *name*, the ``image`` table is created in
          ``self.db_file_name``.
        - If the file exists but has no ``image`` table,
          ImageTableNotInDbException is raised internally.
        - Any failure is logged and ``kill_all.kill_db_agent()`` is called. The
          exception is not re-raised.
        """
        try:
            if not os.path.isfile(name):
                self.create_image_db()
                return
            with self._connect(name) as connector:
                found = connector.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='image'"
                ).fetchall()
            if not found:
                raise ImageTableNotInDbException(name)
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            logging.exception("Image database check failed for %s", name)
            kill_all.kill_db_agent()

    def create_image_db(self):
        """Create the ``image`` table in ``self.db_file_name``.

        Raises sqlite3.OperationalError if the table already exists.
        """
        create_string = """CREATE TABLE image(
            hash BIGINT PRIMARY KEY,
            fullpath TEXT  UNIQUE NOT NULL,
            id TEXT UNIQUE NOT NULL,
            image TEXT NOT NULL UNIQUE,
            range REAL,
            perc BLOB,
            time BLOB,
            lat BLOB,
            lon BLOB,
            altitude BLOB,
            yaw BLOB,
            speed BLOB,
            angv_x BLOB,
            angv_y BLOB,
            angv_z BLOB,
            bb BLOB,
            data BLOB
        );"""
        with self._connect() as connector:
            connector.execute(create_string)

    def insert_image(self, timeGeo: timeGeoObject):
        """Insert *timeGeo* as a new row of the ``image`` table.

        Does nothing if a row with the same ``hash`` already exists.

        Array attributes are stored as their raw ``tobytes()`` buffers, with no
        dtype or shape recorded. The readers in this class decode them as
        float64; presumably the arrays are float64 (TODO confirm).
        """
        def blob(arr):
            return sqlite3.Binary(arr.tobytes())

        with self._connect() as connector:
            cursor = connector.cursor()
            # NOTE(review): no adapter for numpy integer scalars is registered in
            # this module. If timeGeo.hash is an np.int64, sqlite3 may bind it as
            # an 8-byte BLOB rather than an INTEGER, because numpy scalars expose
            # the buffer protocol. Lookups stay self-consistent only while every
            # caller passes the same type. Confirm before normalising with int(),
            # since existing DB files would depend on the current encoding.
            cursor.execute("SELECT id FROM image WHERE hash = ?", (timeGeo.hash,))
            if cursor.fetchone() is not None:
                return
            record = {
                'hash': timeGeo.hash,
                'fullpath': str(timeGeo.fname),
                'image': str(timeGeo.image),
                'id': timeGeo.id,
                'perc': blob(timeGeo.perc),
                'time': blob(timeGeo.time),
                'lat': blob(timeGeo.lat),
                'lon': blob(timeGeo.lon),
                'altitude': blob(timeGeo.altitude),
                'yaw': blob(timeGeo.yaw),
                'speed': blob(timeGeo.speed),
                'angv_x': blob(timeGeo.angv_x),
                'angv_y': blob(timeGeo.angv_y),
                'angv_z': blob(timeGeo.angv_z),
                'range': timeGeo.range,
                'bb': blob(timeGeo.rtree_bb),
                'data': blob(timeGeo.data),
            }
            # Column names come only from the fixed keys above, so formatting
            # them into the SQL text is safe; the values are bound as parameters.
            columns = ','.join(record)
            placeholders = ','.join('?' * len(record))
            cursor.execute(f"insert into image ({columns}) VALUES ({placeholders})",
                           tuple(record.values()))
        # The commit happens when the connection context in _connect exits.

    def read_image(self, hash_: np.int64):
        """Return the stored row for *hash_* wrapped in a timeGeoObject, or None.

        NOTE(review): this method calls ``timeGeoObject(None, record)`` and
        leaves ``data`` flat. is_in_image_db calls
        ``timeGeoObject(hash, None, record)`` and reshapes ``data`` to (-1, 10).
        Confirm which constructor form is intended.
        """
        with self._connect() as connector:
            row = connector.execute("SELECT * FROM image WHERE hash = ?", (hash_,)).fetchone()
        logging.debug("read_image(%r) row: %r", hash_, row)
        if row is None:
            return None
        return timeGeoObject(None, self._row_to_dict(row))

    def check_image(self, image_id: str, timegeo: np.array) -> timeGeoObject:
        """Debug helper: compare a stored row against a freshly supplied *timegeo* array.

        Mismatches are reported via ``logging.info``: "<<<..." for a shape
        mismatch and ">>>..." for a value mismatch. Always returns None.

        NOTE(review): as written this cannot run against the schema created by
        create_image_db:
        - it queries the column ``image_name``, but the table has ``id`` and
          ``image``;
        - it reads ``self.image_funcs`` and calls
          ``self.create_maps_for_images``, neither of which is defined in this
          class;
        - it reads ``res2['times']``, but the column is ``time``.
        The logic is left unchanged pending clarification of the intent.
        """
        with self._connect() as connector:
            row = connector.execute("SELECT * FROM image WHERE image_name = ?", (image_id,)).fetchone()
        if row is None:
            return
        res2 = self._row_to_dict(row)

        # SECURITY: pickle.loads executes arbitrary code for crafted input.
        # This is only safe if the database file is trusted.
        for col in self.image_funcs:
            res2[col] = pickle.loads(res2[col])

        times, t_map, perc_map = self.create_maps_for_images(timegeo)

        def cmp_arrays(a1, a2):
            if a1.shape != a2.shape:
                logging.info("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
            if not np.array_equal(a1, a2):
                logging.info(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

        cmp_arrays(res2['perc'], timegeo[:, 0])
        cmp_arrays(res2['times'], times)
        cmp_arrays(res2['lat'], timegeo[:, 2])
        cmp_arrays(res2['lon'], timegeo[:, 3])
        cmp_arrays(res2['altitude'], timegeo[:, 4])
        cmp_arrays(res2['yaw'], timegeo[:, 5])
        cmp_arrays(res2['speed'], timegeo[:, 6])
        cmp_arrays(res2['angv_x'], timegeo[:, 7])
        cmp_arrays(res2['angv_y'], timegeo[:, 8])
        cmp_arrays(res2['angv_z'], timegeo[:, 9])

        res_t_o = np.array([t_map(x) for x in timegeo[:, 0]])
        res_perc_o = np.array([perc_map(x) for x in times[::-1]])

        res_t_s = np.array([res2['t_map'](x) for x in res2['perc']])
        res_perc_s = np.array([res2['perc_map'](x) for x in res2['times'][::-1]])
        cmp_arrays(res_t_o, res_t_s)
        cmp_arrays(res_perc_o, res_perc_s)

    def create_object_db(self):
        pass

    def insert_object(self):
        pass

    def is_in_image_db(self, hash: np.int64) -> timeGeoObject:
        """Return the stored row for *hash* as a timeGeoObject, or None if absent.

        The parameter name shadows the builtin ``hash``. It is kept because
        callers may pass it by keyword.
        """
        with self._connect() as connector:
            row = connector.execute("SELECT * FROM image WHERE hash = ?", (hash,)).fetchone()
        if row is None:
            return None
        record = self._row_to_dict(row)
        # 'data' is stored flattened; restore it as rows of 10 values.
        # np.resize drops a trailing partial row (len % 10 != 0) instead of raising.
        record['data'] = np.resize(record['data'], (len(record['data']) // 10, 10))
        return timeGeoObject(hash, None, record)


def copyMemoryToDisk(fn: str, memory_db: "sqlite3.Connection | None" = None):
    """Copy an SQLite database into the file *fn* using the online backup API.

    Args:
        fn: destination database file. Its existing contents are replaced.
        memory_db: source connection, e.g. an in-memory database holding the
            data to persist. It is left open for the caller. When omitted, a
            brand-new, empty ``:memory:`` database is used as the source. That
            is the historical behaviour, and it overwrites *fn* with an empty
            database.

    Both connections opened here are closed even if the backup fails.
    """
    owns_source = memory_db is None
    source = sqlite3.connect(':memory:') if owns_source else memory_db
    try:
        target = sqlite3.connect(fn)
        try:
            source.backup(target)
        finally:
            target.close()
    finally:
        # Only close the source if it was created here; a caller-supplied
        # connection remains the caller's responsibility.
        if owns_source:
            source.close()