Completed
Push — master ( 8e6260...e1105a )
by Anas
29s
created

plot_colors()   A

Complexity

Conditions 4

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 4
c 1
b 0
f 0
dl 0
loc 12
rs 9.2
1
#!/usr/bin/python
2
# -*- coding: utf-8 -*-
3
from modules.utils import caption_filter, get_image, send_image, get_param
4
from telegram.ext import CommandHandler, MessageHandler
5
from telegram.ext.dispatcher import run_async
6
from sklearn.cluster import KMeans
7
from telegram import ChatAction
8
from PIL import Image
9
import numpy as np
10
import datetime
11
import cv2
12
13
14
def module_init(gd):
    """Register the palette command handlers on the bot dispatcher.

    Stores ``gd.config["path"]`` in the module-level ``path`` global, which
    ``palette`` reads when fetching and saving images. Then, for every name in
    ``gd.config["commands"]``, registers ``palette`` twice:

    * as a ``MessageHandler`` filtered by ``caption_filter("/" + command)``
      (presumably matching media sent with the command as caption — confirm
      against ``modules.utils.caption_filter``), and
    * as a ``CommandHandler`` for ``/command``.

    :param gd: bot context object exposing ``config`` (a mapping with at least
        ``"path"`` and ``"commands"``) and ``dp`` (the dispatcher).
    """
    global path
    path = gd.config["path"]
    # gd.config["extensions"] used to be read here but was never used, so it
    # is no longer accessed.
    for command in gd.config["commands"]:
        gd.dp.add_handler(MessageHandler(caption_filter("/" + command), palette))
        gd.dp.add_handler(CommandHandler(command, palette))
22
23
24
@run_async
25
def palette(bot, update):
26
    filename = datetime.datetime.now().strftime("%d%m%y-%H%M%S%f")
27
    name = filename + "-palette"
28
    colors = get_param(update, 4, 1, 10)
29
    if colors is None:
30
        return
31
    try:
32
        extension = get_image(bot, update, path, filename)
33
    except:
34
        update.message.reply_text("I can't get the image! :(")
35
        return
36
    update.message.chat.send_action(ChatAction.UPLOAD_PHOTO)
37
    start_computing(path, filename, extension, colors, "flat")
38
    send_image(update, path, name, extension)
39
    print(datetime.datetime.now(), ">>>", "palette", ">>>", update.message.from_user.username)
40
41
42
def start_computing(path, filename, extension, colors, mode):
    """Build and save a palette image for ``path + filename + extension``.

    The saved image is three parts stacked vertically:

    * the original image (converted to RGB),
    * a 15 px dark-grey (30, 30, 30) separator,
    * the colour bar from ``plot_colors`` for the ``colors`` KMeans
      centroids of the image.

    It is written to ``path + filename + "-palette" + extension``.

    :param path: directory prefix. It is concatenated directly, so it
        presumably ends with a path separator.
    :param filename: file stem of the input image.
    :param extension: file extension including the dot (presumably, given the
        plain concatenation — confirm with ``get_image``).
    :param colors: number of clusters / palette colours.
    :param mode: bar layout passed to ``plot_colors`` (``"flat"`` or
        ``"percentage"``).
    """
    open_path = path + filename + extension
    save_path = path + filename + "-palette" + extension

    # Load image here; the context manager closes the source file.
    # convert() returns a new, fully loaded image.
    with Image.open(open_path) as source:
        pil_image = source.convert('RGB')
    width, height = pil_image.size
    original_image = np.float32(pil_image)

    # Cluster on a copy shrunk to fit 400x400 to bound KMeans cost.
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    pil_image.thumbnail((400, 400), Image.LANCZOS)
    # Reshape the image to be a list of RGB pixels.
    pixels = np.float32(pil_image).reshape(-1, 3)

    # Cluster the pixel intensities.
    clt = KMeans(n_clusters=colors, tol=0.001).fit(pixels)
    # Fraction of pixels assigned to each cluster.
    hist = centroid_histogram(clt)
    bar = plot_colors(hist, clt.cluster_centers_, width, height, colors, mode)

    # Separating line for image + palette stacking.
    separator = np.full((15, width, 3), 30, dtype="uint8")

    # Combine original image, separator and colour chart. All parts are
    # already RGB, so no channel swapping is needed (the former BGR<->RGB
    # conversions cancelled each other out).
    stacked = np.concatenate((original_image, separator, bar), axis=0)
    Image.fromarray(stacked.astype('uint8')).save(save_path)
70
71
72
def centroid_histogram(clt):
    """Return the fraction of samples assigned to each centroid of ``clt``.

    :param clt: fitted clustering model exposing ``labels_`` (one
        non-negative integer cluster index per sample) and
        ``cluster_centers_``.
    :returns: float array of length ``len(clt.cluster_centers_)``. Entry ``i``
        is the share of samples labelled ``i``. Entries sum to 1, or are all
        zero when there are no samples.
    """
    labels = np.asarray(clt.labels_, dtype=int)
    # Size the histogram by the number of centroids, not by the labels
    # actually seen. If a cluster ends up empty, binning over np.unique(labels)
    # would shorten the histogram and shift later counts onto the wrong
    # centroid when it is zipped with cluster_centers_.
    hist = np.bincount(labels, minlength=len(clt.cluster_centers_)).astype("float")
    total = hist.sum()
    if total:
        hist /= total
    return hist
78
79
80
def plot_colors(hist, centroids, width, height, number_of_colors, mode):
    """Draw a horizontal colour bar with one stripe per centroid.

    Stripes are laid out left to right in the order of
    ``zip(hist, centroids)``.

    :param hist: per-centroid weights, paired positionally with
        ``centroids``. They set stripe widths in ``"percentage"`` mode and are
        presumably fractions summing to ~1.
    :param centroids: one colour per stripe, each a 3-channel sequence.
        Values are clipped to 0..255 and truncated to ``uint8``. Channels are
        written in the order given.
    :param width: bar width in pixels.
    :param height: reference height; the bar is ``int(height * 0.2)`` px tall.
    :param number_of_colors: divisor for stripe width in ``"flat"`` mode.
    :param mode: ``"flat"`` for equal-width stripes (``width /
        number_of_colors`` each) or ``"percentage"`` for stripes
        ``percent * width`` wide.
    :returns: ``uint8`` array of shape ``(int(height * 0.2), width, 3)``;
        pixels not covered by any stripe stay black.
    :raises ValueError: if ``mode`` is neither ``"flat"`` nor ``"percentage"``.
    """
    if mode not in ("flat", "percentage"):
        raise ValueError("unknown palette mode: %r" % (mode,))
    bar_height = int(height * 0.2)
    bar = np.zeros((bar_height, width, 3), dtype="uint8")
    start_x = 0
    for percent, color in zip(hist, centroids):
        if mode == "flat":
            end_x = start_x + width / number_of_colors
        else:
            end_x = start_x + percent * width
        # Fill columns int(start_x)..int(end_x) inclusive, the same pixels a
        # filled cv2.rectangle covers. The next stripe overwrites the shared
        # edge column, and slicing clips at the right border.
        bar[:, int(start_x):int(end_x) + 1] = np.clip(color, 0, 255).astype("uint8")
        start_x = end_x
    return bar