package ntlmssp

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestLMOWFv1 checks lmowfV1 against a known LM one-way function value.
// The three inputs differ only in letter case and share one expected hash,
// so the table also pins that lmowfV1 ignores case for these passwords.
func TestLMOWFv1(t *testing.T) {
	tables := []struct {
		password string
		want     []byte
		err      error
	}{
		{
			"SecREt01",
			[]byte{
				0xff, 0x37, 0x50, 0xbc, 0xc2, 0xb2, 0x24, 0x12,
				0xc2, 0x26, 0x5b, 0x23, 0x73, 0x4e, 0x0d, 0xac,
			},
			nil,
		},
		{
			"secret01",
			[]byte{
				0xff, 0x37, 0x50, 0xbc, 0xc2, 0xb2, 0x24, 0x12,
				0xc2, 0x26, 0x5b, 0x23, 0x73, 0x4e, 0x0d, 0xac,
			},
			nil,
		},
		{
			"SECRET01",
			[]byte{
				0xff, 0x37, 0x50, 0xbc, 0xc2, 0xb2, 0x24, 0x12,
				0xc2, 0x26, 0x5b, 0x23, 0x73, 0x4e, 0x0d, 0xac,
			},
			nil,
		},
	}

	for _, table := range tables {
		// Name each subtest after its input so a failure identifies the case.
		t.Run(table.password, func(t *testing.T) {
			got, err := lmowfV1(table.password)
			assert.Equal(t, table.err, err)
			// The hash is only meaningful when no error was returned.
			if err == nil {
				assert.Equal(t, table.want, got)
			}
		})
	}
}

// TestNTOWFv1 checks ntowfV1 against a known NT one-way function value.
func TestNTOWFv1(t *testing.T) {
	tables := []struct {
		password string
		want     []byte
		err      error
	}{
		{
			"SecREt01",
			[]byte{
				0xcd, 0x06, 0xca, 0x7c, 0x7e, 0x10, 0xc9, 0x9b,
				0x1d, 0x33, 0xb7, 0x48, 0x5a, 0x2e, 0xd8, 0x08,
			},
			nil,
		},
	}

	for _, table := range tables {
		// Name each subtest after its input so a failure identifies the case.
		t.Run(table.password, func(t *testing.T) {
			got, err := ntowfV1(table.password)
			assert.Equal(t, table.err, err)
			// The hash is only meaningful when no error was returned.
			if err == nil {
				assert.Equal(t, table.want, got)
			}
		})
	}
}

// TestNTOWFv2 verifies that ntowfV2 derives the expected 16-byte value from
// a username, password and domain.
func TestNTOWFv2(t *testing.T) {
	cases := []struct {
		username, password, domain string
		want                       []byte
		err                        error
	}{
		{
			username: "test",
			password: "test1234",
			domain:   "TESTNT",
			want: []byte{
				0xc4, 0xea, 0x95, 0xcb, 0x14, 0x8d, 0xf1, 0x1b,
				0xf9, 0xd7, 0xc3, 0x61, 0x1a, 0xd6, 0xd7, 0x22,
			},
		},
	}

	for _, tc := range cases {
		key, err := ntowfV2(tc.username, tc.password, tc.domain)
		// Both checks run unconditionally.
		assert.Equal(t, tc.want, key)
		assert.Equal(t, tc.err, err)
	}
}

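// The test vectors below are taken from the MS-NLMP specification,
// section 4.2 ("Cryptographic Values for Validation").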

func TestLMv1WithSessionSecurityResponse(t *testing.T) {
	tables := []struct {
		got, want []byte
	}{
		{
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
		},
	}

	for _, table := range tables {
		assert.Equal(t, table.want, lmV1WithSessionSecurityResponse(table.got))
	}
}

// TestLMv1Response checks the 24-byte LMv1 response that lmV1Response
// computes for the password "Password" and a fixed server challenge.
func TestLMv1Response(t *testing.T) {
	cases := []struct {
		password        string
		serverChallenge []byte
		want            []byte
		err             error
	}{
		{
			password: "Password",
			serverChallenge: []byte{
				0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
			},
			want: []byte{
				0x98, 0xde, 0xf7, 0xb8, 0x7f, 0x88, 0xaa, 0x5d,
				0xaf, 0xe2, 0xdf, 0x77, 0x96, 0x88, 0xa1, 0x72,
				0xde, 0xf1, 0x1c, 0x7d, 0x5c, 0xcd, 0xef, 0x13,
			},
		},
	}

	for _, tc := range cases {
		resp, err := lmV1Response(tc.password, tc.serverChallenge)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// The response is only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, resp)
	}
}

// TestLMv2Response checks lmV2Response for a fixed user, password, domain,
// server challenge and client challenge. The expected 24-byte value ends
// with the client challenge bytes.
func TestLMv2Response(t *testing.T) {
	cases := []struct {
		username, password, domain       string
		serverChallenge, clientChallenge []byte
		want                             []byte
		err                              error
	}{
		{
			username: "User",
			password: "Password",
			domain:   "Domain",
			serverChallenge: []byte{
				0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
			},
			clientChallenge: []byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			want: []byte{
				0x86, 0xc3, 0x50, 0x97, 0xac, 0x9c, 0xec, 0x10,
				0x25, 0x54, 0x76, 0x4a, 0x57, 0xcc, 0xcc, 0x19,
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
		},
	}

	for _, tc := range cases {
		resp, err := lmV2Response(tc.username, tc.password, tc.domain, tc.serverChallenge, tc.clientChallenge)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// The response is only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, resp)
	}
}

// TestNTLMv1Response checks both values returned by ntlmV1Response for the
// password "Password": the 24-byte challenge response and the 16-byte
// session base key.
func TestNTLMv1Response(t *testing.T) {
	cases := []struct {
		password        string
		serverChallenge []byte
		want            []byte
		sessionBaseKey  []byte
		err             error
	}{
		{
			password: "Password",
			serverChallenge: []byte{
				0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
			},
			want: []byte{
				0x67, 0xc4, 0x30, 0x11, 0xf3, 0x02, 0x98, 0xa2,
				0xad, 0x35, 0xec, 0xe6, 0x4f, 0x16, 0x33, 0x1c,
				0x44, 0xbd, 0xbe, 0xd9, 0x27, 0x84, 0x1f, 0x94,
			},
			sessionBaseKey: []byte{
				0xd8, 0x72, 0x62, 0xb0, 0xcd, 0xe4, 0xb1, 0xcb,
				0x74, 0x99, 0xbe, 0xcc, 0xcd, 0xf1, 0x07, 0x84,
			},
		},
	}

	for _, tc := range cases {
		resp, key, err := ntlmV1Response(tc.password, tc.serverChallenge)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// Outputs are only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, resp)
		assert.Equal(t, tc.sessionBaseKey, key)
	}
}

// TestNTLM2Response checks both values returned by ntlm2Response for the
// password "Password" with fixed server and client challenges: the 24-byte
// response and the 16-byte session base key.
func TestNTLM2Response(t *testing.T) {
	cases := []struct {
		password                         string
		serverChallenge, clientChallenge []byte
		want                             []byte
		sessionBaseKey                   []byte
		err                              error
	}{
		{
			password: "Password",
			serverChallenge: []byte{
				0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
			},
			clientChallenge: []byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			want: []byte{
				0x75, 0x37, 0xf8, 0x03, 0xae, 0x36, 0x71, 0x28,
				0xca, 0x45, 0x82, 0x04, 0xbd, 0xe7, 0xca, 0xf8,
				0x1e, 0x97, 0xed, 0x26, 0x83, 0x26, 0x72, 0x32,
			},
			sessionBaseKey: []byte{
				0xd8, 0x72, 0x62, 0xb0, 0xcd, 0xe4, 0xb1, 0xcb,
				0x74, 0x99, 0xbe, 0xcc, 0xcd, 0xf1, 0x07, 0x84,
			},
		},
	}

	for _, tc := range cases {
		resp, key, err := ntlm2Response(tc.password, tc.serverChallenge, tc.clientChallenge)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// Outputs are only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, resp)
		assert.Equal(t, tc.sessionBaseKey, key)
	}
}

// TestNTLMv2Temp checks the blob ntlmV2Temp builds from a zero timestamp,
// the client challenge and a targetInfo holding the NetBIOS domain and
// computer names. In the expected bytes the domain-name pair precedes the
// computer-name pair, matching the order of the targetInfo's avID list.
func TestNTLMv2Temp(t *testing.T) {
	cases := []struct {
		timestamp       []byte
		clientChallenge []byte
		targetInfo      targetInfo
		want            []byte
		err             error
	}{
		{
			timestamp: []byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
			clientChallenge: []byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			targetInfo: targetInfo{
				map[avID][]uint8{
					msvAvNbComputerName: {
						0x53, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00,
						0x65, 0x00, 0x72, 0x00,
					},
					msvAvNbDomainName: {
						0x44, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00,
						0x69, 0x00, 0x6e, 0x00,
					},
				},
				[]avID{
					msvAvNbDomainName,
					msvAvNbComputerName,
				},
			},
			want: []byte{
				0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
				0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x0c, 0x00,
				0x44, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00,
				0x69, 0x00, 0x6e, 0x00, 0x01, 0x00, 0x0c, 0x00,
				0x53, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00,
				0x65, 0x00, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00,
			},
		},
	}

	for _, tc := range cases {
		blob, err := ntlmV2Temp(tc.timestamp, tc.clientChallenge, tc.targetInfo)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// The blob is only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, blob)
	}
}

// TestNTLMv2Response checks both values returned by ntlmV2Response: the full
// response (a 16-byte prefix followed by the temp blob) and the 16-byte key.
func TestNTLMv2Response(t *testing.T) {
	cases := []struct {
		username, password, domain                  string
		serverChallenge, clientChallenge, timestamp []byte
		targetInfo                                  targetInfo
		want                                        []byte
		keyExchangeKey                              []byte
		err                                         error
	}{
		{
			username: "User",
			password: "Password",
			domain:   "Domain",
			serverChallenge: []byte{
				0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
			},
			clientChallenge: []byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			timestamp: []byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
			targetInfo: targetInfo{
				map[avID][]uint8{
					msvAvNbComputerName: {
						0x53, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00,
						0x65, 0x00, 0x72, 0x00,
					},
					msvAvNbDomainName: {
						0x44, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00,
						0x69, 0x00, 0x6e, 0x00,
					},
				},
				[]avID{
					msvAvNbDomainName,
					msvAvNbComputerName,
				},
			},
			want: []byte{
				0x68, 0xcd, 0x0a, 0xb8, 0x51, 0xe5, 0x1c, 0x96,
				0xaa, 0xbc, 0x92, 0x7b, 0xeb, 0xef, 0x6a, 0x1c,
				0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
				0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x0c, 0x00,
				0x44, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00,
				0x69, 0x00, 0x6e, 0x00, 0x01, 0x00, 0x0c, 0x00,
				0x53, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00,
				0x65, 0x00, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00,
			},
			keyExchangeKey: []byte{
				0x8d, 0xe4, 0x0c, 0xca, 0xdb, 0xc1, 0x4a, 0x82,
				0xf1, 0x5c, 0xb0, 0xad, 0x0d, 0xe9, 0x5c, 0xa3,
			},
		},
	}

	for _, tc := range cases {
		resp, key, err := ntlmV2Response(tc.username, tc.password, tc.domain, tc.serverChallenge, tc.clientChallenge, tc.timestamp, tc.targetInfo)
		assert.Equal(t, tc.err, err)
		if err != nil {
			// Outputs are only compared when no error was returned.
			continue
		}
		assert.Equal(t, tc.want, resp)
		assert.Equal(t, tc.keyExchangeKey, key)
	}
}

// TestLmChallengeResponse exercises lmChallengeResponse(flags, level,
// clientChallenge, username, password, domain, cm) across these cases:
//
//   - extended session security negotiated: the client challenge followed by
//     16 zero bytes (same value as TestLMv1WithSessionSecurityResponse)
//   - level 1: the LMv1 response (same vector as TestLMv1Response)
//   - level 2: the NTLMv1 response (same vector as TestNTLMv1Response)
//   - level 3 with an msvAvTimestamp AV pair: 24 zero bytes
//   - level 3 without a timestamp: the LMv2 response (same vector as
//     TestLMv2Response)
func TestLmChallengeResponse(t *testing.T) {
	tables := []struct {
		name                       string
		flags                      uint32
		level                      lmCompatibilityLevel
		clientChallenge            []byte
		username, password, domain string
		challengeMessage           *challengeMessage
		want                       []byte
		err                        error
	}{
		{
			"extended session security",
			ntlmsspNegotiateExtendedSessionsecurity.Set(0),
			2,
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			"User",
			"Password",
			"Domain",
			&challengeMessage{},
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
			nil,
		},
		{
			"level 1 LMv1",
			0,
			1,
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			"User",
			"Password",
			"Domain",
			&challengeMessage{
				challengeMessageFields: challengeMessageFields{
					ServerChallenge: [8]byte{
						0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
					},
				},
			},
			[]byte{
				0x98, 0xde, 0xf7, 0xb8, 0x7f, 0x88, 0xaa, 0x5d,
				0xaf, 0xe2, 0xdf, 0x77, 0x96, 0x88, 0xa1, 0x72,
				0xde, 0xf1, 0x1c, 0x7d, 0x5c, 0xcd, 0xef, 0x13,
			},
			nil,
		},
		{
			"level 2 NTLMv1",
			0,
			2,
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			"User",
			"Password",
			"Domain",
			&challengeMessage{
				challengeMessageFields: challengeMessageFields{
					ServerChallenge: [8]byte{
						0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
					},
				},
			},
			[]byte{
				0x67, 0xc4, 0x30, 0x11, 0xf3, 0x02, 0x98, 0xa2,
				0xad, 0x35, 0xec, 0xe6, 0x4f, 0x16, 0x33, 0x1c,
				0x44, 0xbd, 0xbe, 0xd9, 0x27, 0x84, 0x1f, 0x94,
			},
			nil,
		},
		{
			"level 3 with timestamp",
			0,
			3,
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			"User",
			"Password",
			"Domain",
			&challengeMessage{
				challengeMessageFields: challengeMessageFields{
					ServerChallenge: [8]byte{
						0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
					},
				},
				TargetInfo: targetInfo{
					map[avID][]uint8{
						msvAvTimestamp: {
							0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
						},
					},
					[]avID{
						msvAvTimestamp,
					},
				},
			},
			[]byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			},
			nil,
		},
		{
			"level 3 LMv2",
			0,
			3,
			[]byte{
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			"User",
			"Password",
			"Domain",
			&challengeMessage{
				challengeMessageFields: challengeMessageFields{
					ServerChallenge: [8]byte{
						0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
					},
				},
			},
			[]byte{
				0x86, 0xc3, 0x50, 0x97, 0xac, 0x9c, 0xec, 0x10,
				0x25, 0x54, 0x76, 0x4a, 0x57, 0xcc, 0xcc, 0x19,
				0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
			},
			nil,
		},
	}

	for _, table := range tables {
		// A named subtest per case, so a failure says which branch broke.
		t.Run(table.name, func(t *testing.T) {
			got, err := lmChallengeResponse(table.flags, table.level, table.clientChallenge, table.username, table.password, table.domain, table.challengeMessage)
			assert.Equal(t, table.err, err)
			// The response is only meaningful when no error was returned.
			if err == nil {
				assert.Equal(t, table.want, got)
			}
		})
	}
}
