289 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
		
		
			
		
	
	
			289 lines
		
	
	
		
			8.8 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)
|  | 
 | ||
(*************************************************************
 * Conversion of the dag to an assignment list
 *************************************************************)
(*
 * This code is messy.  The main problem is that we want to
 * inline dag nodes conditionally, depending on how many times they
 * are used.  The Right Thing to do would be to modify the
 * state monad to propagate some of the state backwards, so that
 * we know whether a given node will be used again in the future.
 * This modification is trivial in a lazy language, but it is
 * messy in a strict language like ML.
 *
 * In this implementation we simply do the obvious thing: we visit
 * the dag twice, first to count how many times each node is used,
 * and then to produce the output.
 *)
|  | 
 | ||
|  | open Monads.StateMonad | ||
|  | open Monads.MemoMonad | ||
|  | open Expr | ||
|  | 
 | ||
(* Return a new temporary variable (delegates to Variable.make_temporary). *)
let fresh () = Variable.make_temporary ()

(* Association tables keyed by dag nodes.  Keys are hashed with Expr.hash,
   and lookup passes physical equality (==) as the key comparison. *)
let node_insert key = Assoctable.insert Expr.hash key
let node_lookup key = Assoctable.lookup Expr.hash (==) key
let empty = Assoctable.empty
|  | 
 | ||
