189 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
		
		
			
		
	
	
			189 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			OCaml
		
	
	
	
	
	
| 
								 | 
							
								(*
								 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
								 * Copyright (c) 2003, 2007-14 Matteo Frigo
								 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
								 *
								 * This program is free software; you can redistribute it and/or modify
								 * it under the terms of the GNU General Public License as published by
								 * the Free Software Foundation; either version 2 of the License, or
								 * (at your option) any later version.
								 *
								 * This program is distributed in the hope that it will be useful,
								 * but WITHOUT ANY WARRANTY; without even the implied warranty of
								 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
								 * GNU General Public License for more details.
								 *
								 * You should have received a copy of the GNU General Public License
								 * along with this program; if not, write to the Free Software
								 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
								 *
								 *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* policies for loading/computing twiddle factors *)
							 | 
						||
| 
								 | 
							
								open Complex
							 | 
						||
| 
								 | 
							
								open Util
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Opcode of a twiddle instruction.  In this file TW_FULL entries are
								   expanded into TW_CEXP entries by unroll_twfull, and every descriptor
								   built by a policy ends with a (TW_NEXT, 1, 0) entry. *)
								type twop =
								  | TW_FULL
								  | TW_CEXP
								  | TW_NEXT
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Textual form of a twiddle opcode: the constructor's own name, as
								   embedded in the strings built by twinstr_to_c_string. *)
								let optostring op =
								  match op with
								  | TW_FULL -> "TW_FULL"
								  | TW_CEXP -> "TW_CEXP"
								  | TW_NEXT -> "TW_NEXT"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* One twiddle instruction: an opcode followed by two integer operands
								   whose meaning depends on the opcode. *)
								type twinstr = twop * int * int
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Replace every (TW_FULL, v, n) entry of an instruction list by the
								   (TW_CEXP, v, i) entries that [forall] generates for the index range
								   starting at 1 and bounded by n; all other entries are kept as they
								   are, in their original order.
								   NOTE(review): [forall] and [cons] come from an opened module
								   (presumably Util); whether the bound n is inclusive is decided there
								   - confirm. *)
								let rec unroll_twfull = function
								  | [] -> []
								  | (TW_FULL, v, n) :: rest ->
								    let expanded = forall [] cons 1 n (fun i -> (TW_CEXP, v, i)) in
								    expanded @ unroll_twfull rest
								  | instr :: rest -> instr :: unroll_twfull rest
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Render an instruction list as a brace-enclosed initializer: each
								   instruction becomes "{ OP, a, b }" on a line of its own, consecutive
								   instructions are separated by a comma, and the empty list gives "{}". *)
								let twinstr_to_c_string l =
								  (* every entry starts on a fresh line; the commas come from concat *)
								  let render (op, a, b) =
								    "\n" ^ Printf.sprintf "{ %s, %d, %d }" (optostring op) a b
								  in
								  "{" ^ String.concat "," (List.map render l) ^ "}"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* SIMD flavour of twinstr_to_c_string.  The list is first passed through
								   unroll_twfull; each (TW_CEXP, v, b) is then printed as "VTW(v,b)" and
								   the terminator (TW_NEXT, 1, 0) as "{TW_NEXT, <vl>, 0}", where [vl] is
								   a string spliced in verbatim.  Any other instruction (a TW_NEXT with
								   different operands, or a TW_FULL that survived unrolling) raises
								   Failure "twinstr_to_simd_string". *)
								let twinstr_to_simd_string vl l =
								  let render = function
								    | (TW_NEXT, 1, 0) -> "{TW_NEXT, " ^ vl ^ ", 0}"
								    | (TW_CEXP, v, b) -> Printf.sprintf "VTW(%d,%d)" v b
								    | (TW_NEXT, _, _) | (TW_FULL, _, _) ->
								      failwith "twinstr_to_simd_string"
								  in
								  (* every entry starts on a fresh line; the commas come from concat *)
								  let entries =
								    List.map (fun instr -> "\n" ^ render instr) (unroll_twfull l)
								  in
								  "{" ^ String.concat "," entries ^ "}"
							 | 
						||
| 
								 | 
							
								  
							 | 
						||
