581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
		
		
			
		
	
	
			581 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
| 
								 | 
							
								(*
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 2003, 2007-14 Matteo Frigo
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * This program is free software; you can redistribute it and/or modify
							 | 
						||
| 
								 | 
							
								 * it under the terms of the GNU General Public License as published by
							 | 
						||
| 
								 | 
							
								 * the Free Software Foundation; either version 2 of the License, or
							 | 
						||
| 
								 | 
							
								 * (at your option) any later version.
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * This program is distributed in the hope that it will be useful,
							 | 
						||
| 
								 | 
							
								 * but WITHOUT ANY WARRANTY; without even the implied warranty of
							 | 
						||
| 
								 | 
							
								 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
							 | 
						||
| 
								 | 
							
								 * GNU General Public License for more details.
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * You should have received a copy of the GNU General Public License
							 | 
						||
| 
								 | 
							
								 * along with this program; if not, write to the Free Software
							 | 
						||
| 
								 | 
							
								 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								open Util
							 | 
						||
| 
								 | 
							
								open Expr
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								let node_insert x =  Assoctable.insert Expr.hash x
							 | 
						||
| 
								 | 
							
								let node_lookup x =  Assoctable.lookup Expr.hash (==) x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(*************************************************************
							 | 
						||
| 
								 | 
							
								 * Algebraic simplifier/elimination of common subexpressions
							 | 
						||
| 
								 | 
							
								 *************************************************************)
							 | 
						||
| 
								 | 
							
								module AlgSimp : sig 
							 | 
						||
| 
								 | 
							
								  val algsimp : expr list -> expr list
							 | 
						||
| 
								 | 
							
								end = struct
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  open Monads.StateMonad
							 | 
						||
| 
								 | 
							
								  open Monads.MemoMonad
							 | 
						||
| 
								 | 
							
								  open Assoctable
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let fetchSimp = 
							 | 
						||
| 
								 | 
							
								    fetchState >>= fun (s, _) -> returnM s
							 | 
						||
| 
								 | 
							
								  let storeSimp s =
							 | 
						||
| 
								 | 
							
								    fetchState >>= (fun (_, c) -> storeState (s, c))
							 | 
						||
| 
								 | 
							
								  let lookupSimpM key =
							 | 
						||
| 
								 | 
							
								    fetchSimp >>= fun table ->
							 | 
						||
| 
								 | 
							
								      returnM (node_lookup key table)
							 | 
						||
| 
								 | 
							
								  let insertSimpM key value =
							 | 
						||
| 
								 | 
							
								    fetchSimp >>= fun table ->
							 | 
						||
| 
								 | 
							
								      storeSimp (node_insert key value table)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let subset a b =
							 | 
						||
| 
								 | 
							
								    List.for_all (fun x -> List.exists (fun y -> x == y) b) a
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let structurallyEqualCSE a b = 
							 | 
						||
| 
								 | 
							
								    match (a, b) with
							 | 
						||
| 
								 | 
							
								    | (Num a, Num b) -> Number.equal a b
							 | 
						||
| 
								 | 
							
								    | (NaN a, NaN b) -> a == b
							 | 
						||
| 
								 | 
							
								    | (Load a, Load b) -> Variable.same a b
							 | 
						||
| 
								 | 
							
								    | (Times (a, a'), Times (b, b')) ->
							 | 
						||
| 
								 | 
							
								 	((a == b) && (a' == b')) ||
							 | 
						||
| 
								 | 
							
								 	((a == b') && (a' == b))
							 | 
						||
| 
								 | 
							
								    | (CTimes (a, a'), CTimes (b, b')) ->
							 | 
						||
| 
								 | 
							
								 	((a == b) && (a' == b')) ||
							 | 
						||
| 
								 | 
							
								 	((a == b') && (a' == b))
							 | 
						||
| 
								 | 
							
								    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
							 | 
						||
| 
								 | 
							
								    | (Plus a, Plus b) -> subset a b && subset b a
							 | 
						||
| 
								 | 
							
								    | (Uminus a, Uminus b) -> (a == b)
							 | 
						||
| 
								 | 
							
								    | _ -> false
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let hashCSE x = 
							 | 
						||
| 
								 | 
							
								    if (!Magic.randomized_cse) then
							 | 
						||
| 
								 | 
							
								      Oracle.hash x
							 | 
						||
| 
								 | 
							
								    else
							 | 
						||
| 
								 | 
							
								      Expr.hash x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let equalCSE a b = 
							 | 
						||
| 
								 | 
							
								    if (!Magic.randomized_cse) then
							 | 
						||
| 
								 | 
							
								      (structurallyEqualCSE a b || Oracle.likely_equal a b)
							 | 
						||
| 
								 | 
							
								    else
							 | 
						||
| 
								 | 
							
								      structurallyEqualCSE a b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let fetchCSE = 
							 | 
						||
| 
								 | 
							
								    fetchState >>= fun (_, c) -> returnM c
							 | 
						||
| 
								 | 
							
								  let storeCSE c =
							 | 
						||
| 
								 | 
							
								    fetchState >>= (fun (s, _) -> storeState (s, c))
							 | 
						||
| 
								 | 
							
								  let lookupCSEM key =
							 | 
						||
| 
								 | 
							
								    fetchCSE >>= fun table ->
							 | 
						||
| 
								 | 
							
								      returnM (Assoctable.lookup hashCSE equalCSE key table)
							 | 
						||
| 
								 | 
							
								  let insertCSEM key value =
							 | 
						||
| 
								 | 
							
								    fetchCSE >>= fun table ->
							 | 
						||
| 
								 | 
							
								      storeCSE (Assoctable.insert hashCSE key value table)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  (* memoize both x and Uminus x (unless x is already negated) *) 
							 | 
						||
| 
								 | 
							
								  let identityM x =
							 | 
						||
| 
								 | 
							
								    let memo x = memoizing lookupCSEM insertCSEM returnM x in
							 | 
						||
| 
								 | 
							
								    match x with
							 | 
						||
| 
								 | 
							
									Uminus _ -> memo x 
							 | 
						||
| 
								 | 
							
								      |	_ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let makeNode = identityM
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  (* simplifiers for various kinds of nodes *)
							 | 
						||
| 
								 | 
							
								  let rec snumM = function
							 | 
						||
| 
								 | 
							
								      n when Number.is_zero n -> 
							 | 
						||
| 
								 | 
							
									makeNode (Num (Number.zero))
							 | 
						||
| 
								 | 
							
								    | n when Number.negative n -> 
							 | 
						||
| 
								 | 
							
									makeNode (Num (Number.negate n)) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | n -> makeNode (Num n)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and suminusM = function
							 | 
						||
| 
								 | 
							
								      Uminus x -> makeNode x
							 | 
						||
| 
								 | 
							
								    | Num a when (Number.is_zero a) -> snumM Number.zero
							 | 
						||
| 
								 | 
							
								    | a -> makeNode (Uminus a)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and stimesM = function 
							 | 
						||
| 
								 | 
							
								    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (NaN I, CTimes (a, b)) -> stimesM (NaN I, b) >>= 
							 | 
						||
| 
								 | 
							
									fun ib -> sctimesM (a, ib)
							 | 
						||
| 
								 | 
							
								    | (NaN I, CTimesJ (a, b)) -> stimesM (NaN I, b) >>= 
							 | 
						||
| 
								 | 
							
									fun ib -> sctimesjM (a, ib)
							 | 
						||
| 
								 | 
							
								    | (Num a, Num b) -> snumM (Number.mul a b)
							 | 
						||
| 
								 | 
							
								    | (Num a, Times (Num b, c)) -> 
							 | 
						||
| 
								 | 
							
									snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
							 | 
						||
| 
								 | 
							
								    | (Num a, b) when Number.is_zero a -> snumM Number.zero
							 | 
						||
| 
								 | 
							
								    | (Num a, b) when Number.is_one a -> makeNode b
							 | 
						||
| 
								 | 
							
								    | (Num a, b) when Number.is_mone a -> suminusM b
							 | 
						||
| 
								 | 
							
								    | (a, b) when is_known_constant b && not (is_known_constant a) -> 
							 | 
						||
| 
								 | 
							
									stimesM (b, a)
							 | 
						||
| 
								 | 
							
								    | (a, b) -> makeNode (Times (a, b))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and sctimesM = function 
							 | 
						||
| 
								 | 
							
								    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (a, b) -> makeNode (CTimes (a, b))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and sctimesjM = function 
							 | 
						||
| 
								 | 
							
								    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
							 | 
						||
| 
								 | 
							
								    | (a, b) -> makeNode (CTimesJ (a, b))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and reduce_sumM x = match x with
							 | 
						||
| 
								 | 
							
								    [] -> returnM []
							 | 
						||
| 
								 | 
							
								  | [Num a] -> 
							 | 
						||
| 
								 | 
							
								      if (Number.is_zero a) then 
							 | 
						||
| 
								 | 
							
									returnM [] 
							 | 
						||
| 
								 | 
							
								      else returnM x
							 | 
						||
| 
								 | 
							
								  | [Uminus (Num a)] -> 
							 | 
						||
| 
								 | 
							
								      if (Number.is_zero a) then 
							 | 
						||
| 
								 | 
							
									returnM [] 
							 | 
						||
| 
								 | 
							
								      else returnM x
							 | 
						||
| 
								 | 
							
								  | (Num a) :: (Num b) :: s -> 
							 | 
						||
| 
								 | 
							
								      snumM (Number.add a b) >>= fun x ->
							 | 
						||
| 
								 | 
							
									reduce_sumM (x :: s)
							 | 
						||
| 
								 | 
							
								  | (Num a) :: (Uminus (Num b)) :: s -> 
							 | 
						||
| 
								 | 
							
								      snumM (Number.sub a b) >>= fun x ->
							 | 
						||
| 
								 | 
							
									reduce_sumM (x :: s)
							 | 
						||
| 
								 | 
							
								  | (Uminus (Num a)) :: (Num b) :: s -> 
							 | 
						||
| 
								 | 
							
								      snumM (Number.sub b a) >>= fun x ->
							 | 
						||
| 
								 | 
							
									reduce_sumM (x :: s)
							 | 
						||
| 
								 | 
							
								  | (Uminus (Num a)) :: (Uminus (Num b)) :: s -> 
							 | 
						||
| 
								 | 
							
								      snumM (Number.add a b) >>= 
							 | 
						||
| 
								 | 
							
								      suminusM >>= fun x ->
							 | 
						||
| 
								 | 
							
									reduce_sumM (x :: s)
							 | 
						||
| 
								 | 
							
								  | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
							 | 
						||
| 
								 | 
							
								  | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
							 | 
						||
| 
								 | 
							
								  | a :: s -> 
							 | 
						||
| 
								 | 
							
								      reduce_sumM s >>= fun s' -> returnM (a :: s')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and collectible1 = function
							 | 
						||
