import * as React from 'react';

import { ReactClassHooks, ReactClassHookFunction } from './ReactClassHooks';

/**
 * Static shape a component class takes on once hooks are attached to it.
 *
 * Both fields are absent until `addStaticHook` first runs on the class; they
 * are declared optional so the type reflects that and forces callers to guard.
 */
interface HookableClass extends React.ComponentClass {
  /** Hooks shared by every instance of the class. */
  __staticHooks?: ReactClassHooks[];
  /** Set once a static hook has been registered; read by `isHookable`. */
  __hooked?: boolean;
}

/**
 * Instance shape of a component that may carry per-instance hooks.
 */
interface Hookable extends React.Component {
  /** Hooks added to this particular instance via `addHook`; absent until the first one. */
  __hooks?: ReactClassHooks[];
}

/**
 * Instruments `clazz` in place: each lifecycle method on its prototype is
 * replaced by a wrapper that runs the registered hooks and then delegates to
 * the class's own implementation (captured at the time of this call).
 *
 * NOTE(review): nothing guards against calling this twice on the same class;
 * a second call would wrap the existing wrappers and run every hook twice.
 */
export function makeHookable(clazz: React.ComponentClass) {
  const hookableClass = clazz as HookableClass;
  const proto = hookableClass.prototype;

  // Each factory reads only its own prototype slot, so assignment order is irrelevant.
  proto.render = render(hookableClass);
  proto.shouldComponentUpdate = shouldUpdate(hookableClass);
  proto.componentDidMount = mount(hookableClass);
  proto.componentDidUpdate = update(hookableClass);
  proto.componentWillUnmount = unmount(hookableClass);
}

/**
 * Registers a hook shared by every instance of `clazz`, and marks the class as
 * hooked so that `isHookable` reports true for its instances.
 *
 * Distinct hooks accumulate in registration order. Registering the same hook
 * object again is ignored, so repeated registration is idempotent.
 *
 * NOTE(review): `__staticHooks`/`__hooked` are plain static properties, so a
 * subclass without its own copy may read (and push into) its parent's array —
 * confirm whether per-class isolation is required.
 *
 * @param clazz - the component class to attach the hook to
 * @param hook - the hook set to run for every instance of `clazz`
 */
export function addStaticHook(
  clazz: React.ComponentClass,
  hook: ReactClassHooks
) {
  const hookableClass = clazz as HookableClass;

  hookableClass.__hooked = true;

  if (!hookableClass.__staticHooks) {
    hookableClass.__staticHooks = [];
  }

  // Dedupe by identity instead of refusing every call after the first, which
  // previously dropped all but the first registered static hook.
  if (hookableClass.__staticHooks.indexOf(hook) === -1) {
    hookableClass.__staticHooks.push(hook);
  }
}

/**
 * Attaches a hook to one specific component instance (not its class).
 * Hooks are appended in call order; the same hook may be added more than once.
 *
 * @param component - the instance that should run the hook
 * @param hook - the hook set to attach
 */
export function addHook(component: React.Component, hook: ReactClassHooks) {
  const target = component as Hookable;
  const existing = target.__hooks;

  if (existing) {
    existing.push(hook);
  } else {
    target.__hooks = [hook];
  }
}

/**
 * Collects every hook that applies to a component instance: hooks attached to
 * the instance itself come first, followed by hooks registered on its class.
 *
 * Always returns a fresh array, so callers never alias the stored lists.
 */
function getHooks(
  hookable: Hookable,
  hookableClass: HookableClass
): ReactClassHooks[] {
  const instanceHooks = hookable.__hooks || [];
  const classHooks = hookableClass.__staticHooks || [];

  return [...instanceHooks, ...classHooks];
}

/**
 * Calls a single hook function, if present, with the component as its first
 * argument followed by any extra arguments.
 *
 * Hooks are best-effort: an exception thrown by a hook must not break the host
 * component's lifecycle, so it is caught and reported rather than rethrown.
 *
 * @param hookable - component instance passed as the hook's first argument
 * @param hookFunction - the hook to call; a missing hook is skipped
 * @param args - extra arguments forwarded after the component
 * @returns the hook's return value, or `undefined` if it is absent or threw
 */
function invokeHook(
  hookable: Hookable,
  hookFunction: ReactClassHookFunction | undefined,
  ...args: any[]
): any {
  if (!hookFunction) {
    return undefined;
  }

  try {
    return hookFunction(hookable, ...args);
  } catch (e) {
    // Keep the never-throw contract, but do not hide the failure entirely.
    console.error('React class hook threw an error:', e);
    return undefined;
  }
}

