/*
	Copyright (C) 2003 Frdric Giudicelli (contact_nos@yahoo.com). 
	All rights reserved.

	This product includes cryptographic software written by Eric Young
	(eay@cryptsoft.com)

	This program is released under the GPL with the additional exemption that
	compiling, linking, and/or using OpenSSL is allowed.

	This program is free software; you can redistribute it and/or modify it
	under the terms of the GNU General Public License as published by the Free
	Software Foundation; either version 2 of the License.

	This program is distributed in the hope that it will be useful, but WITHOUT
	ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
	FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
	more details.

	You should have received a copy of the GNU General Public License along with
	this program; if not, write to the Free Software Foundation, Inc., 59 Temple
	Place, Suite 330, Boston, MA 02111-1307 USA
*/


// Implementation of the PKI_PKCS12 class (PKCS#12 container wrapper).
//
//////////////////////////////////////////////////////////////////////

#include "PKI_PKCS12.h"


//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////


// Definition of the static, default-constructed (empty) instance.
// NOTE(review): not referenced in this file; presumably used elsewhere as a
// shared "no value" object — confirm against callers.
PKI_PKCS12 PKI_PKCS12::EmptyInstance;

// Copy constructor: starts from an empty container, then reuses the
// assignment operator to copy `other`. p12 must be NULL before the
// assignment runs, because it calls Clear(), which frees any held p12.
PKI_PKCS12::PKI_PKCS12(const PKI_PKCS12 & other)
{
	this->p12 = NULL;
	this->operator=(other);
}

// Default constructor: the object holds no PKCS#12 container yet.
PKI_PKCS12::PKI_PKCS12()
{
	this->p12 = NULL;
}

// Destructor: releases the PKCS#12 container and resets the decoded data.
PKI_PKCS12::~PKI_PKCS12()
{
	this->Clear();
}

// Walks every authenticated safe of a PKCS#12 structure and extracts its
// content. Certificates are pushed onto CertStore; private keys are stored
// into m_EndUserKey (see dump_certs_pkeys_bag).
//
// CertStore : destination stack for the X509 certificates found.
// p12       : container to read (note: shadows the p12 member on purpose).
// pass      : password used to decrypt encrypted safes (may be NULL).
// passlen   : length of pass, or -1 for a NUL-terminated string.
// Returns 1 on success, 0 on failure. The unpacked safes are always
// released, including on failure.
int PKI_PKCS12::dump_certs_keys_p12 (STACK_OF(X509) * CertStore, PKCS12 *p12, const char *pass, int passlen)
{
	STACK_OF(PKCS7) *asafes;
	STACK_OF(PKCS12_SAFEBAG) *bags;
	int i, bagnid;
	int ret = 1;
	PKCS7 *p7;

	if (!( asafes = M_PKCS12_unpack_authsafes (p12))) return 0;

	for (i = 0; i < sk_PKCS7_num (asafes); i++)
	{
		p7 = sk_PKCS7_value (asafes, i);
		bagnid = OBJ_obj2nid (p7->type);
		if (bagnid == NID_pkcs7_data)
		{
			bags = M_PKCS12_unpack_p7data (p7);
		}
		else if (bagnid == NID_pkcs7_encrypted)
		{
			bags = M_PKCS12_unpack_p7encdata (p7, pass, passlen);
		}
		else
		{
			// Unsupported safe type: skip it
			continue;
		}

		if (!bags)
		{
			ret = 0;
			break;
		}

		if (!dump_certs_pkeys_bags (CertStore, bags, pass, passlen))
			ret = 0;
		sk_PKCS12_SAFEBAG_pop_free (bags, PKCS12_SAFEBAG_free);
		if (!ret)
			break;
	}

	// Released on every path (the original leaked it on failure)
	sk_PKCS7_pop_free (asafes, PKCS7_free);
	return ret;
}

// Processes every safe bag of a stack, stopping at the first failure.
// Returns 1 when all bags were handled, 0 as soon as one of them fails.
int PKI_PKCS12::dump_certs_pkeys_bags (STACK_OF(X509) * CertStore, STACK_OF(PKCS12_SAFEBAG) *bags, const char *pass, int passlen)
{
	for (int idx = 0; idx < sk_PKCS12_SAFEBAG_num (bags); ++idx)
	{
		PKCS12_SAFEBAG * CurrentBag = sk_PKCS12_SAFEBAG_value (bags, idx);
		if (!dump_certs_pkeys_bag (CertStore, CurrentBag, pass, passlen))
			return 0;
	}
	return 1;
}

