import calendar
import datetime
import os

import numpy as np
import tqdm
import xarray as xr
from pymongo import MongoClient

def is_leap_year(year):
    """Return True if *year* is a leap year in the Gregorian calendar.

    Divisible by 4, except century years, which must also be divisible
    by 400 (so 2000 is a leap year but 1900 is not).

    Args:
        year: Calendar year as an integer.

    Returns:
        bool: True for leap years, False otherwise.
    """
    return calendar.isleap(year)

# The connection string embeds credentials, so it is read from the environment
# instead of being committed to source control. The password that used to be
# hard-coded here is exposed in the repository history and should be rotated.
mongo_uri = os.environ.get("MONGODB_URI")
if not mongo_uri:
    raise SystemExit("Set MONGODB_URI to the MongoDB connection string (mongodb+srv://...).")

client = MongoClient(mongo_uri)

db = client["gdd_database"]

# Per-year documents; the collection is emptied and re-indexed on every run.
gdd = db["gdd"]
gdd.drop()

gdd.create_index([("location", "2dsphere")])
gdd.create_index([("year", 1)])

# Latitude/longitude axes of the grid, indexed by the `prism_lat` / `prism_lon`
# positions stored in every document. Values are materialized with `.values`
# inside the `with` so the file handle is released afterwards.
with xr.open_dataset("coords.nc") as coords:
    lat = coords.latitude.values
    lon = coords.longitude.values

# Reverse latitude; presumably coords.nc stores it in the opposite order from
# the data grid's rows -- TODO confirm against the temps_*.nc files.
lat = lat[::-1]

# Upload one document per (year, grid cell) to the `gdd` collection for
# 1981-2020. Each document holds that cell's daily min/max temperatures on a
# 365-day calendar (Feb 29 removed in leap years).
years = list(range(1981, 2020 + 1))

# FORCE LOCATIONS TO COLLEGE PARK, LAT 38.99 LON -76.94 BECAUSE OF ATLAS LIMIT.
# Only grid cells inside this lat/lon box are uploaded; when the limit is
# removed, select from every grid cell instead of the box. The box depends
# only on the shared lat/lon axes, so it is computed once for all years.
box_lat = np.nonzero((38.5 < lat) & (lat < 39.5))[0]
box_lon = np.nonzero((-77 < lon) & (lon < -76))[0]

for year in years:
    print(year)
    # Read both fields into memory, then release the file handle.
    with xr.open_dataset("data/temps_%s.nc" % year) as data:
        tmins = data.tmin.values
        tmaxs = data.tmax.values

    # (lat index, lon index) of box cells whose tmin is non-NaN on at least
    # one day of the year (checked before Feb 29 is removed).
    has_data = ~np.isnan(tmins[:, box_lat][:, :, box_lon]).all(axis=0)
    locations = [
        (int(box_lat[r]), int(box_lon[c])) for r, c in zip(*np.nonzero(has_data))
    ]

    if is_leap_year(year):
        # Drop Feb 29 (index 59 = 31 Jan days + 28 Feb days) from each array's
        # OWN data. The previous code spliced tmaxs' March-onward values into
        # tmins here, so leap-year min_temps were actually max temperatures.
        tmins = np.concatenate([tmins[:59], tmins[60:]], axis=0)
        tmaxs = np.concatenate([tmaxs[:59], tmaxs[60:]], axis=0)

    # Jan 1 plus (number of stored days - 1); identical for every cell this year.
    # NOTE(review): because Feb 29 was just dropped, in leap years this is one
    # day before the real last observation (e.g. 2020-12-30 for a full-year
    # 2020 file). Kept as-is in case consumers rely on it -- confirm intent.
    last_date = datetime.datetime(year, 1, 1) + datetime.timedelta(days=tmins.shape[0] - 1)

    batch = []
    for lat_i, lon_i in tqdm.tqdm(locations):
        batch.append({
            "location": {
                "type": "Point",
                "coordinates": [float(lon[lon_i]), float(lat[lat_i])],
            },
            "prism_lat": lat_i,
            "prism_lon": lon_i,
            "last_date": last_date,
            "year": int(year),
            "min_temps": [float(v) for v in tmins[:, lat_i, lon_i]],
            "max_temps": [float(v) for v in tmaxs[:, lat_i, lon_i]],
            "_id": "%s_%s_%s" % (year, lat_i, lon_i),
        })
        # Insert in batches of 100 documents.
        if len(batch) >= 100:
            gdd.insert_many(batch)
            batch = []

    # Flush the final partial batch.
    if batch:
        gdd.insert_many(batch)

