Peter Bruin on Thu, 07 May 2015 12:09:55 +0200


[Date Prev] [Date Next] [Thread Prev] [Thread Next] [Date Index] [Thread Index]

Re: Strassen multiplication over the integers


Hello Bill,

[Sorry if this arrives twice; I tried to send a reply earlier, but
it seems to have disappeared.]

>> > - ideally the program src/test/tune.c should deal with the tuning.
>> > This program deal with dependencies between tuning parameters.
>> 
>> OK, I hope to have some time soon to try to write an addition for
>> src/test/tune.c for this.
>
> Do not waste too much time with tune.c. The most important is to find
> out how the tuning is affected by the coefficient size.

(I have not yet done this, but hope to look at it soon.)

> By the way, what matrix size your project is using ?

In a case that is feasible today (computations with a curve of genus 22
over a finite field), one has to do matrix multiplications up to size
(361 x 69) * (69 x 67).  In a harder but possibly reachable case (a
curve of genus 40), all sizes will be about twice as large.

>> > - Most of the code is fairly generic, so maybe there could be a
>> > gen_matmul_sw function.

I have now implemented this; see the attached patch.  It is completely
analogous to ZM_mul_sw.  There are some new tests to ensure that the new
code is covered.

>> Indeed, and I actually already have an implementation of exactly this
>> function, but it is an older version with awkward conventions for the
>> indices.  I will try to update this soon.  (Unfortunately it will be
>> hard to tune that one...)
>
> Well, we might let the caller provide the tuning parameter.

I did not see an obviously best way to implement this, so for now the
tuning parameter is fixed at 24 (gen_matmul_sw_bound in the patch).  This
seems to be a good value in the cases I tested (matrices over a field of size p^2, where p = 2^80 + 13).

At some point, we could implement functions F2xqM_mul_Kronecker and
FpXQM_mul_Kronecker, which would mean that we only need Strassen
multiplication (and the tuning that comes with it) for ZM_mul and
possibly Flm_mul.

Thanks,

Peter


commit 6615a44670102192c1c5818cd9d44bd01de54798
Author: Peter Bruin <P.J.Bruin@math.leidenuniv.nl>
Date:   Sat May 2 01:08:09 2015 +0200

    implement gen_matmul_sw

diff --git a/src/basemath/alglin1.c b/src/basemath/alglin1.c
index 646a8e9..c354596 100644
--- a/src/basemath/alglin1.c
+++ b/src/basemath/alglin1.c
@@ -389,11 +389,212 @@ gen_matcolmul(GEN A, GEN B, void *E, const struct bb_field *ff)
   return gen_matcolmul_i(A, B, lgA, lgcols(A), E, ff);
 }
 
