C++程序  |  1932行  |  91.16 KB

/*
 * Copyright 2018 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#ifndef SKSL_STANDALONE

#ifdef SK_LLVM_AVAILABLE

#include "SkSLJIT.h"

#include "SkCpu.h"
#include "SkRasterPipeline.h"
#include "ir/SkSLAppendStage.h"
#include "ir/SkSLExpressionStatement.h"
#include "ir/SkSLFunctionCall.h"
#include "ir/SkSLFunctionReference.h"
#include "ir/SkSLIndexExpression.h"
#include "ir/SkSLProgram.h"
#include "ir/SkSLUnresolvedFunction.h"
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"

// Upper bound on SIMD lane count; presumably sizes per-lane storage elsewhere in this file
// (the constructor picks 4 or 8 lanes) — confirm against its users.
static constexpr int MAX_VECTOR_COUNT{16};

extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
    // JIT-compiled code passes the stage as a plain int; turn it back into the enum.
    const auto stockStage = static_cast<SkRasterPipeline::StockStage>(stage);
    p->append(stockStage, ctx);
}

// Size in bytes of a host pointer.
#define PTR_SIZE (sizeof(void*))

extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
    // Callback stages are appended without any context pointer.
    void* ctx = nullptr;
    p->append(fn, ctx);
}

extern "C" void sksl_debug_print(float f) {
    // Backs SkSL's print(float) builtin. Variadic calls promote float to double anyway,
    // so widening explicitly here does not change what %f sees.
    const double value = f;
    printf("Debug: %f\n", value);
}

extern "C" float sksl_clamp1(float f, float min, float max) {
    // Scalar clamp called from JIT-compiled code for SkSL's clamp(float, float, float).
    const float pinned = SkTPin(f, min, max);
    return pinned;
}

// Host-side SIMD vector types mirroring SkSL's float2/float3/float4, used as the argument and
// return types of the sksl_clampN helpers below (presumably to match how LLVM passes
// <N x float> values — confirm against the target ABI).
using float2 = __attribute__((vector_size(8))) float;
// vector_size must be a power of two, so float3 occupies four lanes (16 bytes); the fourth
// lane is padding.
using float3 = __attribute__((vector_size(16))) float;
using float4 = __attribute__((vector_size(16))) float;

extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
    // Pin each lane independently to [min, max].
    const float x = SkTPin(f[0], min, max);
    const float y = SkTPin(f[1], min, max);
    return float2 { x, y };
}

extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
    // Start from all-zero lanes so the padding (fourth) lane is 0, exactly as a
    // three-element initializer would leave it; then pin the three real lanes.
    float3 result = {};
    for (int lane = 0; lane < 3; ++lane) {
        result[lane] = SkTPin(f[lane], min, max);
    }
    return result;
}

extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
    // Pin every one of the four lanes to [min, max].
    float4 result = {};
    for (int lane = 0; lane < 4; ++lane) {
        result[lane] = SkTPin(f[lane], min, max);
    }
    return result;
}

namespace SkSL {

// Number of parameters of a pipeline stage function; the meaning of each is defined where
// stages are generated (not visible here).
static constexpr int STAGE_PARAM_COUNT{12};

static bool ends_with_branch(const Statement& stmt) {
    switch (stmt.fKind) {
        case Statement::kBlock_Kind: {
            const Block& b = (const Block&) stmt;
            if (b.fStatements.size()) {
                return ends_with_branch(*b.fStatements.back());
            }
            return false;
        }
        case Statement::kBreak_Kind:    // fall through
        case Statement::kContinue_Kind: // fall through
        case Statement::kReturn_Kind:   // fall through
            return true;
        default:
            return false;
    }
}

// Initializes LLVM's native target / MCJIT support, chooses the SIMD configuration from the
// host CPU, and creates the LLVM context together with every LLVM type the JIT uses.
JIT::JIT(Compiler* compiler)
: fCompiler(*compiler) {
    LLVMInitializeNativeTarget();
    LLVMInitializeNativeAsmPrinter();
    LLVMLinkInMCJIT();
    SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported

    // AVX-class hosts get a named target CPU and 8 lanes; anything else uses the default CPU
    // (nullptr) and 4 lanes.
    if (SkCpu::Supports(SkCpu::HSW)) {
        fCPU = "haswell";
    } else if (SkCpu::Supports(SkCpu::AVX)) {
        fCPU = "ivybridge";
    } else {
        fCPU = nullptr;
    }
    fVectorCount = fCPU ? 8 : 4;

    fContext = LLVMContextCreate();
    fVoidType = LLVMVoidTypeInContext(fContext);

    // Scalar types.
    fInt1Type = LLVMInt1TypeInContext(fContext);
    fInt8Type = LLVMInt8TypeInContext(fContext);
    fInt32Type = LLVMInt32TypeInContext(fContext);
    fInt64Type = LLVMInt64TypeInContext(fContext);
    fSizeTType = LLVMInt64TypeInContext(fContext);
    fFloat32Type = LLVMFloatTypeInContext(fContext);
    fInt8PtrType = LLVMPointerType(fInt8Type, 0);

    // Vector types: one with fVectorCount lanes per scalar kind (presumably the per-pixel SIMD
    // width), plus fixed 2/3/4-lane vectors used for SkSL vector types.
    fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
    fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
    fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
    fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);

    fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
    fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
    fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
    fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);

    fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
    fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
    fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
    fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
}

// Tears down the JIT: first the ORC JIT stack, then the LLVM context. Keep this order — the
// stack presumably still references modules/types created in fContext (confirm before
// reordering).
JIT::~JIT() {
    LLVMOrcDisposeInstance(fJITStack);
    LLVMContextDispose(fContext);
}

void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
                             std::vector<LLVMTypeRef> parameters) {
    bool found = false;
    for (const auto& pair : *fProgram->fSymbols) {
        if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
            const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
            if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
                parameters.size() != f.fParameters.size()) {
                continue;
            }
            for (size_t i = 0; i < parameters.size(); ++i) {
                if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
                    goto next;
                }
            }
            fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
                                                                                 parameters.data(),
                                                                                 parameters.size(),
                                                                                 false));
            found = true;
        }
        if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
            // FIXME consolidate this with the code above
            for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
                if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
                    parameters.size() != f->fParameters.size()) {
                    continue;
                }
                for (size_t i = 0; i < parameters.size(); ++i) {
                    if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
                        goto next;
                    }
                }
                fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
                                                                                  returnType,
                                                                                  parameters.data(),
                                                                                  parameters.size(),
                                                                                  false));
                found = true;
            }
        }
        next:;
    }
    SkASSERT(found);
}

// Registers the SkSL builtins that are implemented by native functions. Each realName must
// really have the given float signature: the JIT calls it with exactly these LLVM types.
void JIT::loadBuiltinFunctions() {
    // fabsf, not fabs: fabs takes and returns double, so declaring it as float(float) would
    // call it with a mismatched ABI and read back the wrong bits.
    this->addBuiltinFunction("abs", "fabsf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
    this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
    // clamp(genType, float, float) is provided by the sksl_clampN helpers in this file.
    this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
                                                                     fFloat32Type,
                                                                     fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
                                                                            fFloat32Type,
                                                                            fFloat32Type });
    this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
}

// Symbol resolver for JIT-compiled code. The resolution order is:
//   1. the ORC JIT stack itself;
//   2. the sksl_* runtime helpers defined in this file;
//   3. the current process (libc, libm, ...).
// Asserts that the symbol resolved to a non-zero address.
uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
    // Start from 0 so a failed lookup never returns an uninitialized value, and ignore
    // whatever the out-parameter holds if the call reports an error.
    LLVMOrcTargetAddress result = 0;
    if (LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
        result = 0;
    }
    if (!result) {
        struct RuntimeSymbol {
            const char* fName;
            uint64_t fAddress;
        };
        const RuntimeSymbol runtimeSymbols[] = {
            { "sksl_pipeline_append",          (uint64_t) &sksl_pipeline_append          },
            { "sksl_pipeline_append_callback", (uint64_t) &sksl_pipeline_append_callback },
            { "sksl_clamp1",                   (uint64_t) &sksl_clamp1                   },
            { "sksl_clamp2",                   (uint64_t) &sksl_clamp2                   },
            { "sksl_clamp3",                   (uint64_t) &sksl_clamp3                   },
            { "sksl_clamp4",                   (uint64_t) &sksl_clamp4                   },
            { "sksl_debug_print",              (uint64_t) &sksl_debug_print              },
        };
        // Darwin prefixes C symbol names with '_' while ELF platforms do not. Compare against
        // the name with one leading underscore removed so both spellings are recognized.
        const char* unprefixed = (name[0] == '_') ? name + 1 : name;
        for (const RuntimeSymbol& symbol : runtimeSymbols) {
            if (!strcmp(unprefixed, symbol.fName)) {
                result = symbol.fAddress;
                break;
            }
        }
        if (!result) {
            // Not one of ours: look it up in the process with its original spelling.
            result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
        }
    }
    SkASSERT(result);
    return result;
}

// Emits a call to the LLVM function previously registered for fc's SkSL declaration. The
// arguments are compiled left to right.
LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
    LLVMValueRef callee = fFunctions[&fc.fFunction];
    SkASSERT(callee);
    std::vector<LLVMValueRef> args;
    args.reserve(fc.fArguments.size());
    for (size_t i = 0; i < fc.fArguments.size(); ++i) {
        args.push_back(this->compileExpression(builder, *fc.fArguments[i]));
    }
    return LLVMBuildCall(builder, callee, args.data(), args.size(), "");
}

// Maps an SkSL type to the LLVM type used to represent it in JIT-compiled code:
//   - void             -> void
//   - SkRasterPipeline -> i8*
//   - integer scalars  -> i32
//   - float scalars    -> float
//   - bool             -> i1
//   - arrays           -> pointer to the component type
//   - vectors          -> <columns x component>
// Aborts on anything else.
LLVMTypeRef JIT::getType(const Type& type) {
    switch (type.kind()) {
        case Type::kOther_Kind:
            if (type.name() == "void") {
                return fVoidType;
            }
            SkASSERT(type.name() == "SkRasterPipeline");
            return fInt8PtrType;
        case Type::kScalar_Kind:
            // Every integer scalar, signed or unsigned, is widened to i32.
            if (type.isSigned() || type.isUnsigned()) {
                return fInt32Type;
            }
            if (type.isFloat()) {
                return fFloat32Type;
            }
            SkASSERT(type.name() == "bool");
            return fInt1Type;
        case Type::kArray_Kind:
            return LLVMPointerType(this->getType(type.componentType()), 0);
        case Type::kVector_Kind:
            // Lower through the component type so every scalar kind (float/half, signed and
            // unsigned integers, bool) is handled the same way as its scalar form. LLVM uniques
            // vector types per context, so e.g. float2 yields the very same LLVMTypeRef as
            // fFloat32Vector2Type, and pointer comparisons against those members still hold.
            return LLVMVectorType(this->getType(type.componentType()), type.columns());
        default:
            ABORT("unsupported type");
    }
}

// Makes `block` the current insertion point. It positions the builder at the block's end and
// records the block in fCurrentBlock, so code that temporarily moves the builder elsewhere
// (e.g. parameter promotion in getLValue) can return to it.
void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
    LLVMPositionBuilderAtEnd(builder, block);
    fCurrentBlock = block;
}

// Returns an LValue that can load from and store to `expr`. Supported forms:
//   - variable references, including by-value parameters, which are promoted to allocas on
//     first use;
//   - ternaries, which choose between two lvalues at runtime;
//   - swizzles.
// Aborts on any other expression kind.
std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
    switch (expr.fKind) {
        case Expression::kVariableReference_Kind: {
            class PointerLValue : public LValue {
            public:
                explicit PointerLValue(LLVMValueRef ptr)
                : fPointer(ptr) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    return LLVMBuildLoad(builder, fPointer, "lvalue load");
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    LLVMBuildStore(builder, value, fPointer);
                }

            private:
                LLVMValueRef fPointer;
            };
            const Variable* var = &((VariableReference&) expr).fVariable;
            if (var->fStorage == Variable::kParameter_Storage &&
                !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
                fPromotedParameters.find(var) == fPromotedParameters.end()) {
                // Promote the by-value parameter to a stack variable so it can be assigned:
                // emit the alloca in fAllocaBlock, copy the incoming value into it, then
                // return the builder to the block we were emitting into.
                fPromotedParameters.insert(var);
                LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
                LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
                                                      String(var->fName).c_str());
                LLVMBuildStore(builder, fVariables[var], alloca);
                LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
                fVariables[var] = alloca;
            }
            LLVMValueRef ptr = fVariables[var];
            return std::unique_ptr<LValue>(new PointerLValue(ptr));
        }
        case Expression::kTernary_Kind: {
            class TernaryLValue : public LValue {
            public:
                TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
                              std::unique_ptr<LValue> ifFalse)
                : fJIT(*jit)
                , fTest(test)
                , fIfTrue(std::move(ifTrue))
                , fIfFalse(std::move(ifFalse)) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
                                                                              fJIT.fContext,
                                                                              fJIT.fCurrentFunction,
                                                                              "true ? ...");
                    LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
                                                                              fJIT.fContext,
                                                                              fJIT.fCurrentFunction,
                                                                              "false ? ...");
                    LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
                                                                            fJIT.fCurrentFunction,
                                                                            "ternary merge");
                    LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
                    fJIT.setBlock(builder, trueBlock);
                    LLVMValueRef ifTrue = fIfTrue->load(builder);
                    // A nested load may itself have emitted control flow. The phi must name
                    // the block that actually branches to `merge`, not the block we started in.
                    LLVMBasicBlockRef trueExit = LLVMGetInsertBlock(builder);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, falseBlock);
                    LLVMValueRef ifFalse = fIfFalse->load(builder);
                    LLVMBasicBlockRef falseExit = LLVMGetInsertBlock(builder);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, merge);
                    // Both arms produce loaded values (not pointers), so the phi has the
                    // value's own type.
                    LLVMValueRef phi = LLVMBuildPhi(builder, LLVMTypeOf(ifTrue), "?");
                    LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
                    LLVMBasicBlockRef incomingBlocks[2] = { trueExit, falseExit };
                    LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
                    return phi;
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
                                                                              fJIT.fContext,
                                                                              fJIT.fCurrentFunction,
                                                                              "true ? ...");
                    LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
                                                                              fJIT.fContext,
                                                                              fJIT.fCurrentFunction,
                                                                              "false ? ...");
                    LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
                                                                            fJIT.fCurrentFunction,
                                                                            "ternary merge");
                    LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
                    fJIT.setBlock(builder, trueBlock);
                    fIfTrue->store(builder, value);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, falseBlock);
                    // The false arm writes to the false-side lvalue.
                    fIfFalse->store(builder, value);
                    LLVMBuildBr(builder, merge);
                    fJIT.setBlock(builder, merge);
                }

            private:
                JIT& fJIT;
                LLVMValueRef fTest;
                std::unique_ptr<LValue> fIfTrue;
                std::unique_ptr<LValue> fIfFalse;
            };
            const TernaryExpression& t = (const TernaryExpression&) expr;
            LLVMValueRef test = this->compileExpression(builder, *t.fTest);
            return std::unique_ptr<LValue>(new TernaryLValue(this,
                                                             test,
                                                             this->getLValue(builder,
                                                                             *t.fIfTrue),
                                                             this->getLValue(builder,
                                                                             *t.fIfFalse)));
        }
        case Expression::kSwizzle_Kind: {
            class SwizzleLValue : public LValue {
            public:
                SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
                              std::vector<int> components)
                : fJIT(*jit)
                , fType(type)
                , fBase(std::move(base))
                , fComponents(components) {}

                LLVMValueRef load(LLVMBuilderRef builder) override {
                    LLVMValueRef base = fBase->load(builder);
                    if (fComponents.size() > 1) {
                        // Gather the selected components into a new vector of type fType.
                        LLVMValueRef result = LLVMGetUndef(fType);
                        for (size_t i = 0; i < fComponents.size(); ++i) {
                            LLVMValueRef element = LLVMBuildExtractElement(
                                                                       builder,
                                                                       base,
                                                                       LLVMConstInt(fJIT.fInt32Type,
                                                                                    fComponents[i],
                                                                                    false),
                                                                       "swizzle extract");
                            result = LLVMBuildInsertElement(builder, result, element,
                                                            LLVMConstInt(fJIT.fInt32Type, i, false),
                                                            "swizzle insert");
                        }
                        return result;
                    }
                    SkASSERT(fComponents.size() == 1);
                    return LLVMBuildExtractElement(builder, base,
                                                            LLVMConstInt(fJIT.fInt32Type,
                                                                         fComponents[0],
                                                                         false),
                                                            "swizzle extract");
                }

                void store(LLVMBuilderRef builder, LLVMValueRef value) override {
                    // Read-modify-write: load the whole base, overwrite the swizzled
                    // components, and store the base back.
                    LLVMValueRef result = fBase->load(builder);
                    if (fComponents.size() > 1) {
                        for (size_t i = 0; i < fComponents.size(); ++i) {
                            LLVMValueRef element = LLVMBuildExtractElement(builder, value,
                                                                           LLVMConstInt(
                                                                                    fJIT.fInt32Type,
                                                                                    i,
                                                                                    false),
                                                                           "swizzle extract");
                            result = LLVMBuildInsertElement(builder, result, element,
                                                            LLVMConstInt(fJIT.fInt32Type,
                                                                         fComponents[i],
                                                                         false),
                                                            "swizzle insert");
                        }
                    } else {
                        result = LLVMBuildInsertElement(builder, result, value,
                                                        LLVMConstInt(fJIT.fInt32Type,
                                                                     fComponents[0],
                                                                     false),
                                                        "swizzle insert");
                    }
                    fBase->store(builder, result);
                }

            private:
                JIT& fJIT;
                LLVMTypeRef fType;
                std::unique_ptr<LValue> fBase;
                std::vector<int> fComponents;
            };
            const Swizzle& s = (const Swizzle&) expr;
            return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
                                                             this->getLValue(builder, *s.fBase),
                                                             s.fComponents));
        }
        default:
            ABORT("unsupported lvalue");
    }
}

// Classifies a scalar type, or a vector's component type, as signed int, unsigned int, or
// float, based on its SkSL name. Aborts on any other type (e.g. bool).
JIT::TypeKind JIT::typeKind(const Type& type) {
    const Type* scalar = &type;
    while (scalar->kind() == Type::kVector_Kind) {
        scalar = &scalar->componentType();
    }
    const auto& name = scalar->fName;
    if (name == "float" || name == "double" || name == "half") {
        return JIT::kFloat_TypeKind;
    }
    if (name == "uint" || name == "ushort" || name == "ubyte") {
        return JIT::kUInt_TypeKind;
    }
    if (name == "int" || name == "short" || name == "byte") {
        return JIT::kInt_TypeKind;
    }
    ABORT("unsupported type: %s\n", scalar->description().c_str());
}

// Replaces the scalar *value with a <columns x T> vector holding that scalar in every lane.
// The splat is built with one insertelement per lane.
void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
    const LLVMValueRef scalar = *value;
    LLVMValueRef splat = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(scalar), columns));
    int lane = 0;
    while (lane < columns) {
        LLVMValueRef index = LLVMConstInt(fInt32Type, lane, false);
        splat = LLVMBuildInsertElement(builder, splat, scalar, index, "vectorize");
        ++lane;
    }
    *value = splat;
}

// For a scalar-vector binary expression, splats the scalar operand so that both operands have
// the vector's width. Scalar-scalar and vector-vector operands are left unchanged.
void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
                    LLVMValueRef* right) {
    const Type& leftType = b.fLeft->fType;
    const Type& rightType = b.fRight->fType;
    if (leftType.kind() == Type::kScalar_Kind && rightType.kind() == Type::kVector_Kind) {
        this->vectorize(builder, left, rightType.columns());
        return;
    }
    if (leftType.kind() == Type::kVector_Kind && rightType.kind() == Type::kScalar_Kind) {
        this->vectorize(builder, right, leftType.columns());
    }
}


LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
    #define BINARY(SFunc, UFunc, FFunc) {                                    \
        LLVMValueRef left = this->compileExpression(builder, *b.fLeft);      \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
        this->vectorize(builder, b, &left, &right);                          \
        switch (this->typeKind(b.fLeft->fType)) {                            \
            case kInt_TypeKind:                                              \
                return SFunc(builder, left, right, "binary");                \
            case kUInt_TypeKind:                                             \
                return UFunc(builder, left, right, "binary");                \
            case kFloat_TypeKind:                                            \
                return FFunc(builder, left, right, "binary");                \
            default:                                                         \
                ABORT("unsupported typeKind");                               \
        }                                                                    \
    }
    #define COMPOUND(SFunc, UFunc, FFunc) {                                  \
        std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
        LLVMValueRef left = lvalue->load(builder);                           \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
        this->vectorize(builder, b, &left, &right);                          \
        LLVMValueRef result;                                                 \
        switch (this->typeKind(b.fLeft->fType)) {                            \
            case kInt_TypeKind:                                              \
                result = SFunc(builder, left, right, "binary");              \
                break;                                                       \
            case kUInt_TypeKind:                                             \
                result = UFunc(builder, left, right, "binary");              \
                break;                                                       \
            case kFloat_TypeKind:                                            \
                result = FFunc(builder, left, right, "binary");              \
                break;                                                       \
            default:                                                         \
                ABORT("unsupported typeKind");                               \
        }                                                                    \
        lvalue->store(builder, result);                                      \
        return result;                                                       \
    }
    #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) {                    \
        LLVMValueRef left = this->compileExpression(builder, *b.fLeft);      \
        LLVMValueRef right = this->compileExpression(builder, *b.fRight);    \
        this->vectorize(builder, b, &left, &right);                          \
        switch (this->typeKind(b.fLeft->fType)) {                            \
            case kInt_TypeKind:                                              \
                return SFunc(builder, SOp, left, right, "binary");           \
            case kUInt_TypeKind:                                             \
                return UFunc(builder, UOp, left, right, "binary");           \
            case kFloat_TypeKind:                                            \
                return FFunc(builder, FOp, left, right, "binary");           \
            default:                                                         \
                ABORT("unsupported typeKind");                               \
        }                                                                    \
    }
    switch (b.fOperator) {
        case Token::EQ: {
            std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
            LLVMValueRef result = this->compileExpression(builder, *b.fRight);
            lvalue->store(builder, result);
            return result;
        }
        case Token::PLUS:
            BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUS:
            BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAR:
            BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASH:
            BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::PERCENT:
            BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
        case Token::BITWISEAND:
            BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOR:
            BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        case Token::SHL:
            BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
        case Token::SHR:
            BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
        case Token::PLUSEQ:
            COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUSEQ:
            COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAREQ:
            COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASHEQ:
            COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::BITWISEANDEQ:
            COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOREQ:
            COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        case Token::EQEQ:
            switch (b.fLeft->fType.kind()) {
                case Type::kScalar_Kind:
                    COMPARE(LLVMBuildICmp, LLVMIntEQ,
                            LLVMBuildICmp, LLVMIntEQ,
                            LLVMBuildFCmp, LLVMRealOEQ);
                case Type::kVector_Kind: {
                    LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
                    LLVMValueRef right = this->compileExpression(builder, *b.fRight);
                    this->vectorize(builder, b, &left, &right);
                    LLVMValueRef value;
                    switch (this->typeKind(b.fLeft->fType)) {
                        case kInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
                            break;
                        case kUInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
                            break;
                        case kFloat_TypeKind:
                            value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
                            break;
                        default:
                            ABORT("unsupported typeKind");
                    }
                    LLVMValueRef args[1] = { value };
                    LLVMValueRef func;
                    switch (b.fLeft->fType.columns()) {
                        case 2: func = fFoldAnd2Func; break;
                        case 3: func = fFoldAnd3Func; break;
                        case 4: func = fFoldAnd4Func; break;
                        default:
                            SkASSERT(false);
                            func = fFoldAnd2Func;
                    }
                    return LLVMBuildCall(builder, func, args, 1, "all");
                }
                default:
                    SkASSERT(false);
            }
        case Token::NEQ:
            switch (b.fLeft->fType.kind()) {
                case Type::kScalar_Kind:
                    COMPARE(LLVMBuildICmp, LLVMIntNE,
                            LLVMBuildICmp, LLVMIntNE,
                            LLVMBuildFCmp, LLVMRealONE);
                case Type::kVector_Kind: {
                    LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
                    LLVMValueRef right = this->compileExpression(builder, *b.fRight);
                    this->vectorize(builder, b, &left, &right);
                    LLVMValueRef value;
                    switch (this->typeKind(b.fLeft->fType)) {
                        case kInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
                            break;
                        case kUInt_TypeKind:
                            value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
                            break;
                        case kFloat_TypeKind:
                            value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
                            break;
                        default:
                            ABORT("unsupported typeKind");
                    }
                    LLVMValueRef args[1] = { value };
                    LLVMValueRef func;
                    switch (b.fLeft->fType.columns()) {
                        case 2: func = fFoldOr2Func; break;
                        case 3: func = fFoldOr3Func; break;
                        case 4: func = fFoldOr4Func; break;
                        default:
                            SkASSERT(false);
                            func = fFoldOr2Func;
                    }
                    return LLVMBuildCall(builder, func, args, 1, "all");
                }
                default:
                    SkASSERT(false);
            }
        case Token::LT:
            COMPARE(LLVMBuildICmp, LLVMIntSLT,
                    LLVMBuildICmp, LLVMIntULT,
                    LLVMBuildFCmp, LLVMRealOLT);
        case Token::LTEQ:
            COMPARE(LLVMBuildICmp, LLVMIntSLE,
                    LLVMBuildICmp, LLVMIntULE,
                    LLVMBuildFCmp, LLVMRealOLE);
        case Token::GT:
            COMPARE(LLVMBuildICmp, LLVMIntSGT,
                    LLVMBuildICmp, LLVMIntUGT,
                    LLVMBuildFCmp, LLVMRealOGT);
        case Token::GTEQ:
            COMPARE(LLVMBuildICmp, LLVMIntSGE,
                    LLVMBuildICmp, LLVMIntUGE,
                    LLVMBuildFCmp, LLVMRealOGE);
        case Token::LOGICALAND: {
            LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
            LLVMBasicBlockRef ifFalse = fCurrentBlock;
            LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                     "true && ...");
            LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                    "&& merge");
            LLVMBuildCondBr(builder, left, ifTrue, merge);
            this->setBlock(builder, ifTrue);
            LLVMValueRef right = this->compileExpression(builder, *b.fRight);
            LLVMBuildBr(builder, merge);
            this->setBlock(builder, merge);
            LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
            LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
            LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
            LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
            return phi;
        }
        case Token::LOGICALOR: {
            LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
            LLVMBasicBlockRef ifTrue = fCurrentBlock;
            LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                      "false || ...");
            LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                    "|| merge");
            LLVMBuildCondBr(builder, left, merge, ifFalse);
            this->setBlock(builder, ifFalse);
            LLVMValueRef right = this->compileExpression(builder, *b.fRight);
            LLVMBuildBr(builder, merge);
            this->setBlock(builder, merge);
            LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
            LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
            LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
            LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
            return phi;
        }
        default:
            printf("%s\n", b.description().c_str());
            ABORT("unsupported binary operator");
    }
}

// Compiles `base[index]`: evaluates the base, then the subscript, computes the element address
// with a single-index GEP, and loads the element.
// NOTE(review): LLVMBuildGEP requires a pointer operand, but compileExpression() on a local
// variable returns a *loaded* value (see compileVariableReference). Confirm that indexable bases
// actually reach here as pointers.
LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
    LLVMValueRef aggregate = this->compileExpression(builder, *idx.fBase);
    LLVMValueRef subscript = this->compileExpression(builder, *idx.fIndex);
    LLVMValueRef elementAddress = LLVMBuildGEP(builder, aggregate, &subscript, 1, "index ptr");
    LLVMValueRef element = LLVMBuildLoad(builder, elementAddress, "index load");
    return element;
}

// Compiles `x++` / `x--`: loads the lvalue, stores the incremented/decremented value back, and
// returns the value from *before* the update (postfix semantics).
//
// The constant 1 must match the operand's LLVM type. LLVMConstInt only accepts integer types, so
// floating-point operands use LLVMConstReal instead.
LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
    std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
    LLVMValueRef result = lvalue->load(builder);
    LLVMTypeRef type = this->getType(p.fType);
    bool isFloat = false;
    switch (this->typeKind(p.fType)) {
        case kInt_TypeKind: // fall through
        case kUInt_TypeKind:
            isFloat = false;
            break;
        case kFloat_TypeKind:
            isFloat = true;
            break;
        default:
            ABORT("unsupported typeKind");
    }
    LLVMValueRef one = isFloat ? LLVMConstReal(type, 1.0) : LLVMConstInt(type, 1, false);
    LLVMValueRef mod = nullptr;
    switch (p.fOperator) {
        case Token::PLUSPLUS:
            mod = isFloat ? LLVMBuildFAdd(builder, result, one, "++")
                          : LLVMBuildAdd(builder, result, one, "++");
            break;
        case Token::MINUSMINUS:
            mod = isFloat ? LLVMBuildFSub(builder, result, one, "--")
                          : LLVMBuildSub(builder, result, one, "--");
            break;
        default:
            ABORT("unsupported postfix op");
    }
    lvalue->store(builder, mod);
    // Postfix expressions evaluate to the original (pre-update) value.
    return result;
}

// Compiles a prefix expression:
//   !x      -> xor with 1 (bools are i1, see kBoolLiteral_Kind in compileExpression)
//   -x      -> integer negation or floating-point negation depending on the operand type
//   ++x/--x -> updates the lvalue and returns the *new* value
//
// Constants are built per type kind. LLVMConstInt only accepts integer types, so float operands
// use LLVMConstReal, and float negation uses LLVMBuildFNeg rather than an integer `sub 0, x`.
LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
    LLVMTypeRef type = this->getType(p.fType);
    if (Token::LOGICALNOT == p.fOperator) {
        LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
        return LLVMBuildXor(builder, base, LLVMConstInt(type, 1, false), "!");
    }
    if (Token::MINUS == p.fOperator) {
        LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
        switch (this->typeKind(p.fType)) {
            case kInt_TypeKind: // fall through
            case kUInt_TypeKind:
                return LLVMBuildNeg(builder, base, "-");
            case kFloat_TypeKind:
                return LLVMBuildFNeg(builder, base, "-");
            default:
                ABORT("unsupported typeKind");
        }
    }
    std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
    LLVMValueRef raw = lvalue->load(builder);
    bool isFloat = false;
    switch (this->typeKind(p.fType)) {
        case kInt_TypeKind: // fall through
        case kUInt_TypeKind:
            isFloat = false;
            break;
        case kFloat_TypeKind:
            isFloat = true;
            break;
        default:
            ABORT("unsupported typeKind");
    }
    LLVMValueRef one = isFloat ? LLVMConstReal(type, 1.0) : LLVMConstInt(type, 1, false);
    LLVMValueRef result = nullptr;
    switch (p.fOperator) {
        case Token::PLUSPLUS:
            result = isFloat ? LLVMBuildFAdd(builder, raw, one, "++")
                             : LLVMBuildAdd(builder, raw, one, "++");
            break;
        case Token::MINUSMINUS:
            result = isFloat ? LLVMBuildFSub(builder, raw, one, "--")
                             : LLVMBuildSub(builder, raw, one, "--");
            break;
        default:
            ABORT("unsupported prefix op");
    }
    lvalue->store(builder, result);
    // Prefix increment/decrement evaluates to the updated value.
    return result;
}

// Produces the current value of a variable.
// Parameters that are neither `out` nor listed in fPromotedParameters are stored in fVariables as
// values and are returned directly. Every other variable is stored in fVariables as an address,
// and its value is loaded from there.
LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
    const Variable& var = v.fVariable;
    const bool isParameter = Variable::kParameter_Storage == var.fStorage;
    const bool isOut = (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
    if (isParameter && !isOut) {
        const bool isPromoted = fPromotedParameters.find(&var) != fPromotedParameters.end();
        if (!isPromoted) {
            return fVariables[&var];
        }
    }
    LLVMValueRef address = fVariables[&var];
    return LLVMBuildLoad(builder, address, String(var.fName).c_str());
}

// Emits a call that appends a stage to an SkRasterPipeline.
// Argument 0 is always the pipeline.
// - For the `callback` stage, argument 1 names an SkSL function. That function is compiled as a
//   stage function, and its address is passed to fAppendCallbackFunc.
// - For every other stock stage, an optional argument 1 is passed as an i8* context (null when
//   absent), together with the stage id, to fAppendFunc.
void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
    SkASSERT(a.fArguments.size() >= 1);
    SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
    LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
    LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
    if (SkRasterPipeline::callback == a.fStage) {
        SkASSERT(a.fArguments.size() == 2);
        SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
        const FunctionReference& ref = (FunctionReference&) *a.fArguments[1];
        const FunctionDeclaration& wanted = *ref.fFunctions[0];
        bool found = false;
        // Locate the definition matching the referenced declaration.
        for (const auto& element : *fProgram) {
            if (ProgramElement::kFunction_Kind != element.fKind) {
                continue;
            }
            const FunctionDefinition& definition = (const FunctionDefinition&) element;
            if (&definition.fDeclaration != &wanted) {
                continue;
            }
            LLVMValueRef stageFn = this->compileStageFunction(definition);
            LLVMValueRef fnAsBytes = LLVMBuildBitCast(builder, stageFn, fInt8PtrType,
                                                      "callback cast");
            LLVMValueRef callbackArgs[2] = { pipeline, fnAsBytes };
            LLVMBuildCall(builder, fAppendCallbackFunc, callbackArgs, 2, "");
            found = true;
            break;
        }
        SkASSERT(found);
        return;
    }
    LLVMValueRef ctx;
    if (a.fArguments.size() != 2) {
        SkASSERT(a.fArguments.size() == 1);
        ctx = LLVMConstNull(fInt8PtrType);
    } else {
        LLVMValueRef rawCtx = this->compileExpression(builder, *a.fArguments[1]);
        ctx = LLVMBuildBitCast(builder, rawCtx, fInt8PtrType, "context cast");
    }
    LLVMValueRef appendArgs[3] = { pipeline, stage, ctx };
    LLVMBuildCall(builder, fAppendFunc, appendArgs, 3, "");
}

// Compiles a constructor expression.
//
// Scalar constructors take one argument and convert it between int/uint/float/bool:
//   - numeric -> numeric uses sitofp/uitofp/fptosi/fptoui. int <-> uint returns the value
//     unchanged (same LLVM representation).
//   - bool (i1) -> numeric yields 0 or 1 (uitofp / zext).
//   - numeric -> bool is "value != 0". Float uses an unordered compare, so NaN converts to true.
// Unsupported combinations break out and reach the ABORT at the end, instead of falling through
// into another case.
//
// Vector constructors either splat a single scalar into every lane, or concatenate, in order, the
// scalars and the lanes of the vector arguments.
LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
    switch (c.fType.kind()) {
        case Type::kScalar_Kind: {
            SkASSERT(c.fArguments.size() == 1);
            TypeKind from = this->typeKind(c.fArguments[0]->fType);
            TypeKind to = this->typeKind(c.fType);
            LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
            LLVMTypeRef type = this->getType(c.fType);
            switch (to) {
                case kFloat_TypeKind:
                    switch (from) {
                        case kInt_TypeKind:
                            return LLVMBuildSIToFP(builder, base, type, "cast");
                        case kUInt_TypeKind: // fall through
                        case kBool_TypeKind:
                            // bools are i1, so an unsigned conversion yields 0.0 or 1.0
                            return LLVMBuildUIToFP(builder, base, type, "cast");
                        case kFloat_TypeKind:
                            return base;
                    }
                    break;
                case kInt_TypeKind:
                    switch (from) {
                        case kInt_TypeKind:
                            return base;
                        case kUInt_TypeKind:
                            return base;
                        case kFloat_TypeKind:
                            return LLVMBuildFPToSI(builder, base, type, "cast");
                        case kBool_TypeKind:
                            return LLVMBuildZExt(builder, base, type, "cast");
                    }
                    break;
                case kUInt_TypeKind:
                    switch (from) {
                        case kInt_TypeKind:
                            return base;
                        case kUInt_TypeKind:
                            return base;
                        case kFloat_TypeKind:
                            return LLVMBuildFPToUI(builder, base, type, "cast");
                        case kBool_TypeKind:
                            return LLVMBuildZExt(builder, base, type, "cast");
                    }
                    break;
                case kBool_TypeKind:
                    switch (from) {
                        case kInt_TypeKind: // fall through
                        case kUInt_TypeKind:
                            return LLVMBuildICmp(builder, LLVMIntNE, base,
                                                 LLVMConstNull(LLVMTypeOf(base)), "cast");
                        case kFloat_TypeKind:
                            return LLVMBuildFCmp(builder, LLVMRealUNE, base,
                                                 LLVMConstNull(LLVMTypeOf(base)), "cast");
                        case kBool_TypeKind:
                            return base;
                    }
                    break;
            }
            break;
        }
        case Type::kVector_Kind: {
            LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
            if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
                // Splat: broadcast the single scalar into every lane.
                LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
                for (int i = 0; i < c.fType.columns(); ++i) {
                    vec = LLVMBuildInsertElement(builder, vec, value,
                                                 LLVMConstInt(fInt32Type, i, false),
                                                 "vec build 1");
                }
            } else {
                // Concatenate arguments lane by lane into the result.
                int index = 0;
                for (const auto& arg : c.fArguments) {
                    LLVMValueRef value = this->compileExpression(builder, *arg);
                    if (arg->fType.kind() == Type::kVector_Kind) {
                        for (int i = 0; i < arg->fType.columns(); ++i) {
                            // Read lane i of the *argument*, not of the partially built result.
                            LLVMValueRef column = LLVMBuildExtractElement(
                                                                  builder,
                                                                  value,
                                                                  LLVMConstInt(fInt32Type,
                                                                               i,
                                                                               false),
                                                                  "construct extract");
                            vec = LLVMBuildInsertElement(builder, vec, column,
                                                         LLVMConstInt(fInt32Type, index++, false),
                                                         "vec build 2");
                        }
                    } else {
                        vec = LLVMBuildInsertElement(builder, vec, value,
                                                     LLVMConstInt(fInt32Type, index++, false),
                                                     "vec build 3");
                    }
                }
            }
            return vec;
        }
        default:
            break;
    }
    ABORT("unsupported constructor");
}

// Compiles a swizzle (e.g. `v.zyx`, `v.x`).
// A single-component swizzle extracts one lane as a scalar. A multi-component swizzle builds a new
// vector whose lane `dst` is lane fComponents[dst] of the base.
LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
    LLVMValueRef source = this->compileExpression(builder, *s.fBase);
    if (s.fComponents.size() <= 1) {
        SkASSERT(s.fComponents.size() == 1);
        LLVMValueRef lane = LLVMConstInt(fInt32Type, s.fComponents[0], false);
        return LLVMBuildExtractElement(builder, source, lane, "swizzle extract");
    }
    LLVMValueRef swizzled = LLVMGetUndef(this->getType(s.fType));
    for (size_t dst = 0; dst < s.fComponents.size(); ++dst) {
        LLVMValueRef srcLane = LLVMConstInt(fInt32Type, s.fComponents[dst], false);
        LLVMValueRef element = LLVMBuildExtractElement(builder, source, srcLane,
                                                       "swizzle extract");
        LLVMValueRef dstLane = LLVMConstInt(fInt32Type, dst, false);
        swizzled = LLVMBuildInsertElement(builder, swizzled, element, dstLane, "swizzle insert");
    }
    return swizzled;
}

// Compiles `test ? ifTrue : ifFalse` with real control flow: each arm gets its own block, and the
// result is merged with a phi node. Blocks are appended in the order true, merge, false.
LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
    LLVMValueRef condition = this->compileExpression(builder, *t.fTest);
    LLVMBasicBlockRef thenBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if true");
    LLVMBasicBlockRef joinBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if merge");
    LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if false");
    LLVMBuildCondBr(builder, condition, thenBlock, elseBlock);

    LLVMValueRef armValues[2];
    LLVMBasicBlockRef armExits[2];

    this->setBlock(builder, thenBlock);
    armValues[0] = this->compileExpression(builder, *t.fIfTrue);
    // Compiling the arm may itself create blocks (e.g. a nested ?: or &&). The phi must name the
    // block that actually branches to the join, which is whatever block is current now.
    armExits[0] = fCurrentBlock;
    LLVMBuildBr(builder, joinBlock);

    this->setBlock(builder, elseBlock);
    armValues[1] = this->compileExpression(builder, *t.fIfFalse);
    armExits[1] = fCurrentBlock;
    LLVMBuildBr(builder, joinBlock);

    this->setBlock(builder, joinBlock);
    LLVMValueRef merged = LLVMBuildPhi(builder, this->getType(t.fType), "?");
    LLVMAddIncoming(merged, armValues, armExits, 2);
    return merged;
}

// Emits IR for an expression and returns the resulting value.
// Appending a pipeline stage produces no value and returns null. Expression kinds the JIT does not
// support (field access, settings, type references, and anything unknown) abort.
LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
    switch (expr.fKind) {
        // Literals map directly onto LLVM constants.
        case Expression::kBoolLiteral_Kind:
            return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
        case Expression::kIntLiteral_Kind:
            return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
        case Expression::kFloatLiteral_Kind:
            return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);

        // Compound expressions dispatch to dedicated compilers.
        case Expression::kBinary_Kind:
            return this->compileBinary(builder, (BinaryExpression&) expr);
        case Expression::kConstructor_Kind:
            return this->compileConstructor(builder, (Constructor&) expr);
        case Expression::kFunctionCall_Kind:
            return this->compileFunctionCall(builder, (FunctionCall&) expr);
        case Expression::kIndex_Kind:
            return this->compileIndex(builder, (IndexExpression&) expr);
        case Expression::kPrefix_Kind:
            return this->compilePrefix(builder, (PrefixExpression&) expr);
        case Expression::kPostfix_Kind:
            return this->compilePostfix(builder, (PostfixExpression&) expr);
        case Expression::kSwizzle_Kind:
            return this->compileSwizzle(builder, (Swizzle&) expr);
        case Expression::kVariableReference_Kind:
            return this->compileVariableReference(builder, (VariableReference&) expr);
        case Expression::kTernary_Kind:
            return this->compileTernary(builder, (TernaryExpression&) expr);

        // Emitted only for its side effect on the pipeline; yields no value.
        case Expression::kAppendStage_Kind:
            this->appendStage(builder, (const AppendStage&) expr);
            return nullptr;

        // Not supported by the JIT.
        case Expression::kFieldAccess_Kind:
        case Expression::kSetting_Kind:
        case Expression::kTypeReference_Kind:
        default:
            abort();
    }
    ABORT("unsupported expression: %s\n", expr.description().c_str());
}

// Compiles each statement of a block, in source order, into the current insertion point.
void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
    for (const auto& child : block.fStatements) {
        const Statement& statement = *child;
        this->compileStatement(builder, statement);
    }
}

// Declares local variables.
// Each variable's alloca is emitted at the end of fAllocaBlock and recorded in fVariables. The
// builder is then returned to fCurrentBlock, where the optional initializer is evaluated and
// stored.
void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
    for (const auto& entry : decls.fDeclaration->fVars) {
        const VarDeclaration& declaration = (VarDeclaration&) *entry;
        auto variable = declaration.fVar;

        LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
        LLVMValueRef slot = LLVMBuildAlloca(builder, this->getType(variable->fType),
                                            String(variable->fName).c_str());
        fVariables[variable] = slot;

        LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
        if (!declaration.fValue) {
            continue;
        }
        LLVMValueRef initialValue = this->compileExpression(builder, *declaration.fValue);
        LLVMBuildStore(builder, initialValue, slot);
    }
}

// Compiles an if/else statement. Blocks are appended in the order true, merge, false. When there is
// no else clause, a false condition jumps straight to the merge block. An arm for which
// ends_with_branch() is true gets no extra branch to the merge block, since its block is
// presumably already terminated.
void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
    LLVMValueRef condition = this->compileExpression(builder, *i.fTest);
    LLVMBasicBlockRef thenBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if true");
    LLVMBasicBlockRef joinBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "if merge");
    LLVMBasicBlockRef elseBlock = i.fIfFalse
            ? LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false")
            : joinBlock;
    LLVMBuildCondBr(builder, condition, thenBlock, elseBlock);

    this->setBlock(builder, thenBlock);
    this->compileStatement(builder, *i.fIfTrue);
    if (!ends_with_branch(*i.fIfTrue)) {
        LLVMBuildBr(builder, joinBlock);
    }

    if (i.fIfFalse) {
        this->setBlock(builder, elseBlock);
        this->compileStatement(builder, *i.fIfFalse);
        if (!ends_with_branch(*i.fIfFalse)) {
            LLVMBuildBr(builder, joinBlock);
        }
    }

    this->setBlock(builder, joinBlock);
}

// Compiles a for loop:
//   init; [test:] if (!test) goto end; body: ...; next: step; goto test (or body); end:
// `break` targets the end block, and `continue` targets the "next" block so the step still runs.
// Without a test expression, the loop head is the body itself. Blocks are appended in the order
// body, next, end, then (if present) test.
void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
    if (f.fInitializer) {
        this->compileStatement(builder, *f.fInitializer);
    }
    LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "for body");
    LLVMBasicBlockRef stepBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "for next");
    LLVMBasicBlockRef doneBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "for end");
    LLVMBasicBlockRef loopHead = bodyBlock;
    if (f.fTest) {
        loopHead = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
    }
    LLVMBuildBr(builder, loopHead);
    if (f.fTest) {
        this->setBlock(builder, loopHead);
        LLVMValueRef keepGoing = this->compileExpression(builder, *f.fTest);
        LLVMBuildCondBr(builder, keepGoing, bodyBlock, doneBlock);
    }

    this->setBlock(builder, bodyBlock);
    fBreakTarget.push_back(doneBlock);
    fContinueTarget.push_back(stepBlock);
    this->compileStatement(builder, *f.fStatement);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
    if (!ends_with_branch(*f.fStatement)) {
        LLVMBuildBr(builder, stepBlock);
    }

    this->setBlock(builder, stepBlock);
    if (f.fNext) {
        this->compileExpression(builder, *f.fNext);
    }
    LLVMBuildBr(builder, loopHead);

    this->setBlock(builder, doneBlock);
}

// Compiles a do-while loop:
//   body: ...; test: if (test) goto body; end:
// Entry jumps straight into the body, because a do-while always runs at least once.
// `break` targets the end block. `continue` must target the *test* block: in C/GLSL, `continue`
// inside a do-while re-evaluates the loop condition. It does not restart the body unconditionally,
// which could loop forever once the condition became false.
void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
    LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                                "do test");
    LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                           "do body");
    LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                          "do end");
    LLVMBuildBr(builder, body);
    this->setBlock(builder, testBlock);
    LLVMValueRef test = this->compileExpression(builder, *d.fTest);
    LLVMBuildCondBr(builder, test, body, end);
    this->setBlock(builder, body);
    fBreakTarget.push_back(end);
    fContinueTarget.push_back(testBlock);
    this->compileStatement(builder, *d.fStatement);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
    if (!ends_with_branch(*d.fStatement)) {
        LLVMBuildBr(builder, testBlock);
    }
    this->setBlock(builder, end);
}

// Compiles a while loop:
//   test: if (!test) goto end; body: ...; goto test; end:
// Both `continue` and the natural end of the body return to the test block. `break` targets the
// end block.
void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
    LLVMBasicBlockRef header = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                             "while test");
    LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                               "while body");
    LLVMBasicBlockRef done = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
                                                           "while end");
    LLVMBuildBr(builder, header);

    this->setBlock(builder, header);
    LLVMValueRef keepGoing = this->compileExpression(builder, *w.fTest);
    LLVMBuildCondBr(builder, keepGoing, loopBody, done);

    this->setBlock(builder, loopBody);
    fBreakTarget.push_back(done);
    fContinueTarget.push_back(header);
    this->compileStatement(builder, *w.fStatement);
    fBreakTarget.pop_back();
    fContinueTarget.pop_back();
    if (!ends_with_branch(*w.fStatement)) {
        LLVMBuildBr(builder, header);
    }

    this->setBlock(builder, done);
}

// Emits an unconditional branch to the innermost break target (the top of fBreakTarget).
void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
    LLVMBasicBlockRef exitBlock = fBreakTarget.back();
    LLVMBuildBr(builder, exitBlock);
}

// Emits an unconditional branch to the innermost continue target (the top of fContinueTarget).
void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
    LLVMBasicBlockRef resumeBlock = fContinueTarget.back();
    LLVMBuildBr(builder, resumeBlock);
}

// Emits a `ret`: a void return when the statement has no expression, otherwise a return of the
// compiled expression's value.
void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
    if (!r.fExpression) {
        LLVMBuildRetVoid(builder);
        return;
    }
    LLVMValueRef returnValue = this->compileExpression(builder, *r.fExpression);
    LLVMBuildRet(builder, returnValue);
}

// Dispatches a statement to its compiler. Nops emit nothing. Statement kinds the JIT does not
// support (discard, group, switch, and anything unknown) abort.
void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
    switch (stmt.fKind) {
        case Statement::kNop_Kind:
            return;

        // Structured control flow.
        case Statement::kBlock_Kind:
            this->compileBlock(builder, (Block&) stmt);
            return;
        case Statement::kIf_Kind:
            this->compileIf(builder, (IfStatement&) stmt);
            return;
        case Statement::kFor_Kind:
            this->compileFor(builder, (ForStatement&) stmt);
            return;
        case Statement::kWhile_Kind:
            this->compileWhile(builder, (WhileStatement&) stmt);
            return;
        case Statement::kDo_Kind:
            this->compileDo(builder, (DoStatement&) stmt);
            return;

        // Jumps.
        case Statement::kBreak_Kind:
            this->compileBreak(builder, (BreakStatement&) stmt);
            return;
        case Statement::kContinue_Kind:
            this->compileContinue(builder, (ContinueStatement&) stmt);
            return;
        case Statement::kReturn_Kind:
            this->compileReturn(builder, (ReturnStatement&) stmt);
            return;

        // Simple statements.
        case Statement::kExpression_Kind:
            this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
            return;
        case Statement::kVarDeclarations_Kind:
            this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
            return;

        // Not supported by the JIT.
        case Statement::kDiscard_Kind:
        case Statement::kGroup_Kind:
        case Statement::kSwitch_Kind:
        default:
            abort();
    }
}

void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
    // loop over fVectorCount pixels, running the body of the stage function for each of them
    LLVMValueRef oldFunction = fCurrentFunction;
    fCurrentFunction = newFunc;
    std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
    LLVMGetParams(fCurrentFunction, params.get());
    LLVMValueRef programParam = params.get()[1];
    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
    LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    this->setBlock(builder, fAllocaBlock);
    // temporaries to store the color channel vectors
    LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
    LLVMBuildStore(builder, params.get()[4], rVec);
    LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
    LLVMBuildStore(builder, params.get()[5], gVec);
    LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
    LLVMBuildStore(builder, params.get()[6], bVec);
    LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
    LLVMBuildStore(builder, params.get()[7], aVec);
    LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
    fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
                                                               "y->Int32");
    fVariables[f.fDeclaration.fParameters[2]] = color;
    LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
    LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    this->setBlock(builder, start);
    LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
    fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
                                                             LLVMBuildTrunc(builder,
                                                                            params.get()[2],
                                                                            fInt32Type,
                                                                            "x->Int32"),
                                                             iload,
                                                             "x");
    LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
    LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
    LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
    LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
    LLVMBuildCondBr(builder, test, loopBody, loopEnd);
    this->setBlock(builder, loopBody);
    LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
    // extract the r, g, b, and a values from the color channel vectors and store them into "color"
    for (int i = 0; i < 4; ++i) {
        vec = LLVMBuildInsertElement(builder, vec,
                                     LLVMBuildExtractElement(builder,
                                                             params.get()[4 + i],
                                                             iload, "initial"),
                                     LLVMConstInt(fInt32Type, i, false),
                                     "vec build");
    }
    LLVMBuildStore(builder, vec, color);
    // write actual loop body
    this->compileStatement(builder, *f.fBody);
    // extract the r, g, b, and a values from "color" and stick them back into the color channel
    // vectors
    LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 0,
                                                                               false),
                                                                  "rExtract"),
                                          iload, "rInsert"),
                   rVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 1,
                                                                               false),
                                                                  "gExtract"),
                                          iload, "gInsert"),
                   gVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 2,
                                                                               false),
                                                                  "bExtract"),
                                          iload, "bInsert"),
                   bVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 3,
                                                                               false),
                                                                  "aExtract"),
                                          iload, "aInsert"),
                   aVec);
    LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
    LLVMBuildStore(builder, inc, ivar);
    LLVMBuildBr(builder, start);
    this->setBlock(builder, loopEnd);
    // increment program pointer, call the next stage
    LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
    LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
    LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
    LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
                                             LLVMBuildAdd(builder,
                                                          LLVMBuildPtrToInt(builder,
                                                                            programParam,
                                                                            fInt64Type,
                                                                            "cast 1"),
                                                          LLVMConstInt(fInt64Type, PTR_SIZE, false),
                                                          "add"),
                                            LLVMPointerType(fInt8PtrType, 0), "cast 2");
    LLVMValueRef args[STAGE_PARAM_COUNT] = {
        params.get()[0],
        nextInc,
        params.get()[2],
        params.get()[3],
        LLVMBuildLoad(builder, rVec, "rVec"),
        LLVMBuildLoad(builder, gVec, "gVec"),
        LLVMBuildLoad(builder, bVec, "bVec"),
        LLVMBuildLoad(builder, aVec, "aVec"),
        params.get()[8],
        params.get()[9],
        params.get()[10],
        params.get()[11]
    };
    LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
    LLVMBuildRetVoid(builder);
    // finish
    LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
    LLVMBuildBr(builder, start);
    LLVMDisposeBuilder(builder);
    if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
        ABORT("verify failed\n");
    }
    fAllocaBlock = oldAllocaBlock;
    fCurrentBlock = oldCurrentBlock;
    fCurrentFunction = oldFunction;
}

// FIXME: consider pluggable code generators. The scalar codegen and the vector codegen should
// be separated, and this class broken up into multiple classes.

// Resolves an assignable vector-mode expression to the per-channel storage it names.
// Only the stage's color parameter (its four channel allocas in fChannels) and swizzles
// of it are supported. On success, `out` receives one alloca pointer per selected channel.
// Returns false for anything else so the caller can give up on vectorization.
bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
                          LLVMValueRef out[CHANNELS]) {
    if (e.fKind == Expression::kVariableReference_Kind) {
        const Variable* referenced = &((VariableReference&) e).fVariable;
        if (referenced != fColorParam) {
            return false;
        }
        memcpy(out, fChannels, sizeof(fChannels));
        return true;
    }
    if (e.fKind == Expression::kSwizzle_Kind) {
        const Swizzle& swizzle = (const Swizzle&) e;
        LLVMValueRef baseSlots[CHANNELS];
        if (!this->getVectorLValue(builder, *swizzle.fBase, baseSlots)) {
            return false;
        }
        size_t slot = 0;
        for (int component : swizzle.fComponents) {
            out[slot++] = baseSlots[component];
        }
        return true;
    }
    return false;
}

// Compiles both operands of a binary expression in vector mode. When one side is a
// scalar and the other a vector, the scalar's single channel is replicated so that
// outLeft[i] and outRight[i] line up for every component of the result.
// Returns false if either operand cannot be vectorized.
bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
                                  LLVMValueRef outLeft[CHANNELS], const Expression& right,
                                  LLVMValueRef outRight[CHANNELS]) {
    const int leftWidth = left.fType.columns();
    const int rightWidth = right.fType.columns();
    // Copies channel 0 into channels 1..width-1.
    auto broadcast = [](LLVMValueRef* channels, int width) {
        for (int i = 1; i < width; ++i) {
            channels[i] = channels[0];
        }
    };

    if (!this->compileVectorExpression(builder, left, outLeft)) {
        return false;
    }
    if (leftWidth == 1 && rightWidth > 1) {
        broadcast(outLeft, rightWidth);
    }
    if (!this->compileVectorExpression(builder, right, outRight)) {
        return false;
    }
    if (rightWidth == 1 && leftWidth > 1) {
        broadcast(outRight, leftWidth);
    }
    return true;
}

// Vector-mode lowering of a binary expression.
//
// `=` stores each right-hand channel through the corresponding channel pointer of the
// left-hand lvalue (see getVectorLValue). Arithmetic and bitwise operators are emitted one
// channel at a time. The LLVM instruction is chosen from the left operand's component kind
// (signed int, unsigned int or float). Scalar operands are broadcast against vector
// operands, so the result has as many channels as the wider operand.
// Returns false, so the caller can fall back to the per-pixel loop, for unsupported
// operators or operand kinds.
bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
                              LLVMValueRef out[CHANNELS]) {
    LLVMValueRef left[CHANNELS];
    LLVMValueRef right[CHANNELS];
    // Signature shared by the LLVM C API's two-operand builders (LLVMBuildAdd, LLVMBuildFRem...).
    using BinaryBuilder = LLVMValueRef (*)(LLVMBuilderRef, LLVMValueRef, LLVMValueRef,
                                           const char*);
    auto vectorBinary = [&](BinaryBuilder signedOp, BinaryBuilder unsignedOp,
                            BinaryBuilder floatOp) -> bool {
        if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) {
            return false;
        }
        BinaryBuilder op;
        switch (this->typeKind(b.fLeft->fType)) {
            case kInt_TypeKind:
                op = signedOp;
                break;
            case kUInt_TypeKind:
                op = unsignedOp;
                break;
            case kFloat_TypeKind:
                op = floatOp;
                break;
            default:
                // bool (or any other) operands have no vector lowering here.
                return false;
        }
        // getVectorBinaryOperands broadcast the narrower operand, so emit one instruction per
        // channel of the *wider* operand (e.g. `2.0 * color` produces four channels).
        int columns = b.fLeft->fType.columns();
        if (b.fRight->fType.columns() > columns) {
            columns = b.fRight->fType.columns();
        }
        for (int i = 0; i < columns; ++i) {
            out[i] = op(builder, left[i], right[i], "binary");
        }
        return true;
    };
    switch (b.fOperator) {
        case Token::EQ: {
            if (!this->getVectorLValue(builder, *b.fLeft, left)) {
                return false;
            }
            if (!this->compileVectorExpression(builder, *b.fRight, right)) {
                return false;
            }
            int columns = b.fRight->fType.columns();
            for (int i = 0; i < columns; ++i) {
                LLVMBuildStore(builder, right[i], left[i]);
            }
            return true;
        }
        case Token::PLUS:
            return vectorBinary(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
        case Token::MINUS:
            return vectorBinary(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
        case Token::STAR:
            return vectorBinary(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
        case Token::SLASH:
            return vectorBinary(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
        case Token::PERCENT:
            // Floating-point remainder must use frem; srem is only valid on integers.
            return vectorBinary(LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem);
        case Token::BITWISEAND:
            return vectorBinary(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
        case Token::BITWISEOR:
            return vectorBinary(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
        default:
            printf("unsupported operator: %s\n", b.description().c_str());
            return false;
    }
}

// Vector-mode lowering of a constructor expression.
//
// Scalar constructors convert between int, uint and float. The int <-> uint case reuses the
// operand unchanged; the other conversions are applied lane by lane across fVectorCount
// lanes.
//
// Vector constructors are built from their arguments' channels:
//   * a single scalar argument is replicated into every component;
//   * otherwise each argument's channels are appended in order, so half4(v) and
//     half4(v.rgb, 1.0) are handled as well as all-scalar argument lists.
// Component type conversions inside vector constructors are not vectorized.
//
// Returns false for anything unsupported, so the caller can fall back to the per-pixel loop.
bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
                                   LLVMValueRef out[CHANNELS]) {
    switch (c.fType.kind()) {
        case Type::kScalar_Kind: {
            SkASSERT(c.fArguments.size() == 1);
            TypeKind from = this->typeKind(c.fArguments[0]->fType);
            TypeKind to = this->typeKind(c.fType);
            LLVMValueRef base[CHANNELS];
            if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
                return false;
            }
            // int <-> uint: reuse the operand as-is. The result must still be written to out[0].
            if ((kInt_TypeKind == to && kUInt_TypeKind == from) ||
                (kUInt_TypeKind == to && kInt_TypeKind == from)) {
                out[0] = base[0];
                return true;
            }
            // Signature shared by the LLVM C API's cast builders (LLVMBuildSIToFP, ...).
            using CastBuilder = LLVMValueRef (*)(LLVMBuilderRef, LLVMValueRef, LLVMTypeRef,
                                                 const char*);
            CastBuilder cast = nullptr;
            if (kFloat_TypeKind == to && kInt_TypeKind == from) {
                cast = LLVMBuildSIToFP;
            } else if (kFloat_TypeKind == to && kUInt_TypeKind == from) {
                cast = LLVMBuildUIToFP;
            } else if (kInt_TypeKind == to && kFloat_TypeKind == from) {
                cast = LLVMBuildFPToSI;
            } else if (kUInt_TypeKind == to && kFloat_TypeKind == from) {
                cast = LLVMBuildFPToUI;
            }
            if (!cast) {
                return false;
            }
            LLVMTypeRef elementType = this->getType(c.fType);
            out[0] = LLVMGetUndef(LLVMVectorType(elementType, fVectorCount));
            for (int i = 0; i < fVectorCount; ++i) {
                LLVMValueRef index = LLVMConstInt(fInt32Type, i, false);
                LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index,
                                                               "construct extract");
                out[0] = LLVMBuildInsertElement(builder, out[0],
                                                cast(builder, baseVal, elementType, "cast"),
                                                index, "construct insert");
            }
            return true;
        }
        case Type::kVector_Kind: {
            const int columns = c.fType.columns();
            const TypeKind to = this->typeKind(c.fType);
            int filled = 0;
            for (const auto& arg : c.fArguments) {
                if (this->typeKind(arg->fType) != to) {
                    return false;
                }
                LLVMValueRef base[CHANNELS];
                if (!this->compileVectorExpression(builder, *arg, base)) {
                    return false;
                }
                const int argColumns = arg->fType.columns();
                if (filled + argColumns > columns) {
                    return false;
                }
                for (int i = 0; i < argColumns; ++i) {
                    out[filled++] = base[i];
                }
            }
            if (filled == 1 && c.fArguments.size() == 1) {
                // A single scalar argument is replicated into every component.
                for (int i = 1; i < columns; ++i) {
                    out[i] = out[0];
                }
                return true;
            }
            return filled == columns;
        }
        default:
            return false;
    }
}

// Vector-mode lowering of a float literal: a constant vector with the literal in every one of
// the fVectorCount lanes, returned in out[0].
// Returns false (fall back to the per-pixel loop) if fVectorCount exceeds the fixed-size lane
// buffer, rather than writing past the end of it.
bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
                                    const FloatLiteral& f,
                                    LLVMValueRef out[CHANNELS]) {
    if (fVectorCount > MAX_VECTOR_COUNT) {
        return false;
    }
    LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
    LLVMValueRef values[MAX_VECTOR_COUNT];
    for (int i = 0; i < fVectorCount; ++i) {
        values[i] = value;
    }
    out[0] = LLVMConstVector(values, fVectorCount);
    return true;
}


// Vector-mode lowering of a swizzle. The base expression is compiled into all of its
// channels, and then the selected components are copied to `out` in swizzle order.
// Returns false if the base cannot be vectorized.
bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
                               LLVMValueRef out[CHANNELS]) {
    LLVMValueRef baseChannels[CHANNELS];
    if (this->compileVectorExpression(builder, *s.fBase, baseChannels)) {
        size_t dst = 0;
        for (int component : s.fComponents) {
            out[dst++] = baseChannels[component];
        }
        return true;
    }
    return false;
}

// Vector-mode read of a variable. Only the stage's color parameter is supported: each of
// its channel allocas is loaded into `out`. Any other variable returns false.
bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
                                         LLVMValueRef out[CHANNELS]) {
    if (&v.fVariable != fColorParam) {
        return false;
    }
    for (int channel = 0; channel < CHANNELS; ++channel) {
        out[channel] = LLVMBuildLoad(builder, fChannels[channel], "variable reference");
    }
    return true;
}

// Vector-mode expression dispatcher. The result is written as one LLVM vector per component
// into `out`. Only the expression kinds below have vector lowerings; any other kind returns
// false, so the stage can fall back to the per-pixel loop.
bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
                                  LLVMValueRef out[CHANNELS]) {
    const auto kind = expr.fKind;
    if (kind == Expression::kBinary_Kind) {
        return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
    }
    if (kind == Expression::kConstructor_Kind) {
        return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
    }
    if (kind == Expression::kFloatLiteral_Kind) {
        return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
    }
    if (kind == Expression::kSwizzle_Kind) {
        return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
    }
    if (kind == Expression::kVariableReference_Kind) {
        const VariableReference& ref = (const VariableReference&) expr;
        return this->compileVectorVariableReference(builder, ref, out);
    }
    return false;
}

// Vector-mode statement codegen. Only blocks and expression statements are supported.
// Returns false on anything else (or on an unsupported sub-expression), so the caller can
// fall back to the per-pixel loop.
bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
    switch (stmt.fKind) {
        case Statement::kBlock_Kind:
            for (const auto& s : ((const Block&) stmt).fStatements) {
                if (!this->compileVectorStatement(builder, *s)) {
                    return false;
                }
            }
            return true;
        case Statement::kExpression_Kind: {
            // The statement's value is discarded. compileVectorExpression may still write one
            // entry per component, so the scratch buffer must hold CHANNELS entries, not one.
            LLVMValueRef result[CHANNELS];
            return this->compileVectorExpression(builder,
                                                 *((const ExpressionStatement&) stmt).fExpression,
                                                 result);
        }
        default:
            return false;
    }
}

bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
    LLVMValueRef oldFunction = fCurrentFunction;
    fCurrentFunction = newFunc;
    std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
    LLVMGetParams(fCurrentFunction, params.get());
    LLVMValueRef programParam = params.get()[1];
    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
    LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    this->setBlock(builder, fAllocaBlock);
    fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
    LLVMBuildStore(builder, params.get()[4], fChannels[0]);
    fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
    LLVMBuildStore(builder, params.get()[5], fChannels[1]);
    fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
    LLVMBuildStore(builder, params.get()[6], fChannels[2]);
    fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
    LLVMBuildStore(builder, params.get()[7], fChannels[3]);
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    this->setBlock(builder, start);
    bool success = this->compileVectorStatement(builder, *f.fBody);
    if (success) {
        // increment program pointer, call next
        LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
        LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
        LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
                                                "cast next->func");
        LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
                                                 LLVMBuildAdd(builder,
                                                              LLVMBuildPtrToInt(builder,
                                                                                programParam,
                                                                                fInt64Type,
                                                                                "cast 1"),
                                                              LLVMConstInt(fInt64Type, PTR_SIZE,
                                                                           false),
                                                              "add"),
                                                LLVMPointerType(fInt8PtrType, 0), "cast 2");
        LLVMValueRef args[STAGE_PARAM_COUNT] = {
            params.get()[0],
            nextInc,
            params.get()[2],
            params.get()[3],
            LLVMBuildLoad(builder, fChannels[0], "rVec"),
            LLVMBuildLoad(builder, fChannels[1], "gVec"),
            LLVMBuildLoad(builder, fChannels[2], "bVec"),
            LLVMBuildLoad(builder, fChannels[3], "aVec"),
            params.get()[8],
            params.get()[9],
            params.get()[10],
            params.get()[11]
        };
        LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
        LLVMBuildRetVoid(builder);
        // finish
        LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
        LLVMBuildBr(builder, start);
        LLVMDisposeBuilder(builder);
        if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
            ABORT("verify failed\n");
        }
    } else {
        LLVMDeleteBasicBlock(fAllocaBlock);
        LLVMDeleteBasicBlock(start);
    }

    fAllocaBlock = oldAllocaBlock;
    fCurrentBlock = oldCurrentBlock;
    fCurrentFunction = oldFunction;
    return success;
}

// Declares and emits `<name>$stage`, a raster pipeline stage wrapper around `f`.
// The stage takes STAGE_PARAM_COUNT parameters:
//   0: size_t (passed through to the next stage),
//   1: the program pointer (int8**),
//   2, 3: size_t x / y coordinates,
//   4-7: r, g, b, a float vectors,
//   8-11: float vectors (passed through to the next stage).
// Vectorized codegen is attempted first; if it fails, the per-pixel loop is used instead.
LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
    LLVMTypeRef returnType = fVoidType;
    LLVMTypeRef parameterTypes[STAGE_PARAM_COUNT] = {
            fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
            fSizeTType, fFloat32VectorType, fFloat32VectorType,
            fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
            fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
    LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, STAGE_PARAM_COUNT,
                                                 false);
    LLVMValueRef result = LLVMAddFunction(fModule,
                                          (String(f.fDeclaration.fName) + "$stage").c_str(),
                                          stageFuncType);
    // The third SkSL parameter (inout half4) is the color the vector codegen tracks.
    fColorParam = f.fDeclaration.fParameters[2];
    if (!this->compileStageFunctionVector(f, result)) {
        // vectorization failed, fall back to looping over the pixels
        this->compileStageFunctionLoop(f, result);
    }
    return result;
}

// True if `f` has the shape this JIT turns into a pipeline stage:
//     void name(int x, int y, inout half4 color)
// where x and y carry no modifiers and color is exactly `inout`.
bool JIT::hasStageSignature(const FunctionDeclaration& f) {
    const auto& context = *fProgram->fContext;
    if (!(f.fReturnType == *context.fVoid_Type) || f.fParameters.size() != 3) {
        return false;
    }
    const auto& x = *f.fParameters[0];
    const auto& y = *f.fParameters[1];
    const auto& color = *f.fParameters[2];
    const bool xOk = x.fType == *context.fInt_Type && x.fModifiers.fFlags == 0;
    if (!xOk) {
        return false;
    }
    const bool yOk = y.fType == *context.fInt_Type && y.fModifiers.fFlags == 0;
    if (!yOk) {
        return false;
    }
    return color.fType == *context.fHalf4_Type &&
           color.fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
}

// Emits the LLVM function for an SkSL function definition and records it in fFunctions.
// Parameters marked `out` (which includes `inout`) are passed by pointer. An "alloca" entry
// block is created ahead of the body's "start" block and branches to it once the body is
// emitted. If `f` also matches the stage signature, a separate `<name>$stage` wrapper is
// emitted first.
LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
    const FunctionDeclaration& decl = f.fDeclaration;
    if (this->hasStageSignature(decl)) {
        // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
        // was to produce an SkJumper stage just because the signature matched or that the function
        // is not otherwise called. May need a better way to handle this.
        this->compileStageFunction(f);
    }
    LLVMTypeRef returnType = this->getType(decl.fReturnType);
    std::vector<LLVMTypeRef> parameterTypes;
    for (const auto& param : decl.fParameters) {
        LLVMTypeRef type = this->getType(param->fType);
        const bool byPointer = (param->fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
        parameterTypes.push_back(byPointer ? LLVMPointerType(type, 0) : type);
    }
    LLVMTypeRef functionType = LLVMFunctionType(returnType, parameterTypes.data(),
                                                parameterTypes.size(), false);
    fCurrentFunction = LLVMAddFunction(fModule, String(decl.fName).c_str(), functionType);
    fFunctions[&decl] = fCurrentFunction;

    // Bind each SkSL parameter to its LLVM argument value.
    std::vector<LLVMValueRef> params(parameterTypes.size());
    LLVMGetParams(fCurrentFunction, params.data());
    for (size_t i = 0; i < decl.fParameters.size(); ++i) {
        fVariables[decl.fParameters[i]] = params[i];
    }

    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    fCurrentBlock = start;
    LLVMPositionBuilderAtEnd(builder, start);
    this->compileStatement(builder, *f.fBody);
    if (!ends_with_branch(*f.fBody)) {
        // Falling off the end: a plain return for void functions, otherwise marked unreachable.
        const bool returnsVoid = decl.fReturnType == *fProgram->fContext->fVoid_Type;
        if (returnsVoid) {
            LLVMBuildRetVoid(builder);
        } else {
            LLVMBuildUnreachable(builder);
        }
    }
    // The alloca block is the entry block; terminate it with a branch into the body.
    LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
    LLVMBuildBr(builder, start);
    LLVMDisposeBuilder(builder);
    if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
        ABORT("verify failed\n");
    }
    return fCurrentFunction;
}

// Creates a fresh LLVM module for fProgram.
// It declares the builtins, the boolean vector reduction intrinsics and the C helper functions
// defined at the top of this file, then compiles every function in the program into the module.
void JIT::createModule() {
    fPromotedParameters.clear();
    fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
    this->loadBuiltinFunctions();
    // Reduction intrinsics over bool2 / bool3 / bool4 (and = all-of, or = any-of).
    LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
    fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
                                    LLVMFunctionType(fInt1Type, fold2Params, 1, false));
    fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
                                   LLVMFunctionType(fInt1Type, fold2Params, 1, false));
    LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
    fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
                                    LLVMFunctionType(fInt1Type, fold3Params, 1, false));
    fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
                                   LLVMFunctionType(fInt1Type, fold3Params, 1, false));
    LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
    fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
                                    LLVMFunctionType(fInt1Type, fold4Params, 1, false));
    fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
                                   LLVMFunctionType(fInt1Type, fold4Params, 1, false));
    // LLVM doesn't do void*, have to declare it as int8*
    LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
    fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
                                                                                    appendParams,
                                                                                    3,
                                                                                    false));
    LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
    fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
                                          LLVMFunctionType(fVoidType, appendCallbackParams, 2,
                                                           false));

    // sksl_debug_print takes a single float.
    LLVMTypeRef debugParams[1] = { fFloat32Type };
    fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
                                                                               debugParams,
                                                                               1,
                                                                               false));

    for (const auto& e : *fProgram) {
        if (e.fKind == ProgramElement::kFunction_Kind) {
            this->compileFunction((FunctionDefinition&) e);
        }
    }
}

std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
    fCompiler.optimize(*program);
    fProgram = std::move(program);
    this->createModule();
    this->optimize();
    return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
}

void JIT::optimize() {
    LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
    LLVMPassManagerBuilderSetOptLevel(pmb, 3);
    LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
    LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
    LLVMPassManagerRef modulePM = LLVMCreatePassManager();
    LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
    LLVMInitializeFunctionPassManager(functionPM);

    LLVMValueRef func = LLVMGetFirstFunction(fModule);
    for (;;) {
        if (!func) {
            break;
        }
        LLVMRunFunctionPassManager(functionPM, func);
        func = LLVMGetNextFunction(func);
    }
    LLVMRunPassManager(modulePM, fModule);
    LLVMDisposePassManager(functionPM);
    LLVMDisposePassManager(modulePM);
    LLVMPassManagerBuilderDispose(pmb);

    std::string error_string;
    if (LLVMLoadLibraryPermanently(nullptr)) {
        ABORT("LLVMLoadLibraryPermanently failed");
    }
    char* defaultTriple = LLVMGetDefaultTargetTriple();
    char* error;
    LLVMTargetRef target;
    if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
        ABORT("LLVMGetTargetFromTriple failed");
    }

    if (!LLVMTargetHasJIT(target)) {
        ABORT("!LLVMTargetHasJIT");
    }
    LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
                                                                 defaultTriple,
                                                                 fCPU,
                                                                 nullptr,
                                                                 LLVMCodeGenLevelDefault,
                                                                 LLVMRelocDefault,
                                                                 LLVMCodeModelJITDefault);
    LLVMDisposeMessage(defaultTriple);
    LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
    LLVMSetModuleDataLayout(fModule, dataLayout);
    LLVMDisposeTargetData(dataLayout);

    fJITStack = LLVMOrcCreateInstance(targetMachine);
    fSharedModule = LLVMOrcMakeSharedModule(fModule);
    LLVMOrcModuleHandle orcModule;
    LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
                                (LLVMOrcSymbolResolverFn) resolveSymbol, this);
    LLVMDisposeTargetMachine(targetMachine);
}

void* JIT::Module::getSymbol(const char* name) {
    LLVMOrcTargetAddress result;
    if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
        ABORT("GetSymbolAddress error");
    }
    if (!result) {
        ABORT("symbol not found");
    }
    return (void*) result;
}

void* JIT::Module::getJumperStage(const char* name) {
    return this->getSymbol((String(name) + "$stage").c_str());
}

} // namespace

#endif // SK_LLVM_AVAILABLE

#endif // SKSL_STANDALONE