Implementing SSLSocket and SSLSocketFactory #998

@FD-

I'm trying to implement an SSLSocketFactory to use within a j2objc project. I decided to use the Secure Transport API. In theory, it seems like the API provides exactly what I was looking for: It offers a TLS layer that, by means of callbacks, can be used with any lower layer channel. I don't need any of the customisability, I just want to be able to secure a Socket over TLS/SSL to connect to an HTTPS server for TCP-level communication.

I created a WrappedSSLSocket class in Java that extends SSLSocket and forwards all Socket method calls to an underlying Socket; the SSLSocket-specific methods are left as stubs.

WrappedSSLSocket.java

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.channels.SocketChannel;

import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;

public class WrappedSSLSocket extends SSLSocket {
    protected Socket underlyingSocket;

    public WrappedSSLSocket(Socket underlying){
        this.underlyingSocket = underlying;
    }

    public Socket getUnderlyingSocket(){
        return underlyingSocket;
    }

    public void connect(SocketAddress endpoint) throws IOException {
        underlyingSocket.connect(endpoint);
    }

    public void connect(SocketAddress endpoint, int timeout) throws IOException {
        underlyingSocket.connect(endpoint, timeout);
    }

    public void bind(SocketAddress bindpoint) throws IOException {
        underlyingSocket.bind(bindpoint);
    }

    public InetAddress getInetAddress() {
        return underlyingSocket.getInetAddress();
    }

    public InetAddress getLocalAddress() {
        return underlyingSocket.getLocalAddress();
    }

    public int getPort() {
        return underlyingSocket.getPort();
    }

    public int getLocalPort() {
        return underlyingSocket.getLocalPort();
    }

    public SocketAddress getRemoteSocketAddress() {
        return underlyingSocket.getRemoteSocketAddress();
    }

    public SocketAddress getLocalSocketAddress() {
        return underlyingSocket.getLocalSocketAddress();
    }

    public SocketChannel getChannel() {
        return underlyingSocket.getChannel();
    }

    public InputStream getInputStream() throws IOException {
        return underlyingSocket.getInputStream();
    }

    public OutputStream getOutputStream() throws IOException {
        return underlyingSocket.getOutputStream();
    }

    public void setTcpNoDelay(boolean on) throws SocketException {
        underlyingSocket.setTcpNoDelay(on);
    }

    public boolean getTcpNoDelay() throws SocketException {
        return underlyingSocket.getTcpNoDelay();
    }

    public void setSoLinger(boolean on, int linger) throws SocketException {
        underlyingSocket.setSoLinger(on, linger);
    }

    public int getSoLinger() throws SocketException {
        return underlyingSocket.getSoLinger();
    }

    public void sendUrgentData(int data) throws IOException {
        underlyingSocket.sendUrgentData(data);
    }

    public void setOOBInline(boolean on) throws SocketException {
        underlyingSocket.setOOBInline(on);
    }

    public boolean getOOBInline() throws SocketException {
        return underlyingSocket.getOOBInline();
    }

    public synchronized void setSoTimeout(int timeout) throws SocketException {
        underlyingSocket.setSoTimeout(timeout);
    }

    public synchronized int getSoTimeout() throws SocketException {
        return underlyingSocket.getSoTimeout();
    }

    public synchronized void setSendBufferSize(int size) throws SocketException {
        underlyingSocket.setSendBufferSize(size);
    }

    public synchronized int getSendBufferSize() throws SocketException {
        return underlyingSocket.getSendBufferSize();
    }

    public synchronized void setReceiveBufferSize(int size) throws SocketException {
        underlyingSocket.setReceiveBufferSize(size);
    }

    public synchronized int getReceiveBufferSize() throws SocketException {
        return underlyingSocket.getReceiveBufferSize();
    }

    public void setKeepAlive(boolean on) throws SocketException {
        underlyingSocket.setKeepAlive(on);
    }

    public boolean getKeepAlive() throws SocketException {
        return underlyingSocket.getKeepAlive();
    }

    public void setTrafficClass(int tc) throws SocketException {
        underlyingSocket.setTrafficClass(tc);
    }

    public int getTrafficClass() throws SocketException {
        return underlyingSocket.getTrafficClass();
    }

    public void setReuseAddress(boolean on) throws SocketException {
        underlyingSocket.setReuseAddress(on);
    }

    public boolean getReuseAddress() throws SocketException {
        return underlyingSocket.getReuseAddress();
    }

    public synchronized void close() throws IOException {
        underlyingSocket.close();
    }

    public void shutdownInput() throws IOException {
        underlyingSocket.shutdownInput();
    }

    public void shutdownOutput() throws IOException {
        underlyingSocket.shutdownOutput();
    }

    public String toString() {
        return "WrappedSocket (" + underlyingSocket.toString() + ")";
    }

    public boolean isConnected() {
        return underlyingSocket.isConnected();
    }

    public boolean isBound() {
        return underlyingSocket.isBound();
    }

    public boolean isClosed() {
        return underlyingSocket.isClosed();
    }

    public boolean isInputShutdown() {
        return underlyingSocket.isInputShutdown();
    }

    public boolean isOutputShutdown() {
        return underlyingSocket.isOutputShutdown();
    }

    public void setPerformancePreferences(int connectionTime, int latency, int bandwidth) {
        underlyingSocket.setPerformancePreferences(connectionTime, latency, bandwidth);
    }

    /* SSLSocket implementation */

    @Override
    public String[] getSupportedCipherSuites() {
        // Not implemented
        return new String[0];
    }

    @Override
    public String[] getEnabledCipherSuites() {
        // Not implemented
        return new String[0];
    }

    @Override
    public void setEnabledCipherSuites(String[] strings) {
        // Not implemented
    }

    @Override
    public String[] getSupportedProtocols() {
        // Not implemented
        return new String[0];
    }

    @Override
    public String[] getEnabledProtocols() {
        // Not implemented
        return new String[0];
    }

    @Override
    public void setEnabledProtocols(String[] strings) {
        // Not implemented
    }

    @Override
    public SSLSession getSession() {
        // Not implemented
        return null;
    }

    @Override
    public void addHandshakeCompletedListener(HandshakeCompletedListener handshakeCompletedListener) {
        // Not implemented
    }

    @Override
    public void removeHandshakeCompletedListener(HandshakeCompletedListener handshakeCompletedListener) {
        // Not implemented
    }

    @Override
    public void startHandshake() throws IOException {
        // No-op here; the platform-specific subclass overrides this to run the TLS handshake
    }

    @Override
    public void setUseClientMode(boolean b) {
        // Not implemented
    }

    @Override
    public boolean getUseClientMode() {
        return true;
    }

    @Override
    public void setNeedClientAuth(boolean b) {
        // Not implemented
    }

    @Override
    public boolean getNeedClientAuth() {
        // Not implemented
        return false;
    }

    @Override
    public void setWantClientAuth(boolean b) {
        // Not implemented
    }

    @Override
    public boolean getWantClientAuth() {
        // Not implemented
        return false;
    }

    @Override
    public void setEnableSessionCreation(boolean b) {
        // Not implemented
    }

    @Override
    public boolean getEnableSessionCreation() {
        // Not implemented
        return false;
    }
}

I then created a Swift class that extends the j2objc-translated WrappedSSLSocket class. There, I create an SSLContext and set up its IO callbacks to read and write via the underlying Socket. From the startHandshake method, I call SSLHandshake.

iOSSSLSocketFactory.swift

import Foundation
import Security

class iOSSSLInputStream : JavaIoInputStream {
  unowned let sslSocket : iOSSSLSocket
  
  init(sslSocket : iOSSSLSocket) {
    self.sslSocket = sslSocket
    super.init()
  }
  
  override func read(with b: IOSByteArray!, with off: jint, with len: jint) -> jint {
    if (sslSocket.isClosed() || (sslSocket.underlyingSocket?.isClosed())!){
      ObjC.throwException(JavaIoIOException(nsString: "Cannot read from closed socket"))
    }
    
    if (sslSocket.isInputShutdown() || (sslSocket.underlyingSocket?.isInputShutdown())!){
      ObjC.throwException(JavaIoIOException(nsString: "Cannot read from shutdown socket"))
    }
    
    let unsafePointer = UnsafeMutableRawPointer(b.byteRef(at: UInt(off)))
    var actuallyRead : Int = 0
    let status = SSLRead(sslSocket.getSSLContext()!, unsafePointer!, Int(len), &actuallyRead)
    if (status == errSecSuccess){
      return jint(actuallyRead)
    } else {
      ObjC.throwException(JavaIoIOException(nsString: "Error reading from SSL: " + String(status)))
    }
    return jint(actuallyRead)
  }
  
  override func read(with b: IOSByteArray!) -> jint {
    return read(with: b, with: 0, with: b.length())
  }
  
  override func read() -> jint {
    let buffer = IOSByteArray.newArray(withLength: 1)
    let actuallyRead = read(with: buffer)
    if (actuallyRead == 1) {
      let resultByte : jbyte = (buffer?.byte(at: 0))!
      // InputStream.read() must return 0...255, so mask off the sign extension
      return jint(resultByte) & 0xff
    } else {
      return actuallyRead
    }
  }
  
  override func close() {
    sslSocket.close()
  }
  
  override func available() -> jint {
    // Not supported
    return 0;
  }
  
  override func mark(with readlimit: jint) {
    // Not supported
  }
  
  override func markSupported() -> jboolean {
    return false
  }
  
  override func reset() {
    // Not supported
  }
  
  override func skip(withLong n: jlong) -> jlong {
    // Not supported
    return 0
  }
}

class iOSSSLOutputStream : JavaIoOutputStream {
  unowned let sslSocket : iOSSSLSocket
  
  init(sslSocket : iOSSSLSocket) {
    self.sslSocket = sslSocket
    super.init()
  }
  
  override func write(with b: IOSByteArray!, with off: jint, with len: jint) {
    if (sslSocket.isClosed() || (sslSocket.underlyingSocket?.isClosed())!){
      ObjC.throwException(JavaIoIOException(nsString: "Cannot write to closed socket"))
    }
    
    if (sslSocket.isOutputShutdown() || (sslSocket.underlyingSocket?.isOutputShutdown())!){
      ObjC.throwException(JavaIoIOException(nsString: "Cannot write to shutdown socket"))
    }
    
    var unsafePointer = UnsafeMutableRawPointer(b.byteRef(at: UInt(off)))
    var actuallyWrote : Int = 0
    var remaining : Int = Int(len);
    
    while remaining > 0 {
      let status = SSLWrite(sslSocket.getSSLContext()!, unsafePointer, remaining, &actuallyWrote)
      if (status == noErr){
        unsafePointer = unsafePointer?.advanced(by: actuallyWrote)
        remaining -= actuallyWrote
      } else {
        ObjC.throwException(JavaIoIOException(nsString: "Error writing to SSL: " + String(status)))
      }
    }
  }
  
  override func write(with b: IOSByteArray!) {
    write(with: b, with: 0, with: b.length())
  }
  
  override func write(with b: jint) {
    let buffer = IOSByteArray.newArray(withLength: 1)
    // Keep only the low 8 bits; jbyte(b) would trap for values above 127
    buffer?.replaceByte(at: 0, withByte: jbyte(truncatingIfNeeded: b))
    write(with: buffer)
  }
  
  override func close() {
    sslSocket.close()
  }
  
  override func flush() {
    sslSocket.getUnderlyingSocket().getOutputStream().flush()
  }
}

class iOSSSLSocket : WrappedSSLSocket{
  var sslContext : SSLContext?
  var inputStream : iOSSSLInputStream?
  var outputStream : iOSSSLOutputStream?
  var underlyingSocket : JavaNetSocket?
  
  init(underlyingSocket: JavaNetSocket, hostName: String) {
    self.underlyingSocket = underlyingSocket
    self.sslContext = SSLCreateContext(nil, SSLProtocolSide.clientSide, SSLConnectionType.streamType)
  
    super.init(javaNetSocket: underlyingSocket)
    
    self.inputStream = iOSSSLInputStream(sslSocket: self)
    self.outputStream = iOSSSLOutputStream(sslSocket: self)
  
    var status = noErr
    status = SSLSetIOFuncs(sslContext!, sslReadCallback, sslWriteCallback)
    if (status != noErr) {
      ObjC.throwException(JavaIoIOException(nsString: "Error setting IO functions: " + String(status)))
    }
    
    let ref : SSLConnectionRef = UnsafeRawPointer(Unmanaged.passUnretained(self).toOpaque())
    status = SSLSetConnection(sslContext!, ref)
    if (status != noErr) {
      ObjC.throwException(JavaIoIOException(nsString: "Error setting connection data: " + String(status)))
    }
    
    status = SSLSetPeerDomainName(sslContext!, hostName, hostName.lengthOfBytes(using: String.Encoding.utf8))
    if (status != noErr) {
      ObjC.throwException(JavaIoIOException(nsString: "Error setting domain name: " + String(status)))
    }
  }
  
  override func startHandshake() {
    let status = SSLHandshake(sslContext!)
    if (status != noErr) {
      let cfDescription : CFString = SecCopyErrorMessageString(status, nil)!
      let description = cfDescription as String
      ObjC.throwException(JavaIoIOException(nsString: "Handshake error: " + String(status) + " - " + description))
    }
  }
  
  func getSSLContext() -> SSLContext?{
    return sslContext
  }
  
  var sslReadCallback : @convention(c) (SSLConnectionRef, UnsafeMutableRawPointer, UnsafeMutablePointer<Int>) -> OSStatus =
  {(connection: SSLConnectionRef, data: UnsafeMutableRawPointer, size: UnsafeMutablePointer<Int>) -> OSStatus in
    
    let socket = Unmanaged<iOSSSLSocket>.fromOpaque(connection).takeUnretainedValue()
    
    let javaByteBuffer = IOSByteArray.newArray(withLength: UInt(size.pointee))
    var result = noErr
    
    do {
      try ObjC.catchException {
        let actuallyRead = socket.getUnderlyingSocket().getInputStream().read(with: javaByteBuffer)
        if (actuallyRead >= 0) {
          let jbytePointer = data.bindMemory(to: jbyte.self, capacity: size.pointee)
          javaByteBuffer?.getBytes(jbytePointer, length:UInt(actuallyRead))
          size.pointee = Int(actuallyRead)
        } else {
          // End of stream on the underlying socket, not a certificate failure
          result = errSSLClosedGraceful
        }
      }
    } catch {
      result = errSSLClosedAbort
    }
    return result
  }
  
  var sslWriteCallback : @convention(c) (SSLConnectionRef, UnsafeRawPointer, UnsafeMutablePointer<Int>) -> OSStatus =
  {(connection: SSLConnectionRef, data: UnsafeRawPointer, size: UnsafeMutablePointer<Int>) -> OSStatus in

    let socket = Unmanaged<iOSSSLSocket>.fromOpaque(connection).takeUnretainedValue()

    let jbytePointer = data.bindMemory(to: jbyte.self, capacity: size.pointee)
    let javaByteBuffer = IOSByteArray.newArray(withBytes:jbytePointer, count: UInt(size.pointee))
    
    do {
      try ObjC.catchException {
        socket.getUnderlyingSocket().getOutputStream().write(with: javaByteBuffer)
        socket.getUnderlyingSocket().getOutputStream().flush()
      }
    } catch {
      return errSSLClosedAbort
    }
    return noErr
  }
  
  override func close() {
    if (isClosed()) {
      ObjC.throwException(JavaIoIOException(nsString: "Already closed"))
    }
    
    SSLClose(sslContext!)
    
    underlyingSocket?.close()
  }
  
  override func getInputStream() -> JavaIoInputStream! {
    return inputStream
  }
  
  override func getOutputStream() -> JavaIoOutputStream! {
    return outputStream
  }
}

class iOSSSLSocketFactory : JavaxNetSslSSLSocketFactory{
  override func createSocket(with s: JavaNetSocket!, with host: String!, with port: jint, withBoolean autoClose: jboolean) -> JavaNetSocket! {
    return iOSSSLSocket(underlyingSocket: s, hostName: host)
  }
  
  override func createSocket() -> JavaNetSocket! {
    // Not implemented
    return nil
  }
  
  override func createSocket(with host: String!, with port: jint) -> JavaNetSocket! {
    // Not implemented
    return nil
  }
  
  override func getSupportedCipherSuites() -> IOSObjectArray! {
    // Not implemented
    return nil
  }
  
  override func getDefaultCipherSuites() -> IOSObjectArray! {
    // Not implemented
    return nil
  }
}

Now, in theory, I guess this should work. Unfortunately, it only works sporadically: The handshake sometimes succeeds, but more often, SSLHandshake returns an error:

-50 - One or more parameters passed to a function were not valid.

Interestingly, the code seems more likely to succeed if I log from within the callback functions or use breakpoints to stop execution, so my guess is that there is a race condition or timing dependency that I didn't think of. Is there something in the j2objc code that could cause these issues? Is there anything in the handling of the underlying Socket in the callbacks that I've overlooked?
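One thing that might be related, though this is unverified: InputStream.read(byte[]) is allowed to return fewer bytes than the buffer holds, and the read callback above reports such a short read as success. As far as I understand the Secure Transport documentation, a read callback that transfers less than requested is expected to return errSSLWouldBlock. That would fit the timing sensitivity, since pausing gives the socket time to buffer the whole record. The following Java sketch shows a loop that keeps reading until the requested length is filled or the stream ends. The class name ReadFullySketch and the helper readFully are made up for illustration and are not part of j2objc.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

public class ReadFullySketch {
    /**
     * Hypothetical helper: reads until len bytes are stored in buf or the
     * stream ends. Returns the number of bytes actually read, which is
     * smaller than len only if end of stream was reached first.
     */
    public static int readFully(InputStream in, byte[] buf, int len) throws IOException {
        int total = 0;
        while (total < len) {
            int n = in.read(buf, total, len - total);
            if (n < 0) {
                break; // end of stream before len bytes arrived
            }
            total += n;
        }
        return total;
    }

    public static void main(String[] args) throws IOException {
        // A stream that hands out at most 2 bytes per call, to mimic short socket reads
        InputStream trickle = new ByteArrayInputStream(new byte[] {1, 2, 3, 4, 5}) {
            @Override
            public synchronized int read(byte[] b, int off, int len) {
                return super.read(b, off, Math.min(len, 2));
            }
        };
        byte[] buf = new byte[5];
        System.out.println(readFully(trickle, buf, 5)); // prints 5
    }
}
```

A helper along these lines could be called from the read callback instead of a single read(...), so that the number of bytes handed back no longer depends on how much happened to be buffered at that moment.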

For testing the code, I use this simple test:

SSLSocketTest.java

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
import java.net.Socket;

import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

public class SSLSocketTest {
    private SSLSocketFactory mFactory;
    private String mHost;
    private int mPort;

    // For example: www.google.com, 443
    public SSLSocketTest(SSLSocketFactory factory, String host, int port){
        mFactory = factory;
        mHost = host;
        mPort = port;
    }

    public void run(){
        try {
            Socket socket = new Socket();
            socket.connect(new InetSocketAddress(mHost, mPort));
            SSLSocket sslSocket = (SSLSocket) mFactory.createSocket(socket, mHost, mPort, true);
            sslSocket.startHandshake();

            // HTTP/1.1 requires CRLF line endings
            String request = "GET / HTTP/1.1\r\nHost: " + mHost + "\r\n\r\n";
            sslSocket.getOutputStream().write(request.getBytes());
            sslSocket.getOutputStream().flush();

            StringBuilder response = new StringBuilder();
            BufferedReader br = new BufferedReader(new InputStreamReader(sslSocket.getInputStream()));
            response.append(br.readLine());
            br.close();
        } catch (Exception e) {
            // Removed custom logging
        }
    }
}

For completeness:

ObjC.h

#import <Foundation/Foundation.h>

@interface ObjC : NSObject

+ (BOOL)catchException:(void(^)())tryBlock error:(__autoreleasing NSError **)error;
+ (void)throwException:(id)exception;

@end

ObjC.m

#import "ObjC.h"

@implementation ObjC

+ (BOOL)catchException:(void(^)())tryBlock error:(__autoreleasing NSError **)error {
  @try {
    tryBlock();
    return YES;
  }
  @catch (NSException *exception) {
    // Guard against callers that pass NULL for the error out-parameter
    if (error) {
      *error = [[NSError alloc] initWithDomain:exception.name code:0 userInfo:exception.userInfo];
    }
    return NO;
  }
}

+ (void)throwException:(id)exception{
  @throw exception;
}

@end

Also, I'm wondering if an SSLSocketFactory like this could be added to j2objc? It seems like most of Java's SSLSocket functionality can be implemented with the Secure Transport API. Obviously that leaves things like certificate parsing to iOS (without possibility for any Java-side interventions), but I would think that's fine for most applications.
