(** Bombs-Must-Detonate: IR Compiler @author: Brian Go*)


(** Final instruction *)

type instruction =
(** 0-arity *)

    Push
  | Pop
  | Read
  | Print
  | PrintLn
  | Swap
  | Stop
  | Return
  | Cons
  | Car
  | Cdr
  | IsNull
  | Nil
  | GetElem
  | SetElem
  | And
  | Or
  | Add
  | Sub
  | Mul
  | Div
  | DivI
  | RemI
  | Concat
  | Neg
  | Frac
  | Int
  | Lt
  | Gt
  | Lte
  | Gte
  | Eq
  | Neq
  | ConstUninit
  | Apply
(** 1-arity *)

  | ConstInt of int
  | ConstFloat of float
  | ConstString of string
  | PushSf of string
  | Assign of int
  | Acc of int
  | Rev of int
  | MakeBlock of int
  | MakeBlockFilled of int
  | AllocFields of int
  | GetField of int
  | SetField of int
  | Call of string
  | Jmp of string
  | Jz of string
  | Jnz of string  | RPC of string
  | Label of string
  | Comment of string
  | Annotation of string
  | GetElemStatic of int
  | SetElemStatic of int
(** 2-arity *)

  | MakeBlockStatic of int * int

(** IR Instruction *)

type ir_instruction =
    Instruction of instruction
(** Ir-Specific *)

  | GetVar of string
  | SetVar of string
  | NoteVar of string * int
  | NoteFunction of string
  | If of ir_buffer * ir_buffer * ir_buffer 
          (** test, then, else *)

  | For of ir_buffer * ir_buffer * ir_buffer * ir_buffer 
  | While of ir_buffer * ir_buffer
  | DoWhile of ir_buffer * ir_buffer
  | Break
  | Continue
  | BeginScope
  | EndScope
  | EndScopeSf
  | EndScopeSfRPC
  | DeclareGlobal of string

and ir_buffer = ir_instruction list

(** Shorthand for String.concat " *)

let concat s_list = String.concat "" s_list

exception Ircompile_error of string

(** Keeps track of the buffer of code used to initialize structs/arrays/lists. We need to initialize values of these types as null but of the correct structure. i.e. someStructType someStruct; stomeStruct.someField = someValue; should be valid. *)

let init_buffer = ref []

(** Raises an IR compile error by concatenating s_list *)

let raise_compile_error s_list = 
  raise (Ircompile_error (concat s_list))

(** Concatenates n repetitions of the given list *)

let rec repeat_list lst n = 
  if n = 0 then []
  else if n = 1 then lst
  else if n > 1 then lst @ (repeat_list lst (n-1))
  else raise (Invalid_argument "repeat_list must have a nonnegative argument")

(** Used by make_label *)

let label_id = ref (-1)

(** Generate labels with unique prefixes *)

let make_label str = 
  let _ = label_id := !label_id + 1 in
    concat ["_";string_of_int !label_id;"_";str]

(** Takes an argument list and notes the positions *)

let rec note_arg_list lst index = match lst with
    (_,varname)::rest -> (NoteVar (varname,index))::(note_arg_list rest (index+1))
  | [] -> []

(** Compiles the initialization code of data type dt *)

let rec compile_datatype_init dt = match dt with
    (Ast.SynIntType
    | Ast.SynFloatType
    | Ast.SynStringType
    | Ast.SynBoolType
    | Ast.SynVoidType
    | Ast.SynRefType _) -> [Instruction ConstUninit]
  | Ast.SynEnumOrStructType (name,field_types) -> 
      (** field_types is empty if its an enum/template *)

      if List.length !field_types > 0 then
        let n_fields = List.length !field_types in
        let fields_result = compile_datatypes_pushing !field_types in
          fields_result
          @ [Instruction (Rev n_fields)]
          @ [Instruction (MakeBlockStatic (1,n_fields))]
      else
        [Instruction ConstUninit]
  | Ast.SynArrayType (t_arr,vprod) ->
      let dt_code = compile_datatype_init t_arr in
      let vprod_code = compile_value_producer vprod in
        dt_code 
        @ [Instruction Push]
        @ vprod_code
        @ [Instruction (MakeBlockFilled 0)]
  | Ast.SynListType _ -> 
      [Instruction Nil]
  | Ast.SynArrowType (_,_) -> [Instruction (ConstString "")]

(** Compiles each data type in dt_list, inserting push instructions between each compilation result *)

and compile_datatypes_pushing dt_list = match dt_list with
    dt::rest -> (compile_datatype_init dt)@[Instruction Push]@(compile_datatypes_pushing rest)
  | [] -> []

(** Compiles an expression *)

and compile_expression expr = match expr with
    Ast.SynVarDeclare vdecl -> compile_variable_declaration vdecl
  | Ast.SynEnumDeclare edecl -> []
  | Ast.SynVarAssign vasgn -> compile_variable_assignment vasgn
  | Ast.SynCond cond -> compile_conditional cond
  | Ast.SynLoop loop -> compile_loop loop
  | Ast.SynFunctionCall funcall -> compile_function_call funcall
  | Ast.SynReturnStatement ret -> compile_return ret
  | Ast.SynBreak -> [Break]
  | Ast.SynContinue -> [Continue]

(** Compiles a list of expressions *)

and compile_expression_list expr_list = match expr_list with
    cur::rest -> 
      let curbuf = compile_expression cur in
      let restbuf = compile_expression_list rest in
        curbuf@restbuf
  | [] -> []


(** Compiles a variable declarations *)

and compile_variable_declaration vdecl = match vdecl with
    Ast.SynVarDeclareNoInit (dt, varname) -> 
      let dt_init = compile_datatype_init dt in
        dt_init @ [SetVar varname]
  | Ast.SynVarDeclareWithInit(_, varname, vprod) -> 
      let valbuf = compile_value_producer vprod in
        valbuf@[SetVar varname]

(** Compiles a conditional. Would be nice to add short-circuiting in the future. *)

and compile_conditional cond = match cond with
    Ast.SynIf (vp,expr_list) -> 
      let vp_result = compile_value_producer vp in
      let expr_result = compile_expression_list expr_list in
        [If (vp_result, expr_result, [])]
  | Ast.SynIfCase (vp,expr_list,continued_cond) ->
      let vp_result = compile_value_producer vp in
      let expr_result = compile_expression_list expr_list in
      let continued_result = compile_continued_conditional continued_cond in
        [If (vp_result, expr_result, continued_result)]

(** Compiles an else case *)

and compile_continued_conditional cond = match cond with
    Ast.SynFinalElse expr_list -> compile_expression_list expr_list
  | Ast.SynElse cond -> compile_conditional cond

(** Compiles a loop *)

and compile_loop loop = match loop with
    Ast.SynWhile (vprod,lexpr_list) -> 
      let vp_result = compile_value_producer vprod in
      let lexpr_result = compile_expression_list lexpr_list in
        [While (vp_result,lexpr_result)]
  | Ast.SynFor (init,test,step,body) -> 
      let init_result = compile_expression init in
      let test_result = compile_value_producer test in 
      let step_result = compile_expression step in
      let body_result = compile_expression_list body in
        [For (init_result, test_result, step_result, body_result)]
  | Ast.SynDoWhile (lexpr_list, vprod) ->
      let vp_result = compile_value_producer vprod in
      let lexpr_result  = compile_expression_list lexpr_list in
        [While (lexpr_result,vp_result)]

(** Compiles a value *)

and compile_value value = match value with 
    Ast.SynIntValue i -> [Instruction (ConstInt i)]
  | Ast.SynFloatValue f -> [Instruction (ConstFloat f)]
  | Ast.SynBoolValue b -> [Instruction (ConstInt (if b then 1 else 0))]
  | Ast.SynStringValue s -> [Instruction (ConstString s)]

(** Compiles a value produder *)

and compile_value_producer vprod = match vprod with 
    Ast.SynValue value -> compile_value value
  | Ast.SynFunctionCallValue funcall -> compile_function_call funcall
  | Ast.SynVarIdentifier vident -> compile_get_variable_identifier vident
  | Ast.SynBinop (vp1,op,vp2) -> compile_binary_operation vp1 op vp2
  | Ast.SynPrefixUnop (op,vp) -> compile_prefix_unary_operation op vp
  | Ast.SynParenthesized vprod -> compile_value_producer vprod
  | Ast.SynArrayValueProducer vprod_list ->
      let n = List.length vprod_list in
      let values_result = compile_value_producer_list_pushing vprod_list in
        values_result 
        @ [Instruction (Rev n);
           Instruction (MakeBlockStatic (0,n))]
  | Ast.SynListValueProducer lvprod -> compile_list_value_producer lvprod

(** Compiles each value producer in order, pushing the results onto the stack *)

and compile_value_producer_list_pushing vp_list = match vp_list with
    cur::rest -> 
      let cur_result = compile_value_producer cur in
      let rest_result= compile_value_producer_list_pushing rest in
        cur_result
        @ [Instruction Push]
        @ rest_result
  | [] -> []

(** Compiles a list value producer *)

and compile_list_value_producer lvprod = match lvprod with
    Ast.SynListNil _ -> [Instruction Nil]
  | Ast.SynListList vprod_list -> 
      let result = compile_value_producer_list_pushing vprod_list in
        result
        @ [Instruction Nil]
        @ (repeat_list [Instruction Cons] (List.length vprod_list))
  | Ast.SynListCons (vp1,vp2) ->
      let result_car = compile_value_producer vp1 in
      let result_cdr = compile_value_producer vp2 in
        result_car
        @ [Instruction Push]
        @ result_cdr
        @ [Instruction Cons]

(** Compiles a variable assignment *)

and compile_variable_assignment vasgn = match vasgn with 
    Ast.SynVarAssignment (vident,vprod) ->
      let vp_result = compile_value_producer vprod in
      let setvar_result = compile_set_variable_identifier vident in
        vp_result @ setvar_result
  | Ast.SynVarModify (vident, op, vprod) ->
      let vp1, vp2 = Ast.SynVarIdentifier vident, vprod
      in
      let op_result = compile_value_producer 
        (Ast.SynBinop (vp1, op, vp2)) in
      let setvar_result = compile_set_variable_identifier vident in
        op_result @ setvar_result

(** Produces code that sets the variable identifier to the value in the accumulator. Leaves the value in the accumulator *)

and compile_set_variable_identifier vident =
  match vident with
      Ast.SynVarName s -> [SetVar s]
    | Ast.SynStructOrEnumValue (vid1,s2,ftype) ->
        (match !ftype with
             Ast.VIdStructField index ->
               let getstruct = compile_get_variable_identifier vid1 in
                 [Instruction Push]
                 @ getstruct
                 @ [Instruction (SetElemStatic index)] 
                     (** structs are reference types; we're done at this point *)

           | Ast.VIdEnumValue _ ->
               raise_compile_error ["Unexpected compile error: cannot set an enum value."]
           | Ast.VIdUnspecified ->
               raise_compile_error ["Unexpected compile error: field index ";s2;" was not determined by type checker."])
    | Ast.SynArrayCell (vid,vprod) ->
        let vid_result = compile_get_variable_identifier vid in
        let vprod_result = compile_value_producer vprod in
          [Instruction Push]
          @ vprod_result
          @ [Instruction Push]
          @ vid_result
          @ [Instruction SetElem]

(** Get a variable identifier into the accumulator *)

and compile_get_variable_identifier vident =
  match vident with
      Ast.SynVarName s -> [GetVar s]
    | Ast.SynStructOrEnumValue (vid1,s2,ftype) ->  
        (match !ftype with
             Ast.VIdStructField index ->
               let getstruct = compile_get_variable_identifier vid1 in
                 getstruct @ [Instruction (GetElemStatic index)]
           | Ast.VIdEnumValue index ->
               [Instruction (ConstInt index)]
           | Ast.VIdUnspecified -> 
               raise_compile_error ["Unexpected compile error: field index ";s2;" was not determined by type checker."])
    | Ast.SynArrayCell (vid,vprod) ->
        let vid_result = compile_get_variable_identifier vid in
        let vprod_result = compile_value_producer vprod in
          vprod_result 
          @ [Instruction Push]
          @ vid_result
          @ [Instruction GetElem]

(** Compiles a binary operation expression *)

and compile_binary_operation vp1 op vp2 = 
  let vp1_result = compile_value_producer vp1 in
  let vp2_result = compile_value_producer vp2 in
  let op_command = match op with
      Ast.SynBinopAnd -> [Instruction And]
    | Ast.SynBinopOr -> [Instruction Or]
    | Ast.SynBinopAdd -> [Instruction Add]
    | Ast.SynBinopSub -> [Instruction Sub]
    | Ast.SynBinopMul -> [Instruction Mul]
    | Ast.SynBinopDiv -> [Instruction Div]
    | Ast.SynBinopIDiv ->[Instruction DivI]
    | Ast.SynBinopMod -> [Instruction RemI]
    | Ast.SynBinopConcat -> [Instruction Concat]
    | Ast.SynBinopCons -> [Instruction Cons]
    | Ast.SynCompLt -> [Instruction Lt]
    | Ast.SynCompGt -> [Instruction Gt]
    | Ast.SynCompLte ->[Instruction Lte]
    | Ast.SynCompGte ->[Instruction Gte]
    | Ast.SynCompEq -> [Instruction Eq]
    | Ast.SynCompNeq ->[Instruction Neq]
  in
    vp2_result 
    @ [Instruction Push]
    @ vp1_result
    @ op_command
                         
(** Compiles unary operations *)
   
and compile_prefix_unary_operation op vp =
    let vp_result = compile_value_producer vp in
    let op_command = match op with 
        Ast.SynUnopNot -> [Instruction Push;
                           Instruction (ConstInt 1);
                           Instruction Add;
                           Instruction Push;
                           Instruction (ConstInt 2);
                           Instruction Swap;
                           Instruction RemI]
      | Ast.SynUnopCar -> [Instruction Car]
      | Ast.SynUnopCdr -> [Instruction Cdr]
      | Ast.SynUnopTrunc -> [Instruction Int]
      | Ast.SynUnopDeref -> [Instruction (GetElemStatic 0)]
      | Ast.SynUnopNeg -> [Instruction Neg]
      | Ast.SynUnopNull -> [Instruction IsNull]
    in
      vp_result @ op_command

(** Compiles a function call. Doesn't use tail-recursion -- this would be nice to add in the future. *)

and compile_function_call funcall = match funcall with
    Ast.SynLocalCall (fname,vprod_list,isvar) -> 
      let n_args = List.length vprod_list in
      let args_result = compile_value_producer_list_pushing vprod_list in
      let ret_label = make_label (concat ["returnfrom_local_";fname]) in
      let call_instruction = 
        if !isvar then 
          [GetVar fname;
           Instruction Apply]
        else
          [Instruction (Call fname)]
      in
        [Instruction (Comment (concat ["Local call to ";fname]));
         Instruction (PushSf ret_label);
         BeginScope]
        @ args_result
        @ [Instruction (Rev n_args)]
        @ call_instruction
        @ [EndScopeSf;
           Instruction (Label ret_label)]
  | Ast.SynRemoteCall (fname,vprod_list,ret_var) ->
      let n_args = List.length vprod_list in
      let args_result = compile_value_producer_list_pushing vprod_list in
      let ret_label = make_label (concat ["returnfrom_remote_withresult_";fname]) in
      let done_label = make_label (concat ["done_interpreting_remote_results_";fname]) in
      let vident_set_result = compile_set_variable_identifier ret_var in
        [Instruction (Comment (concat ["Remote call with result to ";fname]));
         Instruction (PushSf ret_label);
         BeginScope]
        @ args_result
        @ [Instruction (Rev n_args);
           Instruction (RPC fname);
           EndScopeSfRPC;
           Instruction (Label ret_label); 
              (** Acc || Stack = success || result *)

           Instruction (Jz done_label);
           Instruction Swap
             (** result || success *)

        @ vident_set_result 
          (** ? || success *)

        @ [Instruction Swap
           (** success || ? *)

           Instruction (Label done_label); 
              (** success || ? *)

           Instruction Swap
             (** ? || success *)

           Instruction (Acc 0); 
              (** success || success *)

           Instruction Pop
             (** success || *)

           
  | Ast.SynRemoteCallNoResult (fname, vprod_list) ->
      let n_args = List.length vprod_list in
      let args_result = compile_value_producer_list_pushing vprod_list in
      let ret_label = make_label (concat ["returnfrom_local_";fname]) in
        [Instruction (Comment (concat ["Remote call with no result to ";fname]));
         Instruction (PushSf ret_label);
         BeginScope]
        @ args_result
        @ [Instruction (Rev n_args);
           Instruction (RPC fname);
           EndScopeSfRPC;
           Instruction (Label ret_label); 
              (** success || result *)

           Instruction Swap
             (** result || success *)

           Instruction (Acc 0); 
              (** success || success *)

           Instruction Pop
             (** success || *)

              
(** Compiles a return statement *)

and compile_return ret = match ret with
    Ast.SynVoidReturn ->
      [Instruction ConstUninitInstruction Return
  | Ast.SynValueReturn vprod ->
      let vp_result = compile_value_producer vprod in
        vp_result@[Instruction Return]

(** Compiles a function definition. Maybe this should have gone to IR first, but it works nicely as is. *)

let rec compile_function_definition name arg_list expr_list = 
  let label = make_label name in
  let args = note_arg_list arg_list 0 in
  let body_code = compile_expression_list expr_list in
    [NoteFunction name;
     Instruction (Annotation (concat ["function ";name;" ";string_of_int (List.length arg_list)])); 
           (** annotate *)

     Instruction (Label label);
     BeginScope]
    @ args
    @ body_code
    @ [Instruction (Comment (concat ["end function ";name])); 
       Instruction ConstUninit
       Instruction Return;
       EndScope]

(** Compile global expression *)

let rec compile_global_expression expr = match expr with
    Ast.SynFunctionDeclare (dt, name, arglist) -> [NoteFunction name]
  | Ast.SynRemotableFunctionDeclare (dt, name, arglist) -> []
  | Ast.SynFunctionDefine ((dt, name, arglist), expr_list) -> 
        compile_function_definition name arglist expr_list
  | Ast.SynTemplatedDeclare (_,expr) -> compile_global_expression expr
  | Ast.SynTemplatedDefine (_,expr) -> compile_global_expression expr
  | Ast.SynStructDeclare sdecl -> []
  | Ast.SynGlobalEnumDeclare edecl -> []
  | Ast.SynGlobalVarDeclare vdecl -> 
      (match vdecl with
           Ast.SynVarDeclareNoInit (dt,v) -> 
             let dt_result = compile_datatype_init dt in
             let _ = init_buffer := !init_buffer @ dt_result @ [SetVar v] in
               [DeclareGlobal v]
         | _ -> raise_compile_error ["Unexpected compile error: cannot initialize a variable in the global scope."])                          
  | Ast.SynStateMachine _ ->
      raise_compile_error ["Unexpected compile error: source-to-source compilation of state machine failed."]
  | Ast.SynInclude data -> 
      match !data with
          Ast.IncludeAst ast -> compile_global_expression_list ast
        | Ast.IncludeFileName s -> 
            raise_compile_error ["Unexpected compile error: include ";s;" was not typechecked properly."]
            
(** Compile a global expression list *)

and compile_global_expression_list src_prog = 
  match src_prog with
      cur::rest -> 
        let curbuf = compile_global_expression cur in
        let restbuf = compile_global_expression_list rest in
          curbuf@restbuf
    | [] -> []

(** Compile a source program to IR representation, including the init buffer *)

and compile_ir src_prog =
  try
    let _ = init_buffer := [] in
    let init_label = make_label "INIT" in
    let program_result = compile_global_expression_list src_prog in
      [Instruction (Annotation (concat ["init ";init_label]))]
      @ program_result
      @ [Instruction (Annotation (concat ["function ";init_label;" 0"]));
         Instruction (Label (init_label))]
      @ !init_buffer
      @ [Instruction Return]
  with
      Ircompile_error s ->
        let _ = output_string stderr s in
        let _ = output_string stderr "\n" in
        let _ = flush stderr in
          raise (Ircompile_error s)

(** Produces a string representation of the given instruction *)

let string_of_instruction inst = match inst with
    Push -> "Push"
  | Pop -> "Pop"
  | Read -> "Read"
  | Print -> "Print"
  | PrintLn -> "PrintLn"
  | Swap -> "Swap"
  | Stop -> "Stop"
  | Return -> "Return"
  | Cons -> "Cons"
  | Car -> "Car"
  | Cdr -> "Cdr"
  | IsNull -> "IsNull"
  | Nil -> "Nil"
  | GetElem -> "GetElem"
  | SetElem -> "SetElem"
  | And -> "And"
  | Or -> "Or"
  | Add -> "Add"
  | Sub -> "Sub"
  | Mul -> "Mul"
  | Div -> "Div"
  | DivI -> "DivI"
  | RemI -> "RemI"
  | Concat -> "Concat"
  | Neg -> "Neg"
  | Frac -> "Frac"
  | Int -> "Int"
  | Lt -> "Lt"
  | Gt -> "Gt"
  | Lte -> "Lte"
  | Gte -> "Gte"
  | Eq -> "Eq"
  | Neq -> "Neq"
  | ConstUninit -> "Const"
  | ConstInt i -> concat ["Const ";string_of_int i]
  | ConstFloat f -> concat ["Const ";string_of_float f]
  | ConstString s -> concat ["Const \"";s;"\""]
  | PushSf s -> concat ["PushSf ";s]
  | Assign i -> concat ["Assign ";string_of_int i] 
  | Acc i -> concat ["Acc ";string_of_int i]
  | Rev i -> concat ["Rev ";string_of_int i]
  | MakeBlock i -> concat ["MakeBlock ";string_of_int i]
  | AllocFields i -> concat ["-allocfields, ";string_of_int i]
  | GetField i -> concat ["GetField ";string_of_int i]
  | SetField i -> concat ["SetField ";string_of_int i]
  | Call s -> concat ["Call ";s]
  | Jmp s -> concat ["Jmp ";s]
  | Jz s -> concat ["Jz ";s]
  | Jnz s -> concat ["Jnz ";s]
  | RPC s -> concat ["RPC ";s]
  | Label s -> concat [s;": "]
  | Comment s -> concat ["//";s]
  | Annotation s -> concat ["-";s]
  | GetElemStatic i -> concat ["GetElem ";string_of_int i]
  | SetElemStatic i -> concat ["SetElem ";string_of_int i]
  | MakeBlockStatic (i,j) -> concat ["MakeBlockStatic ";string_of_int i;" ";string_of_int j]
  | MakeBlockFilled n -> concat ["MakeFilledBlock ";string_of_int n]
  | Apply -> "Apply"

(** Produces a string representation of the given IR instruction *)

let rec string_of_ir_instruction inst = match inst with
    Instruction i -> string_of_instruction i
  | GetVar s -> concat ["GetVar ";s]
  | SetVar s -> concat ["SetVar ";s]
  | NoteVar (s,i) -> concat ["NoteVar ";s;" ";string_of_int i]
  | If (test,body,elsebody) -> concat 
      ["IF\n";
       string_of_ir_buffer test;
       "\nTHEN\n";
       string_of_ir_buffer body;
       "\nELSE\n";
       string_of_ir_buffer elsebody;
       "\nENDIF\n"]
  | For (init,test,step,body) -> concat
      ["FOR\nINIT\n";
       string_of_ir_buffer init;
       "\nTEST IF TRUE\n";
       string_of_ir_buffer test;
       "\nSTEP\n";
       string_of_ir_buffer step;
       "\nDO\n";
       string_of_ir_buffer body;
       "\nENDFOR\n"]
  | While (test,body) -> concat
      ["WHILE\nTEST IF TRUE\n";
       string_of_ir_buffer test;
       "\nDO\n";
       string_of_ir_buffer body;
       "\nENDWHILE\n"]
  | DoWhile (body,test) -> concat
      ["WHILE\nTEST IF TRUE\n";
       string_of_ir_buffer test;
       "\nDO\n";
       string_of_ir_buffer body;
       "\nENDWHILE\n"]
  | Break -> "Break"
  | Continue -> "Continue"
  | DeclareGlobal s -> concat ["DeclareGlobal ";s]
  | EndScope -> "EndScope"
  | EndScopeSf -> "EndScopeSf"
  | BeginScope -> "BeginScope"
  | EndScopeSfRPC -> "EndScopeSfRPC"
  | NoteFunction s -> concat ["NoteFunction ";s]
  

(** Produces a string representation of the given IR instruction buffer *)

and string_of_ir_buffer buf = match buf with
    cur::rest -> concat [(string_of_ir_instruction cur);"\n";(string_of_ir_buffer rest)]
  | [] -> ""

(** Prints the IR buffer to the given output channel *)

and print_ir_buffer chan buf =
  output_string chan (string_of_ir_buffer buf)