// CREATOR: https://github.com/vixorien/SimpleShader
// LICENSE: MIT

#pragma once
#pragma comment(lib, "dxguid.lib")
#pragma comment(lib, "d3dcompiler.lib")

#include <d3d11.h>
#include <d3dcompiler.h>
#include <DirectXMath.h>
#include <wrl/client.h>

#include <unordered_map>
#include <vector>
#include <string>


// --------------------------------------------------------
// Used by simple shaders to store information about
// specific variables in constant buffers.
//
// All fields default to 0 so that a default-initialized
// instance never carries indeterminate values.
// --------------------------------------------------------
struct SimpleShaderVariable
{
	unsigned int ByteOffset = 0;			// Offset of the variable (presumably within its constant buffer's data)
	unsigned int Size = 0;					// Size of the variable (presumably in bytes)
	unsigned int ConstantBufferIndex = 0;	// Which constant buffer holds it (presumably an index into ISimpleShader::constantBuffers)
};

// --------------------------------------------------------
// Contains information about a specific
// constant buffer in a shader, as well as
// the local (CPU-side) data buffer for it
// --------------------------------------------------------
struct SimpleConstantBuffer
{
	std::string Name;													// Name of the constant buffer
	D3D_CBUFFER_TYPE Type = D3D_CBUFFER_TYPE::D3D11_CT_CBUFFER;
	unsigned int Size = 0;												// Buffer size (presumably in bytes)
	unsigned int BindIndex = 0;											// Register the buffer binds to
	Microsoft::WRL::ComPtr<ID3D11Buffer> ConstantBuffer;				// GPU buffer; a default ComPtr is null
	unsigned char* LocalDataBuffer = nullptr;							// Local copy of the buffer contents
																		// NOTE(review): raw pointer, presumably freed by ISimpleShader::CleanUp - confirm
	std::vector<SimpleShaderVariable> Variables;						// Variables declared inside this buffer
};

// --------------------------------------------------------
// Contains info about a single SRV in a shader.
// Fields default to 0 so that default-initialized
// instances never hold indeterminate values.
// --------------------------------------------------------
struct SimpleSRV
{
	unsigned int Index = 0;		// The raw index of the SRV
	unsigned int BindIndex = 0;	// The register of the SRV
};

// --------------------------------------------------------
// Contains info about a single Sampler in a shader.
// Fields default to 0 so that default-initialized
// instances never hold indeterminate values.
// --------------------------------------------------------
struct SimpleSampler
{
	unsigned int Index = 0;		// The raw index of the Sampler
	unsigned int BindIndex = 0;	// The register of the Sampler
};

