// DKSOPAL.cpp: out-of-line definitions of the DKSOPAL member functions.
#include "DKSOPAL.h"

DKSOPAL::DKSOPAL()
{
  // No backend helpers exist yet; setupOPAL() creates them once an API is chosen.
  // Holding null here keeps the unconditional deletes in the destructor harmless.
  dksgreens = nullptr;
  dkscol = nullptr;
}

DKSOPAL::DKSOPAL(const char* api_name, const char* device_name) {
  // Null the helper pointers first, as the default constructor does. The
  // destructor deletes them unconditionally, but only setupOPAL() ever assigns
  // them. That may never run, or it may take its unsupported-API branch.
  // Leaving them uninitialized made destroying such an object undefined behavior.
  dkscol = nullptr;
  dksgreens = nullptr;

  // Record the requested API and device. Both arguments are measured with
  // strlen, so they must be valid NUL-terminated C strings.
  setAPI(api_name, strlen(api_name));
  setDevice(device_name, strlen(device_name));
}

DKSOPAL::~DKSOPAL()
{
  // Release the helpers allocated with `new` in setupOPAL(). Deleting a null
  // pointer is a no-op, so this is safe when setupOPAL() never assigned them.
  // NOTE(review): both are deleted through the pointer types declared in the
  // header; those types need virtual destructors. Confirm.
  delete dkscol;
  delete dksgreens;
}

int DKSOPAL::setupOPAL() {
  // Create the collimator-physics and Green's-function helpers for whichever API
  // is selected. OpenCL is checked first, then CUDA; the OpenMP selection uses
  // the MIC* classes. The *_SAFECALL / *_SAFEINIT* macros are defined elsewhere.
  // NOTE(review): presumably they yield an error code / null pointer when that
  // backend is not compiled in. Confirm against their definitions.
  if (apiOpenCL()) {
    const int status = OPENCL_SAFECALL( DKS_SUCCESS );
    //TODO: only enable if AMD libraries are available
    dkscol = OPENCL_SAFEINIT_AMD( new OpenCLCollimatorPhysics(getOpenCLBase()) );
    dksgreens = OPENCL_SAFEINIT_AMD( new OpenCLGreensFunction(getOpenCLBase()) );
    return status;
  }

  if (apiCuda()) {
    const int status = CUDA_SAFECALL( DKS_SUCCESS );
    dkscol = CUDA_SAFEINIT( new CudaCollimatorPhysics(getCudaBase()) );
    dksgreens = CUDA_SAFEINIT( new CudaGreensFunction(getCudaBase()) );
    return status;
  }

  if (apiOpenMP()) {
    const int status = MIC_SAFECALL( DKS_SUCCESS );
    dkscol = MIC_SAFEINIT( new MICCollimatorPhysics(getMICBase()) );
    dksgreens = MIC_SAFEINIT( new MICGreensFunction(getMICBase()) );
    return status;
  }

  // No supported API selected: the helper pointers are left as they were.
  return DKS_ERROR;
}

int DKSOPAL::initDevice() {
  int ierr = setupDevice();
42
  if (ierr == DKS_SUCCESS)
43 44
    ierr = setupOPAL();
  return ierr;
45

46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
}

int DKSOPAL::callGreensIntegral(void *tmp_ptr, int I, int J, int K, int NI, int NJ,
                                double hz_m0, double hz_m1, double hz_m2, int streamId)
{
  // Thin forwarder to the backend Green's-function helper. dksgreens is used
  // unchecked, so setupOPAL() (via initDevice()) must have set it non-null.
  return dksgreens->greensIntegral(tmp_ptr,
                                   I, J, K,
                                   NI, NJ,
                                   hz_m0, hz_m1, hz_m2,
                                   streamId);
}

int DKSOPAL::callGreensIntegration(void *mem_ptr, void *tmp_ptr,
                                   int I, int J, int K, int streamId)
{
  // Thin forwarder to dksgreens->integrationGreensFunction. The returned status
  // is passed through untouched; dksgreens must already be non-null.
  return dksgreens->integrationGreensFunction(mem_ptr, tmp_ptr,
                                              I, J, K,
                                              streamId);
}

int DKSOPAL::callMirrorRhoField(void *mem_ptr, int I, int J, int K, int streamId)
{
  // Thin forwarder to dksgreens->mirrorRhoField; dksgreens must already be non-null.
  return dksgreens->mirrorRhoField(mem_ptr,
                                   I, J, K,
                                   streamId);
}

int DKSOPAL::callMultiplyComplexFields(void *mem_ptr1, void *mem_ptr2, int size, int streamId)
{
  // Thin forwarder to the backend helper. The callee's name really is spelled
  // "multiplyCompelxFields" in the helper interface, so it is kept as-is here.
  return dksgreens->multiplyCompelxFields(mem_ptr1, mem_ptr2,
                                          size,
                                          streamId);
}

