diff --git a/DX11Starter.vcxproj b/DX11Starter.vcxproj index d2480fb..f083e8d 100644 --- a/DX11Starter.vcxproj +++ b/DX11Starter.vcxproj @@ -130,6 +130,7 @@ + @@ -140,6 +141,7 @@ + diff --git a/DX11Starter.vcxproj.filters b/DX11Starter.vcxproj.filters index 1bc47d3..bb2d3b0 100644 --- a/DX11Starter.vcxproj.filters +++ b/DX11Starter.vcxproj.filters @@ -42,6 +42,9 @@ Source Files + + Source Files + @@ -71,6 +74,9 @@ Header Files + + Header Files + diff --git a/Game.cpp b/Game.cpp index 8caa79c..7290e26 100644 --- a/Game.cpp +++ b/Game.cpp @@ -2,6 +2,7 @@ #include "Vertex.h" #include "Input.h" #include "BufferStructs.h" +#include "SimpleShader.h" // Needed for a helper function to read compiled shader files from the hard drive #pragma comment(lib, "d3dcompiler.lib") @@ -93,69 +94,10 @@ void Game::Init() // -------------------------------------------------------- void Game::LoadShaders() { - // Blob for reading raw data - // - This is a simplified way of handling raw data - ID3DBlob* shaderBlob; - - // Read our compiled vertex shader code into a blob - // - Essentially just "open the file and plop its contents here" - D3DReadFileToBlob( - GetFullPathTo_Wide(L"VertexShader.cso").c_str(), // Using a custom helper for file paths - &shaderBlob); - - // Create a vertex shader from the information we - // have read into the blob above - // - A blob can give a pointer to its contents, and knows its own size - device->CreateVertexShader( - shaderBlob->GetBufferPointer(), // Get a pointer to the blob's contents - shaderBlob->GetBufferSize(), // How big is that data? - 0, // No classes in this shader - vertexShader.GetAddressOf()); // The address of the ID3D11VertexShader* - - - // Create an input layout that describes the vertex format - // used by the vertex shader we're using - // - This is used by the pipeline to know how to interpret the raw data - // sitting inside a vertex buffer - // - Doing this NOW because it requires a vertex shader's byte code to verify against! - // - Luckily, we already have that loaded (the blob above) - D3D11_INPUT_ELEMENT_DESC inputElements[2] = {}; - - // Set up the first element - a position, which is 3 float values - inputElements[0].Format = DXGI_FORMAT_R32G32B32_FLOAT; // Most formats are described as color channels; really it just means "Three 32-bit floats" - inputElements[0].SemanticName = "POSITION"; // This is "POSITION" - needs to match the semantics in our vertex shader input! - inputElements[0].AlignedByteOffset = D3D11_APPEND_ALIGNED_ELEMENT; // How far into the vertex is this? Assume it's after the previous element - - // Set up the second element - a color, which is 4 more float values - inputElements[1].Format = DXGI_FORMAT_R32G32B32A32_FLOAT; // 4x 32-bit floats - inputElements[1].SemanticName = "COLOR"; // Match our vertex shader input! - inputElements[1].AlignedByteOffset = D3D11_APPEND_ALIGNED_ELEMENT; // After the previous element - - // Create the input layout, verifying our description against actual shader code - device->CreateInputLayout( - inputElements, // An array of descriptions - 2, // How many elements in that array - shaderBlob->GetBufferPointer(), // Pointer to the code of a shader that uses this layout - shaderBlob->GetBufferSize(), // Size of the shader code that uses this layout - inputLayout.GetAddressOf()); // Address of the resulting ID3D11InputLayout* - - - - // Read and create the pixel shader - // - Reusing the same blob here, since we're done with the vert shader code - D3DReadFileToBlob( - GetFullPathTo_Wide(L"PixelShader.cso").c_str(), // Using a custom helper for file paths - &shaderBlob); - - device->CreatePixelShader( - shaderBlob->GetBufferPointer(), - shaderBlob->GetBufferSize(), - 0, - pixelShader.GetAddressOf()); + vertexShader = std::make_shared(device, context, GetFullPathTo_Wide(L"VertexShader.cso").c_str()); + pixelShader = std::make_shared(device, context, GetFullPathTo_Wide(L"PixelShader.cso").c_str()); } - - // -------------------------------------------------------- // Creates the geometry we're going to draw - a single triangle for now // -------------------------------------------------------- @@ -301,13 +243,6 @@ void Game::Draw(float deltaTime, float totalTime) constantBufferVS.GetAddressOf() // Array of buffers (or address of one) ); - // Set the vertex and pixel shaders to use for the next Draw() command - // - These don't technically need to be set every frame - // - Once you start applying different shaders to different objects, - // you'll need to swap the current shaders before each draw - context->VSSetShader(vertexShader.Get(), 0, 0); - context->PSSetShader(pixelShader.Get(), 0, 0); - // Ensure the pipeline knows how to interpret the data (numbers) // from the vertex buffer. // - If all of your 3D models use the exact same vertex layout, diff --git a/Game.h b/Game.h index 51cfa8d..970d7e3 100644 --- a/Game.h +++ b/Game.h @@ -4,6 +4,7 @@ #include "Camera.h" #include "Mesh.h" #include "Entity.h" +#include "SimpleShader.h" #include #include // Used for ComPtr - a smart pointer for COM objects #include @@ -39,8 +40,8 @@ private: // - More info here: https://github.com/Microsoft/DirectXTK/wiki/ComPtr // Shaders and shader-related constructs - Microsoft::WRL::ComPtr pixelShader; - Microsoft::WRL::ComPtr vertexShader; + std::shared_ptr pixelShader; + std::shared_ptr vertexShader; Microsoft::WRL::ComPtr inputLayout; // Temporary A2 shapes diff --git a/SimpleShader.cpp b/SimpleShader.cpp new file mode 100644 index 0000000..592089e --- /dev/null +++ b/SimpleShader.cpp @@ -0,0 +1,1973 @@ +// CREATOR: https://github.com/vixorien/SimpleShader +// LICENSE: MIT + +#include "SimpleShader.h" + +// Default error reporting state +bool ISimpleShader::ReportErrors = false; +bool ISimpleShader::ReportWarnings = false; + +// To enable error reporting, use either or both +// of the following lines somewhere in your program, +// preferably before loading/using any shaders. +// +// ISimpleShader::ReportErrors = true; +// ISimpleShader::ReportWarnings = true; + + +/////////////////////////////////////////////////////////////////////////////// +// ------ BASE SIMPLE SHADER -------------------------------------------------- +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor accepts Direct3D device & context +// -------------------------------------------------------- +ISimpleShader::ISimpleShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context) +{ + // Save the device + this->device = device; + this->deviceContext = context; + + // Set up fields + this->constantBufferCount = 0; + this->constantBuffers = 0; + this->shaderValid = false; +} + +// -------------------------------------------------------- +// Destructor +// -------------------------------------------------------- +ISimpleShader::~ISimpleShader() +{ + // Derived class destructors will call this class's CleanUp method +} + +// -------------------------------------------------------- +// Cleans up the variable table and buffers - Some things will +// be handled by derived classes +// -------------------------------------------------------- +void ISimpleShader::CleanUp() +{ + // Handle constant buffers and local data buffers + for (unsigned int i = 0; i < constantBufferCount; i++) + { + delete[] constantBuffers[i].LocalDataBuffer; + } + + if (constantBuffers) + { + delete[] constantBuffers; + constantBufferCount = 0; + } + + for (unsigned int i = 0; i < shaderResourceViews.size(); i++) + delete shaderResourceViews[i]; + + for (unsigned int i = 0; i < samplerStates.size(); i++) + delete samplerStates[i]; + + // Clean up tables + varTable.clear(); + cbTable.clear(); + samplerTable.clear(); + textureTable.clear(); +} + +// -------------------------------------------------------- +// Loads the specified shader and builds the variable table +// using shader reflection. +// +// shaderFile - A "wide string" specifying the compiled shader to load +// +// Returns true if shader is loaded properly, false otherwise +// -------------------------------------------------------- +bool ISimpleShader::LoadShaderFile(LPCWSTR shaderFile) +{ + // Load the shader to a blob and ensure it worked + HRESULT hr = D3DReadFileToBlob(shaderFile, shaderBlob.GetAddressOf()); + if (hr != S_OK) + { + if (ReportErrors) + { + LogError("SimpleShader::LoadShaderFile() - Error loading file '"); + LogW(shaderFile); + LogError("'. Ensure this file exists and is spelled correctly.\n"); + } + + return false; + } + + // Create the shader - Calls an overloaded version of this abstract + // method in the appropriate child class + shaderValid = CreateShader(shaderBlob); + if (!shaderValid) + { + if (ReportErrors) + { + LogError("SimpleShader::LoadShaderFile() - Error creating shader from file '"); + LogW(shaderFile); + LogError("'. Ensure the type of shader (vertex, pixel, etc.) matches the SimpleShader type (SimpleVertexShader, SimplePixelShader, etc.) you're using.\n"); + } + + return false; + } + + // Set up shader reflection to get information about + // this shader and its variables, buffers, etc. + Microsoft::WRL::ComPtr refl; + D3DReflect( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + IID_ID3D11ShaderReflection, + (void**)refl.GetAddressOf()); + + // Get the description of the shader + D3D11_SHADER_DESC shaderDesc; + refl->GetDesc(&shaderDesc); + + // Create resource arrays + constantBufferCount = shaderDesc.ConstantBuffers; + constantBuffers = new SimpleConstantBuffer[constantBufferCount]; + + // Handle bound resources (like shaders and samplers) + unsigned int resourceCount = shaderDesc.BoundResources; + for (unsigned int r = 0; r < resourceCount; r++) + { + // Get this resource's description + D3D11_SHADER_INPUT_BIND_DESC resourceDesc; + refl->GetResourceBindingDesc(r, &resourceDesc); + + // Check the type + switch (resourceDesc.Type) + { + case D3D_SIT_STRUCTURED: // Treat structured buffers as texture resources + case D3D_SIT_TEXTURE: // A texture resource + { + // Create the SRV wrapper + SimpleSRV* srv = new SimpleSRV(); + srv->BindIndex = resourceDesc.BindPoint; // Shader bind point + srv->Index = (unsigned int)shaderResourceViews.size(); // Raw index + + textureTable.insert(std::pair(resourceDesc.Name, srv)); + shaderResourceViews.push_back(srv); + } + break; + + case D3D_SIT_SAMPLER: // A sampler resource + { + // Create the sampler wrapper + SimpleSampler* samp = new SimpleSampler(); + samp->BindIndex = resourceDesc.BindPoint; // Shader bind point + samp->Index = (unsigned int)samplerStates.size(); // Raw index + + samplerTable.insert(std::pair(resourceDesc.Name, samp)); + samplerStates.push_back(samp); + } + break; + } + } + + // Loop through all constant buffers + for (unsigned int b = 0; b < constantBufferCount; b++) + { + // Get this buffer + ID3D11ShaderReflectionConstantBuffer* cb = + refl->GetConstantBufferByIndex(b); + + // Get the description of this buffer + D3D11_SHADER_BUFFER_DESC bufferDesc; + cb->GetDesc(&bufferDesc); + + // Save the type, which we reference when setting these buffers + constantBuffers[b].Type = bufferDesc.Type; + + // Get the description of the resource binding, so + // we know exactly how it's bound in the shader + D3D11_SHADER_INPUT_BIND_DESC bindDesc; + refl->GetResourceBindingDescByName(bufferDesc.Name, &bindDesc); + + // Set up the buffer and put its pointer in the table + constantBuffers[b].BindIndex = bindDesc.BindPoint; + constantBuffers[b].Name = bufferDesc.Name; + cbTable.insert(std::pair(bufferDesc.Name, &constantBuffers[b])); + + // Create this constant buffer + D3D11_BUFFER_DESC newBuffDesc = {}; + newBuffDesc.Usage = D3D11_USAGE_DEFAULT; + newBuffDesc.ByteWidth = ((bufferDesc.Size + 15) / 16) * 16; // Quick and dirty 16-byte alignment using integer division + newBuffDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER; + newBuffDesc.CPUAccessFlags = 0; + newBuffDesc.MiscFlags = 0; + newBuffDesc.StructureByteStride = 0; + device->CreateBuffer(&newBuffDesc, 0, constantBuffers[b].ConstantBuffer.GetAddressOf()); + + // Set up the data buffer for this constant buffer + constantBuffers[b].Size = bufferDesc.Size; + constantBuffers[b].LocalDataBuffer = new unsigned char[bufferDesc.Size]; + ZeroMemory(constantBuffers[b].LocalDataBuffer, bufferDesc.Size); + + // Loop through all variables in this buffer + for (unsigned int v = 0; v < bufferDesc.Variables; v++) + { + // Get this variable + ID3D11ShaderReflectionVariable* var = + cb->GetVariableByIndex(v); + + // Get the description of the variable and its type + D3D11_SHADER_VARIABLE_DESC varDesc; + var->GetDesc(&varDesc); + + // Create the variable struct + SimpleShaderVariable varStruct = {}; + varStruct.ConstantBufferIndex = b; + varStruct.ByteOffset = varDesc.StartOffset; + varStruct.Size = varDesc.Size; + + // Get a string version + std::string varName(varDesc.Name); + + // Add this variable to the table and the constant buffer + varTable.insert(std::pair(varName, varStruct)); + constantBuffers[b].Variables.push_back(varStruct); + } + } + + // All set + return true; +} + +// -------------------------------------------------------- +// Helper for looking up a variable by name and also +// verifying that it is the requested size +// +// name - the name of the variable to look for +// size - the size of the variable (for verification), or -1 to bypass +// -------------------------------------------------------- +SimpleShaderVariable* ISimpleShader::FindVariable(std::string name, int size) +{ + // Look for the key + std::unordered_map::iterator result = + varTable.find(name); + + // Did we find the key? + if (result == varTable.end()) + return 0; + + // Grab the result from the iterator + SimpleShaderVariable* var = &(result->second); + + // Is the data size correct ? + if (size > 0 && var->Size != size) + return 0; + + // Success + return var; +} + +// -------------------------------------------------------- +// Helper for looking up a constant buffer by name +// -------------------------------------------------------- +SimpleConstantBuffer* ISimpleShader::FindConstantBuffer(std::string name) +{ + // Look for the key + std::unordered_map::iterator result = + cbTable.find(name); + + // Did we find the key? + if (result == cbTable.end()) + return 0; + + // Success + return result->second; +} + +// -------------------------------------------------------- +// Prints the specified message to the console with the +// given color and Visual Studio's output window +// -------------------------------------------------------- +void ISimpleShader::Log(std::string message, WORD color) +{ + // Swap console color + HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + SetConsoleTextAttribute(hConsole, color); + + printf_s(message.c_str()); + OutputDebugString(message.c_str()); + + // Swap back + SetConsoleTextAttribute(hConsole, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE); +} + +// -------------------------------------------------------- +// Prints the specified message, as a wide string, to the +// console with the given color and Visual Studio's output window +// -------------------------------------------------------- +void ISimpleShader::LogW(std::wstring message, WORD color) +{ + // Swap console color + HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + SetConsoleTextAttribute(hConsole, color); + + wprintf_s(message.c_str()); + OutputDebugStringW(message.c_str()); + + // Swap back + SetConsoleTextAttribute(hConsole, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE); +} + + +// Helpers for pritning errors and warnings in specific colors using regular and wide character strings +void ISimpleShader::Log(std::string message) { Log(message, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE | FOREGROUND_INTENSITY); } +void ISimpleShader::LogW(std::wstring message) { LogW(message, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE | FOREGROUND_INTENSITY); } +void ISimpleShader::LogError(std::string message) { Log(message, FOREGROUND_RED | FOREGROUND_INTENSITY); } +void ISimpleShader::LogErrorW(std::wstring message) { LogW(message, FOREGROUND_RED | FOREGROUND_INTENSITY); } +void ISimpleShader::LogWarning(std::string message) { Log(message, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_INTENSITY); } +void ISimpleShader::LogWarningW(std::wstring message) { LogW(message, FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_INTENSITY); } + + +// -------------------------------------------------------- +// Sets the shader and associated constant buffers in Direct3D +// -------------------------------------------------------- +void ISimpleShader::SetShader() +{ + // Ensure the shader is valid + if (!shaderValid) return; + + // Set the shader and any relevant constant buffers, which + // is an overloaded method in a subclass + SetShaderAndCBs(); +} + +// -------------------------------------------------------- +// Copies the relevant data to the all of this +// shader's constant buffers. To just copy one +// buffer, use CopyBufferData() +// -------------------------------------------------------- +void ISimpleShader::CopyAllBufferData() +{ + // Ensure the shader is valid + if (!shaderValid) return; + + // Loop through the constant buffers and copy all data + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Copy the entire local data buffer + deviceContext->UpdateSubresource( + constantBuffers[i].ConstantBuffer.Get(), 0, 0, + constantBuffers[i].LocalDataBuffer, 0, 0); + } +} + +// -------------------------------------------------------- +// Copies local data to the shader's specified constant buffer +// +// index - The index of the buffer to copy. +// Useful for updating more frequently-changing +// variables without having to re-copy all buffers. +// +// NOTE: The "index" of the buffer might NOT be the same +// as its register, especially if you have buffers +// bound to non-sequential registers! +// -------------------------------------------------------- +void ISimpleShader::CopyBufferData(unsigned int index) +{ + // Ensure the shader is valid + if (!shaderValid) return; + + // Validate the index + if (index >= this->constantBufferCount) + return; + + // Check for the buffer + SimpleConstantBuffer* cb = &this->constantBuffers[index]; + if (!cb) return; + + // Copy the data and get out + deviceContext->UpdateSubresource( + cb->ConstantBuffer.Get(), 0, 0, + cb->LocalDataBuffer, 0, 0); +} + +// -------------------------------------------------------- +// Copies local data to the shader's specified constant buffer +// +// bufferName - Specifies the name of the buffer to copy. +// Useful for updating more frequently-changing +// variables without having to re-copy all buffers. +// -------------------------------------------------------- +void ISimpleShader::CopyBufferData(std::string bufferName) +{ + // Ensure the shader is valid + if (!shaderValid) return; + + // Check for the buffer + SimpleConstantBuffer* cb = this->FindConstantBuffer(bufferName); + if (!cb) return; + + // Copy the data and get out + deviceContext->UpdateSubresource( + cb->ConstantBuffer.Get(), 0, 0, + cb->LocalDataBuffer, 0, 0); +} + + +// -------------------------------------------------------- +// Sets a variable by name with arbitrary data of the specified size +// +// name - The name of the shader variable +// data - The data to set in the buffer +// size - The size of the data (this must be less than or equal to the variable's size) +// +// Returns true if data is copied, false if variable doesn't exist +// -------------------------------------------------------- +bool ISimpleShader::SetData(std::string name, const void* data, unsigned int size) +{ + // Look for the variable and verify + SimpleShaderVariable* var = FindVariable(name, -1); + if (var == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleShader::SetData() - Shader variable '"); + Log(name); + LogWarning("' not found. Ensure the name is spelled correctly and that it exists in a constant buffer in the shader.\n"); + } + return false; + } + + // Ensure we're not trying to copy more data than the variable can hold + // Note: We can copy less data, in the case of a subset of an array + if (size > var->Size) + { + if (ReportWarnings) + { + LogWarning("SimpleShader::SetData() - Shader variable '"); + Log(name); + LogWarning("' is smaller than the size of the data being set. Ensure the variable is large enough for the specified data.\n"); + } + return false; + } + + // Set the data in the local data buffer + memcpy( + constantBuffers[var->ConstantBufferIndex].LocalDataBuffer + var->ByteOffset, + data, + size); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets INTEGER data +// -------------------------------------------------------- +bool ISimpleShader::SetInt(std::string name, int data) +{ + return this->SetData(name, (void*)(&data), sizeof(int)); +} + +// -------------------------------------------------------- +// Sets a FLOAT variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat(std::string name, float data) +{ + return this->SetData(name, (void*)(&data), sizeof(float)); +} + +// -------------------------------------------------------- +// Sets a FLOAT2 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat2(std::string name, const float data[2]) +{ + return this->SetData(name, (void*)data, sizeof(float) * 2); +} + +// -------------------------------------------------------- +// Sets a FLOAT2 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat2(std::string name, const DirectX::XMFLOAT2 data) +{ + return this->SetData(name, &data, sizeof(float) * 2); +} + +// -------------------------------------------------------- +// Sets a FLOAT3 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat3(std::string name, const float data[3]) +{ + return this->SetData(name, (void*)data, sizeof(float) * 3); +} + +// -------------------------------------------------------- +// Sets a FLOAT3 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat3(std::string name, const DirectX::XMFLOAT3 data) +{ + return this->SetData(name, &data, sizeof(float) * 3); +} + +// -------------------------------------------------------- +// Sets a FLOAT4 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat4(std::string name, const float data[4]) +{ + return this->SetData(name, (void*)data, sizeof(float) * 4); +} + +// -------------------------------------------------------- +// Sets a FLOAT4 variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetFloat4(std::string name, const DirectX::XMFLOAT4 data) +{ + return this->SetData(name, &data, sizeof(float) * 4); +} + +// -------------------------------------------------------- +// Sets a MATRIX (4x4) variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetMatrix4x4(std::string name, const float data[16]) +{ + return this->SetData(name, (void*)data, sizeof(float) * 16); +} + +// -------------------------------------------------------- +// Sets a MATRIX (4x4) variable by name in the local data buffer +// -------------------------------------------------------- +bool ISimpleShader::SetMatrix4x4(std::string name, const DirectX::XMFLOAT4X4 data) +{ + return this->SetData(name, &data, sizeof(float) * 16); +} + +// -------------------------------------------------------- +// Determines if the shader contains the specified +// variable within one of its constant buffers +// -------------------------------------------------------- +bool ISimpleShader::HasVariable(std::string name) +{ + return FindVariable(name, -1) != 0; +} + +// -------------------------------------------------------- +// Determines if the shader contains the specified SRV +// -------------------------------------------------------- +bool ISimpleShader::HasShaderResourceView(std::string name) +{ + return GetShaderResourceViewInfo(name) != 0; +} + +// -------------------------------------------------------- +// Determines if the shader contains the specified sampler +// -------------------------------------------------------- +bool ISimpleShader::HasSamplerState(std::string name) +{ + return GetSamplerInfo(name) != 0; +} + +// -------------------------------------------------------- +// Gets info about a shader variable, if it exists +// -------------------------------------------------------- +const SimpleShaderVariable* ISimpleShader::GetVariableInfo(std::string name) +{ + return FindVariable(name, -1); +} + +// -------------------------------------------------------- +// Gets info about an SRV in the shader (or null) +// +// name - the name of the SRV +// -------------------------------------------------------- +const SimpleSRV* ISimpleShader::GetShaderResourceViewInfo(std::string name) +{ + // Look for the key + std::unordered_map::iterator result = + textureTable.find(name); + + // Did we find the key? + if (result == textureTable.end()) + return 0; + + // Success + return result->second; +} + + +// -------------------------------------------------------- +// Gets info about an SRV in the shader (or null) +// +// index - the index of the SRV +// -------------------------------------------------------- +const SimpleSRV* ISimpleShader::GetShaderResourceViewInfo(unsigned int index) +{ + // Valid index? + if (index >= shaderResourceViews.size()) return 0; + + // Grab the bind index + return shaderResourceViews[index]; +} + + +// -------------------------------------------------------- +// Gets info about a sampler in the shader (or null) +// +// name - the name of the sampler +// -------------------------------------------------------- +const SimpleSampler* ISimpleShader::GetSamplerInfo(std::string name) +{ + // Look for the key + std::unordered_map::iterator result = + samplerTable.find(name); + + // Did we find the key? + if (result == samplerTable.end()) + return 0; + + // Success + return result->second; +} + +// -------------------------------------------------------- +// Gets info about a sampler in the shader (or null) +// +// index - the index of the sampler +// -------------------------------------------------------- +const SimpleSampler* ISimpleShader::GetSamplerInfo(unsigned int index) +{ + // Valid index? + if (index >= samplerStates.size()) return 0; + + // Grab the bind index + return samplerStates[index]; +} + + +// -------------------------------------------------------- +// Gets the number of constant buffers in this shader +// -------------------------------------------------------- +unsigned int ISimpleShader::GetBufferCount() { return constantBufferCount; } + + + +// -------------------------------------------------------- +// Gets the size of a particular constant buffer, or -1 +// -------------------------------------------------------- +unsigned int ISimpleShader::GetBufferSize(unsigned int index) +{ + // Valid index? + if (index >= constantBufferCount) + return -1; + + // Grab the size + return constantBuffers[index].Size; +} + +// -------------------------------------------------------- +// Gets info about a particular constant buffer +// by name, if it exists +// -------------------------------------------------------- +const SimpleConstantBuffer* ISimpleShader::GetBufferInfo(std::string name) +{ + return FindConstantBuffer(name); +} + +// -------------------------------------------------------- +// Gets info about a particular constant buffer +// +// index - the index of the constant buffer +// -------------------------------------------------------- +const SimpleConstantBuffer* ISimpleShader::GetBufferInfo(unsigned int index) +{ + // Check for valid index + if (index >= constantBufferCount) return 0; + + // Return the specific buffer + return &constantBuffers[index]; +} + + + + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE VERTEX SHADER ------------------------------------------------ +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor just calls the base +// -------------------------------------------------------- +SimpleVertexShader::SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile) + : ISimpleShader(device, context) +{ + // Ensure we set to zero to successfully trigger + // the Input Layout creation during LoadShaderFile() + this->perInstanceCompatible = false; + + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Constructor overload which takes a custom input layout +// +// Passing in a valid input layout will stop LoadShaderFile() +// from creating an input layout from shader reflection +// -------------------------------------------------------- +SimpleVertexShader::SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, Microsoft::WRL::ComPtr inputLayout, bool perInstanceCompatible) + : ISimpleShader(device, context) +{ + // Save the custom input layout + this->inputLayout = inputLayout; + + // Unable to determine from an input layout, require user to tell us + this->perInstanceCompatible = perInstanceCompatible; + + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimpleVertexShader::~SimpleVertexShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimpleVertexShader::CleanUp() +{ + ISimpleShader::CleanUp(); +} + +// -------------------------------------------------------- +// Creates the Direct3D vertex shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleVertexShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Create the shader from the blob + HRESULT result = device->CreateVertexShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Did the creation work? + if (result != S_OK) + return false; + + // Do we already have an input layout? + // (This would come from one of the constructor overloads) + if (inputLayout) + return true; + + // Vertex shader was created successfully, so we now use the + // shader code to re-reflect and create an input layout that + // matches what the vertex shader expects. Code adapted from: + // https://takinginitiative.wordpress.com/2011/12/11/directx-1011-basic-shader-reflection-automatic-input-layout-creation/ + + // Reflect shader info + Microsoft::WRL::ComPtr refl; + D3DReflect( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + IID_ID3D11ShaderReflection, + (void**)refl.GetAddressOf()); + + // Get shader info + D3D11_SHADER_DESC shaderDesc; + refl->GetDesc(&shaderDesc); + + // Read input layout description from shader info + std::vector inputLayoutDesc; + for (unsigned int i = 0; i < shaderDesc.InputParameters; i++) + { + D3D11_SIGNATURE_PARAMETER_DESC paramDesc; + refl->GetInputParameterDesc(i, ¶mDesc); + + // Check the semantic name for "_PER_INSTANCE" + std::string perInstanceStr = "_PER_INSTANCE"; + std::string sem = paramDesc.SemanticName; + int lenDiff = (int)sem.size() - (int)perInstanceStr.size(); + bool isPerInstance = + lenDiff >= 0 && + sem.compare(lenDiff, perInstanceStr.size(), perInstanceStr) == 0; + + // Fill out input element desc + D3D11_INPUT_ELEMENT_DESC elementDesc = {}; + elementDesc.SemanticName = paramDesc.SemanticName; + elementDesc.SemanticIndex = paramDesc.SemanticIndex; + elementDesc.InputSlot = 0; + elementDesc.AlignedByteOffset = D3D11_APPEND_ALIGNED_ELEMENT; + elementDesc.InputSlotClass = D3D11_INPUT_PER_VERTEX_DATA; + elementDesc.InstanceDataStepRate = 0; + + // Replace anything affected by "per instance" data + if (isPerInstance) + { + elementDesc.InputSlot = 1; // Assume per instance data comes from another input slot! + elementDesc.InputSlotClass = D3D11_INPUT_PER_INSTANCE_DATA; + elementDesc.InstanceDataStepRate = 1; + + perInstanceCompatible = true; + } + + // Determine DXGI format + if (paramDesc.Mask == 1) + { + if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_UINT32) elementDesc.Format = DXGI_FORMAT_R32_UINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_SINT32) elementDesc.Format = DXGI_FORMAT_R32_SINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_FLOAT32) elementDesc.Format = DXGI_FORMAT_R32_FLOAT; + } + else if (paramDesc.Mask <= 3) + { + if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_UINT32) elementDesc.Format = DXGI_FORMAT_R32G32_UINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_SINT32) elementDesc.Format = DXGI_FORMAT_R32G32_SINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_FLOAT32) elementDesc.Format = DXGI_FORMAT_R32G32_FLOAT; + } + else if (paramDesc.Mask <= 7) + { + if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_UINT32) elementDesc.Format = DXGI_FORMAT_R32G32B32_UINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_SINT32) elementDesc.Format = DXGI_FORMAT_R32G32B32_SINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_FLOAT32) elementDesc.Format = DXGI_FORMAT_R32G32B32_FLOAT; + } + else if (paramDesc.Mask <= 15) + { + if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_UINT32) elementDesc.Format = DXGI_FORMAT_R32G32B32A32_UINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_SINT32) elementDesc.Format = DXGI_FORMAT_R32G32B32A32_SINT; + else if (paramDesc.ComponentType == D3D_REGISTER_COMPONENT_FLOAT32) elementDesc.Format = DXGI_FORMAT_R32G32B32A32_FLOAT; + } + + // Save element desc + inputLayoutDesc.push_back(elementDesc); + } + + // Try to create Input Layout + HRESULT hr = device->CreateInputLayout( + &inputLayoutDesc[0], + (unsigned int)inputLayoutDesc.size(), + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + inputLayout.GetAddressOf()); + + // All done, clean up + return true; +} + +// -------------------------------------------------------- +// Sets the vertex shader, input layout and constant buffers +// for future Direct3D drawing +// -------------------------------------------------------- +void SimpleVertexShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader and input layout + deviceContext->IASetInputLayout(inputLayout.Get()); + deviceContext->VSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->VSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Sets a shader resource view in the vertex shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleVertexShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleVertexShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->VSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the vertex shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleVertexShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleVertexShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->VSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE PIXEL SHADER ------------------------------------------------- +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor just calls the base +// -------------------------------------------------------- +SimplePixelShader::SimplePixelShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile) + : ISimpleShader(device, context) +{ + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimplePixelShader::~SimplePixelShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimplePixelShader::CleanUp() +{ + ISimpleShader::CleanUp(); +} + +// -------------------------------------------------------- +// Creates the Direct3D pixel shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimplePixelShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Create the shader from the blob + HRESULT result = device->CreatePixelShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Check the result + return (result == S_OK); +} + +// -------------------------------------------------------- +// Sets the pixel shader and constant buffers for +// future Direct3D drawing +// -------------------------------------------------------- +void SimplePixelShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader + deviceContext->PSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->PSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Sets a shader resource view in the pixel shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimplePixelShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimplePixelShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->PSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the pixel shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimplePixelShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimplePixelShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->PSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + + + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE DOMAIN SHADER ------------------------------------------------ +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor just calls the base +// -------------------------------------------------------- +SimpleDomainShader::SimpleDomainShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile) + : ISimpleShader(device, context) +{ + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimpleDomainShader::~SimpleDomainShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimpleDomainShader::CleanUp() +{ + ISimpleShader::CleanUp(); +} + +// -------------------------------------------------------- +// Creates the Direct3D domain shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleDomainShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Create the shader from the blob + HRESULT result = device->CreateDomainShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Check the result + return (result == S_OK); +} + +// -------------------------------------------------------- +// Sets the domain shader and constant buffers for +// future Direct3D drawing +// -------------------------------------------------------- +void SimpleDomainShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader + deviceContext->DSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->DSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Sets a shader resource view in the domain shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleDomainShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleDomainShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->DSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the domain shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleDomainShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleDomainShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->DSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE HULL SHADER -------------------------------------------------- +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor just calls the base +// -------------------------------------------------------- +SimpleHullShader::SimpleHullShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile) + : ISimpleShader(device, context) +{ + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimpleHullShader::~SimpleHullShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimpleHullShader::CleanUp() +{ + ISimpleShader::CleanUp(); +} + +// -------------------------------------------------------- +// Creates the Direct3D hull shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleHullShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Create the shader from the blob + HRESULT result = device->CreateHullShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Check the result + return (result == S_OK); +} + +// -------------------------------------------------------- +// Sets the hull shader and constant buffers for +// future Direct3D drawing +// -------------------------------------------------------- +void SimpleHullShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader + deviceContext->HSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers? + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->HSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Sets a shader resource view in the hull shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleHullShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleHullShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->HSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the hull shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleHullShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleHullShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->HSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + + + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE GEOMETRY SHADER ---------------------------------------------- +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor calls the base and sets up potential stream-out options +// -------------------------------------------------------- +SimpleGeometryShader::SimpleGeometryShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, bool useStreamOut, bool allowStreamOutRasterization) + : ISimpleShader(device, context) +{ + this->streamOutVertexSize = 0; + this->useStreamOut = useStreamOut; + this->allowStreamOutRasterization = allowStreamOutRasterization; + + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimpleGeometryShader::~SimpleGeometryShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimpleGeometryShader::CleanUp() +{ + ISimpleShader::CleanUp(); +} + +// -------------------------------------------------------- +// Creates the Direct3D Geometry shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleGeometryShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Using stream out? + if (useStreamOut) + return this->CreateShaderWithStreamOut(shaderBlob); + + // Create the shader from the blob + HRESULT result = device->CreateGeometryShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Check the result + return (result == S_OK); +} + +// -------------------------------------------------------- +// Creates the Direct3D Geometry shader and sets it up for +// stream output, if possible. +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleGeometryShader::CreateShaderWithStreamOut(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Reflect shader info + Microsoft::WRL::ComPtr refl; + D3DReflect( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + IID_ID3D11ShaderReflection, + (void**)refl.GetAddressOf()); + + // Get shader info + D3D11_SHADER_DESC shaderDesc; + refl->GetDesc(&shaderDesc); + + // Set up the output signature + streamOutVertexSize = 0; + std::vector soDecl; + for (unsigned int i = 0; i < shaderDesc.OutputParameters; i++) + { + // Get the info about this entry + D3D11_SIGNATURE_PARAMETER_DESC paramDesc; + refl->GetOutputParameterDesc(i, ¶mDesc); + + // Create the SO Declaration + D3D11_SO_DECLARATION_ENTRY entry = {}; + entry.SemanticIndex = paramDesc.SemanticIndex; + entry.SemanticName = paramDesc.SemanticName; + entry.Stream = paramDesc.Stream; + entry.StartComponent = 0; // Assume starting at 0 + entry.OutputSlot = 0; // Assume the first output slot + + // Check the mask to determine how many components are used + entry.ComponentCount = CalcComponentCount(paramDesc.Mask); + + // Increment the size + streamOutVertexSize += entry.ComponentCount * sizeof(float); + + // Add to the declaration + soDecl.push_back(entry); + } + + // Rasterization allowed? + unsigned int rast = allowStreamOutRasterization ? 0 : D3D11_SO_NO_RASTERIZED_STREAM; + + // Create the shader + HRESULT result = device->CreateGeometryShaderWithStreamOutput( + shaderBlob->GetBufferPointer(), // Shader blob pointer + shaderBlob->GetBufferSize(), // Shader blob size + &soDecl[0], // Stream out declaration + (unsigned int)soDecl.size(), // Number of declaration entries + NULL, // Buffer strides (not used - assume tightly packed?) + 0, // No buffer strides + rast, // Index of the stream to rasterize (if any) + NULL, // Not using class linkage + shader.GetAddressOf()); + + return (result == S_OK); +} + +// -------------------------------------------------------- +// Creates a vertex buffer that is compatible with the stream output +// delcaration that was used to create the shader. This buffer will +// not be cleaned up (Released) by the simple shader - you must clean +// it up yourself when you're done with it. Immediately returns +// false if the shader was not created with stream output, the shader +// isn't valid or the determined stream out vertex size is zero. +// +// buffer - Pointer to an ID3D11Buffer pointer to hold the buffer ref +// vertexCount - Amount of vertices the buffer should hold +// +// Returns true if buffer is created successfully AND stream output +// was used to create the shader. False otherwise. +// -------------------------------------------------------- +bool SimpleGeometryShader::CreateCompatibleStreamOutBuffer(Microsoft::WRL::ComPtr buffer, int vertexCount) +{ + // Was stream output actually used? + if (!this->useStreamOut || !shaderValid || streamOutVertexSize == 0) + { + if (ReportErrors) + { + LogError("SimpleGeometryShader::CreateCompatibleStreamOutBuffer() - Either the shader is not valid or this SimpleGeometryShader was not initialized for stream out usage.\n"); + } + + return false; + } + + // Set up the buffer description + D3D11_BUFFER_DESC desc = {}; + desc.BindFlags = D3D11_BIND_STREAM_OUTPUT | D3D11_BIND_VERTEX_BUFFER; + desc.ByteWidth = streamOutVertexSize * vertexCount; + desc.CPUAccessFlags = 0; + desc.MiscFlags = 0; + desc.StructureByteStride = 0; + desc.Usage = D3D11_USAGE_DEFAULT; + + // Attempt to create the buffer and return the result + HRESULT result = device->CreateBuffer(&desc, 0, buffer.GetAddressOf()); + return (result == S_OK); +} + +// -------------------------------------------------------- +// Helper method to unbind all stream out buffers from the SO stage +// -------------------------------------------------------- +void SimpleGeometryShader::UnbindStreamOutStage(Microsoft::WRL::ComPtr deviceContext) +{ + unsigned int offset = 0; + ID3D11Buffer* unset[4] = { 0, 0, 0, 0 }; // Max of 4 output targets according to Direct3D documentation + deviceContext->SOSetTargets(4, unset, &offset); +} + +// -------------------------------------------------------- +// Sets the geometry shader and constant buffers for +// future Direct3D drawing +// -------------------------------------------------------- +void SimpleGeometryShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader + deviceContext->GSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers? + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->GSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Sets a shader resource view in the Geometry shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleGeometryShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleGeometryShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->GSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the Geometry shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleGeometryShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleGeometryShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->GSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Calculates the number of components specified by a parameter description mask +// +// mask - The mask to check (only values 0 - 15 are considered) +// +// Returns an integer between 0 - 4 inclusive +// -------------------------------------------------------- +unsigned int SimpleGeometryShader::CalcComponentCount(unsigned int mask) +{ + unsigned int result = 0; + result += (unsigned int)((mask & 1) == 1); + result += (unsigned int)((mask & 2) == 2); + result += (unsigned int)((mask & 4) == 4); + result += (unsigned int)((mask & 8) == 8); + return result; +} + + + +/////////////////////////////////////////////////////////////////////////////// +// ------ SIMPLE COMPUTE SHADER ----------------------------------------------- +/////////////////////////////////////////////////////////////////////////////// + +// -------------------------------------------------------- +// Constructor just calls the base +// -------------------------------------------------------- +SimpleComputeShader::SimpleComputeShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile) + : ISimpleShader(device, context) +{ + this->threadsTotal = 0; + this->threadsX = 0; + this->threadsY = 0; + this->threadsZ = 0; + + // Load the actual compiled shader file + this->LoadShaderFile(shaderFile); +} + +// -------------------------------------------------------- +// Destructor - Clean up actual shader (base will be called automatically) +// -------------------------------------------------------- +SimpleComputeShader::~SimpleComputeShader() +{ + CleanUp(); +} + +// -------------------------------------------------------- +// Handles cleaning up shader and base class clean up +// -------------------------------------------------------- +void SimpleComputeShader::CleanUp() +{ + ISimpleShader::CleanUp(); + + uavTable.clear(); +} + +// -------------------------------------------------------- +// Creates the Direct3D Compute shader +// +// shaderBlob - The shader's compiled code +// +// Returns true if shader is created correctly, false otherwise +// -------------------------------------------------------- +bool SimpleComputeShader::CreateShader(Microsoft::WRL::ComPtr shaderBlob) +{ + // Clean up first, in the event this method is + // called more than once on the same object + this->CleanUp(); + + // Create the shader from the blob + HRESULT result = device->CreateComputeShader( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + 0, + shader.GetAddressOf()); + + // Was the shader created correctly? + if (result != S_OK) + return false; + + // Set up shader reflection to get information about UAV's + Microsoft::WRL::ComPtr refl; + D3DReflect( + shaderBlob->GetBufferPointer(), + shaderBlob->GetBufferSize(), + IID_ID3D11ShaderReflection, + (void**)refl.GetAddressOf()); + + // Get the description of the shader + D3D11_SHADER_DESC shaderDesc; + refl->GetDesc(&shaderDesc); + + // Grab the thread info + threadsTotal = refl->GetThreadGroupSize( + &threadsX, + &threadsY, + &threadsZ); + + // Loop and get all UAV resources + unsigned int resourceCount = shaderDesc.BoundResources; + for (unsigned int r = 0; r < resourceCount; r++) + { + // Get this resource's description + D3D11_SHADER_INPUT_BIND_DESC resourceDesc; + refl->GetResourceBindingDesc(r, &resourceDesc); + + // Check the type, looking for any kind of UAV + switch (resourceDesc.Type) + { + case D3D_SIT_UAV_APPEND_STRUCTURED: + case D3D_SIT_UAV_CONSUME_STRUCTURED: + case D3D_SIT_UAV_RWBYTEADDRESS: + case D3D_SIT_UAV_RWSTRUCTURED: + case D3D_SIT_UAV_RWSTRUCTURED_WITH_COUNTER: + case D3D_SIT_UAV_RWTYPED: + uavTable.insert(std::pair(resourceDesc.Name, resourceDesc.BindPoint)); + } + } + + // All set + return true; +} + +// -------------------------------------------------------- +// Sets the Compute shader and constant buffers for +// future Direct3D drawing +// -------------------------------------------------------- +void SimpleComputeShader::SetShaderAndCBs() +{ + // Is shader valid? + if (!shaderValid) return; + + // Set the shader + deviceContext->CSSetShader(shader.Get(), 0, 0); + + // Set the constant buffers? + for (unsigned int i = 0; i < constantBufferCount; i++) + { + // Skip "buffers" that aren't true constant buffers + if (constantBuffers[i].Type != D3D11_CT_CBUFFER) + continue; + + // This is a real constant buffer, so set it + deviceContext->CSSetConstantBuffers( + constantBuffers[i].BindIndex, + 1, + constantBuffers[i].ConstantBuffer.GetAddressOf()); + } +} + +// -------------------------------------------------------- +// Dispatches the compute shader with the specified amount +// of groups, using the number of threads per group +// specified in the shader file itself +// +// For example, calling this method with params (5,1,1) on +// a shader with (8,2,2) threads per group will launch a +// total of 160 threads: ((5 * 8) * (1 * 2) * (1 * 2)) +// +// This is identical to using the device context's +// Dispatch() method yourself. +// +// Note: This will dispatch the currently active shader, +// not necessarily THIS shader. Be sure to activate this +// shader with SetShader() before calling Dispatch +// +// groupsX - Numbers of groups in the X dimension +// groupsY - Numbers of groups in the Y dimension +// groupsZ - Numbers of groups in the Z dimension +// -------------------------------------------------------- +void SimpleComputeShader::DispatchByGroups(unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ) +{ + deviceContext->Dispatch(groupsX, groupsY, groupsZ); +} + +// -------------------------------------------------------- +// Dispatches the compute shader with AT LEAST the +// specified amount of threads, calculating the number of +// groups to dispatch using the number of threads per group +// specified in the shader file itself +// +// For example, calling this method with params (10,3,3) on +// a shader with (5,2,2) threads per group will launch +// 8 total groups and 160 total threads, calculated by: +// Groups: ceil(10/5) * ceil(3/2) * ceil(3/2) = 8 +// Threads: ((2 * 5) * (2 * 2) * (2 * 2)) = 160 +// +// Note: This will dispatch the currently active shader, +// not necessarily THIS shader. Be sure to activate this +// shader with SetShader() before calling Dispatch +// +// threadsX - Desired numbers of threads in the X dimension +// threadsY - Desired numbers of threads in the Y dimension +// threadsZ - Desired numbers of threads in the Z dimension +// -------------------------------------------------------- +void SimpleComputeShader::DispatchByThreads(unsigned int threadsX, unsigned int threadsY, unsigned int threadsZ) +{ + deviceContext->Dispatch( + max((unsigned int)ceil((float)threadsX / this->threadsX), 1), + max((unsigned int)ceil((float)threadsY / this->threadsY), 1), + max((unsigned int)ceil((float)threadsZ / this->threadsZ), 1)); +} + +// -------------------------------------------------------- +// Determines if this shader has the specified UAV +// -------------------------------------------------------- +bool SimpleComputeShader::HasUnorderedAccessView(std::string name) +{ + return GetUnorderedAccessViewIndex(name) != -1; +} + +// -------------------------------------------------------- +// Sets a shader resource view in the Compute shader stage +// +// name - The name of the texture resource in the shader +// srv - The shader resource view of the texture in GPU memory +// +// Returns true if a texture of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleComputeShader::SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) +{ + // Look for the variable and verify + const SimpleSRV* srvInfo = GetShaderResourceViewInfo(name); + if (srvInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleComputeShader::SetShaderResourceView() - SRV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->CSSetShaderResources(srvInfo->BindIndex, 1, srv.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets a sampler state in the Compute shader stage +// +// name - The name of the sampler state in the shader +// samplerState - The sampler state in GPU memory +// +// Returns true if a sampler of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleComputeShader::SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) +{ + // Look for the variable and verify + const SimpleSampler* sampInfo = GetSamplerInfo(name); + if (sampInfo == 0) + { + if (ReportWarnings) + { + LogWarning("SimpleComputeShader::SetSamplerState() - Sampler named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->CSSetSamplers(sampInfo->BindIndex, 1, samplerState.GetAddressOf()); + + // Success + return true; +} + +// -------------------------------------------------------- +// Sets an unordered access view in the Compute shader stage +// +// name - The name of the sampler state in the shader +// uav - The UAV in GPU memory +// appendConsumeOffset - Used for append or consume UAV's (optional) +// +// Returns true if a UAV of the given name was found, false otherwise +// -------------------------------------------------------- +bool SimpleComputeShader::SetUnorderedAccessView(std::string name, Microsoft::WRL::ComPtr uav, unsigned int appendConsumeOffset) +{ + // Look for the variable and verify + unsigned int bindIndex = GetUnorderedAccessViewIndex(name); + if (bindIndex == -1) + { + if (ReportWarnings) + { + LogWarning("SimpleComputeShader::SetUnorderedAccessView() - UAV named '"); + Log(name); + LogWarning("' was not found in the shader. Ensure the name is spelled correctly and that it exists in the shader.\n"); + } + return false; + } + + // Set the shader resource view + deviceContext->CSSetUnorderedAccessViews(bindIndex, 1, uav.GetAddressOf(), &appendConsumeOffset); + + // Success + return true; +} + +// -------------------------------------------------------- +// Gets the index of the specified UAV (or -1) +// -------------------------------------------------------- +int SimpleComputeShader::GetUnorderedAccessViewIndex(std::string name) +{ + // Look for the key + std::unordered_map::iterator result = + uavTable.find(name); + + // Did we find the key? + if (result == uavTable.end()) + return -1; + + // Success + return result->second; +} diff --git a/SimpleShader.h b/SimpleShader.h new file mode 100644 index 0000000..78a0110 --- /dev/null +++ b/SimpleShader.h @@ -0,0 +1,326 @@ +// CREATOR: https://github.com/vixorien/SimpleShader +// LICENSE: MIT + +#pragma once +#pragma comment(lib, "dxguid.lib") +#pragma comment(lib, "d3dcompiler.lib") + +#include +#include +#include +#include + +#include +#include +#include + + +// -------------------------------------------------------- +// Used by simple shaders to store information about +// specific variables in constant buffers +// -------------------------------------------------------- +struct SimpleShaderVariable +{ + unsigned int ByteOffset; + unsigned int Size; + unsigned int ConstantBufferIndex; +}; + +// -------------------------------------------------------- +// Contains information about a specific +// constant buffer in a shader, as well as +// the local data buffer for it +// -------------------------------------------------------- +struct SimpleConstantBuffer +{ + std::string Name; + D3D_CBUFFER_TYPE Type = D3D_CBUFFER_TYPE::D3D11_CT_CBUFFER; + unsigned int Size = 0; + unsigned int BindIndex = 0; + Microsoft::WRL::ComPtr ConstantBuffer = 0; + unsigned char* LocalDataBuffer = 0; + std::vector Variables; +}; + +// -------------------------------------------------------- +// Contains info about a single SRV in a shader +// -------------------------------------------------------- +struct SimpleSRV +{ + unsigned int Index; // The raw index of the SRV + unsigned int BindIndex; // The register of the SRV +}; + +// -------------------------------------------------------- +// Contains info about a single Sampler in a shader +// -------------------------------------------------------- +struct SimpleSampler +{ + unsigned int Index; // The raw index of the Sampler + unsigned int BindIndex; // The register of the Sampler +}; + +// -------------------------------------------------------- +// Base abstract class for simplifying shader handling +// -------------------------------------------------------- +class ISimpleShader +{ +public: + ISimpleShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context); + virtual ~ISimpleShader(); + + // Simple helpers + bool IsShaderValid() { return shaderValid; } + + // Activating the shader and copying data + void SetShader(); + void CopyAllBufferData(); + void CopyBufferData(unsigned int index); + void CopyBufferData(std::string bufferName); + + // Sets arbitrary shader data + bool SetData(std::string name, const void* data, unsigned int size); + + bool SetInt(std::string name, int data); + bool SetFloat(std::string name, float data); + bool SetFloat2(std::string name, const float data[2]); + bool SetFloat2(std::string name, const DirectX::XMFLOAT2 data); + bool SetFloat3(std::string name, const float data[3]); + bool SetFloat3(std::string name, const DirectX::XMFLOAT3 data); + bool SetFloat4(std::string name, const float data[4]); + bool SetFloat4(std::string name, const DirectX::XMFLOAT4 data); + bool SetMatrix4x4(std::string name, const float data[16]); + bool SetMatrix4x4(std::string name, const DirectX::XMFLOAT4X4 data); + + // Setting shader resources + virtual bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv) = 0; + virtual bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState) = 0; + + // Simple resource checking + bool HasVariable(std::string name); + bool HasShaderResourceView(std::string name); + bool HasSamplerState(std::string name); + + // Getting data about variables and resources + const SimpleShaderVariable* GetVariableInfo(std::string name); + + const SimpleSRV* GetShaderResourceViewInfo(std::string name); + const SimpleSRV* GetShaderResourceViewInfo(unsigned int index); + size_t GetShaderResourceViewCount() { return textureTable.size(); } + + const SimpleSampler* GetSamplerInfo(std::string name); + const SimpleSampler* GetSamplerInfo(unsigned int index); + size_t GetSamplerCount() { return samplerTable.size(); } + + // Get data about constant buffers + unsigned int GetBufferCount(); + unsigned int GetBufferSize(unsigned int index); + const SimpleConstantBuffer* GetBufferInfo(std::string name); + const SimpleConstantBuffer* GetBufferInfo(unsigned int index); + + // Misc getters + Microsoft::WRL::ComPtr GetShaderBlob() { return shaderBlob; } + + // Error reporting + static bool ReportErrors; + static bool ReportWarnings; + +protected: + + bool shaderValid; + Microsoft::WRL::ComPtr shaderBlob; + Microsoft::WRL::ComPtr device; + Microsoft::WRL::ComPtr deviceContext; + + // Resource counts + unsigned int constantBufferCount; + + // Maps for variables and buffers + SimpleConstantBuffer* constantBuffers; // For index-based lookup + std::vector shaderResourceViews; + std::vector samplerStates; + std::unordered_map cbTable; + std::unordered_map varTable; + std::unordered_map textureTable; + std::unordered_map samplerTable; + + // Initialization method + bool LoadShaderFile(LPCWSTR shaderFile); + + // Pure virtual functions for dealing with shader types + virtual bool CreateShader(Microsoft::WRL::ComPtr shaderBlob) = 0; + virtual void SetShaderAndCBs() = 0; + + virtual void CleanUp(); + + // Helpers for finding data by name + SimpleShaderVariable* FindVariable(std::string name, int size); + SimpleConstantBuffer* FindConstantBuffer(std::string name); + + // Error logging + void Log(std::string message, WORD color); + void LogW(std::wstring message, WORD color); + void Log(std::string message); + void LogW(std::wstring message); + void LogError(std::string message); + void LogErrorW(std::wstring message); + void LogWarning(std::string message); + void LogWarningW(std::wstring message); +}; + +// -------------------------------------------------------- +// Derived class for VERTEX shaders /////////////////////// +// -------------------------------------------------------- +class SimpleVertexShader : public ISimpleShader +{ +public: + SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); + SimpleVertexShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, Microsoft::WRL::ComPtr inputLayout, bool perInstanceCompatible); + ~SimpleVertexShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + Microsoft::WRL::ComPtr GetInputLayout() { return inputLayout; } + bool GetPerInstanceCompatible() { return perInstanceCompatible; } + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + +protected: + bool perInstanceCompatible; + Microsoft::WRL::ComPtr inputLayout; + Microsoft::WRL::ComPtr shader; + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); +}; + + +// -------------------------------------------------------- +// Derived class for PIXEL shaders //////////////////////// +// -------------------------------------------------------- +class SimplePixelShader : public ISimpleShader +{ +public: + SimplePixelShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); + ~SimplePixelShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + +protected: + Microsoft::WRL::ComPtr shader; + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); +}; + +// -------------------------------------------------------- +// Derived class for DOMAIN shaders /////////////////////// +// -------------------------------------------------------- +class SimpleDomainShader : public ISimpleShader +{ +public: + SimpleDomainShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); + ~SimpleDomainShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + +protected: + Microsoft::WRL::ComPtr shader; + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); +}; + +// -------------------------------------------------------- +// Derived class for HULL shaders ///////////////////////// +// -------------------------------------------------------- +class SimpleHullShader : public ISimpleShader +{ +public: + SimpleHullShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); + ~SimpleHullShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + +protected: + Microsoft::WRL::ComPtr shader; + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); +}; + +// -------------------------------------------------------- +// Derived class for GEOMETRY shaders ///////////////////// +// -------------------------------------------------------- +class SimpleGeometryShader : public ISimpleShader +{ +public: + SimpleGeometryShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile, bool useStreamOut = 0, bool allowStreamOutRasterization = 0); + ~SimpleGeometryShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + + bool CreateCompatibleStreamOutBuffer(Microsoft::WRL::ComPtr buffer, int vertexCount); + + static void UnbindStreamOutStage(Microsoft::WRL::ComPtr deviceContext); + +protected: + // Shader itself + Microsoft::WRL::ComPtr shader; + + // Stream out related + bool useStreamOut; + bool allowStreamOutRasterization; + unsigned int streamOutVertexSize; + + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + bool CreateShaderWithStreamOut(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); + + // Helpers + unsigned int CalcComponentCount(unsigned int mask); +}; + + +// -------------------------------------------------------- +// Derived class for COMPUTE shaders ////////////////////// +// -------------------------------------------------------- +class SimpleComputeShader : public ISimpleShader +{ +public: + SimpleComputeShader(Microsoft::WRL::ComPtr device, Microsoft::WRL::ComPtr context, LPCWSTR shaderFile); + ~SimpleComputeShader(); + Microsoft::WRL::ComPtr GetDirectXShader() { return shader; } + + void DispatchByGroups(unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ); + void DispatchByThreads(unsigned int threadsX, unsigned int threadsY, unsigned int threadsZ); + + bool HasUnorderedAccessView(std::string name); + + bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr srv); + bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr samplerState); + bool SetUnorderedAccessView(std::string name, Microsoft::WRL::ComPtr uav, unsigned int appendConsumeOffset = -1); + + int GetUnorderedAccessViewIndex(std::string name); + +protected: + Microsoft::WRL::ComPtr shader; + std::unordered_map uavTable; + + unsigned int threadsX; + unsigned int threadsY; + unsigned int threadsZ; + unsigned int threadsTotal; + + bool CreateShader(Microsoft::WRL::ComPtr shaderBlob); + void SetShaderAndCBs(); + void CleanUp(); +};