/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| #ifndef _DNS_DNSTLSSOCKET_H |
| #define _DNS_DNSTLSSOCKET_H |
| |
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>
#include <openssl/ssl.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"
| |
| namespace android { |
| namespace net { |
| |
| class IDnsTlsSocketObserver; |
| class DnsTlsSessionCache; |
| |
| // A class for managing a TLS socket that sends and receives messages in |
| // [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format). |
| // This class is not aware of query-response pairing or anything else about DNS. |
| // For the observer: |
| // This class is not re-entrant: the observer is not permitted to wait for a call to query() |
| // or the destructor in a callback. Doing so will result in deadlocks. |
| // This class may call the observer at any time after initialize(), until the destructor |
| // returns (but not after). |
| // |
| // Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle: |
| // |
| // UNINITIALIZED |
| // | |
| // v |
| // INITIALIZED |
| // | |
| // v |
| // +----CONNECTING------+ |
| // Handshake fails | | Handshake succeeds |
| // (onClose() when | | |
| // mAsyncHandshake is set) | v |
| // | +---> CONNECTED --+ |
| // | | | | |
| // | +-----------+ | Idle timeout |
| // | Send/Recv queries | onClose() |
| // | onResponse() | |
| // | | |
| // | | |
| // +--> WAIT_FOR_DELETE <-----+ |
| // |
| // |
| // TODO: Add onHandshakeFinished() for handshake results. |
| class DnsTlsSocket : public IDnsTlsSocket { |
| public: |
| enum class State { |
| UNINITIALIZED, |
| INITIALIZED, |
| CONNECTING, |
| CONNECTED, |
| WAIT_FOR_DELETE, |
| }; |
| |
| DnsTlsSocket(const DnsTlsServer& server, unsigned mark, |
| IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache) |
| : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {} |
| ~DnsTlsSocket(); |
| |
| // Creates the SSL context for this session. Returns false on failure. |
| // This method should be called after construction and before use of a DnsTlsSocket. |
| // Only call this method once per DnsTlsSocket. |
| bool initialize() EXCLUDES(mLock); |
| |
| // If async handshake is enabled, this function simply signals a handshake request, and the |
| // handshake will be performed in the loop thread; otherwise, if async handshake is disabled, |
| // this function performs the handshake and returns after the handshake finishes. |
| bool startHandshake() EXCLUDES(mLock); |
| |
| // Send a query on the provided SSL socket. |query| contains |
| // the body of a query, not including the ID header. This function will typically return before |
| // the query is actually sent. If this function fails, DnsTlsSocketObserver will be |
| // notified that the socket is closed. |
| // Note that success here indicates successful sending, not receipt of a response. |
| // Thread-safe. |
| bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock); |
| |
| private: |
| // Lock to be held by the SSL event loop thread. This is not normally in contention. |
| std::mutex mLock; |
| |
| // Forwards queries and receives responses. Blocks until the idle timeout. |
| void loop() EXCLUDES(mLock); |
| std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock); |
| |
| // On success, sets mSslFd to a socket connected to mAddr (the |
| // connection will likely be in progress if mProtocol is IPPROTO_TCP). |
| // On error, returns the errno. |
| netdutils::Status tcpConnect() REQUIRES(mLock); |
| |
| bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock); |
| |
| // Connect an SSL session on the provided socket. If connection fails, closing the |
| // socket remains the caller's responsibility. |
| bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock); |
| |
| // Connect an SSL session on the provided socket. This is an interruptible version |
| // which allows to terminate connection handshake any time. |
| bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock); |
| |
| // Disconnect the SSL session and close the socket. |
| void sslDisconnect() REQUIRES(mLock); |
| |
| // Writes a buffer to the socket. |
| bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock); |
| |
| // Reads exactly the specified number of bytes from the socket, or fails. |
| // Returns SSL_ERROR_NONE on success. |
| // If |wait| is true, then this function always blocks. Otherwise, it |
| // will return SSL_ERROR_WANT_READ if there is no data from the server to read. |
| int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock); |
| |
| bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock); |
| |
| // Read one DNS response. It can potentially block until reading the exact bytes of |
| // the response. |
| bool readResponse() REQUIRES(mLock); |
| |
| // It is only used for DNS-OVER-TLS internal test. |
| bool setTestCaCertificate() REQUIRES(mLock); |
| |
| // Similar to query(), this function uses incrementEventFd to send a message to the |
| // loop thread. However, instead of incrementing the counter by one (indicating a |
| // new query), it wraps the counter to negative, which we use to indicate a shutdown |
| // request. |
| void requestLoopShutdown() EXCLUDES(mLock); |
| |
| // This function sends a message to the loop thread by incrementing mEventFd. |
| bool incrementEventFd(int64_t count) EXCLUDES(mLock); |
| |
| // Transition the state from expected state |from| to new state |to|. |
| void transitionState(State from, State to) REQUIRES(mLock); |
| |
| // Queue of pending queries. query() pushes items onto the queue and notifies |
| // the loop thread by incrementing mEventFd. loop() reads items off the queue. |
| LockedQueue<std::vector<uint8_t>> mQueue; |
| |
| // eventfd socket used for notifying the SSL thread when queries are ready to send. |
| // This socket acts similarly to an atomic counter, incremented by query() and cleared |
| // by loop(). We have to use a socket because the SSL thread needs to wait in poll() |
| // for input from either a remote server or a query thread. Since eventfd does not have |
| // EOF, we indicate a close request by setting the counter to a negative number. |
| // This file descriptor is opened by initialize(), and closed implicitly after |
| // destruction. |
| // Note that: data starts being read from the eventfd when the state is CONNECTED. |
| base::unique_fd mEventFd; |
| |
| // An eventfd used to listen to shutdown requests when the state is CONNECTING. |
| // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively |
| // handle shutdown requests. |
| base::unique_fd mShutdownEvent; |
| |
| // SSL Socket fields. |
| bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock); |
| base::unique_fd mSslFd GUARDED_BY(mLock); |
| bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock); |
| static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20); |
| |
| const unsigned mMark; // Socket mark |
| const DnsTlsServer mServer; |
| IDnsTlsSocketObserver* _Nonnull const mObserver; |
| DnsTlsSessionCache* _Nonnull const mCache; |
| State mState GUARDED_BY(mLock) = State::UNINITIALIZED; |
| |
| // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's |
| // thread (the call to startHandshake()). |
| bool mAsyncHandshake GUARDED_BY(mLock) = false; |
| |
| // The time to wait for the attempt on connecting to the server. |
| // Set the default value 127 seconds to be consistent with TCP connect timeout. |
| // (presume net.ipv4.tcp_syn_retries = 6) |
| static constexpr int kDotConnectTimeoutMs = 127 * 1000; |
| int mConnectTimeoutMs; |
| |
| // For testing. |
| friend class DnsTlsSocketTest; |
| }; |
| |
| } // end of namespace net |
| } // end of namespace android |
| |
| #endif // _DNS_DNSTLSSOCKET_H |