Completed
Push — master ( 27dfdc...0fe3ca )
by Michael
03:01
created

lib/nose/proxy/mysql.rb (1 issue)

1
# frozen_string_literal: true
2
3
require 'mysql'
4
5
module NoSE
6
  module Proxy
7
    # A proxy which speaks the MySQL protocol and executes queries
8
    class MysqlProxy < ProxyBase
      def initialize(*args)
        super

        # Map each client socket to its authenticated ::Mysql::ServerProtocol
        @state = {}
      end

      # Authenticate the client on first contact, otherwise process a single
      # command from it
      # @return [Boolean] true to keep the socket, false once it was closed
      def handle_connection(socket)
        protocol = @state[socket]
        return authenticate(socket) if protocol.nil?

        begin
          protocol.process_command(&method(:process_query))
        rescue ::Mysql::ClientError::ServerGoneError
          # Ensure the socket is closed and remove the state
          remove_connection socket
          return false
        end

        # Keep this socket around
        true
      end

      # Close the socket and discard its saved protocol state
      def remove_connection(socket)
        socket.close
        @state.delete socket
      end

      private

      # Auth the client and prepare for query processing
      # @return [Boolean] true if the handshake succeeded
      def authenticate(socket)
        protocol = ::Mysql::ServerProtocol.new socket

        # Any failure during the handshake drops the connection
        begin
          protocol.authenticate
        rescue StandardError
          remove_connection socket
          return false
        end

        @state[socket] = protocol
        true
      end

      # Execute the query on the backend and return the result.
      # Parse and planning failures are reported to the client as MySQL
      # errors, in which case nil is returned so no result set is written.
      # @return [Enumerator::Lazy<Hash>, nil]
      def process_query(protocol, query)
        @logger.debug { "Got query #{query}" }
        result = query_result query
        # Block form so the message (and result.size) is only built when
        # debug logging is enabled, matching the call above
        @logger.debug { "Executed query with #{result.size} results" }
        result
      rescue ParseFailed => exc
        protocol.error ::Mysql::ServerError::ER_PARSE_ERROR, exc.message
        nil
      rescue Backend::PlanNotFound => exc
        protocol.error ::Mysql::ServerError::ER_UNKNOWN_STMT_HANDLER,
                       exc.message
        nil
      end

      # Parse the query text and run it on the backend, lazily converting
      # each row into a hash keyed by the selected fields' names
      # @return [Enumerator::Lazy<Hash>]
      def query_result(query)
        statement = Statement.parse query, @result.workload.model
        @backend.query(statement).lazy.map do |row|
          statement.select.map { |field| [field.name, row[field.id]] }.to_h
        end
      end
    end
87
  end
88
89
  # Extend the client library with necessary server code
90
  class ::Mysql
1 ignored issue
show
Your coding style requires nested module/class definitions instead of the compact style.

Instead of combining the definitions into one compact declaration, as in

class Foo::Bar
end

nest each definition inside its parent, as in

class Foo
  class Bar
  end
end
Loading history...
91
    # Simple class which doesn't do connection setup
92
    class ServerProtocol < Protocol
      def initialize(socket)
        # We need a much simpler initialization than the default class
        @sock = socket
      end

      # Perform the handshake: send the greeting, read the client's response
      # and acknowledge it with an OK packet
      def authenticate
        reset
        write InitialPacket.serialize
        AuthenticationPacket.parse read # TODO: Check auth
        write ResultPacket.serialize 0
      end

      # Send an error message with the given number and text
      def error(errno, message)
        write ErrorPacket.serialize errno, message
      end

      # Process a single incoming command
      # @yield [protocol, query] this protocol and the query text; the block
      #   should return an Enumerator of row hashes, or nil to send nothing
      def process_command(&block)
        reset
        pkt = read
        command = pkt.utiny

        case command
        when COM_QUIT
          # Stop processing because the client left
          return
        when COM_QUERY
          process_query pkt.to_s, &block
        when COM_PING
          write ResultPacket.serialize 0
        else
          # Return error for invalid commands (this object is the protocol,
          # so #error is called on self)
          error ::Mysql::ServerError::ER_NOT_SUPPORTED_YET,
                'Command not supported'
        end
      end

      private

      # Handle an individual query
      def process_query(query)
        # Execute the query on the backend
        result = yield self, query
        return if result.nil?

        # With no rows there is nothing to describe the columns from, so
        # answer with a plain OK packet; writing the field/row EOF packets
        # after an OK packet would leave stray packets on the connection
        unless result.any?
          write ResultPacket.serialize 0
          return
        end

        # Return the list of fields in the result
        field_names = result.peek.keys
        write_fields result, field_names
        write_rows result, field_names
      end

      # Write the list of fields for the resulting rows
      def write_fields(result, field_names)
        write ResultPacket.serialize field_names.count

        # Column types are inferred from the first row; fetch it only once
        first_row = result.first
        field_names.each do |field_name|
          type, = Protocol.value2net first_row[field_name]

          write FieldPacket.serialize '', '', '', field_name, '', 1, type,
                                      Field::NOT_NULL_FLAG, 0, ''
        end
        write EOFPacket.serialize
      end

      # Write a packet for each row in the results
      def write_rows(result, field_names)
        result.each do |row|
          values = field_names.map { |field_name| row[field_name] }
          write(values.map do |value|
            Protocol.value2net(value.to_s).last
          end.inject('', &:+))
        end
        write EOFPacket.serialize
      end
    end
