1 | # frozen_string_literal: true |
||
2 | |||
3 | require 'mysql' |
||
4 | |||
5 | module NoSE |
||
module Proxy
  # A proxy which speaks the MySQL protocol and executes queries
  class MysqlProxy < ProxyBase
    def initialize(*args)
      super

      # Protocol state for each socket; a socket gets an entry only once it
      # has successfully authenticated
      @state = {}
    end

    # Authenticate a new client, or process one command from a known one
    # @param socket the client's socket
    # @return [Boolean] whether the socket should be kept open
    def handle_connection(socket)
      # Retrieve the saved state of the socket
      protocol = @state[socket]
      return authenticate(socket) if protocol.nil?

      begin
        protocol.process_command(&method(:process_query))
      rescue ::Mysql::ClientError::ServerGoneError
        # The peer went away: close the socket and remove the state
        remove_connection socket
        return false
      end

      # Keep this socket around
      true
    end

    # Close the socket and forget any protocol state kept for it
    def remove_connection(socket)
      socket.close
      @state.delete socket
    end

    private

    # Auth the client and prepare for query processing
    # @return [Boolean] whether the socket should be kept open
    def authenticate(socket)
      protocol = ::Mysql::ServerProtocol.new socket

      begin
        protocol.authenticate
      rescue StandardError => exc
        # A failed handshake only drops this client, but record why instead
        # of discarding the error silently
        @logger.warn { "Authentication failed: #{exc.message}" }
        remove_connection socket
        return false
      end

      @state[socket] = protocol
      true
    end

    # Execute the query on the backend and return the result
    # @param protocol [Mysql::ServerProtocol] used to send error replies
    # @param query [String] SQL text received from the client
    # @return [Enumerator::Lazy<Hash>, nil] the rows, or nil once an error
    #   reply has already been sent to the client
    def process_query(protocol, query)
      @logger.debug { "Got query #{query}" }
      result = query_result query
      # Block form (like the line above) so the message, including
      # result.size, is only built when debug logging is enabled
      @logger.debug { "Executed query with #{result.size} results" }
      result
    rescue ParseFailed => exc
      protocol.error ::Mysql::ServerError::ER_PARSE_ERROR, exc.message
      nil
    rescue Backend::PlanNotFound => exc
      protocol.error ::Mysql::ServerError::ER_UNKNOWN_STMT_HANDLER,
                     exc.message
      nil
    end

    # Parse the query and run it on the backend, lazily mapping each row to
    # a hash from selected field name to value
    def query_result(query)
      statement = Statement.parse query, @result.workload.model
      @backend.query(statement).lazy.map do |row|
        statement.select.map { |field| [field.name, row[field.id]] }.to_h
      end
    end
  end
end
||
88 | |||
89 | # Extend the client library with necessary server code |
||
90 | class ::Mysql |
||
1 ignored issue
–
show
|
|||
# Simple class which doesn't do connection setup
class ServerProtocol < Protocol
  def initialize(socket)
    # We need a much simpler initialization than the default class, so the
    # parent initializer is not called; only the socket is kept
    @sock = socket
  end

  # Perform the server side of the handshake: send the greeting, read the
  # client's reply (credentials are not checked) and acknowledge with OK
  def authenticate
    reset
    write InitialPacket.serialize
    AuthenticationPacket.parse read # TODO: Check auth
    write ResultPacket.serialize 0
  end

  # Send an error message with the given number and text
  # @param errno [Integer] MySQL error code
  # @param message [String] text reported to the client
  def error(errno, message)
    write ErrorPacket.serialize errno, message
  end

  # Read and handle a single incoming command
  # @yieldparam protocol [ServerProtocol] this object, so the block can
  #   reply with #error
  # @yieldparam query [String] SQL text received with COM_QUERY
  # @yieldreturn [Enumerable<Hash>, nil] rows keyed by column name, or nil
  #   when nothing further should be sent
  def process_command(&block)
    reset
    pkt = read
    command = pkt.utiny

    case command
    when COM_QUIT
      # Stop processing because the client left
      return
    when COM_QUERY
      process_query pkt.to_s, &block
    when COM_PING
      write ResultPacket.serialize 0
    else
      # Return an error for unsupported commands via our own #error
      # (there is no `protocol` variable in this scope)
      error ::Mysql::ServerError::ER_NOT_SUPPORTED_YET,
            'Command not supported'
    end
  end

  private

  # Handle an individual query and send its result set to the client
  def process_query(query)
    # Execute the query on the backend
    result = yield self, query
    return if result.nil?

    first_row = result.first
    if first_row.nil?
      # With no rows there are no column names to describe, and a column
      # count of zero is itself a complete OK packet. Send only that, so no
      # stray EOF packets follow it and desynchronize the client.
      write ResultPacket.serialize 0
      return
    end

    field_names = first_row.keys
    write_fields first_row, field_names
    write_rows result, field_names
  end

  # Write the column count, one definition per field, then an EOF marker
  # @param sample_row [Hash] row whose values determine each column's type
  # @param field_names [Array] column names in output order
  def write_fields(sample_row, field_names)
    write ResultPacket.serialize field_names.count
    field_names.each do |field_name|
      type, = Protocol.value2net sample_row[field_name]

      write FieldPacket.serialize '', '', '', field_name, '', 1, type,
                                  Field::NOT_NULL_FLAG, 0, ''
    end
    write EOFPacket.serialize
  end

  # Write one packet per row, each value stringified and encoded with
  # Protocol.value2net, followed by an EOF marker
  def write_rows(result, field_names)
    result.each do |row|
      # NOTE(review): nil values are sent as empty strings rather than SQL
      # NULL; confirm this is intended
      encoded = field_names.map do |field_name|
        Protocol.value2net(row[field_name].to_s).last
      end
      write encoded.join
    end
    write EOFPacket.serialize
  end