int DKSOPAL::callCollimatorPhysics(void *mem_ptr, void *par_ptr, 
				   int numparticles, int numparams,
74 75
				   int &numaddback, int &numdead, 
				   bool enableRutherforScattering) 
76 77
{

78
  return dkscol->CollimatorPhysics(mem_ptr, par_ptr, numparticles, enableRutherforScattering);
79 80 81 82

}


int DKSOPAL::callCollimatorPhysics2(void *mem_ptr, void *par_ptr, int numparticles,
				    bool enableRutherforScattering) 
85 86
{

87
  return dkscol->CollimatorPhysics(mem_ptr, par_ptr, numparticles, enableRutherforScattering);
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
  
}

int DKSOPAL::callCollimatorPhysicsSoA(void *label_ptr, void *localID_ptr,
                                      void *rx_ptr, void *ry_ptr, void *rz_ptr,
                                      void *px_ptr, void *py_ptr, void *pz_ptr,
                                      void *par_ptr, int numparticles)
{
  // Variant that passes each particle component as its own buffer (the "SoA"
  // in the name). All pointers go to the backend unchanged, in the same order.
  return dkscol->CollimatorPhysicsSoA(label_ptr, localID_ptr,
                                      rx_ptr, ry_ptr, rz_ptr,
                                      px_ptr, py_ptr, pz_ptr,
                                      par_ptr, numparticles);
}


int DKSOPAL::callCollimatorPhysicsSort(void *mem_ptr, int numparticles, int &numaddback)
{
  // Thin forwarder. numaddback is passed by reference, so any value the backend
  // writes reaches the caller directly.
  return dkscol->CollimatorPhysicsSort(mem_ptr,
                                       numparticles,
                                       numaddback);
}

int DKSOPAL::callCollimatorPhysicsSortSoA(void *label_ptr, void *localID_ptr,
                                          void *rx_ptr, void *ry_ptr, void *rz_ptr,
                                          void *px_ptr, void *py_ptr, void *pz_ptr,
                                          void *par_ptr, int numparticles, int &numaddback)
{
  // Unlike the other collimator wrappers, this call is wrapped in MIC_SAFECALL
  // (defined elsewhere). NOTE(review): presumably it returns an error code
  // instead of calling through when MIC support is not built in. Confirm.
  return MIC_SAFECALL(dkscol->CollimatorPhysicsSortSoA(label_ptr, localID_ptr,
                                                       rx_ptr, ry_ptr, rz_ptr,
                                                       px_ptr, py_ptr, pz_ptr,
                                                       par_ptr, numparticles,
                                                       numaddback));
}


int DKSOPAL::callParallelTTrackerPush(void *r_ptr, void *p_ptr, int npart,
                                      void *dt_ptr, double dt, double c,
                                      bool usedt, int streamId)
{
  // Full-argument push. Every parameter is forwarded to the collimator helper
  // in the same order.
  return dkscol->ParallelTTrackerPush(r_ptr, p_ptr,
                                      npart,
                                      dt_ptr, dt,
                                      c,
                                      usedt,
                                      streamId);
}

int DKSOPAL::callParallelTTrackerPush(void *r_ptr, void *p_ptr, void *dt_ptr,
                                      int npart, double c, int streamId)
{
  // Convenience overload: always sets usedt = true and passes 0 for the scalar
  // dt. Note that its parameter order (dt_ptr before npart) differs from the
  // backend call's.
  return dkscol->ParallelTTrackerPush(r_ptr, p_ptr,
                                      npart,
                                      dt_ptr, 0,
                                      c,
                                      true,
                                      streamId);
}

int DKSOPAL::callParallelTTrackerKick(void *r_ptr, void *p_ptr, void *ef_ptr,
                                      void *bf_ptr, void *dt_ptr, double charge, double mass,
                                      int npart, double c, int streamId)
{
  // Thin forwarder. All buffers and scalars go to dkscol->ParallelTTrackerKick
  // in the same order.
  return dkscol->ParallelTTrackerKick(r_ptr, p_ptr,
                                      ef_ptr, bf_ptr,
                                      dt_ptr,
                                      charge, mass,
                                      npart, c,
                                      streamId);
}

int DKSOPAL::callParallelTTrackerPushTransform(void *x_ptr, void *p_ptr,
                                               void *lastSec_ptr, void *orient_ptr,
                                               int npart, int nsec, void *dt_ptr, double dt,
                                               double c, bool usedt, int streamId)
{
  // Thin forwarder to dkscol->ParallelTTrackerPushTransform, with the argument
  // order unchanged; dkscol must already be non-null.
  return dkscol->ParallelTTrackerPushTransform(x_ptr, p_ptr,
                                               lastSec_ptr, orient_ptr,
                                               npart, nsec,
                                               dt_ptr, dt,
                                               c, usedt,
                                               streamId);
}