/** @jsx jsx **/

import { Component } from "react";
import { jsx, css } from "@emotion/core";
import { json, csv, keys, extent, max, nest } from "d3";
import { select, selectAll, event as currentEvent } from "d3-selection";
import { format } from "d3-format";
import { sankey, sankeyLinkHorizontal } from "d3-sankey";
import { drag } from "d3-drag";
import { scaleOrdinal, scaleLinear, scaleBand } from "d3-scale";
import { axisBottom, axisLeft, axisTop } from "d3-axis";
import { area, line } from "d3-shape";
import d3Tip from "d3-tip";

import data from "../assets/data/sankey_when_relationship.csv";

// Local stand-in for the global `d3` namespace, assembled from the modular
// d3 packages imported at the top of this file so the drawing code below can
// keep the familiar `d3.xxx(...)` call style. Entries are grouped by the
// package each function is imported from.
const d3 = {
  // from "d3"
  csv,
  json,
  keys,
  extent,
  max,
  nest,
  // from "d3-selection"
  select,
  selectAll,
  // from "d3-format"
  format,
  // from "d3-sankey"
  sankey,
  sankeyLinkHorizontal,
  // from "d3-drag"
  drag,
  // from "d3-scale"
  scaleOrdinal,
  scaleLinear,
  scaleBand,
  // from "d3-axis"
  axisBottom,
  axisLeft,
  axisTop,
  // from "d3-shape"
  area,
  line
};

// Emotion style for the component's outer wrapper div: a flex container
// around the chart's mount point (applied via the `css` prop in render()).
const styles = css`
  display: flex;
`;

// Display text for each sankey node, keyed by the node name as it appears in
// the CSV's `source` / `target` columns. Each value is a list of label lines:
// a single entry is drawn as plain text, two entries are stacked vertically.
const nameToLabel = {
  "first_year": ["First Year"],
  "relationship": ["Had a relationship"],

  // Drop-out nodes, one per year without a relationship.
  "no relationship_y1": ["No relationship", "after 1st Year"],
  "no relationship_y2": ["No relationship", "after 2nd Year"],
  "no relationship_y3": ["No relationship", "after 3rd Year"],
  "no relationship_y4": ["No college", "relationship"]
};

export default class Sankey extends Component {
  constructor(props) {
    super(props);

    this.state = {
      width: 480,
      height: 720,
      data: null
    };
  }

  componentDidMount() {
    d3.csv(data)
      .then(csvData => {
        this.setState({ data: csvData });

        let finalData = this.formatData();

        this.initialize(finalData);
      })
      .catch(function(err) {
        throw err;
      });
  }

  formatData = () => {
    let data = { nodes: [], links: [] };

    this.state.data.forEach(function(d) {
      data.nodes.push({ name: d.source });
      data.nodes.push({ name: d.target });
      data.links.push({ source: d.source, target: d.target, value: +d.n });
    });

    // return only the distinct / unique nodes
    data.nodes = d3.keys(
      d3
        .nest()
        .key(function(d) {
          return d.name;
        })
        .object(data.nodes)
    );

    // loop through each link replacing the text with its index from node
    data.links.forEach(function(d, i) {
      data.links[i].source = data.nodes.indexOf(data.links[i].source);
      data.links[i].target = data.nodes.indexOf(data.links[i].target);
    });

    // now loop through each nodes to make nodes an array of objects
    // rather than an array of strings
    data.nodes.forEach(function(d, i) {
      data.nodes[i] = { name: d };
    });

    return data;
  };

  initialize = data => {
    var margin = { top: 40, right: 40, bottom: 80, left: 40 };

    var svg = d3
      .select("#" + this.props.variable + "-sankey-container")
      .append("svg")
      .attr("width", this.state.width)
      .attr("height", this.state.height)
      .append("g")
      .attr("id", this.props.variable + "-sankey")
      .attr("transform", "translate(" + margin.left + ", " + margin.top + ")");

    this.setState({
      width: this.state.width - margin.right - margin.left,
      height: this.state.height - margin.top - margin.bottom
    });

    this.update(data);
  };

  update = data => {
    var svg = d3.select("#" + this.props.variable + "-sankey");

    let colorScale = d3
      .scaleOrdinal()
      .domain(data.nodes.map(d => d.name))
      .range([
        "#A9AEA5",
        "#cad2d1",
        "#E0A69E",
        "#B37C8D",
        "#90686a",
        "#F0C6CD"
      ]);

    // Set the sankey diagram properties
    var sankey = d3
      .sankey()
      .nodeWidth(36)
      .nodePadding(80)
      .size([this.state.width, this.state.height]);

    let graph = sankey(data);

    // add in the links
    var link = svg
      .append("g")
      .selectAll(".link")
      .data(graph.links)
      .enter()
      .append("path")
      .attr("class", "link")
      .attr("d", d3.sankeyLinkHorizontal())
      .attr("fill", "none")
      .attr("stroke", "#606060")
      .attr("stroke-width", d => d.width)
      .attr("stroke-opacity", 0.3)
      .sort(function(a, b) {
        return b.dy - a.dy;
      });

    svg
      .append("g")
      .attr("class", "link-labels")
      .selectAll("text")
      .data(graph.links)
      .enter()
      .append("text")
      .attr("class", "link-label")
      .text(d => d.value)
      .attr(
        "transform",
        d =>
          "translate(" +
          (d.source.x1 + 10) +
          ", " +
          (d.y0 + d.y1 + 16) / 2 +
          ")"
      );

    // add in the nodes
    let nodes = svg
      .append("g")
      .classed("nodes", true)
      .selectAll("rect")
      .data(graph.nodes)
      .enter()
      .append("rect")
      .classed("node", true)
      .attr("x", d => d.x0)
      .attr("y", d => d.y0)
      .attr("width", d => d.x1 - d.x0)
      .attr("height", d => d.y1 - d.y0)
      .attr("fill", d => colorScale(d.name))
      .attr("opacity", 0.8);

    svg
      .append("g")
      .attr("class", "node-labels")
      .selectAll("text")
      .data(graph.nodes)
      .enter()
      .append("text")
      .attr("class", "node-label")
      .html(d => {
        let label = nameToLabel[d.name];
        if (label.length === 1) {
          return nameToLabel[d.name][0];
        } else {
          return (
            "<tspan  x='0'>" +
            label[0] +
            "</tspan><tspan  x='0' dy='12'> " +
            label[1] +
            "</tspan>"
          );
        }
      })
      .attr(
        "transform",
        d => "translate(" + (d.x0 + 20) + ", " + (d.y0 + 8) + ")rotate(90)"
      )
      .style("font-size", 12);
  };

  render() {
    return (
      <div css={styles}>
        <div
          id={this.props.variable + "-sankey-container"}
          style={{ margin: "auto" }}
        ></div>
      </div>
    );
  }
}
