581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
		
		
			
		
	
	
			581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)
|  | 
 | ||
|  | 
 | ||
|  | open Util | ||
|  | open Expr | ||
|  | 
 | ||
(* Association tables keyed on dag nodes.  Keys are hashed with [Expr.hash];
   lookups match keys by physical identity ([==]), so two structurally equal
   but distinct nodes are different keys. *)
let node_insert node value table = Assoctable.insert Expr.hash node value table

let node_lookup node table = Assoctable.lookup Expr.hash ( == ) node table
|  | 
 | ||
(*************************************************************
 * Algebraic simplifier/elimination of common subexpressions
 *************************************************************)
module AlgSimp : sig
  (* Algebraically simplify a list of dag roots. *)
  val algsimp : expr list -> expr list
end = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monad state is a pair [(simp, cse)] of association tables, both
     initially empty (see [initialTable]):
     - [simp] memoizes the simplified form of each input node ([algsimpM]);
     - [cse] maps each constructed node to its canonical representative,
       eliminating common subexpressions ([identityM]). *)

  let fetchSimp =
    fetchState >>= fun (s, _) -> returnM s
  let storeSimp s =
    fetchState >>= (fun (_, c) -> storeState (s, c))
  let lookupSimpM key =
    fetchSimp >>= fun table ->
    returnM (node_lookup key table)
  let insertSimpM key value =
    fetchSimp >>= fun table ->
    storeSimp (node_insert key value table)

  (* [subset a b]: every element of [a] is physically present in [b]. *)
  let subset a b =
    List.for_all (fun x -> List.exists (fun y -> x == y) b) a

  (* One-level structural equality: operands are compared by physical
     identity.  [Times] and [CTimes] are treated as commutative; [Plus]
     compares its operand lists as sets (order and multiplicity ignored). *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num a, Num b) -> Number.equal a b
    | (NaN a, NaN b) -> a == b
    | (Load a, Load b) -> Variable.same a b
    | (Times (a, a'), Times (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimes (a, a'), CTimes (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
    | (Plus a, Plus b) -> subset a b && subset b a
    | (Uminus a, Uminus b) -> (a == b)
    | _ -> false

  (* With [Magic.randomized_cse], hashing and equality defer to the Oracle
     (equality additionally accepts nodes the Oracle deems likely equal). *)
  let hashCSE x =
    if (!Magic.randomized_cse) then
      Oracle.hash x
    else
      Expr.hash x

  let equalCSE a b =
    if (!Magic.randomized_cse) then
      (structurallyEqualCSE a b || Oracle.likely_equal a b)
    else
      structurallyEqualCSE a b

  let fetchCSE =
    fetchState >>= fun (_, c) -> returnM c
  let storeCSE c =
    fetchState >>= (fun (s, _) -> storeState (s, c))
  let lookupCSEM key =
    fetchCSE >>= fun table ->
    returnM (Assoctable.lookup hashCSE equalCSE key table)
  let insertCSEM key value =
    fetchCSE >>= fun table ->
    storeCSE (Assoctable.insert hashCSE key value table)

  (* memoize both x and Uminus x (unless x is already negated) *)
  let identityM x =
    let memo x = memoizing lookupCSEM insertCSEM returnM x in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  (* Every node built by the simplifier goes through the CSE table. *)
  let makeNode = identityM

  (* simplifiers for various kinds of nodes *)

  (* Numeric constants: negative values become [Uminus (Num |n|)]. *)
  let rec snumM = function
    | n when Number.is_zero n ->
        makeNode (Num (Number.zero))
    | n when Number.negative n ->
        makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* Negation: cancels double negation and maps -0 to 0. *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Products: pulls negations outward, folds constant factors, handles
     multiplication by 0, 1 and -1, and puts a known constant operand
     first. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Fold the numeric constants of a term list into one constant, moving
     constants rightwards past non-constant terms; a zero constant left
     alone at the end is dropped. *)
  and reduce_sumM x = match x with
    | [] -> returnM []
    | [Num a] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | [Uminus (Num a)] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun x ->
        reduce_sumM (x :: s)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A coefficient may be collected unless it is (a negation of) a NaN. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true
  and collectible (a, _) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x *)
  and collectM which x =
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
      | Uminus x ->
          findCoeffM which x >>= fun (coeff, b) ->
          suminusM coeff >>= fun mcoeff ->
          returnM (mcoeff, b)
      | x -> snumM Number.one >>= fun one -> returnM (one, x)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | a :: b ->
          separateM xpr b >>= fun (w, wo) ->
          (* try first factor *)
          findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
          if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
          else
            (* try second factor *)
            findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
            if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
            else returnM (w, a :: wo)
    in
    match x with
    | [] -> returnM x
    | [_] -> returnM x
    | a :: _ ->
        findCoeffM which a >>= fun (_, xpr) ->
        separateM xpr x >>= fun (w, wo) ->
        collectM which wo >>= fun wo' ->
        splusM w >>= fun w' ->
        stimesM (w', xpr) >>= fun t' ->
        returnM (t' :: wo')

  (* The sequence of rewrites applied to the term list of every sum. *)
  and mangleSumM x = returnM x
    >>= reduce_sumM
    >>= collectM (fun (a, b) -> (a, b))
    >>= collectM (fun (a, b) -> (b, a))
    >>= reduce_sumM
    >>= deepCollectM !Magic.deep_collect_depth
    >>= reduce_sumM

  (* push all Uminuses to the end (the negated terms end up in reverse
     order) *)
  and reorder_uminus = function
    | [] -> []
    | (Uminus _ as a') :: b -> (reorder_uminus b) @ [a']
    | a :: b -> a :: (reorder_uminus b)

  (* Build the node for an already-simplified term list. *)
  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a                    (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* With [Magic.enable_fma] and at least two terms of the form
     [(-)(Num c * x)], rewrite  sum c_i x_i  as  max * sum (c_i/max) x_i
     where [max] is the largest such coefficient; otherwise build a plain
     [Plus] node. *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num _, _) -> true
      | Uminus (Times (Num _, _)) -> true
      | _ -> false
    and separate = function
      | [] -> ([], [], Number.zero)
      | (Times (Num a, _) as this) :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | (Uminus (Times (Num a, _)) as this) :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, max) = separate c in
          (x, this :: y, max)
    in fun l ->
      if !Magic.enable_fma && count is_multiplication l >= 2 then begin
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
        snumM max >>= fun max' ->
        mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
        stimesM (max', pw') >>= fun mw' ->
        splusM (wo @ [mw'])
      end else
        makeNode (Plus l)

  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  and deepCollectM maxdepth l =
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> (findTerms (depth - 1) b)
      | Plus l when depth > 0 ->
          x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    and duplicates = function
      | [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
                  else duplicates b
    in
    (* [splitDuplicates depth d x] returns [(p, q)] with x = p + q, where
       [q] gathers the parts of [x] built from the duplicated terms [d]. *)
    let rec splitDuplicates depth d x =
      if List.memq x d then
        (snumM (Number.zero) >>= fun zero ->
         returnM (zero, x))
      else
        match x with
        | Times (a, b) ->
            (* With a = a' + xa and b = b' + xb:
                 a*b - a'*b' = a*xb + xa*b - xa*xb.
               The cross term xa*xb is contained in both a*xb and xa*b, so
               it must be subtracted once (adding it over-counted it).
               When xa or xb is zero this term is zero either way. *)
            splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
            splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
            stimesM (a', b') >>= fun ab ->
            stimesM (a, xb) >>= fun xb' ->
            stimesM (xa, b) >>= fun xa' ->
            stimesM (xa, xb) >>= suminusM >>= fun mxab ->
            splusM [xa'; xb'; mxab] >>= fun x ->
            returnM (ab, x)
        | Uminus a ->
            splitDuplicates depth d a >>= fun (x, y) ->
            suminusM x >>= fun ux ->
            suminusM y >>= fun uy ->
            returnM (ux, uy)
        | Plus l when depth > 0 ->
            mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
            let (l', d') = List.split ld in
            splusM l' >>= fun p ->
            splusM d' >>= fun d'' ->
            returnM (p, d'')
        | x ->
            snumM (Number.zero) >>= fun zero' ->
            returnM (x, zero')
    in
    let l' = List.flatten (List.map (findTerms maxdepth) l) in
    match duplicates l' with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun ld ->
        let (l', d') = List.split ld in
        splusM l' >>= fun l'' ->
        let rec flattenPlusM = function
          | Plus l -> returnM l
          | Uminus x ->
              flattenPlusM x >>= mapM suminusM
          | x -> returnM [x]
        in
        mapM flattenPlusM d' >>= fun d'' ->
        splusM (List.flatten d'') >>= fun d''' ->
        mangleSumM [l''; d''']

  (* Sums: simplify the term list, then choose where the overall sign
     goes. *)
  and splusM l =
    let fma_heuristics x =
      if !Magic.enable_fma then
        (match x with
         | [Uminus (Times _); Times _] -> Some false
         | [Times _; Uminus (Times _)] -> Some false
         | [Uminus (_); Times _] -> Some true
         | [Times _; Uminus (Plus _)] -> Some true
         | [_; Uminus (Times _)] -> Some false
         | [Uminus (Times _); _] -> Some false
         | _ -> None)
      else
        None
    in
    mangleSumM l >>= fun l' ->
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      mapM suminusM l' >>= splusM >>= suminusM
    else
      match fma_heuristics l' with
      | Some true -> mapM suminusM l' >>= splusM >>= suminusM
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse) &&
             Oracle.should_flip_sign (Plus l') then
            mapM suminusM l' >>= splusM >>= suminusM
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag; the simplified form of
     each input node is memoized in the [simp] table *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | NaN _ as x -> makeNode x
        | Plus a ->
            mapM algsimpM a >>= splusM
        | Times (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             stimesM (a', b'))
        | CTimes (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesM (a', b'))
        | CTimesJ (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesjM (a', b'))
        | Uminus a ->
            algsimpM a >>= suminusM
        | Store (v, a) ->
            (algsimpM a >>= fun a' ->
             makeNode (Store (v, a')))
        | Load _ as x -> makeNode x)
      x

  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp = runM initialTable simp_roots
end
|  | 
 | ||
(*************************************************************
 * Network transposition algorithm
 *************************************************************)
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monad state is one association table memoizing the dual of every
     node already transposed. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun duals -> returnM (node_lookup key duals)

  let insertDualsM key value =
    fetchDuals >>= fun duals -> storeDuals (node_insert key value duals)

  (* [visit visited vtable parent_table worklist] walks the dag depth-first
     from [worklist], skipping nodes already recorded in [vtable].  Returns
     all distinct nodes reached (most recently reached first) and a table
     mapping each node to its parents; a parent is listed once per
     occurrence of the node among its operands. *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable parent_table rest
         | None ->
             let children =
               match node with
               | Store (_, n) | Uminus n -> [n]
               | Plus l -> l
               | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
               | _ -> []
             in
             (* prepend [node] to the parent list of [child] *)
             let record_parent table child =
               match node_lookup child table with
               | None -> node_insert child [node] table
               | Some parents -> node_insert child (node :: parents) table
             in
             visit (node :: visited)
               (node_insert node () vtable)
               (List.fold_left record_parent parent_table children)
               (children @ rest))

  (* Build the memoized dual function for a dag whose parent relation is
     [parent_table]. *)
  let make_transposer parent_table =
    (* Contribution of [candidate_parent] to the dual of [node]: the dual of
       the parent, multiplied by the other factor when [node] is the second
       operand of a product, negated under [Uminus]; nothing otherwise. *)
    let rec termM node candidate_parent =
      let contribute wrap =
        dualM candidate_parent >>= fun x' -> returnM [wrap x']
      in
      match candidate_parent with
      | Store (_, n) when n == node -> contribute (fun x' -> x')
      | Plus l when List.memq node l -> contribute (fun x' -> x')
      | Times (a, b) when b == node -> contribute (fun x' -> makeTimes (a, x'))
      | CTimes (a, b) when b == node -> contribute (fun x' -> CTimes (a, x'))
      | CTimesJ (a, b) when b == node -> contribute (fun x' -> CTimesJ (a, x'))
      | Uminus n when n == node -> contribute makeUminus
      | _ -> returnM []

    (* Sum of the contributions of all parents of [this_node]; fails if the
       node has no entry in [parent_table]. *)
    and dualExpressionM this_node =
      let parents =
        match node_lookup this_node parent_table with
        | Some ps -> ps
        | None -> failwith "bug in dualExpressionM"
      in
      mapM (termM this_node) parents >>= fun terms ->
      returnM (makePlus (List.flatten terms))

    (* Dual of a node: stores become loads, non-constant loads become stores
       of their dual expression, constant loads stay loads. *)
    and dualM this_node =
      let compute = function
        | Load v when Variable.is_constant v -> returnM (Load v)
        | Load v as x ->
            (dualExpressionM x >>= fun d -> returnM (Store (v, d)))
        | Store (v, _) -> returnM (Load v)
        | x -> dualExpressionM x
      in
      memoizing lookupDualsM insertDualsM compute this_node
    in
    dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose the network rooted at [dag]: compute the dual of every
     reachable node and keep the duals that are stores. *)
  let transpose dag =
    let () = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposeAllM = mapM (make_transposer parent_table) in
    let duals = runM Assoctable.empty transposeAllM all_nodes in
    let roots = List.filter is_store duals in
    let () = Util.info "end transpose" in
    roots
