summaryrefslogtreecommitdiff
path: root/chromium/third_party/dawn/src/dawn_native/d3d12/ComputePipelineD3D12.cpp
blob: d70846ea099db02684843e164c27a3ec3ca03770 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright 2017 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/d3d12/ComputePipelineD3D12.h"

#include "common/Assert.h"
#include "dawn_native/d3d12/DeviceD3D12.h"
#include "dawn_native/d3d12/PipelineLayoutD3D12.h"
#include "dawn_native/d3d12/PlatformFunctions.h"
#include "dawn_native/d3d12/ShaderModuleD3D12.h"

namespace dawn_native { namespace d3d12 {

    // Builds the D3D12 compute pipeline state described by |descriptor|.
    //
    // The HLSL source of the compute stage's shader module is obtained for this pipeline's
    // layout, compiled with the "cs_5_1" target through the device's d3dCompile entry point,
    // and the resulting bytecode is combined with the layout's root signature to create
    // mPipelineState.
    //
    // |device| is the backend device used for compilation and pipeline creation.
    // |descriptor| supplies the compute stage (shader module and entry point).
    //
    // Failures are reported on stdout and trip ASSERT(false). If shader compilation fails and
    // execution continues past the assert, the constructor returns early and mPipelineState
    // is left unset.
    ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
        : ComputePipelineBase(device, descriptor) {
        uint32_t compileFlags = 0;
#if defined(_DEBUG)
        // Enable better shader debugging with the graphics debugging tools.
        compileFlags |= D3DCOMPILE_DEBUG | D3DCOMPILE_SKIP_OPTIMIZATION;
#endif
        // SPIRV-Cross does matrix multiplication expecting row major matrices
        compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;

        const ShaderModule* shaderModule = ToBackend(descriptor->computeStage->module);
        const std::string& hlslSource = shaderModule->GetHLSLSource(ToBackend(GetLayout()));

        ComPtr<ID3DBlob> compiledShader;
        ComPtr<ID3DBlob> errors;

        const PlatformFunctions* functions = device->GetFunctions();
        if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
                                         nullptr, descriptor->computeStage->entryPoint, "cs_5_1",
                                         compileFlags, 0, &compiledShader, &errors))) {
            // The compiler is not guaranteed to hand back an error blob on every failure, so
            // only print the message when one was actually produced.
            if (errors.Get() != nullptr) {
                printf("%s\n", static_cast<const char*>(errors->GetBufferPointer()));
            }
            ASSERT(false);
            // NOTE(review): if ASSERT does not halt execution in this build configuration,
            // compiledShader is null here; bail out instead of dereferencing it below.
            return;
        }

        D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
        d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature().Get();
        d3dDesc.CS.pShaderBytecode = compiledShader->GetBufferPointer();
        d3dDesc.CS.BytecodeLength = compiledShader->GetBufferSize();

        // Do not silently ignore a failed pipeline creation: mPipelineState would stay unset
        // with no indication of why.
        const HRESULT hr = device->GetD3D12Device()->CreateComputePipelineState(
            &d3dDesc, IID_PPV_ARGS(&mPipelineState));
        if (FAILED(hr)) {
            printf("CreateComputePipelineState failed (HRESULT 0x%08lX)\n",
                   static_cast<unsigned long>(hr));
            ASSERT(false);
        }
    }

    // On destruction, passes the pipeline state object to the backend device through
    // ReferenceUntilUnused.
    // NOTE(review): presumably this keeps the D3D12 object alive until pending GPU work that
    // uses it has completed -- confirm against Device::ReferenceUntilUnused.
    ComputePipeline::~ComputePipeline() {
        auto* backendDevice = ToBackend(GetDevice());
        backendDevice->ReferenceUntilUnused(mPipelineState);
    }

    // Returns, by value, a copy of the ComPtr to the pipeline state object held by this
    // pipeline. The pointer is null if the constructor never stored a pipeline state.
    ComPtr<ID3D12PipelineState> ComputePipeline::GetPipelineState() {
        ComPtr<ID3D12PipelineState> pipelineState = mPipelineState;
        return pipelineState;
    }

}}  // namespace dawn_native::d3d12