// Extracts the content of one PKCS#12 safe bag:
//  - key bags (plain or password-shrouded) are stored into m_EndUserKey;
//  - X.509 certificate bags are pushed onto CertStore (ownership moves to it);
//  - nested safe-contents bags are processed recursively;
//  - any other bag type is ignored.
// Returns 1 on success (or for an ignored bag) and 0 on failure. No
// temporary OpenSSL object is leaked on any path.
int PKI_PKCS12::dump_certs_pkeys_bag (STACK_OF(X509) * CertStore, PKCS12_SAFEBAG *bag, const char *pass, int passlen)
{
	PKCS8_PRIV_KEY_INFO *p8;
	X509 *x509;
	EVP_PKEY * PrivateKey;

	switch (M_PKCS12_bag_type(bag))
	{
		case NID_keyBag:
			// The PKCS#8 structure belongs to the bag: do not free it
			p8 = bag->value.keybag;
			if (!(PrivateKey = EVP_PKCS82PKEY (p8))) return 0;
			// The local key is freed right after, so SetKey presumably keeps its own copy
			if(!m_EndUserKey.SetKey(PrivateKey))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
				EVP_PKEY_free(PrivateKey);
				return 0;
			}
			EVP_PKEY_free(PrivateKey);
			return 1;

		case NID_pkcs8ShroudedKeyBag:
			if (!(p8 = M_PKCS12_decrypt_skey (bag, pass, passlen))) return 0;
			PrivateKey = EVP_PKCS82PKEY (p8);
			// The decrypted PKCS#8 copy is ours: release it whatever the outcome
			PKCS8_PRIV_KEY_INFO_free(p8);
			if (!PrivateKey) return 0;
			if(!m_EndUserKey.SetKey(PrivateKey))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
				EVP_PKEY_free(PrivateKey);
				return 0;
			}
			EVP_PKEY_free(PrivateKey);
			return 1;

		case NID_certBag:
			// Only X.509 certificates are supported; other cert types are skipped
			if (M_PKCS12_cert_bag_type(bag) != NID_x509Certificate ) return 1;
			if (!(x509 = M_PKCS12_certbag2x509(bag))) return 0;
			if (!sk_X509_push(CertStore, x509))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
				X509_free(x509);
				return 0;
			}
			return 1;

		case NID_safeContentsBag:
			return dump_certs_pkeys_bags (CertStore, bag->value.safes, pass, passlen);

		default:
			return 1;
	}
}

bool PKI_PKCS12::LoadFromFile(const char *Filename, const char *Password)
{
	BIO *in;
	
	Clear();
	
	in = BIO_new_file(Filename, "rb");
	if(!in)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}
	p12 = d2i_PKCS12_bio (in, NULL);
	if(!p12)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		BIO_free_all(in);
		return false;
	}

	if(!Private_Load(Password, Password?true:false))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		BIO_free_all(in);
		return false;
	}
	BIO_free_all(in);
	
	return true;
}


