__m128	CPU_Ray_Tracer::pakcet_kd_tree_traverse(const RayPacket& ray_packet,Hit* pHits)
{
	Packet_Traversal_Data m_Stack[KdTreeNode::MAX_DEPTH];
	unsigned int top=0;

	const vec4f zero(0,0,0,0);
	const vec4f infinity(INFINITY,INFINITY,INFINITY,INFINITY);
	
	// first. intersect the big box, binding our kd-tree.
	vec4f t_near,t_far;
	vec4ub hit_box;
	for(int i=0;i<4;i++)
		hit_box[i] = Intersect<Ray4f,AABB4f>::exec(ray_packet.GetRay(i),kd_tree->GetBindingBox(),&t_near[i],&t_far[i]);

	// suppose that if any ray miss the big box, so there is no need to traversae all rays in packet
	if(!hit_box[0] || !hit_box[1] || !hit_box[2] || !hit_box[3])
		return zero; 

	for(int i=0;i<4;i++)
	{
		// if t_near located bihind ray origin(ray.pos), store t_near as 0.
		if(t_near[i] <= 0.0f)
			t_near[i] = 0.0f;

		// ray hit the box behind it's origin
		if( t_near[i] >= t_far[i] ) 
			return zero;
	}

	// precompute 1.0f/ray.dir
	const vec4f one(1,1,1,1);
	const vec4f rdir_inv[3] =  {_mm_div_ps(one,ray_packet.rd[0]),
								_mm_div_ps(one,ray_packet.rd[1]),
								_mm_div_ps(one,ray_packet.rd[2])};

	// this trick allows remove one "if"
	const int left_or_right[4] = { (ray_packet.rd[0][0] >= 0)?1:0,  
								   (ray_packet.rd[1][0] >= 0)?1:0,  
								   (ray_packet.rd[2][0] >= 0)?1:0, 0 };
	// use 32 bit integer offsets instead of pointers
	unsigned int nodeOffset		= 0;

	// we will store 0xFFFFFFFF in each component of trav_result if corresponding ray intersected something 
	__m128 trav_result = zero;

	// this mask shows what rays we must traverse. All four rays valid in the beggining.
	register __m128 mask;
	mask.m128_u32[0] = 0xFFFFFFFF;
	mask.m128_u32[1] = 0xFFFFFFFF;
	mask.m128_u32[2] = 0xFFFFFFFF;
	mask.m128_u32[3] = 0xFFFFFFFF;

	KdTreeNode node;
	while(true)
	{
		node = kd_tree->GetNodeByOffset(nodeOffset);
		while(!node.Leaf())
		{
			const int   split_axis = node.GetAxis();
			const vec4f split_pos(node.GetSplitPos(),node.GetSplitPos(),node.GetSplitPos(),node.GetSplitPos());
			
			// t_split[0..3] = (split_pos[0..3] - ray.pos[0..3])/ ray.dir[0..3]
			const vec4f t_split	= _mm_mul_ps( _mm_sub_ps( split_pos, ray_packet.ro[split_axis] ), rdir_inv[split_axis] );

			unsigned int nearNodeOffset	= node.GetLeftOffset() + (1-left_or_right[split_axis]);
			unsigned int farNodeOffset	= node.GetLeftOffset() + left_or_right[split_axis];

			//if all t_split[0..3] < t_near[0..3] where mask = 'trtraversale' then all rays in far node
			if (_mm_movemask_ps( _mm_and_ps( _mm_cmplt_ps( t_split, t_near ), mask ))  == _mm_movemask_ps(mask) )
			{ 
				node = kd_tree->GetNodeByOffset(farNodeOffset);
			} 
			// else if all t_split[0..3] > t_far[0..3] where mask = 'true' then all rays in near node
			else if(_mm_movemask_ps( _mm_and_ps( _mm_cmpge_ps( t_split, t_far ), mask )) ==  _mm_movemask_ps(mask) )
			{
				node = kd_tree->GetNodeByOffset(nearNodeOffset);
			}
			else
			{
				node = kd_tree->GetNodeByOffset(nearNodeOffset);

				// stack.push( node->far, t_far[0..3], mask[0..3] )
				m_Stack[top].nodeOffset = farNodeOffset;
				m_Stack[top].t_far		= t_far;
				m_Stack[top].mask		= mask;
				top++;
	
				// needn't use mask there, we will repair t_far when pop far node from stack
				t_far = t_split;
				// go to the nearest node, traverse rays where t_split[0..3] > t_near[0..3]
				mask  = _mm_and_ps(mask,_mm_cmpgt_ps( t_split, t_near ));
			}
		}

		// compute intersections, store result 't' in t_hit if intersects
		vec4f t_hit = infinity;
		for(int i=0;i<4;i++)
		{
			if(mask.m128_u32[i])
				t_hit[i] =	IntersectAllPrimitivesInLeaf(ray_packet.GetRay(i),node,pHits+i);
		}

		// store trav_result where t_hit[0..3] <= t_far[0..3]
		trav_result = _mm_or_ps(trav_result,_mm_cmple_ps(t_hit,t_far));

		// all 4 rays have been stoped, nothing else to traverse anymore
		if(top == 0 || _mm_movemask_ps(trav_result) == 0xF )
			break;

		//t_near[0..3] = t_far[0..3] where ray not hit; 
		mask	= _mm_andnot_ps(trav_result	,mask); 
		t_near	= _mm_or_ps(_mm_and_ps( mask, t_far),_mm_andnot_ps(mask,t_near));

		// (next_node,t_far,mask) = stack.pop();
		top--;
		nodeOffset  = m_Stack[top].nodeOffset;
		t_far		= m_Stack[top].t_far; // btw, here we repair t_far
		mask		= _mm_andnot_ps(trav_result	,m_Stack[top].mask);
	}
	
	return trav_result;
}
