package com.anf.ws.ar.xml.validator;

import java.io.IOException;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bouncycastle.asn1.ASN1Boolean;
import org.bouncycastle.asn1.ASN1Integer;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.ASN1OctetString;
import org.bouncycastle.asn1.ASN1Primitive;
import org.bouncycastle.asn1.DERBitString;
import org.bouncycastle.asn1.DEROctetString;
import org.bouncycastle.asn1.DERUTF8String;
import org.bouncycastle.asn1.x500.AttributeTypeAndValue;
import org.bouncycastle.asn1.x500.RDN;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.asn1.x500.style.BCStyle;
import org.bouncycastle.asn1.x509.Extension;
import org.bouncycastle.asn1.x509.Extensions;
import org.bouncycastle.util.encoders.Hex;
import org.w3c.dom.Document;

import com.anf.cryptotoken.oid.PolicyOID;
import com.anf.pkcs10.PKCS10Request;
import com.anf.ws.ar.xml.parser.impl.ARParser;

/**
 * Checks that a PKCS#10 certification request agrees with the data recorded in a
 * registration XML document.
 *
 * <p>The XML document is parsed once, at construction, by {@link ARParser}. Two checks
 * are then made against a request:
 * <ul>
 *   <li>every subject attribute of the CSR (except OU) must have a registered value
 *       in the XML, and the CSR must carry that exact value for that attribute type;</li>
 *   <li>every CSR extension that is registered in the XML (and not excluded by
 *       {@link #excludeOid(String)}) must decode to the registered string value.</li>
 * </ul>
 */
public class XMLCertificationRequestValidator
{

	private static final Logger log = LogManager.getLogger(XMLCertificationRequestValidator.class);

	/**
	 * Sub-branches of an OID's policy root (as resolved by {@link PolicyOID}) whose
	 * extensions are not compared against the XML. NOTE(review): the meaning of these
	 * branches is defined by PolicyOID, not visible here.
	 */
	private static final String[] EXCLUDED_POLICY_BRANCHES = new String[]{ "40", "41", "41.1", "47.1" };

	private final ARParser parser;

	/**
	 * Creates a validator backed by the given registration XML.
	 *
	 * @param registrationXML the registration document to parse with {@link ARParser}
	 */
	public XMLCertificationRequestValidator(Document registrationXML)
	{
		parser = new ARParser(registrationXML);
	}

	/**
	 * Returns the subject attribute value registered in the XML for {@code oid}.
	 * NOTE(review): assumes the parser's subject-attribute map holds String values.
	 *
	 * @param oid dotted OID string of the attribute type
	 * @return the registered value, or {@code null} if the XML has none
	 */
	public String getRegistrationSubjectAttribute(String oid)
	{
		return (String) this.parser.getSubjectAttributes().get(oid);
	}

	/**
	 * Returns the extension value registered in the XML for {@code oid}.
	 * NOTE(review): assumes the parser's extension map holds String values.
	 *
	 * @param oid dotted OID string of the extension
	 * @return the registered value, or {@code null} if the XML has none
	 */
	public String getRegistrationExtension(String oid)
	{
		return (String) this.parser.getExtensions().get(oid);
	}

	/**
	 * Validates a CSR against the registration XML.
	 *
	 * <p>The subject is checked first. The extensions are only checked when the
	 * subject matches.
	 *
	 * @param request the certification request to validate
	 * @return {@code true} only if both the subject and the extensions match the XML
	 */
	public boolean isValidCertificationRequest(PKCS10Request request)
	{
		return matchSubjectAttrsCSRandXML(request) && matchExtensionsCSRandXML(request);
	}

	/**
	 * Checks that every non-OU subject attribute of the CSR has a registered XML value
	 * and that the CSR carries exactly that value for that attribute type.
	 */
	private boolean matchSubjectAttrsCSRandXML(PKCS10Request request)
	{
		X500Name subject = request.getSubject();
		log.info("checking pkcs10 subject data vs xml data");
		for (ASN1ObjectIdentifier oid : subject.getAttributeTypes()) {
			// OU attributes are intentionally not compared against the XML.
			if (BCStyle.OU.equals(oid)) {
				continue;
			}
			String xmlValue = getRegistrationSubjectAttribute(oid.getId());
			if (xmlValue == null) {
				log.warn("xml oid value not found:{}", oid);
				return false;
			}
			if (!subjectHasValue(subject, oid, xmlValue)) {
				log.info("oid value not match xml:{}: {}", oid, xmlValue);
				return false;
			}
		}
		return true;
	}