bool PKI_PKCS12::Private_Load(const char *Password, bool LoadAll)
{
	EVP_CIPHER *enc;
	STACK_OF(X509) * CertStore;
	PKI_CERT ParentCert;
	char Name[50];

	enc = (EVP_CIPHER *)EVP_des_ede3_cbc();
		
	if(!LoadAll)
		return true;
	
	
	if(!Password [0] && PKCS12_verify_mac(p12, NULL, 0))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_INVALID_P12_PWD);
		return false;
	} 
	else
	{
		if (!PKCS12_verify_mac(p12, Password, -1))
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_INVALID_P12_PWD);
			return false;
		}
	}


	CertStore = sk_X509_new_null();
	if(!CertStore)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}

	if (!dump_certs_keys_p12 (CertStore, p12, Password, -1))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		sk_X509_pop_free(CertStore, X509_free);
		return false;
	}

	//S'il n'y a pas de cert
	if(sk_X509_num(CertStore) <= 0)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MISMATCH_KEYS);
		sk_X509_pop_free(CertStore, X509_free);
		return false;
	}


	X509 * cert;
	int i;

	m_EndUserCert.Clear();
	for(i=0; i<sk_X509_num(CertStore); i++)
	{
		cert = sk_X509_value(CertStore, i);
		if(cert)
		{
			if (X509_check_private_key(cert, (EVP_PKEY*)m_EndUserKey.GetRsaKey()))
			{//End user cert
				if(!m_EndUserCert.SetCert(cert))
				{            
					NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MISMATCH_KEYS);
					sk_X509_pop_free(CertStore, X509_free);
					return false;
				}
				if(!m_EndUserCert.SetPrivateKey(m_EndUserKey))
				{
					NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MISMATCH_KEYS);
					sk_X509_pop_free(CertStore, X509_free);
					return false;
				}
			}
			else
			{//Parent cert
				if(!ParentCert.SetCert(cert))
				{
					NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MISMATCH_KEYS);
					sk_X509_pop_free(CertStore, X509_free);
					return false;
				}
				sprintf(Name, "%ld", (unsigned long)cert);
				m_ParentCerts.Add(Name, ParentCert.GetCertPEM().c_str());
			}
		}
	}

	if(!m_EndUserCert)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MISMATCH_KEYS);
		sk_X509_pop_free(CertStore, X509_free);
		return false;
	}
	sk_X509_pop_free(CertStore, X509_free);

	return true;
}

// Replaces the current content with the PKCS#12 carried by strP12.
// strP12 is first converted to DER via mString::ToDER and then decoded.
// When Password is non-NULL the certificates and key are also decoded.
// Returns false on conversion, parsing or decoding failure.
bool PKI_PKCS12::Load(const mString & strP12, const char *Password)
{
	unsigned char * DerBuffer;
	unsigned char * DerCursor;
	int DerLength;

	Clear();

	if(!strP12.ToDER(&DerBuffer, &DerLength))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}

	// d2i_* advances the pointer it is given: decode through a cursor copy
	DerCursor = DerBuffer;
	p12 = d2i_PKCS12(NULL, &DerCursor, DerLength);
	free(DerBuffer);
	if(p12 == NULL)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}

	if(!Private_Load(Password, Password != NULL))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}
	return true;
}

bool PKI_PKCS12::Reload(const char *Password)
{
	if(!p12)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_BAD_PARAM);
		return false;
	}
	if(!Private_Load(Password, Password?true:false))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}
	return true;
}

// Frees the PKCS#12 container and its cached PEM form. When `all` is true,
// the end-user certificate, key and parent certificates are reset as well.
void PKI_PKCS12::Clear(bool all)
{
	if(p12 != NULL)
		PKCS12_free(p12);
	p12 = NULL;
	Pkcs12Pem = "";

	if(!all)
		return;

	m_ParentCerts.Clear();
	m_EndUserCert.Clear();
	m_EndUserKey.Clear();
}

// End-user certificate: decoded by the last successful load, or set
// through SetEndUserCert().
const PKI_CERT & PKI_PKCS12::GetEndUserCert() const
{
	return this->m_EndUserCert;
}

// End-user private key: decoded by the last load, or set through SetEndUserKey().
const PKI_RSA & PKI_PKCS12::GetEndUserKey() const
{
	return this->m_EndUserKey;
}

// Parent (chain) certificates, stored as PEM strings.
const HashTable_String & PKI_PKCS12::GetParentCerts() const
{
	return this->m_ParentCerts;
}

void PKI_PKCS12::SetEndUserCert(const PKI_CERT & EndUserCert)
{
	m_EndUserCert = EndUserCert;
}

void PKI_PKCS12::SetEndUserKey(const PKI_RSA & EndUserKey)
{
	m_EndUserKey = EndUserKey;
}

void PKI_PKCS12::SetParentCerts(const HashTable_String & ParentCerts)
{
	m_ParentCerts = ParentCerts;
}

