/* * Copyright (C) 2016 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.server.net; import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES; import android.annotation.NonNull; import android.app.usage.NetworkStatsManager; import android.content.Context; import android.content.pm.PackageManager; import android.net.DataUsageRequest; import android.net.NetworkIdentitySet; import android.net.NetworkStack; import android.net.NetworkStats; import android.net.NetworkStatsAccess; import android.net.NetworkStatsCollection; import android.net.NetworkStatsHistory; import android.net.NetworkTemplate; import android.net.netstats.IUsageCallback; import android.os.Handler; import android.os.HandlerThread; import android.os.IBinder; import android.os.Looper; import android.os.Message; import android.os.Process; import android.os.RemoteException; import android.util.ArrayMap; import android.util.IndentingPrintWriter; import android.util.Log; import android.util.SparseArray; import com.android.internal.annotations.VisibleForTesting; import com.android.net.module.util.PerUidCounter; import java.util.concurrent.atomic.AtomicInteger; /** * Manages observers of {@link NetworkStats}. Allows observers to be notified when * data usage has been reported in {@link NetworkStatsService}. An observer can set * a threshold of how much data it cares about to be notified. */ class NetworkStatsObservers { private static final String TAG = "NetworkStatsObservers"; private static final boolean LOG = true; private static final boolean LOGV = false; private static final int MSG_REGISTER = 1; private static final int MSG_UNREGISTER = 2; private static final int MSG_UPDATE_STATS = 3; private static final int DUMP_USAGE_REQUESTS_COUNT = 200; // The maximum number of request allowed per uid before an exception is thrown. @VisibleForTesting static final int MAX_REQUESTS_PER_UID = 100; // All access to this map must be done from the handler thread. // indexed by DataUsageRequest#requestId private final SparseArray mDataUsageRequests = new SparseArray<>(); // Request counters per uid, this is thread safe. private final PerUidCounter mDataUsageRequestsPerUid = new PerUidCounter(MAX_REQUESTS_PER_UID); // Sequence number of DataUsageRequests private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger(); // Lazily instantiated when an observer is registered. private volatile Handler mHandler; /** * Creates a wrapper that contains the caller context and a normalized request. * The request should be returned to the caller app, and the wrapper should be sent to this * object through #addObserver by the service handler. * *

It will register the observer asynchronously, so it is safe to call from any thread. * * @return the normalized request wrapped within {@link RequestInfo}. */ public DataUsageRequest register(@NonNull Context context, @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { DataUsageRequest request = buildRequest(context, inputRequest, callingUid); RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid, callingUid, callingPackage, accessLevel); if (LOG) Log.d(TAG, "Registering observer for " + requestInfo); mDataUsageRequestsPerUid.incrementCountOrThrow(callingUid); getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo)); return request; } /** * Unregister a data usage observer. * *

It will unregister the observer asynchronously, so it is safe to call from any thread. */ public void unregister(DataUsageRequest request, int callingUid) { getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */, request)); } /** * Updates data usage statistics of registered observers and notifies if limits are reached. * *

It will update stats asynchronously, so it is safe to call from any thread. */ public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap activeIfaces, ArrayMap activeUidIfaces, long currentTime) { StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces, activeUidIfaces, currentTime); getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext)); } private Handler getHandler() { if (mHandler == null) { synchronized (this) { if (mHandler == null) { if (LOGV) Log.v(TAG, "Creating handler"); mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback); } } } return mHandler; } @VisibleForTesting protected Looper getHandlerLooperLocked() { HandlerThread handlerThread = new HandlerThread(TAG); handlerThread.start(); return handlerThread.getLooper(); } private Handler.Callback mHandlerCallback = new Handler.Callback() { @Override public boolean handleMessage(Message msg) { switch (msg.what) { case MSG_REGISTER: { handleRegister((RequestInfo) msg.obj); return true; } case MSG_UNREGISTER: { handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */); return true; } case MSG_UPDATE_STATS: { handleUpdateStats((StatsContext) msg.obj); return true; } default: { return false; } } } }; /** * Adds a {@link RequestInfo} as an observer. * Should only be called from the handler thread otherwise there will be a race condition * on mDataUsageRequests. */ private void handleRegister(RequestInfo requestInfo) { mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo); } /** * Removes a {@link DataUsageRequest} if the calling uid is authorized. * Should only be called from the handler thread otherwise there will be a race condition * on mDataUsageRequests. */ private void handleUnregister(DataUsageRequest request, int callingUid) { RequestInfo requestInfo; requestInfo = mDataUsageRequests.get(request.requestId); if (requestInfo == null) { if (LOG) Log.d(TAG, "Trying to unregister unknown request " + request); return; } if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) { Log.w(TAG, "Caller uid " + callingUid + " is not owner of " + request); return; } if (LOG) Log.d(TAG, "Unregistering " + requestInfo); mDataUsageRequests.remove(request.requestId); mDataUsageRequestsPerUid.decrementCountOrThrow(requestInfo.mCallingUid); requestInfo.unlinkDeathRecipient(); requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED); } private void handleUpdateStats(StatsContext statsContext) { if (mDataUsageRequests.size() == 0) { return; } for (int i = 0; i < mDataUsageRequests.size(); i++) { RequestInfo requestInfo = mDataUsageRequests.valueAt(i); requestInfo.updateStats(statsContext); } } private DataUsageRequest buildRequest(Context context, DataUsageRequest request, int callingUid) { // For non-NETWORK_STACK permission uid, cap the minimum threshold to a safe default to // avoid too many callbacks. final long thresholdInBytes = (context.checkPermission( NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, Process.myPid(), callingUid) == PackageManager.PERMISSION_GRANTED ? request.thresholdInBytes : Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes)); if (thresholdInBytes > request.thresholdInBytes) { Log.w(TAG, "Threshold was too low for " + request + ". Overriding to a safer default of " + thresholdInBytes + " bytes"); } return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(), request.template, thresholdInBytes); } private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { if (accessLevel <= NetworkStatsAccess.Level.USER) { return new UserUsageRequestInfo(this, request, callback, callingPid, callingUid, callingPackage, accessLevel); } else { // Safety check in case a new access level is added and we forgot to update this if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) { throw new IllegalArgumentException( "accessLevel " + accessLevel + " is less than DEVICESUMMARY."); } return new NetworkUsageRequestInfo(this, request, callback, callingPid, callingUid, callingPackage, accessLevel); } } /** * Tracks information relevant to a data usage observer. * It will notice when the calling process dies so we can self-expire. */ private abstract static class RequestInfo implements IBinder.DeathRecipient { private final NetworkStatsObservers mStatsObserver; protected final DataUsageRequest mRequest; private final IUsageCallback mCallback; protected final int mCallingPid; protected final int mCallingUid; protected final String mCallingPackage; protected final @NetworkStatsAccess.Level int mAccessLevel; protected NetworkStatsRecorder mRecorder; protected NetworkStatsCollection mCollection; RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { mStatsObserver = statsObserver; mRequest = request; mCallback = callback; mCallingPid = callingPid; mCallingUid = callingUid; mCallingPackage = callingPackage; mAccessLevel = accessLevel; try { mCallback.asBinder().linkToDeath(this, 0); } catch (RemoteException e) { binderDied(); } } @Override public void binderDied() { if (LOGV) { Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mCallback + ")"); } mStatsObserver.unregister(mRequest, Process.SYSTEM_UID); callCallback(NetworkStatsManager.CALLBACK_RELEASED); } @Override public String toString() { return "RequestInfo from pid/uid:" + mCallingPid + "/" + mCallingUid + "(" + mCallingPackage + ")" + " for " + mRequest + " accessLevel:" + mAccessLevel; } private void unlinkDeathRecipient() { mCallback.asBinder().unlinkToDeath(this, 0); } /** * Update stats given the samples and interface to identity mappings. */ private void updateStats(StatsContext statsContext) { if (mRecorder == null) { // First run; establish baseline stats resetRecorder(); recordSample(statsContext); return; } recordSample(statsContext); if (checkStats()) { resetRecorder(); callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED); } } private void callCallback(int callbackType) { try { if (LOGV) { Log.v(TAG, "sending notification " + callbackTypeToName(callbackType) + " for " + mRequest); } switch (callbackType) { case NetworkStatsManager.CALLBACK_LIMIT_REACHED: mCallback.onThresholdReached(mRequest); break; case NetworkStatsManager.CALLBACK_RELEASED: mCallback.onCallbackReleased(mRequest); break; } } catch (RemoteException e) { // May occur naturally in the race of binder death. Log.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest); } } private void resetRecorder() { mRecorder = new NetworkStatsRecorder(); mCollection = mRecorder.getSinceBoot(); } protected abstract boolean checkStats(); protected abstract void recordSample(StatsContext statsContext); private String callbackTypeToName(int callbackType) { switch (callbackType) { case NetworkStatsManager.CALLBACK_LIMIT_REACHED: return "LIMIT_REACHED"; case NetworkStatsManager.CALLBACK_RELEASED: return "RELEASED"; default: return "UNKNOWN"; } } } private static class NetworkUsageRequestInfo extends RequestInfo { NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { super(statsObserver, request, callback, callingPid, callingUid, callingPackage, accessLevel); } @Override protected boolean checkStats() { long bytesSoFar = getTotalBytesForNetwork(mRequest.template); if (LOGV) { Log.v(TAG, bytesSoFar + " bytes so far since notification for " + mRequest.template); } if (bytesSoFar > mRequest.thresholdInBytes) { return true; } return false; } @Override protected void recordSample(StatsContext statsContext) { // Recorder does not need to be locked in this context since only the handler // thread will update it. We pass a null VPN array because usage is aggregated by uid // for this snapshot, so VPN traffic can't be reattributed to responsible apps. mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces, statsContext.mCurrentTime); } /** * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate * over all buckets, which in this case should be only one since we built it big enough * that it will outlive the caller. If it doesn't, then there will be multiple buckets. */ private long getTotalBytesForNetwork(NetworkTemplate template) { NetworkStats stats = mCollection.getSummary(template, Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, mAccessLevel, mCallingUid); return stats.getTotalBytes(); } } private static class UserUsageRequestInfo extends RequestInfo { UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, IUsageCallback callback, int callingPid, int callingUid, @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) { super(statsObserver, request, callback, callingPid, callingUid, callingPackage, accessLevel); } @Override protected boolean checkStats() { int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid); for (int i = 0; i < uidsToMonitor.length; i++) { long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]); if (bytesSoFar > mRequest.thresholdInBytes) { return true; } } return false; } @Override protected void recordSample(StatsContext statsContext) { // Recorder does not need to be locked in this context since only the handler // thread will update it. We pass the VPN info so VPN traffic is reattributed to // responsible apps. mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces, statsContext.mCurrentTime); } /** * Reads all stats matching the given template and uid. Ther history will likely only * contain one bucket per ident since we build it big enough that it will outlive the * caller lifetime. */ private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) { try { NetworkStatsHistory history = mCollection.getHistory(template, null, uid, NetworkStats.SET_ALL, NetworkStats.TAG_NONE, NetworkStatsHistory.FIELD_ALL, Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */, mAccessLevel, mCallingUid); return history.getTotalBytes(); } catch (SecurityException e) { if (LOGV) { Log.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid " + uid); } return 0; } } } private static class StatsContext { NetworkStats mXtSnapshot; NetworkStats mUidSnapshot; ArrayMap mActiveIfaces; ArrayMap mActiveUidIfaces; long mCurrentTime; StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap activeIfaces, ArrayMap activeUidIfaces, long currentTime) { mXtSnapshot = xtSnapshot; mUidSnapshot = uidSnapshot; mActiveIfaces = activeIfaces; mActiveUidIfaces = activeUidIfaces; mCurrentTime = currentTime; } } public void dump(IndentingPrintWriter pw) { for (int i = 0; i < Math.min(mDataUsageRequests.size(), DUMP_USAGE_REQUESTS_COUNT); i++) { pw.println(mDataUsageRequests.valueAt(i)); } } }