import { appendUnique, arrayEquals, domainDescriptions, flatMorph, groupBy, isArray, jsTypeOfDescriptions, printable, throwParseError } from "@ark/util";
import { compileLiteralPropAccess, compileSerializedValue } from "../shared/compile.js";
import { Disjoint } from "../shared/disjoint.js";
import { implementNode } from "../shared/implement.js";
import { intersectNodesRoot, intersectOrPipeNodes } from "../shared/intersections.js";
import { $ark, registeredReference } from "../shared/registry.js";
import { hasArkKind } from "../shared/utils.js";
import { BaseRoot } from "./root.js";
import { defineRightwardIntersections } from "./utils.js";
/**
 * Node implementation for the "union" kind: how a union schema's `branches`
 * are parsed, normalized and reduced, the default error text produced when a
 * value matches no branch, and how unions intersect with other unions and
 * with other root kinds.
 */
const implementation = implementNode({
  kind: "union",
  hasAssociatedError: true,
  collapsibleKey: "branches",
  keys: {
    // when set, branches keep their declared order (see the sort in `parse`)
    ordered: {},
    branches: {
      child: true,
      parse: (schema, ctx) => {
        const branches = [];
        schema.forEach(branchSchema => {
          // nodes already of ark kind "root" contribute their branches directly;
          // anything else is parsed first. Either way, nested unions are
          // flattened into this union's branch list.
          const branchNodes = hasArkKind(branchSchema, "root") ? branchSchema.branches : ctx.$.parseSchema(branchSchema).branches;
          branchNodes.forEach(node => {
            if (node.hasKind("morph")) {
              // a morph branch whose morphs equal those of an existing morph
              // branch is merged into it by unioning the two inputs
              const matchingMorphIndex = branches.findIndex(matching => matching.hasKind("morph") && matching.hasEqualMorphs(node));
              if (matchingMorphIndex === -1) branches.push(node);else {
                const matchingMorph = branches[matchingMorphIndex];
                branches[matchingMorphIndex] = ctx.$.node("morph", {
                  ...matchingMorph.inner,
                  in: matchingMorph.in.rawOr(node.in)
                });
              }
            } else branches.push(node);
          });
        });
        // unordered unions are sorted by node hash, presumably so equivalent
        // unions get the same branch order regardless of declaration order.
        // NOTE(review): the comparator never returns 0 for equal hashes.
        if (!ctx.def.ordered) branches.sort((l, r) => l.hash < r.hash ? -1 : 1);
        return branches;
      }
    }
  },
  // a bare array schema is shorthand for `{ branches: [...] }`
  normalize: schema => isArray(schema) ? {
    branches: schema
  } : schema,
  // returns the sole remaining branch, undefined when nothing was removed,
  // or a new prereduced union of the remaining branches
  reduce: (inner, $) => {
    const reducedBranches = reduceBranches(inner);
    if (reducedBranches.length === 1) return reducedBranches[0];
    if (reducedBranches.length === inner.branches.length) return;
    return $.node("union", {
      ...inner,
      branches: reducedBranches
    }, {
      prereduced: true
    });
  },
  defaults: {
    description: node => node.distribute(branch => branch.description, describeBranches),
    // builds one "<path> must be <expected> (was <actual>)" clause per
    // distinct error path, then joins the clauses with describeBranches
    expected: ctx => {
      const byPath = groupBy(ctx.errors, "propString");
      const pathDescriptions = Object.entries(byPath).map(([path, errors]) => {
        const branchesAtPath = [];
        errors.forEach(errorAtPath =>
        // avoid duplicate messages when multiple branches
        // are invalid due to the same error
        appendUnique(branchesAtPath, errorAtPath.expected));
        const expected = describeBranches(branchesAtPath);
        // if there are multiple actual descriptions that differ,
        // just fall back to printable, which is the most specific
        const actual = errors.every(e => e.actual === errors[0].actual) ? errors[0].actual : printable(errors[0].data);
        return `${path && `${path} `}must be ${expected}${actual && ` (was ${actual})`}`;
      });
      return describeBranches(pathDescriptions);
    },
    problem: ctx => ctx.expected,
    message: ctx => ctx.problem
  },
  intersections: {
    union: (l, r, ctx) => {
      if (l.isNever !== r.isNever) {
        // if exactly one operand is never, we can use it to discriminate based on presence
        return Disjoint.init("presence", l, r);
      }
      let resultBranches;
      if (l.ordered) {
        // two ordered unions have no well-defined combined order
        if (r.ordered) {
          throwParseError(writeOrderedIntersectionMessage(l.expression, r.expression));
        }
        // intersectBranches orders its result by its second argument, so l is
        // passed second to keep l's order; a Disjoint is then inverted so its
        // sides still refer to (l, r)
        resultBranches = intersectBranches(r.branches, l.branches, ctx);
        if (resultBranches instanceof Disjoint) resultBranches.invert();
      } else resultBranches = intersectBranches(l.branches, r.branches, ctx);
      if (resultBranches instanceof Disjoint) return resultBranches;
      return ctx.$.parseSchema(l.ordered || r.ordered ? {
        branches: resultBranches,
        ordered: true
      } : {
        branches: resultBranches
      });
    },
    // intersection of a union with a non-union root: intersect each branch
    // with r, collapsing to a single node when only one branch survives
    ...defineRightwardIntersections("union", (l, r, ctx) => {
      const branches = intersectBranches(l.branches, [r], ctx);
      if (branches instanceof Disjoint) return branches;
      if (branches.length === 1) return branches[0];
      return ctx.$.parseSchema(l.ordered ? {
        branches,
        ordered: true
      } : {
        branches
      });
    })
  }
});
/**
 * Root node for a union of `this.branches`. Besides plain traversal, it
 * precomputes a discriminant when one exists: a property path whose domain or
 * unit value selects which branches can possibly match. That discriminant lets
 * compiled code `switch` on the value instead of trying every branch.
 */
