import datetime, time, math
import os

import numpy as np
import tqdm
import xarray as xr
from numba import njit
from pymongo import MongoClient


def is_leap_year(year):
    """Return True if *year* is a leap year under the Gregorian rules.

    Years divisible by 4 are leap years, except century years, which
    only count when they are also divisible by 400.
    """
    if year % 4:
        # Not a multiple of 4: never a leap year.
        return False
    if year % 100:
        # Multiple of 4 that is not a century year.
        return True
    # Century year: leap only when divisible by 400.
    return not (year % 400)

# Connection string for the MongoDB deployment. Prefer supplying it through the
# GDD_MONGO_URI environment variable so credentials stay out of source control.
# NOTE(review): the fallback below embeds a username/password in the repository.
# It is kept only so existing runs keep working; rotate that credential and
# delete the literal once GDD_MONGO_URI is configured everywhere.
_DEFAULT_MONGO_URI = "mongodb+srv://gdd-server:u8i3icLAJXjZEhTs@cluster0.wdxf4.mongodb.net"
client = MongoClient(os.environ.get("GDD_MONGO_URI", _DEFAULT_MONGO_URI))

db = client["gdd_database"]

# gdd = db.gdd

# gdd.drop()

# gdd = db["gdd"]

# resp = gdd.create_index([ ("location", "2dsphere") ])
# resp = gdd.create_index([ ("year", 1) ])

# Load the latitude / longitude coordinate vectors from coords.nc.
coords = xr.open_dataset("coords.nc")
lon = coords.longitude.data
# Latitude is stored reversed relative to the file.
# NOTE(review): presumably flipped so its order matches the row order of the
# temperature grids read later - confirm against the data files.
lat = coords.latitude.data[::-1]



# years = list(range(1981, 2020 + 1))

# for year in years:
#     soy = np.datetime64("%s-01-01" % year)
#     print (year)
#     data = xr.open_dataset("data/temps_%s.nc" % year)
#     x = np.where(~np.isnan(np.nanmean(data.tmin.data, axis=0)))
#     # x = [(a, b) for a, b in zip(lat_a, lon_a)] # fix to x[0], x[1] when atlas limit removed
#     # FORCE LOCATIONS TO COLLEGE PARK, LAT 38.99 LON -76.94 BECAUSE OF ATLAS LIMIT

#     a1 = np.where(38.5 < lat)[0].tolist()
#     a2 = np.where(lat < 39.5)[0].tolist()
#     lat_a = np.array(list(set(a1) & set(a2)))

#     a1 = np.where(-77 < lon)[0].tolist()
#     a2 = np.where(lon < -76)[0].tolist()
#     lon_a = np.array(list(set(a1) & set(a2)))

#     x1 = np.array(np.meshgrid(lat_a, lon_a)).T.reshape(len(lat_a) * len(lon_a), 2).tolist()
#     x1 = [(z[0], z[1]) for z in x1]
#     x2 = [(a, b) for a, b in zip(x[0], x[1])] # fix to x = [..... (x[0], x[1])] and all limiting stuff above and below when atlas limit removed

#     x = list(set(x1) & set(x2))

#     tmins = data.tmin.data
#     tmaxs = data.tmax.data 


#     # if is_leap_year(year): # extra day in leap year screws everything up

#     #     tmin_1 = tmins[:59]
#     #     tmin_2 = tmins[60:]

#     #     tmax_1 = tmaxs[:59]
#     #     tmin_2 = tmaxs[60:]

#     #     tmins = np.concatenate([tmin_1, tmin_2], axis=0)
#     #     tmaxs = np.concatenate([tmax_1, tmin_2], axis=0)


#     locs = []

#     count = 0
#     for i in tqdm.tqdm(x):
#         if len(locs) % 100 == 0 and len(locs) != 0:
#             new_result = gdd.insert_many(locs)
#             locs = []
        