(* Accessors for the threaded state, a triple (al, visited, visited'):
   - al:       list of (variable, expr) bindings, most recent first;
   - visited:  per-node use counts, filled in by the counting pass;
   - visited': memo table from a dag node to its translated expression. *)

(* Replace the whole state triple with [f] applied to it. *)
let modify_state f = fetchState >>= fun state -> storeState (f state)

let fetchAl = fetchState >>= fun (al, _, _) -> returnM al

let storeAl al =
  modify_state (fun (_, visited, visited') -> (al, visited, visited'))

let fetchVisited = fetchState >>= fun (_, visited, _) -> returnM visited

let storeVisited visited =
  modify_state (fun (al, _, visited') -> (al, visited, visited'))

let fetchVisited' = fetchState >>= fun (_, _, visited') -> returnM visited'

let storeVisited' visited' =
  modify_state (fun (al, visited, _) -> (al, visited, visited'))

(* Memo-table lookup in the third state component. *)
let lookupVisitedM' key =
  fetchVisited' >>= fun table -> returnM (node_lookup key table)

(* Memo-table insertion into the third state component. *)
let insertVisitedM' key value =
  fetchVisited' >>= fun table -> storeVisited' (node_insert key value table)
|  | 
 | ||
(* Use-counting wrapper for the first pass.  On the first visit of [x],
   run [f x] (which explores the node) and record a count of 1.  On later
   visits only bump the count, except that for [Uminus child] the visitor
   [f] is re-run on [child]: Uminus is always inlined, so its operand is
   effectively used again. *)
let counting f x =
  (* store [n] as the use count of [x], reading the table afresh because
     [f] may have updated it in the meantime *)
  let set_count n =
    fetchVisited >>= fun table -> storeVisited (node_insert x n table)
  in
  fetchVisited >>= fun table ->
  match node_lookup x table with
  | None -> f x >> set_count 1
  | Some count ->
      let bump = set_count (count + 1) in
      (match x with
       | Uminus child -> f child >> bump
       | _ -> bump)
|  | 
 | ||
(* Record the binding v := x at the head of the assignment list and
   yield [Load v] as the expression standing for x. *)
let with_varM v x =
  let push bindings = storeAl ((v, x) :: bindings) in
  fetchAl >>= push >> returnM (Load v)
|  | 
 | ||
(* Use an expression in place, without binding it to a variable. *)
let inlineM expr = returnM expr
|  | 
 | ||
(* Bind [x] to a fresh temporary, unless [x] already is a load of a
   temporary (that would only produce a trivial move). *)
let with_tempM x =
  let already_temporary =
    match x with
    | Load v -> Variable.is_temporary v
    | _ -> false
  in
  if already_temporary then inlineM x else with_varM (fresh ()) x
|  | 
 | ||
(* Decide whether the translation [x] of dag node [node] gets its own
   temporary.  [x] is inlined only when the counting pass saw [node]
   exactly once and !Magic.inline_single is set.  In every other case it
   goes through with_tempM.
   Raises [Failure "with_temp_maybeM"] if [node] was never counted. *)
let with_temp_maybeM node x =
  fetchVisited >>= fun counts ->
  match node_lookup node counts with
  | None -> failwith "with_temp_maybeM"
  | Some 1 when !Magic.inline_single -> inlineM x
  | Some _ -> with_tempM x
(* Result of matching a two-term sum against the fused multiply-add
   shapes recognised by build_fma.  The payload order is always
   (a, b, c), where b * c is the product being fused. *)
type fma =
  | NO_FMA                       (* not an fma shape *)
  | FMA of expr * expr * expr    (* FMA (a, b, c) => a + b * c *)
  | FMS of expr * expr * expr    (* FMS (a, b, c) => -a + b * c *)
  | FNMS of expr * expr * expr   (* FNMS (a, b, c) => a - b * c *)
|  | 
 | ||
(* Whether the product a * b may be fused.  A factor is rejected if it is
   a NaN other than I or CONJ, or if it is itself a product with a NaN
   operand (of any kind).  Everything else is accepted. *)
let good_for_fma (a, b) =
  let acceptable = function
    | NaN (I | CONJ) -> true
    | NaN _ | Times (NaN _, _) | Times (_, NaN _) -> false
    | _ -> true
  in
  acceptable a && acceptable b
|  | 
 | ||
(* Classify the operand list of a Plus node as an fma shape.  Only
   two-term sums are considered, and only when !Magic.enable_fma is set.
   The clause order matters: when a guard fails, matching falls through
   to the later shapes. *)
let build_fma l =
  let classify = function
    | [addend; Uminus (Times (x, y))] when good_for_fma (x, y) ->
        FNMS (addend, x, y)
    | [Uminus (Times (x, y)); addend] when good_for_fma (x, y) ->
        FNMS (addend, x, y)
    | [Uminus addend; Times (x, y)] when good_for_fma (x, y) ->
        FMS (addend, x, y)
    | [Times (x, y); Uminus addend] when good_for_fma (x, y) ->
        FMS (addend, x, y)
    | [addend; Times (x, y)] when good_for_fma (x, y) ->
        FMA (addend, x, y)
    | [Times (x, y); addend] when good_for_fma (x, y) ->
        FMA (addend, x, y)
    | _ -> NO_FMA
  in
  if !Magic.enable_fma then classify l else NO_FMA
|  | 
 | ||
(* The three operands (a, b, c) of the fma shape of [l], if any. *)
let children_fma l =
  match build_fma l with
  | NO_FMA -> None
  | FMA (a, b, c) | FMS (a, b, c) | FNMS (a, b, c) -> Some (a, b, c)
|  | 
 | ||
|  | 
 | ||
(* First pass: walk the dag below [x] under [counting], so that every node
   reached gets a use count in the visited table. *)
let rec visitM x =
  let explore = function
    | Load _ | Num _ | NaN _ -> returnM ()
    | Store (_, value) -> visitM value
    | Uminus operand -> visitM operand
    | Times (l, r) | CTimes (l, r) | CTimesJ (l, r) -> visitM l >> visitM r
    | Plus terms ->
        (match children_fma terms with
         | None -> mapM visitM terms >> returnM ()
         | Some (a, b, c) ->
             (* visit each fma operand twice so that its count exceeds one
                and it is never inlined *)
             visitM a >> visitM a >>
             visitM b >> visitM b >>
             visitM c >> visitM c)
  in
  counting explore x
|  | 
 | ||
(* Run the counting pass over every root, in list order. *)
let visit_rootsM roots = mapM visitM roots
|  | 
 | ||
|  | 
 | ||
(* Second pass: translate dag node [x] into an expression.  Any temporaries
   it needs are appended to the assignment list in the state.  Translation
   goes through [memoizing] with the visited' table, whose keys are
   compared physically (see node_lookup).  Whether a node gets its own
   temporary is decided by with_temp_maybeM from the counts collected by
   visitM.  Children are always translated left to right, and that order
   fixes the order of the emitted assignments. *)
let rec expr_of_nodeM x =
  (* translate [l] then [r] and hand both results to [k] *)
  let with_two l r k =
    expr_of_nodeM l >>= fun l' ->
    expr_of_nodeM r >>= fun r' ->
    k l' r'
  in
  (* translate fma operands in the order a, b, c *)
  let with_three a b c k =
    expr_of_nodeM a >>= fun a' ->
    with_two b c (k a')
  in
  let translate node =
    match node with
    | Load v ->
        let load = Load v in
        let inlinable =
          Variable.is_temporary v
          || (Variable.is_locative v && !Magic.inline_loads)
          || (Variable.is_constant v && !Magic.inline_loads_constants)
        in
        if inlinable then inlineM load else with_tempM load
    | Num n ->
        if !Magic.inline_constants then inlineM (Num n)
        else with_temp_maybeM node (Num n)
    | NaN t -> inlineM (NaN t)
    | Store (v, value) ->
        let finish = if !Magic.trivial_stores then with_tempM else inlineM in
        expr_of_nodeM value >>= finish >>= with_varM v
    | Plus terms ->
        (match build_fma terms with
         | FMA (a, b, c) ->
             with_three a b c (fun a' b' c' ->
                 with_temp_maybeM node (Plus [a'; Times (b', c')]))
         | FMS (a, b, c) ->
             with_three a b c (fun a' b' c' ->
                 with_temp_maybeM node (Plus [Times (b', c'); Uminus a']))
         | FNMS (a, b, c) ->
             with_three a b c (fun a' b' c' ->
                 with_temp_maybeM node (Plus [a'; Uminus (Times (b', c'))]))
         | NO_FMA ->
             mapM expr_of_nodeM terms >>= fun terms' ->
             with_temp_maybeM node (Plus terms'))
    | CTimes ((Load _ as tw), b) when !Magic.generate_bytw ->
        (* the Load operand is used as is, without translating it *)
        expr_of_nodeM b >>= fun b' -> with_tempM (CTimes (tw, b'))
    | CTimes (a, b) ->
        with_two a b (fun a' b' -> with_tempM (CTimes (a', b')))
    | CTimesJ ((Load _ as tw), b) when !Magic.generate_bytw ->
        expr_of_nodeM b >>= fun b' -> with_tempM (CTimesJ (tw, b'))
    | CTimesJ (a, b) ->
        with_two a b (fun a' b' -> with_tempM (CTimesJ (a', b')))
    | Times (a, b) ->
        with_two a b (fun a' b' ->
            match a' with
            | Num coeff when !Magic.strength_reduce_mul && Number.is_two coeff ->
                (* strength reduction: 2 * b becomes b + b.
                   NOTE(review): inlineM does not bind b to a temporary, so if
                   b was inlined (single use) its whole expression appears
                   twice in the sum; confirm this is intended. *)
                inlineM b' >>= fun b'' ->
                with_temp_maybeM node (Plus [b''; b''])
            | _ -> with_temp_maybeM node (Times (a', b')))
    | Uminus a ->
        (* negation is always inlined *)
        expr_of_nodeM a >>= fun a' -> inlineM (Uminus a')
  in
  memoizing lookupVisitedM' insertVisitedM' translate x
|  | 
 | ||
(* Translate every root, in list order. *)
let expr_of_rootsM roots = mapM expr_of_nodeM roots

(* Whole pipeline over the state monad: count uses, translate, then
   return the accumulated (most recent first) binding list. *)
let peek_alistM roots =
  visit_rootsM roots >>
  expr_of_rootsM roots >>
  fetchAl

(* Turn a (variable, expr) binding into an Expr assignment. *)
let wrap_assign (lhs, rhs) = Expr.Assign (lhs, rhs)
|  | 
 | ||
(* Convert the list of dag roots [dag] into assignments.  The binding list
   is accumulated by consing, so it is reversed here to restore creation
   order. *)
let to_assignments dag =
  Util.info "begin to_alist";
  let bindings = runM ([], empty, empty) peek_alistM dag in
  let assignments = List.map wrap_assign (List.rev bindings) in
  Util.info "end to_alist";
  assignments
|  | 
 | ||
|  | 
 | ||
(* Write the assignment list [alist] as a graph in the dot language,
   passing each piece of text to [print].  Locative variables that are read
   form one rank, locative variables that are assigned form another, and
   there is one edge from every variable read by an assignment to the
   variable it assigns. *)
let dump print alist =
  let quote v = "\"" ^ (Variable.unparse v) ^ "\"" in
  let declare v = print ("\t" ^ (quote v) ^ ";\n") in
  let each_assignment f = List.iter (fun (Expr.Assign (v, x)) -> f v x) alist in
  print "digraph G {\n";
  print "\tsize=\"6,6\";\n";
  (* all input nodes have the same rank *)
  print "{ rank = same;\n";
  each_assignment (fun _ rhs ->
    List.iter (fun y -> if Variable.is_locative y then declare y)
      (Expr.find_vars rhs));
  print "}\n";
  (* all output nodes have the same rank *)
  print "{ rank = same;\n";
  each_assignment (fun lhs _ -> if Variable.is_locative lhs then declare lhs);
  print "}\n";
  (* edges *)
  each_assignment (fun lhs rhs ->
    List.iter (fun y -> print ("\t" ^ (quote y) ^ " -> " ^ (quote lhs) ^ ";\n"))
      (Expr.find_vars rhs));
  print "}\n"
|  | 
 |