import time
import ctypes
from llvmlite import ir, binding


def build_and_compile():
    """Build, JIT-compile and wrap a native ``sum_of_squares`` function.

    The generated function has the signature ``double sum_of_squares(i64 n)``.
    It accumulates ``i * i`` (as doubles) for ``i = 0, 1, ...`` and stops once
    ``i + 1 < n`` is false under a signed comparison. The bound is tested
    *after* the loop body, so the body always executes at least once; for
    ``n <= 1`` that single iteration adds ``0 * 0`` and the result is ``0.0``.

    Returns:
        A ``(cfunc, engine)`` tuple. ``cfunc`` is a ctypes callable taking a
        ``c_int64`` and returning a ``c_double``. ``engine`` is the MCJIT
        execution engine the function address was obtained from; it is
        returned so the caller can hold a reference to it (presumably the
        compiled code is only valid while the engine exists -- confirm
        against the llvmlite documentation).
    """
    # Newer llvmlite builds need the codegen targets registered explicitly.
    binding.initialize_all_targets()
    binding.initialize_all_asmprinters()

    host_triple = binding.get_process_triple()
    mod = ir.Module(name="benchmark")
    mod.triple = host_triple

    # The two LLVM types used throughout the function.
    i64 = ir.IntType(64)
    f64 = ir.DoubleType()

    func = ir.Function(mod, ir.FunctionType(f64, [i64]), name="sum_of_squares")
    (limit,) = func.args

    entry_bb = func.append_basic_block("entry")
    loop_bb = func.append_basic_block("loop")
    exit_bb = func.append_basic_block("exit")

    # entry: jump straight into the loop; there is no pre-check of the bound.
    b = ir.IRBuilder(entry_bb)
    b.branch(loop_bb)

    # loop: phi nodes carry the counter and the running sum across iterations.
    b.position_at_end(loop_bb)
    counter = b.phi(i64, name="i")
    acc = b.phi(f64, name="total")

    counter_fp = b.sitofp(counter, f64)
    squared = b.fmul(counter_fp, counter_fp)
    next_acc = b.fadd(acc, squared)
    next_counter = b.add(counter, ir.Constant(i64, 1))

    # Each phi starts at zero when arriving from entry and takes the updated
    # value when arriving from the loop's own back edge.
    counter.add_incoming(ir.Constant(i64, 0), entry_bb)
    counter.add_incoming(next_counter, loop_bb)
    acc.add_incoming(ir.Constant(f64, 0.0), entry_bb)
    acc.add_incoming(next_acc, loop_bb)

    keep_going = b.icmp_signed("<", next_counter, limit)
    b.cbranch(keep_going, loop_bb, exit_bb)

    # exit: return the sum produced by the final iteration.
    b.position_at_end(exit_bb)
    b.ret(next_acc)

    # Round-trip the textual IR through the LLVM binding layer and verify it.
    parsed = binding.parse_assembly(str(mod))
    parsed.verify()

    # opt=3 is handed to the target machine; no separate IR pass pipeline is
    # run in this function.
    host_target = binding.Target.from_triple(host_triple)
    machine = host_target.create_target_machine(opt=3)

    engine = binding.create_mcjit_compiler(parsed, machine)
    engine.finalize_object()
    engine.run_static_constructors()

    # Wrap the raw function address in a ctypes prototype matching the IR
    # signature: double (int64).
    prototype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_int64)
    native = prototype(engine.get_function_address("sum_of_squares"))
    return native, engine


# Loop bound passed to the JIT-compiled function.
N = 10_000_000

# The engine is bound to a module-level name alongside the callable
# (presumably so it is not garbage-collected while cfunc is still in use).
cfunc, _engine = build_and_compile()

# Time a single call of the compiled function; the build/compile step above
# is outside the measured interval.
start = time.perf_counter()
result = cfunc(N)
elapsed = time.perf_counter() - start

# One print call with a newline separator emits the same four report lines.
print(
    "LLVM IR JIT (opt=3)",
    f"  N       = {N:,}",
    f"  Result  = {result:.6e}",
    f"  Time    = {elapsed:.6f}s",
    sep="\n",
)