import React, { useState, useCallback, useEffect, useRef } from 'react';
import { ForceGraph2D } from 'react-force-graph';
import { useAgentContext, Node } from '../../context/AgentContext';

/**
 * An edge of the knowledge graph.
 *
 * Links are created with plain node ids, but the click handler in this file
 * already has to cope with endpoints that have been turned into node objects
 * (the force layout resolves ids to node references once it has processed the
 * data), so both forms are part of the type.
 */
interface Link {
  /** Id of the node the link starts from, or the resolved node object. */
  source: string | Node;
  /** Id of the node the link points to, or the resolved node object. */
  target: string | Node;
}

/** The node/link collections kept in component state and passed to the graph renderer. */
interface GraphData {
  /** Nodes currently present in the graph. */
  nodes: Array<Node>;
  /** Links currently present in the graph. */
  links: Array<Link>;
}

/** Fill colour for each known node group. */
const NODE_GROUP_COLORS: ReadonlyMap<string, string> = new Map([
  ['agent', '#004DB5'], // Agents are blue
  ['domain', '#7D7D7D'], // Domain nodes are grey
  ['user-added', '#FFB74D'], // User-added nodes are orange
]);

/**
 * Picks the fill colour for a node.
 *
 * @param group - The node's group; unknown groups fall back to black.
 * @param id - The node's id; any id containing `user_uploads` gets the
 *   dedicated "User Uploads" colour regardless of its group.
 * @returns A CSS hex colour string.
 */
const getNodeColor = (group: string, id: string): string => {
  const isUserUploadsNode = id.includes('user_uploads');
  if (isUserUploadsNode) {
    return '#FF6B6B';
  }
  // Black marks nodes whose group is not classified above.
  return NODE_GROUP_COLORS.get(group) ?? '#000000';
};

/** Duration of the zoom-to-fit animation, in milliseconds. */
const ZOOM_FIT_DURATION_MS = 400;
/** Padding left around the graph when fitting it to the canvas, in pixels. */
const ZOOM_FIT_PADDING_PX = 50;
/** Delay between a graph data change and the follow-up zoom-to-fit, in milliseconds. */
const ZOOM_FIT_DELAY_MS = 500;
/** Radius of the circle drawn for each node. */
const NODE_RADIUS = 10;

/**
 * Returns the node id behind a link endpoint, whether the endpoint is still a
 * plain id or has already been replaced by a node object.
 */
const linkEndpointId = (endpoint: string | { id: string }): string =>
  typeof endpoint === 'string' ? endpoint : endpoint.id;

/**
 * Interactive force-directed view of the personas from the agent context.
 *
 * Every persona is shown as an `agent` node linked to its `user-added` and
 * `domain` sub-nodes. Clicking a node toggles it between expanded (extra
 * sub-nodes with ids of the form `<clickedId>-<subNodeId>` are added) and
 * collapsed (those prefixed nodes and their links are removed again).
 */