#         tmin_ = tmins[:, i[0], i[1]]
#         tmax_ = tmaxs[:, i[0], i[1]]
        
#         lat_ = lat[i[0]]
#         lon_ = lon[i[1]]

#         a = i

#         t = {}

#         _id = str(year) + "_"

#         _id += str(a[0]) + "_" + str(a[1])
        
#         t["location"] = {"type": "Point", "coordinates": [float(lon_), float(lat_)]}
#         t["prism_lat"] = int(a[0])
#         t["prism_lon"] = int(a[1])
        
#         t["last_date"] = datetime.datetime.strptime(str(soy + np.timedelta64(len(tmin_) - 1, "D")) , "%Y-%m-%d")
#         t["year"] = int(year)
#         t["min_temps"] = list([float(a) for a in tmin_])
#         t["max_temps"] = list([float(a) for a in tmax_])
#         t["_id"] = _id
     
#         locs.append(t)
        
#         count += 1

#     if len(locs) != 0:
#         new_result = gdd.insert_many(locs)

### 30 YEAR NORMALS ###
### Covers from 1981-2010 ###


@njit(fastmath=True, parallel=True)
def elemetwise(tmins, tmaxs):
    """Return the element-wise midpoint of *tmins* and *tmaxs* (their sum halved)."""
    total = np.add(tmins, tmaxs)
    return np.divide(total, 2)
# @njit(fastmath=True, parallel=False)
def std(arr, fill_arr, j_size=69, k_size=281):
    """Fill *fill_arr* with the standard deviation of *arr* along axis 0.

    The reduction is performed one tile at a time over the last two axes, so
    only one ``j_size`` x ``k_size`` slab of *arr* is reduced per step
    (presumably to bound memory when *arr* is a disk-backed memmap).

    Args:
        arr: array indexed as ``arr[:, :, j, k]``; axis 0 is reduced.
        fill_arr: output array indexed as ``fill_arr[:, j, k]``; it is
            written in place and its axes 1 and 2 define the tiling range.
        j_size: tile extent along axis 1 of *fill_arr*.
        k_size: tile extent along axis 2 of *fill_arr*.

    Returns:
        *fill_arr*, after every tile has been written.
    """
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    total_cells = int(np.prod(arr.shape[1:]))

    with tqdm.tqdm(total=total_cells) as pbar:
        for j in range(0, fill_arr.shape[1], j_size):
            for k in range(0, fill_arr.shape[2], k_size):
                tile = arr[:, :, j:j+j_size, k:k+k_size]
                tile_std = np.std(tile, axis=0)

                fill_arr[:, j:j+j_size, k:k+k_size] = tile_std
                # Advance by the cells actually processed; edge tiles can be
                # smaller than j_size x k_size.
                pbar.update(tile_std.size)
    return fill_arr

# Grid indices inside the College Park window (38.5 < lat < 39.5 and
# -77 < lon < -76). Each axis is selected as the intersection of two
# one-sided tests on the coordinate vector.
lat_above = set(np.where(38.5 < lat)[0].tolist())
lat_below = set(np.where(lat < 39.5)[0].tolist())
lat_a = np.array(list(lat_above & lat_below))

lon_above = set(np.where(-77 < lon)[0].tolist())
lon_below = set(np.where(lon < -76)[0].tolist())
lon_a = np.array(list(lon_above & lon_below))

# x1 = np.array(np.meshgrid(lat_a, lon_a)).T.reshape(len(lat_a) * len(lon_a), 2).tolist()
# x1 = [(z[0], z[1]) for z in x1]
# x2 = [(a, b) for a, b in zip(x[0], x[1])] # fix to x = [..... (x[0], x[1])] and all limiting stuff above and below when atlas limit removed

# x = list(set(x1) & set(x2))

# print (lat.shape)
# exit()


# Start from an empty "normals" collection on every run: drop whatever is
# there, then take a fresh handle to it.
db = client["gdd_database"]
db["normals"].drop()
gdd = db["normals"]

