summaryrefslogtreecommitdiff
path: root/mlir/lib/Debug/ExecutionContext.cpp
blob: f7505b6608c8148075410a94344305f887cac129 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//===- ExecutionContext.cpp - Debug Execution Context Support -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Debug/ExecutionContext.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FormatVariadic.h"

#include <cstddef>

using namespace mlir;
using namespace mlir::tracing;

//===----------------------------------------------------------------------===//
// ActionActiveStack
//===----------------------------------------------------------------------===//

/// Print this stack of active actions to `os`, starting from this frame and
/// walking outwards through the `parent` links.
///
/// Each frame is printed as `#<index>: <action>`. When `withContext` is true
/// and the frame's action has context IR units, they are listed underneath
/// it, one `  - <unit>` line per unit.
///
/// All output goes to `os`. Previously only the header did, and the frames
/// were written to `llvm::errs()` regardless of the stream passed in.
void ActionActiveStack::print(raw_ostream &os, bool withContext) const {
  os << "ActionActiveStack depth " << getDepth() << "\n";
  int count = 0;
  for (const ActionActiveStack *current = this; current;
       current = current->parent) {
    os << llvm::formatv("#{0,3}: ", count++);
    current->action.print(os);
    os << "\n";

    // Fetch the context once and reuse it for both the check and the listing.
    ArrayRef<IRUnit> context = current->action.getContextIRUnits();
    if (!withContext || context.empty())
      continue;

    os << "Context:\n";
    llvm::interleave(
        context,
        [&](const IRUnit &unit) {
          os << "  - ";
          unit.print(os);
        },
        [&]() { os << "\n"; });
    os << "\n";
  }
}

//===----------------------------------------------------------------------===//
// ExecutionContext
//===----------------------------------------------------------------------===//

// Top of this thread's stack of actions currently being executed through an
// ExecutionContext, or nullptr when none is running. Each frame is a local
// `ActionActiveStack` inside a live `ExecutionContext::operator()` call. That
// call points this variable at its frame on entry and restores the frame's
// parent on exit.
static const LLVM_THREAD_LOCAL ActionActiveStack *actionStack = nullptr;

/// Append `observer` to the observers notified around every action executed
/// through this context.
///
/// The raw pointer is stored as-is and dereferenced on later actions, so the
/// observer is expected to stay alive for as long as it is registered.
void ExecutionContext::registerObserver(Observer *observer) {
  observers.emplace_back(observer);
}

/// Run `transform` as the execution of `action`, giving breakpoints, the
/// user control callback and the registered observers a chance to intervene.
///
/// The sequence is:
///   1. Push a frame for `action` onto the thread-local active-action stack.
///   2. Look up the first matching breakpoint across all breakpoint managers.
///   3. If a breakpoint hit, or a pending step/next/finish request targets
///      this depth or a shallower one, ask the user callback whether to run.
///   4. Notify observers, run `transform` if allowed, then notify again.
///   5. If a stepping request still targets this depth, stop once more after
///      the action has completed.
void ExecutionContext::operator()(llvm::function_ref<void()> transform,
                                  const Action &action) {
  // The new frame lives on this C++ stack frame. The scope-exit guard pops it
  // on every return path.
  const int depth = actionStack ? actionStack->getDepth() + 1 : 0;
  ActionActiveStack info{actionStack, action, depth};
  actionStack = &info;
  auto popFrame =
      llvm::make_scope_exit([&]() { actionStack = info.getParent(); });

  // Hand control to the user callback (if any) and turn its answer into
  // "run this action or not". The answer also updates `depthToBreak`, which
  // decides where the next stop happens:
  //   Skip/Apply -> no further stepping stop,
  //   Step       -> stop at the next deeper action,
  //   Next       -> stop at the next action at this depth,
  //   Finish     -> stop once back in the enclosing action.
  auto askUser = [&]() -> bool {
    if (!onBreakpointControlExecutionCallback)
      return true;
    switch (onBreakpointControlExecutionCallback(actionStack)) {
    case ExecutionContext::Skip:
      depthToBreak = std::nullopt;
      return false;
    case ExecutionContext::Apply:
      depthToBreak = std::nullopt;
      return true;
    case ExecutionContext::Step:
      depthToBreak = depth + 1;
      return true;
    case ExecutionContext::Next:
      depthToBreak = depth;
      return true;
    case ExecutionContext::Finish:
      depthToBreak = depth - 1;
      return true;
    }
    llvm::report_fatal_error("Unknown control request");
  };

  // True while a pending stepping request asks to stop at this depth. It is
  // re-evaluated at each use because `depthToBreak` is a member that both the
  // callback and nested actions may change.
  auto pausedByStepping = [&]() {
    return depthToBreak && depth <= *depthToBreak;
  };

  // Only the first matching breakpoint is kept; there is currently no way to
  // report several of them at once.
  Breakpoint *breakpoint = [&]() -> Breakpoint * {
    for (auto *manager : breakpoints)
      if (Breakpoint *hit = manager->match(action))
        return hit;
    return nullptr;
  }();
  info.setBreakpoint(breakpoint);

  const bool shouldExecuteAction =
      (breakpoint || pausedByStepping()) ? askUser() : true;

  for (auto *observer : observers)
    observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction);

  if (shouldExecuteAction) {
    transform();
    for (auto *observer : observers)
      observer->afterExecute(actionStack);
  }

  // Give the user a second stop once the action has finished. At this point
  // the callback's answer can no longer skip anything; it only updates
  // `depthToBreak`.
  if (pausedByStepping())
    (void)askUser();
}