function [u,x,mesh,l2err,h1err] = fem(n,basis,polydeg,plotsln)
%FEM Runs a finite element code to solve Poisson's equation
%
%  -div(grad u) = f,     in Omega=(0,1)^2
%             u = 0,     on boundary of Omega
%
% where f = 2 pi^2 sin(x pi) sin(y pi). The exact solution is
% u = sin(x pi) sin(y pi), which is used to compute the returned errors.
%
% Arguments:
%    n       - Number of elements in 1D - gives either n^2 (for square
%              elements) or 2n^2 (for triangular elements) elements.
%              Rounded to the nearest integer; must be at least 1.
%
%    basis   - The basis to use, also determines the type of elements.
%              The supported bases are (note multiple names are available
%              for each basis):
%
%              tri
%              triangle
%              simplex
%                  Lagrange 2-simplex basis functions on triangles
%              quad
%              quadrilateral
%              rect
%              rectangle
%                  Tensor Lagrange basis on rectangles
%              reduced_tri
%              reduced_triangle
%              reduced_simplex
%                  Reduced 2-simplex basis functions (no interior DoFs).
%                  Only supports polynomial degree p=3
%              reduced_quad
%              reduced_quadrilateral
%              reduced_rect
%              reduced_rectangle
%                  Reduced rectangle basis functions (no interior DoFs).
%                  Only supports polynomial degree p=2 or p=3
%              cr
%              crouzeix_raviart
%                  Crouzeix-Raviart non-conforming basis on triangles
%                  Only supports polynomial degree p=1
%              rotated_quad
%              rotated_quadrilateral
%              rotated_rect
%              rotated_rectangle
%              rotated_bilinear
%                  Rotated bilinear non-conforming basis on rectangles
%                  Only supports polynomial degree p=1
%
%    polydeg - Polynomial degree of basis function to use, must be a
%              positive integer. Note that some bases only support a
%              limited set of polynomial degrees.
%
%    plotsln - If specified and true, the solution is plotted by calling
%              plot_soln.
%
% Returns:
%   u     - Vector of values of the numerical solution at the points
%           represented by the degrees of freedom
%   x     - Nx2 matrix, where N is the number of degrees of freedom, with
%           each row containing the X and Y coordinates of the
%           corresponding degree of freedom.
%   mesh  - Structure denoting the mesh, contains the following fields:
%             vertices    - Vx2 matrix, where V is number of vertices, with
%                           each row containing X/Y coordinates of vertex
%             elements    - ExN, where E is the number of elements and
%                           N is the number of vertices per element, with
%                           each row containing the indices of the vertices
%                           of the element in counter-clockwise order
%             dof_mapping - ExM, where M is the number of degrees of
%                           freedom per element, containing per row the
%                           indices of the global degrees of freedom
%                           corresponding to the element's local DoFs.
%   l2err - L^2 norm of the error, sqrt(\int_Omega (u - u_h)^2 dx),
%           evaluated with the element quadrature rule.
%   h1err - H^1 seminorm of the error, sqrt(\int_Omega |grad(u - u_h)|^2 dx),
%           evaluated element by element (for the non-conforming bases this
%           is the broken, element-wise seminorm).
%
% Errors if n < 1, if the basis name is unknown, or if the requested
% polynomial degree is not supported by the basis.

% Ensure polydeg and n are both integers
polydeg = round(polydeg);
n = round(n);
if n < 1
    error('n must be a positive integer');
end

fprintf('Creating FE space           - ');
tic;
% Create the finite element spaces, the mesh, and the quadrature rules
% depending on the basis requested
% Details of how these are constructed will be skipped, but after this
% switch the following variables will be set:
%
% x          - Matrix containing, per row, coordinates of the dofs
% bdry       - Logical vector (1 or 0) denoting for each dof if the dof is
%              on the boundary of the domain
% mesh       - Structure denoting the mesh (see above)
% quadrature - Structure containing basis function and quadrature info:
%                points        - Matrix containing points on the reference
%                                element for quadrature - one point per row
%                weights       - Vector of quadrature weights corresponding
%                                to the quadrature points
%                basis         - Matrix of the basis functions on the
%                                reference element at the quadrature points
%                                Each row corresponds to a quadrature point
%                                and each column to one of the basis
%                                functions, see basis_* functions for order
%                                of the basis functions
%                deriv_basis_x - Matrix of the derivative of the basis
%                                functions w.r.t the x-coordinate evaluated
%                                at the quadrature points
%                deriv_basis_y - Matrix of the derivative of the basis
%                                functions w.r.t the y-coordinate evaluated
%                                at the quadrature points
switch basis
    case {'tri', 'triangle', 'simplex'}
        [x,bdry,mesh,quadrature] = lagrange_tri_fe_space(n,polydeg);
    case {'quad','quadrilateral','rect','rectangle'}
        [x,bdry,mesh,quadrature] = lagrange_quad_fe_space(n,polydeg);
    case {'reduced_tri', 'reduced_triangle', 'reduced_simplex'}
        [x,bdry,mesh,quadrature] = reduced_tri_fe_space(n,polydeg);
    case {'reduced_quad','reduced_quadrilateral','reduced_rect','reduced_rectangle'}
        [x,bdry,mesh,quadrature] = reduced_quad_fe_space(n,polydeg);
    case {'cr', 'crouzeix_raviart'}
        if polydeg ~= 1
            error('Only p=1 supported for Crouzeix-Raviart element')
        end
        [x,bdry,mesh,quadrature] = crouzeix_raviart_fe_space(n);
    case {'rotated_quad','rotated_quadrilateral','rotated_rect','rotated_rectangle','rotated_bilinear'}
        if polydeg ~= 1
            error('Only p=1 supported for rotated bilinear element')
        end
        [x,bdry,mesh,quadrature] = rotated_bilinear_fe_space(n);
    otherwise
        error('Unknown basis type');