end
||
169 | |||
# Add serialization of the initial packet
class InitialPacket
  # Serialize the greeting the server sends before the client authenticates.
  # Both scramble strings are fixed placeholders since credentials are not
  # verified.
  # @return [String]
  def self.serialize
    capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION
    charset = 33 # utf8_general_ci
    fields = [
      ::Mysql::Protocol::VERSION, 'nose', 0, 'A' * 8, 0,
      capabilities, charset, SERVER_STATUS_AUTOCOMMIT, 'A' * 12
    ]
    # Per the MySQL handshake layout: version byte, NUL-terminated server
    # name, connection id, 8 scramble bytes, filler, capabilities, charset,
    # status flags, 13 zero bytes, NUL-terminated rest of the scramble
    fields.pack('CZ*Va8CvCvx13Z*')
  end
end
||
188 | |||
# Add serialization of result packets
class ResultPacket
  # Build the first packet of a query response. A non-zero field count is
  # sent on its own as the result set header; zero yields a complete OK
  # packet carrying the counters, status, warnings and message.
  # rubocop:disable Metrics/ParameterLists
  # @return [String]
  def self.serialize(field_count, affected_rows = 0, insert_id = 0,
                     server_status = 0, warning_count = 0, message = '')
    header = Packet.lcb(field_count)
    return header if field_count.nonzero?

    counters = Packet.lcb(affected_rows) + Packet.lcb(insert_id)
    status_and_warnings = [server_status, warning_count].pack('vv')
    header + counters + status_and_warnings + Packet.lcs(message)
  end
  # rubocop:enable Metrics/ParameterLists
end
||
209 | |||
# Add serialization of field packets
class FieldPacket
  # Serialize the description of one result column: the 'def' catalog and
  # the name strings, then the fixed-size attributes, then the default.
  # rubocop:disable Metrics/ParameterLists
  # @return [String]
  def self.serialize(db, table, org_table, name, org_name, length, type,
                     flags, decimals, default)
    names = ['def', db, table, org_table, name, org_name]
            .map { |text| Packet.lcs(text) }
            .inject(:+)
    # 0x0c: length marker for the fixed-size attributes (ColumnDefinition41
    # layout); 33: utf8_general_ci; trailing 0: two filler bytes
    attributes = [0x0c, 33, length, type, flags, decimals, 0].pack('CvVCvCv')
    names + attributes + Packet.lcs(default)
  end
  # rubocop:enable Metrics/ParameterLists
end
||
235 | |||
# Add parsing of auth packets
class AuthenticationPacket
  # Accept the client's authentication reply without decoding it.
  # Credentials are not checked yet, so no field is read. When they are,
  # the reply is expected to be read in this order:
  #   client_flags = pkt.ulong
  #   max_packet_size = pkt.ulong
  #   charset_number = pkt.lcb
  #   f1 = pkt.read(23)
  #   username = pkt.string
  #   scrambled_password = pkt.lcs
  #   databasename = pkt.string
  # @param _pkt the raw reply (ignored)
  # @return [nil]
  def self.parse(_pkt)
    nil
  end
end
||
250 | |||
# Simple EOF packet
class EOFPacket
  # Marker byte 0xfe followed by four zero bytes (warning count and status
  # flags per the MySQL protocol)
  BYTES = "\xfe\x00\x00\x00\x00".freeze

  # Static string to indicate EOF
  # @return [String]
  def self.serialize
    BYTES
  end
end
||
259 | |||
# Serialize an error message
class ErrorPacket
  # SQLSTATE sent when the caller does not supply one ("general error")
  DEFAULT_SQLSTATE = 'HY000'.freeze

  # Generate a packet with a given error number and message.
  # The SQLSTATE used to come from an unset class-level @sqlstate, which is
  # nil and made pack('a5') raise TypeError; it is now an explicit argument.
  # @param errno [Integer] MySQL error code
  # @param message [String] error text for the client
  # @param sqlstate [String] five-character SQLSTATE written after '#'
  # @return [String]
  def self.serialize(errno, message, sqlstate = DEFAULT_SQLSTATE)
    [
      0xff,
      errno,
      '#',
      sqlstate,
      message
    ].pack('Cvaa5a*')
  end
end
||
274 | end |
||
275 | end |
||
276 |
Instead of combining definitions as much as possible, nest the definitions more verbosely.