/**
 * FluxKernel
 *
 * For every cell handled by the block, evaluates the fluxes through the
 * cell's south/north and west/east faces and writes the resulting time
 * derivative of (W, hu, hv) to ShalWatCu.TempA / TempB / TempC.
 *
 * Thread layout: 2D blocks. Each block covers (blockDim.x-2) x (blockDim.y-2)
 * output cells; the outermost ring of threads is a one-cell halo that only
 * supplies reconstructed values / fluxes to its neighbours and never writes
 * to global memory.
 *
 * Phases (separated by __syncthreads()):
 *   1. Reconstruct the state at the four faces of the cell; publish the
 *      north-face reconstruction through shared memory and the east-face
 *      reconstruction through a warp shuffle.
 *   2. Compute the south-face flux (published through shared memory) and the
 *      west-face flux (published through a warp shuffle); fold the local
 *      propagation speeds into ShalWatCu.MaximumSpeedOfPropagation.
 *   3. Interior threads inside the domain combine the four face fluxes and the
 *      source terms and store the result.
 *
 * Shared memory: four static arrays of BlockDimX*BlockDimZ elements, indexed
 * with threadIdx.y*blockDim.x + threadIdx.x.
 *   NOTE(review): this assumes the launch uses blockDim == (BlockDimX, BlockDimZ);
 *   a larger launch would index past the arrays - confirm at the call site.
 *
 * Warp shuffles: the x-neighbour exchange uses lane-1 / lane+1.
 *   NOTE(review): that is only the x-1 / x+1 thread of the same row when a row
 *   of the block maps onto one warp (e.g. blockDim.x == warpSize) - confirm.
 *
 * Requires CUDA 9.0 or newer (*_sync warp intrinsics).
 *
 * @param ShalWatCu  solver state passed by value (grids, sizes, constants and
 *                   output buffers used below).
 */
