Passed
Pull Request — master (#8)
by Konstantinos
04:04
created

gui-demo2.update_image_thread()   B

Complexity

Conditions 5

Size

Total Lines 38
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 24
nop 5
dl 0
loc 38
rs 8.8373
c 0
b 0
f 0
1
import tkinter as tk
2
from tkinter import filedialog
3
import threading
4
import matplotlib.pyplot as plt
5
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
6
from matplotlib.figure import Figure
7
import numpy as np
8
from PIL import Image, ImageTk  # You need to install the Python Imaging Library (PIL)
9
10
# from artificial_artwork._demo import create_algo_runner
11
from artificial_artwork._main import create_algo_runner
12
from artificial_artwork.image import convert_to_uint8
13
14
15
# CONSTANTS
16
# Static UI texts for each input-image component, keyed by image type.
IMAGE_COMP_ASSETS = {
    'content': {
        'load_button_text': "Select Content Image",
        'label_text': "Content Image:",
    },
    'style': {
        'load_button_text': "Select Style Image",
        'label_text': "Style Image:",
    },
}

# width x height
WINDOW_GEOMETRY: str = '2600x1800'

# Content and Style Images rendering dimensions
INPUT_IMAGE_THUMBNAIL_SIZE = (200, 200)

# Generated Image rendering dimensions
GENERATED_IMAGE_THUMBNAIL_SIZE = (500, 500)


# Helpers Objects

# Maps an image type ('content' / 'style') to the file path last picked
# through the file dialogs; filled in by the dialog callbacks below.
img_type_2_path = {}
40
41
# Helper Functions
42
def _build_open_image_dialog_callback(image_file_label, image_type: str):
    """Build a zero-argument callback that asks the user for an image file.

    Args:
        image_file_label: widget whose ``text`` option is updated with the
            chosen path.
        image_type: key into ``IMAGE_COMP_ASSETS`` / ``img_type_2_path``
            (the file uses 'content' and 'style').

    Returns:
        A callable taking no arguments; when invoked it opens a file dialog
        and, if a path was returned, records it and refreshes the label.
    """
    def _open_image_dialog():
        file_path = filedialog.askopenfilename()
        if not file_path:
            # Nothing was selected: leave stored path and label untouched.
            return
        img_type_2_path[image_type] = file_path
        label_prefix = IMAGE_COMP_ASSETS[image_type]["label_text"]
        image_file_label.config(text=f'{label_prefix} {file_path}')

    return _open_image_dialog
49
50
51 View Code Duplication
def _build_open_image_dialog_callback_v2(x, image_type: str):
    """Build a zero-argument callback that picks an image and previews it.

    Args:
        x: mapping with the keys ``'image_label'`` (widget showing the chosen
            path) and ``'image_pane'`` (widget showing the thumbnail).
        image_type: key into ``IMAGE_COMP_ASSETS`` / ``img_type_2_path``
            (the file uses 'content' and 'style').

    Returns:
        A callable taking no arguments; when invoked it opens a file dialog
        and, if a path was returned, renders a thumbnail, records the path
        and refreshes the label.
    """
    def _open_file_dialog_v2():
        file_path = filedialog.askopenfilename()
        if not file_path:
            # Nothing was selected: leave stored path and widgets untouched.
            return

        image_label = x['image_label']
        image_pane = x['image_pane']

        # Load and downscale BEFORE recording the path, so a file that cannot
        # be read as an image raises without leaving an unusable path behind
        # in img_type_2_path.
        image = Image.open(file_path)
        image.thumbnail(INPUT_IMAGE_THUMBNAIL_SIZE)  # Resize the image to fit in the pane
        photo = ImageTk.PhotoImage(image=image)

        img_type_2_path[image_type] = file_path

        image_pane.config(image=photo)
        # Keep a reference on the widget so the PhotoImage stays alive.
        image_pane.image = photo

        label_prefix = IMAGE_COMP_ASSETS[image_type]["label_text"]
        image_label.config(text=f'{label_prefix} {file_path}')
        image_label.update_idletasks()

    return _open_file_dialog_v2
70
71
72
# MAIN
73
74
# Per image type: the static UI texts from IMAGE_COMP_ASSETS plus two factories
# that turn widget(s) into a ready-to-use "open file dialog" callback.
images_components_data = {
    'content': dict(
        IMAGE_COMP_ASSETS['content'],
        image_dialog_from_label=lambda label_obj: _build_open_image_dialog_callback(label_obj, 'content'),
        image_dialog=lambda x: _build_open_image_dialog_callback_v2(x, 'content'),
    ),
    'style': dict(
        IMAGE_COMP_ASSETS['style'],
        image_dialog_from_label=lambda label_obj: _build_open_image_dialog_callback(label_obj, 'style'),
        image_dialog=lambda x: _build_open_image_dialog_callback_v2(x, 'style'),
    ),
}
86
87
# Create the main window
88
root = tk.Tk()
root.title("Neural Style Transfer - Desktop")
# width x height; use the module constant rather than repeating its literal
root.geometry(WINDOW_GEOMETRY)  # Larger window size

# Add a label to describe the purpose of the GUI
description_label = tk.Label(root, text="Select a file using the buttons below:")
description_label.pack(pady=10)  # Add padding
96
97
98
# CONTENT IMAGE UI/UX
99
100
# BUTTON -> Load Content Image
101
button1 = tk.Button(
    root,
    text=images_components_data['content']['load_button_text'],
    # command=lambda: images_components_data['content']['image_dialog_from_label'](content_image_label)(),
    # The lambda defers the lookup of content_image_label / content_image_pane
    # (created further down) until the button is actually pressed.
    command=lambda: images_components_data['content']['image_dialog']({
        'image_label': content_image_label,
        'image_pane': content_image_pane,
    })(),
)
button1.pack(pady=5)  # Add padding

# LABEL -> Show path of loaded Content Image
content_image_label = tk.Label(root, text=images_components_data['content']['label_text'])
content_image_label.pack()

# LABEL -> PANE to Render the Content Image
content_image_pane = tk.Label(root, width=0, height=0, bg="white")  # Set initial dimensions to 0
# content_image_pane = tk.Label(root, width=200, height=200, bg="white")
content_image_pane.pack()
120
121
122
# STYLE IMAGE UI/UX
123
124
# BUTTON -> Load Style Image
125
load_style_image_btn = tk.Button(
    root,
    text=images_components_data['style']['load_button_text'],
    # command=lambda: images_components_data['style']['image_dialog_from_label'](style_image_label)()
    # The lambda defers the lookup of style_image_label / style_image_pane
    # (created further down) until the button is actually pressed.
    command=lambda: images_components_data['style']['image_dialog']({
        'image_label': style_image_label,
        'image_pane': style_image_pane,
    })(),
)
load_style_image_btn.pack(pady=5)  # Add padding

# LABEL -> Show path of loaded Style Image
style_image_label = tk.Label(root, text=images_components_data['style']['label_text'])
style_image_label.pack()

# LABEL -> PANE to Render the Style Image
style_image_pane = tk.Label(root, width=0, height=0, bg="white")  # Set initial dimensions to 0
# style_image_pane = tk.Label(root, width=200, height=200, bg="white")
style_image_pane.pack()
144
145
146
# GENERATED IMAGE UI/UX
147
148
# Helper Update Callback
149
# def update_image_thread(progress, gen_image_pane, _iteration_count_label, fig, combined_subplot):
150
#     t = threading.Thread(
151
#         target=update_image,
152
#         args=(progress, gen_image_pane, _iteration_count_label, fig, combined_subplot)
153
#     )
154
#     t.start()
155
156
#### UPDATE UI based on BACKEND progress ####
157
158
# Function to update the GUI with the result from the backend task
159
# def update_image(progress, gen_image_pane, _iteration_count_label, fig, combined_subplot):
160
def update_image_thread(progress, gen_image_pane, _iteration_count_label, fig, combined_subplot):
    """Refresh the generated-image pane, the iteration label and the cost chart.

    Args:
        progress: object exposing ``state.matrix`` (the image array) and
            ``state.metrics`` (a mapping holding at least ``'iterations'``).
        gen_image_pane: widget whose ``image`` option shows the generated image.
        _iteration_count_label: widget whose ``text`` shows the iteration count.
        fig: not used by this function; kept so existing callers keep working.
        combined_subplot: axes-like object forwarded to ``update_chart``.
    """
    state = progress.state
    metrics = state.metrics
    current_iteration_count: int = metrics['iterations']

    # Always bind ``matrix``: previously it was only assigned inside the
    # 4-dimensional branch, so any other input raised UnboundLocalError below.
    matrix = state.matrix

    # if we have shape of form (1, Width, Height, Number_of_Color_Channels)
    if matrix.ndim == 4 and matrix.shape[0] == 1:
        # reshape to (Width, Height, Number_of_Color_Channels)
        matrix = np.reshape(matrix, tuple(matrix.shape[1:]))

    if str(matrix.dtype) != 'uint8':
        matrix = convert_to_uint8(matrix)

    image = Image.fromarray(matrix)

    # Resize the image to fit in the pane
    image.thumbnail(GENERATED_IMAGE_THUMBNAIL_SIZE)
    # Convert the image to PhotoImage
    photo = ImageTk.PhotoImage(image=image)
    # Update the image label with the new image
    gen_image_pane.config(image=photo)
    # Keep a reference on the widget so the PhotoImage stays alive.
    gen_image_pane.image = photo

    _iteration_count_label.config(text=f'Iteration Count: {current_iteration_count}')

    if 'cost' not in metrics:
        # No cost entries in this update: nothing to add to the chart.
        return

    # backend has evaluated the costs into scalars (floats)
    # Update metrics
    total_cost_values.append(metrics['cost'])
    style_cost_values.append(metrics['style-cost-weighted'])
    content_cost_values.append(metrics['content-cost-weighted'])
    iteration_values.append(current_iteration_count)

    # Update the graph
    update_chart(
        iteration_values,
        total_cost_values,
        style_cost_values,
        content_cost_values,
        combined_subplot,
    )
199
200
################
201
202
# LABEL -> Text to display above Live Updated Generated Image
203
generated_image_label = tk.Label(root, text="Generated Image:")
generated_image_label.pack(pady=10)

# LABEL -> Live Display of Generated Image ! (this will be updated during the learning loop)
generated_image_pane = tk.Label(root, width=0, height=0, bg="white")  # Set initial dimensions to 0
generated_image_pane.pack(pady=5)

# ITERATION COUNT UI/UX
# LABEL -> Iteration Count Live Update
iteration_count_label = tk.Label(root, text="Iteration Count:")
iteration_count_label.pack(pady=5)
214
215
216
# RUN NST ALGORITHM UI/UX
217
218
# Helper Run Functions
219
220
# Run NST Computations in a non-blocking way
221
222 View Code Duplication
def run_nst(fig, combined_subplot):
    """Run the NST algorithm on the selected images, streaming progress to the UI.

    Does nothing unless both a content and a style image path have been
    recorded in ``img_type_2_path``.

    Args:
        fig: figure object forwarded to ``update_image_thread``.
        combined_subplot: axes-like object forwarded to ``update_image_thread``.
    """
    # ``.get`` instead of indexing: a missing selection used to raise KeyError
    # here, which made the truthiness check below unreachable in that case.
    content_image_path = img_type_2_path.get('content')
    style_image_path = img_type_2_path.get('style')
    if not (content_image_path and style_image_path):
        return

    # Run tf.compat.v1.reset_default_graph()
    # and tf.compat.v1.disable_eager_execution()
    # Initialize Session as tf.compat.v1.InteractiveSession()
    backend_object = create_algo_runner(
        iterations=100,  # NB of Times to pass Image through the Network
        output_folder='gui-output-folder',  # Output Folder to store gen img snapshots
        noisy_ratio=0.6,
    )

    def _forward_progress(progress):
        # NOTE(review): this is invoked by the backend, presumably on the
        # worker thread that runs run_nst; confirm that updating Tk widgets
        # from there is safe on the targeted platforms.
        return update_image_thread(
            progress,
            generated_image_pane,
            iteration_count_label,
            fig,  # Pass the Figure to the update function
            combined_subplot,  # Pass the combined subplot to the update function
        )

    # The subscriber is a class object whose ``update`` attribute is a plain
    # function taking only ``progress``.
    observer = type('Observer', (), {'update': _forward_progress})
    backend_object['subscribe'](observer)

    backend_object['run'](
        content_image_path,
        style_image_path,
    )
251
252
# Define the thread that runs the NST Algorithm
253
def start_nst_thread():
    """Create the metrics chart and launch ``run_nst`` on a background thread.

    The thread is a daemon so it does not keep the process alive once the
    main program exits.
    """
    # NOTE(review): every call builds and packs a new chart via
    # initialize_graph; confirm whether repeated presses are expected.
    fig, combined_subplot = initialize_graph(root)
    nst_thread = threading.Thread(
        target=run_nst,
        args=(fig, combined_subplot),
        daemon=True,  # Set as a daemon thread to exit when the main program exits
    )
    nst_thread.start()
258
259
260
# BUTTON -> Run NST Algorithm on press
261
run_nst_btn = tk.Button(
    root,
    text="Run NST Algorithm",
    command=start_nst_thread,
)

run_nst_btn.pack(pady=5)  # Add padding
268
269
270
# PLOTTING
271
272
# Parallel series appended to by update_image_thread: entry i of each cost
# list belongs to iteration_values[i].
total_cost_values = []
style_cost_values = []
content_cost_values = []
iteration_values = []
276
277
278
# Helper Functions
279
# Initialize Matplotlib figure and subplot
280
def initialize_graph(root):
    """Create the metrics figure and embed it in the given Tk container.

    Args:
        root: Tk widget used as master of the embedded canvas.

    Returns:
        Tuple ``(fig, combined_subplot)`` as returned by ``plt.subplots``.
    """
    fig, combined_subplot = plt.subplots(figsize=(8, 6))
    combined_subplot.set_title('Metrics Over Iterations')
    combined_subplot.set_xlabel('Iterations')
    combined_subplot.set_ylabel('Metric Values')

    # Embed the figure in the Tk container and let it fill the free space.
    canvas = FigureCanvasTkAgg(fig, master=root)
    canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

    return fig, combined_subplot
291
292
293
# Update Matplotlib chart with metrics data
294
def update_chart(_iteration_values, _total_cost_values, _style_cost_values, _content_cost_values, _combined_subplot):
    """Redraw the metrics chart from scratch with the given series.

    Args:
        _iteration_values: x values shared by all three series.
        _total_cost_values: y values plotted as 'Total Cost'.
        _style_cost_values: y values plotted as 'Weighted Style Cost'.
        _content_cost_values: y values plotted as 'Weighted Content Cost'.
        _combined_subplot: axes-like object to draw on; its
            ``figure.canvas`` is redrawn at the end.
    """
    # (y values, legend label, marker) for each plotted series, in draw order.
    series = (
        (_total_cost_values, 'Total Cost', 'o'),
        (_style_cost_values, 'Weighted Style Cost', 's'),
        (_content_cost_values, 'Weighted Content Cost', 'x'),
    )

    # Title and axis labels are set again after clear().
    _combined_subplot.clear()
    for y_values, label, marker in series:
        _combined_subplot.plot(_iteration_values, y_values, label=label, marker=marker)
    _combined_subplot.set_title('Metrics Over Iterations')
    _combined_subplot.set_xlabel('Iterations')
    _combined_subplot.set_ylabel('Metric Values')
    _combined_subplot.legend()

    _combined_subplot.figure.canvas.draw()
305
306
307
# TKINTER MAIN LOOP
308
root.mainloop()
309