end
|  | 
 | ||
|  | 
 | ||
(*************************************************************
 * Various dag statistics
 *************************************************************)
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* (loads, stores, additions, multiplications, negations, constants) *)
  type complexity = int * int * int * int * int * int

  (* Depth-first walk from the given roots; returns every distinct (by
     physical identity) reachable node once.  Operands of CTimes/CTimesJ are
     traversed too, consistently with [Transpose.visit]; previously anything
     reachable only through a complex product was never counted. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable rest
         | None ->
             let children = match node with
               | Store (_, n) -> [n]
               | Plus l -> l
               | Times (a, b) -> [a; b]
               | CTimes (a, b) -> [a; b]
               | CTimesJ (a, b) -> [a; b]
               | Uminus x -> [x]
               | _ -> []
             in
             visit (node :: visited)
               (node_insert node () vtable)
               (children @ rest))

  (* Count each distinct node once.  A [Plus] of k operands costs k-1
     additions; CTimes, CTimesJ and NaN nodes are not counted themselves. *)
  let complexity dag =
    let tally (load, store, plus, times, uminus, num) = function
      | Load _ -> (load + 1, store, plus, times, uminus, num)
      | Store _ -> (load, store + 1, plus, times, uminus, num)
      | Plus x -> (load, store, plus + (List.length x - 1), times, uminus, num)
      | Times _ -> (load, store, plus, times + 1, uminus, num)
      | Uminus _ -> (load, store, plus, times, uminus + 1, num)
      | Num _ -> (load, store, plus, times, uminus, num + 1)
      | CTimes _ | CTimesJ _ | NaN _ -> (load, store, plus, times, uminus, num)
    in
    List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Weighted cost: an addition costs 10, a multiplication 20, and every
     other counted node 1. *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = weight a = weight b
  let leq_complexity a b = weight a <= weight b

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n
end
|  | 
 | ||
(* simplify the dag *)
(* Simplify the dag [v] (a list of roots).  One pass of [AlgSimp.algsimp]
   always runs.  When [!Magic.network_transposition] is set, the result is
   refined by rounds of transpose / simplify / transpose / simplify.  Rounds
   repeat until one produces a dag whose [Stats] complexity is no greater
   than that of the dag the round started from; that dag is returned. *)
let algsimp v =
  (* One round: simplify the transposed network, transpose back, simplify. *)
  let transpose_round dag =
    AlgSimp.algsimp
      (Transpose.transpose (AlgSimp.algsimp (Transpose.transpose dag)))
  in
  let rec simplification_loop current =
    Util.info "simplification step";
    let before = Stats.complexity current in
    Util.info ("complexity = " ^ (Stats.to_string before));
    let next = transpose_round current in
    let after = Stats.complexity next in
    Util.info ("complexity = " ^ (Stats.to_string after));
    if Stats.leq_complexity after before then begin
      Util.info "end algsimp";
      next
    end else
      simplification_loop next
  in
  Util.info "begin algsimp";
  let simplified = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop simplified
  else simplified
|  | 
 |