156 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			156 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Go
		
	
	
	
| // Copyright 2017 The Go Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| // +build !plan9
 | |
| 
 | |
| package main
 | |
| 
 | |
import (
	"bufio"
	"context"
	"fmt"
	"io"
	"os"
	"os/user"
	"path/filepath"
	"runtime"
	"strings"
)
 | |
| 
 | |
// Names of the shell start-up files, relative to the user's home directory,
// that shellConfigFile selects between.
const (
	// bashConfig is the file chosen when the current shell is bash.
	bashConfig = ".bash_profile"
	// zshConfig is the file chosen when the current shell is zsh.
	zshConfig = ".zshrc"
)
 | |
| 
 | |
| // appendToPATH adds the given path to the PATH environment variable and
 | |
| // persists it for future sessions.
 | |
| func appendToPATH(value string) error {
 | |
| 	if isInPATH(value) {
 | |
| 		return nil
 | |
| 	}
 | |
| 	return persistEnvVar("PATH", pathVar+envSeparator+value)
 | |
| }
 | |
| 
 | |
// isInPATH reports whether dir is already one of the entries of this
// process's PATH environment variable.
//
// PATH is split with filepath.SplitList, the standard library's splitter for
// OS path lists. Both dir and each entry are normalised with filepath.Clean
// before being compared, so purely lexical differences such as a trailing
// separator ("/usr/local/go/bin/" vs "/usr/local/go/bin") do not cause a
// miss and, in turn, a duplicate PATH entry.
func isInPATH(dir string) bool {
	want := filepath.Clean(dir)

	for _, entry := range filepath.SplitList(os.Getenv("PATH")) {
		if filepath.Clean(entry) == want {
			return true
		}
	}

	return false
}
 | |
| 
 | |
| func getHomeDir() (string, error) {
 | |
| 	home := os.Getenv(homeKey)
 | |
| 	if home != "" {
 | |
| 		return home, nil
 | |
| 	}
 | |
| 
 | |
| 	u, err := user.Current()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return u.HomeDir, nil
 | |
| }
 | |
| 
 | |
// checkStringExistsFile reports whether some line of the file named filename
// is exactly equal to value. The line terminator ("\n", optionally preceded
// by a single "\r") is not part of the compared text, and a final line
// without a terminator is checked as well.
//
// A file that does not exist is reported as (false, nil); any other open or
// read error is returned.
//
// Lines are read with bufio.Reader.ReadString rather than bufio.Scanner: a
// Scanner gives up with bufio.ErrTooLong on any line longer than
// bufio.MaxScanTokenSize, which would turn one very long line in the file
// into a hard failure for the caller.
func checkStringExistsFile(filename, value string) (bool, error) {
	file, err := os.Open(filename)
	if err != nil {
		if os.IsNotExist(err) {
			return false, nil
		}
		return false, err
	}
	defer file.Close()

	r := bufio.NewReader(file)
	for {
		line, err := r.ReadString('\n')
		if err != nil && err != io.EOF {
			return false, err
		}

		// ReadString hands back whatever was read before io.EOF, so an
		// unterminated last line still gets compared. An empty read at EOF
		// is not a line at all.
		if line != "" {
			line = strings.TrimSuffix(line, "\n")
			line = strings.TrimSuffix(line, "\r")
			if line == value {
				return true, nil
			}
		}

		if err == io.EOF {
			return false, nil
		}
	}
}
 | |
| 
 | |
| func appendToFile(filename, value string) error {
 | |
| 	verbosef("Adding %q to %s", value, filename)
 | |
| 
 | |
| 	ok, err := checkStringExistsFile(filename, value)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if ok {
 | |
| 		// Nothing to do.
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	f, err := os.OpenFile(filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer f.Close()
 | |
| 
 | |
| 	_, err = f.WriteString(lineEnding + value + lineEnding)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func isShell(name string) bool {
 | |
| 	return strings.Contains(currentShell(), name)
 | |
| }
 | |
| 
 | |
| // persistEnvVarWindows sets an environment variable in the Windows
 | |
| // registry.
 | |
| func persistEnvVarWindows(name, value string) error {
 | |
| 	_, err := runCommand(context.Background(), "powershell", "-command",
 | |
| 		fmt.Sprintf(`[Environment]::SetEnvironmentVariable("%s", "%s", "User")`, name, value))
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func persistEnvVar(name, value string) error {
 | |
| 	if runtime.GOOS == "windows" {
 | |
| 		if err := persistEnvVarWindows(name, value); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if isShell("cmd.exe") || isShell("powershell.exe") {
 | |
| 			return os.Setenv(strings.ToUpper(name), value)
 | |
| 		}
 | |
| 		// User is in bash, zsh, etc.
 | |
| 		// Also set the environment variable in their shell config.
 | |
| 	}
 | |
| 
 | |
| 	rc, err := shellConfigFile()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	line := fmt.Sprintf("export %s=%s", strings.ToUpper(name), value)
 | |
| 	if err := appendToFile(rc, line); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return os.Setenv(strings.ToUpper(name), value)
 | |
| }
 | |
| 
 | |
| func shellConfigFile() (string, error) {
 | |
| 	home, err := getHomeDir()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	switch {
 | |
| 	case isShell("bash"):
 | |
| 		return filepath.Join(home, bashConfig), nil
 | |
| 	case isShell("zsh"):
 | |
| 		return filepath.Join(home, zshConfig), nil
 | |
| 	default:
 | |
| 		return "", fmt.Errorf("%q is not a supported shell", currentShell())
 | |
| 	}
 | |
| }
 |