export class UnionNode extends BaseRoot {
  // true only for exactly the two branches `false` then `true`; such a union is
  // rendered as the keyword "boolean" (see nestableExpression)
  isBoolean = this.branches.length === 2 && this.branches[0].hasUnit(false) && this.branches[1].hasUnit(true);
  /**
   * Branches grouped for JSON Schema output. The first boolean unit branch is
   * kept in place; if a second one is found, that slot becomes the `boolean`
   * intrinsic. All other branches are passed through in order.
   */
  get branchGroups() {
    const branchGroups = [];
    let firstBooleanIndex = -1;
    this.branches.forEach(branch => {
      if (branch.hasKind("unit") && branch.domain === "boolean") {
        if (firstBooleanIndex === -1) {
          firstBooleanIndex = branchGroups.length;
          branchGroups.push(branch);
        } else branchGroups[firstBooleanIndex] = $ark.intrinsic.boolean;
        return;
      }
      branchGroups.push(branch);
    });
    return branchGroups;
  }
  // branches whose input is a single unit value (must be initialized before
  // `discriminant`, which reads it)
  unitBranches = this.branches.filter(n => n.in.hasKind("unit"));
  discriminant = this.discriminate();
  discriminantJson = this.discriminant ? discriminantToJson(this.discriminant) : null;
  expression = this.distribute(n => n.nestableExpression, expressBranches);
  get shortDescription() {
    return this.distribute(branch => branch.shortDescription, describeBranches);
  }
  // JSON Schema `anyOf` built from branchGroups
  innerToJsonSchema() {
    return {
      anyOf: this.branchGroups.map(group =>
      // special case to simplify { const: true } | { const: false }
      // to the canonical JSON Schema representation { type: "boolean" }
      group.equals($ark.intrinsic.boolean) ? {
        type: "boolean"
      } : group.toJsonSchema())
    };
  }
  // data is allowed if any branch allows it
  traverseAllows = (data, ctx) => this.branches.some(b => b.traverseAllows(data, ctx));
  /**
   * Tries each branch in order, each inside its own pushed branch context. On
   * the first branch without errors, its queued morphs (if it includes a
   * morph) are moved onto the outer context and traversal stops. If every
   * branch fails, a single "union" error carrying all branch errors is
   * reported.
   */
  traverseApply = (data, ctx) => {
    const errors = [];
    for (let i = 0; i < this.branches.length; i++) {
      ctx.pushBranch();
      this.branches[i].traverseApply(data, ctx);
      if (!ctx.hasError()) {
        if (this.branches[i].includesMorph) return ctx.queuedMorphs.push(...ctx.popBranch().queuedMorphs);
        return ctx.popBranch();
      }
      errors.push(ctx.popBranch().error);
    }
    ctx.errorFromNodeContext({
      code: "union",
      errors,
      meta: this.meta
    });
  };
  /**
   * Emits traversal code. With a usable discriminant, this is a `switch` on
   * the discriminant's value (or the value's domain) where each case returns
   * `true` or invokes the pruned case node. Without one, it falls back to
   * compileIndiscriminable.
   */
  compile(js) {
    if (!this.discriminant ||
    // if we have a union of two units like `boolean`, the
    // undiscriminated compilation will be just as fast
    this.unitBranches.length === this.branches.length && this.branches.length === 2) return this.compileIndiscriminable(js);
    // we need to access the path as optional so we don't throw if it isn't present
    let condition = this.discriminant.optionallyChainedPropString;
    // for domain discriminants, map `typeof` onto domain names:
    // null -> "null", objects and functions -> "object", otherwise `typeof`
    if (this.discriminant.kind === "domain") condition = `typeof ${condition} === "object" ? ${condition} === null ? "null" : "object" : typeof ${condition} === "function" ? "object" : typeof ${condition}`;
    const cases = this.discriminant.cases;
    const caseKeys = Object.keys(cases);
    js.block(`switch(${condition})`, () => {
      for (const k in cases) {
        const v = cases[k];
        const caseCondition = k === "default" ? k : `case ${k}`;
        js.line(`${caseCondition}: return ${v === true ? v : js.invoke(v)}`);
      }
      return js;
    });
    // code below runs only when no case (and no `default`) returned
    if (js.traversalKind === "Allows") {
      js.return(false);
      return;
    }
    const expected = describeBranches(this.discriminant.kind === "domain" ? caseKeys.map(k => {
      // case keys for domains are quoted names like `"string"`; strip the quotes
      const jsTypeOf = k.slice(1, -1);
      return jsTypeOf === "function" ? domainDescriptions.object : domainDescriptions[jsTypeOf];
    }) : caseKeys);
    const serializedPathSegments = this.discriminant.path.map(k => typeof k === "string" ? JSON.stringify(k) : registeredReference(k));
    const serializedExpected = JSON.stringify(expected);
    const serializedActual = this.discriminant.kind === "domain" ? `${serializedTypeOfDescriptions}[${condition}]` : `${serializedPrintable}(${condition})`;
    // TODO: should have its own error code
    js.line(`ctx.errorFromNodeContext({
	code: "predicate",
	expected: ${serializedExpected},
	actual: ${serializedActual},
	relativePath: [${serializedPathSegments}],
	meta: ${this.compiledMeta}
})`);
  }
  /**
   * Emits code that tries every branch in order. In "Apply" mode it mirrors
   * traverseApply (push/pop a branch context per branch, collect errors);
   * otherwise it returns true on the first branch that allows the data.
   */
  compileIndiscriminable(js) {
    if (js.traversalKind === "Apply") {
      js.const("errors", "[]");
      this.branches.forEach(branch => js.line("ctx.pushBranch()").line(js.invoke(branch)).if("!ctx.hasError()", () => js.return(branch.includesMorph ? "ctx.queuedMorphs.push(...ctx.popBranch().queuedMorphs)" : "ctx.popBranch()")).line("errors.push(ctx.popBranch().error)"));
      js.line(`ctx.errorFromNodeContext({ code: "union", errors, meta: ${this.compiledMeta} })`);
    } else {
      this.branches.forEach(branch => js.if(`${js.invoke(branch)}`, () => js.return(true)));
      js.return(false);
    }
  }
  // the expression wrapped in parentheses so it can be embedded in a larger expression
  get nestableExpression() {
    // avoid adding unnecessary parentheses around boolean since it's
    // already collapsed to a single keyword
    return this.isBoolean ? "boolean" : `(${this.expression})`;
  }
  /**
   * Chooses a discriminant for this union, or returns null when there are
   * fewer than two branches or no usable candidate.
   *
   * If every branch has a unit input, `data` itself is switched on by
   * serialized unit value. Otherwise, each pair of branch inputs is
   * intersected. Every required (non-optional) domain or unit conflict in the
   * resulting Disjoint becomes a candidate keyed by (path, kind). The
   * candidate with the most distinct cases wins; ties go to the one found
   * last. Branches not covered by any case go into `default`.
   */
  discriminate() {
    if (this.branches.length < 2) return null;
    if (this.unitBranches.length === this.branches.length) {
      // morph branches must still run their morph; plain units just succeed
      const cases = flatMorph(this.unitBranches, (i, n) => [`${n.in.serializedValue}`, n.hasKind("morph") ? n : true]);
      return {
        kind: "unit",
        path: [],
        optionallyChainedPropString: "data",
        cases
      };
    }
    const candidates = [];
    for (let lIndex = 0; lIndex < this.branches.length - 1; lIndex++) {
      const l = this.branches[lIndex];
      for (let rIndex = lIndex + 1; rIndex < this.branches.length; rIndex++) {
        const r = this.branches[rIndex];
        const result = intersectNodesRoot(l.in, r.in, l.$);
        // overlapping branches can't be told apart by this pair
        if (!(result instanceof Disjoint)) continue;
        for (const entry of result) {
          // only required domain/unit conflicts can be switched on
          if (!entry.kind || entry.optional) continue;
          let lSerialized;
          let rSerialized;
          if (entry.kind === "domain") {
            const lValue = entry.l;
            const rValue = entry.r;
            lSerialized = `"${typeof lValue === "string" ? lValue : lValue.domain}"`;
            rSerialized = `"${typeof rValue === "string" ? rValue : rValue.domain}"`;
          } else if (entry.kind === "unit") {
            lSerialized = entry.l.serializedValue;
            rSerialized = entry.r.serializedValue;
          } else continue;
          const matching = candidates.find(d => arrayEquals(d.path, entry.path) && d.kind === entry.kind);
          if (!matching) {
            candidates.push({
              kind: entry.kind,
              cases: {
                [lSerialized]: [l],
                [rSerialized]: [r]
              },
              path: entry.path
            });
          } else {
            matching.cases[lSerialized] = appendUnique(matching.cases[lSerialized], l);
            matching.cases[rSerialized] = appendUnique(matching.cases[rSerialized], r);
          }
        }
      }
    }
    // ascending sort by case count, then take the last (most cases)
    const best = candidates.sort((l, r) => Object.keys(l.cases).length - Object.keys(r.cases).length).at(-1);
    if (!best) return null;
    let defaultBranches = [...this.branches];
    const bestCtx = {
      kind: best.kind,
      path: best.path,
      optionallyChainedPropString: optionallyChainPropString(best.path)
    };
    const cases = flatMorph(best.cases, (k, caseBranches) => {
      const prunedBranches = [];
      // branches that appear in some case don't need the default
      defaultBranches = defaultBranches.filter(n => !caseBranches.includes(n));
      for (const branch of caseBranches) {
        // drop checks the switch itself already performed
        const pruned = pruneDiscriminant(branch, bestCtx);
        // if any branch of the union has no constraints (i.e. is unknown)
        // return it right away
        if (pruned === null) return [k, true];
        prunedBranches.push(pruned);
      }
      const caseNode = prunedBranches.length === 1 ? prunedBranches[0] : this.$.node("union", prunedBranches);
      // presumably so compiled code can reference the case node by id
      Object.assign(this.referencesById, caseNode.referencesById);
      return [k, caseNode];
    });
    if (defaultBranches.length) {
      cases.default = this.$.node("union", defaultBranches, {
        prereduced: true
      });
      Object.assign(this.referencesById, cases.default.referencesById);
    }
    return Object.assign(bestCtx, {
      cases
    });
  }
}
/**
 * Builds the compiled access expression for `path`, rooted at `data`.
 * Each segment is appended via `compileLiteralPropAccess(key, true)`
 * (the optional-chaining form), e.g. `data?.a?.b`.
 */
