// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package authenticode

import (
	"encoding/hex"
	"errors"
	"fmt"
	"path/filepath"
	"strings"
	"unsafe"

	"github.com/dblohm7/wingoes"
	"golang.org/x/sys/windows"
)

var (
	// ErrSigNotFound is returned if no authenticode signature could be found.
	ErrSigNotFound = errors.New("authenticode signature not found")
	// ErrUnexpectedCertSubject is wrapped with the actual cert subject and
	// returned when the binary is signed by a different subject than expected.
	ErrUnexpectedCertSubject = errors.New("unexpected cert subject")
	errCertSubjectNotFound = errors.New("cert subject not found")
	errCertSubjectDecodeLenMismatch = errors.New("length mismatch while decoding cert subject")
)

const (
	_CERT_STRONG_SIGN_OID_INFO_CHOICE = 2
	_CMSG_SIGNER_CERT_INFO_PARAM      = 7
	_MSI_INVALID_HASH_IS_FATAL        = 1
	_TRUST_E_NOSIGNATURE              = wingoes.HRESULT(-((0x800B0100 ^ 0xFFFFFFFF) + 1))
)

// Verify performs authenticode verification on the file at path, and also
// ensures that expectedCertSubject was the entity who signed it. path may point
// to either a PE binary or an MSI package. ErrSigNotFound is returned if no
// signature is found.
func Verify(path string, expectedCertSubject string) error {
	path16, err := windows.UTF16PtrFromString(path)
	if err != nil {
		return err
	}

	var subject string
	if strings.EqualFold(filepath.Ext(path), ".msi") {
		subject, err = verifyMSI(path16)
	} else {
		subject, _, err = queryPE(path16, true)
	}

	if err != nil {
		return err
	}

	if subject != expectedCertSubject {
		return fmt.Errorf("%w %q", ErrUnexpectedCertSubject, subject)
	}

	return nil
}

// SigProvenance indicates whether an authenticode signature was embedded within
// the file itself, or the signature applies to an associated catalog file.
type SigProvenance int

const (
	SigProvUnknown = SigProvenance(iota)
	SigProvEmbedded
	SigProvCatalog
)

// QueryCertSubject obtains the subject associated with the certificate used to
// sign the PE binary located at path. When err == nil, it also returns the
// provenance of that signature. ErrSigNotFound is returned if no signature
// is found. Note that this function does *not* validate the chain of trust; use
// Verify for that purpose!
func QueryCertSubject(path string) (certSubject string, provenance SigProvenance, err error) {
	path16, err := windows.UTF16PtrFromString(path)
	if err != nil {
		return "", SigProvUnknown, err
	}

	return queryPE(path16, false)
}

func queryPE(utf16Path *uint16, verify bool) (string, SigProvenance, error) {
	certSubject, err := queryEmbeddedCertSubject(utf16Path, verify)
	switch err {
	case nil:
		return certSubject, SigProvEmbedded, nil
	case ErrSigNotFound:
		// Try looking for the signature in a catalog file.
	default:
		return certSubject, SigProvEmbedded, err
	}

	certSubject, err = queryCatalogCertSubject(utf16Path, verify)
	if err != nil {
		if err == ErrSigNotFound {
			return "", SigProvUnknown, err
		}
		return certSubject, SigProvCatalog, err
	}

	return certSubject, SigProvCatalog, nil
}

func verifyMSI(path *uint16) (string, error) {
	var certCtx *windows.CertContext
	hr := msiGetFileSignatureInformation(path, _MSI_INVALID_HASH_IS_FATAL, &certCtx, nil, nil)
	if e := wingoes.ErrorFromHRESULT(hr); e.Failed() {
		if e == wingoes.ErrorFromHRESULT(_TRUST_E_NOSIGNATURE) {
			return "", ErrSigNotFound
		}
		return "", e
	}
	defer windows.CertFreeCertificateContext(certCtx)

	return certSubjectFromCertContext(certCtx)
}

func certSubjectFromCertContext(certCtx *windows.CertContext) (string, error) {
	desiredLen := windows.CertGetNameString(
		certCtx,
		windows.CERT_NAME_SIMPLE_DISPLAY_TYPE,
		0,
		nil,
		nil,
		0,
	)
	if desiredLen <= 1 {
		return "", errCertSubjectNotFound
	}

	buf := make([]uint16, desiredLen)
	actualLen := windows.CertGetNameString(
		certCtx,
		windows.CERT_NAME_SIMPLE_DISPLAY_TYPE,
		0,
		nil,
		&buf[0],
		desiredLen,
	)
	if actualLen != desiredLen {
		return "", errCertSubjectDecodeLenMismatch
	}

	return windows.UTF16ToString(buf), nil
}

