Skip site navigation (1) | Skip section navigation (2)



index | | raw e-mail

diff --git a/sys/netinet/in_mcast.c b/sys/netinet/in_mcast.c
index 3206828aacff..b5636c29daeb 100644
--- a/sys/netinet/in_mcast.c
+++ b/sys/netinet/in_mcast.c
@@ -2524,6 +2524,7 @@ inp_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 {
 	struct epoch_tracker	 et;
 	struct __msfilterreq	 msfr;
+	struct sockaddr_storage	*kss;
 	sockunion_t		*gsa;
 	struct ifnet		*ifp;
 	struct in_mfilter	*imf;
@@ -2536,9 +2537,6 @@ inp_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (error)
 		return (error);
 
-	if (msfr.msfr_nsrcs > in_mcast_maxsocksrc)
-		return (ENOBUFS);
-
 	if ((msfr.msfr_fmode != MCAST_EXCLUDE &&
 	     msfr.msfr_fmode != MCAST_INCLUDE))
 		return (EINVAL);
@@ -2551,13 +2549,24 @@ inp_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (!IN_MULTICAST(ntohl(gsa->sin.sin_addr.s_addr)))
 		return (EINVAL);
 
+	if (msfr.msfr_nsrcs > in_mcast_maxsocksrc)
+		return (ENOBUFS);
+	kss = mallocarray(msfr.msfr_nsrcs, sizeof(struct sockaddr_storage),
+	    M_TEMP, M_WAITOK);
+	error = copyin(msfr.msfr_srcs, kss,
+	    sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs);
+	if (error)
+		goto out_inp_unlocked;
+
 	gsa->sin.sin_port = 0;	/* ignore port */
 
 	NET_EPOCH_ENTER(et);
 	ifp = ifnet_byindex(msfr.msfr_ifindex);
 	NET_EPOCH_EXIT(et);	/* XXXGL: unsafe ifp */
-	if (ifp == NULL)
-		return (EADDRNOTAVAIL);
+	if (ifp == NULL) {
+		error = EADDRNOTAVAIL;
+		goto out_inp_unlocked;
+	}
 
 	IN_MULTI_LOCK();
 
@@ -2589,25 +2598,9 @@ inp_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (msfr.msfr_nsrcs > 0) {
 		struct in_msource	*lims;
 		struct sockaddr_in	*psin;
-		struct sockaddr_storage	*kss, *pkss;
+		struct sockaddr_storage	*pkss;
 		int			 i;
 
-		INP_WUNLOCK(inp);
-
-		CTR2(KTR_IGMPV3, "%s: loading %lu source list entries",
-		    __func__, (unsigned long)msfr.msfr_nsrcs);
-		kss = malloc(sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs,
-		    M_TEMP, M_WAITOK);
-		error = copyin(msfr.msfr_srcs, kss,
-		    sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs);
-		if (error) {
-			IN_MULTI_UNLOCK();
-			free(kss, M_TEMP);
-			return (error);
-		}
-
-		INP_WLOCK(inp);
-
 		/*
 		 * Mark all source filters as UNDEFINED at t1.
 		 * Restore new group filter mode, as imf_leave()
@@ -2642,7 +2635,6 @@ inp_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 				break;
 			lims->imsl_st[1] = imf->imf_st[1];
 		}
-		free(kss, M_TEMP);
 	}
 
 	if (error)
@@ -2679,6 +2671,8 @@ out_imf_rollback:
 out_inp_locked:
 	INP_WUNLOCK(inp);
 	IN_MULTI_UNLOCK();
+out_inp_unlocked:
+	free(kss, M_TEMP);
 	return (error);
 }
 
diff --git a/sys/netinet6/in6_mcast.c b/sys/netinet6/in6_mcast.c
index a6186568ecb2..4ec9f36cd9ac 100644
--- a/sys/netinet6/in6_mcast.c
+++ b/sys/netinet6/in6_mcast.c
@@ -2489,6 +2489,7 @@ in6p_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 {
 	struct __msfilterreq	 msfr;
 	struct epoch_tracker	 et;
+	struct sockaddr_storage	*kss;
 	sockunion_t		*gsa;
 	struct ifnet		*ifp;
 	struct in6_mfilter	*imf;
@@ -2501,9 +2502,6 @@ in6p_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (error)
 		return (error);
 
-	if (msfr.msfr_nsrcs > in6_mcast_maxsocksrc)
-		return (ENOBUFS);
-
 	if (msfr.msfr_fmode != MCAST_EXCLUDE &&
 	    msfr.msfr_fmode != MCAST_INCLUDE)
 		return (EINVAL);
@@ -2516,19 +2514,31 @@ in6p_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (!IN6_IS_ADDR_MULTICAST(&gsa->sin6.sin6_addr))
 		return (EINVAL);
 
+	if (msfr.msfr_nsrcs > in6_mcast_maxsocksrc)
+		return (ENOBUFS);
+	kss = mallocarray(msfr.msfr_nsrcs, sizeof(struct sockaddr_storage),
+	    M_TEMP, M_WAITOK);
+	error = copyin(msfr.msfr_srcs, kss,
+	    sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs);
+	if (error)
+		goto out_in6p_unlocked;
+
 	gsa->sin6.sin6_port = 0;	/* ignore port */
 
 	NET_EPOCH_ENTER(et);
 	ifp = ifnet_byindex(msfr.msfr_ifindex);
 	NET_EPOCH_EXIT(et);
-	if (ifp == NULL)
-		return (EADDRNOTAVAIL);
+	if (ifp == NULL) {
+		error = EADDRNOTAVAIL;
+		goto out_in6p_unlocked;
+	}
 	(void)in6_setscope(&gsa->sin6.sin6_addr, ifp, NULL);
 
 	/*
 	 * Take the INP write lock.
 	 * Check if this socket is a member of this group.
 	 */
+	IN6_MULTI_LOCK();
 	imo = in6p_findmoptions(inp);
 	imf = im6o_match_group(imo, ifp, &gsa->sa);
 	if (imf == NULL) {
@@ -2553,24 +2563,9 @@ in6p_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 	if (msfr.msfr_nsrcs > 0) {
 		struct in6_msource	*lims;
 		struct sockaddr_in6	*psin;
-		struct sockaddr_storage	*kss, *pkss;
+		struct sockaddr_storage	*pkss;
 		int			 i;
 
-		INP_WUNLOCK(inp);
-
-		CTR2(KTR_MLD, "%s: loading %lu source list entries",
-		    __func__, (unsigned long)msfr.msfr_nsrcs);
-		kss = malloc(sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs,
-		    M_TEMP, M_WAITOK);
-		error = copyin(msfr.msfr_srcs, kss,
-		    sizeof(struct sockaddr_storage) * msfr.msfr_nsrcs);
-		if (error) {
-			free(kss, M_TEMP);
-			return (error);
-		}
-
-		INP_WLOCK(inp);
-
 		/*
 		 * Mark all source filters as UNDEFINED at t1.
 		 * Restore new group filter mode, as im6f_leave()
@@ -2615,7 +2610,6 @@ in6p_set_source_filters(struct inpcb *inp, struct sockopt *sopt)
 				break;
 			lims->im6sl_st[1] = imf->im6f_st[1];
 		}
-		free(kss, M_TEMP);
 	}
 
 	if (error)
@@ -2650,6 +2644,9 @@ out_im6f_rollback:
 
 out_in6p_locked:
 	INP_WUNLOCK(inp);
+	IN6_MULTI_UNLOCK();
+out_in6p_unlocked:
+	free(kss, M_TEMP);
 	return (error);
 }
 


home | help