581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
			
		
		
	
	
			581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
| (*
 | |
|  * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 | |
|  * Copyright (c) 2003, 2007-14 Matteo Frigo
 | |
|  * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 | |
|  *
 | |
|  * This program is free software; you can redistribute it and/or modify
 | |
|  * it under the terms of the GNU General Public License as published by
 | |
|  * the Free Software Foundation; either version 2 of the License, or
 | |
|  * (at your option) any later version.
 | |
|  *
 | |
|  * This program is distributed in the hope that it will be useful,
 | |
|  * but WITHOUT ANY WARRANTY; without even the implied warranty of
 | |
|  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | |
|  * GNU General Public License for more details.
 | |
|  *
 | |
|  * You should have received a copy of the GNU General Public License
 | |
|  * along with this program; if not, write to the Free Software
 | |
|  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 | |
|  *
 | |
|  *)
 | |
| 
 | |
| 
 | |
| open Util
 | |
| open Expr
 | |
| 
 | |
(* Association-table helpers keyed on expression nodes.  Both hash the key
   with [Expr.hash]; lookups compare keys by physical identity (==). *)

(* [node_insert key value table] is [table] extended with [key -> value]. *)
let node_insert key value table =
  Assoctable.insert Expr.hash key value table

(* [node_lookup key table] is the result of [Assoctable.lookup] for the
   entry whose key is physically equal to [key] (matched against
   [Some]/[None] by the callers in this file). *)
let node_lookup key table =
  Assoctable.lookup Expr.hash (==) key table
 | |
| 
 | |
| (*************************************************************
 | |
|  * Algebraic simplifier/elimination of common subexpressions
 | |
|  *************************************************************)
 | |