+static GEN
+gen_matmul_classical(GEN A, GEN B, long l, long la, long lb,
+		     void *E, const struct bb_field *ff)
+{ /* Classical product C = A*B, one column at a time: C[,j] = A*B[,j] */
+  long j; /* l = lgcols(A), la = lg(A), lb = lg(B), i.e. dimensions + 1 */
+  GEN C = cgetg(lb, t_MAT);
+  for(j = 1; j < lb; j++)
+    gel(C, j) = gen_matcolmul_i(A, gel(B, j), la, l, E, ff);
+  return C;
+}
+
+/* Strassen-Winograd algorithm */
+
+/*
+  Return A[ma+1..ma+da, na+1..na+ea] + B[mb+1..mb+db, nb+1..nb+eb]
+  as an (m x n)-matrix, padding the input with zeroes as necessary.
+  Entries lying in only one slice are shared with A or B, not copied. */
+static GEN
+add_slices(long m, long n,
+	   GEN A, long ma, long da, long na, long ea,
+	   GEN B, long mb, long db, long nb, long eb,
+	   void *E, const struct bb_field *ff)
+{
+  long min_d = minss(da, db), min_e = minss(ea, eb), i, j;
+  GEN M = cgetg(n + 1, t_MAT), C;
+
+  for (j = 1; j <= min_e; j++) { /* columns present in both slices */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= min_d; i++)
+      gel(C, i) = ff->add(E, gcoeff(A, ma + i, na + j),
+			  gcoeff(B, mb + i, nb + j));
+    for (; i <= da; i++)
+      gel(C, i) = gcoeff(A, ma + i, na + j);
+    for (; i <= db; i++)
+      gel(C, i) = gcoeff(B, mb + i, nb + j);
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= ea; j++) { /* columns only in A's slice */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= da; i++)
+      gel(C, i) = gcoeff(A, ma + i, na + j);
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= eb; j++) { /* columns only in B's slice */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= db; i++)
+      gel(C, i) = gcoeff(B, mb + i, nb + j);
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= n; j++) { /* zero padding columns */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  return M;
+}
+
+/*
+  Return A[ma+1..ma+da, na+1..na+ea] - B[mb+1..mb+db, nb+1..nb+eb]
+  as an (m x n)-matrix, padding the input with zeroes as necessary.
+  Entries lying only in A's slice are shared with A, not copied. */
+static GEN
+subtract_slices(long m, long n,
+		GEN A, long ma, long da, long na, long ea,
+		GEN B, long mb, long db, long nb, long eb,
+		void *E, const struct bb_field *ff)
+{
+  long min_d = minss(da, db), min_e = minss(ea, eb), i, j;
+  GEN M = cgetg(n + 1, t_MAT), C;
+
+  for (j = 1; j <= min_e; j++) { /* columns present in both slices */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= min_d; i++)
+      gel(C, i) = ff->add(E, gcoeff(A, ma + i, na + j),
+			  ff->neg(E, gcoeff(B, mb + i, nb + j)));
+    for (; i <= da; i++)
+      gel(C, i) = gcoeff(A, ma + i, na + j);
+    for (; i <= db; i++)
+      gel(C, i) = ff->neg(E, gcoeff(B, mb + i, nb + j));
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= ea; j++) { /* columns only in A's slice */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= da; i++)
+      gel(C, i) = gcoeff(A, ma + i, na + j);
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= eb; j++) { /* columns only in B's slice: negated */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= db; i++)
+      gel(C, i) = ff->neg(E, gcoeff(B, mb + i, nb + j));
+    for (; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  for (; j <= n; j++) { /* zero padding columns */
+    gel(M, j) = C = cgetg(m + 1, t_COL);
+    for (i = 1; i <= m; i++)
+      gel(C, i) = ff->s(E, 0);
+  }
+  return M;
+}
+
+static GEN gen_matmul_i(GEN A, GEN B, long l, long la, long lb,
+			void *E, const struct bb_field *ff); /* defined below (mutual recursion) */
+
+static GEN
+gen_matmul_sw(GEN A, GEN B, long m, long n, long p,
+	      void *E, const struct bb_field *ff)
+{ /* One Strassen-Winograd step: C = A*B, where A is m x n and B is n x p */
+  pari_sp av = avma; /* odd sizes: operands are cut to size or zero-padded by *_slices */
+  long m1 = (m + 1)/2, m2 = m/2, /* row split of A and C */
+    n1 = (n + 1)/2, n2 = n/2, /* column split of A, row split of B */
+    p1 = (p + 1)/2, p2 = p/2; /* column split of B and C */
+  GEN A11, A12, A22, B11, B21, B22, /* Xij = block (i,j) of X */
+    S1, S2, S3, S4, T1, T2, T3, T4,
+    M1, M2, M3, M4, M5, M6, M7,
+    V1, V2, V3, C11, C12, C21, C22, C;
+
+  T2 = subtract_slices(n1, p2, B, 0, n1, p1, p2, B, n1, n2, p1, p2, E, ff); /* B12 - B22 */
+  S1 = subtract_slices(m2, n1, A, m1, m2, 0, n1, A, 0, m2, 0, n1, E, ff); /* A21 - A11 */
+  M2 = gen_matmul_i(S1, T2, m2 + 1, n1 + 1, p2 + 1, E, ff); /* S1*T2 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 2, &T2, &M2);  /* destroy S1 */
+  T3 = subtract_slices(n1, p1, T2, 0, n1, 0, p2, B, 0, n1, 0, p1, E, ff); /* T2 - B11 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 2, &M2, &T3);  /* destroy T2 */
+  S2 = add_slices(m2, n1, A, m1, m2, 0, n1, A, m1, m2, n1, n2, E, ff); /* A21 + A22 */
+  T1 = subtract_slices(n1, p1, B, 0, n1, p1, p2, B, 0, n1, 0, p2, E, ff); /* B12 - B11 */
+  M3 = gen_matmul_i(S2, T1, m2 + 1, n1 + 1, p2 + 1, E, ff); /* S2*T1 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 4, &M2, &T3, &S2, &M3);  /* destroy T1 */
+  S3 = subtract_slices(m1, n1, S2, 0, m2, 0, n1, A, 0, m1, 0, n1, E, ff); /* S2 - A11 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 4, &M2, &T3, &M3, &S3);  /* destroy S2 */
+  A11 = matslice(A, 1, m1, 1, n1);
+  B11 = matslice(B, 1, n1, 1, p1);
+  M1 = gen_matmul_i(A11, B11, m1 + 1, n1 + 1, p1 + 1, E, ff); /* A11*B11 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 5, &M2, &T3, &M3, &S3, &M1);  /* destroy A11, B11 */
+  A12 = matslice(A, 1, m1, n1 + 1, n);
+  B21 = matslice(B, n1 + 1, n, 1, p1);
+  M4 = gen_matmul_i(A12, B21, m1 + 1, n2 + 1, p1 + 1, E, ff); /* A12*B21 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &T3, &M3, &S3, &M1, &M4);  /* destroy A12, B21 */
+  C11 = add_slices(m1, p1, M1, 0, m1, 0, p1, M4, 0, m1, 0, p1, E, ff); /* M1 + M4 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &T3, &M3, &S3, &M1, &C11);  /* destroy M4 */
+  M5 = gen_matmul_i(S3, T3, m1 + 1, n1 + 1, p1 + 1, E, ff); /* S3*T3 */
+  S4 = subtract_slices(m1, n2, A, 0, m1, n1, n2, S3, 0, m1, 0, n2, E, ff); /* A12 - S3 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 7, &M2, &T3, &M3, &M1, &C11, &M5, &S4);  /* destroy S3 */
+  T4 = add_slices(n2, p1, B, n1, n2, 0, p1, T3, 0, n2, 0, p1, E, ff); /* B21 + T3 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 7, &M2, &M3, &M1, &C11, &M5, &S4, &T4);  /* destroy T3 */
+  V1 = subtract_slices(m1, p1, M1, 0, m1, 0, p1, M5, 0, m1, 0, p1, E, ff); /* M1 - M5 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &M3, &S4, &T4, &C11, &V1);  /* destroy M1, M5 */
+  B22 = matslice(B, n1 + 1, n, p1 + 1, p);
+  M6 = gen_matmul_i(S4, B22, m1 + 1, n2 + 1, p2 + 1, E, ff); /* S4*B22 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &M3, &T4, &C11, &V1, &M6);  /* destroy S4, B22 */
+  A22 = matslice(A, m1 + 1, m, n1 + 1, n);
+  M7 = gen_matmul_i(A22, T4, m2 + 1, n2 + 1, p1 + 1, E, ff); /* A22*T4 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &M3, &C11, &V1, &M6, &M7);  /* destroy A22, T4 */
+  V3 = add_slices(m1, p2, V1, 0, m1, 0, p2, M3, 0, m2, 0, p2, E, ff); /* V1 + M3 */
+  C12 = add_slices(m1, p2, V3, 0, m1, 0, p2, M6, 0, m1, 0, p2, E, ff); /* V3 + M6 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 6, &M2, &M3, &C11, &V1, &M7, &C12);  /* destroy V3, M6 */
+  V2 = add_slices(m2, p1, V1, 0, m2, 0, p1, M2, 0, m2, 0, p2, E, ff); /* V1 + M2 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 5, &M3, &C11, &M7, &C12, &V2);  /* destroy V1, M2 */
+  C21 = add_slices(m2, p1, V2, 0, m2, 0, p1, M7, 0, m2, 0, p1, E, ff); /* V2 + M7 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 5, &M3, &C11, &C12, &V2, &C21);  /* destroy M7 */
+  C22 = add_slices(m2, p2, V2, 0, m2, 0, p2, M3, 0, m2, 0, p2, E, ff); /* V2 + M3 */
+  if (gc_needed(av, 1))
+    gerepileall(av, 4, &C11, &C12, &C21, &C22);  /* destroy V2, M3 */
+  C = mkmat2(mkcol2(C11, C21), mkcol2(C12, C22)); /* block matrix [C11, C12; C21, C22] */
+  return gerepileupto(av, matconcat(C));
+}
+
+/* Strassen-Winograd is used only when all three dimensions are >= gen_matmul_sw_bound */
+static const long gen_matmul_sw_bound = 24;
+
+static GEN
+gen_matmul_i(GEN A, GEN B, long l, long la, long lb,
+	     void *E, const struct bb_field *ff)
+{ /* l = lgcols(A), la = lg(A), lb = lg(B), i.e. dimensions + 1 */
+  if (l <= gen_matmul_sw_bound
+      || la <= gen_matmul_sw_bound
+      || lb <= gen_matmul_sw_bound)
+    return gen_matmul_classical(A, B, l, la, lb, E, ff);
+  else
+    return gen_matmul_sw(A, B, l - 1, la - 1, lb - 1, E, ff);
+}
+
 GEN
 gen_matmul(GEN A, GEN B, void *E, const struct bb_field *ff)
 {
-  ulong j, l, lgA, lgB = lg(B);
-  GEN C;
+  ulong lgA, lgB = lg(B);
   if (lgB == 1)
     return cgetg(1, t_MAT);
   lgA = lg(A);
@@ -401,11 +602,7 @@ gen_matmul(GEN A, GEN B, void *E, const struct bb_field *ff)
     pari_err_OP("operation 'gen_matmul'", A, B);
   if (lgA == 1)
     return zeromat(0, lgB - 1);
-  l = lgcols(A);
-  C = cgetg(lgB, t_MAT);
-  for(j = 1; j < lgB; j++)
-    gel(C, j) = gen_matcolmul_i(A, gel(B, j), lgA, l, E, ff);
-  return C;
+  return gen_matmul_i(A, B, lgcols(A), lgA, lgB, E, ff); /* classical or Strassen-Winograd, by size */
 }
 
 static GEN