/**
 * Builds a `componentDidMount` replacement. It runs every hook's `mount`
 * callback, then the class's original `componentDidMount` if one existed.
 *
 * @param hookableClass - class whose original method is captured now
 * @returns the wrapper; it returns the original's result, or `null` without one
 */
function mount(hookableClass: HookableClass) {
  const originalDidMount = hookableClass.prototype.componentDidMount;

  return function inner() {
    // @ts-ignore
    const component = this as Hookable;

    for (const hook of getHooks(component, hookableClass)) {
      invokeHook(component, hook.mount);
    }

    if (!originalDidMount) {
      return null;
    }
    return originalDidMount.call(component);
  };
}

/**
 * Builds a `componentWillUnmount` replacement. It runs every hook's `unmount`
 * callback, then the class's original `componentWillUnmount` if one existed.
 *
 * @param hookableClass - class whose original method is captured now
 * @returns the wrapper; it returns the original's result, or `null` without one
 */
function unmount(hookableClass: HookableClass) {
  const originalWillUnmount = hookableClass.prototype.componentWillUnmount;

  return function inner() {
    // @ts-ignore
    const component = this as Hookable;
    const activeHooks = getHooks(component, hookableClass);

    for (const hook of activeHooks) {
      invokeHook(component, hook.unmount);
    }

    return originalWillUnmount ? originalWillUnmount.call(component) : null;
  };
}

/**
 * Builds a `render` replacement that brackets the class's original `render`
 * with each hook's `beforeRender` and `afterRender` callbacks.
 *
 * The hook list is snapshotted once per render, so the same hooks receive both
 * the before and after callbacks.
 *
 * @param hookableClass - class whose original `render` is captured now
 * @returns the wrapper; it returns whatever the original `render` produced
 */
function render(hookableClass: HookableClass) {
  const originalRender = hookableClass.prototype.render;

  return function inner() {
    // @ts-ignore
    const component = this as Hookable;
    const snapshot = getHooks(component, hookableClass);

    for (const hook of snapshot) {
      invokeHook(component, hook.beforeRender);
    }

    const output = originalRender.call(component);

    for (const hook of snapshot) {
      invokeHook(component, hook.afterRender);
    }

    return output;
  };
}

/**
 * Builds a `componentDidUpdate` replacement. It runs every hook's `update`
 * callback (with the component only), then forwards all the arguments it
 * received to the class's original `componentDidUpdate`, if one existed.
 *
 * @param hookableClass - class whose original method is captured now
 * @returns the wrapper; it returns the original's result, or `null` without one
 */
function update(hookableClass: HookableClass) {
  const originalDidUpdate = hookableClass.prototype.componentDidUpdate;

  return function inner(...lifecycleArgs: any[]) {
    // @ts-ignore
    const component = this as Hookable;

    for (const hook of getHooks(component, hookableClass)) {
      invokeHook(component, hook.update);
    }

    if (!originalDidUpdate) {
      return null;
    }
    return originalDidUpdate.apply(component, lifecycleArgs);
  };
}

/**
 * Builds a `shouldComponentUpdate` replacement.
 *
 * Each hook's `shouldComponentUpdate` is asked in turn and receives the
 * component plus the raw `arguments` object as a single value. The first
 * truthy answer forces an update and skips the remaining hooks. Otherwise the
 * class's own `shouldComponentUpdate` decides. If the class defines none, the
 * wrapper returns `true`, matching React's default of re-rendering. (It used to
 * return `false`, which froze every component lacking its own method.)
 *
 * NOTE(review): defining this method on the prototype also bypasses
 * React.PureComponent's built-in shallow comparison — confirm that is acceptable.
 *
 * @param hookableClass - class whose original method is captured now
 * @returns the wrapper used as the prototype's `shouldComponentUpdate`
 */
function shouldUpdate(hookableClass: HookableClass) {
  const prev = hookableClass.prototype.shouldComponentUpdate;

  return function inner() {
    // @ts-ignore
    const hookable = this as Hookable;
    const args = arguments;

    const hooks = getHooks(hookable, hookableClass);
    const forcedByHook = hooks.reduce((update, hook) => {
      return update || invokeHook(hookable, hook.shouldComponentUpdate, args);
    }, false);

    if (forcedByHook) {
      return forcedByHook;
    }

    // No class-level override: fall back to React's default ("always update").
    return prev ? prev.apply(hookable, args) : true;
  };
}

/**
 * Reports whether a component's class has been marked hooked. In this module
 * only `addStaticHook` sets that marker.
 *
 * @param component - a component instance; falsy values yield `false`
 * @returns `true` when the instance's constructor carries a truthy `__hooked`
 */
export function isHookable(component: React.Component) {
  if (!component) {
    return false;
  }
  // @ts-ignore
  return Boolean(component.constructor.__hooked);
}
