// internal/crypto/keys.go
package crypto

import (
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"sync"

	"github.com/nbd-wtf/go-nostr"
	"github.com/nbd-wtf/go-nostr/nip19" // Added import for proper NSEC decoding
	"golang.org/x/crypto/nacl/secretbox"
)

// Sentinel errors returned by KeyStore operations; compare with errors.Is.
var (
	// ErrKeyNotFound reports that the store holds no key for the requested pubkey.
	ErrKeyNotFound = errors.New("key not found")

	// ErrInvalidKey reports a private key that is malformed (for example, hex
	// of the wrong length) or from which no public key could be derived.
	ErrInvalidKey = errors.New("invalid key")

	// ErrDecryptionFailed reports that a stored private key could not be
	// decrypted, for example because the password does not match the one it
	// was encrypted with or the stored data is truncated.
	ErrDecryptionFailed = errors.New("decryption failed")
)

// KeyStore holds Nostr private keys encrypted at rest, indexed by public key,
// and persists them as a JSON object in the file at filePath.
//
// keys and password are only modified while mu is held for writing.
type KeyStore struct {
	filePath string            // path of the JSON key file
	keys     map[string]string // pubkey -> hex(nonce || secretbox ciphertext) of the private key
	password []byte            // password bytes; copied into the 32-byte secretbox key
	mu       sync.RWMutex      // protects keys and password
}

// NewKeyStore creates a KeyStore backed by the file at filePath whose private
// keys are encrypted using password.
//
// If filePath exists, its keys are loaded. A missing file yields an empty
// store; the file is written on the first save. Any other error while reading
// or decoding the file is returned wrapped.
func NewKeyStore(filePath string, password string) (*KeyStore, error) {
	ks := &KeyStore{
		filePath: filePath,
		keys:     make(map[string]string),
		password: []byte(password),
	}

	// Load directly and treat only "does not exist" as benign. A prior
	// os.Stat check silently skipped loading on every other Stat error (e.g.
	// permission denied), yielding an empty store for an existing key file.
	if err := ks.loadKeys(); err != nil && !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("failed to load keys: %w", err)
	}

	return ks, nil
}

// loadKeys replaces ks.keys with the pubkey -> encrypted-privkey map decoded
// from the JSON file at ks.filePath.
//
// Read errors (including a missing file, detectable with
// errors.Is(err, os.ErrNotExist)) and JSON decoding errors are returned
// unwrapped. On any error ks.keys is left unchanged.
func (ks *KeyStore) loadKeys() error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	data, err := os.ReadFile(ks.filePath)
	if err != nil {
		return err
	}

	var keys map[string]string
	if err := json.Unmarshal(data, &keys); err != nil {
		return err
	}
	if keys == nil {
		// A file containing JSON `null` decodes to a nil map, and assigning
		// into a nil map (as AddKey does) panics, so use an empty map.
		keys = make(map[string]string)
	}
	ks.keys = keys
	return nil
}

// saveKeys writes the encrypted key map to ks.filePath as indented JSON.
// A newly created file gets mode 0600 (before umask).
//
// The exclusive lock is taken, not a read lock. Under RLock, two concurrent
// saves (e.g. from parallel AddKey calls) could run os.WriteFile on the same
// path at the same time and interleave their writes, leaving a corrupt key
// file. Holding Lock serializes writers. Because every mutation of ks.keys
// also happens under Lock, the last save always reflects the latest state.
//
// ks.mu must not already be held by the caller.
func (ks *KeyStore) saveKeys() error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	data, err := json.MarshalIndent(ks.keys, "", "  ")
	if err != nil {
		return err
	}

	return os.WriteFile(ks.filePath, data, 0600)
}

// GenerateKey creates a new random Nostr keypair and stores it via AddKey.
//
// The private key is 32 bytes from crypto/rand, hex-encoded. The public key is
// derived from it with nostr.GetPublicKey. Both hex strings are returned; note
// that privkey is returned in plaintext. If storing fails, AddKey's error is
// returned unwrapped.
func (ks *KeyStore) GenerateKey() (pubkey, privkey string, err error) {
	var secret [32]byte
	if _, readErr := io.ReadFull(rand.Reader, secret[:]); readErr != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", readErr)
	}
	privHex := hex.EncodeToString(secret[:])

	pubHex, deriveErr := nostr.GetPublicKey(privHex)
	if deriveErr != nil {
		return "", "", fmt.Errorf("failed to derive public key: %w", deriveErr)
	}

	if storeErr := ks.AddKey(pubHex, privHex); storeErr != nil {
		return "", "", storeErr
	}
	return pubHex, privHex, nil
}

