WebRTC + TensorFlow Lite + Android


In this part, I am going to combine a WebRTC app with TensorFlow Lite so that it can recognize objects in peer-to-peer (P2P) video communication. Before the implementation, we have to know how to build TensorFlow Lite; we follow the official instructions on the TensorFlow website.
Let's take a look at the result first:

WebRTC

WebRTC (Web Real-Time Communication) is a free, open-source project that provides web browsers and mobile applications with real-time communication (RTC) via simple application programming interfaces (APIs). It allows audio and video communication to work inside web pages through direct peer-to-peer communication.

TensorFlow Lite

TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables on-device machine learning inference with low latency and a small binary size. TensorFlow Lite uses many techniques for achieving low latency such as optimizing the kernels for mobile apps, pre-fused activations, and quantized kernels that allow smaller and faster (fixed-point math) models.

TensorFlow Lite architecture (image credit: TensorFlow)
TensorFlow Lite also supports hardware acceleration with the Android Neural Networks API. The TensorFlow website has a Developer Guide that explains how to convert a pre-trained model to TensorFlow Mobile/Lite. There is also example code for TensorFlow Mobile/Lite on GitHub, along with guides for building the apps.
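To make the idea concrete before we touch WebRTC, here is a minimal sketch of on-device inference using the public org.tensorflow.lite.Interpreter API. The class name TfLiteSketch and the tensor shapes are my own assumptions for the quantized MobileNet used later in this post; the app itself will use the example's TFLiteImageClassifier wrapper instead.
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;

import java.io.FileInputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

import org.tensorflow.lite.Interpreter;

// Minimal sketch: map a quantized MobileNet from the APK assets and run one inference.
// Input/output shapes are assumptions for mobilenet_quant_v1_224; check your own model.
public class TfLiteSketch {

    static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename) throws Exception {
        AssetFileDescriptor fd = assets.openFd(modelFilename);
        FileInputStream inputStream = new FileInputStream(fd.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, fd.getStartOffset(), fd.getDeclaredLength());
    }

    static void runOnce(AssetManager assets) throws Exception {
        Interpreter tflite = new Interpreter(loadModelFile(assets, "mobilenet_quant_v1_224.tflite"));
        byte[][][][] input = new byte[1][224][224][3]; // one 224x224 RGB frame, uint8 quantized
        byte[][] output = new byte[1][1001];           // one score per label
        tflite.run(input, output);                     // fills output with class scores
        tflite.close();
    }
}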

Getting Started:

Build Environment:


  • Android Studio 3.x or higher
  • Java 8

We are going to use the following examples. First, clone them:
git clone https://github.com/IhorKlimov/Android-WebRtc.git
git clone https://github.com/tensorflow/tensorflow.git
We will use the MobileNet+SSD model provided in the example, so we add the TensorFlow Lite library to the Gradle build.
apply plugin: 'com.android.application'
project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'

assert file(project.ext.ASSET_DIR + "/mobilenet_quant_v1_224.tflite").exists()
assert file(project.ext.ASSET_DIR + "/mobilenet_ssd.tflite").exists()
assert file(project.ext.ASSET_DIR + "/labels.txt").exists()
...
android {
   dataBinding { 
      enabled true
   }
   ...
   aaptOptions {
      noCompress "tflite"
      noCompress "lite"
   }
}
...
dependencies {
   …
   implementation 'pub.devrel:easypermissions:1.1.3'
   implementation 'org.tensorflow:tensorflow-lite:+'
   implementation files('libs/autobanh.jar')
   implementation files('libs/base_java.jar')
   implementation files('libs/libjingle_peerconnection.jar')
}
Which noCompress entries you need depends on the extension of the model file you use in the app; I add noCompress "lite" as an extra option. Next, we copy the pre-trained model and labels under tensorflow/tensorflow/contrib/lite/examples/android/assets/ into src/main/assets as below:

We also need to copy the Java files from the TensorFlow Lite example into the src/ folder.

Now we will put WebRTC together with TensorFlow Lite. I add two new files copied directly from the WebRTC library, SurfaceViewRenderer and EglRenderer, and then change the package name in the code and layouts to the local path.
<org.appspot.apprtc.webrtc.SurfaceViewRenderer
                android:id="@+id/remote_video_view"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content" />
...
<org.appspot.apprtc.webrtc.SurfaceViewRenderer
                android:id="@+id/local_video_view"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content" />

In the ClassifierActivity of the TensorFlow Lite example, we can see that the bitmap of the camera preview is passed to the model for classification.
@Override
    protected void processImage() {
        L.i(this);
        rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);

        runInBackground(
                () -> {
                    // For examining the actual TF input.
                    if (SAVE_PREVIEW_BITMAP) {
                        ImageUtils.saveBitmap(croppedBitmap);
                    }
                    final long startTime = SystemClock.uptimeMillis();
                    final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                    lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
                    LOGGER.i("Detect: %s", results);
                    cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
                    if (resultsView == null) {
                        resultsView = findViewById(R.id.results);
                    }
                    resultsView.setResults(results);
                    requestRender();
                    readyForNextImage();
                });
    }
So we have to get the bitmap from the WebRTC renderers in order to pass it to the model. The EglRenderer class provides addFrameListener() to receive bitmaps through a callback, but each registered callback only fires once. We therefore comment out a few lines in the notifyCallbacks() function of EglRenderer so that the callbacks keep delivering bitmaps indefinitely.
private void notifyCallbacks(I420Frame frame, float[] texMatrix) {
        if (!this.frameListeners.isEmpty()) {
            ArrayList tmpList = new ArrayList(this.frameListeners);
//            this.frameListeners.clear();
            float[] bitmapMatrix = RendererCommon.multiplyMatrices(RendererCommon.multiplyMatrices(texMatrix,
                    this.mirror ? RendererCommon.horizontalFlipMatrix() : RendererCommon.identityMatrix()), RendererCommon.verticalFlipMatrix());
            Iterator i$ = tmpList.iterator();

//            while(true) {
            while (i$.hasNext()) {

                EglRenderer.FrameListenerAndParams listenerAndParams = (EglRenderer.FrameListenerAndParams) i$.next();
                int scaledWidth = (int) (listenerAndParams.scale * (float) frame.rotatedWidth());
                int scaledHeight = (int) (listenerAndParams.scale * (float) frame.rotatedHeight());
                if (scaledWidth != 0 && scaledHeight != 0) {
                    ...
                    ...
                    listenerAndParams.listener.onFrame(bitmap);
                } else {
                    listenerAndParams.listener.onFrame(null);
                }
            }
//                return;
//            }
        }
    }
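For reference, the callback we rely on is declared roughly like this inside EglRenderer (paraphrased from the WebRTC sources, so verify against your local copy):
import android.graphics.Bitmap;

// Paraphrased from org.webrtc.EglRenderer; listeners are registered with
// addFrameListener(FrameListener listener, float scale), where scale controls
// the size of the bitmap delivered to onFrame().
public interface FrameListener {
    void onFrame(Bitmap frame); // bitmap rendered from the incoming frame, or null
}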

In CallActivity, we add a few new field variables for processing the bitmaps and holding the parameters of the pre-trained model. We use a worker thread, declared as a HandlerThread, to run the model's recognition. You can see how it is used in the TensorFlow Lite example.
private static final int INPUT_SIZE = 224;
    private static final String MODEL_FILE = "mobilenet_quant_v1_224.tflite";
    private static final String LABEL_FILE = "labels_mobilenet_quant_v1_224.txt";
    private Classifier classifier;

    private Handler handler;
    private HandlerThread handlerThread;
    private static final boolean SAVE_BITMAP = false;
    private Boolean isProcessing = false;
    private Bitmap croppedBitmap = null;
    private long lastProcessingTimeMs;

Initialize TensorFlow Lite in onCreate(). We implement the callback and pass it to addFrameListener() of the SurfaceViewRenderer. You can rescale the delivered bitmap by changing the value of the scale parameter.
    @Override
    public void onCreate(Bundle savedInstanceState) {
        binding.remoteVideoView.addFrameListener(bitmap -> {
            processAndRecognize(bitmap);
        }, 1.f);
        ...
        initTfLite();
    }

    private void initTfLite() {
        if (croppedBitmap == null)
            croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888);

        if (classifier == null)
            classifier = TFLiteImageClassifier.create(getAssets(), MODEL_FILE, LABEL_FILE, INPUT_SIZE);
    }
Initialize the HandlerThread in onResume() and release it in onPause().
    @Override
    public synchronized void onResume() {
        super.onResume();

        handlerThread = new HandlerThread("inference");
        handlerThread.start();
        handler = new Handler(handlerThread.getLooper());
        ...
    }

    @Override
    public synchronized void onPause() {
        super.onPause();

        handlerThread.quitSafely();
        try {
            handlerThread.join();
            handlerThread = null;
            handler = null;
        } catch (final InterruptedException e) {
            L.e(this, e.toString());
        }

        ...
    }

    protected synchronized void runInBackground(final Runnable r) {
        if (handler != null) {
            handler.post(r);
        }
    }
As the next step, we add a new function called processAndRecognize() for recognition. Once we receive the bitmap, we center-crop it, as in the image below, to ensure that the model sees the object.


You can enable SAVE_BITMAP to inspect the resulting bitmaps, but it may degrade performance. You can also scale down the bitmap to make classification much faster (see the sketch after the code below).
private void processAndRecognize(Bitmap srcBitmap) {
        if (isProcessing) {
            return;
        }

        isProcessing = true;

        //crop the center of bitmap
        Bitmap dstBitmap;
        if (srcBitmap.getWidth() >= srcBitmap.getHeight()) {
            dstBitmap = Bitmap.createBitmap(srcBitmap, srcBitmap.getWidth()/2 - srcBitmap.getHeight()/2, 0,
                    srcBitmap.getHeight(), srcBitmap.getHeight()
            );

        } else {
            dstBitmap = Bitmap.createBitmap(srcBitmap, 0, srcBitmap.getHeight()/2 - srcBitmap.getWidth()/2,
                    srcBitmap.getWidth(), srcBitmap.getWidth()
            );
        }

        Matrix frameToCropTransform = ImageUtils.getTransformationMatrix(dstBitmap.getWidth(), dstBitmap.getHeight(),
                INPUT_SIZE, INPUT_SIZE, 0, true);

        Matrix cropToFrameTransform = new Matrix();
        frameToCropTransform.invert(cropToFrameTransform);

        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(dstBitmap, frameToCropTransform, null);

        //enable this for analyzing the bitmaps, but it may degrade the performance
        if (SAVE_BITMAP) {
            ImageUtils.saveBitmap(srcBitmap, "remote_raw.png");
            ImageUtils.saveBitmap(dstBitmap, "remote_rawCrop.png");
            ImageUtils.saveBitmap(croppedBitmap, "remote_crop.png");
        }

        runInBackground(() -> {
            //pass bitmap to TFLite model
            final long startTime = SystemClock.uptimeMillis();
            final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            //update views
            binding.results.setResults(results);
            isProcessing = false;
        });

        // recycle bitmaps
        dstBitmap.recycle();
    }
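As mentioned above, scaling down the bitmap speeds up classification. Here is a minimal sketch, assuming the same addFrameListener() overload used in onCreate(); the 0.5f value is just an illustrative choice, not a tuned one.
// Request a half-size bitmap from the renderer; a smaller input means faster
// cropping and classification.
binding.remoteVideoView.addFrameListener(bitmap -> {
    processAndRecognize(bitmap);
}, 0.5f);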
Remember to release the objects you created.
@Override
    protected synchronized void onDestroy() {
        disconnect();
        ...
        croppedBitmap.recycle();
        croppedBitmap = null;

        super.onDestroy();
    }
Now we are good to go. See the demo below.

Add Zooming:

For more features, I'd like to add pinch-to-zoom and double-tap-to-zoom to the WebRTC + TensorFlow Lite example, so we will need ScaleGestureDetector and GestureDetector. See Detect Common Gestures for more information.
Let's start with the EglRenderer class. Add a new field variable with a setter and getter for updating the scale instantly.
private float surfaceViewScale = 1.0f;

    public void setSurfaceViewScale(final float scale) {
        this.surfaceViewScale = scale;
    }

    public float getSurfaceViewScale() {
        return surfaceViewScale;
    }
Second, we can find this.drawer.drawOes(), which draws the texture and controls the position and size of the SurfaceView. We have to do some computation on the size of the SurfaceView to get the proper x/y location. In the renderFrameOnRenderThread() function:
float newWidth = eglBase.surfaceWidth() * surfaceViewScale;
float newHeight = eglBase.surfaceHeight() * surfaceViewScale;
int newX = -(int) (Math.abs(newWidth - eglBase.surfaceWidth()) / 2.f);
int newY = -(int) (Math.abs(newHeight - eglBase.surfaceHeight()) / 2.f);
this.drawer.drawOes(frame.textureId, drawMatrix, drawnFrameWidth, drawnFrameHeight, newX, newY, (int)newWidth, (int)newHeight);
Third, let's add gesture detectors to the SurfaceViewRenderer class. Add some field variables and initialize them.
private ScaleGestureDetector scaleGestureDetector;
    private GestureDetector gestureDetector;
    private Context context;
    private boolean isSingleTouch;
    private float zoomScale = 1f;
    private float minScale = 1f;
    private float maxScale = 2.5f;
    private float dX, dY;
    private static final boolean MOVE_VIEW = false;
    private ValueAnimator zoomAnimator;
    private boolean isZoomed = false;

    public void init(EglBase.Context sharedContext, RendererEvents rendererEvents, int[] configAttributes, GlDrawer drawer) {
        ...
        scaleGestureDetector = new ScaleGestureDetector(context, new ScaleListener());
        gestureDetector = new GestureDetector(context, new GestureListener());
    }
Fourth, we override the onTouchEvent() method and pass events to the detectors. If you enable MOVE_VIEW, you are able to move the entire SurfaceView.
    @Override
    public boolean onTouchEvent(MotionEvent event) {
        super.onTouchEvent(event);
        scaleGestureDetector.onTouchEvent(event);
        gestureDetector.onTouchEvent(event);
        if (MOVE_VIEW) {
            if (event.getPointerCount() > 1) {
                isSingleTouch = false;
            } else {
                if (event.getAction() == MotionEvent.ACTION_UP) {
                    isSingleTouch = true;
                }
            }
            switch (event.getAction()) {
                case MotionEvent.ACTION_DOWN:
                    dX = this.getX() - event.getRawX();
                    dY = this.getY() - event.getRawY();
                    break;

                case MotionEvent.ACTION_MOVE:
                    if (isSingleTouch) {
                        this.animate().x(event.getRawX() + dX)
                                      .y(event.getRawY() + dY)
                                      .setDuration(0)
                                      .start();
                    }
                    break;
                default:
                    return true;
            }
        }
        return true;
    }
Finally, let's implement the gesture detectors. We pass zoomScale to the setSurfaceViewScale() setter we added in EglRenderer. On a double-tap, we animate the zoom with a ValueAnimator that updates eglRenderer.setSurfaceViewScale(zoomScale) on every frame. Before pinch-to-zooming, we cancel any running animation and clamp the value between minScale and maxScale.
 private class ScaleListener extends ScaleGestureDetector.SimpleOnScaleGestureListener {
        @Override
        public boolean onScale(ScaleGestureDetector detector) {
            if (zoomAnimator != null && zoomAnimator.isRunning())
                zoomAnimator.cancel();
            zoomScale *= detector.getScaleFactor();
            zoomScale = Math.max(minScale, Math.min(zoomScale, maxScale));

            eglRenderer.setSurfaceViewScale(zoomScale);
            isZoomed = zoomScale > minScale;
            return true;
        }
    }

    private class GestureListener extends GestureDetector.SimpleOnGestureListener {
        @Override
        public boolean onSingleTapConfirmed(MotionEvent e){
            return performClick();
        }

        @Override
        public void onLongPress(MotionEvent e) {
            performLongClick();
        }

        @Override
        public boolean onFling(MotionEvent e1, MotionEvent e2, float velocityX, float velocityY) {
            return super.onFling(e1, e2, velocityX, velocityY);
        }

        @Override
        public boolean onDoubleTap(MotionEvent e) {
            if (zoomAnimator != null && zoomAnimator.isRunning())
                zoomAnimator.cancel();

            float start = (isZoomed) ? zoomScale : minScale;
            float end = (isZoomed) ? minScale : Math.max(zoomScale, maxScale);
            zoomAnimator = ValueAnimator.ofFloat(start, end);
            zoomAnimator.addUpdateListener(valueAnimator -> {
                zoomScale = (float) valueAnimator.getAnimatedValue();
                eglRenderer.setSurfaceViewScale(zoomScale);
                isZoomed = zoomScale > minScale;
            });
            zoomAnimator.setDuration(300);
            zoomAnimator.start();

            return super.onDoubleTap(e);
        }

        @Override
        public boolean onDoubleTapEvent(MotionEvent e) {
            return false;
        }
    }

Let's see the effect. You can see that each SurfaceView zooms properly.



