Plan9 libmp のカラツバ乗算

Plan9 libmp はカラツバ乗算を利用しています。どのように途中結果を求めたり一時領域に保存するのか、オープンソース版のソースを読んで調べてみました。

https://github.com/brho/plan9/blob/master/sys/src/libmp/port/mpmul.c

Plan9 libmp の多倍長整数は、配列の添字の小さい方に小さな桁を、添字が増える方向に桁が上がっていきます。低レベルの関数は符号なし整数中の着目する桁スライス同士の四則演算になっていて、桁へのポインタと演算対象の桁数を指定して演算をおこないます。

乗算は、桁数が32を越えると、カラツバ乗算の再帰になり、桁数が少ないときは筆算乗算に切り換えます。このライブラリのカラツバ乗算は、次の式にしたがいます。

p = (u0 + u1 * B**n) * (v0 + v1 * B**n)
  = u0*v0 + (u0*v0 + u1*v1 + (u1-u0)*(v0-v1)) * B**n + u1*v1 * B**(2*n)

B は 32 ビットなら 2**32、64 ビットなら 2**64 です。n を求めるには、被乗数の B 進数での桁数を 2 で割って切り上げます。

static void
mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
	mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
	int u0len, u1len, v0len, v1len, reslen;
	int sign, n;

	n = alen/2;
	if(alen&1)
		n++;

続いて、u0、u1、v0、v1 のポインタと桁数を求めます。原則として、上で求めた n を u0 と v0 の桁数に使います。n を切り上げで求めるため、u0、u1、v0、v1 の 4 つの桁数の最大値は n に一致します。これで、 n を作業領域の確保サイズの計算に使っても桁が不足することがなくなります。

	u0len = n;
	u1len = alen-n;
	if(blen > n){
		v0len = n;
		v1len = blen-n;
	} else {
		v0len = blen;
		v1len = 0;
	}
	u0 = a;
	u1 = a + u0len;
	v0 = b;
	v1 = b + v0len;

カラツバ乗算では、筆算乗算とは違い、u0*v1 等を 2 度加算しなければなりません。筆算乗算では計算結果にどんどん足して畳み込んでいけば良いので作業領域が不要ですが、カラツバ乗算では乗算回数を減らす代償として、作業領域が必要になります。ところで、カラツバ乗算に必要な作業領域は、n 桁分 2 つの減算結果と 2n 桁分の部分乗算の結果を格納する分だけで良いはずですが、2n+1 桁分 3 つと、4n+1 桁分一つで、合わせて 10n 桁を越える作業領域を使っています。そうしなければならない理由は、libmp の低レベル n 桁と m 桁の加算関数が桁上げ含めて n + m + 1 桁の結果を求めるためです。

作業領域の確保では、まとめて calloc に相当する mallocz してから切り分けます。

ところで、libmp の作業領域の確保の仕方は、無駄が多いやりかたになっています。理屈では、再帰による繰り返しを含めて全体で必要な作業領域の大きさは、公比 1/2 の等比級数におさまります。例えば、最初の分割時の作業領域の大きさが 10n の場合、その乗算に必要とする領域が倍の 20n を越えることはありません。乗算の再帰呼び出しに入る前に領域をプールとして確保して、その中をスタック構造として再帰呼び出しの格段階の作業領域を切り分けて使うことで、効率良くメモリ管理をおこなうことができますが、そうしたオーバーヘッドを減らす工夫をしていません。

	t = mallocz(Dbytes*5*(2*n+1), 1);
	if(t == nil)
		sysfatal("mpkaratsuba: %r");
	u0v0 = t;
	u1v1 = t + (2*n+1);
	diffprod = t + 2*(2*n+1);
	res = t + 3*(2*n+1);
	reslen = 4*n+1;

部分乗算は最初に w1 = (u1-u0)*(v0-v1) を符号なし減算と符号なし乗算を使って求めます。乗算の方は、桁を分割しながら再帰的にカラツバ乗算を繰り返していきます。減算は符号なしなので、大小関係を比較して、大きな方から小さな方を引いて、符号を別に求めます。この符号は w1 を足すときに符号なし加算か減算のどちらを使うべきかの判定に利用します。

ここで、部分乗算を w1 = (u1+u0)*(v1+v0) でおこなうと、符号判定が不要になる反面、足し算で桁上がりの分、桁が増えてしまいます。減算では桁が増えることはありません。これで、必ず桁数が減らしつつ部分乗算を再帰的におこなっていけるようになり、桁数による終了条件の判定で安全に再帰呼び出しループから抜け出すことが保証できます。もしも加算でおこなうと、場合によっては部分乗算の桁数が変化しなくなり、桁数による終了判定では再帰呼び出しが無限ループに陥る場合があります。

	sign = 1;
	if(mpveccmp(u1, u1len, u0, u0len) < 0){
		sign = -1;
		mpvecsub(u0, u0len, u1, u1len, u0v0);
	} else
		mpvecsub(u1, u1len, u0, u1len, u0v0);

	if(mpveccmp(v0, v0len, v1, v1len) < 0){
		sign *= -1;
		mpvecsub(v1, v1len, v0, v1len, u1v1);
	} else
		mpvecsub(v0, v0len, v1, v1len, u1v1);

	mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);

続いて、w2 = u1*v1w0 = u0*v0 を求めます。

	memset(t, 0, 2*(2*n+1)*Dbytes);
	if(v1len > 0)
		mpvecmul(u1, u1len, v1, v1len, u1v1);

	mpvecmul(u0, u0len, v0, v0len, u0v0);

3つの部分乗算が終わったので、桁をずらしながら足して作業領域に結果を求めます。u0 * v0 は一番下の桁に足し、さらに n 桁ずらして足します。u1 * v1 は n 桁ずらして足し、さらに 2n 桁ずらして足します。その上で、w1 = (u1-u0)*(v0-v1) を n 桁ずらして足します。ただし、w1 は符号があるので、符号が正なら足し込み、負なら引き算します。

	// res = u0*v0<<n + u0*v0
	mpvecadd(res, reslen, u0v0, u0len+v0len, res);
	mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);

	// res += u1*v1<<n + u1*v1<<2*n
	if(v1len > 0){
		mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
		mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
	}

	// res += (u1-u0)*(v0-v1)<<n
	if(sign < 0)
		mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
	else
		mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);

最後に、作業領域中の結果を、参照引数の指す配列へコピーして、作業領域を開放します。

	memmove(p, res, (alen+blen)*Dbytes);

	free(t);
}