// AddKey validates privkey against pubkey, encrypts it, stores it under
// pubkey, and saves the store to disk.
//
// privkey may be a 64-character hex string or an "nsec..." NIP-19 string,
// which is decoded to hex first. It returns:
//   - a wrapped error if nsec decoding fails;
//   - ErrInvalidKey if the (decoded) key is not 64 characters long or no
//     public key can be derived from it;
//   - an error if the derived public key differs from pubkey;
//   - any error from encryption or from saving the file.
func (ks *KeyStore) AddKey(pubkey, privkey string) error {
	if len(privkey) > 4 && privkey[:4] == "nsec" {
		var err error
		privkey, err = decodePrivateKey(privkey)
		if err != nil {
			return fmt.Errorf("failed to decode nsec key: %w", err)
		}
	}

	// Require a full 64-character hex secret, the same length check ImportKey
	// applies, so truncated keys are rejected before being stored.
	if len(privkey) != 64 {
		return ErrInvalidKey
	}

	derivedPub, err := nostr.GetPublicKey(privkey)
	if err != nil {
		return ErrInvalidKey
	}
	if derivedPub != pubkey {
		return errors.New("public key does not match private key")
	}

	// Encrypt and insert under one exclusive lock. encryptKey reads
	// ks.password, and ChangePassword swaps that password and re-encrypts
	// ks.keys while holding the same lock. If encryption ran outside the lock,
	// a concurrent password change could slip in between, and this key would
	// be stored encrypted under a password the store no longer uses.
	ks.mu.Lock()
	encPrivkey, err := ks.encryptKey(privkey)
	if err == nil {
		ks.keys[pubkey] = encPrivkey
	}
	ks.mu.Unlock()
	if err != nil {
		return err
	}

	return ks.saveKeys()
}

// ImportKey adds a private key given either as an "nsec..." NIP-19 string or
// as a 64-character hex string, and returns the derived public key.
//
// A non-nsec input whose length is not 64 yields ErrInvalidKey. Decoding and
// derivation failures are returned wrapped. Errors from AddKey are returned
// as-is.
func (ks *KeyStore) ImportKey(keyStr string) (string, error) {
	privkeyHex := keyStr
	switch {
	case len(keyStr) > 4 && keyStr[:4] == "nsec":
		decoded, decodeErr := decodePrivateKey(keyStr)
		if decodeErr != nil {
			return "", fmt.Errorf("invalid nsec key: %w", decodeErr)
		}
		privkeyHex = decoded
	case len(keyStr) != 64:
		// Not nsec, so it must be a hex-encoded 32-byte key.
		return "", ErrInvalidKey
	}

	pubkey, err := nostr.GetPublicKey(privkeyHex)
	if err != nil {
		return "", fmt.Errorf("invalid private key: %w", err)
	}

	if err := ks.AddKey(pubkey, privkeyHex); err != nil {
		return "", err
	}
	return pubkey, nil
}

// GetPrivateKey returns the decrypted private key stored for pubkey.
// It returns ErrKeyNotFound if no key is stored for pubkey. Errors from
// decryptKey (such as ErrDecryptionFailed) are returned unchanged.
//
// The read lock is held across both the lookup and the decryption.
// decryptKey reads ks.password, and ChangePassword replaces that password and
// re-encrypts every key while holding the write lock. Releasing the lock
// between the two steps would race on ks.password. It could also decrypt an
// old ciphertext with the new password, producing a spurious
// ErrDecryptionFailed.
func (ks *KeyStore) GetPrivateKey(pubkey string) (string, error) {
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	encryptedKey, exists := ks.keys[pubkey]
	if !exists {
		return "", ErrKeyNotFound
	}

	return ks.decryptKey(encryptedKey)
}

// RemoveKey deletes the key stored for pubkey and saves the store.
//
// Removing a pubkey that is not present is not an error; the unchanged store
// is still saved. Only an error from saving is returned.
func (ks *KeyStore) RemoveKey(pubkey string) error {
	func() {
		ks.mu.Lock()
		defer ks.mu.Unlock()
		delete(ks.keys, pubkey)
	}()

	return ks.saveKeys()
}

// ListKeys returns the public keys of all stored keys.
//
// The result is a fresh, non-nil slice (empty when the store is empty). Its
// order follows Go map iteration and is therefore unspecified.
func (ks *KeyStore) ListKeys() []string {
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	pubkeys := make([]string, len(ks.keys))
	i := 0
	for pubkey := range ks.keys {
		pubkeys[i] = pubkey
		i++
	}
	return pubkeys
}

// DecodeNsecKey converts an "nsec..." NIP-19 private key into its hex form.
// It neither reads nor modifies the KeyStore.
func (ks *KeyStore) DecodeNsecKey(nsecKey string) (string, error) {
	hexKey, err := decodePrivateKey(nsecKey)
	if err != nil {
		return "", err
	}
	return hexKey, nil
}