# A 2dsphere index on "location" and an ascending index on "year".
for index_keys in ([("location", "2dsphere")], [("year", 1)]):
    resp = gdd.create_index(index_keys)

# Disk-backed float32 buffers holding one daily field per year, laid out as
# (year slot, day of year, lat index, lon index). mode="w+" creates or
# overwrites the backing files.
# NOTE: np.memmap writes raw bytes with no header, so despite the ".npy"
# extension these files cannot be read back with np.load.
grid_shape = (len(lat_a), len(lon_a))
year_gdd_base = np.memmap("year_gdd_base.npy", mode="w+", shape=(40, 366) + grid_shape, dtype=np.float32)
corn_year_gdd_base = np.memmap("corn_year_gdd_base.npy", mode="w+", shape=(40, 366) + grid_shape, dtype=np.float32)

# Running sums of the leap-day minimum / maximum fields and the number of leap
# years seen. They are accumulated with "+=", so they must start at zero
# (np.empty would leave arbitrary values in them). float32 matches the
# temperature arrays being summed and avoids float16 rounding of the totals.
ld_min = np.zeros(grid_shape, dtype=np.float32)
ld_max = np.zeros(grid_shape, dtype=np.float32)
ld_count = 0

# Fill one slot of the per-year buffers for each year from 1981 through 2020.
for year_idx, year in enumerate(range(1981, 2020 + 1)):
    print (year)

    # The context manager closes the NetCDF file once the cropped copies exist.
    with xr.open_dataset("data/temps_%s.nc" % year) as data:
        tmins = data.tmin.data
        tmaxs = data.tmax.data

        # Crop to the lat_a x lon_a window, keeping every day; fancy indexing
        # copies the data, so nothing refers to the file afterwards.
        window = np.ix_(list(range(len(tmaxs))), lat_a, lon_a)
        tmaxs = tmaxs[window].astype(np.float32)
        tmins = tmins[window].astype(np.float32)

    ## idx 59 is leap day

    if is_leap_year(year):
        # TODO One for normal one for reg
        # Accumulate the leap-day fields so they can be averaged after the loop.
        ld_min += tmins[59]
        ld_max += tmaxs[59]
        ld_count += 1
    else:
        # Pad non-leap years with a NaN day at index 59 so every year has the
        # same 366-day layout.
        tmins = np.insert(tmins, 59, np.nan, axis=0)
        tmaxs = np.insert(tmaxs, 59, np.nan, axis=0)

    # Plain daily mean temperature.
    year_gdd_base[year_idx] = elemetwise(tmins, tmaxs)

    # Corn variant: clamp BOTH the minimum and the maximum into [50, 86] before
    # averaging, the same way the leap-day values are clamped with np.clip
    # later in this script. The previous masks only fired when tmin > 86 or
    # tmax < 50, so a day such as tmin=40 / tmax=90 was left unclamped.
    # NaN entries stay NaN under np.clip.
    corn_tmins = np.clip(tmins, 50, 86)
    corn_tmaxs = np.clip(tmaxs, 50, 86)
    corn_year_gdd_base[year_idx] = elemetwise(corn_tmins, corn_tmaxs)


print ("generating means")
# Average over the year axis: the first 30 year slots feed the "normals",
# every slot feeds the "global" means.
corn_normals_gdd_base = corn_year_gdd_base[:30].mean(axis=0)
normals_gdd_base = year_gdd_base[:30].mean(axis=0)

gdd_base_mean = year_gdd_base.mean(axis=0)
corn_global_mean = corn_year_gdd_base.mean(axis=0)

# Turn the accumulated leap-day sums into averages over the leap years seen.
ld_min, ld_max = ld_min / ld_count, ld_max / ld_count

