package services

import (
	"math"
	"sort"
	"time"

	"github.com/montanaflynn/stats"
	"github.com/tgs266/dawn-go-common/common"
	"gitlab.cs.umd.edu/dawn/go-backend/dawn-gdd/models"
	"gitlab.cs.umd.edu/dawn/go-backend/dawn-gdd/models/enums"
	"gitlab.cs.umd.edu/dawn/go-backend/dawn-gdd/persistence"
	"gitlab.cs.umd.edu/dawn/go-backend/dawn-gdd/persistence/entities"
	"gitlab.cs.umd.edu/dawn/go-backend/dawn-gdd/utils"
)

// GetStageYearData assembles the daily, non-accumulated GDD series used for
// stage forecasting at the request's location.
//
// The current-year record is converted to daily GDD values and extended once
// per CFS record returned by persistence.CfsFindByLocationMultiple. Every
// resulting series is then cut, or padded with the per-day mean GDD of the
// normals, so that it has exactly as many days as the comparison record.
//
// comparison selects the source of ComparisonAll: -1 uses the records from
// persistence.GetLastNormalsYearly; any other value is passed as the year to
// persistence.GddFindFirstByYearAndLocation.
//
// The returned StageData carries:
//   - AllGdds: one daily series per CFS record
//   - ComparisonAll: one row per day, one column per comparison record
//
// The function panics if the comparison set is empty (it indexes gs[0]).
func GetStageYearData(ctx common.DawnCtx, request models.GddRequest, comparison int) models.StageData {
	product := enums.GetProductFromString(request.Product)
	gddData := persistence.CurrentGddFindFirstByYearAndLocation(ctx, request.BuildLocation())
	gdds := utils.CalculateGddValues(gddData.MinTemps, gddData.MaxTemps, product, false)
	request.Year = gddData.AnalogYear

	norms := persistence.GetLastNormalsYearly(request.BuildLocation())

	var gs []entities.Gdd
	if comparison == -1 {
		gs = norms
	} else {
		gs = []entities.Gdd{persistence.GddFindFirstByYearAndLocation(comparison, request.BuildLocation())}
	}

	days := len(gs[0].MinTemps)
	normalMeanNonAcc := make([]float64, 0, days)
	comparisonRows := make([][]float64, 0, days)

	for i := 0; i < days; i++ {
		rowComp := make([]float64, 0, len(gs))
		for j := range gs {
			rowComp = append(rowComp, utils.CalculateSingleGdd(gs[j].MinTemps[i], gs[j].MaxTemps[i], product))
		}
		rowNormal := make([]float64, 0, len(norms))
		for j := range norms {
			rowNormal = append(rowNormal, utils.CalculateSingleGdd(norms[j].MinTemps[i], norms[j].MaxTemps[i], product))
		}
		comparisonRows = append(comparisonRows, rowComp)

		// The error is discarded exactly as before.
		// NOTE(review): confirm what stats.Mean yields when norms is empty.
		normMeanNoAccValue, _ := stats.Mean(rowNormal)
		normalMeanNonAcc = append(normalMeanNonAcc, normMeanNoAccValue)
	}

	allCfs := persistence.CfsFindByLocationMultiple(request.BuildLocation(), 4)
	// cfsMeans := persistence.CfsFindAllByLocation(request.BuildLocation())

	gddArr := make([][]float64, 0, len(allCfs))
	for _, c := range allCfs {
		cfsGddData := utils.CalculateGddValues(c.MinTemps, c.MaxTemps, product, false) // not accumulated
		// anomaly adjustment function
		// cfsGddData := utils.CalculateGddValuesCfsNormed(c.MinTemps, c.MaxTemps, product, cfsMeans.MinTemps, cfsMeans.MaxTemps, normalMeanNonAcc) // not accumulated

		// Each member gets its own backing array. Appending directly to gdds
		// would let members overwrite each other's forecast whenever gdds has
		// spare capacity, because they would all share one array.
		series := make([]float64, 0, len(gdds)+len(cfsGddData))
		series = append(series, gdds...)
		series = append(series, cfsGddData...)

		if len(series) > len(normalMeanNonAcc) {
			series = series[:len(normalMeanNonAcc)]
		} else {
			// Fill the remaining days with the normals' mean.
			series = append(series, normalMeanNonAcc[len(series):]...)
		}
		gddArr = append(gddArr, series)
	}

	// none of this data is accumulated
	return models.StageData{
		AllGdds:       gddArr,
		ComparisonAll: comparisonRows,
	}
}

