n番目の項をFnとすると
- F0 = 0
- F1 = 1
- Fn + 2 = Fn + Fn + 1 (n >= 0)
となるような数列
0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, ...
int fib(int n) {
if (n == 0) { return 0; }
if (n == 1) { return 1; }
return fib(n - 1) + fib(n - 2);
}
//動作確認用のmainメソッド。以降は省略します。
public static void main(String[] args) {
try (Scanner s = new Scanner(System.in)) {
long n = s.nextLong();
long ret = new Fib().fib(n);
System.out.printf("fib(%d) = %d%n", n, ret);
}
}実行結果
fib(10) = 55
※今回のテーマには関係無いけどメソッドの実行が終わらないのでメモ化します
Map<Integer, Integer> values = new HashMap<>();
int fib(int n) {
if (n == 0) { return 0; }
if (n == 1) { return 1; }
Integer m = values.get(n);
if (m != null) { return m; }
m = fib(n - 1) + fib(n - 2);
values.put(n, m);
return m;
}実行結果
fib(100) = -1869596475
Map<BigInteger, BigInteger> values = new HashMap<>();
BigInteger zero = BigInteger.ZERO;
BigInteger one = BigInteger.ONE;
BigInteger two = one.add(one);
BigInteger fib(BigInteger n) {
if (n.compareTo(zero) == 0) { return zero; }
if (n.compareTo(one) == 0) { return one; }
BigInteger m = values.get(n);
if (m != null) { return m; }
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return m;
}演算子オーバーロードしたい!!!
まあそれはさておき、
実行結果
fib(100) = 354224848179261915075
Exception in thread "main" java.lang.StackOverflowError
at java.math.BigInteger.subtract(BigInteger.java:1425)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
at Fib.fib(Fib.java:24)
(以下略)
- 末尾再帰呼び出しに変形して
- 末尾再帰呼び出しを最適化
long sum(long n) {
if (n == 0) { return 0; }
return n + sum(n - 1);
}やはり n = 10000 ぐらいでスタックオーバーフロー
Exception in thread "main" java.lang.StackOverflowError
at Sum.sum(Sum.java:11)
at Sum.sum(Sum.java:11)
at Sum.sum(Sum.java:11)
(以下略)
long sum(long n) { return sum(n, 0); }
long sum(long n, long m) {
if (n == 0) { return m; }
return sum(n - 1, n + m);
}この時点ではまだスタックオーバーフロー
Javaコンパイラは末尾再帰呼び出しの最適化を行わないので Javaによる関数型プログラミング の7章を参考に末尾再帰呼び出しを手動で最適化する
public class TailRec<T> {
private final Supplier<TailRec<T>> next;
private final boolean done;
private final T result;
private TailRec(Supplier<TailRec<T>> next, boolean done, T result) {
this.next = next;
this.done = done;
this.result = result;
}
public T get() {
return Stream.iterate(this, a -> a.next.get())
.filter(a -> a.done)
.map(a -> a.result)
.findFirst()
.get();
}
public static <T> TailRec<T> call(Supplier<TailRec<T>> next) {
return new TailRec<>(next, false, null);
}
public static <T> TailRec<T> done(T result) {
return new TailRec<>(() -> null, true, result);
}
}long sum(long n) { return sum(n, 0).get(); }
TailRec<Long> sum(long n, long m) {
if (n == 0) { return TailRec.done(m); }
return TailRec.call(() -> sum(n - 1, n + m));
}戻り値をTailRecにしてreturnしてる所をTailRec.callとTailRec.doneで包んだだけ
long sum(long n) { return sum(n, 0); }
long sum(long n, long m) {
if (n == 0) { return m; }
return sum(n - 1, n + m);
}long sum(long n) { return sum(n, 0).get(); }
TailRec<Long> sum(long n, long m) {
if (n == 0) { return TailRec.done(m); }
return TailRec.call(() -> sum(n - 1, n + m));
}無事に結果を得られた!
sum(10000) = 50005000
long sum(long n) { return sum(n, 0); }
long sum(long n, long m) {
if (n == 0) { return m; }
//↓再帰呼び出しでスタック消費
return sum(n - 1, n + m);
}long sum(long n) { return sum(n, 0).get(); }
TailRec<Long> sum(long n, long m) {
if (n == 0) { return TailRec.done(m); }
//↓再帰呼び出しをTailRecで表現して呼び出し元に返す
//↓ここではメソッド実行はしていない
return TailRec.call(() -> sum(n - 1, n + m));
}long sum(long n) { return sum(n, 0).get(); }
TailRec<Long> sum(long n, long m) {
//↓再帰の終了部分
if (n == 0) { return TailRec.done(m); }
return TailRec.call(() -> sum(n - 1, n + m));
}//nextがTailRec.callでラップした部分
//Stream.iterateでTailRecを次々に生成 = 再帰呼び出し
Stream.iterate(this, a -> a.next.get())
.filter(a -> a.done) //再帰呼び出し終了部分だけに絞る
.map(a -> a.result) //値を取り出す
.findFirst() //再帰呼び出し終了部分はひとつだけで良い
.get();sumを実行してTailRec.callでTailRecを返すStream.iterateでTailRec.next.getを実行する- これを再帰呼び出しの分だけ繰り返す
- 再帰呼び出しが終了したら
TailRec.doneでTailRecを返す TailRecから値を取り出して処理終了
これが再帰呼び出しではなくループで実現されている
long sum(long n) { return sum(n, 0).get(); }
TailRec<Long> sum(long n, long m) {
if (n == 0) { return TailRec.done(m); }
//↓再帰呼び出しっぽい
return TailRec.call(() -> sum(n - 1, n + m));
}BigInteger fib(BigInteger n) {
if (n.compareTo(zero) == 0) { return zero; }
if (n.compareTo(one) == 0) { return one; }
BigInteger m = values.get(n);
if (m != null) { return m; }
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return m;
}TailRec<BigInteger> fib(BigInteger n) {
if (n.compareTo(zero) == 0) { return TailRec.done(zero); }
if (n.compareTo(one) == 0) { return TailRec.done(one); }
BigInteger m1 = memo.get(n);
if (m1 != null) { return TailRec.done(m1); }
return TailRec.call(() -> {
BigInteger m2 = fib(n.subtract(one)).add(fib(n.subtract(two)));
memo.put(n, m2);
return m2;
});
}TailRec<BigInteger> fib(BigInteger n) {
(中略)
return TailRec.call(() -> {
//fibの戻り値はTailRecなのでaddメソッドは無い!!!
BigInteger m2 = fib(n.subtract(one)).add(fib(n.subtract(two)));
memo.put(n, m2);
return m2;
});
}変形前のコード
BigInteger fib(BigInteger n) {
if (n.compareTo(zero) == 0) { return zero; }
if (n.compareTo(one) == 0) { return one; }
BigInteger m = values.get(n);
if (m != null) { return m; }
//↓再帰呼び出しが二つある!!!
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return m;
}- 継続渡しスタイルに変換すること
- 継続とはある時点における残りの処理
- 関数を計算結果を返すから計算結果を継続に渡すよう変換する
- 詳しくは 再帰関数のスタックオーバーフローを倒す話 その1 - ぐるぐる~ を読もう!(めっちゃわかりやすい)
int increment(int n) {
return n + 1;
}
int twice(int n) {
return n * 2;
}
void run() {
System.out.println(twice(increment(1)));
}これを呼び出し元の後続処理を継続として受け取り、計算結果を渡すようにすると……
//IntConsumer は int -> void な関数インターフェース
void increment(int n, IntConsumer k) {
int ret = n + 1; k.accept(ret);
}
void twice(int n, IntConsumer k) {
int ret = n * 2; k.accept(ret);
}
void run() {
increment(1, x -> twice(x, y -> System.out.println(y)));
}こうなる。
- 継続はある時点における残りの処理なので自然と末尾で呼び出すことになる
- つまり、再帰呼び出しを行う関数を継続渡しスタイルで書けば末尾再帰呼び出しを自然と書く事ができる!!!
簡単な例のため再び末尾再帰呼び出しになっていないsum関数に登場してもらう。
long sum(long n) {
if (n == 0) { return 0; }
//再帰呼び出ししてからnを足している!!!
return n + sum(n - 1);
}//こんな感じで呼び出す
System.out.println(sum(100));void sum(int n, IntConsumer k) {
if (n == 0) { return 0; }
return n + sum(n - 1);
}void sum(int n, IntConsumer k) {
if (n == 0) { k.accept(0); }
//returnしなくなったのでelseブロックを導入した
else {
return n + sum(n - 1);
}
}void sum(int n, IntConsumer k) {
if (n == 0) { k.accept(0); }
else {
int ret = sum(n - 1);
return n + ret;
}
}void sum(int n, IntConsumer k) {
if (n == 0) { k.accept(0); }
else {
int ret = sum(n - 1);
IntConsumer k = ret -> n + ret;
}
}void sum(int n, IntConsumer k) {
if (n == 0) { k.accept(0); }
else {
//n + retをkに渡す必要がある
sum(n - 1, ret -> n + ret);
}
}void sum(int n, IntConsumer k) {
if (n == 0) {
k.accept(0);
} else {
sum(n - 1, ret -> k.accept(n + ret));
}
}//こんな感じで呼び出す
sum(100, System.out::println);末尾再帰呼び出しになっている!!!
void sum(int n, IntConsumer k) { ... }じゃなくて
int sum(int n, IntConsumer k) { ... }にしたいんや!
int sum(int n, IntConsumer k) {
if (n == 0) {
k.accept(0);
} else {
sum(n - 1, ret -> k.accept(n + ret));
}
}int sum(int n, IntConsumer k) {
if (n == 0) {
return k.accept(0);
} else {
return sum(n - 1, ret -> k.accept(n + ret));
}
}//IntUnaryOperator は int -> int な関数インターフェース
int sum(int n, IntUnaryOperator k) {
if (n == 0) {
return k.accept(0);
} else {
return sum(n - 1, ret -> k.accept(n + ret));
}
}int sum(int n, IntUnaryOperator k) {
if (n == 0) {
return k.applyAsInt(0);
} else {
return sum(n - 1, ret -> k.applyAsInt(n + ret));
}
}これで戻り値がintになった!
もともとはこんな感じで呼び出していた。
//System.out::println が IntConsumer
sum(100, System.out::println);この呼び出し方で計算結果が出力されていた
- 計算結果(
int)を受け取って消費していた(void)ところを - 計算結果(
int)を受け取って返す(int)ようにすればいい
int ret = sum(100, IntUnaryOperator.identity());
System.out.println(ret);//こんな感じの呼び出し用メソッドを作っておくと便利
int sum(int n) {
return sum(n, IntUnaryOperator.identity());
}returnしている箇所を継続の実行に変える- 再帰呼び出しから後の処理を継続として渡すように変える
BigInteger fib(BigInteger n) {
if (n.compareTo(zero) == 0) { return zero; }
if (n.compareTo(one) == 0) { return one; }
BigInteger m = values.get(n);
if (m != null) { return m; }
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return m;
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return zero; }
if (n.compareTo(one) == 0) { return one; }
BigInteger m = values.get(n);
if (m != null) { return m; }
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return m;
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
m = fib(n.subtract(one)).add(fib(n.subtract(two)));
values.put(n, m);
return k.apply(m);
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
BigInteger x = fib(n.subtract(one));
BigInteger y = fib(n.subtract(two));
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
BigInteger x = fib(n.subtract(one));
BigInteger y = fib(n.subtract(two));
UnaryOperator<BigInteger> k = y -> {
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
};
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
BigInteger x = fib(n.subtract(one));
fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
});
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
BigInteger x = fib(n.subtract(one));
UnaryOperator<BigInteger> k = x -> {
return fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
});
});
}BigInteger fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
return fib(n.subtract(one), x -> {
return fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
});
});
}これで継続渡しスタイルになった!
TailRec<BigInteger> fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return k.apply(zero); }
if (n.compareTo(one) == 0) { return k.apply(one); }
BigInteger m = values.get(n);
if (m != null) { return k.apply(m); }
return fib(n.subtract(one), x -> {
return fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return k.apply(z);
});
});
}TailRec<BigInteger> fib(BigInteger n, UnaryOperator<BigInteger> k) {
if (n.compareTo(zero) == 0) { return TailRec.call(() -> k.apply(zero)); }
if (n.compareTo(one) == 0) { return TailRec.call(() -> k.apply(one)); }
BigInteger m = values.get(n);
if (m != null) { return TailRec.call(() -> k.apply(m)); }
return TailRec.call(() -> fib(n.subtract(one), x -> {
return TailRec.call(() -> fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return TailRec.call(() -> k.apply(z));
}));
}));
}TailRec<BigInteger> fib(BigInteger n, Function<BigInteger, TailRec<BigInteger>> k) {
if (n.compareTo(zero) == 0) { return TailRec.call(() -> k.apply(zero)); }
if (n.compareTo(one) == 0) { return TailRec.call(() -> k.apply(one)); }
BigInteger m = values.get(n);
if (m != null) { return TailRec.call(() -> k.apply(m)); }
return TailRec.call(() -> fib(n.subtract(one), x -> {
return TailRec.call(() -> fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return TailRec.call(() -> k.apply(z));
}));
}));
}ちなみに呼び出し方はこんな感じになる
BigInteger n = new BigInteger("10000");
//UnaryOperator.identity()をTailRec::doneに変えた
TailRec<BigInteger> tailRec = fib(n, TailRec::done);
BigInteger ret = tailRec.get();
System.out.println(ret);33644764876431783266621612005107543310302148460680063906564769974680081442166662368155595513633734025582065332680836159373734790483865268263040892463056431887354544369559827491606602099884183933864652731300088830269235673613135117579297437854413752130520504347701602264758318906527890855154366159582987279682987510631200575428783453215515103870818298969791613127856265033195487140214287532698187962046936097879900350962302291026368131493195275630227837628441540360584402572114334961180023091208287046088923962328835461505776583271252546093591128203925285393434620904245248929403901706233888991085841065183173360437470737908552631764325733993712871937587746897479926305837065742830161637408969178426378624212835258112820516370298089332099905707920064367426202389783111470054074998459250360633560933883831923386783056136435351892133279732908133732642652633989763922723407882928177953580570993691049175470808931841056146322338217465637321248226383092103297701648054726243842374862411453093812206564914032751086643394517512161526545361333111314042436854805106765843493523836959653428071768775328348234345557366719731392746273629108210679280784718035329131176778924659089938635459327894523777674406192240337638674004021330343297496902028328145933418826817683893072003634795623117103101291953169794607632737589253530772552375943788434504067715555779056450443016640119462580972216729758615026968443146952034614932291105970676243268515992834709891284706740862008587135016260312071903172086094081298321581077282076353186624611278245537208532365305775956430072517744315051539600905168603220349163222640885248852433158051534849622434848299380905070483482449327453732624567755879089187190803662058009594743150052402532709746995318770724376825907419939632265984147498193609285223945039707165443156421328157688908058783183404917434556270520223564846495196112460268313970975069382648706613264507665074611512677522748621598642530711298441182622661057163515069260029861704945425047491378115154139941550671256271197133252763631939606902895650288268608362241082050562430701794976171121233066073310059947366875
- Javaコンパイラは末尾再帰呼び出しの最適化をしてくれない
- でも手動で最適化できる
- 再帰呼び出しを末尾再帰呼び出しの形にするにはCPS変換が分かりやすい
- ていうかScalaやれ
before
TailRec<BigInteger> fib(BigInteger n, Function<BigInteger, TailRec<BigInteger>> k) {
if (n.compareTo(zero) == 0) { return TailRec.call(() -> k.apply(zero)); }
if (n.compareTo(one) == 0) { return TailRec.call(() -> k.apply(one)); }
BigInteger m = values.get(n);
if (m != null) { return TailRec.call(() -> k.apply(m)); }
return TailRec.call(() -> fib(n.subtract(one), x -> {
return TailRec.call(() -> fib(n.subtract(two), y -> {
BigInteger z = x.add(y);
values.put(n, z);
return TailRec.call(() -> k.apply(z));
}));
}));
}after
public class SemicolonlessFibonacci {
public static void main(String[] args) {
if (java.util.stream.Stream
.of(new java.math.BigInteger("100"))
.flatMap(a -> java.util.stream.Stream.<F> of(f -> n -> values -> k -> n.compareTo(java.math.BigInteger.ZERO) == 0
? () -> new javafx.util.Pair<>(k.apply(java.math.BigInteger.ZERO), java.util.Optional.empty())
: n.compareTo(java.math.BigInteger.ONE) == 0
? () -> new javafx.util.Pair<>(k.apply(java.math.BigInteger.ONE), java.util.Optional .empty())
: values.get(n) != null
? () -> new javafx.util.Pair<>(k.apply(values.get(n)), java.util.Optional.empty())
: () -> new javafx.util.Pair<>(
f.apply(f)
.apply(n.subtract(java.math.BigInteger.ONE))
.apply(values)
.apply(x -> () -> new javafx.util.Pair<>(
f.apply(f)
.apply(n.subtract(new java.math.BigInteger("2")))
.apply(values)
.apply(y -> values.put(n, x.add(y)) == null
? () -> new javafx.util.Pair<>( k.apply(x .add(y)), java.util.Optional .empty())
: () -> new javafx.util.Pair<>( k.apply(x .add(y)), java.util.Optional .empty())),
java.util.Optional.empty())),
java.util.Optional.empty()))
.map(fib -> java.util.stream.Stream.iterate(
fib.apply(fib)
.apply(a)
.apply(new java.util.HashMap<>())
.apply(m -> () -> new javafx.util.Pair<>(null, java.util.Optional.of(m))), t -> t.get().getKey())
.filter(t -> t.get().getValue().isPresent())
.map(t -> t.get().getValue().get())
.findFirst().get()))
.peek(System.out::println).count() == 0) {}
}
interface F extends java.util.function.Function<F, java.util.function.Function<java.math.BigInteger, java.util.function.Function<java.util.Map<java.math.BigInteger, java.math.BigInteger>, java.util.function.Function<java.util.function.Function<java.math.BigInteger, TailRec<java.math.BigInteger>>, TailRec<java.math.BigInteger>>>>> {
}
interface TailRec<T> extends java.util.function.Supplier<javafx.util.Pair<TailRec<T>, java.util.Optional<T>>> {
}
}