// encryptKey seals privkey with NaCl secretbox and returns the hex encoding of
// nonce || ciphertext.
//
// A fresh random 24-byte nonce is drawn on every call. The secretbox key is
// ks.password copied into a 32-byte array: longer passwords are truncated to
// 32 bytes and shorter ones are zero-padded.
// NOTE(review): no key-derivation function (e.g. scrypt/argon2) is applied,
// so key strength equals raw-password strength. Changing this alters the
// on-disk format and needs a migration.
//
// This method reads ks.password without taking ks.mu.
func (ks *KeyStore) encryptKey(privkey string) (string, error) {
	var (
		nonce   [24]byte
		sealKey [32]byte
	)
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	copy(sealKey[:], ks.password)

	// Seal appends to its first argument. Passing the nonce therefore yields
	// nonce || ciphertext, the layout decryptKey expects.
	sealed := secretbox.Seal(nonce[:], []byte(privkey), &nonce, &sealKey)
	return hex.EncodeToString(sealed), nil
}

// decryptKey reverses encryptKey. It hex-decodes encryptedKey, treats the
// first 24 bytes as the nonce, and opens the rest as a secretbox. The key is
// ks.password copied into a 32-byte array, exactly as in encryptKey.
//
// It returns a wrapped error if encryptedKey is not valid hex. It returns
// ErrDecryptionFailed if the data is shorter than a nonce or if
// secretbox.Open rejects it (e.g. wrong password or modified ciphertext).
// This method reads ks.password without taking ks.mu.
func (ks *KeyStore) decryptKey(encryptedKey string) (string, error) {
	const nonceLen = 24

	raw, err := hex.DecodeString(encryptedKey)
	if err != nil {
		return "", fmt.Errorf("invalid encrypted key format: %w", err)
	}
	if len(raw) < nonceLen {
		return "", ErrDecryptionFailed
	}

	var nonce [nonceLen]byte
	copy(nonce[:], raw)

	var openKey [32]byte
	copy(openKey[:], ks.password)

	plain, ok := secretbox.Open(nil, raw[nonceLen:], &nonce, &openKey)
	if ok {
		return string(plain), nil
	}
	return "", ErrDecryptionFailed
}

// ChangePassword re-encrypts every stored private key under newPassword and
// writes the result to the key file.
//
// The change is all-or-nothing:
//  1. Each key is decrypted with the current password and re-encrypted with
//     the new one into a fresh map.
//  2. That map is written to disk.
//  3. Only then are the in-memory keys and password replaced.
// If any step fails, the KeyStore keeps its previous keys and password.
//
// The exclusive lock is held throughout. The file is written here rather than
// via saveKeys, because saveKeys acquires ks.mu itself and sync.RWMutex is not
// reentrant, so calling it while holding the lock would deadlock. The on-disk
// format (indented JSON, mode 0600 on creation) matches saveKeys.
func (ks *KeyStore) ChangePassword(newPassword string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	oldPassword := ks.password
	newPasswordBytes := []byte(newPassword)

	// encryptKey and decryptKey read ks.password, so it is switched between
	// the two passwords for each key. Restore the old password on every early
	// return. This defer runs before the deferred Unlock above.
	committed := false
	defer func() {
		if !committed {
			ks.password = oldPassword
		}
	}()

	reencrypted := make(map[string]string, len(ks.keys))
	for pubkey, encryptedKey := range ks.keys {
		ks.password = oldPassword
		privkey, err := ks.decryptKey(encryptedKey)
		if err != nil {
			return fmt.Errorf("failed to decrypt key: %w", err)
		}

		ks.password = newPasswordBytes
		newEncryptedKey, err := ks.encryptKey(privkey)
		if err != nil {
			return fmt.Errorf("failed to re-encrypt key: %w", err)
		}
		reencrypted[pubkey] = newEncryptedKey
	}

	// Persist before committing in memory, so a failed write leaves memory
	// consistent with what is on disk.
	data, err := json.MarshalIndent(reencrypted, "", "  ")
	if err != nil {
		return err
	}
	if err := os.WriteFile(ks.filePath, data, 0600); err != nil {
		return err
	}

	ks.keys = reencrypted
	ks.password = newPasswordBytes
	committed = true
	return nil
}

// decodePrivateKey converts an "nsec..." NIP-19 string into the hex private
// key it encodes, using nip19.Decode.
//
// It returns an error when the input does not start with "nsec", when
// nip19.Decode fails, when the decoded prefix is not "nsec", or when the
// decoded payload is not a string. The code expects nip19.Decode to yield the
// key as a hex string for nsec input.
func decodePrivateKey(nsecKey string) (string, error) {
	hasNsecPrefix := len(nsecKey) >= 4 && nsecKey[:4] == "nsec"
	if !hasNsecPrefix {
		return "", fmt.Errorf("invalid nsec key")
	}

	prefix, payload, err := nip19.Decode(nsecKey)
	switch {
	case err != nil:
		return "", fmt.Errorf("failed to decode nsec key: %w", err)
	case prefix != "nsec":
		return "", fmt.Errorf("expected nsec prefix, got %s", prefix)
	}

	hexKey, isString := payload.(string)
	if !isString {
		return "", fmt.Errorf("invalid nsec data format")
	}
	return hexKey, nil
}