// path: root/tests/common/java/android/net/NetworkProviderTest.kt
// blob: 3ceacf8ad6c10c8f3814ac67ce05405540270a89
/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.net

import android.app.Instrumentation
import android.content.Context
import android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED
import android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED
import android.net.NetworkCapabilities.TRANSPORT_TEST
import android.net.NetworkProviderTest.TestNetworkCallback.CallbackEntry.OnUnavailable
import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequestWithdrawn
import android.net.NetworkProviderTest.TestNetworkProvider.CallbackEntry.OnNetworkRequested
import android.os.Build
import android.os.Handler
import android.os.HandlerThread
import android.os.Looper
import android.util.Log
import androidx.test.InstrumentationRegistry
import com.android.net.module.util.ArrayTrackRecord
import com.android.testutils.CompatUtil
import com.android.testutils.ConnectivityModuleTest
import com.android.testutils.DevSdkIgnoreRule
import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
import com.android.testutils.DevSdkIgnoreRunner
import com.android.testutils.TestableNetworkOfferCallback
import com.android.testutils.isDevSdkInRange
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.doReturn
import org.mockito.Mockito.mock
import org.mockito.Mockito.verifyNoMoreInteractions
import java.util.UUID
import java.util.concurrent.Executor
import java.util.concurrent.RejectedExecutionException
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
import kotlin.test.fail

// How long to wait for a callback that is expected to arrive.
private const val DEFAULT_TIMEOUT_MS = 5000L
// How long to wait before concluding that no callback is coming.
private const val DEFAULT_NO_CALLBACK_TIMEOUT_MS = 200L
// Name handed to the NetworkProvider constructor by TestNetworkProvider. Declared const like
// the other compile-time constants in this file.
private const val PROVIDER_NAME = "NetworkProviderTest"

// Both accessors query InstrumentationRegistry on every read instead of caching the result.
private val instrumentation: Instrumentation
    get() = InstrumentationRegistry.getInstrumentation()
private val context: Context
    get() = InstrumentationRegistry.getContext()

@RunWith(DevSdkIgnoreRunner::class)
@IgnoreUpTo(Build.VERSION_CODES.Q)
@ConnectivityModuleTest
class NetworkProviderTest {
    // JUnit rule; presumably what enforces the @IgnoreUpTo / @IgnoreAfter annotations on the
    // individual tests — confirm against DevSdkIgnoreRule.
    @Rule @JvmField
    val mIgnoreRule = DevSdkIgnoreRule()
    // ConnectivityManager obtained from the instrumentation context.
    private val mCm = context.getSystemService(ConnectivityManager::class.java)
    // Started in setUp() and quit in tearDown(); its looper and handler are used by the
    // providers, agents and offer executors the tests create.
    private val mHandlerThread = HandlerThread("${javaClass.simpleName} handler thread")

    /** Adopts the shell permission identity, then starts the shared handler thread. */
    @Before
    fun setUp() {
        val uiAutomation = instrumentation.uiAutomation
        uiAutomation.adoptShellPermissionIdentity()
        mHandlerThread.start()
    }

    /** Quits the shared handler thread, then drops the shell permission identity. */
    @After
    fun tearDown() {
        mHandlerThread.quitSafely()
        val uiAutomation = instrumentation.uiAutomation
        uiAutomation.dropShellPermissionIdentity()
    }

    /**
     * [NetworkProvider] that records the request callbacks it receives so that tests can
     * assert on them afterwards.
     */
    private class TestNetworkProvider(context: Context, looper: Looper) :
            NetworkProvider(context, looper, PROVIDER_NAME) {
        private val logTag = this::class.simpleName
        // Callbacks received so far; tests consume them through poll().
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()

        /** One recorded provider callback together with the arguments it was invoked with. */
        sealed class CallbackEntry {
            data class OnNetworkRequested(
                val request: NetworkRequest,
                val score: Int,
                val id: Int
            ) : CallbackEntry()
            data class OnNetworkRequestWithdrawn(val request: NetworkRequest) : CallbackEntry()
        }

        /** Logs the callback and records it as an [CallbackEntry.OnNetworkRequested]. */
        override fun onNetworkRequested(request: NetworkRequest, score: Int, id: Int) {
            Log.d(logTag, "onNetworkRequested $request, $score, $id")
            val entry = OnNetworkRequested(request, score, id)
            seenEvents.add(entry)
        }

        /** Logs the callback and records it as an [CallbackEntry.OnNetworkRequestWithdrawn]. */
        override fun onNetworkRequestWithdrawn(request: NetworkRequest) {
            Log.d(logTag, "onNetworkRequestWithdrawn $request")
            val entry = OnNetworkRequestWithdrawn(request)
            seenEvents.add(entry)
        }

        /**
         * Polls for up to [DEFAULT_TIMEOUT_MS] for a recorded callback of type [T] that
         * satisfies [predicate] and returns it; fails the test if the poll yields nothing.
         */
        inline fun <reified T : CallbackEntry> eventuallyExpectCallbackThat(
            crossinline predicate: (T) -> Boolean
        ): CallbackEntry {
            val matched = seenEvents.poll(DEFAULT_TIMEOUT_MS) { entry ->
                entry is T && predicate(entry)
            }
            return matched ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")
        }

        /** Fails the test if polling for [DEFAULT_NO_CALLBACK_TIMEOUT_MS] yields any callback. */
        fun assertNoCallback() {
            seenEvents.poll(DEFAULT_NO_CALLBACK_TIMEOUT_MS)?.let {
                fail("Expected no callback but got $it")
            }
        }
    }