diff --git a/src/test/32/ff b/src/test/32/ff
index df1b630..2f03fc2 100644
--- a/src/test/32/ff
+++ b/src/test/32/ff
@@ -296,14 +296,26 @@ t^4 + t^2
 [3*t^2 + 3*t, 18446744073709551628*t^4 + 5*t^3 + 4*t^2 + 1844674407370955162
 7*t + 18446744073709551628, 18446744073709551628*t^4 + 5*t^3 + 4*t^2 + 18446
 744073709551628*t]~
-? test(q)=my(t=ffgen(q,'t),M=matrix(10,10,i,j,random(t)));subst(charpoly(M),'x,M)==0;
-? test(nextprime(2^7)^5)
+? test(q,n)=my(t=ffgen(q,'t),M=matrix(n,n,i,j,random(t)));subst(charpoly(M),'x,M)==0;
+? test(nextprime(2^7)^5,10)
 1
-? test(nextprime(2^15)^5)
+? test(nextprime(2^15)^5,10)
 1
-? test(nextprime(2^31)^5)
+? test(nextprime(2^31)^5,10)
 1
-? test(nextprime(2^63)^5)
+? test(nextprime(2^63)^5,10)
+1
+? test(nextprime(2^80)^2,10)
+1
+? test(nextprime(2^7)^5,27)
+1
+? test(nextprime(2^15)^5,27)
+1
+? test(nextprime(2^31)^5,27)
+1
+? test(nextprime(2^63)^5,27)
+1
+? test(nextprime(2^80)^2,27)
 1
 ? print("Total time spent: ",gettime);
-Total time spent: 1440
+Total time spent: 1959
diff --git a/src/test/in/ff b/src/test/in/ff
index 8ed968c..4d07a4e 100644
--- a/src/test/in/ff
+++ b/src/test/in/ff
@@ -88,11 +88,17 @@ test(2^5)
 test(7^5)
 test((2^64+13)^5)
 
-test(q)={
-  my(t = ffgen(q, 't), M = matrix(10, 10, i, j, random(t)));
+test(q, n)={ \\ check that charpoly(M) vanishes at M, for a random n x n matrix M over F_q
+  my(t = ffgen(q, 't), M = matrix(n, n, i, j, random(t)));
   subst(charpoly(M), 'x, M) == 0;
 }
-test(nextprime(2^7)^5)
-test(nextprime(2^15)^5)
-test(nextprime(2^31)^5)
-test(nextprime(2^63)^5)
+test(nextprime(2^7)^5, 10)
+test(nextprime(2^15)^5, 10)
+test(nextprime(2^31)^5, 10)
+test(nextprime(2^63)^5, 10)
+test(nextprime(2^80)^2, 10)
+test(nextprime(2^7)^5, 27) \\ n = 27 exceeds gen_matmul_sw_bound = 24
+test(nextprime(2^15)^5, 27)
+test(nextprime(2^31)^5, 27)
+test(nextprime(2^63)^5, 27)
+test(nextprime(2^80)^2, 27)