자바 재귀함수 꼬리물기 최적화

Recursion

아래 예제를 보자. 1부터 N까지 합을 구하는 재귀함수다.

1
2
3
int sum(int n) {
return n <= 1 ? n : n + sum(n-1);
}

sum(100); 을 실행하면 결과 값 5050이 잘 리턴된다.

하지만 큰 숫자를 넣고 실행하면 어떻게 될까?

sum(100000); 을 실행하면 우리에게 익숙한 StactOverflowError가 발생한다. (컴퓨터마다 한계숫자는 차이가 있음)
십만은 컴퓨터에게 그리 큰 숫자가 아닌데 왜 메모리가 모자른걸까?

StackOverFlow 개발자라면 모두 아는 사이트

Recursion And Stack in Java

자바는 메소드가 호출될 때 현재 하고 있는 일을 중단하고(suspend) 현재 컨텍스트의 환경을 스택 메모리에 저장(push)한다.

그리고 메소드가 리턴되었을 때 다시 스택에서 가져와(pop) 실행을 재개한다.(resume)

위에서 큰 수를 넣었을 때 재귀함수가 반복적으로 호출되면서 스택에 계속 push만 되다보니 메모리 허용치를 넘어서 에러가 나는 것이였다. 스택메모리는 빠르지만 공간은 작다.

현재 환경을 스택에 저장하는 이유는 메소드 호출 후에 돌아올 지점을 기억하고, 돌아와서 나머지 작업을 재개해야하기 때문이다. 그렇다면 반환 후 나머지 작업을 할 필요가 없을 경우엔 스택에 저장 할 필요가 없다는 것이 아닐까??

이처럼 재귀호출 부분을 맨 마지막 꼬리에 위치시키는 방법을 Tail Call Recursion(꼬리물기 재귀) 라고하며.
Tail Call Recursion을 했을 때 스택저장을 피하는 것을 TCE(Tail Call Elimination) 또는 TCO(Tail Call Optimization)라고 한다.
나머지 작업이 없기 때문에 스택에 저장할 필요가 없어진다.

Tail Call Elimination (꼬리물기 최적화)

정말 그런지 확인해보자 위의 sum 함수를 Tail Call Recursion으로 만들면 다음과 같다.

1
2
3
int sum(int n, int acc) {
return n <= 1 ? acc : sum(n-1, n+acc);
}

처음에 나온 sum 함수는 내부에서 sum(n-1)을 호출한 후 + 연산을 재개해야 하기 때문에 Tail Call이 아니다.

하지만 변경된 위의 예제는 마지막에 호출하고 뒤에 아무일도 하지 않기 때문에 tail call 이라고 할 수 있다.

변경된 함수로 sum(100000); 을 실행해보자.. 역시나 에러가 나는 것을 볼 수 있다.

불행히도 자바에서는 TCE가 구현되어 있지 않다. 함수형 언어같은 경우 위처럼 했을 때 최적화가되어 스택메모리에 저장이 안되지만 자바는 여전히 저장된다.

Tail Call Elimination in Java

자바는 TCE가 지원되지 않지만 비슷하게 구현할 수는 있다.
먼저 TailCall 이라는 추상클래스를 만들고 3개의 추상메소드를 만든다.

1
2
3
4
5
public abstract class TailCall<T> {
public abstract TailCall<T> resume();
public abstract T eval();
public abstract boolean isSuspend();
}

TailCall을 상속하는 두개의 클래스 Return, Suspend를 만든다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public class Return<T> extends TailCall<T> {
private T t;

public Return(T t) {
this.t = t;
}

@Override
public TailCall<T> resume() {
throw new IllegalStateException("Return has no resume");
}

@Override
public T eval() {
return t;
}

@Override
public boolean isSuspend() {
return false;
}
}

public class Suspend<T> extends TailCall<T> {
private Supplier<TailCall<T>> supplier;

@Override
public TailCall<T> resume() {
return supplier.get();
}

@Override
public T eval() {
throw new IllegalStateException("Suspend has no value");
}

@Override
public boolean isSuspend() {
return true;
}
}

TailCall을 사용하여 sum 함수를 수정하면 다음과 같다.

1
2
3
TailCall<Integer> sum(int n, int acc) {
return n <= 1 ? new TailCall.Return<>(acc) : new TailCall.Suspend<>(() -> sum(n-1, n + acc));
}

종료 조건이 되면 결과값을 담은 TailCall.Return 타입을 리턴하고,

종료 조건이 아니면 TailCall.Suspend 타입을 리턴한다. Suspend를 생성할 때 생성자 인자로 supplier를 받는다.

sum함수에서의 supplier는 () -> sum(n-1, n + acc) 이다.

재귀호출을 하는 함수이며 resume() 했을 때 다음에 호출될 TailCall을 가르키게된다. Linked List 구조와 비슷해보인다.

사용법은 다음과 같다. 큰숫자를 넣고 실행해도 에러없이 잘 동작하는 것을 볼 수 있다.

1
2
3
4
5
TailCall<Integer> tailCall = sum(100000, 1);
while(tailCall.isSuspend()) {
tailCall = tailCall.resume();
}
System.out.println(tailCall.eval());

TailCall tailCall = sum(100000, 1); 을 실행해도 재귀적으로 함수를 바로 호출하지 않기 때문에 스택메모리가 쌓이지 않는다. 함수 호출 부분을 supplier로 선언하여 지연평가를 하기 때문이다.

tailCall.resume(); 을 했을 때 비로소 평가되며 다음 TailCall을 반환한다.
결과적으로 스택메모리를 사용하는 대신 힙메모리를 사용하게되면서 StactOverflowError를 피할 수 있다.

마지막으로 Suspend 인스턴스에서 바로 eval()을 사용할 수 있게 구현하고, 정적 팩토리 메서드 ret(), sus()를 구현하면 TailCall이 다음과 같이 완성된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import java.util.function.Supplier;

public abstract class TailCall<T> {
public abstract TailCall<T> resume();
public abstract T eval();
public abstract boolean isSuspend();

public static <T> Return<T> ret(T t) {
return new Return<>(t);
}

public static <T> Suspend<T> sus(Supplier<TailCall<T>> s) {
return new Suspend<>(s);
}

public static class Return<T> extends TailCall<T> {
private T t;

public Return(T t) {
this.t = t;
}

@Override
public TailCall<T> resume() {
throw new IllegalStateException("Return has no resume");
}

@Override
public T eval() {
return t;
}

@Override
public boolean isSuspend() {
return false;
}
}

public static class Suspend<T> extends TailCall<T> {
private Supplier<TailCall<T>> supplier;

public Suspend(Supplier<TailCall<T>> supplier) {
this.supplier = supplier;
}

@Override
public TailCall<T> resume() {
return supplier.get();
}

@Override
public T eval() {
TailCall<T> curr = this;
while(curr.isSuspend()) {
curr = curr.resume();
}
return curr.eval();
}

@Override
public boolean isSuspend() {
return true;
}
}
}

사용 코드

1
2
3
4
5
6
7
8
9
TailCall<Integer> sumRecur(int n, int acc) {
return n <= 1 ? TailCall.ret(acc) : TailCall.sus(() -> sumRecur(n-1, n + acc));
}

int sum(int n) {
return sumRecur(n, 1).eval();
}

System.out.println(sum(100000));

reference

Functional Programming in Java