// Listing metadata: 307 lines, 7.2 KiB, Go.

// internal/crypto/keys.go
package crypto
import (
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"sync"
"github.com/nbd-wtf/go-nostr"
"golang.org/x/crypto/nacl/secretbox"
)
// ErrKeyNotFound reports that the store holds no entry for the requested
// public key.
var ErrKeyNotFound = errors.New("key not found")

// ErrInvalidKey reports that a supplied key could not be accepted as a
// private key.
var ErrInvalidKey = errors.New("invalid key")

// ErrDecryptionFailed reports that a stored private key could not be
// decrypted.
var ErrDecryptionFailed = errors.New("decryption failed")
// KeyStore keeps Nostr private keys encrypted under a password, both in
// memory and in a JSON file on disk.
type KeyStore struct {
	// filePath is the location of the JSON key file.
	filePath string
	// password is the secret the encryption key is taken from.
	password []byte

	// mu guards keys.
	mu sync.RWMutex
	// keys maps a public key to its encrypted private key.
	keys map[string]string
}
// NewKeyStore creates a KeyStore backed by the file at filePath, using
// password to encrypt and decrypt the stored private keys.
//
// If the file exists its contents are loaded. If it does not exist the store
// starts empty and the file is created by the first operation that saves.
// Any other failure to inspect the file (for example a permission error or an
// invalid path) is returned instead of being treated as "no file", so that an
// existing key file is never silently replaced by an empty store.
func NewKeyStore(filePath string, password string) (*KeyStore, error) {
	ks := &KeyStore{
		filePath: filePath,
		keys:     make(map[string]string),
		password: []byte(password),
	}

	_, statErr := os.Stat(filePath)
	switch {
	case statErr == nil:
		if err := ks.loadKeys(); err != nil {
			return nil, fmt.Errorf("failed to load keys: %w", err)
		}
	case errors.Is(statErr, os.ErrNotExist):
		// No key file yet: start with an empty store.
	default:
		return nil, fmt.Errorf("failed to access key file: %w", statErr)
	}
	return ks, nil
}
// loadKeys reads the JSON key file and merges its entries into ks.keys.
//
// The file is decoded into a scratch map first, so a malformed file leaves
// the in-memory map untouched. A file containing JSON null is treated as
// holding no keys; ks.keys is never left nil, which would make later
// insertions panic.
//
// It returns the error from reading the file or from decoding its contents.
func (ks *KeyStore) loadKeys() error {
	data, err := os.ReadFile(ks.filePath)
	if err != nil {
		return err
	}

	var loaded map[string]string
	if err := json.Unmarshal(data, &loaded); err != nil {
		return err
	}

	ks.mu.Lock()
	defer ks.mu.Unlock()
	if ks.keys == nil {
		ks.keys = make(map[string]string, len(loaded))
	}
	for pubkey, encrypted := range loaded {
		ks.keys[pubkey] = encrypted
	}
	return nil
}
// saveKeys serializes the encrypted key map as indented JSON and writes it
// to the key file with mode 0600.
//
// It takes the read lock itself, so it must not be called while the caller
// already holds ks.mu for writing.
func (ks *KeyStore) saveKeys() error {
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	encoded, marshalErr := json.MarshalIndent(ks.keys, "", " ")
	if marshalErr == nil {
		return os.WriteFile(ks.filePath, encoded, 0600)
	}
	return marshalErr
}
// GenerateKey creates a fresh Nostr keypair from 32 random bytes, stores the
// private key in the key store via AddKey, and returns both keys as hex
// strings. On any failure both returned keys are empty.
func (ks *KeyStore) GenerateKey() (pubkey, privkey string, err error) {
	var secret [32]byte
	if _, readErr := io.ReadFull(rand.Reader, secret[:]); readErr != nil {
		return "", "", fmt.Errorf("failed to generate random bytes: %w", readErr)
	}
	privkey = hex.EncodeToString(secret[:])

	pubkey, err = nostr.GetPublicKey(privkey)
	if err != nil {
		return "", "", fmt.Errorf("failed to derive public key: %w", err)
	}

	// AddKey encrypts the private key and persists the store.
	if err = ks.AddKey(pubkey, privkey); err != nil {
		return "", "", err
	}
	return pubkey, privkey, nil
}
// AddKey stores privkey under pubkey after checking that the two belong
// together, then writes the store to disk.
//
// It returns ErrInvalidKey when no public key can be derived from privkey,
// an error when the derived public key differs from pubkey, or the error
// from encrypting or saving.
func (ks *KeyStore) AddKey(pubkey, privkey string) error {
	expected, err := nostr.GetPublicKey(privkey)
	switch {
	case err != nil:
		return ErrInvalidKey
	case expected != pubkey:
		return errors.New("public key does not match private key")
	}

	sealed, err := ks.encryptKey(privkey)
	if err != nil {
		return err
	}

	// Release the write lock before saving: saveKeys takes the read lock.
	ks.mu.Lock()
	ks.keys[pubkey] = sealed
	ks.mu.Unlock()

	return ks.saveKeys()
}
// ImportKey adds a private key given either as an nsec string or as a
// 64-character hex string, and returns the public key derived from it.
//
// Input longer than four characters that starts with "nsec" is passed to
// decodePrivateKey; anything else must be exactly 64 characters long or
// ErrInvalidKey is returned.
func (ks *KeyStore) ImportKey(keyStr string) (string, error) {
	privkeyHex := keyStr
	switch {
	case len(keyStr) > 4 && keyStr[:4] == "nsec":
		decoded, err := decodePrivateKey(keyStr)
		if err != nil {
			return "", fmt.Errorf("invalid nsec key: %w", err)
		}
		privkeyHex = decoded
	case len(keyStr) != 64:
		return "", ErrInvalidKey
	}

	pubkey, err := nostr.GetPublicKey(privkeyHex)
	if err != nil {
		return "", fmt.Errorf("invalid private key: %w", err)
	}

	if err := ks.AddKey(pubkey, privkeyHex); err != nil {
		return "", err
	}
	return pubkey, nil
}
// GetPrivateKey returns the decrypted private key stored for pubkey.
// It returns ErrKeyNotFound when the store has no entry for pubkey, or the
// error from decryptKey when the entry cannot be decrypted.
func (ks *KeyStore) GetPrivateKey(pubkey string) (string, error) {
	// Hold the read lock only for the map lookup, not for decryption.
	ks.mu.RLock()
	sealed, found := ks.keys[pubkey]
	ks.mu.RUnlock()

	if found {
		return ks.decryptKey(sealed)
	}
	return "", ErrKeyNotFound
}
// RemoveKey deletes the entry for pubkey, if any, and writes the store to
// disk. Removing a key that is not present is not an error; the file is
// rewritten either way.
func (ks *KeyStore) RemoveKey(pubkey string) error {
	func() {
		ks.mu.Lock()
		defer ks.mu.Unlock()
		delete(ks.keys, pubkey)
	}()
	// The write lock is released above because saveKeys takes the read lock.
	return ks.saveKeys()
}
// ListKeys returns the public keys currently held by the store. The result
// is a new slice in map iteration order, which is unspecified.
func (ks *KeyStore) ListKeys() []string {
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	pubkeys := make([]string, len(ks.keys))
	next := 0
	for pubkey := range ks.keys {
		pubkeys[next] = pubkey
		next++
	}
	return pubkeys
}
// encryptKey seals privkey with NaCl secretbox under a fresh random 24-byte
// nonce and returns hex(nonce || ciphertext).
//
// The 32-byte secretbox key is the store password copied into a zeroed
// array: shorter passwords are zero-padded and bytes beyond the 32nd are
// ignored.
// NOTE(review): no key-derivation function is applied to the password;
// changing that would make existing key files undecryptable.
func (ks *KeyStore) encryptKey(privkey string) (string, error) {
	var (
		nonce  [24]byte
		secret [32]byte
	)
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	copy(secret[:], ks.password)

	// Seal appends the ciphertext to its first argument, so the output
	// starts with the nonce.
	sealed := secretbox.Seal(nonce[:], []byte(privkey), &nonce, &secret)
	return hex.EncodeToString(sealed), nil
}
// decryptKey reverses encryptKey: it hex-decodes encryptedKey, takes the
// first 24 bytes as the nonce, and opens the remainder with NaCl secretbox
// using the store password as the key.
//
// It returns a wrapped error when encryptedKey is not valid hex, and
// ErrDecryptionFailed when the data is shorter than a nonce or secretbox
// rejects it.
func (ks *KeyStore) decryptKey(encryptedKey string) (string, error) {
	const nonceSize = 24

	raw, err := hex.DecodeString(encryptedKey)
	if err != nil {
		return "", fmt.Errorf("invalid encrypted key format: %w", err)
	}
	if len(raw) < nonceSize {
		return "", ErrDecryptionFailed
	}

	var (
		nonce  [nonceSize]byte
		secret [32]byte
	)
	copy(nonce[:], raw[:nonceSize])
	// Same key construction as encryptKey: password copied into a zeroed array.
	copy(secret[:], ks.password)

	if plain, ok := secretbox.Open(nil, raw[nonceSize:], &nonce, &secret); ok {
		return string(plain), nil
	}
	return "", ErrDecryptionFailed
}
// ChangePassword re-encrypts every stored private key under newPassword and
// writes the result to the key file.
//
// The change is all-or-nothing for the in-memory store: every key is first
// decrypted with the current password, then re-encrypted with the new one,
// then written to disk, and only after the write succeeds are the new
// entries and password kept. On any error the previous password and entries
// remain in effect and the error is returned.
//
// The write lock is held for the whole operation. The file is therefore
// written here directly instead of through saveKeys, which takes the read
// lock and would block forever on a lock this goroutine already holds.
func (ks *KeyStore) ChangePassword(newPassword string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	oldPassword := ks.password

	// Phase 1: decrypt everything with the current password. Nothing in the
	// store is modified yet, so a failure here needs no rollback.
	plaintext := make(map[string]string, len(ks.keys))
	for pubkey, encryptedKey := range ks.keys {
		privkey, err := ks.decryptKey(encryptedKey)
		if err != nil {
			return fmt.Errorf("failed to decrypt key: %w", err)
		}
		plaintext[pubkey] = privkey
	}

	// Phase 2: encrypt into a separate map under the new password.
	// encryptKey reads ks.password, so it is switched here and restored on
	// any failure below.
	ks.password = []byte(newPassword)
	reEncrypted := make(map[string]string, len(plaintext))
	for pubkey, privkey := range plaintext {
		newEncryptedKey, err := ks.encryptKey(privkey)
		if err != nil {
			ks.password = oldPassword
			return fmt.Errorf("failed to re-encrypt key: %w", err)
		}
		reEncrypted[pubkey] = newEncryptedKey
	}

	// Phase 3: persist in the same format saveKeys uses.
	data, err := json.MarshalIndent(reEncrypted, "", " ")
	if err != nil {
		ks.password = oldPassword
		return err
	}
	if err := os.WriteFile(ks.filePath, data, 0600); err != nil {
		ks.password = oldPassword
		return err
	}

	// Phase 4: commit.
	ks.keys = reEncrypted
	return nil
}
// decodePrivateKey converts an nsec private key string to lowercase hex.
//
// Two encodings are accepted:
//
//   - NIP-19 bech32: "nsec1" followed by the bech32 data and checksum for a
//     32-byte key. Only the lowercase form is recognized.
//   - The legacy form this function originally handled: "nsec" followed by
//     standard base64 of one version byte, the key bytes, and four trailing
//     checksum bytes. The version and checksum bytes are discarded without
//     being checked.
//
// An error is returned when the input does not start with "nsec", when it is
// valid in neither encoding, or when a legacy payload is too short to hold
// the version and checksum bytes.
func decodePrivateKey(nsecKey string) (string, error) {
	if len(nsecKey) < 4 || nsecKey[:4] != "nsec" {
		return "", fmt.Errorf("invalid nsec key")
	}

	// A 32-byte bech32 nsec always has 59 characters after "nsec", which is
	// never valid padded base64, so the two encodings cannot be confused.
	if key, ok := decodeNsecBech32(nsecKey); ok {
		return hex.EncodeToString(key), nil
	}

	data, err := base64.StdEncoding.DecodeString(nsecKey[4:])
	if err != nil {
		return "", fmt.Errorf("neither NIP-19 bech32 nor legacy base64: %w", err)
	}
	// One version byte plus four checksum bytes must be present; slicing a
	// shorter payload would panic.
	if len(data) < 5 {
		return "", fmt.Errorf("invalid nsec data length")
	}
	privkeyBytes := data[1 : len(data)-4]
	return hex.EncodeToString(privkeyBytes), nil
}

