Commit 9734157a authored by Uldis Locans's avatar Uldis Locans

push and kick for OPALs BorisPusher

parent 24f9c9db
...@@ -33,6 +33,10 @@ public: ...@@ -33,6 +33,10 @@ public:
virtual int ParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart, void *dt_ptr, virtual int ParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart, void *dt_ptr,
double dt, double c, bool usedt = false, int streamId = -1) = 0; double dt, double c, bool usedt = false, int streamId = -1) = 0;
virtual int ParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
void *bf_ptr, void *dt_ptr, double charge,
double mass, int npart, double c, int streamId = -1) = 0;
virtual int ParallelTTrackerPushTransform(void *x_ptr, void *p_ptr, void *lastSec_ptr, virtual int ParallelTTrackerPushTransform(void *x_ptr, void *p_ptr, void *lastSec_ptr,
void *orient_ptr, int npart, int nsec, void *dt_ptr, void *orient_ptr, int npart, int nsec, void *dt_ptr,
double dt, double c, bool usedt = false, double dt, double c, bool usedt = false,
......
...@@ -33,6 +33,14 @@ __device__ inline double dot(double3 &d1, double3 &d2) { ...@@ -33,6 +33,14 @@ __device__ inline double dot(double3 &d1, double3 &d2) {
} }
__device__ inline double3 cross(double3 &lhs, double3 &rhs) {
double3 tmp;
tmp.x = lhs.y * rhs.z - lhs.z * rhs.y;
tmp.y = lhs.z * rhs.x - lhs.x * rhs.z;
tmp.z = lhs.x * rhs.y - lhs.y * rhs.x;
return tmp;
}
__device__ inline bool checkHit(double &z, double *par) { __device__ inline bool checkHit(double &z, double *par) {
/* check if particle is in the degrader material */ /* check if particle is in the degrader material */
...@@ -423,7 +431,7 @@ __global__ void kernelPush(double3 *gR, double3 *gP, int npart, double dtc) { ...@@ -423,7 +431,7 @@ __global__ void kernelPush(double3 *gR, double3 *gP, int npart, double dtc) {
} }
__global__ void kernelPush(double3 *gR, double3 *gP, int npart, double *gdt, double c) { __global__ void kernelPush(double3 *gR, double3 *gP, double *gdt, int npart, double c) {
//get global id and thread id //get global id and thread id
volatile int tid = threadIdx.x; volatile int tid = threadIdx.x;
...@@ -449,7 +457,51 @@ __global__ void kernelPush(double3 *gR, double3 *gP, int npart, double *gdt, dou ...@@ -449,7 +457,51 @@ __global__ void kernelPush(double3 *gR, double3 *gP, int npart, double *gdt, dou
} }
} }
//TODO: kernel for push with switch off unitless positions with dt[i]*c __global__ void kernelKick(double3 *gR, double3 *gP, double3 *gEf,
double3 *gBf, double *gdt, double charge,
double mass, int npart, double c)
{
volatile int tid = threadIdx.x;
volatile int idx = blockIdx.x * blockDim.x + tid;
if (idx < npart) {
double3 R = gR[idx];
double3 P = gP[idx];
double3 Ef = gEf[idx];
double3 Bf = gBf[idx];
double dt = gdt[idx];
P.x += 0.5 * dt * charge * c / mass * Ef.x;
P.y += 0.5 * dt * charge * c / mass * Ef.y;
P.z += 0.5 * dt * charge * c / mass * Ef.z;
double gamma = sqrt(1.0 + dot(P, P));
double3 t, w, s;
t.x = 0.5 * dt * charge * c * c / (gamma * mass) * Bf.x;
t.y = 0.5 * dt * charge * c * c / (gamma * mass) * Bf.y;
t.z = 0.5 * dt * charge * c * c / (gamma * mass) * Bf.z;
double3 crossPt = cross(P, t);
w.x = P.x + crossPt.x;
w.y = P.y + crossPt.y;
w.z = P.z + crossPt.z;
s.x = 2.0 / (1.0 + dot(t, t)) * t.x;
s.y = 2.0 / (1.0 + dot(t, t)) * t.y;
s.z = 2.0 / (1.0 + dot(t, t)) * t.z;
double3 crossws = cross(w, s);
P.x += crossws.x;
P.y += crossws.y;
P.z += crossws.z;
P.x += 0.5 * dt * charge * c / mass * Ef.x;
P.y += 0.5 * dt * charge * c / mass * Ef.y;
P.z += 0.5 * dt * charge * c / mass * Ef.z;
gP[idx] = P;
}
}
__device__ double3 deviceTransformTo(const double3 &vec, const double3 &ori) { __device__ double3 deviceTransformTo(const double3 &vec, const double3 &ori) {
...@@ -671,12 +723,12 @@ int CudaCollimatorPhysics::ParallelTTrackerPush(void *r_ptr, void *p_ptr, int np ...@@ -671,12 +723,12 @@ int CudaCollimatorPhysics::ParallelTTrackerPush(void *r_ptr, void *p_ptr, int np
} }
} else { } else {
if (streamId == -1) { if (streamId == -1) {
kernelPush<<<blocks, threads>>>((double3*)r_ptr, (double3*)p_ptr, npart, kernelPush<<<blocks, threads>>>((double3*)r_ptr, (double3*)p_ptr,
(double*)dt_ptr, c); (double*)dt_ptr, npart, c);
} else { } else {
cudaStream_t cs = m_base->cuda_getStream(streamId); cudaStream_t cs = m_base->cuda_getStream(streamId);
kernelPush<<<blocks, threads, 0, cs >>>((double3*)r_ptr, (double3*)p_ptr, npart, kernelPush<<<blocks, threads, 0, cs >>>((double3*)r_ptr, (double3*)p_ptr,
(double*)dt_ptr, c); (double*)dt_ptr, npart, c);
} }
} }
...@@ -684,6 +736,29 @@ int CudaCollimatorPhysics::ParallelTTrackerPush(void *r_ptr, void *p_ptr, int np ...@@ -684,6 +736,29 @@ int CudaCollimatorPhysics::ParallelTTrackerPush(void *r_ptr, void *p_ptr, int np
return DKS_SUCCESS; return DKS_SUCCESS;
} }
int CudaCollimatorPhysics::ParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
void *bf_ptr, void *dt_ptr, double charge,
double mass, int npart,
double c, int streamId)
{
int threads = BLOCK_SIZE;
int blocks = npart / threads + 1;
//call kernel
if (streamId == -1) {
kernelKick<<<blocks, threads>>>((double3*)r_ptr, (double3*)p_ptr, (double3*)ef_ptr,
(double3*)bf_ptr, (double*)dt_ptr, charge, mass, npart, c);
} else {
cudaStream_t cs = m_base->cuda_getStream(streamId);
kernelKick<<<blocks, threads, 0, cs >>>((double3*)r_ptr, (double3*)p_ptr,
(double3*)ef_ptr, (double3*)bf_ptr,
(double*)dt_ptr, charge, mass, npart, c);
}
return DKS_SUCCESS;
}
int CudaCollimatorPhysics::ParallelTTrackerPushTransform(void *x_ptr, void *p_ptr, int CudaCollimatorPhysics::ParallelTTrackerPushTransform(void *x_ptr, void *p_ptr,
void *lastSec_ptr, void *orient_ptr, void *lastSec_ptr, void *orient_ptr,
int npart, int nsec, int npart, int nsec,
......
...@@ -141,6 +141,10 @@ public: ...@@ -141,6 +141,10 @@ public:
int ParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart, void *dt_ptr, int ParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart, void *dt_ptr,
double dt, double c, bool usedt = false, int streamId = -1); double dt, double c, bool usedt = false, int streamId = -1);
int ParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
void *bf_ptr, void *dt_ptr, double charge, double mass,
int npart, double c, int streamId = -1);
/** BorisPusher push function with transformto function form OPAL /** BorisPusher push function with transformto function form OPAL
* ParallelTTracker integration from OPAL implemented in cuda. * ParallelTTracker integration from OPAL implemented in cuda.
* For more details see ParallelTTracler docomentation in opal * For more details see ParallelTTracler docomentation in opal
......
...@@ -6,6 +6,11 @@ DKSOPAL::DKSOPAL() { ...@@ -6,6 +6,11 @@ DKSOPAL::DKSOPAL() {
dksgreens = nullptr; dksgreens = nullptr;
} }
DKSOPAL::DKSOPAL(const char* api_name, const char* device_name) {
setAPI(api_name, strlen(api_name));
setDevice(device_name, strlen(device_name));
}
DKSOPAL::~DKSOPAL() { DKSOPAL::~DKSOPAL() {
delete dksfft; delete dksfft;
delete dkscol; delete dkscol;
...@@ -14,7 +19,6 @@ DKSOPAL::~DKSOPAL() { ...@@ -14,7 +19,6 @@ DKSOPAL::~DKSOPAL() {
int DKSOPAL::setupOPAL() { int DKSOPAL::setupOPAL() {
int ierr = DKS_ERROR; int ierr = DKS_ERROR;
if (apiOpenCL()) { if (apiOpenCL()) {
ierr = OPENCL_SAFECALL( DKS_SUCCESS ); ierr = OPENCL_SAFECALL( DKS_SUCCESS );
//TODO: only enable if AMD libraries are available //TODO: only enable if AMD libraries are available
...@@ -40,9 +44,10 @@ int DKSOPAL::setupOPAL() { ...@@ -40,9 +44,10 @@ int DKSOPAL::setupOPAL() {
int DKSOPAL::initDevice() { int DKSOPAL::initDevice() {
int ierr = setupDevice(); int ierr = setupDevice();
if (ierr == DKS_ERROR) if (ierr == DKS_SUCCESS)
ierr = setupOPAL(); ierr = setupOPAL();
return ierr; return ierr;
} }
/* setup fft plans to reuse if multiple ffts of same size are needed */ /* setup fft plans to reuse if multiple ffts of same size are needed */
...@@ -237,7 +242,6 @@ int DKSOPAL::callCollimatorPhysicsSoA(void *label_ptr, void *localID_ptr, ...@@ -237,7 +242,6 @@ int DKSOPAL::callCollimatorPhysicsSoA(void *label_ptr, void *localID_ptr,
int DKSOPAL::callCollimatorPhysicsSort(void *mem_ptr, int numparticles, int &numaddback) int DKSOPAL::callCollimatorPhysicsSort(void *mem_ptr, int numparticles, int &numaddback)
{ {
return dkscol->CollimatorPhysicsSort(mem_ptr, numparticles, numaddback); return dkscol->CollimatorPhysicsSort(mem_ptr, numparticles, numaddback);
} }
...@@ -265,6 +269,23 @@ int DKSOPAL::callParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart, ...@@ -265,6 +269,23 @@ int DKSOPAL::callParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart,
} }
int DKSOPAL::callParallelTTrackerPush(void *r_ptr, void *p_ptr, void *dt_ptr,
int npart, double c, int streamId) {
return dkscol->ParallelTTrackerPush(r_ptr, p_ptr, npart, dt_ptr, 0, c, true, streamId);
}
int DKSOPAL::callParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
void *bf_ptr, void *dt_ptr, double charge, double mass,
int npart, double c, int streamId)
{
return dkscol->ParallelTTrackerKick(r_ptr, p_ptr, ef_ptr, bf_ptr, dt_ptr,
charge, mass, npart, c, streamId);
}
int DKSOPAL::callParallelTTrackerPushTransform(void *x_ptr, void *p_ptr, int DKSOPAL::callParallelTTrackerPushTransform(void *x_ptr, void *p_ptr,
void *lastSec_ptr, void *orient_ptr, void *lastSec_ptr, void *orient_ptr,
int npart, int nsec, void *dt_ptr, double dt, int npart, int nsec, void *dt_ptr, double dt,
......
...@@ -46,6 +46,8 @@ public: ...@@ -46,6 +46,8 @@ public:
DKSOPAL(); DKSOPAL();
DKSOPAL(const char* api_name, const char* device_name);
~DKSOPAL(); ~DKSOPAL();
int initDevice(); int initDevice();
...@@ -210,6 +212,20 @@ public: ...@@ -210,6 +212,20 @@ public:
int npart, int nsec, void *dt_ptr, int npart, int nsec, void *dt_ptr,
double dt, double c, bool usedt = false, double dt, double c, bool usedt = false,
int streamId = -1); int streamId = -1);
/**
* Integration code from ParallelTTracker from OPAL.
* For specifics check OPAL docs and CudaCollimatorPhysics class docs
*/
int callParallelTTrackerPush(void *r_ptr, void *p_ptr, void *dt_ptr,
int npart, double c, int streamId = -1);
/**
* Integration code from ParallelTTracker from OPAL.
* For specifics check OPAL docs and CudaCollimatorPhysics class docs
*/
int callParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
void *bf_ptr, void *dt_ptr, double charge,
double mass, int npart, double c, int streamId = -1);
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment