관리 메뉴

밤 늦게까지 여는 카페

ARA* 알고리즘 - 처음부터 완벽할 필요는 없잖아요! 본문

알고리즘/Path Finding

ARA* 알고리즘 - 처음부터 완벽할 필요는 없잖아요!

Jㅐ둥이 2023. 5. 8. 04:27

안녕하세요. 오늘은 A* 알고리즘의 변형 중 하나인 Anytime Reparing A*(ARA*) 알고리즘을 공부해보려고 해요!

  • LIKHACHEV, Maxim; GORDON, Geoffrey J.; THRUN, Sebastian. ARA*: Anytime A* with provable bounds on sub-optimality. Advances in neural information processing systems, 2003, 16.
  • R. Zhou and E. A. Hansen. Multiple sequence alignment using A*. In Proc. of the National Conference on Artificial Intelligence (AAAI), 2002. Student abstract
    • 논문에서 Anytime 알고리즘으로 소개하길래 읽어봤는데 W를 변경시키지는 않고, bound라는 개념을 이용해서 경로를 개선하더라고요.

 

WA* 알고리즘 포스팅 마지막에 언급했던 바로 그 Anytime A*인데요...!

ARA*에서 W를 어떻게 다루고 있는지 한번 살펴보도록 하겠습니다.


1. ARA* 알고리즘

ARA* 알고리즘도 WA* 알고리즘과 같이 평가 함수 f `에 가중치(w)가 포함되어 있습니다.
f ` = g ` + w * h ` (w >= 1)
노드들의 g ` 값은 ∞ 으로 초기화 되어 있습니다.

A* 알고리즘과 비교하면 1) 종결 조건이 달라진 것, 2) INCONS라는 집합이 추가된 것이 크게 다릅니다.

 

ARA* 알고리즘은 이미 계획된 경로를 다시 계획합니다. 그래서 목적 노드를 OPEN에서 다시 뽑을 수 있기 때문에 종결 조건이 달라져야만 했습니다.

또한, 경로를 다시 계획하면서 w가 감소되기 때문에 노드들의 g ` 값이 변할 수 있습니다. 이런 노드들을 추적하기 위해서 INCONS라는 집합이 추가되었습니다.

 

MAIN 과정

  1. 시작 노드의 g ` 값을 0으로 설정합니다.
  2. OPEN, CLOSED, INCONS 라는 빈 리스트들을 초기화 합니다.
  3. 시작 노드의 f ` 값을 계산하고, OPEN 리스트에 추가합니다.
  4. ImprovePath 를 실행합니다.
  5. w를 min(w, g `(목적 노드) / OPEN과 INCONS에 포함되어 있는 노드들의 f ` 최솟값) 으로 설정합니다.
  6. w를 감소시킵니다(정하기 나름인 것 같습니다).
  7. 계획된 경로를 반환합니다.
  8. w가 1보다 작거나 같으면 종료합니다.
  9. INCONS에 추가된 노드들을 OPEN으로 옮깁니다.
  10. CLOSED를 초기화합니다.
  11. 4로 돌아갑니다.

ImprovePath 과정

  1. 목적 노드의 f `값이 OPEN에 포함되어 있는 노드들의 f `최솟값보다 작거나 같으면 종료합니다.
  2. OPEN에서 f `값이 가장 작은 노드 n을 추출합니다.
  3. CLOSED에 n을 추가합니다.
  4. Connected(n)의 모든 노드들(n_c)에 대해서 다음 과정을 수행합니다.
    1. 만약 g `(n_c)가 g `(n) + c(n, n_c) 보다 작거나 같다면 다음 노드로 넘어갑니다.
    2. g `(n_c)를 갱신합니다.
    3. 만약 n_c가 CLOSED에 포함되어 있지 않다면 n_c를 OPEN에 추가합니다.
    4. CLOSED에 포함되어 있다면 INCONS에 추가합니다.
  5. 1로 돌아갑니다.

 

2. 예제 코드

W 값이 바뀔 때마다 f `이 자동으로 재계산되어야 해서 heuristic 함수에서 W 값을 전역 변수로 읽어오게 수정했습니다.

 

더보기
from dataclasses import dataclass
from typing import List
import heapq
import copy


@dataclass
class Node:
    id: int
    x: float
    y: float

    g_val: float = float("inf")
    parent_id: int = None

    @property
    def f_val(self) -> float:
        return self.g_val + heuristic(self)

    def __lt__(self, other):
        return self.f_val < other.f_val


@dataclass
class Edge:
    start_node: Node
    end_node: Node
    weight: float

    def get_other_node(self, node_id: int) -> Node:
        if self.start_node.id == node_id:
            return self.end_node
        else:
            return self.start_node


@dataclass
class Graph:
    nodes: dict
    adjacent_edges: dict

    def get_node(self, node_id: int) -> Node:
        return self.nodes[node_id]

    def get_connected_nodes(self, node_id: int) -> List[Node]:
        adjacent_edges = self.adjacent_edges.get(node_id, [])
        return [edge.get_other_node(node_id) for edge in adjacent_edges]

    def get_adjacent_edges(self, node_id: int) -> List[Edge]:
        return self.adjacent_edges.get(node_id, [])

    def set_edge_weight(self, edge_id: int, weight: float):
        self.edges[edge_id].set_weight(weight)

    def trace(self, node_id):
        parent_id = self.nodes[node_id].parent_id
        if parent_id != None:
            return self.trace(parent_id) + [node_id]
        return [node_id]


END_NODE = None
W = None


def set_W(w):
    global W
    W = w


def set_heuristic_end_node(node: Node):
    global END_NODE
    END_NODE = node


def heuristic(node) -> float:
    global END_NODE, W
    distance = abs(END_NODE.x - node.x) + abs(END_NODE.y - node.y)
    return W * distance


class Open:
    open_heap: List
    open_dict: dict

    def __init__(self):
        self.open_heap = []
        self.open_dict = {}
        self.num_of_search = 0

    def add(self, node: Node):
        heapq.heappush(self.open_heap, node)
        try:
            self.open_dict[node.id] += 1
        except:
            self.open_dict[node.id] = 1

    def pop(self) -> Node:
        node = heapq.heappop(self.open_heap)
        self.open_dict[node.id] -= 1
        if self.open_dict[node.id] < 0:
            print("[ERROR] 뭔가 잘못되었습니다!!!")
        return node

    @property
    def length(self) -> int:
        return len(self.open_heap)

    @property
    def min_f(self) -> float:
        try:
            return heapq.nsmallest(1, self.open_heap)[0].f_val
        except:
            return float("inf")


class Closed:
    closed_dict: dict

    def __init__(self):
        self.closed_dict = {}

    def is_contain(self, node: Node):
        return node.id in self.closed_dict

    def add(self, node: Node):
        self.closed_dict[node.id] = True

    def remove(self, node: Node):
        del self.closed_dict[node.id]

    def clear(self):
        self.closed_dict = {}


def ara(graph, start_node_id, end_node_id, W=1):
    open = Open()
    closed = Closed()
    incons = Open()
    PREV_W = W

    start_node = graph.get_node(start_node_id)
    end_node = graph.get_node(end_node_id)
    set_heuristic_end_node(end_node)
    set_W(W)

    start_node.g_val = 0
    open.add(start_node)

    while W >= 1:
        PREV_W = W
        improve_path(graph, end_node, W, open, closed, incons, heuristic)
        W = min(W, end_node.g_val / min(open.min_f, incons.min_f))
        W -= 0.1
        while incons.length > 0:
            open.add(incons.pop())
        closed.clear()

    # W가 1이 되었을 때
    if PREV_W != max(1, W):
        improve_path(graph, end_node, max(1, W), open, closed, incons, heuristic)


def improve_path(graph, end_node, W, open, closed, incons, heuristic):
    set_W(W)
    num_of_searched_node = 0
    while end_node.f_val > open.min_f:
        node = open.pop()
        num_of_searched_node += 1
        closed.add(node)

        edges = graph.get_adjacent_edges(node.id)
        for edge in edges:
            other_node = edge.get_other_node(node.id)
            new_g_val = node.g_val + edge.weight
            if other_node.g_val <= new_g_val:
                continue

            other_node.g_val = new_g_val

            if not closed.is_contain(other_node):
                other_node.parent_id = node.id
                open.add(other_node)
            else:
                incons.add(other_node)
    print("W 값:", W)
    print("경로:", graph.trace(end_node.id))
    print("실제로 탐색된 노드 수:", num_of_searched_node)
    print("====================")


def make_graph(grid: List[List[int]]) -> Graph:
    nodes = {}
    for rowIdx in range(len(grid)):
        for colIdx in range(len(grid[rowIdx])):
            if grid[rowIdx][colIdx] == 1:
                continue
            id = f"{colIdx}_{rowIdx*2}"
            nodes[id] = Node(id=id, x=colIdx, y=rowIdx * 2)

    adjacent_edges = {}
    for rowIdx in range(len(grid)):
        for colIdx in range(len(grid[rowIdx])):
            if grid[rowIdx][colIdx] == 1:
                continue
            id = f"{colIdx}_{rowIdx*2}"
            upper_id = f"{colIdx}_{rowIdx*2-2}"
            lower_id = f"{colIdx}_{rowIdx*2+2}"
            left_id = f"{colIdx-1}_{rowIdx*2}"
            right_id = f"{colIdx+1}_{rowIdx*2}"
            adjacent_edges[id] = []
            try:
                adjacent_edges[id].append(Edge(start_node=nodes[id], end_node=nodes[upper_id], weight=2))
            except:
                pass
            try:
                adjacent_edges[id].append(Edge(start_node=nodes[id], end_node=nodes[lower_id], weight=2))
            except:
                pass
            try:
                adjacent_edges[id].append(Edge(start_node=nodes[id], end_node=nodes[left_id], weight=1))
            except:
                pass
            try:
                adjacent_edges[id].append(Edge(start_node=nodes[id], end_node=nodes[right_id], weight=1))
            except:
                pass
    return Graph(nodes=nodes, adjacent_edges=adjacent_edges)


if __name__ == "__main__":
    grid = [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
    ]

    for val in range(1, 5):
        print("초기 W 값:", val)
        print("====================")
        graph = make_graph(grid)
        ara(graph, "0_0", "9_8", W=val)
        print()
        print()

 

3. 예시

WA* 알고리즘에서 사용한 격자 지도를 재활용했습니다... ㅎㅎ

언제까지 우려 먹을 수 있을까요?

 

ARA* 알고리즘이 정상적으로 동작하는지 어떻게 알 수 있을까요?

저는 W가 3 이상일 때, 최초에 긴 경로를 반환하고, 그 후에는 최적의 경로를 반환하는 것을 확인하려고 합니다.

 

다행히 실행 결과가 가설을 입증해주네요...!

휴우우...


ARA* 알고리즘은 어떠셨나요?

Anytime이라고 해서 그 어떤 순간이라도 경로를 찾을 수 있는 것은 아니고, 시간 제약 조건에 맞춰서 해를 찾을 수 있는 것이었더라고요 ㅎㅎ

 

ARA* 알고리즘에서 소개된 W를 다루는 방법을 MAPF 알고리즘들에도 적용시킬 수 있지 않을까 싶은데...

어떻게 적용하면 좋을지는 조금 더 고민해봐야 할 것 같습니다👍👍👍

반응형