import React, { Component } from 'react'
import PropTypes from 'prop-types'
import * as d3 from 'd3'
import { flatMapDeep, groupBy, map, mean, round, uniq } from 'lodash'
import { compose } from 'recompose'
import { connect } from 'react-redux'

import './Chart.css'
import { setDetailedComplex } from '../../../ducks/evaluation'

// Maps a ratio to a fill colour on a red -> green gradient.
// `threshold` is the value drawn as pure red; 1 is drawn as pure green.
// The interpolation parameter is not clamped here.
const color = (value, threshold = 0.8) => {
  const redToGreen = d3.interpolateRgb('red', 'green')
  const distanceAboveThreshold = value - threshold
  const scaleToUnit = 1 / (1 - threshold)
  return redToGreen(distanceAboveThreshold * scaleToUnit)
}

// Sums `reps * (weight / reference * 100)` over all log entries.
// Returns 0 for an empty list.
const accumulateLogs = logs => {
  let accumulated = 0
  logs.forEach(entry => {
    const { reps, weight, reference } = entry
    // weight expressed as a percentage of the entry's reference value
    const percentOfReference = weight / reference * 100
    accumulated = accumulated + reps * percentOfReference
  })
  return accumulated
}

// Divides the accumulated log volume of `d.logs` by the number of distinct
// `athleteDetails` values found in those logs. Yields NaN when there are no
// logs (0 / 0).
// NOTE(review): `uniq` compares by identity, so this presumably expects
// `athleteDetails` to be a primitive or a shared object reference - confirm.
const calculateAverage = d => {
  const accumulated = accumulateLogs(d.logs)
  const athletes = uniq(map(d.logs, entry => entry.athleteDetails))
  return accumulated / athletes.length
}

// The "total" of a complex is currently just its per-athlete average; this
// thin alias simply forwards to `calculateAverage`.
const calculateTotal = d => {
  return calculateAverage(d)
}

// Makes an svg selection scale with the width of its parent element while
// keeping the aspect ratio given by its initial `width` / `height` attributes.
// Intended to be used as `selection.call(responsivefy)`.
const responsivefy = svg => {
  // get container + svg aspect ratio
  const container = d3.select(svg.node().parentNode)
  const width = parseInt(svg.attr('width'), 10)
  const height = parseInt(svg.attr('height'), 10)
  const aspect = width / height

  // get width of container and resize svg to fit it
  const resize = () => {
    const targetWidth = parseInt(container.style('width'), 10)
    const targetHeight = Math.round(targetWidth / aspect)

    // The container width is not always a pixel length (e.g. the computed
    // value is 'auto' while the container is hidden). Keep the previous size
    // instead of writing NaN into the svg attributes.
    if (!Number.isFinite(targetWidth) || !Number.isFinite(targetHeight)) {
      return
    }

    svg.attr('width', targetWidth)
    svg.attr('height', targetHeight)
  }

  // add viewBox and preserveAspectRatio properties,
  // and call resize so that the svg resizes on initial page load
  svg
    .attr('viewBox', `0 0 ${width} ${height}`)
    .attr('preserveAspectRatio', 'xMinYMid')
    .call(resize)

  // To register multiple listeners for the same event type a namespace is
  // required, i.e. 'resize.foo'. The svg id serves as that namespace, which
  // is necessary when this function is invoked for multiple svgs.
  d3.select(window).on(`resize.${svg.attr('id')}`, resize)
}

class Chart extends Component {
  componentDidMount = () => {
    this.setContext()
    this.update()
  }

  setContext = () => {
    this.createChart()
  }

  componentDidUpdate = (prevProps, prevState) => {
    this.update()
  }

  update = () => {
    const { data, detailedComplex } = this.props
    if (!detailedComplex) {
      this.renderEverything(data)
    } else {
      const complex = data.component.complexes.find(
        c => c.id === detailedComplex,
      )
      if (complex) {
        this.renderDetail(complex, data.logs)
      }
    }
  }

  createChart = () => {
    const { width: totalWidth, height: totalHeight, id } = this.props

    const margin = { top: 20, right: 20, bottom: 30, left: 40 }
    this.width = totalWidth - margin.left - margin.right
    this.height = totalHeight - margin.top - margin.bottom
    const { width, height } = this

    this.chart = d3
      .select(this.chartDiv)
      .append('svg')
      .attr('id', id)
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom)
      .call(responsivefy)
      .append('g')
      .attr('transform', `translate(${margin.left},${margin.top})`)

    this.chart.append('g').attr('class', 'axis axis--x')

