289 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
		
		
			
		
	
	
			289 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
| 
								 | 
							
								(*
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 2003, 2007-14 Matteo Frigo
							 | 
						||
| 
								 | 
							
								 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * This program is free software; you can redistribute it and/or modify
							 | 
						||
| 
								 | 
							
								 * it under the terms of the GNU General Public License as published by
							 | 
						||
| 
								 | 
							
								 * the Free Software Foundation; either version 2 of the License, or
							 | 
						||
| 
								 | 
							
								 * (at your option) any later version.
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * This program is distributed in the hope that it will be useful,
							 | 
						||
| 
								 | 
							
								 * but WITHOUT ANY WARRANTY; without even the implied warranty of
							 | 
						||
| 
								 | 
							
								 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
							 | 
						||
| 
								 | 
							
								 * GNU General Public License for more details.
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * You should have received a copy of the GNU General Public License
							 | 
						||
| 
								 | 
							
								 * along with this program; if not, write to the Free Software
							 | 
						||
| 
								 | 
							
								 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(*************************************************************
							 | 
						||
| 
								 | 
							
								 * Conversion of the dag to an assignment list
							 | 
						||
| 
								 | 
							
								 *************************************************************)
							 | 
						||
| 
								 | 
							
								(*
							 | 
						||
| 
								 | 
							
								 * This function is messy.  The main problem is that we want to
							 | 
						||
| 
								 | 
							
								 * inline dag nodes conditionally, depending on how many times they
							 | 
						||
| 
								 | 
							
								 * are used.  The Right Thing to do would be to modify the
							 | 
						||
| 
								 | 
							
								 * state monad to propagate some of the state backwards, so that
							 | 
						||
| 
								 | 
							
								 * we know whether a given node will be used again in the future.
							 | 
						||
| 
								 | 
							
								 * This modification is trivial in a lazy language, but it is
							 | 
						||
| 
								 | 
							
								 * messy in a strict language like ML.  
							 | 
						||
| 
								 | 
							
								 *
							 | 
						||
| 
								 | 
							
								 * In this implementation, we just do the obvious thing, i.e., visit
							 | 
						||
| 
								 | 
							
								 * the dag twice, the first to count the node usages, and the second to
							 | 
						||
| 
								 | 
							
								 * produce the output.
							 | 
						||
| 
								 | 
							
								 *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								open Monads.StateMonad
							 | 
						||
| 
								 | 
							
								open Monads.MemoMonad
							 | 
						||
| 
								 | 
							
								open Expr
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Allocator for fresh temporary variables; called as [fresh ()]. *)
								let fresh = Variable.make_temporary

								(* Insert a binding for [node] into an association table, hashing the
								   node with [Expr.hash].  The remaining arguments are passed straight
								   through to [Assoctable.insert]. *)
								let node_insert node = Assoctable.insert Expr.hash node

								(* Look [node] up in an association table.  The comparison handed to
								   [Assoctable.lookup] is physical equality [(==)], not structural
								   equality. *)
								let node_lookup node = Assoctable.lookup Expr.hash (==) node

								(* The empty association table. *)
								let empty = Assoctable.empty
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Monadic action that yields the first component of the state triple,
								   i.e. the assignment list accumulated so far. *)
								let fetchAl =
								  fetchState >>= fun (assignments, _, _) -> returnM assignments
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Replace the assignment list (first component of the state triple)
								   with [al], leaving the other two components untouched. *)
								let storeAl al =
								  fetchState >>= fun (_, counts, memo) -> storeState (al, counts, memo)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Monadic action that yields the second component of the state
								   triple: the table of per-node usage counts. *)
								let fetchVisited =
								  fetchState >>= fun (_, counts, _) -> returnM counts
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Replace the usage-count table (second component of the state
								   triple) with [visited], keeping the other two components. *)
								let storeVisited visited =
								  fetchState >>= fun (al, _, memo) -> storeState (al, visited, memo)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Monadic action that yields the third component of the state
								   triple: the table consulted by [lookupVisitedM'] and extended by
								   [insertVisitedM']. *)
								let fetchVisited' =
								  fetchState >>= fun (_, _, memo) -> returnM memo
							 | 
						||
| 
								 | 
							
								(* Replace the third component of the state triple with [visited'],
								   keeping the assignment list and the usage-count table. *)
								let storeVisited' visited' =
								  fetchState >>= fun (al, counts, _) -> storeState (al, counts, visited')
							 | 
						||
| 
								 | 
							
								(* Monadic lookup of [key] in the third state component.  Returns
								   whatever [node_lookup] returns (matched elsewhere in this file as
								   [Some _] / [None]). *)
								let lookupVisitedM' key =
								  let find memo = returnM (node_lookup key memo) in
								  fetchVisited' >>= find
							 | 
						||
| 
								 | 
							
								(* Monadic insertion: bind [key] to [value] in the third state
								   component and store the updated table back into the state. *)
								let insertVisitedM' key value =
								  let extend memo = storeVisited' (node_insert key value memo) in
								  fetchVisited' >>= extend
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Usage-counting wrapper around a visitor [f].

								   - If [x] has no entry in the usage-count table yet, run [f x] and
								     then record a count of 1 for [x].
								   - If [x] already has count [n], record [n + 1].  [f] is not run
								     again on [x]; the one exception is a [Uminus y] node, for which
								     [f y] is run before the count is bumped (Uminus is always
								     inlined, so its child is visited).

								   The table is re-fetched immediately before every store, so updates
								   made while [f] ran are kept. *)
								let counting f x =
								  (* Continuation that writes [n] as the count of [x] into [table]. *)
								  let store_count n = fun table -> storeVisited (node_insert x n table) in
								  fetchVisited >>= fun seen ->
								    match node_lookup x seen with
								    | None ->
								        f x >> fetchVisited >>= store_count 1
								    | Some count ->
								        let bump = fetchVisited >>= store_count (count + 1) in
								        (match x with
								         (* Uminus is always inlined.  Visit child *)
								         | Uminus y -> f y >> bump
								         | _ -> bump)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Emit the assignment [(v, x)] by consing it onto the front of the
								   assignment list, then return [Load v] as the expression that stands
								   for the assigned value. *)
								let with_varM v x =
								  let record_binding = fetchAl >>= fun al -> storeAl ((v, x) :: al) in
								  record_binding >> returnM (Load v)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* "Inline" an expression: return it unchanged, emitting no
								   assignment.  This is just the monad's [returnM]. *)
								let inlineM x = returnM x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Bind [x] to a fresh temporary and return a load of that temporary.
								   If [x] is already a load of a temporary it is returned as is, which
								   avoids emitting a trivial move. *)
								let with_tempM x =
								  let already_temporary =
								    match x with
								    | Load v -> Variable.is_temporary v
								    | _ -> false
								  in
								  if already_temporary then inlineM x (* avoid trivial moves *)
								  else
								    let tmp = fresh () in
								    with_varM tmp x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Declare a temporary for [x] only if the dag node [node] is used
								   more than once.  [node]'s count is read from the usage-count table:
								   when it is exactly 1 and [Magic.inline_single] is set, [x] is
								   inlined; otherwise [x] goes through [with_tempM].

								   Fails with [Failure "with_temp_maybeM"] if [node] has no entry in
								   the table. *)
								let with_temp_maybeM node x =
								  fetchVisited >>= fun usage ->
								    match node_lookup node usage with
								    | None -> failwith "with_temp_maybeM"
								    | Some count ->
								        if count = 1 && !Magic.inline_single then inlineM x
								        else with_tempM x
							 | 
						||