module AlgSimp : sig
  (* [algsimp roots] algebraically simplifies the dag rooted at [roots],
     sharing common subexpressions, and returns the simplified roots. *)
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monad state is a pair of tables (simp, cse):
     - simp is the memo table used by [algsimpM];
     - cse is the memo table used by [makeNode], searched with
       [hashCSE] / [equalCSE]. *)

  (* Accessors for the first component of the state. *)
  let fetchSimp =
    fetchState >>= fun (simp, _) -> returnM simp

  let storeSimp simp =
    fetchState >>= fun (_, cse) -> storeState (simp, cse)

  let lookupSimpM key =
    fetchSimp >>= fun simp -> returnM (node_lookup key simp)

  let insertSimpM key value =
    fetchSimp >>= fun simp -> storeSimp (node_insert key value simp)

  (* [subset a b]: every element of [a] is physically present in [b]. *)
  let subset a b =
    List.for_all (fun x -> List.memq x b) a

  (* Equality used for CSE.  Children are compared by physical identity,
     so two nodes are equal only when their operands are the very same
     nodes.  Times and CTimes are treated as commutative, CTimesJ is not,
     and Plus is compared as a set of operands.  Any other pairing
     (including Store) is never equal. *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num x, Num y) -> Number.equal x y
    | (NaN x, NaN y) -> x == y
    | (Load x, Load y) -> Variable.same x y
    | (Times (x, x'), Times (y, y'))
    | (CTimes (x, x'), CTimes (y, y')) ->
        (x == y && x' == y') || (x == y' && x' == y)
    | (CTimesJ (x, x'), CTimesJ (y, y')) -> x == y && x' == y'
    | (Plus xs, Plus ys) -> subset xs ys && subset ys xs
    | (Uminus x, Uminus y) -> x == y
    | _ -> false

  (* Hash for the CSE table: the Oracle hash when randomized CSE is
     enabled, the plain expression hash otherwise. *)
  let hashCSE x =
    if !Magic.randomized_cse then Oracle.hash x else Expr.hash x

  (* Equality for the CSE table: structural equality, widened by the
     Oracle test when randomized CSE is enabled. *)
  let equalCSE a b =
    structurallyEqualCSE a b
    || (!Magic.randomized_cse && Oracle.likely_equal a b)

  (* Accessors for the second component of the state. *)
  let fetchCSE =
    fetchState >>= fun (_, cse) -> returnM cse

  let storeCSE cse =
    fetchState >>= fun (simp, _) -> storeState (simp, cse)

  let lookupCSEM key =
    fetchCSE >>= fun cse ->
    returnM (Assoctable.lookup hashCSE equalCSE key cse)

  let insertCSEM key value =
    fetchCSE >>= fun cse ->
    storeCSE (Assoctable.insert hashCSE key value cse)

  (* memoize both x and Uminus x (unless x is already negated) *)
  let identityM x =
    let memo node = memoizing lookupCSEM insertCSEM returnM node in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  let makeNode = identityM

  (* simplifiers for various kinds of nodes *)

  (* Numeric constant: zero is normalized, a negative number becomes the
     negation of its positive counterpart. *)
  let rec snumM n =
    if Number.is_zero n then
      makeNode (Num Number.zero)
    else if Number.negative n then
      makeNode (Num (Number.negate n)) >>= suminusM
    else
      makeNode (Num n)

  (* Negation: a double negation cancels and minus zero is zero. *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when Number.is_zero a -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Real product.  Minus signs are pulled out of either operand,
     numeric constants are folded, multiplication by 0, 1 and -1 is
     simplified, and a known constant on the right is moved to the
     left. *)
  and stimesM = function
    | (Uminus a, b) | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) ->
        stimesM (NaN I, b) >>= fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun ab -> stimesM (ab, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  (* CTimes product: only minus signs are pulled out. *)
  and sctimesM = function
    | (Uminus a, b) | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  (* CTimesJ product: only minus signs are pulled out. *)
  and sctimesjM = function
    | (Uminus a, b) | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Fold the numeric constants of a sum.  Adjacent constants are
     combined, a constant followed by a non-constant is swapped with it,
     and a zero constant left alone at the end is dropped. *)
  and reduce_sumM x =
    match x with
    | [] -> returnM []
    | [Num a] | [Uminus (Num a)] ->
        if Number.is_zero a then returnM [] else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun c -> reduce_sumM (c :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun c -> reduce_sumM (c :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun c -> reduce_sumM (c :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun c ->
        reduce_sumM (c :: s)
    | ((Num _) as a) :: b :: s
    | ((Uminus (Num _)) as a) :: b :: s ->
        reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A coefficient may be collected unless it is a NaN, possibly under
     minus signs. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true

  (* Only the first component (the coefficient) is inspected. *)
  and collectible (a, _) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x *)
  and collectM which x =
    (* Split a term into (coefficient, factor).  [pick] selects which
       operand of a product plays the coefficient; a term that is not a
       collectible product gets coefficient one. *)
    let rec findCoeffM pick = function
      | Times (a, b) when collectible (pick (a, b)) ->
          returnM (pick (a, b))
      | Uminus t ->
          findCoeffM pick t >>= fun (coeff, factor) ->
          suminusM coeff >>= fun mcoeff ->
          returnM (mcoeff, factor)
      | t -> snumM Number.one >>= fun one -> returnM (one, t)
    (* Partition the terms into the coefficients of those whose factor
       is physically [xpr], and the remaining terms. *)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | term :: rest ->
          let keep = fun (p, q) -> (p, q)
          and swap = fun (p, q) -> (q, p) in
          separateM xpr rest >>= fun (w, wo) ->
          (* try first factor *)
          findCoeffM keep term >>= fun (c, f) ->
          if (xpr == f) && collectible (c, f) then returnM (c :: w, wo)
          else
            (* try second factor *)
            findCoeffM swap term >>= fun (c, f) ->
            if (xpr == f) && collectible (c, f) then returnM (c :: w, wo)
            else returnM (w, term :: wo)
    in
    match x with
    | [] | [_] -> returnM x
    | first :: _ ->
        findCoeffM which first >>= fun (_, xpr) ->
        separateM xpr x >>= fun (w, wo) ->
        collectM which wo >>= fun wo' ->
        splusM w >>= fun w' ->
        stimesM (w', xpr) >>= fun t' ->
        returnM (t' :: wo')

  (* Normalization pipeline applied to the operands of a sum. *)
  and mangleSumM terms = returnM terms
      >>= reduce_sumM
      >>= collectM (fun (a, b) -> (a, b))
      >>= collectM (fun (a, b) -> (b, a))
      >>= reduce_sumM
      >>= deepCollectM !Magic.deep_collect_depth
      >>= reduce_sumM

  (* push all Uminuses to the end *)
  and reorder_uminus = function
    | [] -> []
    | ((Uminus _) as neg) :: rest -> (reorder_uminus rest) @ [neg]
    | pos :: rest -> pos :: (reorder_uminus rest)

  (* Build the node for an already mangled operand list. *)
  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a                    (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* Build a Plus node.  When FMA generation is enabled and at least two
     operands are products by a numeric constant, those products are
     divided by their largest constant, summed, and the sum is multiplied
     back by that constant. *)
  and generateFusedMultAddM l =
    let is_multiplication = function
      | Times (Num _, _) -> true
      | Uminus (Times (Num _, _)) -> true
      | _ -> false
    in
    (* Returns (constant products, other terms, largest constant seen,
       starting from zero). *)
    let rec separate = function
      | [] -> ([], [], Number.zero)
      | ((Times (Num a, _) | Uminus (Times (Num a, _))) as this) :: c ->
          let (x, y, biggest) = separate c in
          let newmax = if Number.greater a biggest then a else biggest in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, biggest) = separate c in
          (x, this :: y, biggest)
    in
    if !Magic.enable_fma && count is_multiplication l >= 2 then begin
      let (w, wo, biggest) = separate l in
      snumM (Number.div Number.one biggest) >>= fun invmax' ->
      snumM biggest >>= fun max' ->
      mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
      stimesM (max', pw') >>= fun mw' ->
      splusM (wo @ [mw'])
    end else
      makeNode (Plus l)

  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  and deepCollectM maxdepth l =
    (* Candidate terms reachable from [node] through minus signs,
       products by a numeric constant and, while depth remains, sums. *)
    let rec findTerms depth node =
      match node with
      | Uminus inner -> findTerms depth inner
      | Times (Num _, b) -> findTerms (depth - 1) b
      | Plus terms when depth > 0 ->
          node :: List.flatten (List.map (findTerms (depth - 1)) terms)
      | _ -> [node]
    (* Elements that occur again (physically) later in the list. *)
    and duplicates = function
      | [] -> []
      | a :: b ->
          if List.memq a b then a :: duplicates b else duplicates b
    in
    (* Split [x] into a pair (part free of duplicated terms, part made
       of duplicated terms). *)
    let rec splitDuplicates depth d x =
      if List.memq x d then
        snumM Number.zero >>= fun zero ->
        returnM (zero, x)
      else match x with
      | Times (a, b) ->
          splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
          splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
          stimesM (a', b') >>= fun ab ->
          stimesM (a, xb) >>= fun xb' ->
          stimesM (xa, b) >>= fun xa' ->
          stimesM (xa, xb) >>= fun xab ->
          splusM [xa'; xb'; xab] >>= fun dup ->
          returnM (ab, dup)
      | Uminus a ->
          splitDuplicates depth d a >>= fun (p, q) ->
          suminusM p >>= fun up ->
          suminusM q >>= fun uq ->
          returnM (up, uq)
      | Plus terms when depth > 0 ->
          mapM (splitDuplicates (depth - 1) d) terms >>= fun pairs ->
          let (plain, dups) = List.split pairs in
          splusM plain >>= fun p ->
          splusM dups >>= fun q ->
          returnM (p, q)
      | other ->
          snumM Number.zero >>= fun zero' ->
          returnM (other, zero')
    in
    (* Operands of a sum, with any enclosing minus signs distributed. *)
    let rec flattenPlusM = function
      | Plus terms -> returnM terms
      | Uminus inner -> flattenPlusM inner >>= mapM suminusM
      | other -> returnM [other]
    in
    let candidates = List.flatten (List.map (findTerms maxdepth) l) in
    match duplicates candidates with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun pairs ->
        let (plain, dups) = List.split pairs in
        splusM plain >>= fun plain' ->
        mapM flattenPlusM dups >>= fun dups' ->
        splusM (List.flatten dups') >>= fun dups'' ->
        mangleSumM [plain'; dups'']

  (* Sum of a list of terms.  The operands are mangled first; then the
     sign convention of the result is chosen: kept as is, or every
     operand negated and the whole sum wrapped in a minus. *)
  and splusM l =
    (* For two-operand sums with FMA generation enabled: Some true asks
       for the sign flip, Some false forbids it, None means no opinion. *)
    let fma_heuristics x =
      if not !Magic.enable_fma then None
      else match x with
        | [Uminus (Times _); Times _]
        | [Times _; Uminus (Times _)] -> Some false
        | [Uminus (_); Times _]
        | [Times _; Uminus (Plus _)] -> Some true
        | [_; Uminus (Times _)]
        | [Uminus (Times _); _] -> Some false
        | _ -> None
    in
    mangleSumM l >>= fun l' ->
    (* Kept as a thunk so that nothing is built unless a branch asks
       for it. *)
    let flip_sign () = mapM suminusM l' >>= splusM >>= suminusM in
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      flip_sign ()
    else match fma_heuristics l' with
      | Some true -> flip_sign ()
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse) &&
             Oracle.should_flip_sign (Plus l') then
            flip_sign ()
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag *)
  let rec algsimpM x =
    (* Simplify both operands, left then right, then rebuild. *)
    let binary build (a, b) =
      algsimpM a >>= fun a' ->
      algsimpM b >>= fun b' ->
      build (a', b')
    in
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | (NaN _ | Load _) as leaf -> makeNode leaf
        | Plus terms -> mapM algsimpM terms >>= splusM
        | Times (a, b) -> binary stimesM (a, b)
        | CTimes (a, b) -> binary sctimesM (a, b)
        | CTimesJ (a, b) -> binary sctimesjM (a, b)
        | Uminus a -> algsimpM a >>= suminusM
        | Store (v, a) ->
            (algsimpM a >>= fun a' -> makeNode (Store (v, a'))))
      x

  (* Both tables start out empty. *)
  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp roots = runM initialTable simp_roots roots
end
 | |
| 
 | |
| (*************************************************************
 | |
|  * Network transposition algorithm
 | |
|  *************************************************************)
 | |
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monad state is the table mapping a node to its dual. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun duals -> returnM (node_lookup key duals)

  let insertDualsM key value =
    fetchDuals >>= fun duals -> storeDuals (node_insert key value duals)

  (* Walk the dag from the given work list.  Returns the list of all
     distinct nodes (most recently discovered first) together with the
     table mapping each node to the list of its parents.  [vtable]
     records the nodes already expanded. *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable parent_table rest
         | None ->
             let children =
               match node with
               | Store (_, n) -> [n]
               | Plus l -> l
               | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
               | Uminus x -> [x]
               | _ -> []
             in
             (* Record [node] as one more parent of [child]. *)
             let add_parent table child =
               match node_lookup child table with
               | None -> node_insert child [node] table
               | Some parents -> node_insert child (node :: parents) table
             in
             visit
               (node :: visited)
               (node_insert node () vtable)
               (List.fold_left add_parent parent_table children)
               (children @ rest))

  (* Given the parent table computed by [visit], build the memoized
     monadic function that maps a node to its dual. *)
  let make_transposer parent_table =
    (* Contribution of [candidate_parent] to the dual of [node]: a list
       with zero or one term.  Only the second operand of a product is
       considered. *)
    let rec termM node candidate_parent =
      let contribute wrap =
        dualM candidate_parent >>= fun dual -> returnM [wrap dual]
      in
      match candidate_parent with
      | Store (_, n) when n == node -> contribute (fun d -> d)
      | Plus l when List.memq node l -> contribute (fun d -> d)
      | Times (a, b) when b == node ->
          contribute (fun d -> makeTimes (a, d))
      | CTimes (a, b) when b == node ->
          contribute (fun d -> CTimes (a, d))
      | CTimesJ (a, b) when b == node ->
          contribute (fun d -> CTimesJ (a, d))
      | Uminus n when n == node ->
          contribute (fun d -> makeUminus d)
      | _ -> returnM []

    (* Sum of the contributions of all parents of [this_node].  Fails if
       the node has no entry in the parent table. *)
    and dualExpressionM this_node =
      let parents =
        match node_lookup this_node parent_table with
        | Some p -> p
        | None -> failwith "bug in dualExpressionM"
      in
      mapM (termM this_node) parents >>= fun terms ->
      returnM (makePlus (List.flatten terms))

    (* A non-constant Load becomes a Store of its dual expression, a
       Store becomes a Load of the same variable, and a constant Load is
       left alone. *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as leaf ->
              if Variable.is_constant v then
                returnM (Load v)
              else
                (dualExpressionM leaf >>= fun d ->
                 returnM (Store (v, d)))
          | Store (v, _) -> returnM (Load v)
          | inner -> dualExpressionM inner)
        this_node

    in dualM

  let is_store node =
    match node with
    | Store _ -> true
    | _ -> false

  (* Transpose the dag: compute the dual of every node and keep the
     duals that are Store nodes as the new roots. *)
  let transpose dag =
    let () = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let duals = runM Assoctable.empty (mapM transposerM) all_nodes in
    let roots = List.filter is_store duals in
    let () = Util.info "end transpose" in
    roots
end
 | |
| 
 | |
| 
 | |
| (*************************************************************
 | |
|  * Various dag statistics
 | |
|  *************************************************************)
 | |
module Stats : sig
  (* Operation counts of a dag. *)
  type complexity
  (* [complexity dag] counts the operations reachable from the roots,
     each shared node being counted once. *)
  val complexity : Expr.expr list -> complexity
  (* The two comparisons below are made on a single weighted sum of the
     counts, not component by component. *)
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  (* One-line rendering of the counts, terminated by a newline. *)
  val to_string : complexity -> string
end = struct
  (* (loads, stores, additions, multiplications, negations, constants) *)
  type complexity = int * int * int * int * int * int

  (* Collect every distinct node reachable from the work list.  [vtable]
     records the nodes already expanded.  The operands of CTimes and
     CTimesJ are followed too, as Transpose.visit does; otherwise nodes
     reachable only through a complex multiplication would be missing
     from the statistics. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable rest
         | None ->
             let children =
               match node with
               | Store (_, n) -> [n]
               | Plus l -> l
               | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
               | Uminus x -> [x]
               | _ -> []
             in
             visit (node :: visited)
               (node_insert node () vtable)
               (children @ rest))

  let complexity dag =
    (* Add the contribution of one node to the running counts.  A sum of
       k operands counts as k - 1 additions.  CTimes, CTimesJ and NaN
       nodes themselves are not counted. *)
    let tally (load, store, plus, times, uminus, num) node =
      match node with
      | Load _ -> (load + 1, store, plus, times, uminus, num)
      | Store _ -> (load, store + 1, plus, times, uminus, num)
      | Plus x ->
          (load, store, plus + (List.length x - 1), times, uminus, num)
      | Times _ -> (load, store, plus, times + 1, uminus, num)
      | Uminus _ -> (load, store, plus, times, uminus + 1, num)
      | Num _ -> (load, store, plus, times, uminus, num + 1)
      | CTimes _ | CTimesJ _ | NaN _ ->
          (load, store, plus, times, uminus, num)
    in
    List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Scalar cost: an addition weighs 10, a multiplication 20, everything
     else 1. *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = weight a = weight b
  let leq_complexity a b = weight a <= weight b

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n

end
 | |
| 
 | |
(* Simplify the dag.  The roots are always run through AlgSimp.algsimp
   once; when Magic.network_transposition is set, rounds of
   transpose-and-simplify follow. *)
let algsimp v =
  let report dag_complexity =
    Util.info ("complexity = " ^ (Stats.to_string dag_complexity)) in
  (* One transposition followed by one simplification.  A round applies
     this twice; it is the explicit form of the former chain
     AlgSimp.algsimp @@ Transpose.transpose @@ AlgSimp.algsimp @@
     Transpose.transpose, where @@ is taken to be function composition
     (the only reading under which that chain type-checks). *)
  let transpose_and_simplify dag =
    AlgSimp.algsimp (Transpose.transpose dag) in
  let rec simplification_loop dag =
    Util.info "simplification step";
    let before = Stats.complexity dag in
    report before;
    let dag' = transpose_and_simplify (transpose_and_simplify dag) in
    let after = Stats.complexity dag' in
    report after;
    (* NOTE(review): the loop stops as soon as the weight did not grow,
       so another round runs only when the result got heavier.  Confirm
       that this is the intended stopping rule. *)
    if Stats.leq_complexity after before then begin
      Util.info "end algsimp";
      dag'
    end else
      simplification_loop dag'
  in
  Util.info "begin algsimp";
  let simplified = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop simplified
  else simplified
 | |
| 
 | 
