(** Bombs-Must-Detonate: Linker @author Brian Go *)


open Linkst

(** List of provided function names. *)

type provide_environment = string list

(** Association list mapping a function name to its link type. *)

type type_environment = (string * link_type) list

(** Raised for link-time failures; the payload is the human-readable message.
    In this file it is raised through [raise_linker_error], re-raised with a
    longer message by [check_types], and printed by [main]. *)

exception Linker_error of string

(** [concat parts] joins the strings in [parts] with no separator.
    Shorthand for [String.concat ""]. *)
let concat parts = String.concat "" parts

(** [raise_linker_error msg] raises [Linker_error msg]; it never returns
    normally, so it can be used in any expression position. *)
let raise_linker_error msg = raise (Linker_error msg)

(** VM-provided remotable functions, expressed as [LnkProvide] entries so they
    can be checked together with the definitions parsed from source files. *)
let vm_defs =
  (* [vm_fn name args ret] builds the provide entry for one VM function. *)
  let vm_fn name args ret = LnkProvide (name, LnkFunction (args, ret)) in
  [ vm_fn "to_string" [LnkAnything] LnkString;
    vm_fn "parseint" [LnkString] LnkInt;
    vm_fn "parsefloat" [LnkString] LnkFloat;
    vm_fn "array_length" [LnkArray LnkAnything] LnkInt;
    vm_fn "string_length" [LnkString] LnkInt;
    vm_fn "rand" [LnkVoid] LnkFloat;
    vm_fn "print" [LnkAnything] LnkVoid;
    vm_fn "print_error" [LnkAnything] LnkVoid;
    vm_fn "println" [LnkAnything] LnkVoid ]

(** [extract_code lst] returns the string carried by the first [LnkCode]
    element of the linker syntax tree [lst], or the empty string when the
    tree contains no [LnkCode] element. *)
let rec extract_code lst = match lst with
  | [] -> ""
  | LnkCode code :: _ -> code
  | _ :: remaining -> extract_code remaining

(** Returns the singleton list holding the function name when the syntax tree
    element is a [LnkProvide]; returns the empty list for any other element. *)
let extract_provide = function
  | LnkProvide (fn_name, _) -> [fn_name]
  | _ -> []

(** Returns the singleton list holding the (name, type) pair when the syntax
    tree element is a [LnkProvide]; returns the empty list otherwise. *)
let extract_provide_type = function
  | LnkProvide (fn_name, fn_type) -> [(fn_name, fn_type)]
  | _ -> []

(** Returns the singleton list holding the (name, type) pair of a [LnkProvide]
    or [LnkRequire] element; returns the empty list for any other element. *)
let extract_type = function
  | LnkProvide (fn_name, fn_type)
  | LnkRequire (fn_name, fn_type) -> [(fn_name, fn_type)]
  | _ -> []

(** [check_duplicates provide_list] scans the provided function names from
    left to right and raises [Linker_error] for the first name that occurs
    again later in the list. Returns unit when every name is unique. *)
let rec check_duplicates provide_list = match provide_list with
  | [] -> ()
  | fn_name :: later ->
      if List.mem fn_name later then
        raise_linker_error
          (concat ["Function "; fn_name;
                   " is provided by more than one actor/vm."])
      else
        check_duplicates later

(** [check_require provide_list def] checks a single syntax tree element:
    when [def] is a [LnkRequire], its name must appear in [provide_list]
    (only names are compared, not types), otherwise [Linker_error] is raised.
    Elements that are not [LnkRequire] are accepted unconditionally. *)
let check_require provide_list def = match def with
  | LnkRequire (fn_name, _) ->
      if not (List.mem fn_name provide_list) then
        raise_linker_error
          (concat ["Function "; fn_name; " was required but not provided."])
  | _ -> ()

(** [are_assoclists_compatible l1 l2] is [true] when the two association
    lists contain no conflicting key/value pairs: every key of [l1] that is
    also bound in [l2] must carry a value structurally equal to the first
    value bound to that key in [l2]. Keys of [l1] that are absent from [l2]
    never conflict.

    Every entry of [l1] is inspected. The earlier version returned as soon as
    it met the first key shared with [l2], so conflicts further down [l1]
    went unnoticed. *)