| 
								 | 
							
								(* [pow m n] is the integer power m^n, obtained by repeated
								   multiplication ([pow m 0] is 1 for every m).
								   Raises Invalid_argument when [n] is negative: without this check the
								   recursion would count down past zero and never reach its base case. *)
								let rec pow m n =
								  if n < 0 then invalid_arg "pow: negative exponent"
								  else if n = 0 then 1
								  else m * pow m (n - 1)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* [is_pow m n] tells whether [n] is a power m^k of the base [m] with
								   k >= 0 (so 1 always qualifies).  The base is expected to be at least
								   2; for a smaller base only n = 1 is accepted.  Zero and negative [n]
								   are never powers.  The explicit guards are needed for termination:
								   dividing 0 by m, or dividing n by 1, makes no progress. *)
								let rec is_pow m n =
								  n = 1 || (n > 1 && m > 1 && n mod m = 0 && is_pow m (n / m))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* [log m n] is the integer base-[m] logarithm of [n], rounded down; it
								   is exact when [n] is a power of [m].  This definition shadows the
								   floating-point [log] of the standard library for the rest of the file.
								   Raises Invalid_argument if [m] < 2 or [n] < 1.  The loop stops as soon
								   as the remaining quotient drops below the base, so inputs such as
								   [log 3 6], whose quotients never hit exactly 1, terminate as well. *)
								let log m n =
								  if m < 2 then invalid_arg "log: base must be at least 2"
								  else if n < 1 then invalid_arg "log: argument must be positive"
								  else
								    let rec count acc k =
								      if k < m then acc else count (acc + 1) (k / m)
								    in
								    count 0 n
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Largest power of [m] that does not exceed [i].  Despite the name the
								   bound is inclusive: [i] itself is returned when it is already a power.
								   The search simply counts down from [i] and has no lower-bound check,
								   so [i] is expected to be at least 1. *)
								let largest_power_smaller_than m i =
								  let rec descend candidate =
								    if is_pow m candidate then candidate else descend (candidate - 1)
								  in
								  descend i
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Smallest power of [m] that is not below [i].  Despite the name the
								   bound is inclusive: [i] itself is returned when it is already a power.
								   The search counts up from [i]; [i] is expected to be at least 1. *)
								let smallest_power_larger_than m i =
								  let rec ascend candidate =
								    if is_pow m candidate then candidate else ascend (candidate + 1)
								  in
								  ascend i
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* [rec_array n f] builds a self-referential table of [n] lazy cells and
								   returns the function that computes an entry.  [f self i] defines entry
								   [i]; inside [f], [self j] forces cell [j] (0 <= j < n), so a cell that
								   has been computed is not computed again.  The returned function calls
								   [f] directly on its argument: only the recursive [self] look-ups go
								   through the table. *)
								let rec_array n f =
								  (* placeholder; it is replaced by [compute] before any cell is forced *)
								  let knot = ref (fun _ -> Complex.zero) in
								  let cells = Array.init n (fun idx -> lazy (!knot idx)) in
								  let lookup j = Lazy.force cells.(j) in
								  let compute i = f lookup i in
								  knot := compute;
								  compute
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								 
							 | 
						||