// --------------------------------------------------------
// Base abstract class for simplifying shader handling.
//
// Keeps the compiled shader blob, the device/context used
// with it, and name/index lookup tables for constant
// buffers, variables, SRVs and samplers. Each derived
// shader stage implements CreateShader, SetShaderAndCBs and
// the SRV/sampler setters.
//
// Copying is disabled: the object stores raw pointers
// (constantBuffers and the SimpleSRV*/SimpleSampler*
// containers). These are presumably released in CleanUp(),
// so a member-wise copy would make two objects release
// the same memory.
// NOTE(review): confirm ownership against the .cpp.
// --------------------------------------------------------
class ISimpleShader
{
public:
	ISimpleShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context);
	virtual ~ISimpleShader();

	// Non-copyable (see class comment)
	ISimpleShader(const ISimpleShader&) = delete;
	ISimpleShader& operator=(const ISimpleShader&) = delete;

	// Simple helpers
	bool IsShaderValid() const { return shaderValid; }

	// Activating the shader and copying data
	void SetShader();
	void CopyAllBufferData();
	void CopyBufferData(unsigned int index);
	void CopyBufferData(std::string bufferName);

	// Sets arbitrary shader data; returns whether it succeeded
	bool SetData(std::string name, const void* data, unsigned int size);

	bool SetInt(std::string name, int data);
	bool SetFloat(std::string name, float data);
	bool SetFloat2(std::string name, const float data[2]);
	bool SetFloat2(std::string name, const DirectX::XMFLOAT2 data);
	bool SetFloat3(std::string name, const float data[3]);
	bool SetFloat3(std::string name, const DirectX::XMFLOAT3 data);
	bool SetFloat4(std::string name, const float data[4]);
	bool SetFloat4(std::string name, const DirectX::XMFLOAT4 data);
	bool SetMatrix4x4(std::string name, const float data[16]);
	bool SetMatrix4x4(std::string name, const DirectX::XMFLOAT4X4 data);

	// Setting shader resources (stage-specific, so implemented by derived classes)
	virtual bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv) = 0;
	virtual bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState) = 0;

	// Simple resource checking
	bool HasVariable(std::string name);
	bool HasShaderResourceView(std::string name);
	bool HasSamplerState(std::string name);

	// Getting data about variables and resources
	const SimpleShaderVariable* GetVariableInfo(std::string name);

	const SimpleSRV* GetShaderResourceViewInfo(std::string name);
	const SimpleSRV* GetShaderResourceViewInfo(unsigned int index);
	size_t GetShaderResourceViewCount() const { return textureTable.size(); }

	const SimpleSampler* GetSamplerInfo(std::string name);
	const SimpleSampler* GetSamplerInfo(unsigned int index);
	size_t GetSamplerCount() const { return samplerTable.size(); }

	// Get data about constant buffers
	unsigned int GetBufferCount();
	unsigned int GetBufferSize(unsigned int index);
	const SimpleConstantBuffer* GetBufferInfo(std::string name);
	const SimpleConstantBuffer* GetBufferInfo(unsigned int index);

	// Misc getters
	Microsoft::WRL::ComPtr<ID3DBlob> GetShaderBlob() const { return shaderBlob; }

	// Error reporting toggles (shared by every shader instance)
	static bool ReportErrors;
	static bool ReportWarnings;

protected:

	bool shaderValid = false;
	Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob;
	Microsoft::WRL::ComPtr<ID3D11Device> device;
	Microsoft::WRL::ComPtr<ID3D11DeviceContext> deviceContext;

	// Resource counts
	unsigned int constantBufferCount = 0;

	// Maps for variables and buffers
	SimpleConstantBuffer* constantBuffers = nullptr; // For index-based lookup
	std::vector<SimpleSRV*>		shaderResourceViews;
	std::vector<SimpleSampler*>	samplerStates;
	std::unordered_map<std::string, SimpleConstantBuffer*> cbTable;
	std::unordered_map<std::string, SimpleShaderVariable> varTable;
	std::unordered_map<std::string, SimpleSRV*> textureTable;
	std::unordered_map<std::string, SimpleSampler*> samplerTable;

	// Initialization method
	bool LoadShaderFile(LPCWSTR shaderFile);

	// Pure virtual functions for dealing with shader types
	virtual bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob) = 0;
	virtual void SetShaderAndCBs() = 0;

	virtual void CleanUp();

	// Helpers for finding data by name
	SimpleShaderVariable* FindVariable(std::string name, int size);
	SimpleConstantBuffer* FindConstantBuffer(std::string name);

	// Error logging
	void Log(std::string message, WORD color);
	void LogW(std::wstring message, WORD color);
	void Log(std::string message);
	void LogW(std::wstring message);
	void LogError(std::string message);
	void LogErrorW(std::wstring message);
	void LogWarning(std::string message);
	void LogWarningW(std::wstring message);
};

// --------------------------------------------------------
// Derived class for VERTEX shaders ///////////////////////
// --------------------------------------------------------
class SimpleVertexShader : public ISimpleShader
{
public:
	SimpleVertexShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile);
	SimpleVertexShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile, Microsoft::WRL::ComPtr<ID3D11InputLayout> inputLayout, bool perInstanceCompatible);
	~SimpleVertexShader();
	Microsoft::WRL::ComPtr<ID3D11VertexShader> GetDirectXShader() { return shader; }
	Microsoft::WRL::ComPtr<ID3D11InputLayout> GetInputLayout() { return inputLayout; }
	bool GetPerInstanceCompatible() { return perInstanceCompatible; }

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);

protected:
	bool perInstanceCompatible;
	Microsoft::WRL::ComPtr<ID3D11InputLayout> inputLayout;
	Microsoft::WRL::ComPtr<ID3D11VertexShader> shader;
	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();
};