let are_assoclists_compatible l1 l2 =
  let agrees (key, value) =
    (* A key missing from [l2] cannot conflict with anything. *)
    not (List.mem_assoc key l2) || value = List.assoc key l2
  in
  List.for_all agrees l1


(** [check_unified t1 t2] checks that link types [t1] and [t2] can be
    reconciled and returns the template bindings found on the way, as an
    association list from template name to the link type it was matched with.

    - Structurally equal types, and anything paired with [LnkAnything],
      yield the empty list.
    - A [LnkTemplate s] on either side yields the single binding of [s] to
      the other type.
    - Arrays, lists and refs recurse on their element type.
    - Functions unify the return types, then the arguments pairwise; the
      result is the return bindings followed by the argument bindings.
    - Structs must have pairwise-unifiable fields and the same name; bindings
      found inside fields are discarded and the empty list is returned.

    @raise Linker_error if the types cannot be reconciled. This now includes
    structs or functions whose field/argument counts differ, which used to
    escape as [Invalid_argument] from [List.map2]. *)
let rec check_unified t1 t2 = match (t1,t2) with
  | x, y when x = y -> []
  | LnkAnything, _ -> []
  | _, LnkAnything -> []
  | LnkTemplate s, other -> [(s, other)]
  | other, LnkTemplate s -> [(s, other)]
    (* Templates are accepted on either side for convenience: a required
       function will never be a templated function, since we don't infer
       that some type is templated; rather we require that the specific
       types in each case where it is called are valid. *)
  | LnkEnumType (name1,size1), LnkEnumType (name2,size2)
  | LnkEnum (name1,size1), LnkEnum (name2,size2) ->
      if (name1 = name2) && (size1 = size2) then
        []
      else
        raise_linker_error
          (concat ["Failed to reconcile enums ";
                   name1;"(";(string_of_int size1);") and ";
                   name2;"(";(string_of_int size2);")."])
  | LnkStruct (name1,fields1), LnkStruct (name2,fields2) ->
      (* Guard the pairwise walk: iterating over lists of different lengths
         would raise Invalid_argument instead of a linker error. *)
      if List.length fields1 <> List.length fields2 then
        raise_linker_error
          (concat ["Failed to reconcile structs ";
                   name1;" and ";name2;": field counts differ."])
      else begin
        List.iter2
          (fun field1 field2 -> ignore (check_unified field1 field2))
          fields1 fields2;
        if name1 = name2 then
          []
        else
          raise_linker_error
            (concat ["Failed to reconcile structs ";
                     name1;" and ";name2;"."])
      end
  | LnkArray t_arr1, LnkArray t_arr2 ->
      check_unified t_arr1 t_arr2
  | LnkList t_lst1, LnkList t_lst2 ->
      check_unified t_lst1 t_lst2
  | LnkFunction (args1,ret1), LnkFunction (args2,ret2) ->
      if List.length args1 <> List.length args2 then
        raise_linker_error
          "Could not reconcile arrow types: argument counts differ."
      else
        let ret_bindings = check_unified ret1 ret2 in
        let arg_bindings =
          List.flatten (List.map2 check_unified args1 args2)
        in
        (* NOTE(review): only return-vs-argument bindings are cross-checked
           here; two arguments binding the same template to different types
           are not compared with each other. Confirm whether that is
           intended. *)
        if are_assoclists_compatible ret_bindings arg_bindings then
          ret_bindings @ arg_bindings
        else
          raise_linker_error "Could not reconcile templated arrow types."
  | LnkRef t_ref1, LnkRef t_ref2 ->
      check_unified t_ref1 t_ref2
  | _ -> raise_linker_error "Irreconciliable types."
        
(** [check_types types provided_types] walks the (name, type) pairs of
    [types] in order. For each name that is also bound in [provided_types],
    the type is unified (via [check_unified]) with the first type bound to
    that name there; names without a provided type are skipped.

    @raise Linker_error with the prefix "Conflicting type definitions for"
    followed by the name and the underlying message when unification fails. *)
let check_types types provided_types =
  let check_one (name, ty) =
    if List.mem_assoc name provided_types then begin
      let other_type = List.assoc name provided_types in
      (* Only the success/failure of unification matters here; the
         resulting template bindings are dropped. *)
      try ignore (check_unified ty other_type)
      with Linker_error s ->
        raise_linker_error
          (concat ["Conflicting type definitions for ";name;": ";s])
    end
  in
  List.iter check_one types
        