# Plain leap-day mean first, then the corn variant from the clamped extremes.
ld_gdd_base = (ld_min + ld_max) / 2
ld_min, ld_max = np.clip(ld_min, 50, 86), np.clip(ld_max, 50, 86)
corn_ld_gdd_base = (ld_min + ld_max) / 2

# Index 59 is the leap-day slot. Overwrite it in every aggregate, and in every
# year of the per-year buffers, with the leap-day averages.
for daily in (normals_gdd_base, gdd_base_mean):
    daily[59] = ld_gdd_base
for daily in (corn_normals_gdd_base, corn_global_mean):
    daily[59] = corn_ld_gdd_base

year_gdd_base[:, 59] = ld_gdd_base
corn_year_gdd_base[:, 59] = corn_ld_gdd_base
print ("completed first bit")

# Standard deviation across the year axis for both per-year buffers. std()
# writes into the array it is given and returns that same array.
std_shape = (366, len(lat_a), len(lon_a))

print ("corn std")
corn_year_gdd_std = std(corn_year_gdd_base, np.empty(std_shape))
print ("gdd std")
year_gdd_base_std = std(year_gdd_base, np.empty(std_shape))
print ("done")

# Cells that have a value on at least one day: nanmean over axis 0 is NaN only
# where every entry along that axis is NaN.
x = np.where(~np.isnan(np.nanmean(corn_normals_gdd_base, axis=0)))

# FORCE LOCATIONS TO COLLEGE PARK, LAT 38.99 LON -76.94 BECAUSE OF ATLAS LIMIT
# (lat_a / lon_a, computed earlier, already restrict the arrays to that window.)

# Every (lat index, lon index) pairing of the window, latitude varying slowest.
x1 = [(lat_idx, lon_idx) for lat_idx in lat_a.tolist() for lon_idx in lon_a.tolist()]

# (row, column) positions of the cells selected above.
# fix to x = [..... (x[0], x[1])] and all limiting stuff above and below when atlas limit removed
x2 = list(zip(x[0], x[1]))

# x = list(set(x1) & set(x2))
# print (x2)

# tmins = single_year_min
# tmaxs = single_year_max

def _float_list(values):
    """Return *values* as a list of built-in floats."""
    return [float(v) for v in values]


# Build one document per selected cell and insert them in batches of 100.
locs = []

for i in tqdm.tqdm(x2):
    # Flush a full batch before building the next document.
    if len(locs) >= 100:
        gdd.insert_many(locs)
        locs = []

    row, col = i

    # x2 holds positions inside the cropped (lat_a x lon_a) arrays, while lat
    # and lon are the full-grid coordinate vectors. Map back through
    # lat_a / lon_a first; indexing lat / lon with the cropped position
    # directly (as before) read the coordinates of the wrong grid cell.
    grid_row = int(lat_a[row])
    grid_col = int(lon_a[col])

    t = {}
    # GeoJSON point: longitude first, then latitude.
    t["location"] = {"type": "Point", "coordinates": [float(lon[grid_col]), float(lat[grid_row])]}
    # Full-grid indices, as in the (commented-out) yearly loader above.
    t["prism_lat"] = grid_row
    t["prism_lon"] = grid_col

    t["normals_gdd_base"] = _float_list(normals_gdd_base[:, row, col])
    t["corn_normals_gdd_base"] = _float_list(corn_normals_gdd_base[:, row, col])
    t["global_mean"] = _float_list(gdd_base_mean[:, row, col])
    t["corn_global_mean"] = _float_list(corn_global_mean[:, row, col])
    t["corn_std"] = _float_list(corn_year_gdd_std[:, row, col])
    t["main_std"] = _float_list(year_gdd_base_std[:, row, col])

    # NOTE(review): `year` is whatever the yearly loop left behind (its last
    # value); the prefix is kept so the _id format does not change shape.
    t["_id"] = "%s_%s_%s_normal" % (year, grid_row, grid_col)

    locs.append(t)

# Insert whatever is left over from the last partial batch.
if locs:
    gdd.insert_many(locs)