bool PKI_PKCS12::Generate(const char * passphrase)
{
	Clear(false);
	
	STACK_OF(PKCS12_SAFEBAG) *bags;
	STACK_OF(PKCS7) *safes;
	PKCS12_SAFEBAG *bag;
	PKCS8_PRIV_KEY_INFO *p8;
	PKCS7 *authsafe;
	unsigned char keyid[EVP_MAX_MD_SIZE];
	unsigned int keyidlen;
	STACK_OF(X509) *certs;
	int i;
	PKI_CERT tmpCert;
	const char * name;
	const char * strCert;
	X509 *cert;
	long pos;

	/*
		-> Gestion des certificats
	*/
	certs = sk_X509_new_null();
	if (!certs)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}

	// We get the public key id
	if(!X509_digest(m_EndUserCert.GetX509(), EVP_sha1(), keyid, &keyidlen))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		sk_X509_pop_free(certs, X509_free);
		return false;
	}

	// We add the end user's cert to the stack
	if(!sk_X509_push(certs, m_EndUserCert.GetX509(true)))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		sk_X509_pop_free(certs, X509_free);
		return false;
	}

	for(i=0; i<m_ParentCerts.EntriesCount() ; i++)
	{
		strCert = m_ParentCerts.Get(i);
		if(!strCert)
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}

		if(!tmpCert.SetCert(strCert))
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}

		if(!sk_X509_push(certs, tmpCert.GetX509(true)))
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}
	}
	
	bags = sk_PKCS12_SAFEBAG_new_null ();
	if(!bags)
	{
		sk_X509_pop_free(certs, X509_free);
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}

	// We now have loads of certificates: include them all
	for(i = 0; i < sk_X509_num(certs); i++) 
	{
		cert = sk_X509_value(certs, i);
		if(!cert)
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}

		bag = PKCS12_x5092certbag(cert);
		if(!bag)
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}
		
		name = NULL;
		if(tmpCert.SetCert(cert))
		{
			pos = tmpCert.GetCertDN().SeekEntryName("commonName", -1);
			if(pos != HASHTABLE_NOT_FOUND)
			{
				name = tmpCert.GetCertDN().Get(pos); 
			}
		}
		
		// If it's and user cert, set id
		if(tmpCert == m_EndUserCert)
		{
			if(!name) name = "User Certificate";
			if(!PKCS12_add_friendlyname(bag, name, -1))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
				sk_X509_pop_free(certs, X509_free);
				return false;
			}
			if(!PKCS12_add_localkeyid(bag, keyid, keyidlen))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
				sk_X509_pop_free(certs, X509_free);
				return false;
			}
		} 
		else
		{
			if(!name) name = "CA Certificate";
			if(!PKCS12_add_friendlyname(bag, name, -1))
			{
				NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
				sk_X509_pop_free(certs, X509_free);
				return false;
			}
		}
		if(!sk_PKCS12_SAFEBAG_push(bags, bag))
		{
			NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
			sk_X509_pop_free(certs, X509_free);
			return false;
		}
	}
	sk_X509_pop_free(certs, X509_free);



	authsafe = PKCS12_pack_p7encdata(NID_pbe_WithSHA1And40BitRC2_CBC, passphrase, -1, NULL, 0, PKCS12_DEFAULT_ITER, bags);
	if(!authsafe)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);
		return false;
	}
	sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);


	safes = sk_PKCS7_new_null ();
	if(!safes)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}

	// On ajoute tout les certificats au P7b
	if(!sk_PKCS7_push (safes, authsafe))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}
	/*
		<- Gestion des certificats
	*/




	//*********************************************************************



	/*
		-> Gestion de la cl�priv�
	*/

	
	/* Make a shrouded key bag */
	p8 = EVP_PKEY2PKCS8 ((EVP_PKEY*)m_EndUserKey.GetRsaKey());
	if(!p8)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}

	bag = PKCS12_MAKE_SHKEYBAG(NID_pbe_WithSHA1And3_Key_TripleDES_CBC, passphrase, -1, NULL, 0, PKCS12_DEFAULT_ITER, p8);
	if(!bag)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		PKCS8_PRIV_KEY_INFO_free(p8);
		return false;
	}
	PKCS8_PRIV_KEY_INFO_free(p8);


	if(!PKCS12_add_friendlyname (bag, "User Private Key", -1))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}
	if(!PKCS12_add_localkeyid (bag, keyid, keyidlen))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}
	
	bags = sk_PKCS12_SAFEBAG_new_null();
	if(!bags)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}
	// Turn it into unencrypted safe bag
	if(!sk_PKCS12_SAFEBAG_push (bags, bag))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);
		return false;
	}
	
	authsafe = PKCS12_pack_p7data (bags);
	if(!authsafe)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);
		return false;
	}
	sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);

	// On ajoute la cl�priv� au P7b
	if(!sk_PKCS7_push (safes, authsafe))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		sk_PKCS12_SAFEBAG_pop_free(bags, PKCS12_SAFEBAG_free);
		return false;
	}

	/*
		<- Gestion de la cl�priv�
	*/




	/*
		Cr�tion du P12
	*/
	p12 = PKCS12_init (NID_pkcs7_data);
	if(!p12)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}

	if(!PKCS12_pack_authsafes (p12, safes))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		sk_PKCS7_pop_free(safes, PKCS7_free);
		return false;
	}
	sk_PKCS7_pop_free(safes, PKCS7_free);

	if(!PKCS12_set_mac (p12, passphrase, -1, NULL, 0, PKCS12_DEFAULT_ITER, NULL))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		return false;
	}

	return true;	
}

