summaryrefslogtreecommitdiff
path: root/_downloads/d4ff05da55438abd8747b5d94df09ea1/plot_beam_search.ipynb
blob: c1612ba3286a3712824d8ff98edc45b952f6b774 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Beam Search\n\nBeam search with dynamic beam width.\n\nThe progressive widening beam search repeatedly executes a beam search\nwith increasing beam width until the target node is found.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nimport matplotlib.pyplot as plt\nimport networkx as nx\n\n\ndef progressive_widening_search(G, source, value, condition, initial_width=1):\n    \"\"\"Progressive widening beam search to find a node.\n\n    The progressive widening beam search involves a repeated beam\n    search, starting with a small beam width then extending to\n    progressively larger beam widths if the target node is not\n    found. This implementation simply returns the first node found that\n    matches the termination condition.\n\n    `G` is a NetworkX graph.\n\n    `source` is a node in the graph. The search for the node of interest\n    begins here and extends only to those nodes in the (weakly)\n    connected component of this node.\n\n    `value` is a function that returns a real number indicating how good\n    a potential neighbor node is when deciding which neighbor nodes to\n    enqueue in the breadth-first search. Only the best nodes within the\n    current beam width will be enqueued at each step.\n\n    `condition` is the termination condition for the search. This is a\n    function that takes a node as input and returns a Boolean indicating\n    whether the node is the target. If no node matches the termination\n    condition, this function raises :exc:`NodeNotFound`.\n\n    `initial_width` is the starting beam width for the beam search (the\n    default is one). If no node matching the `condition` is found with\n    this beam width, the beam search is restarted from the `source` node\n    with a beam width that is twice as large (so the beam width\n    increases exponentially). The search terminates after a pass has\n    been made with a beam width of at least the number of nodes in the\n    graph.\n\n    \"\"\"\n    # Check for the special case in which the source node satisfies the\n    # termination condition.\n    if condition(source):\n        return source\n    num_nodes = len(G)\n    # Doubling the width `ceil(log2(n))` times takes it from 1 to at\n    # least `n`, so the range needs one extra iteration to include that\n    # final width. Stopping one doubling short would leave the widest\n    # beam at `2 ** (ceil(log2(n)) - 1)`, which can be narrower than a\n    # node's neighborhood and therefore miss reachable targets.\n    max_doublings = math.ceil(math.log2(num_nodes))\n    for i in range(max_doublings + 1):\n        width = initial_width * 2**i\n        # Since we are always starting from the same source node, this\n        # search may visit the same nodes many times (depending on the\n        # implementation of the `value` function).\n        for _, v in nx.bfs_beam_edges(G, source, value, width):\n            if condition(v):\n                return v\n        # A beam at least as wide as the graph never prunes a neighbor,\n        # so this pass was equivalent to a plain breadth-first search\n        # and wider passes cannot reach anything new.\n        if width >= num_nodes:\n            break\n    # At this point every node reachable from `source` has been visited,\n    # so we know that none of them satisfied the termination condition.\n    raise nx.NodeNotFound(\"no node satisfied the termination condition\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Search for a node with high centrality.\n\nWe generate a random graph, compute the centrality of each node, then perform\nthe progressive widening search in order to find a node of high centrality.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fix the random seed so that the generated graph and its layout are the\n# same on every run.\nseed = 89\n\n# Random graph plus the eigenvector centrality of each of its nodes.\nG = nx.gnp_random_graph(100, 0.5, seed=seed)\ncentrality = nx.eigenvector_centrality(G)\navg_centrality = sum(centrality.values()) / len(G)\n\n\ndef has_high_centrality(v):\n    \"\"\"Return True if the centrality of node `v` is at least the average.\"\"\"\n    return centrality[v] >= avg_centrality\n\n\n# Search from node 0, ranking candidate neighbors by their centrality and\n# stopping at the first node accepted by `has_high_centrality`.\nsource = 0\nvalue = centrality.get\ncondition = has_high_centrality\n\nfound_node = progressive_widening_search(G, source, value, condition)\nc = centrality[found_node]\nprint(f\"found node {found_node} with centrality {c}\")\n\n\n# Draw the whole graph with small blue nodes and thin grey edges\npos = nx.spring_layout(G, seed=seed)\noptions = dict(\n    node_color=\"blue\",\n    node_size=20,\n    edge_color=\"grey\",\n    linewidths=0,\n    width=0.1,\n)\nnx.draw(G, pos, **options)\n# Overlay the node returned by the search as a larger red marker\nnx.draw_networkx_nodes(G, pos, nodelist=[found_node], node_size=100, node_color=\"r\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.15"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}