169
170
    # Add serialization of the initial packet
171
    class InitialPacket
      # Build the greeting the server sends as soon as a client connects.
      # The fixed 'A' strings stand in for the challenge (scramble) bytes;
      # nothing in this proxy checks credentials, so their value is unused.
      # @return [String]
      def self.serialize
        capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
        connection_id = 0
        filler = 0
        charset = 33 # utf8_general_ci

        greeting = [
          ::Mysql::Protocol::VERSION,
          'nose',
          connection_id,
          'AAAAAAAA',
          filler,
          capabilities,
          charset,
          SERVER_STATUS_AUTOCOMMIT,
          'AAAAAAAAAAAA'
        ]
        # x13 emits 13 zero bytes between the status word and the final
        # NUL-terminated string
        greeting.pack('CZ*Va8CvCvx13Z*')
      end
    end
188
189
    # Add serialization of result packets
190
    class ResultPacket
      # Serialize a result header. A non-zero field count is sent alone (it
      # announces a result set); a zero count produces a full OK packet
      # carrying affected rows, insert id, status, warnings and a message.
      # rubocop:disable Metrics/ParameterLists
      # @return [String]
      def self.serialize(field_count, affected_rows = 0, insert_id = 0,
                         server_status = 0, warning_count = 0, message = '')
        header = Packet.lcb(field_count)
        return header if field_count.nonzero?

        [
          header,
          Packet.lcb(affected_rows),
          Packet.lcb(insert_id),
          [server_status, warning_count].pack('vv'),
          Packet.lcs(message)
        ].inject(:+)
      end
      # rubocop:enable Metrics/ParameterLists
    end
209
210
    # Add serialization of field packets
211
    class FieldPacket
      # Serialize a column definition: length-coded catalog ('def'), db,
      # table, org_table, name and org_name, then a fixed section whose first
      # byte, 0x0c, equals the 12 bytes that follow it (charset 33 =
      # utf8_general_ci, length, type, flags, decimals, two zero bytes), and
      # finally the length-coded default value.
      # rubocop:disable Metrics/ParameterLists
      # @return [String]
      def self.serialize(db, table, org_table, name, org_name, length, type,
                         flags, decimals, default)
        labels = ['def', db, table, org_table, name, org_name]
        encoded_labels = labels.map { |text| Packet.lcs(text) }
        fixed = [0x0c, 33, length, type, flags, decimals, 0].pack('CvVCvCv')
        (encoded_labels + [fixed, Packet.lcs(default)]).inject(:+)
      end
      # rubocop:enable Metrics/ParameterLists
    end
235
236
    # Add parsing of auth packets
237
    class AuthenticationPacket
      # Accept the client's handshake response without inspecting it.
      # Credentials are not verified yet, so nothing is decoded. A full
      # implementation would read, in order: client flags (ulong), max packet
      # size (ulong), charset number (lcb), 23 filler bytes, user name
      # (string), scrambled password (lcs) and database name (string).
      # @param _pkt the raw response packet (unused; presumably a
      #   ::Mysql::Packet as returned by Protocol#read)
      # @return [nil]
      def self.parse(_pkt)
        nil
      end
    end
250
251
    # Simple EOF packet
252
    class EOFPacket
      # Bytes of an EOF packet: the 0xfe marker followed by four zero bytes
      PAYLOAD = "\xfe\x00\x00\x00\x00"

      # Return the fixed EOF packet payload
      # @return [String]
      def self.serialize
        PAYLOAD
      end
    end
259
260
    # Serialize an error message
261
    class ErrorPacket
      # SQLSTATE used when the caller does not supply a more specific one
      # ('HY000' is MySQL's generic "general error" state)
      DEFAULT_SQLSTATE = 'HY000'

      # Generate a packet with a given error number and message: the 0xff
      # marker, the error number (little-endian 16 bit), a '#' marker, a
      # five-character SQLSTATE and the message text.
      # The previous version packed an unset class-level @sqlstate (nil),
      # which makes Array#pack('a5') raise TypeError on every call.
      # @param errno [Integer] MySQL error number
      # @param message [String] human-readable error text
      # @param sqlstate [String] five-character SQLSTATE code
      # @return [String]
      def self.serialize(errno, message, sqlstate: DEFAULT_SQLSTATE)
        [
          0xff,
          errno,
          '#',
          sqlstate,
          message
        ].pack('Cvaa5a*')
      end
    end
274
  end
275
end
276