diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index ecc82b37c4ccf00fb863ecf6981de19bbccec52a..b3e7a667e03c24ca5c3d1c54c0db5c7b0bdde052 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -137,47 +137,24 @@ extern void mem_cgroup_print_oom_info(struct mem_cgroup *memcg,
 extern void mem_cgroup_replace_page_cache(struct page *oldpage,
 					struct page *newpage);
 
-/**
- * mem_cgroup_toggle_oom - toggle the memcg OOM killer for the current task
- * @new: true to enable, false to disable
- *
- * Toggle whether a failed memcg charge should invoke the OOM killer
- * or just return -ENOMEM.  Returns the previous toggle state.
- *
- * NOTE: Any path that enables the OOM killer before charging must
- *       call mem_cgroup_oom_synchronize() afterward to finalize the
- *       OOM handling and clean up.
- */
-static inline bool mem_cgroup_toggle_oom(bool new)
+static inline void mem_cgroup_oom_enable(void)
 {
-	bool old;
-
-	old = current->memcg_oom.may_oom;
-	current->memcg_oom.may_oom = new;
-
-	return old;
+	WARN_ON(current->memcg_oom.may_oom);
+	current->memcg_oom.may_oom = 1;
 }
 
-static inline void mem_cgroup_enable_oom(void)
+static inline void mem_cgroup_oom_disable(void)
 {
-	bool old = mem_cgroup_toggle_oom(true);
-
-	WARN_ON(old == true);
-}
-
-static inline void mem_cgroup_disable_oom(void)
-{
-	bool old = mem_cgroup_toggle_oom(false);
-
-	WARN_ON(old == false);
+	WARN_ON(!current->memcg_oom.may_oom);
+	current->memcg_oom.may_oom = 0;
 }
 
 static inline bool task_in_memcg_oom(struct task_struct *p)
 {
-	return p->memcg_oom.in_memcg_oom;
+	return p->memcg_oom.memcg;
 }
 
-bool mem_cgroup_oom_synchronize(void);
+bool mem_cgroup_oom_synchronize(bool handle);
 
 #ifdef CONFIG_MEMCG_SWAP
 extern int do_swap_account;
@@ -402,16 +379,11 @@ static inline void mem_cgroup_end_update_page_stat(struct page *page,
 {
 }
 
-static inline bool mem_cgroup_toggle_oom(bool new)
-{
-	return false;
-}
-
-static inline void mem_cgroup_enable_oom(void)
+static inline void mem_cgroup_oom_enable(void)
 {
 }
 
-static inline void mem_cgroup_disable_oom(void)
+static inline void mem_cgroup_oom_disable(void)
 {
 }
 
@@ -420,7 +392,7 @@ static inline bool task_in_memcg_oom(struct task_struct *p)
 	return false;
 }
 
-static inline bool mem_cgroup_oom_synchronize(void)
+static inline bool mem_cgroup_oom_synchronize(bool handle)
 {
 	return false;
 }
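
The new header API replaces the error-prone toggle with a strict enable/disable bracket (the WARN_ONs catch unbalanced use) and moves all actual OOM handling into mem_cgroup_oom_synchronize(). The intended calling convention, condensed from the mm/memory.c and mm/oom_kill.c hunks further down in this patch:

	/* Userspace fault: let failed charges record a memcg OOM context. */
	if (flags & FAULT_FLAG_USER)
		mem_cgroup_oom_enable();

	ret = __handle_mm_fault(mm, vma, address, flags);

	if (flags & FAULT_FLAG_USER) {
		mem_cgroup_oom_disable();
		/* Fault recovered?  Drop the OOM state without killing. */
		if (task_in_memcg_oom(current) && !(ret & VM_FAULT_OOM))
			mem_cgroup_oom_synchronize(false);
	}

A genuine VM_FAULT_OOM instead reaches pagefault_out_of_memory(), which calls mem_cgroup_oom_synchronize(true) to kill or wait.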
diff --git a/include/linux/sched.h b/include/linux/sched.h
index 6682da36b293cfad598a0c5803e9e32038f72aa3..e27baeeda3f470ed99ae899d8de22da062856bf8 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1394,11 +1394,10 @@ struct task_struct {
 	} memcg_batch;
 	unsigned int memcg_kmem_skip_account;
 	struct memcg_oom_info {
+		struct mem_cgroup *memcg;
+		gfp_t gfp_mask;
+		int order;
 		unsigned int may_oom:1;
-		unsigned int in_memcg_oom:1;
-		unsigned int oom_locked:1;
-		int wakeups;
-		struct mem_cgroup *wait_on_memcg;
 	} memcg_oom;
 #endif
 #ifdef CONFIG_UPROBES
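
The per-task OOM state shrinks to exactly what is needed to replay the failed charge later: the memcg pointer (whose non-NULL-ness also replaces the old in_memcg_oom flag, see task_in_memcg_oom() above), plus the gfp mask and order for the eventual kill. The fields are filled at charge time and consumed at fault end, both in the mm/memcontrol.c hunks below:

	/* charge path, mem_cgroup_oom(): remember the context, pin the css */
	css_get(&memcg->css);
	current->memcg_oom.memcg = memcg;
	current->memcg_oom.gfp_mask = mask;
	current->memcg_oom.order = order;

	/* fault end, mem_cgroup_oom_synchronize(): replay it for the killer */
	mem_cgroup_out_of_memory(memcg, current->memcg_oom.gfp_mask,
				 current->memcg_oom.order);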
diff --git a/mm/filemap.c b/mm/filemap.c
index 1e6aec4a2d2ebae29c0fea0405a7e81f422beeb4..ae4846ff48494e9ac5e733a5d5339b61763a49f5 100644
--- a/mm/filemap.c
+++ b/mm/filemap.c
@@ -1616,7 +1616,6 @@ int filemap_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 	struct inode *inode = mapping->host;
 	pgoff_t offset = vmf->pgoff;
 	struct page *page;
-	bool memcg_oom;
 	pgoff_t size;
 	int ret = 0;
 
@@ -1625,11 +1624,7 @@ int filemap_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 		return VM_FAULT_SIGBUS;
 
 	/*
-	 * Do we have something in the page cache already?  Either
-	 * way, try readahead, but disable the memcg OOM killer for it
-	 * as readahead is optional and no errors are propagated up
-	 * the fault stack.  The OOM killer is enabled while trying to
-	 * instantiate the faulting page individually below.
+	 * Do we have something in the page cache already?
 	 */
 	page = find_get_page(mapping, offset);
 	if (likely(page) && !(vmf->flags & FAULT_FLAG_TRIED)) {
@@ -1637,14 +1632,10 @@ int filemap_fault(struct vm_area_struct *vma, struct vm_fault *vmf)
 		 * We found the page, so try async readahead before
 		 * waiting for the lock.
 		 */
-		memcg_oom = mem_cgroup_toggle_oom(false);
 		do_async_mmap_readahead(vma, ra, file, page, offset);
-		mem_cgroup_toggle_oom(memcg_oom);
 	} else if (!page) {
 		/* No page in the page cache at all */
-		memcg_oom = mem_cgroup_toggle_oom(false);
 		do_sync_mmap_readahead(vma, ra, file, offset);
-		mem_cgroup_toggle_oom(memcg_oom);
 		count_vm_event(PGMAJFAULT);
 		mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
 		ret = VM_FAULT_MAJOR;
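
The explicit toggling around readahead can go because the charge path no longer invokes the OOM killer at all; it only records the context. A condensed sketch of what now happens when optional readahead hits a memcg at its limit (control flow only, not literal code):

	/* in the failed charge: */
	mem_cgroup_oom(memcg, gfp_mask, order);	/* records state, kills nothing */
	/* charge returns -ENOMEM; readahead drops the page and carries on */

	/* at fault end, if the fault still succeeded overall: */
	mem_cgroup_oom_synchronize(false);	/* discards the recorded state */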
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 5335b2b6be779a2c3ee014c68465b9ad653cc3de..65fc6a4498411332478c92d7ccf473a1cd341e53 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2161,110 +2161,59 @@ static void memcg_oom_recover(struct mem_cgroup *memcg)
 		memcg_wakeup_oom(memcg);
 }
 
-/*
- * try to call OOM killer
- */
 static void mem_cgroup_oom(struct mem_cgroup *memcg, gfp_t mask, int order)
 {
-	bool locked;
-	int wakeups;
-
 	if (!current->memcg_oom.may_oom)
 		return;
-
-	current->memcg_oom.in_memcg_oom = 1;
-
 	/*
-	 * As with any blocking lock, a contender needs to start
-	 * listening for wakeups before attempting the trylock,
-	 * otherwise it can miss the wakeup from the unlock and sleep
-	 * indefinitely.  This is just open-coded because our locking
-	 * is so particular to memcg hierarchies.
+	 * We are in the middle of the charge context here, so we
+	 * don't want to block when potentially sitting on a callstack
+	 * that holds all kinds of filesystem and mm locks.
+	 *
+	 * Also, the caller may handle a failed allocation gracefully
+	 * (like optional page cache readahead) and so an OOM killer
+	 * invocation might not even be necessary.
+	 *
+	 * That's why we don't do anything here except remember the
+	 * OOM context and then deal with it at the end of the page
+	 * fault when the stack is unwound, the locks are released,
+	 * and when we know whether the fault was overall successful.
 	 */
-	wakeups = atomic_read(&memcg->oom_wakeups);
-	mem_cgroup_mark_under_oom(memcg);
-
-	locked = mem_cgroup_oom_trylock(memcg);
-
-	if (locked)
-		mem_cgroup_oom_notify(memcg);
-
-	if (locked && !memcg->oom_kill_disable) {
-		mem_cgroup_unmark_under_oom(memcg);
-		mem_cgroup_out_of_memory(memcg, mask, order);
-		mem_cgroup_oom_unlock(memcg);
-		/*
-		 * There is no guarantee that an OOM-lock contender
-		 * sees the wakeups triggered by the OOM kill
-		 * uncharges.  Wake any sleepers explicitely.
-		 */
-		memcg_oom_recover(memcg);
-	} else {
-		/*
-		 * A system call can just return -ENOMEM, but if this
-		 * is a page fault and somebody else is handling the
-		 * OOM already, we need to sleep on the OOM waitqueue
-		 * for this memcg until the situation is resolved.
-		 * Which can take some time because it might be
-		 * handled by a userspace task.
-		 *
-		 * However, this is the charge context, which means
-		 * that we may sit on a large call stack and hold
-		 * various filesystem locks, the mmap_sem etc. and we
-		 * don't want the OOM handler to deadlock on them
-		 * while we sit here and wait.  Store the current OOM
-		 * context in the task_struct, then return -ENOMEM.
-		 * At the end of the page fault handler, with the
-		 * stack unwound, pagefault_out_of_memory() will check
-		 * back with us by calling
-		 * mem_cgroup_oom_synchronize(), possibly putting the
-		 * task to sleep.
-		 */
-		current->memcg_oom.oom_locked = locked;
-		current->memcg_oom.wakeups = wakeups;
-		css_get(&memcg->css);
-		current->memcg_oom.wait_on_memcg = memcg;
-	}
+	css_get(&memcg->css);
+	current->memcg_oom.memcg = memcg;
+	current->memcg_oom.gfp_mask = mask;
+	current->memcg_oom.order = order;
 }
 
 /**
  * mem_cgroup_oom_synchronize - complete memcg OOM handling
+ * @handle: actually kill/wait or just clean up the OOM state
  *
- * This has to be called at the end of a page fault if the the memcg
- * OOM handler was enabled and the fault is returning %VM_FAULT_OOM.
+ * This has to be called at the end of a page fault if the memcg OOM
+ * handler was enabled.
  *
- * Memcg supports userspace OOM handling, so failed allocations must
+ * Memcg supports userspace OOM handling where failed allocations must
  * sleep on a waitqueue until the userspace task resolves the
  * situation.  Sleeping directly in the charge context with all kinds
  * of locks held is not a good idea, instead we remember an OOM state
  * in the task and mem_cgroup_oom_synchronize() has to be called at
- * the end of the page fault to put the task to sleep and clean up the
- * OOM state.
+ * the end of the page fault to complete the OOM handling.
  *
  * Returns %true if an ongoing memcg OOM situation was detected and
- * finalized, %false otherwise.
+ * completed, %false otherwise.
  */
-bool mem_cgroup_oom_synchronize(void)
+bool mem_cgroup_oom_synchronize(bool handle)
 {
+	struct mem_cgroup *memcg = current->memcg_oom.memcg;
 	struct oom_wait_info owait;
-	struct mem_cgroup *memcg;
+	bool locked;
 
 	/* OOM is global, do not handle */
-	if (!current->memcg_oom.in_memcg_oom)
-		return false;
-
-	/*
-	 * We invoked the OOM killer but there is a chance that a kill
-	 * did not free up any charges.  Everybody else might already
-	 * be sleeping, so restart the fault and keep the rampage
-	 * going until some charges are released.
-	 */
-	memcg = current->memcg_oom.wait_on_memcg;
 	if (!memcg)
-		goto out;
+		return false;
 
-	if (test_thread_flag(TIF_MEMDIE) || fatal_signal_pending(current))
-		goto out_memcg;
+	if (!handle)
+		goto cleanup;
 
 	owait.memcg = memcg;
 	owait.wait.flags = 0;
@@ -2273,13 +2222,25 @@ bool mem_cgroup_oom_synchronize(void)
 	INIT_LIST_HEAD(&owait.wait.task_list);
 
 	prepare_to_wait(&memcg_oom_waitq, &owait.wait, TASK_KILLABLE);
-	/* Only sleep if we didn't miss any wakeups since OOM */
-	if (atomic_read(&memcg->oom_wakeups) == current->memcg_oom.wakeups)
+	mem_cgroup_mark_under_oom(memcg);
+
+	locked = mem_cgroup_oom_trylock(memcg);
+
+	if (locked)
+		mem_cgroup_oom_notify(memcg);
+
+	if (locked && !memcg->oom_kill_disable) {
+		mem_cgroup_unmark_under_oom(memcg);
+		finish_wait(&memcg_oom_waitq, &owait.wait);
+		mem_cgroup_out_of_memory(memcg, current->memcg_oom.gfp_mask,
+					 current->memcg_oom.order);
+	} else {
 		schedule();
-	finish_wait(&memcg_oom_waitq, &owait.wait);
-out_memcg:
-	mem_cgroup_unmark_under_oom(memcg);
-	if (current->memcg_oom.oom_locked) {
+		mem_cgroup_unmark_under_oom(memcg);
+		finish_wait(&memcg_oom_waitq, &owait.wait);
+	}
+
+	if (locked) {
 		mem_cgroup_oom_unlock(memcg);
 		/*
 		 * There is no guarantee that an OOM-lock contender
@@ -2288,10 +2249,9 @@ bool mem_cgroup_oom_synchronize(void)
 		 */
 		memcg_oom_recover(memcg);
 	}
+cleanup:
+	current->memcg_oom.memcg = NULL;
 	css_put(&memcg->css);
-	current->memcg_oom.wait_on_memcg = NULL;
-out:
-	current->memcg_oom.in_memcg_oom = 0;
 	return true;
 }
 
@@ -2705,6 +2665,9 @@ static int __mem_cgroup_try_charge(struct mm_struct *mm,
 		     || fatal_signal_pending(current)))
 		goto bypass;
 
+	if (unlikely(task_in_memcg_oom(current)))
+		goto bypass;
+
 	/*
 	 * We always charge the cgroup the mm_struct belongs to.
 	 * The mm_struct's mem_cgroup changes on task migration if the
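
The trylock/notify/kill dance that used to run deep in the charge stack now runs entirely at the end of the fault, where sleeping is safe. Condensed, the decision in mem_cgroup_oom_synchronize() is:

	if (!memcg)
		return false;	/* no memcg OOM recorded: global OOM, not ours */
	if (!handle)
		goto cleanup;	/* fault recovered: drop state, kill nothing */

	if (mem_cgroup_oom_trylock(memcg) && !memcg->oom_kill_disable)
		mem_cgroup_out_of_memory(memcg, current->memcg_oom.gfp_mask,
					 current->memcg_oom.order);
	else
		schedule();	/* sleep on memcg_oom_waitq until the lock
				   holder or a userspace handler resolves it */

Note that prepare_to_wait() still runs before the trylock, preserving the listen-before-trylock ordering the deleted charge-path comment warned about: a contender must start listening for wakeups before attempting the lock or it can miss the wakeup from the unlock.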
diff --git a/mm/memory.c b/mm/memory.c
index f7b7692c05edee78faf951778bb3909143ac9ac1..1311f26497e6a0f776682ed8a7e23b8620a5b9ad 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3865,15 +3865,21 @@ int handle_mm_fault(struct mm_struct *mm, struct vm_area_struct *vma,
 	 * space.  Kernel faults are handled more gracefully.
 	 */
 	if (flags & FAULT_FLAG_USER)
-		mem_cgroup_enable_oom();
+		mem_cgroup_oom_enable();
 
 	ret = __handle_mm_fault(mm, vma, address, flags);
 
-	if (flags & FAULT_FLAG_USER)
-		mem_cgroup_disable_oom();
-
-	if (WARN_ON(task_in_memcg_oom(current) && !(ret & VM_FAULT_OOM)))
-		mem_cgroup_oom_synchronize();
+	if (flags & FAULT_FLAG_USER) {
+		mem_cgroup_oom_disable();
+		/*
+		 * The task may have entered a memcg OOM situation but
+		 * if the allocation error was handled gracefully (no
+		 * VM_FAULT_OOM), there is no need to kill anything.
+		 * Just clean up the OOM state peacefully.
+		 */
+		if (task_in_memcg_oom(current) && !(ret & VM_FAULT_OOM))
+			mem_cgroup_oom_synchronize(false);
+	}
 
 	return ret;
 }
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 314e9d2743813ea5e52c08929b8a4cbb4621dfa9..6738c47f1f7280edc5f3fe610b2658195a0a77e0 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -680,7 +680,7 @@ void pagefault_out_of_memory(void)
 {
 	struct zonelist *zonelist;
 
-	if (mem_cgroup_oom_synchronize())
+	if (mem_cgroup_oom_synchronize(true))
 		return;
 
 	zonelist = node_zonelist(first_online_node, GFP_KERNEL);
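
Taken together, this turns memcg OOM handling into a two-phase protocol; the numbered lifecycle below is a summary of the hunks above, not part of the patch:

	/*
	 * 1. fault entry:  mem_cgroup_oom_enable()            may_oom = 1
	 * 2. charge fails: mem_cgroup_oom() records memcg, gfp_mask, order;
	 *                  the charge returns -ENOMEM and the stack unwinds
	 * 3. fault exit:   mem_cgroup_oom_disable()           may_oom = 0
	 *    3a. fault recovered -> mem_cgroup_oom_synchronize(false): drop state
	 *    3b. VM_FAULT_OOM    -> pagefault_out_of_memory()
	 *                          -> mem_cgroup_oom_synchronize(true): kill or wait
	 */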