type objectQuery struct {
	certStore    windows.Handle
	cryptMsg     windows.Handle
	encodingType uint32
}

func newObjectQuery(utf16Path *uint16) (*objectQuery, error) {
	var oq objectQuery
	if err := windows.CryptQueryObject(
		windows.CERT_QUERY_OBJECT_FILE,
		unsafe.Pointer(utf16Path),
		windows.CERT_QUERY_CONTENT_FLAG_PKCS7_SIGNED_EMBED,
		windows.CERT_QUERY_FORMAT_FLAG_BINARY,
		0,
		&oq.encodingType,
		nil,
		nil,
		&oq.certStore,
		&oq.cryptMsg,
		nil,
	); err != nil {
		return nil, err
	}

	return &oq, nil
}

func (oq *objectQuery) Close() error {
	if oq.certStore != 0 {
		if err := windows.CertCloseStore(oq.certStore, 0); err != nil {
			return err
		}
		oq.certStore = 0
	}

	if oq.cryptMsg != 0 {
		if err := cryptMsgClose(oq.cryptMsg); err != nil {
			return err
		}
		oq.cryptMsg = 0
	}

	return nil
}

func (oq *objectQuery) certSubject() (string, error) {
	var certInfoLen uint32
	if err := cryptMsgGetParam(
		oq.cryptMsg,
		_CMSG_SIGNER_CERT_INFO_PARAM,
		0,
		unsafe.Pointer(nil),
		&certInfoLen,
	); err != nil {
		return "", err
	}

	buf := make([]byte, certInfoLen)
	if err := cryptMsgGetParam(
		oq.cryptMsg,
		_CMSG_SIGNER_CERT_INFO_PARAM,
		0,
		unsafe.Pointer(&buf[0]),
		&certInfoLen,
	); err != nil {
		return "", err
	}

	certInfo := (*windows.CertInfo)(unsafe.Pointer(&buf[0]))
	certCtx, err := windows.CertFindCertificateInStore(
		oq.certStore,
		oq.encodingType,
		0,
		windows.CERT_FIND_SUBJECT_CERT,
		unsafe.Pointer(certInfo),
		nil,
	)
	if err != nil {
		return "", err
	}
	defer windows.CertFreeCertificateContext(certCtx)

	return certSubjectFromCertContext(certCtx)
}

func queryEmbeddedCertSubject(utf16Path *uint16, verify bool) (string, error) {
	oq, err := newObjectQuery(utf16Path)
	if err != nil {
		if e, ok := err.(windows.Errno); ok && uint32(e) == uint32(windows.CRYPT_E_NO_MATCH) {
			err = ErrSigNotFound
		}
		return "", err
	}
	defer oq.Close()

	certSubj, err := oq.certSubject()
	if err != nil {
		return "", err
	}

	if !verify {
		return certSubj, nil
	}

	wintrustArg := unsafe.Pointer(&windows.WinTrustFileInfo{
		Size:     uint32(unsafe.Sizeof(windows.WinTrustFileInfo{})),
		FilePath: utf16Path,
	})
	err = verifyTrust(windows.WTD_CHOICE_FILE, wintrustArg)
	// Note that we intentionally return certSubj even when err != nil; the idea
	// is that we might still want to know who the cert subject claims to be,
	// even if the validation has failed (eg for troubleshooting purposes).
	return certSubj, err
}

var (
	_BCRYPT_SHA256_ALGORITHM   = &([]uint16{'S', 'H', 'A', '2', '5', '6', 0})[0]
	_OID_CERT_STRONG_SIGN_OS_1 = &([]byte("1.3.6.1.4.1.311.72.1.1\x00"))[0]
)

type _HCATADMIN windows.Handle
type _HCATINFO windows.Handle

type _CATALOG_INFO struct {
	size        uint32
	catalogFile [windows.MAX_PATH]uint16
}

type _WINTRUST_CATALOG_INFO struct {
	size                 uint32
	catalogVersion       uint32
	catalogFilePath      *uint16
	memberTag            *uint16
	memberFilePath       *uint16
	memberFile           windows.Handle
	pCalculatedFileHash  *byte
	cbCalculatedFileHash uint32
	catalogContext       uintptr
	catAdmin             _HCATADMIN
}