| 
								 | 
							
								(* Product of [a] and [b]: delegated to Complex.ctimes when
								   [use_complex_arith] is set and to Complex.times otherwise. *)
								let ctimes use_complex_arith a b =
								  match use_complex_arith with
								  | true -> Complex.ctimes a b
								  | false -> Complex.times a b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Product of the conjugate of [a] with [b].  Without complex arithmetic
								   this is spelled out as Complex.times (Complex.conj a) b; with it the
								   work is delegated to Complex.ctimesj (presumably the same product in
								   complex-arithmetic form - defined outside this file). *)
								let ctimesj use_complex_arith a b =
								  match use_complex_arith with
								  | true -> Complex.ctimesj a b
								  | false -> Complex.times (Complex.conj a) b
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Input [f i] multiplied by twiddle factor [g i].  Index 0 is passed
								   through untouched ([g] is not consulted).  For the other indices the
								   product is formed with ctimes when [sign] is 1 and with ctimesj (which
								   conjugates the twiddle) for any other sign. *)
								let make_bytwiddle sign use_complex_arith g f i =
								  match i with
								  | 0 -> f i
								  | _ ->
								    let multiply = if sign = 1 then ctimes else ctimesj in
								    multiply use_complex_arith (g i) (f i)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* various policies for computing/loading twiddle factors *)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Policy that loads one twiddle factor per non-zero index.  Returns the
								   triple (bytwiddle, twidlen, twdesc):
								   - bytwiddle n sign w f: input i (i <> 0) is multiplied by w (i - 1);
								     [n] is unused here but kept so that all policies share a signature;
								   - twidlen n: 2 * (n - 1);
								   - twdesc r: the descriptor [(TW_FULL, v, r); (TW_NEXT, 1, 0)]. *)
								let twiddle_policy_load_all v use_complex_arith =
								  let bytwiddle _n sign w f =
								    let twiddle_of i = w (i - 1) in
								    make_bytwiddle sign use_complex_arith twiddle_of f
								  in
								  let twidlen n = 2 * (n - 1) in
								  let twdesc r = [ (TW_FULL, v, r); (TW_NEXT, 1, 0) ] in
								  (bytwiddle, twidlen, twdesc)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(*
								 * Policy that stores only the twiddles whose index is a power of two.
								 *
								 * if i is a power of two, then load w (log i)
								 * else let x = largest power of 2 less than i in
								 *      let y = i - x in
								 *      compute w^{x+y} = w^x * w^y
								 *
								 * Returns the triple (bytwiddle, twidlen, twdesc); twdesc lists one
								 * TW_CEXP entry per power of two among the values of iota n, followed
								 * by the (TW_NEXT, 1, 0) terminator.
								 *)
								let twiddle_policy_log2 v use_complex_arith =
								  let bytwiddle n sign w f =
								    (* entry i of the twiddle table; [self] reads other entries *)
								    let compute self i =
								      if i = 0 then Complex.one
								      else if is_pow 2 i then w (log 2 i)
								      else
								        let x = largest_power_smaller_than 2 i in
								        ctimes use_complex_arith (self x) (self (i - x))
								    in
								    make_bytwiddle sign use_complex_arith (rec_array n compute) f
								  in
								  let twidlen n = 2 * log 2 (largest_power_smaller_than 2 (2 * n - 1)) in
								  let twdesc n =
								    let powers = List.filter (fun i -> i > 0 && is_pow 2 i) (iota n) in
								    List.map (fun i -> (TW_CEXP, v, i)) powers @ [ (TW_NEXT, 1, 0) ]
								  in
								  (bytwiddle, twidlen, twdesc)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Policy that stores the twiddles whose index is a power of three, with
								   the last stored index clamped to n - 1.  Returns the triple
								   (bytwiddle, twidlen, twdesc).

								   terms_needed n counts how many of the terms 1, 3, 9, ... must be
								   summed before the running total reaches n - 1; that count is the
								   number of stored twiddles.  twdesc emits, for k below that count, a
								   TW_CEXP entry with index min (3^k) (n - 1), then (TW_NEXT, 1, 0).

								   bytwiddle rebuilds w^i as follows:
								   - i = 0: the constant one;
								   - i a power of three: load w (log 3 i);
								   - i = n - 1 while the last stored power was clamped (3^(nterms-1) >= n):
								     load w (nterms - 1);
								   - 2 * i at or above the next power x: conj (w^(x - i)) * w^x, with x
								     clamped to n - 1;
								   - otherwise: w^(i - x) * w^x with x the largest power below i. *)
								let twiddle_policy_log3 v use_complex_arith =
								  let terms_needed n =
								    let rec count terms power covered =
								      if covered >= n - 1 then terms
								      else count (terms + 1) (3 * power) (covered + power)
								    in
								    count 0 1 0
								  in
								  let bytwiddle n sign w f =
								    let nterms = terms_needed n in
								    let maxterm = pow 3 (nterms - 1) in
								    (* entry i of the twiddle table; [self] reads other entries *)
								    let compute self i =
								      if i = 0 then Complex.one
								      else if is_pow 3 i then w (log 3 i)
								      else if i = n - 1 && maxterm >= n then w (nterms - 1)
								      else
								        let above = smallest_power_larger_than 3 i in
								        if i + i >= above then
								          let above = min above (n - 1) in
								          ctimesj use_complex_arith (self (above - i)) (self above)
								        else
								          let below = largest_power_smaller_than 3 i in
								          ctimes use_complex_arith (self (i - below)) (self below)
								    in
								    make_bytwiddle sign use_complex_arith (rec_array (3 * n) compute) f
								  in
								  let twidlen n = 2 * terms_needed n in
								  let twdesc n =
								    let entry k = (TW_CEXP, v, min (pow 3 k) (n - 1)) in
								    List.map entry (iota (twidlen n / 2)) @ [ (TW_NEXT, 1, 0) ]
								  in
								  (bytwiddle, twidlen, twdesc)
							 | 
						||
| 
								 | 
							
								    
							 | 
						||
| 
								 | 
							
								(* The twiddle policy currently in force.  It starts out as
								   twiddle_policy_load_all, is overwritten by the callbacks that
								   set_policy and set_policy_int build, and is read by twiddle_policy. *)
								let current_twiddle_policy = ref twiddle_policy_load_all
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Apply the currently selected policy.  The single explicit argument is
								   forwarded as the policy's FIRST parameter (called [v] in the policy
								   definitions of this file); the value returned still expects the
								   use_complex_arith flag before it yields the
								   (bytwiddle, twidlen, twdesc) triple. *)
								let twiddle_policy v =
								  let policy = !current_twiddle_policy in
								  policy v
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Command-line action (Arg.Unit) that makes [x] the current twiddle
								   policy when the corresponding flag is seen. *)
								let set_policy x =
								  let select () = current_twiddle_policy := x in
								  Arg.Unit select
							 | 
						||
| 
								 | 
							
								(* Command-line action (Arg.Int) for policies parameterised by an
								   integer: the flag's argument [i] is passed to [x] and the policy it
								   returns becomes the current one. *)
								let set_policy_int x =
								  let select i = current_twiddle_policy := x i in
								  Arg.Int select
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Help text shared by every entry of speclist. *)
								let undocumented : string = " Undocumented twiddle policy"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								(* Command-line options selecting a twiddle policy, as
								   (flag, action, help text) triples. *)
								let speclist =
								  let entry flag policy = (flag, set_policy policy, undocumented) in
								  [
								    entry "-twiddle-load-all" twiddle_policy_load_all;
								    entry "-twiddle-log2" twiddle_policy_log2;
								    entry "-twiddle-log3" twiddle_policy_log3;
								  ]
							 |