const optionallyChainPropString = path => {
  let compiled = "data";
  path.forEach(key => {
    compiled += compileLiteralPropAccess(key, true);
  });
  return compiled;
};
// Registered references to runtime helpers, embedded into compiled code to
// describe the `actual` value in a discriminant mismatch error
// (typeof-descriptions lookup for domain discriminants, `printable` otherwise).
const [serializedTypeOfDescriptions, serializedPrintable] = [
  registeredReference(jsTypeOfDescriptions),
  registeredReference(printable)
];
/**
 * Bundles the union kind's `implementation` (parsing, reduction, default
 * error text and intersections) with its node class, `UnionNode`.
 */
export const Union = {
  implementation,
  Node: UnionNode
};
/**
 * Serializes a discriminant to plain JSON.
 * String path keys are kept as-is; other keys go through
 * `compileSerializedValue`. Each case becomes `true`, the nested
 * discriminant JSON of a discriminated union case, or the case node's `json`.
 */
const discriminantToJson = discriminant => {
  const pathKeyToJson = k => typeof k === "string" ? k : compileSerializedValue(k);
  const caseToJson = node => {
    if (node === true) return true;
    if (node.hasKind("union") && node.discriminantJson) return node.discriminantJson;
    return node.json;
  };
  return {
    kind: discriminant.kind,
    path: discriminant.path.map(k => pathKeyToJson(k)),
    cases: flatMorph(discriminant.cases, (k, node) => [k, caseToJson(node)])
  };
};
// Separators for union expressions: every branch, including the last, is
// joined with " | ".
const describeExpressionOptions = { delimiter: " | ", finalDelimiter: " | " };
/**
 * Joins branch expressions into a union expression string by delegating to
 * `describeBranches` with pipe separators.
 */