    this.chart.append('g').attr('class', 'axis axis--y')
  }

  renderEverything = data => {
    const { chart, width, height } = this

    let accumulatedData = data.component.complexes.map(c => ({
      id: c.id,
      key: c.exercises.map(e => e.exercise.name).join(' & '),
      estimated: c.exercises.reduce(
        (acc, e) =>
          acc +
          e.sets.reduce((setAcc, set) => setAcc + set.reps * set.weight, 0),
        0,
      ),
      logs: flatMapDeep(c.exercises, e =>
        map(e.sets, set => data.logs.filter(l => l.set === set.id)),
      ),
    }))

    accumulatedData = accumulatedData.map(d => {
      const total = calculateTotal(d)
      return {
        ...d,
        total: isFinite(total) ? total : 0,
      }
    })

    const x = d3
      .scaleBand()
      .padding(0.2)
      .domain(accumulatedData.map(d => d.key))
      .range([0, width])

    const y = d3
      .scaleLinear()
      .domain([0, 1.2])
      .rangeRound([height, 0])

    chart
      .selectAll('g.axis.axis--x')
      .attr('transform', 'translate(0,' + height + ')')
      .call(d3.axisBottom(x))

    chart.selectAll('g.axis.axis--y').call(d3.axisLeft(y).ticks(10, '%'))

    const selection = chart.selectAll('.bar').data(accumulatedData, d => d.key)

    const textSelection = chart
      .selectAll('.text')
      .data(accumulatedData, d => d.key)

    selection
      .enter()
      .append('rect')
      .attr('class', 'bar')
      .attr('x', d => x(d.key))
      .attr('y', d => y(d.total / d.estimated))
      .attr('width', x.bandwidth())
      .attr('height', d => height - y(d.total / d.estimated))
      .attr('fill', d => color(d.total / d.estimated))
      .on('click', d => this.props.setDetailedComplex({ id: d.id }))

    // const update = chart.selectAll('.bar').data(accumulatedData)

    selection
      .transition()
      .ease(d3.easeQuadOut)
      .attr('y', d => y(d.total / d.estimated))
      .attr('height', d => height - y(d.total / d.estimated))
      .attr('fill', d => color(d.total / d.estimated))

    selection.exit().remove()

    textSelection.exit().remove()
  }

  renderDetail = (complex, logs) => {
    const { chart, width, height } = this

    const sets = flatMapDeep(complex.exercises, e =>
      e.sets.map(s => ({ ...s, label: `${s.reps}x${s.weight}` })),
    )

    const groupedLogs = groupBy(logs, 'set')

    const x = d3
      .scaleBand()
      .padding(0.2)
      .domain(sets.map((set, index) => set.id))
      .range([0, width])

    const y = d3
      .scaleLinear()
      .domain([0, 1.2])
      .rangeRound([height, 0])

    const xAxis = d3
      .axisBottom(x)
      .tickFormat(s => sets.find(set => set.id === s).label)

    chart
      .selectAll('g.axis.axis--x')
      .attr('transform', 'translate(0,' + height + ')')
      .call(xAxis)

    chart.selectAll('g.axis.axis--y').call(d3.axisLeft(y).ticks(10, '%'))

    const selection = chart.selectAll('.bar').data(sets, s => s.id)
    const textSelection = chart.selectAll('.text').data(sets, s => s.id)

    const meanLogs = set =>
      mean(groupedLogs[set.id].map(l => l.weight / l.reference)) /
      set.weight *
      100

    selection
      .enter()
      .append('rect')
      .attr('class', 'bar')
      .attr('x', set => x(set.id))
      .attr('y', set => y(meanLogs(set)))
      .attr('width', set => {
        const factor = mean(groupedLogs[set.id].map(l => l.reps)) / set.reps
        return x.bandwidth() * factor
      })
      .attr('height', set => height - y(meanLogs(set)))
      .attr('fill', set => color(meanLogs(set)))
      .on('click', d => this.props.setDetailedComplex({ id: null }))

    textSelection
      .enter()
      .append('text')
      .attr('class', 'text')
      .attr('x', set => x(set.id) + x.bandwidth() / 2)
      .attr('y', set => height - 20)
      .attr('text-anchor', 'middle')
      .attr('font-size', '12px')
      .text(
        set => `${round(mean(groupedLogs[set.id].map(l => l.reps)), 2)} reps`,
      )

    textSelection
      .enter()
      .append('text')
      .attr('class', 'text')
      .attr('x', set => x(set.id) + x.bandwidth() / 2)
      .attr('y', set => height - 10)
      .attr('text-anchor', 'middle')
      .attr('font-size', '12px')
      .text(
        set =>
          `${round(
            mean(groupedLogs[set.id].map(l => l.weight / l.reference * 100)),
          )} %`,
      )

    selection
      .enter()
      .append('rect')
      .attr('class', 'bar')
      .attr('x', set => x(set.id))
      .attr('y', d => y(1))
      .attr('width', x.bandwidth())
      .attr('height', d => height - y(1))
      .attr('stroke', 'black')
      .attr('stroke-dasharray', '5 10')
      .attr('stroke-width', '1')
      .attr('fill', 'none')
      .on('click', d => this.props.setDetailedComplex({ id: null }))

    selection.exit().remove()

    textSelection.exit().remove()
  }

  render() {
    return (
      <div
        ref={c => {
          this.chartDiv = c
        }}
      />
    )
  }
}

// NOTE(review): ids are assumed to be strings or numbers (they are compared
// with `===` and written to a DOM attribute) - adjust if the store uses
// another id type.
const idType = PropTypes.oneOfType([PropTypes.string, PropTypes.number])

Chart.propTypes = {
  // Written to the svg's `id` attribute, which also namespaces the window
  // resize listener.
  id: idType,
  // Outer svg size in pixels, margins included.
  width: PropTypes.number,
  height: PropTypes.number,
  // Evaluation data: the complexes to plot and the logs recorded for them.
  data: PropTypes.shape({
    component: PropTypes.shape({
      complexes: PropTypes.arrayOf(PropTypes.object).isRequired,
    }).isRequired,
    logs: PropTypes.arrayOf(PropTypes.object).isRequired,
  }).isRequired,
  // Id of the complex shown in detail; falsy selects the overview.
  detailedComplex: idType,
  // Called with `{ id }` when a bar is clicked (`id: null` leaves the detail).
  setDetailedComplex: PropTypes.func.isRequired,
}

Chart.defaultProps = {
  width: 600,
  height: 400,
}

// Exposes the id of the complex currently selected for the detail view
// (`state.evaluation.detail.complex`) as the `detailedComplex` prop.
const mapStateToProps = state => ({
  detailedComplex: state.evaluation.detail.complex,
})

// Object shorthand: react-redux wraps the action creator in `dispatch`.
const mapDispatchToProps = { setDetailedComplex }

const enhanceChart = compose(connect(mapStateToProps, mapDispatchToProps))

export default enhanceChart(Chart)