bool PKI_PKCS12::PKCS12ToString()
{
	int n;
	unsigned char * b;
	unsigned char * p;

	n=i2d_PKCS12(p12,NULL);
	if(!n)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}

	b=(unsigned char *)malloc(n);
	if (!b)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}
	p=b;

	n = i2d_PKCS12(p12,&p);
	if(!n)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_GEN_P12);
		free(b);
		return false;
	}

	if(!Pkcs12Pem.FromDER(b, n))
	{
		free(b);
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}
	free(b);

	return true;
}

// Returns the cached PEM form of the container. The cache is built lazily
// on the first call that finds a container but an empty cache. On an
// encoding failure the error is logged and the (empty) cache is returned.
const mString & PKI_PKCS12::GetPemPKCS12() const
{
	const bool NeedEncoding = (p12 != NULL) && (Pkcs12Pem.size() == 0);
	if(NeedEncoding && !const_cast<PKI_PKCS12*>(this)->PKCS12ToString())
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
	}
	return Pkcs12Pem;
}


// Returns the held container (NULL if none). When Duplicate is true, a
// deep copy is returned instead, and the caller must release it with
// PKCS12_free().
PKCS12 * PKI_PKCS12::GetPKCS12(bool Duplicate) const
{
	if(p12 != NULL && Duplicate)
		return static_cast<PKCS12*>(ASN1_item_dup(ASN1_ITEM_rptr(PKCS12), static_cast<void*>(p12)));
	return p12;
}


bool PKI_PKCS12::load_Datas(const PKCS12 * c_p12)
{
	Clear();

	if(!(p12 = (PKCS12*)ASN1_item_dup(ASN1_ITEM_rptr(PKCS12), (void*)c_p12)))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_MALLOC);
		return false;
	}
	
	if(!Private_Load(NULL, false))
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_ABORT);
		return false;
	}
	
	return true;
}

// Stores into *c_p12 a deep copy of the held container, or NULL when there
// is none. Any container previously referenced by *c_p12 is freed first.
// Returns false only if the duplication fails.
bool PKI_PKCS12::give_Datas(PKCS12 **c_p12) const
{
	if(*c_p12 != NULL)
		PKCS12_free(*c_p12);

	// GetPKCS12(true) yields NULL when no container is held
	*c_p12 = GetPKCS12(true);
	if(p12 != NULL && *c_p12 == NULL)
	{
		NEWPKIerr(CRYPTO_ERROR_TXT, ERROR_BAD_PARAM);
		return false;
	}
	return true;
}

// Truth test: 1 when a PKCS#12 container is held, 0 otherwise.
PKI_PKCS12::operator int() const
{
	if(p12 == NULL)
		return 0;
	return 1;
}

// Assignment: copies the raw container (deep copy) together with the
// decoded end-user certificate, private key and parent certificates.
// Previously only the container was copied, so copies silently lost the
// decoded data. The cached PEM form is not copied; it is rebuilt on demand.
// Returns false when `other` holds no container or duplication fails
// (kept for backward compatibility). Self-assignment is a no-op.
bool PKI_PKCS12::operator=(const PKI_PKCS12 & other)
{
	// Trying to copy myself on me
	if(&other == this)
		return true;

	Clear();

	// Copy the container first: load_Datas() calls Clear(), which would
	// otherwise wipe the members copied below
	if(other.p12 && !load_Datas(other.p12))
		return false;

	m_EndUserCert = other.m_EndUserCert;
	m_EndUserKey = other.m_EndUserKey;
	m_ParentCerts = other.m_ParentCerts;

	return other.p12 ? true : false;
}