| 
								 | 
							
								    | NaN _ -> false
							 | 
						||
| 
								 | 
							
								    | Uminus x -> collectible1 x
							 | 
						||
| 
								 | 
							
								    | _ -> true
							 | 
						||
| 
								 | 
							
								  and collectible (a, b) = collectible1 a
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  (* collect common factors: ax + bx -> (a+b)x *)
							 | 
						||
| 
								 | 
							
								  and collectM which x = 
							 | 
						||
| 
								 | 
							
								    let rec findCoeffM which = function
							 | 
						||
| 
								 | 
							
								      |	Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
							 | 
						||
| 
								 | 
							
								      | Uminus x -> 
							 | 
						||
| 
								 | 
							
									  findCoeffM which x >>= fun (coeff, b) ->
							 | 
						||
| 
								 | 
							
									    suminusM coeff >>= fun mcoeff ->
							 | 
						||
| 
								 | 
							
								 	      returnM (mcoeff, b)
							 | 
						||
| 
								 | 
							
								      | x -> snumM Number.one >>= fun one -> returnM (one, x)
							 | 
						||
| 
								 | 
							
								    and separateM xpr = function
							 | 
						||
| 
								 | 
							
								 	[] -> returnM ([], [])
							 | 
						||
| 
								 | 
							
								      |	a :: b ->
							 | 
						||