// CalculateStages finds, for every stage returned by models.BuildStageMatches,
// the day on which accumulated GDD comes closest to that stage's target value,
// and returns the result binned by BinStageMatches.
//
// Accumulation starts at the planting day index and runs to the end of the
// series from GetStageYearData (product fixed to "CORN"). Each series in
// AllGdds is accumulated independently and yields its own best day. The columns
// of ComparisonAll are accumulated too, and their mean gives a single
// comparison best day.
//
// The function panics if GetStageYearData returns no series or no comparison
// rows, because both are indexed at [0].
func CalculateStages(ctx common.DawnCtx, request models.StageRequest) map[string]models.Bins {
	gddReq := models.GddRequest{
		Year:       request.PlantDate.Year(),
		Latitude:   request.Latitude,
		Longitude:  request.Longitude,
		Accumulate: false,
		Product:    "CORN",
	}
	fyData := GetStageYearData(ctx, gddReq, request.Comparison)

	start := request.PlantDate.YearDay()
	year := request.PlantDate.Year()
	// Leap years shift the start index back by one.
	// NOTE(review): presumably the stored series have no Feb 29 entry — confirm.
	if year%4 == 0 && year%100 != 0 || year%400 == 0 {
		start -= 1
	}

	state := map[string]models.StageStateInner{}
	stageMatches := models.BuildStageMatches(request.Mode, request.Value, start, fyData, request)

	memberAcc := make([]float64, len(fyData.AllGdds))
	comparisonAcc := make([]float64, len(fyData.ComparisonAll[0]))
	for day := start; day < len(fyData.AllGdds[0]); day++ {
		for r, series := range fyData.AllGdds {
			memberAcc[r] += series[day]
		}
		for j := range comparisonAcc {
			comparisonAcc[j] += fyData.ComparisonAll[day][j]
		}

		// Error discarded as before.
		comparisonMean, _ := stats.Mean(comparisonAcc)

		for stage, target := range stageMatches {
			comparisonDist := math.Abs(target - comparisonMean)

			tracked, seen := state[stage]
			if !seen {
				// Seed with the current day so the first day is a real
				// candidate. Seeding the indexes with 0 would report a day
				// before planting whenever the first day is the closest one.
				dists := make([]float64, len(memberAcc))
				hists := make([]int, len(memberAcc))
				for r, acc := range memberAcc {
					dists[r] = math.Abs(target - acc)
					hists[r] = day
				}
				state[stage] = models.StageStateInner{
					Dists:          dists,
					Hists:          hists,
					NormalMeanDist: comparisonDist,
					NormalMeanIdx:  day,
				}
				continue
			}

			if comparisonDist < tracked.NormalMeanDist {
				tracked.NormalMeanDist = comparisonDist
				tracked.NormalMeanIdx = day
			}
			for r, acc := range memberAcc {
				if d := math.Abs(target - acc); d < tracked.Dists[r] {
					tracked.Hists[r] = day
					tracked.Dists[r] = d
				}
			}
			state[stage] = tracked
		}
	}

	return BinStageMatches(state, year, start, request.PlantDate)
}

// AvgDiff returns the mean gap between neighbouring values of data after
// sorting it in ascending order.
//
// data is sorted in place as a side effect. With fewer than two values there
// are no gaps and the result is 0; without this guard the division would be
// 0/0, and the resulting NaN would leak into callers that convert it to an int.
func AvgDiff(data []int) float64 {
	if len(data) < 2 {
		return 0
	}
	sort.Ints(data)

	sum := 0.0
	for i := 1; i < len(data); i++ {
		// Sorted order makes each gap non-negative.
		sum += float64(data[i] - data[i-1])
	}
	return sum / float64(len(data)-1)
}
// Min returns the smallest value in data.
//
// The slice is sorted in place (ascending) as a side effect. An empty slice
// causes an index-out-of-range panic.
func Min(data []int) int {
	sort.Sort(sort.IntSlice(data))
	smallest := data[0]
	return smallest
}