const expressBranches = expressions => {
  const joined = describeBranches(expressions, describeExpressionOptions);
  return joined;
};
/**
 * Joins branch descriptions into one human-readable string.
 *
 * Duplicates are removed (first occurrence wins). A list that reduces to
 * exactly "true" and "false", in either order, is rendered as "boolean".
 * Otherwise descriptions are joined with `delimiter`, except the last two,
 * which are joined with `finalDelimiter`.
 *
 * @param {string[]} descriptions - descriptions of each branch
 * @param {{ delimiter?: string, finalDelimiter?: string }} [opts]
 *   `delimiter` defaults to ", " and `finalDelimiter` to " or "
 * @returns {string} "never" for an empty list
 */
const describeBranches = (descriptions, opts) => {
  const delimiter = opts?.delimiter ?? ", ";
  const finalDelimiter = opts?.finalDelimiter ?? " or ";
  if (descriptions.length === 0) return "never";
  // a Set (not a plain object) so inherited keys such as "constructor" or
  // "toString" are not mistaken for already-seen descriptions
  const unique = [...new Set(descriptions)];
  if (unique.length === 1) return unique[0];
  // parenthesized so both orderings require exactly two descriptions
  if (unique.length === 2 && (unique[0] === "false" && unique[1] === "true" || unique[0] === "true" && unique[1] === "false")) return "boolean";
  const last = unique.pop();
  return `${unique.join(delimiter)}${finalDelimiter}${last}`;
};
/**
 * Intersects two lists of branches pairwise and returns the reduced list of
 * resulting branches, ordered by `r`'s branches. Returns a "union" Disjoint
 * when nothing survives.
 *
 * `batchesByR[rIndex]` collects, by l index, the intersections attributed to
 * `r[rIndex]`. A `null` batch means `r[rIndex]` is itself kept in the result
 * because it is equal to, or a subtype of, some l branch.
 */