| 
								 | 
							
								(* Result of trying to read a two-term sum as a fused multiply-add.
								   In every non-trivial case the payload is [(a, b, c)], where [b * c]
								   is the product and [a] is the other term. *)
								type fma =
								  | NO_FMA                      (* the sum has no fma shape *)
								  | FMA of expr * expr * expr   (* FMA (a, b, c) => a + b * c *)
								  | FMS of expr * expr * expr   (* FMS (a, b, c) => -a + b * c *)
								  | FNMS of expr * expr * expr  (* FNMS (a, b, c) => a - b * c *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Decide whether the two factors [(a, b)] of a product may take part
								   in an fma.  A factor is rejected when it is a [NaN] node other than
								   [NaN I] or [NaN CONJ], or when it is itself a [Times] with a [NaN]
								   operand.  Both factors must be acceptable. *)
								let good_for_fma (a, b) =
								  let good = function
								    | NaN I | NaN CONJ -> true
								    | NaN _ -> false
								    | Times (NaN _, _) | Times (_, NaN _) -> false
								    | _ -> true
								  in
								  good a && good b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Classify the term list [l] of a sum as an fma shape.

								   Returns [NO_FMA] when [Magic.enable_fma] is off, when [l] does not
								   have exactly two terms of a recognised shape, or when the factors
								   of the product fail [good_for_fma].

								   The cases are tried in order and each carries its own guard: a
								   later case is still tried when an earlier pattern matches but its
								   guard fails.  For that reason the symmetric cases are kept separate
								   rather than merged into or-patterns. *)
								let build_fma l =
								  if !Magic.enable_fma then
								    match l with
								    (* a - b * c, in either order *)
								    | [a; Uminus (Times (b, c))] when good_for_fma (b, c) -> FNMS (a, b, c)
								    | [Uminus (Times (b, c)); a] when good_for_fma (b, c) -> FNMS (a, b, c)
								    (* -a + b * c, in either order *)
								    | [Uminus a; Times (b, c)] when good_for_fma (b, c) -> FMS (a, b, c)
								    | [Times (b, c); Uminus a] when good_for_fma (b, c) -> FMS (a, b, c)
								    (* a + b * c, in either order *)
								    | [a; Times (b, c)] when good_for_fma (b, c) -> FMA (a, b, c)
								    | [Times (b, c); a] when good_for_fma (b, c) -> FMA (a, b, c)
								    | _ -> NO_FMA
								  else NO_FMA
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Operands [(a, b, c)] of the fma that [build_fma] recognises in [l],
								   whichever flavour it is, or [None] if [l] is not an fma. *)
								let children_fma l =
								  match build_fma l with
								  | NO_FMA -> None
								  | FMA (a, b, c) | FMS (a, b, c) | FNMS (a, b, c) -> Some (a, b, c)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* First pass over the dag: count how many times each node is used.
								   The counting itself is done by [counting]; the function passed to
								   it only says which children of a node to visit. *)
								let rec visitM x =
								  let visit_children = function
								    | Load _ | Num _ | NaN _ -> returnM ()
								    | Store (_, rhs) | Uminus rhs -> visitM rhs
								    | Times (l, r) | CTimes (l, r) | CTimesJ (l, r) ->
								        visitM l >> visitM r
								    | Plus terms ->
								        (match children_fma terms with
								         | None -> mapM visitM terms >> returnM ()
								         | Some (a, b, c) ->
								             (* visit fma's arguments twice to make sure they are
								                not inlined *)
								             visitM a >> visitM a >>
								             visitM b >> visitM b >>
								             visitM c >> visitM c)
								  in
								  counting visit_children x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Run the counting pass [visitM] over every root, in list order. *)
								let visit_rootsM roots = mapM visitM roots
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Second pass over the dag: turn node [x] into an expression,
								   emitting assignments to variables along the way.

								   The translation is wrapped in [memoizing] (from the opened Monads
								   modules, not shown here), keyed through [lookupVisitedM'] and
								   [insertVisitedM'] -- presumably so that each dag node is translated
								   only once.

								   Whether a node gets its own temporary depends on its kind, on the
								   [Magic] flags, and (via [with_temp_maybeM]) on the usage count
								   gathered by the first pass. *)
								let rec expr_of_nodeM x =
								  (* Translate [a], then [b], then continue with [k a' b']. *)
								  let binary a b k =
								    expr_of_nodeM a >>= fun a' ->
								      expr_of_nodeM b >>= fun b' ->
								        k a' b'
								  in
								  (* Translate the operands [a], [b], [c] of an fma-shaped sum in that
								     order, rebuild the sum with [shape], and let the usage count of
								     [node] decide whether it gets a temporary. *)
								  let fused node a b c shape =
								    expr_of_nodeM a >>= fun a' ->
								      binary b c (fun b' c' ->
								        with_temp_maybeM node (shape a' b' c'))
								  in
								  let translate node =
								    match node with
								    | Load v ->
								        (* Loads of temporaries are always inlined; loads of
								           locatives and constants only when the matching flag is
								           set.  Any other load is copied into a temporary. *)
								        let inline =
								          Variable.is_temporary v
								          || (Variable.is_locative v && !Magic.inline_loads)
								          || (Variable.is_constant v && !Magic.inline_loads_constants)
								        in
								        if inline then inlineM (Load v) else with_tempM (Load v)
								    | Num a ->
								        if !Magic.inline_constants then inlineM (Num a)
								        else with_temp_maybeM node (Num a)
								    | NaN a -> inlineM (NaN a)
								    | Store (v, rhs) ->
								        (* Translate the right-hand side, optionally force it into
								           a temporary, then emit the assignment to [v]. *)
								        expr_of_nodeM rhs >>=
								        (if !Magic.trivial_stores then with_tempM else inlineM) >>=
								        with_varM v
								    | Plus terms ->
								        (match build_fma terms with
								         | FMA (a, b, c) ->
								             fused node a b c
								               (fun a' b' c' -> Plus [a'; Times (b', c')])
								         | FMS (a, b, c) ->
								             fused node a b c
								               (fun a' b' c' -> Plus [Times (b', c'); Uminus a'])
								         | FNMS (a, b, c) ->
								             fused node a b c
								               (fun a' b' c' -> Plus [a'; Uminus (Times (b', c'))])
								         | NO_FMA ->
								             mapM expr_of_nodeM terms >>= fun terms' ->
								               with_temp_maybeM node (Plus terms'))
								    (* With [Magic.generate_bytw] set, a [Load] as first operand of a
								       complex product is kept as is instead of being translated. *)
								    | CTimes (Load _ as a, b) when !Magic.generate_bytw ->
								        expr_of_nodeM b >>= fun b' ->
								          with_tempM (CTimes (a, b'))
								    | CTimes (a, b) ->
								        binary a b (fun a' b' -> with_tempM (CTimes (a', b')))
								    | CTimesJ (Load _ as a, b) when !Magic.generate_bytw ->
								        expr_of_nodeM b >>= fun b' ->
								          with_tempM (CTimesJ (a, b'))
								    | CTimesJ (a, b) ->
								        binary a b (fun a' b' -> with_tempM (CTimesJ (a', b')))
								    | Times (a, b) ->
								        binary a b (fun a' b' ->
								          match a' with
								          (* strength reduction: 2 * b is emitted as b + b *)
								          | Num two when !Magic.strength_reduce_mul && Number.is_two two ->
								              inlineM b' >>= fun b'' ->
								                with_temp_maybeM node (Plus [b''; b''])
								          | _ -> with_temp_maybeM node (Times (a', b')))
								    | Uminus a ->
								        (* negation is always inlined *)
								        expr_of_nodeM a >>= fun a' ->
								          inlineM (Uminus a')
								  in
								  memoizing lookupVisitedM' insertVisitedM' translate x
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Run the translation pass [expr_of_nodeM] over every root, in list
								   order. *)
								let expr_of_rootsM roots = mapM expr_of_nodeM roots
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Full two-pass traversal: count node usages over [roots], translate
								   [roots] into assignments, then return the accumulated assignment
								   list. *)
								let peek_alistM roots =
								  let count_pass = visit_rootsM roots in
								  let emit_pass = expr_of_rootsM roots in
								  count_pass >> emit_pass >> fetchAl
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Turn a (variable, expression) pair into an [Expr.Assign]. *)
								let wrap_assign (lhs, rhs) = Expr.Assign (lhs, rhs)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Convert [dag] into a list of [Expr.Assign] values.

								   The traversal starts from an empty assignment list and two empty
								   tables.  [with_varM] conses each new assignment onto the front of
								   the list, so the list is reversed here to put the assignments back
								   in the order they were emitted.  Progress is reported through
								   [Util.info] before and after the conversion. *)
								let to_assignments dag =
								  Util.info "begin to_alist";
								  let newest_first = runM ([], empty, empty) peek_alistM dag in
								  let res = List.map wrap_assign (List.rev newest_first) in
								  Util.info "end to_alist";
								  res
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* dump alist in `dot' format

								   [print] receives the output one string at a time; [alist] is a
								   list of [Expr.Assign (v, x)].  The graph contains:
								   - a same-rank group listing every locative variable that occurs in
								     a right-hand side (the inputs),
								   - a same-rank group listing every locative left-hand side (the
								     outputs),
								   - one edge [y -> v] for each variable [y] found in the right-hand
								     side of the assignment to [v]. *)
								let dump print alist =
								  (* variable name wrapped in double quotes *)
								  let quoted v = "\"" ^ (Variable.unparse v) ^ "\"" in
								  let node_line v = print ("\t" ^ (quoted v) ^ ";\n") in
								  (* print one "{ rank = same; ... }" group, calling [emit] on each
								     assignment to produce the group's contents *)
								  let same_rank emit =
								    print "{ rank = same;\n";
								    List.iter emit alist;
								    print "}\n"
								  in
								  print "digraph G {\n";
								  print "\tsize=\"6,6\";\n";

								  (* all input nodes have the same rank *)
								  same_rank (fun (Expr.Assign (_, x)) ->
								    List.iter (fun y -> if Variable.is_locative y then node_line y)
								      (Expr.find_vars x));

								  (* all output nodes have the same rank *)
								  same_rank (fun (Expr.Assign (v, _)) ->
								    if Variable.is_locative v then node_line v);

								  (* edges *)
								  List.iter (fun (Expr.Assign (v, x)) ->
								    List.iter
								      (fun y -> print ("\t" ^ (quoted y) ^ " -> " ^ (quoted v) ^ ";\n"))
								      (Expr.find_vars x))
								    alist;

								  print "}\n"
							 | 
						||
| 
								 | 
							
								
							 |