func queryCatalogCertSubject(utf16Path *uint16, verify bool) (string, error) {
	var catAdmin _HCATADMIN
	policy := windows.CertStrongSignPara{
		Size:                      uint32(unsafe.Sizeof(windows.CertStrongSignPara{})),
		InfoChoice:                _CERT_STRONG_SIGN_OID_INFO_CHOICE,
		InfoOrSerializedInfoOrOID: unsafe.Pointer(_OID_CERT_STRONG_SIGN_OS_1),
	}
	if err := cryptCATAdminAcquireContext2(
		&catAdmin,
		nil,
		_BCRYPT_SHA256_ALGORITHM,
		&policy,
		0,
	); err != nil {
		return "", err
	}
	defer cryptCATAdminReleaseContext(catAdmin, 0)

	// We use windows.CreateFile instead of standard library facilities because:
	// 1. Subsequent API calls directly utilize the file's Win32 HANDLE;
	// 2. We're going to be hashing the contents of this file, so we want to
	//    provide a sequential-scan hint to the kernel.
	memberFile, err := windows.CreateFile(
		utf16Path,
		windows.GENERIC_READ,
		windows.FILE_SHARE_READ,
		nil,
		windows.OPEN_EXISTING,
		windows.FILE_FLAG_SEQUENTIAL_SCAN,
		0,
	)
	if err != nil {
		return "", err
	}
	defer windows.CloseHandle(memberFile)

	var hashLen uint32
	if err := cryptCATAdminCalcHashFromFileHandle2(
		catAdmin,
		memberFile,
		&hashLen,
		nil,
		0,
	); err != nil {
		return "", err
	}

	hashBuf := make([]byte, hashLen)
	if err := cryptCATAdminCalcHashFromFileHandle2(
		catAdmin,
		memberFile,
		&hashLen,
		&hashBuf[0],
		0,
	); err != nil {
		return "", err
	}

	catInfoCtx, err := cryptCATAdminEnumCatalogFromHash(
		catAdmin,
		&hashBuf[0],
		hashLen,
		0,
		nil,
	)
	if err != nil {
		if err == windows.ERROR_NOT_FOUND {
			err = ErrSigNotFound
		}
		return "", err
	}
	defer cryptCATAdminReleaseCatalogContext(catAdmin, catInfoCtx, 0)

	catInfo := _CATALOG_INFO{
		size: uint32(unsafe.Sizeof(_CATALOG_INFO{})),
	}
	if err := cryptCATAdminCatalogInfoFromContext(catInfoCtx, &catInfo, 0); err != nil {
		return "", err
	}

	oq, err := newObjectQuery(&catInfo.catalogFile[0])
	if err != nil {
		return "", err
	}
	defer oq.Close()

	certSubj, err := oq.certSubject()
	if err != nil {
		return "", err
	}

	if !verify {
		return certSubj, nil
	}

	// memberTag is required to be formatted this way.
	hbh := strings.ToUpper(hex.EncodeToString(hashBuf))
	memberTag, err := windows.UTF16PtrFromString(hbh)
	if err != nil {
		return "", err
	}

	wintrustArg := unsafe.Pointer(&_WINTRUST_CATALOG_INFO{
		size:            uint32(unsafe.Sizeof(_WINTRUST_CATALOG_INFO{})),
		catalogFilePath: &catInfo.catalogFile[0],
		memberTag:       memberTag,
		memberFilePath:  utf16Path,
		memberFile:      memberFile,
		catAdmin:        catAdmin,
	})
	err = verifyTrust(windows.WTD_CHOICE_CATALOG, wintrustArg)
	// Note that we intentionally return certSubj even when err != nil; the idea
	// is that we might still want to know who the cert subject claims to be,
	// even if the validation has failed (eg for troubleshooting purposes).
	return certSubj, err
}

func verifyTrust(infoType uint32, info unsafe.Pointer) error {
	data := &windows.WinTrustData{
		Size:                            uint32(unsafe.Sizeof(windows.WinTrustData{})),
		UIChoice:                        windows.WTD_UI_NONE,
		RevocationChecks:                windows.WTD_REVOKE_WHOLECHAIN, // Full revocation checking, as this is called with network connectivity.
		UnionChoice:                     infoType,
		StateAction:                     windows.WTD_STATEACTION_VERIFY,
		FileOrCatalogOrBlobOrSgnrOrCert: info,
	}
	err := windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)

	data.StateAction = windows.WTD_STATEACTION_CLOSE
	windows.WinVerifyTrustEx(windows.InvalidHWND, &windows.WINTRUST_ACTION_GENERIC_VERIFY_V2, data)

	return err
}