const KnowledgeGraph: React.FC = () => {
  const { personas } = useAgentContext();
  const [graphData, setGraphData] = useState<GraphData>({ nodes: [], links: [] });
  const [expandedNodes, setExpandedNodes] = useState<Set<string>>(new Set());
  const graphRef = useRef<any>(null);
  const containerRef = useRef<HTMLDivElement>(null);

  // Rebuild the graph from scratch whenever the personas change.
  useEffect(() => {
    const initialNodes: Node[] = [];
    const initialLinks: Link[] = [];

    for (const persona of personas) {
      // Root node for the agent itself.
      initialNodes.push({ id: persona.id, name: persona.name, group: 'agent', val: 30 });

      // Its direct sub-nodes (User Uploads & Domains), each linked to the agent.
      const visibleSubNodes = persona.nodes.filter(
        node => node.group === 'user-added' || node.group === 'domain'
      );
      for (const subNode of visibleSubNodes) {
        initialNodes.push(subNode);
        initialLinks.push({ source: persona.id, target: subNode.id });
      }
    }

    setGraphData({ nodes: initialNodes, links: initialLinks });
    // The rebuilt graph contains no expansions, so the bookkeeping must not
    // carry over ids from the previous graph.
    setExpandedNodes(new Set());
  }, [personas]);

  // Tune the simulation and bring the whole graph into view after data changes.
  useEffect(() => {
    const fg = graphRef.current;
    if (!fg || graphData.nodes.length === 0 || !containerRef.current) {
      return undefined;
    }

    // Adjust force simulation to prevent graph from disappearing
    fg.d3Force('charge')?.strength(-200);
    fg.d3Force('link')?.distance(150);

    // Disable default centering
    fg.d3Force('center', null);

    // Zoom once the layout has had a moment to spread out. The ref is re-read
    // inside the callback and the timer is cleared on cleanup so a stale timer
    // never touches an unmounted graph.
    const zoomTimer = setTimeout(() => {
      graphRef.current?.zoomToFit(ZOOM_FIT_DURATION_MS, ZOOM_FIT_PADDING_PX);
    }, ZOOM_FIT_DELAY_MS);

    return () => clearTimeout(zoomTimer);
  }, [graphData]);

  const handleNodeClick = useCallback(
    (node: Node) => {
      const childPrefix = `${node.id}-`;

      if (expandedNodes.has(node.id)) {
        // Collapse: forget the node and drop everything that was added under it.
        setExpandedNodes(prev => {
          const next = new Set(prev);
          next.delete(node.id);
          return next;
        });

        setGraphData(prev => ({
          nodes: prev.nodes.filter(n => !n.id.startsWith(childPrefix)),
          links: prev.links.filter(link => {
            const sourceId = linkEndpointId(link.source);
            const targetId = linkEndpointId(link.target);
            // NOTE(review): links whose source is the clicked node are removed
            // even when their target stays in the graph (target id without the
            // `<id>-` prefix), which can leave that target unconnected —
            // confirm this is intended.
            return (
              sourceId !== node.id &&
              !sourceId.startsWith(childPrefix) &&
              // Also drop links that would otherwise point at a removed node.
              !targetId.startsWith(childPrefix)
            );
          }),
        }));
        return;
      }

      // Expand: copy the Set before adding so the previous state is not mutated.
      setExpandedNodes(prev => new Set(prev).add(node.id));

      setGraphData(prev => {
        const knownIds = new Set(prev.nodes.map(n => n.id));
        const subNodes: Node[] = [];

        // NOTE(review): candidates are taken from every persona, not only from
        // the clicked node's persona — confirm this is intended.
        for (const persona of personas) {
          for (const candidate of persona.nodes) {
            if (candidate.group !== 'domain' || knownIds.has(candidate.id)) {
              continue;
            }
            // Prefix with the clicked node's id so the ids stay unique and can
            // be found again when collapsing.
            const prefixedId = `${childPrefix}${candidate.id}`;
            if (knownIds.has(prefixedId)) {
              continue; // never add the same node id twice
            }
            knownIds.add(prefixedId);
            subNodes.push({ ...candidate, id: prefixedId });
          }
        }

        if (subNodes.length === 0) {
          // Nothing to add; keep the same object so the layout is not restarted.
          return prev;
        }

        const newLinks: Link[] = subNodes.map(subNode => ({
          source: node.id,
          target: subNode.id,
        }));

        return {
          nodes: [...prev.nodes, ...subNodes],
          links: [...prev.links, ...newLinks],
        };
      });
    },
    [personas, expandedNodes]
  );

  return (
    <div
      ref={containerRef}
      className="w-full max-w-full h-[800px] bg-gray-100 rounded-lg shadow-lg relative overflow-hidden"
    >
      <ForceGraph2D
        ref={graphRef}
        graphData={graphData}
        nodeLabel="name"
        nodeColor={(node: Node) => getNodeColor(node.group, node.id)}
        nodeCanvasObject={(node, ctx, globalScale) => {
          const label = node.name;
          const x = node.x ?? 0;
          const y = node.y ?? 0;

          // Scale the font inversely with zoom so labels keep a constant on-screen size.
          const fontSize = 12 / globalScale;
          ctx.font = `${fontSize}px Sans-Serif`;
          const padding = fontSize * 0.2;
          const boxWidth = ctx.measureText(label).width + padding;
          const boxHeight = fontSize + padding;

          // Node circle, filled with its group colour.
          ctx.beginPath();
          ctx.arc(x, y, NODE_RADIUS, 0, 2 * Math.PI);
          ctx.fillStyle = getNodeColor(node.group, node.id);
          ctx.fill();

          ctx.strokeStyle = '#FFFFFF'; // White border for contrast
          ctx.lineWidth = 2;
          ctx.stroke();

          // Translucent backdrop so the label stays readable on top of the circle.
          ctx.fillStyle = 'rgba(255, 255, 255, 0.8)';
          ctx.fillRect(x - boxWidth / 2, y - boxHeight / 2, boxWidth, boxHeight);

          ctx.textAlign = 'center';
          ctx.textBaseline = 'middle';
          ctx.fillStyle = '#000000';
          ctx.fillText(label, x, y);
        }}
        nodeCanvasObjectMode={() => 'replace'}
        onNodeClick={handleNodeClick}
        linkColor={() => 'rgba(0,0,0,0.2)'}
        linkWidth={1}
        cooldownTicks={100}
        onEngineStop={() => {
          // The ref may already be cleared if the component is unmounting.
          graphRef.current?.zoomToFit(ZOOM_FIT_DURATION_MS, ZOOM_FIT_PADDING_PX);
        }}
      />
    </div>
  );
};

export default KnowledgeGraph;