end
toc;

% Sizes used throughout. size(x,1) is used rather than length(x) because
% length returns max(size(x)), which is wrong if there are fewer than two
% DoFs.
ndofs = size(quadrature.basis,2);
nglobal = size(x,1);
nelements = size(mesh.elements,1);

% Quadrature weights as a column vector (whatever orientation the quadrature
% routine returned) and as a diagonal matrix for the stiffness products
w = quadrature.weights(:);
W = diag(w);

% Right hand side forcing function
f = @(X) 2.*pi^2.*sin(X(:,1).*pi).*sin(X(:,2).*pi);

% Setup right hand side vector
b = zeros(nglobal,1);

% Setup row, column, and value vectors for constructing sparse matrix.
% Each element appends its full local stiffness matrix (with global row and
% column indices) to these vectors; sparse() sums duplicate (row,col)
% entries, which performs the global assembly.
fprintf('Constructing matrix and RHS - ');
tic;
block = ndofs*ndofs;
I = zeros(block*nelements,1);
J = zeros(block*nelements,1);
V = zeros(block*nelements,1);

for ele=1:nelements
    dofs = mesh.dof_mapping(ele,:);
    [xq, inv_jac, det_jac] = map_to_element(mesh, ele, quadrature.points);
    [gx, gy] = physical_gradients(quadrature, inv_jac);

    % K(i,j) = \int_T grad(u_j).grad(v_i) dx  using numerical integration
    K = det_jac*(gx'*W*gx + gy'*W*gy);
    % \int_T f v_i dx  using numerical integration
    local_rhs = det_jac*(quadrature.basis'*(w.*f(xq)));

    % rows(i,j) = dofs(i), cols(i,j) = dofs(j), matching K(i,j)
    [rows, cols] = ndgrid(dofs, dofs);
    slot = (ele-1)*block + (1:block);
    I(slot) = rows(:);
    J(slot) = cols(:);
    V(slot) = K(:);

    % Add the local RHS into the correct indices of the global RHS
    b(dofs) = b(dofs) + local_rhs;
end

% Construct matrix from sparse vectors
A = sparse(I, J, V, nglobal, nglobal);

% Delete boundary DoFs from the matrix and right hand side (homogeneous
% Dirichlet condition: boundary values stay zero)
A = A(~bdry, ~bdry);
b = b(~bdry);
toc;

% Zero-init result vector
u = zeros(nglobal,1);
% Compute non-boundary values by solving Au=b:
fprintf('Solving linear system       - ');
tic;
u(~bdry) = A\b;
toc;

fprintf('Computing errors            - ');
tic;
% Compute errors - we have a known analytical solution: sin(x pi) sin(y pi)
anal_soln = @(X) sin(X(:,1).*pi).*sin(X(:,2).*pi);
deriv_anal_soln = @(X) [pi.*cos(X(:,1).*pi).*sin(X(:,2).*pi) ...
    pi.*sin(X(:,1).*pi).*cos(X(:,2).*pi)];

% Accumulate the squared norms element by element; the norm itself is the
% square root of the total integral, taken once after the loop.
l2sq = 0;
h1sq = 0;
for ele=1:nelements
    [xq, inv_jac, det_jac] = map_to_element(mesh, ele, quadrature.points);
    [gx, gy] = physical_gradients(quadrature, inv_jac);

    % Numerical solution (and its physical gradient) at the quadrature
    % points: basis values weighted by the element's coefficients
    uloc = u(mesh.dof_mapping(ele,:));
    uh = quadrature.basis*uloc;
    d_uh = [gx*uloc gy*uloc];

    % \int_T (u - u_h)^2 dx
    e = anal_soln(xq) - uh;
    l2sq = l2sq + det_jac*(w'*(e.^2));

    % \int_T |grad(u) - grad(u_h)|^2 dx; the error gradient has one column
    % per derivative direction, so sum the squares across each row
    de = deriv_anal_soln(xq) - d_uh;
    h1sq = h1sq + det_jac*(w'*sum(de.^2, 2));
end
l2err = sqrt(l2sq);
h1err = sqrt(h1sq);
toc;

% Plot solution if requested
if exist('plotsln','var') && plotsln
    plot_soln(basis,polydeg,mesh,u);
end

% =========================================================================
end % End of main function

function [xq, inv_jac, det_jac] = map_to_element(mesh, ele, points)
%MAP_TO_ELEMENT Map reference quadrature points onto element ELE
%
% Arguments:
%   mesh   - Mesh structure (vertices, elements)
%   ele    - Index of the element (row of mesh.elements)
%   points - Reference quadrature points, one point per row
%
% Returns:
%   xq      - Physical coordinates of the quadrature points, one per row
%   inv_jac - Inverse of the transposed affine map matrix; row gradients on
%             the reference element map to physical ones via
%             [dphi/dxhat dphi/dyhat]*inv_jac'
%   det_jac - |det| of the affine map (area scaling factor)
vertices = mesh.vertices(mesh.elements(ele,:), :)';
if size(vertices,2) == 3
    % Triangles
    [B, c] = tri_mapping(vertices);
else
    % Quads
    [B, c] = quad_mapping(vertices);
end
% The quadrature points are row vectors, so apply the transposed map
% q*B' + c' to every row at once.
B = B';
xq = points*B + repmat(c', size(points,1), 1);
inv_jac = inv(B);
det_jac = abs(det(B));
end

function [gx, gy] = physical_gradients(quadrature, inv_jac)
%PHYSICAL_GRADIENTS x/y derivatives of the basis functions on an element
%
% Column k of GX (GY) holds the physical x (y) derivative of basis function
% k at every quadrature point. This is the component-wise form of
% [deriv_basis_x(:,k) deriv_basis_y(:,k)]*inv_jac'.
M = inv_jac';
gx = quadrature.deriv_basis_x*M(1,1) + quadrature.deriv_basis_y*M(2,1);
gy = quadrature.deriv_basis_x*M(1,2) + quadrature.deriv_basis_y*M(2,2);
end
% =========================================================================

% =========================================================================
% Functions to construct Finite Element Spaces
% =========================================================================

function [x,bdry,mesh,quadrature] = lagrange_tri_fe_space(n,polydeg)
%LAGRANGE_TRI_FE_SPACE Gets mesh, basis and quadrature for Lagrange n-simplex
%
% Builds the degree-POLYDEG Lagrange space on the triangle mesh produced by
% construct_tri_mesh(n).
%
% The global DoFs are the nodes of a uniform (n*polydeg+1)-by-(n*polydeg+1)
% grid on [0,1]^2, numbered column-major (meshgrid + reshape). Index+1 moves
% one node in +y, and index+STRIDE (STRIDE = n*polydeg+1) moves one node
% in +x.
%
% Local DoF order in each row of mesh.dof_mapping:
%   1:3                       - element vertices (same order as
%                               mesh.elements)
%   4:(polydeg+2)             - interior nodes of edge vertex 1 -> 2
%   (polydeg+3):(2*polydeg+1) - interior nodes of edge vertex 2 -> 3
%   (2*polydeg+2):(3*polydeg) - interior nodes of edge vertex 3 -> 1
%   (3*polydeg+1):end         - element-interior nodes (polydeg > 2 only)
% NOTE(review): this order must agree with basis_lagrange_tri, which is
% defined elsewhere - confirm there.
if polydeg < 1
    error('Only p>=1 supported for Lagrange simplex element')
end
mesh = construct_tri_mesh(n);
ndofs = (polydeg+2)*(polydeg+1)/2;
mesh.dof_mapping = zeros(size(mesh.elements,1), ndofs);
nodes = linspace(0, 1, n*polydeg+1);
[X, Y] = meshgrid(nodes, nodes);
% Nodes on the outermost grid rows/columns lie on the boundary of Omega
bdry = false(size(X));
bdry(1,:) = true;
bdry(end,:) = true;
bdry(:,1) = true;
bdry(:,end) = true;
x = [reshape(X, [], 1) reshape(Y, [], 1)];
bdry = reshape(bdry, [], 1);
stride = n*polydeg+1;
% Loop over cells: i is the x-cell index and j the y-cell index (0-based).
% Cell (i,j) holds triangles 2*(i*n+j)+1 (lower-left: SW, SE, NW vertices)
% and 2*(i*n+j)+2 (upper-right: NE, NW, SE vertices).
for i=0:n-1
    for j=0:n-1
        % Grid node at the cell's lower-left corner
        index = i*stride*polydeg+j*polydeg+1;
        mesh.dof_mapping(2*(i*n+j)+1, 1:3) = ...
            [index index+stride*polydeg index+polydeg];
        mesh.dof_mapping(2*(i*n+j)+2, 1:3) = ...
            [index+stride*polydeg+polydeg index+polydeg index+stride*polydeg];
        if polydeg > 1
            % Lower-left triangle edges: bottom (+x), diagonal
            % (SE -> NW), left (-y)
            mesh.dof_mapping(2*(i*n+j)+1, 4:(polydeg+2)) ...
                = (index+stride):stride:(index+stride*(polydeg-1));
            mesh.dof_mapping(2*(i*n+j)+1, (polydeg+3):(2*polydeg+1)) ...
                = (index+stride*(polydeg-1)+1):(1-stride):(index+stride+polydeg-1);
            mesh.dof_mapping(2*(i*n+j)+1, (2*polydeg+2):(3*polydeg)) ...
                = (index+polydeg-1):-1:(index+1);

            % Upper-right triangle edges: top (-x), diagonal (NW -> SE),
            % right (+y)
            mesh.dof_mapping(2*(i*n+j)+2, 4:(polydeg+2)) ...
                = (index+stride*(polydeg-1)+polydeg):-stride:(index+stride+polydeg);
            mesh.dof_mapping(2*(i*n+j)+2, (polydeg+3):(2*polydeg+1)) ...
                = (index+stride+polydeg-1):(stride-1):(index+stride*(polydeg-1)+1);
            mesh.dof_mapping(2*(i*n+j)+2, (2*polydeg+2):(3*polydeg)) ...
                = (index+stride*polydeg+1):(index+stride*polydeg+polydeg-1);
        end
        if polydeg > 2
            % Interior nodes, one grid row at a time (polydeg-1-k nodes in
            % row k). The upper-right triangle uses the 180-degree rotated
            % pattern of the lower-left one.
            offset = 3*polydeg+1;
            for k=1:(polydeg-2)
                mesh.dof_mapping(2*(i*n+j)+1, offset:(offset+polydeg-2-k)) ...
                    = (index+stride+k):stride:(index+stride*(polydeg-1-k)+k);
                mesh.dof_mapping(2*(i*n+j)+2, offset:(offset+polydeg-2-k)) ...
                    = (index+polydeg+stride*(polydeg-1)-k):-stride:(index+polydeg+stride*(k+1)-k);
                offset = offset + polydeg-1-k;
            end
        end
    end
end
% Quadrature of order 2*polydeg on the reference triangle
[quadrature.points,quadrature.weights] = quadrature_tri(polydeg*2);
quadrature.basis = ...
    basis_lagrange_tri(polydeg,quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_lagrange_tri(polydeg,quadrature.points(:,1),quadrature.points(:,2));
end

function [x,bdry,mesh,quadrature] = lagrange_quad_fe_space(n,polydeg)
%LAGRANGE_QUAD_FE_SPACE Gets mesh, basis and quadrature for Lagrange rectangle
%
% Builds the degree-POLYDEG tensor Lagrange space on the square mesh
% produced by construct_quad_mesh(n).
%
% The global DoFs are the nodes of a uniform (n*polydeg+1)-by-(n*polydeg+1)
% grid on [0,1]^2, numbered column-major (meshgrid + reshape). Index+1 moves
% one node in +y, and index+STRIDE (STRIDE = n*polydeg+1) moves one node
% in +x.
%
% Local DoF order in each row of mesh.dof_mapping:
%   1:4                        - vertices SW, SE, NE, NW
%   5:(polydeg+3)              - bottom edge interior nodes (+x)
%   (polydeg+4):(2*polydeg+2)  - right edge interior nodes (+y)
%   (2*polydeg+3):(3*polydeg+1)- top edge interior nodes (-x)
%   (3*polydeg+2):(4*polydeg)  - left edge interior nodes (-y)
%   (4*polydeg+1):end          - interior nodes, grid row by grid row
%                                (x fastest)
% NOTE(review): this order must agree with basis_lagrange_quad, which is
% defined elsewhere - confirm there.
if polydeg < 1
    error('Only p>=1 supported for Lagrange quadrilateral element')
end
mesh = construct_quad_mesh(n);
ndofs = (polydeg+1)^2;
mesh.dof_mapping = zeros(size(mesh.elements,1), ndofs);
nodes = linspace(0, 1, n*polydeg+1);
[X, Y] = meshgrid(nodes, nodes);
% Nodes on the outermost grid rows/columns lie on the boundary of Omega
bdry = false(size(X));
bdry(1,:) = true;
bdry(end,:) = true;
bdry(:,1) = true;
bdry(:,end) = true;
x = [reshape(X, [], 1) reshape(Y, [], 1)];
bdry = reshape(bdry, [] ,1);
stride = n*polydeg+1;
% Loop over cells: i is the x-cell index and j the y-cell index (0-based);
% cell (i,j) is element i*n+j+1, matching construct_quad_mesh.
for i=0:n-1
    for j=0:n-1
        % Grid node at the cell's lower-left (SW) corner
        index = i*stride*polydeg+j*polydeg+1;
        mesh.dof_mapping(i*n+j+1, 1:4) = ...
            [index index+stride*polydeg index+stride*polydeg+polydeg index+polydeg];
        if polydeg > 1
            mesh.dof_mapping(i*n+j+1, 5:(polydeg+3)) ...
                = (index+stride):stride:(index+stride*(polydeg-1));
            mesh.dof_mapping(i*n+j+1, (polydeg+4):(2*polydeg+2)) ...
                = (index+stride*polydeg+1):(index+stride*polydeg+polydeg-1);
            mesh.dof_mapping(i*n+j+1, (2*polydeg+3):(3*polydeg+1)) ...
                = (index+stride*(polydeg-1)+polydeg):-stride:(index+stride+polydeg);
            mesh.dof_mapping(i*n+j+1, (3*polydeg+2):(4*polydeg)) ...
                = (index+polydeg-1):-1:(index+1);
            % Interior nodes: row k (k grid steps above the bottom edge)
            % contributes polydeg-1 nodes, left to right
            for k=1:(polydeg-1)
                mesh.dof_mapping(i*n+j+1, (4*polydeg+(polydeg-1)*(k-1)+1):(4*polydeg+(polydeg-1)*k)) ...
                    = (index+stride+k):stride:(index+stride*(polydeg-1)+k);
            end
        end
    end
end
% Quadrature of order 2*polydeg on the reference square
[quadrature.points,quadrature.weights] = quadrature_rect(polydeg*2);
quadrature.basis = ...
    basis_lagrange_quad(polydeg,quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_lagrange_quad(polydeg,quadrature.points(:,1),quadrature.points(:,2));
end

function [x,bdry,mesh,quadrature] = reduced_tri_fe_space(n,polydeg)
%REDUCED_TRI_FE_SPACE Gets mesh, basis and quadrature for reduced n-simplex
%
% DoFs are the mesh vertices plus polydeg-1 equispaced nodes on every edge
% (no element-interior nodes). Global numbering, in blocks:
%   1..(n+1)^2                   - mesh vertices (construct_tri_mesh order)
%   next n*(n+1)*(polydeg-1)     - nodes on vertical edges   (starts after
%                                  YOFFSET)
%   next n*(n+1)*(polydeg-1)     - nodes on horizontal edges (starts after
%                                  XOFFSET)
%   last n^2*(polydeg-1)         - nodes on the NW-SE cell diagonals
%
% Local order per element: 3 vertices, then the edge nodes of edges
% vertex 1->2, 2->3 and 3->1.
% NOTE(review): this order must agree with basis_reduced_tri, which is
% defined elsewhere - confirm there.
if polydeg ~= 3
    error('Only p=3 supported for reduced Lagrange simplex element')
end
mesh = construct_tri_mesh(n);
ndofs = polydeg*3;
mesh.dof_mapping = zeros(size(mesh.elements,1), ndofs);
verts = linspace(0, 1, n+1);
% Fine 1D node positions with the mesh-vertex positions removed, i.e. only
% the positions strictly inside each cell
reduced = linspace(0, 1, n*polydeg+1);
reduced(1:polydeg:end) = [];
[VX, VY] = meshgrid(verts,verts);
[EX, EY] = meshgrid(verts,reduced);
% Diagonal nodes: within each cell the x positions run in the opposite
% direction to y, following the NW-SE diagonal
DY = repmat(reduced', 1, n);
DX = repmat(flip(reshape(reduced,[], n)), n, 1);

bdry = false(size(VX));
bdry(1,:) = true;
bdry(end,:) = true;
bdry(:,1) = true;
bdry(:,end) = true;
% Edge nodes lying on the first/last mesh line are on the boundary
edge_bdry = false(size(EX));
edge_bdry(:,1) = true;
edge_bdry(:,end) = true;

x = [reshape(VX, [], 1) reshape(VY, [], 1); ...
    reshape(EX, [], 1) reshape(EY, [], 1); ...
    reshape(EY, [], 1) reshape(EX, [], 1); ...
    reshape(DX, [], 1) reshape(DY, [], 1)];
bdry = [reshape(bdry, [], 1); reshape(edge_bdry, [], 1); ...
    reshape(edge_bdry, [], 1); reshape(false(size(DX)), [], 1)];

yoffset = (n+1)^2;
xoffset = yoffset + n*(n+1)*(polydeg-1);
% NOTE: doffset is n*(polydeg-1) smaller than the true start of the
% diagonal block (xoffset + n*(n+1)*(polydeg-1)). The diagonal indices below
% compensate by using (i+1)*n instead of i*n, so the result is correct.
doffset = xoffset + n^2*(polydeg-1);
% Loop over cells: i is the x-cell index and j the y-cell index (0-based)
for i=0:n-1
    for j=0:n-1
        mesh.dof_mapping(2*(i*n+j)+1, 1:3) = ...
            [i*(n+1)+j+1 (i+1)*(n+1)+j+1 i*(n+1)+j+2];
        mesh.dof_mapping(2*(i*n+j)+2, 1:3) = ...
            [(i+1)*(n+1)+j+2 i*(n+1)+j+2 (i+1)*(n+1)+j+1];
        % Always true given the polydeg == 3 check above
        if polydeg > 1
            % Lower-left triangle: bottom edge, diagonal, left edge
            mesh.dof_mapping(2*(i*n+j)+1, 4:(polydeg+2)) ...
                = xoffset + (j*n+i)*(polydeg-1) + (1:(polydeg-1));
            mesh.dof_mapping(2*(i*n+j)+1, (polydeg+3):(2*polydeg+1)) ...
                = doffset + ((i+1)*n+j)*(polydeg-1) + (1:(polydeg-1));
            mesh.dof_mapping(2*(i*n+j)+1, (2*polydeg+2):(3*polydeg)) ...
                = yoffset + (i*n+j)*(polydeg-1) + ((polydeg-1):-1:1);

            % Upper-right triangle: top edge, diagonal, right edge
            mesh.dof_mapping(2*(i*n+j)+2, 4:(polydeg+2)) ...
                = xoffset + ((j+1)*n+i)*(polydeg-1) + ((polydeg-1):-1:1);
            mesh.dof_mapping(2*(i*n+j)+2, (polydeg+3):(2*polydeg+1)) ...
                = doffset + ((i+1)*n+j)*(polydeg-1) + ((polydeg-1):-1:1);
            mesh.dof_mapping(2*(i*n+j)+2, (2*polydeg+2):(3*polydeg)) ...
                = yoffset + ((i+1)*n+j)*(polydeg-1) + (1:(polydeg-1));
        end
    end
end
[quadrature.points,quadrature.weights] = quadrature_tri(polydeg*2);
quadrature.basis = ...
    basis_reduced_tri(polydeg,quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_reduced_tri(polydeg,quadrature.points(:,1),quadrature.points(:,2));
end

function [x,bdry,mesh,quadrature] = reduced_quad_fe_space(n,polydeg)
%REDUCED_QUAD_FE_SPACE Gets mesh, basis and quadrature for reduced rectangle
%
% DoFs are the mesh vertices plus polydeg-1 equispaced nodes on every edge
% (no element-interior nodes). Global numbering, in blocks:
%   1..(n+1)^2                   - mesh vertices (construct_quad_mesh order)
%   next n*(n+1)*(polydeg-1)     - nodes on vertical edges   (starts after
%                                  YOFFSET)
%   next n*(n+1)*(polydeg-1)     - nodes on horizontal edges (starts after
%                                  XOFFSET)
%
% Local order per element: vertices SW, SE, NE, NW, then the nodes on the
% bottom (+x), right (+y), top (-x) and left (-y) edges.
% NOTE(review): this order must agree with basis_reduced_quad, which is
% defined elsewhere - confirm there.
if polydeg < 2 || polydeg > 3
    error('Only p=2,3 supported for reduced Lagrange quadrilateral element')
end
mesh = construct_quad_mesh(n);
ndofs = polydeg*4;
mesh.dof_mapping = zeros(size(mesh.elements,1), ndofs);
verts = linspace(0, 1, n+1);
% Fine 1D node positions with the mesh-vertex positions removed, i.e. only
% the positions strictly inside each cell
reduced = linspace(0, 1, n*polydeg+1);
reduced(1:polydeg:end) = [];
[VX, VY] = meshgrid(verts,verts);
[EX, EY] = meshgrid(verts,reduced);

bdry = false(size(VX));
bdry(1,:) = true;
bdry(end,:) = true;
bdry(:,1) = true;
bdry(:,end) = true;
% Edge nodes lying on the first/last mesh line are on the boundary
edge_bdry = false(size(EX));
edge_bdry(:,1) = true;
edge_bdry(:,end) = true;

x = [reshape(VX, [], 1) reshape(VY, [], 1); ...
    reshape(EX, [], 1) reshape(EY, [], 1); ...
    reshape(EY, [], 1) reshape(EX, [], 1)];
bdry = [reshape(bdry, [], 1); reshape(edge_bdry, [], 1); reshape(edge_bdry, [], 1)];

yoffset = (n+1)^2;
xoffset = yoffset + n*(n+1)*(polydeg-1);
% Loop over cells: i is the x-cell index and j the y-cell index (0-based);
% cell (i,j) is element i*n+j+1
for i=0:n-1
    for j=0:n-1
        mesh.dof_mapping(i*n+j+1, 1:4) = ...
            [i*(n+1)+j+1 (i+1)*(n+1)+j+1 (i+1)*(n+1)+j+2 i*(n+1)+j+2];
        % Always true given the polydeg check above
        if polydeg > 1
            mesh.dof_mapping(i*n+j+1, 5:(polydeg+3)) ...
                = xoffset + (j*n+i)*(polydeg-1) + (1:(polydeg-1));
            mesh.dof_mapping(i*n+j+1, (polydeg+4):(2*polydeg+2)) ...
                = yoffset + ((i+1)*n+j)*(polydeg-1) + (1:(polydeg-1));
            mesh.dof_mapping(i*n+j+1, (2*polydeg+3):(3*polydeg+1)) ...
                = xoffset + ((j+1)*n+i)*(polydeg-1) + ((polydeg-1):-1:1);
            mesh.dof_mapping(i*n+j+1, (3*polydeg+2):(4*polydeg)) ...
                = yoffset + (i*n+j)*(polydeg-1) + ((polydeg-1):-1:1);
        end
    end
end
[quadrature.points,quadrature.weights] = quadrature_rect(polydeg*2);
quadrature.basis = ...
    basis_reduced_quad(polydeg,quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_reduced_quad(polydeg,quadrature.points(:,1),quadrature.points(:,2));
end

function [x,bdry,mesh,quadrature] = crouzeix_raviart_fe_space(n)
%CROUZEIX_RAVIART_FE_SPACE Gets mesh, basis and quadrature for Crouzeix-Raviart
%
% Places one DoF at the midpoint of every edge of the construct_tri_mesh(n)
% triangulation. Global numbering, in blocks:
%   1..n*(n+1)              - midpoints of vertical edges
%   next n*(n+1)            - midpoints of horizontal edges
%   last n^2                - midpoints of the NW-SE diagonals (cell centres)
% Local order per element: midpoint of the edge opposite vertex 1, 2, 3.
mesh = construct_tri_mesh(n);

lines1d = linspace(0, 1, n+1);
mids1d = linspace(1/(2*n), 1-1/(2*n), n);
% Columns of EX/EY index the vertical mesh lines, rows the cell midpoints
[EX, EY] = meshgrid(lines1d, mids1d);
[CX, CY] = meshgrid(mids1d, mids1d);

% Midpoints on the first/last mesh line sit on the domain boundary
on_bdry = false(size(EX));
on_bdry(:,[1 end]) = true;

x = [EX(:) EY(:); EY(:) EX(:); CX(:) CY(:)];
bdry = [on_bdry(:); on_bdry(:); false(numel(CX),1)];

% Cell (i,j): i = x-cell, j = y-cell (0-based), with j varying fastest,
% which is the column-major order of ndgrid's outputs
nedge = n*(n+1);
[jj, ii] = ndgrid(0:n-1, 0:n-1);
jj = jj(:);
ii = ii(:);
centre = 2*nedge + ii*n + jj + 1;
mesh.dof_mapping = zeros(size(mesh.elements,1), 3);
% Lower-left triangles: diagonal, left edge, bottom edge
mesh.dof_mapping(1:2:end,:) = [centre, ii*n+jj+1, nedge+jj*n+ii+1];
% Upper-right triangles: diagonal, right edge, top edge
mesh.dof_mapping(2:2:end,:) = [centre, (ii+1)*n+jj+1, nedge+(jj+1)*n+ii+1];

[quadrature.points,quadrature.weights] = quadrature_tri(2);
quadrature.basis = ...
    basis_crouzeix_raviart(quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_crouzeix_raviart(quadrature.points(:,1),quadrature.points(:,2));
end

function [x,bdry,mesh,quadrature] = rotated_bilinear_fe_space(n)
%ROTATED_BILINEAR_FE_SPACE Gets mesh, basis and quadrature for rotated bilinear
%
% Places one DoF at the midpoint of every edge of the construct_quad_mesh(n)
% mesh. Global numbering: the n*(n+1) vertical-edge midpoints first, then the
% n*(n+1) horizontal-edge midpoints.
% Local order per element: right, top, left, bottom edge.
mesh = construct_quad_mesh(n);

lines1d = linspace(0, 1, n+1);
mids1d = linspace(1/(2*n), 1-1/(2*n), n);
% Columns of EX/EY index the vertical mesh lines, rows the cell midpoints
[EX, EY] = meshgrid(lines1d, mids1d);

% Midpoints on the first/last mesh line sit on the domain boundary
on_bdry = false(size(EX));
on_bdry(:,[1 end]) = true;

x = [EX(:) EY(:); EY(:) EX(:)];
bdry = [on_bdry(:); on_bdry(:)];

% Cell (i,j): i = x-cell, j = y-cell (0-based), element i*n+j+1; j varies
% fastest, which is the column-major order of ndgrid's outputs
nvertical = n*(n+1);
[jj, ii] = ndgrid(0:n-1, 0:n-1);
jj = jj(:);
ii = ii(:);
mesh.dof_mapping = [(ii+1)*n+jj+1, nvertical+(jj+1)*n+ii+1, ...
    ii*n+jj+1, nvertical+jj*n+ii+1];

[quadrature.points,quadrature.weights] = quadrature_rect(2);
quadrature.basis = ...
    basis_rotated_bilinear(quadrature.points(:,1),quadrature.points(:,2));
[quadrature.deriv_basis_x,quadrature.deriv_basis_y] = ...
    grad_basis_rotated_bilinear(quadrature.points(:,1),quadrature.points(:,2));
end
       
% =========================================================================
% Functions to construct mesh
% =========================================================================

function mesh = construct_quad_mesh(n)
%CONSTRUCT_QUAD_MESH Constructs a uniform mesh of quadrilaterals
%
% Splits [0,1]^2 into an n-by-n grid of equal squares.
%
% Arguments:
%   n - Number of squares along each side of the domain
%
% Returns:
%   mesh - Structure with fields:
%            vertices - ((n+1)^2)x2 matrix of X & Y vertex coordinates,
%                       numbered column-major (y varies fastest)
%            elements - (n^2)x4 matrix; row k lists the SW, SE, NE, NW
%                       vertex indices (counter-clockwise) of square k.
%                       Squares are also numbered with y varying fastest.
coords = linspace(0, 1, n+1);
[X, Y] = meshgrid(coords, coords);
mesh.vertices = [X(:) Y(:)];

% Vertex numbers laid out on the grid: row = y position, column = x position
grid_ids = reshape(1:((n+1)^2), n+1, n+1);
sw = grid_ids(1:n, 1:n);
se = grid_ids(1:n, 2:end);
ne = grid_ids(2:end, 2:end);
nw = grid_ids(2:end, 1:n);
mesh.elements = [sw(:), se(:), ne(:), nw(:)];
end

function mesh = construct_tri_mesh(n)
%CONSTRUCT_TRI_MESH Constructs a uniform mesh of triangles
%
% Splits [0,1]^2 into an n-by-n grid of equal squares and cuts each square
% along its north-west to south-east diagonal.
%
% Arguments:
%   n - Number of squares along each side of the domain
%
% Returns:
%   mesh - Structure with fields:
%            vertices - ((n+1)^2)x2 matrix of X & Y vertex coordinates,
%                       numbered column-major (y varies fastest)
%            elements - (2n^2)x3 matrix of vertex indices in
%                       counter-clockwise order. Square k gives row 2k-1
%                       (SW, SE, NW) followed by row 2k (NE, NW, SE).
coords = linspace(0, 1, n+1);
grid_ids = reshape(1:((n+1)^2), n+1, n+1);
[X, Y] = meshgrid(coords, coords);
mesh.vertices = [X(:) Y(:)];

% Corner vertex numbers of every square (row = y position, column = x)
sw = grid_ids(1:n, 1:n);
se = grid_ids(1:n, 2:end);
ne = grid_ids(2:end, 2:end);
nw = grid_ids(2:end, 1:n);
lower = [sw(:) se(:) nw(:)];
upper = [ne(:) nw(:) se(:)];
% Interleave: each square's lower-left triangle followed by its upper-right
mesh.elements = reshape([lower upper]', 3, [])';
end

% =========================================================================
% Functions returning mapping from reference element to actual element
% =========================================================================

function [B, b] = quad_mapping(vertices)
%QUAD_MAPPING Returns affine map from [-1,1]^2 to specified element
% Assumes the element is a parallelogram, so that the map is affine.
%
% Arguments:
%   vertices - 2x4 matrix whose columns are the X/Y coordinates of the
%              element's vertices in counter-clockwise order
%
% Returns:
%   [B, b] - 2x2 matrix and 2x1 vector performing the mapping from the
%            reference element: x = B*xhat + b. Reference corner (-1,-1)
%            maps to vertex 1, (1,-1) to vertex 2 and (-1,1) to vertex 4.

B = [vertices(:,2)-vertices(:,1) vertices(:,4)-vertices(:,1)]/2;
% The centroid is the image of the reference centre (0,0)
b = (vertices(:,1)+vertices(:,2)+vertices(:,3)+vertices(:,4))/4;
end

function [B, b] = tri_mapping(vertices)
%TRI_MAPPING Returns affine map from the (-1,-1)-(1,-1)-(-1,1) triangle
%
% Arguments:
%   vertices - 2x3 matrix whose columns are the X/Y coordinates of the
%              element's vertices in counter-clockwise order
%
% Returns:
%   [B, b] - 2x2 matrix and 2x1 vector performing the mapping from the
%            reference element: x = B*xhat + b. Reference vertices
%            (-1,-1), (1,-1) and (-1,1) map to vertices 1, 2 and 3.

B = [vertices(:,2)-vertices(:,1)  vertices(:,3)-vertices(:,1)]/2;
% (0,0) on the reference triangle is the midpoint of its edge from (1,-1)
% to (-1,1), so it maps to the midpoint of vertices 2 and 3
b = (vertices(:,2)+vertices(:,3))/2;
end