	/**
	 * Returns whether {@code subject} contains an attribute of type {@code oid} whose
	 * string form equals {@code expected}. Only values of the requested type are
	 * considered, so a sibling attribute in a multi-valued RDN cannot satisfy the match.
	 */
	private static boolean subjectHasValue(X500Name subject, ASN1ObjectIdentifier oid, String expected)
	{
		for (RDN rdn : subject.getRDNs(oid)) {
			for (AttributeTypeAndValue typeAndValue : rdn.getTypesAndValues()) {
				if (oid.equals(typeAndValue.getType())
						&& expected.equals(typeAndValue.getValue().toString())) {
					return true;
				}
			}
		}
		return false;
	}

	/**
	 * Checks that every CSR extension registered in the XML decodes to the registered
	 * value. Extensions excluded by {@link #excludeOid(String)}, and extensions absent
	 * from the XML, are skipped.
	 *
	 * <p>An extension that is registered but cannot be decoded is treated as a mismatch
	 * (fail closed), because its value cannot be verified.
	 */
	private boolean matchExtensionsCSRandXML(PKCS10Request request)
	{
		Extensions csrExtensions = request.getExtensions();
		if (csrExtensions == null) {
			return true;
		}
		for (ASN1ObjectIdentifier oid : csrExtensions.getExtensionOIDs()) {
			if (excludeOid(oid.getId())) {
				continue;
			}
			String xmlValue = getRegistrationExtension(oid.getId());
			if (xmlValue == null) {
				// Extensions not present in the XML are not checked.
				continue;
			}
			ASN1OctetString extnValue = csrExtensions.getExtension(oid).getExtnValue();
			String value;
			try {
				value = getStringValue(ASN1Primitive.fromByteArray(extnValue.getOctets()));
			} catch (IOException e) {
				log.warn("csr extension could not be parsed: {}", oid, e);
				return false;
			}
			if (value == null || !xmlValue.equals(value)) {
				log.info("csr extension not match xml extension: {} xml: {}, csr: {}", oid, xmlValue, value);
				return false;
			}
		}
		return true;
	}

	/**
	 * Renders a decoded extension value as the string form used for comparison.
	 *
	 * <ul>
	 *   <li>UTF8String: its text.</li>
	 *   <li>INTEGER: decimal.</li>
	 *   <li>BOOLEAN: {@code "true"}/{@code "false"}.</li>
	 *   <li>OCTET STRING and BIT STRING: lowercase hex of the content bytes.</li>
	 *   <li>OBJECT IDENTIFIER: dotted form.</li>
	 * </ul>
	 *
	 * @return the string form, or {@code null} for any other ASN.1 type
	 */
	private static String getStringValue(ASN1Primitive obj)
	{
		if (obj instanceof DERUTF8String) {
			return ((DERUTF8String) obj).getString();
		}
		if (obj instanceof ASN1Integer) {
			return String.valueOf(((ASN1Integer) obj).getValue());
		}
		if (obj instanceof ASN1Boolean) {
			return String.valueOf(((ASN1Boolean) obj).isTrue());
		}
		if (obj instanceof DEROctetString) {
			return Hex.toHexString(((DEROctetString) obj).getOctets());
		}
		if (obj instanceof DERBitString) {
			// getPadBits() counts unused BITS (0-7) in the last octet, not octets, so it
			// must not be subtracted from the byte length; encode every content byte.
			return Hex.toHexString(((DERBitString) obj).getBytes());
		}
		if (obj instanceof ASN1ObjectIdentifier) {
			return ((ASN1ObjectIdentifier) obj).getId();
		}
		return null;
	}

	/**
	 * Returns whether an extension OID is exempt from comparison with the XML.
	 *
	 * <p>The OID is exempt if {@link PolicyOID#fromOid(String)} resolves it and it equals
	 * one of the {@link #EXCLUDED_POLICY_BRANCHES} under that policy's root.
	 *
	 * @param oid dotted OID string to test
	 * @return {@code true} if the OID is excluded; {@code false} otherwise, including
	 *         when PolicyOID does not recognise it
	 */
	public static boolean excludeOid(String oid)
	{
		PolicyOID policy = PolicyOID.fromOid(oid);
		if (policy == null) {
			return false;
		}
		for (String branch : EXCLUDED_POLICY_BRANCHES) {
			if (policy.policyRoot().branch(branch).getId().equals(oid)) {
				return true;
			}
		}
		return false;
	}
}