__global__
#ifdef __CUDACC__
__launch_bounds__ (32*FluxKernelOccupancy, 1)
#endif
void FluxKernel(ShallowWaterSolverCu ShalWatCu)
{
    // A: north-face reconstruction of every thread (read by the row above).
    // B: south-face flux of every thread (read by the row below as its north flux).
    __shared__ float  WTempSharedA[BlockDimX*BlockDimZ];
    __shared__ float  WTempSharedB[BlockDimX*BlockDimZ];
    __shared__ float2 huhvTempSharedA[BlockDimX*BlockDimZ];
    __shared__ float2 huhvTempSharedB[BlockDimX*BlockDimZ];

    // Global cell coordinates; the -1 shifts the block by its one-cell halo, so
    // the first thread row/column of block 0 lands on coordinate -1.
    const int GlobalX = blockIdx.x *(blockDim.x-2) + threadIdx.x-1;
    const int GlobalZ = blockIdx.y *(blockDim.y-2) + threadIdx.y-1;

    // Slot of this thread in the shared arrays (row-major over the block).
    const int SharedIndex = threadIdx.y*blockDim.x + threadIdx.x;

    // Predicates of the conditional regions below. They are evaluated up front
    // so that the participant masks for the warp shuffles can be built with
    // __ballot_sync at a point that every thread of the warp reaches.
    const bool RowInRange     = (GlobalZ <= ShalWatCu.SizeZ);
    const bool NotFirstRow    = (threadIdx.y != 0);
    const bool NotLastRow     = (threadIdx.y != blockDim.y-1);
    const bool DoXReconstruct = RowInRange && NotFirstRow;
    const bool DoXFlux        = DoXReconstruct && NotLastRow;

    float BSouthWest, BSouthEast, BNorthWest, BNorthEast;
    float BWest, BEast, BNorth, BSouth;
    U UAverageCenter, UReconstructedSouth, UReconstructedWest, UReconstructedEastNeighbor;

    // ---------------------------------------------------------------- phase 1
    // Lanes that will take part in the east-reconstruction shuffle.
    const unsigned XReconstructMask = __ballot_sync(0xffffffffu, DoXReconstruct);

    if(RowInRange)
    {
        // B sampled at the four corners of the cell, then averaged onto the faces.
        BSouthWest=GetValueFromGrid(ShalWatCu.B,GlobalX,GlobalZ,ShalWatCu.XBAlign, ShalWatCu.SizeXB,ShalWatCu.SizeZB);
        BSouthEast=GetValueFromGrid(ShalWatCu.B,GlobalX+1,GlobalZ,ShalWatCu.XBAlign, ShalWatCu.SizeXB,ShalWatCu.SizeZB);
        BNorthWest=GetValueFromGrid(ShalWatCu.B,GlobalX,GlobalZ+1,ShalWatCu.XBAlign, ShalWatCu.SizeXB,ShalWatCu.SizeZB);
        BNorthEast=GetValueFromGrid(ShalWatCu.B,GlobalX+1,GlobalZ+1,ShalWatCu.XBAlign, ShalWatCu.SizeXB,ShalWatCu.SizeZB);
        BWest  = 0.5f*(BSouthWest+BNorthWest);
        BEast  = 0.5f*(BSouthEast+BNorthEast);
        BNorth = 0.5f*(BNorthWest+BNorthEast);
        BSouth = 0.5f*(BSouthWest+BSouthEast);

        // Reconstruction along Z from the cell and its two Z neighbours.
        UAverageCenter = GetU(ShalWatCu,GlobalX, GlobalZ);
        U UAverageNorth = GetU(ShalWatCu,GlobalX, GlobalZ+1);
        U UAverageSouth = GetU(ShalWatCu,GlobalX, GlobalZ-1);
        U UGradientZ = CalculateUGradient(UAverageNorth, UAverageCenter, UAverageSouth,ShalWatCu.DXInv);
        U UReconstructedNorth = ReconstructU(UAverageCenter ,UGradientZ,1.f,ShalWatCu.DX);
        UReconstructedSouth = ReconstructU(UAverageCenter ,UGradientZ,-1.f,ShalWatCu.DX);
        ForcePositivity(UReconstructedSouth.W,UReconstructedNorth.W,BSouth,BNorth,UAverageCenter.W);

        // Publish the north-face values for the thread one row above.
        WTempSharedA[SharedIndex] = UReconstructedNorth.W;
        huhvTempSharedA[SharedIndex] = make_float2(UReconstructedNorth.hu,UReconstructedNorth.hv);

        if(NotFirstRow)
        {
            // Reconstruction along X from the cell and its two X neighbours.
            U UAverageEast = GetU(ShalWatCu,GlobalX+1, GlobalZ);
            U UAverageWest = GetU(ShalWatCu,GlobalX-1, GlobalZ);
            U UGradientX = CalculateUGradient(UAverageEast, UAverageCenter, UAverageWest,ShalWatCu.DXInv);
            U UReconstructedEast = ReconstructU(UAverageCenter ,UGradientX,1.f,ShalWatCu.DX);
            UReconstructedWest = ReconstructU(UAverageCenter ,UGradientX,-1.f,ShalWatCu.DX);
            ForcePositivity(UReconstructedWest.W,UReconstructedEast.W,BWest,BEast,UAverageCenter.W);

            // Fetch the east-face reconstruction held by lane-1. The explicit
            // mask replaces the legacy mask-less __shfl_up, which is not
            // available on compute capability 7.0+.
            UReconstructedEastNeighbor.W  = __shfl_up_sync(XReconstructMask, UReconstructedEast.W,  1);
            UReconstructedEastNeighbor.hu = __shfl_up_sync(XReconstructMask, UReconstructedEast.hu, 1);
            UReconstructedEastNeighbor.hv = __shfl_up_sync(XReconstructMask, UReconstructedEast.hv, 1);
        }
    }
    // All north-face values must be in shared memory before any row reads them.
    __syncthreads();

    // ---------------------------------------------------------------- phase 2
    U HWest, HSouth, HEast, HNorth;

    // Lanes that will take part in the west-flux shuffle.
    const unsigned XFluxMask = __ballot_sync(0xffffffffu, DoXFlux);

    if(DoXReconstruct)
    {
        // North-face reconstruction of the thread one row below (threadIdx.y-1);
        // it sits on the same face as this thread's south-face reconstruction.
        const int SouthRowIndex = SharedIndex - blockDim.x;
        U UReconstructedNorthNeighbor;
        UReconstructedNorthNeighbor.W = WTempSharedA[SouthRowIndex];
        float2 huhv = huhvTempSharedA[SouthRowIndex];
        UReconstructedNorthNeighbor.hu = huhv.x;
        UReconstructedNorthNeighbor.hv = huhv.y;

        float BPlus; float BMinus;
        float hNorthNeighbor = ReconstructH(UReconstructedNorthNeighbor.W,BSouth);
        float hSouth = ReconstructH(UReconstructedSouth.W,BSouth);
        float uNorthNeighbor, vNorthNeighbor,uSouth,vSouth;
        ReconstructV(hNorthNeighbor,uNorthNeighbor,vNorthNeighbor,UReconstructedNorthNeighbor.hu,UReconstructedNorthNeighbor.hv,ShalWatCu.DX,ShalWatCu.DX4);
        ReconstructV(hSouth,uSouth,vSouth,UReconstructedSouth.hu,UReconstructedSouth.hv,ShalWatCu.DX,ShalWatCu.DX4);
        U UCorrectedNorthNeighbor=CalculateUCorrected(UReconstructedNorthNeighbor.W,hNorthNeighbor,uNorthNeighbor,vNorthNeighbor);
        U UCorrectedSouth=CalculateUCorrected(UReconstructedSouth.W,hSouth,uSouth,vSouth);
        CalculateSpeedOfPropagation(BPlus, BMinus, hNorthNeighbor,hSouth,vNorthNeighbor,vSouth,ShalWatCu.g);
        HSouth = CalculateHZ(BPlus, BMinus ,UCorrectedNorthNeighbor,UCorrectedSouth,BSouth, ShalWatCu.g, ShalWatCu.Minh);

        // Publish the south-face flux; the row below reads it as its north flux.
        WTempSharedB[SharedIndex]= HSouth.W;
        huhvTempSharedB[SharedIndex]= make_float2(HSouth.hu,HSouth.hv);

        if(NotLastRow)
        {
            float APlus; float AMinus;
            float hEastNeighbor = ReconstructH(UReconstructedEastNeighbor.W,BWest);
            float hWest = ReconstructH(UReconstructedWest.W,BWest);
            float uEastNeighbor, vEastNeighbor,uWest,vWest;
            ReconstructV(hEastNeighbor,uEastNeighbor,vEastNeighbor,UReconstructedEastNeighbor.hu,UReconstructedEastNeighbor.hv,ShalWatCu.DX,ShalWatCu.DX4);
            ReconstructV(hWest,uWest,vWest,UReconstructedWest.hu,UReconstructedWest.hv,ShalWatCu.DX,ShalWatCu.DX4);
            U UCorrectedEastNeighbor=CalculateUCorrected(UReconstructedEastNeighbor.W,hEastNeighbor,uEastNeighbor,vEastNeighbor);
            U UCorrectedWest=CalculateUCorrected(UReconstructedWest.W,hWest,uWest,vWest);
            CalculateSpeedOfPropagation(APlus, AMinus, hEastNeighbor,hWest,uEastNeighbor,uWest,ShalWatCu.g);

            // Largest magnitude among the four one-sided speeds of this thread.
            float MaximumSpeedOfPropagation = fmaxf(APlus,fmaxf(-AMinus,fmaxf(BPlus,-BMinus)));
            // NOTE(review): ReduceToMaxWarp is defined elsewhere and is called
            // inside a conditional region; if it uses warp shuffles, confirm it
            // is safe when not all lanes of the warp take this branch.
            MaximumSpeedOfPropagation = ReduceToMaxWarp(MaximumSpeedOfPropagation);
            // Only the first lane of each warp issues the global atomic.
            int ThreadIDWithinWarp = SharedIndex % warpSize;
            if(ThreadIDWithinWarp == 0)
                AtomicMaxInt(ShalWatCu.MaximumSpeedOfPropagation,MaximumSpeedOfPropagation);

            HWest = CalculateHX(APlus, AMinus ,UCorrectedEastNeighbor,UCorrectedWest,BWest, ShalWatCu.g, ShalWatCu.Minh);

            // The west-face flux of lane+1 is this thread's east-face flux.
            // Explicit mask replaces the legacy mask-less __shfl_down.
            HEast.W  = __shfl_down_sync(XFluxMask, HWest.W,  1);
            HEast.hu = __shfl_down_sync(XFluxMask, HWest.hu, 1);
            HEast.hv = __shfl_down_sync(XFluxMask, HWest.hv, 1);
        }
    }
    // All south-face fluxes must be in shared memory before any row reads them.
    __syncthreads();

    // ---------------------------------------------------------------- phase 3
    // Only non-halo threads whose cell lies inside the domain produce output.
    const bool InteriorThread = threadIdx.x > 0 && threadIdx.x < blockDim.x -1 &&
                                threadIdx.y > 0 && threadIdx.y < blockDim.y-1;
    const bool InsideDomain   = GlobalX >= 0 && GlobalX < ShalWatCu.SizeX &&
                                GlobalZ >= 0 && GlobalZ < ShalWatCu.SizeZ;
    if(InteriorThread && InsideDomain)
    {
        // South-face flux of the thread one row above == this thread's north flux.
        const int NorthRowIndex = SharedIndex + blockDim.x;
        float2 Zhuhv = huhvTempSharedB[NorthRowIndex];
        HNorth.W = WTempSharedB[NorthRowIndex];
        HNorth.hu = Zhuhv.x;
        HNorth.hv = Zhuhv.y;

        // Source terms in X and Z, evaluated with the mean of the four corner B values.
        float BAverage = 0.25f*(BSouthWest+BSouthEast+BNorthWest+BNorthEast);
        float SX=CalculateS(BWest,BEast,BAverage,UAverageCenter.W,ShalWatCu.DXInv,ShalWatCu.g);
        float SZ=CalculateS(BSouth,BNorth,BAverage, UAverageCenter.W,ShalWatCu.DXInv,ShalWatCu.g);

        // Flux difference across the cell for each of the three components.
        U DUDT;
        for(int i=0; i< 3; i++)
            DUDT.Values[i] = - (HNorth.Values[i]-HSouth.Values[i]+HEast.Values[i]-HWest.Values[i])*ShalWatCu.DXInv;
        DUDT.hu+=SX;
        DUDT.hv+=SZ;

        // Row-major store with row pitch ShalWatCu.XAlign.
        int Index = GlobalZ * ShalWatCu.XAlign + GlobalX;
        ShalWatCu.TempA[Index]= DUDT.W;
        ShalWatCu.TempB[Index]= DUDT.hu;
        ShalWatCu.TempC[Index]= DUDT.hv;
    }
}