| 
								 | 
							
								 	  separateM xpr b >>= fun (w, wo) ->
							 | 
						||
| 
								 | 
							
									    (* try first factor *)
							 | 
						||
| 
								 | 
							
								 	    findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
							 | 
						||
| 
								 | 
							
								 	      if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
							 | 
						||
| 
								 | 
							
								 	      else
							 | 
						||
| 
								 | 
							
									      (* try second factor *)
							 | 
						||
| 
								 | 
							
								 		findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
							 | 
						||
| 
								 | 
							
								 		  if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
							 | 
						||
| 
								 | 
							
								 		  else returnM (w, a :: wo)
							 | 
						||
| 
								 | 
							
								    in match x with
							 | 
						||
| 
								 | 
							
								      [] -> returnM x
							 | 
						||
| 
								 | 
							
								    | [a] -> returnM x
							 | 
						||
| 
								 | 
							
								    | a :: b ->
							 | 
						||
| 
								 | 
							
								 	findCoeffM which a >>= fun (_, xpr) ->
							 | 
						||
| 
								 | 
							
								 	  separateM xpr x >>= fun (w, wo) ->
							 | 
						||
| 
								 | 
							
								 	    collectM which wo >>= fun wo' ->
							 | 
						||
| 
								 | 
							
								 	      splusM w >>= fun w' ->
							 | 
						||
