<template>
    <div ref="chartWrapper">
        <!-- drawChart() renders the d3 SVG into this element; the chart width is taken from chartWrapper's parent element -->
        <div ref="chart"></div>
        <!-- hover tooltip; its content and viewport (position: fixed) coordinates are set by drawChart's mouse handlers -->
        <div ref="tooltip" class="tooltip"></div>
    </div>
  </template>
  
  <script>
  import * as d3 from 'd3';
  import Vue from 'vue';
  
  // Characters that must be escaped before data values are interpolated into tooltip HTML.
  const HTML_ESCAPES = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };

  /**
   * Escape a value so it can be interpolated into an HTML string as plain text.
   * @param {*} value - any value; converted with String() first.
   * @returns {string} the escaped text.
   */
  function escapeHtml(value) {
    return String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);
  }

  /**
   * Build one "label: value" row of the tooltip.
   * @param {*} label - row label (escaped).
   * @param {*} value - row value (escaped).
   * @param {boolean} [truncateLabel=false] - clip long labels with an ellipsis; used for
   *   the caller-supplied key names, which can be arbitrarily long.
   * @returns {string} HTML for the row.
   */
  function tooltipRow(label, value, truncateLabel = false) {
    const labelStyle = truncateLabel
      ? 'width:50px;font-weight:bold;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;'
      : 'width:50px;font-weight:bold';
    return `<div style="display:flex;gap:5px"><div style="${labelStyle}">${escapeHtml(label)}:</div> ${escapeHtml(value)}</div>`;
  }

  /**
   * Build the tooltip HTML for one violin entry.
   * @param {Object} entry - data item; reads entry[primaryKey], entry[subsetKey], gene,
   *   max, q3, median, q1 and min.
   * @param {string} primaryKey - key used to label the primary row.
   * @param {string|null} subsetKey - key used to label the subset row.
   * @param {boolean} hasSubsetKey - whether the subset row is shown.
   * @returns {string} tooltip HTML with every data value escaped.
   */
  function buildTooltipHtml(entry, primaryKey, subsetKey, hasSubsetKey) {
    const rows = [tooltipRow(primaryKey, entry[primaryKey], true)];
    if (hasSubsetKey) {
      rows.push(tooltipRow(subsetKey, entry[subsetKey], true));
    }
    if (entry.gene) {
      rows.push(tooltipRow('Gene', entry.gene));
    }
    // non-numeric medians are shown verbatim instead of throwing on .toFixed
    const median = typeof entry.median === 'number' ? entry.median.toFixed(4) : entry.median;
    rows.push(
      tooltipRow('Max', entry.max),
      tooltipRow('Q3', entry.q3),
      tooltipRow('Median', median),
      tooltipRow('Q1', entry.q1),
      tooltipRow('Min', entry.min)
    );
    return rows.join('');
  }

  /**
   * Pre-render the rotated x-axis labels in a throw-away SVG to measure how much
   * vertical space they need, so long labels do not get cut off at the bottom.
   * @param {Element} container - element the temporary SVG is attached to; it must be
   *   in the document for getBBox() to report real sizes.
   * @param {Array} labels - label values to measure.
   * @returns {number} height of the rotated labels' bounding box, in px.
   */
  function measureRotatedLabelsHeight(container, labels) {
    const tempSvg = d3.select(container).append('svg');
    const group = tempSvg.append('g');
    group.selectAll('text')
      .data(labels)
      .enter()
      .append('text')
      .text((d) => d)
      .style('text-anchor', 'end')
      .attr('font-size', '12px')
      .attr('transform', 'rotate(-55)');
    const height = group.node().getBBox().height;
    tempSvg.remove();
    return height;
  }

  /**
   * Draw one violin (mirrored density outline) plus its box-plot overlay into `group`.
   * Elements are appended in this order so the box plot renders above the violin and
   * the transparent hover rectangle above everything.
   * @param {Object} group - d3 selection of the entry's <g>.
   * @param {Object} entry - data item; reads color, min, q1, median, q3 and max.
   * @param {Array<Array<number>>} density - [value, density] pairs as returned by kde().
   * @param {Object} geometry - { xCenter, boxWidth, y }: horizontal center in px,
   *   full violin width in px, and the value-to-pixel y scale.
   */
  function drawViolin(group, entry, density, { xCenter, boxWidth, y }) {
    if (density.length > 0) {
      const halfWidth = boxWidth / 2;
      // each violin is scaled to its own peak density
      const maxDensity = d3.max(density, (d) => d[1]);
      const xViolinScale = d3.scaleLinear()
        .domain([-maxDensity, maxDensity])
        .range([-halfWidth, halfWidth]);
      const violinPath = d3.line()
        .x((d) => xViolinScale(d[1]) + xCenter)
        .y((d) => y(d[0]));
      // go up one side (ascending values), then back down the mirrored side, for a closed outline
      const mirrored = density.map((d) => [d[0], -d[1]]).reverse();
      group.append('path')
        .datum(density.concat(mirrored))
        .attr('d', violinPath)
        .attr('fill', entry.color)
        .attr('stroke', 'none');
    }

    const appendLine = (x1, x2, y1, y2) => group.append('line')
      .attr('x1', x1)
      .attr('x2', x2)
      .attr('y1', y1)
      .attr('y2', y2)
      .attr('stroke', 'black');

    // q1..q3 box
    const quartileBoxWidth = 5;
    group.append('rect')
      .attr('x', xCenter - quartileBoxWidth / 2)
      .attr('y', y(entry.q3))
      .attr('width', quartileBoxWidth)
      .attr('height', Math.max(0, y(entry.q1) - y(entry.q3))) // avoid negative heights
      .attr('fill', 'transparent')
      .attr('stroke', 'black');

    // median
    appendLine(xCenter - boxWidth / 2, xCenter + boxWidth / 2, y(entry.median), y(entry.median));

    // whiskers
    appendLine(xCenter, xCenter, y(entry.min), y(entry.q1));
    appendLine(xCenter, xCenter, y(entry.q3), y(entry.max));

    // whisker caps
    appendLine(xCenter - boxWidth / 4, xCenter + boxWidth / 4, y(entry.min), y(entry.min));
    appendLine(xCenter - boxWidth / 4, xCenter + boxWidth / 4, y(entry.max), y(entry.max));

    // transparent hit area so hovering anywhere between min and max triggers the tooltip
    group.append('rect')
      .attr('x', xCenter - boxWidth / 2)
      .attr('y', y(entry.max))
      .attr('width', boxWidth)
      .attr('height', Math.max(0, y(entry.min) - y(entry.max)))
      .attr('fill', 'transparent')
      .style('pointer-events', 'all');
  }

  export default Vue.component('research-violin-plot', {
    props: {
      // Array of violin entries, or null while no data is available.
      // Each entry is read for: entry[primaryKey], optional entry[subsetKey], min, q1,
      // median, q3, max, exprValues (raw values for the density), color, optional gene.
      data: {
        // `type` is intentionally omitted: `type: Array` would warn on null, which is accepted.
        required: true,
        validator: (value) => value == null || Array.isArray(value),
      },
      primaryKey: {                     // entry key grouping violins along the x axis
        type: String,
        required: true,
      },
      subsetKey: {                      // optional second key splitting each primary group
        type: String,
        required: false,
        default: null,
      },
      width: {                          // unused, chart will take size of its parent container
        type: Number,
        required: false,
        default: 600,
      },
      height: {                         // minimum chart height in px; grows to fit long x labels
        type: Number,
        required: false,
        default: 300,
      },
      highlightKey: {                   // key of label to highlight; other violins are dimmed
        type: String,
        required: false,
      },
      // NOTE(review): xAxisLabel is rendered rotated along the left edge and yAxisLabel
      // horizontally below the plot; kept as-is so existing callers are unaffected.
      yAxisLabel: {
        type: String,
        required: false,
      },
      xAxisLabel: {
        type: String,
        required: false,
      },
      kdeBandwidth: {                   // Epanechnikov kernel bandwidth, in data units
        type: Number,
        required: false,
        default: 1,
        validator: (value) => value > 0,
      },
      kdeStep: {                        // spacing between density sample points, in data units
        type: Number,
        required: false,
        default: 0.1,
        validator: (value) => value > 0,
      },
    },
    data() {
      return {};
    },
    watch: {
      data() {
        this.drawChart();
      },
      highlightKey(key) {
        this.doHighlight(key);
      },
    },
    mounted() {
      if (this.data) {
        this.drawChart();
      } else {
        console.log('no data');
      }
      // Window resizing is not handled; handleResize() can be attached to 'resize' if needed.
    },
    methods: {
      /** Redraw the chart, e.g. after the parent container changed size. */
      handleResize() {
        this.drawChart();
      },

      /**
       * (Re)render the whole chart into the `chart` ref, replacing any previous rendering.
       * Does nothing while `data` is null; clears the chart when `data` is empty.
       */
      drawChart() {
        if (!this.data) return;

        const container = this.$refs.chart;
        const tooltip = this.$refs.tooltip;
        const { data, primaryKey, subsetKey } = this;

        if (data.length === 0) {
          // nothing to plot; the code below needs at least one entry
          d3.select(container).html('');
          return;
        }

        const keys = Array.from(new Set(data.map((d) => d[primaryKey])));
        // the subset split is used only when the first entry carries a non-empty subset value
        const hasSubsetKey = Boolean(subsetKey && data[0][subsetKey]);
        const entryKey = (entry) => (hasSubsetKey
          ? `${entry[primaryKey]} - ${entry[subsetKey]}`
          : entry[primaryKey]);
        const domain = hasSubsetKey ? data.map(entryKey) : keys;

        const labelsHeight = measureRotatedLabelsHeight(container, keys);

        // clear rendering
        d3.select(container).html('');

        // sizes and margins
        const width = this.$refs.chartWrapper.parentElement.offsetWidth;
        const labels = { xAxis: this.xAxisLabel ? 20 : 0, yAxis: this.yAxisLabel ? 20 : 0 };
        const margin = { top: 20, right: 10, bottom: labelsHeight + labels.xAxis, left: 40 };
        // grow the chart so the bottom margin never takes more than half of it
        const height = Math.max(this.height, margin.bottom * 2);
        const plotWidth = width - margin.left - margin.right - labels.xAxis;
        const plotHeight = height - margin.top - margin.bottom - labels.yAxis;

        const min = d3.min(data, (d) => d.min);
        const max = d3.max(data, (d) => d.max);

        const svg = d3.select(container)
          .append('svg')
          .attr('width', width)
          .attr('height', height);

        if (this.xAxisLabel) {
          const label = svg.append('g')
            .append('text')
            .attr('class', 'chart-label')
            .text(this.xAxisLabel);
          const labelWidth = label.node().getBBox().width;
          const topPosition = (margin.top + plotHeight / 2) + (labelWidth / 2);
          label.attr('transform', `rotate(-90) translate( -${topPosition}, 15)`);
        }
        if (this.yAxisLabel) {
          const label = svg.append('g')
            .append('text')
            .attr('class', 'chart-label')
            .text(this.yAxisLabel);
          const labelWidth = label.node().getBBox().width;
          const leftPosition = width - (plotWidth / 2) - (labelWidth / 2);
          label.attr('transform', `translate(${leftPosition},${height - 15})`);
        }

        const plot = svg.append('g')
          .attr('transform', `translate(${margin.left + labels.xAxis},${margin.top})`);

        // one band per violin
        const x = d3.scaleBand()
          .domain(domain)
          .range([5, plotWidth])
          .padding(0);

        // one band per primary key, spanning all of its subsets
        const x2 = hasSubsetKey
          ? d3.scaleBand().domain(keys).range([5, plotWidth]).padding(0)
          : null;

        const y = d3.scaleLinear()
          .domain([min, max])
          .range([plotHeight, 0])
          .nice();

        // x-axis ticks
        plot.append('g')
          .attr('transform', `translate(0,${plotHeight})`)
          .call(d3.axisBottom(hasSubsetKey ? x2 : x))
          .selectAll('text')
          .style('text-anchor', 'end')
          .attr('font-size', '12px')
          .attr('transform', 'rotate(-55) translate(-5, 0)');

        // y-axis ticks
        plot.append('g')
          .call(d3.axisLeft(y));

        const boxWidth = x.bandwidth() * 0.6;

        // alternating background stripes mark each primary group
        if (hasSubsetKey) {
          keys.forEach((key, i) => {
            plot.append('rect')
              .attr('width', x2.bandwidth())
              .attr('height', plotHeight)
              .attr('x', x2(key))
              .attr('class', 'violin-bg')
              .attr('fill', i % 2 ? '#fff' : '#eee');
          });
        }

        const kernel = this.epanechnikovKernel(this.kdeBandwidth);

        data.forEach((entry) => {
          const key = entryKey(entry);
          const xCenter = x(key) + x.bandwidth() / 2;

          const group = plot.append('g')
            .attr('class', 'violin-group')
            .attr('data-key', key);

          const values = entry.exprValues || [];
          const thresholds = d3.range(d3.min(values), d3.max(values), this.kdeStep);
          const density = this.kde(kernel, thresholds, values);

          drawViolin(group, entry, density, { xCenter, boxWidth, y });

          // mouseenter/mouseleave fire once per violin, not on every child transition
          const node = group.node();
          node.addEventListener('mouseenter', () => {
            tooltip.innerHTML = buildTooltipHtml(entry, primaryKey, subsetKey, hasSubsetKey);
            tooltip.classList.add('show');
          });
          // follow the cursor
          node.addEventListener('mousemove', (e) => {
            tooltip.style.top = `${e.clientY - 10}px`;
            tooltip.style.left = `${e.clientX + 10}px`;
          });
          node.addEventListener('mouseleave', () => {
            tooltip.classList.remove('show');
            tooltip.style.top = '-1000px';
            tooltip.style.left = '-1000px';
          });
        });

        // the redraw recreated every violin at full opacity; re-apply the current highlight
        this.doHighlight(this.highlightKey);
      },

      /**
       * Kernel density estimate sampled at `thresholds`.
       * The result is the unnormalized kernel sum (not divided by data.length); callers
       * rescale it to the violin width, so only the shape matters.
       * @param {Function} kernel - kernel function of the distance (t - d).
       * @param {number[]} thresholds - sample points.
       * @param {number[]} data - observed values.
       * @returns {Array<Array<number>>} [threshold, density] pairs.
       */
      kde(kernel, thresholds, data) {
        return thresholds.map((t) => [t, data.reduce((sum, d) => sum + kernel(t - d), 0)]);
      },

      /**
       * Epanechnikov kernel scaled to `bandwidth`: 0.75 * (1 - u^2) / bandwidth for
       * |u| <= 1 (u = distance / bandwidth), 0 otherwise.
       * @param {number} bandwidth - kernel half-width, in data units.
       * @returns {Function} kernel of the distance.
       */
      epanechnikovKernel(bandwidth) {
        return function (u) {
          u = u / bandwidth;
          return Math.abs(u) <= 1 ? 0.75 * (1 - u * u) / bandwidth : 0;
        };
      },

      /**
       * Dim every violin whose data-key differs from `key`; a falsy key restores all.
       * @param {string} [key] - entry key to highlight.
       */
      doHighlight(key) {
        const container = this.$refs.chart;
        if (!container) return;
        container.querySelectorAll('.violin-group').forEach((violin) => {
          violin.style.opacity = !key || violin.dataset.key === key ? '1' : '0.1';
        });
      },
    },
  });
  </script>
  
  <style scoped>
  /* d3 creates the <svg> at runtime, so it lacks the scope attribute: reach it with ::v-deep */
  ::v-deep svg {
    font-family: sans-serif;
  }
  ::v-deep .chart-label{
    font-size:12px;
    opacity:0.5;
  }
  /* Hidden until the hover handlers add .show; never intercepts the pointer. */
  .tooltip{
    position:fixed;
    background: white;
    padding: 5px 10px;
    box-shadow: rgba(0, 0, 0, 0.5) -4px 9px 25px -6px;
    opacity: 0;
    pointer-events: none;
  }
  .tooltip.show{
    opacity: 1;
  }
  </style>
  