#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "math_constants.h"

#include "ImageIO.h"

#include <iostream>
#include <ctime>

using namespace std;

// Image size (N x N pixels) and thread-block edge length used by main().
#define N           512
#define BLOCKDIM    16
// Passed by main() as juliaKernel's `scale`: the half-width of the
// complex-plane window mapped onto the image. Float literal, so the float
// parameter is not fed a double.
#define SCALE       1.5f

// Real and imaginary parts of the constant c that onePixel() adds in each
// step of z <- z*z + c. step_r is parenthesized because it is negative.
#define step_r      (-0.8f)
#define step_i       0.156f

void cudaCheckError(char* errInfo);

__global__ void juliaKernel(unsigned char* ptr, int w, int h, float scale);
__device__ int onePixel(float x, float y);

// Minimal single-precision complex number used by the Julia iteration.
// Members are callable from both host and device code, and the accessors and
// operators are const so they also work on const operands.
struct gpuComplex {
    float   r;  // real part
    float   i;  // imaginary part
    __host__ __device__ gpuComplex(float _r, float _i) : r(_r), i(_i) {}

    // Squared modulus r^2 + i^2 (no square root needed for threshold tests).
    __host__ __device__ float magnitudeSqr(void) const {
        return r * r + i * i;
    }
    // Complex product: (a + bi)(c + di) = (ac - bd) + (bc + ad)i.
    __host__ __device__ gpuComplex operator*(const gpuComplex& other) const {
        return gpuComplex(r*other.r - i*other.i, i*other.r + r*other.i);
    }
    // Component-wise complex sum.
    __host__ __device__ gpuComplex operator+(const gpuComplex& other) const {
        return gpuComplex(r + other.r, i + other.i);
    }
};

// Renders an N x N Julia-set image on the GPU and writes it to "image.png".
//
// Steps: allocate a device buffer (4 bytes per pixel), launch juliaKernel on
// a 2D grid of BLOCKDIM x BLOCKDIM blocks, copy the buffer back, and pass it
// to writeRGBImageToFile().
//
// Returns 0 on success, or 1 if the allocation, launch or copy fails (each
// step is also reported via cudaCheckError). The image is only written when
// every CUDA step succeeded, and both buffers are released on every path.
int main(int argc, char** argv)
{
    const size_t imageBytes = N * N * 4 * sizeof(unsigned char);

    unsigned char* result = new unsigned char[N * N * 4];
    unsigned char* dev_result = 0;

    cudaError_t status = cudaMalloc((void**)&dev_result, imageBytes);
    cudaCheckError("Memory allocation.");
    if (status != cudaSuccess) {
        delete[] result;
        return 1;
    }

    // Ceil-divide so the grid covers every pixel even when N is not a
    // multiple of BLOCKDIM. The names avoid shadowing the blockDim/gridDim
    // built-ins.
    dim3 threadsPerBlock(BLOCKDIM, BLOCKDIM, 1);
    dim3 blocksPerGrid((N + BLOCKDIM - 1) / BLOCKDIM, (N + BLOCKDIM - 1) / BLOCKDIM, 1);

    juliaKernel<<<blocksPerGrid, threadsPerBlock>>>(dev_result, N, N, SCALE);
    // Peek (without clearing) so the launch status is kept for our own check;
    // cudaCheckError consumes the last-error state when it prints it.
    status = cudaPeekAtLastError();
    cudaCheckError("Kernel call.");

    if (status == cudaSuccess) {
        // Blocking copy: it waits for the kernel, so faults raised during
        // kernel execution are returned here.
        status = cudaMemcpy(result, dev_result, imageBytes, cudaMemcpyDeviceToHost);
        cudaCheckError("Memcpy: dev -> host");
    }

    if (status == cudaSuccess)
        writeRGBImageToFile("image.png", result, N, N);

    cudaFree(dev_result);
    delete[] result;

    return (status == cudaSuccess) ? 0 : 1;
}

// Computes one Julia-set pixel per thread.
//
// Launch layout: 2D grid of 2D blocks. Thread (i_x, i_y), built from the .x/.y
// components of blockIdx/blockDim/threadIdx, handles pixel column i_x, row i_y.
// The grid may over-cover the w x h image; surplus threads exit without
// writing. No shared memory is used.
//
// ptr   - device buffer of at least w * h * 4 bytes, row-major, 4 bytes per
//         pixel (presumably RGBA for writeRGBImageToFile - confirm).
// w, h  - image width and height in pixels.
// scale - column i_x maps to x = scale * (w/2 - i_x) / (w/2), and likewise for
//         rows, so the image spans (-scale, scale] on both axes.
__global__ void juliaKernel(unsigned char *ptr, int w, int h, float scale) {

    int i_x = blockIdx.x * blockDim.x + threadIdx.x;
    int i_y = blockIdx.y * blockDim.y + threadIdx.y;

    // Guard on the pixel indices, not on the mapped plane coordinates (those
    // are always within +/-scale). Otherwise threads in partial tail blocks
    // would write past the end of ptr.
    if (i_x >= w || i_y >= h)
        return;

    float x = scale * (w / 2.0f - i_x) / (w / 2.0f);
    float y = scale * (h / 2.0f - i_y) / (h / 2.0f);

    int tid = (i_y * w + i_x) * 4;

    // onePixel returns 1 (bounded) or 0 (escaped): yellow-ish vs black.
    int value = onePixel(x, y);
    ptr[tid + 0] = 255 * value;
    ptr[tid + 1] = 255 * value;
    ptr[tid + 2] = 0;
    ptr[tid + 3] = 255;
}

// Decides whether the point (x, y) is treated as inside the Julia set.
//
// Starting from z = x + y*i, repeatedly applies z <- z*z + c, where
// c = step_r + step_i*i. As soon as |z|^2 exceeds 1000 the point is considered
// escaped and 0 is returned. If it survives all 200 steps, 1 is returned.
__device__ int onePixel(float x, float y) {

    const gpuComplex juliaC(step_r, step_i);
    gpuComplex z(x, y);

    for (int iter = 0; iter < 200; ++iter) {
        z = z * z + juliaC;

        const bool escaped = z.magnitudeSqr() > 1000;
        if (escaped) {
            return 0;
        }
    }

    return 1;
}

// Reports the most recent CUDA runtime error, if any, without terminating.
//
// Calls cudaGetLastError(), which both returns and clears the runtime's
// last-error state. If the result is not cudaSuccess, it prints the error
// description and the caller-supplied context to stderr. Errors from
// asynchronous work (e.g. kernel execution) only appear here after a
// synchronizing call. Execution continues afterwards; callers decide whether
// to bail out.
//
// errInfo - short human-readable description of the step just checked.
void cudaCheckError(char* errInfo)
{
    cudaError_t errCode = cudaGetLastError();

    if (errCode != cudaSuccess)
    {
        std::cerr << "CUDA error occured: " << std::endl;
        std::cerr << " - Error description: " << cudaGetErrorString(errCode) << std::endl;
        std::cerr << " - Error info: " << errInfo << std::endl;
    }
}