export const intersectBranches = (l, r, ctx) => {
  // If the corresponding r branch is identified as a subtype of an l branch, the
  // value at rIndex is set to null so we can avoid including previous/future
  // intersections in the reduced result.
  const batchesByR = r.map(() => []);
  for (let lIndex = 0; lIndex < l.length; lIndex++) {
    let candidatesByR = {};
    for (let rIndex = 0; rIndex < r.length; rIndex++) {
      if (batchesByR[rIndex] === null) {
        // rBranch is a subtype of an lBranch and
        // will not yield any distinct intersection
        continue;
      }
      if (l[lIndex].equals(r[rIndex])) {
        // Equal branches are both subtype and supertype: keep r[rIndex] as-is
        // and drop the other intersections of this l branch.
        batchesByR[rIndex] = null;
        candidatesByR = {};
        break;
      }
      const branchIntersection = intersectOrPipeNodes(l[lIndex], r[rIndex], ctx);
      if (branchIntersection instanceof Disjoint) {
        // Doesn't tell us anything useful about their relationships
        // with other branches
        continue;
      }
      if (branchIntersection.equals(l[lIndex])) {
        // If the current l branch is a subtype of r, intersections
        // with previous and remaining branches of r won't lead to
        // distinct intersections.
        batchesByR[rIndex].push(l[lIndex]);
        candidatesByR = {};
        break;
      }
      if (branchIntersection.equals(r[rIndex])) {
        // If the current r branch is a subtype of l, set its batch to
        // null, removing any previous intersections and preventing any
        // of its remaining intersections from being computed.
        batchesByR[rIndex] = null;
      } else {
        // If neither l nor r is a subtype of the other, add their
        // intersection as a candidate (could still be removed if it is
        // determined l or r is a subtype of a remaining branch).
        candidatesByR[rIndex] = branchIntersection;
      }
    }
    // for...in yields string keys; they index the arrays the same way numbers do
    for (const rIndex in candidatesByR) {
      // batchesByR at rIndex should never be null if it is in candidatesByR
      batchesByR[rIndex][lIndex] = candidatesByR[rIndex];
    }
  }
  // Compile the reduced intersection result, including:
  //   1. Remaining candidates: distinct intersections, or l branches that are
  //      strict subtypes of an r branch
  //   2. Original r branches whose batch is null (subtypes of some l branch)
  // Sparse batch slots are skipped by flatMap.
  const resultBranches = batchesByR.flatMap(
  // ensure unions returned from branchable intersections like sequence are flattened
  (batch, i) => batch?.flatMap(branch => branch.branches) ?? r[i]);
  return resultBranches.length === 0 ? Disjoint.init("union", l, r) : resultBranches;
};
/**
 * Removes redundant branches from a union.
 *
 * A branch is dropped when it equals an earlier branch, or when its input is
 * a subtype of another branch's input. For ordered unions, an earlier branch
 * that is a subtype of a later one is preserved, since it still matches
 * first. For unordered unions, every overlapping pair is checked with
 * assertDeterminateOverlap.
 *
 * @param {{ branches: readonly object[], ordered?: boolean }} inner
 * @returns the original `branches` when there are fewer than two, otherwise a
 *   new array of the remaining branches in their original order
 */