(** [check_consistent defs] checks that all provide/require definitions in
    the linker syntax tree agree with each other. In order, it checks that:
    - no function name is provided twice,
    - every required name is provided,
    - every provided or required type unifies with the provided type of the
      same name.

    @raise Linker_error on the first violation found. *)
let check_consistent defs =
  (* Apply an extractor to every definition and gather the results. *)
  let collect extract = List.concat (List.map extract defs) in
  let provided = collect extract_provide in
  let types = collect extract_type in
  let provided_types = collect extract_provide_type in
  check_duplicates provided;
  List.iter (check_require provided) defs;
  check_types types provided_types

(** Command-line entry point.

    Usage, with argument positions as read from [Sys.argv]:
    - [linker src-file-1 output-file-1 src-file-2 output-file-2] links both
      source files (together with the VM definitions) and writes the code of
      each one, linker header removed, to its output file. At most two source
      files are accepted because this is the max number of AIs on a team.
    - [linker src-file output-file] links a single source file against the
      VM definitions and writes its code to [output-file].

    With fewer than two arguments a usage line is printed. With exactly
    three arguments the third one is ignored, as before.

    All failures are reported on standard output and [main] returns unit;
    the process exit status is not changed by this function. *)
let main () =
  try
    print_string "Linking...\n";
    flush stdout;
    if Array.length Sys.argv < 3 then begin
      (* Without this guard the Sys.argv accesses below go out of bounds and
         the user only sees the generic failure message. *)
      print_string
        (concat ["Usage: "; Sys.argv.(0);
                 " src-file-1 output-file-1 [src-file-2 output-file-2]\n"]);
      flush stdout
    end else begin
      let n_files =
        if Array.length Sys.argv >= 5 then 2 else 1
      in
      let filename_list, outfile_names =
        if n_files = 2 then
          [Sys.argv.(1); Sys.argv.(3)], [Sys.argv.(2); Sys.argv.(4)]
        else
          [Sys.argv.(1)], [Sys.argv.(2)]
      in
      let file_list = List.map open_in filename_list in
      let close_inputs () = List.iter close_in_noerr file_list in

      (* Parse files; the input channels are closed whether or not parsing
         succeeds. *)
      let linkst_list =
        try
          let parsed =
            List.map
              (fun chan ->
                 Bmdlinkparse.link Bmdlinklex.bmd_link
                   (Lexing.from_channel chan))
              file_list
          in
          close_inputs ();
          parsed
        with e ->
          close_inputs ();
          raise e
      in

      (* Extract code *)
      let code_list = List.map extract_code linkst_list in

      (* Check for consistency *)
      let defs = List.flatten (vm_defs::linkst_list) in
      check_consistent defs;

      (* Output code. Output-side failures are turned into linker errors so
         they are not reported as input-file problems by the handler below. *)
      let outfile_list =
        try List.map open_out outfile_names
        with Sys_error s ->
          raise_linker_error (concat ["Could not open output file "; s])
      in
      (try
         List.iter2 output_string outfile_list code_list;
         (* Closing flushes the buffers, so write errors surface here rather
            than being ignored when the runtime flushes at exit. *)
         List.iter close_out outfile_list
       with Sys_error s ->
         List.iter close_out_noerr outfile_list;
         raise_linker_error (concat ["Could not write output file "; s]));
      print_string "Done.\n";
      flush stdout
    end
  with
      Parsing.Parse_error ->
        begin
          print_string "Syntax error on line ";
          print_string (string_of_int !Bmdlinklex.line_count);
          print_string ".\n";
          flush stdout;
        end
    | Sys_error s ->
        begin
          print_string "Could not open input file ";
          print_string s;
          print_string "\n";
          flush stdout;
        end
    | Linker_error s ->
        begin
          print_string s;
          print_string "\n";
          flush stdout;
        end
    | e ->
        begin
          print_string "Linking failed.\n";
          (* Show what went wrong instead of dropping the exception. *)
          print_string (Printexc.to_string e);
          print_string "\n";
          flush stdout;
        end

(* Run the linker when the program starts. *)
let () = main ()