| 
								 | 
							
								 		stimesM (w', xpr) >>= fun t' ->
							 | 
						||
| 
								 | 
							
								 		  returnM (t':: wo')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and mangleSumM x = returnM x
							 | 
						||
| 
								 | 
							
								      >>= reduce_sumM 
							 | 
						||
| 
								 | 
							
								      >>= collectM (fun (a, b) -> (a, b))
							 | 
						||
| 
								 | 
							
								      >>= collectM (fun (a, b) -> (b, a))
							 | 
						||
| 
								 | 
							
								      >>= reduce_sumM 
							 | 
						||
| 
								 | 
							
								      >>= deepCollectM !Magic.deep_collect_depth
							 | 
						||
| 
								 | 
							
								      >>= reduce_sumM
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and reorder_uminus = function  (* push all Uminuses to the end *)
							 | 
						||
| 
								 | 
							
								      [] -> []
							 | 
						||
| 
								 | 
							
								    | ((Uminus _) as a' :: b) -> (reorder_uminus b) @ [a']
							 | 
						||
| 
								 | 
							
								    | (a :: b) -> a :: (reorder_uminus b)                      
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and canonicalizeM = function 
							 | 
						||
| 
								 | 
							
								      [] -> snumM Number.zero
							 | 
						||
| 
								 | 
							
								    | [a] -> makeNode a                    (* one term *)
							 | 
						||
| 
								 | 
							
								    | a -> generateFusedMultAddM (reorder_uminus a)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and generateFusedMultAddM = 
							 | 
						||
| 
								 | 
							
								    let rec is_multiplication = function
							 | 
						||
| 
								 | 
							
								      | Times (Num a, b) -> true
							 | 
						||
| 
								 | 
							
								      | Uminus (Times (Num a, b)) -> true
							 | 
						||
| 
								 | 
							
								      | _ -> false
							 | 
						||
| 
								 | 
							
								    and separate = function
							 | 
						||
| 
								 | 
							
									[] -> ([], [], Number.zero)
							 | 
						||
| 
								 | 
							
								      | (Times (Num a, b)) as this :: c -> 
							 | 
						||
| 
								 | 
							
									  let (x, y, max) = separate c in
							 | 
						||
| 
								 | 
							
									  let newmax = if (Number.greater a max) then a else max in
							 | 
						||
| 
								 | 
							
									  (this :: x, y, newmax)
							 | 
						||
| 
								 | 
							
								      | (Uminus (Times (Num a, b))) as this :: c -> 
							 | 
						||
| 
								 | 
							
									  let (x, y, max) = separate c in
							 | 
						||
| 
								 | 
							
									  let newmax = if (Number.greater a max) then a else max in
							 | 
						||
| 
								 | 
							
									  (this :: x, y, newmax)
							 | 
						||
| 
								 | 
							
								      | this :: c ->
							 | 
						||
| 
								 | 
							
									  let (x, y, max) = separate c in
							 | 
						||
| 
								 | 
							
									  (x, this :: y, max)
							 | 
						||
| 
								 | 
							
								    in fun l ->
							 | 
						||
| 
								 | 
							
								      if !Magic.enable_fma && count is_multiplication l >= 2 then
							 | 
						||
| 
								 | 
							
									let (w, wo, max) = separate l in
							 | 
						||
| 
								 | 
							
									snumM (Number.div Number.one max) >>= fun invmax' ->
							 | 
						||
| 
								 | 
							
									  snumM max >>= fun max' ->
							 | 
						||
| 
								 | 
							
									    mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
							 | 
						||
| 
								 | 
							
									      stimesM (max', pw') >>= fun mw' ->
							 | 
						||
| 
								 | 
							
										splusM (wo @ [mw'])
							 | 
						||
| 
								 | 
							
								      else 
							 | 
						||
| 
								 | 
							
									makeNode (Plus l)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and negative = function
							 | 
						||
| 
								 | 
							
								      Uminus _ -> true
							 | 
						||
| 
								 | 
							
								    | _ -> false
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  (*
							 | 
						||
| 
								 | 
							
								   * simplify patterns of the form
							 | 
						||
| 
								 | 
							
								   *
							 | 
						||
| 
								 | 
							
								   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
							 | 
						||
| 
								 | 
							
								   *
							 | 
						||
| 
								 | 
							
								   * The pattern includes arbitrary coefficients and minus signs.
							 | 
						||
| 
								 | 
							
								   * A common case of this pattern is the butterfly
							 | 
						||
| 
								 | 
							
								   *   (a + b) + (a - b)
							 | 
						||
| 
								 | 
							
								   *   (a + b) - (a - b)
							 | 
						||
| 
								 | 
							
								   *)
							 | 
						||
| 
								 | 
							
								  (* this whole procedure needs much more thought *)
							 | 
						||
| 
								 | 
							
								  and deepCollectM maxdepth l =
							 | 
						||
| 
								 | 
							
								    let rec findTerms depth x = match x with
							 | 
						||
| 
								 | 
							
								      | Uminus x -> findTerms depth x
							 | 
						||
| 
								 | 
							
								      |	Times (Num _, b) -> (findTerms (depth - 1) b)
							 | 
						||
| 
								 | 
							
								      |	Plus l when depth > 0 ->
							 | 
						||
| 
								 | 
							
									  x :: List.flatten (List.map (findTerms (depth - 1)) l)
							 | 
						||
| 
								 | 
							
								      |	x -> [x]
							 | 
						||
| 
								 | 
							
								    and duplicates = function
							 | 
						||
| 
								 | 
							
									[] -> []
							 | 
						||
| 
								 | 
							
								      |	a :: b -> if List.memq a b then a :: duplicates b
							 | 
						||
| 
								 | 
							
								      else duplicates b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    in let rec splitDuplicates depth d x =
							 | 
						||
| 
								 | 
							
								      if (List.memq x d) then 
							 | 
						||
| 
								 | 
							
									snumM (Number.zero) >>= fun zero ->
							 | 
						||
| 
								 | 
							
									  returnM (zero, x)
							 | 
						||
| 
								 | 
							
								      else match x with
							 | 
						||
| 
								 | 
							
								      |	Times (a, b) ->
							 | 
						||
| 
								 | 
							
									  splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
							 | 
						||
| 
								 | 
							
									    splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
							 | 
						||
| 
								 | 
							
									      stimesM (a', b') >>= fun ab ->
							 | 
						||
| 
								 | 
							
										stimesM (a, xb) >>= fun xb' ->
							 | 
						||
| 
								 | 
							
										  stimesM (xa, b) >>= fun xa' ->
							 | 
						||
| 
								 | 
							
										    stimesM (xa, xb) >>= fun xab ->
							 | 
						||
| 
								 | 
							
										      splusM [xa'; xb'; xab] >>= fun x ->
							 | 
						||
| 
								 | 
							
											returnM (ab, x)
							 | 
						||
| 
								 | 
							
								      | Uminus a -> 
							 | 
						||
| 
								 | 
							
									  splitDuplicates depth d a >>= fun (x, y) ->
							 | 
						||
| 
								 | 
							
									    suminusM x >>= fun ux -> 
							 | 
						||
| 
								 | 
							
									      suminusM y >>= fun uy -> 
							 | 
						||
| 
								 | 
							
										returnM (ux, uy)
							 | 
						||
| 
								 | 
							
								      |	Plus l when depth > 0 -> 
							 | 
						||
| 
								 | 
							
									  mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
							 | 
						||
| 
								 | 
							
									    let (l', d') = List.split ld in
							 | 
						||
| 
								 | 
							
									    splusM l' >>= fun p ->
							 | 
						||
| 
								 | 
							
									      splusM d' >>= fun d'' ->
							 | 
						||
| 
								 | 
							
									      returnM (p, d'')
							 | 
						||
| 
								 | 
							
								      |	x -> 
							 | 
						||
| 
								 | 
							
									  snumM (Number.zero) >>= fun zero' ->
							 | 
						||
| 
								 | 
							
									    returnM (x, zero')
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    in let l' = List.flatten (List.map (findTerms maxdepth) l)
							 | 
						||
| 
								 | 
							
								    in match duplicates l' with
							 | 
						||
| 
								 | 
							
								    | [] -> returnM l
							 | 
						||
| 
								 | 
							
								    | d ->
							 | 
						||
| 
								 | 
							
									mapM (splitDuplicates maxdepth d) l >>= fun ld ->
							 | 
						||
| 
								 | 
							
									  let (l', d') = List.split ld in
							 | 
						||
| 
								 | 
							
									  splusM l' >>= fun l'' ->
							 | 
						||
| 
								 | 
							
									    let rec flattenPlusM = function
							 | 
						||
| 
								 | 
							
									      | Plus l -> returnM l
							 | 
						||
| 
								 | 
							
									      | Uminus x ->
							 | 
						||
| 
								 | 
							
										  flattenPlusM x >>= mapM suminusM
							 | 
						||
| 
								 | 
							
									      | x -> returnM [x]
							 | 
						||
| 
								 | 
							
									    in
							 | 
						||
| 
								 | 
							
									    mapM flattenPlusM d' >>= fun d'' ->
							 | 
						||
| 
								 | 
							
									      splusM (List.flatten d'') >>= fun d''' ->
							 | 
						||
| 
								 | 
							
										mangleSumM [l''; d''']
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  and splusM l =
							 | 
						||
| 
								 | 
							
								    let fma_heuristics x = 
							 | 
						||
| 
								 | 
							
								      if !Magic.enable_fma then 
							 | 
						||
| 
								 | 
							
									match x with
							 | 
						||
| 
								 | 
							
									| [Uminus (Times _); Times _] -> Some false
							 | 
						||
| 
								 | 
							
									| [Times _; Uminus (Times _)] -> Some false
							 | 
						||
| 
								 | 
							
									| [Uminus (_); Times _] -> Some true
							 | 
						||
| 
								 | 
							
									| [Times _; Uminus (Plus _)] -> Some true
							 | 
						||
| 
								 | 
							
									| [_; Uminus (Times _)] -> Some false
							 | 
						||
| 
								 | 
							
									| [Uminus (Times _); _] -> Some false
							 | 
						||
| 
								 | 
							
									| _ -> None
							 | 
						||
| 
								 | 
							
								      else
							 | 
						||
| 
								 | 
							
									None
							 | 
						||
| 
								 | 
							
								    in
							 | 
						||
| 
								 | 
							
								    mangleSumM l >>=  fun l' ->
							 | 
						||
| 
								 | 
							
								      (* no terms are negative.  Don't do anything *)
							 | 
						||
| 
								 | 
							
								      if not (List.exists negative l') then
							 | 
						||
| 
								 | 
							
									canonicalizeM l'
							 | 
						||
| 
								 | 
							
								      (* all terms are negative.  Negate them all and collect the minus sign *)
							 | 
						||
| 
								 | 
							
								      else if List.for_all negative l' then
							 | 
						||
| 
								 | 
							
									mapM suminusM l' >>= splusM >>= suminusM
							 | 
						||
| 
								 | 
							
								      else match fma_heuristics l' with
							 | 
						||
| 
								 | 
							
								      |	Some true -> mapM suminusM l' >>= splusM >>= suminusM
							 | 
						||
| 
								 | 
							
								      |	Some false -> canonicalizeM l'
							 | 
						||
| 
								 | 
							
								      |	None ->
							 | 
						||
| 
								 | 
							
								         (* Ask the Oracle for the canonical form *)
							 | 
						||
| 
								 | 
							
									  if (not !Magic.randomized_cse) &&
							 | 
						||
| 
								 | 
							
									    Oracle.should_flip_sign (Plus l') then
							 | 
						||
| 
								 | 
							
									    mapM suminusM l' >>= splusM >>= suminusM
							 | 
						||
| 
								 | 
							
									  else
							 | 
						||
| 
								 | 
							
									    canonicalizeM l'
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  (* monadic style algebraic simplifier for the dag *)
							 | 
						||
| 
								 | 
							
								  let rec algsimpM x =
							 | 
						||
| 
								 | 
							
								    memoizing lookupSimpM insertSimpM 
							 | 
						||
| 
								 | 
							
								      (function 
							 | 
						||
| 
								 | 
							
								 	| Num a -> snumM a
							 | 
						||
| 
								 | 
							
								 	| NaN _ as x -> makeNode x
							 | 
						||
| 
								 | 
							
								 	| Plus a -> 
							 | 
						||
| 
								 | 
							
								 	    mapM algsimpM a >>= splusM
							 | 
						||
| 
								 | 
							
								 	| Times (a, b) -> 
							 | 
						||
| 
								 | 
							
								 	    (algsimpM a >>= fun a' ->
							 | 
						||
| 
								 | 
							
								 	      algsimpM b >>= fun b' ->
							 | 
						||
| 
								 | 
							
								 		stimesM (a', b'))
							 | 
						||
| 
								 | 
							
								 	| CTimes (a, b) -> 
							 | 
						||
| 
								 | 
							
								 	    (algsimpM a >>= fun a' ->
							 | 
						||
| 
								 | 
							
								 	      algsimpM b >>= fun b' ->
							 | 
						||
| 
								 | 
							
										sctimesM (a', b'))
							 | 
						||
| 
								 | 
							
								 	| CTimesJ (a, b) -> 
							 | 
						||
| 
								 | 
							
								 	    (algsimpM a >>= fun a' ->
							 | 
						||
| 
								 | 
							
								 	      algsimpM b >>= fun b' ->
							 | 
						||
| 
								 | 
							
										sctimesjM (a', b'))
							 | 
						||
| 
								 | 
							
								 	| Uminus a -> 
							 | 
						||
| 
								 | 
							
								 	    algsimpM a >>= suminusM 
							 | 
						||
| 
								 | 
							
								 	| Store (v, a) ->
							 | 
						||
| 
								 | 
							
								 	    algsimpM a >>= fun a' ->
							 | 
						||
| 
								 | 
							
								 	      makeNode (Store (v, a'))
							 | 
						||
| 
								 | 
							
								 	| Load _ as x -> makeNode x)
							 | 
						||
| 
								 | 
							
								      x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								   let initialTable = (empty, empty)
							 | 
						||
| 
								 | 
							
								   let simp_roots = mapM algsimpM
							 | 
						||
| 
								 | 
							
								   let algsimp = runM initialTable simp_roots
							 | 
						||
| 
								 | 
							
								end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(*************************************************************
							 | 
						||
| 
								 | 
							
								 * Network transposition algorithm
							 | 
						||
| 
								 | 
							
								 *************************************************************)
							 | 
						||
| 
								 | 
							
								module Transpose = struct
							 | 
						||
| 
								 | 
							
								  open Monads.StateMonad
							 | 
						||
| 
								 | 
							
								  open Monads.MemoMonad
							 | 
						||
| 
								 | 
							
								  open Littlesimp
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let fetchDuals = fetchState
							 | 
						||
| 
								 | 
							
								  let storeDuals = storeState
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let lookupDualsM key =
							 | 
						||
| 
								 | 
							
								    fetchDuals >>= fun table ->
							 | 
						||
| 
								 | 
							
								      returnM (node_lookup key table)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let insertDualsM key value =
							 | 
						||
| 
								 | 
							
								    fetchDuals >>= fun table ->
							 | 
						||
| 
								 | 
							
								      storeDuals (node_insert key value table)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let rec visit visited vtable parent_table = function
							 | 
						||
| 
								 | 
							
								      [] -> (visited, parent_table)
							 | 
						||
| 
								 | 
							
								    | node :: rest ->
							 | 
						||
| 
								 | 
							
									match node_lookup node vtable with
							 | 
						||
| 
								 | 
							
									| Some _ -> visit visited vtable parent_table rest
							 | 
						||
| 
								 | 
							
									| None ->
							 | 
						||
| 
								 | 
							
									    let children = match node with
							 | 
						||
| 
								 | 
							
									    | Store (v, n) -> [n]
							 | 
						||
| 
								 | 
							
									    | Plus l -> l
							 | 
						||
| 
								 | 
							
									    | Times (a, b) -> [a; b]
							 | 
						||
| 
								 | 
							
									    | CTimes (a, b) -> [a; b]
							 | 
						||
| 
								 | 
							
									    | CTimesJ (a, b) -> [a; b]
							 | 
						||
| 
								 | 
							
									    | Uminus x -> [x]
							 | 
						||
| 
								 | 
							
									    | _ -> []
							 | 
						||
| 
								 | 
							
									    in let rec loop t = function
							 | 
						||
| 
								 | 
							
										[] -> t
							 | 
						||
| 
								 | 
							
									      |	a :: rest ->
							 | 
						||
| 
								 | 
							
										  (match node_lookup a t with
							 | 
						||
| 
								 | 
							
										    None -> loop (node_insert a [node] t) rest
							 | 
						||
| 
								 | 
							
										  | Some c -> loop (node_insert a (node :: c) t) rest)
							 | 
						||
| 
								 | 
							
									    in 
							 | 
						||
| 
								 | 
							
									    (visit 
							 | 
						||
| 
								 | 
							
									       (node :: visited)
							 | 
						||
| 
								 | 
							
									       (node_insert node () vtable)
							 | 
						||
| 
								 | 
							
									       (loop parent_table children)
							 | 
						||
| 
								 | 
							
									       (children @ rest))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let make_transposer parent_table =
							 | 
						||
| 
								 | 
							
								    let rec termM node candidate_parent = 
							 | 
						||
| 
								 | 
							
								      match candidate_parent with
							 | 
						||
| 
								 | 
							
								      |	Store (_, n) when n == node -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> returnM [x']
							 | 
						||
| 
								 | 
							
								      | Plus (l) when List.memq node l -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> returnM [x']
							 | 
						||
| 
								 | 
							
								      | Times (a, b) when b == node -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> 
							 | 
						||
| 
								 | 
							
									    returnM [makeTimes (a, x')]
							 | 
						||
| 
								 | 
							
								      | CTimes (a, b) when b == node -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> 
							 | 
						||
| 
								 | 
							
									    returnM [CTimes (a, x')]
							 | 
						||
| 
								 | 
							
								      | CTimesJ (a, b) when b == node -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> 
							 | 
						||
| 
								 | 
							
									    returnM [CTimesJ (a, x')]
							 | 
						||
| 
								 | 
							
								      | Uminus n when n == node -> 
							 | 
						||
| 
								 | 
							
									  dualM candidate_parent >>= fun x' -> 
							 | 
						||
| 
								 | 
							
									    returnM [makeUminus x']
							 | 
						||
| 
								 | 
							
								      | _ -> returnM []
							 | 
						||
| 
								 | 
							
								    
							 | 
						||
| 
								 | 
							
								    and dualExpressionM this_node = 
							 | 
						||
| 
								 | 
							
								      mapM (termM this_node) 
							 | 
						||
| 
								 | 
							
									(match node_lookup this_node parent_table with
							 | 
						||
| 
								 | 
							
									| Some a -> a
							 | 
						||
| 
								 | 
							
									| None -> failwith "bug in dualExpressionM"
							 | 
						||
| 
								 | 
							
									) >>= fun l ->
							 | 
						||
| 
								 | 
							
									returnM (makePlus (List.flatten l))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    and dualM this_node =
							 | 
						||
| 
								 | 
							
								      memoizing lookupDualsM insertDualsM
							 | 
						||
| 
								 | 
							
									(function
							 | 
						||
| 
								 | 
							
									  | Load v as x -> 
							 | 
						||
| 
								 | 
							
									      if (Variable.is_constant v) then
							 | 
						||
| 
								 | 
							
										returnM (Load v)
							 | 
						||
| 
								 | 
							
									      else
							 | 
						||
| 
								 | 
							
										(dualExpressionM x >>= fun d ->
							 | 
						||
| 
								 | 
							
										  returnM (Store (v, d)))
							 | 
						||
| 
								 | 
							
									  | Store (v, x) -> returnM (Load v)
							 | 
						||
| 
								 | 
							
									  | x -> dualExpressionM x)
							 | 
						||
| 
								 | 
							
									this_node
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    in dualM
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let is_store = function 
							 | 
						||
| 
								 | 
							
								    | Store _ -> true
							 | 
						||
| 
								 | 
							
								    | _ -> false
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let transpose dag = 
							 | 
						||
| 
								 | 
							
								    let _ = Util.info "begin transpose" in
							 | 
						||
| 
								 | 
							
								    let (all_nodes, parent_table) = 
							 | 
						||
| 
								 | 
							
								      visit [] Assoctable.empty Assoctable.empty dag in
							 | 
						||
| 
								 | 
							
								    let transposerM = make_transposer parent_table in
							 | 
						||
| 
								 | 
							
								    let mapTransposerM = mapM transposerM in
							 | 
						||
| 
								 | 
							
								    let duals = runM Assoctable.empty mapTransposerM all_nodes in
							 | 
						||
| 
								 | 
							
								    let roots = List.filter is_store duals in
							 | 
						||
| 
								 | 
							
								    let _ = Util.info "end transpose" in
							 | 
						||
| 
								 | 
							
								    roots
							 | 
						||
| 
								 | 
							
								end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(*************************************************************
							 | 
						||
| 
								 | 
							
								 * Various dag statistics
							 | 
						||
| 
								 | 
							
								 *************************************************************)
							 | 
						||
| 
								 | 
							
								module Stats : sig
							 | 
						||
| 
								 | 
							
								  type complexity
							 | 
						||
| 
								 | 
							
								  val complexity : Expr.expr list -> complexity
							 | 
						||
| 
								 | 
							
								  val same_complexity : complexity -> complexity -> bool
							 | 
						||
| 
								 | 
							
								  val leq_complexity : complexity -> complexity -> bool
							 | 
						||
| 
								 | 
							
								  val to_string : complexity -> string
							 | 
						||
| 
								 | 
							
								end = struct
							 | 
						||
| 
								 | 
							
								  type complexity = int * int * int * int * int * int
							 | 
						||
| 
								 | 
							
								  let rec visit visited vtable = function
							 | 
						||
| 
								 | 
							
								      [] -> visited
							 | 
						||
| 
								 | 
							
								    | node :: rest ->
							 | 
						||
| 
								 | 
							
									match node_lookup node vtable with
							 | 
						||
| 
								 | 
							
									  Some _ -> visit visited vtable rest
							 | 
						||
| 
								 | 
							
									| None ->
							 | 
						||
| 
								 | 
							
									    let children = match node with
							 | 
						||
| 
								 | 
							
									      Store (v, n) -> [n]
							 | 
						||
| 
								 | 
							
									    | Plus l -> l
							 | 
						||
| 
								 | 
							
									    | Times (a, b) -> [a; b]
							 | 
						||
| 
								 | 
							
									    | Uminus x -> [x]
							 | 
						||
| 
								 | 
							
									    | _ -> []
							 | 
						||
| 
								 | 
							
									    in visit (node :: visited)
							 | 
						||
| 
								 | 
							
									      (node_insert node () vtable)
							 | 
						||
| 
								 | 
							
									      (children @ rest)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let complexity dag = 
							 | 
						||
| 
								 | 
							
								    let rec loop (load, store, plus, times, uminus, num) = function 
							 | 
						||
| 
								 | 
							
								      	[] -> (load, store, plus, times, uminus, num)
							 | 
						||
| 
								 | 
							
								      | node :: rest ->
							 | 
						||
| 
								 | 
							
									  loop
							 | 
						||
| 
								 | 
							
									    (match node with
							 | 
						||
| 
								 | 
							
									    | Load _ -> (load + 1, store, plus, times, uminus, num)
							 | 
						||
| 
								 | 
							
									    | Store _ -> (load, store + 1, plus, times, uminus, num)
							 | 
						||
| 
								 | 
							
									    | Plus x -> (load, store, plus + (List.length x - 1), times, uminus, num)
							 | 
						||
| 
								 | 
							
									    | Times _ -> (load, store, plus, times + 1, uminus, num)
							 | 
						||
| 
								 | 
							
									    | Uminus _ -> (load, store, plus, times, uminus + 1, num)
							 | 
						||
| 
								 | 
							
									    | Num _ -> (load, store, plus, times, uminus, num + 1)
							 | 
						||
| 
								 | 
							
									    | CTimes _ -> (load, store, plus, times, uminus, num)
							 | 
						||
| 
								 | 
							
									    | CTimesJ _ -> (load, store, plus, times, uminus, num)
							 | 
						||
| 
								 | 
							
									    | NaN _ -> (load, store, plus, times, uminus, num))
							 | 
						||
| 
								 | 
							
									    rest
							 | 
						||
| 
								 | 
							
								    in let (l, s, p, t, u, n) = 
							 | 
						||
| 
								 | 
							
								      loop (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)
							 | 
						||
| 
								 | 
							
								    in (l, s, p, t, u, n)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let weight (l, s, p, t, u, n) =
							 | 
						||
| 
								 | 
							
								    l + s + 10 * p + 20 * t + u + n
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let same_complexity a b = weight a = weight b
							 | 
						||
| 
								 | 
							
								  let leq_complexity a b = weight a <= weight b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  let to_string (l, s, p, t, u, n) =
							 | 
						||
| 
								 | 
							
								    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
							 | 
						||
| 
								 | 
							
										   l s p t u n
							 | 
						||
| 
								 | 
							
										   
							 | 
						||
| 
								 | 
							
								end    
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* simplify the dag *)
							 | 
						||
| 
								 | 
							
								let algsimp v = 
							 | 
						||
| 
								 | 
							
								  let rec simplification_loop v =
							 | 
						||
| 
								 | 
							
								    let () = Util.info "simplification step" in
							 | 
						||
| 
								 | 
							
								    let complexity = Stats.complexity v in
							 | 
						||
| 
								 | 
							
								    let () = Util.info ("complexity = " ^ (Stats.to_string complexity)) in
							 | 
						||
| 
								 | 
							
								    let v = (AlgSimp.algsimp @@ Transpose.transpose @@ 
							 | 
						||
| 
								 | 
							
									     AlgSimp.algsimp @@ Transpose.transpose) v in
							 | 
						||
| 
								 | 
							
								    let complexity' = Stats.complexity v in
							 | 
						||
| 
								 | 
							
								    let () = Util.info ("complexity = " ^ (Stats.to_string complexity')) in
							 | 
						||
| 
								 | 
							
								    if (Stats.leq_complexity complexity' complexity) then
							 | 
						||
| 
								 | 
							
								      let () = Util.info "end algsimp" in
							 | 
						||
| 
								 | 
							
								      v
							 | 
						||
| 
								 | 
							
								    else
							 | 
						||
| 
								 | 
							
								      simplification_loop v
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  in
							 | 
						||
| 
								 | 
							
								  let () = Util.info "begin algsimp" in
							 | 
						||
| 
								 | 
							
								  let v = AlgSimp.algsimp v in
							 | 
						||
| 
								 | 
							
								  if !Magic.network_transposition then simplification_loop v else v
							 | 
						||
| 
								 | 
							
								
							 |