export const reduceBranches = ({
  branches,
  ordered
}) => {
  if (branches.length < 2) return branches;
  // uniquenessByIndex[k] becomes false once branch k is known to be redundant
  const uniquenessByIndex = branches.map(() => true);
  for (let i = 0; i < branches.length; i++) {
    // once branch i is redundant, further comparisons can't change the result
    for (let j = i + 1; j < branches.length && uniquenessByIndex[i]; j++) {
      // skip (rather than stop at) branches already marked redundant so that
      // branch i is still compared against every later branch
      if (!uniquenessByIndex[j]) continue;
      if (branches[i].equals(branches[j])) {
        // if the two branches are equal, only "j" is marked as
        // redundant so at least one copy could still be included in
        // the final set of branches.
        uniquenessByIndex[j] = false;
        continue;
      }
      const intersection = intersectNodesRoot(branches[i].in, branches[j].in, branches[0].$);
      if (intersection instanceof Disjoint) continue;
      if (!ordered) assertDeterminateOverlap(branches[i], branches[j]);
      if (intersection.equals(branches[i].in)) {
        // preserve ordered branches that are a subtype of a subsequent branch
        uniquenessByIndex[i] = !!ordered;
      } else if (intersection.equals(branches[j].in)) uniquenessByIndex[j] = false;
    }
  }
  return branches.filter((_, i) => uniquenessByIndex[i]);
};
/**
 * Throws a parse error when two overlapping branches of an unordered union
 * involve morphs that are not identical.
 *
 * Morphs are compared first by `shallowMorphs`, then, only if those match,
 * by `flatMorphs` (same prop string and equal morphs at each entry).
 */
