import { SelectionChange } from '@angular/cdk/collections';
import { Directive, Input, OnDestroy, OnInit, Optional } from '@angular/core';
import { MatTree } from '@angular/material/tree';
import { ComponentStateService } from '@core/services';
import { Subscription } from 'rxjs';
import { StorageKey } from '../../models/storage-key.model';

/**
 * Persists the expanded/collapsed state of a `<mat-tree>` through
 * `ComponentStateService`, and re-applies it when the directive initialises.
 *
 * State is stored under `{ key: stateKey, namespace: 'expand' }` as a set of
 * node identifiers (the value of each node's `key` property).
 *
 * Usage: `<mat-tree appPersistExpandState="my-tree" ...>`
 */
@Directive({
  standalone: true,
  selector: '[appPersistExpandState]',
})
export class PersistExpandStateDirective<T extends object = any>
  implements OnInit, OnDestroy
{
  /**
   * @throws Error when the directive is not placed on a `MatTree` host
   *   (the tree is injected as optional so a clear message can be raised here).
   */
  constructor(
    private _cmpState: ComponentStateService,
    @Optional() private _matTree: MatTree<T>
  ) {
    if (!(_matTree instanceof MatTree)) {
      throw new Error('Rosetta: Directive must use <MatTree> component.');
    }
  }

  /** Storage key under which this tree's expansion state is persisted. */
  @Input('appPersistExpandState') stateKey = '';

  /** Name of the node property used as the node's identifier. */
  @Input('appPersistExpandStateKey') key = 'id';

  /** Name of the node property that flags a node as expandable. */
  @Input('appPersistExpandStateExpandKey') expandKey = 'expandable';

  /** Collects subscriptions so they can be released together on destroy. */
  private _sub = new Subscription();

  /** Namespaced key passed to the state service for every read/write. */
  private get _namespaceKey(): StorageKey {
    return { key: this.stateKey, namespace: 'expand' };
  }

  /**
   * Restores any persisted state, then starts recording expansion changes.
   *
   * The restore happens before subscribing, so the restored selection is not
   * written back to the store as a fresh change.
   */
  ngOnInit(): void {
    this.resetState();

    this._sub.add(
      this._matTree.treeControl.expansionModel.changed.subscribe(change =>
        this._updateState(change)
      )
    );
  }

  /** Stops listening to expansion changes. */
  ngOnDestroy(): void {
    this._sub.unsubscribe();
  }

  /**
   * Re-applies the persisted expansion state to the tree's expansion model.
   *
   * Only nodes currently present in `treeControl.dataNodes` that are flagged
   * expandable (via `expandKey`) and whose identifier (via `key`) is in the
   * persisted set are expanded.
   *
   * @returns `true` when persisted state exists for this key (even if no node
   *   matched it), `false` when there is nothing to restore.
   */
  resetState(): boolean {
    // Nothing stored for this key: there is no state to restore.
    // (Previously this guard was inverted, so stored state was never applied.)
    if (!this._cmpState.has(this._namespaceKey)) {
      return false;
    }

    const dataNodes = this._matTree.treeControl.dataNodes;
    const expandedNodeSet = this._cmpState.get(this._namespaceKey);

    // NOTE(review): dataNodes may still be unset at this point if the tree's
    // data arrives after init; in that case nothing is expanded here.
    if (dataNodes && dataNodes.length > 0 && expandedNodeSet) {
      const toExpand = dataNodes.filter(
        (n: any) => n[this.expandKey] && expandedNodeSet.has(n[this.key])
      );
      this._matTree.treeControl.expansionModel.select(...toExpand);
    }

    return true;
  }

  /**
   * Applies an expansion-model change to the persisted identifier set and
   * writes the result back to the state service.
   *
   * @param selectionChanged nodes added to / removed from the expansion model.
   */
  private _updateState(
    selectionChanged: Omit<SelectionChange<T>, 'source'>
  ): void {
    // Start from the stored set, or a fresh one on first change.
    const state = this._cmpState.get(this._namespaceKey) || new Set();
    selectionChanged.removed.forEach(n => state.delete((n as any)[this.key]));
    selectionChanged.added.forEach(n => state.add((n as any)[this.key]));
    this._cmpState.set(this._namespaceKey, state);
  }
}