// --------------------------------------------------------
// Derived class for PIXEL shaders ////////////////////////
// --------------------------------------------------------
class SimplePixelShader : public ISimpleShader
{
public:
	SimplePixelShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile);
	~SimplePixelShader();
	Microsoft::WRL::ComPtr<ID3D11PixelShader> GetDirectXShader() { return shader; }

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);

protected:
	Microsoft::WRL::ComPtr<ID3D11PixelShader> shader;
	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();
};

// --------------------------------------------------------
// Derived class for DOMAIN shaders ///////////////////////
// --------------------------------------------------------
class SimpleDomainShader : public ISimpleShader
{
public:
	SimpleDomainShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile);
	~SimpleDomainShader();
	Microsoft::WRL::ComPtr<ID3D11DomainShader> GetDirectXShader() { return shader; }

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);

protected:
	Microsoft::WRL::ComPtr<ID3D11DomainShader> shader;
	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();
};

// --------------------------------------------------------
// Derived class for HULL shaders /////////////////////////
// --------------------------------------------------------
class SimpleHullShader : public ISimpleShader
{
public:
	SimpleHullShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile);
	~SimpleHullShader();
	Microsoft::WRL::ComPtr<ID3D11HullShader> GetDirectXShader() { return shader; }

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);

protected:
	Microsoft::WRL::ComPtr<ID3D11HullShader> shader;
	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();
};

// --------------------------------------------------------
// Derived class for GEOMETRY shaders /////////////////////
// --------------------------------------------------------
class SimpleGeometryShader : public ISimpleShader
{
public:
	SimpleGeometryShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile, bool useStreamOut = 0, bool allowStreamOutRasterization = 0);
	~SimpleGeometryShader();
	Microsoft::WRL::ComPtr<ID3D11GeometryShader> GetDirectXShader() { return shader; }

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);

	bool CreateCompatibleStreamOutBuffer(Microsoft::WRL::ComPtr<ID3D11Buffer> buffer, int vertexCount);

	static void UnbindStreamOutStage(Microsoft::WRL::ComPtr<ID3D11DeviceContext> deviceContext);

protected:
	// Shader itself
	Microsoft::WRL::ComPtr<ID3D11GeometryShader> shader;

	// Stream out related
	bool useStreamOut;
	bool allowStreamOutRasterization;
	unsigned int streamOutVertexSize;

	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	bool CreateShaderWithStreamOut(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();

	// Helpers
	unsigned int CalcComponentCount(unsigned int mask);
};


// --------------------------------------------------------
// Derived class for COMPUTE shaders //////////////////////
// --------------------------------------------------------
class SimpleComputeShader : public ISimpleShader
{
public:
	SimpleComputeShader(Microsoft::WRL::ComPtr<ID3D11Device> device, Microsoft::WRL::ComPtr<ID3D11DeviceContext> context, LPCWSTR shaderFile);
	~SimpleComputeShader();
	Microsoft::WRL::ComPtr<ID3D11ComputeShader> GetDirectXShader() { return shader; }

	void DispatchByGroups(unsigned int groupsX, unsigned int groupsY, unsigned int groupsZ);
	void DispatchByThreads(unsigned int threadsX, unsigned int threadsY, unsigned int threadsZ);

	bool HasUnorderedAccessView(std::string name);

	bool SetShaderResourceView(std::string name, Microsoft::WRL::ComPtr<ID3D11ShaderResourceView> srv);
	bool SetSamplerState(std::string name, Microsoft::WRL::ComPtr<ID3D11SamplerState> samplerState);
	bool SetUnorderedAccessView(std::string name, Microsoft::WRL::ComPtr<ID3D11UnorderedAccessView> uav, unsigned int appendConsumeOffset = -1);

	int GetUnorderedAccessViewIndex(std::string name);

protected:
	Microsoft::WRL::ComPtr<ID3D11ComputeShader> shader;
	std::unordered_map<std::string, unsigned int> uavTable;

	unsigned int threadsX;
	unsigned int threadsY;
	unsigned int threadsZ;
	unsigned int threadsTotal;

	bool CreateShader(Microsoft::WRL::ComPtr<ID3DBlob> shaderBlob);
	void SetShaderAndCBs();
	void CleanUp();
};