    /** Builds an unregistered [TestNetworkProvider] on the test handler thread's looper. */
    private fun createNetworkProvider(ctx: Context = context): TestNetworkProvider =
        TestNetworkProvider(ctx, mHandlerThread.looper)

    /**
     * Creates a [TestNetworkProvider], registers it with ConnectivityManager and returns it.
     *
     * Asserts that the provider id is [NetworkProvider.ID_NONE] before registration and
     * something else afterwards.
     */
    private fun createAndRegisterNetworkProvider(ctx: Context = context) =
        createNetworkProvider(ctx).also {
            // kotlin.test takes the expected (or illegal) value first; keeping that order makes
            // a failure message report the right side as "expected".
            assertEquals(NetworkProvider.ID_NONE, it.providerId)
            mCm.registerNetworkProvider(it)
            assertNotEquals(NetworkProvider.ID_NONE, it.providerId)
        }

    // Only runs on R. In the S+ framework the provider no longer receives onNetworkRequested
    // for every request. Instead, a provider needs to call {@code registerNetworkOffer} with
    // the description of the networks it might be able to set up, and then expects
    // {@link NetworkOfferCallback#onNetworkNeeded}.
    //
    // Checks that the provider sees the request when it is filed, sees it again with the
    // agent's score and provider id once an agent is connected and whenever that score
    // changes, and sees the withdrawal when the request's callback is unregistered.
    @IgnoreAfter(Build.VERSION_CODES.R)
    @Test
    fun testOnNetworkRequested() {
        val provider = createAndRegisterNetworkProvider()

        val specifier = CompatUtil.makeTestNetworkSpecifier(UUID.randomUUID().toString())
        // Test network is not allowed to be trusted.
        val nr: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier)
                .build()
        val cb = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nr, cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.networkSpecifier == specifier &&
                    callback.request.hasTransport(TRANSPORT_TEST)
        }

        val initialScore = 40
        val updatedScore = 60
        val nc = NetworkCapabilities().apply {
            addTransportType(TRANSPORT_TEST)
            removeCapability(NET_CAPABILITY_TRUSTED)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
            setNetworkSpecifier(specifier)
        }
        val lp = LinkProperties()
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc, lp,
                initialScore, config, provider) {}
        agent.register()
        agent.markConnected()

        // Connecting the agent makes the provider see the request again, now carrying the
        // agent's score and provider id.
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.networkSpecifier == specifier &&
                    callback.score == initialScore &&
                    callback.id == agent.providerId
        }

        agent.sendNetworkScore(updatedScore)
        provider.eventuallyExpectCallbackThat<OnNetworkRequested> { callback ->
            callback.request.networkSpecifier == specifier &&
                    callback.score == updatedScore &&
                    callback.id == agent.providerId
        }

        mCm.unregisterNetworkCallback(cb)
        provider.eventuallyExpectCallbackThat<OnNetworkRequestWithdrawn> { callback ->
            callback.request.networkSpecifier == specifier &&
                    callback.request.hasTransport(TRANSPORT_TEST)
        }
        mCm.unregisterNetworkProvider(provider)
        // Provider id should be ID_NONE after unregister network provider
        assertEquals(NetworkProvider.ID_NONE, provider.providerId)
        // unregisterNetworkProvider should not crash even if it's called on an
        // already unregistered provider.
        mCm.unregisterNetworkProvider(provider)

        // Unregister the agent so the test network does not stay registered after the test,
        // as testRegisterNetworkOffer does for its own agent.
        agent.unregister()
    }

    // Mainline module can't use internal HandlerExecutor, so add an identical executor here.
    // TODO: Refactor with the one in MultiNetworkPolicyTracker.
    private class HandlerExecutor(private val handler: Handler) : Executor {
        /**
         * Posts [command] to the wrapped handler.
         *
         * @throws RejectedExecutionException if the handler does not accept the post.
         */
        override fun execute(command: Runnable) {
            val accepted = handler.post(command)
            if (accepted) return
            throw RejectedExecutionException("$handler is shutting down")
        }
    }

    /**
     * Registers four network offers across two providers and checks which offer callbacks
     * report needed / unneeded as requests are filed and released, offers are re-registered
     * with an extra capability, an agent connects and changes score, and finally offers and
     * providers are unregistered.
     *
     * Not run on R and below (see the annotation).
     */
    @IgnoreUpTo(Build.VERSION_CODES.R)
    @Test
    fun testRegisterNetworkOffer() {
        val provider = createAndRegisterNetworkProvider()
        val provider2 = createAndRegisterNetworkProvider()

        // Prepare the materials which will be used to create different offers.
        val specifier1 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-1")
        val specifier2 = CompatUtil.makeTestNetworkSpecifier("TEST-SPECIFIER-2")
        val scoreWeaker = NetworkScore.Builder().build()
        val scoreStronger = NetworkScore.Builder().setTransportPrimary(true).build()
        val ncFilter1 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter2 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                .setNetworkSpecifier(specifier1).build()
        val ncFilter3 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()
        val ncFilter4 = NetworkCapabilities.Builder().addTransportType(TRANSPORT_TEST)
                .setNetworkSpecifier(specifier2).build()

        // Make 4 offers, where 1 doesn't have NOT_VCN, 2 has NOT_VCN and the stronger score,
        // 3 is like 1 but with a different specifier, and 4 uses the same filter as 3 but is
        // registered by a different provider.
        val offerCallback1 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback2 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback3 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        val offerCallback4 = TestableNetworkOfferCallback(
                DEFAULT_TIMEOUT_MS, DEFAULT_NO_CALLBACK_TIMEOUT_MS)
        provider.registerNetworkOffer(scoreWeaker, ncFilter1,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback1)
        provider.registerNetworkOffer(scoreStronger, ncFilter2,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback2)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback3)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback4)
        // Unlike on Android R, an Android S+ provider only receives the requests it is
        // interested in, and only via the offer callback. Verify that the callbacks do not see
        // any existing request such as the default requests.
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request with a specifier but without NOT_VCN, and verify that the network is
        // needed for the offers with the same specifier.
        val nrNoNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .setNetworkSpecifier(specifier1)
                .build()
        val cb1 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNoNotVcn, cb1)
        offerCallback1.expectOnNetworkNeeded(ncFilter1)
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // Releasing the request makes both of those offers unneeded again.
        mCm.unregisterNetworkCallback(cb1)
        offerCallback1.expectOnNetworkUnneeded(ncFilter1)
        offerCallback2.expectOnNetworkUnneeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // File a request without a specifier but with NOT_VCN, and verify that the network is
        // needed for the offer with NOT_VCN.
        val nrNotVcn: NetworkRequest = NetworkRequest.Builder()
                .addTransportType(TRANSPORT_TEST)
                .addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
                // Test network is not allowed to be trusted.
                .removeCapability(NET_CAPABILITY_TRUSTED)
                .build()
        val cb2 = ConnectivityManager.NetworkCallback()
        mCm.requestNetwork(nrNotVcn, cb2)
        offerCallback1.assertNoCallback()
        offerCallback2.expectOnNetworkNeeded(ncFilter2)
        offerCallback3.assertNoCallback()
        offerCallback4.assertNoCallback()

        // Upgrade offers 3 & 4 so they satisfy the previous request, by re-registering them
        // with NOT_VCN added to their filters, and then verify they are also needed.
        ncFilter3.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider.registerNetworkOffer(scoreWeaker, ncFilter3,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback3)
        ncFilter4.addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
        provider2.registerNetworkOffer(scoreWeaker, ncFilter4,
                HandlerExecutor(mHandlerThread.threadHandler), offerCallback4)
        offerCallback1.assertNoCallback()
        offerCallback2.assertNoCallback()
        offerCallback3.expectOnNetworkNeeded(ncFilter3)
        offerCallback4.expectOnNetworkNeeded(ncFilter4)

        // Connect an agent to fulfill the request, and verify that offer 4 is no longer needed
        // since it neither comes from the currently serving provider nor can beat the current
        // satisfier.
        val nc = NetworkCapabilities().apply {
            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VCN_MANAGED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_SUSPENDED)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_ROAMING)
            addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
            setNetworkSpecifier(specifier1)
        }
        val config = NetworkAgentConfig.Builder().build()
        val agent = object : NetworkAgent(context, mHandlerThread.looper, "TestAgent", nc,
                LinkProperties(), scoreWeaker, config, provider) {}
        agent.register()
        agent.markConnected()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.expectOnNetworkUnneeded(ncFilter4)

        // Upgrade the agent, and verify nothing changes since the framework treats an offer as
        // needed if a request is currently satisfied by a network from the same provider.
        // TODO: Consider offers with weaker score are unneeded.
        agent.sendNetworkScore(scoreStronger)
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        offerCallback4.assertNoCallback()  // Still unneeded.

        // Verify that offer callbacks cannot receive any event if offer is unregistered.
        provider2.unregisterNetworkOffer(offerCallback4)
        agent.unregister()
        offerCallback1.assertNoCallback()  // Still unneeded.
        offerCallback2.assertNoCallback()  // Still needed.
        offerCallback3.assertNoCallback()  // Still needed.
        // Since the agent is unregistered and this offer could satisfy the request, the
        // callback would have received "needed" had the offer not been unregistered.
        offerCallback4.assertNoCallback()

        // Verify that offer callbacks cannot receive any event if provider is unregistered.
        mCm.unregisterNetworkProvider(provider)
        mCm.unregisterNetworkCallback(cb2)
        offerCallback1.assertNoCallback()  // No callback since it is still unneeded.
        offerCallback2.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback3.assertNoCallback()  // Should be unneeded if not unregistered.
        offerCallback4.assertNoCallback()  // Already unregistered.

        // Clean up and Verify providers did not receive any callback during the entire test.
        mCm.unregisterNetworkProvider(provider2)
        provider.assertNoCallback()
        provider2.assertNoCallback()
    }

    /**
     * [ConnectivityManager.NetworkCallback] that records the callbacks it receives so tests
     * can wait for them. Only onUnavailable is recorded.
     */
    private class TestNetworkCallback : ConnectivityManager.NetworkCallback() {
        // Callbacks received so far; consumed through poll() by expectCallback.
        private val seenEvents = ArrayTrackRecord<CallbackEntry>().newReadHead()
        sealed class CallbackEntry {
            object OnUnavailable : CallbackEntry()
        }

        override fun onUnavailable() {
            seenEvents.add(OnUnavailable)
        }

        /**
         * Polls for up to [DEFAULT_TIMEOUT_MS] for a recorded callback of type [T] that
         * satisfies [predicate] and returns it.
         *
         * Fails the test if the poll yields nothing. Returning null here instead would let a
         * caller that ignores the result pass without the callback ever having been delivered.
         */
        inline fun <reified T : CallbackEntry> expectCallback(
            crossinline predicate: (T) -> Boolean
        ) = seenEvents.poll(DEFAULT_TIMEOUT_MS) { it is T && predicate(it) }
                ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")
    }

    /**
     * Files a request for a test network, has a registered provider declare that request
     * unfulfillable, then waits for the request's callback to report onUnavailable.
     */
    @Test
    fun testDeclareNetworkRequestUnfulfillable() {
        val mockContext = mock(Context::class.java)
        doReturn(mCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
        val provider = createNetworkProvider(mockContext)
        // ConnectivityManager not required at creation time after R
        val isUpToR = isDevSdkInRange(0, Build.VERSION_CODES.R)
        if (!isUpToR) verifyNoMoreInteractions(mockContext)

        mCm.registerNetworkProvider(provider)

        val specifier = CompatUtil.makeTestNetworkSpecifier(UUID.randomUUID().toString())
        val request: NetworkRequest = NetworkRequest.Builder().run {
            addTransportType(TRANSPORT_TEST)
            setNetworkSpecifier(specifier)
            build()
        }

        val unavailableCallback = TestNetworkCallback()
        mCm.requestNetwork(request, unavailableCallback)
        provider.declareNetworkRequestUnfulfillable(request)
        // NOTE(review): the predicate inspects the locally built request rather than the
        // received callback entry, so it holds for any OnUnavailable — confirm this is intended.
        unavailableCallback.expectCallback<OnUnavailable> { request.networkSpecifier == specifier }
        mCm.unregisterNetworkProvider(provider)
    }
}