const assertDeterminateOverlap = (l, r) => {
  // branches without any morphs can overlap freely
  if (!l.includesMorph && !r.includesMorph) return;
  const shallowMorphsMatch = arrayEquals(l.shallowMorphs, r.shallowMorphs, {
    isEqual: (lMorph, rMorph) => lMorph.hasEqualMorphs(rMorph)
  });
  const morphsMatch = shallowMorphsMatch && arrayEquals(l.flatMorphs, r.flatMorphs, {
    isEqual: (lEntry, rEntry) => lEntry.propString === rEntry.propString && lEntry.node.hasEqualMorphs(rEntry.node)
  });
  if (!morphsMatch) {
    throwParseError(writeIndiscriminableMorphMessage(l.expression, r.expression));
  }
};
/**
 * Removes from `discriminantBranch` the domain/unit checks that the
 * discriminant switch has already performed, using the branch's `transform`.
 *
 * Nodes selected by `shouldTransform` whose kind is "domain" or "unit" are
 * mapped to null; all other nodes keep their own inner definition.
 */
export const pruneDiscriminant = (discriminantBranch, discriminantCtx) => {
  const dropCheckedConstraint = (nodeKind, inner) => nodeKind === "domain" || nodeKind === "unit" ? null : inner;
  const shouldTransform = (node, ctx) => {
    const propString = optionallyChainPropString(ctx.path);
    // nodes not on the way to the discriminant's path are left untouched
    const onDiscriminantPath = discriminantCtx.optionallyChainedPropString.startsWith(propString);
    if (!onDiscriminantPath) return false;
    // if we've already checked a path at least as long as the current one,
    // we don't need to revalidate that we're in an object
    if (node.hasKind("domain") && node.domain === "object") return true;
    // if the discriminant has already checked the domain at the current path
    // (or a unit literal, implying a domain), we don't need to recheck it
    const alreadyChecked = (node.hasKind("domain") || discriminantCtx.kind === "unit") && propString === discriminantCtx.optionallyChainedPropString;
    if (alreadyChecked) return true;
    // index nodes never have a required path, so they can't hold a
    // discriminant; only recurse into other nodes that have children
    return node.children.length !== 0 && node.kind !== "index";
  };
  return discriminantBranch.transform(dropCheckedConstraint, {
    shouldTransform
  });
};
/**
 * Error message for an unordered union whose overlapping branches involve
 * morphs that differ, so which morph applies would be ambiguous.
 */
export const writeIndiscriminableMorphMessage = (lDescription, rDescription) => [
  "An unordered union of a type including a morph and a type with overlapping input is indeterminate:",
  `Left: ${lDescription}`,
  `Right: ${rDescription}`
].join("\n");
/**
 * Error message for intersecting two ordered unions, whose combined branch
 * order is not well defined.
 */
export const writeOrderedIntersectionMessage = (lDescription, rDescription) => [
  "The intersection of two ordered unions is indeterminate:",
  `Left: ${lDescription}`,
  `Right: ${rDescription}`
].join("\n");