// CREATOR: https://github.com/vixorien/SimpleShader // LICENSE: MIT #pragma once #pragma comment(lib, "dxguid.lib") #pragma comment(lib, "d3dcompiler.lib") #include #include #include #include #include #include #include // -------------------------------------------------------- // Used by simple shaders to store information about // specific variables in constant buffers // -------------------------------------------------------- struct SimpleShaderVariable { unsigned int ByteOffset; unsigned int Size; unsigned int ConstantBufferIndex; }; // -------------------------------------------------------- // Contains information about a specific // constant buffer in a shader, as well as // the local data buffer for it // -------------------------------------------------------- struct SimpleConstantBuffer { std::string Name; D3D_CBUFFER_TYPE Type = D3D_CBUFFER_TYPE::D3D11_CT_CBUFFER; unsigned int Size = 0; unsigned int BindIndex = 0; Microsoft::WRL::ComPtr ConstantBuffer = 0; unsigned char* LocalDataBuffer = 0; std::vector Variables; }; // -------------------------------------------------------- // Contains info about a single SRV in a shader // -------------------------------------------------------- struct SimpleSRV { unsigned int Index; // The raw index of the SRV unsigned int BindIndex; // The register of the SRV }; // -------------------------------------------------------- // Contains info about a single Sampler in a shader // -------------------------------------------------------- struct SimpleSampler { unsigned int Index; // The raw index of the Sampler unsigned int BindIndex; // The register of the Sampler }; // -------------------------------------------------------- // Base abstract class for simplifying shader handling // -------------------------------------------------------- class ISimpleShader { public: ISimpleShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context); virtual ~ISimpleShader(); // Simple helpers bool IsShaderValid() { return shaderValid; } // Activating the shader and copying data void SetShader(); void CopyAllBufferData(); void CopyBufferData(unsigned int index); void CopyBufferData(std::string bufferName); // Sets arbitrary shader data bool SetData(std::string name, const void* data, unsigned int size); bool SetInt(std::string name, int data); bool SetFloat(std::string name, float data); bool SetFloat2(std::string name, const float data[2]); bool SetFloat2(std::string name, const DirectX::XMFLOAT2 data); bool SetFloat3(std::string name, const float data[3]); bool SetFloat3(std::string name, const DirectX::XMFLOAT3 data); bool SetFloat4(std::string name, const float data[4]); bool SetFloat4(std::string name, const DirectX::XMFLOAT4 data); bool SetMatrix4x4(std::string name, const float data[16]); bool SetMatrix4x4(std::string name, const DirectX::XMFLOAT4X4 data); // Setting shader resources virtual bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) = 0; virtual bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) = 0; // Simple resource checking bool HasVariable(std::string name); bool HasShaderResourceView(std::string name); bool HasSamplerState(std::string name); // Getting data about variables and resources const SimpleShaderVariable* GetVariableInfo(std::string name); const SimpleSRV* GetShaderResourceViewInfo(std::string name); const SimpleSRV* GetShaderResourceViewInfo(unsigned int index); size_t GetShaderResourceViewCount() { return textureTable.size(); } const SimpleSampler* GetSamplerInfo(std::string name); const SimpleSampler* GetSamplerInfo(unsigned int index); size_t GetSamplerCount() { return samplerTable.size(); } // Get data about constant buffers unsigned int GetBufferCount(); unsigned int GetBufferSize(unsigned int index); const SimpleConstantBuffer* GetBufferInfo(std::string name); const SimpleConstantBuffer* GetBufferInfo(unsigned int index); // Misc getters Microsoft::WRL::ComPtr GetShaderBlob() { return shaderBlob; } // Error reporting static bool ReportErrors; static bool ReportWarnings; protected: bool shaderValid; Microsoft::WRL::ComPtr shaderBlob; Microsoft::WRL::ComPtr device; Microsoft::WRL::ComPtr deviceContext; // Resource counts unsigned int constantBufferCount; // Maps for variables and buffers SimpleConstantBuffer* constantBuffers; // For index-based lookup std::vector shaderResourceViews; std::vector samplerStates; std::unordered_map cbTable; std::unordered_map varTable; std::unordered_map textureTable; std::unordered_map samplerTable; // Initialization method bool LoadShaderFile(LPCWSTR shaderFile); // Pure virtual functions for dealing with shader types virtual bool CreateShader(Microsoft::WRL::ComPtr shaderBlob) = 0; virtual void SetShaderAndCBs() = 0; virtual void CleanUp(); // Helpers for finding data by name SimpleShaderVariable* FindVariable(std::string name, int size); SimpleConstantBuffer* FindConstantBuffer(std::string name); // Error logging void Log(std::string message, WORD color); void LogW(std::wstring message, WORD color); void Log(std::string message); void LogW(std::wstring message); void LogError(std::string message); void LogErrorW(std::wstring message); void LogWarning(std::string message); void LogWarningW(std::wstring message); }; // -------------------------------------------------------- // Derived class for VERTEX shaders /////////////////////// // -------------------------------------------------------- class SimpleVertexShader : public ISimpleShader { public: SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, Microsoft::WRL::ComPtr inputLayout, bool perInstanceCompatible); ~SimpleVertexShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } Microsoft::WRL::ComPtr GetInputLayout() { return inputLayout; } bool GetPerInstanceCompatible() { return perInstanceCompatible; } bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); protected: bool perInstanceCompatible; Microsoft::WRL::ComPtr inputLayout; Microsoft::WRL::ComPtr shader; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); }; // -------------------------------------------------------- // Derived class for PIXEL shaders //////////////////////// // -------------------------------------------------------- class SimplePixelShader : public ISimpleShader { public: SimplePixelShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); ~SimplePixelShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); protected: Microsoft::WRL::ComPtr shader; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); }; // -------------------------------------------------------- // Derived class for DOMAIN shaders /////////////////////// // -------------------------------------------------------- class SimpleDomainShader : public ISimpleShader { public: SimpleDomainShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); ~SimpleDomainShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); protected: Microsoft::WRL::ComPtr shader; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); }; // -------------------------------------------------------- // Derived class for HULL shaders ///////////////////////// // -------------------------------------------------------- class SimpleHullShader : public ISimpleShader { public: SimpleHullShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); ~SimpleHullShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); protected: Microsoft::WRL::ComPtr shader; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); }; // -------------------------------------------------------- // Derived class for GEOMETRY shaders ///////////////////// // -------------------------------------------------------- class SimpleGeometryShader : public ISimpleShader { public: SimpleGeometryShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, bool useStreamOut = 0, bool allowStreamOutRasterization = 0); ~SimpleGeometryShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); bool CreateCompatibleStreamOutBuffer(Microsoft::WRL::ComPtr buffer, int vertexCount); static void UnbindStreamOutStage(Microsoft::WRL::ComPtr deviceContext); protected: // Shader itself Microsoft::WRL::ComPtr shader; // Stream out related bool useStreamOut; bool allowStreamOutRasterization; unsigned int streamOutVertexSize; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); bool CreateShaderWithStreamOut(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); // Helpers unsigned int CalcComponentCount(unsigned int mask); }; // -------------------------------------------------------- // Derived class for COMPUTE shaders ////////////////////// // -------------------------------------------------------- class SimpleComputeShader : public ISimpleShader { public: SimpleComputeShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); ~SimpleComputeShader(); Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } void DispatchByGroups(unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ); void DispatchByThreads(unsigned int threadsX, unsigned int threadsY, unsigned int threadsZ); bool HasUnorderedAccessView(std::string name); bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); bool SetUnorderedAccessView(std::string name, Microsoft::WRL::ComPtr uav, unsigned int appendConsumeOffset = -1); int GetUnorderedAccessViewIndex(std::string name); protected: Microsoft::WRL::ComPtr shader; std::unordered_map uavTable; unsigned int threadsX; unsigned int threadsY; unsigned int threadsZ; unsigned int threadsTotal; bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); void SetShaderAndCBs(); void CleanUp(); };