// decodeNsecBech32 decodes a lowercase bech32 string with human-readable
// part "nsec" into its 32-byte payload. It reports false when s is not such
// a string: wrong prefix, a character outside the bech32 alphabet, a bad
// checksum, non-zero padding bits, or a payload that is not 32 bytes.
func decodeNsecBech32(s string) ([]byte, bool) {
	const (
		hrp         = "nsec"
		charset     = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
		checksumLen = 6
		keyLen      = 32
	)
	generators := [5]uint32{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}

	if len(s) <= len(hrp)+1+checksumLen || s[:len(hrp)+1] != hrp+"1" {
		return nil, false
	}

	// Running bech32 checksum (BIP-173 polymod), fed one 5-bit value at a time.
	chk := uint32(1)
	absorb := func(v byte) {
		top := chk >> 25
		chk = (chk&0x1ffffff)<<5 ^ uint32(v)
		for i, g := range generators {
			if (top>>uint(i))&1 == 1 {
				chk ^= g
			}
		}
	}

	// The checksum covers the expanded prefix: high bits, a zero, low bits.
	for i := 0; i < len(hrp); i++ {
		absorb(hrp[i] >> 5)
	}
	absorb(0)
	for i := 0; i < len(hrp); i++ {
		absorb(hrp[i] & 31)
	}

	// Map every character after the separator to its 5-bit value.
	symbols := make([]byte, 0, len(s)-len(hrp)-1)
	for i := len(hrp) + 1; i < len(s); i++ {
		value := -1
		for j := 0; j < len(charset); j++ {
			if charset[j] == s[i] {
				value = j
				break
			}
		}
		if value < 0 {
			return nil, false
		}
		symbols = append(symbols, byte(value))
		absorb(byte(value))
	}
	if chk != 1 {
		return nil, false
	}

	// Regroup the 5-bit payload values (checksum excluded) into bytes.
	payload := symbols[:len(symbols)-checksumLen]
	key := make([]byte, 0, len(payload)*5/8)
	var acc uint32
	bits := 0
	for _, v := range payload {
		acc = acc<<5 | uint32(v)
		bits += 5
		if bits >= 8 {
			bits -= 8
			key = append(key, byte(acc>>uint(bits)))
			acc &= 1<<uint(bits) - 1
		}
	}
	// Leftover bits are padding: fewer than five of them, all zero.
	if bits >= 5 || acc != 0 {
		return nil, false
	}
	if len(key) != keyLen {
		return nil, false
	}
	return key, true
}