Blob Blame History Raw
From 4b7951be5b9b05bdc0a577034edf15fe4ebd2940 Mon Sep 17 00:00:00 2001
From: Guido van Rossum <guido@dropbox.com>
Date: Thu, 14 Jan 2016 18:44:05 -0800
Subject: [PATCH 4/4] Find partial types anywhere in the stack. Fixes #1126.

---
 mypy/checker.py                     | 32 +++++++++++++++++++-------------
 mypy/checkexpr.py                   | 14 +++++---------
 mypy/test/data/check-inference.test | 17 ++++++++++++++++-
 3 files changed, 40 insertions(+), 23 deletions(-)

diff --git a/mypy/checker.py b/mypy/checker.py
index 2086275..ae879a2 100644
--- a/mypy/checker.py
+++ b/mypy/checker.py
@@ -464,9 +464,11 @@ class TypeChecker(NodeVisitor[Type]):
                 if isinstance(orig_type, PartialType):
                     if orig_type.type is None:
                         # Ah this is a partial type. Give it the type of the function.
-                        defn.original_def.type = new_type
-                        partial_types = self.partial_types[-1]
-                        del partial_types[defn.original_def]
+                        var = defn.original_def
+                        partial_types = self.find_partial_types(var)
+                        if partial_types is not None:
+                            var.type = new_type
+                            del partial_types[var]
                     else:
                         # Trying to redefine something like partial empty list as function.
                         self.fail(messages.INCOMPATIBLE_REDEFINITION, defn)
@@ -1016,9 +1018,11 @@ class TypeChecker(NodeVisitor[Type]):
                         # None initializers preserve the partial None type.
                         return
                     if is_valid_inferred_type(rvalue_type):
-                        lvalue_type.var.type = rvalue_type
-                        partial_types = self.partial_types[-1]
-                        del partial_types[lvalue_type.var]
+                        var = lvalue_type.var
+                        partial_types = self.find_partial_types(var)
+                        if partial_types is not None:
+                            var.type = rvalue_type
+                            del partial_types[var]
                     # Try to infer a partial type. No need to check the return value, as
                     # an error will be reported elsewhere.
                     self.infer_partial_type(lvalue_type.var, lvalue, rvalue_type)
@@ -1376,14 +1380,10 @@ class TypeChecker(NodeVisitor[Type]):
     def try_infer_partial_type_from_indexed_assignment(
             self, lvalue: IndexExpr, rvalue: Node) -> None:
         # TODO: Should we share some of this with try_infer_partial_type?
-        partial_types = self.partial_types[-1]
-        if not partial_types:
-            # Fast path leave -- no partial types in the current scope.
-            return
         if isinstance(lvalue.base, RefExpr):
-            var = lvalue.base.node
-            if var in partial_types:
-                var = cast(Var, var)
+            var = cast(Var, lvalue.base.node)
+            partial_types = self.find_partial_types(var)
+            if partial_types is not None:
                 typename = cast(Instance, var.type).type.fullname()
                 if typename == 'builtins.dict':
                     # TODO: Don't infer things twice.
@@ -2147,6 +2147,12 @@ class TypeChecker(NodeVisitor[Type]):
             self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
             var.type = AnyType()
 
+    def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]:
+        for partial_types in reversed(self.partial_types):
+            if var in partial_types:
+                return partial_types
+        return None
+
     def is_within_function(self) -> bool:
         """Are we currently type checking within a function?
 
diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py
index 4b0cfac..e70faf3 100644
--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -80,8 +80,8 @@ class ExpressionChecker:
                     if not lvalue:
                         result = NoneTyp()
                 else:
-                    partial_types = self.chk.partial_types[-1]
-                    if node in partial_types:
+                    partial_types = self.chk.find_partial_types(node)
+                    if partial_types is not None:
                         context = partial_types[node]
                         self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
                     result = AnyType()
@@ -141,14 +141,10 @@ class ExpressionChecker:
                       }
 
     def try_infer_partial_type(self, e: CallExpr) -> None:
-        partial_types = self.chk.partial_types[-1]
-        if not partial_types:
-            # Fast path leave -- no partial types in the current scope.
-            return
         if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
-            var = e.callee.expr.node
-            if var in partial_types:
-                var = cast(Var, var)
+            var = cast(Var, e.callee.expr.node)
+            partial_types = self.chk.find_partial_types(var)
+            if partial_types is not None:
                 partial_type_type = cast(PartialType, var.type).type
                 if partial_type_type is None:
                     # A partial None type -> can't infer anything.
diff --git a/mypy/test/data/check-inference.test b/mypy/test/data/check-inference.test
index 525fba5..7b6438c 100644
--- a/mypy/test/data/check-inference.test
+++ b/mypy/test/data/check-inference.test
@@ -1183,7 +1183,7 @@ b[{}] = 1
 [out]
 
 [case testInferDictInitializedToEmptyAndUpdatedFromMethod]
-map = {}  # E: Need type annotation for variable
+map = {}
 def add():
     map[1] = 2
 [builtins fixtures/dict.py]
@@ -1315,3 +1315,18 @@ class A:
         self.x()
 [out]
 main: note: In member "f" of class "A":
+
+[case testGlobalInitializedToNoneSetFromFunction]
+a = None
+def f():
+    global a
+    a = 42
+[out]
+
+[case testGlobalInitializedToNoneSetFromMethod]
+a = None
+class C:
+    def m(self):
+        global a
+        a = 42
+[out]
-- 
2.7.0