### 30 YEAR NORMALS ###
### Covers from 1981-2010 ###

db = client["gdd_database"]

# From here on `gdd` refers to the normals collection, which is wiped and
# given the same two indexes as the per-year collection.
gdd = db["normals"]
gdd.drop()

for index_keys in ([("location", "2dsphere")], [("year", 1)]):
    resp = gdd.create_index(index_keys)

# Mean daily min/max temperature per grid cell over the 1981-2010 normal
# period, on a 365-day calendar, accumulated as float64.
normal_years = range(1981, 2010 + 1)

# Allocated from the first file's grid shape instead of a hard-coded
# (365, 621, 1405).
single_year_min = None
single_year_max = None

for year in normal_years:
    print(year)
    # Read both fields into memory, then release the file handle.
    with xr.open_dataset("data/temps_%s.nc" % year) as data:
        tmins = data.tmin.values
        tmaxs = data.tmax.values

    if is_leap_year(year):
        # Drop Feb 29 (index 59 = 31 Jan days + 28 Feb days) from each array's
        # OWN data. The previous code spliced tmaxs' March-onward values into
        # tmins here, corrupting the minimum normals in every leap year.
        tmins = np.concatenate([tmins[:59], tmins[60:]], axis=0)
        tmaxs = np.concatenate([tmaxs[:59], tmaxs[60:]], axis=0)

    if single_year_min is None:
        single_year_min = np.zeros(tmins.shape)
        single_year_max = np.zeros(tmaxs.shape)

    # A NaN in any year propagates through the sum, so such cells stay NaN.
    single_year_max += tmaxs / len(normal_years)
    single_year_min += tmins / len(normal_years)

# FORCE LOCATIONS TO COLLEGE PARK, LAT 38.99 LON -76.94 BECAUSE OF ATLAS LIMIT.
# Only grid cells inside this lat/lon box are uploaded; when the limit is
# removed, select from every grid cell instead of the box.
box_lat = np.nonzero((38.5 < lat) & (lat < 39.5))[0]
box_lon = np.nonzero((-77 < lon) & (lon < -76))[0]

# (lat index, lon index) of box cells whose mean max temperature is non-NaN
# on at least one day.
has_data = ~np.isnan(single_year_max[:, box_lat][:, :, box_lon]).all(axis=0)
locations = [
    (int(box_lat[r]), int(box_lon[c])) for r, c in zip(*np.nonzero(has_data))
]

# Normals ids are prefixed with 2010, the final year of the 1981-2010 period
# (the value the leftover loop variable `year` held in the original code),
# e.g. "2010_<lat index>_<lon index>_normal".
NORMALS_ID_YEAR = 2010

batch = []
for lat_i, lon_i in tqdm.tqdm(locations):
    batch.append({
        "location": {
            "type": "Point",
            "coordinates": [float(lon[lon_i]), float(lat[lat_i])],
        },
        "prism_lat": lat_i,
        "prism_lon": lon_i,
        "min_temps": [float(v) for v in single_year_min[:, lat_i, lon_i]],
        "max_temps": [float(v) for v in single_year_max[:, lat_i, lon_i]],
        "_id": "%s_%s_%s_normal" % (NORMALS_ID_YEAR, lat_i, lon_i),
    })
    # Insert in batches of 100 documents.
    if len(batch) >= 100:
        gdd.insert_many(batch)
        batch = []

# Flush the final partial batch.
if batch:
    gdd.insert_many(batch)
