Blob Blame History Raw
diff --git a/dnssec-trigger-script.in b/dnssec-trigger-script.in
index b572dd1..32d7749 100644
--- a/dnssec-trigger-script.in
+++ b/dnssec-trigger-script.in
@@ -6,7 +6,7 @@
 """
 
 from gi.repository import NMClient
-import os, sys, shutil, subprocess
+import os, sys, fcntl, shutil, glob, subprocess
 import logging, logging.handlers
 import socket, struct
 
@@ -15,8 +15,7 @@ DEVNULL = open("/dev/null", "wb")
 log = logging.getLogger()
 log.setLevel(logging.INFO)
 log.addHandler(logging.handlers.SysLogHandler())
-if sys.stderr.isatty():
-    log.addHandler(logging.StreamHandler())
+log.addHandler(logging.StreamHandler())
 
 # NetworkManager reportedly doesn't pass the PATH environment variable.
 os.environ['PATH'] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
@@ -24,6 +23,24 @@ os.environ['PATH'] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/b
 class UserError(Exception):
     pass
 
+class Lock:
+    """Lock used to serialize the script"""
+
+    path = "/var/run/dnssec-trigger/lock"
+
+    def __init__(self):
+        # We don't use os.makedirs(..., exist_ok=True) to ensure Python 2 compatibility
+        dirname = os.path.dirname(self.path)
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+        self.lock = open(self.path, "w")
+
+    def __enter__(self):
+        fcntl.lockf(self.lock, fcntl.LOCK_EX)
+
+    def __exit__(self, t, v, tb):
+        fcntl.lockf(self.lock, fcntl.LOCK_UN)
+
 class Config:
     """Global configuration options"""
 
@@ -53,18 +70,17 @@ class ConnectionList:
 
     nm_connections = None
 
-    def __init__(self, only_default=False, skip_wifi=False):
+    def __init__(self, client, only_default=False, skip_wifi=False):
         # Cache the active connection list in the class
+        if not client.get_manager_running():
+            raise UserError("NetworkManager is not running.")
         if self.nm_connections is None:
-            self.__class__.client = NMClient.Client()
-            self.__class__.nm_connections = self.client.get_active_connections()
+            self.__class__.nm_connections = client.get_active_connections()
         self.skip_wifi = skip_wifi
         self.only_default = only_default
         log.debug(self)
 
     def __repr__(self):
-        if not list(self):
-            raise Exception("!!!")
         return "<ConnectionList(only_default={only_default}, skip_wifi={skip_wifi}, connections={})>".format(list(self), **vars(self))
 
     def __iter__(self):
@@ -190,10 +206,10 @@ class UnboundZoneConfig:
                 if fields.pop(0) in ('forward', 'forward:'):
                     fields.pop(0)
                 secure = False
-                if fields[0] == '+i':
+                if fields and fields[0] == '+i':
                     secure = True
                     fields.pop(0)
-                self.cache[name] = set(fields[3:]), secure
+                self.cache[name] = set(fields), secure
         log.debug(self)
 
     def __repr__(self):
@@ -255,7 +271,7 @@ class Store:
                     line = line.strip()
                     if line:
                         self.cache.add(line)
-        except FileNotFoundError:
+        except IOError:
             pass
         log.debug(self)
 
@@ -277,10 +293,16 @@ class Store:
         log.debug(self)
 
     def update(self, zones):
-        """Commit a new zone list."""
+        """Commit a new set of items and return True when it differs"""
 
-        self.cache = set(zones)
-        log.debug(self)
+        zones = set(zones)
+
+        if zones != self.cache:
+            self.cache = set(zones)
+            log.debug(self)
+            return True
+
+        return False
 
     def remove(self, zone):
         """Remove zone from the cache."""
@@ -309,7 +331,7 @@ class GlobalForwarders:
                     line = line.strip()
                     if line:
                         self.cache.add(line)
-        except FileNotFoundError:
+        except IOError:
             pass
 
 class Application:
@@ -328,32 +350,40 @@ class Application:
         except AttributeError:
             self.usage()
         self.config = Config()
+        self.client = NMClient.Client()
+
+        self.resolvconf = "/etc/resolv.conf"
+        self.resolvconf_backup = "/var/run/dnssec-trigger/resolv.conf.bak"
 
     def nm_handles_resolv_conf(self):
-        if subprocess.call(["pidof", "NetworkManager"], stdout=DEVNULL, stderr=DEVNULL) != 0:
+        if not self.client.get_manager_running():
+            log.debug("NetworkManager is not running")
             return False
         try:
             with open("/etc/NetworkManager/NetworkManager.conf") as nm_config_file:
                 for line in nm_config_file:
-                    if line.strip == "dns=none":
+                    if line.strip() in ("dns=none", "dns=unbound"):
+                        log.debug("NetworkManager doesn't handle /etc/resolv.conf")
                         return False
         except IOError:
             pass
+        log.debug("NetworkManager handles /etc/resolv.conf")
         return True
 
     def usage(self):
         raise UserError("Usage: dnssec-trigger-script [--debug] [--async] --prepare|--update|--update-global-forwarders|--update-connection-zones|--cleanup")
 
     def run(self):
-        log.debug("Running: {}".format(self.method.__name__))
-        self.method()
+        with Lock():
+            log.debug("Running: {}".format(self.method.__name__))
+            self.method()
 
     def run_prepare(self):
         """Prepare for dnssec-trigger."""
 
         if not self.nm_handles_resolv_conf():
             log.info("Backing up /etc/resolv.conf")
-            shutil.copy("/etc/resolv.conf", "/var/run/dnssec-trigger/resolv.conf.bak")
+            shutil.copy(self.resolvconf, self.resolvconf_backup)
 
     def run_cleanup(self):
         """Clean up after dnssec-trigger."""
@@ -361,6 +391,18 @@ class Application:
         stored_zones = Store('zones')
         unbound_zones = UnboundZoneConfig()
 
+        # provide upgrade path for previous versions
+        old_zones = glob.glob("/var/run/dnssec-trigger/????????-????-????-????-????????????")
+        if old_zones:
+            log.info("Reading zones from the legacy zone store")
+            with open("/var/run/dnssec-trigger/zones", "a") as target:
+                for filename in old_zones:
+                    with open(filename) as source:
+                        log.debug("Reading zones from {}".format(filename))
+                        for line in source:
+                            stored_zones.add(line.strip())
+                        os.remove(filename)
+
         log.debug("clearing unbound configuration")
         for zone in stored_zones:
             unbound_zones.remove(zone)
@@ -370,11 +412,14 @@ class Application:
         log.debug("recovering /etc/resolv.conf")
         subprocess.check_call(["chattr", "-i", "/etc/resolv.conf"])
         if not self.nm_handles_resolv_conf():
-            shutil.copy("/var/run/dnssec-trigger/resolv.conf.bak", "/etc/resolv.conf")
+            try:
+                shutil.copy(self.resolvconf_backup, self.resolvconf)
+            except IOError as error:
+                log.warning("Cannot restore resolv.conf from {!r}: {}".format(self.resolvconf_backup, error.strerror))
         # NetworkManager currently doesn't support explicit /etc/resolv.conf
         # write out. For now we simply restart the daemon.
         elif os.path.exists("/sys/fs/cgroup/systemd"):
-            subprocess.check_call(["systemctl", "try-restart", "NetworkManager.service"])
+            subprocess.check_call(["systemctl", "--ignore-dependencies", "try-restart", "NetworkManager.service"])
         else:
             subprocess.check_call(["/etc/init.d/NetworkManager", "restart"])
 
@@ -387,7 +432,7 @@ class Application:
 
         subprocess.check_call(["dnssec-trigger-control", "status"], stdout=DEVNULL, stderr=DEVNULL)
 
-        default_connections = ConnectionList(only_default=True)
+        default_connections = ConnectionList(self.client, only_default=True)
         servers = Store('servers')
 
         if servers.update(sum((connection.servers for connection in default_connections), [])):
@@ -399,7 +444,7 @@ class Application:
     def run_update_connection_zones(self):
         """Configures forward zones in the unbound using unbound-control."""
 
-        connections = ConnectionList(skip_wifi=not self.config.add_wifi_provided_zones).get_zone_connection_mapping()
+        connections = ConnectionList(self.client, skip_wifi=not self.config.add_wifi_provided_zones).get_zone_connection_mapping()
         unbound_zones = UnboundZoneConfig()
         stored_zones = Store('zones')