// BinStageMatches turns the per-series best-day indexes of every stage into a
// three-bin histogram of dates.
//
// For each stage the bins start at the smallest index in Hists, and each bin
// spans ceil(AvgDiff(Hists))+1 days. The extra day deliberately widens the
// range to allow for uncertainty. Indexes beyond the third bin are not counted.
// Bin values are add-one (Laplace-style) smoothed shares of the counted
// indexes, so no bin is ever reported as 0%.
//
// Bin dates are plantDate shifted by (bin start index - start) days, with one
// further day subtracted when year is a leap year. ComparisonMean is plantDate
// shifted by (NormalMeanIdx - start) days, without that leap-year adjustment.
// Count is the number of indexes that fell into any bin.
//
// Min and AvgDiff sort each stage's Hists slice in place.
func BinStageMatches(stageState map[string]models.StageStateInner, year int, start int, plantDate time.Time) map[string]models.Bins {
	const (
		binCount = 3
		alpha    = 1.0
	)

	leapShift := 0
	if year%400 == 0 || (year%4 == 0 && year%100 != 0) {
		leapShift = -1
	}

	response := map[string]models.Bins{}
	for stage, st := range stageState {
		lowest := Min(st.Hists)
		width := int(math.Ceil(AvgDiff(st.Hists)) + 1)

		// First pass: count how many indexes land in each bin.
		counts := make([]float64, binCount)
		lowerEdges := make([]int, binCount)
		binned := 0
		for b := range counts {
			lo := lowest + b*width
			lowerEdges[b] = lo
			for _, h := range st.Hists {
				if h >= lo && h < lo+width {
					counts[b]++
					binned++
				}
			}
		}

		// Second pass: smoothing needs the final total, so it cannot be
		// folded into the counting loop.
		denominator := float64(binned) + float64(binCount)*alpha
		bins := make([]models.Bin, 0, binCount)
		for b, lo := range lowerEdges {
			bins = append(bins, models.Bin{
				Date:  plantDate.AddDate(0, 0, lo+leapShift-start),
				Value: (counts[b] + alpha) / denominator,
			})
		}

		result := models.Bins{}
		result.Bins = bins
		result.ComparisonMean = plantDate.AddDate(0, 0, st.NormalMeanIdx-start)
		result.Count = binned
		response[stage] = result
	}
	return response
}

// ForecastFirstLastFreeze estimates the last freeze of the first part of the
// year and the first freeze of the second part for the requested location.
//
// A daily minimum-temperature series is built from the current-year record
// followed by the CFS record. If that is shorter than the normals record, the
// remaining entries are taken from the normals. Index i of the series is mapped
// to January 1 of the current year (time.Now().Year(), dates in UTC) plus i
// days.
//
// A day counts as a freeze when its minimum is <= request.FreezingTemp:
//   - LastFreeze is the latest freeze from January 1 through July 31
//   - FirstFreeze is the earliest freeze from August 1 through December 31
//
// Both windows include their end dates. Strict After/Before comparisons on
// midnight dates would silently drop freezes falling exactly on July 31,
// August 1 or December 31.
//
// When no freeze is found in a window the corresponding index stays 0, so
// January 1 is returned for it.
func ForecastFirstLastFreeze(ctx common.DawnCtx, request models.FreezingForecastRequest) models.FreezingForecastResponse {
	baseData := persistence.CurrentGddFindFirstByYearAndLocation(ctx, models.BuildLocation(request.Latitude, request.Longitude))
	cfsData := persistence.CfsFindAllByLocation(models.BuildLocation(request.Latitude, request.Longitude))
	normalsData := persistence.NormalsFindFirstByYearAndLocation(models.BuildLocation(request.Latitude, request.Longitude))

	// The full slice expression caps capacity at length, so append copies
	// instead of writing into the observation record's backing array.
	observed := baseData.MinTemps
	minTemps := append(observed[:len(observed):len(observed)], cfsData.MinTemps...)
	if len(minTemps) < len(normalsData.MinTemps) {
		minTemps = append(minTemps, normalsData.MinTemps[len(minTemps):]...)
	}

	// Read the clock once so every boundary refers to the same year.
	year := time.Now().Year()
	startDate := time.Date(year, time.January, 1, 0, 0, 0, 0, time.UTC)
	firstHalfLastDate := time.Date(year, time.July, 31, 0, 0, 0, 0, time.UTC)
	lastHalfLastDate := time.Date(year, time.December, 31, 0, 0, 0, 0, time.UTC)

	lastFreezeIdx := 0
	firstFreezeIdx := 0
	for i, temp := range minTemps {
		currentDate := startDate.AddDate(0, 0, i)
		if currentDate.After(lastHalfLastDate) {
			// Past December 31: nothing further can match either window.
			break
		}
		if freezing := temp <= request.FreezingTemp; !freezing {
			continue
		}
		if !currentDate.After(firstHalfLastDate) {
			// January 1 .. July 31: keep overwriting to end on the latest.
			lastFreezeIdx = i
			continue
		}
		// August 1 .. December 31: the first hit is the answer.
		firstFreezeIdx = i
		break
	}

	return models.FreezingForecastResponse{
		LastFreeze:  []time.Time{startDate.AddDate(0, 0, lastFreezeIdx)},
		FirstFreeze: []time.Time{startDate.AddDate(0, 0, firstFreezeIdx)},
	}
}
